fix: typo of torch dimension (#151)
This commit was merged in pull request #151.
This commit is contained in:
@@ -99,7 +99,7 @@ class RMSNorm(nn.Module):
|
||||
return output * self.weight
|
||||
```
|
||||
|
||||
并且,我们可以用下面的代码来对`RMSNorm`模块进行测试,可以看到代码最终输出的形状为`torch.Size([1, 50, 288])`,与我们输入的形状一致,说明模块的实现是正确的,归一化并不会改变输入的形状。
|
||||
并且,我们可以用下面的代码来对`RMSNorm`模块进行测试,可以看到代码最终输出的形状为`torch.Size([1, 50, 768])`,与我们输入的形状一致,说明模块的实现是正确的,归一化并不会改变输入的形状。
|
||||
|
||||
```python
|
||||
norm = RMSNorm(args.dim, args.norm_eps)
|
||||
|
||||
Reference in New Issue
Block a user