Implementing the Transformer in PyTorch (Part 1): Multi-Head Attention

I plan to write a short series of posts on how to implement the Transformer architecture from the 2017 paper "Attention Is All You Need" in PyTorch. There are already plenty of articles and videos online covering the theory behind the Transformer; if you want to learn that part first, I personally recommend an article from the Turing Press (图灵出版社) WeChat account: 《一文读懂 Transformer,工作原理与实现全解析》 ("Understanding the Transformer in One Article: How It Works and How to Implement It").

Setup

Before writing the model in PyTorch, we usually start by initializing some basic environment configuration.

import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(42)

Let's also define the Transformer's structural hyperparameters up front.

d_model = 512  # embedding size
max_len = 1024 # max length of sequence
d_ff = 4 * d_model # feed-forward network dimension
d_q = d_k = d_v = 64 # per-head dimension of q, k (must equal q), and v
n_layers = 6 # number of encoder and decoder layers
n_heads = 8 # number of heads in multi-head attention
p_drop = 0.1 # probability of dropout
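
Note that these values are not all independent: with this configuration, $d_{model} = n_{heads} \times d_k = n_{heads} \times d_v$, which is what lets the attention layer split the embedding evenly across heads and concatenate them back to the model width. A quick sanity check (a small optional addition, not part of the original setup):

# Optional: with this configuration the heads tile d_model exactly,
# so the concatenated heads have the same width as the model embedding.
assert d_model == n_heads * d_k == n_heads * d_v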

Scaled Dot-Product Attention

The core of the Transformer is the multi-head attention mechanism. Before implementing multi-head attention, we first need to implement scaled dot-product attention.

Implementing scaled dot-product attention amounts to implementing the formula $\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$, which corresponds to the figure below:

(Figure: Scaled Dot-Product Attention)

We can implement scaled dot-product attention as a plain function:

def scaled_dot_product_attention(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    mask: torch.Tensor | None = None,
):
    """
    Args:
        Q: (batch_size, n_heads, seq_len_q, d_q)
        K: (batch_size, n_heads, seq_len_k, d_k)
        V: (batch_size, n_heads, seq_len_v, d_v)
        mask: Optional mask.
            Its shape should be broadcastable to (batch_size, n_heads, seq_len_q, seq_len_k).
            Positions with True (or 1) will be masked (set to -inf before softmax).
            Positions with False (or 0) will be kept.
    Returns:
        output: (batch_size, n_heads, seq_len_q, d_v)
        attention_weights: (batch_size, n_heads, seq_len_q, seq_len_k)
    """
    assert Q.dim() == 4, f"Query should be 4-dim but got {Q.dim()}-dim"
    assert K.dim() == 4, f"Key should be 4-dim but got {K.dim()}-dim"
    assert V.dim() == 4, f"Value should be 4-dim but got {V.dim()}-dim"
    assert Q.size(-1) == K.size(-1), "Dimension of query (d_q) and key (d_k) must be the same."
    assert K.size(-2) == V.size(-2), "Sequence length of key (seq_len_k) and value (seq_len_v) must be the same."

    scores = Q @ K.transpose(-2, -1) / torch.sqrt(torch.tensor(d_k, dtype=Q.dtype, device=Q.device))
    # (batch_size, n_heads, seq_len_q, seq_len_k)

    if mask is not None:
        scores.masked_fill_(mask.bool(), -torch.inf)

    attention_weights = F.softmax(scores, dim=-1)
    # (batch_size, n_heads, seq_len_q, seq_len_k)

    output = attention_weights @ V
    # (batch_size, n_heads, seq_len_q, d_v)

    return output, attention_weights

In the code above:

  1. @ is equivalent to torch.matmul()
  2. Dividing by $\sqrt{d_k}$ scales the dot products down to keep them from growing too large, which would push softmax into regions with vanishing gradients; the scaling stabilizes training
  3. Negative infinity is used as the mask value because:
    • the softmax function has the form $\text{softmax}(z)_i = \frac{e^{z_i}}{\sum^n_{j=1} e^{z_j}}$
    • as $x$ tends to negative infinity, the exponential $e^x$ tends to 0
    • so to mask out position $k$, we set its score to negative infinity; its numerator in the softmax then approaches 0, and the softmax output at that position is 0 (see the quick check after this list)
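
To make the masking behavior concrete, here is a quick sanity check (a minimal sketch, assuming the function and hyperparameters defined above) that applies a causal mask so each position can only attend to itself and earlier positions:

# Toy shapes: batch of 2, 8 heads, sequence length 4.
Q = torch.randn(2, n_heads, 4, d_q)
K = torch.randn(2, n_heads, 4, d_k)
V = torch.randn(2, n_heads, 4, d_v)

# Causal mask: True above the diagonal, i.e. future positions are masked.
causal_mask = torch.triu(torch.ones(4, 4, dtype=torch.bool), diagonal=1)

output, attention_weights = scaled_dot_product_attention(Q, K, V, causal_mask)
print(output.shape)             # torch.Size([2, 8, 4, 64])
print(attention_weights[0, 0])  # upper triangle is exactly 0 after softmax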

Multi-Head Attention

The scaled dot-product attention above computes attention for a single head; next, let's implement multi-head attention. The main job of this code is to split the model into multiple heads: for example, with $d_{model} = 512$ split across 8 heads, each head has dimension 64.

(Figure: Multi-Head Attention)
class MultiHeadAttention(nn.Module):
    def __init__(self, bias=False):
        super().__init__()

        self.W_q = nn.Linear(d_model, d_q * n_heads, bias=bias)
        self.W_k = nn.Linear(d_model, d_k * n_heads, bias=bias)
        self.W_v = nn.Linear(d_model, d_v * n_heads, bias=bias)
        self.W_o = nn.Linear(d_v * n_heads, d_model, bias=bias)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor | None = None,
    ):
        """
        Args:
            query: (batch_size, seq_len_q, d_model)
            key: (batch_size, seq_len_k, d_model)
            value: (batch_size, seq_len_v, d_model)
            mask: Optional mask.
        """
        batch_size = query.size(0)

        # 1. Linear projections
        Q = self.W_q(query)
        K = self.W_k(key)
        V = self.W_v(value)
        # (batch_size, seq_len_qkv, d_qkv * n_heads)

        # 2. Split into heads
        # view(): (batch_size, seq_len_qkv, d_qkv * n_heads) -> (batch_size, seq_len_qkv, n_heads, d_qkv)
        # transpose(): (batch_size, seq_len_qkv, n_heads, d_qkv) -> (batch_size, n_heads, seq_len_qkv, d_qkv)
        Q = Q.view(batch_size, -1, n_heads, d_q).transpose(1, 2)
        K = K.view(batch_size, -1, n_heads, d_k).transpose(1, 2)
        V = V.view(batch_size, -1, n_heads, d_v).transpose(1, 2)

        # 3. Apply attention
        output, attention_weights = scaled_dot_product_attention(Q, K, V, mask)
        # (batch_size, n_heads, seq_len_q, d_v), (batch_size, n_heads, seq_len_q, seq_len_k)

        # 4. Concatenate heads
        # transpose(): (batch_size, n_heads, seq_len_q, d_v) -> (batch_size, seq_len_q, n_heads, d_v)
        # contiguous(): transpose() can leave the tensor non-contiguous in memory;
        #   contiguous() returns a memory-contiguous tensor.
        #   If the tensor is already contiguous, contiguous() is a no-op (no data copy).
        #   If it is not, contiguous() copies the data into contiguous storage.
        # view(): (batch_size, seq_len_q, n_heads, d_v) -> (batch_size, seq_len_q, d_v * n_heads = d_model)
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, d_v * n_heads)

        # 5. Final projection
        output = self.W_o(output)
        # (batch_size, seq_len_q, d_model)

        return output, attention_weights

In the code above:

  1. bias does not have to be False for every projection; two other common configurations are setting bias to True for all four projections, or setting bias to True only for W_o
  2. scaled_dot_product_attention() could instead be folded into the MultiHeadAttention class as a static method
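
As a quick smoke test (a minimal sketch using the hyperparameters defined above), we can run a self-attention pass, where query, key, and value are all the same tensor, and check the output shapes:

mha = MultiHeadAttention().to(device)

x = torch.randn(2, 10, d_model, device=device)  # (batch_size=2, seq_len=10, d_model)
output, attention_weights = mha(x, x, x)

print(output.shape)             # torch.Size([2, 10, 512])
print(attention_weights.shape)  # torch.Size([2, 8, 10, 10])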

FeedForward Network

The feed-forward network is straightforward to implement: it is a standard MLP.

class FeedForwardNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(p_drop),
            nn.Linear(d_ff, d_model),
            nn.Dropout(p_drop),
        )

    def forward(self, x: torch.Tensor):
        """
        Args:
            x: (batch_size, seq_len_q, d_model)
        """
        output = self.model(x)
        # (batch_size, seq_len_q, d_model)

        return output
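
Like the attention block, the feed-forward network preserves the input shape, which is what makes it possible to wrap a residual connection around it later. A quick check (a minimal sketch, assuming the setup above):

ffn = FeedForwardNetwork().to(device)

x = torch.randn(2, 10, d_model, device=device)
print(ffn(x).shape)  # torch.Size([2, 10, 512]), same shape as the input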