BaseIterableDataset
BaseIterableDataset(wall_e/dataset/dataset.py):基于 BaseDataset 的 iterable-style 封装,面向流式大数据场景。
何时使用
- 数据规模较大或来自流式数据源(无需随机索引,顺序遍历即可)。
- 需要边读边处理,减少内存占用。
提示:继承自 BaseDataset,可复用其 get_batch_loader、map/filter、持久化接口,但在 iterable 模式下存在若干限制(见下)。
构造与数据加载
wall_e/dataset/dataset.py (L180–195)
@registry.register_dataset("BaseIterableDataset")
class BaseIterableDataset(BaseDataset):
def __init__(self, data_source, only_local=False, metadata=None, *args, **kwargs):
...
self.dataset = self._set_dataset(data_source, only_local)
- 参数:
data_source (str | Sequence | datasets.Dataset | DatasetDict | IterableDataset):本地文件/HF Hub 名称/已有数据集。only_local (bool):True时仅从本地缓存加载 HF 数据集。metadata (dict | None):用于生成dataset_card的元信息。
- 行为:
- 通过
_set_dataset统一转换为IterableDataset或其字典。
- 通过
_set_dataset(流式)
wall_e/dataset/dataset.py (L197–232)
def _set_dataset(self, data_source, only_local=False):
# 支持:本地文件/HF Hub/Dataset -> to_iterable_dataset()/DatasetDict -> 每个 split to_iterable
return dataset
- 说明:
- 本地“目录形式”的 Dataset 不支持直接作为流式加载;应使用原始数据文件。
- 当传入
Dataset/DatasetDict时,会调用to_iterable_dataset()转换为流式。
迭代与卡片
wall_e/dataset/dataset.py (L234–251)
def __iter__(self):
yield from self.dataset
@property
def dataset_card(self) -> Dict: ...
def save_to_disk(self, path: str): # 仅保存case与card,提示不可直接保存流式数据
...
__iter__:直接产出底层流式数据。dataset_card:在metadata基础上标记streaming=True。save_to_disk(path):仅保存 case 与 card,并提示“流式数据集不支持保存到磁盘”。
与 BaseDataset 的关系
get_batch_loader(...):- 在未提供
transform_fn时,默认对底层IterableDataset调用with_format("torch"),返回的 batch 常为dict[str, torch.Tensor]; - 如需自定义批处理,提供合适的
collate_fn与batch_size。
- 在未提供
map/filter:IterableDataset不支持remove_columns;filter不支持num_proc;- 其余参数语义与
BaseDataset相同。
示例
示例:
from wall_e.dataset.dataset import BaseIterableDataset
stream_ds = BaseIterableDataset(data_source=["cc_news", None])
loader = stream_ds.get_batch_loader(batch_size=4) # with_format("torch") 后按需 collate
first_batch = next(iter(loader))
print(type(first_batch))
# 也可使用继承自 BaseDataset 的 from_cfg(路径或 Hub 名称):
# stream_ds2 = BaseIterableDataset.from_cfg(["cc_news", None])