fix: typo of torch dimension (#151)

This commit was merged in pull request #151.
This commit is contained in:
Genghong Hu
2026-02-25 20:56:19 +08:00
committed by GitHub
parent 77aff4b66a
commit 8908c4a8c3

View File

@@ -99,7 +99,7 @@ class RMSNorm(nn.Module):
return output * self.weight
```
并且,我们可以用下面的代码来对`RMSNorm`模块进行测试,可以看到代码最终输出的形状为`torch.Size([1, 50, 288])`,与我们输入的形状一致,说明模块的实现是正确的,归一化并不会改变输入的形状。
并且,我们可以用下面的代码来对`RMSNorm`模块进行测试,可以看到代码最终输出的形状为`torch.Size([1, 50, 768])`,与我们输入的形状一致,说明模块的实现是正确的,归一化并不会改变输入的形状。
```python
norm = RMSNorm(args.dim, args.norm_eps)