大数跨境
0
0

flash attention 讲解

flash attention 讲解 Angela的外贸日常
2025-10-23
3
导读:flash attention讲解引言 transformer 是当前AI大模型时代的最火热的架构,而

flash attention讲解

引言

transformer 是当前AI大模型时代的最火热的架构,而flash attention则是加速transformer在硬件中的快速的算法。本篇文章来分析flash attention的计算细节。

FlashAttention v1 公式推导

https://www.youtube.com/watch?v=gBMO1JZav44[1]





online softmax


flash attention


两个for loop能否变成一个for loop 呢?


项给提出来









举例说明

我们通过一个非常简化的例子(用小矩阵、具体数字)来展示它怎么一步步实现同样的结果却节省显存

一、例子设定

假设我们有:

  • 序列长度 n = 4
  • 每个向量维度 d = 2
    矩阵如下(假设都是已经计算好的 embedding):

二、传统 Attention 的计算方式

1️⃣ 计算 :

除以 得到:

2️⃣ 对每一行做 softmax:

以第一行为例:

softmax 后为:

计算 softmax:

softmax 权重:

3️⃣ 最终标准attention输出:

计算最终输出:

或者可以这样写

三、FlashAttention 的做法(关键)

FlashAttention 不会一次性计算完整的 QK^T,
而是——按 块(tile) 来计算。
假设我们每次只处理 2个token的Key/Value(block size = 2)。


步骤一:初始化状态

对于每个 query token ,我们需要保存三个标量:

  • 当前最大值
  • 归一化因子
  • 累积的加权值
    初始:

步骤二:第一块计算(K₀,V₀)

Key块(0–1):

对所有 Query 做:

得到:

在线 softmax 更新逻辑

对于每个 query i:
设这次块的最大值为

计算新的最大值:

然后更新归一化因子:

并更新输出:

这三个公式是 FlashAttention 的核心。

它们允许我们在“逐块更新”的过程中,

得到与标准 softmax 完全相同的最终结果

我们以 第一个 query 向量 Q₀ = [1,0] 为例来算:

① 当前块的最大值

更新新的最大值:

② 更新归一化因子

由于初始

③ 更新加权输出

此时:

步骤三:第二块计算(K₂,V₂)

Key块(2–3):

再计算:

得到:

然后重复刚才的在线 softmax 更新步骤,更新

对于同一个 query(Q₀ = [1,0]):

① 当前块的最大值

新的总体最大值:


② 更新归一化因子

代入数值:

逐步计算:


③ 更新输出向量

第一项:

第二项(来自当前块):

合并:

最后:

这就是最终输出的

如果我们直接用完整 softmax 算,会得到完全相同的结果(浮点误差内一致)。
最终 即为每个 query 的输出结果,与标准 attention 完全一致。
区别在于——我们从未存储过完整的 注意力矩阵


四、关键总结

概念 传统方法 FlashAttention
是否存完整矩阵 ✅ 是 ❌ 否
计算方式 一次性算完 分块逐步计算
softmax 一次性归一化 在线(动态)归一化
显存开销 O(n^2) O(n)
结果 完全一致 完全一致

online softmax 在线更新

我们把 K,V 按块分成两组(block size = 2):

  • 第 1 块
  • 第 2 块

FlashAttention 的在线 softmax 三个核心变量

每个 query(比如第 i 个 Q 向量)会维护三个标量:

变量 含义
m_i 当前处理过的所有块的最大 logit 值(用于数值稳定)
l_i 当前归一化分母
o_i 当前加权输出向量(相当于累积的 softmax 权重×V 的结果)

过程总结

步骤 操作 含义
1 逐块计算 QKᵀ 避免显存爆炸
2 维护 m_i 防止指数溢出(数值稳定)
3 动态更新 l_i 逐步构造正确的 softmax 归一化因子
4 更新 o_i 累积 softmax 权重×V
最终结果 与标准 attention 完全一致

Triton code impl

@triton.jit
def flashatt_kernel(
    q_ptr, k_ptr, v_ptr, z_ptr, N0, T, B0: tl.constexpr, B1: tl.constexpr
):
    block_id_i = tl.program_id(0)
    log2_e = 1.44269504
    # myexp = lambda x: tl.exp2(log2_e * x)
    # Finish me!

    off_i = block_id_i * B0 + tl.arange(0, B0)
    mask_i = off_i < N0
    inf = 1.0e6

    # Need `other`!!!
    q = tl.load(q_ptr + off_i, mask=mask_i)

    # The variable names of Triton's offcial FlashAttention tutorial
    # is attached here for reference.
    # Our variable names are consistent with Puzzle 8.

    # l_i
    exp_sum = tl.zeros((B0,), dtype=tl.float32)
    # m_i
    qk_max = tl.full((B0,), -inf, dtype=tl.float32)
    z = tl.zeros((B0,), dtype=tl.float32)

    for id_j in tl.range(0, T, B1):
        off_j = id_j + tl.arange(0, B1)
        mask_j = off_j < T
        mask_ij = mask_i[:, None] & mask_j[None, :]

        k = tl.load(k_ptr + off_j, mask=mask_j)
        qk = q[:, None] * k[None, :] + tl.where(mask_ij, 0, -1.0e6)
        # print(qk.shape)

        # m_ij
        new_max = tl.maximum(tl.max(qk, axis=1), qk_max)
        qk_exp = tl.exp2(log2_e * (qk - new_max[:, None]))
        # alpha
        factor = tl.exp2(log2_e * (qk_max - new_max))
        # l_ij
        new_exp_sum = exp_sum * factor + tl.sum(qk_exp, axis=1)
        v = tl.load(v_ptr + off_j, mask=mask_j, other=0.0)
        # 这里可以和公式有些不一样    
        # z = z * factor + tl.sum(qk_exp * v[None, :], axis=1)
        z = (
            z * factor * (exp_sum / new_exp_sum)
            + tl.sum(qk_exp * v[None, :], axis=1) / new_exp_sum
        )

        qk_max = new_max
        exp_sum = new_exp_sum

    # z = z / exp_sum
    tl.store(z_ptr + off_i, z, mask=mask_i)
    return

  1. https://www.youtube.com/watch?v=gBMO1JZav44 ↩

【声明】内容源于网络
0
0
Angela的外贸日常
跨境分享间 | 长期积累专业经验
内容 45910
粉丝 1
Angela的外贸日常 跨境分享间 | 长期积累专业经验
总阅读274.8k
粉丝1
内容45.9k