refactor: 将参数名从keyargs改为kwargs以符合惯例
修改forward方法的参数命名,使其更符合Python常用命名规范
This commit is contained in:
@@ -335,20 +335,20 @@ class Transformer(PreTrainedModel):
|
||||
elif isinstance(module, nn.Embedding):
|
||||
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
||||
|
||||
def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None, **keyargs) -> torch.Tensor:
|
||||
def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
|
||||
"""
|
||||
- tokens: Optional[torch.Tensor], 输入 token 张量。
|
||||
- targets: Optional[torch.Tensor], 目标 token 张量。
|
||||
- kv_cache: bool, 是否使用键值缓存。
|
||||
- keyargs: 其他关键字参数。
|
||||
- kwargs: 其他关键字参数。
|
||||
|
||||
- self.OUT: CausalLMOutputWithPast, 包含 logits 和损失。
|
||||
"""
|
||||
|
||||
if 'input_ids' in keyargs:
|
||||
tokens = keyargs['input_ids']
|
||||
if 'attention_mask' in keyargs:
|
||||
targets = keyargs['attention_mask']
|
||||
if 'input_ids' in kwargs:
|
||||
tokens = kwargs['input_ids']
|
||||
if 'attention_mask' in kwargs:
|
||||
targets = kwargs['attention_mask']
|
||||
|
||||
# 前向传播函数
|
||||
_bsz, seqlen = tokens.shape
|
||||
|
||||
@@ -554,20 +554,20 @@ class Transformer(PreTrainedModel):
|
||||
elif isinstance(module, nn.Embedding):
|
||||
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
||||
|
||||
def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None, **keyargs) -> torch.Tensor:
|
||||
def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
|
||||
"""
|
||||
- tokens: Optional[torch.Tensor], 输入 token 张量。
|
||||
- targets: Optional[torch.Tensor], 目标 token 张量。
|
||||
- kv_cache: bool, 是否使用键值缓存。
|
||||
- keyargs: 其他关键字参数。
|
||||
- kwargs: 其他关键字参数。
|
||||
|
||||
- self.OUT: CausalLMOutputWithPast, 包含 logits 和损失。
|
||||
"""
|
||||
|
||||
if 'input_ids' in keyargs:
|
||||
tokens = keyargs['input_ids']
|
||||
if 'attention_mask' in keyargs:
|
||||
targets = keyargs['attention_mask']
|
||||
if 'input_ids' in kwargs:
|
||||
tokens = kwargs['input_ids']
|
||||
if 'attention_mask' in kwargs:
|
||||
targets = kwargs['attention_mask']
|
||||
|
||||
# 前向传播函数
|
||||
_bsz, seqlen = tokens.shape
|
||||
|
||||
Reference in New Issue
Block a user