赞
踩
1、源码
import torch
from einops import rearrange
from torch import einsum, nn
class RotaryEmbedding(nn.Module):
    """Precompute the rotation angles ``m * theta_i`` used by rotary position embedding.

    ``theta_i = base ** (-2i / dim)`` for ``i = 0, 1, ..., dim/2 - 1``.

    Args:
        dim: size of the feature axis that will be rotated; must be even,
            because features are rotated in pairs ``(x_i, x_(i + dim/2))``.
        base: frequency base (10000 by default, as in the original code).

    Raises:
        ValueError: if ``dim`` is odd.
    """

    def __init__(self, dim, base=10000):
        super().__init__()
        # An odd dim leaves one feature unpaired: the angle table would get
        # dim + 1 columns and could never broadcast against a [..., dim] input.
        if dim % 2 != 0:
            raise ValueError(f"dim must be even for rotary embedding, got {dim}")
        # inv_freq = [theta_0, theta_1, ..., theta_(d/2-1)], shape [d/2]
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, max_seq_len, *, device):
        """Return the angle table of shape ``[max_seq_len, dim]``.

        Row ``m`` is
        ``[m*theta_0, ..., m*theta_(d/2-1), m*theta_0, ..., m*theta_(d/2-1)]``.
        The half-width table is duplicated so that column ``i`` and column
        ``i + d/2`` carry the same angle ``m*theta_i``.
        """
        # positions m = 0, 1, ..., max_seq_len - 1, shape [length]
        seq = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype)
        # outer product: freqs[m, i] = m * theta_i, shape [length, d/2]
        freqs = einsum("i , j -> i j", seq, self.inv_freq)
        return torch.cat((freqs, freqs), dim=-1)  # [length, d]
def rotate_half(x):
    """Turn ``[x_0, ..., x_(d-1)]`` on the last axis into
    ``[-x_(d/2), ..., -x_(d-1), x_0, ..., x_(d/2-1)]``.

    The last axis is cut into a front half and a back half. The back half is
    negated and moved to the front. An odd-length last axis raises an error.
    """
    # [..., d] -> [..., 2, d/2]; slot 0 is the front half, slot 1 the back half
    halves = x.unflatten(-1, (2, -1))
    front = halves[..., 0, :]
    back = halves[..., 1, :]
    return torch.cat([back.neg(), front], dim=-1)
def apply_rotary_pos_emb(pos, t):
    """Rotate the features of ``t`` by the position-dependent angles in ``pos``.

    Args:
        pos: angle table, e.g. from ``RotaryEmbedding.forward``, shape
            ``[length, d]``; row ``m`` is
            ``[m*theta_0, ..., m*theta_(d/2-1), m*theta_0, ..., m*theta_(d/2-1)]``.
        t: query or key tensor whose trailing axes broadcast against ``pos``,
            e.g. ``[bs, head, length, d]``.

    Returns:
        ``t * cos(pos) + rotate_half(t) * sin(pos)``. For each ``i < d/2`` the
        pair ``(t_i, t_(i+d/2))`` at position ``m`` is rotated
        counter-clockwise by ``m*theta_i``:
            out_i       = t_i * cos(m*theta_i) - t_(i+d/2) * sin(m*theta_i)
            out_(i+d/2) = t_(i+d/2) * cos(m*theta_i) + t_i * sin(m*theta_i)
    """
    return (t * pos.cos()) + (rotate_half(t) * pos.sin())
if __name__ == '__main__':
    # Demo: apply RoPE to random q and k laid out as (bs, head, length, d).
    bs, heads, length, dim = 2, 12, 10, 32
    q = torch.randn((bs, heads, length, dim))  # last axis is [q0, q1, .., q(d-1)]
    k = torch.randn((bs, heads, length, dim))
    print('q:', q[0][0][0])
    print('k:', k[0][0][0])
    rotary_emb = RotaryEmbedding(dim=dim)
    pos_emb = rotary_emb(max_seq_len=length, device=torch.device('cpu'))  # [length, d]
    q_new, k_new = map(lambda t: apply_rotary_pos_emb(pos_emb, t), (q, k))
    # Position m=0 has angle 0, so these rows equal the inputs printed above;
    # position 1 shows an actual rotation.
    print('q_new (m=0):', q_new[0][0][0])
    print('k_new (m=0):', k_new[0][0][0])
    print('q_new (m=1):', q_new[0][0][1])
    # Each (x_i, x_(i+d/2)) pair is rotated, so vector norms must be unchanged.
    print('norm preserved:',
          torch.allclose(q.norm(dim=-1), q_new.norm(dim=-1), atol=1e-5))
2、公式
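$$
\begin{pmatrix} q_0 \\ q_1 \\ \vdots \\ q_{d/2-1} \\ q_{d/2} \\ \vdots \\ q_{d-2} \\ q_{d-1} \end{pmatrix}
\ast
\begin{pmatrix} \cos(m\theta_0) \\ \cos(m\theta_1) \\ \vdots \\ \cos(m\theta_{d/2-1}) \\ \cos(m\theta_0) \\ \vdots \\ \cos(m\theta_{d/2-2}) \\ \cos(m\theta_{d/2-1}) \end{pmatrix}
+
\begin{pmatrix} -q_{d/2} \\ -q_{d/2+1} \\ \vdots \\ -q_{d-1} \\ q_0 \\ \vdots \\ q_{d/2-2} \\ q_{d/2-1} \end{pmatrix}
\ast
\begin{pmatrix} \sin(m\theta_0) \\ \sin(m\theta_1) \\ \vdots \\ \sin(m\theta_{d/2-1}) \\ \sin(m\theta_0) \\ \vdots \\ \sin(m\theta_{d/2-2}) \\ \sin(m\theta_{d/2-1}) \end{pmatrix}
$$
其中 $\ast$ 表示逐元素相乘,$\theta_i = 10000^{-2i/d}$,$m$ 为位置下标。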
3、图形
观察上图,可以发现 $q_0$ 和 $q_{d/2}$ 相互作用,生成新的 $q^{new}_0$ 和 $q^{new}_{d/2}$,拆解后可以得到下式:
$$q^{new}_0 = q_0 \cos(m\theta_0) - q_{d/2} \sin(m\theta_0)$$
$$q^{new}_{d/2} = q_0 \sin(m\theta_0) + q_{d/2} \cos(m\theta_0)$$
也即向量 $(q^{new}_0, q^{new}_{d/2})$ 由向量 $(q_0, q_{d/2})$ 逆时针旋转 $m\theta_0$ 得到。更一般地,对每个 $i < d/2$,向量 $(q_i, q_{i+d/2})$ 被逆时针旋转 $m\theta_i$。
可与下面链接中 LLaMA 的 RoPE 实现做对比:
LLaMA中ROPE位置编码实现源码解析
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。