优化了大规模数据读取逻辑,解决了一次性加载所有数据导致内存占用过大以及训练过程中内存占用持续上升的问题

This commit is contained in:
sjjjoaps
2026-01-02 11:30:42 +08:00
parent 47164fcca5
commit 94e6e4a5be

View File

@@ -1,7 +1,6 @@
import json
import random
import re
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
@@ -15,14 +14,22 @@ class PretrainDataset(Dataset):
self.tokenizer = tokenizer
self.max_length = max_length
self.padding = 0
with open(data_path, 'r', encoding='utf-8') as f:
self.data = f.readlines()
# 预计算每行的起始字节偏移量
self._offsets = []
with open(data_path, 'rb') as f:
self._offsets.append(0)
while f.readline():
self._offsets.append(f.tell())
self._total_lines = len(self._offsets) - 1 # 最后一个 tell() 是 EOF
def __len__(self):
return len(self.data)
return self._total_lines
def __getitem__(self, index: int):
sample = json.loads(self.data[index])
with open(self.data_path, 'rb') as f:
f.seek(self._offsets[index])
line = f.readline().decode('utf-8')
sample = json.loads(line)
text = f"{self.tokenizer.bos_token}{sample['text']}"
input_id = self.tokenizer(text).data['input_ids'][:self.max_length]
text_len = len(input_id)
@@ -37,8 +44,7 @@ class PretrainDataset(Dataset):
Y = np.array(input_id[1:]).astype(np.int64)
loss_mask = np.array(loss_mask[1:]).astype(np.int64)
return torch.from_numpy(X), torch.from_numpy(Y), torch.from_numpy(loss_mask)
class SFTDataset(Dataset):
def __init__(self, data_path, tokenizer, max_length=512):
super().__init__()
@@ -46,11 +52,15 @@ class SFTDataset(Dataset):
self.tokenizer = tokenizer
self.max_length = max_length
self.padding = 0
with open(data_path, 'r', encoding='utf-8') as f:
self.data = f.readlines()
self._offsets = []
with open(data_path, 'rb') as f:
self._offsets.append(0)
while f.readline():
self._offsets.append(f.tell())
self._total_lines = len(self._offsets) - 1
def __len__(self):
return len(self.data)
return self._total_lines
def generate_loss_mask(self, input_ids):
# 生成 loss mask, 0 表示不计算损失, 1 表示计算损失
@@ -89,7 +99,10 @@ class SFTDataset(Dataset):
return mask
def __getitem__(self, index: int):
sample = json.loads(self.data[index])
with open(self.data_path, 'rb') as f:
f.seek(self._offsets[index])
line = f.readline().decode('utf-8')
sample = json.loads(line)
text = self.tokenizer.apply_chat_template(sample, tokenize=False, add_generation_prompt=False)
input_id = self.tokenizer(text).data['input_ids'][:self.max_length]
text_len = len(input_id)