模型
8 分钟 等级: 入门
概览(更多细节见 API → BaseModel)
WALL_E 的模型需要继承 BaseModel,并实现训练/验证/测试三个步骤接口,以便与 TrainLoop/ValidLoop/TestLoop 无缝对接。本文档介绍:
- 基类
BaseModel的核心接口与工具方法 - 与 Loop 的对接规范(必须的返回与行为)
- 通过配置
from_cfg构建模型的推荐方式
必须实现的接口
模型需继承 BaseModel 并实现以下抽象方法:
@abstractmethod
def train_step(self, *args: Any, **kwargs: Any) -> dict:
pass
@abstractmethod
def valid_step(self, *args: Any, **kwargs: Any) -> dict:
pass
@abstractmethod
def test_step(self, *args: Any, **kwargs: Any) -> dict:
pass
@abstractmethod
def compute_loss(self, *args: Any, **kwargs: Any) -> dict:
pass
- train_step 返回要求:必须返回包含键
loss的字典,且loss为torch.Tensor。- 训练循环会从该字典中读取所有包含 "loss" 的键,做指标登记与上报。
- valid_step/test_step 返回约定:返回可被
Evaluator.process(data_samples, data_batch)接受的结构(例如张量、字典或自定义对象列表),具体结构由你的Evaluator决定。
与 Loop 的对接规范
训练、验证、测试循环的关键行为如下:
def run_iter(self, idx, data_batch: dict[str, Sequence]) -> None:
self.runner.before_running_batch()
if hasattr(self.runner.model, "module"):
model_output = self.runner.model.module.train_step(data_batch)
else:
model_output = self.runner.model.train_step(data_batch)
assert "loss" in model_output, "模型输出必须返回包含loss的字典"
skip_batch = torch.isnan(model_output["loss"]).any().item()
# 分布式场景会同步跳过决策
if skip_batch:
self.runner.logger.error("该批次检测到NaN,所有进程跳过该batch")
else:
self.backward(self.scaler, model_output["loss"]) # 梯度缩放/累积/裁剪/step
self.register_model_output(model_output) # 登记含"loss"键的指标
self.runner.after_running_batch()
- DDP/DP 兼容:如使用并行封装,Loop 会通过
model.module.*_step调用。 - 混合精度:当
training.fp16=true时自动启用 autocast 与 GradScaler。 - 梯度累积:受
training.gradient_accumulation控制,loss会按步数平均。 - 梯度裁剪:若
training.grad_clip非空则在未缩放梯度空间执行裁剪。 - 学习率调度:若非 DeepSpeed,Loop 会按
scheduler配置步进。 - NaN 处理:若
loss含 NaN,该 batch 会被全部进程跳过并记录日志。
验证/测试循环会将 valid_step/test_step 的输出喂入评估器:
with autocast(enabled = self.fp16):
if hasattr(self.runner.model, "module"):
outputs = self.runner.model.module.valid_step(data_batch)
else:
outputs = self.runner.model.valid_step(data_batch)
if self.evaluator is not None:
self.evaluator.process(data_samples = outputs, data_batch = data_batch)
with autocast(enabled=self.fp16):
if hasattr(self.runner.model, "module"):
outputs = self.runner.model.module.test_step(data_batch)
else:
outputs = self.runner.model.test_step(data_batch)
if self.evaluator is not None:
self.evaluator.process(data_samples=outputs, data_batch=data_batch)
触发时机(可按 epoch 与 iter 两种粒度):参考 training.valid_* 与 training.test_* 配置字段(详见《配置文件》)。
通过配置构建模型(from_cfg)
BaseModel.from_cfg 支持从 OmegaConf/argparse/dict 等配置构建实例,并自动:
- 从注册表解析
type找到模型类 - 过滤仅匹配构造函数签名的参数
- 可选加载预训练权重
pretrained - 选择设备
device并将模型迁移至该设备
@classmethod
def from_cfg(cls: Type[T], cfg: Any) -> T:
# 解析配置、解析 type -> model_cls、过滤构造参数
# 可选加载 pretrained、分配 device 并 to(device)
return model
示例:
from wall_e.model.base_model import BaseModel
cfg = {
"type": "YourRegisteredModel", # 在 registry 中注册的模型名
"hidden_size": 512,
"num_layers": 6,
"pretrained": "./checkpoints/epoch-1.pt", # 可选
"device": "cuda", # 可选
}
model = BaseModel.from_cfg(cfg)
print(model.device, model.num_parameters)
关于注册机制,请确保你的模型类已在 registry 中注册为可发现的 model_class。
预训练权重与状态恢复
def load_checkpoint(self, cached_path: str) -> List[str]:
checkpoint = torch.load(cached_path, map_location="cpu")
state_dict = checkpoint["model"] if "model" in checkpoint else checkpoint
msg = self.load_state_dict(state_dict, strict=False)
return msg.missing_keys
- 返回缺失键列表,默认为非严格加载。若你不允许缺失键,可在
from_cfg后自行断言。 - 训练过程中使用恢复训练请参考
TrainLoop.resume()与检查点回调配置。
实用工具方法
- 参数冻结/解冻:
def freeze_parameters(self, freeze: bool = True, regex: Optional[str] = None) -> "BaseModel":
for name, param in self.named_parameters():
if re.fullmatch(regex, name) if regex else True:
param.requires_grad = not freeze
return self
- 计算图可视化(需安装
torchviz):
def visualize_architecture(self, input: dict, save_path: str = "model_graph.png") -> str:
from torchviz import make_dot
output = self(**input)
graph = make_dot(output, params = dict(self.named_parameters()))
graph.render(save_path, format = 'png', cleanup = True)
return save_path
- 参数分布与异常检测(基于
matplotlib与 3σ 原则):
def plot_parameter_histogram(self, param_name: str, bins: int = 50) -> Figure:
# 直方图可视化
def detect_parameter_outliers(self, sigma: float = 3) -> List[Dict[str, Any]]:
# 基于均值/方差的异常值统计
- 便捷属性:
device、dtype、num_parameters、num_trainable_parameters。
最小可用样例
import torch
from wall_e.model.base_model import BaseModel
class ToyClassifier(BaseModel):
def __init__(self, in_dim: int = 10, num_classes: int = 2):
super().__init__()
self.net = torch.nn.Linear(in_dim, num_classes)
self.criterion = torch.nn.CrossEntropyLoss()
def forward(self, x):
return self.net(x)
def compute_loss(self, logits, labels):
return {"loss": self.criterion(logits, labels)}
def train_step(self, batch):
logits = self.forward(batch["x"]) # batch: {"x": Tensor, "y": LongTensor}
return self.compute_loss(logits, batch["y"]) # 必须含 "loss"
def valid_step(self, batch):
with torch.no_grad():
return {"logits": self.forward(batch["x"]), "labels": batch["y"]}
def test_step(self, batch):
with torch.no_grad():
return {"logits": self.forward(batch["x"]) }
# from_cfg 构建(如已在 registry 注册可用 type 指定)
model = ToyClassifier()
与配置的关系与常见问题
- Loop 触发频率:由
training.valid_*、training.test_*字段控制;详见《配置文件》文档中的“训练参数”。 - 混合精度/梯度裁剪/累积:分别由
training.fp16、training.grad_clip、training.gradient_accumulation控制。 - DeepSpeed:启用后调度器需在 DeepSpeed 配置中指定;非 DS 模式下由框架根据
scheduler配置自动步进。 - Evaluator 对接:自定义
Evaluator.process/evaluate以匹配valid_step/test_step输出结构。
进一步阅读
- 《配置文件》:详见
docs/tutorials/配置文件.md - 训练循环实现:
wall_e/runner/loop/train_loop.py - 验证/测试循环实现:
wall_e/runner/loop/valid_loop.py、wall_e/runner/loop/test_loop.py - 模型基类:
wall_e/model/base_model.py