Skip to main content

BaseIterableDataset

BaseIterableDataset(wall_e/dataset/dataset.py):基于 BaseDataset 的 iterable-style 封装,面向流式大数据场景。

何时使用

  • 数据规模较大或来自流式数据源(无需随机索引,顺序遍历即可)。
  • 需要边读边处理,减少内存占用。

提示:继承自 BaseDataset,可复用其 get_batch_loadermap/filter、持久化接口,但在 iterable 模式下存在若干限制(见下)。

构造与数据加载

wall_e/dataset/dataset.py (L180–195)
@registry.register_dataset("BaseIterableDataset")
class BaseIterableDataset(BaseDataset):
def __init__(self, data_source, only_local=False, metadata=None, *args, **kwargs):
...
self.dataset = self._set_dataset(data_source, only_local)
  • 参数:
    • data_source (str | Sequence | datasets.Dataset | DatasetDict | IterableDataset):本地文件/HF Hub 名称/已有数据集。
    • only_local (bool)True 时仅从本地缓存加载 HF 数据集。
    • metadata (dict | None):用于生成 dataset_card 的元信息。
  • 行为:
    • 通过 _set_dataset 统一转换为 IterableDataset 或其字典。

_set_dataset(流式)

wall_e/dataset/dataset.py (L197–232)
def _set_dataset(self, data_source, only_local=False):
# 支持:本地文件/HF Hub/Dataset -> to_iterable_dataset()/DatasetDict -> 每个 split to_iterable
return dataset
  • 说明:
    • 本地“目录形式”的 Dataset 不支持直接作为流式加载;应使用原始数据文件。
    • 当传入 Dataset/DatasetDict 时,会调用 to_iterable_dataset() 转换为流式。

迭代与卡片

wall_e/dataset/dataset.py (L234–251)
def __iter__(self):
yield from self.dataset

@property
def dataset_card(self) -> Dict: ...

def save_to_disk(self, path: str): # 仅保存case与card,提示不可直接保存流式数据
...
  • __iter__:直接产出底层流式数据。
  • dataset_card:在 metadata 基础上标记 streaming=True
  • save_to_disk(path):仅保存 case 与 card,并提示“流式数据集不支持保存到磁盘”。

与 BaseDataset 的关系

  • get_batch_loader(...)
    • 在未提供 transform_fn 时,默认对底层 IterableDataset 调用 with_format("torch"),返回的 batch 常为 dict[str, torch.Tensor]
    • 如需自定义批处理,提供合适的 collate_fnbatch_size
  • map/filter
    • IterableDataset 不支持 remove_columns
    • filter 不支持 num_proc
    • 其余参数语义与 BaseDataset 相同。

示例

示例:

from wall_e.dataset.dataset import BaseIterableDataset

stream_ds = BaseIterableDataset(data_source=["cc_news", None])
loader = stream_ds.get_batch_loader(batch_size=4) # with_format("torch") 后按需 collate
first_batch = next(iter(loader))
print(type(first_batch))

# 也可使用继承自 BaseDataset 的 from_cfg(路径或 Hub 名称):
# stream_ds2 = BaseIterableDataset.from_cfg(["cc_news", None])