Skip to main content

BaseModel

BaseModel(wall_e/model/base_model.py) 是一个最小而稳健的模型抽象:既可独立在任意工程中使用,也可与 wall_eRunner/Loop 配合训练。

何时单独使用

  • 只需要 from_cfg 构建、load_checkpoint 重载、参数冻结/可视化等通用能力;
  • 自己控制训练循环或集成到第三方训练框架。

何时配合 Runner 使用

  • 需要分布式/DeepSpeed、回调、日志、评估与调度器等一体化训练编排。

抽象方法

wall_e/model/base_model.py (L15–36)
class BaseModel(nn.Module, ABC):
def train_step(self, data_batch: dict, *args: Any, **kwargs: Any) -> dict: ...
def valid_step(self, data_batch: dict, *args: Any, **kwargs: Any) -> dict: ...
def test_step(self, data_batch: dict, *args: Any, **kwargs: Any) -> dict: ...
def compute_loss(self, *args: Any, **kwargs: Any) -> dict: ...
  • train_step 必须返回包含键 loss 的字典,以便训练循环执行反向传播。
  • valid_step/test_step 返回度量相关的字典(可自由扩展)。

参数与返回约定:

  • train_step(batch: dict, /, **kwargs) -> dict

    • 参数:
      • batch:单个 mini-batch 数据,通常包含张量,如 {"x": Tensor, "y": Tensor}
      • **kwargs:可选的额外上下文(例如模式开关、采样信息)。
    • 返回:
      • {"loss": Tensor, ...},必须包含 loss;可附加 loss_aux、自定义度量等。
  • valid_step(batch: dict, /, **kwargs) -> dicttest_step(...) -> dict

    • 参数:同上。
    • 返回:任意评估所需的标量或张量,可与评估器协同使用。
  • compute_loss(*args, **kwargs) -> dict

    • 参数:由实现者自定义(例如传入 x, y)。
    • 返回:{"loss": Tensor, ...},用于复用训练/验证/测试中的损失计算。

常用方法

load_checkpoint

wall_e/model/base_model.py (L38–56)
def load_checkpoint(self, cached_path: str) -> List[str]:
checkpoint = torch.load(cached_path, map_location="cpu")
state_dict = checkpoint.get("model", checkpoint)
msg = self.load_state_dict(state_dict, strict=False)
return msg.missing_keys
  • 参数:
    • cached_path (str):权重文件路径;支持 state_dict 或包含 "model" 键的state_dict
  • 返回:
    • List[str]:加载时缺失的参数名列表(若为空表示完全匹配)。

from_cfg

wall_e/model/base_model.py (L58–124)
@classmethod
def from_cfg(cls: Type[T], cfg: Any) -> T:
# 解析 cfg,选择子类构造器,预训练加载与 device 放置
return model
  • 参数:
    • cfg (dict | OmegaConf | argparse.Namespace 等)
      • 若包含 type,通过注册表 registry.get_model_class(type) 定位子类并按其构造签名取参;
      • 若由子类调用且不含 type,使用当前子类构建;
      • 支持键:pretrained(权重路径)、allow_missing_keys(是否允许缺参)、device(如 "cuda"/"cpu")。
  • 返回:
    • BaseModel 子类实例,按需迁移到指定 device

示例:

# 方式一:通过注册表指定 type(推荐用于解耦构建)
from wall_e.model.base_model import BaseModel

cfg = {
"type": "YourModelClassName", # 必须先在 registry 中注册
"hidden_size": 256,
"pretrained": "./checkpoints/model.pt",
"device": "cuda"
}
model = BaseModel.from_cfg(cfg)
print(type(model)) # YourModelClassName

# 方式二:由子类直接调用(不需要提供 type)
from your_pkg.models import YourModelClassName as YourModel

model2 = YourModel.from_cfg({
"hidden_size": 256,
# 也可支持:"pretrained": "./checkpoints/model.pt", "device": "cpu"
})

freeze_parameters

wall_e/model/base_model.py (L127–143)
def freeze_parameters(self, freeze: bool = True, regex: Optional[str] = None) -> "BaseModel": ...
  • 参数:
    • freeze (bool):是否冻结匹配参数(冻结后 requires_grad=False)。
    • regex (str | None):正则表达式,匹配需操作的参数名;缺省表示作用于全部参数。
  • 返回:
    • BaseModel:返回自身以便链式调用。

visualize_architecture

wall_e/model/base_model.py (L145–159)
def visualize_architecture(self, input: dict, save_path: str = "model_graph.png") -> str: ...
  • 参数:
    • input (dict):前向调用所需的入参字典,将以 self(**input) 方式前向一次。
    • save_path (str):输出图片路径(无扩展名时保存为 png)。
  • 返回:
    • str:保存后的路径。

plot_parameter_histogram

wall_e/model/base_model.py (L162–173)
def plot_parameter_histogram(self, param_name: str, bins: int = 50) -> Figure: ...
  • 参数:
    • param_name (str):待分析的参数名(如 "encoder.layers.0.self_attn.q_proj.weight")。
    • bins (int):直方图分箱数。
  • 返回:
    • matplotlib.figure.Figure:绘图对象。

属性

  • device / dtype
  • num_parameters / num_trainable_parameters

说明:

  • device: torch.device:第一个参数所在设备(模型整体设备)。
  • dtype: torch.dtype:首个参数的数据类型。
  • num_parameters: int:参数总量。
  • num_trainable_parameters: int:可训练参数总量。

最小使用示例(独立)

from wall_e.model.base_model import BaseModel

class ToyModel(BaseModel):
def __init__(self, dim=10):
super().__init__()
import torch.nn as nn
self.linear = nn.Linear(dim, 1)

def compute_loss(self, x, y):
import torch.nn.functional as F
pred = self.linear(x)
return {"loss": F.mse_loss(pred, y)}

def train_step(self, batch):
return self.compute_loss(**batch)

def valid_step(self, batch):
return self.compute_loss(**batch)

def test_step(self, batch):
return self.compute_loss(**batch)

# 独立使用常见能力
toy_model = ToyModel()
toy_model.freeze_parameters(regex="^linear")

与 Runner 协同(示意)

from wall_e.runner.runner import Runner

# 假设已构建 dataloader 与 cfg
runner = Runner(model=toy_model, epochs=5, train_data_loader=train_loader, cfg=cfg)
runner.fit()