Compare commits
41 Commits
v1.0.1
...
osquerkkzl
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b38c8cb261 | ||
|
|
47164fcca5 | ||
|
|
7b83aa6118 | ||
|
|
de9d9e0048 | ||
|
|
7b091acc64 | ||
|
|
88f31c0d14 | ||
|
|
21bac613c0 | ||
|
|
63e88022f3 | ||
|
|
1c0a0c22e1 | ||
|
|
3afabec1a8 | ||
|
|
72b41341e1 | ||
|
|
b9172031c8 | ||
|
|
46b509c9c1 | ||
|
|
4ed47f3918 | ||
|
|
fc6c8c81ee | ||
|
|
9c461438c7 | ||
|
|
50bd19efb4 | ||
|
|
712415e0a7 | ||
|
|
9098d6527f | ||
|
|
550d9bd40c | ||
|
|
59ea8f65ad | ||
|
|
edbcd3ad38 | ||
|
|
76b3cb848f | ||
|
|
6ce019cb2e | ||
|
|
0e09304c88 | ||
|
|
5ab392358e | ||
|
|
f30ddbcd1a | ||
|
|
d35df306ed | ||
|
|
ebe52dc086 | ||
|
|
0428271b7f | ||
|
|
590363587c | ||
|
|
b7e1a26255 | ||
|
|
9a882a92ed | ||
|
|
d278182a90 | ||
|
|
18d1f56840 | ||
|
|
3a8eb17848 | ||
|
|
f192a4ecd4 | ||
|
|
c889b864a9 | ||
|
|
b7d3e0678e | ||
|
|
a110181cf8 | ||
|
|
9bdf9ed202 |
1375
Extra-Chapter/CDDRS/CDDRS.ipynb
Normal file
BIN
Extra-Chapter/CDDRS/images/pic1.png
Normal file
|
After Width: | Height: | Size: 438 KiB |
BIN
Extra-Chapter/CDDRS/images/pic2.png
Normal file
|
After Width: | Height: | Size: 307 KiB |
BIN
Extra-Chapter/CDDRS/images/pic3.png
Normal file
|
After Width: | Height: | Size: 927 KiB |
879
Extra-Chapter/CDDRS/readme.md
Normal file
@@ -0,0 +1,879 @@
|
||||
# 建筑文档智能RAG审查系统
|
||||
|
||||
一个从零开始实现的建筑文档智能审查系统,旨在帮助开发者理解知识引导检索在专业领域文档审查中的核心原理和实现细节。
|
||||
|
||||
## 项目动机
|
||||
|
||||
建筑施工交底文档的合规性审查是保障施工项目安全性、经济性的关键环节。在施工项目全周期中,各项操作必须符合相关规范条文要求,才能确保建设项目的安全性与可持续性。然而,相关查询参考往往分散在各个项目文件中,传统基于人工的审查方法难以处理庞大复杂的建筑条文,其审查过程需要基于审查人员的经验与专业知识,具有主观性强,耗时长且易出错等弊端。
|
||||
|
||||
随着大语言模型技术的发展,LLM为自动化建筑文档审查带来了新的希望。然而,大语言模型通常使用通用语料进行训练,缺乏建筑相关背景知识,在处理建造背景下的复杂推理问题中会产生严重的幻觉现象。通过使用基于向量相似匹配的RAG方法,可以为LLMs提供初步的相似参考知识,从而减轻基于人工或规则的审查方法难以处理庞大建筑文本所带来的错误率高的问题。
|
||||
|
||||
然而,传统RAG方法在建筑专业文档审查中存在关键局限:由于固定的分块设计,使得文本块之间面临知识信息缺失问题;在检索过程中,使用整句问询嵌入的方法进行相似性匹配,缺少对问询细粒度特征的识别与考量,检索效率低下。在建筑施工交底文档中,这类文档详细阐述了施工工艺特点和方法、质量规格、操作程序以及安全协议,包含大量知识细节且专业性极强。因此需要一个能够精准理解和检索建筑领域专业知识的智能系统。
|
||||
|
||||
因此,本项目提出了一个生成式知识引导的建筑文档审查系统,旨在提升审查的可靠性和准确性。系统具有两大核心创新:首先提出动态语义知识分块策略,构建具有更优语义连贯性和完整性的知识库;其次基于增强的知识表示,提出生成式知识引导检索框架,在语义嵌入检索过程中增强对细粒度信息的关注,从而提高知识参考检索的准确性和建筑文档审查任务中修正的可靠性。
|
||||
|
||||
需要注意的是,由于篇幅限制,我们无法展示完整的整个实现过程,但是,我们将在文档中讲解每个必要的实现步骤以及背后的思考,您可以通过这些内容快速理解如何实现一个建筑文档智能审查系统。
|
||||
|
||||
## 前置实现
|
||||
|
||||
接下来,我们将带领大家,从0开始,实现一个建筑文档智能审查系统。首先,我们将完成一些基本的准备过程。
|
||||
|
||||
### 1. 实现 LLM 模块
|
||||
|
||||
首先我们需要实现 LLM 模块,这是系统中最基本的模块,我们将利用大模型完成文档的清洗,信息提取等工作,可以说本系统的一部分精髓即为使用大模型预先处理文档信息,方便后续进行检索,这里我们使用 DeepSeek 的 api 来实现。
|
||||
|
||||
```python
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Optional
|
||||
|
||||
class BaseLLM(ABC):
|
||||
"""Interface for large language models."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
model_params: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
self.model_name = model_name
|
||||
self.model_params = model_params or {}
|
||||
|
||||
@abstractmethod
|
||||
def predict(self, input: str) -> str:
|
||||
"""Sends a text input to the LLM and retrieves a response."""
|
||||
```
|
||||
|
||||
如上是一个调用大模型的抽象接口,这可以帮助我们统一调用大模型的格式,我们继承这个基类,实现调用大模型的接口。
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
from typing import Any, Optional
|
||||
from .base import BaseLLM
|
||||
|
||||
class DeepSeekLLM(BaseLLM):
|
||||
"""Implementation of the BaseLLM interface using DeepSeek API."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
api_key: str,
|
||||
base_url: str = "https://api.deepseek.com/v1",
|
||||
model_params: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
super().__init__(model_name, model_params, **kwargs)
|
||||
self.client = OpenAI(api_key=api_key, base_url=base_url)
|
||||
|
||||
def predict(self, input: str) -> str:
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=[{"role": "user", "content": input}],
|
||||
)
|
||||
return response.choices[0].message.content
|
||||
```
|
||||
|
||||
完成搭建后,我们可以通过尝试调用 predict 方法来测试是否成功。
|
||||
|
||||
```python
|
||||
llm = DeepSeekLLM(
|
||||
model_name="deepseek-chat",
|
||||
api_key="your-api-key-here",
|
||||
base_url="https://api.deepseek.com/v1"
|
||||
)
|
||||
print(llm.predict("你好,你能帮助我进行建筑文档审查吗?"))
|
||||
```
|
||||
|
||||
当观察到 LLM 正确回复后,我们这一模块的构建就完成了。
|
||||
|
||||
### 2. 实现 Embedding 模块
|
||||
|
||||
除了调用大模型,我们还需要实现 Embedding 模块,Embedding 模块用于将文本转换为向量,我们将使用向量来表示文档中的信息,这样的好处是,我们可以通过向量的相似度来衡量文档与查询之间的相似度,从而召回对回复用户问题最有帮助的文档。
|
||||
|
||||
构建 Embedding 模块的方法与构建 LLM 模块类似。
|
||||
|
||||
```python
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Any, Optional
|
||||
|
||||
class BaseEmb(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
model_params: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
self.model_name = model_name
|
||||
self.model_params = model_params or {}
|
||||
|
||||
@abstractmethod
|
||||
def get_emb(self, input: str) -> List[float]:
|
||||
"""Sends a text input to the embedding model and retrieves the embedding."""
|
||||
pass
|
||||
|
||||
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
||||
from .base import BaseEmb
|
||||
|
||||
class BGEEmbedding(BaseEmb):
|
||||
def __init__(self, model_name: str = "BAAI/bge-m3", **kwargs):
|
||||
super().__init__(model_name=model_name, **kwargs)
|
||||
self.embed_model = HuggingFaceEmbedding(
|
||||
model_name=model_name,
|
||||
trust_remote_code=True,
|
||||
cache_folder="./model_cache"
|
||||
)
|
||||
|
||||
def get_emb(self, text: str) -> List[float]:
|
||||
embedding = self.embed_model.get_text_embedding(text)
|
||||
return embedding
|
||||
```
|
||||
|
||||
完成搭建后,我们可以通过尝试调用 get_emb 方法来测试是否成功。
|
||||
|
||||
```python
|
||||
emb = BGEEmbedding(model_name="BAAI/bge-m3")
|
||||
print(emb.get_emb("建筑结构的安全性检查包括哪些方面?"))
|
||||
```
|
||||
|
||||
当观察到 Embedding 正确给出了编码后的向量,我们这一模块的构建就完成了。
|
||||
|
||||
### 3. 实现文档预处理模块
|
||||
|
||||
为了处理建筑文档,我们需要预先准备好文档读取模块。本系统假设所有建筑规范和标准已经转换为Markdown格式,便于后续的文本处理和分析。
|
||||
|
||||
```python
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
class DocumentProcessor:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def load_documents(self, directory_path: str) -> List[str]:
|
||||
documents = []
|
||||
|
||||
for file_path in Path(directory_path).rglob('*.md'):
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
documents.append(content)
|
||||
except Exception as e:
|
||||
print(f"Error reading {file_path}: {e}")
|
||||
|
||||
return documents
|
||||
```
|
||||
|
||||
完成文档预处理模块的设置后,我们就可以采用下面的方法来加载建筑规范文档了。
|
||||
|
||||
```python
|
||||
processor = DocumentProcessor()
|
||||
documents = processor.load_documents("./construction_standards")
|
||||
print(f"加载了 {len(documents)} 个建筑规范文档")
|
||||
```
|
||||
|
||||
## 核心实现
|
||||
|
||||
建筑文档审查系统的主要流程如下。首先,让我们来梳理一下建筑文档审查的工作流程,系统的一个核心思想在于,我们需要把用户提供的文档内容通过智能化的问询生成和知识引导检索来识别潜在的合规性问题。与传统RAG方法不同,我们的系统专门针对建筑领域的专业特点进行了优化,能够更准确地理解建筑规范要求,提供更可靠的审查建议。
|
||||
|
||||
### 动态语义知识分块
|
||||
|
||||
在传统RAG流程中,文本通过设置固定的token数量划分文本区块。然而,固定token数量会在句子中间截断,导致信息缺失。为此,本系统使用基于建筑文本语义动态划分的方式,通过双重语义聚类的方式,完成考虑建筑语义连贯性的知识chunk划分。
|
||||
|
||||
首先,将整个文档内容处理成单独句子序列 $S = \{s_0, s_1, \ldots, s_a\}$。通过计算相邻句子间的语义差异度来识别潜在的语义边界:
|
||||
|
||||
$$\gamma_i = 1 - \frac{s_{i-1} \cdot s_i}{\|s_{i-1}\| \|s_i\|}$$
|
||||
|
||||
基于语义差异度分布自动确定动态阈值:
|
||||
|
||||
$$\psi = \text{Quantile}(\Gamma, \frac{a-p}{a})$$
|
||||
|
||||
确保最终的分块既保持语义连贯性又满足长度约束:
|
||||
|
||||
$$\mathbb{E}[\gamma_{\text{intra}}] < \mathbb{E}[\gamma_{\text{inter}}]$$
|
||||
|
||||
```python
|
||||
import re
|
||||
import numpy as np
|
||||
from typing import List, Dict, Tuple
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
|
||||
class DynamicSemanticChunker:
|
||||
def __init__(self,
|
||||
embedding_model: str = "BAAI/bge-m3",
|
||||
max_chunk_length: int = 512,
|
||||
min_chunk_length: int = 50):
|
||||
self.embedding_model = SentenceTransformer(embedding_model)
|
||||
self.max_chunk_length = max_chunk_length
|
||||
self.min_chunk_length = min_chunk_length
|
||||
|
||||
def split_text(self, text: str) -> Dict[str, str]:
|
||||
sentences = self._split_into_sentences(text)
|
||||
if len(sentences) == 0:
|
||||
return {}
|
||||
|
||||
sentence_embeddings = self.embedding_model.encode(sentences)
|
||||
gamma_values = self._compute_semantic_discrepancy(sentence_embeddings)
|
||||
|
||||
total_tokens = sum(len(s.split()) for s in sentences)
|
||||
baseline_chunks = max(1, total_tokens // self.max_chunk_length)
|
||||
alpha = max(0.1, (len(sentences) - baseline_chunks) / len(sentences))
|
||||
threshold = np.quantile(gamma_values, alpha) if len(gamma_values) > 0 else 0.5
|
||||
|
||||
boundaries = self._identify_boundaries(gamma_values, threshold)
|
||||
initial_chunks = self._create_initial_chunks(sentences, boundaries)
|
||||
final_chunks = self._enforce_length_constraints(initial_chunks)
|
||||
|
||||
chunks_dict = {}
|
||||
for i, chunk in enumerate(final_chunks):
|
||||
chunk_id = f"chunk-{i+1:03d}"
|
||||
chunks_dict[chunk_id] = chunk
|
||||
|
||||
return chunks_dict
|
||||
|
||||
def _split_into_sentences(self, text: str) -> List[str]:
|
||||
sentence_pattern = r'[。!?;\n]+'
|
||||
sentences = re.split(sentence_pattern, text)
|
||||
|
||||
cleaned_sentences = []
|
||||
for sentence in sentences:
|
||||
sentence = sentence.strip()
|
||||
if len(sentence) > 5:
|
||||
cleaned_sentences.append(sentence)
|
||||
|
||||
return cleaned_sentences
|
||||
|
||||
def _compute_semantic_discrepancy(self, embeddings: np.ndarray) -> List[float]:
|
||||
gamma_values = []
|
||||
|
||||
for i in range(1, len(embeddings)):
|
||||
similarity = cosine_similarity(
|
||||
embeddings[i-1].reshape(1, -1),
|
||||
embeddings[i].reshape(1, -1)
|
||||
)[0][0]
|
||||
|
||||
gamma = 1 - similarity
|
||||
gamma_values.append(gamma)
|
||||
|
||||
return gamma_values
|
||||
|
||||
def _identify_boundaries(self, gamma_values: List[float], threshold: float) -> List[int]:
|
||||
boundaries = [0]
|
||||
|
||||
for i, gamma in enumerate(gamma_values):
|
||||
if gamma > threshold:
|
||||
boundaries.append(i + 1)
|
||||
|
||||
boundaries.append(len(gamma_values) + 1)
|
||||
return sorted(set(boundaries))
|
||||
|
||||
def _create_initial_chunks(self, sentences: List[str], boundaries: List[int]) -> List[str]:
|
||||
chunks = []
|
||||
|
||||
for i in range(len(boundaries) - 1):
|
||||
start = boundaries[i]
|
||||
end = boundaries[i + 1]
|
||||
|
||||
chunk_sentences = sentences[start:end]
|
||||
chunk_text = ' '.join(chunk_sentences)
|
||||
chunks.append(chunk_text)
|
||||
|
||||
return chunks
|
||||
|
||||
def _enforce_length_constraints(self, chunks: List[str]) -> List[str]:
|
||||
final_chunks = []
|
||||
|
||||
for chunk in chunks:
|
||||
chunk_length = len(chunk.split())
|
||||
|
||||
if chunk_length <= self.max_chunk_length:
|
||||
if chunk_length >= self.min_chunk_length:
|
||||
final_chunks.append(chunk)
|
||||
else:
|
||||
sub_chunks = self._split_long_chunk(chunk)
|
||||
final_chunks.extend(sub_chunks)
|
||||
|
||||
return final_chunks
|
||||
|
||||
def _split_long_chunk(self, chunk: str) -> List[str]:
|
||||
sentences = chunk.split('。')
|
||||
sub_chunks = []
|
||||
current_chunk = ""
|
||||
|
||||
for sentence in sentences:
|
||||
if sentence.strip():
|
||||
test_chunk = current_chunk + sentence + "。"
|
||||
if len(test_chunk.split()) <= self.max_chunk_length:
|
||||
current_chunk = test_chunk
|
||||
else:
|
||||
if current_chunk:
|
||||
sub_chunks.append(current_chunk.strip())
|
||||
current_chunk = sentence + "。"
|
||||
|
||||
if current_chunk:
|
||||
sub_chunks.append(current_chunk.strip())
|
||||
|
||||
return sub_chunks
|
||||
```
|
||||
|
||||
### 建筑文档审查系统
|
||||
|
||||
整体的审查过程如下图所示。系统获取需要审查的区域后,依据提示生成审查问题推荐,此部分也可供工程师进行相关问题输入或推荐问题选择,生成待审查问题。随后,系统通过生成式知识引导检索框架,依据审查问题在所建文本知识库中检索出相应的知识参考。最终,依据检索的部分与审查原文,进行问题分析与审查修正,完成最终的审查流程。
|
||||
|
||||

|
||||
|
||||
#### 审查问题生成
|
||||
|
||||
在文档审查流程中,系统引入了双阶段Prompt工程驱动的智能化问询生成机制,旨在对建筑施工交底文档进行预见性分析与风险挖掘,实现对文档潜在问题的高效、精准定位。
|
||||
|
||||
阶段1为待查文档主旨目标解构,模型被指示从文本中提炼核心事件、关键技术、工艺流程等要素,结构化地总结文档的核心内容,由此界定本次审查的靶向目标,为后续的精细化问询奠定基础。阶段2为多维度风险探测与定制化问询生成,基于第一阶段提炼的核心要素,通过few-shot等方式引导 LLM 从合规性、安全性、可操作性等多维度对文档进行风险探测。Prompt 指示模型围绕潜在的限制条件、操作流程、以及可能存在的合规性隐患等方面,进行细粒度、多角度的审查提问。
|
||||
|
||||
```python
|
||||
import re
|
||||
|
||||
CORE_COMPONENTS_PROMPT = """
|
||||
Task: Your task involves the extraction of crucial information components from a designated text segment. The purpose of this extraction is to assist in uncovering hidden descriptions indicative of regulatory non-compliance. Key information components encompass, but are not limited to, core descriptive events, essential construction techniques, technologies, and associated limitations and restrictions.
|
||||
|
||||
Input: {document_chunk}
|
||||
Answer:
|
||||
"""
|
||||
|
||||
REVIEW_QUERIES_PROMPT = """
|
||||
Task: Your task is to generate relevant search queries based on the text under review and provided core descriptive references. These queries should target potential areas of non-compliance within the text, facilitating the subsequent retrieval of original regulatory documents for detailed examination.
|
||||
|
||||
Input: {document_chunk}
|
||||
Core components: {core_components}
|
||||
Queries:
|
||||
"""
|
||||
|
||||
def generate_review_queries(llm, document_chunk: str) -> List[str]:
|
||||
core_prompt = CORE_COMPONENTS_PROMPT.format(document_chunk=document_chunk)
|
||||
core_response = llm.predict(core_prompt)
|
||||
|
||||
# 生成审查查询
|
||||
queries_prompt = REVIEW_QUERIES_PROMPT.format(
|
||||
document_chunk=document_chunk,
|
||||
core_components=core_response
|
||||
)
|
||||
queries_response = llm.predict(queries_prompt)
|
||||
|
||||
# 从响应中提取查询列表
|
||||
queries = re.findall(r"'([^']*)'", queries_response)
|
||||
|
||||
return queries[:5]
|
||||
```
|
||||
|
||||
#### 知识引导生成式检索
|
||||
|
||||
系统的核心创新在于知识引导的检索框架,整个过程分为三个关键步骤。步骤1为句子级编码,主要负责输入查询句子的初始表示学习,计算查询与知识库chunks间的句子级相似度分数。步骤2为知识引导检索,进一步从查询中提取关键信息,利用这些信息结合文档长度自适应加权等机制,对每个知识库chunk进行更详细的评分。步骤3为重排序与增强,使用大语言模型对步骤2检索的结果进行进一步重排序,并利用精炼的知识来增强原始查询。
|
||||

|
||||
|
||||
首先建立专门针对建筑领域文本分析的深度提取模块,集成领域预训练BERT进行上下文编码,结合双向LSTM进行建筑法规依赖建模。建立三级重要性分类层次:max(最高)、mid(中等)、lit(字面)优先级。本项目直接通过大语言模型进行关键信息提取,如果需要更精准的效果,可以自行训练BERT模型进行专门的关键信息提取。
|
||||

|
||||
|
||||
```python
|
||||
import re
|
||||
from typing import Dict, Tuple, List
|
||||
|
||||
KEY_INFO_EXTRACTION_PROMPT = """
|
||||
Your task is to extract key information from the query with three different priority levels:
|
||||
|
||||
Maximum priority (max): The most important core concepts or entities
|
||||
Medium priority (mid): Important modifiers or qualifying conditions
|
||||
Literal priority (lit): Specific values, standards or specifications
|
||||
|
||||
Query: {query}
|
||||
max:
|
||||
mid:
|
||||
lit:
|
||||
"""
|
||||
|
||||
class KeyInfoExtractor:
|
||||
def __init__(self, llm):
|
||||
self.llm = llm
|
||||
|
||||
def extract_key_info(self, query: str) -> Dict[str, Tuple[str, float]]:
|
||||
prompt = KEY_INFO_EXTRACTION_PROMPT.format(query=query)
|
||||
response = self.llm.predict(prompt)
|
||||
|
||||
lines = response.strip().split('\n')
|
||||
key_info = {}
|
||||
weights = {'max': 0.5, 'mid': 0.3, 'lit': 0.2}
|
||||
|
||||
for line in lines:
|
||||
if line.startswith('max:'):
|
||||
key_info['max'] = (line[4:].strip(), weights['max'])
|
||||
elif line.startswith('mid:'):
|
||||
key_info['mid'] = (line[4:].strip(), weights['mid'])
|
||||
elif line.startswith('lit:'):
|
||||
key_info['lit'] = (line[4:].strip(), weights['lit'])
|
||||
|
||||
return key_info
|
||||
```
|
||||
|
||||
#### 文档长度自适应因子
|
||||
|
||||
在知识引导检索过程中,文档长度自适应因子用于调整不同长度文档的权重分配,确保长短文档都能得到公平的评分机会。该因子的计算考虑了当前文档chunk的长度与平均文档长度的关系。
|
||||
|
||||
$$\Lambda_{\text{DL}} = \frac{\overline{|k|} + |k_j|}{2\overline{|k|}}$$
|
||||
|
||||
其中 $|k_j|$ 表示当前文档chunk的长度,$\overline{|k|}$ 表示平均文档长度。通过这种归一化处理,可以避免因文档长度差异导致的评分偏差。
|
||||
|
||||
```python
|
||||
def compute_document_length_factor(chunk_length: int, avg_length: int = 100) -> float:
|
||||
lambda_dl = (avg_length + chunk_length) / (2 * avg_length)
|
||||
return lambda_dl
|
||||
```
|
||||
|
||||
#### 术语重要性计算
|
||||
|
||||
术语重要性指标衡量术语在文档中的显著程度,结合术语频率和文档长度自适应因子,能够更准确地评估术语在当前文档中的重要性。计算公式考虑了术语频率的非线性增长特性。
|
||||
|
||||
$$\text{Sign}(t_{e_i}^\tau, k_j) = \frac{2 \cdot f(t_{e_i}^\tau, k_j) \cdot \Lambda_{\text{DL}}}{f(t_{e_i}^\tau, k_j) + 1}$$
|
||||
|
||||
其中 $f(t_{e_i}^\tau, k_j)$ 表示术语在文档chunk中的出现频率,$\Lambda_{\text{DL}}$ 为文档长度自适应因子。这种计算方式能够防止高频术语过度影响评分。
|
||||
|
||||
```python
|
||||
def compute_term_significance(term_freq: int, doc_length_factor: float) -> float:
|
||||
significance = (2 * term_freq * doc_length_factor) / (term_freq + 1)
|
||||
return significance
|
||||
```
|
||||
|
||||
#### 术语稀有度计算
|
||||
|
||||
术语稀有度用于衡量术语在整个知识库中的稀缺程度,稀有度越高的术语在检索中的权重越大。计算采用了改进的IDF公式,增加了平滑处理以避免零除问题。
|
||||
|
||||
$\text{Rarity}(t_{e_i}^\tau) = \log\left(\frac{D - \text{df}(t_{e_i}^\tau) + 0.5}{\text{df}(t_{e_i}^\tau) + 0.5} + 1\right)$
|
||||
|
||||
其中 $D$ 表示文档总数,$\text{df}(t_{e_i}^\tau)$ 表示包含该术语的文档数量。加一操作确保了对数值始终为正数。
|
||||
|
||||
```python
|
||||
import numpy as np
|
||||
|
||||
def compute_term_rarity(doc_freq: int, total_docs: int) -> float:
|
||||
rarity = np.log((total_docs - doc_freq + 0.5) / (doc_freq + 0.5) + 1)
|
||||
return rarity
|
||||
```
|
||||
|
||||
#### 连贯性指数评估
|
||||
|
||||
连贯性指数反映术语在文档中的分布连贯性,通过滑动窗口技术分析术语在文档中的局部分布情况。连贯性高的术语往往在文档的特定区域集中出现,表明其与文档主题的强相关性。
|
||||
|
||||
$$\text{CI}(t_{e_i}^\tau, k_j) = \max_{w \in W, \, t \in w} \frac{\sum I(t = t_{e_i}^\tau) \cdot |w|}{|k_j|}$$
|
||||
|
||||
其中 $W$ 表示文档中的滑动窗口集合,$I(t = t_{e_i}^\tau)$ 为指示函数,当窗口中包含该术语时为1,否则为0。
|
||||
|
||||
```python
|
||||
def compute_coherence_index(term: str, chunk: str, window_size: int = 50) -> float:
|
||||
chunk_tokens = chunk.lower().split()
|
||||
chunk_length = len(chunk_tokens)
|
||||
|
||||
if chunk_length == 0:
|
||||
return 0.0
|
||||
|
||||
max_coherence = 0.0
|
||||
|
||||
for i in range(0, chunk_length - window_size + 1, 10):
|
||||
window = chunk_tokens[i:i + window_size]
|
||||
term_count = window.count(term.lower())
|
||||
|
||||
if term_count > 0:
|
||||
coherence = (term_count * window_size) / chunk_length
|
||||
max_coherence = max(max_coherence, coherence)
|
||||
|
||||
return max_coherence
|
||||
```
|
||||
|
||||
#### 评分融合与检索
|
||||
|
||||
将句子级相似度评分与知识级评分进行融合,形成最终的文档相关性评分。融合过程采用加权平均的方式,平衡参数λ控制两种评分方式的重要性。
|
||||
|
||||
$\Phi = \lambda \Phi(\mathcal{K}) + (1 - \lambda) \Phi(\mathcal{S})$
|
||||
|
||||
其中 $\lambda$ 为平衡参数,$\Phi(\mathcal{K})$ 为知识级评分,$\Phi(\mathcal{S})$ 为句子级评分。通过调整λ值,可以控制系统更偏向语义相似还是知识匹配。当λ=0时,系统完全依赖句子级语义相似度;当λ=1时,系统完全依赖知识匹配评分;λ=0.5时,两种评分方式权重相等。在建筑文档审查场景中,通常设置λ=0.5以平衡专业知识匹配和语义理解。
|
||||
|
||||
```python
|
||||
import numpy as np
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
from typing import List, Tuple, Dict, Any
|
||||
|
||||
class GKGRRetriever:
|
||||
def __init__(self,
|
||||
knowledge_base: List[str],
|
||||
embedding_model,
|
||||
key_info_extractor: KeyInfoExtractor,
|
||||
llm,
|
||||
config: Dict[str, Any] = None):
|
||||
self.knowledge_base = knowledge_base
|
||||
self.embedding_model = embedding_model
|
||||
self.key_info_extractor = key_info_extractor
|
||||
self.llm = llm
|
||||
|
||||
default_config = {
|
||||
"lambda_param": 0.5,
|
||||
"top_k": 5,
|
||||
"rerank_enabled": True,
|
||||
"query_expansion": True,
|
||||
"similarity_threshold": 0.1
|
||||
}
|
||||
self.config = {**default_config, **(config or {})}
|
||||
|
||||
self.kb_embeddings = self._precompute_embeddings()
|
||||
|
||||
def _precompute_embeddings(self) -> np.ndarray:
|
||||
embeddings = self.embedding_model.encode(self.knowledge_base, show_progress_bar=True)
|
||||
return embeddings
|
||||
|
||||
def retrieve_with_scores(self, query: str) -> List[Tuple[str, float, Dict[str, float]]]:
|
||||
query_embedding = self.embedding_model.encode([query])[0]
|
||||
sentence_scores = cosine_similarity(
|
||||
query_embedding.reshape(1, -1),
|
||||
self.kb_embeddings
|
||||
)[0]
|
||||
|
||||
key_info = self.key_info_extractor.extract_key_info(query)
|
||||
knowledge_scores = self._compute_knowledge_scores(key_info)
|
||||
|
||||
final_scores = []
|
||||
for i in range(len(self.knowledge_base)):
|
||||
norm_sent = sentence_scores[i]
|
||||
norm_know = knowledge_scores[i] / max(knowledge_scores) if max(knowledge_scores) > 0 else 0
|
||||
|
||||
final_score = (self.config["lambda_param"] * norm_know +
|
||||
(1 - self.config["lambda_param"]) * norm_sent)
|
||||
final_scores.append(final_score)
|
||||
|
||||
results_with_scores = []
|
||||
for i, final_score in enumerate(final_scores):
|
||||
if final_score > self.config["similarity_threshold"]:
|
||||
score_details = {
|
||||
"sentence_score": float(sentence_scores[i]),
|
||||
"knowledge_score": float(knowledge_scores[i]),
|
||||
"final_score": float(final_score)
|
||||
}
|
||||
results_with_scores.append((self.knowledge_base[i], final_score, score_details))
|
||||
|
||||
results_with_scores.sort(key=lambda x: x[1], reverse=True)
|
||||
return results_with_scores[:self.config["top_k"]]
|
||||
|
||||
def _compute_knowledge_scores(self, key_info: Dict[str, Tuple[str, float]]) -> List[float]:
|
||||
scores = []
|
||||
avg_length = sum(len(chunk.split()) for chunk in self.knowledge_base) / len(self.knowledge_base)
|
||||
|
||||
for chunk in self.knowledge_base:
|
||||
chunk_score = 0.0
|
||||
chunk_tokens = chunk.lower().split()
|
||||
chunk_length = len(chunk_tokens)
|
||||
|
||||
lambda_dl = compute_document_length_factor(chunk_length, avg_length)
|
||||
|
||||
for priority, (info_text, weight) in key_info.items():
|
||||
if not info_text.strip():
|
||||
continue
|
||||
|
||||
terms = info_text.lower().split()
|
||||
for term in terms:
|
||||
if term in chunk_tokens:
|
||||
tf = chunk_tokens.count(term)
|
||||
|
||||
significance = compute_term_significance(tf, lambda_dl)
|
||||
|
||||
segments_with_term = sum(1 for kb_chunk in self.knowledge_base
|
||||
if term in kb_chunk.lower())
|
||||
rarity = compute_term_rarity(segments_with_term, len(self.knowledge_base))
|
||||
|
||||
coherence = compute_coherence_index(term, chunk)
|
||||
|
||||
term_score = significance * rarity * (1 + coherence) * weight
|
||||
chunk_score += term_score
|
||||
|
||||
scores.append(chunk_score)
|
||||
|
||||
return scores
|
||||
|
||||
def retrieve(self, query: str) -> Tuple[List[str], str]:
|
||||
results_with_scores = self.retrieve_with_scores(query)
|
||||
|
||||
documents = [doc for doc, _, _ in results_with_scores]
|
||||
|
||||
if self.config["rerank_enabled"] and len(documents) > 1:
|
||||
documents = self._llm_rerank(query, documents)
|
||||
|
||||
augmented_query = query
|
||||
if self.config["query_expansion"]:
|
||||
augmented_query = self._augment_query(query, documents[:3])
|
||||
|
||||
return documents, augmented_query
|
||||
```
|
||||
|
||||
#### 重排序优化
|
||||
|
||||
系统使用大语言模型对检索结果进行进一步重排序,通过LLM的语义理解能力优化文档的相关性排序。重排序过程中,系统会构造包含查询和候选文档的提示,要求LLM根据相关性对文档进行重新排序。
|
||||
|
||||
```python
|
||||
def _llm_rerank(self, query: str, documents: List[str]) -> List[str]:
|
||||
if len(documents) <= 1:
|
||||
return documents
|
||||
|
||||
rerank_prompt = f"""
|
||||
Task: A list of documents is shown below. Each document has a number next to it. A question is also provided. Your task is to return the numbers of ALL documents in order of relevance from MOST to LEAST relevant. MUST include EVERY document number exactly once.
|
||||
|
||||
Example format:
|
||||
Document 1: <document 1>
|
||||
Document 2: <document 2>
|
||||
Document 3: <document 3>
|
||||
Question: <question>
|
||||
Answer: 3,1,2
|
||||
|
||||
Now here are the actual documents and question.
|
||||
|
||||
"""
|
||||
for i, doc in enumerate(documents):
|
||||
rerank_prompt += f"Document {i+1}: {doc[:150]}...\n"
|
||||
|
||||
rerank_prompt += f"Question: {query}\nAnswer:"
|
||||
|
||||
try:
|
||||
response = self.llm.predict(rerank_prompt)
|
||||
order_nums = [int(x.strip()) - 1 for x in response.split(',')
|
||||
if x.strip().isdigit() and 0 <= int(x.strip()) - 1 < len(documents)]
|
||||
|
||||
reranked = [documents[i] for i in order_nums if i < len(documents)]
|
||||
|
||||
# 添加遗漏的文档
|
||||
used_indices = set(order_nums)
|
||||
for i, doc in enumerate(documents):
|
||||
if i not in used_indices:
|
||||
reranked.append(doc)
|
||||
|
||||
return reranked[:len(documents)]
|
||||
except:
|
||||
return documents
|
||||
```
|
||||
|
||||
#### 查询增强
|
||||
|
||||
同时系统还会利用检索到的知识来增强原始查询,生成更具体、更详细的查询用于进一步检索。查询增强通过分析检索结果的上下文信息,识别查询中可能遗漏的关键概念和术语。
|
||||
|
||||
```python
|
||||
def _augment_query(self, original_query: str, top_results: List[str]) -> str:
|
||||
if not top_results:
|
||||
return original_query
|
||||
|
||||
document_list = ""
|
||||
for i, doc in enumerate(top_results):
|
||||
document_list += f"Document {i+1}: {doc[:100]}...\n"
|
||||
|
||||
augment_prompt = f"""
|
||||
Task: Your task is to generate a detailed answer to the question by synthesizing information from ALL provided documents. Prioritize relevance, cite document numbers, and structure your response as follows:
|
||||
|
||||
Question: {original_query}
|
||||
{document_list}
|
||||
Answer:
|
||||
"""
|
||||
|
||||
try:
|
||||
augmented = self.llm.predict(augment_prompt)
|
||||
return augmented.strip()
|
||||
except:
|
||||
return original_query
|
||||
```
|
||||
|
||||
#### 偏差检测分析
|
||||
|
||||
在先期知识增强检索阶段获取领域知识后,系统随即进入误差辨析模块。该模块基于检索得到的知识参考,并结合预设的审阅问题,对原文进行细致的偏差检测与评估。
|
||||
|
||||
```python
|
||||
class ErrorAnalyzer:
|
||||
def __init__(self, llm):
|
||||
self.llm = llm
|
||||
|
||||
def analyze_errors(self, document_chunk: str, query: str, retrieved_knowledge: List[str]) -> Dict[str, Any]:
|
||||
|
||||
analysis_prompt = f"""
|
||||
Task: Your task is to conduct an error analysis on a given review document, based on a provided review query and relevant reference specifications. This analysis MUST strictly adhere to the provided reference and focus specifically on reviewing and analyzing the original descriptive sections within the review document.
|
||||
|
||||
Review document: {document_chunk}
|
||||
Query: {query}
|
||||
Reference: {chr(10).join([f"{i+1}. {ref}" for i, ref in enumerate(retrieved_knowledge)])}
|
||||
Analysis:
|
||||
"""
|
||||
|
||||
analysis = self.llm.predict(analysis_prompt)
|
||||
|
||||
return {
|
||||
"analysis": analysis,
|
||||
"reference_support": retrieved_knowledge
|
||||
}
|
||||
```
|
||||
|
||||
#### 修订建议生成
|
||||
|
||||
误差辨析模块完成后,系统将输出标记偏差区域以及相关知识佐证。随后,系统进入修订策略生成模块。该模块依据误差分析结果和知识参考,对标记区域进行针对性的修订建议生成,最终实现对原文的知识驱动型自动修正。
|
||||
|
||||
```python
|
||||
class RevisionGenerator:
|
||||
def __init__(self, llm):
|
||||
self.llm = llm
|
||||
|
||||
def generate_revisions(self, document_chunk: str, analysis: Dict[str, Any]) -> Dict[str, str]:
|
||||
revision_prompt = f"""
|
||||
Task: Your task is to review and revise the provided document based on the given analysis and corresponding reference specifications. STRICT adherence to the provided reference specifications is required. If the review document aligns with the analysis and reference specifications WITHOUT discrepancies, revision is not necessary.
|
||||
|
||||
Review document: {document_chunk}
|
||||
Analysis: {analysis['analysis']}
|
||||
Reference: {chr(10).join([f"- {ref}" for ref in analysis['reference_support']])}
|
||||
Revision:
|
||||
"""
|
||||
|
||||
revision = self.llm.predict(revision_prompt)
|
||||
|
||||
return {
|
||||
"original_text": document_chunk,
|
||||
"revision_suggestions": revision,
|
||||
"modified_regions": analysis.get("error_regions", []),
|
||||
"confidence": self._calculate_confidence(analysis)
|
||||
}
|
||||
|
||||
def _calculate_confidence(self, analysis: Dict[str, Any]) -> float:
|
||||
ref_count = len(analysis.get("reference_support", []))
|
||||
error_count = len(analysis.get("error_regions", []))
|
||||
|
||||
confidence = min(0.9, 0.5 + (ref_count * 0.1) + (error_count * 0.05))
|
||||
return confidence
|
||||
```
|
||||
|
||||
#### 完整审查流程
|
||||
|
||||
将上述所有模块整合,形成完整的文档审查流程。系统首先生成审查问题,然后进行知识引导检索,接着执行错误分析,最后生成修订建议。
|
||||
|
||||
```python
|
||||
def complete_review_process(document_chunk: str,
|
||||
gkgr_framework: GKGRRetriever,
|
||||
error_analyzer: ErrorAnalyzer,
|
||||
revision_generator: RevisionGenerator) -> Dict[str, Any]:
|
||||
review_queries = generate_review_queries(gkgr_framework.llm, document_chunk)
|
||||
|
||||
results = {}
|
||||
for query in review_queries[:3]:
|
||||
retrieved_docs, augmented_query = gkgr_framework.retrieve(query)
|
||||
|
||||
knowledge_refs = retrieved_docs
|
||||
analysis = error_analyzer.analyze_errors(document_chunk, query, knowledge_refs)
|
||||
|
||||
revision = revision_generator.generate_revisions(document_chunk, analysis)
|
||||
|
||||
results[query] = {
|
||||
"retrieved_knowledge": retrieved_docs,
|
||||
"augmented_query": augmented_query,
|
||||
"analysis": analysis,
|
||||
"revision": revision
|
||||
}
|
||||
|
||||
return results
|
||||
```
|
||||
|
||||
至此,我们就完成了建筑文档智能审查系统的核心实现。
|
||||
|
||||
## 实际应用示例
|
||||
|
||||
让我们通过一个完整的示例来展示系统的使用:
|
||||
|
||||
```python
|
||||
# 初始化系统组件
|
||||
llm = DeepSeekLLM(
|
||||
model_name='deepseek-chat',
|
||||
api_key='your-api-key',
|
||||
base_url='https://api.deepseek.com/v1'
|
||||
)
|
||||
|
||||
embedding = BGEEmbedding(model_name="BAAI/bge-m3")
|
||||
key_extractor = KeyInfoExtractor(llm)
|
||||
|
||||
# 从markdown文档构建知识库
|
||||
processor = DocumentProcessor()
|
||||
documents = processor.load_documents("./construction_standards")
|
||||
|
||||
# 对文档进行动态语义分块
|
||||
chunker = DynamicSemanticChunker()
|
||||
knowledge_base = []
|
||||
for doc in documents:
|
||||
chunks = chunker.split_text(doc)
|
||||
knowledge_base.extend(chunks.values())
|
||||
|
||||
# 初始化检索器
|
||||
gkgr_retriever = GKGRRetriever(
|
||||
knowledge_base=knowledge_base,
|
||||
embedding_model=embedding,
|
||||
key_info_extractor=key_extractor,
|
||||
llm=llm
|
||||
)
|
||||
|
||||
# 初始化分析器
|
||||
error_analyzer = ErrorAnalyzer(llm)
|
||||
revision_generator = RevisionGenerator(llm)
|
||||
|
||||
# 待审查的文档内容
|
||||
sample_document = """
|
||||
钢筋混凝土柱的施工应符合以下要求:
|
||||
1. 混凝土强度等级不低于C25
|
||||
2. 钢筋保护层厚度为25mm
|
||||
3. 混凝土浇筑应连续进行,间歇时间不超过1小时
|
||||
4. 养护期间应保持混凝土表面湿润
|
||||
"""
|
||||
|
||||
# 执行审查
|
||||
result = complete_review_process(
|
||||
sample_document,
|
||||
gkgr_retriever,
|
||||
error_analyzer,
|
||||
revision_generator
|
||||
)
|
||||
|
||||
# 查看审查结果
|
||||
for query, analysis in result.items():
|
||||
print(f"审查问题: {query}")
|
||||
print(f"修订建议: {analysis['revision']['revision_suggestions']}")
|
||||
print("-" * 50)
|
||||
```
|
||||
|
||||
## 扩展性说明
|
||||
|
||||
系统可以通过更换知识库轻松适应其他领域。对于特定企业或项目,可以通过微调关键信息提取模型来提升准确性。在性能优化方面,使用动态语义分块可以提升检索质量,预计算并缓存知识库嵌入以提升检索速度,对于大量文档可使用批量处理模式,根据具体应用场景调整λ参数和top-k值。
|
||||
|
||||
## 写在最后
|
||||
|
||||
恭喜你阅读完此文,你已经充分了解了如何实现一个建筑文档智能审查系统以及其背后的思考。这个系统展示了如何将动态语义分块、知识引导检索和大语言模型有机结合,为建筑行业的文档审查提供了一个实用的解决方案。
|
||||
|
||||
虽然当前系统已经取得了不错的效果,但仍有改进空间。全局关联增强方面,当前基于文本块的检索可以进一步结合知识图谱等技术。多模态支持方面,未来可以扩展支持CAD图纸、施工图等视觉信息。实时更新方面,支持知识库的增量更新和动态维护。个性化定制方面,根据不同企业和项目特点进行系统定制。
|
||||
|
||||
读者们可以运行项目中的示例代码,体验完整的建筑文档智能审查流程。我们相信这个系统不仅能够提升审查效率,更能为建筑行业的数字化转型贡献力量。
|
||||
|
||||
## 致谢
|
||||
|
||||
本项目的开发过程中,我们深入研究了建筑工程领域的专业知识和最新的自然语言处理技术。特别感谢建筑行业专家提供的宝贵建议,以及开源社区在技术实现方面的支持。项目代码实现参考了LlamaIndex、Transformers等优秀开源项目的设计理念。
|
||||
|
||||
需要说明的是,本项目专门针对建筑施工领域的文档审查场景进行了深度优化。如果您需要处理其他领域的文档,建议根据具体需求对系统进行相应调整。
|
||||
|
||||
## 源码获取
|
||||
|
||||
本项目的源码以及实例数据存放在 [GitHub 仓库](https://github.com/Hongru0306/CDDRS)。
|
||||
|
||||
## 引用
|
||||
|
||||
如果您在研究中使用了本项目的成果,请按如下方式引用:
|
||||
|
||||
```bibtex
|
||||
@article{XIAO2025103618,
|
||||
title = {Generative knowledge-guided review system for construction disclosure documents},
|
||||
journal = {Advanced Engineering Informatics},
|
||||
volume = {68},
|
||||
pages = {103618},
|
||||
year = {2025},
|
||||
issn = {1474-0346},
|
||||
doi = {https://doi.org/10.1016/j.aei.2025.103618},
|
||||
url = {https://www.sciencedirect.com/science/article/pii/S1474034625005117},
|
||||
author = {Hongru Xiao and Jiankun Zhuang and Bin Yang and Jiale Han and Yantao Yu and Songning Lai},
|
||||
keywords = {Construction documents review, Large language model (LLM), Knowledge-guided retrieval, Natural Language Processing (NLP)}
|
||||
}
|
||||
```
|
||||
152
Extra-Chapter/generation-method/llm_generation.py
Normal file
@@ -0,0 +1,152 @@
|
||||
import torch
|
||||
from modelscope import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
def test_decoding_strategies():
|
||||
"""
|
||||
测试三种解码策略:贪婪解码、随机采样、束搜索
|
||||
"""
|
||||
model_id = "../model/kmno4zx/happy-llm-215M-sft/"
|
||||
|
||||
print("正在加载模型和tokenizer...")
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, device_map="cpu").eval()
|
||||
|
||||
# 测试prompt
|
||||
test_prompt = "请介绍一下自己"
|
||||
messages = [
|
||||
{"role": "system", "content": "你是一个AI助手"},
|
||||
{"role": "user", "content": test_prompt}
|
||||
]
|
||||
|
||||
# 准备输入
|
||||
input_ids = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
input_ids = tokenizer(input_ids).data['input_ids']
|
||||
x = (torch.tensor(input_ids, dtype=torch.long)[None, ...]).to(model.device)
|
||||
|
||||
print(f"测试prompt: {test_prompt}")
|
||||
print(f"输入token数量: {len(input_ids)}")
|
||||
print("=" * 60)
|
||||
|
||||
# 测试1: 贪婪解码 (Greedy Search)
|
||||
print("🔍 测试1: 贪婪解码 (Greedy Search)")
|
||||
print("参数: do_sample=False, num_beams=1, temperature=0.0")
|
||||
print("特点: 每步选择概率最大的token,结果确定,速度快")
|
||||
|
||||
with torch.no_grad():
|
||||
greedy_output = model.generate_super(
|
||||
x,
|
||||
stop_id=tokenizer.eos_token_id,
|
||||
max_new_tokens=50,
|
||||
temperature=0.0,
|
||||
do_sample=False,
|
||||
num_beams=1
|
||||
)
|
||||
greedy_response = tokenizer.decode(greedy_output[0].tolist(), skip_special_tokens=True)
|
||||
|
||||
print(f"贪婪解码结果: {greedy_response}")
|
||||
print()
|
||||
|
||||
# 测试2: 随机采样 (Random Sampling)
|
||||
print("🎲 测试2: 随机采样 (Random Sampling)")
|
||||
print("参数: do_sample=True, num_beams=1, temperature=0.8, top_k=50")
|
||||
print("特点: 基于概率分布随机采样,结果多样,创造性高")
|
||||
|
||||
with torch.no_grad():
|
||||
# 运行多次以展示随机性
|
||||
for i in range(3):
|
||||
sampling_output = model.generate_super(
|
||||
x,
|
||||
stop_id=tokenizer.eos_token_id,
|
||||
max_new_tokens=50,
|
||||
temperature=0.8,
|
||||
top_k=50,
|
||||
do_sample=True,
|
||||
num_beams=1
|
||||
)
|
||||
sampling_response = tokenizer.decode(sampling_output[0].tolist(), skip_special_tokens=True)
|
||||
print(f"随机采样结果 {i+1}: {sampling_response}")
|
||||
|
||||
print()
|
||||
|
||||
# 测试3: 束搜索 (Beam Search)
|
||||
print("🔦 测试3: 束搜索 (Beam Search)")
|
||||
print("参数: do_sample=False, num_beams=3, temperature=1.0")
|
||||
print("特点: 维护多条候选路径,选择总概率最高的序列,质量更高")
|
||||
|
||||
with torch.no_grad():
|
||||
beam_output = model.generate_super(
|
||||
x,
|
||||
stop_id=tokenizer.eos_token_id,
|
||||
max_new_tokens=50,
|
||||
temperature=1.0,
|
||||
do_sample=False,
|
||||
num_beams=3
|
||||
)
|
||||
beam_response = tokenizer.decode(beam_output[0].tolist(), skip_special_tokens=True)
|
||||
|
||||
print(f"束搜索结果: {beam_response}")
|
||||
print()
|
||||
|
||||
# 测试4: 不同的温度参数对随机采样的影响
|
||||
print("🌡️ 测试4: 不同温度参数对随机采样的影响")
|
||||
print("参数: do_sample=True, num_beams=1, 测试不同temperature值")
|
||||
|
||||
temperatures = [0.2, 0.8, 1.5]
|
||||
for temp in temperatures:
|
||||
with torch.no_grad():
|
||||
temp_output = model.generate_super(
|
||||
x,
|
||||
stop_id=tokenizer.eos_token_id,
|
||||
max_new_tokens=30,
|
||||
temperature=temp,
|
||||
do_sample=True,
|
||||
num_beams=1
|
||||
)
|
||||
temp_response = tokenizer.decode(temp_output[0].tolist(), skip_special_tokens=True)
|
||||
print(f"温度 {temp}: {temp_response}")
|
||||
|
||||
print()
|
||||
print("=" * 60)
|
||||
print("✅ 三种解码策略测试完成!")
|
||||
print()
|
||||
print("📊 总结对比:")
|
||||
print("• 贪婪解码: 速度快,结果确定,适合确定性任务")
|
||||
print("• 随机采样: 创造性强,结果多样,适合创意生成")
|
||||
print("• 束搜索: 质量较高,平衡速度和质量,适合一般对话")
|
||||
|
||||
def test_original_generation():
|
||||
"""
|
||||
原始的生成代码作为对比
|
||||
"""
|
||||
model_id = "../model/kmno4zx/happy-llm-215M-sft/"
|
||||
|
||||
print("运行原始生成代码...")
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, device_map="cpu").eval()
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "你是一个AI助手"},
|
||||
{"role": "user", "content": "你好,请介绍一下自己。"}
|
||||
]
|
||||
|
||||
input_ids = tokenizer.apply_chat_template(messages,tokenize=False,add_generation_prompt=True)
|
||||
input_ids = tokenizer(input_ids).data['input_ids']
|
||||
|
||||
x = (torch.tensor(input_ids, dtype=torch.long)[None, ...]).to(model.device)
|
||||
|
||||
with torch.no_grad():
|
||||
y = model.generate_super(x, stop_id=tokenizer.eos_token_id, max_new_tokens=512, temperature=0.6)
|
||||
response = tokenizer.decode(y[0].tolist(), skip_special_tokens=True)
|
||||
|
||||
print(f"Assistant: {response}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("开始测试三种解码策略...")
|
||||
print()
|
||||
|
||||
try:
|
||||
test_decoding_strategies()
|
||||
except Exception as e:
|
||||
print(f"测试过程中出现错误: {e}")
|
||||
print("运行原始生成代码...")
|
||||
test_original_generation()
|
||||
3
Extra-Chapter/generation-method/model_down.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from modelscope import snapshot_download
|
||||
|
||||
model_dir = snapshot_download('kmno4zx/happy-llm-215M-sft', cache_dir='your/cache/dir', revision='master')
|
||||
511
Extra-Chapter/generation-method/readme.md
Normal file
@@ -0,0 +1,511 @@
|
||||
# 大模型生成Token的方式
|
||||
|
||||
> 代码已更新到 Happy-LLM 仓库第五章的代码中。
|
||||
|
||||
## 贪婪解码(Greedy Decoding)
|
||||
|
||||
### 原理说明
|
||||
贪婪解码是最简单直接的文本生成策略。在每一步生成时,它总是选择概率最大的那个token作为下一个token,然后继续生成,直到遇到停止条件或达到最大长度。
|
||||
|
||||
**核心思想**:局部最优选择 → 希望全局最优
|
||||
|
||||
**数学表达**:
|
||||
```
|
||||
token_t = argmax P(token_t | token_1, token_2, ..., token_{t-1})
|
||||
```
|
||||
|
||||
### 代码实现
|
||||
基于我们实现的 `_greedy_decode` 方法:
|
||||
|
||||
```python
|
||||
def _greedy_decode(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
贪婪解码:选择概率最大的token
|
||||
|
||||
Args:
|
||||
logits: 模型输出的logits,形状为 (batch_size, vocab_size)
|
||||
|
||||
Returns:
|
||||
选择的token索引,形状为 (batch_size, 1)
|
||||
"""
|
||||
_, idx_next = torch.topk(logits, k=1, dim=-1)
|
||||
return idx_next
|
||||
```
|
||||
|
||||
**关键步骤解析**:
|
||||
1. `torch.topk(logits, k=1, dim=-1)`:找到logits中最大值的位置
|
||||
2. 返回最大概率token的索引
|
||||
3. 该token被添加到序列中,继续下一轮生成
|
||||
|
||||
### 使用示例
|
||||
```python
|
||||
# 在 generate_super 函数中调用贪婪解码
|
||||
output = model.generate_super(
|
||||
input_ids,
|
||||
do_sample=False, # 不使用采样
|
||||
num_beams=1, # 不使用束搜索
|
||||
temperature=0.0, # 温度为0确保确定性
|
||||
max_new_tokens=100
|
||||
)
|
||||
```
|
||||
|
||||
### 优缺点分析
|
||||
|
||||
**优点**:
|
||||
- ✅ **速度快**:每步只需要一次前向传播和简单的argmax操作
|
||||
- ✅ **结果确定**:相同的输入总是产生相同的输出
|
||||
- ✅ **内存效率高**:不需要维护多个候选序列
|
||||
- ✅ **实现简单**:算法逻辑直观易懂
|
||||
|
||||
**缺点**:
|
||||
- ❌ **容易陷入局部最优**:每步的局部最优不一定等于全局最优
|
||||
- ❌ **缺乏多样性**:总是产生相同的序列,缺乏创造性
|
||||
- ❌ **可能产生重复内容**:容易陷入重复循环
|
||||
- ❌ **忽略长程依赖**:不考虑序列的整体连贯性
|
||||
|
||||
### 典型例子
|
||||
假设模型生成了以下概率分布:
|
||||
|
||||
```
|
||||
输入: "今天天气"
|
||||
下一token概率:
|
||||
- "很" (0.4)
|
||||
- "不错" (0.3)
|
||||
- "真好" (0.2)
|
||||
- "不太好" (0.1)
|
||||
```
|
||||
|
||||
贪婪解码会选择"很",生成"今天天气很",然后继续这个过程。
|
||||
|
||||
### 使用场景
|
||||
- **确定性任务**:如数学计算、代码生成
|
||||
- **需要一致性的应用**:如API服务、自动化脚本
|
||||
- **计算资源受限的环境**:需要快速生成结果
|
||||
- **基准测试**:作为其他算法的对比基准
|
||||
|
||||
## 采样解码(Sampling Decoding)
|
||||
|
||||
### 原理说明
|
||||
采样解码不是选择概率最大的token,而是基于模型的概率分布进行随机采样。这样可以在每次生成时产生不同的结果,增加文本的多样性和创造性。
|
||||
|
||||
**核心思想**:基于概率分布随机选择 → 增加多样性
|
||||
|
||||
**数学表达**:
|
||||
|
||||
```
|
||||
token_t ~ P(token_t | token_1, token_2, ..., token_{t-1})
|
||||
```
|
||||
|
||||
### 关键参数
|
||||
|
||||
#### 1. Temperature(温度)
|
||||
- **作用**:控制概率分布的平滑程度
|
||||
- **原理**:将logits除以temperature,然后进行softmax
|
||||
- **效果**:
|
||||
- `temperature > 1`:分布更平滑,增加随机性
|
||||
- `temperature < 1`:分布更尖锐,更接近贪婪解码
|
||||
- `temperature → 0`:等价于贪婪解码
|
||||
|
||||
#### 2. Top-k Sampling
|
||||
- **作用**:限制候选token的范围
|
||||
- **原理**:只考虑概率最高的k个token,其他token概率设为0
|
||||
- **效果**:避免选择概率很低的"奇怪"token,提高质量
|
||||
|
||||
### 代码实现
|
||||
基于我们实现的 `_random_sample` 方法:
|
||||
|
||||
```python
|
||||
def _random_sample(self, logits: torch.Tensor, temperature: float = 1.0, top_k: int = None) -> torch.Tensor:
|
||||
"""
|
||||
随机采样:基于概率分布随机选择token
|
||||
|
||||
Args:
|
||||
logits: 模型输出的logits,形状为 (batch_size, vocab_size)
|
||||
temperature: 温度参数,控制随机性
|
||||
top_k: 只考虑概率最高的k个token
|
||||
|
||||
Returns:
|
||||
选择的token索引,形状为 (batch_size, 1)
|
||||
"""
|
||||
# 1. 温度缩放
|
||||
logits = logits / temperature
|
||||
|
||||
# 2. Top-k过滤
|
||||
if top_k is not None:
|
||||
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
||||
logits[logits < v[:, [-1]]] = -float('Inf')
|
||||
|
||||
# 3. 计算概率并采样
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
idx_next = torch.multinomial(probs, num_samples=1)
|
||||
return idx_next
|
||||
```
|
||||
|
||||
**关键步骤解析**:
|
||||
1. **温度缩放**:调整概率分布的平滑程度
|
||||
2. **Top-k过滤**:移除低概率候选,提高质量
|
||||
3. **概率归一化**:使用softmax得到概率分布
|
||||
4. **随机采样**:根据概率分布随机选择token
|
||||
|
||||
### 使用示例
|
||||
```python
|
||||
# 基本采样
|
||||
output = model.generate_super(
|
||||
input_ids,
|
||||
do_sample=True, # 启用采样
|
||||
num_beams=1, # 不使用束搜索
|
||||
temperature=0.8, # 中等温度
|
||||
max_new_tokens=100
|
||||
)
|
||||
|
||||
# 带top-k的采样
|
||||
output = model.generate_super(
|
||||
input_ids,
|
||||
do_sample=True,
|
||||
num_beams=1,
|
||||
temperature=1.0, # 较高温度增加随机性
|
||||
top_k=50, # 只考虑前50个候选
|
||||
max_new_tokens=100
|
||||
)
|
||||
```
|
||||
|
||||
### 温度参数详解
|
||||
|
||||
**不同温度的效果对比**:
|
||||
|
||||
```python
|
||||
# 示例概率分布
|
||||
original_probs = [0.6, 0.2, 0.1, 0.05, 0.05]
|
||||
|
||||
# Temperature = 0.1 (低温度,接近贪婪)
|
||||
scaled_probs = [0.85, 0.08, 0.04, 0.015, 0.015]
|
||||
# 结果:很可能选择第一个token
|
||||
|
||||
# Temperature = 1.0 (标准温度)
|
||||
scaled_probs = [0.6, 0.2, 0.1, 0.05, 0.05]
|
||||
# 结果:按原始概率采样
|
||||
|
||||
# Temperature = 2.0 (高温度,增加随机性)
|
||||
scaled_probs = [0.35, 0.25, 0.18, 0.11, 0.11]
|
||||
# 结果:各个token都有机会被选中
|
||||
```
|
||||
|
||||
### Top-k机制详解
|
||||
|
||||
**Top-k过滤过程**:
|
||||
|
||||
```python
|
||||
# 假设词汇表大小为1000,top_k=50
|
||||
logits = [0.1, 2.3, 0.5, 1.8, 0.3, 3.2, 0.9, 0.2, 1.5, 0.7, ...] # 1000个值
|
||||
|
||||
# 步骤1:找到前50个最大值
|
||||
v, _ = torch.topk(logits, 50)
|
||||
threshold = v[-1] # 第50大的值
|
||||
|
||||
# 步骤2:过滤
|
||||
logits[logits < threshold] = -float('Inf')
|
||||
# 结果:只有50个token有非零概率,其他950个token概率为0
|
||||
```
|
||||
|
||||
### 优缺点分析
|
||||
|
||||
**优点**:
|
||||
- ✅ **多样性好**:每次生成可能产生不同的结果
|
||||
- ✅ **创造性高**:能产生意想不到的内容
|
||||
- ✅ **避免重复**:不容易陷入重复循环
|
||||
- ✅ **可调性强**:通过参数控制随机程度
|
||||
|
||||
**缺点**:
|
||||
- ❌ **结果不确定**:相同输入可能产生不同输出
|
||||
- ❌ **质量不稳定**:可能产生低质量或不连贯的内容
|
||||
- ❌ **需要调参**:temperature和top_k需要仔细调节
|
||||
- ❌ **计算开销**:需要计算完整的概率分布
|
||||
|
||||
### 使用场景
|
||||
- **创意写作**:故事生成、诗歌创作
|
||||
- **对话系统**:让对话更加自然和有趣
|
||||
- **数据增强**:生成多样化的训练数据
|
||||
- **探索性任务**:需要探索多种可能性的场景
|
||||
|
||||
## 束搜索(Beam Search)
|
||||
|
||||
### 原理说明
|
||||
束搜索是一种启发式搜索算法,它在每一步生成时保留多个候选序列(束),而不是只选择一个最佳序列。通过维护多条路径,它能够在计算效率和生成质量之间取得平衡。
|
||||
|
||||
**核心思想**:维护多条候选路径 → 选择累积概率最高的序列
|
||||
|
||||
**算法流程**:
|
||||
1. **初始化**:从输入序列开始
|
||||
2. **扩展**:为每个候选序列生成多个扩展
|
||||
3. **评分**:计算每个新序列的累积概率
|
||||
4. **筛选**:保留分数最高的N个候选
|
||||
5. **重复**:继续扩展直到结束条件
|
||||
|
||||
### 关键概念
|
||||
|
||||
#### 束宽度(Beam Width)
|
||||
- **定义**:每步保留的候选序列数量
|
||||
- **权衡**:
|
||||
- 宽度=1:等价于贪婪解码
|
||||
- 宽度越大:搜索空间越大,质量越高,但计算成本也越大
|
||||
|
||||
#### 累积概率
|
||||
- **计算方式**:序列概率 = 各个token概率的乘积
|
||||
- **数值稳定性**:通常使用对数概率求和
|
||||
- **公式**:`log P(sequence) = Σ log P(token_i | context)`
|
||||
|
||||
### 代码实现
|
||||
基于我们实现的 `_beam_search` 方法:
|
||||
|
||||
```python
|
||||
def _beam_search(self, idx: torch.Tensor, max_new_tokens: int, num_beams: int,
|
||||
temperature: float = 1.0, top_k: int = None, stop_id: int = None) -> torch.Tensor:
|
||||
"""
|
||||
束搜索:维护多个候选序列,选择最优路径
|
||||
|
||||
Args:
|
||||
idx: 输入序列,形状为 (batch_size, seq_len)
|
||||
max_new_tokens: 最大生成token数量
|
||||
num_beams: 束宽度,表示保留的候选路径数量
|
||||
temperature: 温度参数,控制分布的平滑程度
|
||||
top_k: top-k过滤参数,限制候选token范围
|
||||
stop_id: 停止生成的token ID,遇到则停止
|
||||
|
||||
Returns:
|
||||
生成的token序列,形状为 (batch_size, generated_length)
|
||||
"""
|
||||
# 1. 初始化束
|
||||
beams = [idx.clone() for _ in range(num_beams)]
|
||||
beam_scores = torch.zeros(num_beams, device=idx.device)
|
||||
beam_scores[0] = 0.0 # 第一个候选是原始序列
|
||||
beam_scores[1:] = float('-inf') # 其他候选初始分数为负无穷
|
||||
|
||||
# 2. 主循环:逐步生成token
|
||||
for step in range(max_new_tokens):
|
||||
new_beams = []
|
||||
new_scores = []
|
||||
|
||||
# 3. 扩展每个候选序列
|
||||
for beam_idx, beam in enumerate(beams):
|
||||
if beam_scores[beam_idx] == float('-inf'):
|
||||
continue # 跳过无效候选
|
||||
|
||||
# 前向传播获取logits
|
||||
output = self(beam)
|
||||
logits = output.logits[:, -1, :]
|
||||
|
||||
# 应用温度和top-k
|
||||
if temperature != 1.0:
|
||||
logits = logits / temperature
|
||||
if top_k is not None:
|
||||
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
||||
logits[logits < v[:, [-1]]] = -float('Inf')
|
||||
|
||||
# 计算对数概率
|
||||
log_probs = F.log_softmax(logits, dim=-1)
|
||||
|
||||
# 获取前num_beams个候选token
|
||||
top_log_probs, top_indices = torch.topk(log_probs, k=num_beams, dim=-1)
|
||||
|
||||
# 4. 为当前候选生成多个扩展
|
||||
for k in range(num_beams):
|
||||
token = top_indices[:, k:k+1]
|
||||
log_prob = top_log_probs[:, k]
|
||||
|
||||
new_beam = torch.cat([beam, token], dim=1)
|
||||
new_score = beam_scores[beam_idx] + log_prob.item()
|
||||
|
||||
new_beams.append(new_beam)
|
||||
new_scores.append(new_score)
|
||||
|
||||
# 5. 筛选最佳候选
|
||||
if not new_beams:
|
||||
break
|
||||
|
||||
# 按分数排序,选择前num_beams个
|
||||
sorted_indices = sorted(range(len(new_scores)), key=lambda i: new_scores[i], reverse=True)
|
||||
beams = [new_beams[i] for i in sorted_indices[:num_beams]]
|
||||
beam_scores = [new_scores[i] for i in sorted_indices[:num_beams]]
|
||||
|
||||
# 检查停止条件
|
||||
if stop_id is not None and beams[0][0, -1] == stop_id:
|
||||
break
|
||||
|
||||
# 6. 返回最佳序列
|
||||
return beams[0][:, idx.shape[1]:] # 只返回生成部分
|
||||
```
|
||||
|
||||
### 束搜索过程示例
|
||||
|
||||
假设束宽度=3,输入="今天天气":
|
||||
|
||||
**第1步扩展**:
|
||||
```
|
||||
候选1: "今天天气很好" (分数: 0.4)
|
||||
候选2: "今天天气不错" (分数: 0.3)
|
||||
候选3: "今天天气真好" (分数: 0.2)
|
||||
```
|
||||
|
||||
**第2步扩展**(每个候选再扩展3个):
|
||||
```
|
||||
候选1.1: "今天天气很好啊" (分数: 0.4 + 0.1 = 0.5)
|
||||
候选1.2: "今天天气很好。" (分数: 0.4 + 0.2 = 0.6) ← 保留
|
||||
候选1.3: "今天天气很好," (分数: 0.4 + 0.05 = 0.45)
|
||||
|
||||
候选2.1: "今天天气不错啊" (分数: 0.3 + 0.15 = 0.45)
|
||||
候选2.2: "今天天气不错。" (分数: 0.3 + 0.1 = 0.4) ← 保留
|
||||
候选2.3: "今天天气不错," (分数: 0.3 + 0.08 = 0.38)
|
||||
|
||||
候选3.1: "今天天气真好啊" (分数: 0.2 + 0.12 = 0.32)
|
||||
候选3.2: "今天天气真好。" (分数: 0.2 + 0.25 = 0.45) ← 保留
|
||||
候选3.3: "今天天气真好," (分数: 0.2 + 0.1 = 0.3)
|
||||
```
|
||||
|
||||
**筛选结果**(保留分数最高的3个):
|
||||
```
|
||||
最佳候选: "今天天气很好。" (分数: 0.6)
|
||||
次佳候选: "今天天气不错。" (分数: 0.4)
|
||||
第三候选: "今天天气真好。" (分数: 0.45)
|
||||
```
|
||||
|
||||
### 使用示例
|
||||
```python
|
||||
# 基本束搜索
|
||||
output = model.generate_super(
|
||||
input_ids,
|
||||
do_sample=False, # 不使用采样
|
||||
num_beams=3, # 束宽度为3
|
||||
temperature=1.0, # 标准温度
|
||||
max_new_tokens=100
|
||||
)
|
||||
|
||||
# 带top-k的束搜索
|
||||
output = model.generate_super(
|
||||
input_ids,
|
||||
do_sample=False,
|
||||
num_beams=5, # 更大的束宽度
|
||||
temperature=0.8, # 稍微降低温度
|
||||
top_k=50, # 限制候选范围
|
||||
max_new_tokens=100
|
||||
)
|
||||
```
|
||||
|
||||
### 优缺点分析
|
||||
|
||||
**优点**:
|
||||
- ✅ **质量较高**:比贪婪解码质量更好
|
||||
- ✅ **确定性**:结果相对稳定(相同输入产生相同输出)
|
||||
- ✅ **平衡性好**:在质量和效率之间取得平衡
|
||||
- ✅ **避免明显错误**:不容易选择明显不合适的token
|
||||
|
||||
**缺点**:
|
||||
- ❌ **计算开销大**:需要维护多个候选序列
|
||||
- ❌ **内存占用高**:存储多个候选序列和分数
|
||||
- ❌ **仍可能局部最优**:虽然比贪婪好,但仍可能错过全局最优
|
||||
- ❌ **多样性有限**:仍然偏向高概率路径,创造性不如采样
|
||||
|
||||
### 束宽度选择建议
|
||||
|
||||
| 束宽度 | 适用场景 | 优点 | 缺点 |
|
||||
|--------|----------|------|------|
|
||||
| 1-2 | 实时应用、计算资源有限 | 速度快、资源占用少 | 质量相对较低 |
|
||||
| 3-5 | 一般对话、文本生成 | 质量较好、速度适中 | 资源占用中等 |
|
||||
| 6-10 | 高质量生成、翻译 | 质量很高 | 计算开销大 |
|
||||
| 10+ | 专业应用、研究 | 最高质量 | 开销很大 |
|
||||
|
||||
### 使用场景
|
||||
- **机器翻译**:需要准确性和流畅性的平衡
|
||||
- **文本摘要**:生成连贯的摘要内容
|
||||
- **对话系统**:生成有逻辑的回复
|
||||
- **代码生成**:需要语法正确和逻辑合理
|
||||
- **长文本生成**:如文章写作、报告生成
|
||||
|
||||
## 辅助模型投机解码(Assisted Decoding)
|
||||
|
||||
### 原理说明
|
||||
投机解码是一种**用小模型加速大模型推理**的技术。它通过"草稿-验证"的方式,让小先生成候选token,然后大家模型快速验证,减少大模型的前向传播次数。
|
||||
|
||||
**核心思想**:小模型投机生成 → 大模型批量验证 → 减少大模型计算负担
|
||||
|
||||
### 工作流程
|
||||
|
||||
#### 1. 草稿生成阶段
|
||||
```
|
||||
输入: "今天天气"
|
||||
小模型快速生成草稿: "今天天气很好,适合出门散步"
|
||||
```
|
||||
|
||||
#### 2. 验证阶段
|
||||
大模型一次性验证整个草稿序列:
|
||||
- ✅ 接受的token:"今天天气很好,"
|
||||
- ❌ 拒绝的token:从"适合"开始拒绝
|
||||
- 🔧 大模型重新生成:"适合在家休息"
|
||||
|
||||
#### 3. 最终结果
|
||||
```
|
||||
输出: "今天天气很好,适合在家休息"
|
||||
```
|
||||
|
||||
### 关键优势
|
||||
|
||||
**速度提升**:
|
||||
- 小模型推理快 → 生成多个候选token
|
||||
- 大模型批量验证 → 一次处理多个token
|
||||
- 减少90%+的大模型前向传播
|
||||
|
||||
**质量保证**:
|
||||
- 大模型有最终否决权
|
||||
- 只有大模型认可的token才会被保留
|
||||
- 不会降低生成质量
|
||||
|
||||
### 具体例子对比
|
||||
|
||||
**传统方式**(大模型逐个生成):
|
||||
```
|
||||
第1步: 大模型 → "今天"
|
||||
第2步: 大模型 → "今天天气"
|
||||
第3步: 大模型 → "今天天气很"
|
||||
第4步: 大模型 → "今天天气很好"
|
||||
第5步: 大模型 → "今天天气很好,"
|
||||
第6步: 大模型 → "今天天气很好,适合"
|
||||
... (每步都需要大模型前向传播)
|
||||
```
|
||||
|
||||
**投机解码**:
|
||||
```
|
||||
第1步: 小模型快速草稿 → "今天天气很好,适合出门散步"
|
||||
第2步: 大模型批量验证 → 接受"今天天气很好,",拒绝"适合出门散步"
|
||||
第3步: 大模型重新生成 → "适合在家休息"
|
||||
```
|
||||
|
||||
这样原本需要6次大模型推理的过程,现在只需要2次!
|
||||
|
||||
### 技术实现要点
|
||||
|
||||
#### 1. 草稿长度控制
|
||||
- **草稿不宜过长**:通常2-10个token
|
||||
- **接受率平衡**:太长接受率低,太短加速效果不明显
|
||||
- **动态调整**:根据接受率调整草稿长度
|
||||
|
||||
#### 2. 验证机制
|
||||
```python
|
||||
# 伪代码
|
||||
def assisted_decoding(input_ids, assistant_model, main_model):
|
||||
# 小模型生成草稿
|
||||
draft_tokens = assistant_model.generate_draft(input_ids, max_draft_len=5)
|
||||
|
||||
# 大模型验证
|
||||
accepted_count = main_model.verify_draft(input_ids, draft_tokens)
|
||||
|
||||
# 构建最终结果
|
||||
if accepted_count == len(draft_tokens):
|
||||
return draft_tokens # 全部接受
|
||||
else:
|
||||
# 部分接受,大模型重新生成剩余部分
|
||||
accepted_part = draft_tokens[:accepted_count]
|
||||
remaining_part = main_model.generate_remaining(input_ids + accepted_part)
|
||||
return accepted_part + remaining_part
|
||||
```
|
||||
|
||||
### 总结
|
||||
投机解码本质上是用**计算资源换时间**,通过小模型的"投机"来减少大模型的计算负担。它是一种聪明的工程优化,在不牺牲质量的前提下显著提升推理速度。
|
||||
BIN
Extra-Chapter/s1-vllm-thinking-budget/images/image-1.png
Normal file
|
After Width: | Height: | Size: 115 KiB |
BIN
Extra-Chapter/s1-vllm-thinking-budget/images/image-2.png
Normal file
|
After Width: | Height: | Size: 272 KiB |
BIN
Extra-Chapter/s1-vllm-thinking-budget/images/image-3.png
Normal file
|
After Width: | Height: | Size: 157 KiB |
BIN
Extra-Chapter/s1-vllm-thinking-budget/images/image-4.png
Normal file
|
After Width: | Height: | Size: 289 KiB |
BIN
Extra-Chapter/s1-vllm-thinking-budget/images/thinking-budget.png
Normal file
|
After Width: | Height: | Size: 146 KiB |
2044
Extra-Chapter/s1-vllm-thinking-budget/output/output_1754208752.txt
Normal file
1978
Extra-Chapter/s1-vllm-thinking-budget/output/output_1754209653.txt
Normal file
179
Extra-Chapter/s1-vllm-thinking-budget/readme.md
Normal file
@@ -0,0 +1,179 @@
|
||||
# S1: Thinking Budget with vLLM
|
||||
|
||||
首先,我们来了解一下AI教母李飞飞教授关于 Test-time scaling 的论文:[*《s1: Simple test-time scaling》*](http://arxiv.org/abs/2501.19393)
|
||||
|
||||
<div align='center'>
|
||||
<img src="./images/image-1.png" alt="alt text" width="50%">
|
||||
</div>
|
||||
|
||||
论文大致讲了个什么事情呢?简单来说,提出了一种新的测试时间缩放方法,旨在提高模型在推理阶段的效率和准确性。通过调整模型的思考预算,可以在不同的任务和数据集上实现更好的性能。
|
||||
|
||||
就是说对于一些复杂问题,需要用推理链来解决的问题,我们可以通过调整模型的思考预算来提高推理效率和准确性。上图也可以看到当思考预算增加时,模型的性能会有明显提升。
|
||||
|
||||
<div align='center'>
|
||||
<img src="./images/image-2.png" alt="alt text" width="50%">
|
||||
</div>
|
||||
|
||||
插一句题外话,论文中判断问题难易程度的方式是通过让 Qwen2.5-32B-Instruct 模型回答问题,答对的问题就是简单问题,答错的就是复杂问题。
|
||||
|
||||
<div align='center'>
|
||||
<img src="./images/image-3.png" alt="alt text" width="50%">
|
||||
</div>
|
||||
|
||||
|
||||
论文也做了消融实验来探讨,在未满足思考预算时插入一些不同的特定词(如:Wait!)对模型最终性能的影响。结果表明,插入特定词可以有效地引导模型进行更深入的思考,并且“Wait,Wait”的效果最好。
|
||||
|
||||
## 代码实现
|
||||
|
||||
我们使用 vLLM 来实现模型的思考预算。vLLM 是一个高性能的推理引擎,支持大规模语言模型的高效推理。以下为代码实现的步骤:
|
||||
|
||||
> 考虑到部分同学配置环境可能会遇到一些问题,我们在 ucloud 平台准备了环境镜像,点击下方链接并直接创建 ucloud 示例即可。 https://www.compshare.cn/images/8gfTTB5y0ql6?referral_code=ELukJdQS3vvCwYIfgsQf2C
|
||||
|
||||
<div align='center'>
|
||||
<img src="./images/thinking-budget.png" alt="alt text" width="80%">
|
||||
</div>
|
||||
|
||||
左侧为不使用思考预算的推理过程,右侧为使用思考预算的推理过程。可以看到,使用思考预算后,模型会在推理过程中插入特定词来引导模型进行更深入的思考。
|
||||
|
||||
以下为核心代码实现,完整代码请参考 [*s1.py*](./s1.py)
|
||||
|
||||
```python
|
||||
def run_thinking_budget_sample(llm_model, tokenizer, user_input, thinking_budget):
|
||||
input_text = build_input(user_input, tokenizer)
|
||||
input_token_count = count_token(input_text, tokenizer)
|
||||
|
||||
iteration_count= 0
|
||||
max_token = input_token_count + thinking_budget
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0.7,
|
||||
max_tokens=4096,
|
||||
skip_special_tokens=False
|
||||
)
|
||||
|
||||
think_token_count = 0
|
||||
|
||||
while True:
|
||||
|
||||
wait_sampling_params = SamplingParams(
|
||||
temperature=0.7,
|
||||
max_tokens=thinking_budget - think_token_count,
|
||||
stop='</think>',
|
||||
skip_special_tokens=False
|
||||
)
|
||||
|
||||
outputs = llm_model.generate(
|
||||
input_text,
|
||||
wait_sampling_params
|
||||
)
|
||||
total_token, think_token_count = count_thinking_token(outputs, tokenizer)
|
||||
|
||||
print(f'第{iteration_count}次迭代,思考token数:{think_token_count}')
|
||||
|
||||
if think_token_count > thinking_budget:
|
||||
break
|
||||
input_text = total_token + "\nWait!\n"
|
||||
|
||||
# \nWait a moment. Was there any loophole in my thought just now?!\n
|
||||
# \nWait!\n
|
||||
|
||||
iteration_count += 1
|
||||
|
||||
final_outputs = llm_model.generate(
|
||||
outputs[0].prompt + outputs[0].outputs[0].text + "\n</think>\n",
|
||||
sampling_params
|
||||
)
|
||||
|
||||
total_content = final_outputs[0].prompt + final_outputs[0].outputs[0].text
|
||||
thinking_content = total_content.split("<think>")[-1].split("</think>")[0]
|
||||
|
||||
print(total_content)
|
||||
|
||||
print(f"迭代次数:{iteration_count}, 输入token数:{input_token_count}, 思考token数:{count_token(thinking_content, tokenizer)}, 总token数:{count_token(total_content, tokenizer)}")
|
||||
```
|
||||
|
||||
首先是要定义一个函数 `run_thinking_budget_sample`,该函数接收模型、tokenizer、用户输入和思考预算作为参数。然后构建输入文本并计算输入的 token 数量。
|
||||
|
||||
因为`max_tokens` 参数是指生成的最大 token 数量,所以我们需要计算输入文本的 token 数量,并将其与思考预算相加,得到 `max_token = thinking_budget - think_token_count`。如果思考 token 数量超过了思考预算,就停止迭代。另外还需要在 `SamplingParams` 中设置 `stop` 参数为 `</think>`,这样模型在生成文本时会在遇到 `</think>` 时停止生成。
|
||||
|
||||
```python
|
||||
wait_sampling_params = SamplingParams(
|
||||
temperature=0.7,
|
||||
max_tokens=thinking_budget - think_token_count,
|
||||
stop='</think>',
|
||||
skip_special_tokens=False
|
||||
)
|
||||
```
|
||||
|
||||
另外还需要在每次迭代中,使用 `llm_model.generate` 方法生成文本,并计算思考 token 数量。如果思考 token 数量超过了思考预算,就停止迭代。否则,将生成的文本添加到输入文本中,并在文本末尾添加 `\nWait!\n`,以引导模型进行更深入的思考。
|
||||
|
||||
```python
|
||||
while True:
|
||||
wait_sampling_params = SamplingParams(
|
||||
temperature=0.7,
|
||||
max_tokens=thinking_budget - think_token_count,
|
||||
stop='</think>',
|
||||
skip_special_tokens=False
|
||||
)
|
||||
|
||||
outputs = llm_model.generate(
|
||||
input_text,
|
||||
wait_sampling_params
|
||||
)
|
||||
total_token, think_token_count = count_thinking_token(outputs, tokenizer)
|
||||
|
||||
print(f'第{iteration_count}次迭代,思考token数:{think_token_count}')
|
||||
|
||||
if think_token_count > thinking_budget:
|
||||
break
|
||||
input_text = total_token + "\nWait!\n"
|
||||
|
||||
# \nWait a moment. Was there any loophole in my thought just now?!\n
|
||||
# \nWait!\n
|
||||
|
||||
iteration_count += 1
|
||||
```
|
||||
|
||||
当达到思考预算后,使用 `llm_model.generate` 方法生成最终的输出文本,并将其打印出来。最后输出迭代次数、输入 token 数量、思考 token 数量和总 token 数量。
|
||||
|
||||
```python
|
||||
final_outputs = llm_model.generate(
|
||||
outputs[0].prompt + outputs[0].outputs[0].text + "\n</think>\n",
|
||||
sampling_params
|
||||
)
|
||||
|
||||
total_content = final_outputs[0].prompt + final_outputs[0].outputs[0].text
|
||||
thinking_content = total_content.split("<think>")[-1].split("</think>")[0]
|
||||
|
||||
print(total_content)
|
||||
|
||||
print(f"迭代次数:{iteration_count}, 输入token数:{input_token_count}, 思考token数:{count_token(thinking_content, tokenizer)}, 总token数:{count_token(total_content, tokenizer)}")
|
||||
```
|
||||
|
||||
此时我们还需要另外一个 `SamplingParams` 对象来设置最终生成文本的参数。`max_tokens` 参数设置为 4096,模型根据思考过程进行总结最后得出答案,这个过程也需要很多tokn,此时设置为多少都可以,通常设置为一个较大的值即可。
|
||||
|
||||
```python
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0.7,
|
||||
max_tokens=4096,
|
||||
skip_special_tokens=False
|
||||
)
|
||||
```
|
||||
|
||||
以上为核心代码实现,完整代码请参考 [*s1.py*](./s1.py)。在实际使用中,可以根据具体的任务和数据集调整思考预算和其他参数,以获得更好的性能。
|
||||
|
||||
## 结果分析
|
||||
|
||||
使用思考预算后,模型在推理过程中能够更深入地思考问题,从而提高了推理效率和准确性。但是也发现了一些有趣的现象。
|
||||
|
||||
例如,在某些情况下,就算插入了`Wait!`,模型并不会按照论文中所示进行多种不同方式尝试解答,或是反思之前的思考过程是否正确。而且还会出现模型在思考过程中重复生成相同的内容,导致思考 token 数量超过思考预算的情况。
|
||||
|
||||
<div align='center'>
|
||||
<img src="./images/image-4.png" alt="alt text" width="70%">
|
||||
</div>
|
||||
|
||||
当然,也有可能本身测试的模型只有14B参数,导致其在思考过程中的能力受到限制。
|
||||
|
||||
经过测试下来,有可能强行使用特定词(如:Wait!)来引导模型进行更深入的思考,可能会促使模型产生 “一条道走到黑” 的想法。
|
||||
|
||||
部分实验测试记录在 [*output*](./output/) 中。
|
||||
131
Extra-Chapter/s1-vllm-thinking-budget/s1.py
Normal file
@@ -0,0 +1,131 @@
|
||||
from vllm import LLM, SamplingParams
|
||||
from transformers import AutoTokenizer
|
||||
import time
|
||||
|
||||
def build_input(prompt, tokenizer):
|
||||
messages = [
|
||||
{"role": "system", "content": "Please reason step by step, and put your final answer within \\boxed{{}}."},
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
input_text = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=True
|
||||
)
|
||||
return input_text
|
||||
|
||||
def count_thinking_token(outputs, tokenizer):
|
||||
total_token = outputs[0].prompt + outputs[0].outputs[0].text
|
||||
thinking_token = total_token.split("<think>\n")[-1]
|
||||
thinking_token_id = tokenizer(thinking_token)["input_ids"]
|
||||
return total_token, len(thinking_token_id)
|
||||
|
||||
def count_token(string, tokenizer):
|
||||
return len(tokenizer(string)["input_ids"])
|
||||
|
||||
|
||||
def run_thinking_budget_sample(llm_model, tokenizer, user_input, thinking_budget):
|
||||
input_text = build_input(user_input, tokenizer)
|
||||
input_token_count = count_token(input_text, tokenizer)
|
||||
|
||||
iteration_count= 0
|
||||
max_token = input_token_count + thinking_budget
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0.7,
|
||||
max_tokens=4096,
|
||||
skip_special_tokens=False
|
||||
)
|
||||
|
||||
think_token_count = 0
|
||||
|
||||
while True:
|
||||
|
||||
wait_sampling_params = SamplingParams(
|
||||
temperature=0.7,
|
||||
max_tokens=thinking_budget - think_token_count,
|
||||
stop='</think>',
|
||||
skip_special_tokens=False
|
||||
)
|
||||
|
||||
outputs = llm_model.generate(
|
||||
input_text,
|
||||
wait_sampling_params
|
||||
)
|
||||
total_token, think_token_count = count_thinking_token(outputs, tokenizer)
|
||||
|
||||
print(f'第{iteration_count}次迭代,思考token数:{think_token_count}')
|
||||
|
||||
if think_token_count > thinking_budget:
|
||||
break
|
||||
input_text = total_token + "\nWait!\n"
|
||||
|
||||
# \nWait a moment. Was there any loophole in my thought just now?!\n
|
||||
# \nWait!\n
|
||||
|
||||
iteration_count += 1
|
||||
|
||||
final_outputs = llm_model.generate(
|
||||
outputs[0].prompt + outputs[0].outputs[0].text + "\n</think>\n",
|
||||
sampling_params
|
||||
)
|
||||
|
||||
total_content = final_outputs[0].prompt + final_outputs[0].outputs[0].text
|
||||
thinking_content = total_content.split("<think>")[-1].split("</think>")[0]
|
||||
|
||||
print(total_content)
|
||||
|
||||
print(f"迭代次数:{iteration_count}, 输入token数:{input_token_count}, 思考token数:{count_token(thinking_content, tokenizer)}, 总token数:{count_token(total_content, tokenizer)}")
|
||||
|
||||
# 保存输出到文件
|
||||
with open(f"output_{int(time.time())}.txt", "w") as f:
|
||||
f.write(total_content)
|
||||
f.write(f"\n迭代次数:{iteration_count}, 输入token数:{input_token_count}, 思考token数:{count_token(thinking_content, tokenizer)}, 总token数:{count_token(total_content, tokenizer)}")
|
||||
|
||||
|
||||
def run_sample(llm_model, tokenizer, user_input):
|
||||
input_text = build_input(user_input, tokenizer)
|
||||
input_token_count = count_token(input_text, tokenizer)
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0.7,
|
||||
max_tokens=32768,
|
||||
skip_special_tokens=False
|
||||
)
|
||||
|
||||
final_outputs = llm_model.generate(
|
||||
input_text,
|
||||
sampling_params
|
||||
)
|
||||
|
||||
total_content = final_outputs[0].prompt + final_outputs[0].outputs[0].text
|
||||
thinking_content = total_content.split("<think>")[-1].split("</think>")[0]
|
||||
print(total_content)
|
||||
|
||||
print(f"输入token数:{input_token_count}, 思考token数:{count_token(thinking_content, tokenizer)}, 总token数:{count_token(total_content, tokenizer)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
model_path = "/model/ModelScope/Qwen/Qwen3-14B"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
llm = LLM(
|
||||
model=model_path,
|
||||
gpu_memory_utilization=0.9,
|
||||
trust_remote_code=True
|
||||
)
|
||||
|
||||
print("=================================== 思考预算采样 ===================================")
|
||||
run_thinking_budget_sample(
|
||||
llm_model=llm,
|
||||
tokenizer=tokenizer,
|
||||
user_input="There are exactly three positive real numbers $ k $ such that the function\n$ f(x) = \\frac{(x - 18)(x - 72)(x - 98)(x - k)}{x} $\ndefined over the positive real numbers achieves its minimum value at exactly two positive real numbers $ x $. Find the sum of these three values of $ k $.",
|
||||
thinking_budget=32768
|
||||
)
|
||||
|
||||
# print("=================================== 无思考预算采样 ===================================")
|
||||
# run_sample(
|
||||
# llm_model=llm,
|
||||
# tokenizer=tokenizer,
|
||||
# user_input="There are exactly three positive real numbers $ k $ such that the function\n$ f(x) = \\frac{(x - 18)(x - 72)(x - 98)(x - k)}{x} $\ndefined over the positive real numbers achieves its minimum value at exactly two positive real numbers $ x $. Find the sum of these three values of $ k $."
|
||||
# )
|
||||
742
Extra-Chapter/vlm-concatenation-finetune/README.md
Normal file
@@ -0,0 +1,742 @@
|
||||
# Qwen3-"VL"——超小中文多模态模型的“拼接微调”之路1(附代码和SwanLab记录)
|
||||
|
||||
* 作者:情感机器实验室——陈少宏
|
||||
|
||||
* 邮箱:<shaohon_chen@115lab.club>
|
||||
|
||||
* GitHub:[https://github.com/ShaohonChen/Qwen3-SmVL](https://github.com/ShaohonChen/Qwen3-SmVL)
|
||||
* SwanLab:[https://swanlab.cn/@ShaohonChen/Qwen3-SmVL/overview](https://swanlab.cn/@ShaohonChen/Qwen3-SmVL/overview)
|
||||
* 数据集:[https://huggingface.co/datasets/HuggingFaceM4/the_cauldron](https://huggingface.co/datasets/HuggingFaceM4/the_cauldron)
|
||||
|
||||
> 💚 **特别感谢**
|
||||
> 感谢 [@zhihuazhao-bit](https://github.com/zhihuazhao-bit) 帮我审阅和修复了提交代码中众多的小 bug,并在 NV 上完成了测试。
|
||||
> 感谢 [@KMnO4-zx](https://github.com/KMnO4-zx) 对教程文章内容的审核与修正。
|
||||
|
||||
## 摘要
|
||||
|
||||
最近Huggingface团队发布了超小多模态模型SmolVLM2,可以做到端侧1GB显存推理。在怀着惊喜试用后发现,虽然模型有极其强大的视觉文本理解能力,但是模型却无法理解中文。这对一个“四六级压线过”的笔者来说十分不友好。刚好前段时间做SwanLab硬件检测适配时有一台未到期的沐曦曦云C500服务器,因此萌生了使用**沐曦GPU芯片**微调、把当前中文小模型扛把子Qwen3与SmolVLM2直接微调拼接的想法。
|
||||
|
||||
本教程将介绍一种模型拼接的思路,将SmolVLM2的视觉模块(0.09B)与Qwen3最小的模型(0.6B)进行对齐微调,最终使得Qwen模型具备一定的视觉理解能力。由于笔者时间有限且考虑到文章篇幅的原因,因此该系列预计将以系列的方式放出。篇幅规划如下:
|
||||
|
||||
* **第一篇**:如何构建和微调一个拼接模型(**本篇博客**)
|
||||
* **第二篇**:模型测评、数据集优化、回答人类对齐
|
||||
* **第三篇**:微调技巧介绍、视觉位置编码改动与模型结构优化
|
||||
|
||||
<div align="center">
|
||||
<figure>
|
||||
<img src="./images/PPAP.png" alt="PPAP" width="400" />
|
||||
<figcaption>I have a Qwen, I have a SmolVLM...</figcaption>
|
||||
</figure>
|
||||
</div>
|
||||
|
||||
<div style="background-color:#fff3cd; color:black; padding:10px; border-radius:4px; border:1px solid #fbe5b0; width: 90%; max-width: 100%; margin: auto;">
|
||||
⚠️关于算力的注意:本教程涉及VLM微调训练,对算力要求较高,需要40G及以上的GPU显存才能运行本教程的训练代码。
|
||||
</div>
|
||||
|
||||
## 目录
|
||||
|
||||
* [SmolVLM2的背景知识](#SmolVLM2的背景知识)
|
||||
* [模型拼接和微调思路简介](#模型拼接和微调思路简介)
|
||||
* [模型拼接实现和关键代码讲解](#模型拼接实现和关键代码讲解)
|
||||
* [微调数据集构建](#微调数据集构建)
|
||||
* [微调方法与代码实现](#微调方法与代码实现)
|
||||
* [微调训练&结果展示](#微调训练&结果展示)
|
||||
* [代码及数据集链接汇总](#代码及数据集链接汇总)
|
||||
|
||||
## SmolVLM2的背景知识
|
||||
|
||||
首先,我们先回顾一下SmolVLM2模型的构建方案,SmolVLM2模型的整体包括三大块:视觉模型层,特征映射层和大语言模型层,见下图:
|
||||
|
||||
<div align="center">
|
||||
<figure>
|
||||
<img src="./images/smolvlm2.png" alt="smolvlm2" width="400" />
|
||||
<figcaption>SmolVLM2的架构图</figcaption>
|
||||
</figure>
|
||||
</div>
|
||||
|
||||
这个设计是现在比较常见的VLM方案。核心设计思想就是让视觉模型的输出特征与经过embedding的文本特征直接拼接后输入到语言模型(LLM)当中,没有交叉注意力等模块。相比于早期LLaVA等架构,这种最大的优点就是可以最大程度复用已有的语言模型。以Qwen2.5-VL为例,其3B、7B、72B模型大小指的只是LLM部分,并没有包含Vision模块,实际上3B模型的参数量接近4B,视觉模块大概0.4B左右,三个不同大小的VLM使用的是统一的视觉模型。对于一些较大的VLM来说,构建视觉模型时绝大多数的训练都集中在特征映射模块和视觉模块,只在最后阶段为了最终效果进行整体微调时才会调整语言模块。保证了VLM的语言能力。
|
||||
|
||||
下面简述一下各个模块的细节:
|
||||
|
||||
* 视觉模型层:SmolVLM2-256M版本用的是Google的SigLip模型,一个基于ViT的视觉模型,选用的是最小的SigLip-93M的版本,HF论文里没具体写是直接用的SigLip的参数还是他们从零构建的(有注意到的读者可以评论留言下)。在SmolVLM2代码中对应的是`SmolVLMVisionTransformer`类
|
||||
|
||||
* 特征映射层:就是一个简单的MLP,不过SmolVLM2中为了降低图像分辨率还做了一个Pixel shuffle来降低图像分辨率,进一步减少视觉的Token占用,减少了文本长度。HF团队在论文里提到对于参数量较小的VLM来说使用Pixel shuffle还能提升性能。但可训练参数其实就是一个单层的神经网络,这个模块的核心作用就是做特征对齐,将视觉特征从768维(SigLip的维度)映射到576维(SmolLLM2的维度)
|
||||
|
||||
* 大语言模型:SmolVLM2-256M模型使用的文本模型是SmolLM-135M版本。可能是由于模型较小,HF团队在论文中说到训练时仅采用两阶段训练:大规模图文训练+针对视频任务的专门微调。为了保障模型的文本能力HF团队在训练数据中参杂了大概14%的纯文本微调数据。不过考虑到视觉模块本身参数量(93M)大小接近于文本模型(135M),因此笔者推测相比于冻结文本模型,数据平衡在这之中会起到更关键的作用。
|
||||
|
||||
HF团队在原文中还提到了许多影像小模型VLM性能的trick,感兴趣的读者可以进一步参考SmolVLM2的论文
|
||||
|
||||
## 模型拼接和微调思路简介
|
||||
|
||||
正所谓顶级食材(模型)只需要最简单的烹饪。模型拼接的思路非常简单直接,基本就三步:
|
||||
|
||||
1. 调整SmolVLM2的“上下文控制格式”,使得其与Qwen3兼容。
|
||||
|
||||
2. 将模型的文本部分直接从SmolLM2换成Qwen3-0.6B,包括其文本tokenizer和词嵌入、文本模型、以及模型最后输出的语言模型头(LM Head)。
|
||||
|
||||
3. 需要重新初始化特征映射层的MLP,从768->576的单层神经网络改成768->1024的单层神经网络即可。
|
||||
|
||||
整体架构和对图文对前后处理依旧保持SmolVLM2的流程不变,具体改动见下图:
|
||||
|
||||
<div align="center">
|
||||
<figure>
|
||||
<img src="./images/concatation.png" alt="concatation" width="400" />
|
||||
<figcaption>将Qwen3-0.6B替换SmolVLM2的语言模型部分</figcaption>
|
||||
</figure>
|
||||
</div>
|
||||
|
||||
笔者接下来详细介绍下为了实现“拼接”,具体改动的地方,供之后有类似的任务的读者参考。
|
||||
|
||||
## 模型拼接实现和关键代码讲解
|
||||
|
||||
### 第一处改动:SmolVLM2的Tokenizers部分
|
||||
|
||||
首先需要改动的就是需要改动的是SmolVLM2的Tokenizers部分,这里面主要是涉及两个问题:
|
||||
|
||||
* 第一个问题是要将SmolVLM2用于指示图像位置的特殊令牌(Special Token)加入到Qwen3的Tokenizer当中,这么做的目的是防止SmolVLM2的图像Token`<image>`被切分为`<`、`image`、`>`三块。幸运的是,Qwen3本身在Tokenizers中预留了未来用于多模态的特殊特殊令牌`<|image_pad|>`。因此读者直接使用了`<|image_pad|>`代替了`<image>`。用于在文本中预留图像特征的插入点。
|
||||
|
||||
* 第二个问题是:SmolVLM2的chat_template和Qwen3的chat_template差别极大。chat_template的作用是通过格式化文本让模型清楚知道不同Token所代表的背景信息。用最近比较流行的话来说就是“上下文工程”(Context Engineering)。
|
||||
|
||||
这里我列举了一下Qwen3、SmolVLM2、Qwen2.5-VL在聊天场景下的上下文,供读者参考。
|
||||
|
||||
**Qwen3聊天上下文格式**
|
||||
|
||||
以给一张图片,问题是“你的名字是什么?”,模型回答是“我的名字是Qwen”为例子。模型的上下文如下:
|
||||
|
||||
```txt
|
||||
<|im_start|>user
|
||||
你的名字是什么?<|im_end|>
|
||||
<|im_start|>assistant
|
||||
<think>
|
||||
|
||||
</think>
|
||||
|
||||
我的名字是Qwen<|im_end|>
|
||||
|
||||
```
|
||||
|
||||
注意Qwen3上下文是没有预留图像位置的,但相比于一般的LLM和VLM多了一个用于插入模型思考过程的`<think><\think>`,以及包含额外的函数调用控制文本。为了便于读者理解,读者在在下面举了一个函数调用的例子。这些函数调用上下文用于控制模型调用外部函数、API或者MCP接口和接收其返回的信息。
|
||||
|
||||
考虑到篇幅限制,本文就不粘贴带函数调用、推理、思考等一系列上下文的信息了(笔者打印了下发现实在太长了)。感兴趣的读者可以在Qwen3的官方文处了解详细设计
|
||||
|
||||
* [Qwen3函数调用案例](https://qwen.readthedocs.io/zh-cn/latest/framework/function_call.html#the-example-case)
|
||||
|
||||
可以说正是这些复杂的上下文信息让模型有可能实现推理、调用函数等多样化的能力。包括多模态理解任务也需要先对上下文进行设计。
|
||||
|
||||
**SmdwadwdoVLM2聊天上下文格式:**
|
||||
|
||||
以给一张图片,问题是“How many dog in there.”,模型回答是“There are Three dogs.”为例子。三种不同模型的上下文如下:
|
||||
|
||||
```txt
|
||||
<|im_start|>User:<fake_token_around_image><row_1_col_1><image>...<image><fake_token_around_image><row_1_col_2><image>...<image><fake_token_around_image><row_1_col_3><image>...<image>...<fake_token_around_image><row_4_col_4><image>...<image>
|
||||
|
||||
<fake_token_around_image><global-img><image>...<image><fake_token_around_image>How many dog in there.<end_of_utterance>
|
||||
Assistant: There are Three dogs.<end_of_utterance>
|
||||
Assistant:
|
||||
```
|
||||
|
||||
看起来非常乱,是因为有大量的`<image>`占位符。`<image>...<image>`之间是许多的`<image>`,笔者为了文章观感删掉了大量的占位符。注意模型的回车、空格均为上下文的一部分,在进行推理时需要严格遵守缩进关系。
|
||||
|
||||
但是我们仍能找到熟悉的内容,如`User:`,`Assistant:`等用于提示模型用户的输入与模型应当输出的位置。这些关键词和Qwen类似。
|
||||
|
||||
读者注意到了除了`<fake_token_around_image>`,`<image>`等用于指示图像的词,还出现了<row_1_col_1>这种位置指示符,这是因为SmolVLM2为了防止降采样对图像分辨率影响,专门使用了`image splitting`技术,简单来说就是将全局图和高清的局部图共同输入到模型当中(见下图`image splitting`模块),感兴趣的读者可在文末找到HF的技术报告了解详细技术。
|
||||
|
||||
<div align="center">
|
||||
<figure>
|
||||
<img src="./images/image-split.png" alt="image-split" width="400" />
|
||||
<figcaption>SmolVLM2的完整推理流程,可以看到在图像输入前使用`image splitting`进行了预切分</figcaption>
|
||||
</figure>
|
||||
</div>
|
||||
|
||||
**本博文的拼接模型Qwen3-SmVL模型**
|
||||
|
||||
相比于Qwen3,SmolVLM2少了很多上下控制的
|
||||
|
||||
为了尽可能保存或者说预留Qwen3的思考、函数调用等能力,笔者最终选择将SmolVLM2对于图像特征的排列插入到Qwen3的上下文格式当中。最终上下文格式如下:
|
||||
|
||||
```txt
|
||||
<|im_start|>user
|
||||
<vision_start><row_1_col_1><|image_pad|>(图像插入的地方)<|image_pad|><vision_start>
|
||||
(用户提问的地方)
|
||||
<|im_end|>
|
||||
<|im_start|>assistant
|
||||
<think>
|
||||
|
||||
</think>
|
||||
|
||||
(模型回答的地方)<|im_end|>
|
||||
<|endoftext|>
|
||||
```
|
||||
|
||||
可以看到读者尽量保持了与Qwen3的风格和复用特殊令牌。这样能够使得后续拼接的Qwen3-0.6B模型不至于受到上下文差异过大带来的性能损耗。实际上在设计微调上下文时应尽量与模型先前训练的任务接近,以减少微调带来的性能损失。
|
||||
|
||||
transformers实现模型上下文格式控制的代码并非python语言,而是一种前端文本格式控制的语言Jinja。这个语言的变量作用域设计简直可以说是有魔法在里面。配合上Qwen3功能丰富且复杂的上下文策略,让笔者花了2个小时用于修改chat_teamplate。这里笔者不赘述如何修改chat_template,感兴趣的读者可以去文末代码链接寻找`chat_template.jinja`文件,笔者专门将chat_template模版拿出来,并且做了格式化方便读者阅读。未来有时间了笔者专门写一篇模型上下文控制与jinja语言的博客。
|
||||
|
||||
### 第二处改动:替换SmolVLM2的SmolLM2模型为Qwen3-0.6B
|
||||
|
||||
替换模型这块没什么复杂的,主要是需要处理Transformers比较复杂的嵌套逻辑。Tranformers通常建议模型将预训练模型backbone和下游任务分开来。改动逻辑图如下:
|
||||
|
||||
<div align="center">
|
||||
<figure>
|
||||
<img src="./images/change_model.png" alt="change_model" width="400" />
|
||||
<figcaption>替换smolvlm2的文本模块和语言模型头</figcaption>
|
||||
</figure>
|
||||
</div>
|
||||
|
||||
以Qwen3为例,预训练Backbone模型为`Qwen3Model`,仅仅包含embedding层、各个Decoder层,最后输出的是所有输入token的hidden state。负责下游任务的Qwen3提供了包括:用于因果语言序列生成的`Qwen3ForCausalLM`,也就是大家常用的语言生成。负责句子分类`Qwen3ForSequenceClassification`,使用最后一个生成的token输入到一个单层MLP做序列级分类,做句子情绪分类等可以用这个下游模型;`Qwen3ForTokenClassification`用于做Token级分类,比如语言实体抽取任务可以使用这个下游模型。`Qwen3ForQuestionAnswering`则是专门做抽取式问答任务的模型,核心思想是输入(问题,参考文本)让模型从参考文本中找到与问题最相关的一段,这类任务由于RAG系统的出现没那么流行了,未来笔者专门出一个系列的教程阐述除了因果语言序列生成以外的任务则怎么微调。
|
||||
|
||||
**关键代码如下**
|
||||
|
||||
```python
|
||||
from transformers import (
|
||||
AutoProcessor,
|
||||
AutoModelForImageTextToText,
|
||||
AutoTokenizer,
|
||||
AutoModelForCausalLM
|
||||
)
|
||||
|
||||
# 替换text模型和head
|
||||
smolvlm2_02B_model = AutoModelForImageTextToText.from_pretrained(
|
||||
"model/SmolVLM2-256M-Video-Instruct",
|
||||
torch_dtype=torch.bfloat16,
|
||||
_attn_implementation="eager",
|
||||
).to(device)
|
||||
|
||||
qwen3_06b_model = AutoModelForCausalLM.from_pretrained(
|
||||
"model/Qwen3-0.6B", torch_dtype=torch.bfloat16
|
||||
).to(device)
|
||||
|
||||
smolvlm2_02B_model.model.text_model = qwen3_06b_model.model
|
||||
smolvlm2_02B_model.lm_head = qwen3_06b_model.lm_head
|
||||
...
|
||||
```
|
||||
|
||||
接下来比较复杂的是替换所有的关键变量,比如模型内用于在文本序列中为图像特征预留的占位符`image_token_id`,用于指示停止生成的`eos_token_id`,和计算loss值会用到的`vocab_size`,Qwen的词表大小为151936,远远大过SmolVLM2的词表49280。具体代码如下:
|
||||
|
||||
```python
|
||||
...
|
||||
# 替换词表大小
|
||||
smolvlm2_02B_model.vocab_size = qwen3_06b_model.vocab_size
|
||||
smolvlm2_02B_model.model.vocab_size = qwen3_06b_model.vocab_size
|
||||
smolvlm2_02B_model.config.vocab_size = qwen3_06b_model.vocab_size
|
||||
smolvlm2_02B_model.config.text_config.vocab_size = qwen3_06b_model.vocab_size
|
||||
smolvlm2_02B_model.model.config.vocab_siz = qwen3_06b_model.vocab_size
|
||||
smolvlm2_02B_model.model.config.text_config.vocab_size = qwen3_06b_model.vocab_size
|
||||
# 替换图像token
|
||||
smolvlm2_02B_model.image_token_id = 151655
|
||||
smolvlm2_02B_model.model.image_token_id = 151655
|
||||
smolvlm2_02B_model.config.image_token_id = 151655
|
||||
smolvlm2_02B_model.model.config.image_token_id = 151655
|
||||
# 替换模型生成停止符
|
||||
smolvlm2_02B_model.generation_config.eos_token_id = 151645
|
||||
···
|
||||
```
|
||||
|
||||
上面的代码可以看到在替换各个变量时需要将嵌套模型的变量一起替换掉,笔者之前训练时就因为仅仅替换了`SmolVLMForConditionalGeneration`而忘记替换`SmolVLMModel`中的`image_token_id`,导致语言模型接收不到图像特征,最后表现出来就是loss下降的极快且低,grad_norm看起来也学到位了,一推理效果特别差,附上错误训练的损失图:
|
||||
|
||||
<div align="center">
|
||||
<figure>
|
||||
<img src="./images/fail_train.png" alt="fail_train" width="800" />
|
||||
<figcaption>SwanLab记录训练结果展示:蓝色为错误训练的完整微调loss图,可以看到损失下降很快,然而实际推理会发现模型并没有图像理解能力。冻结语言模型头(红色)后发现grad_norm为零且loss不收敛,正确的应该是黄色</figcaption>
|
||||
</figure>
|
||||
</div>
|
||||
|
||||
笔者最早没发现改动错误,先做完整微调(蓝色曲线)后发现损失下降很快达到了0.1以下,结果实际一推理发现模型完全没有图像理解能力,就补了一个冻结语言模型只微调视觉模型的实验(红色曲线),结果发现损失完全没下降,才定位到了视觉特征传入有问题。后续修复后正确的损失下降过程见黄色图像。
|
||||
|
||||
### 第三处改动:构建和替换特征映射层
|
||||
|
||||
这个相对较简单,只需要重新构建一个维度对齐的`SmolVLMConnector`即可。Qwen3的hidden_dim是1024,SigLip的hidden_dim是768,因此构建一个768➡️1024映射的`SmolVLMConnector`即可。代码如下:
|
||||
|
||||
```python
|
||||
···
|
||||
# 构建配置并且创建连接器
|
||||
@dataclass
|
||||
class VisionConfig:
|
||||
hidden_size: int = 768
|
||||
|
||||
@dataclass
|
||||
class TextConfig:
|
||||
hidden_size: int = 1024
|
||||
|
||||
@dataclass
|
||||
class ConnectConfig:
|
||||
scale_factor: int = 4
|
||||
vision_config: VisionConfig = VisionConfig()
|
||||
text_config: TextConfig = TextConfig()
|
||||
|
||||
new_connector_config = ConnectConfig()
|
||||
|
||||
# 替换 SigLit 到 LLM 的 connector 层
|
||||
new_connector = SmolVLMConnector(new_connector_config).to(device).to(torch.bfloat16)
|
||||
smolvlm2_02B_model.model.connector = new_connector
|
||||
···
|
||||
```
|
||||
|
||||
## 微调数据集构建
|
||||
|
||||
笔者最初计划寻找中文多模态数据集,但发现相关的资料比较少。因此决定先用英文的多模态数据集凑合一下。之后再考虑通过数据合成的方式将部分数据翻译为中文。关于数据合成和配比的问题将在之后的博客讨论。
|
||||
|
||||
<div align="center">
|
||||
<figure>
|
||||
<img src="./images/the_cauldron.png" alt="the_cauldron" width="400" />
|
||||
<figcaption>the_cauldron数据集logo</figcaption>
|
||||
</figure>
|
||||
</div>
|
||||
|
||||
这里为了方便本项目直接使用HuggingFace团队整合的多模态数据集the Cauldron数据集,Cauldron翻译成中文类似于煮东西的“釜”,不知道HF团队是不是玩“炼丹”的梗。这个数据集整合了50个视觉微调任务数据集的训练集,用于微调Huggingface发布的多模态模型Idefics2模型。这50多个数据集都被处理成了一致的格式(见下图),共有1,880,992条数据,完整下载约169G,非常方便使用。
|
||||
|
||||
<div align="center">
|
||||
<figure>
|
||||
<img src="./images/data_show.png" alt="data_show" width="800" />
|
||||
<figcaption>数据集样本展示</figcaption>
|
||||
</figure>
|
||||
</div>
|
||||
|
||||
不过可惜数据集的文本都是英文内容,且绝大多数数据集的回复非常短,只有一个词,这也给后面模型训练带来了麻烦。本篇博客暂时不讨论关于数据构建和配比的问题,后续有时间了专门做相关的实验。本博客先以为Qwen3模型带来视觉能力为核心目标。
|
||||
|
||||
数据集的下载链接如下,国内推荐用modelscope下载:
|
||||
|
||||
* [HuggingFace Hub](https://huggingface.co/datasets/HuggingFaceM4/the_cauldron)
|
||||
* [ModelScope](https://modelscope.cn/datasets/AI-ModelScope/the_cauldron)
|
||||
|
||||
笔者在实际测试时发现"mimic_cgd","localized_narratives","okvqa","ocrvqa","clevr_math"这几个子数据集加载有点异常,建议使用此数据集训练的读者手动处理下,社区也有用户反馈这几个数据可以在原始来源处额外下载,未来笔者将会补全这几个数据集重新上传一次完整版的the Cauldron数据集。
|
||||
|
||||
## 微调方法与代码实现
|
||||
|
||||
### 冻结模型参数微调
|
||||
|
||||
整体微调方法采用了CLM模型通常的Teacher Forcing的学习方法,损失就是标准的交叉熵损失。考虑到此次本教程的目标是先确保模型具备中文多模态能力(优化模型性能等之后撰写其他博客),因此为了实验效率,在对齐微调阶段**采用冻结视觉模型与文本模型,仅微调特征映射器和语言模型头**的方法。
|
||||
|
||||
冻结模型参数的核心代码如下:
|
||||
|
||||
```python
|
||||
def freeze_model(qwen_smvl):
|
||||
for _, param in qwen_smvl.model.text_model.named_parameters():
|
||||
param.requires_grad = False
|
||||
for _, param in qwen_smvl.model.vision_model.named_parameters():
|
||||
param.requires_grad = False
|
||||
return qwen_smvl
|
||||
```
|
||||
|
||||
冻结后训练参数、模型总参数、与占比如下:
|
||||
|
||||
```txt
|
||||
trainable params: 12.00M || all params: 662.87M || trainable%: 1.81
|
||||
```
|
||||
|
||||
### 文本长度,损失掩码和截断策略
|
||||
|
||||
**文本长度**
|
||||
|
||||
由于视觉特征需要占据大量的文本长度,笔者简单测试了下the_cauldron图像占0.8K到1.3K左右的token。而数据集中大多数文本token数在200-500左右,极少情况会有3-4K的情况。因此笔者统一采用2K的文本长度,超出部分截断处理。
|
||||
|
||||
这里有一个不同于文本微调的细节要注意,文本截断长度不能小于图像token,否则会导致模型在进行特征拼接时报错(当然图像特征如果被截断了,这条训练数据也就没意义了)。因此对于显存不足64G的同学如果需要适当缩短文本长度(不建议低于1.5K),最好连同图像分辨率也缩小些。在后面的博客我们会专门增加对减少图片token占用的研究。
|
||||
|
||||
同样由于文本长度受限,且图像特征没法截断,我们也没使用“packing dataset”的方法提升模型的训练效率。
|
||||
|
||||
考虑到部分数据集存在多张图片的情况,考虑到本次训练仅采用2k的文本长度(与之对比HF在训练SmolVLM-256M版本采用的是8K的文本长度,2.2B版使用了16K的文本长度)。针对单条数据中存在多张图片的情况仅仅选用第一张。
|
||||
|
||||
**损失掩码**
|
||||
|
||||
在采用Teacher Forcing的学习方法时,文本微调中损失掩码有两种策略:
|
||||
|
||||
* 对包含“用户问题”和“模型回复”的完整文本进行微调优化
|
||||
* 仅对“模型回复”部分进行微调优化
|
||||
|
||||
这两种策略的对比如下图:
|
||||
|
||||
<div align="center">
|
||||
<figure>
|
||||
<img src="./images/mask.png" alt="mask" width="800" />
|
||||
<figcaption>两种微调掩码策略的差异,通常建议选择“仅微调模型回答部分”以增强泛化性</figcaption>
|
||||
</figure>
|
||||
</div>
|
||||
|
||||
通常来说使用“仅微调模型回复部分”的策略模型更容易泛化(这点与HF在SmolVLM2的论文提到的trick)。然而笔者为了提高训练效率选择了完整文本微调。可以在后续博客中增加消融实验做进一步对比。
|
||||
|
||||
值得注意的是,在进行完整文本微调时,需要单独屏蔽Image Token以防止对图像占位token计算损失,影响模型表现。
|
||||
|
||||
**关键代码如下:**
|
||||
|
||||
```python
|
||||
def data_collate_fix2k(examples, processor, device, max_length=2048):
|
||||
batch_text = []
|
||||
batch_image = []
|
||||
for example in examples:
|
||||
images = example["images"][:1] # 只允许一张图,不然显存压力太大
|
||||
batch_image.append(images)
|
||||
image_num = len(images)
|
||||
chat_texts = example["texts"][0]
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "image"}] * image_num
|
||||
+ [{"type": "text", "text": chat_texts["user"]}],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": chat_texts["assistant"]}],
|
||||
},
|
||||
]
|
||||
text = processor.apply_chat_template(
|
||||
messages, enable_thinking=False, add_generation_prompt=False
|
||||
)
|
||||
|
||||
batch_text.append(text)
|
||||
|
||||
batch = processor(
|
||||
text=batch_text,
|
||||
images=batch_image,
|
||||
max_length=max_length,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
)
|
||||
labels = batch["input_ids"].clone()
|
||||
labels[labels == processor.tokenizer.pad_token_id] = -100
|
||||
labels[labels == processor.image_token_id] = -100
|
||||
batch["labels"] = labels
|
||||
return batch.to(device, dtype=torch.bfloat16)
|
||||
```
|
||||
|
||||
### 微调超参数设置
|
||||
|
||||
**学习率**
|
||||
|
||||
由于仅仅针对特征映射层(connector)进行训练,且conntector由于要对齐Qwen3的维度因此参数为随机初始化(理论上可以采用一些独特的初始化策略提升性能,但考虑到模型较小因此笔者没关注初始化策略)。因此学习率设置为lora中较为流行的1e-4学习率策略。
|
||||
|
||||
为了保障有效收敛,学习率衰减基本是必备的trick,采用的是社区比较流行的cosine学习率衰减,衰减至0。warm up为整体步长的10%(在超过1000k step的情况下固定为50)。
|
||||
|
||||
**batch size**
|
||||
|
||||
Batch size通常来说越大越好,然而由于VLM模型的文本长度太大,因此采用每卡1 batch和4梯度累加(grad accelerate),在8卡训练中等效32 Batch size。
|
||||
|
||||
**训练参数设置代码**
|
||||
|
||||
```python
|
||||
training_args = TrainingArguments(
|
||||
seed=42,
|
||||
data_seed=42,
|
||||
max_steps=200,
|
||||
# num_train_epochs=1, # 训练1个epoch 约1k steps
|
||||
per_device_train_batch_size=1,
|
||||
gradient_accumulation_steps=4,
|
||||
dataloader_pin_memory=False,
|
||||
warmup_ratio=0.1,
|
||||
learning_rate=1e-4,
|
||||
lr_scheduler_type="cosine",
|
||||
weight_decay=0.01,
|
||||
logging_steps=5,
|
||||
eval_strategy="steps",
|
||||
eval_steps=0.125,
|
||||
save_strategy="steps",
|
||||
save_steps=0.125,
|
||||
save_total_limit=8,
|
||||
optim="adamw_torch",
|
||||
bf16=True,
|
||||
output_dir=f"./model/freeze_except_connector_cocovqa",
|
||||
overwrite_output_dir=False,
|
||||
report_to="swanlab",
|
||||
run_name="freeze_except_connector_cocovqa",
|
||||
remove_unused_columns=False,
|
||||
gradient_checkpointing=False,
|
||||
)
|
||||
```
|
||||
|
||||
### 训练环境
|
||||
|
||||
微调代码基于沐曦的C500国产通用计算GPU实现,显存为64G。沐曦的AI芯片基本完全兼容pytorch和huggingface transformers场景,并且在做多模态训练时相比较其他国产AI芯片罕见的没有兼容性问题。读者在尝试本项目代码时可以采用Nvidia显存40G以上的显卡运行本教程。
|
||||
|
||||
**笔者个人感觉沐曦的GPU整体适配效果还是非常好的,没遇到适配性的问题。体验上和用NV的GPU做训练没什么区别**。笔者自己也用过好几款国产GPU,沐曦的体验肯定是名列前茅的,包括代码中有指定flash attention在沐曦GPU上都能成功迁移,这点非常值得给沐曦团队点个赞。希望国产GPU生态能越发展越好,造福广大炼丹师;)。
|
||||
|
||||
<div align="center">
|
||||
<figure>
|
||||
<img src="./images/muxi-gpu.jpg" alt="muxi-gpu" width="400" />
|
||||
<figcaption>沐曦国产GPU,笔者用的云端服务器没见过真机,因此找了张网图</figcaption>
|
||||
</figure>
|
||||
</div>
|
||||
|
||||
训练环境的话除了安装GPU对应的驱动和pytorch外,本教程需要额外安装Huggingface全家桶,如下:
|
||||
|
||||
```txt
|
||||
torch # 推荐版本>=6.0
|
||||
torchvision
|
||||
transformers>=4.53.0
|
||||
accelerate
|
||||
datasets
|
||||
num2words # SmolVLM2需要
|
||||
```
|
||||
|
||||
额外补充一句,如果采用沐曦GPU训练的话,需要在沐曦官方文档处寻找[沐曦版torch](https://developer.metax-tech.com/softnova/index)的安装方式进行下载。其他HF环境和NV基本一样。附赠一个沐曦查看GPU的命令:
|
||||
|
||||
```bash
|
||||
mx-smi
|
||||
```
|
||||
|
||||
效果如下:
|
||||
|
||||
```bash
|
||||
=================== MetaX System Management Interface Log ===================
|
||||
Timestamp : Sat Jul 12 14:58:51 2025
|
||||
|
||||
Attached GPUs : 8
|
||||
+---------------------------------------------------------------------------------+
|
||||
| MX-SMI 2.1.12 Kernel Mode Driver Version: 2.12.13 |
|
||||
| MACA Version: 2.29.0.19 BIOS Version: 1.22.3.0 |
|
||||
|------------------------------------+---------------------+----------------------+
|
||||
| GPU NAME | Bus-id | GPU-Util |
|
||||
| Temp Pwr:Usage/Cap | Memory-Usage | |
|
||||
|====================================+=====================+======================|
|
||||
| 0 MetaX C500 | 0000:0e:00.0 | 0% |
|
||||
| 36C 69W / 350W | 5680/65536 MiB | |
|
||||
+------------------------------------+---------------------+----------------------+
|
||||
| 1 MetaX C500 | 0000:0f:00.0 | 0% |
|
||||
| 38C 70W / 350W | 4986/65536 MiB | |
|
||||
+------------------------------------+---------------------+----------------------+
|
||||
| 2 MetaX C500 | 0000:10:00.0 | 0% |
|
||||
| 37C 69W / 350W | 4986/65536 MiB | |
|
||||
+------------------------------------+---------------------+----------------------+
|
||||
| 3 MetaX C500 | 0000:12:00.0 | 1% |
|
||||
| 37C 71W / 350W | 4986/65536 MiB | |
|
||||
+------------------------------------+---------------------+----------------------+
|
||||
| 4 MetaX C500 | 0000:35:00.0 | 0% |
|
||||
| 37C 70W / 350W | 4986/65536 MiB | |
|
||||
+------------------------------------+---------------------+----------------------+
|
||||
| 5 MetaX C500 | 0000:36:00.0 | 1% |
|
||||
| 36C 68W / 350W | 4986/65536 MiB | |
|
||||
+------------------------------------+---------------------+----------------------+
|
||||
| 6 MetaX C500 | 0000:37:00.0 | 0% |
|
||||
| 39C 73W / 350W | 4986/65536 MiB | |
|
||||
+------------------------------------+---------------------+----------------------+
|
||||
| 7 MetaX C500 | 0000:38:00.0 | 0% |
|
||||
| 38C 71W / 350W | 4986/65536 MiB | |
|
||||
+------------------------------------+---------------------+----------------------+
|
||||
|
||||
+---------------------------------------------------------------------------------+
|
||||
| Process: |
|
||||
| GPU PID Process Name GPU Memory |
|
||||
| Usage(MiB) |
|
||||
|=================================================================================|
|
||||
| 0 3496691 python3.10 4066 |
|
||||
| 0 3496692 python3.10 102 |
|
||||
| 0 3496693 python3.10 102 |
|
||||
| 0 3496694 python3.10 102 |
|
||||
| 0 3496695 python3.10 102 |
|
||||
| 0 3496696 python3.10 102 |
|
||||
| 0 3496697 python3.10 102 |
|
||||
| 0 3496698 python3.10 170 |
|
||||
| 1 3496692 python3.10 4154 |
|
||||
| 2 3496693 python3.10 4154 |
|
||||
| 3 3496694 python3.10 4154 |
|
||||
| 4 3496695 python3.10 4154 |
|
||||
| 5 3496696 python3.10 4154 |
|
||||
| 6 3496697 python3.10 4154 |
|
||||
| 7 3496698 python3.10 4154 |
|
||||
+---------------------------------------------------------------------------------+
|
||||
```
|
||||
|
||||
### 训练代码实现
|
||||
|
||||
在构建训练代码时,笔者使用HuggingFace Transfomers框架的Trainer类来完成训练代码。Trainer类实现的训练逻辑基本能完成大部分微调任务。这里唯一需要提到的是笔者使用了Qwen3-0.6B而非通常此类任务该使用的Qwen3-0.6B-Base模型,Qwen3-0.6B相比于Qwen3-0.6B-Base模型经过了指令遵从微调、对齐等,能实现聊天问答功能。
|
||||
|
||||
通常来说对经过微调的模型进行持续训练会一定程度带来性能损失,然而此次微调时笔者冻结了LLM参数,因此需要选用经过微调的模型来实现多模态问答能力。
|
||||
|
||||
笔者在训练过程中使用的是bfloat16精度,相比于float16来说bfloat16增加了尾数位数,训练过程中精度会更高些。
|
||||
|
||||
在前期进行方案验证阶段笔者采用的是cocoqa数据集,并且进行200steps的微调训练。在确定方案可行后笔者计划使用完整数据集进行微调训练,然而考虑到训练数据量仅仅只有整个模型的12M,因此笔者按参数量与训练Token的比值为1:10采样数据集,即总共从数据集中采样出60K条数据用于实际训练(文本长度按照2k计算,实际上有padding部分因此实际参与token数小于120M)。笔者认为参与训练的数量是足以令模型收敛的,后续实验也证明了模型确实能达到我们所期望的效果。
|
||||
|
||||
**训练关键代码实现**
|
||||
|
||||
代码比较长是因为增加了断点续训的能力
|
||||
|
||||
```python
|
||||
################
|
||||
# 开启训练
|
||||
################
|
||||
last_checkpoint = None # load last checkpoint if available
|
||||
if (
|
||||
os.path.isdir(training_args.output_dir)
|
||||
and not training_args.overwrite_output_dir
|
||||
):
|
||||
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
||||
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
|
||||
raise ValueError(
|
||||
f"Output directory ({training_args.output_dir}) already exists"
|
||||
)
|
||||
print(
|
||||
f"Checkpoint detected, resuming training at {last_checkpoint}."
|
||||
)
|
||||
# Init Trainer
|
||||
trainer = Trainer(
|
||||
model=qwen_smvl,
|
||||
args=training_args,
|
||||
train_dataset=raw_data["train"],
|
||||
eval_dataset=raw_data["test"],
|
||||
data_collator=collate_fn,
|
||||
)
|
||||
trainer.train(resume_from_checkpoint=last_checkpoint)
|
||||
qwen_smvl.save_pretrained(training_args.output_dir)
|
||||
```
|
||||
|
||||
完整代码见[代码及数据集链接汇总](#代码及数据集链接汇总)
|
||||
|
||||
或者直接由[完整项目GitHub地址]()
|
||||
|
||||
## 微调训练&结果展示
|
||||
|
||||
### 环境安装与微调代码执行
|
||||
|
||||
**代码准备与环境安装**
|
||||
|
||||
可以在[GitHub仓库地址](https://github.com/ShaohonChen/Qwen3-SmVL)处找到实验的完整代码。使用git clone后使用如下命令安装环境
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
**数据集和模型下载**
|
||||
|
||||
笔者附上自动下载脚本,注意该脚本使用[魔塔社区](https://modelscope.cn/)完成模型与数据集的下载
|
||||
|
||||
```bash
|
||||
bash download_resource.sh
|
||||
```
|
||||
|
||||
### 小批量微调训练
|
||||
|
||||
为了进行快速验证,笔者首先使用cocoqa数据集并且进行了200steps的训练,所有参数与前文所述一致。通过
|
||||
|
||||
运行实验命令如下,推荐使用8卡进行训练,在8张沐曦GPU卡上预计需要使用20min
|
||||
|
||||
```bash
|
||||
# 单GPU训练
|
||||
CUDA_VISIBLE_DEVICES=0 python train.py ./cocoqa_train.yaml
|
||||
# 8GPU训练
|
||||
accelerate launch --num_process 8 train.py ./cocoqa_train.yaml
|
||||
```
|
||||
|
||||
注意,本项目使用SwanLab进行训练日志记录与分析,如果未登陆SwanLab需要使用`swanlab login`进行登陆。运行后看到如下结果即代表实验成功开启:
|
||||
|
||||
<div align="center">
|
||||
<figure>
|
||||
<img src="./images/run.png" alt="run" width="800" />
|
||||
<figcaption>成功训练后可以看到SwanLab链接</figcaption>
|
||||
</figure>
|
||||
</div>
|
||||
|
||||
下面是笔者完成小批量微调训练的训练损失、测试损失结果图
|
||||
|
||||
<div align="center">
|
||||
<figure>
|
||||
<img src="./images/cocoqa_swanlab.png" alt="cocoqa_swanlab" width="800" />
|
||||
<figcaption>SwanLab训练可视化分析结果,可以看到最后训练损失和测试损失都收敛在0.65左右</figcaption>
|
||||
</figure>
|
||||
</div>
|
||||
|
||||
模型在完成训练后会自动使用一张狗狗图片配合问题“图中有什么动物?”让模型根据图片进行推理,推理结果如下:
|
||||
|
||||
<div align="center">
|
||||
<figure>
|
||||
<img src="./images/bad_case.png" alt="bad_case" width="800" />
|
||||
<figcaption>SwanLab记录了模型训练好后的推理结果,可以看到模型能正常理解和回复中文</figcaption>
|
||||
</figure>
|
||||
</div>
|
||||
|
||||
当时看到模型对着三只狗的图片回答“兔子”时笔者一时认为炼丹失败了,当然如果实际炼丹失败后模型是不会输出动物类型的,而是输出一些乱码或者告诉用户并没有看到图片。识别错误的原因实际上是由于训练步数过少导致的。后续加大训练步数与数据量后模型能正常识别出狗狗并且能准确的说出有三只狗。
|
||||
|
||||
<div align="center">
|
||||
<figure>
|
||||
<img src="./images/dog.png" alt="dog" width="250" />
|
||||
<figcaption>附上三只眼神忧伤的狗子,难道长得很像兔子吗?</figcaption>
|
||||
</figure>
|
||||
</div>
|
||||
|
||||
PS: 作者公开了在[SwanLab上的训练结果](https://swanlab.cn/@ShaohonChen/Qwen3-SmVL/overview),感兴趣的读者可以自己查看,SwanLab也支持Clone作者的训练日志,大家可以在自己训练时clone笔者的项目去做对照。
|
||||
|
||||
### 完整微调训练结果展示
|
||||
|
||||
运行实验命令如下,推荐使用8卡进行训练,在8片沐曦C500芯片上预计需要使用1.5h
|
||||
|
||||
```bash
|
||||
# 单GPU训练
|
||||
CUDA_VISIBLE_DEVICES=0 python train.py ./full_train.yaml
|
||||
# 8GPU训练
|
||||
accelerate launch --num_processes 8 train.py ./full_train.yaml
|
||||
```
|
||||
|
||||
下图展示了使用完整微调数据对比于小批量训练,可以看到全量数据微调时loss变得更为抖动,这是由于数据类型的丰富给模型的学习带来了一定的挑战。
|
||||
|
||||
<div align="center">
|
||||
<figure>
|
||||
<img src="./images/fulldata_swanlab.png" alt="fulldata_swanlab" width="800" />
|
||||
<figcaption>红色为完整训练loss,黄色为小批量训练结果</figcaption>
|
||||
</figure>
|
||||
</div>
|
||||
|
||||
进一步对比完整训练和小批量训练的训练和测试损失,可以看到完整训练的模型训练损失达到了0.61,远低于仅仅使用cocoqa模型的效果,评估损失也远低于前者,维持在0.58左右。
|
||||
|
||||
<div align="center">
|
||||
<figure>
|
||||
<img src="./images/evalloss.png" alt="evalloss" width="800" />
|
||||
<figcaption>红色为完整训练loss,黄色为小批量训练结果</figcaption>
|
||||
</figure>
|
||||
</div>
|
||||
|
||||
这里值得一提的是,由于我们选用的测试集比较小(仅有64条数据),因此训练损失和测试损失的差距并不能直接理解为过拟合的证据。实际上在大模型训练上,如果数据集足够大的情况下,通常可以认为训练损失等同于评估损失。
|
||||
|
||||
此外,模型通过分析1k步之后的训练损失、平均梯度范数(Grad Norm)变化。此时训练任务已过半,且学习率开始快速衰减。如下图,可以看到学习率快速衰减的情况下模型损失并没有明显的进一步下降,这说明模型已经实现了充分训练。
|
||||
|
||||
<div align="center">
|
||||
<figure>
|
||||
<img src="./images/1kstep.png" alt="1kstep" width="800" />
|
||||
<figcaption>1k step之后模型的训练损失变化</figcaption>
|
||||
</figure>
|
||||
</div>
|
||||
|
||||
在训练效率方面,可以看到我们仍没有充分榨干沐曦GPU的性能,当然这也是由于多模态任务的网络本身架构上比较复杂,其中包含许多对图像、文本的拼接工作,这也导致了GPU性能没法完全利用。
|
||||
|
||||
<div align="center">
|
||||
<figure>
|
||||
<img src="./images/mx-gpu-use.png" alt="mx-gpu-use" width="800" />
|
||||
<figcaption>SwanLab对沐曦C500训效率自动记录</figcaption>
|
||||
</figure>
|
||||
</div>
|
||||
|
||||
同样在完成训练后使用狗狗图进行了测试,这次模型能理解图片、中文以及给出正确的回复。更为关键的是模型完全保留了Qwen3-0.6B原有的全部能力,包括函数调用、推理等。在此基础上,仅仅增加了0.09B参数量的情况下为模型带来了图像理解能力!
|
||||
|
||||
<div align="center">
|
||||
<figure>
|
||||
<img src="./images/good_case.png" alt="good_case" width="800" />
|
||||
<figcaption>同样的图片与问题,更大的数据量和更充足的数据使得模型能够正确给出回复</figcaption>
|
||||
</figure>
|
||||
</div>
|
||||
|
||||
### 模型推理与效果分析
|
||||
|
||||
等笔者下完数据集后未来补一下测试环节 ; )
|
||||
|
||||
可以关注[swanlab教程集合](https://docs.swanlab.cn/examples/qwen3_smolvlm_muxi.html)获取最新更新教程!
|
||||
|
||||
## 代码及数据集链接汇总
|
||||
|
||||
微调用The Cauldron数据集下载链接:
|
||||
|
||||
* HuggingFace Hub: [https://huggingface.co/datasets/HuggingFaceM4/the_cauldron](https://huggingface.co/datasets/HuggingFaceM4/the_cauldron)
|
||||
* ModelScope: [https://modelscope.cn/datasets/AI-ModelScope/the_cauldron](https://modelscope.cn/datasets/AI-ModelScope/the_cauldron)
|
||||
|
||||
Qwen3-0.6B模型下载:
|
||||
|
||||
* HuggingFace Hub: [https://huggingface.co/Qwen/Qwen3-0.6B](https://huggingface.co/Qwen/Qwen3-0.6B)
|
||||
* ModelScope: [https://modelscope.cn/Qwen/Qwen3-0.6B](https://modelscope.cn/Qwen/Qwen3-0.6B)
|
||||
|
||||
本实验完整代码GitHub链接:
|
||||
|
||||
* 完整项目GitHub地址:[https://github.com/ShaohonChen/Qwen3-SmVL](https://github.com/ShaohonChen/Qwen3-SmVL)
|
||||
|
||||
本实验SwanLab日志:
|
||||
|
||||
* SwanLab训练过程查看:[https://swanlab.cn/@ShaohonChen/Qwen3-SmVL/overview](https://swanlab.cn/@ShaohonChen/Qwen3-SmVL/overview)
|
||||
|
||||
## 参考资料
|
||||
|
||||
* Huggingface SmolVLM2技术报告:[https://arxiv.org/pdf/2504.05299](https://arxiv.org/pdf/2504.05299)
|
||||
BIN
Extra-Chapter/vlm-concatenation-finetune/images/1kstep.png
Normal file
|
After Width: | Height: | Size: 273 KiB |
BIN
Extra-Chapter/vlm-concatenation-finetune/images/PPAP.png
Normal file
|
After Width: | Height: | Size: 438 KiB |
BIN
Extra-Chapter/vlm-concatenation-finetune/images/bad_case.png
Normal file
|
After Width: | Height: | Size: 458 KiB |
BIN
Extra-Chapter/vlm-concatenation-finetune/images/change_model.png
Normal file
|
After Width: | Height: | Size: 102 KiB |
BIN
Extra-Chapter/vlm-concatenation-finetune/images/cocoqa.png
Normal file
|
After Width: | Height: | Size: 707 KiB |
|
After Width: | Height: | Size: 532 KiB |
BIN
Extra-Chapter/vlm-concatenation-finetune/images/concatation.png
Normal file
|
After Width: | Height: | Size: 269 KiB |
BIN
Extra-Chapter/vlm-concatenation-finetune/images/data_show.png
Normal file
|
After Width: | Height: | Size: 836 KiB |
BIN
Extra-Chapter/vlm-concatenation-finetune/images/dog.png
Normal file
|
After Width: | Height: | Size: 836 KiB |
BIN
Extra-Chapter/vlm-concatenation-finetune/images/evalloss.png
Normal file
|
After Width: | Height: | Size: 121 KiB |
BIN
Extra-Chapter/vlm-concatenation-finetune/images/fail_train.png
Normal file
|
After Width: | Height: | Size: 136 KiB |
|
After Width: | Height: | Size: 212 KiB |
BIN
Extra-Chapter/vlm-concatenation-finetune/images/good_case.png
Normal file
|
After Width: | Height: | Size: 368 KiB |
BIN
Extra-Chapter/vlm-concatenation-finetune/images/image-split.png
Normal file
|
After Width: | Height: | Size: 293 KiB |
BIN
Extra-Chapter/vlm-concatenation-finetune/images/mask.png
Normal file
|
After Width: | Height: | Size: 49 KiB |
BIN
Extra-Chapter/vlm-concatenation-finetune/images/muxi-gpu.jpg
Normal file
|
After Width: | Height: | Size: 2.1 MiB |
BIN
Extra-Chapter/vlm-concatenation-finetune/images/mx-gpu-use.png
Normal file
|
After Width: | Height: | Size: 718 KiB |
BIN
Extra-Chapter/vlm-concatenation-finetune/images/run.png
Normal file
|
After Width: | Height: | Size: 764 KiB |
BIN
Extra-Chapter/vlm-concatenation-finetune/images/smolvlm2.png
Normal file
|
After Width: | Height: | Size: 196 KiB |
BIN
Extra-Chapter/vlm-concatenation-finetune/images/the_cauldron.png
Normal file
|
After Width: | Height: | Size: 321 KiB |
21
README.md
@@ -67,6 +67,16 @@
|
||||
|
||||
- [文本数据处理详解](./Extra-Chapter/text-data-processing/readme.md) @[蔡鋆捷](https://github.com/xinala-781) 2025-7-14
|
||||
|
||||
- [Qwen3-"VL"——超小中文多模态模型的“拼接微调”之路](./Extra-Chapter/vlm-concatenation-finetune/README.md) @[ShaohonChen](https://github.com/ShaohonChen) 2025-7-30
|
||||
|
||||
- [S1: Thinking Budget with vLLM](./Extra-Chapter/s1-vllm-thinking-budget/readme.md) @[不要葱姜蒜](https://github.com/kmno4-zx) 2025-8-03
|
||||
|
||||
|
||||
- [CDDRS: 使用细粒度语义信息指导增强的RAG检索方法](./Extra-Chapter/CDDRS/readme.md) @[Hongru0306](https://github.com/Hongru0306) 2025-8-21
|
||||
|
||||
- [大模型生成 Token 的方式有哪些?](./Extra-Chapter/generation-method/readme.md) @[不要葱姜蒜](https://github.com/kmno4-zx) 2025-10-17
|
||||
|
||||
|
||||
>   *如果大家在学习 Happy-LLM 项目或 LLM 相关知识中有自己独到的见解、认知、实践,欢迎大家 PR 在 [Extra Chapter LLM Blog](./Extra-Chapter/) 中。请遵守 Extra Chapter LLM Blog 的 [PR 规范](./Extra-Chapter/Readme.md),我们会视 PR 内容的质量和价值来决定是否合并或补充到 Happy-LLM 正文中来。*
|
||||
|
||||
### 模型下载
|
||||
@@ -83,8 +93,7 @@
|
||||
|
||||
  ***本 Happy-LLM PDF 教程完全开源免费。为防止各类营销号加水印后贩卖给大模型初学者,我们特地在 PDF 文件中预先添加了不影响阅读的 Datawhale 开源标志水印,敬请谅解~***
|
||||
|
||||
> *Happy-LLM PDF : https://github.com/datawhalechina/happy-llm/releases/tag/PDF*
|
||||
> *Happy-LLM PDF 国内下载地址 : https://www.datawhale.cn/learn/summary/179*
|
||||
> *Happy-LLM PDF : https://github.com/datawhalechina/happy-llm/releases/tag/v1.0.1*
|
||||
|
||||
## 💡 如何学习
|
||||
|
||||
@@ -96,6 +105,8 @@
|
||||
|
||||
  最后,欢迎每一位读者在学习完本项目后加入到 LLM 开发者的行列。作为国内 AI 开源社区,我们希望充分聚集共创者,一起丰富这个开源 LLM 的世界,打造更多、更全面特色 LLM 的教程。星火点点,汇聚成海。我们希望成为 LLM 与普罗大众的阶梯,以自由、平等的开源精神,拥抱更恢弘而辽阔的 LLM 世界。
|
||||
|
||||
> - 中国计算机学会(CCF) × Datawhale × GitLink开源平台联合推出AI普惠课程,免费算力报名参加 [【报名地址】](https://mp.weixin.qq.com/s/P03f3e2vUUh7OxDP40Ra6w)[【GitLink 地址】](https://gitlink.org.cn/datawhalechina/happy-llm)
|
||||
|
||||
## 🤝 如何贡献
|
||||
|
||||
我们欢迎任何形式的贡献!
|
||||
@@ -108,7 +119,7 @@
|
||||
## 🙏 致谢
|
||||
|
||||
### 核心贡献者
|
||||
- [宋志学-项目负责人](https://github.com/KMnO4-zx) (Datawhale成员-中国矿业大学(北京))
|
||||
- [宋志学-项目负责人](https://github.com/KMnO4-zx) (Datawhale成员)
|
||||
- [邹雨衡-项目负责人](https://github.com/logan-zou) (Datawhale成员-对外经济贸易大学)
|
||||
- [朱信忠-指导专家](https://xinzhongzhu.github.io/)(Datawhale首席科学家-浙江师范大学杭州人工智能研究院教授)
|
||||
|
||||
@@ -116,6 +127,8 @@
|
||||
|
||||
- [ditingdapeng](https://github.com/ditingdapeng)(内容贡献者-云原生基础架构工程师)
|
||||
- [蔡鋆捷](https://github.com/xinala-781)(内容贡献者-福州大学)
|
||||
- [ShaohonChen](https://github.com/ShaohonChen) (情感机器实验室研究员-西安电子科技大学在读硕士)
|
||||
- [肖鸿儒, 庄健琨](https://github.com/Hongru0306) (内容贡献者-同济大学)
|
||||
|
||||
### 特别感谢
|
||||
- 感谢 [@Sm1les](https://github.com/Sm1les) 对本项目的帮助与支持
|
||||
@@ -130,7 +143,7 @@
|
||||
## Star History
|
||||
|
||||
<div align='center'>
|
||||
<img src="./images/star-history-2025710.png" alt="Datawhale" width="90%">
|
||||
<img src="./images/star-history-20251017.png" alt="Datawhale" width="90%">
|
||||
</div>
|
||||
|
||||
<div align="center">
|
||||
|
||||
23
README_en.md
@@ -52,6 +52,25 @@
|
||||
| [Chapter 5: Building Large Models from Scratch](./docs/chapter5/第五章%20动手搭建大模型.md) | Implementing LLaMA2, training Tokenizer, pre-training small LLM | ✅ |
|
||||
| [Chapter 6: Large Model Training Practice](./docs/chapter6/第六章%20大模型训练流程实践.md) | Pre-training, supervised fine-tuning, LoRA/QLoRA efficient fine-tuning | 🚧 |
|
||||
| [Chapter 7: Large Model Applications](./docs/chapter7/第七章%20大模型应用.md) | Model evaluation, RAG retrieval enhancement, Agent intelligent agents | ✅ |
|
||||
| [Extra Chapter LLM Blog](./Extra-Chapter/) | Excellent Learning Notes/Blog on LLMs ,Welcome PR !| 🚧 |
|
||||
|
||||
### Extra Chapter LLM Blog
|
||||
|
||||
- [With large models becoming so powerful, what’s the significance of fine-tuning a 0.6B small model?](./Extra-Chapter/why-fine-tune-small-large-language-models/readme.md) @[不要葱姜蒜](https://github.com/KMnO4-zx) 2025-7-11
|
||||
|
||||
- [Details of the Transformer modules](./Extra-Chapter/transformer-architecture/) @[ditingdapeng](https://github.com/ditingdapeng) 2025-7-14
|
||||
|
||||
- [Detailed Explanation of Text Data Processing](./Extra-Chapter/text-data-processing/readme.md) @[蔡鋆捷](https://github.com/xinala-781) 2025-7-14
|
||||
|
||||
- [Qwen3-"VL"——Path to 'Concatenation Fine-tuning' for Ultra-small Chinese Multimodal Models](./Extra-Chapter/vlm-concatenation-finetune/README.md) @[ShaohonChen](https://github.com/ShaohonChen) 2025-7-30
|
||||
|
||||
- [S1: Thinking Budget with vLLM](./Extra-Chapter/s1-vllm-thinking-budget/readme.md) @[kmno4-zx](https://github.com/kmno4-zx) 2025-8-03
|
||||
|
||||
- [CDDRS: Key elements guided Enhancement for RAG-based Retrieval Methods](./Extra-Chapter/CDDRS/readme.md) @[Hongru0306](https://github.com/Hongru0306) 2025-8-21
|
||||
|
||||
|
||||
>   * If anyone has unique insights, knowledge, or practices related to the Happy-LLM project or LLMs in general, you are welcome to submit a PR to the [Extra Chapter LLM Blog](./Extra-Chapter/). Please adhere to the [PR Guidances](./Extra-Chapter/Readme.md). We will decide whether to merge or supplement the content into the main Happy-LLM text based on the quality and value of the PR.*
|
||||
|
||||
|
||||
### Model Downloads
|
||||
|
||||
@@ -92,7 +111,7 @@ We welcome any form of contribution!
|
||||
## 🙏 Acknowledgments
|
||||
|
||||
### Core Contributors
|
||||
- [Song Zhixue - Project Leader](https://github.com/KMnO4-zx) (Datawhale Member - China University of Mining and Technology, Beijing)
|
||||
- [Song Zhixue - Project Leader](https://github.com/KMnO4-zx) (Datawhale Member)
|
||||
- [Zou Yuheng - Project Leader](https://github.com/logan-zou) (Datawhale Member - University of International Business and Economics)
|
||||
- [Zhu Xinzhong - Expert Advisor](https://xinzhongzhu.github.io/) (Datawhale Chief Scientist - Professor at Hangzhou Institute for Advanced Study, Zhejiang Normal University)
|
||||
|
||||
@@ -109,7 +128,7 @@ We welcome any form of contribution!
|
||||
## Star History
|
||||
|
||||
<div align='center'>
|
||||
<img src="./images/star-history-2025710.png" alt="Datawhale" width="90%">
|
||||
<img src="./images/star-history-20251017.png" alt="Datawhale" width="90%">
|
||||
</div>
|
||||
|
||||
<div align="center">
|
||||
|
||||
@@ -106,7 +106,7 @@
|
||||
## Star History
|
||||
|
||||
<div align='center'>
|
||||
<img src="./images/star-history-2025710.png" alt="Datawhale" width="90%">
|
||||
<img src="./images/star-history-20251017.png" alt="Datawhale" width="90%">
|
||||
</div>
|
||||
|
||||
<div align="center">
|
||||
|
||||
@@ -26,21 +26,18 @@ class MultiHeadAttention(nn.Module):
|
||||
super().__init__()
|
||||
# 隐藏层维度必须是头数的整数倍,因为后面我们会将输入拆成头数个矩阵
|
||||
assert args.dim % args.n_heads == 0
|
||||
# 模型并行处理大小,默认为1。
|
||||
model_parallel_size = 1
|
||||
# 本地计算头数,等于总头数除以模型并行处理大小。
|
||||
self.n_local_heads = args.n_heads // model_parallel_size
|
||||
# 每个头的维度,等于模型维度除以头的总数。
|
||||
self.head_dim = args.dim // args.n_heads
|
||||
self.n_heads = args.n_heads
|
||||
|
||||
# Wq, Wk, Wv 参数矩阵,每个参数矩阵为 n_embd x n_embd
|
||||
# Wq, Wk, Wv 参数矩阵,每个参数矩阵为 n_embd x dim
|
||||
# 这里通过三个组合矩阵来代替了n个参数矩阵的组合,其逻辑在于矩阵内积再拼接其实等同于拼接矩阵再内积,
|
||||
# 不理解的读者可以自行模拟一下,每一个线性层其实相当于n个参数矩阵的拼接
|
||||
self.wq = nn.Linear(args.n_embd, args.n_local_heads * self.head_dim, bias=False)
|
||||
self.wk = nn.Linear(args.n_embd, args.n_local_heads * self.head_dim, bias=False)
|
||||
self.wv = nn.Linear(args.n_embd, args.n_local_heads * self.head_dim, bias=False)
|
||||
# 输出权重矩阵,维度为 dim x n_embd(head_dim = n_embeds / n_heads)
|
||||
self.wo = nn.Linear(args.n_local_heads * self.head_dim, args.dim, bias=False)
|
||||
self.wq = nn.Linear(args.n_embd, self.n_heads * self.head_dim, bias=False)
|
||||
self.wk = nn.Linear(args.n_embd, self.n_heads * self.head_dim, bias=False)
|
||||
self.wv = nn.Linear(args.n_embd, self.n_heads * self.head_dim, bias=False)
|
||||
# 输出权重矩阵,维度为 dim x dim(head_dim = dim / n_heads)
|
||||
self.wo = nn.Linear(self.n_heads * self.head_dim, args.dim, bias=False)
|
||||
# 注意力的 dropout
|
||||
self.attn_dropout = nn.Dropout(args.dropout)
|
||||
# 残差连接的 dropout
|
||||
@@ -60,16 +57,16 @@ class MultiHeadAttention(nn.Module):
|
||||
# 获取批次大小和序列长度,[batch_size, seq_len, dim]
|
||||
bsz, seqlen, _ = q.shape
|
||||
|
||||
# 计算查询(Q)、键(K)、值(V),输入通过参数矩阵层,维度为 (B, T, n_embed) x (n_embed, n_embed) -> (B, T, n_embed)
|
||||
# 计算查询(Q)、键(K)、值(V),输入通过参数矩阵层,维度为 (B, T, n_embed) x (n_embed, dim) -> (B, T, dim)
|
||||
xq, xk, xv = self.wq(q), self.wk(k), self.wv(v)
|
||||
|
||||
# 将 Q、K、V 拆分成多头,维度为 (B, T, n_head, C // n_head),然后交换维度,变成 (B, n_head, T, C // n_head)
|
||||
# 将 Q、K、V 拆分成多头,维度为 (B, T, n_head, dim // n_head),然后交换维度,变成 (B, n_head, T, dim // n_head)
|
||||
# 因为在注意力计算中我们是取了后两个维度参与计算
|
||||
# 为什么要先按B*T*n_head*C//n_head展开再互换1、2维度而不是直接按注意力输入展开,是因为view的展开方式是直接把输入全部排开,
|
||||
# 然后按要求构造,可以发现只有上述操作能够实现我们将每个头对应部分取出来的目标
|
||||
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
|
||||
xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
|
||||
xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)
|
||||
xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
|
||||
xk = xk.view(bsz, seqlen, self.n_heads, self.head_dim)
|
||||
xv = xv.view(bsz, seqlen, self.n_heads, self.head_dim)
|
||||
xq = xq.transpose(1, 2)
|
||||
xk = xk.transpose(1, 2)
|
||||
xv = xv.transpose(1, 2)
|
||||
@@ -90,7 +87,7 @@ class MultiHeadAttention(nn.Module):
|
||||
output = torch.matmul(scores, xv)
|
||||
|
||||
# 恢复时间维度并合并头。
|
||||
# 将多头的结果拼接起来, 先交换维度为 (B, T, n_head, C // n_head),再拼接成 (B, T, n_head * C // n_head)
|
||||
# 将多头的结果拼接起来, 先交换维度为 (B, T, n_head, dim // n_head),再拼接成 (B, T, n_head * dim // n_head)
|
||||
# contiguous 函数用于重新开辟一块新内存存储,因为Pytorch设置先transpose再view会报错,
|
||||
# 因为view直接基于底层存储得到,然而transpose并不会改变底层存储,因此需要额外存储
|
||||
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
|
||||
|
||||
@@ -253,54 +253,51 @@ class MultiHeadAttention(nn.Module):
|
||||
super().__init__()
|
||||
# 隐藏层维度必须是头数的整数倍,因为后面我们会将输入拆成头数个矩阵
|
||||
assert args.dim % args.n_heads == 0
|
||||
# 模型并行处理大小,默认为1。
|
||||
model_parallel_size = 1
|
||||
# 本地计算头数,等于总头数除以模型并行处理大小。
|
||||
self.n_local_heads = args.n_heads // model_parallel_size
|
||||
# 每个头的维度,等于模型维度除以头的总数。
|
||||
self.head_dim = args.dim // args.n_heads
|
||||
self.n_heads = args.n_heads
|
||||
|
||||
# Wq, Wk, Wv 参数矩阵,每个参数矩阵为 n_embd x n_embd
|
||||
# Wq, Wk, Wv 参数矩阵,每个参数矩阵为 n_embd x dim
|
||||
# 这里通过三个组合矩阵来代替了n个参数矩阵的组合,其逻辑在于矩阵内积再拼接其实等同于拼接矩阵再内积,
|
||||
# 不理解的读者可以自行模拟一下,每一个线性层其实相当于n个参数矩阵的拼接
|
||||
self.wq = nn.Linear(args.dim, args.n_local_heads * self.head_dim, bias=False)
|
||||
self.wk = nn.Linear(args.dim, args.n_local_heads * self.head_dim, bias=False)
|
||||
self.wv = nn.Linear(args.dim, args.n_local_heads * self.head_dim, bias=False)
|
||||
# 输出权重矩阵,维度为 dim x n_embd(head_dim = n_embeds / n_heads)
|
||||
self.wo = nn.Linear(args.n_local_heads * self.head_dim, args.dim, bias=False)
|
||||
self.wq = nn.Linear(args.n_embd, self.n_heads * self.head_dim, bias=False)
|
||||
self.wk = nn.Linear(args.n_embd, self.n_heads * self.head_dim, bias=False)
|
||||
self.wv = nn.Linear(args.n_embd, self.n_heads * self.head_dim, bias=False)
|
||||
# 输出权重矩阵,维度为 dim x dim(head_dim = dim / n_heads)
|
||||
self.wo = nn.Linear(self.n_heads * self.head_dim, args.dim, bias=False)
|
||||
# 注意力的 dropout
|
||||
self.attn_dropout = nn.Dropout(args.dropout)
|
||||
# 残差连接的 dropout
|
||||
self.resid_dropout = nn.Dropout(args.dropout)
|
||||
|
||||
self.is_causal = is_causal
|
||||
|
||||
# 创建一个上三角矩阵,用于遮蔽未来信息
|
||||
# 注意,因为是多头注意力,Mask 矩阵比之前我们定义的多一个维度
|
||||
if is_causal:
|
||||
mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
|
||||
mask = torch.triu(mask, diagonal=1)
|
||||
# 注册为模型的缓冲区
|
||||
self.register_buffer("mask", mask)
|
||||
mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
|
||||
mask = torch.triu(mask, diagonal=1)
|
||||
# 注册为模型的缓冲区
|
||||
self.register_buffer("mask", mask)
|
||||
|
||||
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
|
||||
|
||||
# 获取批次大小和序列长度,[batch_size, seq_len, dim]
|
||||
bsz, seqlen, _ = q.shape
|
||||
|
||||
# 计算查询(Q)、键(K)、值(V),输入通过参数矩阵层,维度为 (B, T, n_embed) x (n_embed, n_embed) -> (B, T, n_embed)
|
||||
# 计算查询(Q)、键(K)、值(V),输入通过参数矩阵层,维度为 (B, T, n_embed) x (n_embed, dim) -> (B, T, dim)
|
||||
xq, xk, xv = self.wq(q), self.wk(k), self.wv(v)
|
||||
|
||||
# 将 Q、K、V 拆分成多头,维度为 (B, T, n_head, C // n_head),然后交换维度,变成 (B, n_head, T, C // n_head)
|
||||
# 将 Q、K、V 拆分成多头,维度为 (B, T, n_head, dim // n_head),然后交换维度,变成 (B, n_head, T, dim // n_head)
|
||||
# 因为在注意力计算中我们是取了后两个维度参与计算
|
||||
# 为什么要先按B*T*n_head*C//n_head展开再互换1、2维度而不是直接按注意力输入展开,是因为view的展开方式是直接把输入全部排开,
|
||||
# 然后按要求构造,可以发现只有上述操作能够实现我们将每个头对应部分取出来的目标
|
||||
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
|
||||
xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
|
||||
xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)
|
||||
xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
|
||||
xk = xk.view(bsz, seqlen, self.n_heads, self.head_dim)
|
||||
xv = xv.view(bsz, seqlen, self.n_heads, self.head_dim)
|
||||
xq = xq.transpose(1, 2)
|
||||
xk = xk.transpose(1, 2)
|
||||
xv = xv.transpose(1, 2)
|
||||
|
||||
|
||||
# 注意力计算
|
||||
# 计算 QK^T / sqrt(d_k),维度为 (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
|
||||
scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
@@ -317,7 +314,7 @@ class MultiHeadAttention(nn.Module):
|
||||
output = torch.matmul(scores, xv)
|
||||
|
||||
# 恢复时间维度并合并头。
|
||||
# 将多头的结果拼接起来, 先交换维度为 (B, T, n_head, C // n_head),再拼接成 (B, T, n_head * C // n_head)
|
||||
# 将多头的结果拼接起来, 先交换维度为 (B, T, n_head, dim // n_head),再拼接成 (B, T, n_head * dim // n_head)
|
||||
# contiguous 函数用于重新开辟一块新内存存储,因为Pytorch设置先transpose再view会报错,
|
||||
# 因为view直接基于底层存储得到,然而transpose并不会改变底层存储,因此需要额外存储
|
||||
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
|
||||
@@ -326,7 +323,6 @@ class MultiHeadAttention(nn.Module):
|
||||
output = self.wo(output)
|
||||
output = self.resid_dropout(output)
|
||||
return output
|
||||
|
||||
```
|
||||
|
||||
## 2.2 Encoder-Decoder
|
||||
@@ -356,7 +352,7 @@ Transformer 由 Encoder 和 Decoder 组成,每一个 Encoder(Decoder)又
|
||||
|
||||
### 2.2.2 前馈神经网络
|
||||
|
||||
前馈神经网络(Feed Forward Neural Network,下简称 FFN),也就是我们在上一节提过的每一层的神经元都和上下两层的每一个神经元完全连接的网络结构。每一个 Encoder Layer 都包含一个上文讲的注意力机制和一个前馈神经网络。前馈神经网络的实现是较为简单的:
|
||||
前馈神经网络(Feed Forward Neural Network,下简称 FNN),也就是我们在上一节提过的每一层的神经元都和上下两层的每一个神经元完全连接的网络结构。每一个 Encoder Layer 都包含一个上文讲的注意力机制和一个前馈神经网络。前馈神经网络的实现是较为简单的:
|
||||
|
||||
```python
|
||||
class MLP(nn.Module):
|
||||
@@ -392,7 +388,7 @@ $$
|
||||
\mu_j = \frac{1}{m}\sum^{m}_{i=1}Z_j^{i}
|
||||
$$
|
||||
|
||||
其中,$Z_j^{i}$ 是样本 i 在第 j 个维度上的值,m 就是 mini-batch 的大小。
|
||||
其中, $Z_j^{i}$ 是样本 i 在第 j 个维度上的值,m 就是 mini-batch 的大小。
|
||||
|
||||
再计算样本的方差:
|
||||
|
||||
@@ -478,7 +474,7 @@ class EncoderLayer(nn.Module):
|
||||
# Encoder 不需要掩码,传入 is_causal=False
|
||||
self.attention = MultiHeadAttention(args, is_causal=False)
|
||||
self.fnn_norm = LayerNorm(args.n_embd)
|
||||
self.feed_forward = MLP(args)
|
||||
self.feed_forward = MLP(args.dim, args.dim, args.dropout)
|
||||
|
||||
def forward(self, x):
|
||||
# Layer Norm
|
||||
@@ -528,7 +524,7 @@ class DecoderLayer(nn.Module):
|
||||
self.attention = MultiHeadAttention(args, is_causal=False)
|
||||
self.ffn_norm = LayerNorm(args.n_embd)
|
||||
# 第三个部分是 MLP
|
||||
self.feed_forward = MLP(args)
|
||||
self.feed_forward = MLP(args.dim, args.dim, args.dropout)
|
||||
|
||||
def forward(self, x, enc_out):
|
||||
# Layer Norm
|
||||
@@ -620,13 +616,33 @@ $$
|
||||
我们以一个简单的例子来说明位置编码的计算过程:假如我们输入的是一个长度为 4 的句子"I like to code",我们可以得到下面的词向量矩阵 $\rm x$ ,其中每一行代表的就是一个词向量, $\rm x_0=[0.1,0.2,0.3,0.4]$ 对应的就是“I”的词向量,它的pos就是为0,以此类推,第二行代表的是“like”的词向量,它的pos就是1:
|
||||
|
||||
$$
|
||||
\rm x = \begin{bmatrix} 0.1 & 0.2 & 0.3 & 0.4 \\ 0.2 & 0.3 & 0.4 & 0.5 \\ 0.3 & 0.4 & 0.5 & 0.6 \\ 0.4 & 0.5 & 0.6 & 0.7 \end{bmatrix}
|
||||
\rm x = \begin{bmatrix}
|
||||
0.1 & 0.2 & 0.3 & 0.4 \\
|
||||
0.2 & 0.3 & 0.4 & 0.5 \\
|
||||
0.3 & 0.4 & 0.5 & 0.6 \\
|
||||
0.4 & 0.5 & 0.6 & 0.7
|
||||
\end{bmatrix}
|
||||
$$
|
||||
|
||||
则经过位置编码后的词向量为:
|
||||
|
||||
$$
|
||||
\rm x_{PE} = \begin{bmatrix} 0.1 & 0.2 & 0.3 & 0.4 \\ 0.2 & 0.3 & 0.4 & 0.5 \\ 0.3 & 0.4 & 0.5 & 0.6 \\ 0.4 & 0.5 & 0.6 & 0.7 \end{bmatrix} + \begin{bmatrix} \sin(\frac{0}{10000^0}) & \cos(\frac{0}{10000^0}) & \sin(\frac{0}{10000^{2/4}}) & \cos(\frac{0}{10000^{2/4}}) \\ \sin(\frac{1}{10000^0}) & \cos(\frac{1}{10000^0}) & \sin(\frac{1}{10000^{2/4}}) & \cos(\frac{1}{10000^{2/4}}) \\ \sin(\frac{2}{10000^0}) & \cos(\frac{2}{10000^0}) & \sin(\frac{2}{10000^{2/4}}) & \cos(\frac{2}{10000^{2/4}}) \\ \sin(\frac{3}{10000^0}) & \cos(\frac{3}{10000^0}) & \sin(\frac{3}{10000^{2/4}}) & \cos(\frac{3}{10000^{2/4}}) \end{bmatrix} = \begin{bmatrix} 0.1 & 1.2 & 0.3 & 1.4 \\ 1.041 & 0.84 & 0.41 & 1.49 \\ 1.209 & -0.016 & 0.52 & 1.59 \\ 0.541 & -0.489 & 0.895 & 1.655 \end{bmatrix}
|
||||
\rm x_{PE} = \begin{bmatrix}
|
||||
0.1 & 0.2 & 0.3 & 0.4 \\
|
||||
0.2 & 0.3 & 0.4 & 0.5 \\
|
||||
0.3 & 0.4 & 0.5 & 0.6 \\
|
||||
0.4 & 0.5 & 0.6 & 0.7
|
||||
\end{bmatrix} + \begin{bmatrix}
|
||||
\sin(\frac{0}{10000^0}) & \cos(\frac{0}{10000^0}) & \sin(\frac{0}{10000^{2/4}}) & \cos(\frac{0}{10000^{2/4}}) \\
|
||||
\sin(\frac{1}{10000^0}) & \cos(\frac{1}{10000^0}) & \sin(\frac{1}{10000^{2/4}}) & \cos(\frac{1}{10000^{2/4}}) \\
|
||||
\sin(\frac{2}{10000^0}) & \cos(\frac{2}{10000^0}) & \sin(\frac{2}{10000^{2/4}}) & \cos(\frac{2}{10000^{2/4}}) \\
|
||||
\sin(\frac{3}{10000^0}) & \cos(\frac{3}{10000^0}) & \sin(\frac{3}{10000^{2/4}}) & \cos(\frac{3}{10000^{2/4}})
|
||||
\end{bmatrix} = \begin{bmatrix}
|
||||
0.1 & 1.2 & 0.3 & 1.4 \\
|
||||
1.041 & 0.84 & 0.41 & 1.49 \\
|
||||
1.209 & -0.016 & 0.52 & 1.59 \\
|
||||
0.541 & -0.489 & 0.895 & 1.655
|
||||
\end{bmatrix}
|
||||
$$
|
||||
|
||||
我们可以使用如下的代码来获取上述例子的位置编码:
|
||||
|
||||
@@ -39,7 +39,7 @@ BERT 是针对于 NLU 任务打造的预训练模型,其输入一般是文本
|
||||
<p>图3.2 BERT 模型简略结构</p>
|
||||
</div>
|
||||
|
||||
输入的文本序列会首先通过 tokenizer(分词器) 转化成 input_ids(基本每一个模型在 tokenizer 的操作都类似,可以参考 Transformer 的 tokenizer 机制,后文不再赘述),然后进入 Embedding 层转化为特定维度的 hidden_states,再经过 Encoder 块。Encoder 块中是对叠起来的 N 层 Encoder Layer,BERT 有两种规模的模型,分别是 base 版本(12层 Encoder Layer,768 的隐藏层维度,总参数量 110M),large 版本(24层 Encoder Layer,1024 的隐藏层维度,总参数量 340M)。通过Encoder 编码之后的最顶层 hidden_states 最后经过 prediction_heads 就得到了最后的类别概率,经过 Softmax 计算就可以计算出模型预测的类别。
|
||||
输入的文本序列会首先通过 tokenizer(分词器) 转化成 input_ids(基本每一个模型在 tokenizer 的操作都类似,可以参考 Transformer 的 tokenizer 机制,后文不再赘述),然后进入 Embedding 层转化为特定维度的 hidden_states,再经过 Encoder 块。Encoder 块中是堆叠起来的 N 层 Encoder Layer,BERT 有两种规模的模型,分别是 base 版本(12层 Encoder Layer,768 的隐藏层维度,总参数量 110M),large 版本(24层 Encoder Layer,1024 的隐藏层维度,总参数量 340M)。通过Encoder 编码之后的最顶层 hidden_states 最后经过 prediction_heads 就得到了最后的类别概率,经过 Softmax 计算就可以计算出模型预测的类别。
|
||||
|
||||
|
||||
> BERT 采用 WordPiece 作为分词方法。WordPiece 是一种基于统计的子词切分算法,其核心在于将单词拆解为子词(例如,"playing" -> ["play", "##ing"])。其合并操作的依据是最大化语言模型的似然度。对于中文等非空格分隔的语言,通常将单个汉字作为原子分词单位(token)处理。
|
||||
@@ -80,6 +80,8 @@ BERT 的 注意力机制和 Transformer 中 Encoder 的 自注意力机制几乎
|
||||
|
||||
如图,BERT 的注意力计算过程和 Transformer 的唯一差异在于,在完成注意力分数的计算之后,先通过 Position Embedding 层来融入相对位置信息。这里的 Position Embedding 层,其实就是一层线性矩阵。通过可训练的参数来拟合相对位置,相对而言比 Transformer 使用的绝对位置编码 Sinusoidal 能够拟合更丰富的相对位置信息,但是,这样也增加了不少模型参数,同时完全无法处理超过模型训练长度的输入(例如,对 BERT 而言能处理的最大上下文长度是 512 个 token)。
|
||||
|
||||
注:原始 BERT(即论文提出)使用和 Transformer 一致的绝对位置编码,后续改进(包括 BERT 的各种变体)使用了上述相对位置编码,为帮助读者了解更全面的模型结构设计,此处选择了改进版 BERT。
|
||||
|
||||
可以看出,BERT 的模型架构既是建立在 Transformer 的 Encoder 之上的,这也是为什么说 BERT 沿承了 Transformer 的思想。
|
||||
|
||||
#### (3)预训练任务——MLM + NSP
|
||||
|
||||
@@ -335,20 +335,20 @@ class Transformer(PreTrainedModel):
|
||||
elif isinstance(module, nn.Embedding):
|
||||
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
||||
|
||||
def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None, **keyargs) -> torch.Tensor:
|
||||
def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
|
||||
"""
|
||||
- tokens: Optional[torch.Tensor], 输入 token 张量。
|
||||
- targets: Optional[torch.Tensor], 目标 token 张量。
|
||||
- kv_cache: bool, 是否使用键值缓存。
|
||||
- keyargs: 其他关键字参数。
|
||||
- kwargs: 其他关键字参数。
|
||||
|
||||
- self.OUT: CausalLMOutputWithPast, 包含 logits 和损失。
|
||||
"""
|
||||
|
||||
if 'input_ids' in keyargs:
|
||||
tokens = keyargs['input_ids']
|
||||
if 'attention_mask' in keyargs:
|
||||
targets = keyargs['attention_mask']
|
||||
if 'input_ids' in kwargs:
|
||||
tokens = kwargs['input_ids']
|
||||
if 'attention_mask' in kwargs:
|
||||
targets = kwargs['attention_mask']
|
||||
|
||||
# 前向传播函数
|
||||
_bsz, seqlen = tokens.shape
|
||||
@@ -415,6 +415,234 @@ class Transformer(PreTrainedModel):
|
||||
idx = torch.cat((idx, idx_next), dim=1)
|
||||
|
||||
return idx[:, index:] # 只返回生成的token
|
||||
|
||||
def _greedy_decode(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
贪婪解码:选择概率最大的token
|
||||
|
||||
Args:
|
||||
logits: 模型输出的logits,形状为 (batch_size, vocab_size)
|
||||
|
||||
Returns:
|
||||
选择的token索引,形状为 (batch_size, 1)
|
||||
"""
|
||||
_, idx_next = torch.topk(logits, k=1, dim=-1)
|
||||
return idx_next
|
||||
|
||||
def _random_sample(self, logits: torch.Tensor, temperature: float = 1.0, top_k: int = None) -> torch.Tensor:
|
||||
"""
|
||||
随机采样:基于概率分布随机选择token
|
||||
|
||||
Args:
|
||||
logits: 模型输出的logits,形状为 (batch_size, vocab_size)
|
||||
temperature: 温度参数,控制随机性
|
||||
top_k: 只考虑概率最高的k个token
|
||||
|
||||
Returns:
|
||||
选择的token索引,形状为 (batch_size, 1)
|
||||
"""
|
||||
# 缩放 logits
|
||||
logits = logits / temperature
|
||||
|
||||
# 应用top-k过滤
|
||||
if top_k is not None:
|
||||
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
||||
# 将不在 top-k 内的 logits 设为负无穷
|
||||
logits[logits < v[:, [-1]]] = -float('Inf')
|
||||
|
||||
# 计算概率并采样
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
idx_next = torch.multinomial(probs, num_samples=1)
|
||||
return idx_next
|
||||
|
||||
def _beam_search(self, idx: torch.Tensor, max_new_tokens: int, num_beams: int,
|
||||
temperature: float = 1.0, top_k: int = None, stop_id: int = None) -> torch.Tensor:
|
||||
"""
|
||||
束搜索:维护多个候选序列,选择最优路径
|
||||
|
||||
束搜索的核心思想:在每一步生成时,不是只选择一个最佳token,
|
||||
而是保留多个候选路径,最终选择累积概率最高的完整序列。
|
||||
|
||||
Args:
|
||||
idx: 输入序列,形状为 (batch_size, seq_len)
|
||||
max_new_tokens: 最大生成token数量
|
||||
num_beams: 束宽度,表示保留的候选路径数量
|
||||
temperature: 温度参数,控制分布的平滑程度
|
||||
top_k: top-k过滤参数,限制候选token范围
|
||||
stop_id: 停止生成的token ID,遇到则停止
|
||||
|
||||
Returns:
|
||||
生成的token序列,形状为 (batch_size, generated_length)
|
||||
只返回新生成的部分,不包含原始输入序列
|
||||
"""
|
||||
# 获取输入序列的基本信息
|
||||
batch_size = idx.shape[0] # 批次大小,通常为1
|
||||
seq_len = idx.shape[1] # 输入序列长度
|
||||
|
||||
# 初始化束:创建 num_beams 个候选序列
|
||||
beams = [idx.clone() for _ in range(num_beams)]
|
||||
# 初始化每个候选序列的累积对数概率分数
|
||||
beam_scores = torch.zeros(num_beams, device=idx.device)
|
||||
# 第一个候选是原始输入序列,分数为0
|
||||
beam_scores[0] = 0.0
|
||||
# 其他候选初始分数设为负无穷,表示尚未生成
|
||||
beam_scores[1:] = float('-inf')
|
||||
|
||||
# 主循环:逐步生成新的token,最多生成 max_new_tokens 个
|
||||
for step in range(max_new_tokens):
|
||||
# 每轮迭代收集新的候选序列和分数
|
||||
new_beams = [] # 新的候选序列列表
|
||||
new_scores = [] # 对应的分数列表
|
||||
|
||||
# 遍历当前的所有候选序列
|
||||
for beam_idx, beam in enumerate(beams):
|
||||
# 跳过无效候选(分数为负无穷的序列)
|
||||
if beam_scores[beam_idx] == float('-inf'):
|
||||
continue
|
||||
|
||||
# 序列长度检查:如果超过最大长度,截取最后的部分
|
||||
beam_cond = beam if beam.size(1) <= self.args.max_seq_len else beam[:, -self.args.max_seq_len:]
|
||||
|
||||
# 前向传播:获取模型对当前序列的预测
|
||||
output = self(beam_cond)
|
||||
# 提取最后一个位置的logits,用于预测下一个token
|
||||
logits = output.logits[:, -1, :] # 形状: (1, vocab_size)
|
||||
|
||||
# 温度缩放:调整logits的分布
|
||||
if temperature != 1.0:
|
||||
logits = logits / temperature
|
||||
# 温度 > 1:分布更平滑,增加随机性
|
||||
# 温度 < 1:分布更尖锐,更确定
|
||||
|
||||
# Top-k过滤:限制候选token的范围,提高质量
|
||||
if top_k is not None:
|
||||
# 找到logits中前top_k个最大的值
|
||||
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
||||
# 将不在前top_k内的logits设为负无穷
|
||||
logits[logits < v[:, [-1]]] = -float('Inf')
|
||||
# 这样采样时只会考虑前top_k个token
|
||||
|
||||
# 计算对数概率:使用log_softmax避免数值不稳定
|
||||
log_probs = F.log_softmax(logits, dim=-1)
|
||||
|
||||
# 获取前 num_beams 个最可能的候选token
|
||||
# 注意:这里的top-k与上面的top-k不同
|
||||
# 上面的top-k是全局过滤,这里是束搜索的分支选择
|
||||
top_log_probs, top_indices = torch.topk(log_probs, k=num_beams, dim=-1)
|
||||
|
||||
# 为当前候选序列生成 num_beams 个扩展序列
|
||||
for k in range(num_beams):
|
||||
# 选择第k个候选token
|
||||
token = top_indices[:, k:k+1] # token ID
|
||||
log_prob = top_log_probs[:, k] # 对应的对数概率
|
||||
|
||||
# 扩展序列:将新token添加到当前序列末尾
|
||||
new_beam = torch.cat([beam, token], dim=1)
|
||||
# 更新累积分数:原序列分数 + 新token的对数概率
|
||||
new_score = beam_scores[beam_idx] + log_prob.item()
|
||||
|
||||
# 保存新的候选序列和分数
|
||||
new_beams.append(new_beam)
|
||||
new_scores.append(new_score)
|
||||
|
||||
# 安全检查:如果没有生成任何有效候选,提前结束
|
||||
if not new_beams:
|
||||
break
|
||||
|
||||
# 筛选最佳候选:从所有新生成的候选中选择分数最高的 num_beams 个
|
||||
# 按分数降序排序,获取索引
|
||||
sorted_indices = sorted(range(len(new_scores)), key=lambda i: new_scores[i], reverse=True)
|
||||
# 选择前 num_beams 个最佳候选
|
||||
beams = [new_beams[i] for i in sorted_indices[:num_beams]]
|
||||
beam_scores = [new_scores[i] for i in sorted_indices[:num_beams]]
|
||||
|
||||
# 停止条件检查:检查最佳序列是否以停止token结尾
|
||||
if stop_id is not None and beams[0][0, -1] == stop_id:
|
||||
break
|
||||
|
||||
# 返回得分最高的序列,只返回新生成的部分(去掉原始输入)
|
||||
# beams[0] 是最终得分最高的完整序列
|
||||
# [:, seq_len:] 切片只保留生成部分
|
||||
return beams[0][:, seq_len:]
|
||||
|
||||
@torch.inference_mode()
|
||||
def generate_super(self,
|
||||
idx,
|
||||
stop_id=None,
|
||||
max_new_tokens=256,
|
||||
temperature=1.0,
|
||||
top_k=None,
|
||||
do_sample=False,
|
||||
num_beams=1
|
||||
):
|
||||
"""
|
||||
高级文本生成函数,支持三种解码策略:
|
||||
|
||||
1. 贪婪解码(Greedy Search):
|
||||
- 参数:do_sample=False, num_beams=1
|
||||
- 特点:每步选择概率最大的token,速度快、结果确定
|
||||
|
||||
2. 随机采样(Random Sampling):
|
||||
- 参数:do_sample=True, num_beams=1
|
||||
- 特点:基于概率分布随机采样,可配合temperature和top-k控制多样性
|
||||
|
||||
3. 束搜索(Beam Search):
|
||||
- 参数:do_sample=False, num_beams>1
|
||||
- 特点:维护多条候选路径,选择总概率最高的序列,质量更高但速度较慢
|
||||
|
||||
Args:
|
||||
idx: 输入序列张量,形状为 (batch_size, seq_len)
|
||||
stop_id: 停止生成的token ID
|
||||
max_new_tokens: 最大生成token数量
|
||||
temperature: 温度参数,控制随机性,越高越随机
|
||||
top_k: 只考虑概率最高的k个token,None表示不考虑
|
||||
do_sample: 是否使用随机采样,False时使用确定性解码
|
||||
num_beams: 束搜索的束宽度,1表示不使用束搜索
|
||||
|
||||
Returns:
|
||||
生成的token序列,形状为 (batch_size, generated_length)
|
||||
"""
|
||||
# 参数验证
|
||||
if temperature <= 0:
|
||||
temperature = 0.001 # 避免除零错误
|
||||
if num_beams < 1:
|
||||
num_beams = 1
|
||||
if top_k is not None and top_k < 1:
|
||||
top_k = None
|
||||
|
||||
# 束搜索逻辑
|
||||
if not do_sample and num_beams > 1:
|
||||
return self._beam_search(idx, max_new_tokens, num_beams, temperature, top_k, stop_id)
|
||||
|
||||
# 贪婪解码和随机采样逻辑
|
||||
index = idx.shape[1]
|
||||
for _ in range(max_new_tokens):
|
||||
# 如果序列上下文过长,截断它到最大长度
|
||||
idx_cond = idx if idx.size(1) <= self.args.max_seq_len else idx[:, -self.args.max_seq_len:]
|
||||
|
||||
# 前向传播获取序列中最后一个位置的 logits
|
||||
logits = self(idx_cond).logits
|
||||
logits = logits[:, -1, :] # 只保留最后一个时间步的输出
|
||||
|
||||
# 根据参数选择解码策略
|
||||
if do_sample:
|
||||
idx_next = self._random_sample(logits, temperature, top_k)
|
||||
else:
|
||||
# 当temperature=0时使用贪婪解码
|
||||
if temperature < 0.1:
|
||||
idx_next = self._greedy_decode(logits)
|
||||
else:
|
||||
# 低温度下的随机采样(接近贪婪)
|
||||
idx_next = self._random_sample(logits, temperature, top_k)
|
||||
|
||||
# 检查停止条件
|
||||
if stop_id is not None and idx_next[0, 0] == stop_id:
|
||||
break
|
||||
|
||||
# 将选择的token添加到序列中
|
||||
idx = torch.cat((idx, idx_next), dim=1)
|
||||
|
||||
return idx[:, index:] # 只返回生成的token
|
||||
|
||||
if __name__ == '__main__':
|
||||
tokenizer = AutoTokenizer.from_pretrained("tokenizer_k")
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
# 第五章 动手搭建大模型
|
||||
|
||||
|
||||
|
||||
## 5.1 动手实现一个 LLaMA2 大模型
|
||||
|
||||
Meta(原Facebook)于2023年2月发布第一款基于Transformer结构的大型语言模型LLaMA,并于同年7月发布同系列模型LLaMA2。我们在第四章已经学习和了解了LLM,以及如何训练LLM等内容。本小节我们就来学习如何动手实现一个LLaMA2模型。
|
||||
@@ -554,20 +556,20 @@ class Transformer(PreTrainedModel):
|
||||
elif isinstance(module, nn.Embedding):
|
||||
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
||||
|
||||
def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None, **keyargs) -> torch.Tensor:
|
||||
def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
|
||||
"""
|
||||
- tokens: Optional[torch.Tensor], 输入 token 张量。
|
||||
- targets: Optional[torch.Tensor], 目标 token 张量。
|
||||
- kv_cache: bool, 是否使用键值缓存。
|
||||
- keyargs: 其他关键字参数。
|
||||
- kwargs: 其他关键字参数。
|
||||
|
||||
- self.OUT: CausalLMOutputWithPast, 包含 logits 和损失。
|
||||
"""
|
||||
|
||||
if 'input_ids' in keyargs:
|
||||
tokens = keyargs['input_ids']
|
||||
if 'attention_mask' in keyargs:
|
||||
targets = keyargs['attention_mask']
|
||||
if 'input_ids' in kwargs:
|
||||
tokens = kwargs['input_ids']
|
||||
if 'attention_mask' in kwargs:
|
||||
targets = kwargs['attention_mask']
|
||||
|
||||
# 前向传播函数
|
||||
_bsz, seqlen = tokens.shape
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
transformers
|
||||
datasets
|
||||
torch
|
||||
torchdata
|
||||
torchdata==0.9.0
|
||||
deepspeed
|
||||
pandas
|
||||
swanlab
|
||||
swanlab
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
# 第六章 大模型训练流程实践
|
||||
|
||||
第五章中,我们从零开始动手搭建了 LLaMA2 模型,并完整实现了其预训练和微调的全流程。在本章中,我们将深入探讨大模型的训练流程实践,重点介绍如何利用主流的大模型框架高效地进行模型训练和性能优化。
|
||||
|
||||
## 6.1 模型预训练
|
||||
|
||||
在上一章,我们逐步拆解了 LLM 的模型结构及训练过程,从零手写实现了 LLaMA 模型结构及 Pretrain、SFT 全流程,更深入地理解了 LLM 的模型原理及训练细节。但是,在实际应用中,手写实现的 LLM 训练存在以下问题:
|
||||
|
||||
@@ -51,7 +51,24 @@ class Agent:
|
||||
stream=False,
|
||||
)
|
||||
if response.choices[0].message.tool_calls:
|
||||
self.messages.append({"role": "assistant", "content": response.choices[0].message.content})
|
||||
# 将包含 tool_calls 的完整 assistant 消息添加到历史中
|
||||
assistant_message = {
|
||||
"role": "assistant",
|
||||
"content": response.choices[0].message.content,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": tool_call.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_call.function.name,
|
||||
"arguments": tool_call.function.arguments
|
||||
}
|
||||
}
|
||||
for tool_call in response.choices[0].message.tool_calls
|
||||
]
|
||||
}
|
||||
self.messages.append(assistant_message)
|
||||
|
||||
# 处理工具调用
|
||||
tool_list = []
|
||||
for tool_call in response.choices[0].message.tool_calls:
|
||||
|
||||
@@ -13,7 +13,7 @@ st.set_page_config(
|
||||
|
||||
# --- OpenAI客户端初始化 ---
|
||||
client = OpenAI(
|
||||
api_key="sk-quovvfgjdmmrvwiljusggiwvxfiekzicwjgtdvpfqhpmbpqu",
|
||||
api_key="your siliconflow api key",
|
||||
base_url="https://api.siliconflow.cn/v1",
|
||||
)
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# 此处默认使用国内可访问的轨迹流动平台 https://cloud.siliconflow.cn/
|
||||
# 此处默认使用国内可访问的硅基流动平台 https://cloud.siliconflow.cn/
|
||||
|
||||
OPENAI_API_KEY='your api key'
|
||||
OPENAI_BASE_URL='https://api.siliconflow.cn/v1'
|
||||
@@ -32,20 +32,12 @@ class ReadFiles:
|
||||
self.file_list = self.get_files()
|
||||
|
||||
def get_files(self):
|
||||
# args:dir_path,目标文件夹路径
|
||||
file_list = []
|
||||
for filepath, dirnames, filenames in os.walk(self._path):
|
||||
# os.walk 函数将递归遍历指定文件夹
|
||||
for filename in filenames:
|
||||
# 通过后缀名判断文件类型是否满足要求
|
||||
if filename.endswith(".md"):
|
||||
# 如果满足要求,将其绝对路径加入到结果列表
|
||||
file_list.append(os.path.join(filepath, filename))
|
||||
elif filename.endswith(".txt"):
|
||||
file_list.append(os.path.join(filepath, filename))
|
||||
elif filename.endswith(".pdf"):
|
||||
file_list.append(os.path.join(filepath, filename))
|
||||
return file_list
|
||||
file_list=[]
|
||||
for file_path,dir_names,file_names in os.walk(self.path):
|
||||
for file_name in file_names:
|
||||
if any([file_name.endswith(suffix) for suffix in [".md",".pdf",".txt"]]):
|
||||
file_list.append(os.path.join(file_path,file_name))
|
||||
return file_list
|
||||
|
||||
def get_content(self, max_token_len: int = 600, cover_content: int = 150):
|
||||
docs = []
|
||||
@@ -146,13 +138,10 @@ class ReadFiles:
|
||||
|
||||
@classmethod
|
||||
def read_pdf(cls, file_path: str):
|
||||
# 读取PDF文件
|
||||
with open(file_path, 'rb') as file:
|
||||
reader = PyPDF2.PdfReader(file)
|
||||
text = ""
|
||||
for page_num in range(len(reader.pages)):
|
||||
text += reader.pages[page_num].extract_text()
|
||||
return text
|
||||
with open(file_path,"rb") as file:
|
||||
reader=PyPDF2.PdfReader(file)
|
||||
return "".join([page.extract_text() for page in reader.pages])
|
||||
|
||||
|
||||
@classmethod
|
||||
def read_markdown(cls, file_path: str):
|
||||
@@ -185,3 +174,4 @@ class Documents:
|
||||
with open(self.path, mode='r', encoding='utf-8') as f:
|
||||
content = json.load(f)
|
||||
return content
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
# 大模型应用
|
||||
|
||||
在前面的章节中,我们系统地介绍了大模型的基础知识、训练方法和微调技术。本章将重点探讨大模型在实际应用中的关键技术和框架,涵盖大模型评测、RAG(检索增强生成)以及Agent(智能体)等核心内容,帮助读者深入理解大模型的实际应用场景和实现方法。
|
||||
|
||||
## 7.1 LLM 的评测
|
||||
|
||||
近年来,随着人工智能领域的迅猛发展,大规模预训练语言模型(简称大模型)成为了推动技术进步的核心力量。这些大模型在自然语言处理等任务中展现出了令人惊叹的能力。然而,要准确衡量一个大模型的性能,必须依靠科学而合理的评测。
|
||||
@@ -320,7 +322,7 @@ class OpenAIEmbedding(BaseEmbeddings):
|
||||
|
||||
def get_embedding(self, text: str, model: str = "BAAI/bge-m3") -> List[float]:
|
||||
"""
|
||||
此处默认使用轨迹流动的免费嵌入模型 BAAI/bge-m3
|
||||
此处默认使用硅基流动的免费嵌入模型 BAAI/bge-m3
|
||||
"""
|
||||
if self.is_api:
|
||||
text = text.replace("\n", " ")
|
||||
@@ -817,4 +819,4 @@ User: exit
|
||||
|
||||
[11] Zhiruo Wang, Jun Araki, Zhengbao Jiang, Md Rizwan Parvez, 和 Graham Neubig. (2023). *Learning to Filter Context for Retrieval-Augmented Generation.* arXiv preprint arXiv:2311.08377.
|
||||
|
||||
[12] Ori Ram, Yoav Levine, Itay Dalmedigos, Dor Muhlgay, Amnon Shashua, Kevin Leyton-Brown 和 Yoav Shoham. (2023). *In-Context Retrieval-Augmented Language Models.* arXiv preprint arXiv:2302.00083.
|
||||
[12] Ori Ram, Yoav Levine, Itay Dalmedigos, Dor Muhlgay, Amnon Shashua, Kevin Leyton-Brown 和 Yoav Shoham. (2023). *In-Context Retrieval-Augmented Language Models.* arXiv preprint arXiv:2302.00083.
|
||||
|
||||
|
Before Width: | Height: | Size: 907 KiB After Width: | Height: | Size: 959 KiB |
|
Before Width: | Height: | Size: 907 KiB After Width: | Height: | Size: 959 KiB |
|
Before Width: | Height: | Size: 321 KiB After Width: | Height: | Size: 318 KiB |
BIN
docs/images/star-history-20251017.png
Normal file
|
After Width: | Height: | Size: 373 KiB |
|
Before Width: | Height: | Size: 138 KiB |
202
docs/index.html
@@ -8,20 +8,146 @@
|
||||
<meta name="description" content="Description">
|
||||
<meta name="viewport"
|
||||
content="width=device-width, user-scalable=no, initial-scale=1.0, maximum-scale=1.0, minimum-scale=1.0">
|
||||
|
||||
<!-- 1. 回归原始主题:保证列宽和排版是你最熟悉的样子 -->
|
||||
<link rel="stylesheet" href="//cdn.jsdelivr.net/npm/docsify@latest/lib/themes/vue.css">
|
||||
|
||||
<!-- 2. 手动定义暗黑模式样式 (精准覆盖,不影响布局) -->
|
||||
<style>
|
||||
/* --- 核心变量 --- */
|
||||
:root {
|
||||
--dark-bg: #1a1a1a;
|
||||
--dark-text: #c4c4c4;
|
||||
--dark-sidebar: #141414;
|
||||
--dark-code-bg: #2b2b2b;
|
||||
--dark-border: #333;
|
||||
--theme-color: #42b983; /* 你的主题绿 */
|
||||
}
|
||||
|
||||
/* --- 暗黑模式激活时的样式 --- */
|
||||
body.dark-mode {
|
||||
background-color: var(--dark-bg);
|
||||
color: var(--dark-text);
|
||||
}
|
||||
|
||||
/* 侧边栏变黑 */
|
||||
body.dark-mode .sidebar {
|
||||
background-color: var(--dark-sidebar);
|
||||
border-right: 1px solid var(--dark-border);
|
||||
color: var(--dark-text);
|
||||
}
|
||||
body.dark-mode .sidebar-nav li a {
|
||||
color: #999;
|
||||
}
|
||||
body.dark-mode .sidebar-nav li.active > a {
|
||||
color: var(--theme-color);
|
||||
border-right: 2px solid var(--theme-color);
|
||||
}
|
||||
|
||||
/* 正文内容适配 */
|
||||
body.dark-mode section.content {
|
||||
padding-top: 20px; /* 避免顶栏遮挡 */
|
||||
}
|
||||
|
||||
/* --- 重点:修复代码块颜色 --- */
|
||||
/* 背景变深灰,文字变亮 */
|
||||
body.dark-mode pre {
|
||||
background-color: var(--dark-code-bg) !important;
|
||||
}
|
||||
body.dark-mode code {
|
||||
background-color: var(--dark-code-bg) !important;
|
||||
color: #e0e0e0 !important;
|
||||
}
|
||||
/* 行内代码 (`code`) */
|
||||
body.dark-mode .markdown-section code {
|
||||
color: #f08d49; /* 醒目的橙色 */
|
||||
background-color: rgba(255,255,255,0.1);
|
||||
}
|
||||
/* 代码块内的文字强制变亮,防止原本的黑色字看不清 */
|
||||
body.dark-mode .token.comment,
|
||||
body.dark-mode .token.prolog,
|
||||
body.dark-mode .token.doctype,
|
||||
body.dark-mode .token.cdata {
|
||||
color: #777;
|
||||
}
|
||||
body.dark-mode .token.punctuation {
|
||||
color: #ccc;
|
||||
}
|
||||
body.dark-mode .token.operator,
|
||||
body.dark-mode .token.entity,
|
||||
body.dark-mode .token.url,
|
||||
body.dark-mode .language-css .token.string,
|
||||
body.dark-mode .style .token.string {
|
||||
color: #c4c4c4;
|
||||
}
|
||||
|
||||
/* 标题和引用 */
|
||||
body.dark-mode h1, body.dark-mode h2, body.dark-mode h3, body.dark-mode h4, body.dark-mode h5 {
|
||||
color: #e0e0e0;
|
||||
}
|
||||
body.dark-mode blockquote {
|
||||
color: #999;
|
||||
background: rgba(255,255,255,0.05);
|
||||
}
|
||||
|
||||
/* 表格修复 */
|
||||
body.dark-mode .markdown-section tr:nth-child(2n) {
|
||||
background-color: rgba(255,255,255,0.03);
|
||||
}
|
||||
body.dark-mode .markdown-section td,
|
||||
body.dark-mode .markdown-section th {
|
||||
border-color: var(--dark-border);
|
||||
}
|
||||
|
||||
/* Mermaid 图表反色 */
|
||||
body.dark-mode .mermaid {
|
||||
filter: invert(1) hue-rotate(180deg);
|
||||
}
|
||||
|
||||
/* --- 切换按钮样式 (嵌入侧边栏) --- */
|
||||
.sidebar-toggle-btn {
|
||||
cursor: pointer;
|
||||
display: block;
|
||||
text-align: center;
|
||||
padding: 10px 0;
|
||||
margin: 0 15px 10px 15px;
|
||||
font-weight: bold;
|
||||
font-size: 14px;
|
||||
border-radius: 4px;
|
||||
background-color: rgba(0,0,0,0.05);
|
||||
color: #505d6b;
|
||||
border: 1px solid rgba(0,0,0,0.05);
|
||||
transition: all 0.3s;
|
||||
}
|
||||
body.dark-mode .sidebar-toggle-btn {
|
||||
background-color: rgba(255,255,255,0.1);
|
||||
color: #ccc;
|
||||
border: 1px solid #444;
|
||||
}
|
||||
.sidebar-toggle-btn:hover {
|
||||
background-color: var(--theme-color);
|
||||
color: white;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
|
||||
<body>
|
||||
<div id="app"></div>
|
||||
|
||||
<!-- Mermaid JS -->
|
||||
<script src="//cdn.jsdelivr.net/npm/mermaid@8.0.0-rc.8/dist/mermaid.min.js"></script>
|
||||
|
||||
<script>
|
||||
var num = 0;
|
||||
mermaid.initialize({ startOnLoad: false });
|
||||
|
||||
window.$docsify = {
|
||||
name: 'Happy-LLM',
|
||||
repo: 'https://github.com/datawhalechina/happy-llm',
|
||||
loadSidebar: true,
|
||||
auto2top: true,
|
||||
subMaxLevel: 2,
|
||||
relativePath: false, // 启用相对路径支持
|
||||
relativePath: false,
|
||||
alias: {
|
||||
'/.*/_sidebar.md': '/_sidebar.md'
|
||||
},
|
||||
@@ -34,24 +160,80 @@
|
||||
fontsize: '0.9em',
|
||||
color: 'rgb(90,90,90)',
|
||||
language: 'chinese'
|
||||
}
|
||||
},
|
||||
// Mermaid 渲染配置
|
||||
markdown: {
|
||||
renderer: {
|
||||
code: function(code, lang) {
|
||||
if (lang === "mermaid") {
|
||||
return (
|
||||
'<div class="mermaid">' + mermaid.render('mermaid-svg-' + num++, code) + "</div>"
|
||||
);
|
||||
}
|
||||
return this.origin.code.apply(this, arguments);
|
||||
}
|
||||
}
|
||||
},
|
||||
// --- 插件逻辑:插入按钮 + 图片放大 ---
|
||||
plugins: [
|
||||
function(hook, vm) {
|
||||
// 每次路由切换完成后执行
|
||||
hook.doneEach(function() {
|
||||
// 1. 获取侧边栏元素
|
||||
const sidebar = document.querySelector('.sidebar-nav');
|
||||
// 2. 如果已存在按钮则不重复添加
|
||||
if (!sidebar || document.querySelector('.sidebar-toggle-btn')) return;
|
||||
|
||||
// 3. 创建按钮
|
||||
const btn = document.createElement('div');
|
||||
btn.className = 'sidebar-toggle-btn';
|
||||
|
||||
// 4. 初始化状态
|
||||
const savedTheme = localStorage.getItem('theme-mode');
|
||||
if (savedTheme === 'dark') {
|
||||
document.body.classList.add('dark-mode');
|
||||
btn.textContent = '🌙 Switch to Light';
|
||||
} else {
|
||||
btn.textContent = '☀️ Switch to Dark';
|
||||
}
|
||||
|
||||
// 5. 点击事件
|
||||
btn.onclick = function() {
|
||||
document.body.classList.toggle('dark-mode');
|
||||
const isDark = document.body.classList.contains('dark-mode');
|
||||
// 保存状态
|
||||
localStorage.setItem('theme-mode', isDark ? 'dark' : 'light');
|
||||
// 更新按钮文字
|
||||
btn.textContent = isDark ? '🌙 Switch to Light' : '☀️ Switch to Dark';
|
||||
};
|
||||
|
||||
// 6. 插入到侧边栏最顶部
|
||||
sidebar.insertBefore(btn, sidebar.firstChild);
|
||||
});
|
||||
}
|
||||
]
|
||||
}
|
||||
</script>
|
||||
<!-- Put them above docsify.min.js -->
|
||||
|
||||
<!-- Docsify 核心 -->
|
||||
<script src="//cdn.jsdelivr.net/npm/docsify@latest/lib/docsify.min.js"></script>
|
||||
<!-- code render-->
|
||||
|
||||
<!-- 图片放大 (官方插件,最稳) -->
|
||||
<script src="//cdn.jsdelivr.net/npm/docsify/lib/plugins/zoom-image.min.js"></script>
|
||||
|
||||
<!-- 复制代码 -->
|
||||
<script src="//cdn.jsdelivr.net/npm/docsify-copy-code"></script>
|
||||
|
||||
<!-- 代码高亮支持 -->
|
||||
<script src="//cdn.jsdelivr.net/npm/prismjs@latest/components/prism-bash.js"></script>
|
||||
<script src="//cdn.jsdelivr.net/npm/prismjs@latest/components/prism-python.js"></script>
|
||||
<script src="//cdn.jsdelivr.net/npm/docsify-pagination@latest/dist/docsify-pagination.min.js"></script>
|
||||
<script src="//cdn.jsdelivr.net/npm/docsify-copy-code"></script>
|
||||
|
||||
<!-- 其他插件 -->
|
||||
<script src="//cdn.jsdelivr.net/npm/docsify-pagination@latest/dist/docsify-pagination.min.js"></script>
|
||||
<script src="https://cdn.jsdelivr.net/npm/katex@latest/dist/katex.min.js"></script>
|
||||
<link rel="stylesheet" href="//cdn.jsdelivr.net/npm/katex@latest/dist/katex.min.css" />
|
||||
<script src="https://cdn.jsdelivr.net/npm/marked@3"></script>
|
||||
<!-- CDN files for docsify-katex -->
|
||||
<script src="//cdn.jsdelivr.net/npm/docsify-katex@latest/dist/docsify-katex.js"></script>
|
||||
<!-- 字数统计 -->
|
||||
<script src="//unpkg.com/docsify-count/dist/countable.js"></script>
|
||||
</body>
|
||||
|
||||
</html>
|
||||
</html>
|
||||
|
||||
BIN
images/star-history-20251017.png
Normal file
|
After Width: | Height: | Size: 373 KiB |
|
Before Width: | Height: | Size: 138 KiB |