跳到主要内容

概述

12 分钟 等级: 入门

WALL_E 瓦力是一个基于 PyTorch、Datasets 与 OmegaConf 构建的轻量级通用深度学习框架,旨在简化模型的开发、训练与评估。框架采用模块化设计:BaseModelBaseDatasetload_cfg 可独立使用;与 Runner/Loops、分布式与回调/评估组合时,提供完整训练编排。

系统架构

该框架采用模块化设计模式,将关注点分离为独立组件,实现灵活组合和易于维护。其核心是 Runner,通过基于回调的事件管理协调整个训练生命周期。

代码结构(API 链接)

  • Runner(wall_e.runner.runner.Runner: 训练执行引擎,负责设备/分布式初始化、日志、优化器、回调与评估装配,统一调度 TrainLoop/ValidLoop/TestLoop等。(见 API → api/runner
  • Loops(wall_e.runner.loop.*:
    • TrainLoop: 管理 epoch/iter 级训练,支持混合精度、梯度累积、验证/测试的时间调度、学习率调度。(见 API → api/loops
    • ValidLoop/TestLoop: 独立的验证与测试流程。
  • Callbacks(wall_e.callback.*:(见 API → Supporting → 回调)
    • EpochSummaryCallBack: 汇总 loss 与指标并上报(含 WandB,当开启时)。
    • ProcessCallBack: 训练进度与关键状态打印。
    • CheckpointCallback: 保存/恢复训练状态与权重。
    • WandbCallback: Weights & Biases 集成(可选)。
  • Dataset(wall_e.dataset.*: 提供 BaseMapDatasetBaseIterableDataset 基类,抽象统一数据接口与切分策略。(见 API → Dataset)
  • Evaluator & Metric(wall_e.eval.*: 统一注册与汇报评估指标,支持结果转储。(见 API → Eval)
  • Config(wall_e.config.*: 使用 OmegaConf 读取 YAML,集中管理训练、日志、分布式、回调、检查点等参数。(见 API → api/config
  • Logging(wall_e.logging.*: 控制台/文件日志,主从进程日志级别分离。
  • Distributed(wall_e.dist.*: 自动选择单机/多卡 DDP/DeepSpeed 并包装模型与优化器。
  • Scheduler(wall_e.scheduler: 通过注册表配置学习率调度器。
  • Tuner(wall_e.tunner.Tuner: 配合 Ray 等进行参数搜索(可选)。

运行流程概览

  1. 读取 YAML,构建/注册 datasetmodelevaluator 等组件。
  2. 初始化 Runner:完成设备/分布式、模型与优化器、循环(Loop)、回调与评估器装配。
  3. 执行 runner.fit():进入训练主循环;按配置在 epoch/iter 级触发验证与测试;自动记录日志与保存检查点。
  4. 需要时可调用 runner.valid()runner.test() 进行独立评估。

扩展点

  • 自定义训练逻辑:继承 TrainLoop/ValidLoop/TestLoop
  • 自定义回调:实现 BaseCallBack 并通过 Runner.extend_callbacks() 注册。
  • 自定义指标与评估:实现 BaseMetric 并在 Evaluator 中装配。
  • 自定义数据集/模型:继承相应基类并在配置中声明 typeparams

下一步