
01
引言
Causal Self-Attention 通常又被称之为基于mask的Self-Attention , 是Transformer 模型的一个基本概念,尤其是在语言建模等自回归任务中。其目的是确保序列中的每个位置上的Token只关注它之前的位置(包括它自己),而不关注它之后位置上的Token 。这种机制可以防止未来位置的信息泄露,这对于预测句子中的下一个单词等任务至关重要。让我们深入了解Causal Self-Attention的工作原理以及如何实现它。
02
Self-Attention
在自注意力机制中,序列中的每个Token都可以关注其他每个位置上的Token。这需要计算一组注意力得分,用以表明每个位置对其他位置的关注程度。

注意力得分是通过Query 和Key 向量的点积,再乘以缩放因子经过softmax后计算得出的。这里,Q、K和 V均由输入序列经过线性层变换后得到。
03
Causal Self-Attention中的mask
为了加强因果关系的约束,我们对注意力得分进行了掩码处理。该掩码会将当前位置之后所有位置上的注意力得分设置为-∞(或一个非常大的负数),从而有效地在softmax操作后将其贡献值置为零。

其中,M是屏蔽mask矩阵,其中需要被屏蔽的位置处的值为-∞ ,其他位置的值为0。
04
掩码矩阵mask的作用
-
防止信息泄漏:确保模型在预测当前标记时不会使用未来信息。 -
支持自回归生成:允许模型一次生成一个文本标记,只关注已经生成的标记。
05
代码实现
让我们用 PyTorch 在 Python 中实现Causal Self-Attention 。
import torchimport torch.nn as nnimport torch.nn.functional as Fclass CausalSelfAttention(nn.Module):def __init__(self, embed_size, num_heads):super(CausalSelfAttention, self).__init__()self.num_heads = num_headsself.embed_size = embed_sizeself.head_dim = embed_size // num_headsassert self.head_dim * num_heads == embed_size, "Embedding size must be divisible by number of heads"self.values = nn.Linear(embed_size, embed_size, bias=False)self.keys = nn.Linear(embed_size, embed_size, bias=False)self.queries = nn.Linear(embed_size, embed_size, bias=False)self.fc_out = nn.Linear(embed_size, embed_size)def forward(self, x):N, seq_length, embed_size = x.shape# Split the embedding into num_heads different piecesvalues = self.values(x).view(N, seq_length, self.num_heads, self.head_dim)keys = self.keys(x).view(N, seq_length, self.num_heads, self.head_dim)queries = self.queries(x).view(N, seq_length, self.num_heads, self.head_dim)values = values.transpose(1, 2)keys = keys.transpose(1, 2)queries = queries.transpose(1, 2)# Scaled dot-product attentionenergy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])mask = torch.tril(torch.ones((seq_length, seq_length))).expand(N, 1, seq_length, seq_length)energy = energy.masked_fill(mask == 0, float("-1e20"))attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(N, seq_length, self.embed_size)out = self.fc_out(out)return out
测试代码如下:
# Example usageembed_size = 512num_heads = 8seq_length = 10x = torch.rand((1, seq_length, embed_size))causal_self_attention = CausalSelfAttention(embed_size, num_heads)output = causal_self_attention(x)print(output.shape) # Output: torch.Size([1, seq_length, embed_size])
上述过程可以总结如下:
-
初始化: CausalSelfAttention类使用embed_size和num_heads进行初始化。嵌入embeddings在各头之间平均分配。 线性层:我们为
key、value和query创建线性层。重塑:对输入张量进行
reshape操作,以分离头部。基于点积的注意力:计算注意力得分,应用掩码
mask以确保因果关系,并用 softmax 进行归一化处理。合并heads: 合并各头并通过最后的线性层后输出。
06
结论
Causal SelfAttention是一种强大的注意力机制,可确保模型在自回归任务中保证数据的顺序性。通过实施该机制,大家可以构建模型,在不窥探未来的情况下一步步生成新的序列。通过本文,希望大家可以对因果自注意力机制有更加扎实的了解,并知道如何在自己的模型中实现它。祝大家编码愉快!
点击上方小卡片关注我
添加个人微信,进专属粉丝群!


