Update 第二章 Transformer架构.md
This commit is contained in:
@@ -253,54 +253,51 @@ class MultiHeadAttention(nn.Module):
|
||||
super().__init__()
|
||||
# 隐藏层维度必须是头数的整数倍,因为后面我们会将输入拆成头数个矩阵
|
||||
assert args.dim % args.n_heads == 0
|
||||
# 模型并行处理大小,默认为1。
|
||||
model_parallel_size = 1
|
||||
# 本地计算头数,等于总头数除以模型并行处理大小。
|
||||
self.n_local_heads = args.n_heads // model_parallel_size
|
||||
# 每个头的维度,等于模型维度除以头的总数。
|
||||
self.head_dim = args.dim // args.n_heads
|
||||
self.n_heads = args.n_heads
|
||||
|
||||
# Wq, Wk, Wv 参数矩阵,每个参数矩阵为 n_embd x n_embd
|
||||
# Wq, Wk, Wv 参数矩阵,每个参数矩阵为 n_embd x dim
|
||||
# 这里通过三个组合矩阵来代替了n个参数矩阵的组合,其逻辑在于矩阵内积再拼接其实等同于拼接矩阵再内积,
|
||||
# 不理解的读者可以自行模拟一下,每一个线性层其实相当于n个参数矩阵的拼接
|
||||
self.wq = nn.Linear(args.dim, self.n_local_heads * self.head_dim, bias=False)
|
||||
self.wk = nn.Linear(args.dim, self.n_local_heads * self.head_dim, bias=False)
|
||||
self.wv = nn.Linear(args.dim, self.n_local_heads * self.head_dim, bias=False)
|
||||
# 输出权重矩阵,维度为 dim x n_embd(head_dim = n_embeds / n_heads)
|
||||
self.wo = nn.Linear(self.n_local_heads * self.head_dim, args.dim, bias=False)
|
||||
self.wq = nn.Linear(args.n_embd, self.n_heads * self.head_dim, bias=False)
|
||||
self.wk = nn.Linear(args.n_embd, self.n_heads * self.head_dim, bias=False)
|
||||
self.wv = nn.Linear(args.n_embd, self.n_heads * self.head_dim, bias=False)
|
||||
# 输出权重矩阵,维度为 dim x dim(head_dim = dim / n_heads)
|
||||
self.wo = nn.Linear(self.n_heads * self.head_dim, args.dim, bias=False)
|
||||
# 注意力的 dropout
|
||||
self.attn_dropout = nn.Dropout(args.dropout)
|
||||
# 残差连接的 dropout
|
||||
self.resid_dropout = nn.Dropout(args.dropout)
|
||||
|
||||
self.is_causal = is_causal
|
||||
|
||||
# 创建一个上三角矩阵,用于遮蔽未来信息
|
||||
# 注意,因为是多头注意力,Mask 矩阵比之前我们定义的多一个维度
|
||||
if is_causal:
|
||||
mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
|
||||
mask = torch.triu(mask, diagonal=1)
|
||||
# 注册为模型的缓冲区
|
||||
self.register_buffer("mask", mask)
|
||||
mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
|
||||
mask = torch.triu(mask, diagonal=1)
|
||||
# 注册为模型的缓冲区
|
||||
self.register_buffer("mask", mask)
|
||||
|
||||
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
|
||||
|
||||
# 获取批次大小和序列长度,[batch_size, seq_len, dim]
|
||||
bsz, seqlen, _ = q.shape
|
||||
|
||||
# 计算查询(Q)、键(K)、值(V),输入通过参数矩阵层,维度为 (B, T, n_embed) x (n_embed, n_embed) -> (B, T, n_embed)
|
||||
# 计算查询(Q)、键(K)、值(V),输入通过参数矩阵层,维度为 (B, T, n_embed) x (n_embed, dim) -> (B, T, dim)
|
||||
xq, xk, xv = self.wq(q), self.wk(k), self.wv(v)
|
||||
|
||||
# 将 Q、K、V 拆分成多头,维度为 (B, T, n_head, C // n_head),然后交换维度,变成 (B, n_head, T, C // n_head)
|
||||
# 将 Q、K、V 拆分成多头,维度为 (B, T, n_head, dim // n_head),然后交换维度,变成 (B, n_head, T, dim // n_head)
|
||||
# 因为在注意力计算中我们是取了后两个维度参与计算
|
||||
# 为什么要先按B*T*n_head*C//n_head展开再互换1、2维度而不是直接按注意力输入展开,是因为view的展开方式是直接把输入全部排开,
|
||||
# 然后按要求构造,可以发现只有上述操作能够实现我们将每个头对应部分取出来的目标
|
||||
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
|
||||
xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
|
||||
xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)
|
||||
xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
|
||||
xk = xk.view(bsz, seqlen, self.n_heads, self.head_dim)
|
||||
xv = xv.view(bsz, seqlen, self.n_heads, self.head_dim)
|
||||
xq = xq.transpose(1, 2)
|
||||
xk = xk.transpose(1, 2)
|
||||
xv = xv.transpose(1, 2)
|
||||
|
||||
|
||||
# 注意力计算
|
||||
# 计算 QK^T / sqrt(d_k),维度为 (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
|
||||
scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
@@ -317,7 +314,7 @@ class MultiHeadAttention(nn.Module):
|
||||
output = torch.matmul(scores, xv)
|
||||
|
||||
# 恢复时间维度并合并头。
|
||||
# 将多头的结果拼接起来, 先交换维度为 (B, T, n_head, C // n_head),再拼接成 (B, T, n_head * C // n_head)
|
||||
# 将多头的结果拼接起来, 先交换维度为 (B, T, n_head, dim // n_head),再拼接成 (B, T, n_head * dim // n_head)
|
||||
# contiguous 函数用于重新开辟一块新内存存储,因为Pytorch设置先transpose再view会报错,
|
||||
# 因为view直接基于底层存储得到,然而transpose并不会改变底层存储,因此需要额外存储
|
||||
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
|
||||
@@ -326,7 +323,6 @@ class MultiHeadAttention(nn.Module):
|
||||
output = self.wo(output)
|
||||
output = self.resid_dropout(output)
|
||||
return output
|
||||
|
||||
```
|
||||
|
||||
## 2.2 Encoder-Decoder
|
||||
|
||||
Reference in New Issue
Block a user