本仓库提供了以下论文中 FlashAttention 和 FlashAttention-2 的官方实现。
FlashAttention: 具有 IO 感知的快速、内存高效精确注意力机制
Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré
论文:https://arxiv.org/abs/2205.14135
IEEE Spectrum 关于我们使用 FlashAttention 提交 MLPerf 2.0 基准测试的文章。

FlashAttention-2: 具有更好并行性和工作划分的更快注意力机制
Tri Dao
论文:https://tridao.me/publications/flash2/flash2.pdf

我们很高兴看到 FlashAttention 在发布后短时间内被广泛采用。这个页面列出了部分使用 FlashAttention 的项目。
FlashAttention 和 FlashAttention-2 可免费使用和修改(见 LICENSE)。如果您使用了它,请引用并注明来源。
FlashAttention-3 针对 Hopper 架构 GPU(如 H100)进行了优化。
博客文章:https://tridao.me/blog/2024/flash3/
论文:https://tridao.me/publications/flash3/flash3.pdf

这是一个 Beta 版本,用于在我们将其集成到仓库其余部分之前进行测试和基准测试。
目前已发布:
- FP16 / BF16 前向和反向传播,FP8 前向传播
要求:H100 / H800 GPU,CUDA >= 12.3。
我们强烈推荐使用 CUDA 12.8 以获得最佳性能。
安装方法:
cd hopper
python setup.py install
运行测试:
export PYTHONPATH=$PWD
pytest -q -s test_flash_attn.py
安装完成后,可以按如下方式导入:
import flash_attn_interface
flash_attn_interface.flash_attn_func()
要求:
- CUDA 工具包或 ROCm 工具包
- PyTorch 2.2 及以上版本
- packaging Python 包 (pip install packaging)
- psutil Python 包 (pip install psutil)
- ninja Python 包 (pip install ninja) *
- Linux。从 v2.3.2 开始可能支持 Windows(我们已看到一些积极的报告),但 Windows 编译仍需更多测试。如果您有关于如何为 Windows 设置预构建 CUDA wheel 包的想法,请通过 Github Issue 联系我们。
* 请确保 ninja 已安装且工作正常(例如,运行 ninja --version 然后 echo $? 应返回退出码 0)。如果不行(有时 ninja --version 后 echo $? 返回非零退出码),请卸载后重新安装 ninja (pip uninstall -y ninja && pip install ninja)。没有 ninja,编译会耗时很长(2小时),因为它不使用多 CPU 核心。使用 ninja 后,在 64 核机器上使用 CUDA 工具包编译只需 3-5 分钟。
安装方法:
pip install flash-attn --no-build-isolation
或者,您也可以从源代码编译:
python setup.py install
如果您的机器内存少于 96GB 且 CPU 核心数很多,ninja 可能会启动过多的并行编译任务,导致内存耗尽。要限制并行编译任务数,可以设置环境变量 MAX_JOBS:
MAX_JOBS=4 pip install flash-attn --no-build-isolation
接口文件: src/flash_attention_interface.py
要求:
- CUDA 12.0 及以上版本。
我们推荐使用 Nvidia 的 Pytorch 容器,它包含了安装 FlashAttention 所需的所有工具。
目前 FlashAttention-2 在 CUDA 上支持:
1. Ampere、Ada 或 Hopper 架构 GPU(例如 A100、RTX 3090、RTX 4090、H100)。对 Turing 架构 GPU(T4、RTX 2080)的支持即将推出,目前请使用 FlashAttention 1.x。
2. 数据类型 fp16 和 bf16(bf16 需要 Ampere、Ada 或 Hopper 架构 GPU)。
3. 所有头维度(head dimension)最大支持到 256。~~头维度 > 192 的反向传播需要 A100/A800 或 H100/H800~~。从 flash-attn 2.5.5 开始,头维度 256 的反向传播现在可以在消费级 GPU 上运行(如果没有 dropout)。
ROCm 版本有两个后端。默认后端是 composable_kernel (ck),还有一个 Triton 后端。它们都提供了 FlashAttention-2 的实现。
要求:
- ROCm 6.0 及以上版本。
我们推荐使用 ROCm 的 Pytorch 容器,它包含了安装 FlashAttention 所需的所有工具。
目前 FlashAttention-2 ROCm CK 后端支持:
1. MI200x、MI250x、MI300x 和 MI355x GPU。
2. 数据类型 fp16 和 bf16。
3. 前向和反向传播的头维度最大支持到 256。
基于 Triton 的 Flash Attention 实现支持 AMD 的 CDNA(MI200、MI300)和 RDNA GPU,使用 fp16、bf16 和 fp32 数据类型。它提供前向和反向传播,支持因果掩码、可变序列长度、任意 Q/KV 序列长度和头大小、MQA/GQA、dropout、旋转位置编码、ALiBi、分页注意力(paged attention)以及 FP8(通过 Flash Attention v3 接口)。滑动窗口注意力目前正在开发中。
安装方法:首先从 https://pytorch.org/get-started/locally/ 获取 ROCm 版 PyTorch,然后安装 Triton 和 Flash Attention:
pip install triton==3.5.1
cd flash-attention
FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python setup.py install
运行测试(注意:完整测试套件需要数小时):
FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" pytest tests/test_flash_attn_triton_amd.py
为获得更好性能,可使用 FLASH_ATTENTION_TRITON_AMD_AUTOTUNE="TRUE" 启用自动调优。
或者,如果不进行自动调优,可以使用 FLASH_ATTENTION_FWD_TRITON_AMD_CONFIG_JSON 来设置单个 triton 配置,覆盖 attn_fwd 的硬编码默认值。例如:
FLASH_ATTENTION_FWD_TRITON_AMD_CONFIG_JSON='{"BLOCK_M":128,"BLOCK_N":64,"waves_per_eu":1,"PRE_LOAD_V":false,"num_stages":1,"num_warps":8}'
使用 Docker 快速开始:
FROM rocm/pytorch:latest
WORKDIR /workspace
# 安装 triton
RUN pip install triton==3.5.1
# 使用 triton 后端构建 flash attention
RUN git clone https://github.com/Dao-AILab/flash-attention &&\
cd flash-attention &&\
FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python setup.py install
# 设置工作目录
WORKDIR /workspace/flash-attention
# 设置环境变量以使用 triton 后端
ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
构建并运行:
docker build -t flash-attn-triton .
docker run -it --network=host --user root --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --ipc=host --shm-size 16G --device=/dev/kfd --device=/dev/dri flash-attn-triton
主要函数实现了缩放点积注意力(scaled dot product attention): softmax(Q @ K^T * softmax_scale) @ V
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False,
window_size=(-1, -1), alibi_slopes=None, deterministic=False):
"""评估时应将 dropout_p 设置为 0.0
如果 Q、K、V 已经堆叠成一个张量,调用此函数会比在 Q、K、V 上调用 flash_attn_func 更快,因为反向传播避免了显式拼接 Q、K、V 的梯度。
如果 window_size != (-1, -1),则实现滑动窗口局部注意力。位置 i 的查询将只关注键在 [i - window_size[0], i + window_size[1]] 区间内(包含边界)。
参数:
qkv: (batch_size, seqlen, 3, nheads, headdim)
dropout_p: float。Dropout 概率。
softmax_scale: float。在应用 softmax 之前对 QK^T 的缩放因子。默认为 1 / sqrt(headdim)。
causal: bool。是否应用因果注意力掩码(例如,用于自回归建模)。
window_size: (left, right)。如果不是 (-1, -1),则实现滑动窗口局部注意力。
alibi_slopes: (nheads,) 或 (batch_size, nheads), fp32。会在查询 i 和键 j 的注意力分数上加上 (-alibi_slope * |i - j|) 的偏置。
deterministic: bool。是否使用确定性的反向传播实现,该实现稍慢且使用更多内存。前向传播始终是确定性的。
返回:
out: (batch_size, seqlen, nheads, headdim)。
"""
flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False,
window_size=(-1, -1), alibi_slopes=None, deterministic=False):
"""评估时应将 dropout_p 设置为 0.0
支持多查询注意力(MQA)和分组查询注意力(GQA),方法是传入比 Q 头数少的 KV。注意,Q 的头数必须能被 KV 的头数整除。
例如,如果 Q 有 6 个头,而 K、V 有 2 个头,那么 Q 的头 0、1、2 将关注 K、V 的头 0,Q 的头 3、4、5 将关注 K、V 的头 1。
如果 window_size != (-1, -1),则实现滑动窗口局部注意力。位置 i 的查询将只关注键在
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] 区间内(包含边界)。
参数:
q: (batch_size, seqlen, nheads, headdim)
k: (batch_size, seqlen, nheads_k, headdim)
v: (batch_size, seqlen, nheads_k, headdim)
dropout_p: float。Dropout 概率。
softmax_scale: float。在应用 softmax 之前对 QK^T 的缩放因子。默认为 1 / sqrt(headdim)。
causal: bool。是否应用因果注意力掩码(例如,用于自回归建模)。
window_size: (left, right)。如果不是 (-1, -1),则实现滑动窗口局部注意力。
alibi_slopes: (nheads,) 或 (batch_size, nheads), fp32。会在查询 i 和键 j 的注意力分数上加上
(-alibi_slope * |i + seqlen_k - seqlen_q - j|) 的偏置。
deterministic: bool。是否使用确定性的反向传播实现,该实现稍慢且使用更多内存。前向传播始终是确定性的。
返回:
out: (batch_size, seqlen, nheads, headdim)。
"""
```python
def flash_attn_with_kvcache(
q,
k_cache,
v_cache,
k=None,
v=None,
rotary_cos=None,
rotary_sin=None,
cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
cache_batch_idx: Optional[torch.Tensor] = None,
block_table: Optional[torch.Tensor] = None,
softmax_scale=None,
causal=False,
window_size=(-1, -1), # -1 表示无限上下文窗口
rotary_interleaved=True,
alibi_slopes=None,
):
"""
如果 k 和 v 不为 None,k_cache 和 v_cache 将原地更新为来自 k 和 v 的新值。这对于增量解码很有用:您可以传入上一步缓存的键/值,并用当前步的新键/值更新它们,并在 1 个内核中完成与更新后缓存的注意力计算。
如果传入了 k / v,必须确保缓存足够大以容纳新值。例如,KV 缓存可以预分配最大序列长度,然后使用 cache_seqlens 来跟踪批次中每个序列的当前序列长度。
如果传入了 rotary_cos 和 rotary_sin,还会应用旋转位置编码。键 @k 将在 cache_seqlens、cache_seqlens + 1 等索引处被 rotary_cos 和 rotary_sin 旋转。
如果 causal 或 local(即 window_size != (-1, -1)),查询 @q 将在 cache_seqlens、cache_seqlens + 1 等索引处被旋转。
如果不是 causal 且不是 local,查询 @q 将只在 cache_seqlens 索引处被旋转(即我们认为 @q 中的所有 token 都位于 cache_seqlens 位置)。
有关如何使用此函数的示例,请参见 tests/test_flash_attn.py::test_flash_attn_kvcache。
支持多查询注意力(MQA)和分组查询注意力(GQA),方法是传入比 Q 头数少的 KV。注意,Q 的头数必须能被 KV 的头数整除。
例如,如果 Q 有 6 个头,而 K、V 有 2 个头,那么 Q 的头 0、1、2 将关注 K、V 的头 0,Q 的头 3、4、5 将关注 K、V 的头 1。
如果 causal=True,因果掩码与注意力矩阵的右下角对齐。
例如,如果 seqlen_q = 2 且 seqlen_k = 5,因果掩码(1 = 保留,0 = 屏蔽)为:
1 1 1 1 0
1 1 1 1 1
如果 seqlen_q = 5 且 seqlen_k = 2,因果掩码为:
0 0
0 0
0 0
1 0
1 1
如果掩码的某一行全为零,则输出将为零。
如果 window_size != (-1, -1),则实现滑动窗口局部注意力。位置 i 的查询将只关注键在
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] 区间内(包含边界)。
注意:不支持反向传播。
参数:
q: (batch_size, seqlen, nheads, headdim)
k_cache: 如果没有 block_table,则为 (batch_size_cache, seqlen_cache, nheads_k, headdim);
如果有 block_table(即分页 KV 缓存),则为 (num_blocks, page_block_size, nheads_k, headdim)
page_block_size 必须是 256 的倍数。
v_cache: 同上。
k [可选]: (batch_size, seqlen_new, nheads_k, headdim)。如果不为 None,我们将从 cache_seqlens 指定的索引开始,将 k 与 k_cache 拼接。
v [可选]: (batch_size, seqlen_new, nheads_k, headdim)。类似 k。
rotary_cos [可选]: (seqlen_ro, rotary_dim / 2)。如果不为 None,我们将对 k 和 q 应用旋转位置编码。仅在传入 k 和 v 时适用。rotary_dim 必须能被 16 整除。
rotary_sin [可选]: (seqlen_ro, rotary_dim / 2)。类似 rotary_cos。
cache_seqlens: int, 或 (batch_size,), dtype torch.int32。KV 缓存的序列长度。
block_table [可选]: (batch_size, max_num_blocks_per_seq), dtype torch.int32。
cache_batch_idx: (batch_size,), dtype torch.int32。用于索引 KV 缓存的索引。
如果为 None,我们假设批次索引是 [0, 1, 2, ..., batch_size - 1]。
如果索引不唯一,并且提供了 k 和 v,则缓存中更新的值可能来自任何重复的索引。
softmax_scale: float。在应用 softmax 之前对 QK^T 的缩放因子。默认为 1 / sqrt(headdim)。
causal: bool。是否应用因果注意力掩码(例如,