
Quick Start

5 min · Level: Beginner

Installation

```bash
git clone https://github.com/alan-tsang/wall_e.git
cd wall_e
pip install .
```
Note

Requires Python 3.9+ and PyTorch 2.3.0+. Optional dependencies such as DeepSpeed, WandB, and Ray can be installed as needed.
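As a quick post-install sanity check, the version requirement above can be verified from Python (a minimal sketch; the torch import is guarded in case PyTorch is not installed yet):

```python
import sys

def python_ok(min_version=(3, 9)):
    """Return True if the running interpreter meets the minimum version."""
    return sys.version_info[:2] >= min_version

print("Python >= 3.9:", python_ok())

try:
    import torch
    print("PyTorch:", torch.__version__)
except ImportError:
    print("PyTorch is not installed")
```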

Minimal Runnable Example

Two paths are provided:

  1. Shortest path in code (up and running in 10 minutes)
  2. YAML-config-driven launch

1) Shortest path in code (up and running in 10 minutes)

Define a minimal model and dataset, then run the Runner:

```python
from wall_e.model.base_model import BaseModel
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyModel(BaseModel):
    def __init__(self, dim=8):
        super().__init__()
        self.linear = nn.Linear(dim, 1)

    def compute_loss(self, x, y):
        pred = self.linear(x)
        return {"loss": F.mse_loss(pred, y)}

    def train_step(self, batch):
        return self.compute_loss(**batch)

    def valid_step(self, batch):
        return self.compute_loss(**batch)

    def test_step(self, batch):
        return self.compute_loss(**batch)
```
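The pattern above assumes a simple contract between the model and the Runner: each `*_step` takes a batch and returns a dict containing a `"loss"` entry. A framework-free sketch of that contract (plain floats stand in for tensors; `StepContract` is a hypothetical name used only for illustration):

```python
class StepContract:
    """Hypothetical stand-in illustrating the dict-returning step contract."""

    def compute_loss(self, x, y):
        # Mean squared error on plain floats, standing in for F.mse_loss
        mse = sum((xi - yi) ** 2 for xi, yi in zip(x, y)) / len(x)
        return {"loss": mse}

    def train_step(self, batch):
        # The batch dict's keys must match compute_loss's signature
        return self.compute_loss(**batch)

out = StepContract().train_step({"x": [1.0, 2.0], "y": [1.0, 4.0]})
print(out["loss"])  # 2.0
```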

Build the dataset and DataLoader:

```python
from wall_e.dataset.dataset import BaseMapDataset
from torch.utils.data import DataLoader, TensorDataset
import torch

def make_tensor_ds(n=64, dim=8):
    x = torch.randn(n, dim)
    y = torch.randn(n, 1)
    return TensorDataset(x, y)

class TensorMapDataset(BaseMapDataset):
    def _set_dataset(self, data_source, only_local=False):
        ds = make_tensor_ds()
        return ds

ds = TensorMapDataset(data_source="dummy", split_ratios=(0.9, 0.1))
train_loader = DataLoader(ds.get_split("train"), batch_size=8)
valid_loader = DataLoader(ds.get_split("test"), batch_size=8)
```
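How `split_ratios=(0.9, 0.1)` carves up the 64 samples is handled inside `BaseMapDataset`; the sketch below shows one plausible contiguous split, stdlib-only. The actual split names and boundary rounding are wall_e's, so treat this purely as an illustration:

```python
def split_indices(n, ratios=(0.9, 0.1)):
    """Split range(n) into contiguous 'train'/'test' chunks by ratio."""
    cut = int(n * ratios[0])  # floor-rounded boundary
    return {"train": list(range(cut)), "test": list(range(cut, n))}

parts = split_indices(64)
print(len(parts["train"]), len(parts["test"]))  # 57 7
```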

Run the Runner:

```python
from wall_e.runner.runner import Runner
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "run_name": "quickstart",
    "training": {
        "fp16": False,
        "progress_every_n_batches": 10,
    },
    "optimizer": {"lr": 1e-3, "weight_decay": 0.0},
})

runner = Runner(
    model=ToyModel(),
    epochs=2,
    train_data_loader=train_loader,
    valid_data_loader=valid_loader,
    cfg=cfg,
)
runner.fit()
```

Expected result: you should see progress printouts, per-epoch summaries, and checkpoint output (see logs/checkpoints/).
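To confirm that checkpoints were actually written, you can list the output directory (a stdlib-only sketch; `logs/checkpoints` is the path mentioned above, and an empty list simply means nothing has been saved yet):

```python
from pathlib import Path

def list_checkpoints(ckpt_dir="logs/checkpoints"):
    """Return checkpoint file names, or [] if the directory does not exist."""
    d = Path(ckpt_dir)
    return sorted(p.name for p in d.iterdir()) if d.is_dir() else []

print(list_checkpoints())
```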

2) Config-driven launch (YAML)

1) Create a config file

Create demo.yaml in the project root:

```yaml
# demo.yaml
# Optional sections
dataset:
  # A dataset you implement; see "Define your dataset" below
  type: 'YourMapDataset'
  params:
    data_source: './dataset/your_path'
    shuffle: true
    split_ratios: [0.98, 0.01, 0.01]

model:
  # A model you implement or reuse; see "Define your model" below
  type: 'YourModel'
  params:
    vocab_size: 32000
    hidden_size: 512

# Required sections
run_name: 'quickstart-demo'
run_description: "Minimal runnable example"

# Entries set to null can be removed
training:
  epochs: 1
  ds_config: ''
  gradient_accumulation: 1
  resume_from: null
  load_from: null
  activation_checkpoint: []
  grad_clip: null
  fp16: false

  print_model: false
  progress_show:
    loss: true
    valid/acc: true
    test/acc: true
  valid_begin_epoch: 1
  valid_interval_epoch: 1
  valid_begin_iter: 0
  valid_interval_iter: 100000000  # validate by epoch only

  test_begin_epoch: 1
  test_interval_epoch: 1
  test_begin_iter: 0
  test_interval_iter: 100000000  # test by epoch only

  progress_every_n_epochs: 1
  progress_every_n_batches: 1

log:
  to_file: true
  folder: "./assert/logs"
  level: "INFO"
  rank_level: "WARNING"
pt:
  enable: true
  dir: "./assert/checkpoints"
  best_monitor:
    loss: true
  topk: 3
  begin_epoch: 1
  epoch_interval: 1
  begin_batch: 0
  batch_interval: 100000000
wandb:
  enable: false
  proj_name: "wall_e quickstart"
  offline: true
  dir: "./assert"
  tags: ["wall_e", "quickstart"]
```
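Since only `run_name` and `run_description` are marked as required above, a small pre-flight check can catch a missing key before the Runner starts. A stdlib-only sketch (the required-key list is an assumption taken from the comments in `demo.yaml`):

```python
REQUIRED_KEYS = ("run_name", "run_description")

def missing_keys(cfg: dict, required=REQUIRED_KEYS):
    """Return the required keys that are absent or empty in cfg."""
    return [k for k in required if not cfg.get(k)]

print(missing_keys({"run_name": "quickstart-demo"}))  # ['run_description']
```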

2) Python launch script

```python
# run_demo.py
from wall_e.config.load_config import load_cfg
from wall_e.runner.runner import Runner
from wall_e.model.base_model import BaseModel
from wall_e.dataset.dataset import BaseMapDataset
from torch.utils.data import DataLoader

cfg = load_cfg('demo.yaml')

# Build the dataset / model as needed (or create them from cfg via the registry)
ds = BaseMapDataset.from_cfg(cfg.dataset.path, metadata=getattr(cfg.dataset, 'metadata', None))
train_loader = DataLoader(ds.get_split('train'), batch_size=8)
valid_loader = DataLoader(ds.get_split('test'), batch_size=8)

model = BaseModel.from_cfg(cfg.model.params | {'type': cfg.model.type})  # must already be registered in the registry

runner = Runner(
    model=model,
    epochs=cfg.training.epochs,
    train_data_loader=train_loader,
    valid_data_loader=valid_loader,
    cfg=cfg,
)
runner.fit()
```

Run it:

```bash
python run_demo.py
# or, for multi-process training:
torchrun --nproc_per_node=<NUM_GPUS> ./run_demo.py
```
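When launched via torchrun, each worker process can read its rank from the standard environment variables torchrun sets (`RANK`, `WORLD_SIZE`, `LOCAL_RANK`); under a plain `python` launch these are absent, so the defaults below describe a single-process run:

```python
import os

def process_info(env=os.environ):
    """Read torchrun's rank variables, defaulting to a single-process run."""
    return {
        "rank": int(env.get("RANK", 0)),
        "world_size": int(env.get("WORLD_SIZE", 1)),
        "local_rank": int(env.get("LOCAL_RANK", 0)),
    }

print(process_info())
```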

Next steps:

  • "Start Your First Task" tutorial: docs/tutorials/开始你的第一个任务
  • API reference: api/runner, api/loops, api/base-model, api/dataset/overview, api/config