Files
happy-llm/docs/chapter5/code/k_model.py

804 lines
35 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import math
import inspect
from dataclasses import dataclass
from typing import Any, Optional, Tuple
import torch
import torch.nn.functional as F
from torch import nn
from transformers import PreTrainedModel, AutoTokenizer
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers import PretrainedConfig
class ModelConfig(PretrainedConfig):
model_type = "Tiny-K"
def __init__(
self,
dim: int = 768,
n_layers: int = 12,
n_heads: int = 16,
n_kv_heads: int = 8,
vocab_size: int = 6144,
hidden_dim: int = None,
multiple_of: int = 64,
norm_eps: float = 1e-5,
max_seq_len: int = 512,
dropout: float = 0.0,
flash_attn: bool = True,
pad_token_id: int = 0,
**kwargs,
):
self.dim = dim
self.n_layers = n_layers
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.vocab_size = vocab_size
self.hidden_dim = hidden_dim
self.multiple_of = multiple_of
self.norm_eps = norm_eps
self.max_seq_len = max_seq_len
self.dropout = dropout
self.flash_attn = flash_attn
self.pad_token_id = pad_token_id
super().__init__(**kwargs)
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float):
super().__init__()
# eps是为了防止除以0的情况
self.eps = eps
# weight是一个可学习的参数全部初始化为1
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
# 计算RMSNorm的核心部分
# x.pow(2).mean(-1, keepdim=True)计算了输入x的平方的均值
# torch.rsqrt是平方根的倒数这样就得到了RMSNorm的分母部分再加上eps防止分母为0
# 最后乘以x得到RMSNorm的结果
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
# forward函数是模型的前向传播
# 首先将输入x转为float类型然后进行RMSNorm最后再转回原来的数据类型
# 最后乘以weight这是RMSNorm的一个可学习的缩放因子
output = self._norm(x.float()).type_as(x)
return output * self.weight
# 获得旋转嵌入的实部和虚部
# 注意此处的dim应为 dim//n_head因为我们是对每个head进行旋转嵌入
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
# torch.arange(0, dim, 2)[: (dim // 2)].float()生成了一个从0开始步长为2的序列长度为dim的一半
# 然后每个元素除以dim再取theta的倒数得到频率
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
# 生成一个从0到end的序列长度为end
t = torch.arange(end, device=freqs.device)
# 计算外积得到一个二维矩阵每一行是t的元素乘以freqs的元素
freqs = torch.outer(t, freqs).float()
# 计算频率的余弦值,得到实部
freqs_cos = torch.cos(freqs)
# 计算频率的正弦值,得到虚部
freqs_sin = torch.sin(freqs)
return freqs_cos, freqs_sin
# 此函数的作用是将freqs_cis调整为与x的形状相同以便能够与x进行广播操作
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
# 获取x的维度数
ndim = x.ndim
# 断言确保1在x的维度范围内
assert 0 <= 1 < ndim
# 断言确保freqs_cis的形状与x的第二维和最后一维相同
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
# 构造一个新的形状除了第二维和最后一维其他维度都为1这样做是为了能够将freqs_cis与x进行广播操作
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
# 将freqs_cis调整为新的形状并返回
return freqs_cis.view(shape)
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cos: torch.Tensor,
freqs_sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
# 将查询和键张量转换为浮点数,并重塑形状以分离实部和虚部
xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)
# 重新塑形频率张量以进行广播
freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)
# 应用旋转,分别计算旋转后的实部和虚部
xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos
# 将最后两个维度合并,并还原为原始张量的形状
xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
# 获取输入张量的形状:批量大小、序列长度、键/值对头的数量、每个头的维度大小
bs, slen, n_kv_heads, head_dim = x.shape
# 如果重复次数为1则不需要重复直接返回原始张量
if n_rep == 1:
return x
# 对张量进行扩展和重塑操作以重复键值对
return (
x[:, :, :, None, :] # 在第四个维度(头的维度前)添加一个新的维度
.expand(bs, slen, n_kv_heads, n_rep, head_dim) # 将新添加的维度扩展到n_rep大小实现重复的效果
.reshape(bs, slen, n_kv_heads * n_rep, head_dim) # 重新塑形,合并键/值对头的数量和重复次数的维度
)
class Attention(nn.Module):
def __init__(self, args: ModelConfig):
super().__init__()
# 根据是否指定n_kv_heads确定用于键key和值value的头的数量。
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
# 确保总头数可以被键值头数整除。
assert args.n_heads % self.n_kv_heads == 0
# 模型并行处理大小默认为1。
model_parallel_size = 1
# 本地计算头数,等于总头数除以模型并行处理大小。
self.n_local_heads = args.n_heads // model_parallel_size
# 本地键值头数,等于键值头数除以模型并行处理大小。
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
# 重复次数,用于扩展键和值的尺寸。
self.n_rep = self.n_local_heads // self.n_local_kv_heads
# 每个头的维度,等于模型维度除以头的总数。
self.head_dim = args.dim // args.n_heads
# 定义权重矩阵。
self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
# 输出权重矩阵。
self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
# 定义dropout。
self.attn_dropout = nn.Dropout(args.dropout)
self.resid_dropout = nn.Dropout(args.dropout)
# 保存dropout概率。
self.dropout = args.dropout
# 检查是否使用Flash Attention需要PyTorch >= 2.0)。
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
if not self.flash:
# 若不支持Flash Attention则使用手动实现的注意力机制并设置mask。
print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
# 创建一个上三角矩阵,用于遮蔽未来信息。
mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
mask = torch.triu(mask, diagonal=1)
# 注册为模型的缓冲区
self.register_buffer("mask", mask)
def forward(self, x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
# 获取批次大小和序列长度,[batch_size, seq_len, dim]
bsz, seqlen, _ = x.shape
# 计算查询Q、键K、值V
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
# 调整形状以适应头的维度。
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
# 应用旋转位置嵌入RoPE
xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
# 对键和值进行扩展以适应重复次数。
xk = repeat_kv(xk, self.n_rep)
xv = repeat_kv(xv, self.n_rep)
# 将头作为批次维度处理。
xq = xq.transpose(1, 2)
xk = xk.transpose(1, 2)
xv = xv.transpose(1, 2)
key_padding_mask = None
if attention_mask is not None:
key_padding_mask = attention_mask[:, None, None, :].to(dtype=torch.bool)
# 根据是否支持Flash Attention选择实现方式。
if self.flash:
# 使用Flash Attention。
if key_padding_mask is not None:
causal_mask = torch.ones((seqlen, seqlen), dtype=torch.bool, device=x.device).tril()
full_attn_mask = causal_mask[None, None, :, :] & key_padding_mask
output = torch.nn.functional.scaled_dot_product_attention(
xq,
xk,
xv,
attn_mask=full_attn_mask,
dropout_p=self.dropout if self.training else 0.0,
is_causal=False,
)
else:
output = torch.nn.functional.scaled_dot_product_attention(
xq,
xk,
xv,
attn_mask=None,
dropout_p=self.dropout if self.training else 0.0,
is_causal=True,
)
else:
# 使用手动实现的注意力机制。
scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
assert hasattr(self, 'mask')
scores = scores + self.mask[:, :, :seqlen, :seqlen]
if key_padding_mask is not None:
scores = scores.masked_fill(~key_padding_mask, float("-inf"))
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
scores = self.attn_dropout(scores)
output = torch.matmul(scores, xv)
# 恢复时间维度并合并头。
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
# 最终投影回残差流。
output = self.wo(output)
output = self.resid_dropout(output)
return output
class MLP(nn.Module):
def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):
super().__init__()
# 如果没有指定隐藏层的维度我们将其设置为输入维度的4倍
# 然后将其减少到2/3最后确保它是multiple_of的倍数
if hidden_dim is None:
hidden_dim = 4 * dim
hidden_dim = int(2 * hidden_dim / 3)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
# 定义第一层线性变换,从输入维度到隐藏维度
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
# 定义第二层线性变换,从隐藏维度到输入维度
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
# 定义第三层线性变换,从输入维度到隐藏维度
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
# 定义dropout层用于防止过拟合
self.dropout = nn.Dropout(dropout)
def forward(self, x):
# 前向传播函数
# 首先输入x通过第一层线性变换和SILU激活函数
# 然后结果乘以输入x通过第三层线性变换的结果
# 最后通过第二层线性变换和dropout层
return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
class DecoderLayer(nn.Module):
def __init__(self, layer_id: int, args: ModelConfig):
super().__init__()
# 定义多头注意力的头数
self.n_heads = args.n_heads
# 定义输入维度
self.dim = args.dim
# 定义每个头的维度,等于输入维度除以头数
self.head_dim = args.dim // args.n_heads
# 定义LLaMA2Attention对象用于进行多头注意力计算
self.attention = Attention(args)
# 定义LLaMAMLP对象用于进行前馈神经网络计算
self.feed_forward = MLP(
dim=args.dim,
hidden_dim=args.hidden_dim,
multiple_of=args.multiple_of,
dropout=args.dropout,
)
# 定义层的ID
self.layer_id = layer_id
# 定义注意力计算的归一化层
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
# 定义前馈神经网络计算的归一化层
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
def forward(self, x, freqs_cos, freqs_sin, attention_mask: Optional[torch.Tensor] = None):
# 前向传播函数
# 首先输入x经过注意力归一化层然后进行注意力计算结果与输入x相加得到h
# 然后h经过前馈神经网络归一化层然后进行前馈神经网络计算结果与h相加得到输出
h = x + self.attention.forward(self.attention_norm(x), freqs_cos, freqs_sin, attention_mask=attention_mask)
out = h + self.feed_forward.forward(self.ffn_norm(h))
return out
class Transformer(PreTrainedModel):
config_class = ModelConfig # 配置类
last_loss: Optional[torch.Tensor] # 记录最后一次计算的损失
def __init__(self, args: ModelConfig = None):
super().__init__(args)
# 初始化模型参数
self.args = args
# 词汇表大小
self.vocab_size = args.vocab_size
# 层数
self.n_layers = args.n_layers
# 词嵌入层
self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
# Dropout层
self.dropout = nn.Dropout(args.dropout)
# Decoder层
self.layers = torch.nn.ModuleList()
for layer_id in range(args.n_layers):
self.layers.append(DecoderLayer(layer_id, args))
# 归一化层
self.norm = RMSNorm(args.dim, eps=args.norm_eps)
# 输出层
self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
# 将词嵌入层的权重与输出层的权重共享
self.tok_embeddings.weight = self.output.weight
# 预计算相对位置嵌入的频率
freqs_cos, freqs_sin = precompute_freqs_cis(self.args.dim // self.args.n_heads, self.args.max_seq_len)
self.register_buffer("freqs_cos", freqs_cos, persistent=False)
self.register_buffer("freqs_sin", freqs_sin, persistent=False)
# 初始化所有权重
self.apply(self._init_weights)
# 对残差投影进行特殊的缩放初始化
for pn, p in self.named_parameters():
if pn.endswith('w3.weight') or pn.endswith('wo.weight'):
torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * args.n_layers))
# 初始化最后一次前向传播的损失属性
self.last_loss = None
self.OUT = CausalLMOutputWithPast() # 输出容器
self._no_split_modules = [name for name, _ in self.named_modules()] # 不分割的模块列表
def _init_weights(self, module):
# 初始化权重的函数
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
def _prepare_attention_mask(self, attention_mask: Optional[torch.Tensor], tokens: torch.Tensor) -> Optional[torch.Tensor]:
if attention_mask is None:
return None
if attention_mask.dim() == 4:
attention_mask = attention_mask[:, 0, 0, :]
elif attention_mask.dim() == 3:
attention_mask = attention_mask[:, 0, :]
attention_mask = attention_mask.to(tokens.device)
if attention_mask.dtype != torch.bool:
attention_mask = attention_mask > 0
if attention_mask.shape != tokens.shape:
raise ValueError(f"attention_mask shape {attention_mask.shape} must match input_ids shape {tokens.shape}")
return attention_mask
def _left_pad_by_attention_mask(
self,
idx: torch.Tensor,
attention_mask: Optional[torch.Tensor],
pad_token_id: int
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
if attention_mask is None or attention_mask.all():
return idx, attention_mask
bsz = idx.size(0)
lengths = attention_mask.long().sum(dim=1)
max_len = max(int(lengths.max().item()), 1)
packed_idx = idx.new_full((bsz, max_len), pad_token_id)
packed_mask = attention_mask.new_zeros((bsz, max_len), dtype=torch.bool)
for row in range(bsz):
valid_len = int(lengths[row].item())
if valid_len <= 0:
continue
valid_tokens = idx[row][attention_mask[row]]
packed_idx[row, max_len - valid_len:] = valid_tokens
packed_mask[row, max_len - valid_len:] = True
return packed_idx, packed_mask
def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
"""
- tokens: Optional[torch.Tensor], 输入 token 张量。
- targets: Optional[torch.Tensor], 目标 token 张量。
- kv_cache: bool, 是否使用键值缓存。
- kwargs: 其他关键字参数。
- self.OUT: CausalLMOutputWithPast, 包含 logits 和损失。
"""
if 'input_ids' in kwargs:
tokens = kwargs['input_ids']
if 'labels' in kwargs:
targets = kwargs['labels']
attention_mask = self._prepare_attention_mask(kwargs.get('attention_mask'), tokens)
# 前向传播函数
_bsz, seqlen = tokens.shape
# 通过词嵌入层和Dropout层
h = self.tok_embeddings(tokens)
h = self.dropout(h)
# 获取相对位置嵌入的频率
freqs_cos = self.freqs_cos[:seqlen]
freqs_sin = self.freqs_sin[:seqlen]
# 通过Decoder层
for layer in self.layers:
h = layer(h, freqs_cos, freqs_sin, attention_mask=attention_mask)
# 通过归一化层
h = self.norm(h)
if targets is not None:
# 如果给定了目标,计算损失
logits = self.output(h)
ignore_index = self.args.pad_token_id if self.args.pad_token_id is not None else 0
if torch.any(targets == -100):
ignore_index = -100
self.last_loss = F.cross_entropy(
logits.view(-1, logits.size(-1)),
targets.view(-1),
ignore_index=ignore_index,
reduction='none'
)
else:
# 推理时的小优化:只对最后一个位置的输出进行前向传播
if attention_mask is None:
logits = self.output(h[:, [-1], :])
else:
full_logits = self.output(h)
last_token_pos = attention_mask.long().sum(dim=1).clamp(min=1) - 1
logits = full_logits[torch.arange(_bsz, device=tokens.device), last_token_pos].unsqueeze(1)
self.last_loss = None
# 设置输出
self.OUT.__setitem__('logits', logits)
self.OUT.__setitem__('last_loss', self.last_loss)
return self.OUT
@torch.inference_mode()
def generate(
self,
idx,
stop_id=None,
max_new_tokens=256,
temperature=1.0,
top_k=None,
attention_mask: Optional[torch.Tensor] = None,
pad_token_id: Optional[int] = None
):
"""
给定输入序列 idx形状为 (bz,seq_len) 的长整型张量),通过多次生成新 token 来完成序列。
在 model.eval() 模式下运行。效率较低的采样版本没有使用键k/v cache。
"""
if pad_token_id is None:
pad_token_id = self.args.pad_token_id if self.args.pad_token_id is not None else 0
attention_mask = self._prepare_attention_mask(attention_mask, idx)
idx, attention_mask = self._left_pad_by_attention_mask(idx, attention_mask, pad_token_id)
finished = torch.zeros(idx.size(0), dtype=torch.bool, device=idx.device)
index = idx.shape[1]
for _ in range(max_new_tokens):
# 如果序列上下文过长,截断它到最大长度
idx_cond = idx if idx.size(1) <= self.args.max_seq_len else idx[:, -self.args.max_seq_len:]
mask_cond = None
if attention_mask is not None:
mask_cond = attention_mask if attention_mask.size(1) <= self.args.max_seq_len else attention_mask[:, -self.args.max_seq_len:]
# 前向传播获取序列中最后一个位置的 logits
logits = self(idx_cond, attention_mask=mask_cond).logits
logits = logits[:, -1, :] # 只保留最后一个时间步的输出
if temperature == 0.0:
# 选择最有可能的索引
_, idx_next = torch.topk(logits, k=1, dim=-1)
else:
# 缩放 logits 并应用 softmax
logits = logits / temperature
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('Inf')
probs = F.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1)
prev_finished = finished.clone()
if stop_id is not None:
if prev_finished.any():
fill_token = pad_token_id if pad_token_id is not None else stop_id
idx_next = torch.where(prev_finished[:, None], torch.full_like(idx_next, fill_token), idx_next)
finished = prev_finished | idx_next[:, 0].eq(stop_id)
# 将采样的索引添加到序列中并继续
idx = torch.cat((idx, idx_next), dim=1)
if attention_mask is not None:
next_mask = torch.ones((attention_mask.size(0), 1), dtype=attention_mask.dtype, device=attention_mask.device)
if prev_finished.any():
next_mask[prev_finished] = False
attention_mask = torch.cat((attention_mask, next_mask), dim=1)
if stop_id is not None and finished.all():
break
return idx[:, index:] # 只返回生成的token
def _greedy_decode(self, logits: torch.Tensor) -> torch.Tensor:
"""
贪婪解码选择概率最大的token
Args:
logits: 模型输出的logits形状为 (batch_size, vocab_size)
Returns:
选择的token索引形状为 (batch_size, 1)
"""
_, idx_next = torch.topk(logits, k=1, dim=-1)
return idx_next
def _random_sample(self, logits: torch.Tensor, temperature: float = 1.0, top_k: int = None) -> torch.Tensor:
"""
随机采样基于概率分布随机选择token
Args:
logits: 模型输出的logits形状为 (batch_size, vocab_size)
temperature: 温度参数,控制随机性
top_k: 只考虑概率最高的k个token
Returns:
选择的token索引形状为 (batch_size, 1)
"""
# 缩放 logits
logits = logits / temperature
# 应用top-k过滤
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
# 将不在 top-k 内的 logits 设为负无穷
logits[logits < v[:, [-1]]] = -float('Inf')
# 计算概率并采样
probs = F.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1)
return idx_next
def _beam_search(self, idx: torch.Tensor, max_new_tokens: int, num_beams: int,
temperature: float = 1.0, top_k: int = None, stop_id: int = None) -> torch.Tensor:
"""
束搜索:维护多个候选序列,选择最优路径
束搜索的核心思想在每一步生成时不是只选择一个最佳token
而是保留多个候选路径,最终选择累积概率最高的完整序列。
Args:
idx: 输入序列,形状为 (batch_size, seq_len)
max_new_tokens: 最大生成token数量
num_beams: 束宽度,表示保留的候选路径数量
temperature: 温度参数,控制分布的平滑程度
top_k: top-k过滤参数限制候选token范围
stop_id: 停止生成的token ID遇到则停止
Returns:
生成的token序列形状为 (batch_size, generated_length)
只返回新生成的部分,不包含原始输入序列
"""
# 获取输入序列的基本信息
batch_size = idx.shape[0] # 批次大小通常为1
seq_len = idx.shape[1] # 输入序列长度
# 初始化束:创建 num_beams 个候选序列
beams = [idx.clone() for _ in range(num_beams)]
# 初始化每个候选序列的累积对数概率分数
beam_scores = torch.zeros(num_beams, device=idx.device)
# 第一个候选是原始输入序列分数为0
beam_scores[0] = 0.0
# 其他候选初始分数设为负无穷,表示尚未生成
beam_scores[1:] = float('-inf')
# 主循环逐步生成新的token最多生成 max_new_tokens 个
for step in range(max_new_tokens):
# 每轮迭代收集新的候选序列和分数
new_beams = [] # 新的候选序列列表
new_scores = [] # 对应的分数列表
# 遍历当前的所有候选序列
for beam_idx, beam in enumerate(beams):
# 跳过无效候选(分数为负无穷的序列)
if beam_scores[beam_idx] == float('-inf'):
continue
# 序列长度检查:如果超过最大长度,截取最后的部分
beam_cond = beam if beam.size(1) <= self.args.max_seq_len else beam[:, -self.args.max_seq_len:]
# 前向传播:获取模型对当前序列的预测
output = self(beam_cond)
# 提取最后一个位置的logits用于预测下一个token
logits = output.logits[:, -1, :] # 形状: (1, vocab_size)
# 温度缩放调整logits的分布
if temperature != 1.0:
logits = logits / temperature
# 温度 > 1分布更平滑增加随机性
# 温度 < 1分布更尖锐更确定
# Top-k过滤限制候选token的范围提高质量
if top_k is not None:
# 找到logits中前top_k个最大的值
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
# 将不在前top_k内的logits设为负无穷
logits[logits < v[:, [-1]]] = -float('Inf')
# 这样采样时只会考虑前top_k个token
# 计算对数概率使用log_softmax避免数值不稳定
log_probs = F.log_softmax(logits, dim=-1)
# 获取前 num_beams 个最可能的候选token
# 注意这里的top-k与上面的top-k不同
# 上面的top-k是全局过滤这里是束搜索的分支选择
top_log_probs, top_indices = torch.topk(log_probs, k=num_beams, dim=-1)
# 为当前候选序列生成 num_beams 个扩展序列
for k in range(num_beams):
# 选择第k个候选token
token = top_indices[:, k:k+1] # token ID
log_prob = top_log_probs[:, k] # 对应的对数概率
# 扩展序列将新token添加到当前序列末尾
new_beam = torch.cat([beam, token], dim=1)
# 更新累积分数:原序列分数 + 新token的对数概率
new_score = beam_scores[beam_idx] + log_prob.item()
# 保存新的候选序列和分数
new_beams.append(new_beam)
new_scores.append(new_score)
# 安全检查:如果没有生成任何有效候选,提前结束
if not new_beams:
break
# 筛选最佳候选:从所有新生成的候选中选择分数最高的 num_beams 个
# 按分数降序排序,获取索引
sorted_indices = sorted(range(len(new_scores)), key=lambda i: new_scores[i], reverse=True)
# 选择前 num_beams 个最佳候选
beams = [new_beams[i] for i in sorted_indices[:num_beams]]
beam_scores = [new_scores[i] for i in sorted_indices[:num_beams]]
# 停止条件检查检查最佳序列是否以停止token结尾
if stop_id is not None and beams[0][0, -1] == stop_id:
break
# 返回得分最高的序列,只返回新生成的部分(去掉原始输入)
# beams[0] 是最终得分最高的完整序列
# [:, seq_len:] 切片只保留生成部分
return beams[0][:, seq_len:]
@torch.inference_mode()
def generate_super(self,
idx,
stop_id=None,
max_new_tokens=256,
temperature=1.0,
top_k=None,
do_sample=False,
num_beams=1,
attention_mask: Optional[torch.Tensor] = None,
pad_token_id: Optional[int] = None
):
"""
高级文本生成函数,支持三种解码策略:
1. 贪婪解码Greedy Search
- 参数do_sample=False, num_beams=1
- 特点每步选择概率最大的token速度快、结果确定
2. 随机采样Random Sampling
- 参数do_sample=True, num_beams=1
- 特点基于概率分布随机采样可配合temperature和top-k控制多样性
3. 束搜索Beam Search
- 参数do_sample=False, num_beams>1
- 特点:维护多条候选路径,选择总概率最高的序列,质量更高但速度较慢
Args:
idx: 输入序列张量,形状为 (batch_size, seq_len)
stop_id: 停止生成的token ID
max_new_tokens: 最大生成token数量
temperature: 温度参数,控制随机性,越高越随机
top_k: 只考虑概率最高的k个tokenNone表示不考虑
do_sample: 是否使用随机采样False时使用确定性解码
num_beams: 束搜索的束宽度1表示不使用束搜索
Returns:
生成的token序列形状为 (batch_size, generated_length)
"""
# 参数验证
if temperature <= 0:
temperature = 0.001 # 避免除零错误
if num_beams < 1:
num_beams = 1
if top_k is not None and top_k < 1:
top_k = None
if pad_token_id is None:
pad_token_id = self.args.pad_token_id if self.args.pad_token_id is not None else 0
attention_mask = self._prepare_attention_mask(attention_mask, idx)
idx, attention_mask = self._left_pad_by_attention_mask(idx, attention_mask, pad_token_id)
# 束搜索逻辑
if not do_sample and num_beams > 1:
return self._beam_search(idx, max_new_tokens, num_beams, temperature, top_k, stop_id)
# 贪婪解码和随机采样逻辑
finished = torch.zeros(idx.size(0), dtype=torch.bool, device=idx.device)
index = idx.shape[1]
for _ in range(max_new_tokens):
# 如果序列上下文过长,截断它到最大长度
idx_cond = idx if idx.size(1) <= self.args.max_seq_len else idx[:, -self.args.max_seq_len:]
mask_cond = None
if attention_mask is not None:
mask_cond = attention_mask if attention_mask.size(1) <= self.args.max_seq_len else attention_mask[:, -self.args.max_seq_len:]
# 前向传播获取序列中最后一个位置的 logits
logits = self(idx_cond, attention_mask=mask_cond).logits
logits = logits[:, -1, :] # 只保留最后一个时间步的输出
# 根据参数选择解码策略
if do_sample:
idx_next = self._random_sample(logits, temperature, top_k)
else:
# 当temperature=0时使用贪婪解码
if temperature < 0.1:
idx_next = self._greedy_decode(logits)
else:
# 低温度下的随机采样(接近贪婪)
idx_next = self._random_sample(logits, temperature, top_k)
prev_finished = finished.clone()
if stop_id is not None:
if prev_finished.any():
fill_token = pad_token_id if pad_token_id is not None else stop_id
idx_next = torch.where(prev_finished[:, None], torch.full_like(idx_next, fill_token), idx_next)
finished = prev_finished | idx_next[:, 0].eq(stop_id)
# 将选择的token添加到序列中
idx = torch.cat((idx, idx_next), dim=1)
if attention_mask is not None:
next_mask = torch.ones((attention_mask.size(0), 1), dtype=attention_mask.dtype, device=attention_mask.device)
if prev_finished.any():
next_mask[prev_finished] = False
attention_mask = torch.cat((attention_mask, next_mask), dim=1)
if stop_id is not None and finished.all():
break
return idx[:, index:] # 只返回生成的token
if __name__ == '__main__':
tokenizer = AutoTokenizer.from_pretrained("tokenizer_k")
args = ModelConfig(
dim=1024,
n_layers=18,
pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0,
)
# 实例化LLaMA2Model
model = Transformer(args=args)
# 计算model的全部参数
num_params = sum(p.numel() for p in model.parameters())
print(f'LLM总参数量{num_params / 1e6:.3f} 百万')
prompt = "你好呀,今天吃什么呢?你过得怎么样嘞?"
text = f"{tokenizer.bos_token}{prompt}{tokenizer.eos_token}"
print(f"Input text: {text}")
input_id = tokenizer(text).data['input_ids']
print("input_ids :", input_id)
print("dcode_str :", tokenizer.decode(input_id))
X = torch.tensor(input_id[:-1]).unsqueeze(0)
Y = torch.tensor(input_id[1:]).unsqueeze(0)
print("X shape :", X.shape)
print("Y shape :", Y.shape)
# 将输入张量传入模型
output = model(X, Y)