Merge pull request #155 from sjjjoaps/main

优化了大规模数据读取逻辑,解决了一次性加载所有数据导致内存占用过大以及训练过程中内存占用持续上升的问题
This commit was merged in pull request #155.
This commit is contained in:
不要葱姜蒜
2026-01-03 11:28:55 +08:00
committed by GitHub
2 changed files with 50 additions and 22 deletions

View File

@@ -1,7 +1,6 @@
import json
import random
import re
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
@@ -15,14 +14,22 @@ class PretrainDataset(Dataset):
self.tokenizer = tokenizer
self.max_length = max_length
self.padding = 0
with open(data_path, 'r', encoding='utf-8') as f:
self.data = f.readlines()
# 预计算每行的起始字节偏移量
self._offsets = []
with open(data_path, 'rb') as f:
self._offsets.append(0)
while f.readline():
self._offsets.append(f.tell())
self._total_lines = len(self._offsets) - 1 # 最后一个 tell() 是 EOF
def __len__(self):
return len(self.data)
return self._total_lines
def __getitem__(self, index: int):
sample = json.loads(self.data[index])
with open(self.data_path, 'rb') as f:
f.seek(self._offsets[index])
line = f.readline().decode('utf-8')
sample = json.loads(line)
text = f"{self.tokenizer.bos_token}{sample['text']}"
input_id = self.tokenizer(text).data['input_ids'][:self.max_length]
text_len = len(input_id)
@@ -37,8 +44,7 @@ class PretrainDataset(Dataset):
Y = np.array(input_id[1:]).astype(np.int64)
loss_mask = np.array(loss_mask[1:]).astype(np.int64)
return torch.from_numpy(X), torch.from_numpy(Y), torch.from_numpy(loss_mask)
class SFTDataset(Dataset):
def __init__(self, data_path, tokenizer, max_length=512):
super().__init__()
@@ -46,11 +52,15 @@ class SFTDataset(Dataset):
self.tokenizer = tokenizer
self.max_length = max_length
self.padding = 0
with open(data_path, 'r', encoding='utf-8') as f:
self.data = f.readlines()
self._offsets = []
with open(data_path, 'rb') as f:
self._offsets.append(0)
while f.readline():
self._offsets.append(f.tell())
self._total_lines = len(self._offsets) - 1
def __len__(self):
return len(self.data)
return self._total_lines
def generate_loss_mask(self, input_ids):
# 生成 loss mask, 0 表示不计算损失, 1 表示计算损失
@@ -89,7 +99,10 @@ class SFTDataset(Dataset):
return mask
def __getitem__(self, index: int):
sample = json.loads(self.data[index])
with open(self.data_path, 'rb') as f:
f.seek(self._offsets[index])
line = f.readline().decode('utf-8')
sample = json.loads(line)
text = self.tokenizer.apply_chat_template(sample, tokenize=False, add_generation_prompt=False)
input_id = self.tokenizer(text).data['input_ids'][:self.max_length]
text_len = len(input_id)

View File

@@ -1307,14 +1307,22 @@ class PretrainDataset(Dataset):
self.tokenizer = tokenizer
self.max_length = max_length
self.padding = 0
with open(data_path, 'r', encoding='utf-8') as f:
self.data = f.readlines()
# 预计算每行的起始字节偏移量
self._offsets = []
with open(data_path, 'rb') as f:
self._offsets.append(0)
while f.readline():
self._offsets.append(f.tell())
self._total_lines = len(self._offsets) - 1 # 最后一个 tell() 是 EOF
def __len__(self):
return len(self.data)
return self._total_lines
def __getitem__(self, index: int):
sample = json.loads(self.data[index])
with open(self.data_path, 'rb') as f:
f.seek(self._offsets[index])
line = f.readline().decode('utf-8')
sample = json.loads(line)
text = f"{self.tokenizer.bos_token}{sample['text']}"
input_id = self.tokenizer(text).data['input_ids'][:self.max_length]
text_len = len(input_id)
@@ -1358,16 +1366,20 @@ class SFTDataset(Dataset):
self.tokenizer = tokenizer
self.max_length = max_length
self.padding = 0
with open(data_path, 'r', encoding='utf-8') as f:
self.data = f.readlines()
self._offsets = []
with open(data_path, 'rb') as f:
self._offsets.append(0)
while f.readline():
self._offsets.append(f.tell())
self._total_lines = len(self._offsets) - 1
def __len__(self):
return len(self.data)
return self._total_lines
def generate_loss_mask(self, input_ids):
# 生成 loss mask, 0 表示不计算损失, 1 表示计算损失
mask = [0] * len(input_ids)
a_sequence = [3, 1074, 537, 500, 203] # <|im_start|>assistant\n
a_sequence = self.tokenizer("<|im_start|>assistant\n")['input_ids'] # <|im_start|>assistant\n
a_length = len(a_sequence)
n = len(input_ids)
i = 0
@@ -1380,10 +1392,10 @@ class SFTDataset(Dataset):
match = False
break
if match:
# 从子序列结束的位置开始查找第一个4, 4 为 <|im_end|> EOS id
# 从子序列结束的位置开始查找第一个 4 (eos_token_id)
j = None
for idx in range(i + a_length, n):
if input_ids[idx] == 4:
if input_ids[idx] == self.tokenizer.eos_token_id:
j = idx
break
if j is not None:
@@ -1401,7 +1413,10 @@ class SFTDataset(Dataset):
return mask
def __getitem__(self, index: int):
sample = json.loads(self.data[index])
with open(self.data_path, 'rb') as f:
f.seek(self._offsets[index])
line = f.readline().decode('utf-8')
sample = json.loads(line)
text = self.tokenizer.apply_chat_template(sample, tokenize=False, add_generation_prompt=False)
input_id = self.tokenizer(text).data['input_ids'][:self.max_length]
text_len = len(input_id)