fix(chapter5): 修复 labels/attention_mask 语义并补齐 padding-aware 批量推理 #170
Reference in New Issue
Block a user
Delete Branch "refs/pull/170/head"
Deleting a branch is permanent. Although the deleted branch may continue to exist for a short time before it actually gets removed, it CANNOT be undone in most cases. Continue?
PR Description
BG
当前
chapter5示例代码中存在 4 个一致性问题:forward中将attention_mask误赋值给targets。loss_mask,模型注意力层未使用padding_mask。-1位置 logits,batch 场景下含 padding 会取到错误位置。pad_token_id不一致(实现写死为0)。本次改动
修复
forward函数,使用labels作为监督信号(不再把attention_mask当 targets)。增加 attention 对
padding_mask的支持Attention.forward新增attention_mask参数。causal + key padding组合 mask。masked_fill(-inf)。attention_mask时按“最后有效 token”取 logits,不再固定[:, -1, :]。generate/generate_super新增attention_mask和pad_token_id参数。stop_id)后的生成。dataset.py改为使用tokenizer.pad_token_id(fallback 为 0)。lm_config.pad_token_id。影响范围
docs/chapter5/code/k_model.pydocs/chapter5/code/dataset.pydocs/chapter5/code/ddp_pretrain.pydocs/chapter5/code/ddp_sft_full.pydocs/chapter5/code/model_sample.pydocs/chapter5/code/export_model.pydocs/chapter5/第五章 动手搭建大模型.md兼容性
model(X, Y)训练调用兼容。labels + attention_mask语义。感谢这个PR!改动审查完毕:
✅ 正确修复了
forward中attention_mask被错误赋值给targets的 bug✅ attention_mask 实现逻辑正确
✅ 批量生成支持正确
代码实现正确,逻辑清晰。LGTM 👍