大数跨境
0
0

Attention的分块计算: 从flash-attention到ring-attention

Attention的分块计算: 从flash-attention到ring-attention 极市平台
2024-03-14
2
导读:↑ 点击蓝字 关注极市平台作者丨宫酱手艺人@知乎(已授权)来源丨https://zhuanlan.zhihu.
↑ 点击蓝字 关注极市平台
者丨宫酱手艺人@知乎(已授权)
来源丨https://zhuanlan.zhihu.com/p/686240618
编辑丨极市平台

极市导读

 

flash attention是LLM训练的标配。它是一个加速attention的cuda算子;ring attention则是利用分布式计算扩展attention长度的一个工作。然而它们背后的核心则都是softmax局部和全局关系的一个巧妙公式。 >>加入极市CV技术交流群,走在计算机视觉的最前沿

TL; DR

观察到局部softmax和全局softmax的关系,充分利用容量小但速度快的cache计算局部的attention,再推导出全局的attention,最终达到加速attention计算,或扩展attention长度的目的。

概要

局部的softmax和全局的softmax可以推出一个公式关系。利用这一点,flash-attention使用SRAM来计算局部的attention,再规约到全局的attention,并将attention包装为一个CUDA kernel,大大加速attention计算速度,并减小现存占用。而ring attention则反着利用这个公式关系,让一个GPU计算attention的一个局部,整个GPU多卡集群就可以计算出全局的attention,这样就大大扩展了Transformer序列长度。

方法

局部softmax和全局softmax之间的关系

Attention机制里面使用Softmax函数将Attention权重归一化。考虑向量 ,那么向量 作为Softmax的输出,有

注意到softmax函数有平移不变的性质

一般的softmax实现都是利用这个性质,给分子分母同时减去最大值,这样取指数的时候就不容易越界了:

flash-attention则利用这个性质,来大幅提高softmax算子的局部性。具体上,将 向量分成 个块,每个块 个元素,对每个分块先进行计算,这样SRAM里面需要处理的元素就不多了,处理完以后,再将每个分块里面计算的结果进行组合,计算出来最终的softmax结果。

这个事情的关键就是要把局部块的值和全局的值得关系找出来。为了符号的简单,考虑两个块之间的关系。 。令 为局部的最大值, 为全局最大值:

对于第一个块里面的输出值 ,有

第二个块也类似。那么我们就可以看出来,全局softmax值通过局部块的softmax值和这些 这些局部最大值因子可以算出来。所以呢,通过计算局部块的指数 ,累积和 ,以及最值 并保留下来,就可以算出全局的softmax了。

这样的好处是什么呢?

attention里面的softmax,一般值还挺多的,特别是对于长序列,所以整体计算全局softmax,可能cache里面就放不下,就得放到容量大但是比较慢的地方去计算这些数了。而局部的值,数量少,就可以放在cache里面算,算完以后再根据上面的公式,把总体的softmax算出来。

flash attention:利用SRAM作为cache

flash attention是一个attention的算子,主要目的是加速attention的计算。

GPU里面的存储有个层次结构。HBM (high bandwidth memory,可以认为就是cuda编程里面的global memory)就是显卡上边的memory,容量大,但是速度慢; SRAM (Static Random-Access Memory,可以认为就是cuda编程里面的shared memory),容量小,但是速度快。

flash-attention的核心思想就是,把attention的计算分成一小块一小块的,放在SRAM里面算,算完以后再通过前面介绍的关系,把全局的attention值算出来。大大提升了attention的计算速度。

flash-attention还把整个attention的计算做成一个算子,这样就可以把中间的结果给它省掉,大大减小了显存占用。

CPU/GPU计算时候的存储层次结构 from flash-attention

ring attention:利用单GPU卡作为cache

ring attention的主要目的是扩展Transformer的序列长度。计算Transformer序列长度的一个核心困难是算attention的时候,序列太长会OOM。

ring attention的核心想法是,每一个GPU只计算一个局部的attention,然后全局的attention再利用前面的公式给计算出来。这样,因为每个GPU的算的attention长度就没那么长了,就可以计算了,但整体的attention长度就可以大大扩展了。这个attention长度的扩展还是根据GPU数量线性增加的,有多少GPU就能扩多长,所以ring attention的论文题目里说"Near-Infinite Context"。

小结与想法

flash attention已经是LLM训练的标配了。它是一个加速attention的cuda算子;ring attention则是利用分布式计算扩展attention长度的一个工作。然而它们背后的核心则都是softmax局部和全局关系的一个巧妙公式。真的是非常漂亮。

公众号后台回复“数据集”获取100+深度学习各方向资源整理

极市干货

技术专栏:多模态大模型超详细解读专栏搞懂Tranformer系列ICCV2023论文解读极市直播
极视角动态欢迎高校师生申报极视角2023年教育部产学合作协同育人项目新视野+智慧脑,「无人机+AI」成为道路智能巡检好帮手!
技术综述:四万字详解Neural ODE:用神经网络去刻画非离散的状态变化transformer的细节到底是怎么样的?Transformer 连环18问!

点击阅读原文进入CV社区

收获更多技术干货


【声明】内容源于网络
0
0
极市平台
为计算机视觉开发者提供全流程算法开发训练平台,以及大咖技术分享、社区交流、竞赛实践等丰富的内容与服务。
内容 8155
粉丝 0
极市平台 为计算机视觉开发者提供全流程算法开发训练平台,以及大咖技术分享、社区交流、竞赛实践等丰富的内容与服务。
总阅读3.2k
粉丝0
内容8.2k