Update 第二章 Transformer架构.md
This commit is contained in:
@@ -747,7 +747,7 @@ class PositionalEncoding(nn.Module):
|
||||
def __init__(self, args):
|
||||
super(PositionalEncoding, self).__init__()
|
||||
# Dropout 层
|
||||
self.dropout = nn.Dropout(p=args.dropout)
|
||||
# self.dropout = nn.Dropout(p=args.dropout)
|
||||
|
||||
# block size 是序列的最大长度
|
||||
pe = torch.zeros(args.block_size, args.n_embd)
|
||||
@@ -765,7 +765,7 @@ class PositionalEncoding(nn.Module):
|
||||
def forward(self, x):
|
||||
# 将位置编码加到 Embedding 结果上
|
||||
x = x + self.pe[:, : x.size(1)].requires_grad_(False)
|
||||
return self.dropout(x)
|
||||
return x
|
||||
```
|
||||
|
||||
### 2.3.3 一个完整的 Transformer
|
||||
|
||||
Reference in New Issue
Block a user