概述 | 安装 | 快速开始 | 文档 | 社区 | 引用 torchtune | 许可证
torchtune 是一个用于轻松编写、后训练和实验大语言模型(LLM)的 PyTorch 库。它提供:
torchtune 支持 完整的后训练生命周期。一个成功的后训练模型可能会用到以下多种方法。
| 权重更新类型 | 单设备 | 多设备 | 多节点 |
|---|---|---|---|
| 全参数 | ✅ | ✅ | ✅ |
| LoRA/QLoRA | ✅ | ✅ | ✅ |
示例:tune run lora_finetune_single_device --config llama3_2/3B_lora_single_device
您也可以运行例如 tune ls lora_finetune_single_device 来查看所有可用配置的完整列表。
| 权重更新类型 | 单设备 | 多设备 | 多节点 |
|---|---|---|---|
| 全参数 | ❌ | ❌ | ❌ |
| LoRA/QLoRA | ✅ | ✅ | ❌ |
示例:tune run knowledge_distillation_distributed --config qwen2/1.5B_to_0.5B_KD_lora_distributed
您也可以运行例如 tune ls knowledge_distillation_distributed 来查看所有可用配置的完整列表。
| 方法 | 权重更新类型 | 单设备 | 多设备 | 多节点 |
|---|---|---|---|---|
| DPO | 全参数 | ❌ | ✅ | ❌ |
| LoRA/QLoRA | ✅ | ✅ | ❌ | |
| PPO | 全参数 | ✅ | ❌ | ❌ |
| LoRA/QLoRA | ❌ | ❌ | ❌ | |
| GRPO | 全参数 | 🚧 | ✅ | ✅ |
| LoRA/QLoRA | ❌ | ❌ | ❌ |
示例:tune run lora_dpo_single_device --config llama3_1/8B_dpo_single_device
您也可以运行例如 tune ls full_dpo_distributed 来查看所有可用配置的完整列表。
| 权重更新类型 | 单设备 | 多设备 | 多节点 |
|---|---|---|---|
| 全参数 | ✅ | ✅ | ❌ |
| LoRA/QLoRA | ❌ | ✅ | ❌ |
示例:tune run qat_distributed --config llama3_1/8B_qat_lora
您也可以运行例如 tune ls qat_distributed 或 tune ls qat_single_device 来查看所有可用配置的完整列表。
以上配置仅是帮助您入门的示例。完整的配方列表可以在 recipes/ 找到。如果您想填补上述表格中的空白,请提交 PR!如果您希望在 torchtune 中实现一个全新的后训练方法,请随时创建一个 Issue。
对于上述配方,torchtune 支持许多在 Hugging Face Hub 或 Kaggle Hub 上可用的先进模型。我们支持的部分模型:
| 模型 | 尺寸 |
|---|---|
| Llama4 | Scout (17B x 16E) [模型, 配置] |
| Llama3.3 | 70B [模型, 配置] |
| Llama3.2-Vision | 11B, 90B [模型, 配置] |
| Llama3.2 | 1B, 3B [模型, 配置] |
| Llama3.1 | 8B, 70B, 405B [模型, 配置] |
| Mistral | 7B [模型, 配置] |
| Gemma2 | 2B, 9B, 27B [模型, 配置] |
| Microsoft Phi4 | 14B [模型, 配置] |
| Microsoft Phi3 | Mini [模型, 配置] |
| Qwen3 | 0.6B, 1.7B, 4B, 8B, 14B, 32B [模型, 配置] |
| Qwen2.5 | 0.5B, 1.5B, 3B, 7B, 14B, 32B, 72B [模型, 配置] |
| Qwen2 | 0.5B, 1.5B, 7B [模型, 配置] |
我们一直在添加新模型,但如果您希望看到某个新模型出现在 torchtune 中,请随时 提交 Issue。
以下是不同 Llama 3.1 模型的内存需求和训练速度示例。
[!NOTE]
为便于比较,以下所有数据均基于批大小 2(无梯度累积)、数据集打包至序列长度 2048 且启用 torch compile 的条件提供。
如果您有兴趣在不同的硬件上或使用不同的模型运行,请查看我们关于内存优化的文档 此处,以找到适合您的设置。
| 模型 | 微调方法 | 可运行于 | 每 GPU 峰值内存 | 每秒处理 Token 数 * |
|---|---|---|---|---|
| Llama 3.1 8B | 全参数微调 | 1x 4090 | 18.9 GiB | 1650 |
| Llama 3.1 8B | 全参数微调 | 1x A6000 | 37.4 GiB | 2579 |
| Llama 3.1 8B | LoRA | 1x 4090 | 16.2 GiB | 3083 |
| Llama 3.1 8B | LoRA | 1x A6000 | 30.3 GiB | 4699 |
| Llama 3.1 8B | QLoRA | 1x 4090 | 7.4 GiB | 2413 |
| Llama 3.1 70B | 全参数微调 | 8x A100 | 13.9 GiB ** | 1568 |
| Llama 3.1 70B | LoRA | 8x A100 | 27.6 GiB | 3497 |
| Llama 3.1 405B | QLoRA | 8x A100 | 44.8 GB | 653 |
= 测量于一个完整的训练周期
*= 使用 CPU 卸载和融合优化器
torchtune 提供了许多用于内存效率和性能的调节选项。下表展示了将其中一些技术依次应用于 Llama 3.2 3B 模型的效果。每种技术都是在前一种基础上添加的,除了 LoRA 和 QLoRA,它们不使用 optimizer_in_bwd 或 AdamW8bit 优化器。
基线使用 配方=full_finetune_single_device, 模型=Llama 3.2 3B, 批大小=2, 最大序列长度=4096, 精度=bf16, 硬件=A100
| 技术 | 峰值活跃内存 (GiB) | 内存变化 vs 前一项 | 每秒处理 Token 数 | Token/秒 变化 vs 前一项 |
|---|---|---|---|---|
| 基线 | 25.5 | - | 2091 | - |
| + 打包数据集 | 60.0 | +135.16% | 7075 | +238.40% |
| + 编译 | 51.0 | -14.93% | 8998 | +27.18% |
| + 分块交叉熵 | 42.9 | -15.83% | 9174 | +1.96% |
| + 激活检查点 | 24.9 | -41.93% | 7210 | -21.41% |
| + 将优化器步骤融合到反向传播中 | 23.1 | -7.29% | 7309 | +1.38% |
| + 激活卸载 | 21.8 | -5.48% | 7301 | -0.11% |
| + 8-bit AdamW | 17.6 | -19.63% | 6960 | -4.67% |
| LoRA | 8.5 | -51.61% | 8210 | +17.96% |
| QLoRA | 4.6 | -45.71% | 8035 | -2.13% |
表格最后一行与基线 + 打包数据集相比,内存使用减少了 81.9%,每秒处理 Token 数增加了 284.3%。
tune run lora_finetune_single_device --config llama3_2/3B_qlora_single_device \
dataset.packed=True \
compile=True \
loss=torchtune.modules.loss.CEWithChunkedOutputLoss \
enable_activation_checkpointing=True \
optimizer_in_bwd=False \
enable_activation_offloading=True \
optimizer=torch.optim.AdamW \
tokenizer.max_seq_len=4096 \
gradient_accumulation_steps=1 \
epochs=1 \
batch_size=2
torchtune 仅与最新的稳定版 PyTorch(当前为 2.6.0)以及预览版 nightly 版本进行过测试,并利用 torchvision 进行多模态 LL