OA0
OA0 是一个探索 AI 的社区
现在注册
已注册用户请  登录
OA0  ›  代码  ›  Dataloader — 为多模态训练与大规模数据处理提供支持

Dataloader — 为多模态训练与大规模数据处理提供支持

 
  but ·  2026-05-21 11:00:22 · 11 次点击  · 0 条评论  

Test
DeepSource

%matplotlib inline
import matplotlib.pyplot as plt
import torch.utils.data
import torch.nn
from random import randrange
import os
os.environ["WDS_VERBOSE_CACHE"] = "1"
os.environ["GOPEN_VERBOSE"] = "0"

WebDataset 格式

WebDataset 格式的文件是 tar 文件,遵循两个约定:

  • 在每个 tar 文件中,属于同一个训练样本的文件共享相同的基本名称(去掉所有文件扩展名后)
  • tar 文件的分片编号形如 something-000000.tarsomething-012345.tar,通常使用大括号表示法指定,如 something-{000000..012345}.tar

你可以在 WebDataset 格式规范 中找到更详细、更完整的 WebDataset 格式说明。

WebDataset 可以从本地磁盘或任何管道读取文件,从而能够使用常见的云对象存储访问文件。WebDataset 还可以读取连接的 MsgPack 和 CBOR 来源。

WebDataset 表示方法允许为大规模深度学习编写纯顺序 I/O 流水线。这对于实现本地存储的高 I/O 速率(相比随机访问,本地驱动可提升 3-10 倍)以及使用对象存储和云存储进行训练非常重要。

WebDataset 格式以图像、视频、音频等的原始文件格式呈现,使得创建 WebDataset 格式的数据就像创建 tar 存档一样简单。由于数据的对齐方式,WebDataset 也能很好地与块去重配合,并将数据对齐到可预测的边界。

可以使用标准工具访问和处理 WebDataset 格式的文件。

bucket = "https://storage.googleapis.com/webdataset/testdata/"
dataset = "publaynet-train-{000000..000009}.tar"

url = bucket + dataset
!curl -s {bucket}publaynet-train-000000.tar | dd count=5000 2> /dev/null | tar tf - 2> /dev/null | sed 10q
PMC4991227_00003.json
PMC4991227_00003.png
PMC4537884_00002.json
PMC4537884_00002.png
PMC4323233_00003.json
PMC4323233_00003.png
PMC5429906_00004.json
PMC5429906_00004.png
PMC5592712_00002.json
PMC5592712_00002.png

注意:在这些 .tar 文件中,我们有一对 .json.png 文件;每一对构成一个训练样本。

WebDataset 库

有多个库支持 WebDataset 格式:

  • Python3 的 webdataset(包含 wids 库),即本仓库
  • Julia 实现的 Webdataset.jl
  • Golang 实现的 tarp 及命令行工具
  • Ray Data 的源和汇

webdataset 库可与 PyTorch、Tensorflow 和 Jax 一起使用。

webdataset

webdataset 库是 PyTorch IterableDataset 的一个实现(如果你不使用 PyTorch,则是其模拟实现)。它实现了流处理形式。其部分特性包括:

  • 通过分片实现大规模并行数据访问
  • 纯顺序读取带来的高性能磁盘 I/O
  • 大管道带来的低延迟敏感性
  • 无需本地存储
  • 训练作业即时启动
  • 只需读取文件描述符/网络流,无需特殊 API
  • 其 API 鼓励高性能 I/O 流水线
  • 从微型桌面数据集到 PB 级数据集均可扩展
  • 可选本地缓存
  • 无需数据集元数据;可以立即读取和使用任何分片集合

用户遇到的主要限制与以下事实相关:IterableDataset 在 PyTorch 中使用较少,某些现有代码可能无法很好地支持它;以及在多个计算节点上为固定 epoch 大小实现精确平衡的训练样本数量比较棘手;对于多节点训练,webdataset 通常与分片重采样一起使用。

有两个接口:简洁的“流体”接口和更长的“流水线”接口。我们将使用流体接口展示示例,这通常是你需要的。

import webdataset as wds
shuffle_buffer = 10  # 通常,选择一个更大的值,比如 1000
pil_dataset = wds.WebDataset(url).shuffle(shuffle_buffer).decode("pil").to_tuple("png", "json")
/Users/tbreuel/proj/webdataset/src/webdataset/compat.py:379: UserWarning: WebDataset(shardshuffle=...) is None; set explicitly to False or a number
  warnings.warn("WebDataset(shardshuffle=...) is None; set explicitly to False or a number")

生成的 datasets 是标准的 PyTorch IterableDataset 实例。

isinstance(pil_dataset, torch.utils.data.IterableDataset)
True
for image, json in pil_dataset:
    break
plt.imshow(image)
<matplotlib.image.AxesImage at 0x12fec5050>

png

我们可以添加到现有流水线中进行增强和数据准备。

import torchvision.transforms as transforms
from PIL import Image

preproc = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    lambda x: 1-x,
])

def preprocess(sample):
    image, json = sample
    try:
        label = json["annotations"][0]["category_id"]
    except Exception:
        label = 0
    return preproc(image), label

dataset = pil_dataset.map(preprocess)

for image, label in dataset:
    break
plt.imshow(image.numpy().transpose(1, 2, 0))
<matplotlib.image.AxesImage at 0x1677a1250>

png

WebDataset 只是标准 IterableDataset 的一个实例。它是一种单线程迭代数据集的方式。由于图像解压和数据增强可能计算密集,PyTorch 通常使用 DataLoader 类来并行化数据加载和预处理。WebDataset 与标准 DataLoader 完全兼容。

以下是多个说明如何使用 WebDataset 进行图像分类和 LLM 训练的笔记本:

wds-notes 笔记本包含一些关于库的附加文档和信息。

webdataset 流水线 API

wds.WebDataset 流体接口只是编写流水线的便捷简写。底层流水线是 wds.DataPipeline 类的一个实例,你可以显式构建数据流水线,类似于在模型内部使用 nn.Sequential 的方式。

dataset = wds.DataPipeline(
    wds.SimpleShardList(url),
    # 此时,我们有一个所有分片的迭代器

    # 这会打乱分片
    wds.shuffle(100),

    # 如果使用多个节点,在此处添加 wds.split_by_node
    wds.split_by_worker,

    # 此时,我们有一个分配给每个工作进程的分片迭代器
    wds.tarfile_to_samples(),

    # 这会打乱内存中的样本
    wds.shuffle(shuffle_buffer),

    # 这会解码图像和 json
    wds.decode("pil"),
    wds.to_tuple("png", "json"),
    wds.map(preprocess),
    wds.batched(16)
)

batch = next(iter(dataset))
batch[0].shape, batch[1].shape
(torch.Size([16, 3, 224, 224]), (16,))

安全模式

你可以运行 WebDataset 以提高安全性。这会禁用 pipe:file: 协议,以及尝试解码 Python pickle。这应该能禁用简单的攻击,目前尚无已知的成功攻击;自行承担使用此模式的风险。

你可以通过在使用库之前设置 webdataset.utils.enforce_security = True 来启用安全模式。你也可以在环境中设置 WDS_SECURE=1

安装和文档

$ pip install webdataset

对于 Github 版本:

$ pip install git+https://github.com/tmbdev/webdataset.git

以下是一些关于 WebDataset 和大规模深度学习的视频:

依赖

WebDataset 库仅需要 PyTorch、NumPy 和一个名为 braceexpand 的小型库。

WebDataset 仅在需要时动态加载一些额外的库,且仅在解码器中:

  • PIL/Pillow 用于图像解码
  • torchvisiontorchvideotorchaudio 用于图像/视频/音频解码
  • msgpack 用于 MessagePack 解码
  • curl 命令行工具用于访问 HTTP 服务器
  • Google/Amazon/Azure 命令行工具用于访问云存储桶

这些库的加载由配置的解码器触发,当尝试解码给定格式的内容并在解码过程中遇到该格式的文件时加载。(最终,torch... 依赖将重构到这些库中。)

11 次点击  ∙  0 人收藏  
登录后收藏  
0 条回复
关于 ·  帮助 ·  PING ·  隐私 ·  条款   
OA0 - Omni AI 0 一个探索 AI 的社区
沪ICP备2024103595号-2
耗时 23 ms
Developed with Cursor