跳到主要内容

配置文件

5 分钟 等级: 入门

快速开始

创建一个配置文件 cfg.yaml

# 实验信息
run_name: 'your task' # 实验名称
run_description: "your custom tasks" # 实验描述

# 核心组件(需要你实现具体 type)
dataset:
type: "" # 数据集类型
model:
type: "" # 模型类型
metric:
type: "" # 评估指标类型

# 优化器,默认AdamW
optimizer:
type: 'AdamW'
lr: 3e-4 # 学习率
weight_decay: 0.01 # 权重衰减

# 学习率调度
scheduler:
type: "LinearWarmupCosineLRScheduler" # 调度器类型
min_lr: 3e-5 # 最小学习率
max_lr: 3e-4 # 最大学习率
warmup_rate: 0.05 # 预热比例(0-1)
warmup_start_lr: 1e-5 # 预热起始学习率

# 训练参数
training:
epochs: 20 # 训练轮数
ds_config: 'path/to/deepspeed.json' # DeepSpeed 配置(可选),若提供将并入配置
gradient_accumulation: 1 # 梯度累积步数
resume_from: null # 从某检查点继续训练
load_from: null # 预加载权重路径
grad_clip: null # 梯度裁剪阈值(null 表示不启用)
fp16: false # 是否使用半精度
print_model: false # 是否打印模型结构

# 训练/验证/测试进度展示与频率
progress_show:
loss: false # 训练进度中是否展示 loss 曲线/进度(false表示loss不是越大越好)

# 校验/测试触发频率(按 epoch 与 iter 两种粒度)
valid_begin_epoch: 2
valid_interval_epoch: 1
valid_begin_iter: 10
valid_interval_iter: 5
test_begin_epoch: 3
test_interval_epoch: 1
test_begin_iter: 60
test_interval_iter: 5
progress_every_n_epochs: 1 # 进度汇报间隔(按 epoch)
progress_every_n_batches: 1 # 进度汇报间隔(按 iter)

# 日志
log:
to_file: true # 是否写入文件
folder: "./assert/logs" # 日志目录
level: "INFO" # 控制台日志级别
rank_level: "WARNING" # 分布式时的非主进程日志级别

# 检查点
pt:
enable: true # 是否启用保存
dir: "./assert/checkpoints" # 保存目录
best_monitor:
loss: false # 以最低 loss 作为选择标准
topk: 5 # 保留前 K 个最佳模型
begin_epoch: 2 # 从第几轮开始按 epoch 保存
epoch_interval: 2 # 每隔多少个 epoch 保存一次
begin_batch: 10 # 从第几个 iter 开始按 batch 保存
batch_interval: 10 # 每隔多少个 iter 保存一次

# Wandb
wandb:
enable: true # 是否启用 Wandb
proj_name: "" # 项目名称
offline: true # 是否离线模式
dir: "./assert" # 日志/工件目录
tags: ["", ""] # 标签

train.py代码中加载配置:

from wall_e.config.load_config import load_cfg

# 加载配置
cfg = load_cfg('./cfg.yaml')

# 使用配置
lr = cfg.optimizer.lr
epochs = cfg.training.epochs
print(f"学习率: {lr}, 训练轮数: {epochs}")

命令行覆盖

你可以用命令行参数覆盖任何配置:

# 修改学习率
python train.py optimizer.lr=1e-4

# 修改训练轮数
python train.py training.epochs=100

# 同时修改多个参数
python train.py training.epochs=50 optimizer.lr=2e-4 training.fp16=true

# 修改嵌套配置
python train.py wandb.enable=true wandb.proj_name=my_experiment

配置优先级

配置按以下顺序合并(后面的会覆盖前面的):

  1. 基础配置文件
  2. DeepSpeed 配置
  3. 命令行参数

配置操作(增删改查与保存)

配置的操作与 OmegaConfDictConfig API 保持一致。以下示例涵盖常用的读取、修改、添加、删除和保存,更多能力请查阅OmegaConf文档

from omegaconf import OmegaConf

# 读取(含默认值与存在性判断)
lr = cfg.optimizer.lr
epochs = cfg.training.epochs
batch_size = cfg.training.get("batch_size", 32)
if "wandb" in cfg:
print("Wandb 已启用")

# 修改(单项与批量)
cfg.optimizer.lr = 1e-3
cfg.training.epochs = 100
cfg.update({
"optimizer.lr": 1e-3,
"training.epochs": 100,
"model.type": "transformer",
})

# 添加(顶层与嵌套)
cfg.new_section = {"param1": "value1", "param2": "value2"}
cfg.model.new_param = "new_value"

# 删除(直接删除与安全删除)
del cfg.training.grad_clip
cfg.pop("training.grad_clip", None)

# 保存与导出
OmegaConf.save(cfg, "output_config.yaml")
config_dict = OmegaConf.to_container(cfg, resolve=True)

常见问题(FAQ)

  • 命令行覆盖如何与文件/DeepSpeed 配置交互?
    • 合并顺序:基础配置文件 < DeepSpeed 配置(若存在) < 命令行参数;后者覆盖前者。
  • 如何只修改深层嵌套字段?
    • 使用 cfg.update({"a.b.c": value}) 或直接 cfg.a.b.c = value
  • 为什么修改没有生效?
    • 确认修改在 load_cfg 之后、构建组件之前;并检查是否被后续合并或 CLI 覆盖。
  • 如何检测某个可选字段是否存在?
    • 使用 "field" in cfg.sectioncfg.section.get("field", default)
  • 如何保存最终生效配置以便复现实验?
    • 使用 OmegaConf.save(cfg, path) 保存合并后的最终配置。