Update 第二章 Transformer架构.md
This commit is contained in:
@@ -478,7 +478,7 @@ class EncoderLayer(nn.Module):
|
||||
# Encoder 不需要掩码,传入 is_causal=False
|
||||
self.attention = MultiHeadAttention(args, is_causal=False)
|
||||
self.fnn_norm = LayerNorm(args.n_embd)
|
||||
self.feed_forward = MLP(args)
|
||||
self.feed_forward = MLP(args.dim, args.dim, args.dropout)
|
||||
|
||||
def forward(self, x):
|
||||
# Layer Norm
|
||||
@@ -528,7 +528,7 @@ class DecoderLayer(nn.Module):
|
||||
self.attention = MultiHeadAttention(args, is_causal=False)
|
||||
self.ffn_norm = LayerNorm(args.n_embd)
|
||||
# 第三个部分是 MLP
|
||||
self.feed_forward = MLP(args)
|
||||
self.feed_forward = MLP(args.dim, args.dim, args.dropout)
|
||||
|
||||
def forward(self, x, enc_out):
|
||||
# Layer Norm
|
||||
|
||||
Reference in New Issue
Block a user