PyTorch 构建 Transformer 模型(手把手讲解)

PyTorch 构建 Transformer 模型

使用 PyTorch 构建 Transformer 模型是当前自然语言处理(NLP)和计算机视觉(CV)领域最热门的技术之一。本文将围绕"PyTorch 构建 Transformer 模型"这一主题,提供清晰、简洁的实现路径与最佳实践,帮助开发者快速构建自己的 Transformer 模型。

快速解决

直接使用 PyTorch 提供的 nn.Transformer 模块可以快速构建 Transformer 模型。这个方法适合已经了解 Transformer 架构的开发者,能够快速搭建一个基础模型结构并训练。

示例代码如下:

import torch
import torch.nn as nn

model = nn.Transformer(nhead=8, num_encoder_layers=3, num_decoder_layers=3)

src = torch.rand((10, 32, 512))
tgt = torch.rand((20, 32, 512))

output = model(src, tgt)
print(output.shape)  # 输出: (20, 32, 512)

该代码创建了一个包含 3 个编码层和 3 个解码层的 Transformer 模型,适用于序列到序列任务,例如机器翻译。

常用方法

以下是构建 Transformer 模型时常用的 PyTorch 方法和模块:

方法/模块 功能说明 示例代码
nn.Transformer 构建完整的 Transformer 模型 model = nn.Transformer(nhead=8)
nn.TransformerEncoder 构建 Transformer 编码器部分 encoder = nn.TransformerEncoder(encoder_layer, num_layers=3)
nn.TransformerDecoder 构建 Transformer 解码器部分 decoder = nn.TransformerDecoder(decoder_layer, num_layers=3)
nn.TransformerEncoderLayer 定义编码器层的结构 encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
nn.TransformerDecoderLayer 定义解码器层的结构 decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
nn.Embedding 将输入序列转换为嵌入向量 emb = nn.Embedding(1000, 512)
nn.Linear 输出层,用于生成最终预测 output_layer = nn.Linear(512, 1000)
nn.Dropout 添加 dropout 层防止过拟合 dropout = nn.Dropout(p=0.1)

详细说明

使用 nn.TransformerEncoderLayer 构建编码器

TransformerEncoderLayer 是构建 Transformer 编码器的基础,它包含了多头自注意力机制和前馈网络。

import torch
import torch.nn as nn

encoder_layer = nn.TransformerEncoderLayer(
    d_model=512,        # 嵌入维度
    nhead=8,          # 多头注意力头数
    dim_feedforward=2048,  # 前馈网络隐藏层维度
    dropout=0.1       # dropout 概率
)

encoder = nn.TransformerEncoder(encoder_layer, num_layers=3)

src = torch.rand((10, 32, 512))
encoded = encoder(src)
print(encoded.shape)  # 输出: (10, 32, 512)

使用 nn.TransformerDecoderLayer 构建解码器

与编码器类似,解码器由 TransformerDecoderLayer 构成,它支持自注意力和编码器-解码器注意力。

decoder_layer = nn.TransformerDecoderLayer(
    d_model=512,        # 嵌入维度
    nhead=8,          # 多头注意力头数
    dim_feedforward=2048,  # 前馈网络隐藏层维度
    dropout=0.1       # dropout 概率
)

decoder = nn.TransformerDecoder(decoder_layer, num_layers=3)

tgt = torch.rand((20, 32, 512))
decoded = decoder(tgt, encoded)
print(decoded.shape)  # 输出: (20, 32, 512)

使用 nn.Transformer 构建完整模型

Transformer 模块封装了编码器和解码器,适合快速构建端到端的 Transformer 模型。

model = nn.Transformer(
    d_model=512,        # 嵌入维度
    nhead=8,          # 多头注意力头数
    num_encoder_layers=3,  # 编码器层数
    num_decoder_layers=3,  # 解码器层数
    dim_feedforward=2048,  # 前馈网络维度
    dropout=0.1       # dropout 概率
)

src = torch.rand((10, 32, 512))
tgt = torch.rand((20, 32, 512))
output = model(src, tgt)
print(output.shape)  # 输出: (20, 32, 512)

高级技巧

自定义嵌入层

在实际应用中,输入通常是 token ID 形式,需要通过嵌入层将其转换为向量表示。可以自定义嵌入层以支持位置编码或词向量加载。

class TransformerEmbedding(nn.Module):
    def __init__(self, vocab_size, d_model, max_len=5000):
        super(TransformerEmbedding, self).__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, x):
        seq_len = x.size(1)
        positions = torch.arange(0, seq_len).expand(x.size(0), seq_len).to(x.device)
        x = self.token_emb(x) + self.pos_emb(positions)
        return self.dropout(x)

emb = TransformerEmbedding(vocab_size=10000, d_model=512)
src = emb(torch.randint(0, 10000, (32, 10)))  # 输入 (batch_size, seq_len)
src = src.permute(1, 0, 2)  # 转换为 (seq_len, batch_size, embed_dim)

动态生成解码器输入

在解码阶段,解码器输入通常是逐步生成的,可以使用一个循环机制,将每一步的输出作为下一步的输入。

start_token = torch.tensor([1] * 32).unsqueeze(0)  # shape: (1, 32)
emb = TransformerEmbedding(vocab_size=10000, d_model=512)
decoder_input = emb(start_token).permute(1, 0, 2)  # shape: (1, 32, 512)

outputs = []
for i in range(20):  # 生成 20 个 token
    decoder_output = decoder(decoder_input, encoded)
    decoder_input = decoder_output[-1, :, :].unsqueeze(0)  # 取最后一个 token 作为下一步输入
    outputs.append(decoder_output)

使用自定义的 Transformer 模型

在实际任务中,可能需要对 Transformer 进行修改,例如添加自定义层或处理不同任务需求。可以继承 nn.Transformer 并覆盖 forward 方法。

import torch
import torch.nn as nn

class CustomTransformer(nn.Module):
    def __init__(self, vocab_size, d_model=512, nhead=8, num_layers=3):
        super(CustomTransformer, self).__init__()
        self.emb = TransformerEmbedding(vocab_size, d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers)
        decoder_layer = nn.TransformerDecoderLayer(d_model, nhead)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, src, tgt):
        src = self.emb(src)
        src = src.permute(1, 0, 2)
        tgt = self.emb(tgt)
        tgt = tgt.permute(1, 0, 2)
        memory = self.encoder(src)
        output = self.decoder(tgt, memory)
        output = self.fc(output)
        return output.permute(1, 0, 2)  # 返回 (batch_size, seq_len, vocab_size)

model = CustomTransformer(vocab_size=10000)
src = torch.randint(0, 10000, (32, 10))
tgt = torch.randint(0, 10000, (32, 20))
output = model(src, tgt)
print(output.shape)  # 输出: (32, 20, 10000)

常见问题

Q1:为什么 Transformer 的输入顺序是 (seq_len, batch_size, embed_dim)?

PyTorch 的 Transformer 模块要求输入格式为 (seq_len, batch_size, embed_dim),这是为了与内部的注意力机制和缓存机制保持一致。如果使用的是 (batch_size, seq_len, embed_dim),需要通过 .permute(1, 0, 2) 转换。

Q2:如何处理 Transformer 的位置编码?

位置编码通常通过 nn.Embedding 或直接计算的正弦/余弦形式添加到词嵌入中。在自定义嵌入层中,可以将位置信息与词嵌入相加或拼接。

Q3:TransformerDecoder 的输入为什么需要是 decoder_input?

TransformerDecoder 在每一步生成时都需要知道当前生成的 token 序列,因此输入通常包括前面生成的部分序列。在训练中,可以通过 teacher forcing 提供完整的目标序列;在推理中,则逐步生成。

Q4:如何选择 nhead 和 d_model 的值?

nhead 是注意力头数,一般建议是 d_model 的因数。d_model 是嵌入维度,通常设为 256、512 或 1024,取决于任务复杂度和计算资源。

总结

PyTorch 构建 Transformer 模型可以通过内置模块快速实现,也可以通过自定义嵌入和解码逻辑适应特定任务需求。