当前位置:   article > 正文

Self-attention与multi-head self-attention

Self-attention与multi-head self-attention

自注意力(self-attention)允许模型在处理序列数据时,根据输入中的其他位置来加权考虑每个位置的信息。这对于处理长文本或序列中的依赖关系非常有用。

多头自注意力(multi-head self-attention)建立在自注意力机制之上,它通过允许模型同时关注不同表示子空间的信息,来增强模型捕捉不同类型的关系和依赖性的能力。

虽然自注意力专注于捕捉序列内部的依赖关系,但多头自注意力通过利用多个注意力头来捕捉不同类型的关系和依赖关系,提高了这种捕捉能力。

自注意力机制的实现:

  1. 计算注意力分数(Attention Scores)
    给定一个输入序列 X X X,我们首先将其投影到查询 Q Q Q,键 K K K,和值 V V V 的向量空间中,这是通过学习得到的权重矩阵 W Q W_Q WQ W K W_K WK W V W_V WV 实现的。然后,我们计算查询与键的点积,最后通过 softmax 函数进行标准化,得到注意力分数:

    Attention Scores = softmax ( Q K T d k ) \text{Attention Scores} = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) Attention Scores=softmax(dk QKT)
    其中, d k d_k dk 是键的维度。

  2. 计算加权和(Weighted Sum)
    注意力分数与值 V V V 相乘,得到每个位置的加权和。这一步捕捉了每个位置对其他所有位置的相对重要性。具体计算公式为:

    Self-Attention ( X ) = Attention Scores ⋅ V \text{Self-Attention}(X) = \text{Attention Scores} \cdot V Self-Attention(X)=Attention ScoresV

多头自注意力机制的实现

多头自注意力不只是单一地关注输入序列的全局信息,而是同时从不同的角度去看待输入序列。这就好比你在看待一个问题时,可以从不同的角度去思考,这样就能够更全面地理解问题了。

具体来说,多头自注意力包括了几个步骤:

  1. 投影(Projection)
    对于给定的输入序列 X X X,我们首先将其分别投影为查询 Q Q Q,键 K K K,和值 V V V 的向量。这可以通过线性变换实现:

    Q = X ⋅ W Q , K = X ⋅ W K , V = X ⋅ W V Q = X \cdot W_Q, \quad K = X \cdot W_K, \quad V = X \cdot W_V Q=XWQ,K=XWK,V=XWV
    这里, W Q W_Q WQ W K W_K WK W V W_V WV 是可学习的权重矩阵。

  2. 头的拆分(Splitting Heads)
    将每个投影后的向量 Q , K , V Q, K, V Q,K,V 分成 h h h 个头(通常 h h h 是一个超参数)。即,我们得到 h h h 组查询 Q i Q_i Qi,键 K i K_i Ki,值 V i V_i Vi,每组都是原始投影向量的子集。

  3. 注意力计算(Attention Computation)
    对于每个头 i i i,我们计算注意力分数:

    Attention i ( Q i , K i ) = softmax ( Q i K i T d k ) V i \text{Attention}_i(Q_i, K_i) = \text{softmax}\left(\frac{Q_iK_i^T}{\sqrt{d_k}}\right)V_i Attentioni(Qi,Ki)=softmax(dk QiKiT)Vi
    其中, d k d_k dk 是键的维度。通过这个步骤,每个头都得到了一组注意力加权的值。

  4. 头的合并(Concatenation of Heads)
    将每个头的注意力加权值拼接在一起,形成多头注意力的输出:
    MultiHead ( Q , K , V ) = Concat ( Attention 1 , Attention 2 , . . . , Attention h ) \text{MultiHead}(Q, K, V) = \text{Concat}(\text{Attention}_1, \text{Attention}_2, ..., \text{Attention}_h) MultiHead(Q,K,V)=Concat(Attention1,Attention2,...,Attentionh)

  5. 线性变换(Linear Transformation)
    将拼接后的输出通过另一个线性变换:
    MultiHead ( Q , K , V ) ⋅ W O \text{MultiHead}(Q, K, V) \cdot W_O MultiHead(Q,K,V)WO
    其中, W O W_O WO 是输出层的权重矩阵。

通过这些步骤,我们就得到了多头自注意力机制的最终输出,该输出保留了输入序列的各个部分,并且在多个注意力头的帮助下,能够更好地捕捉序列中的长程依赖关系。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小蓝xlanll/article/detail/379845
推荐阅读
相关标签
  

闽ICP备14008679号