BaseModel
BaseModel(wall_e/model/base_model.py) 是一个最小而稳健的模型抽象:既可独立在任意工程中使用,也可与 wall_e 的 Runner/Loop 配合训练。
何时单独使用
- 只需要
from_cfg构建、load_checkpoint重载、参数冻结/可视化等通用能力; - 自己控制训练循环或集成到第三方训练框架。
何时配合 Runner 使用
- 需要分布式/DeepSpeed、回调、日志、评估与调度器等一体化训练编排。
抽象方法
wall_e/model/base_model.py (L15–36)
class BaseModel(nn.Module, ABC):
def train_step(self, data_batch: dict, *args: Any, **kwargs: Any) -> dict: ...
def valid_step(self, data_batch: dict, *args: Any, **kwargs: Any) -> dict: ...
def test_step(self, data_batch: dict, *args: Any, **kwargs: Any) -> dict: ...
def compute_loss(self, *args: Any, **kwargs: Any) -> dict: ...
train_step必须返回包含键loss的字典,以便训练循环执行反向传播。valid_step/test_step返回度量相关的字典(可自由扩展)。
参数与返回约定:
-
train_step(batch: dict, /, **kwargs) -> dict- 参数:
batch:单个 mini-batch 数据,通常包含张量,如{"x": Tensor, "y": Tensor}。**kwargs:可选的额外上下文(例如模式开关、采样信息)。
- 返回:
{"loss": Tensor, ...},必须包含loss;可附加loss_aux、自定义度量等。
- 参数:
-
valid_step(batch: dict, /, **kwargs) -> dict与test_step(...) -> dict- 参数:同上。
- 返回:任意评估所需的标量或张量,可与评估器协同使用。
-
compute_loss(*args, **kwargs) -> dict- 参数:由实现者自定义(例如传入
x, y)。 - 返回:
{"loss": Tensor, ...},用于复用训练/验证/测试中的损失计算。
- 参数:由实现者自定义(例如传入
常用方法
load_checkpoint
wall_e/model/base_model.py (L38–56)
def load_checkpoint(self, cached_path: str) -> List[str]:
checkpoint = torch.load(cached_path, map_location="cpu")
state_dict = checkpoint.get("model", checkpoint)
msg = self.load_state_dict(state_dict, strict=False)
return msg.missing_keys
- 参数:
cached_path (str):权重文件路径;支持state_dict或包含"model"键的state_dict。
- 返回:
List[str]:加载时缺失的参数名列表(若为空表示完全匹配)。
from_cfg
wall_e/model/base_model.py (L58–124)
@classmethod
def from_cfg(cls: Type[T], cfg: Any) -> T:
# 解析 cfg,选择子类构造器,预训练加载与 device 放置
return model
- 参数:
cfg (dict | OmegaConf | argparse.Namespace 等):- 若包含
type,通过注册表registry.get_model_class(type)定位子类并按其构造签名取参; - 若由子类调用且不含
type,使用当前子类构建; - 支持键:
pretrained(权重路径)、allow_missing_keys(是否允许缺参)、device(如"cuda"/"cpu")。
- 若包含
- 返回:
BaseModel子类实例,按需迁移到指定device。
示例:
# 方式一:通过注册表指定 type(推荐用于解耦构建)
from wall_e.model.base_model import BaseModel
cfg = {
"type": "YourModelClassName", # 必须先在 registry 中注册
"hidden_size": 256,
"pretrained": "./checkpoints/model.pt",
"device": "cuda"
}
model = BaseModel.from_cfg(cfg)
print(type(model)) # YourModelClassName
# 方式二:由子类直接调用(不需要提供 type)
from your_pkg.models import YourModelClassName as YourModel
model2 = YourModel.from_cfg({
"hidden_size": 256,
# 也可支持:"pretrained": "./checkpoints/model.pt", "device": "cpu"
})
freeze_parameters
wall_e/model/base_model.py (L127–143)
def freeze_parameters(self, freeze: bool = True, regex: Optional[str] = None) -> "BaseModel": ...
- 参数:
freeze (bool):是否冻结匹配参数(冻结后requires_grad=False)。regex (str | None):正则表达式,匹配需操作的参数名;缺省表示作用于全部参数。
- 返回:
BaseModel:返回自身以便链式调用。
visualize_architecture
wall_e/model/base_model.py (L145–159)
def visualize_architecture(self, input: dict, save_path: str = "model_graph.png") -> str: ...
- 参数:
input (dict):前向调用所需的入参字典,将以self(**input)方式前向一次。save_path (str):输出图片路径(无扩展名时保存为png)。
- 返回:
str:保存后的路径。
plot_parameter_histogram
wall_e/model/base_model.py (L162–173)
def plot_parameter_histogram(self, param_name: str, bins: int = 50) -> Figure: ...
- 参数:
param_name (str):待分析的参数名(如"encoder.layers.0.self_attn.q_proj.weight")。bins (int):直方图分箱数。
- 返回:
matplotlib.figure.Figure:绘图对象。
属性
device/dtypenum_parameters/num_trainable_parameters
说明:
device: torch.device:第一个参数所在设备(模型整体设备)。dtype: torch.dtype:首个参数的数据类型。num_parameters: int:参数总量。num_trainable_parameters: int:可训练参数总量。
最小使用示例(独立)
from wall_e.model.base_model import BaseModel
class ToyModel(BaseModel):
def __init__(self, dim=10):
super().__init__()
import torch.nn as nn
self.linear = nn.Linear(dim, 1)
def compute_loss(self, x, y):
import torch.nn.functional as F
pred = self.linear(x)
return {"loss": F.mse_loss(pred, y)}
def train_step(self, batch):
return self.compute_loss(**batch)
def valid_step(self, batch):
return self.compute_loss(**batch)
def test_step(self, batch):
return self.compute_loss(**batch)
# 独立使用常见能力
toy_model = ToyModel()
toy_model.freeze_parameters(regex="^linear")
与 Runner 协同(示意)
from wall_e.runner.runner import Runner
# 假设已构建 dataloader 与 cfg
runner = Runner(model=toy_model, epochs=5, train_data_loader=train_loader, cfg=cfg)
runner.fit()