Skip to main content

BaseMapDataset

BaseMapDataset(wall_e/dataset/dataset.py):基于 BaseDataset 的 map-style 封装,支持自动划分与子集操作。

何时使用

  • 数据可随机访问、适合内存内 map-style 处理(非流式)。
  • 需要自动划分 train/valid/test、快速抽样/切片与便捷的 DataLoader 构建。

提示:继承自 BaseDataset,可直接使用其 get_batch_loadermap/filter、持久化与 Hub 接口。

构造与数据加载

wall_e/dataset/dataset.py (L10–27)
@registry.register_dataset("BaseMapDataset")
class BaseMapDataset(BaseDataset):
def __init__(self, data_source, only_local=False, metadata=None, shuffle=True,
split_ratios: Optional[tuple] = (0.96, 0.02, 0.02), *args, **kwargs):
...
self.dataset = self._set_dataset(data_source, only_local)
if split_ratios:
... # 自动划分
  • 参数:
    • data_source (Sequence | str | Dataset | DatasetDict):本地路径/文件、HF Hub 名称、或已有数据集对象。
    • only_local (bool)True 时仅从本地缓存加载 HF 数据集。
    • metadata (dict | None):元信息,用于 dataset_card
    • shuffle (bool):自动划分时是否打乱。
    • split_ratios (tuple | None):自动划分比例;None 表示不自动划分。
  • 行为:
    • 通过 _set_dataset 加载为 DatasetDatasetDict;若 split_ratios 非空且为 Dataset,会自动调用 auto_split

_set_dataset

wall_e/dataset/dataset.py (L38–71)
def _set_dataset(self, data_source, only_local) -> Union[Dataset, DatasetDict]:
# 支持:本地路径/HF Hub/Sequence[repo, name]/已有 Dataset(Dict)
return dataset
  • 说明:
    • 本地路径:目录使用 load_from_disk;文件根据扩展名推断格式并用 load_dataset 加载。
    • 远程 Hub:支持 path[path, name] 形式,streaming=False(map-style)。
    • 也接受现成的 Dataset/DatasetDict

auto_split

wall_e/dataset/dataset.py (L73–123)
def auto_split(self, shuffle, ratios: tuple) -> Optional[DatasetDict]:
# 2 或 3 段比例划分,输出 DatasetDict(train/valid/test)
  • 参数:
    • shuffle (bool):划分时是否打乱。
    • ratios (tuple):长度 2 或 3;长度 3 输出 train/valid/test,长度 2 输出 train/test
  • 返回:
    • DatasetDict | None:当底层已是 DatasetDict 时不进行划分并返回 None

说明:

  • len(ratios)==3 时,先划分出 train 与临时集,再将临时集按 val/test 比例二次划分。

get_subset/get_split

wall_e/dataset/dataset.py (L124–149)
def get_subset(self, split=None, n=None, start=0) -> "BaseMapDataset": ...
def get_split(self, split: str) -> "BaseMapDataset": ...
  • 参数:
    • get_subset(split=None, n=None, start=0):从指定 split(或单一 Dataset)切片,返回新的 BaseMapDataset
      • split:当底层为 DatasetDict 时必填(如 "train");
      • n:样本数量;start:起始偏移。
    • get_split(split: str):当底层为 DatasetDict 时,获取指定 split 的 BaseMapDataset
  • 返回:
    • BaseMapDataset:新包装实例(不修改原实例)。

len/getitem/sample

wall_e/dataset/dataset.py (L151–165)
def __len__(self) -> int: ...
def __getitem__(self, idx: int): ...
def sample(self, split, n=1, start=0) -> Dataset: ...
  • 说明:
    • __len__/__getitem__ 直接代理到底层 Dataset/DatasetDict[split]
    • sample(split, n=1, start=0):返回底层 Dataset 的一个切片(非包装)。

dataset_card

wall_e/dataset/dataset.py (L167–177)
@property
def dataset_card(self) -> Dict: ...
  • BaseDataset.dataset_card 基础上增加:
    • splits:当为 DatasetDict 时列出键;否则为 'no split'
    • size:各 split 的样本数或整体样本数;
    • streaming=False

示例:

from wall_e.dataset.dataset import BaseMapDataset

ds = BaseMapDataset(data_source="imdb", split_ratios=(0.9, 0.1))
train_loader = ds.get_batch_loader(split="train", batch_size=8)
small_valid = ds.get_subset(split="valid", n=64)
print(list(small_valid.sample("valid", n=3)))

# 也可通过 from_cfg 使用路径或 Hub 名称:
# ds2 = BaseMapDataset.from_cfg("imdb")