refactor: 将参数名从keyargs改为kwargs以符合惯例

修改forward方法的参数命名,使其更符合Python常用命名规范
This commit is contained in:
KMnO4-zx
2025-08-07 19:37:01 +08:00
parent ebe52dc086
commit d35df306ed
2 changed files with 12 additions and 12 deletions

View File

@@ -335,20 +335,20 @@ class Transformer(PreTrainedModel):
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None, **keyargs) -> torch.Tensor:
def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
"""
- tokens: Optional[torch.Tensor], 输入 token 张量。
- targets: Optional[torch.Tensor], 目标 token 张量。
- kv_cache: bool, 是否使用键值缓存。
- keyargs: 其他关键字参数。
- kwargs: 其他关键字参数。
- self.OUT: CausalLMOutputWithPast, 包含 logits 和损失。
"""
if 'input_ids' in keyargs:
tokens = keyargs['input_ids']
if 'attention_mask' in keyargs:
targets = keyargs['attention_mask']
if 'input_ids' in kwargs:
tokens = kwargs['input_ids']
if 'attention_mask' in kwargs:
targets = kwargs['attention_mask']
# 前向传播函数
_bsz, seqlen = tokens.shape

View File

@@ -554,20 +554,20 @@ class Transformer(PreTrainedModel):
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None, **keyargs) -> torch.Tensor:
def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
"""
- tokens: Optional[torch.Tensor], 输入 token 张量。
- targets: Optional[torch.Tensor], 目标 token 张量。
- kv_cache: bool, 是否使用键值缓存。
- keyargs: 其他关键字参数。
- kwargs: 其他关键字参数。
- self.OUT: CausalLMOutputWithPast, 包含 logits 和损失。
"""
if 'input_ids' in keyargs:
tokens = keyargs['input_ids']
if 'attention_mask' in keyargs:
targets = keyargs['attention_mask']
if 'input_ids' in kwargs:
tokens = kwargs['input_ids']
if 'attention_mask' in kwargs:
targets = kwargs['attention_mask']
# 前向传播函数
_bsz, seqlen = tokens.shape