Source: huggingface/transformers#25199
RoPE: rotary positional embedding
考虑处于序列中位置为 i 的特征
对于一对特征
其中的
然后计算:
最后将分成对的特征再拼接回去即可。
具体地,对于
在实现上,huggingface transformers
这样实现 LLaMA 的 RoPE 操作:
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors."""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
其计算中用到了一个 rotate_half
功能是将 x 切两半,但拼接的时候是拼成 [-x2,x1]
计算的时候,上半部结果是:
可是,这只对于 d=2 的情况下成立,当 d 不是 2 的时候,rotate_half
会将整个特征从中间切分,而非按顺序两两为一对地切分,例如,会把[x1, x2, x3, x4, x5, x6]
切分为 [x1, x2, x3], [x4, x5, x6]
这样再计算 embedding 的时候,会得到:
那么问题来了,huggingface transformers 的模型又是可以正常使用的,说明 RoPE 的实现并没有问题。原因在哪呢?
huggingface/transformers#25199 (comment)
这个人的回答帮我弄明白了:
原来,HF transformers 在加载 llama 参数的时候,会对 wq, wk 先进行一个置换操作:
# permute for sliced rotary
def permute(w, n_heads=n_heads, dim1=dim, dim2=dim):
return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
直接看代码不一定好理解,我们通过了一个具体例子来看一下这个 permute 到底做了什么:
>>> w = torch.randn(6,6)
>>> w
tensor([[ 0.0351, -1.8382, -0.4659, -0.6392, -1.4064, 2.5892],
[ 0.1871, -1.6733, -0.1340, 0.1229, -0.0832, 0.8563],
[-1.4261, 0.1210, -0.7404, -0.7363, 0.2171, -0.5006],
[ 1.1344, 0.9882, 0.5771, 1.6343, -0.5803, -0.6329],
[ 0.5153, -0.4251, 0.2446, 0.8374, -1.2831, 0.0325],
[-0.5279, -0.5472, -0.2414, 0.1889, 1.3524, -0.7277]])
>>> w1 = w.view(1, 3, 2, 6)
>>> w2 = w1.transpose(1,2).reshape(6,6)
>>> w2
tensor([[ 0.0351, -1.8382, -0.4659, -0.6392, -1.4064, 2.5892],
[-1.4261, 0.1210, -0.7404, -0.7363, 0.2171, -0.5006],
[ 0.5153, -0.4251, 0.2446, 0.8374, -1.2831, 0.0325],
[ 0.1871, -1.6733, -0.1340, 0.1229, -0.0832, 0.8563],
[ 1.1344, 0.9882, 0.5771, 1.6343, -0.5803, -0.6329],
[-0.5279, -0.5472, -0.2414, 0.1889, 1.3524, -0.7277]])
可以看到,这个 permute 操作实际上就是把原本的参数的第 [2k-1, 2k] 行,置换到第 [k, d/2+k] 行去了。那么,我们将得到特征 [x1, x3, x5, x2, x4, x6]
,这样的话,再进行 rotate_half 操作之后,我们会切分为 [x1, x3, x5], [x2, x4, x6]
,再计算 embedding,会得到:
至此,我们弄明白了 RoPE 的原理和公式,以及 HF 的实现。
但是,Meta 的官方 LLaMA 实现是这样做的:
def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor,) -> Tuple[torch.Tensor, torch.Tensor]:
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
首先,可以看到用到了 view_as_complex
函数,其作用是将输入的最后一维(必须为2)分别视作复数的实部和虚部,具体地,根据 torch 的例子:
>>> x=torch.randn(4, 2)
>>> x
tensor([[ 1.6116, -0.5772],
[-1.4606, -0.9120],
[ 0.0786, -1.7497],
[-0.6561, -1.6623]])
>>>torch.view_as_complex(x)
tensor([(1.6116-0.5772j), (-1.4606-0.9120j), (0.0786-1.7497j), (-0.6561-1.6623j)])
接下来解释 reshape
到底在做什么:
-
首先,假设 xq.shape 是 (batch_size, seq_len, hidden_dim)
-
那么
xq.shape[:-1]
即 (batch_size, seq_len) 这样一个 tuple -
*
操作表示将 tuple 展开,那么 reshape 的变量是:reshape(batch_size, seq_len, -1, 2)
-
这里的 -1 表达的意思是:不指定这一维的值,具体值可以推理得到。其实是这些库为了方便开发提供的自动化功能,实际上这一维的值就是
hidden_dim/2
,即将原本 hidden_dim 那一维的值,按顺序分为两两一组 -
给一个具体例子:
-
>>> x = torch.randn(1,2,4) >>> x tensor([[[-0.9001, 0.5472, -0.2143, 1.7586], [ 0.3527, -2.4083, 0.0380, -2.0296]]]) >>> x1 = x.reshape(1, 2, -1, 2) >>> x1 tensor([[[[-0.9001, 0.5472], [-0.2143, 1.7586]], [[ 0.3527, -2.4083], [ 0.0380, -2.0296]]]]) >>> x1.shape torch.Size([1, 2, 2, 2])
那么,这个 reshape 操作实际上就是将特征维度按顺序两两组成一对,之后再和每一对特征的encoding freqs_cis
相乘。
首先,要知道复数的乘法如何计算。对于 x1=a+bj, x2=c+dj 而言,x1*x2 = (ac-bd) + (ad+bc)j
那么,对于复数特征
即
最后,再将复数形式转化为实数形式,完成 RoPE。
总结一下,Meta 是按照标准的 RoPE 方式来实现,即先将特征维度按顺序两两配对,与 [cos, sin] encoding 进行计算后再恢复为正常结构。只是在实现的时候,使用复数作为中间形式,简化了向量操作。实际上,从 RoFormer 也就是提出 RoPE 的工作来看,其设计的出发点应该就是一种复数的表达,所以使用复数反而是更自然的。只是这确实是我第一次看到在 DL 的实现中使用。