BaseMapDataset
BaseMapDataset(wall_e/dataset/dataset.py):基于 BaseDataset 的 map-style 封装,支持自动划分与子集操作。
何时使用
- 数据可随机访问、适合内存内 map-style 处理(非流式)。
- 需要自动划分 train/valid/test、快速抽样/切片与便捷的 DataLoader 构建。
提示:继承自 BaseDataset,可直接使用其 get_batch_loader、map/filter、持久化与 Hub 接口。
构造与数据加载
wall_e/dataset/dataset.py (L10–27)
@registry.register_dataset("BaseMapDataset")
class BaseMapDataset(BaseDataset):
def __init__(self, data_source, only_local=False, metadata=None, shuffle=True,
split_ratios: Optional[tuple] = (0.96, 0.02, 0.02), *args, **kwargs):
...
self.dataset = self._set_dataset(data_source, only_local)
if split_ratios:
... # 自动划分
- 参数:
data_source (Sequence | str | Dataset | DatasetDict):本地路径/文件、HF Hub 名称、或已有数据集对象。only_local (bool):True时仅从本地缓存加载 HF 数据集。metadata (dict | None):元信息,用于dataset_card。shuffle (bool):自动划分时是否打乱。split_ratios (tuple | None):自动划分比例;None表示不自动划分。
- 行为:
- 通过
_set_dataset加载为Dataset或DatasetDict;若split_ratios非空且为Dataset,会自动调用auto_split。
- 通过
_set_dataset
wall_e/dataset/dataset.py (L38–71)
def _set_dataset(self, data_source, only_local) -> Union[Dataset, DatasetDict]:
# 支持:本地路径/HF Hub/Sequence[repo, name]/已有 Dataset(Dict)
return dataset
- 说明:
- 本地路径:目录使用
load_from_disk;文件根据扩展名推断格式并用load_dataset加载。 - 远程 Hub:支持
path或[path, name]形式,streaming=False(map-style)。 - 也接受现成的
Dataset/DatasetDict。
- 本地路径:目录使用
auto_split
wall_e/dataset/dataset.py (L73–123)
def auto_split(self, shuffle, ratios: tuple) -> Optional[DatasetDict]:
# 2 或 3 段比例划分,输出 DatasetDict(train/valid/test)
- 参数:
shuffle (bool):划分时是否打乱。ratios (tuple):长度 2 或 3;长度 3 输出train/valid/test,长度 2 输出train/test。
- 返回:
DatasetDict | None:当底层已是DatasetDict时不进行划分并返回None。
说明:
len(ratios)==3时,先划分出 train 与临时集,再将临时集按val/test比例二次划分。
get_subset/get_split
wall_e/dataset/dataset.py (L124–149)
def get_subset(self, split=None, n=None, start=0) -> "BaseMapDataset": ...
def get_split(self, split: str) -> "BaseMapDataset": ...
- 参数:
get_subset(split=None, n=None, start=0):从指定 split(或单一 Dataset)切片,返回新的BaseMapDataset;split:当底层为DatasetDict时必填(如"train");n:样本数量;start:起始偏移。
get_split(split: str):当底层为DatasetDict时,获取指定 split 的BaseMapDataset。
- 返回:
BaseMapDataset:新包装实例(不修改原实例)。
len/getitem/sample
wall_e/dataset/dataset.py (L151–165)
def __len__(self) -> int: ...
def __getitem__(self, idx: int): ...
def sample(self, split, n=1, start=0) -> Dataset: ...
- 说明:
__len__/__getitem__直接代理到底层Dataset/DatasetDict[split]。sample(split, n=1, start=0):返回底层Dataset的一个切片(非包装)。
dataset_card
wall_e/dataset/dataset.py (L167–177)
@property
def dataset_card(self) -> Dict: ...
- 在
BaseDataset.dataset_card基础上增加:splits:当为DatasetDict时列出键;否则为'no split';size:各 split 的样本数或整体样本数;streaming=False。
示例:
from wall_e.dataset.dataset import BaseMapDataset
ds = BaseMapDataset(data_source="imdb", split_ratios=(0.9, 0.1))
train_loader = ds.get_batch_loader(split="train", batch_size=8)
small_valid = ds.get_subset(split="valid", n=64)
print(list(small_valid.sample("valid", n=3)))
# 也可通过 from_cfg 使用路径或 Hub 名称:
# ds2 = BaseMapDataset.from_cfg("imdb")