
极市导读
简单回顾了下 pytorch 前期一些量化工具的痛点,并通过新的工具 torch.fx 展示了一种更加自动化的对模型进行量化的思路。 >>加入极市CV技术交流群,走在计算机视觉的最前沿
我在之前的文章中写了很多模型量化相关的入门教程,同时也附带提供了一个用 pytorch 从零搭建的量化 demo。有不少同学还把我这个 demo 用到了工作和学习中,这让我很惶恐。从后来跟他们交流中发现,这些同学果然被我这个简陋的 demo 害惨了。。
pytorch量化工具的痛点
简陋的量化demo
其实这个 demo 只是出于教程的目的,让不懂量化的同学能大概知道模型量化到底在做什么。但要想把它当作一个量化库,那就有点过于天真了。
我们简单看一段代码,看看用这个 demo 跑量化有哪些让人蠢哭的操作。
首先,定义一个很简单的卷积网络
class Net(nn.Module):
def __init__(self, num_channels=1):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(num_channels, 40, 3, 1)
self.conv2 = nn.Conv2d(40, 40, 3, 1, groups=20)
self.fc = nn.Linear(5*5*40, 10)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.max_pool2d(x, 2, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2, 2)
x = x.view(-1, 5*5*40)
x = self.fc(x)
return x
然后,如果要对这个网络做量化的话,就需要把里面的 Conv、ReLU、MaxPool 这些都手动转成量化的模式,并在之后的代码中,调用 quantize_forward 来跑量化训练的前向传播。
class Net(nn.Module):
def quantize(self, num_bits=8):
self.qconv1 = QConv2d(self.conv1, qi=True, qo=True, num_bits=num_bits)
self.qrelu1 = QReLU()
self.qmaxpool2d_1 = QMaxPooling2d(kernel_size=2, stride=2, padding=0)
self.qconv2 = QConv2d(self.conv2, qi=False, qo=True, num_bits=num_bits)
self.qrelu2 = QReLU()
self.qmaxpool2d_2 = QMaxPooling2d(kernel_size=2, stride=2, padding=0)
self.qfc = QLinear(self.fc, qi=False, qo=True, num_bits=num_bits)
def quantize_forward(self, x):
x = self.qconv1(x)
x = self.qrelu1(x)
x = self.qmaxpool2d_1(x)
x = self.qconv2(x)
x = self.qrelu2(x)
x = self.qmaxpool2d_2(x)
x = x.view(-1, 5*5*40)
x = self.qfc(x)
return x
这里面改动的工作量并不小。这还只是一个小网络,有读者还用它去量化 mobilenet,遇到很多 bug,然后问我怎么解决。我心疼了他三秒钟后,表示爱莫能助,这种量化的方式确实很难定位问题,哪个地方写漏或者写错都有可能出错。
此外,这个工具也不支持 concat、add 这些算子的量化 (这些算子通常需要统计输入输出的数值范围)。因此,我一直强调这只是个入门的 demo,并不适合用来量化复杂点的网络。
其实,不少厂商都有针对 pytorch 开发自己的模型量化工具,例如英特尔的 OpenVino、高通的 AIMET 等。甚至 pytorch 在 1.3 版本开始就加入对量化的支持了。
这些工具都有什么特点呢?
以高通的 AIMET 以及 pytorch 官方的量化库为例。
AIMET
AIMET 的自动化程度当然比我那「破玩意」高很多,不需要自己手动一步步转量化 layer。但用过这个工具的同学应该也会发现,这个工具也存在很多坑。
高通在文档中说明了使用这个工具的一些注意事项。
比如说,在模型定义的时候,不能使用 torch.nn.functional 里面的函数,必须用 torch.nn.Module 定义。也就是 ReLU 这样的函数,要使用 torch.nn.ReLU来定义,不能在 forward 里面使用 torch.nn.functional.relu。
此外,Module 不能重复使用。也就是说,如果有多个 ReLU,那就不能使用下面这种代码:
def __init__(self,...):
...
self.relu = torch.nn.ReLU()
...
def forward(...):
...
x = self.relu(x)
...
x2 = self.relu(x2)
...
而应该这样写:
def __init__(self,...):
...
self.relu = torch.nn.ReLU()
self.relu2 = torch.nn.ReLU()
...
def forward(...):
...
x = self.relu(x)
...
x2 = self.relu2(x2)
...
这个写法和我自己写的那个 demo 其实很类似了,主要就是因为不同 layer 的量化参数是要单独统计的,不能共享。
在一条条遵守高通的规范定义好网络后,就可以用 AIMET 提供的工具一键量化并训练了:
def quantization_aware_training_range_learning(forward_pass):
model = mnist_model.Net().to(device='cuda')
# 这一步一键量化
sim = QuantizationSimModel(model, quant_scheme=QuantScheme.training_range_learning_with_tf_init, default_output_bw=8,
default_param_bw=8, input_shapes=(1, 1, 28, 28))
# Initialize the model with encodings
sim.compute_encodings(forward_pass, forward_pass_callback_args=5)
# 量化训练
sim.model.train()
mnist_model.train(sim.model, epochs=1, num_batches=100, use_cuda=True)
torch.quantization
再来看看 pytorch 官方早期推出的量化工具 (eager quantization)。
pytorch 的 eager quantization 也跟 AIMET 一样,不能使用 torch.nn.functional,而且要用 torch.nn.quantized.FloatFunctional 把 Add、Concat 这些封装起来。
针对后端推理引擎的不同,还有一些额外的操作要处理,具体地可以看看官方的文档,或者参考我之前的文章。这套操作下来总结一句话就是:劝退。
torch.fx:全新的版本
为什么这些量化工具都这么难用呢?这跟 pytorch 动态图的机制有关。
在做量化训练的时候,需要在网络结构的每个节点中插入伪量化节点。这些节点主要有两个作用:一是统计 weight 和 feature map 的 minmax 等量化参数;二是量化训练的时候承担 STE 的作用,可以让反向传播进行下去。
大家都知道,pytorch 中真正定义网络结构的地方是在 forward 函数。因此,要完成自动插入伪量化节点这个工作,就需要一套工具,可以自动对 forward 函数进行解析,得到整个网络的拓扑结构。
不过,遗憾的是,在 torch.fx 之前,这样的工具几乎没有。所以,大部分量化框架的做法,是在 __init__ 函数定义的 Module 中插入伪量化节点。因为 pytorch 会自动注册这部分 Module,所以可以逐个遍历这些 Module 并自动化地插入伪量化节点。但对于非 Module 定义的节点就爱莫能助了。
不过,自从 pytorch1.8 推出 FX 这套工具后,情况有了很大的改观。这套工具终于可以比较好地解析整个 forward 流程,并以 python 代码的形式返回整个网络结构。
有了这套工具加持后,可以做到真正自动化地插入伪量化节点,我在之前的文章中也对比了 FX 和 eager quantization 的差别,真的是全新的体验。
进阶版demo
下面简单展示一下,如何用 FX 来让我之前写的 demo 更加自动化一些。
还是以前面提到的这个简单的网络为例:
class Net(nn.Module):
def __init__(self, num_channels=1):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(num_channels, 40, 3, 1)
self.conv2 = nn.Conv2d(40, 40, 3, 1, groups=20)
self.fc = nn.Linear(5*5*40, 10)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.max_pool2d(x, 2, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2, 2)
x = x.view(-1, 5*5*40)
x = self.fc(x)
return x
现在,我们希望在不做任何修改的情况下,借助 FX 自动完成插入伪量化节点的工作。
首先,我们用 FX 解析一下这个网络:
class Net(nn.Module):
def quantize(self, num_bits=8):
graph_model = torch.fx.symbolic_trace(self)
print(graph_model)
打印出来的结果如下:
Net(
(conv1): Conv2d(1, 40, kernel_size=(3, 3), stride=(1, 1))
(conv2): Conv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), groups=20)
(fc): Linear(in_features=1000, out_features=10, bias=True)
)
def forward(self, x):
conv1 = self.conv1(x); x = None
relu = torch.nn.functional.relu(conv1, inplace = False); conv1 = None
max_pool2d = torch.nn.functional.max_pool2d(relu, 2, stride = 2, padding = 0, dilation = 1, ceil_mode = False, return_indices = False); relu = None
conv2 = self.conv2(max_pool2d); max_pool2d = None
relu_1 = torch.nn.functional.relu(conv2, inplace = False); conv2 = None
max_pool2d_1 = torch.nn.functional.max_pool2d(relu_1, 2, stride = 2, padding = 0, dilation = 1, ceil_mode = False, return_indices = False); relu_1 = None
view = max_pool2d_1.view(-1, 1000); max_pool2d_1 = None
fc = self.fc(view); view = None
return fc
FX 在解析完网络结构后,会动态修改 forward 函数的代码。换句话说,经过torch.fx.symbolic_trace(self) 处理后,我们原先的 forward 代码就被替换成上面这种形式了。
接下来,我们将 __init__ 中定义的 Module 替换成量化的形式:
class Net(nn.Module):
def quantize(self, num_bits=8):
graph_model = torch.fx.symbolic_trace(self)
# module quant
graph_model = self._module_quant(graph_model, num_bits)
print(graph_model)
def _module_quant(self, graph_model: GraphModule, num_bits=8):
device = list(graph_model.parameters())[0].device
reassign = {}
for i, (name, mod) in enumerate(graph_model.named_children()):
qi = False
qo = True
if i == 0:
qi = True
if isinstance(mod, nn.Conv2d):
reassign[name] = QConv2d(mod, qi, qo, num_bits).to(device)
elif isinstance(mod, nn.Linear):
reassign[name] = QLinear(mod, qi, qo, num_bits).to(device)
for key, value in reassign.items():
graph_model._modules[key] = value
return graph_model
_module_quant 这一步主要是遍历原来网络中定义好的 Module,把它们替换成对应的量化模块。AIMET 等工具的自动化基本只做到这一步。
经过这一步替换后,网络变成了:
Net(
(conv1): QConv2d(
(qi): QParam()
(qo): QParam()
(conv_module): Conv2d(1, 40, kernel_size=(3, 3), stride=(1, 1))
(qw): QParam()
)
(conv2): QConv2d(
(qo): QParam()
(conv_module): Conv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), groups=20)
(qw): QParam()
)
(fc): QLinear(
(qo): QParam()
(fc_module): Linear(in_features=1000, out_features=10, bias=True)
(qw): QParam()
)
)
def forward(self, x):
conv1 = self.conv1(x); x = None
relu = torch.nn.functional.relu(conv1, inplace = False); conv1 = None
max_pool2d = torch.nn.functional.max_pool2d(relu, 2, stride = 2, padding = 0, dilation = 1, ceil_mode = False, return_indices = False); relu = None
conv2 = self.conv2(max_pool2d); max_pool2d = None
relu_1 = torch.nn.functional.relu(conv2, inplace = False); conv2 = None
max_pool2d_1 = torch.nn.functional.max_pool2d(relu_1, 2, stride = 2, padding = 0, dilation = 1, ceil_mode = False, return_indices = False); relu_1 = None
view = max_pool2d_1.view(-1, 1000); max_pool2d_1 = None
fc = self.fc(view); view = None
return fc
可以看出,Module 都是量化的模块了。
接下来才是 FX 发挥作用的地方,我们要把 forward 中使用到的各种 torch.nn.functional 等函数也替换成量化的模块。
class Net(nn.Module):
def quantize(self, num_bits=8):
graph_model = torch.fx.symbolic_trace(self)
# module quant
graph_model = self._module_quant(graph_model, num_bits)
# function quant
self._function_quant(graph_model, num_bits)
print(graph_model)
def _function_quant(self, graph_model: GraphModule, num_bits=8):
device = list(graph_model.parameters())[0].device
nodes = list(graph_model.graph.nodes)
for i, node in enumerate(nodes):
if node.op == "call_function":
if node.target == F.relu:
setattr(graph_model, "qrelu_%d" % i, QReLU().to(device))
with graph_model.graph.inserting_after(node):
new_node = graph_model.graph.call_module("qrelu_%d" % i, node.args, node.kwargs)
node.replace_all_uses_with(new_node)
elif node.target == F.max_pool2d:
setattr(graph_model, "qmaxpool2d_%d" % i, QMaxPooling2d().to(device))
with graph_model.graph.inserting_after(node):
new_node = graph_model.graph.call_module("qmaxpool2d_%d" % i, node.args, node.kwargs)
node.replace_all_uses_with(new_node)
graph_model.graph.erase_node(node)
graph_model.recompile()
return graph_model
这里面的核心是 _function_quant 这个函数。这个函数会遍历整个网络中的节点,然后找出 torch.nn.functional 相关的节点,把它们替换成量化的模块。
在我们这个例子中,使用到了 torch.nn.functional.relu 和 torch.nn.functional.max_pool2d 两个函数。因此,在捕获到这两个节点后,需要分别在原来的网络中加入 QReLU 和 QMaxPooling2d 这两个量化的模块。同时,再通过 FX 中提供的对网络结构进行修改的函数,把原先的 relu 和 max_pool2d 替换成这两个量化模块。
经过这一连串操作后,我们的网络变成了下面这个样子:
Net(
(conv1): QConv2d(
(qi): QParam()
(qo): QParam()
(conv_module): Conv2d(1, 40, kernel_size=(3, 3), stride=(1, 1))
(qw): QParam()
)
(conv2): QConv2d(
(qo): QParam()
(conv_module): Conv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), groups=20)
(qw): QParam()
)
(fc): QLinear(
(qo): QParam()
(fc_module): Linear(in_features=1000, out_features=10, bias=True)
(qw): QParam()
)
(qrelu_2): QReLU(
(qo): QParam()
)
(qmaxpool2d_3): QMaxPooling2d(
(qo): QParam()
)
(qrelu_5): QReLU(
(qo): QParam()
)
(qmaxpool2d_6): QMaxPooling2d(
(qo): QParam()
)
)
def forward(self, x):
conv1 = self.conv1(x); x = None
qrelu_2 = self.qrelu_2(conv1, inplace = False); conv1 = None
qmaxpool2d_3 = self.qmaxpool2d_3(qrelu_2, 2, stride = 2, padding = 0, dilation = 1, ceil_mode = False, return_indices = False); qrelu_2 = None
conv2 = self.conv2(qmaxpool2d_3); qmaxpool2d_3 = None
qrelu_5 = self.qrelu_5(conv2, inplace = False); conv2 = None
qmaxpool2d_6 = self.qmaxpool2d_6(qrelu_5, 2, stride = 2, padding = 0, dilation = 1, ceil_mode = False, return_indices = False); qrelu_5 = None
view = qmaxpool2d_6.view(-1, 1000); qmaxpool2d_6 = None
fc = self.fc(view); view = None
return fc
看到没,原来 forward 中的 relu 和 max_pool2d 现在都变成了 QReLU 和 QMaxPooling2d。这样,我们在不做任何修改的情况下,完成了插入伪量化节点的工作。之后,就可以像训练普通的网络那样,直接做量化训练了。
完整的代码请看:https://github.com/Jermmy/pytorch-quantization-demo/blob/test_fx/model.py#L27
当然啦,这个新的 demo 仍然非常的简陋,很多功能并不完善 (比如不支持 add 和 concat 的量化,针对每个模型仍然需要单独定制 quantize 函数),因此不推荐大家在工作中使用这些代码。日常工作学习还是推荐大家使用更成熟的工具。
总结
这篇文章算是简单回顾了下 pytorch 前期一些量化工具的痛点,并通过新的工具 torch.fx 展示了一种更加自动化的对模型进行量化的思路。在 FX 的加持下,用 pytorch 做模型量化将变成一件很舒服的事情。

公众号后台回复“速查表”获取
21张速查表(神经网络、线性代数、可视化等)打包下载~
算法竞赛:算法offer直通车、50万总奖池!高通人工智能创新应用大赛等你来战!
技术干货:超简单正则表达式入门教程|22 款神经网络设计和可视化的工具大汇总
极视角动态:芜湖市湾沚区联手极视角打造核酸检测便民服务系统上线!|青岛市委常委、组织部部长于玉一行莅临极视角调研
# CV技术社群邀请函 #
备注:姓名-学校/公司-研究方向-城市(如:小极-北大-目标检测-深圳)
即可申请加入极市目标检测/图像分割/工业检测/人脸/医学影像/3D/SLAM/自动驾驶/超分辨率/姿态估计/ReID/GAN/图像增强/OCR/视频理解等技术交流群
极市&深大CV技术交流群已创建,欢迎深大校友加入,在群内自由交流学术心得,分享学术讯息,共建良好的技术交流氛围。
点击阅读原文进入CV社区
获取更多技术干货

