Update 第二章 Transformer架构.md

This commit is contained in:
Logan Zou
2025-07-30 20:52:28 +08:00
committed by GitHub
parent 18d1f56840
commit d278182a90

View File

@@ -478,7 +478,7 @@ class EncoderLayer(nn.Module):
# Encoder 不需要掩码,传入 is_causal=False
self.attention = MultiHeadAttention(args, is_causal=False)
self.fnn_norm = LayerNorm(args.n_embd)
-self.feed_forward = MLP(args)
+self.feed_forward = MLP(args.dim, args.dim, args.dropout)
def forward(self, x):
# Layer Norm
@@ -528,7 +528,7 @@ class DecoderLayer(nn.Module):
self.attention = MultiHeadAttention(args, is_causal=False)
self.ffn_norm = LayerNorm(args.n_embd)
# 第三个部分是 MLP
-self.feed_forward = MLP(args)
+self.feed_forward = MLP(args.dim, args.dim, args.dropout)
def forward(self, x, enc_out):
# Layer Norm