Update transformer.py
This commit is contained in:
@@ -9,6 +9,7 @@ import torch.nn.functional as F
|
||||
class ModelArgs:
|
||||
n_embd: int # 嵌入维度
|
||||
n_heads: int # 头数
|
||||
n_local_heads: int # 本地计算头数
|
||||
dim: int # 模型维度
|
||||
dropout: float
|
||||
max_seq_len: int
|
||||
@@ -36,11 +37,11 @@ class MultiHeadAttention(nn.Module):
|
||||
# Wq, Wk, Wv 参数矩阵,每个参数矩阵为 n_embd x (n_local_heads * head_dim)
|
||||
# 这里通过三个组合矩阵来代替了n个参数矩阵的组合,其逻辑在于矩阵内积再拼接其实等同于拼接矩阵再内积,
|
||||
# 不理解的读者可以自行模拟一下,每一个线性层其实相当于n个参数矩阵的拼接
|
||||
self.wq = nn.Linear(args.n_embd, args.n_local_heads * self.head_dim, bias=False)
|
||||
self.wk = nn.Linear(args.n_embd, args.n_local_heads * self.head_dim, bias=False)
|
||||
self.wv = nn.Linear(args.n_embd, args.n_local_heads * self.head_dim, bias=False)
|
||||
self.wq = nn.Linear(args.n_embd, self.n_local_heads * self.head_dim, bias=False)
|
||||
self.wk = nn.Linear(args.n_embd, self.n_local_heads * self.head_dim, bias=False)
|
||||
self.wv = nn.Linear(args.n_embd, self.n_local_heads * self.head_dim, bias=False)
|
||||
# 输出权重矩阵,维度为 (n_local_heads * head_dim) x dim(head_dim = n_embd / n_heads)
|
||||
self.wo = nn.Linear(args.n_local_heads * self.head_dim, args.dim, bias=False)
|
||||
self.wo = nn.Linear(self.n_local_heads * self.head_dim, args.dim, bias=False)
|
||||
# 注意力的 dropout
|
||||
self.attn_dropout = nn.Dropout(args.dropout)
|
||||
# 残差连接的 dropout
|
||||
|
||||
Reference in New Issue
Block a user