TinyLlama 项目旨在使用 3 万亿 token 预训练一个 1.1B 参数的 Llama 模型。通过适当的优化,我们可以在 90 天内,使用 16 张 A100-40G GPU 完成这一目标 🚀🚀。训练已于 2023-09-01 开始。
我们采用了与 Llama 2 完全相同的架构和分词器。这意味着 TinyLlama 可以即插即用地应用于许多基于 Llama 的开源项目。此外,TinyLlama 非常紧凑,仅有 1.1B 参数。这种紧凑性使其能够满足众多对计算和内存资源有严格限制的应用场景。
您可以在 EVAL.md 中找到 TinyLlama 的评估结果。
我们将按照以下时间表发布中间检查点。
基础模型:
| 日期 | Hugging Face 检查点 | Token 数量 | 训练步数 | 常识推理平均分 |
|---|---|---|---|---|
| 2023-09-01 | Pythia-1.0B | 300B | 143k | 48.30 |
| 2023-09-04 | TinyLlama-1.1B-intermediate-step-50k-105b | 105B | 50k | 46.11 |
| 2023-09-16 | TinyLlama-1.1B-intermediate-step-240k-503b | 503B | 240K | 48.28 |
| 2023-10-01 | TinyLlama-1.1B-intermediate-step-480k-1T | 1T | 480k | 50.22 |
| 2023-11-04 | TinyLlama-1.1B-intermediate-step-715k-1.5T | 1.5T | 715k | 51.28 |
| 2023-11-20 | TinyLlama-1.1B-intermediate-step-955k-2T | 2T | 955k | 51.64 |
| 2023-12-11 | TinyLlama-1.1B-intermediate-step-1195k-2.5T | 2.5T | 1195k | 53.86 |
| 2023-12-28 | TinyLlama-1.1B-intermediate-step-1431k-3T | 3T | 1431k | 52.99 |
我们正在撰写一篇说明,解释为什么从 2T 到 2.5T 检查点有显著提升(这与 bos_id 问题 有关)。
聊天模型:
| 日期 | Hugging Face 检查点 | Token 数量 | 训练步数 | 常识推理平均分 |
|---|---|---|---|---|
| 2023-09-16 | TinyLlama-1.1B-Chat-V0.1 | 503B | 240K | 49.57 |
| 2023-10-1 | TinyLlama-1.1B-Chat-V0.3 | 1T | 480K | 51.36 |
| 2023-11-04 | TinyLlama-1.1B-Chat-V0.4 | 1.5T | 715K | 52.30 |
请注意,基础模型的学习率尚未冷却,因此我们建议您也使用微调后的聊天模型。
同时,您可以在此处实时跟踪交叉熵损失:这里。
小巧而强大的语言模型在许多应用中都非常有用。以下是一些潜在的应用场景:
- 辅助更大模型的推测解码。(参见 Andrej Karpathy 的教程)
- 部署在内存和计算能力受限的边缘设备上,实现无需互联网连接的实时机器翻译等功能(4位量化的 TinyLlama-1.1B 权重仅占用 637 MB)。
- 在视频游戏中实现实时对话生成。
此外,我们的代码可以作为对预训练 50 亿参数以下语言模型感兴趣的爱好者的参考,而无需过早深入研究 Megatron-LM。
以下是我们训练设置的一些细节:
| 设置项 | 描述 |
|---|---|
| 参数量 | 1.1B |
| 注意力机制变体 | 分组查询注意力 |
| 模型尺寸 | 层数: 22, 注意力头数: 32, 查询组数: 4, 嵌入维度: 2048, 中间层维度 (Swiglu): 5632 |
| 序列长度 | 2048 |
| 批次大小 | 200 万 token (2048 * 1024) |
| 学习率 | 4e-4 |
| 学习率调度 | 余弦退火,包含 2000 步预热。关于一个小错误,请参见 Issue 27 |
| 训练数据 | Slimpajama & Starcoderdata |
| 数据预处理 | 排除了 Slimpajama 中的 GitHub 子集;从 Starcoderdata 中采样了所有代码数据 |
| 合并数据集大小 | 约 950B token |
| 训练期间总 Token 数 | 3 万亿 (略多于 3 个周期/1430k 步) |
| 自然语言与代码比例 | 7:3 |
| 硬件 | 16 张 A100-40G GPU |
我们的代码库支持以下特性:
- 使用 FSDP 进行多 GPU 和多节点分布式训练。
- Flash Attention 2。
- 融合层归一化。
- 融合 Swiglu。
- 融合交叉熵损失。
- 融合旋转位置嵌入。
致谢:Flash Attention 2、融合层归一化、融合交叉熵损失和融合旋转位置嵌入来自 FlashAttention 仓库。融合 Swiglu 来自 xformers。
得益于这些优化,我们在每张 A100-40G GPU 上实现了 每秒 24k token 的吞吐量,这相当于 56% 的模型浮点运算利用率(未使用激活检查点)(我们预计在 A100-80G 上 MFU 会更高)。这意味着您可以在 32 小时内,使用 8 张 A100 训练一个 chinchilla 最优的 TinyLlama(1.1B 参数,220 亿 token)。这些优化也极大地减少了内存占用,使我们能够将 1.1B 模型塞进 40GB 的 GPU 显存中,并以每 GPU 16k token 的批次大小进行训练。您也可以在 3090/4090 GPU 上以更小的每 GPU 批次大小预训练 TinyLlama。
以下是我们代码库与 Pythia 和 MPT 训练速度的对比。
| 模型 | 在 300B token 上消耗的 A100 GPU 小时数 |
|---|---|
| TinyLlama-1.1B | 3456 |
| Pythia-1.0B | 4830 |
| MPT-1.3B | 7920 |
Pythia 的数据来自他们的论文。MPT 的数据来自此处,其中提到 MPT-1.3B "在 200B token 上,使用 440 张 A100-40GB 训练了大约半天"。
TinyLlama 是一个相对较小且采用分组查询注意力的模型,这意味着它在推理时也很快。以下是我们测量的一些吞吐量:
| 框架 | 设备 | 设置 | 吞吐量 (token/秒) |
|---|---|---|---|
| Llama.cpp | Mac M2 16GB RAM | batch_size=1; 4-bit 推理 | 71.8 |
| vLLM | A40 GPU | batch_size=100, n=10 | 7094.5 |
有关如何预训练 TinyLlama 的说明,请参阅 PRETRAIN.md。
我们在 sft 中包含了一个简单的全参数微调和推理脚本。我们的 V0.1 聊天模型就是使用此脚本微调的。我们使用的微调数据集是 openassistant-guanaco。
对于使用少于 4GB 内存的微调,我们推荐您参考 Qlora 和 bitsandbytes 仓库。
我们没有进行广泛的超参数调优,也没有选择性能更好的微调数据集。我们希望社区能够探索对 TinyLlama 的微调,并开发出更好的聊天模型。我将把社区微调的模型包含在此仓库中。
该项目仍在积极开发中。我们是一个很小的团队。社区的反馈和贡献将非常受欢迎。以下是我们计划开展的工作:
- [ ] 添加在其他数据集上进行预训练的脚本。
- [ ] 序列长度外推。
- [ ] 测试 Llama-2-7B 的推测解码。
- [ ] 测试在 RTX 3090/4090 上的吞吐量。
- [ ] 添加微调脚本。
- [ ] 在下游任务上正确评估模型。
- [ ] 在手机上运行的演示。
- [ ] 探索检索增强。
此仓库基于 lit-gpt 和 flash-attention 构建。如果您对这个出色的开源项目还不熟悉,请务必探索一下!
@online{lit-gpt,
author = {Lightning AI},
title = {Lit-GPT},
url = {https://github.com/Lightning-AI/lit-gpt},
year = {2023},
}
@article{dao2023flashattention2,
title ={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning},
author ={Dao, Tri},
year ={2023}
}
该项目目前由新加坡科技设计大学 StatNLP 研究组的 Peiyuan Zhang 、Guangtao Zeng 、Tianduo Wang 和 Wei Lu 贡献。
如果您认为我们的工作有价值,请引用:
@misc{zhang2024tinyllama,
title={TinyLlama: An Open-Source Small Language Model},
author={Peiyuan Zhang and Guangtao Zeng and Tianduo Wang and Wei Lu},
year={2024},
eprint={2401.02385},
archivePrefix={arXiv},
primaryClass={cs.CL}
}

上图是 Llama 2 论文中的训练损失曲线。我在此引用该论文:"我们观察到,在 2T Token 上预训练后,模型仍未显示出任何饱和迹象"。这就是为什么我们认为在 3T token 上预训练一个 1.1B 模型是合理的。即使损失曲线最终不再下降,我们仍然可以研究饱和现象并从中学习。

Pythia 论文中的图表显示了 LAMBADA 准确率与总训练 token 数(300B)的关系。"饱和"一词特指 70M 和 160M 模型。值得注意的是,即使是 410M 模型在 300B token 上也没有饱和,因为它仍然显示出与更大模型相似的增长趋势。