Quick Start
5 minutes · Level: Beginner
Installation
```bash
git clone https://github.com/alan-tsang/wall_e.git
cd wall_e
pip install .
```
Note
Requires Python 3.9+ and PyTorch 2.3.0+. Optional dependencies such as DeepSpeed, WandB, and Ray can be installed as needed.
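As a quick sanity check of these requirements, dotted version strings can be compared numerically with a few lines of standard-library Python (a sketch; `meets_minimum` is a hypothetical helper for illustration, not part of wall_e):

```python
import sys

def meets_minimum(version: str, minimum: str) -> bool:
    """Compare dotted version strings numerically, e.g. '2.3.1' >= '2.3.0'."""
    as_tuple = lambda v: tuple(int(part) for part in v.split("."))
    return as_tuple(version) >= as_tuple(minimum)

# The interpreter itself must be 3.9+
print(sys.version_info >= (3, 9))

# e.g. compare torch.__version__.split("+")[0] against "2.3.0"
print(meets_minimum("2.3.1", "2.3.0"))  # True
```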
Minimal Runnable Example
Two paths are provided:
- Shortest code path (up and running in 10 minutes)
- Launching from a YAML configuration
1) Shortest code path (up and running in 10 minutes)
Define a minimal model and dataset, then run the Runner:
```python
from wall_e.model.base_model import BaseModel
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyModel(BaseModel):
    def __init__(self, dim=8):
        super().__init__()
        self.linear = nn.Linear(dim, 1)

    def compute_loss(self, x, y):
        pred = self.linear(x)
        return {"loss": F.mse_loss(pred, y)}

    def train_step(self, batch):
        return self.compute_loss(**batch)

    def valid_step(self, batch):
        return self.compute_loss(**batch)

    def test_step(self, batch):
        return self.compute_loss(**batch)
```
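The only contract the Runner relies on here is that each `*_step` method returns a dict containing a `"loss"` entry. That contract can be exercised even without torch using a plain mock (`MockModel` is a hypothetical stand-in for illustration, not the wall_e API):

```python
class MockModel:
    """Torch-free mock of the step/loss-dict contract shown above."""

    def compute_loss(self, x, y):
        # mean squared error on plain floats, mimicking F.mse_loss
        loss = sum((xi - yi) ** 2 for xi, yi in zip(x, y)) / len(x)
        return {"loss": loss}

    def train_step(self, batch):
        # unpack the batch dict into compute_loss, as ToyModel does
        return self.compute_loss(**batch)

out = MockModel().train_step({"x": [1.0, 2.0], "y": [1.0, 4.0]})
print(out["loss"])  # 2.0
```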
Build the dataset and DataLoader:
```python
from wall_e.dataset.dataset import BaseMapDataset
from torch.utils.data import DataLoader, TensorDataset
import torch

def make_tensor_ds(n=64, dim=8):
    x = torch.randn(n, dim)
    y = torch.randn(n, 1)
    return TensorDataset(x, y)

class TensorMapDataset(BaseMapDataset):
    def _set_dataset(self, data_source, only_local=False):
        ds = make_tensor_ds()
        return ds

ds = TensorMapDataset(data_source="dummy", split_ratios=(0.9, 0.1))
train_loader = DataLoader(ds.get_split("train"), batch_size=8)
# with two split ratios the second split is "test"; reuse it for validation here
valid_loader = DataLoader(ds.get_split("test"), batch_size=8)
```
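The `split_ratios=(0.9, 0.1)` pair asks for a 90/10 partition of the 64 samples. How such ratios map to integer split sizes can be sketched in plain Python (an illustration only; wall_e's actual rounding and split naming may differ):

```python
def split_sizes(n, ratios):
    """Allocate n samples by ratio, giving any rounding remainder to the first split."""
    sizes = [int(n * r) for r in ratios]
    sizes[0] += n - sum(sizes)
    return sizes

print(split_sizes(64, (0.9, 0.1)))  # [58, 6]
```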
Run the Runner:
```python
from wall_e.runner.runner import Runner
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "run_name": "quickstart",
    "training": {
        "fp16": False,
        "progress_every_n_batches": 10,
    },
    "optimizer": {"lr": 1e-3, "weight_decay": 0.0},
})

runner = Runner(
    model=ToyModel(),
    epochs=2,
    train_data_loader=train_loader,
    valid_data_loader=valid_loader,
    cfg=cfg,
)
runner.fit()
```
Expected result: you will see progress prints, per-epoch summaries, and checkpoint output (see logs/ and checkpoints/).
2) Configuration-based launch (YAML)
1) Create a configuration file
Create demo.yaml in the project root:
```yaml
# demo.yaml
# Optional section
dataset:
  # Your own dataset implementation; see "Define your dataset" below
  type: 'YourMapDataset'
  params:
    data_source: './dataset/your_path'
    shuffle: true
    split_ratios: [0.98, 0.01, 0.01]
model:
  # Your own or a reused model; see "Define your model" below
  type: 'YourModel'
  params:
    vocab_size: 32000
    hidden_size: 512

# Required section
run_name: 'quickstart-demo'
run_description: "Minimal runnable example"

# Entries set to null may be removed
training:
  epochs: 1
  ds_config: ''
  gradient_accumulation: 1
  resume_from: null
  load_from: null
  activation_checkpoint: []
  grad_clip: null
  fp16: false
  print_model: false
  progress_show:
    loss: true
    valid/acc: true
    test/acc: true
  valid_begin_epoch: 1
  valid_interval_epoch: 1
  valid_begin_iter: 0
  valid_interval_iter: 100000000  # validate by epoch only
  test_begin_epoch: 1
  test_interval_epoch: 1
  test_begin_iter: 0
  test_interval_iter: 100000000  # test by epoch only
  progress_every_n_epochs: 1
  progress_every_n_batches: 1

log:
  to_file: true
  folder: "./assert/logs"
  level: "INFO"
  rank_level: "WARNING"

pt:
  enable: true
  dir: "./assert/checkpoints"
  best_monitor:
    loss: true
  topk: 3
  begin_epoch: 1
  epoch_interval: 1
  begin_batch: 0
  batch_interval: 100000000

wandb:
  enable: false
  proj_name: "wall_e quickstart"
  offline: true
  dir: "./assert"
  tags: ["wall_e", "quickstart"]
```
2) Python launch script
```python
# run_demo.py
from wall_e.config.load_config import load_cfg
from wall_e.runner.runner import Runner
from wall_e.model.base_model import BaseModel
from wall_e.dataset.dataset import BaseMapDataset
from torch.utils.data import DataLoader

cfg = load_cfg('demo.yaml')

# Build dataset / model as needed (or create them from cfg via the registry)
ds = BaseMapDataset.from_cfg(cfg.dataset.path, metadata=getattr(cfg.dataset, 'metadata', None))
train_loader = DataLoader(ds.get_split('train'), batch_size=8)
valid_loader = DataLoader(ds.get_split('test'), batch_size=8)

model = BaseModel.from_cfg(cfg.model.params | {'type': cfg.model.type})  # must already be registered in the registry
runner = Runner(model=model, epochs=cfg.training.epochs,
                train_data_loader=train_loader, valid_data_loader=valid_loader, cfg=cfg)
runner.fit()
```
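The `|` in `cfg.model.params | {'type': cfg.model.type}` is the dict-union operator introduced in Python 3.9, one reason for the version floor above; it merges the model parameters with the `type` key a registry can dispatch on. On plain dicts (an OmegaConf node may need converting to a dict first):

```python
params = {"vocab_size": 32000, "hidden_size": 512}
merged = params | {"type": "YourModel"}  # right-hand side wins on key clashes
print(merged)
# {'vocab_size': 32000, 'hidden_size': 512, 'type': 'YourModel'}
```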
Run it:
```bash
python run_demo.py
# or, for distributed training:
torchrun --nproc_per_node=<num_gpus> ./run_demo.py
```
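When launched through torchrun, each worker process can read its rank from the standard environment variables that torch.distributed launchers export (`RANK`, `WORLD_SIZE`); a plain `python` launch leaves them unset, so a script can default to a single-process view:

```python
import os

# torchrun exports RANK/WORLD_SIZE per worker; a plain `python` launch
# does not, so fall back to a single-process view
rank = int(os.environ.get("RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))
print(f"rank {rank} of {world_size}")
```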
Next steps:
- "Start Your First Task": docs/tutorials/开始你的第一个任务
- API reference:
api/runner, api/loops, api/base-model, api/dataset/overview, api/config