引言
在大语言模型(LLM)竞相追逐长上下文(Long Context)的今天,如何让模型高效地"吞吐"数万甚至数十万 token,成为了系统优化的圣杯。Transformer 的自注意力机制(Self-Attention)在时间和内存上都随着序列长度呈二次方增长。虽然 FlashAttention V1 利用 GPU 非对称存储结构,通过 IO 感知(IO-Awareness)算法成功打破了内存墙,实现了 2-4 倍的加速,但斯坦福大学 Tri Dao 团队发现,它距离硬件的物理极限仍有不小的差距。
在 2023 年发布的 FlashAttention-2 中,作者指出 V1 在 A100 GPU 上仅达到了理论最大 FLOPs/s 的 25-40%,而优化的矩阵乘法(GEMM)通常能达到 80-90%。为了填补这一效率鸿沟,FlashAttention-2 通过重构算法公式、并行策略和 Warp 工作划分,将计算速度再次提升了约 2 倍,在 A100 上达到了理论最大吞吐量的 73%。
Paper:https://arxiv.org/pdf/2307.08691
GitHub:https://github.com/Dao-AILab/flash-attention
背景:从 Memory Bound 到 Compute Bound
标准 Attention 的计算瓶颈
在标准的 Attention 实现中,给定输入序列
,其中
是序列长度,
是头维度,需要计算注意力输出
。整个计算过程包括三个步骤:
这个过程的核心问题在于需要将巨大的
矩阵
和
频繁写入高带宽内存(HBM)。当序列长度
达到几千甚至上万时,这个中间矩阵会占用大量内存。标准实现通常需要三次矩阵乘法调用和多次内存读写:先调用 GEMM 计算
并写入 HBM,再从 HBM 读取
计算 softmax 得到
并写回 HBM,最后再读取
计算
。由于大部分操作都受内存带宽限制,大量的内存访问导致了显著的性能瓶颈,使得实际执行时间远超理论计算时间。
FlashAttention V1 的突破与局限
FlashAttention V1 引入了分块(Tiling)和重计算技术来解决内存瓶颈。其核心思想是将输入切分为小块,在片上 SRAM 中融合计算,避免实例化巨大的中间矩阵。具体而言,算法采用在线 Softmax 技术,通过维护行最大值
和指数和
这两个统计量,可以逐块处理注意力计算。每处理一个新的块时,算法会更新这些统计量,并对已计算的输出进行重新缩放(rescaling),最终得到正确的注意力结果,整个过程无需保存完整的
和
矩阵。
图1: FlashAttention 前向传播逻辑
通过这种方式,FlashAttention V1 显著减少了内存读写量,实现了 2-4 倍的加速,并将内存占用从
降低到
。然而,当内存不再是瓶颈时,计算效率问题就暴露出来了。作者通过性能分析(Profiling)发现,V1 存在几个关键的计算效率问题:首先,非矩阵运算(Non-matmul FLOPs)虽然占比较小,但在 GPU 上执行很慢;其次,线程块(Thread Blocks)和线程束(Warps)的划分不够优化,导致了 GPU 占用率(Occupancy)不足或过多的同步开销。FlashAttention-2 正是为了解决这些"计算受限"问题而诞生的。
方法详解:三大维度的系统级重构
FlashAttention-2 的核心设计哲学是"多做矩阵乘法,少做标量运算,减少通信"。这个原则基于一个关键的硬件特性:现代 GPU 的矩阵乘法吞吐量远高于标量运算。
算法层面:重写 Online Softmax,为 Tensor Core 减负
1.硬件性能的巨大差异
在 GPU(如 A100)上,FP16 矩阵乘法吞吐量达到 312 TFLOPs/s,而非矩阵 FP32 运算仅有 19.5 TFLOPs/s。这意味着每个非矩阵 FLOP 的代价是矩阵乘法 FLOP 的 16 倍。这种差异源于 GPU 内部的 Tensor Core 专用硬件单元,它们专门为矩阵乘法运算优化,可以在单个时钟周期内完成大量的乘加运算。因此,为了保持高吞吐量(例如超过理论峰值的 50%),必须尽可能多地将时间花在矩阵乘法上,而不是除法、指数等标量运算上。
2.V1 的性能瓶颈分析
在 FlashAttention V1 使用的 Online Softmax 算法中,为了数值稳定性,每次处理一个新的分块时,都需要对之前的中间结果
进行重缩放。V1 的更新公式中包含大量除法操作,特别是
这种逐元素除法频繁出现。具体来说,V1 的输出更新公式为:
这个公式中包含两个除法操作:一个是
,另一个是除以
。这导致 GPU 在高速矩阵运算的执行流程中频繁插入低速的除法和指数运算,打断了 Tensor Core 的流水线,造成计算资源的浪费。
3.FlashAttention-2 的前向传播改进
FlashAttention-2 对数学逻辑进行了巧妙调整。算法不再维护"归一化后"的输出
,而是维护一个"未缩放"的版本
。更新公式简化为:
关键的改进在于,只在循环结束的最后一步,才统一乘以
得到最终结果。这样做的好处是将归一化操作延迟到最后,在循环过程中避免了重复的除法运算,让 GPU 能够连续执行矩阵乘法操作,大幅提高了 Tensor Core 的利用率。
以两个块的简单情况为例,新的在线 Softmax 流程变为:
处理第一个块时,计算行最大值
,指数和
,以及未归一化的输出![]()
处理第二个块时,更新最大值
,更新指数和
,更新输出![]()
最后统一归一化:![]()
4.反向传播的优化
在反向传播中,FlashAttention-2 发现不再需要同时保存行最大值
和指数和
这两个统计量。算法只需保存 LogSumExp 统计量:
这个改进减少了需要保存的中间状态,进一步降低了内存占用。同时,由于 LogSumExp 本身就包含了归一化所需的全部信息,反向传播时可以直接使用这个值,而无需分别处理最大值和指数和,简化了计算流程。
这些改动看似微小,实际上大幅减少了非矩阵 FLOPs 的比例,让 Tensor Core 能够更满负荷地运转,从而显著提升了整体吞吐量。
并行机制:序列长度维度的并行化
1.V1 的并行策略及其局限
FlashAttention V1 的并行策略是基于 Batch Size 和 Head Number 的。具体来说,每个线程块(Thread Block)负责处理一个 Attention Head,算法在批次维度和头数维度上进行并行。在 A100 GPU 上有 108 个流多处理器(SM),当 Batch Size × Head Number 的值较大(如 ≥ 80)时,这种策略非常有效,可以充分利用几乎所有的计算资源。
然而,在长序列训练(Long Context)场景中,这种策略暴露出严重问题。由于显存限制,当序列长度很长时,Batch Size 通常需要设置得很小(甚至为 1),以避免显存溢出。此时,即使有多个头,总的并行度(Batch Size × Head Number)仍然较小,导致 GPU 上大量的 SM 处于空闲状态,计算资源利用率(Occupancy)很低。这就像有 108 个工人,但只给了 16 个任务,剩下的 92 个工人只能闲置,造成严重的资源浪费。
2.FlashAttention-2 的前向传播并行
FlashAttention-2 引入了序列长度(Sequence Length)维度的并行来解决这个问题。在前向传播中,外层循环(对 Query 的行分块进行遍历)被并行化。不同的线程块可以同时计算同一个 Attention Head 的不同行块,而这些行块之间是完全独立的,不需要任何通信。
具体而言,假设序列长度为 16k,块大小为 128,那么会有 16k/128 = 128 个行块。FlashAttention-2 会启动 128 个线程块(在多个头和批次上还会有更多),每个线程块负责一个行块的计算。这样即使 Batch Size 为 1,也能保证有足够的并行度来填满 GPU 的计算资源。算法仍然在批次维度和头数维度上并行,但额外增加了序列长度维度的并行,显著提升了 GPU 利用率。
图2: 并行策略对比
3.FlashAttention-2 的反向传播并行
反向传播的并行化更具挑战性,因为存在数据依赖。观察反向传播的算法流程,唯一的共享计算是更新
(Query 的梯度)。多个列块在计算时都需要累加到同一个
上,这就产生了写冲突。
FlashAttention-2 采用了列块并行化策略:每个线程块负责计算一个列块。为了处理多个线程块同时更新
的写冲突问题,算法使用原子加(Atomic Adds)操作。原子加是 GPU 提供的一种硬件级同步原语,可以保证多个线程同时写入同一内存位置时不会发生数据竞争。虽然原子操作会带来一定开销,但相比于让大量 SM 空闲,这个代价是值得的。
这种细粒度的并行策略确保了无论 Batch Size 大小如何,都能充分利用 GPU 的计算资源。在长序列训练场景中,这种改进带来的性能提升尤为显著。
Warp 工作划分:从 Split-K 到 Partition-Q
这是微架构层面最硬核的优化。在 GPU 的一个线程块内部,32 个线程被组织成一个线程束(Warp)。一个线程块通常包含 4 到 8 个 Warp。如何在这些 Warp 之间分配工作,对性能有着关键影响。
1.V1 的 Split-K 方案及其问题
FlashAttention V1 采用的是 Split-K 方案:将
和
矩阵在列维度上切分给不同的 Warp。例如,有 4 个 Warp 时,Warp 1 处理
和
的第 1-32 列,Warp 2 处理第 33-64 列,以此类推。每个 Warp 都能访问完整的
矩阵。
这种方案的问题在于:每个 Warp 计算出的结果只是部分和。例如,Warp 1 计算
,Warp 2 计算
,最终的输出
。这意味着所有 Warp 必须将它们的中间结果写入共享内存(Shared Memory),然后进行同步(Synchronization),最后再从共享内存读取并累加这些部分和。
共享内存的读写非常昂贵。虽然共享内存比 HBM 快得多,但相比于寄存器访问仍然慢很多。更严重的是,这种方案需要频繁的 Warp 间同步。GPU 的同步机制会强制所有 Warp 等待最慢的那个完成,造成计算资源的闲置。此外,共享内存的容量有限(A100 上每个 SM 只有 192KB),大量的中间结果写入会导致共享内存成为新的瓶颈。
2.FlashAttention-2 的 Partition-Q 方案
FlashAttention-2 采用了截然不同的策略:Partition-Q 方案。算法将
矩阵在行维度上切分给不同的 Warp,而让
和
对所有 Warp 可见(通过共享内存加载一次)。例如,有 4 个 Warp 时:
Warp 1 负责计算
和对应的输出
Warp 2 负责计算
和对应的输出
Warp 3 负责计算
和对应的输出
Warp 4 负责计算
和对应的输出
这种方案的关键优势在于:每个 Warp 负责输出矩阵
的不同行,彼此之间完全独立,不需要任何通信。Warp 1 计算完
后,可以直接写入寄存器或最终输出,无需等待其他 Warp,也不需要通过共享内存进行结果累加。
图3: Warp 工作划分对比
具体执行流程如下:
所有 Warp 协作将当前块的
和
加载到共享内存(这是唯一需要共享内存的地方)每个 Warp 从寄存器读取自己负责的
的行块每个 Warp 独立计算自己的注意力分数、softmax 和输出,全程在寄存器中操作
计算完成后,每个 Warp 将结果直接写回 HBM,无需同步
这彻底消除了共享内存读写的瓶颈和 Warp 间同步的开销。测试显示,这一改进单独就能带来约 1.3-1.5 倍的性能提升。
3.反向传播的工作划分
反向传播的工作划分类似,但由于涉及更复杂的数据依赖(需要计算
三个梯度),仍需要一定的同步。不过,通过避免 Split-K 方案,算法显著减少了共享内存的读写量,同样带来了可观的性能提升。
4.块大小的调优
块大小的选择对性能至关重要。增大块大小通常能减少共享内存的加载次数,提高计算密度。但块大小过大会导致两个问题:一是寄存器使用量增加,超过硬件限制后会发生寄存器溢出(Register Spilling),数据被迫存储到更慢的本地内存,严重影响性能;二是共享内存需求超过 SM 的容量限制,导致内核无法运行。
FlashAttention-2 通常选择 64×64、64×128、128×64 或 128×128 的块大小,具体取决于头维度
和设备的共享内存容量。虽然当前是手动调优,但未来可以通过自动调优(Auto-tuning)技术来自动选择最优参数。
实验:逼近硬件极限
算子性能评测
作者在 NVIDIA A100 80GB SXM4 GPU 上进行了全面的性能测试,涵盖了不同的配置:是否使用因果掩码(Causal Mask)、不同的头维度(64 或 128)、不同的序列长度(512 到 16k)。
1.性能对比结果
测试结果令人印象深刻。在序列长度为 16k、头维度为 128 的典型配置下:
前向传播:FlashAttention-2 达到了 225 TFLOPs/s,这是 A100 理论峰值 312 TFLOPs/s 的 73%。相比之下,FlashAttention V1 只能达到约 110 TFLOPs/s(35%),标准 PyTorch 实现更是只有约 46 TFLOPs/s(15%)
反向传播:FlashAttention-2 达到了约 196 TFLOPs/s,占理论峰值的 63%。这个数字尤其难得,因为反向传播涉及更复杂的数据依赖和更多的矩阵乘法(5 个而非 2 个)
前向+反向总时间:FlashAttention-2 相比 V1 提速约 2 倍,相比 PyTorch 标准实现提速 5-10 倍
图4: A100 GPU 上前向+后向传播总速度
图5: A100 GPU 上前向传播速度
图6: A100 GPU 上反向传播速度
特别值得注意的是因果掩码场景。在自回归语言建模中,由于未来位置不能看到当前位置的信息,注意力矩阵的上三角部分会被屏蔽。FlashAttention-2 利用这一特性,直接跳过约一半的块计算,实现了额外 1.7-1.8 倍的加速。这意味着在 GPT 类模型训练中,FlashAttention-2 的优势更加明显。
2.性能提升的来源分析
通过详细的性能分析,可以量化各项优化的贡献:
算法优化(减少非矩阵 FLOPs):约 15-20% 的性能提升
序列长度并行:在长序列、小批次场景下贡献 30-40% 的提升
Warp 工作重划分(Partition-Q):贡献约 30-35% 的提升
这三项优化叠加,最终实现了 2 倍的整体加速,使 Attention 运算的效率首次接近 GEMM 的水平。
端到端训练性能
算子级别的优化最终要体现在实际应用中。作者在 8×A100 80GB SXM 上训练 GPT 风格的模型(1.3B 和 2.7B 参数),测试了不同序列长度(2k 和 8k)下的训练吞吐量。
1.训练速度对比
结果显示,FlashAttention-2 在端到端训练中同样表现出色:
GPT-1.3B,2k 上下文:FlashAttention-2 达到 196 TFLOPs/s,相比无 FlashAttention 基线提速 1.4 倍,相比 V1 提速 1.04 倍
GPT-1.3B,8k 上下文:FlashAttention-2 达到 220 TFLOPs/s(72% 模型 FLOPs 利用率),相比基线提速 3.1 倍,相比 V1 提速 1.3 倍
GPT-2.7B,8k 上下文:FlashAttention-2 达到 225 TFLOPs/s(72% 利用率),相比基线提速 2.8 倍,相比 V1 提速 1.3 倍
需要注意的是,端到端训练的提速比算子级别的 2 倍提速要小。这是因为模型训练还包括其他计算,如前馈网络(FFN)、LayerNorm、优化器更新等。但在长序列场景下(8k),Attention 占比增大,FlashAttention-2 的优势更加突出。
2.实际意义
这些数字意味着什么?最直接的影响是训练成本的大幅降低。原本训练一个 8k 上下文长度的模型需要的时间和成本,现在可以用来训练 16k 上下文的模型,几乎不增加额外开销。对于需要处理长文档、高分辨率图像或长视频的应用,这个提升是革命性的。
H100 GPU 上的初步测试
作者还在新一代 H100 GPU 上进行了初步测试。值得注意的是,这些测试使用的是相同的代码,没有针对 H100 的新特性(如 Tensor Memory Accelerator TMA 和第四代 Tensor Core)进行任何优化。
图7: H100 GPU 上的初步测试结果
即便如此,FlashAttention-2 在 H100 上也达到了 335 TFLOPs/s 的惊人速度。考虑到 H100 的理论峰值更高,作者估计通过针对性优化,可以再获得 1.5-2 倍的提升,最终达到 500-670 TFLOPs/s 的吞吐量。这将进一步缩小与理论峰值的差距,让 Attention 运算真正成为计算密集型而非内存密集型操作。
总结与展望
1.核心成就
FlashAttention-2 的成功证明了"系统感知(System-aware)"优化的重要性。Tri Dao 团队没有止步于 IO 优化,而是深入公式推导和底层指令,通过三个层面的协同优化实现了质的飞跃:
算法层面:重构 Online Softmax 公式,将归一化延迟到循环结束,减少了约 20% 的非矩阵运算,让 Tensor Core 能够更连续地工作
并行层面:引入序列长度维度的并行化,解决了长序列、小批次场景下的 GPU 利用率不足问题,在该场景下贡献了 30-40% 的性能提升
微架构层面:从 Split-K 切换到 Partition-Q 的 Warp 工作划分,彻底消除了 Warp 间通信开销,贡献了约 30-35% 的性能提升
这一系列优化将 Attention 的计算效率推向了 GEMM 的水平,在 A100 上达到理论峰值的 73%,为长上下文大模型时代奠定了坚实的技术基石。
2.实际影响
FlashAttention-2 的 2 倍加速具有深远意义。这意味着原本训练 8k 长度模型的时间和成本,现在可以用来训练 16k 长度的模型。对于需要理解长文档(如法律合同、学术论文)、处理高分辨率图像、生成长视频或音频的应用,这个提升是革命性的。同时,它也加速了现有模型的训练、微调和推理,降低了 AI 应用的成本门槛。
3.未来方向
FlashAttention-2 的工作还远未结束。作者提出了几个重要的未来研究方向:
硬件适配:针对 H100 的新特性(TMA、第四代 Tensor Core、FP8 数据类型)进行优化,预期可再获得 1.5-2 倍提升;适配 AMD GPU 等其他硬件平台
算法扩展:将底层优化与高层算法(如局部注意力、稀疏注意力、滑动窗口注意力)结合,进一步突破序列长度限制
编译器支持:与编译器研究者合作,将这些优化技术融入编译框架,使开发者无需手动优化就能获得高性能
自动调优:开发自动调优系统,根据硬件特性和模型配置自动选择最优的块大小和并行策略
FlashAttention-2 的成功不仅是一次技术突破,更是系统优化方法论的胜利。它证明了通过深入理解硬件特性、精心设计算法、细致优化执行流程,可以将理论性能转化为实际性能,推动 AI 技术向更长上下文、更复杂应用的方向发展。

