快速开始 | 预训练 VLA 模型 | 安装 | 通过 LoRA 微调 OpenVLA | 全参数微调 OpenVLA |
从头训练 VLA | 评估 OpenVLA | 项目网站
这是一个用于训练和微调通用机器人操作的视觉-语言-动作模型(VLA)的简单且可扩展的代码库:
基于 Prismatic VLMs 构建。
为了开始加载并运行 OpenVLA 模型进行推理,我们提供了一个轻量级接口,它利用了 HuggingFace transformers 的 AutoClasses,依赖项极少。
例如,要在 BridgeData V2 环境 中加载 openvla-7b 模型,用于 WidowX 机器人的零样本指令跟随:
# 安装最小依赖项 (`torch`, `transformers`, `timm`, `tokenizers`, ...)
# > pip install -r https://raw.githubusercontent.com/openvla/openvla/main/requirements-min.txt
from transformers import AutoModelForVision2Seq, AutoProcessor
from PIL import Image
import torch
# 加载处理器和 VLA
processor = AutoProcessor.from_pretrained("openvla/openvla-7b", trust_remote_code=True)
vla = AutoModelForVision2Seq.from_pretrained(
"openvla/openvla-7b",
attn_implementation="flash_attention_2", # [可选] 需要 `flash_attn`
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
trust_remote_code=True
).to("cuda:0")
# 获取图像输入并格式化提示
image: Image.Image = get_from_camera(...)
prompt = "In: What action should the robot take to {<INSTRUCTION>}?\nOut:"
# 预测动作(7自由度;针对 BridgeData V2 进行反归一化)
inputs = processor(prompt, image).to("cuda:0", dtype=torch.bfloat16)
action = vla.predict_action(**inputs, unnorm_key="bridge_orig", do_sample=False)
# 执行...
robot.act(action, ...)
我们还提供了一个用于为新任务和机器人形态微调 OpenVLA 模型的示例脚本;该脚本支持不同的微调模式——包括由 HuggingFace 的 PEFT 库 支持的(量化)低秩适配(LoRA)。
对于部署,我们提供了一个轻量级脚本用于通过 REST API 提供 OpenVLA 模型服务,为将 OpenVLA 模型集成到现有机器人控制栈中提供了一种简单的方法,无需强大的设备端计算能力。
我们发布了作为我们工作一部分训练的两个 OpenVLA 模型,检查点、配置和模型卡片可在 我们的 HuggingFace 页面 获取:
- openvla-7b:我们论文中的旗舰模型,基于 Prismatic prism-dinosiglip-224px VLM(融合了 DINOv2 和 SigLIP 视觉主干以及 Llama-2 LLM)训练。在来自 Open X-Embodiment 的包含 97 万条轨迹的大型数据集混合上训练(混合详情 - 参见 "Open-X Magic Soup++")。
- openvla-v01-7b:开发过程中使用的早期模型,基于 Prismatic siglip-224px VLM(单一的 SigLIP 视觉主干和 Vicuña v1.5 LLM)训练。在与 Octo 相同的数据集混合上训练,但 GPU 训练时长远少于我们的最终模型(混合详情 - 参见 "Open-X Magic Soup")。
关于模型许可和商业使用的明确说明:虽然本仓库中的所有代码均以 MIT 许可证发布,但我们的预训练模型可能继承自我们使用的底层基础模型的限制。具体来说,上述两个模型都源自 Llama-2,因此受 Llama 社区许可证 约束。
注意:这些安装说明适用于全规模预训练(和分布式微调);如果只想运行 OpenVLA 模型进行推理(或进行轻量级微调),请参见上面的说明!
本仓库使用 Python 3.10 构建,但应向后兼容任何 Python >= 3.8。我们需要 PyTorch 2.2.* —— 安装说明可在此处找到。本仓库的最新版本使用以下版本开发和全面测试:
- PyTorch 2.2.0, torchvision 0.17.0, transformers 4.40.1, tokenizers 0.19.1, timm 0.9.10, 以及 flash-attn 2.5.5
[2024年5月21日] 注意:根据 transformers、timm 和 tokenizers 后续版本中报告的回归问题和破坏性变更,我们明确固定了上述依赖版本。我们正在努力实现全面的测试,并计划尽快放宽这些限制。
使用以下设置命令开始:
# 创建并激活 conda 环境
conda create -n openvla python=3.10 -y
conda activate openvla
# 安装 PyTorch。下面是一个示例命令,但你应该检查以下链接
# 以找到针对你计算平台的特定安装说明:
# https://pytorch.org/get-started/locally/
conda install pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch -c nvidia -y # 请更新此命令!
# 克隆并安装 openvla 仓库
git clone https://github.com/openvla/openvla.git
cd openvla
pip install -e .
# 为训练安装 Flash Attention 2 (https://github.com/Dao-AILab/flash-attention)
# =>> 如果遇到困难,请先尝试 `pip cache remove flash_attn`
pip install packaging ninja
ninja --version; echo $? # 验证 Ninja --> 应返回退出码 "0"
pip install "flash-attn==2.5.5" --no-build-isolation
如果在安装过程中遇到任何问题,请提交 GitHub Issue。
注意: 有关 OpenVLA 模型的完整训练和验证脚本,请参见 vla-scripts/。注意 scripts/ 目录主要是原始(基础)prismatic-vlms 仓库的遗留部分,支持训练和评估视觉条件语言模型;虽然你可以使用此仓库训练 VLM 和 VLA,但请注意,尝试使用现有的 OpenVLA 模型生成语言(通过 scripts/generate.py)将无法工作(因为我们只训练当前的 OpenVLA 模型来生成动作,且仅生成动作)。
(2025-03-03 更新:我们建议尝试新的 OFT 方案来微调 OpenVLA,以产生更快、更成功的策略。查看项目网站此处。)
在本节中,我们将讨论如何使用 Hugging Face transformers 库通过低秩适配(LoRA)微调 OpenVLA。如果你没有足够的计算资源来全参数微调一个 70 亿参数的模型,推荐使用此方法。LoRA 微调的主要脚本是 vla-scripts/finetune.py。(如果你想进行全参数微调,请参见全参数微调 OpenVLA 部分。)
下面我们展示一个示例,说明如何通过 LoRA 微调主要的 OpenVLA 检查点(openvla-7b)。这里我们使用单个 80 GB VRAM 的 A100 GPU 在 BridgeData V2 上进行微调。(你也可以使用更小的 GPU 进行微调,只要它至少有约 27 GB 内存,通过修改批大小即可。)
首先,下载 BridgeData V2 数据集:
# 切换到你的基础数据集文件夹
cd <基础数据集目录路径>
# 下载完整数据集 (124 GB)
wget -r -nH --cut-dirs=4 --reject="index.html*" https://rail.eecs.berkeley.edu/datasets/bridge_release/data/tfds/bridge_dataset/
# 将数据集重命名为 `bridge_orig` (注意:省略此步骤可能导致后续运行时错误)
mv bridge_dataset bridge_orig
现在,启动 LoRA 微调脚本,如下所示。注意,--batch_size==16 且 --grad_accumulation_steps==1 需要约 72 GB GPU 内存。如果你的 GPU 较小,应减小 --batch_size 并增加 --grad_accumulation_steps,以保持足够大的有效批大小以进行稳定训练。如果你有多个 GPU 并希望通过 PyTorch 分布式数据并行(DDP)进行训练,只需将下面 torchrun 命令中的 --nproc-per-node 设置为可用 GPU 的数量。
torchrun --standalone --nnodes 1 --nproc-per-node 1 vla-scripts/finetune.py \
--vla_path "openvla/openvla-7b" \
--data_root_dir <基础数据集目录路径> \
--dataset_name bridge_orig \
--run_root_dir <日志/检查点目录路径> \
--adapter_tmp_dir <保存适配器权重的临时目录路径> \
--lora_rank 32 \
--batch_size 16 \
--grad_accumulation_steps 1 \
--learning_rate 5e-4 \
--image_aug <True 或 False> \
--wandb_project <项目名> \
--wandb_entity <实体名> \
--save_steps <每个检查点保存间隔的梯度步数>
注意:如果你在上述命令中设置 --image_aug==False,你将在训练日志中观察到接近 100% 的 action_accuracy,因为 openvla-7b 模型已经在包含 BridgeData V2 的数据集超集上进行了预训练(未使用数据增强)。
要在其他数据集上进行 LoRA 微调,你可以从 Open X-Embodiment (OXE) 混合中下载数据集(参见此自定义脚本了解如何从 OXE 下载数据集的示例)。或者,如果你有一个不属于 OXE 的自定义数据集,你可以 (a) 将数据集转换为与我们的微调脚本兼容的 RLDS 格式(有关说明,请参见此仓库),或 (b) 使用你自己的自定义 PyTorch Dataset 包装器(有关说明,请参见 vla-scripts/finetune.py 中的注释)。对于大多数用户,我们推荐选项 (a);RLDS 数据集和数据加载器经过了更广泛的测试,因为我们在所有的预训练和微调实验中使用了它们。
对于选项 (a),在将数据集转换为 RLDS 后,你需要通过在此处注册数据集配置此处和数据集转换函数此处来将其注册到我们的数据加载器。
一旦你整合了新的数据集,就可以使用上面相同的 vla-scripts/finetune.py 脚本启动 LoRA 微调。如果遇到任何问题,请访问 VLA 故障排除 部分或在 OpenVLA GitHub Issues 页面(包括“已关闭”的问题)中搜索类似问题。如果在那里找不到类似问题,请随时创建一个新 Issue。
(2025-03-03 更新:我们建议尝试新的 OFT 方案来微调 OpenVLA,以产生更快、更成功的策略。查看项目网站此处。)
在本节中,我们将讨论使用 Prismatic VLMs 训练脚本,通过原生 PyTorch 全分片数据并行(FSDP)全参数微调 OpenVLA(所有 75 亿参数)。全参数微调更高级/复杂,仅在你拥有足够计算资源(例如,一个包含 8 个 A100 GPU 的完整节点)且 LoRA 微调不足以满足你的用例(例如,如果微调分布与预训练分布差异巨大)时推荐。否则,我们建议你尝试通过 LoRA 进行参数高效微调,这在 通过 LoRA 微调 OpenVLA 部分有描述。
对于全参数微调,你需要下载与 Prismatic VLMs 代码库兼容的 OpenVLA 模型检查点版本,我们基于该代码库开发了 OpenVLA 模型。你可以使用下面的 git 命令下载这个与 Prismatic 兼容的 OpenVLA 检查点(或者,你也可以通过 Hugging Face CLI 下载):
```bash
cd <基础模型检查点目录路径>
git clone git@hf.co:openvla/openvla-7b-prismatic