跳到主要内容

Runner

Runner 负责训练编排与系统集成;与 BaseModel/Dataset 通过最小接口耦合。

基类(回调生命周期)

wall_e/runner/base_runner.py (L4–86)
class RunnerBase:
def register_callbacks(self, callbacks): ...
def on_exception(self, *args, **kwargs): ...
def before_fit(self): ...; def after_fit(self): ...
def before_train(self): ...; def after_train(self): ...
def before_running_batch(self): ...; def after_running_batch(self): ...
def before_running_epoch(self): ...; def after_running_epoch(self): ...
def before_valid(self): ...; def after_valid(self): ...
def before_test(self): ...; def after_test(self): ...

要点:

  • 所有回调均继承 BaseCallBack;注册时会校验类型并触发 on_register()
  • 训练/验证/测试与批次/轮次的生命周期钩子均由基类统一分发。

Runner 构造

wall_e/runner/runner.py (L55–116)
class Runner(RunnerBase):
def __init__(self, model, epochs=None,
train_data_loader=None, valid_data_loader=None, test_data_loader=None,
train_loop=None, valid_loop=None, test_loop=None,
train_evaluator=None, valid_evaluator=None, test_evaluator=None,
optimizer=None, cfg=None, *args, **kwargs):
...

要点:

  • model:实现 train_step/valid_step/test_stepBaseModel 子类实例。
  • *_data_loader:可选的三种阶段数据加载器;未提供则对应阶段不可用。
  • *_loop:可注入自定义循环;默认自动装配。
  • *_evaluator:可选评估器,装配时注入 RunnerState
  • optimizer 可为 None(DeepSpeed 模式下初始化)。
  • cfg 采用 OmegaConf 合并后的配置。

关键方法

wall_e/runner/runner.py (L156–199)
def fit(self): ...  # before_fit(); train(); 异常捕获 on_exception(); finally after_fit()
def train(self) -> nn.Module: ...
def valid(self) -> dict: ...
def test(self) -> dict: ...

注意:

  • train() 要求存在 train_loopvalid()/test() 同理。
  • fit() 包含统一异常记录与回调触发,异常后会再次抛出。

循环装配

wall_e/runner/runner.py (L228–269)
def setup_loop(self):
# 未显式提供时:根据 cfg.training 的频率参数创建默认 Train/Valid/TestLoop
...

注意:

  • 训练频率既支持按 epoch,也支持按 iter;两套触发条件可同时工作。
  • TrainLoop 会创建并回填 self.scheduler

日志与分布式/DeepSpeed

wall_e/runner/runner.py (L270–371)
def setup_logger(self): ...
def setup_launch_strategy(self): ... # 单机/分布式/CPU 设备与 is_main_process
def setup_model_optimizer(self): ... # 检查点重载、激活检查点、DDP/DeepSpeed 包装、默认 AdamW

注意:

  • 分布式时使用 DistributedDataParallel;DeepSpeed 时通过 deepspeed.initialize 管理模型与优化器。
  • wrap_dataloader 在分布式时会注入 DistributedSampler

回调与评估器

wall_e/runner/runner.py (L382–447)
def setup_callbacks(self): ...  # 进度、W&B(可选)、检查点、轮次总结;可扩展
def setup_evaluator(self): ... # 将 RunnerState 下发给各阶段 evaluator

注意:

  • wandb 可选;未安装时会报错提示安装。
  • pt.enable 或存在恢复点时启用 CheckpointCallback