MENU

Transformer 中的 Positional Encoding

July 7, 2020 • Read: 38649 • Deep Learning阅读设置

最近我在学习 Transformer 结构的时候,发现其中的 positional encoding 很不好理解,尤其是其中的公式,为什么要这样设计,后来上网收集各种资料,方才理解,遂于此写一篇文章进行记录

首先你需要知道,Transformer 是以字作为输入,将字进行字嵌入之后,再与位置嵌入进行相加(不是拼接,就是单纯的对应位置上的数值进行加和)

需要使用位置嵌入的原因也很简单,因为 Transformer 摈弃了 RNN 的结构,因此需要一个东西来标记各个字之间的时序 or 位置关系,而这个东西,就是位置嵌入

One possible solution to give the model some sense of order is to add a piece of information to each word about its position in the sentence. We call this “piece of information”, the positional encoding.

如果让我们从 0 开始设计一个 Positional Encoding,比较容易想到的第一个方法是取 [0,1] 之间的数分配给每个字,其中 0 给第一个字,1 给最后一个字,具体公式就是 $PE=\frac {pos}{T-1}$。这样做的问题在于,假设在较短文本中任意两个字位置编码的差为 0.0333,同时在某一个较长文本中也有两个字的位置编码的差是 0.0333。假设较短文本总共 30 个字,那么较短文本中的这两个字其实是相邻的;假设较长文本总共 90 个字,那么较长文本中这两个字中间实际上隔了两个字。这显然是不合适的,因为相同的差值,在不同的句子中却不是同一个含义

另一个想法是线性的给每个时间步分配一个数字,也就是说,第一个单词被赋予 1,第二个单词被赋予 2,依此类推。这种方式也有很大的问题:1. 它比一般的字嵌入的数值要大,难免会抢了字嵌入的「风头」,对模型可能有一定的干扰;2. 最后一个字比第一个字大太多,和字嵌入合并后难免会出现特征在数值上的倾斜

理想的设计

理想情况下,位置嵌入的设计应该满足以下条件:

  • 它应该为每个字输出唯一的编码
  • 不同长度的句子之间,任何两个字之间的差值应该保持一致
  • 它的值应该是有界的

作者设计的位置嵌入满足以上的要求。首先,它不是一个数字,而是一个包含句子中特定位置信息的 $d$ 维向量。其次,这种嵌入方式没有集成到模型中,相反,这个向量是用来给句子中的每个字提供位置信息的,换句话说,我们通过注入每个字位置信息的方式,增强了模型的输入(其实说白了就是将位置嵌入和字嵌入相加,然后作为输入

设 $t$ 为一句话中的某个字的位置,$\vec {p_t} \in \mathbb {R}^d$ 表示位置 $t$ 时刻这个词位置嵌入的向量,$\vec {p_t}$ 的定义如下

$$ \begin{align} \vec{p_t}^{(i)} = f(t)^{(i)} & := \begin{cases} \sin({\omega_k} . t), & \text{if}\ i = 2k \\ \cos({\omega_k} . t), & \text{if}\ i = 2k + 1 \end{cases} \end{align} $$

其中

$$ \omega_k = \frac{1}{10000^{2k / d}} $$

$k$​指的是位置嵌入中维度的下标,为了使得位置嵌入和字嵌入能够相加,因此位置嵌入维度和字嵌入的维度必须相同,所以 $i\in [0, d)$,因此就有 $k\in [0, \frac {d-1}{2}]$

对于三角函数 $y=A\sin (Bx+C)+D$ 来说,周期是 $\frac {2\pi}{B}$,频率为 $\frac {B}{2\pi}$,因此 B 越大,频率值越大,一个周期内函数图像重复次数越多,波长越短(如果这里的数学知识忘了,可以看这篇文章

回到 $\vec {p_t}$ 的定义中,$k$ 是越来越大的,因此 $w_k$ 越来越小,所以 $\frac {w_k}{2\pi}$ 也越来越小,于是频率随着向量维度下标的递增而递减,频率递减 = 周期变长。我们计算一下周期最小是 $2\pi$($k=0$ 时),周期最大是 $10000・2\pi$(假设 $k=\frac {d}{2}$ 时)

你可以想象下 $t$ 时刻字的位置编码 $\vec {p_t}$ 是一个包含 $\sin$ 和 $\cos$ 函数的向量(假设 $d$ 可以被 2 整除)

$$ \vec{p_t} = \begin{bmatrix} \sin({\omega_0}.t)\\ \cos({\omega_0}.t)\\ \\ \sin({\omega_1}.t)\\ \cos({\omega_1}.t)\\ \\ \vdots\\ \\ \sin({\omega_{\frac{d}{2}-1}}.t)\\ \cos({\omega_{\frac{d}{2}-1}}.t) \end{bmatrix}_{d \times 1} $$

直观展示

你可能想知道 $\sin$ 和 $\cos$ 的组合是如何表示位置信息的?这其实很简单,假设你想用二进制表示一个数字,你会怎么做?

$$ \begin{align} 0: \ \ \ \ \color{orange}{\texttt{0}} \ \ \color{green}{\texttt{0}} \ \ \color{blue}{\texttt{0}} \ \ \color{red}{\texttt{0}} & & 8: \ \ \ \ \color{orange}{\texttt{1}} \ \ \color{green}{\texttt{0}} \ \ \color{blue}{\texttt{0}} \ \ \color{red}{\texttt{0}} \\ 1: \ \ \ \ \color{orange}{\texttt{0}} \ \ \color{green}{\texttt{0}} \ \ \color{blue}{\texttt{0}} \ \ \color{red}{\texttt{1}} & & 9: \ \ \ \ \color{orange}{\texttt{1}} \ \ \color{green}{\texttt{0}} \ \ \color{blue}{\texttt{0}} \ \ \color{red}{\texttt{1}} \\ 2: \ \ \ \ \color{orange}{\texttt{0}} \ \ \color{green}{\texttt{0}} \ \ \color{blue}{\texttt{1}} \ \ \color{red}{\texttt{0}} & & 2: \ \ \ \ \color{orange}{\texttt{1}} \ \ \color{green}{\texttt{0}} \ \ \color{blue}{\texttt{1}} \ \ \color{red}{\texttt{0}} \\ 3: \ \ \ \ \color{orange}{\texttt{0}} \ \ \color{green}{\texttt{0}} \ \ \color{blue}{\texttt{1}} \ \ \color{red}{\texttt{1}} & & 11: \ \ \ \ \color{orange}{\texttt{1}} \ \ \color{green}{\texttt{0}} \ \ \color{blue}{\texttt{1}} \ \ \color{red}{\texttt{1}} \\ 4: \ \ \ \ \color{orange}{\texttt{0}} \ \ \color{green}{\texttt{1}} \ \ \color{blue}{\texttt{0}} \ \ \color{red}{\texttt{0}} & & 12: \ \ \ \ \color{orange}{\texttt{1}} \ \ \color{green}{\texttt{1}} \ \ \color{blue}{\texttt{0}} \ \ \color{red}{\texttt{0}} \\ 5: \ \ \ \ \color{orange}{\texttt{0}} \ \ \color{green}{\texttt{1}} \ \ \color{blue}{\texttt{0}} \ \ \color{red}{\texttt{1}} & & 13: \ \ \ \ \color{orange}{\texttt{1}} \ \ \color{green}{\texttt{1}} \ \ \color{blue}{\texttt{0}} \ \ \color{red}{\texttt{1}} \\ 6: \ \ \ \ \color{orange}{\texttt{0}} \ \ \color{green}{\texttt{1}} \ \ \color{blue}{\texttt{1}} \ \ \color{red}{\texttt{0}} & & 14: \ \ \ \ \color{orange}{\texttt{1}} \ \ \color{green}{\texttt{1}} \ \ \color{blue}{\texttt{1}} \ \ \color{red}{\texttt{0}} \\ 7: \ \ \ \ \color{orange}{\texttt{0}} \ \ \color{green}{\texttt{1}} \ \ \color{blue}{\texttt{1}} \ \ \color{red}{\texttt{1}} & & 15: \ \ \ \ \color{orange}{\texttt{1}} \ \ \color{green}{\texttt{1}} \ \ \color{blue}{\texttt{1}} \ \ \color{red}{\texttt{1}} \\ \end{align} $$

用二进制表示一个数字太浪费空间了,因此我们可以使用与之对应的连续函数 —— 正弦函数

参考文献

Last Modified: August 6, 2021
Archives Tip
QR Code for this page
Tipping QR Code
Leave a Comment

26 Comments
  1. 一只小白 一只小白

    解开了我对 position encoding 的困惑。有一点小建议:补充一下 position encoding 如何解决开篇提出的两个问题,逻辑感觉更严密

    1. mathor mathor

      @一只小白好的,这两天我再改改

  2. chrisL chrisL

    在任意时刻下的位置通过 WT 乘积送到 sin/cos,在不同 W 下的频率实现不同数字编码,T 记录时间步。也有一种解释,不同的频率本事就含有时间信息。博主好像最后没解释清楚最后如何表示这个 position 的。

  3. gq gq

    你好,请问这是怎么做到不同长度的句子之间,任何两个字之间的差值应该保持一致的?

  4. 李白二 李白二

    你好,博主解析的很漂亮!@(大拇指)
    按照你说的理想条件: 1. 它应该为每个字输出唯一的编码;2. 不同长度的句子之间,任何两个字之间的差值应该保持一致; 3. 它的值应该是有界的
    注释:评论过长,只好分两次写了

  5. 李白二 李白二

    我想这样的函数应该都能满足理想条件:
    向量 V =(pos/C)*f (t) 其中,pos 为字符的位置字符的位置索引,t 为 t 时刻字的位置编码(同你博客里边的 t 时刻意思一样), C 是一个比较大的常数;
    博主,你怎么看?

    1. mathor mathor

      @李白二是的,不过后来从 bert 开始,所有用到 positional encoding 的部分都不再是这种单纯的公式计算得到的,而是使用模型训练的,即 nn.Embedding ()

  6. ctrlplayer ctrlplayer

    我有个问题,positional encoding 层应该被训练吗?

    1. mathor mathor

      @ctrlplayer 在最原始的 transformer 的论文以及源码实现中,是不训练的

      到了 bert 及其之后的论文里,是训练的

    2. AI_magician AI_magician

      @mathor 据说是因为在 convolution 中的尝试中有用过但效果不好

    3. AI_magician AI_magician

      @mathor 据说是因为在 convolution 中的尝试中有用过但效果不好

  7. hhy hhy

    楼主好,我想请教一下,位置编码满足了不同长度的句子之间,任何两个字之间的差值应该保持一致,但是位置编码和嵌入向量相加之后,这些信息会被打乱,那么相加后的输入向量又是怎么体现出位置的前后关系的了

    1. mathor mathor

      @hhy 我的水平有限,无法解答你的问题,苏神最近有一篇关于位置编码的博客,或许可以帮到你
      https://spaces.ac.cn/archives/8231

  8. lp lp

    真的很强,受益匪浅,非常感谢博主的分享!!

  9. 张宇杰 张宇杰

    博主您好,我想问一下,我训练好了一个 transformer 用于做机器翻译,但是用您的 greedy_decoder 方法取实现的时候我发现得到的 decoder 是乱的。顺便说一下,我用的是 pytorch 封装的 transformer 就是 nn.TransformerEncoderLayer 之类的

  10. 亭亭玉立 亭亭玉立

    你好,在 k 的取值范围应该可以取到(d-1)/2 吧,这个上面写的是开区间

    1. mathor mathor

      @亭亭玉立感谢提醒,已修改

  11. Bryce Bryce

    感谢分享 @(真棒)

  12. 深度学习 - Transformer 详解 - StubbornHuang Blog

    [...] 如果不理解这里为何这么设计,可以看这篇文章 Transformer 中的 Positional Encoding。[...]

  13. 梦想就是摸鱼 梦想就是摸鱼

    赞啊,解释得太好了

  14. wildstone wildstone

    大神,​k 指的是位置嵌入中维度的下标是不是写错了,应该是 i 把?k 是成对的 2 个一组的组序号下标把?

  15. 小鱼 小鱼

    谢谢博主!让我更加明白了这个 transformer 的位置编码的概念。不过对于这个 t 的取值我不是很清楚它的取值,能够详细的讲解一下嘛?它的取值我知道肯定是跟句子中的词位置是有关系的。

  16. 小鱼 小鱼

    感谢博主,在你的另外的博客上面看到了,取值为 [0,max_sentence_length)。

  17. AI_magician AI_magician

    其实说白了的话也就是拼接,可以这么理解,先拼接一个独热编码向量,然后再乘以 W 矩阵,得到的结果也是这两个的相加,但是这个是可以 learn 的,效果不好才用一个公式得出数值相加

  18. 菜鸟 菜鸟

    很好的解释

  19. Vince Vince

    写的真好,但是怎么感觉戛然而止了。。。。。。。会写请您多写点谢谢 hhhh