```python
import torch


def create_padding_mask(seq: torch.Tensor, pad_token_id: int = 0):
    """
    Create a mask for padding tokens.

    The mask marks positions that are padding (True/1) and should be masked.

    Args:
        seq: Input sequence tensor of token IDs (batch_size, seq_len).
        pad_token_id: The ID used for padding tokens. Defaults to 0.

    Returns:
        mask: Padding mask (batch_size, 1, 1, seq_len).
              Positions with True (or 1) will be masked.
              Positions with False (or 0) will be kept.
    """
    # The comparison yields a boolean tensor with the same shape as seq (batch_size, seq_len):
    # True marks a padding token that should be masked; False marks a real token to keep.
    mask = seq == pad_token_id
    # Add head and query dims so the mask broadcasts over the attention scores.
    return mask.unsqueeze(1).unsqueeze(2)
```
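To see the shape and broadcasting behavior, here is a quick sanity check on a toy batch (the token values below are made up for illustration):

```python
# Toy batch: two sequences padded with 0 to length 5.
seq = torch.tensor([[5, 7, 9, 0, 0],
                    [3, 0, 0, 0, 0]])
mask = create_padding_mask(seq)   # (2, 1, 1, 5)
print(mask[0, 0, 0])              # tensor([False, False, False,  True,  True])

# In an attention layer the mask broadcasts over the score tensor
# (batch, heads, q_len, k_len); True positions are set to -inf before softmax.
scores = torch.randn(2, 8, 5, 5)
scores = scores.masked_fill(mask, float("-inf"))
```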
The following function masks out future tokens.
```python
def create_future_mask(size: int, device: torch.device = None):
    """
    Create a mask that prevents attention to future positions (look-ahead mask).

    The mask marks positions that lie in the future (True/1) and should be masked.

    Args:
        size: Size of the square mask, typically target_seq_len.
        device: Device on which to allocate the mask.

    Returns:
        mask: Future mask (1, 1, size, size).
              Positions with True (or 1) will be masked.
              Positions with False (or 0) will be kept.
    """
    # torch.ones((size, size), device=device) creates a square matrix of ones.
    # torch.triu(..., diagonal=1) keeps its upper-triangular part, excluding the main diagonal:
    #   - diagonal=0:  include the main diagonal
    #   - diagonal=1:  start one diagonal above the main diagonal
    #   - diagonal=-1: start one diagonal below the main diagonal
    mask = torch.triu(torch.ones(size, size, device=device), diagonal=1)
    mask = mask.bool()
    return mask.unsqueeze(0).unsqueeze(0)
```
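A small check makes the triangular structure visible (size 4 is arbitrary): row i may only attend to positions up to and including i, and the True (1) entries above the diagonal are blocked.

```python
m = create_future_mask(4)   # (1, 1, 4, 4)
print(m[0, 0].int())
# tensor([[0, 1, 1, 1],
#         [0, 0, 1, 1],
#         [0, 0, 0, 1],
#         [0, 0, 0, 0]])
```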
```python
def create_masks(src: torch.Tensor, tgt: torch.Tensor, src_pad_id: int = 0, tgt_pad_id: int = 0):
    """
    Create all masks needed for the encoder and decoder.

    Args:
        src: Source sequence tensor of token IDs (batch_size, src_len).
        tgt: Target sequence tensor of token IDs (batch_size, tgt_len).
        src_pad_id: Token ID for padding in the source sequence.
        tgt_pad_id: Token ID for padding in the target sequence.

    Returns:
        src_mask: Padding mask for the encoder's self-attention and the decoder's
                  cross-attention. Shape: (batch_size, 1, 1, src_len). True means mask.
        tgt_mask: Combined look-ahead and padding mask for the decoder's
                  self-attention. Shape: (batch_size, 1, tgt_len, tgt_len). True means mask.
    """
    # 1. Create the padding mask for the source sequence (src_mask).
    #    This mask is used by:
    #    a) the encoder's self-attention (masking padding in the source), and
    #    b) the decoder's cross-attention (masking the encoder outputs that
    #       correspond to source padding).
    src_padding_mask = create_padding_mask(src, src_pad_id)  # (batch_size, 1, 1, src_len)
```
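The docstring above promises a combined `tgt_mask`. As a standalone sketch of the idea (assuming the two helpers above and a logical OR doing the combination, which broadcasts the padding mask over the (tgt_len, tgt_len) grid):

```python
# Sketch: combining padding and look-ahead masks for the decoder's self-attention.
tgt = torch.tensor([[4, 6, 0]])            # one sequence with one pad token
pad = create_padding_mask(tgt)             # (1, 1, 1, 3)
fut = create_future_mask(tgt.size(1))      # (1, 1, 3, 3)
tgt_mask = pad | fut                       # broadcasts to (1, 1, 3, 3)
print(tgt_mask[0, 0].int())
# tensor([[0, 1, 1],
#         [0, 0, 1],
#         [0, 0, 1]])
```

A position is masked if it is either in the future *or* a padding token, which is why OR is the right combinator for masks where True means "block".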