配置文件
5 分钟 等级: 入门
快速开始
创建一个配置文件 cfg.yaml:
# 实验信息
run_name: 'your task' # 实验名称
run_description: "your custom tasks" # 实验描述
# 核心组件(需要你实现具体 type)
dataset:
type: "" # 数据集类型
model:
type: "" # 模型类型
metric:
type: "" # 评估指标类型
# 优化器,默认AdamW
optimizer:
type: 'AdamW'
lr: 3e-4 # 学习率
weight_decay: 0.01 # 权重衰减
# 学习率调度
scheduler:
type: "LinearWarmupCosineLRScheduler" # 调度器类型
min_lr: 3e-5 # 最小学习率
max_lr: 3e-4 # 最大学习率
warmup_rate: 0.05 # 预热比例(0-1)
warmup_start_lr: 1e-5 # 预热起始学习率
# 训练参数
training:
epochs: 20 # 训练轮数
ds_config: 'path/to/deepspeed.json' # DeepSpeed 配置(可选),若提供将并入配置
gradient_accumulation: 1 # 梯度累积步数
resume_from: null # 从某检查点继续训练
load_from: null # 预加载权重路径
grad_clip: null # 梯度裁剪阈值(null 表示不启用)
fp16: false # 是否使用半精度
print_model: false # 是否打印模型结构
# 训练/验证/测试进度展示与频率
progress_show:
loss: false # 训练进度中是否展示 loss 曲线/进度(false表示loss不是越大越好)
# 校验/测试触发频率(按 epoch 与 iter 两种粒度)
valid_begin_epoch: 2
valid_interval_epoch: 1
valid_begin_iter: 10
valid_interval_iter: 5
test_begin_epoch: 3
test_interval_epoch: 1
test_begin_iter: 60
test_interval_iter: 5
progress_every_n_epochs: 1 # 进度汇报间隔(按 epoch)
progress_every_n_batches: 1 # 进度汇报间隔(按 iter)
# 日志
log:
to_file: true # 是否写入文件
folder: "./assert/logs" # 日志目录
level: "INFO" # 控制台日志级别
rank_level: "WARNING" # 分布式时的非主进程日志级别
# 检查点
pt:
enable: true # 是否启用保存
dir: "./assert/checkpoints" # 保存目录
best_monitor:
loss: false # 以最低 loss 作为选择标准
topk: 5 # 保留前 K 个最佳模型
begin_epoch: 2 # 从第几轮开始按 epoch 保存
epoch_interval: 2 # 每隔多少个 epoch 保存一次
begin_batch: 10 # 从第几个 iter 开始按 batch 保存
batch_interval: 10 # 每隔多少个 iter 保存一次
# Wandb
wandb:
enable: true # 是否启用 Wandb
proj_name: "" # 项目名称
offline: true # 是否离线模式
dir: "./assert" # 日志/工件目录
tags: ["", ""] # 标签
在train.py代码中加载配置:
from wall_e.config.load_config import load_cfg
# 加载配置
cfg = load_cfg('./cfg.yaml')
# 使用配置
lr = cfg.optimizer.lr
epochs = cfg.training.epochs
print(f"学习率: {lr}, 训练轮数: {epochs}")
命令行覆盖
你可以用命令行参数覆盖任何配置:
# 修改学习率
python train.py optimizer.lr=1e-4
# 修改训练轮数
python train.py training.epochs=100
# 同时修改多个参数
python train.py training.epochs=50 optimizer.lr=2e-4 training.fp16=true
# 修改嵌套配置
python train.py wandb.enable=true wandb.proj_name=my_experiment
配置优先级
配置按以下顺序合并(后面的会覆盖前面的):
- 基础配置文件
- DeepSpeed 配置
- 命令行参数
配置操作(增删改查与保存)
配置的操作与 OmegaConf 的 DictConfig API 保持一致。以下示例涵盖常用的读取、修改、添加、删除和保存,更多能力请查阅OmegaConf文档。
from omegaconf import OmegaConf
# 读取(含默认值与存在性判断)
lr = cfg.optimizer.lr
epochs = cfg.training.epochs
batch_size = cfg.training.get("batch_size", 32)
if "wandb" in cfg:
print("Wandb 已启用")
# 修改(单项与批量)
cfg.optimizer.lr = 1e-3
cfg.training.epochs = 100
cfg.update({
"optimizer.lr": 1e-3,
"training.epochs": 100,
"model.type": "transformer",
})
# 添加(顶层与嵌套)
cfg.new_section = {"param1": "value1", "param2": "value2"}
cfg.model.new_param = "new_value"
# 删除(直接删除与安全删除)
del cfg.training.grad_clip
cfg.pop("training.grad_clip", None)
# 保存与导出
OmegaConf.save(cfg, "output_config.yaml")
config_dict = OmegaConf.to_container(cfg, resolve=True)
常见问题(FAQ)
- 命令行覆盖如何与文件/DeepSpeed 配置交互?
- 合并顺序:基础配置文件 < DeepSpeed 配置(若存在) < 命令行参数;后者覆盖前者。
- 如何只修改深层嵌套字段?
- 使用
cfg.update({"a.b.c": value})或直接cfg.a.b.c = value。
- 使用
- 为什么修改没有生效?
- 确认修改在
load_cfg之后、构建组件之前;并检查是否被后续合并或 CLI 覆盖。
- 确认修改在
- 如何检测某个可选字段是否存在?
- 使用
"field" in cfg.section或cfg.section.get("field", default)。
- 使用
- 如何保存最终生效配置以便复现实验?
- 使用
OmegaConf.save(cfg, path)保存合并后的最终配置。
- 使用