使用 PyTorch 实现 Transformer 结构(三):Transformer 封装与参数量计算

在上一篇文章中已经实现了完整的编解码器,那么接下去我们就可以实现完整的 Transformer 结构。不过在此之前,需要先为 Mask 编写几个辅助方法。

Mask

下面这个方法被用来创建 Padding Mask。例如,seq[1, 2, 3, 0, 0],那么 mask 则为 [False, False, False, True, True]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def create_padding_mask(seq: torch.Tensor, pad_token_id: int = 0):
"""
Create mask for padding tokens.
The mask indicates positions that are padding (True/1) and should be masked.
Args:
seq: Input sequence tensor of token IDs (batch_size, seq_len).
pad_token_id: The ID used for padding tokens. Defaults to 0.
Returns:
mask: Padding mask (batch_size, 1, 1, seq_len).
Positions with True (or 1) will be masked.
Positions with False (or 0) will be kept.
"""
# 结果是一个布尔张量 mask,形状与 seq 相同 (batch_size, seq_len)
# 在 mask 中,值为 True 的位置表示该 token 是一个 padding token,应该被屏蔽;值为 False 的位置表示是非 padding token,应该保留
mask = seq == pad_token_id
return mask.unsqueeze(1).unsqueeze(2)

下面这个方法用来掩盖掉未来的 Token。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
def create_future_mask(size: int):
"""
Create mask to prevent attention to future positions (look-ahead mask).
The mask indicates positions that are future (True/1) and should be masked.
Args:
size: Size of the square mask, typically target_seq_len.
Returns:
mask: Future mask (1, 1, size, size).
Positions with True (or 1) will be masked.
Positions with False (or 0) will be kept.
"""
# torch.ones((size, size), device=device) 创建一个全 1 的方阵
# torch.triu(..., diagonal=1) 获取这个方阵的上三角部分(不包括对角线)。
# - diagonal=0: 包括主对角线
# - diagonal=1: 从主对角线的上一条对角线开始
# - diagonal=-1: 从主对角线的下一条对角线开始
mask = torch.triu(torch.ones(size, size, device=device), diagonal=1)
mask = mask.bool()
return mask.unsqueeze(0).unsqueeze(0)

下面这个方法将上述两个方法封装,用于为后续 Transformer 类提供编码器和解码器各自所需的完整掩码。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
def create_masks(src: torch.Tensor, tgt: torch.Tensor, src_pad_id: int = 0, tgt_pad_id: int = 0):
"""
Create all masks needed for the encoder and decoder.
Args:
src: Source sequence tensor of token IDs (batch_size, src_len).
tgt: Target sequence tensor of token IDs (batch_size, tgt_len).
src_pad_id: Token ID for padding in the source sequence.
tgt_pad_id: Token ID for padding in the target sequence.
Returns:
src_mask: Padding mask for the encoder's self-attention and decoder's cross-attention.
Shape: (batch_size, 1, 1, src_len). True means mask.
tgt_mask: Combined look-ahead and padding mask for the decoder's self-attention.
Shape: (batch_size, 1, tgt_len, tgt_len). True means mask.
"""
# 1. 创建源序列的填充掩码 (src_mask)
# 这个掩码将用于:
# a) Encoder 的自注意力层 (屏蔽源序列中的填充)
# b) Decoder 的交叉注意力层 (屏蔽编码器输出中对应源序列填充的部分)
src_padding_mask = create_padding_mask(src)
# (batch_size, 1, 1, src_len)

# 2. 创建目标序列的填充掩码 (tgt_padding_mask)
# 这个掩码用于屏蔽目标序列 (作为 Key 时) 的填充部分,在解码器的自注意力层中使用。
# 注意:这里的 tgt 是输入给解码器的序列 (例如,decoder_input)
tgt_padding_mask = create_padding_mask(tgt)
# (batch_size, 1, 1, tgt_len)

# 3. 创建目标序列的前瞻掩码 (tgt_future_mask / look-ahead mask)
# 这个掩码用于防止解码器的自注意力层关注到未来的位置。
tgt_len = tgt.size(1)
tgt_future_mask = create_future_mask(tgt_len)
# (1, 1, tgt_len, tgt_len)

# 4. 组合目标序列的填充掩码和前瞻掩码得到最终的 tgt_mask
# tgt_mask 将用于解码器的自注意力层。
tgt_mask = tgt_padding_mask | tgt_future_mask
# (batch_size, 1, tgt_len, tgt_len)

return src_padding_mask, tgt_mask

Transformer

最后,完整的 Transformer 结构如下。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
class Transformer(nn.Module):
def __init__(
self,
src_vocab_size: int,
tgt_vocab_size: int,
src_pad_id: int = 0,
tgt_pad_id: int = 0,
):
super().__init__()

# 保存填充 ID
self.src_pad_id = src_pad_id
self.tgt_pad_id = tgt_pad_id

# 实例化 Encoder
self.encoder = Encoder(src_vocab_size)
# 实例化 Decoder
self.decoder = Decoder(tgt_vocab_size)

# 最终的线性投影层,将解码器输出映射到目标词汇表大小
# 输入维度是 d_model,输出维度是 tgt_vocab_size
self.final_layer = nn.Linear(d_model, tgt_vocab_size)

def forward(self, src: torch.Tensor, tgt: torch.Tensor):
"""
Args:
src: Source sequence tensor of token IDs (batch_size, src_len).
tgt: Target sequence tensor of token IDs (batch_size, tgt_len).
Returns:
output: Logits over the target vocabulary.
Shape: (batch_size, tgt_len, tgt_vocab_size).
"""
# 1. 创建源序列和目标序列的掩码
src_mask, tgt_mask = create_masks(src, tgt, self.src_pad_id, self.tgt_pad_id)
# src_mask: (batch_size, 1, 1, src_len)
# tgt_mask: (batch_size, 1, tgt_len, tgt_len)

# 2. 将源序列传递给编码器
encoder_output = self.encoder(src, src_mask)
# (batch_size, src_len, d_model)

# 3. 将目标序列和编码器输出传递给解码器
decoder_output = self.decoder(tgt, encoder_output, src_mask, tgt_mask)
# (batch_size, tgt_len, d_model)

# 4. 将解码器的输出通过最后的线性层,得到每个位置的词汇表 logits
output = self.final_layer(decoder_output)
# (batch_size, tgt_len, tgt_vocab_size)

# 注意:通常在这里不应用 softmax。
# 如果使用 PyTorch 的 nn.CrossEntropyLoss,它内部会计算 log_softmax。
# 直接返回 logits (原始分数) 是标准做法。
return output

参数计算

最后来计算一下上述实现的 Transformer 参数量。设编码器和解码器各有 $N$ 层,$d_{model}$ 代码模型的隐藏层维度,$d_{ff}$ 代表前馈神经网络的中间层维度,$h$ 为多头注意力的头数,$V_{src}$ 为源语言词汇表大小,$V_{tgt}$ 为目标语言词汇表大小。

  1. 先从 Encoder Block/Layer 算起:
    1. 多头注意力:
      • Q、K、V 的线性投影权重:$3 \times d_{model} \times d_{model}$
      • 输出线性投影的权重:$d_{model} \times d_{model}$
    2. 前馈神经网络:$2 \times d_{model} \times d_{ff}$
    3. 每个 Encoder Layer 有 2 个 LayerNorm,每个 LayerNorm 有 $2 \times d_{model}$ 个参数($\gamma$ 和 $\beta$),总共 $2 \times 2 \times d_{model}$
  2. 对于 Decoder Block/Layer 而言,大致相似:
    1. 掩码多头注意力总共 $4 \times d_{model}^2$
    2. Cross Attention 同上。
    3. 前馈神经网络为 $2 \times d_{model} \times d_{ff}$
    4. 每个 Decoder Layer 有 3 个 LayerNorm,共 $3 \times 2 \times d_{model}$
  3. 词嵌入层:
    1. 源语言词嵌入:$V_{src} \times d_{model}$
    2. 目标语言词嵌入:$V_{tgt} \times d_{model}$
  4. 输出线性层的权重:$d_{model} \times V_{tgt}$

综上,当 $d_{model}$ 为 512,$h$ 为 8,$N$ 为 6,$d_{ff}$ 为 2048 时,参数量为 $44,070,912 + 1024 \times V_{tgt} + 512 \times V_{src}$。