Skip to main content

BaseDataset

BaseDataset(wall_e/dataset/base_dataset.py):通用数据集基类,统一批加载、流水线操作、持久化与 Hub 推送。

何时单独使用

  • 需要基于 HuggingFace Datasets 的 map/filter、批加载与持久化/Hub 推送能力;
  • 自己控制训练循环或集成到第三方框架。

何时与 Runner 配合

  • 需要分布式采样、统一日志与评估、调度器、回调等训练编排(通过 Runner.wrap_dataloader 等处理)。

核心接口

wall_e/dataset/base_dataset.py (L15–41)
class BaseDataset(ABC):
def __init__(self, data_source, metadata: Optional[Dict] = None, *args, **kwargs): ...
def _set_dataset(self, *args, **kwargs): ...
def get_batch_loader(self, split=None, transform_fn=None, **kwargs): ...

get_batch_loader

wall_e/dataset/base_dataset.py (L41–59)
def get_batch_loader(self, split=None, transform_fn=None, **kwargs): ...
  • 参数:
    • split (str | None):当 datasetDatasetDict 时指定 split。
    • transform_fn (Callable | None):设置 set_transform;缺省使用 with_format("torch")
    • **kwargs:透传给 torch.utils.data.DataLoader
  • 返回:DataLoader

说明:

  • 典型返回的 batch 类型为 dict[str, torch.Tensor](当已 with_format("torch") 且未自定义 collate_fn)。
  • 自定义 collate_fn 时,batch 结构取决于你的实现(list/tuple/dict/自定义对象)。

_get_batch_loader(类方法)

wall_e/dataset/base_dataset.py (L60–87)
@classmethod
def _get_batch_loader(cls, dataset, transform_fn=None, batch_size=1, pin_memory=False,
sampler=None, num_workers=0, collate_fn=None, shuffle=False, **loader_kwargs) -> DataLoader: ...

map / filter / _apply_operation

wall_e/dataset/base_dataset.py (L89–139)
def map(...); def filter(...); def _apply_operation(...):  # 统一处理 Dataset/IterableDataset/DatasetDict
  • 提示:IterableDatasetremove_columns/num_proc 有限制。

参数(通用):

  • fn:映射或过滤函数;当 batched=True 时传入批次,需返回等长结构。
  • batched (bool):是否以批为单位处理。
  • batch_size (int):批尺寸(仅 batched=True 时生效)。
  • num_proc (int):进程数;注意 IterableDataset 的限制(filter 不支持 num_proc)。
  • remove_columns (list | None)map 时可移除列(IterableDataset 不支持)。

返回:

  • BaseDataset:返回新的数据集包装实例(不修改原实例)。

持久化与 Hub

参数/返回:

  • save_to_disk(path: str):创建目录并顺序调用 save_casesave_carddataset.save_to_disk;无返回。
  • save_case(path: str):保存一个样例(或每个 split 的首个样例)到文本;无返回。
  • save_card(path: str):将 dataset_card 保存为 dataset_card.json;无返回。
  • push_to_hub(repo_id: str, **kwargs):将底层 dataset 推送到 HuggingFace Hub;无返回。
wall_e/dataset/base_dataset.py (L141–190)
def save_to_disk(self, path: str); def save_case(self, path: str); def save_card(self, path: str); def push_to_hub(self, repo_id: str, **kwargs)

dataset_card

说明:

  • 默认返回 metadata 的浅拷贝;在具体子类(如 BaseMapDataset)中会扩充 splits/size/streaming 等字段。
wall_e/dataset/base_dataset.py (L175–179)
@property
def dataset_card(self) -> Dict: ...

from_cfg

wall_e/dataset/base_dataset.py (L182–186)
@classmethod
def from_cfg(cls, path: str, **kwargs) -> "BaseDataset": ...

示例:

from wall_e.dataset.base_dataset import BaseDataset

class MyDataset(BaseDataset):
def _set_dataset(self):
from datasets import load_dataset
self.dataset = load_dataset(self.data_source)

ds = MyDataset.from_cfg("imdb")
loader = ds.get_batch_loader(split="train", batch_size=8)
print(type(next(iter(loader)))) # 常见为 dict[str, torch.Tensor]