BaseDataset
BaseDataset(wall_e/dataset/base_dataset.py):通用数据集基类,统一批加载、流水线操作、持久化与 Hub 推送。
何时单独使用
- 需要基于 HuggingFace Datasets 的 map/filter、批加载与持久化/Hub 推送能力;
- 自己控制训练循环或集成到第三方框架。
何时与 Runner 配合
- 需要分布式采样、统一日志与评估、调度器、回调等训练编排(通过
Runner.wrap_dataloader等处理)。
核心接口
wall_e/dataset/base_dataset.py (L15–41)
class BaseDataset(ABC):
def __init__(self, data_source, metadata: Optional[Dict] = None, *args, **kwargs): ...
def _set_dataset(self, *args, **kwargs): ...
def get_batch_loader(self, split=None, transform_fn=None, **kwargs): ...
get_batch_loader
wall_e/dataset/base_dataset.py (L41–59)
def get_batch_loader(self, split=None, transform_fn=None, **kwargs): ...
- 参数:
split (str | None):当dataset为DatasetDict时指定 split。transform_fn (Callable | None):设置set_transform;缺省使用with_format("torch")。**kwargs:透传给torch.utils.data.DataLoader。
- 返回:
DataLoader。
说明:
- 典型返回的 batch 类型为
dict[str, torch.Tensor](当已with_format("torch")且未自定义collate_fn)。 - 自定义
collate_fn时,batch 结构取决于你的实现(list/tuple/dict/自定义对象)。
_get_batch_loader(类方法)
wall_e/dataset/base_dataset.py (L60–87)
@classmethod
def _get_batch_loader(cls, dataset, transform_fn=None, batch_size=1, pin_memory=False,
sampler=None, num_workers=0, collate_fn=None, shuffle=False, **loader_kwargs) -> DataLoader: ...
map / filter / _apply_operation
wall_e/dataset/base_dataset.py (L89–139)
def map(...); def filter(...); def _apply_operation(...): # 统一处理 Dataset/IterableDataset/DatasetDict
- 提示:
IterableDataset对remove_columns/num_proc有限制。
参数(通用):
fn:映射或过滤函数;当batched=True时传入批次,需返回等长结构。batched (bool):是否以批为单位处理。batch_size (int):批尺寸(仅batched=True时生效)。num_proc (int):进程数;注意IterableDataset的限制(filter不支持num_proc)。remove_columns (list | None):map时可移除列(IterableDataset不支持)。
返回:
BaseDataset:返回新的数据集包装实例(不修改原实例)。
持久化与 Hub
参数/返回:
save_to_disk(path: str):创建目录并顺序调用save_case、save_card、dataset.save_to_disk;无返回。save_case(path: str):保存一个样例(或每个 split 的首个样例)到文本;无返回。save_card(path: str):将dataset_card保存为dataset_card.json;无返回。push_to_hub(repo_id: str, **kwargs):将底层dataset推送到 HuggingFace Hub;无返回。
wall_e/dataset/base_dataset.py (L141–190)
def save_to_disk(self, path: str); def save_case(self, path: str); def save_card(self, path: str); def push_to_hub(self, repo_id: str, **kwargs)
dataset_card
说明:
- 默认返回
metadata的浅拷贝;在具体子类(如BaseMapDataset)中会扩充splits/size/streaming等字段。
wall_e/dataset/base_dataset.py (L175–179)
@property
def dataset_card(self) -> Dict: ...
from_cfg
wall_e/dataset/base_dataset.py (L182–186)
@classmethod
def from_cfg(cls, path: str, **kwargs) -> "BaseDataset": ...
示例:
from wall_e.dataset.base_dataset import BaseDataset
class MyDataset(BaseDataset):
def _set_dataset(self):
from datasets import load_dataset
self.dataset = load_dataset(self.data_source)
ds = MyDataset.from_cfg("imdb")
loader = ds.get_batch_loader(split="train", batch_size=8)
print(type(next(iter(loader)))) # 常见为 dict[str, torch.Tensor]