这是 Marlin(Mixed Auto-Regressive Linear kernel 的缩写,也是一种地球上速度最快的鱼的名字),一个极度优化的 FP16xINT4 矩阵乘法内核,专为 LLM 推理设计。它能够在批次大小达到 16-32 个 token 时提供接近理想(4倍)的加速效果(相比之下,先前具有类似加速效果的工作仅支持 1-2 个 token)。这使得 Marlin 非常适合大规模服务、推测解码或高级多推理方案,如 CoT-Majority。
大多数现代 GPU 的 FLOP 与字节比约为 100-200。因此,只要我们对每 4 位量化权重执行少于 25-50 次(张量核心)乘加运算,理论上就有可能维持相对于 FP16 权重接近理想的 4 倍加速。这意味着,权重量化带来的全部性能优势,原则上应能扩展到比现有内核当前支持的批次大小大 4-8 倍的范围。
然而,在实践中实现这一点极具挑战性,因为我们本质上需要同时充分利用所有可用的 GPU 资源(全局内存、L2 缓存、共享内存、张量核心、向量核心)。Marlin 通过多种技术和优化实现了这一点,简要概述如下:
我们首先在一个可以在 NVIDIA A10 GPU 上理想分区的大型矩阵上,比较 Marlin 与其他流行的 4 位推理内核的性能。这允许所有内核达到其最佳性能。所有内核均在分组大小 128 下执行(但请注意,缩放因子的格式并非 100% 相同)。
现有内核在批次大小为 1 时实现了相对接近最优的 3.87 倍加速(注意分组缩放因子有 0.125 位的存储开销),但随着输入数量的增加,其性能迅速下降。相比之下,Marlin 在所有批次大小下都提供了基本理想的加速,在批次大小约为 16-32 时仍能实现最大可能的 3.87 倍加速。
得益于其条带化分区方案,Marlin 在真实(较小)的矩阵和各种 GPU 上也表现出强劲的性能。下面的结果证明了这一点,我们在批次大小为 16 的情况下,对流行开源模型的 Transformer 块中所有线性层的总运行时间进行了基准测试。
最后,我们还研究了在锁定 GPU 基础时钟频率的情况下,长时间内可以维持的性能。有趣的是,我们发现降低时钟频率会显著损害先前内核的相对加速比,但对 Marlin 几乎最优的性能(相对于较低的时钟设置)没有影响。
nvcc 编译器的版本应与 torch 匹配)torch>=2.0.0numpytransformersdatasetssentencepiece如果满足所有要求,可以通过在此仓库的根目录下调用以下命令来安装 Marlin:
pip install .
安装后,使用 Marlin 内核最简单的方式是通过 marlin.Layer,这是一个代表 Marlin 量化层的 torch 模块。它允许通过 marlin.Layer.pack(linear, scales) 将一个“伪量化”(反量化值以 FP16 存储)的 torch.Linear 层转换为压缩的 Marlin 格式。或者,如果权重和缩放因子已经过适当的预处理(参见 marlin.Layer.pack(...)),也可以通过 marlin.mul(..) 直接调用内核。内核本身位于独立的 marlin/marlin_cuda_kernel.cu 文件中,该文件除了基础 CUDA 外不包含任何依赖项,因此应该很容易集成到其他低级框架中。
可以通过 python test.py 运行正确性测试,通过 python bench.py 运行基准测试。请注意,为了复现我们的“可持续性能”基准测试,需要使用以下命令将 GPU 时钟锁定在各自的基础值:
sudo nvidia-smi --lock-gpu-clocks=BASE_GPU_CLOCK --lock-memory-clocks=BASE_MEM_CLOCK
此外,如果启用了 ECC(例如在 A10 上),则最大可实现的内存带宽将比官方规格表低 10-15%,因为每个内存请求都会包含校验和开销。可以通过以下命令禁用 ECC:
sudo nvidia-smi -e 0
我们在 A10 基准测试中就是这样做的。
在 gptq 子文件夹中,我们还提供了 GPTQ 算法的一个略微改进的版本,具有更好的分组网格裁剪和非均匀校准样本长度,可以生成与 Marlin 兼容的 Llama2 模型的 4 位版本。此外,还有一个脚本可以在流行的 LLM 评估工具 中使用 Marlin 内核评估此类压缩模型。下面的脚本使用 lm-eval-harness==0.4.0 进行了测试,可能不适用于更新或更旧的版本。以下是相应的示例命令(必须安装 marlin、transformers 和 datasets 包):
% 压缩 Llama2 模型并以 Marlin 格式导出模型。
python llama2.py LLAMA2_CHECKPOINT --wbits 4 --save checkpoint.pt
% 对未压缩模型执行困惑度评估。
python llama2.py LLAMA2_CHECKPOINT
% 在评估工具中使用 Marlin 内核评估压缩模型。
python eval.py --model hf --model_args pretrained=LLAMA2_CHECKPOINT --tasks mmlu \
--marlin_checkpoint checkpoint.marlin.g128
% 评估全精度基线模型。
python eval.py --model hf --model_args pretrained=LLAMA2_CHECKPOINT --tasks mmlu
我们测量了 4 位(分组大小=128)Marlin 模型的以下 WikiText 和 Red-Pajama 困惑度,以及 MMLU 零样本准确率:
| Llama2 | Wiki2 (FP16) | Wiki2 (INT4) | RedPaj (FP16) | RedPaj (INT4) | MMLU (FP16) | MMLU (INT4) |
|---|---|---|---|---|---|---|
| 7B | 5.12 | 5.27 | 6.14 | 6.30 | 41.80 | 40.07 |
| 13B | 4.57 | 4.67 | 5.67 | 5.79 | 52.10 | 51.13 |
| 70B | 3.12 | 3.21 | 4.74 | 4.81 | 65.43 | 64.81 |
请注意,此 GPTQ 示例目前主要旨在演示如何生成准确的 Marlin 模型,并作为内核正确性的端到端验证(而不是作为一个灵活的压缩工具)。
如果您觉得这项工作有用,请考虑引用:
@article{frantar2024marlin,
title={MARLIN: Mixed-Precision Auto-Regressive Parallel Inference on Large Language Models},
author={Frantar, Elias and Castro, Roberto L and Chen, Jiale and Hoefler, Torsten and Alistarh, Dan},
journal={arXiv preprint arXiv:2408.11743},
year={2024}
}