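Building a Transformer by Hand, Step by Step
The Transformer architecture is the cornerstone of modern NLP. In this article we implement a simplified Transformer from scratch in PyTorch.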
Core Components
- Multi-Head Attention
- Positional Encoding
- Feed-Forward Network
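The list above names positional encoding and the feed-forward network without showing them. Below is a minimal PyTorch sketch of both. It assumes the sinusoidal encoding from the original Transformer paper (the article does not say which variant it uses), an even `d_model`, and a ReLU feed-forward network with a hypothetical inner width `d_ff`. The class names `PositionalEncoding` and `FeedForward` are illustrative and do not come from the article.

```python
import math

import torch
import torch.nn as nn


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding (assumed variant; assumes d_model is even)."""

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)  # (max_len, 1)
        # 1 / 10000^(2i / d_model) for each even dimension index 2i
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions use sin
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions use cos
        # A buffer follows .to(device) but is not a trainable parameter
        self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, d_model); add the encoding for the first seq_len positions
        return x + self.pe[:, : x.size(1)]


class FeedForward(nn.Module):
    """Position-wise FFN: Linear -> ReLU -> Linear, applied to each position independently."""

    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),  # expand to the hypothetical inner width d_ff
            nn.ReLU(),
            nn.Linear(d_ff, d_model),  # project back to d_model
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)
```

For example, `FeedForward(16, 64)(PositionalEncoding(16)(torch.zeros(2, 10, 16)))` returns a tensor of shape `(2, 10, 16)`. Registering `pe` as a buffer keeps it out of `parameters()`, so the optimizer never updates it.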
Implementation
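```python
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        # d_model must split evenly across the heads
        assert d_model % n_heads == 0
        self.d_k = d_model // n_heads  # dimension of each head
        self.n_heads = n_heads
        # Linear projections for queries, keys, values, and the final output
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
```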
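The constructor defines the projections but not how they are used. One possible `forward` is sketched below: project the inputs, split them into `n_heads` heads of size `d_k`, apply scaled dot-product attention `softmax(Q Kᵀ / sqrt(d_k)) V`, then merge the heads and apply `W_o`. The constructor is repeated so the block runs on its own. The helper `_split_heads`, the mask convention (`0` means "do not attend"), and returning the attention weights are assumptions made for this sketch, not details from the article.

```python
import math

import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_k = d_model // n_heads
        self.n_heads = n_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        # Hypothetical helper: (batch, seq, d_model) -> (batch, n_heads, seq, d_k)
        b, s, _ = x.shape
        return x.view(b, s, self.n_heads, self.d_k).transpose(1, 2)

    def forward(self, query, key, value, mask=None):
        q = self._split_heads(self.W_q(query))
        k = self._split_heads(self.W_k(key))
        v = self._split_heads(self.W_v(value))
        # Scaled dot-product scores: (batch, n_heads, seq_q, seq_k)
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_k)
        if mask is not None:
            # Assumed convention: positions where mask == 0 are blocked
            scores = scores.masked_fill(mask == 0, float("-inf"))
        attn = torch.softmax(scores, dim=-1)  # each row sums to 1 over seq_k
        out = attn @ v  # (batch, n_heads, seq_q, d_k)
        # Merge heads back: (batch, seq_q, d_model)
        b, _, s, _ = out.shape
        out = out.transpose(1, 2).contiguous().view(b, s, self.n_heads * self.d_k)
        return self.W_o(out), attn
```

Returning the attention weights alongside the output only makes the sketch easier to inspect; drop them if just the output is needed. Passing `torch.tril(torch.ones(seq, seq))` as `mask` gives causal (decoder-style) attention, and `query` may have a different sequence length than `key`/`value` for cross-attention.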