回调(Callbacks)
回调由 Runner.setup_callbacks 统一注册,可扩展,覆盖训练的关键生命周期。
wall_e/runner/runner.py (L382–425)
def setup_callbacks(self):
epoch_summary_callback = EpochSummaryCallBack(self)
progress_callback = ProcessCallBack(self)
wandb_callback = WandbCallback(self) if cfg.wandb.enable else None
checkpoint_callback = CheckpointCallback(self) if cfg.pt.enable or resume_from else None
self.register_callbacks([...])
要点:
EpochSummaryCallBack:汇总包含 "loss" 的指标与 evaluator 输出(也会注册到 wandb)。ProcessCallBack:训练进度与状态输出。WandbCallback:可选;未安装时会报错提示安装pip install wandb。CheckpointCallback:当pt.enable=True或存在恢复点时启用。
扩展:
- 自定义回调需继承
BaseCallBack,并由RunnerBase.register_callbacks校验与注册。
生命周期(BaseCallBack)
wall_e/callback/base_callback.py (L1–80)
class BaseCallBack:
def on_register(self): ...
def on_exception(self, *args, **kwargs): ...
def before_fit(self); def after_fit(self)
def before_train(self); def after_train(self)
def before_running_epoch(self); def after_running_epoch(self)
def before_running_batch(self); def after_running_batch(self)
def before_valid(self); def after_valid(self)
def before_test(self); def after_test(self)
说明:
- 这些钩子由
RunnerBase在相应时机广播调用。 on_register在注册时触发,适合做一次性检查或资源准备。
功能型回调
进度与监控(ProcessCallBack)
wall_e/callback/progress_callback.py (L11–61)
class ProcessCallBack(BaseCallBack):
# 从 cfg.training.progress_* 读取显示频率;
# 通过 registry.get("metric.*") 汇总并打印进度/速度/剩余时间与监控指标。
注意:
- 依赖
TrainLoop.register_model_output将包含 "loss" 的指标注册至metric.*。 - 支持平均/最佳值打印与自定义监控项。
轮次总结(EpochSummaryCallBack)
wall_e/callback/epoch_summary_callback.py (L8–38)
class EpochSummaryCallBack(BaseCallBack):
# 记录单轮起止时间,结束时输出耗时与 registry.get('metric') 概览。
检查点(CheckpointCallback)
wall_e/callback/checkpoint_callback.py (L16–176, L179–246)
class CheckpointCallback(BaseCallBack):
# initial/final/按 epoch 或 batch 周期保存;异常时额外保存;
# 支持 TopK 与 best 模型维护;DeepSpeed/标准 PyTorch 两套保存/加载实现。
注意:
- 仅主进程执行标准保存;DeepSpeed 由
deepspeed.save_checkpoint管理。 pt.begin_*与pt.*_interval控制保存频率,pt.best_monitor/topk控制 Best/TopK 策略。
Wandb(WandbCallback)
wall_e/callback/wandb_callback.py (L8–31)
class WandbCallback(BaseCallBack):
# 主进程初始化 wandb;每个 batch 将 registry.get("metric") 上报 wandb。
注意:
- 仅在
cfg.wandb.enable=True且安装了wandb时使用; mode支持offline/online;可设置proj_name/run_name/dir/tags等。