数据集
6 分钟 等级: 入门
数据集
本框架的数据集接口基于 HuggingFace Datasets,提供一致的 Map/Iterable 两种风格的数据集封装,并补充了自动划分、批处理、便捷的 map/filter 操作、保存卡片、与配置文件集成等能力。
你通常只需要在配置里指定 dataset.type 与数据来源路径,或在代码中直接实例化对应的数据集类:
BaseMapDataset: 适用于可随机索引、可缓存的小中型数据BaseIterableDataset: 适用于流式/超大规模数据
快速开始
Map 与 Iterable 教程
- Map 风格:见《数据集(Map 风格)》
- Iterable 风格:见《数据集(Iterable 风格)》
以下为快速示例:
从本地文件加载(Map 风格)
from wall_e.dataset import BaseMapDataset
# 自动根据扩展名推断数据格式(csv/json/jsonl/text/tsv)
ds = BaseMapDataset(
data_source="./data/train.jsonl", # 也可传目录(load_from_disk 保存的 HF Dataset)
metadata={"task": "classification", "source": "local"},
split_ratios=(0.96, 0.02, 0.02), # 自动划分 train/valid/test
)
# 构造 Pytorch DataLoader(默认把 HF Dataset 格式化为 torch)
train_loader = ds.get_batch_loader(
split="train",
batch_size=8,
shuffle=True,
num_workers=2,
)
for batch in train_loader:
# 训练循环
pass
从 HuggingFace Hub 加载(Map 风格)
from wall_e.dataset import BaseMapDataset
ds = BaseMapDataset(
data_source="imdb", # 数据集 repo id
metadata={"task": "sentiment"},
split_ratios=None # Hub 上通常自带多个 split,关闭自动划分
)
valid_subset = ds.get_split("test") # 获取已有切分
print(valid_subset.dataset_card)
流式读取(Iterable 风格)
from wall_e.dataset import BaseIterableDataset
stream_ds = BaseIterableDataset(
data_source=["c4", "en"], # [repo_id, config]
metadata={"task": "lm"}
)
loader = stream_ds.get_batch_loader(batch_size=16, shuffle=False)
for batch in loader:
# 流式训练
pass
数据来源与自动推断
data_source 支持:
- 字符串路径:文件(csv/json/jsonl/text/tsv)或 HF
load_from_disk保存的目录 - HuggingFace Hub 名称:如
"imdb"或["c4", "en"] - 现成对象:
datasets.Dataset/DatasetDict - 序列
[path_or_repo, name]:指定 Hub 的 config
文件场景下会自动根据扩展名推断格式:
# infer_data_format: csv/json/jsonl/text/tsv -> 对应 HF 格式
Map vs Iterable 的选择
- Map:支持随机索引,
len(ds)、ds[i],适合大多数训练任务;可以save_to_disk - Iterable:流式逐条读取,内存占用友好;不支持
save_to_disk,可转为 Map 再保存
自动划分与子集抽样(Map)
当传入的是单一 Dataset(或本地原始文件)且 split_ratios 不为空时,将自动划分:
ds = BaseMapDataset("./data.jsonl", split_ratios=(0.9, 0.1)) # train/test
print(ds.dataset_card["splits"]) # ["train", "test"]
mini = ds.get_subset(split="train", n=100) # 取前100条作为子集
sampled = ds.sample(split="train", n=3) # 便捷采样,返回 HF Dataset
若原始就是 DatasetDict(多 split),自动划分会被忽略并给出提示。
map / filter 数据变换
对底层 HF Dataset 进行统一封装,保持 Map/Iterable 的一致接口:
def tokenize(batch):
# batched=True 时 batch 是样本的字典列表
return {"text": [x.lower() for x in batch["text"]]}
ds2 = ds.map(tokenize, batched=True, batch_size=64, remove_columns=["raw"])
def keep_non_empty(x):
return len(x["text"].strip()) > 0
ds3 = ds2.filter(keep_non_empty, batched=False)
注意:Iterable 模式下 filter 不支持 num_proc,且不支持 remove_columns;实现已做兼容。
批处理与 DataLoader
你可以传入 transform_fn 在取 batch 前对样本做动态格式化:
from wall_e.dataset.utils import default_collate, pseudo_collate
def to_features(batch):
# 对 HF Dataset 应用 set_transform 或 with_format("torch") 的替代方案
return {"input_ids": batch["ids"], "labels": batch["label"]}
loader = ds.get_batch_loader(
split="train",
batch_size=8,
transform_fn=to_features,
collate_fn=default_collate, # 或 pseudo_collate
shuffle=True,
)
当不传 transform_fn 时会自动调用 with_format("torch"),使返回的字段为张量。
保存与加载
# 保存(Map/Iterable 都会保存 case 与 card;Iterable 不支持保存底层数据)
ds.save_to_disk("./assert/my_dataset")
# 重新加载(Map)
reloaded = BaseMapDataset("./assert/my_dataset")
print(reloaded.dataset_card)
保存数据卡片与 case(重点)
保存后会额外生成:
dataset_case.txt或train_case.txt/...:首条样本快照,便于排查dataset_card.json:来自dataset_card属性的元信息
简易示例:
from wall_e.dataset import BaseMapDataset, BaseIterableDataset
# Map 风格:从本地 jsonl 加载,保存后生成 dataset_card.json 与 dataset_case.txt
ds_map = BaseMapDataset(
data_source="./data/samples.jsonl",
metadata={"task": "demo", "language": "en"},
)
ds_map = ds_map.map(
fn=lambda b: {"text": [t.lower() for t in b["text"]], "label": b["label"]},
batched=True,
)
ds_map.save_to_disk("./assert/dataset_map")
# Iterable 风格:同样保存 card 与 case(不保存底层数据)
ds_iter = BaseIterableDataset(
data_source="./data/samples.jsonl",
metadata={"task": "demo", "language": "en"},
)
ds_iter.save_to_disk("./assert/dataset_iter")
保存目录示例:
./assert/dataset_map/
- dataset_card.json
- dataset_case.txt
- dataset/ # Map 风格会包含 HF 数据目录
./assert/dataset_iter/
- dataset_card.json
- dataset_case.txt # Iterable 仅包含 card 与 case
dataset_card.json 典型内容:
{
"task": "sentiment analysis",
"language": "English",
"splits": "no split",
"size": 1234,
"streaming": false
}
数据卡片(dataset_card)
两类数据集都暴露 dataset_card:
print(ds.dataset_card)
# Map:{"splits": ["train", "valid", "test"], "size": {...}, "streaming": False, **metadata}
# Iterable:{"streaming": True, **metadata}
你可以在构造时传入 metadata(如任务名、数据来源、许可等),这些信息会被写入卡片并保存到磁盘。
与配置文件集成
在 cfg.yaml 中指定数据集:
dataset:
type: "BaseMapDataset" # 或 "BaseIterableDataset"
path: "./data/train.jsonl" # 或 HF repo id,如 imdb
metadata:
task: "classification"
split_ratios: [0.96, 0.02, 0.02]
在代码中按配置加载:
from wall_e.config.load_config import load_cfg
from wall_e.dataset import BaseMapDataset, BaseIterableDataset
cfg = load_cfg("./cfg.yaml")
DatasetCls = BaseMapDataset if cfg.dataset.type == "BaseMapDataset" else BaseIterableDataset
ds = DatasetCls.from_cfg(cfg.dataset.path, metadata=getattr(cfg.dataset, "metadata", None))
推送到 HuggingFace Hub(Map)
ds.push_to_hub("your-namespace/your-dataset")
常见问题(FAQ)
- 需要指定
split吗?- 当底层是
DatasetDict时(自动或原生多切分),get_batch_loader/get_subset/sample需要提供split。
- 当底层是
- 本地文件没有显式切分怎么办?
- 使用
split_ratios自动划分,或手动构造DatasetDict后传给BaseMapDataset。
- 使用
- 流式数据能保存吗?
- 会保存 case 与 card;但底层数据不支持
save_to_disk,需要先转为 Map 风格再保存。
- 会保存 case 与 card;但底层数据不支持
- 只用本地缓存加载 Hub 数据?
- 传参
only_local=True。
- 传参