chapter 6 use swanlab

This commit is contained in:
ZeYi Lin
2025-07-03 18:18:44 +08:00
parent 0d2471d3ee
commit db3a162cd8

View File

@@ -289,7 +289,7 @@ from transformers import (
import datetime
from transformers.testing_utils import CaptureLogger
from transformers.trainer_utils import get_last_checkpoint
import wandb
import swanlab
```
首先需要定义几个超参的类型,用于处理 sh 脚本中设定的超参值。由于 transformers 本身有 TraingingArguments 类,其中包括了训练的一些必备超参数。我们这里只需定义 TrainingArguments 中未包含的超参即可,主要包括模型相关的超参(定义在 ModelArguments和数据相关的超参定义在 DataTrainingArguments
@@ -456,14 +456,14 @@ trainer.save_model()
```
注意,由于上文检测了是否存在 checkpoint此处使用 resume_from_checkpoint 来实现从 checkpoint 恢复训练的功能。
由于在大规模训练中监测训练进度、loss 下降趋势尤为重要,在脚本中,我们使用了 wandb 作为训练检测的工具。在脚本开始进行了 wandb 的初始化:
由于在大规模训练中监测训练进度、loss 下降趋势尤为重要,在脚本中,我们使用了 swanlab 作为训练检测的工具。在脚本开始进行了 swanlab 的初始化:
```python
# 初始化 WandB
wandb.init(project="pretrain", name="from_scrach")
# 初始化 SwanLab
swanlab.init(project="pretrain", name="from_scrach")
```
在启动训练后,终端会输出 wandb 监测的 url点击即可观察训练进度。此处不再赘述 wandb 的使用细节,欢迎读者查阅相关的资料说明。
在启动训练后,终端会输出 swanlab 监测的 url点击即可观察训练进度。此处不再赘述 swanlab 的使用细节,欢迎读者查阅相关的资料说明。
完成上述代码后,我们使用一个 sh 脚本(`./code/pretrain.sh`)定义超参数的值,并通过 Deepspeed 启动训练,从而实现高效的多卡分布式训练:
@@ -495,7 +495,7 @@ deepspeed pretrain.py \
--bf16 \
--gradient_checkpointing \
--deepspeed ./ds_config_zero2.json \
--report_to wandb
--report_to swanlab
# --resume_from_checkpoint ${output_model}/checkpoint-20400 \
```
在安装了 Deepspeed 第三方库后,可以直接通过 Deepspeed 命令来启动多卡训练。上述脚本命令主要是定义了各种超参数的值,可参考使用。在第四章中,我们介绍了 DeepSpeed 分布式训练的原理和 ZeRO 阶段设置,在这里,我们使用 ZeRO-2 进行训练。此处加载了 `ds_config_zero.json` 作为 DeepSpeed 的配置参数:
@@ -690,8 +690,8 @@ class SupervisedDataset(Dataset):
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
# 初始化 WandB
wandb.init(project="sft", name="qwen-1.5b")
# 初始化 SwanLab
swanlab.init(project="sft", name="qwen-1.5b")
# 设置日志
logging.basicConfig(