大数跨境
0
0

FlashAttention-2:重构算法与并行机制,如何将长文本训练提速 200%?

FlashAttention-2:重构算法与并行机制,如何将长文本训练提速 200%? 汇智灵曦
2025-12-25
0
导读:本文介绍了FlashAttention-2是如何通过重构算法与并行机制将长文本训练提速 200%,并且它的成功也进一步证明了通过深入理解硬件特性等,可以将理论性能转化为实际性能,推动 AI 技术向更长

引言

在大语言模型(LLM)竞相追逐长上下文(Long Context)的今天,如何让模型高效地"吞吐"数万甚至数十万 token,成为了系统优化的圣杯。Transformer 的自注意力机制(Self-Attention)在时间和内存上都随着序列长度呈二次方增长。虽然 FlashAttention V1 利用 GPU 非对称存储结构,通过 IO 感知(IO-Awareness)算法成功打破了内存墙,实现了 2-4 倍的加速,但斯坦福大学 Tri Dao 团队发现,它距离硬件的物理极限仍有不小的差距。


在 2023 年发布的 FlashAttention-2 中,作者指出 V1 在 A100 GPU 上仅达到了理论最大 FLOPs/s 的 25-40%,而优化的矩阵乘法(GEMM)通常能达到 80-90%。为了填补这一效率鸿沟,FlashAttention-2 通过重构算法公式、并行策略和 Warp 工作划分,将计算速度再次提升了约 2 倍,在 A100 上达到了理论最大吞吐量的 73%。


Paper:https://arxiv.org/pdf/2307.08691


GitHub:https://github.com/Dao-AILab/flash-attention


背景:从 Memory Bound 到 Compute Bound


1

标准 Attention 的计算瓶颈


在标准的 Attention 实现中,给定输入序列,其中是序列长度,是头维度,需要计算注意力输出。整个计算过程包括三个步骤:





这个过程的核心问题在于需要将巨大的矩阵频繁写入高带宽内存(HBM)。当序列长度达到几千甚至上万时,这个中间矩阵会占用大量内存。标准实现通常需要三次矩阵乘法调用和多次内存读写:先调用 GEMM 计算并写入 HBM,再从 HBM 读取计算 softmax 得到并写回 HBM,最后再读取计算。由于大部分操作都受内存带宽限制,大量的内存访问导致了显著的性能瓶颈,使得实际执行时间远超理论计算时间。


2

FlashAttention V1 的突破与局限


FlashAttention V1 引入了分块(Tiling)和重计算技术来解决内存瓶颈。其核心思想是将输入切分为小块,在片上 SRAM 中融合计算,避免实例化巨大的中间矩阵。具体而言,算法采用在线 Softmax 技术,通过维护行最大值和指数和这两个统计量,可以逐块处理注意力计算。每处理一个新的块时,算法会更新这些统计量,并对已计算的输出进行重新缩放(rescaling),最终得到正确的注意力结果,整个过程无需保存完整的矩阵。


图1: FlashAttention 前向传播逻辑


通过这种方式,FlashAttention V1 显著减少了内存读写量,实现了 2-4 倍的加速,并将内存占用从降低到。然而,当内存不再是瓶颈时,计算效率问题就暴露出来了。作者通过性能分析(Profiling)发现,V1 存在几个关键的计算效率问题:首先,非矩阵运算(Non-matmul FLOPs)虽然占比较小,但在 GPU 上执行很慢;其次,线程块(Thread Blocks)和线程束(Warps)的划分不够优化,导致了 GPU 占用率(Occupancy)不足或过多的同步开销。FlashAttention-2 正是为了解决这些"计算受限"问题而诞生的。


方法详解:三大维度的系统级重构


FlashAttention-2 的核心设计哲学是"多做矩阵乘法,少做标量运算,减少通信"。这个原则基于一个关键的硬件特性:现代 GPU 的矩阵乘法吞吐量远高于标量运算。



算法层面:重写 Online Softmax,为 Tensor Core 减负


1.硬件性能的巨大差异


在 GPU(如 A100)上,FP16 矩阵乘法吞吐量达到 312 TFLOPs/s,而非矩阵 FP32 运算仅有 19.5 TFLOPs/s。这意味着每个非矩阵 FLOP 的代价是矩阵乘法 FLOP 的 16 倍。这种差异源于 GPU 内部的 Tensor Core 专用硬件单元,它们专门为矩阵乘法运算优化,可以在单个时钟周期内完成大量的乘加运算。因此,为了保持高吞吐量(例如超过理论峰值的 50%),必须尽可能多地将时间花在矩阵乘法上,而不是除法、指数等标量运算上。


2.V1 的性能瓶颈分析


在 FlashAttention V1 使用的 Online Softmax 算法中,为了数值稳定性,每次处理一个新的分块时,都需要对之前的中间结果进行重缩放。V1 的更新公式中包含大量除法操作,特别是这种逐元素除法频繁出现。具体来说,V1 的输出更新公式为:



这个公式中包含两个除法操作:一个是 ,另一个是除以 。这导致 GPU 在高速矩阵运算的执行流程中频繁插入低速的除法和指数运算,打断了 Tensor Core 的流水线,造成计算资源的浪费。


3.FlashAttention-2 的前向传播改进


FlashAttention-2 对数学逻辑进行了巧妙调整。算法不再维护"归一化后"的输出,而是维护一个"未缩放"的版本。更新公式简化为:



关键的改进在于,只在循环结束的最后一步,才统一乘以得到最终结果。这样做的好处是将归一化操作延迟到最后,在循环过程中避免了重复的除法运算,让 GPU 能够连续执行矩阵乘法操作,大幅提高了 Tensor Core 的利用率。


以两个块的简单情况为例,新的在线 Softmax 流程变为:


处理第一个块时,计算行最大值,指数和,以及未归一化的输出


处理第二个块时,更新最大值 ,更新指数和 ,更新输出


最后统一归一化:


4.反向传播的优化


在反向传播中,FlashAttention-2 发现不再需要同时保存行最大值和指数和这两个统计量。算法只需保存 LogSumExp 统计量:



这个改进减少了需要保存的中间状态,进一步降低了内存占用。同时,由于 LogSumExp 本身就包含了归一化所需的全部信息,反向传播时可以直接使用这个值,而无需分别处理最大值和指数和,简化了计算流程。


这些改动看似微小,实际上大幅减少了非矩阵 FLOPs 的比例,让 Tensor Core 能够更满负荷地运转,从而显著提升了整体吞吐量。



并行机制:序列长度维度的并行化


1.V1 的并行策略及其局限


FlashAttention V1 的并行策略是基于 Batch Size 和 Head Number 的。具体来说,每个线程块(Thread Block)负责处理一个 Attention Head,算法在批次维度和头数维度上进行并行。在 A100 GPU 上有 108 个流多处理器(SM),当 Batch Size × Head Number 的值较大(如 ≥ 80)时,这种策略非常有效,可以充分利用几乎所有的计算资源。


然而,在长序列训练(Long Context)场景中,这种策略暴露出严重问题。由于显存限制,当序列长度很长时,Batch Size 通常需要设置得很小(甚至为 1),以避免显存溢出。此时,即使有多个头,总的并行度(Batch Size × Head Number)仍然较小,导致 GPU 上大量的 SM 处于空闲状态,计算资源利用率(Occupancy)很低。这就像有 108 个工人,但只给了 16 个任务,剩下的 92 个工人只能闲置,造成严重的资源浪费。


2.FlashAttention-2 的前向传播并行


FlashAttention-2 引入了序列长度(Sequence Length)维度的并行来解决这个问题。在前向传播中,外层循环(对 Query 的行分块进行遍历)被并行化。不同的线程块可以同时计算同一个 Attention Head 的不同行块,而这些行块之间是完全独立的,不需要任何通信。


具体而言,假设序列长度为 16k,块大小为 128,那么会有 16k/128 = 128 个行块。FlashAttention-2 会启动 128 个线程块(在多个头和批次上还会有更多),每个线程块负责一个行块的计算。这样即使 Batch Size 为 1,也能保证有足够的并行度来填满 GPU 的计算资源。算法仍然在批次维度和头数维度上并行,但额外增加了序列长度维度的并行,显著提升了 GPU 利用率。


图2: 并行策略对比


3.FlashAttention-2 的反向传播并行


反向传播的并行化更具挑战性,因为存在数据依赖。观察反向传播的算法流程,唯一的共享计算是更新(Query 的梯度)。多个列块在计算时都需要累加到同一个上,这就产生了写冲突。


FlashAttention-2 采用了列块并行化策略:每个线程块负责计算一个列块。为了处理多个线程块同时更新的写冲突问题,算法使用原子加(Atomic Adds)操作。原子加是 GPU 提供的一种硬件级同步原语,可以保证多个线程同时写入同一内存位置时不会发生数据竞争。虽然原子操作会带来一定开销,但相比于让大量 SM 空闲,这个代价是值得的。


这种细粒度的并行策略确保了无论 Batch Size 大小如何,都能充分利用 GPU 的计算资源。在长序列训练场景中,这种改进带来的性能提升尤为显著。



Warp 工作划分:从 Split-K 到 Partition-Q


这是微架构层面最硬核的优化。在 GPU 的一个线程块内部,32 个线程被组织成一个线程束(Warp)。一个线程块通常包含 4 到 8 个 Warp。如何在这些 Warp 之间分配工作,对性能有着关键影响。


1.V1 的 Split-K 方案及其问题


FlashAttention V1 采用的是 Split-K 方案:将 矩阵在列维度上切分给不同的 Warp。例如,有 4 个 Warp 时,Warp 1 处理的第 1-32 列,Warp 2 处理第 33-64 列,以此类推。每个 Warp 都能访问完整的矩阵。


这种方案的问题在于:每个 Warp 计算出的结果只是部分和。例如,Warp 1 计算,Warp 2 计算,最终的输出。这意味着所有 Warp 必须将它们的中间结果写入共享内存(Shared Memory),然后进行同步(Synchronization),最后再从共享内存读取并累加这些部分和。


共享内存的读写非常昂贵。虽然共享内存比 HBM 快得多,但相比于寄存器访问仍然慢很多。更严重的是,这种方案需要频繁的 Warp 间同步。GPU 的同步机制会强制所有 Warp 等待最慢的那个完成,造成计算资源的闲置。此外,共享内存的容量有限(A100 上每个 SM 只有 192KB),大量的中间结果写入会导致共享内存成为新的瓶颈。


2.FlashAttention-2 的 Partition-Q 方案


FlashAttention-2 采用了截然不同的策略:Partition-Q 方案。算法将矩阵在行维度上切分给不同的 Warp,而让对所有 Warp 可见(通过共享内存加载一次)。例如,有 4 个 Warp 时:


  • Warp 1 负责计算和对应的输出

  • Warp 2 负责计算和对应的输出

  • Warp 3 负责计算和对应的输出

  • Warp 4 负责计算和对应的输出


这种方案的关键优势在于:每个 Warp 负责输出矩阵的不同行,彼此之间完全独立,不需要任何通信。Warp 1 计算完后,可以直接写入寄存器或最终输出,无需等待其他 Warp,也不需要通过共享内存进行结果累加。


图3: Warp 工作划分对比


具体执行流程如下:


  1. 所有 Warp 协作将当前块的加载到共享内存(这是唯一需要共享内存的地方)

  2. 每个 Warp 从寄存器读取自己负责的的行块

  3. 每个 Warp 独立计算自己的注意力分数、softmax 和输出,全程在寄存器中操作

  4. 计算完成后,每个 Warp 将结果直接写回 HBM,无需同步


这彻底消除了共享内存读写的瓶颈和 Warp 间同步的开销。测试显示,这一改进单独就能带来约 1.3-1.5 倍的性能提升。


3.反向传播的工作划分


反向传播的工作划分类似,但由于涉及更复杂的数据依赖(需要计算三个梯度),仍需要一定的同步。不过,通过避免 Split-K 方案,算法显著减少了共享内存的读写量,同样带来了可观的性能提升。


4.块大小的调优


块大小的选择对性能至关重要。增大块大小通常能减少共享内存的加载次数,提高计算密度。但块大小过大会导致两个问题:一是寄存器使用量增加,超过硬件限制后会发生寄存器溢出(Register Spilling),数据被迫存储到更慢的本地内存,严重影响性能;二是共享内存需求超过 SM 的容量限制,导致内核无法运行。


FlashAttention-2 通常选择 64×64、64×128、128×64 或 128×128 的块大小,具体取决于头维度和设备的共享内存容量。虽然当前是手动调优,但未来可以通过自动调优(Auto-tuning)技术来自动选择最优参数。


实验:逼近硬件极限


01

算子性能评测


作者在 NVIDIA A100 80GB SXM4 GPU 上进行了全面的性能测试,涵盖了不同的配置:是否使用因果掩码(Causal Mask)、不同的头维度(64 或 128)、不同的序列长度(512 到 16k)。


1.性能对比结果


测试结果令人印象深刻。在序列长度为 16k、头维度为 128 的典型配置下:


  • 前向传播:FlashAttention-2 达到了 225 TFLOPs/s,这是 A100 理论峰值 312 TFLOPs/s 的 73%。相比之下,FlashAttention V1 只能达到约 110 TFLOPs/s(35%),标准 PyTorch 实现更是只有约 46 TFLOPs/s(15%)


  • 反向传播:FlashAttention-2 达到了约 196 TFLOPs/s,占理论峰值的 63%。这个数字尤其难得,因为反向传播涉及更复杂的数据依赖和更多的矩阵乘法(5 个而非 2 个)


  • 前向+反向总时间:FlashAttention-2 相比 V1 提速约 2 倍,相比 PyTorch 标准实现提速 5-10 倍


图4: A100 GPU 上前向+后向传播总速度


图5: A100 GPU 上前向传播速度


图6: A100 GPU 上反向传播速度


特别值得注意的是因果掩码场景。在自回归语言建模中,由于未来位置不能看到当前位置的信息,注意力矩阵的上三角部分会被屏蔽。FlashAttention-2 利用这一特性,直接跳过约一半的块计算,实现了额外 1.7-1.8 倍的加速。这意味着在 GPT 类模型训练中,FlashAttention-2 的优势更加明显。


2.性能提升的来源分析


通过详细的性能分析,可以量化各项优化的贡献:


  1. 算法优化(减少非矩阵 FLOPs):约 15-20% 的性能提升

  2. 序列长度并行:在长序列、小批次场景下贡献 30-40% 的提升

  3. Warp 工作重划分(Partition-Q):贡献约 30-35% 的提升


这三项优化叠加,最终实现了 2 倍的整体加速,使 Attention 运算的效率首次接近 GEMM 的水平。


02

端到端训练性能


算子级别的优化最终要体现在实际应用中。作者在 8×A100 80GB SXM 上训练 GPT 风格的模型(1.3B 和 2.7B 参数),测试了不同序列长度(2k 和 8k)下的训练吞吐量。


1.训练速度对比


结果显示,FlashAttention-2 在端到端训练中同样表现出色:



  • GPT-1.3B,2k 上下文:FlashAttention-2 达到 196 TFLOPs/s,相比无 FlashAttention 基线提速 1.4 倍,相比 V1 提速 1.04 倍


  • GPT-1.3B,8k 上下文:FlashAttention-2 达到 220 TFLOPs/s(72% 模型 FLOPs 利用率),相比基线提速 3.1 倍,相比 V1 提速 1.3 倍


  • GPT-2.7B,8k 上下文:FlashAttention-2 达到 225 TFLOPs/s(72% 利用率),相比基线提速 2.8 倍,相比 V1 提速 1.3 倍


需要注意的是,端到端训练的提速比算子级别的 2 倍提速要小。这是因为模型训练还包括其他计算,如前馈网络(FFN)、LayerNorm、优化器更新等。但在长序列场景下(8k),Attention 占比增大,FlashAttention-2 的优势更加突出。


2.实际意义


这些数字意味着什么?最直接的影响是训练成本的大幅降低。原本训练一个 8k 上下文长度的模型需要的时间和成本,现在可以用来训练 16k 上下文的模型,几乎不增加额外开销。对于需要处理长文档、高分辨率图像或长视频的应用,这个提升是革命性的。


03

H100 GPU 上的初步测试


作者还在新一代 H100 GPU 上进行了初步测试。值得注意的是,这些测试使用的是相同的代码,没有针对 H100 的新特性(如 Tensor Memory Accelerator TMA 和第四代 Tensor Core)进行任何优化。


图7: H100 GPU 上的初步测试结果


即便如此,FlashAttention-2 在 H100 上也达到了 335 TFLOPs/s 的惊人速度。考虑到 H100 的理论峰值更高,作者估计通过针对性优化,可以再获得 1.5-2 倍的提升,最终达到 500-670 TFLOPs/s 的吞吐量。这将进一步缩小与理论峰值的差距,让 Attention 运算真正成为计算密集型而非内存密集型操作。


总结与展望


1.核心成就


FlashAttention-2 的成功证明了"系统感知(System-aware)"优化的重要性。Tri Dao 团队没有止步于 IO 优化,而是深入公式推导和底层指令,通过三个层面的协同优化实现了质的飞跃:


  • 算法层面:重构 Online Softmax 公式,将归一化延迟到循环结束,减少了约 20% 的非矩阵运算,让 Tensor Core 能够更连续地工作


  • 并行层面:引入序列长度维度的并行化,解决了长序列、小批次场景下的 GPU 利用率不足问题,在该场景下贡献了 30-40% 的性能提升


  • 微架构层面:从 Split-K 切换到 Partition-Q 的 Warp 工作划分,彻底消除了 Warp 间通信开销,贡献了约 30-35% 的性能提升


这一系列优化将 Attention 的计算效率推向了 GEMM 的水平,在 A100 上达到理论峰值的 73%,为长上下文大模型时代奠定了坚实的技术基石。


2.实际影响


FlashAttention-2 的 2 倍加速具有深远意义。这意味着原本训练 8k 长度模型的时间和成本,现在可以用来训练 16k 长度的模型。对于需要理解长文档(如法律合同、学术论文)、处理高分辨率图像、生成长视频或音频的应用,这个提升是革命性的。同时,它也加速了现有模型的训练、微调和推理,降低了 AI 应用的成本门槛。


3.未来方向


FlashAttention-2 的工作还远未结束。作者提出了几个重要的未来研究方向:


  • 硬件适配:针对 H100 的新特性(TMA、第四代 Tensor Core、FP8 数据类型)进行优化,预期可再获得 1.5-2 倍提升;适配 AMD GPU 等其他硬件平台


  • 算法扩展:将底层优化与高层算法(如局部注意力、稀疏注意力、滑动窗口注意力)结合,进一步突破序列长度限制


  • 编译器支持:与编译器研究者合作,将这些优化技术融入编译框架,使开发者无需手动优化就能获得高性能


  • 自动调优:开发自动调优系统,根据硬件特性和模型配置自动选择最优的块大小和并行策略


FlashAttention-2 的成功不仅是一次技术突破,更是系统优化方法论的胜利。它证明了通过深入理解硬件特性、精心设计算法、细致优化执行流程,可以将理论性能转化为实际性能,推动 AI 技术向更长上下文、更复杂应用的方向发展。

【声明】内容源于网络
0
0
汇智灵曦
汇智灵曦数字科技以“智赋医疗,研以致用”为理念,致力于通过AI技术推动医疗健康数字化转型。公司聚焦医疗场景需求,打造了包含深度问数、汇智查房等医疗AI产品,为医疗机构提供从临床决策到科研创新的全链条解决方案,大幅提升诊疗质量与科研效率。
内容 31
粉丝 0
汇智灵曦 汇智灵曦数字科技以“智赋医疗,研以致用”为理念,致力于通过AI技术推动医疗健康数字化转型。公司聚焦医疗场景需求,打造了包含深度问数、汇智查房等医疗AI产品,为医疗机构提供从临床决策到科研创新的全链条解决方案,大幅提升诊疗质量与科研效率。
总阅读0
粉丝0
内容31