From 435661a5d54f58f25a64342a10e70457bcbf4445 Mon Sep 17 00:00:00 2001 From: Logan Zou <74288839+logan-zou@users.noreply.github.com> Date: Fri, 25 Jul 2025 16:15:16 +0800 Subject: [PATCH] Update transformer.py --- docs/chapter2/code/transformer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/chapter2/code/transformer.py b/docs/chapter2/code/transformer.py index 1d5814c..d1fccaf 100644 --- a/docs/chapter2/code/transformer.py +++ b/docs/chapter2/code/transformer.py @@ -36,11 +36,11 @@ class MultiHeadAttention(nn.Module): # Wq, Wk, Wv 参数矩阵,每个参数矩阵为 n_embd x n_embd # 这里通过三个组合矩阵来代替了n个参数矩阵的组合,其逻辑在于矩阵内积再拼接其实等同于拼接矩阵再内积, # 不理解的读者可以自行模拟一下,每一个线性层其实相当于n个参数矩阵的拼接 - self.wq = nn.Linear(args.n_embd, args.n_heads * self.head_dim, bias=False) - self.wk = nn.Linear(args.n_embd, args.n_heads * self.head_dim, bias=False) - self.wv = nn.Linear(args.n_embd, args.n_heads * self.head_dim, bias=False) + self.wq = nn.Linear(args.n_embd, args.n_local_heads * self.head_dim, bias=False) + self.wk = nn.Linear(args.n_embd, args.n_local_heads * self.head_dim, bias=False) + self.wv = nn.Linear(args.n_embd, args.n_local_heads * self.head_dim, bias=False) # 输出权重矩阵,维度为 dim x n_embd(head_dim = n_embeds / n_heads) - self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False) + self.wo = nn.Linear(args.n_local_heads * self.head_dim, args.dim, bias=False) # 注意力的 dropout self.attn_dropout = nn.Dropout(args.dropout) # 残差连接的 dropout