flash attention讲解
FlashAttention v1 公式推导
https://www.youtube.com/watch?v=gBMO1JZav44[1]
online softmax
flash attention
两个for loop能否变成一个for loop 呢?
把 项给提出来
举例说明
我们通过一个非常简化的例子(用小矩阵、具体数字)来展示它怎么一步步实现同样的结果却节省显存。
一、例子设定
假设我们有:
- 序列长度 n = 4
- 每个向量维度 d = 2
矩阵如下(假设都是已经计算好的 embedding):
二、传统 Attention 的计算方式
1️⃣ 计算 :
除以 得到:
2️⃣ 对每一行做 softmax:
以第一行为例:
softmax 后为:
计算 softmax:
softmax 权重:
3️⃣ 最终标准attention输出:
计算最终输出:
或者可以这样写
三、FlashAttention 的做法(关键)
FlashAttention 不会一次性计算完整的 QK^T,
而是——按 块(tile) 来计算。
假设我们每次只处理 2个token的Key/Value(block size = 2)。
步骤一:初始化状态
对于每个 query token ,我们需要保存三个标量:
- 当前最大值
- 归一化因子
- 累积的加权值
初始:
步骤二:第一块计算(K₀,V₀)
Key块(0–1):
对所有 Query 做:
得到:
在线 softmax 更新逻辑
对于每个 query i:
设这次块的最大值为
计算新的最大值:
然后更新归一化因子:
并更新输出:
这三个公式是 FlashAttention 的核心。
它们允许我们在“逐块更新”的过程中,
得到与标准 softmax 完全相同的最终结果。
我们以 第一个 query 向量 Q₀ = [1,0] 为例来算:
① 当前块的最大值
更新新的最大值:
② 更新归一化因子
由于初始
③ 更新加权输出
此时:
步骤三:第二块计算(K₂,V₂)
Key块(2–3):
再计算:
得到:
然后重复刚才的在线 softmax 更新步骤,更新 。
对于同一个 query(Q₀ = [1,0]):
① 当前块的最大值
新的总体最大值:
② 更新归一化因子
代入数值:
逐步计算:
③ 更新输出向量
第一项:
第二项(来自当前块):
合并:
最后:
这就是最终输出的 。
如果我们直接用完整 softmax 算,会得到完全相同的结果(浮点误差内一致)。
最终
即为每个 query 的输出结果,与标准 attention 完全一致。
区别在于——我们从未存储过完整的
注意力矩阵!
四、关键总结
| 概念 | 传统方法 | FlashAttention |
|---|---|---|
| 是否存完整矩阵 | ✅ 是 | ❌ 否 |
| 计算方式 | 一次性算完 | 分块逐步计算 |
| softmax | 一次性归一化 | 在线(动态)归一化 |
| 显存开销 | O(n^2) | O(n) |
| 结果 | 完全一致 | 完全一致 |
online softmax 在线更新
我们把 K,V 按块分成两组(block size = 2):
- 第 1 块:
- 第 2 块:
FlashAttention 的在线 softmax 三个核心变量
每个 query(比如第 i 个 Q 向量)会维护三个标量:
| 变量 | 含义 |
|---|---|
| m_i | 当前处理过的所有块的最大 logit 值(用于数值稳定) |
| l_i | 当前归一化分母 |
| o_i | 当前加权输出向量(相当于累积的 softmax 权重×V 的结果) |
过程总结
| 步骤 | 操作 | 含义 |
|---|---|---|
| 1 | 逐块计算 QKᵀ | 避免显存爆炸 |
| 2 | 维护 m_i | 防止指数溢出(数值稳定) |
| 3 | 动态更新 l_i | 逐步构造正确的 softmax 归一化因子 |
| 4 | 更新 o_i | 累积 softmax 权重×V |
| ✅ | 最终结果 | 与标准 attention 完全一致 |
Triton code impl
@triton.jit
def flashatt_kernel(
q_ptr, k_ptr, v_ptr, z_ptr, N0, T, B0: tl.constexpr, B1: tl.constexpr
):
block_id_i = tl.program_id(0)
log2_e = 1.44269504
# myexp = lambda x: tl.exp2(log2_e * x)
# Finish me!
off_i = block_id_i * B0 + tl.arange(0, B0)
mask_i = off_i < N0
inf = 1.0e6
# Need `other`!!!
q = tl.load(q_ptr + off_i, mask=mask_i)
# The variable names of Triton's offcial FlashAttention tutorial
# is attached here for reference.
# Our variable names are consistent with Puzzle 8.
# l_i
exp_sum = tl.zeros((B0,), dtype=tl.float32)
# m_i
qk_max = tl.full((B0,), -inf, dtype=tl.float32)
z = tl.zeros((B0,), dtype=tl.float32)
for id_j in tl.range(0, T, B1):
off_j = id_j + tl.arange(0, B1)
mask_j = off_j < T
mask_ij = mask_i[:, None] & mask_j[None, :]
k = tl.load(k_ptr + off_j, mask=mask_j)
qk = q[:, None] * k[None, :] + tl.where(mask_ij, 0, -1.0e6)
# print(qk.shape)
# m_ij
new_max = tl.maximum(tl.max(qk, axis=1), qk_max)
qk_exp = tl.exp2(log2_e * (qk - new_max[:, None]))
# alpha
factor = tl.exp2(log2_e * (qk_max - new_max))
# l_ij
new_exp_sum = exp_sum * factor + tl.sum(qk_exp, axis=1)
v = tl.load(v_ptr + off_j, mask=mask_j, other=0.0)
# 这里可以和公式有些不一样
# z = z * factor + tl.sum(qk_exp * v[None, :], axis=1)
z = (
z * factor * (exp_sum / new_exp_sum)
+ tl.sum(qk_exp * v[None, :], axis=1) / new_exp_sum
)
qk_max = new_max
exp_sum = new_exp_sum
# z = z / exp_sum
tl.store(z_ptr + off_i, z, mask=mask_i)
return
- https://www.youtube.com/watch?v=gBMO1JZav44 ↩

