From 94e6e4a5be40cc259bcbe9b06ac3f809898444f1 Mon Sep 17 00:00:00 2001 From: sjjjoaps Date: Fri, 2 Jan 2026 11:30:42 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E4=BA=86=E5=A4=A7=E8=A7=84?= =?UTF-8?q?=E6=A8=A1=E6=95=B0=E6=8D=AE=E8=AF=BB=E5=8F=96=E9=80=BB=E8=BE=91?= =?UTF-8?q?,=E8=A7=A3=E5=86=B3=E4=BA=86=E4=B8=80=E6=AC=A1=E6=80=A7?= =?UTF-8?q?=E5=8A=A0=E8=BD=BD=E6=89=80=E6=9C=89=E6=95=B0=E6=8D=AE=E5=AF=BC?= =?UTF-8?q?=E8=87=B4=E5=86=85=E5=AD=98=E5=8D=A0=E7=94=A8=E8=BF=87=E5=A4=A7?= =?UTF-8?q?=E4=BB=A5=E5=8F=8A=E8=AE=AD=E7=BB=83=E8=BF=87=E7=A8=8B=E4=B8=AD?= =?UTF-8?q?=E5=86=85=E5=AD=98=E5=8D=A0=E7=94=A8=E6=8C=81=E7=BB=AD=E4=B8=8A?= =?UTF-8?q?=E5=8D=87=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/chapter5/code/dataset.py | 35 ++++++++++++++++++++++++----------- 1 file changed, 24 insertions(+), 11 deletions(-) diff --git a/docs/chapter5/code/dataset.py b/docs/chapter5/code/dataset.py index b021443..25b3ee6 100644 --- a/docs/chapter5/code/dataset.py +++ b/docs/chapter5/code/dataset.py @@ -1,7 +1,6 @@ import json import random import re - import pandas as pd import numpy as np from torch.utils.data import Dataset, DataLoader @@ -15,14 +14,22 @@ class PretrainDataset(Dataset): self.tokenizer = tokenizer self.max_length = max_length self.padding = 0 - with open(data_path, 'r', encoding='utf-8') as f: - self.data = f.readlines() + # 预计算每行的起始字节偏移量 + self._offsets = [] + with open(data_path, 'rb') as f: + self._offsets.append(0) + while f.readline(): + self._offsets.append(f.tell()) + self._total_lines = len(self._offsets) - 1 # 最后一个 tell() 是 EOF def __len__(self): - return len(self.data) + return self._total_lines def __getitem__(self, index: int): - sample = json.loads(self.data[index]) + with open(self.data_path, 'rb') as f: + f.seek(self._offsets[index]) + line = f.readline().decode('utf-8') + sample = json.loads(line) text = f"{self.tokenizer.bos_token}{sample['text']}" input_id = self.tokenizer(text).data['input_ids'][:self.max_length] text_len = len(input_id) @@ -37,8 +44,7 @@ class PretrainDataset(Dataset): Y = np.array(input_id[1:]).astype(np.int64) loss_mask = np.array(loss_mask[1:]).astype(np.int64) return torch.from_numpy(X), torch.from_numpy(Y), torch.from_numpy(loss_mask) - - + class SFTDataset(Dataset): def __init__(self, data_path, tokenizer, max_length=512): super().__init__() @@ -46,11 +52,15 @@ class SFTDataset(Dataset): self.tokenizer = tokenizer self.max_length = max_length self.padding = 0 - with open(data_path, 'r', encoding='utf-8') as f: - self.data = f.readlines() + self._offsets = [] + with open(data_path, 'rb') as f: + self._offsets.append(0) + while f.readline(): + self._offsets.append(f.tell()) + self._total_lines = len(self._offsets) - 1 def __len__(self): - return len(self.data) + return self._total_lines def generate_loss_mask(self, input_ids): # 生成 loss mask, 0 表示不计算损失, 1 表示计算损失 @@ -89,7 +99,10 @@ class SFTDataset(Dataset): return mask def __getitem__(self, index: int): - sample = json.loads(self.data[index]) + with open(self.data_path, 'rb') as f: + f.seek(self._offsets[index]) + line = f.readline().decode('utf-8') + sample = json.loads(line) text = self.tokenizer.apply_chat_template(sample, tokenize=False, add_generation_prompt=False) input_id = self.tokenizer(text).data['input_ids'][:self.max_length] text_len = len(input_id)