refactor(RAG): 改进文本分块逻辑以正确处理长行和空格

重构文本分块算法,保留空格并优化长行处理
使用 token 级别分割避免跨单词分割问题
添加覆盖内容逻辑以保持上下文连贯性
This commit is contained in:
KMnO4-zx
2025-07-04 09:07:52 +08:00
parent 5c474e4730
commit f50df92095
2 changed files with 106 additions and 31 deletions

View File

@@ -68,37 +68,65 @@ class ReadFiles:
lines = text.splitlines() # 假设以换行符分割文本为行
for line in lines:
line = line.replace(' ', '')
# 保留空格,只移除行首行尾空格
line = line.strip()
line_len = len(enc.encode(line))
if line_len > max_token_len:
# 如果单行长度就超过限制,则将其分割成多个块
num_chunks = (line_len + token_len - 1) // token_len
for i in range(num_chunks):
start = i * token_len
end = start + token_len
# 避免跨单词分割
while not line[start:end].rstrip().isspace():
start += 1
end += 1
if start >= line_len:
break
curr_chunk = curr_chunk[-cover_content:] + line[start:end]
# 先保存当前块(如果有内容)
if curr_chunk:
chunk_text.append(curr_chunk)
# 处理最后一个块
start = (num_chunks - 1) * token_len
curr_chunk = curr_chunk[-cover_content:] + line[start:end]
chunk_text.append(curr_chunk)
curr_chunk = ''
curr_len = 0
if curr_len + line_len <= token_len:
# 将长行按token长度分割
line_tokens = enc.encode(line)
num_chunks = (len(line_tokens) + token_len - 1) // token_len
for i in range(num_chunks):
start_token = i * token_len
end_token = min(start_token + token_len, len(line_tokens))
# 解码token片段回文本
chunk_tokens = line_tokens[start_token:end_token]
chunk_part = enc.decode(chunk_tokens)
# 添加覆盖内容(除了第一个块)
if i > 0 and chunk_text:
prev_chunk = chunk_text[-1]
cover_part = prev_chunk[-cover_content:] if len(prev_chunk) > cover_content else prev_chunk
chunk_part = cover_part + chunk_part
chunk_text.append(chunk_part)
# 重置当前块状态
curr_chunk = ''
curr_len = 0
elif curr_len + line_len + 1 <= token_len: # +1 for newline
# 当前行可以加入当前块
if curr_chunk:
curr_chunk += '\n'
curr_len += 1
curr_chunk += line
curr_chunk += '\n'
curr_len += line_len
curr_len += 1
else:
chunk_text.append(curr_chunk)
curr_chunk = curr_chunk[-cover_content:]+line
curr_len = line_len + cover_content
# 当前行无法加入当前块,开始新块
if curr_chunk:
chunk_text.append(curr_chunk)
# 开始新块,添加覆盖内容
if chunk_text:
prev_chunk = chunk_text[-1]
cover_part = prev_chunk[-cover_content:] if len(prev_chunk) > cover_content else prev_chunk
curr_chunk = cover_part + '\n' + line
curr_len = len(enc.encode(cover_part)) + 1 + line_len
else:
curr_chunk = line
curr_len = line_len
# 添加最后一个块(如果有内容)
if curr_chunk:
chunk_text.append(curr_chunk)

View File

@@ -261,27 +261,74 @@ def get_chunk(cls, text: str, max_token_len: int = 600, cover_content: int = 150
curr_len = 0
curr_chunk = ''
lines = text.split('\n')
token_len = max_token_len - cover_content
lines = text.splitlines() # 假设以换行符分割文本为行
for line in lines:
line = line.replace(' ', '')
# 保留空格,只移除行首行尾空格
line = line.strip()
line_len = len(enc.encode(line))
if line_len > max_token_len:
print('warning line_len = ', line_len)
if curr_len + line_len <= max_token_len:
# 如果单行长度就超过限制,则将其分割成多个块
# 先保存当前块(如果有内容)
if curr_chunk:
chunk_text.append(curr_chunk)
curr_chunk = ''
curr_len = 0
# 将长行按token长度分割
line_tokens = enc.encode(line)
num_chunks = (len(line_tokens) + token_len - 1) // token_len
for i in range(num_chunks):
start_token = i * token_len
end_token = min(start_token + token_len, len(line_tokens))
# 解码token片段回文本
chunk_tokens = line_tokens[start_token:end_token]
chunk_part = enc.decode(chunk_tokens)
# 添加覆盖内容(除了第一个块)
if i > 0 and chunk_text:
prev_chunk = chunk_text[-1]
cover_part = prev_chunk[-cover_content:] if len(prev_chunk) > cover_content else prev_chunk
chunk_part = cover_part + chunk_part
chunk_text.append(chunk_part)
# 重置当前块状态
curr_chunk = ''
curr_len = 0
elif curr_len + line_len + 1 <= token_len: # +1 for newline
# 当前行可以加入当前块
if curr_chunk:
curr_chunk += '\n'
curr_len += 1
curr_chunk += line
curr_chunk += '\n'
curr_len += line_len
curr_len += 1
else:
chunk_text.append(curr_chunk)
curr_chunk = curr_chunk[-cover_content:] + line
curr_len = line_len + cover_content
# 当前行无法加入当前块,开始新块
if curr_chunk:
chunk_text.append(curr_chunk)
# 开始新块,添加覆盖内容
if chunk_text:
prev_chunk = chunk_text[-1]
cover_part = prev_chunk[-cover_content:] if len(prev_chunk) > cover_content else prev_chunk
curr_chunk = cover_part + '\n' + line
curr_len = len(enc.encode(cover_part)) + 1 + line_len
else:
curr_chunk = line
curr_len = line_len
# 添加最后一个块(如果有内容)
if curr_chunk:
chunk_text.append(curr_chunk)
return chunk_text
```
#### Step 4: 数据库与向量检索