赞
踩
论文: LLaMA: Open and Efficient Foundation Language Models
作者: Meta AI
代码: LLaMA
特点: 该方法在Transformer的基础上增加了Pre-normalization (RMSNorm)、SwiGLU activation function (SwiGLU)、Rotary Embeddings (RoPE)、FlashAttention。
⚠️ 在学习该方法前,建议补充BatchNorm、LayerNorm、位置编码、Attention的相关知识。
Transformer和LLaMA的结构图如下:
可见,其结构差异主要体现在如下方面:
- Transformer采用了左编码器+右解码器(Encoder+Decoder)的结构,LLaMA采用了仅解码器(Decoder-only)的结构。由于仅包含解码器不需要与编码器输出交互,故LLaMA去掉了Transformer中Decoder中间的交叉Multi-Head Attention和Add & Norm。
- LLaMA采用了归一化前置(Pre-normalization)的策略,将归一化操作放在了注意力、FFN前并在线性映射前增加了一个归一化。此外,LLaMA还将LayerNorm替换为了
RMSNorm
。- LLaMA将绝对位置编码替换为了旋转位置编码,即
RoPE
,这是一种只对Q和K进行位置编码的方式。- 为加速训练,LLaMA引入了
FlashAttention
。- LLaMA将ReLU替换为了
SwiGLU
。
均方根归一化RMSNorm
简化了LayerNorm
的计算。
要了解RMSNorm
,首先需回顾LayerNorm
的公式:
其中, x \boldsymbol{x} x为输入的token序列, E [ x ] = 1 n ∑ i = 1 n x i {\bf E}\boldsymbol{[x]}=\frac{1}{n}\sum_{i=1}^{n}\boldsymbol{x}_i E[x]=n1∑i=1nxi和 V a r [ x ] = 1 n ∑ i = 1 n ( x i − E [ x ] ) 2 {\bf Var}\boldsymbol{[x]}=\sqrt{\frac{1}{n}\sum_{i=1}^n(\boldsymbol{x}_i-{\bf E}\boldsymbol{[x]})^2} Var[x]=n1∑i=1n(xi−E[x])2 为 x \boldsymbol{x} x的均值和有偏方差, ϵ \boldsymbol{\epsilon} ϵ用来防止分母为0, γ \boldsymbol{\gamma} γ和 β \boldsymbol{\beta} β是可学习的参数用来缩放和平移。
RMSNorm
简化了LayerNorm
的计算,其公式如下:
其中, R M S [ x ] = 1 n ∑ i = 1 n x i 2 {\bf RMS}\boldsymbol{[x]}=\sqrt{\frac{1}{n}\sum_{i=1}^{n}\boldsymbol{x}_i^2} RMS[x]=n1∑i=1nxi2 是均方根。
可见,RMSNorm
与LayerNorm
主要有如下差别:
RMSNorm
无需计算均值 E [ x ] {\bf E}[\boldsymbol{x}] E[x]。RMSNorm
将有偏方差 V a r [ x ] {\bf Var[\boldsymbol{x}]} Var[x]替换为了均方根 R M S [ x ] {\bf RMS[\boldsymbol{x}]} RMS[x]。RMSNorm
无需平移项 γ \boldsymbol{\gamma} γ。
与LayerNorm
一样,RMSNorm
也能以句子或单词(token)为单位进行归一化,如下给出了以token为单位的代码示例。
import torch import torch.nn as nn class MyRMSNorm(nn.Module): def __init__(self, hidden_dim, eps=1e-8): super().__init__() # 防止分母计算为0 self._eps = eps # 仿射变换参数,缩放norm后的数据分布 self._gamma = nn.Parameter(torch.ones(hidden_dim)) def forward(self, input): # input(N,L,C) ms = input.pow(2).mean(dim=-1, keepdim=True) # 计算均方,token-wise input = input / torch.sqrt(ms + self._eps) # 执行标准化 return input * self._gamma # 仿射变换 if __name__ == '__main__': batch_size = 4 length = 2 hidden_dim = 3 input = torch.rand(4, 2, 3) myRMSN = MyRMSNorm(hidden_dim=hidden_dim) MyO = myRMSN(input) pytorchRMSN = nn.RMSNorm(normalized_shape=hidden_dim, elementwise_affine=False) # 不使用可学习的gamma和beta pytorchO = pytorchRMSN(input) print(MyO == pytorchO)
旋转位置编码RoPE
使用绝对位置信息设计旋转规则,使旋转后的数据能够表达相对位置信息。
要了解RoPE
,首先我们来了解一下二维空间的旋转。如下图:
其中,
X
=
[
ρ
cos
ϕ
,
ρ
sin
ϕ
]
X=[\rho\cos\phi,\rho\sin\phi]
X=[ρcosϕ,ρsinϕ]是一个二维向量,逆时针旋转
θ
\theta
θ度变成
X
R
(
θ
)
XR(\theta)
XR(θ)。此时
R
(
θ
)
=
[
cos
θ
,
sin
θ
−
sin
θ
,
cos
θ
]
R(\theta)=\left[
X
R
(
θ
)
=
[
ρ
cos
ϕ
,
ρ
sin
ϕ
]
[
cos
θ
,
sin
θ
−
sin
θ
,
cos
θ
]
=
ρ
[
cos
ϕ
cos
θ
−
sin
ϕ
sin
θ
,
cos
ϕ
sin
θ
+
sin
ϕ
cos
θ
]
=
[
ρ
cos
(
ϕ
+
θ
)
,
ρ
sin
(
ϕ
+
θ
)
]
XR(\theta)=[\rho\cos\phi,\rho\sin\phi]\left[
可见, X X X与 X R ( θ ) XR(\theta) XR(θ)仅差一个 θ \theta θ,所以二维空间逆时针旋转 θ \theta θ度可通过 R ( θ ) R(\theta) R(θ)实现。
旋转只改变角度,不改变长度。
RoPE
将旋转应用在了注意力模块的查询
Q
Q
Q和
K
K
K上。它将第
i
i
i个查询
Q
i
Q_i
Qi旋转
i
θ
i\theta
iθ的角度,再将第
j
j
j个键
K
j
K_j
Kj旋转
j
θ
j\theta
jθ的角度,那么
Q
i
K
j
T
Q_iK_j^T
QiKjT就会变成一个与相对位置
i
−
j
i-j
i−j相关的值。推导过程如下:
i i i和 j j j是查询 Q i Q_i Qi和 K j K_j Kj的绝对位置, i − j i-j i−j是它们的相对位置。
然而, Q i Q_i Qi和 K j K_j Kj的维度通常都是大于2的,我们假设它是 D D D且 D D D是2的整数倍,于是我们可以将 Q i Q_i Qi和 K j K_j Kj分别划分为 d = D 2 d=\frac{D}{2} d=2D个子空间,每个子空间都是二维的。
下图给出了一个 D = 10 D=10 D=10的例子,我们将 Q i Q_i Qi和 K j K_j Kj分为5个子空间并分配1个包括5个角度的旋转序列 Θ = ( θ 1 , θ 2 , ⋯ , θ 5 ) \Theta=(\theta_1,\theta_2,\cdots,\theta_5) Θ=(θ1,θ2,⋯,θ5),每个子空间的旋转角度是在对应旋转序列的基础上乘以 i i i或 j j j。
将其扩展到 d d d个子空间,可以得到如下信息:
其中, X i X_i Xi代指 Q i Q_i Qi或 K j K_j Kj。此时,这种旋转仍然具有相对位置的表达能力,证明如下:
显然,上面的 R ( i Θ ) R(i\Theta) R(iΘ)过于稀疏,为了提升计算效率,通常 d d d个子空间的旋转使用下式表达:
为避免token数过多,
i
θ
k
i\theta_k
iθk和
j
θ
k
j\theta_k
jθk重叠导致相对位置得不到表达(同一个子空间
k
k
k,绝对位置
i
i
i和
j
j
j不同,
i
θ
k
−
j
θ
k
=
2
m
π
i\theta_k-j\theta_k=2m\pi
iθk−jθk=2mπ时重叠,
m
m
m是一个整数),RoPE
使用了一个递减的等比数列作为
θ
\theta
θ序列,如下:
θ k \theta_k θk是递减的,这表示token中前几个子空间的旋转角度较大,越往后旋转角度越小。
事实上,为了方便我们通常不是将相邻的两个值划分至同一子空间,而是将D分为前后两个部分,前后各取一个依次组成子空间,例如[q0,q1,q2,q3]被划分为[q0,q2], [q1,q3]而不是[q0,q1], [q2,q3]。以下为使用这种方式进行子空间划分的RoPE
代码:
from torch.nn import functional as F import torch.nn as nn import torch import math class Rotator: """根据hidden_dim,和position_ids 生成对应的旋转位置编码, 和论文中定义略有不同,一个个二维的子空间被 分割到了前后两部分,分别进行旋转,然后拼接起来 """ def __init__(self, D, position_ids): """ position_ids: [seq_len], D 和单个头的hidden_dim对应 """ base = 10000 d = D / 2 B = base ** (1/d) theta_base = 1.0 / (B ** (torch.arange(0, d))) # 等比数列, $\Theta$ thetas = position_ids.outer(theta_base) # [seq_len, D/2] # 这里的子空间划分与讲解不同,[q0,q1,q2,q3] -> [q0,q2],[q1,q3]是两个子空间而不是[q0,q1],[q2,q3] full_thetas = torch.cat((thetas, thetas), dim=-1) # [seq_len, D] self.cos = full_thetas.cos() self.sin = full_thetas.sin() def rotate(self, x): """ x: [bs, num_attention_heads, seq_len, D] q: [bs, num_attention_heads, seq_len, D] cos: [seq_len, D] [x,y] @ [[cos, sin], [-sin, cos]] = [x*cos-y*sin, ycos+x*sin] =[x,y]*cos+[-y, x]*sin """ return x * self.cos + Rotator.reverse_half(x) * self.sin @staticmethod def reverse_half(q): """ q: [bs, num_attention_heads, seq_len, D] trick2 """ u = q[..., :q.shape[-1] // 2] # 认为是各个二维子空间的第一维的向量集结 v = q[..., q.shape[-1] // 2:] # 认为是各个二维子空间的第二维的向量集结 return torch.cat((-v, u), dim=-1) if __name__ == "__main__": batch_size = 2 num_heads = 3 D = 6 # 单个头的token向量长度 hidden_dim = D * num_heads seq_len = 4 position_ids = torch.arange(seq_len) rotator = Rotator(D, position_ids) x = torch.randn((batch_size, seq_len, hidden_dim)) # 对每个头分别进行旋转,[batch_size,seq_len,hidden_dim] -> [batch_size,seq_len,num_heads,D] -> [batch_size,num_heads,seq_len,D] x = x.view(batch_size, seq_len, num_heads, D).transpose(1, 2) x = rotator.rotate(x)
FlashAttention
以分块的形式进行注意力计算,避免了SRAM和HBM之间频繁读写导致的时间浪费。
详情请参考我之前的博客FlashAttention in NeurIPS 2022。
激活函数SwiGLU
是门控线性单元(Gated Linear Units, GLU
)的变体,下图红框中表达了GLU
的计算过程:
可见,GLU
会先使用两个带偏执的线性层映射输入
x
\boldsymbol{x}
x,分别记为
x
W
1
+
b
1
\boldsymbol{xW_1+b_1}
xW1+b1和
x
W
2
+
b
2
\boldsymbol{xW_2+b_2}
xW2+b2;其中一个线性映射后会跟一个非线性激活函数sigmoid
,记为
σ
(
x
W
1
+
b
1
)
\sigma(\boldsymbol{xW_1+b_1})
σ(xW1+b1);然后将左右两边的结果对应元素相乘即完成了GLU
,记为
σ
(
x
W
1
+
b
1
)
⊗
(
x
W
2
+
b
2
)
\sigma(\boldsymbol{xW_1+b_1})\otimes(\boldsymbol{xW_2+b_2})
σ(xW1+b1)⊗(xW2+b2)。
SwiGLU
对GLU
做了两点改进:
- 去掉了两个线性映射的偏执项,此时公式变成 σ ( x W 1 ) ⊗ ( x W 2 ) \sigma(\boldsymbol{xW_1})\otimes(\boldsymbol{xW_2}) σ(xW1)⊗(xW2)。
- 将
sigmoid
替换为了Swish
,此时公式变成 Swish β ( x W 1 ) ⊗ ( x W 2 ) \text{Swish}_{\beta}(\boldsymbol{xW_1})\otimes(\boldsymbol{xW_2}) Swishβ(xW1)⊗(xW2)。
Swish
的公式为
Swish
β
(
a
)
=
a
σ
(
β
a
)
=
a
1
+
e
−
β
a
\text{Swish}_{\beta}(a)=a\sigma(\beta a)=\frac{a}{1+e^{-\beta a}}
Swishβ(a)=aσ(βa)=1+e−βaa,在不同的
β
\beta
β下该非线性激活函数的曲线如下:
可见,当
β
\beta
β较大时,该曲线与ReLU
十分接近;当
β
=
1
\beta=1
β=1时,小于0但接近0的曲线变得更光滑且非单调。
SwiGLU
则选用了
β
=
1
\beta=1
β=1的Swish
,于是我们得到SwiGLU
的公式如下:
Swish
(
x
W
1
)
⊗
(
x
W
2
)
=
x
W
1
1
+
e
−
x
W
1
⊗
x
W
2
\text{Swish}(\boldsymbol{xW_1})\otimes(\boldsymbol{xW_2})=\frac{\boldsymbol{xW_1}}{1+e^{-\boldsymbol{xW_1}}}\otimes\boldsymbol{xW_2}
Swish(xW1)⊗(xW2)=1+e−xW1xW1⊗xW2
本博客仅做记录使用,无任何商业用途,参考内容如下:
解密旋转位置编码:数学基础、代码实现与绝对编码一体化探索
一文为你深度解析 LLaMA2 模型架构
Llama改进之——SwiGLU激活函数
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。