Update transformer.py

This commit is contained in:
Logan Zou
2025-07-28 17:39:34 +08:00
committed by GitHub
parent 9bdf9ed202
commit a110181cf8

View File

@@ -9,6 +9,7 @@ import torch.nn.functional as F
class ModelArgs:
n_embd: int # 嵌入维度
n_heads: int # 头数
n_local_heads: int # 本地计算头数
dim: int # 模型维度
dropout: float
max_seq_len: int
@@ -36,11 +37,11 @@ class MultiHeadAttention(nn.Module):
# Wq, Wk, Wv 参数矩阵,每个参数矩阵为 n_embd x n_embd
# 这里通过三个组合矩阵来代替了 n 个参数矩阵的组合,其逻辑在于:矩阵内积再拼接,其实等同于拼接矩阵再内积。
# 不理解的读者可以自行模拟一下,每一个线性层其实相当于 n 个参数矩阵的拼接。
self.wq = nn.Linear(args.n_embd, args.n_local_heads * self.head_dim, bias=False)
self.wk = nn.Linear(args.n_embd, args.n_local_heads * self.head_dim, bias=False)
self.wv = nn.Linear(args.n_embd, args.n_local_heads * self.head_dim, bias=False)
self.wq = nn.Linear(args.n_embd, self.n_local_heads * self.head_dim, bias=False)
self.wk = nn.Linear(args.n_embd, self.n_local_heads * self.head_dim, bias=False)
self.wv = nn.Linear(args.n_embd, self.n_local_heads * self.head_dim, bias=False)
# 输出权重矩阵,维度为 dim x n_embd (head_dim = n_embeds / n_heads)
self.wo = nn.Linear(args.n_local_heads * self.head_dim, args.dim, bias=False)
self.wo = nn.Linear(self.n_local_heads * self.head_dim, args.dim, bias=False)
# 注意力的 dropout
self.attn_dropout = nn.Dropout(args.dropout)
# 残差连接的 dropout