Update transformer.py
This commit is contained in:
@@ -267,7 +267,7 @@ class Transformer(nn.Module):
|
||||
n_params = sum(p.numel() for p in self.parameters())
|
||||
# 如果不统计 embedding 的参数,就减去
|
||||
if non_embedding:
|
||||
n_params -= self.transformer.wpe.weight.numel()
|
||||
n_params -= self.transformer.wte.weight.numel()
|
||||
return n_params
|
||||
|
||||
'''初始化权重'''
|
||||
|
||||
Reference in New Issue
Block a user