diff --git a/docs/chapter5/code/k_model.py b/docs/chapter5/code/k_model.py index bc4cef9..576a861 100644 --- a/docs/chapter5/code/k_model.py +++ b/docs/chapter5/code/k_model.py @@ -335,20 +335,20 @@ class Transformer(PreTrainedModel): elif isinstance(module, nn.Embedding): torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) - def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None, **keyargs) -> torch.Tensor: + def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: """ - tokens: Optional[torch.Tensor], 输入 token 张量。 - targets: Optional[torch.Tensor], 目标 token 张量。 - kv_cache: bool, 是否使用键值缓存。 - - keyargs: 其他关键字参数。 + - kwargs: 其他关键字参数。 - self.OUT: CausalLMOutputWithPast, 包含 logits 和损失。 """ - if 'input_ids' in keyargs: - tokens = keyargs['input_ids'] - if 'attention_mask' in keyargs: - targets = keyargs['attention_mask'] + if 'input_ids' in kwargs: + tokens = kwargs['input_ids'] + if 'attention_mask' in kwargs: + targets = kwargs['attention_mask'] # 前向传播函数 _bsz, seqlen = tokens.shape diff --git a/docs/chapter5/第五章 动手搭建大模型.md b/docs/chapter5/第五章 动手搭建大模型.md index 43c3455..ad456b4 100644 --- a/docs/chapter5/第五章 动手搭建大模型.md +++ b/docs/chapter5/第五章 动手搭建大模型.md @@ -554,20 +554,20 @@ class Transformer(PreTrainedModel): elif isinstance(module, nn.Embedding): torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) - def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None, **keyargs) -> torch.Tensor: + def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: """ - tokens: Optional[torch.Tensor], 输入 token 张量。 - targets: Optional[torch.Tensor], 目标 token 张量。 - kv_cache: bool, 是否使用键值缓存。 - - keyargs: 其他关键字参数。 + - kwargs: 其他关键字参数。 - self.OUT: CausalLMOutputWithPast, 包含 logits 和损失。 """ - if 'input_ids' in keyargs: - tokens = keyargs['input_ids'] - if 'attention_mask' in keyargs: - targets = keyargs['attention_mask'] + if 'input_ids' in kwargs: + tokens = kwargs['input_ids'] + if 'attention_mask' in kwargs: + targets = kwargs['attention_mask'] # 前向传播函数 _bsz, seqlen = tokens.shape