Registration Mechanism (Registry)
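Registry (wall_e/common/registry.py) provides a unified registration and lookup mechanism, used for:
- Component management and decoupling: models, datasets, callbacks, schedulers, Runners, and so on are registered under string names and fetched on demand;
- Global context: registry.register/get maintain runtime state (such as metrics or cfg fragments).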
Registrable components
wall_e/common/registry.py (L19–187)
class Registry:
    def register_model(cls, name): ...
    def register_metric(cls, name): ...
    def register_dataset(cls, name): ...
    def register_callback(cls, name): ...
    def register_lr_scheduler(cls, name): ...
    def register_runner(cls, name): ...
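Key points:
- Each register_* method returns a decorator that validates the base-class inheritance relationship and prevents duplicate registration;
- Typical usage: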
from wall_e.common.registry import registry

@registry.register_model("MyModel")
class MyModel(BaseModel):
    ...

# Retrieve:
model_cls = registry.get_model_class("MyModel")
model = model_cls(...)
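The decorator behaviour described above (a base-class check plus duplicate protection) can be sketched as follows. This is a minimal, self-contained illustration and not the actual wall_e implementation: MiniRegistry, its _models table, the stand-in BaseModel, and the choice of exception types are all assumptions made for the example.

```python
class BaseModel:
    """Stand-in for the project's model base class (hypothetical)."""


class MiniRegistry:
    """Toy registry illustrating a register_* decorator; not wall_e's real code."""

    _models = {}  # maps registered name -> class

    @classmethod
    def register_model(cls, name):
        def wrap(model_cls):
            # Validate the inheritance relationship before accepting the class.
            if not (isinstance(model_cls, type) and issubclass(model_cls, BaseModel)):
                raise TypeError(f"{model_cls!r} must inherit from BaseModel")
            # Refuse to silently overwrite an existing entry.
            if name in cls._models:
                raise KeyError(f"model name '{name}' is already registered")
            cls._models[name] = model_cls
            return model_cls  # hand the class back unchanged

        return wrap

    @classmethod
    def get_model_class(cls, name):
        # Returns None when the name is unknown (an assumption of this sketch).
        return cls._models.get(name)


@MiniRegistry.register_model("MyModel")
class MyModel(BaseModel):
    pass
```

Because the decorator returns the class unchanged, MyModel can still be imported and instantiated directly; the registry only adds a name-based lookup path on top.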
Global context (state)
wall_e/common/registry.py (L207–233)
def register(name, obj): ... # e.g. registry.register("metric.train/loss", 0.1)
def get(name, default=None, no_warning=False): ... # e.g. registry.get("metric.train/loss", 0.1)
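Key points:
- name supports dot-separated levels (e.g. metric.train/loss or cfg.training.is_sweep);
- During training, TrainLoop.register_model_output registers scalars whose keys contain "loss" under metric.*, for use by logging and monitoring;
- It can be used to pass runtime configuration or objects between modules, keeping them decoupled.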
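One way a dot-separated name could map onto nested state is sketched below. This is a hedged illustration only: MiniState and its internals are hypothetical, and the warning on a missing key is an assumption inferred from the no_warning parameter rather than confirmed behaviour of wall_e.

```python
import warnings


class MiniState:
    """Toy dotted-key store; a sketch, not wall_e's actual implementation."""

    def __init__(self):
        self._state = {}

    def register(self, name, obj):
        # "metric.train/loss" is stored as {"metric": {"train/loss": obj}}.
        # Only "." splits levels; "/" is part of the leaf key.
        *parents, leaf = name.split(".")
        node = self._state
        for key in parents:
            # Sketch limitation: assumes a parent key never holds a non-dict value.
            node = node.setdefault(key, {})
        node[leaf] = obj

    def get(self, name, default=None, no_warning=False):
        node = self._state
        for key in name.split("."):
            if not isinstance(node, dict) or key not in node:
                if not no_warning:
                    warnings.warn(f"'{name}' is not registered; returning default")
                return default
            node = node[key]
        return node


state = MiniState()
state.register("metric.train/loss", 0.1)
state.register("cfg.training.is_sweep", False)
```

With this layout, state.get("metric") returns the whole sub-dictionary of metrics, which is convenient for a logger that wants every registered metric at once.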
Enumeration and paths
wall_e/common/registry.py (L255–282)
def list_models(): ...
def list_datasets(): ...
def list_callbacks(): ...
def list_lr_schedulers(): ...
def list_runners(): ...
def register_path(name, path): ...
def get_path(name): ...
Relationship to from_cfg/Runner
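- BaseModel.from_cfg: when cfg contains type, the concrete subclass is located via registry.get_model_class(type);
- TrainLoop: registers loss metrics under metric.*, tying in with logging and visualization;
- Calling registry.register("cfg", cfg) or similar at startup lets every module read the shared configuration while staying decoupled.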
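The type-based dispatch in from_cfg can be illustrated with a small self-contained sketch. Everything here is an assumption made for the example: the module-level _MODEL_CLASSES table stands in for the registry, hidden_size is an invented parameter, and the real BaseModel.from_cfg may treat cfg differently.

```python
_MODEL_CLASSES = {}  # hypothetical stand-in for the registry's model table


def register_model(name):
    def wrap(cls):
        _MODEL_CLASSES[name] = cls
        return cls

    return wrap


class BaseModel:
    def __init__(self, hidden_size=8):
        self.hidden_size = hidden_size

    @classmethod
    def from_cfg(cls, cfg):
        cfg = dict(cfg)  # copy so the caller's config is not mutated
        model_type = cfg.pop("type", None)
        if model_type is not None:
            # A "type" entry selects the concrete subclass by its registered name.
            target = _MODEL_CLASSES.get(model_type)
            if target is None:
                raise KeyError(f"unknown model type '{model_type}'")
            return target(**cfg)
        # Without "type", build the class that from_cfg was called on.
        return cls(**cfg)


@register_model("TinyModel")
class TinyModel(BaseModel):
    pass


model = BaseModel.from_cfg({"type": "TinyModel", "hidden_size": 16})
```

The caller only needs the base class and a config dictionary; which subclass is built is decided by the string in cfg, which is what keeps the launching code decoupled from concrete model modules.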