优化了大规模数据读取逻辑,解决了一次性加载所有数据导致内存占用过大以及训练过程中内存占用持续上升的问题
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
import json
|
||||
import random
|
||||
import re
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
@@ -15,14 +14,22 @@ class PretrainDataset(Dataset):
|
||||
self.tokenizer = tokenizer
|
||||
self.max_length = max_length
|
||||
self.padding = 0
|
||||
with open(data_path, 'r', encoding='utf-8') as f:
|
||||
self.data = f.readlines()
|
||||
# 预计算每行的起始字节偏移量
|
||||
self._offsets = []
|
||||
with open(data_path, 'rb') as f:
|
||||
self._offsets.append(0)
|
||||
while f.readline():
|
||||
self._offsets.append(f.tell())
|
||||
self._total_lines = len(self._offsets) - 1 # 最后一个 tell() 是 EOF
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
return self._total_lines
|
||||
|
||||
def __getitem__(self, index: int):
|
||||
sample = json.loads(self.data[index])
|
||||
with open(self.data_path, 'rb') as f:
|
||||
f.seek(self._offsets[index])
|
||||
line = f.readline().decode('utf-8')
|
||||
sample = json.loads(line)
|
||||
text = f"{self.tokenizer.bos_token}{sample['text']}"
|
||||
input_id = self.tokenizer(text).data['input_ids'][:self.max_length]
|
||||
text_len = len(input_id)
|
||||
@@ -37,8 +44,7 @@ class PretrainDataset(Dataset):
|
||||
Y = np.array(input_id[1:]).astype(np.int64)
|
||||
loss_mask = np.array(loss_mask[1:]).astype(np.int64)
|
||||
return torch.from_numpy(X), torch.from_numpy(Y), torch.from_numpy(loss_mask)
|
||||
|
||||
|
||||
|
||||
class SFTDataset(Dataset):
|
||||
def __init__(self, data_path, tokenizer, max_length=512):
|
||||
super().__init__()
|
||||
@@ -46,11 +52,15 @@ class SFTDataset(Dataset):
|
||||
self.tokenizer = tokenizer
|
||||
self.max_length = max_length
|
||||
self.padding = 0
|
||||
with open(data_path, 'r', encoding='utf-8') as f:
|
||||
self.data = f.readlines()
|
||||
self._offsets = []
|
||||
with open(data_path, 'rb') as f:
|
||||
self._offsets.append(0)
|
||||
while f.readline():
|
||||
self._offsets.append(f.tell())
|
||||
self._total_lines = len(self._offsets) - 1
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
return self._total_lines
|
||||
|
||||
def generate_loss_mask(self, input_ids):
|
||||
# 生成 loss mask, 0 表示不计算损失, 1 表示计算损失
|
||||
@@ -89,7 +99,10 @@ class SFTDataset(Dataset):
|
||||
return mask
|
||||
|
||||
def __getitem__(self, index: int):
|
||||
sample = json.loads(self.data[index])
|
||||
with open(self.data_path, 'rb') as f:
|
||||
f.seek(self._offsets[index])
|
||||
line = f.readline().decode('utf-8')
|
||||
sample = json.loads(line)
|
||||
text = self.tokenizer.apply_chat_template(sample, tokenize=False, add_generation_prompt=False)
|
||||
input_id = self.tokenizer(text).data['input_ids'][:self.max_length]
|
||||
text_len = len(input_id)
|
||||
|
||||
Reference in New Issue
Block a user