概述
12 分钟 等级: 入门
WALL_E 瓦力是一个基于 PyTorch、Datasets 与 OmegaConf 构建的轻量级通用深度学习框架,旨在简化模型的开发、训练与评估。框架采用模块化设计:BaseModel、BaseDataset、load_cfg 可独立使用;与 Runner/Loops、分布式与回调/评估组合时,提供完整训练编排。
系统架构
该框架采用模块化设计模式,将关注点分离为独立组件,实现灵活组合和易于维护。其核心是 Runner,通过基于回调的事件管理协调整个训练生命周期。
代码结构(API 链接)
- Runner(
wall_e.runner.runner.Runner): 训练执行引擎,负责设备/分布式初始化、日志、优化器、回调与评估装配,统一调度TrainLoop/ValidLoop/TestLoop等。(见 API →api/runner) - Loops(
wall_e.runner.loop.*):TrainLoop: 管理 epoch/iter 级训练,支持混合精度、梯度累积、验证/测试的时间调度、学习率调度。(见 API →api/loops)ValidLoop/TestLoop: 独立的验证与测试流程。
- Callbacks(
wall_e.callback.*):(见 API → Supporting → 回调)EpochSummaryCallBack: 汇总 loss 与指标并上报(含 WandB,当开启时)。ProcessCallBack: 训练进度与关键状态打印。CheckpointCallback: 保存/恢复训练状态与权重。WandbCallback: Weights & Biases 集成(可选)。
- Dataset(
wall_e.dataset.*): 提供BaseMapDataset、BaseIterableDataset基类,抽象统一数据接口与切分策略。(见 API → Dataset) - Evaluator & Metric(
wall_e.eval.*): 统一注册与汇报评估指标,支持结果转储。(见 API → Eval) - Config(
wall_e.config.*): 使用 OmegaConf 读取 YAML,集中管理训练、日志、分布式、回调、检查点等参数。(见 API →api/config) - Logging(
wall_e.logging.*): 控制台/文件日志,主从进程日志级别分离。 - Distributed(
wall_e.dist.*): 自动选择单机/多卡 DDP/DeepSpeed 并包装模型与优化器。 - Scheduler(
wall_e.scheduler): 通过注册表配置学习率调度器。 - Tuner(
wall_e.tunner.Tuner): 配合 Ray 等进行参数搜索(可选)。
运行流程概览
- 读取 YAML,构建/注册
dataset、model、evaluator等组件。 - 初始化
Runner:完成设备/分布式、模型与优化器、循环(Loop)、回调与评估器装配。 - 执行
runner.fit():进入训练主循环;按配置在 epoch/iter 级触发验证与测试;自动记录日志与保存检查点。 - 需要时可调用
runner.valid()与runner.test()进行独立评估。
扩展点
- 自定义训练逻辑:继承
TrainLoop/ValidLoop/TestLoop。 - 自定义回调:实现
BaseCallBack并通过Runner.extend_callbacks()注册。 - 自定义指标与评估:实现
BaseMetric并在Evaluator中装配。 - 自定义数据集/模型:继承相应基类并在配置中声明
type与params。