当前位置:   article > 正文

PaLM中RoPE位置编码实现源码解析

rope 实现

1、源码

import torch
from einops import rearrange
from torch import einsum, nn

class RotaryEmbedding(nn.Module):
    """Precompute rotary position-embedding (RoPE) angles, PaLM-style.

    For a head dimension ``dim`` the frequencies are
    ``theta_i = 10000 ** (-2 * i / dim)`` for ``i = 0 .. dim/2 - 1``.
    Calling the module returns the angle table ``m * theta_i`` for positions
    ``m = 0 .. max_seq_len - 1``. The table is duplicated along the last axis
    so that it lines up with the "first half / second half" split that
    ``rotate_half`` applies to q and k.

    Args:
        dim: head dimension d; must be even because features are rotated in pairs.

    Raises:
        ValueError: if ``dim`` is odd.
    """

    def __init__(self, dim):
        super().__init__()
        if dim % 2:
            # An odd dim would yield a [length, dim + 1] table that cannot be
            # applied to q/k of width dim.
            raise ValueError(f"RotaryEmbedding expects an even dim, got {dim}")
        # inv_freq = [theta_0, theta_1, ..., theta_(d/2-1)], shape [d/2]
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        # Registered as a buffer (not a parameter), so it follows .to()/.cuda()
        # together with the module.
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, max_seq_len, *, device):
        """Return the RoPE angle table.

        Args:
            max_seq_len: number of positions m to generate (m = 0 .. max_seq_len - 1).
            device: device on which the table is created.

        Returns:
            Tensor of shape ``[max_seq_len, dim]``. Row m is
            ``[m*theta_0, ..., m*theta_(d/2-1), m*theta_0, ..., m*theta_(d/2-1)]``.
        """
        # Bring the frequencies onto the requested device; otherwise einsum
        # would mix devices whenever the module lives somewhere other than
        # `device`.
        inv_freq = self.inv_freq.to(device)
        # Positions m = 0, 1, ..., max_seq_len - 1, shape [length]
        seq = torch.arange(max_seq_len, device=device, dtype=inv_freq.dtype)
        # Outer product: freqs[m, i] = m * theta_i, shape [length, d/2]
        freqs = einsum("i , j -> i j", seq, inv_freq)
        # Duplicate to [length, d] so both halves of q/k see the same angles.
        return torch.cat((freqs, freqs), dim=-1)


def rotate_half(x):
    """Swap and negate the two halves of the last axis: ``[x1, x2] -> [-x2, x1]``.

    With ``x = [q0, q1, ..., q(d-1)]`` this returns
    ``[-q(d/2), -q(d/2+1), ..., -q(d-1), q0, q1, ..., q(d/2-1)]``.
    In other words, each coordinate pair ``(q_i, q_(i+d/2))`` is mapped to
    ``(-q_(i+d/2), q_i)``, which is the 90-degree rotation RoPE needs for
    its sine term.

    Args:
        x: q or k tensor whose last dimension d is even, e.g. [bs, head, length, d].

    Returns:
        Tensor with the same shape as ``x``.

    Raises:
        ValueError: if the last dimension of ``x`` is odd.
    """
    d = x.shape[-1]
    if d % 2:
        raise ValueError(f"rotate_half expects an even last dimension, got {d}")
    # x1 = [q0, ..., q(d/2-1)], x2 = [q(d/2), ..., q(d-1)]
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(pos, t):
    """Apply rotary position embedding to ``t``.

    Args:
        pos: angle table as returned by ``RotaryEmbedding``, shape [length, d].
            Each row is laid out as
            [m*theta_0, ..., m*theta_(d/2-1), m*theta_0, ..., m*theta_(d/2-1)].
        t: q or k tensor, e.g. [bs, head, length, d] = [q0, q1, ..., q(d-1)];
            ``pos`` broadcasts over the leading axes.

    Returns:
        ``t * cos(pos) + rotate_half(t) * sin(pos)``, with the broadcast shape
        of ``t`` and ``pos``. For example:
            out_0 = q0*cos(m*theta_0) - q(d/2)*sin(m*theta_0)
            out_1 = q1*cos(m*theta_1) - q(d/2+1)*sin(m*theta_1)
    """
    return (t * pos.cos()) + (rotate_half(t) * pos.sin())


if __name__ == '__main__':
    # Demo: apply RoPE to random q/k tensors of shape (bs, head, length, d).
    bs, n_heads, seq_len, head_dim = 2, 12, 10, 32
    q = torch.randn((bs, n_heads, seq_len, head_dim))  # q=[q0, q1, .., qd-1]
    k = torch.randn((bs, n_heads, seq_len, head_dim))
    print('q:', q[0][0][0])
    print('k:', k[0][0][0])
    rotary_emb = RotaryEmbedding(dim=head_dim)
    pos_emb = rotary_emb(max_seq_len=seq_len, device=torch.device('cpu'))  # [length, d]
    q_new, k_new = map(lambda t: apply_rotary_pos_emb(pos_emb, t), (q, k))
    # Position m=0 has angle 0, so it is left unchanged; m=1 shows the rotation.
    print('q_new[m=1]:', q_new[0][0][1])
    print('k_new[m=1]:', k_new[0][0][1])
    # RoPE only rotates coordinate pairs, so per-vector norms must be unchanged.
    print('norm preserved:', torch.allclose(q.norm(dim=-1), q_new.norm(dim=-1), atol=1e-5))

2、公式
$$
\begin{pmatrix} q_0 \\ q_1 \\ \vdots \\ q_{d/2-1} \\ q_{d/2} \\ \vdots \\ q_{d-2} \\ q_{d-1} \end{pmatrix}
\otimes
\begin{pmatrix} \cos(m\theta_0) \\ \cos(m\theta_1) \\ \vdots \\ \cos(m\theta_{d/2-1}) \\ \cos(m\theta_0) \\ \vdots \\ \cos(m\theta_{d/2-2}) \\ \cos(m\theta_{d/2-1}) \end{pmatrix}
+
\begin{pmatrix} -q_{d/2} \\ -q_{d/2+1} \\ \vdots \\ -q_{d-1} \\ q_0 \\ \vdots \\ q_{d/2-2} \\ q_{d/2-1} \end{pmatrix}
\otimes
\begin{pmatrix} \sin(m\theta_0) \\ \sin(m\theta_1) \\ \vdots \\ \sin(m\theta_{d/2-1}) \\ \sin(m\theta_0) \\ \vdots \\ \sin(m\theta_{d/2-2}) \\ \sin(m\theta_{d/2-1}) \end{pmatrix}
$$
其中 $\otimes$ 表示逐元素相乘,$m$ 为位置下标,$\theta_i = 10000^{-2i/d}$($i = 0, 1, \dots, d/2-1$)。
3、图形
观察上图,可以发现 $q_0$ 与 $q_{d/2}$ 相互作用,生成新的 $q^{new}_0$ 与 $q^{new}_{d/2}$,拆解后可以得到下式:

$$q^{new}_0 = q_0 \cos(m\theta_0) - q_{d/2} \sin(m\theta_0)$$
$$q^{new}_{d/2} = q_0 \sin(m\theta_0) + q_{d/2} \cos(m\theta_0)$$
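也即向量 $(q^{new}_0, q^{new}_{d/2})$ 由向量 $(q_0, q_{d/2})$ 逆时针旋转 $m\theta_0$ 得到:
$$
\begin{pmatrix} q^{new}_0 \\ q^{new}_{d/2} \end{pmatrix}
=
\begin{pmatrix} \cos(m\theta_0) & -\sin(m\theta_0) \\ \sin(m\theta_0) & \cos(m\theta_0) \end{pmatrix}
\begin{pmatrix} q_0 \\ q_{d/2} \end{pmatrix}
$$
一般地,对 $i = 0, 1, \dots, d/2-1$,向量 $(q_i, q_{i+d/2})$ 被逆时针旋转 $m\theta_i$。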
可与下面链接中 LLaMA 的 RoPE 实现做对比:
LLaMA中RoPE位置编码实现源码解析

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/繁依Fanyi0/article/detail/371424
推荐阅读
  

闽ICP备14008679号