Skip to main content

模型

8 分钟 等级: 入门

概览(更多细节见 API → BaseModel

WALL_E 的模型需要继承 BaseModel,并实现训练/验证/测试三个步骤接口,以便与 TrainLoop/ValidLoop/TestLoop 无缝对接。本文档介绍:

  • 基类 BaseModel 的核心接口与工具方法
  • 与 Loop 的对接规范(必须的返回与行为)
  • 通过配置 from_cfg 构建模型的推荐方式

必须实现的接口

模型需继承 BaseModel 并实现以下抽象方法:

@abstractmethod
def train_step(self, *args: Any, **kwargs: Any) -> dict:
pass

@abstractmethod
def valid_step(self, *args: Any, **kwargs: Any) -> dict:
pass

@abstractmethod
def test_step(self, *args: Any, **kwargs: Any) -> dict:
pass

@abstractmethod
def compute_loss(self, *args: Any, **kwargs: Any) -> dict:
pass
  • train_step 返回要求:必须返回包含键 loss 的字典,且 losstorch.Tensor
    • 训练循环会从该字典中读取所有包含 "loss" 的键,做指标登记与上报。
  • valid_step/test_step 返回约定:返回可被 Evaluator.process(data_samples, data_batch) 接受的结构(例如张量、字典或自定义对象列表),具体结构由你的 Evaluator 决定。

与 Loop 的对接规范

训练、验证、测试循环的关键行为如下:

def run_iter(self, idx, data_batch: dict[str, Sequence]) -> None:
self.runner.before_running_batch()
if hasattr(self.runner.model, "module"):
model_output = self.runner.model.module.train_step(data_batch)
else:
model_output = self.runner.model.train_step(data_batch)
assert "loss" in model_output, "模型输出必须返回包含loss的字典"

skip_batch = torch.isnan(model_output["loss"]).any().item()
# 分布式场景会同步跳过决策
if skip_batch:
self.runner.logger.error("该批次检测到NaN,所有进程跳过该batch")
else:
self.backward(self.scaler, model_output["loss"]) # 梯度缩放/累积/裁剪/step
self.register_model_output(model_output) # 登记含"loss"键的指标
self.runner.after_running_batch()
  • DDP/DP 兼容:如使用并行封装,Loop 会通过 model.module.*_step 调用。
  • 混合精度:当 training.fp16=true 时自动启用 autocast 与 GradScaler。
  • 梯度累积:受 training.gradient_accumulation 控制,loss 会按步数平均。
  • 梯度裁剪:若 training.grad_clip 非空则在未缩放梯度空间执行裁剪。
  • 学习率调度:若非 DeepSpeed,Loop 会按 scheduler 配置步进。
  • NaN 处理:若 loss 含 NaN,该 batch 会被全部进程跳过并记录日志。

验证/测试循环会将 valid_step/test_step 的输出喂入评估器:

with autocast(enabled = self.fp16):
if hasattr(self.runner.model, "module"):
outputs = self.runner.model.module.valid_step(data_batch)
else:
outputs = self.runner.model.valid_step(data_batch)
if self.evaluator is not None:
self.evaluator.process(data_samples = outputs, data_batch = data_batch)
with autocast(enabled=self.fp16):
if hasattr(self.runner.model, "module"):
outputs = self.runner.model.module.test_step(data_batch)
else:
outputs = self.runner.model.test_step(data_batch)
if self.evaluator is not None:
self.evaluator.process(data_samples=outputs, data_batch=data_batch)

触发时机(可按 epoch 与 iter 两种粒度):参考 training.valid_*training.test_* 配置字段(详见《配置文件》)。

通过配置构建模型(from_cfg)

BaseModel.from_cfg 支持从 OmegaConf/argparse/dict 等配置构建实例,并自动:

  • 从注册表解析 type 找到模型类
  • 过滤仅匹配构造函数签名的参数
  • 可选加载预训练权重 pretrained
  • 选择设备 device 并将模型迁移至该设备
@classmethod
def from_cfg(cls: Type[T], cfg: Any) -> T:
# 解析配置、解析 type -> model_cls、过滤构造参数
# 可选加载 pretrained、分配 device 并 to(device)
return model

示例:

from wall_e.model.base_model import BaseModel

cfg = {
"type": "YourRegisteredModel", # 在 registry 中注册的模型名
"hidden_size": 512,
"num_layers": 6,
"pretrained": "./checkpoints/epoch-1.pt", # 可选
"device": "cuda", # 可选
}

model = BaseModel.from_cfg(cfg)
print(model.device, model.num_parameters)

关于注册机制,请确保你的模型类已在 registry 中注册为可发现的 model_class

预训练权重与状态恢复

def load_checkpoint(self, cached_path: str) -> List[str]:
checkpoint = torch.load(cached_path, map_location="cpu")
state_dict = checkpoint["model"] if "model" in checkpoint else checkpoint
msg = self.load_state_dict(state_dict, strict=False)
return msg.missing_keys
  • 返回缺失键列表,默认为非严格加载。若你不允许缺失键,可在 from_cfg 后自行断言。
  • 训练过程中使用恢复训练请参考 TrainLoop.resume() 与检查点回调配置。

实用工具方法

  • 参数冻结/解冻
def freeze_parameters(self, freeze: bool = True, regex: Optional[str] = None) -> "BaseModel":
for name, param in self.named_parameters():
if re.fullmatch(regex, name) if regex else True:
param.requires_grad = not freeze
return self
  • 计算图可视化(需安装 torchviz):
def visualize_architecture(self, input: dict, save_path: str = "model_graph.png") -> str:
from torchviz import make_dot
output = self(**input)
graph = make_dot(output, params = dict(self.named_parameters()))
graph.render(save_path, format = 'png', cleanup = True)
return save_path
  • 参数分布与异常检测(基于 matplotlib 与 3σ 原则):
def plot_parameter_histogram(self, param_name: str, bins: int = 50) -> Figure:
# 直方图可视化

def detect_parameter_outliers(self, sigma: float = 3) -> List[Dict[str, Any]]:
# 基于均值/方差的异常值统计
  • 便捷属性devicedtypenum_parametersnum_trainable_parameters

最小可用样例

import torch
from wall_e.model.base_model import BaseModel

class ToyClassifier(BaseModel):
def __init__(self, in_dim: int = 10, num_classes: int = 2):
super().__init__()
self.net = torch.nn.Linear(in_dim, num_classes)
self.criterion = torch.nn.CrossEntropyLoss()

def forward(self, x):
return self.net(x)

def compute_loss(self, logits, labels):
return {"loss": self.criterion(logits, labels)}

def train_step(self, batch):
logits = self.forward(batch["x"]) # batch: {"x": Tensor, "y": LongTensor}
return self.compute_loss(logits, batch["y"]) # 必须含 "loss"

def valid_step(self, batch):
with torch.no_grad():
return {"logits": self.forward(batch["x"]), "labels": batch["y"]}

def test_step(self, batch):
with torch.no_grad():
return {"logits": self.forward(batch["x"]) }

# from_cfg 构建(如已在 registry 注册可用 type 指定)
model = ToyClassifier()

与配置的关系与常见问题

  • Loop 触发频率:由 training.valid_*training.test_* 字段控制;详见《配置文件》文档中的“训练参数”。
  • 混合精度/梯度裁剪/累积:分别由 training.fp16training.grad_cliptraining.gradient_accumulation 控制。
  • DeepSpeed:启用后调度器需在 DeepSpeed 配置中指定;非 DS 模式下由框架根据 scheduler 配置自动步进。
  • Evaluator 对接:自定义 Evaluator.process/evaluate 以匹配 valid_step/test_step 输出结构。

进一步阅读

  • 《配置文件》:详见 docs/tutorials/配置文件.md
  • 训练循环实现:wall_e/runner/loop/train_loop.py
  • 验证/测试循环实现:wall_e/runner/loop/valid_loop.pywall_e/runner/loop/test_loop.py
  • 模型基类:wall_e/model/base_model.py