Implementing the Transformer in PyTorch (Part 2): The Encoder and Decoder

The previous article implemented the core building blocks of the Transformer: scaled dot-product attention, multi-head attention, and the feed-forward network. In this article we first wrap those pieces into encoder and decoder layers, and then use them to implement the full encoder and decoder.

Transformer

Encoder Block/Layer

Let's start with a basic encoder block (layer). The code is straightforward: take the MultiHeadAttention and FeedForwardNetwork from the previous article and add a residual connection plus Layer Normalization after each of them.

class EncoderBlock(nn.Module):
    def __init__(self):
        super().__init__()
        # 1. Multi-head attention
        self.mha = MultiHeadAttention()
        # 2. Layer normalization
        self.layer_norm_1 = nn.LayerNorm(d_model)
        # 3. Feed forward
        self.ffn = FeedForwardNetwork()
        # 4. Another layer normalization
        self.layer_norm_2 = nn.LayerNorm(d_model)
        # 5. Dropout
        self.dropout = nn.Dropout(p_drop)

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor | None = None,
    ):
        """
        Args:
            x: (batch_size, seq_len, d_model)
            mask: Optional mask.
        """
        # 1. Multi-head attention with residual connection and layer norm
        # 1a. Multi-Head Self-Attention (Q=K=V=x)
        attention_output, _ = self.mha(x, x, x, mask)
        # (batch_size, seq_len, d_model)
        # 1b. Add & Norm
        x = x + self.dropout(attention_output)
        x = self.layer_norm_1(x)

        # 2. Feed forward with residual connection and layer norm
        # 2a. Feed Forward Network
        feedforward_output = self.ffn(x)
        # 2b. Add & Norm
        x = x + self.dropout(feedforward_output)
        x = self.layer_norm_2(x)

        return x  # (batch_size, seq_len, d_model)

In the code above, mask is used to mask out the padded positions of the input sequence, so the model does not attend to these meaningless tokens.
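For reference, here is a minimal sketch of how such a padding mask could be built from a batch of token IDs. The pad_id value and the (batch_size, 1, 1, seq_len) output shape are assumptions made for illustration; the shape has to match whatever the MultiHeadAttention implementation from the previous article expects when it applies the mask to the (batch_size, n_heads, seq_len, seq_len) attention scores.

import torch

def make_padding_mask(token_ids: torch.Tensor, pad_id: int = 0) -> torch.Tensor:
    """
    Hypothetical helper (not part of the original series).
    token_ids: (batch_size, seq_len)
    Returns a boolean mask of shape (batch_size, 1, 1, seq_len) where True marks
    real tokens and False marks padding, so it broadcasts over heads and query positions.
    """
    return (token_ids != pad_id).unsqueeze(1).unsqueeze(2)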

Decoder Block/Layer

The decoder block (layer) is slightly more involved than EncoderBlock, because it adds a masked Multi-Head Attention sub-layer and also needs to handle cross attention with the encoder output.

class DecoderBlock(nn.Module):
    def __init__(self):
        super().__init__()

        # 1. Masked Multi-head attention
        #    Q, K, V from target sequence
        self.mha_1 = MultiHeadAttention()
        # 2. Layer norm for first sub-layer
        self.layer_norm_1 = nn.LayerNorm(d_model)
        # 3. Multi-head attention for cross attention with encoder output
        #    Query from target sequence (output of previous layer), Key & Value from encoder_output
        self.mha_2 = MultiHeadAttention()
        # 4. Layer norm for second sub-layer
        self.layer_norm_2 = nn.LayerNorm(d_model)
        # 5. Feed forward network
        self.ffn = FeedForwardNetwork()
        # 6. Layer norm for third sub-layer
        self.layer_norm_3 = nn.LayerNorm(d_model)
        # 7. Dropout (applied before Add & Norm for each sub-layer)
        self.dropout = nn.Dropout(p_drop)

    def forward(
        self,
        x: torch.Tensor,
        encoder_output: torch.Tensor,
        src_mask: torch.Tensor | None = None,
        tgt_mask: torch.Tensor | None = None,
    ):
        """
        Args:
            x: Target sequence embedding, (batch_size, target_seq_len, d_model)
            encoder_output: Output from encoder, (batch_size, source_seq_len, d_model)
            src_mask: Mask for source padding in encoder_output.
            tgt_mask: Mask for target sequence (self-attention). Combines look-ahead and padding.
        """
        # --- 1. First sub-layer: Masked Multi-Head Self-Attention + Add & Norm ---
        # Q, K, V all come from the decoder input x; tgt_mask is used for this self-attention
        attention_output_1, _ = self.mha_1(x, x, x, tgt_mask)
        x = x + self.dropout(attention_output_1)
        x = self.layer_norm_1(x)

        # --- 2. Second sub-layer: Multi-Head Cross-Attention + Add & Norm ---
        # Query comes from the output of the previous sub-layer (x),
        # Key and Value come from the encoder output (encoder_output).
        # src_mask masks out the padding positions in encoder_output.
        attention_output_2, _ = self.mha_2(x, encoder_output, encoder_output, src_mask)
        x = x + self.dropout(attention_output_2)
        x = self.layer_norm_2(x)

        # --- 3. Third sub-layer: Feed Forward Network + Add & Norm ---
        feedforward_output = self.ffn(x)
        x = x + self.dropout(feedforward_output)
        x = self.layer_norm_3(x)

        return x  # (batch_size, target_seq_len, d_model)

In the code above:

  1. tgt_mask is applied in the first multi-head attention layer, i.e. the decoder's Masked Multi-Head Attention.
    • It sets the entries of the attention score matrix that correspond to future positions to negative infinity (typically via an upper-triangular matrix whose entries above the diagonal are True/1, meaning "masked"), driving the attention weights for those positions toward 0 so the model cannot see future information.
    • It also masks the padding tokens in the target sequence.
  2. src_mask is applied in the second multi-head attention layer, i.e. the decoder's Cross Attention, where its purpose is to mask the padding positions in the encoder output. A sketch of how such a combined target mask could be built follows this list.
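Below is a minimal sketch, under the same assumptions as the padding-mask helper above, of how a combined look-ahead plus padding mask for the decoder self-attention could be built. Note that it uses the convention True = may attend (a lower-triangular matrix), which is the logical complement of the upper-triangular "blocked" view described in point 1; either works as long as MultiHeadAttention interprets the mask consistently.

import torch

def make_tgt_mask(tgt_ids: torch.Tensor, pad_id: int = 0) -> torch.Tensor:
    """
    Hypothetical helper (not part of the original series).
    tgt_ids: (batch_size, target_seq_len)
    Returns a boolean mask of shape (batch_size, 1, target_seq_len, target_seq_len)
    where True means "may attend" and False means "blocked".
    """
    tgt_len = tgt_ids.size(1)
    # Padding mask: True at real-token positions, (batch_size, 1, 1, target_seq_len)
    pad_mask = (tgt_ids != pad_id).unsqueeze(1).unsqueeze(2)
    # Look-ahead mask: lower-triangular, True on and below the diagonal, (target_seq_len, target_seq_len)
    look_ahead = torch.tril(
        torch.ones(tgt_len, tgt_len, dtype=torch.bool, device=tgt_ids.device)
    )
    # A position may be attended only if it is a real token AND not in the future
    return pad_mask & look_ahead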

Positional Encoding

Before implementing the Encoder and Decoder, we need positional encoding. The Transformer uses sinusoidal positional encoding, so we need to implement the following two formulas:

$$
\text{PE}(pos, 2i) = \sin\left(\frac{pos}{10000^{\frac{2i}{d_{\text{model}}}}}\right)
$$

$$
\text{PE}(pos, 2i+1) = \cos\left(\frac{pos}{10000^{\frac{2i}{d_{\text{model}}}}}\right)
$$

class PositionalEncoding(nn.Module):
    def __init__(self):
        super().__init__()

        # 1. Create a zero matrix of shape (max_len, d_model) to hold the positional encodings
        pe = torch.zeros(max_len, d_model, device=device)
        # (max_len, d_model)

        # 2. Create the position index vector
        # position holds every position index in the sequence (0, 1, 2, ..., max_len - 1)
        # .unsqueeze(1) prepares it for the broadcast multiplication with div_term below
        position = torch.arange(0, max_len, device=device).float().unsqueeze(1)
        # (max_len, 1)

        # 3. Create the division term (div_term)
        # Indices representing 2i (0, 2, 4, ..., d_model - 2), i.e. the even dimensions of d_model
        indices_2i = torch.arange(0, d_model, 2, device=device).float()
        # Compute 10000^(-2i/d_model)
        div_term = torch.pow(10000.0, -indices_2i / d_model)
        # (d_model/2, )

        # When computing position * div_term in step 4:
        # 1. div_term is promoted to (1, d_model/2)
        # 2. position is broadcast from (max_len, 1) to (max_len, d_model/2)
        # 3. div_term is broadcast from (1, d_model/2) to (max_len, d_model/2)

        # 4. Compute the positional encodings
        # Apply sin to the even dimensions (0, 2, 4, ...)
        pe[:, 0::2] = torch.sin(position * div_term)
        # Apply cos to the odd dimensions (1, 3, 5, ...)
        pe[:, 1::2] = torch.cos(position * div_term)

        # 5. Register as a buffer
        # A buffer is part of the model's state but not a parameter (it is not updated during backprop).
        # Buffers are saved in the model's state_dict and move with the model (e.g. to the GPU).
        # .unsqueeze(0) makes it easy to add to batched inputs (batch_size, seq_len, d_model) in forward.
        self.register_buffer("pe", pe.unsqueeze(0))
        # (1, max_len, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch_size, seq_len, d_model)
        """
        # Slice the precomputed positional encoding matrix (self.pe)
        # to match the length of the current input sequence (x.size(1)).
        # (batch_size, seq_len, d_model) + (1, seq_len, d_model) -> (batch_size, seq_len, d_model)
        return x + self.pe[:, : x.size(1), :]
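A quick shape check, assuming the hyperparameters d_model, max_len and device are defined as in the previous article:

pos_enc = PositionalEncoding()
dummy = torch.zeros(2, 10, d_model, device=device)  # (batch_size=2, seq_len=10, d_model)
out = pos_enc(dummy)
print(out.shape)         # torch.Size([2, 10, d_model])
print(pos_enc.pe.shape)  # torch.Size([1, max_len, d_model])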

Encoder

With the absolute positional encoding in place, we can now implement the Encoder.

class Encoder(nn.Module):
    def __init__(self, vocab_size: int):
        super().__init__()

        # 1. Input embedding
        #    Maps input token IDs to d_model-dimensional vectors
        self.embeddings = nn.Embedding(vocab_size, d_model)
        # 2. Positional encoding
        self.pe = PositionalEncoding()
        # 3. Dropout
        self.dropout = nn.Dropout(p_drop)
        # 4. Stack of N encoder layers
        self.encoder_blocks = nn.ModuleList([EncoderBlock() for _ in range(n_layers)])

    def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None):
        """
        Args:
            x: Input token IDs, (batch_size, seq_len).
            mask: Padding mask for the source sequence.
        """
        # 1. Embedding layer
        x = self.embeddings(x)  # (batch_size, seq_len, d_model)

        # 2. Add positional encoding and apply dropout
        x = self.pe(x)       # (batch_size, seq_len, d_model)
        x = self.dropout(x)  # (batch_size, seq_len, d_model)

        # 3. Pass through the N EncoderBlocks in order
        # mask is forwarded to every EncoderBlock for its internal self-attention layer
        for block in self.encoder_blocks:
            x = block(x, mask)  # x stays (batch_size, seq_len, d_model)

        return x  # final output shape (batch_size, seq_len, d_model)
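A minimal usage sketch, assuming the hyperparameters from the previous article and an illustrative vocabulary size; make_padding_mask is the hypothetical helper sketched earlier:

vocab_size = 10000  # illustrative value, not from the original article
encoder = Encoder(vocab_size).to(device)

src = torch.randint(0, vocab_size, (2, 16), device=device)  # (batch_size=2, seq_len=16) token IDs
src_mask = make_padding_mask(src)                            # (2, 1, 1, 16)
memory = encoder(src, src_mask)
print(memory.shape)  # torch.Size([2, 16, d_model])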

Decoder

The Decoder implementation is almost identical to the Encoder.

class Decoder(nn.Module):
    def __init__(self, vocab_size: int):
        super().__init__()

        # 1. Output embedding (for the target sequence)
        self.embeddings = nn.Embedding(vocab_size, d_model)
        # 2. Positional encoding
        self.pe = PositionalEncoding()
        # 3. Dropout
        self.dropout = nn.Dropout(p_drop)
        # 4. Stack of N decoder layers
        self.decoder_blocks = nn.ModuleList([DecoderBlock() for _ in range(n_layers)])

    def forward(
        self,
        x: torch.Tensor,
        encoder_output: torch.Tensor,
        src_mask: torch.Tensor | None = None,
        tgt_mask: torch.Tensor | None = None,
    ):
        """
        Args:
            x: Target tokens (IDs), (batch_size, target_seq_len)
            encoder_output: Output from encoder, (batch_size, source_seq_len, d_model)
            src_mask: Mask for source padding (encoder output).
            tgt_mask: Mask for target padding and future positions (target sequence).
        """
        # 1. Embedding layer
        x = self.embeddings(x)  # (batch_size, target_seq_len, d_model)

        # 2. Add positional encoding and apply dropout
        x = self.pe(x)       # (batch_size, target_seq_len, d_model)
        x = self.dropout(x)  # (batch_size, target_seq_len, d_model)

        # 3. Pass through the N DecoderBlocks in order
        for block in self.decoder_blocks:
            x = block(x, encoder_output, src_mask, tgt_mask)
            # x stays (batch_size, target_seq_len, d_model)

        return x  # final output shape (batch_size, target_seq_len, d_model)
        # This output is fed into the final linear layer and softmax to predict the next token
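To make that last comment concrete, here is a minimal sketch of how the Encoder, the Decoder and the final linear projection could be assembled into one module. The class name and interface are assumptions for illustration, not the article's final code; during training the softmax can be left to nn.CrossEntropyLoss, which works directly on the logits.

class Transformer(nn.Module):
    # Hypothetical assembly of the pieces defined above (a sketch, not the series' final code)
    def __init__(self, src_vocab_size: int, tgt_vocab_size: int):
        super().__init__()
        self.encoder = Encoder(src_vocab_size)
        self.decoder = Decoder(tgt_vocab_size)
        # Final linear layer: project d_model back to vocabulary logits
        self.out_proj = nn.Linear(d_model, tgt_vocab_size)

    def forward(
        self,
        src: torch.Tensor,  # (batch_size, source_seq_len) token IDs
        tgt: torch.Tensor,  # (batch_size, target_seq_len) token IDs
        src_mask: torch.Tensor | None = None,
        tgt_mask: torch.Tensor | None = None,
    ):
        encoder_output = self.encoder(src, src_mask)
        decoder_output = self.decoder(tgt, encoder_output, src_mask, tgt_mask)
        # (batch_size, target_seq_len, tgt_vocab_size) logits; softmax is applied outside this module
        return self.out_proj(decoder_output)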