跳到主要内容

训练循环(Loops)

TrainLoop/ValidLoop/TestLoop 专注迭代逻辑,Runner 负责装配与调度。

BaseLoop(公共能力)

wall_e/runner/loop/base_loop.py (L1–29)
class BaseLoop:
def __init__(self, runner, dataloader, shuffle=True): ...

注意:

  • 统一持有 runnerdataloader 引用;shuffle 由具体 loop 决定是否使用。

TrainLoop(训练)

wall_e/runner/loop/train_loop.py (L24–66)
class TrainLoop(BaseLoop):
def __init__(..., max_epochs: int, valid_begin_epoch, valid_interval_epoch,
valid_begin_iter, valid_interval_iter,
test_begin_epoch, test_interval_epoch,
test_begin_iter, test_interval_iter, shuffle=True): ...

要点:

  • 计算 _max_epochs/_max_iters/_iters_per_epoch,维护 _epoch/_iterrunner.state
  • 支持 fp16(autocast+GradScaler)、梯度累积、梯度裁剪与调度器步进。
  • 迭代过程中按 epoch 与 iter 双重条件触发验证/测试。

核心流程:

wall_e/runner/loop/train_loop.py (L112–171)
def run(self) -> torch.nn.Module: ...  # 循环 epoch -> run_epoch(); 期间按配置触发 valid/test
def run_epoch(self): ... # 遍历 dataloader,调用 run_iter

单步训练与反向:

wall_e/runner/loop/train_loop.py (L173–278)
def run_iter(self, idx, data_batch):
output = model.train_step(data_batch); assert "loss" in output; backward(...); register_model_output(...)

def backward(self, scaler, loss):
# deepspeed: model.backward/step; 否则:混合精度缩放、梯度累积、可选 clip_grad、scheduler.step

注意:

  • register_model_output 会将所有包含 "loss" 的标量注册到 registry,用于日志与可选 sweeps。
  • 分布式场景下,NaN 检测结果会在各进程间同步决定是否跳过该 batch。

ValidLoop/TestLoop

wall_e/runner/loop/valid_loop.py (L1–74)
class ValidLoop(BaseLoop):
def run(self) -> dict: ...

```python title="wall_e/runner/loop/test_loop.py (L1–71)"
class TestLoop(BaseLoop):
def run(self) -> dict: ...

要点:

  • Runner.setup_loop 中由 Runner 创建并注入到 TrainLoop 的调度流程中,或由用户显式调用 runner.valid()/runner.test()
  • evaluator(若提供)会在 Runner.setup_evaluator 时注入 RunnerState