大数跨境
0
0

Self Attention 固定激活值显存分析与优化及PyTorch实现

Self Attention 固定激活值显存分析与优化及PyTorch实现 极市平台
2023-01-07
0
↑ 点击蓝字 关注极市平台
作者丨Connolly@知乎(已授权)
来源丨https://zhuanlan.zhihu.com/p/445016136
编辑丨极市平台

极市导读

 

通过修改SelfAttention的执行逻辑,就可以节省大量的激活值显存开销。 >>加入极市CV技术交流群,走在计算机视觉的最前沿

通过修改SelfAttention的执行逻辑,可以节省大量的激活值显存开销。

这篇文章的消除方法来自于2021年12月10日谷歌放到arxiv上的文章self attention does not need O(n^2) memory. 该方法巧妙地使用了小学学到的加法分配率,将self attention中的固定激活值降到了O(1)的程度。[1]

Self Attention 固定激活值显存分析

Hugging face Transformers中,SelfAttention 内核实现

表格中只列举了会实测中产生激活值的操作,其中B为Batch_size,L为sequence_length,H为hidden_size,m为SelfAttention中head的数量。

则总和

观察:

  1. 固定时, 即模型结构是固定的时候, 我们发现激活值是和 线性相关的。
  2. 变化时, 我们发现会存在一个常数项 , 我称这个常数激活值开销为固定激活值。这个主要是在Query和Key矩阵做乘法, 以及后续的一些操作中生成的。即在 等操作中出现。

SelfAttention 固定激活值显存优化

1. Prerequisites

1.1 Softmax 计算过程

对于向量 表示 中的第 个元素, 那么这个元素的softmax值为:

1.2 SelfAttention计算过程

为了简化计算,我们先忽略掉Scale和Dropout,因为它们都是单操作数的op,这个忽略不会给我们的分析带来影响。考虑最后输出矩阵第i行,第j列的结果,在原始的实现中,他的计算过程为:

, QK的矩阵乘法, 产生Tensor , shape为

维度的Softmax, 产生Tensor , shape为

. Softmax和Value的矩阵乘, 产生最终输出结果, shape为 .

写成伪代码则为:

"""
inputs: Q[L][H/m], K[L][H/m], V[L][H/m]
outputs: O[L][H/m]

matrix A[L][L]=0, S[L][L]=0, O[L][H/m]=0 # 初始化为0矩阵, A,S为中间激活值矩阵
"""


# QK Matmul
for i in range(L):
    for j in range(L):
        for l in range(H/m):
            A[i][j] += Q[i][l]*Q[l][j]

# Softmax, dim=-1
for i in range(L):
    temp = 0
    for j in range(L):
        S[i][j] = math.exp(A[i][j])
        temp += S[i][j]
    S[i]/=temp

# OV Matmul
for i in range(L):
    for j in range(H/m):
        for l in range(L):
            O[i][j] += S[i][l]*Q[l][j]

return O

2. 显存优化

Google采用了一个非常简单的方法来节省Attention核中的大量的显存开销,具体计算过程为:

, QK的矩阵乘法, 但是不单独执行, 直接代入下一个式子。

, 这里没有除以求和值, 而是把除法挪到了下面。

可以发现, 和原来的算法的差别在于把 的计算放到了后面。采用这种方法的好处是, 我 们可以分开计算 了。

我们用临时变量 来存储这两个值的和, 即

来避开原始的实现中所产生的A和S矩阵。

写成伪代码:

"""
Inputs: Q[L][H/m], K[L][H/m], V[L][H/m]
outputs: O[L][H/m]

matrix O[L][H/m]=0 # 初始化为0矩阵
"""


for i in range(L): # O row, Q row
        sum_s = 0
        for j in range(L): # O column, K^T column, V row
            a_ij = 0
            for k in range(H/m): # Q column, K^T row
                a_ij += Q[i][k]*K[k][j] # Q_i K_j matmul
            a_ij = a_ij / math.sqrt(H) # scale
            s_ij_prime = math.exp(a_ij) # softmax numerator
            sum_s_i += s_prime_ij # softmax denominator of i-th row
            for oj in range(H/m): # broacast along V column axis
                if random.uniform(0,1) > 0.1: # dropout
                    O[i][oj] += s_ij_prime * V[j][oj] # attention weight, V matmul
        O[i][:] = O[i][:] / sum_s # attention weight, V matmul 
return O

一个可行的PyTorch api实现,但是效率很低很低,不可能用的。效率想要高估计还是需要用CUDA去写个算子...按照文章的说法,实现的好的话,推断的时候是可以比原始方法要快的,但是就训练而言,这里在后向过程中肯定需要进行丢失信息的重计算,论文里可以预见的会被原始方法慢两倍。

key_layer = key_layer.transpose(-1, -2)

outputs = torch.zeros([1, self.num_attention_heads, 512, 64])
for i in range(512):  # sequence length
    Qi = torch.narrow(query_layer, 2, i, 1)  # (1, 16, 1, 64)
    sum_s = torch.zeros([1, self.num_attention_heads, 1, 1])
    outputs_i = torch.narrow(outputs, 2, i, 1)  # (1, 16, 1, 64)

    for j in range(512): 
        Kj = torch.narrow(key_layer, 3, j, 1)  # (1, 16, 64, 1)
        A_ij = torch.matmul(Qi, Kj) / math.sqrt(self.attention_head_size)  # (1, 16, 1, 1)
        s_ij_prime = torch.exp(A_ij)
        sum_s.add(s_ij_prime)
        V_j = torch.narrow(value_layer, 2, j, 1)  # (1, 16, 1, 64) jth_row
        if random.uniform(0,1) > 0.1:
            outputs_i.add(s_ij_prime.mul(V_j))  # (1, 16, 1, 64)
     outputs_i.div(sum_s)

outputs = outputs.permute(0, 2, 1, 3).contiguous()
outputs_shape = outputs.size()[
                        :-2] + (self.all_head_size,)
outputs = outputs.view(*outputs_shape)

这个实现增加的显存约为 , 相比 来说已经减少了很多了,拿Bert-Large举例,他的L=512, H=1024,在B等于1的时候,原始实现中每个selfattention的matmul等操作核会产生52MB的显存,改良后则会产生2MB的显存,太顶了。考虑到Bert-Large有24层,这一下就去掉了1.2GB/sample的显存,真的是舒服哇。

不想写CUDA又想要提升性能的话,可以考虑narrow的时候多取几行或者几列,跟GPU的核数对应上应该比较合适(文章里是4096,也忒大了),然后换成einsum的张量乘法实现可调整遍历窗口大小的优化方法。

总结

  • 这个方法跟原始方法在逻辑上是等价的,而且计算复杂度也是一致的。
  • 显存开销极大降低,根据实现的方法,最低是可以到O(1)的,但是为了速度考虑可以适当调整每次narrow出来的size来提高GPU利用率。文章中显存开销是 .
  • 使用的时候需要注意在计算指数的时候可能会存在的溢出问题(这个原始实现里也有),因此文章里面的实现在做指数运算前减去了最大的A_ij值。
  • 收敛性相同,且在训练小Transformer时有4个百分点的速度提升。
  • 需要在Backward的时候重计算丢失掉的信息,这里可能会影响到dropout,所以dropout的结果我猜肯定在前向的时候是不能被丢弃的。
  • 推理系统的福音,可以调整并降低中间产生的激活值峰值,同时保证一定的推理速度。

参考

  1. ^self attention does not need O(n^2) memory https://arxiv.org/abs/2112.05682

公众号后台回复“CNN综述”获取67页综述深度卷积神经网络架构

极市干货

技术干货损失函数技术总结及Pytorch使用示例深度学习有哪些trick?目标检测正负样本区分策略和平衡策略总结

实操教程GPU多卡并行训练总结(以pytorch为例)CUDA WarpReduce 学习笔记卷积神经网络压缩方法总结

极市原创作者激励计划 #


极市平台深耕CV开发者领域近5年,拥有一大批优质CV开发者受众,覆盖微信、知乎、B站、微博等多个渠道。通过极市平台,您的文章的观点和看法能分享至更多CV开发者,既能体现文章的价值,又能让文章在视觉圈内得到更大程度上的推广,并且极市还将给予优质的作者可观的稿酬!

我们欢迎领域内的各位来进行投稿或者是宣传自己/团队的工作,让知识成为最为流通的干货!

对于优质内容开发者,极市可推荐至国内优秀出版社合作出书,同时为开发者引荐行业大牛,组织个人分享交流会,推荐名企就业机会等。


投稿须知:
1.作者保证投稿作品为自己的原创作品。
2.极市平台尊重原作者署名权,并支付相应稿费。文章发布后,版权仍属于原作者。
3.原作者可以将文章发在其他平台的个人账号,但需要在文章顶部标明首发于极市平台

投稿方式:
添加小编微信Fengcall(微信号:fengcall19),备注:姓名-投稿

点击阅读原文进入CV社区

收获更多技术干货

【声明】内容源于网络
0
0
极市平台
为计算机视觉开发者提供全流程算法开发训练平台,以及大咖技术分享、社区交流、竞赛实践等丰富的内容与服务。
内容 8155
粉丝 0
极市平台 为计算机视觉开发者提供全流程算法开发训练平台,以及大咖技术分享、社区交流、竞赛实践等丰富的内容与服务。
总阅读5.7k
粉丝0
内容8.2k