Implementing a Transformer Step by Step

The Transformer architecture is the foundation of modern NLP. This article implements a simplified Transformer from scratch in PyTorch.

Core Components

  • Multi-Head Attention
  • Positional Encoding
  • Feed-Forward Network
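The two non-attention components in the list can be sketched as follows. This is a minimal sketch assuming the sinusoidal encoding PE(pos, 2i) = sin(pos / 10000^(2i/d_model)), PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)) and the ReLU position-wise feed-forward layer from the original Transformer paper. The class names `PositionalEncoding` and `FeedForward` and the parameters `max_len` and `d_ff` are illustrative choices, not names taken from this article.

```python
import math

import torch
import torch.nn as nn


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding (assumes an even d_model)."""

    def __init__(self, d_model, max_len=512):
        super().__init__()
        assert d_model % 2 == 0, "sin/cos pairs need an even d_model"
        position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        # 1 / 10000^(2i / d_model) for i = 0 .. d_model/2 - 1
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float) * (-math.log(10000.0) / d_model)
        )
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
        # A buffer follows .to(device) but is not a trainable parameter
        self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        # x: (batch, seq_len, d_model); add the encodings of the first seq_len positions
        return x + self.pe[:, : x.size(1)]


class FeedForward(nn.Module):
    """Position-wise FFN: Linear -> ReLU -> Linear, applied to every position independently."""

    def __init__(self, d_model, d_ff):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        )

    def forward(self, x):
        return self.net(x)


if __name__ == "__main__":
    x = torch.randn(2, 10, 16)  # (batch, seq_len, d_model), e.g. token embeddings
    x = PositionalEncoding(d_model=16)(x)
    y = FeedForward(d_model=16, d_ff=64)(x)
    print(y.shape)  # torch.Size([2, 10, 16])
```

Registering the table as a buffer rather than an `nn.Parameter` keeps the encoding fixed during training; a learned positional embedding would be the alternative design.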

Implementation

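import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        # The model dimension must split evenly across the heads
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_k = d_model // n_heads  # dimension of each head
        self.n_heads = n_heads
        # Linear projections for queries, keys and values
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        # Output projection applied after the heads are concatenated
        self.W_o = nn.Linear(d_model, d_model)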
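The constructor defines the projections but not how they are combined. A minimal sketch of the forward pass, assuming standard scaled dot-product attention, softmax(Q K^T / sqrt(d_k)) V, computed per head: `mha_forward` is a hypothetical helper that takes any object exposing the attributes set in `__init__` (`W_q`, `W_k`, `W_v`, `W_o`, `n_heads`, `d_k`). Inside the class the same body would become `forward(self, query, key, value, mask=None)`. The mask convention (0 = blocked) and the `SimpleNamespace` stand-in used in the usage section are assumptions made so the block runs on its own.

```python
import math
from types import SimpleNamespace

import torch
import torch.nn as nn


def scaled_dot_product_attention(q, k, v, mask=None):
    """softmax(q k^T / sqrt(d_k)) v, computed independently for every head."""
    # q, k, v: (batch, n_heads, seq_len, d_k)
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)  # (batch, n_heads, L_q, L_k)
    if mask is not None:
        # Positions where mask == 0 get -inf, so softmax gives them zero weight
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = torch.softmax(scores, dim=-1)
    return weights @ v, weights


def mha_forward(mha, query, key, value, mask=None):
    """Multi-head attention forward pass for an object with W_q, W_k, W_v, W_o, n_heads, d_k."""
    batch = query.size(0)

    def split_heads(x):
        # (batch, seq_len, d_model) -> (batch, n_heads, seq_len, d_k)
        return x.view(batch, -1, mha.n_heads, mha.d_k).transpose(1, 2)

    q = split_heads(mha.W_q(query))
    k = split_heads(mha.W_k(key))
    v = split_heads(mha.W_v(value))
    out, weights = scaled_dot_product_attention(q, k, v, mask)
    # Concatenate heads: (batch, n_heads, seq_len, d_k) -> (batch, seq_len, d_model)
    out = out.transpose(1, 2).contiguous().view(batch, -1, mha.n_heads * mha.d_k)
    return mha.W_o(out), weights


if __name__ == "__main__":
    d_model, n_heads = 16, 4
    # Stand-in exposing the same attributes as MultiHeadAttention.__init__
    mha = SimpleNamespace(
        n_heads=n_heads,
        d_k=d_model // n_heads,
        W_q=nn.Linear(d_model, d_model),
        W_k=nn.Linear(d_model, d_model),
        W_v=nn.Linear(d_model, d_model),
        W_o=nn.Linear(d_model, d_model),
    )
    x = torch.randn(2, 5, d_model)  # (batch, seq_len, d_model)
    causal_mask = torch.tril(torch.ones(5, 5))  # each position sees itself and earlier ones
    out, weights = mha_forward(mha, x, x, x, mask=causal_mask)
    print(out.shape)  # torch.Size([2, 5, 16])
```

Dividing by sqrt(d_k) keeps the dot products from growing with the head dimension, which would otherwise push the softmax toward near one-hot outputs. The (seq_len, seq_len) mask broadcasts across the batch and head dimensions.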