
Registration Mechanism (Registry)

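Registry (wall_e/common/registry.py) provides a unified registration and lookup mechanism. It serves two purposes:

  • Decoupled component management: models, datasets, callbacks, schedulers, Runners, and similar components are registered under string names and retrieved on demand;
  • Global context: runtime state (such as metrics or cfg fragments) is maintained through registry.register/get.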

Registrable components

wall_e/common/registry.py (L19–187)
class Registry:
    def register_model(cls, name): ...
    def register_metric(cls, name): ...
    def register_dataset(cls, name): ...
    def register_callback(cls, name): ...
    def register_lr_scheduler(cls, name): ...
    def register_runner(cls, name): ...

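Key points:

  • Each register_* method returns a decorator that checks the class inherits from the expected base class and rejects duplicate registrations;
  • Typical usage: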
from wall_e.common.registry import registry

@registry.register_model("MyModel")
class MyModel(BaseModel):
    ...

# Retrieve the registered class by name, then instantiate it:
model_cls = registry.get_model_class("MyModel")
model = model_cls(...)
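
To make the two checks described above concrete (base-class validation and duplicate rejection), here is a minimal, self-contained sketch of a decorator-returning registry. It is not the code in wall_e/common/registry.py: MiniRegistry, the stand-in BaseModel, and the choice of TypeError/KeyError are assumptions made for illustration only.

```python
class BaseModel:
    """Stand-in base class for this sketch (hypothetical, not wall_e's BaseModel)."""


class MiniRegistry:
    """Toy registry illustrating the pattern; not the wall_e implementation."""

    _models = {}

    @classmethod
    def register_model(cls, name):
        """Return a decorator that registers a BaseModel subclass under `name`."""

        def wrap(model_cls):
            # Check the inheritance relationship before accepting the class.
            if not (isinstance(model_cls, type) and issubclass(model_cls, BaseModel)):
                raise TypeError(f"{model_cls!r} must inherit from BaseModel")
            # Refuse to overwrite an existing entry (duplicate registration).
            if name in cls._models:
                raise KeyError(f"model '{name}' is already registered")
            cls._models[name] = model_cls
            return model_cls  # hand the class back unchanged

        return wrap

    @classmethod
    def get_model_class(cls, name):
        """Look up a registered class; returns None if the name is unknown."""
        return cls._models.get(name)


@MiniRegistry.register_model("MyModel")
class MyModel(BaseModel):
    pass
```

Because the decorator returns the class unchanged, MyModel remains directly importable and usable; the registry only adds a name-based lookup on top of it.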

Global context (state)

wall_e/common/registry.py (L207–233)
def register(name, obj): ...  # e.g. registry.register("metric.train/loss", 0.1)
def get(name, default=None, no_warning=False): ...  # e.g. registry.get("metric.train/loss", 0.1), where 0.1 is the fallback default

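Key points:

  • name supports dot-separated hierarchy (e.g. metric.train/loss, cfg.training.is_sweep);
  • During training, TrainLoop.register_model_output registers scalars whose names contain "loss" under metric.*, where logging and monitoring can read them;
  • The same mechanism can carry runtime configuration or objects between modules, keeping them decoupled.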
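
The points above can be sketched with a nested dictionary keyed by the dot-separated parts of the name. This is only an illustration under stated assumptions: the module-level _state dict, the nested-dict storage, and the helper register_loss_scalars (with its metric.<split>/<key> naming) are hypothetical and may differ from how wall_e stores state or how TrainLoop.register_model_output behaves.

```python
_state = {}  # hypothetical backing store for this sketch


def register(name, obj):
    """Store obj under a dot-separated name, creating nested dicts as needed."""
    *parents, leaf = name.split(".")
    node = _state
    for key in parents:
        node = node.setdefault(key, {})
    node[leaf] = obj


def get(name, default=None):
    """Walk the dot-separated path; return default if any part is missing."""
    node = _state
    for key in name.split("."):
        if not isinstance(node, dict) or key not in node:
            return default
        node = node[key]
    return node


def register_loss_scalars(outputs, split="train"):
    """Hypothetical helper: register every numeric output whose key contains 'loss'."""
    for key, value in outputs.items():
        if "loss" in key and isinstance(value, (int, float)):
            register(f"metric.{split}/{key}", value)


register("metric.train/loss", 0.1)
register("cfg.training.is_sweep", True)
loss = get("metric.train/loss")         # the value stored above
missing = get("metric.val/loss", -1.0)  # absent, so the default is returned
```

Note that only "." separates levels here, so "train/loss" is a single key under "metric". The sketch also does not guard against registering a subtree beneath a name that already holds a scalar.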

Enumeration and paths

wall_e/common/registry.py (L255–282)
def list_models(): ...
def list_datasets(): ...
def list_callbacks(): ...
def list_lr_schedulers(): ...
def list_runners(): ...
def register_path(name, path): ...
def get_path(name): ...

Relationship to from_cfg/Runner

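  • BaseModel.from_cfg: when cfg contains a type, the concrete subclass is located through registry.get_model_class(type);
  • TrainLoop: registers loss metrics under metric.*, which ties into logging and visualization;
  • Calling registry.register("cfg", cfg) or similar at startup lets every module read the shared configuration while staying decoupled.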
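
The from_cfg dispatch described above might look roughly like the following sketch. It uses a plain dict in place of the real registry, and several details are assumptions: cfg is treated as a mapping, the remaining keys are passed to the constructor as keyword arguments, and a cfg without type falls back to the class from_cfg was called on. The actual BaseModel.from_cfg in wall_e may differ.

```python
_MODEL_CLASSES = {}  # stands in for the registry's model table (hypothetical)


def register_model(name):
    """Minimal decorator that records a class under a string name."""
    def wrap(cls):
        _MODEL_CLASSES[name] = cls
        return cls
    return wrap


class BaseModel:
    def __init__(self, **kwargs):
        self.kwargs = kwargs

    @classmethod
    def from_cfg(cls, cfg):
        cfg = dict(cfg)  # copy so the caller's cfg is left untouched
        type_name = cfg.pop("type", None)
        if type_name is None:
            # Assumption: with no "type", build the class from_cfg was called on.
            return cls(**cfg)
        target = _MODEL_CLASSES.get(type_name)
        if target is None:
            raise KeyError(f"no model registered under '{type_name}'")
        return target(**cfg)


@register_model("TinyModel")
class TinyModel(BaseModel):
    pass


# The string in cfg["type"] selects the subclass; other keys become kwargs.
model = BaseModel.from_cfg({"type": "TinyModel", "hidden": 8})
```

The calling code only needs the string name, so it never has to import TinyModel directly, which is the decoupling the registry is meant to provide.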