fix(chapter5): align labels/attention_mask semantics and add padding-aware batch generation (#170)

This commit was merged in pull request #170.
This commit is contained in:
Founce
2026-02-26 15:34:10 +08:00
committed by GitHub
parent 827808c1e3
commit 5dd78a0fe8
7 changed files with 184 additions and 42 deletions
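For context, a minimal usage sketch of the convention these changes move to (assuming `model` and `tokenizer` are already loaded as in the training scripts below; the prompt strings and hyperparameters are illustrative only):

import torch
# Two prompts of different lengths; the tokenizer right-pads and returns an attention_mask.
enc = tokenizer(["你好", "请介绍一下大语言模型"], padding=True, return_tensors="pt")
# Training-style call: targets now travel in `labels`, padding information in `attention_mask`.
out = model(enc["input_ids"], attention_mask=enc["attention_mask"], labels=enc["input_ids"])
per_token_loss = model.last_loss  # reduction='none'; padded / ignored positions contribute zero
# Batched generation: generate() left-pads internally and keeps already-finished rows padded.
new_tokens = model.generate(enc["input_ids"],
                            attention_mask=enc["attention_mask"],
                            max_new_tokens=32,
                            stop_id=tokenizer.eos_token_id)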

View File

@@ -13,7 +13,7 @@ class PretrainDataset(Dataset):
self.data_path = data_path
self.tokenizer = tokenizer
self.max_length = max_length
self.padding = 0
self.padding = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
# Precompute the starting byte offset of each line
self._offsets = []
with open(data_path, 'rb') as f:
@@ -51,7 +51,7 @@ class SFTDataset(Dataset):
self.data_path = data_path
self.tokenizer = tokenizer
self.max_length = max_length
self.padding = 0
self.padding = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
self._offsets = []
with open(data_path, 'rb') as f:
self._offsets.append(0)
@@ -116,4 +116,4 @@ class SFTDataset(Dataset):
X = np.array(input_id[:-1]).astype(np.int64)
Y = np.array(input_id[1:]).astype(np.int64)
loss_mask = np.array(loss_mask[1:]).astype(np.int64)
return torch.from_numpy(X), torch.from_numpy(Y), torch.from_numpy(loss_mask)

View File

@@ -196,6 +196,8 @@ def init_model():
# Load the pretrained tokenizer from a local path
tokenizer = AutoTokenizer.from_pretrained('./tokenizer_k/')
if tokenizer.pad_token_id is not None:
lm_config.pad_token_id = tokenizer.pad_token_id
# Build the Transformer model from the config
model = Transformer(lm_config)
@@ -320,4 +322,4 @@ if __name__ == "__main__":
# Start the training loop
for epoch in range(args.epochs):
train_epoch(epoch)

View File

@@ -127,6 +127,8 @@ def init_model():
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained('./tokenizer_k/')
if tokenizer.pad_token_id is not None:
lm_config.pad_token_id = tokenizer.pad_token_id
# Initialize the model
model = Transformer(lm_config)
@@ -224,4 +226,4 @@ if __name__ == "__main__":
# Start training
iter_per_epoch = len(train_loader)
for epoch in range(args.epochs):
train_epoch(epoch)

View File

@@ -15,6 +15,15 @@ def export_model(tokenizer_path, model_config, model_ckpt_path, save_directory):
ModelConfig.register_for_auto_class()
Transformer.register_for_auto_class("AutoModelForCausalLM")
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_path,
trust_remote_code=True,
use_fast=False
)
if tokenizer.pad_token_id is not None:
model_config.pad_token_id = tokenizer.pad_token_id
# Initialize the model
model = Transformer(model_config)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -31,13 +40,6 @@ def export_model(tokenizer_path, model_config, model_ckpt_path, save_directory):
model.load_state_dict(state_dict, strict=False)
print(f'模型参数: {count_parameters(model)/1e6:.2f}M = {count_parameters(model)/1e9:.2f}B')
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_path,
trust_remote_code=True,
use_fast=False
)
# Save the full model and tokenizer
model.save_pretrained(save_directory, safe_serialization=False)
tokenizer.save_pretrained(save_directory)
@@ -56,4 +58,4 @@ if __name__ == '__main__':
model_config=config,
model_ckpt_path='./BeelGroup_sft_model_215M/sft_dim1024_layers18_vocab_size6144.pth',
save_directory="k-model-215M"
)

View File

@@ -26,6 +26,7 @@ class ModelConfig(PretrainedConfig):
max_seq_len: int = 512,
dropout: float = 0.0,
flash_attn: bool = True,
pad_token_id: int = 0,
**kwargs,
):
self.dim = dim
@@ -39,6 +40,7 @@ class ModelConfig(PretrainedConfig):
self.max_seq_len = max_seq_len
self.dropout = dropout
self.flash_attn = flash_attn
self.pad_token_id = pad_token_id
super().__init__(**kwargs)
class RMSNorm(nn.Module):
@@ -177,7 +179,7 @@ class Attention(nn.Module):
# Register as a buffer on the model
self.register_buffer("mask", mask)
def forward(self, x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor):
def forward(self, x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
# Get the batch size and sequence length, [batch_size, seq_len, dim]
bsz, seqlen, _ = x.shape
@@ -199,16 +201,40 @@ class Attention(nn.Module):
xq = xq.transpose(1, 2)
xk = xk.transpose(1, 2)
xv = xv.transpose(1, 2)
key_padding_mask = None
if attention_mask is not None:
key_padding_mask = attention_mask[:, None, None, :].to(dtype=torch.bool)
# Choose the implementation depending on whether Flash Attention is available.
if self.flash:
# Use Flash Attention.
output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, dropout_p=self.dropout if self.training else 0.0, is_causal=True)
if key_padding_mask is not None:
causal_mask = torch.ones((seqlen, seqlen), dtype=torch.bool, device=x.device).tril()
full_attn_mask = causal_mask[None, None, :, :] & key_padding_mask
output = torch.nn.functional.scaled_dot_product_attention(
xq,
xk,
xv,
attn_mask=full_attn_mask,
dropout_p=self.dropout if self.training else 0.0,
is_causal=False,
)
else:
output = torch.nn.functional.scaled_dot_product_attention(
xq,
xk,
xv,
attn_mask=None,
dropout_p=self.dropout if self.training else 0.0,
is_causal=True,
)
else:
# Use the manually implemented attention.
scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
assert hasattr(self, 'mask')
scores = scores + self.mask[:, :, :seqlen, :seqlen]
if key_padding_mask is not None:
scores = scores.masked_fill(~key_padding_mask, float("-inf"))
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
scores = self.attn_dropout(scores)
output = torch.matmul(scores, xv)
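As a sanity check on the mask construction above (a throwaway sketch with toy shapes, not part of the commit), the key-padding mask broadcasts against the causal mask into a [bsz, 1, seqlen, seqlen] boolean mask in which True means "may attend":

import torch
import torch.nn.functional as F

B, H, S, D = 2, 1, 4, 8                                               # toy batch, heads, seq len, head dim
q = k = v = torch.randn(B, H, S, D)
attention_mask = torch.tensor([[1, 1, 1, 1],
                               [1, 1, 1, 0]])                         # second row has one padding position
key_padding_mask = attention_mask[:, None, None, :].to(torch.bool)    # [B, 1, 1, S]
causal_mask = torch.ones((S, S), dtype=torch.bool).tril()             # [S, S]
full_attn_mask = causal_mask[None, None, :, :] & key_padding_mask     # broadcasts to [B, 1, S, S]
out = F.scaled_dot_product_attention(q, k, v, attn_mask=full_attn_mask)
print(full_attn_mask.shape, out.shape)  # torch.Size([2, 1, 4, 4]) torch.Size([2, 1, 4, 8])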
@@ -272,11 +298,11 @@ class DecoderLayer(nn.Module):
# Normalization layer applied before the feed-forward network
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
def forward(self, x, freqs_cos, freqs_sin):
def forward(self, x, freqs_cos, freqs_sin, attention_mask: Optional[torch.Tensor] = None):
# Forward pass
# First, x goes through the attention norm and the attention block; the result is added to x to give h.
# Then h goes through the FFN norm and the feed-forward network; the result is added to h to give the output.
h = x + self.attention.forward(self.attention_norm(x), freqs_cos, freqs_sin)
h = x + self.attention.forward(self.attention_norm(x), freqs_cos, freqs_sin, attention_mask=attention_mask)
out = h + self.feed_forward.forward(self.ffn_norm(h))
return out
@@ -334,6 +360,45 @@ class Transformer(PreTrainedModel):
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
def _prepare_attention_mask(self, attention_mask: Optional[torch.Tensor], tokens: torch.Tensor) -> Optional[torch.Tensor]:
if attention_mask is None:
return None
if attention_mask.dim() == 4:
attention_mask = attention_mask[:, 0, 0, :]
elif attention_mask.dim() == 3:
attention_mask = attention_mask[:, 0, :]
attention_mask = attention_mask.to(tokens.device)
if attention_mask.dtype != torch.bool:
attention_mask = attention_mask > 0
if attention_mask.shape != tokens.shape:
raise ValueError(f"attention_mask shape {attention_mask.shape} must match input_ids shape {tokens.shape}")
return attention_mask
def _left_pad_by_attention_mask(
self,
idx: torch.Tensor,
attention_mask: Optional[torch.Tensor],
pad_token_id: int
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
if attention_mask is None or attention_mask.all():
return idx, attention_mask
bsz = idx.size(0)
lengths = attention_mask.long().sum(dim=1)
max_len = max(int(lengths.max().item()), 1)
packed_idx = idx.new_full((bsz, max_len), pad_token_id)
packed_mask = attention_mask.new_zeros((bsz, max_len), dtype=torch.bool)
for row in range(bsz):
valid_len = int(lengths[row].item())
if valid_len <= 0:
continue
valid_tokens = idx[row][attention_mask[row]]
packed_idx[row, max_len - valid_len:] = valid_tokens
packed_mask[row, max_len - valid_len:] = True
return packed_idx, packed_mask
def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
"""
@@ -347,8 +412,9 @@ class Transformer(PreTrainedModel):
if 'input_ids' in kwargs:
tokens = kwargs['input_ids']
if 'attention_mask' in kwargs:
targets = kwargs['attention_mask']
if 'labels' in kwargs:
targets = kwargs['labels']
attention_mask = self._prepare_attention_mask(kwargs.get('attention_mask'), tokens)
# Forward pass
_bsz, seqlen = tokens.shape
@@ -361,17 +427,30 @@ class Transformer(PreTrainedModel):
# Pass through the decoder layers
for layer in self.layers:
h = layer(h, freqs_cos, freqs_sin)
h = layer(h, freqs_cos, freqs_sin, attention_mask=attention_mask)
# Pass through the final normalization layer
h = self.norm(h)
if targets is not None:
# If targets are given, compute the loss
logits = self.output(h)
self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=0, reduction='none')
ignore_index = self.args.pad_token_id if self.args.pad_token_id is not None else 0
if torch.any(targets == -100):
ignore_index = -100
self.last_loss = F.cross_entropy(
logits.view(-1, logits.size(-1)),
targets.view(-1),
ignore_index=ignore_index,
reduction='none'
)
else:
# Small inference-time optimization: only project the output at the last position
logits = self.output(h[:, [-1], :])
if attention_mask is None:
logits = self.output(h[:, [-1], :])
else:
full_logits = self.output(h)
last_token_pos = attention_mask.long().sum(dim=1).clamp(min=1) - 1
logits = full_logits[torch.arange(_bsz, device=tokens.device), last_token_pos].unsqueeze(1)
self.last_loss = None
# Set the outputs
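A small illustration of the ignore_index selection above (toy numbers, with 0 standing in for the pad id; not from the repository): with reduction='none', any position whose label equals the ignore index contributes exactly zero loss, whether that index is the pad id or -100.

import torch
import torch.nn.functional as F

vocab_size = 10
logits = torch.randn(1, 3, vocab_size)
labels = torch.tensor([[4, 7, -100]])                      # last position masked out HF-style
ignore_index = -100 if torch.any(labels == -100) else 0    # mirrors the selection logic above
loss = F.cross_entropy(logits.view(-1, vocab_size), labels.view(-1),
                       ignore_index=ignore_index, reduction='none')
print(loss)  # third element is 0.0; it is skipped when averaging over valid tokens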
@@ -381,18 +460,36 @@ class Transformer(PreTrainedModel):
@torch.inference_mode()
def generate(self, idx, stop_id=None, max_new_tokens=256, temperature=1.0, top_k=None):
def generate(
self,
idx,
stop_id=None,
max_new_tokens=256,
temperature=1.0,
top_k=None,
attention_mask: Optional[torch.Tensor] = None,
pad_token_id: Optional[int] = None
):
"""
Given an input sequence idx (a LongTensor of shape (bz, seq_len)), complete the sequence by repeatedly generating new tokens.
Intended to run in model.eval() mode. This is a relatively inefficient sampling version that does not use a k/v cache.
"""
if pad_token_id is None:
pad_token_id = self.args.pad_token_id if self.args.pad_token_id is not None else 0
attention_mask = self._prepare_attention_mask(attention_mask, idx)
idx, attention_mask = self._left_pad_by_attention_mask(idx, attention_mask, pad_token_id)
finished = torch.zeros(idx.size(0), dtype=torch.bool, device=idx.device)
index = idx.shape[1]
for _ in range(max_new_tokens):
# If the sequence context is too long, crop it to the maximum length
idx_cond = idx if idx.size(1) <= self.args.max_seq_len else idx[:, -self.args.max_seq_len:]
mask_cond = None
if attention_mask is not None:
mask_cond = attention_mask if attention_mask.size(1) <= self.args.max_seq_len else attention_mask[:, -self.args.max_seq_len:]
# Forward pass to get the logits at the last position of the sequence
logits = self(idx_cond).logits
logits = self(idx_cond, attention_mask=mask_cond).logits
logits = logits[:, -1, :] # keep only the output of the last time step
if temperature == 0.0:
@@ -406,13 +503,24 @@ class Transformer(PreTrainedModel):
logits[logits < v[:, [-1]]] = -float('Inf')
probs = F.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1)
if idx_next == stop_id:
break
prev_finished = finished.clone()
if stop_id is not None:
if prev_finished.any():
fill_token = pad_token_id if pad_token_id is not None else stop_id
idx_next = torch.where(prev_finished[:, None], torch.full_like(idx_next, fill_token), idx_next)
finished = prev_finished | idx_next[:, 0].eq(stop_id)
# Append the sampled index to the running sequence and continue
idx = torch.cat((idx, idx_next), dim=1)
if attention_mask is not None:
next_mask = torch.ones((attention_mask.size(0), 1), dtype=attention_mask.dtype, device=attention_mask.device)
if prev_finished.any():
next_mask[prev_finished] = False
attention_mask = torch.cat((attention_mask, next_mask), dim=1)
if stop_id is not None and finished.all():
break
return idx[:, index:] # return only the generated tokens
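For intuition, a toy example (values invented; pad id assumed to be 0 and `model` built as in this file) of the left-padding repack that generate() performs before decoding:

import torch

idx = torch.tensor([[5, 6, 7, 0],
                    [8, 0, 0, 0]])                    # right-padded prompts
mask = torch.tensor([[1, 1, 1, 0],
                     [1, 0, 0, 0]], dtype=torch.bool)
packed_idx, packed_mask = model._left_pad_by_attention_mask(idx, mask, pad_token_id=0)
# packed_idx:  [[5, 6, 7],       packed_mask: [[True,  True,  True],
#               [0, 0, 8]]                     [False, False, True]]
# Left padding keeps the last column real for every row, so logits[:, -1, :] in the
# decoding loop selects the correct next-token distribution for each sequence.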
@@ -573,7 +681,9 @@ class Transformer(PreTrainedModel):
temperature=1.0,
top_k=None,
do_sample=False,
num_beams=1
num_beams=1,
attention_mask: Optional[torch.Tensor] = None,
pad_token_id: Optional[int] = None
):
"""
Advanced text generation function supporting three decoding strategies:
@@ -609,19 +719,27 @@ class Transformer(PreTrainedModel):
num_beams = 1
if top_k is not None and top_k < 1:
top_k = None
if pad_token_id is None:
pad_token_id = self.args.pad_token_id if self.args.pad_token_id is not None else 0
attention_mask = self._prepare_attention_mask(attention_mask, idx)
idx, attention_mask = self._left_pad_by_attention_mask(idx, attention_mask, pad_token_id)
# Beam search logic
if not do_sample and num_beams > 1:
return self._beam_search(idx, max_new_tokens, num_beams, temperature, top_k, stop_id)
# Greedy decoding and random sampling logic
finished = torch.zeros(idx.size(0), dtype=torch.bool, device=idx.device)
index = idx.shape[1]
for _ in range(max_new_tokens):
# If the sequence context is too long, crop it to the maximum length
idx_cond = idx if idx.size(1) <= self.args.max_seq_len else idx[:, -self.args.max_seq_len:]
mask_cond = None
if attention_mask is not None:
mask_cond = attention_mask if attention_mask.size(1) <= self.args.max_seq_len else attention_mask[:, -self.args.max_seq_len:]
# Forward pass to get the logits at the last position of the sequence
logits = self(idx_cond).logits
logits = self(idx_cond, attention_mask=mask_cond).logits
logits = logits[:, -1, :] # keep only the output of the last time step
# Choose the decoding strategy based on the arguments
@@ -635,12 +753,23 @@ class Transformer(PreTrainedModel):
# Random sampling at low temperature (close to greedy)
idx_next = self._random_sample(logits, temperature, top_k)
# Check the stop condition
if stop_id is not None and idx_next[0, 0] == stop_id:
break
prev_finished = finished.clone()
if stop_id is not None:
if prev_finished.any():
fill_token = pad_token_id if pad_token_id is not None else stop_id
idx_next = torch.where(prev_finished[:, None], torch.full_like(idx_next, fill_token), idx_next)
finished = prev_finished | idx_next[:, 0].eq(stop_id)
# Append the chosen token to the sequence
idx = torch.cat((idx, idx_next), dim=1)
if attention_mask is not None:
next_mask = torch.ones((attention_mask.size(0), 1), dtype=attention_mask.dtype, device=attention_mask.device)
if prev_finished.any():
next_mask[prev_finished] = False
attention_mask = torch.cat((attention_mask, next_mask), dim=1)
if stop_id is not None and finished.all():
break
return idx[:, index:] # return only the generated tokens
@@ -649,6 +778,7 @@ if __name__ == '__main__':
args = ModelConfig(
dim=1024,
n_layers=18,
pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0,
)
# Instantiate the LLaMA2Model
model = Transformer(args=args)
@@ -670,4 +800,4 @@ if __name__ == '__main__':
print("Y shape :", Y.shape)
# Feed the input tensors to the model
output = model(X, Y)

View File

@@ -33,10 +33,18 @@ class TextGenerator:
# Choose the appropriate automatic mixed-precision context based on dtype
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[self.dtype]
self.ctx = nullcontext() if self.device_type == 'cpu' else torch.amp.autocast(device_type=self.device_type, dtype=ptdtype)
# Initialize the tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_model_path) # load the tokenizer from the given path
# Load the model checkpoint file
checkpoint_dict = torch.load(self.checkpoint, map_location=self.device) # load the model parameters
self.model = Transformer(ModelConfig(dim=1024, n_layers=18)) # instantiate the Transformer model
self.model = Transformer(
ModelConfig(
dim=1024,
n_layers=18,
pad_token_id=self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else 0
)
) # instantiate the Transformer model
sunwanted_prefix = '_orig_mod.'
for k, v in list(checkpoint_dict.items()):
if k.startswith(sunwanted_prefix):
@@ -50,8 +58,6 @@ class TextGenerator:
self.model.eval()
# Move the model to the correct device (GPU or CPU)
self.model.to(self.device)
# Initialize the tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_model_path) # load the tokenizer from the given path
def chat_template(self, prompt):
message = [

View File

@@ -568,8 +568,8 @@ class Transformer(PreTrainedModel):
if 'input_ids' in kwargs:
tokens = kwargs['input_ids']
if 'attention_mask' in kwargs:
targets = kwargs['attention_mask']
if 'labels' in kwargs:
targets = kwargs['labels']
# Forward pass
_bsz, seqlen = tokens.shape
@@ -1306,7 +1306,7 @@ class PretrainDataset(Dataset):
self.data_path = data_path
self.tokenizer = tokenizer
self.max_length = max_length
self.padding = 0
self.padding = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
# Precompute the starting byte offset of each line
self._offsets = []
with open(data_path, 'rb') as f:
@@ -1365,7 +1365,7 @@ class SFTDataset(Dataset):
self.data_path = data_path
self.tokenizer = tokenizer
self.max_length = max_length
self.padding = 0
self.padding = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
self._offsets = []
with open(data_path, 'rb') as f:
self._offsets.append(0)
@@ -2249,4 +2249,4 @@ Sample 2:
[6] Jingyao Gong. (2023). *minimind: Minimalist LLM implementation*. GitHub repository. https://github.com/jingyaogong/minimind
[7] Mobvoi. (2023). *seq-monkey-data: Llama2 training/inference data*. GitHub repository. https://github.com/mobvoi/seq-monkey-data