Update 第二章 Transformer架构.md

This commit is contained in:
Logan Zou
2025-07-25 16:14:58 +08:00
committed by GitHub
parent 139ffd84b2
commit 1c8ce38bb9

View File

@@ -263,11 +263,11 @@ class MultiHeadAttention(nn.Module):
# Wq, Wk, Wv 参数矩阵,每个参数矩阵为 n_embd x n_embd
# 这里通过三个组合矩阵来代替了n个参数矩阵的组合其逻辑在于矩阵内积再拼接其实等同于拼接矩阵再内积
# 不理解的读者可以自行模拟一下每一个线性层其实相当于n个参数矩阵的拼接
self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
self.wv = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
self.wq = nn.Linear(args.dim, args.n_local_heads * self.head_dim, bias=False)
self.wk = nn.Linear(args.dim, args.n_local_heads * self.head_dim, bias=False)
self.wv = nn.Linear(args.dim, args.n_local_heads * self.head_dim, bias=False)
# 输出权重矩阵,维度为 dim x n_embdhead_dim = n_embeds / n_heads
self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
self.wo = nn.Linear(args.n_local_heads * self.head_dim, args.dim, bias=False)
# 注意力的 dropout
self.attn_dropout = nn.Dropout(args.dropout)
# 残差连接的 dropout