训练循环(Loops)
TrainLoop/ValidLoop/TestLoop 专注迭代逻辑,Runner 负责装配与调度。
BaseLoop(公共能力)
wall_e/runner/loop/base_loop.py (L1–29)
class BaseLoop:
def __init__(self, runner, dataloader, shuffle=True): ...
注意:
- 统一持有
runner与dataloader引用;shuffle由具体 loop 决定是否使用。
TrainLoop(训练)
wall_e/runner/loop/train_loop.py (L24–66)
class TrainLoop(BaseLoop):
def __init__(..., max_epochs: int, valid_begin_epoch, valid_interval_epoch,
valid_begin_iter, valid_interval_iter,
test_begin_epoch, test_interval_epoch,
test_begin_iter, test_interval_iter, shuffle=True): ...
要点:
- 计算
_max_epochs/_max_iters/_iters_per_epoch,维护_epoch/_iter与runner.state。 - 支持 fp16(autocast+GradScaler)、梯度累积、梯度裁剪与调度器步进。
- 迭代过程中按 epoch 与 iter 双重条件触发验证/测试。
核心流程:
wall_e/runner/loop/train_loop.py (L112–171)
def run(self) -> torch.nn.Module: ... # 循环 epoch -> run_epoch(); 期间按配置触发 valid/test
def run_epoch(self): ... # 遍历 dataloader,调用 run_iter
单步训练与反向:
wall_e/runner/loop/train_loop.py (L173–278)
def run_iter(self, idx, data_batch):
output = model.train_step(data_batch); assert "loss" in output; backward(...); register_model_output(...)
def backward(self, scaler, loss):
# deepspeed: model.backward/step; 否则:混合精度缩放、梯度累积、可选 clip_grad、scheduler.step
注意:
register_model_output会将所有包含 "loss" 的标量注册到registry,用于日志与可选 sweeps。- 分布式场景下,NaN 检测结果会在各进程间同步决定是否跳过该 batch。
ValidLoop/TestLoop
wall_e/runner/loop/valid_loop.py (L1–74)
class ValidLoop(BaseLoop):
def run(self) -> dict: ...
```python title="wall_e/runner/loop/test_loop.py (L1–71)"
class TestLoop(BaseLoop):
def run(self) -> dict: ...
要点:
- 在
Runner.setup_loop中由Runner创建并注入到TrainLoop的调度流程中,或由用户显式调用runner.valid()/runner.test()。 evaluator(若提供)会在Runner.setup_evaluator时注入RunnerState。