Skip to main content

回调(Callbacks)

回调由 Runner.setup_callbacks 统一注册,可扩展,覆盖训练的关键生命周期。

wall_e/runner/runner.py (L382–425)
def setup_callbacks(self):
epoch_summary_callback = EpochSummaryCallBack(self)
progress_callback = ProcessCallBack(self)
wandb_callback = WandbCallback(self) if cfg.wandb.enable else None
checkpoint_callback = CheckpointCallback(self) if cfg.pt.enable or resume_from else None
self.register_callbacks([...])

要点:

  • EpochSummaryCallBack:汇总包含 "loss" 的指标与 evaluator 输出(也会注册到 wandb)。
  • ProcessCallBack:训练进度与状态输出。
  • WandbCallback:可选;未安装时会报错提示安装 pip install wandb
  • CheckpointCallback:当 pt.enable=True 或存在恢复点时启用。

扩展:

  • 自定义回调需继承 BaseCallBack,并由 RunnerBase.register_callbacks 校验与注册。

生命周期(BaseCallBack)

wall_e/callback/base_callback.py (L1–80)
class BaseCallBack:
def on_register(self): ...
def on_exception(self, *args, **kwargs): ...
def before_fit(self); def after_fit(self)
def before_train(self); def after_train(self)
def before_running_epoch(self); def after_running_epoch(self)
def before_running_batch(self); def after_running_batch(self)
def before_valid(self); def after_valid(self)
def before_test(self); def after_test(self)

说明:

  • 这些钩子由 RunnerBase 在相应时机广播调用。
  • on_register 在注册时触发,适合做一次性检查或资源准备。

功能型回调

进度与监控(ProcessCallBack)

wall_e/callback/progress_callback.py (L11–61)
class ProcessCallBack(BaseCallBack):
# 从 cfg.training.progress_* 读取显示频率;
# 通过 registry.get("metric.*") 汇总并打印进度/速度/剩余时间与监控指标。

注意:

  • 依赖 TrainLoop.register_model_output 将包含 "loss" 的指标注册至 metric.*
  • 支持平均/最佳值打印与自定义监控项。

轮次总结(EpochSummaryCallBack)

wall_e/callback/epoch_summary_callback.py (L8–38)
class EpochSummaryCallBack(BaseCallBack):
# 记录单轮起止时间,结束时输出耗时与 registry.get('metric') 概览。

检查点(CheckpointCallback)

wall_e/callback/checkpoint_callback.py (L16–176, L179–246)
class CheckpointCallback(BaseCallBack):
# initial/final/按 epoch 或 batch 周期保存;异常时额外保存;
# 支持 TopK 与 best 模型维护;DeepSpeed/标准 PyTorch 两套保存/加载实现。

注意:

  • 仅主进程执行标准保存;DeepSpeed 由 deepspeed.save_checkpoint 管理。
  • pt.begin_*pt.*_interval 控制保存频率,pt.best_monitor/topk 控制 Best/TopK 策略。

Wandb(WandbCallback)

wall_e/callback/wandb_callback.py (L8–31)
class WandbCallback(BaseCallBack):
# 主进程初始化 wandb;每个 batch 将 registry.get("metric") 上报 wandb。

注意:

  • 仅在 cfg.wandb.enable=True 且安装了 wandb 时使用;
  • mode 支持 offline/online;可设置 proj_name/run_name/dir/tags 等。