大数跨境
0
0

实践教程|2022年了,该用pytorch跑量化训练了

实践教程|2022年了,该用pytorch跑量化训练了 极市平台
2022-11-23
1
↑ 点击蓝字 关注极市平台

作者丨jermmy
来源丨AI小男孩
编辑丨极市平台

极市导读

 

简单回顾了下 pytorch 前期一些量化工具的痛点,并通过新的工具 torch.fx 展示了一种更加自动化的对模型进行量化的思路。 >>加入极市CV技术交流群,走在计算机视觉的最前沿

我在之前的文章中写了很多模型量化相关的入门教程,同时也附带提供了一个用 pytorch 从零搭建的量化 demo。有不少同学还把我这个 demo 用到了工作和学习中,这让我很惶恐。从后来跟他们交流中发现,这些同学果然被我这个简陋的 demo 害惨了。。

pytorch量化工具的痛点

简陋的量化demo

其实这个 demo 只是出于教程的目的,让不懂量化的同学能大概知道模型量化到底在做什么。但要想把它当作一个量化库,那就有点过于天真了。

我们简单看一段代码,看看用这个 demo 跑量化有哪些让人蠢哭的操作。

首先,定义一个很简单的卷积网络

class Net(nn.Module):

    def __init__(self, num_channels=1):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(num_channels, 40, 3, 1)
        self.conv2 = nn.Conv2d(40, 40, 3, 1, groups=20)
        self.fc = nn.Linear(5*5*40, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 5*5*40)
        x = self.fc(x)
        return x

然后,如果要对这个网络做量化的话,就需要把里面的 Conv、ReLU、MaxPool 这些都手动转成量化的模式,并在之后的代码中,调用 quantize_forward 来跑量化训练的前向传播。

class Net(nn.Module):

    def quantize(self, num_bits=8):
        self.qconv1 = QConv2d(self.conv1, qi=True, qo=True, num_bits=num_bits)
        self.qrelu1 = QReLU()
        self.qmaxpool2d_1 = QMaxPooling2d(kernel_size=2, stride=2, padding=0)
        self.qconv2 = QConv2d(self.conv2, qi=False, qo=True, num_bits=num_bits)
        self.qrelu2 = QReLU()
        self.qmaxpool2d_2 = QMaxPooling2d(kernel_size=2, stride=2, padding=0)
        self.qfc = QLinear(self.fc, qi=False, qo=True, num_bits=num_bits)

    def quantize_forward(self, x):
        x = self.qconv1(x)
        x = self.qrelu1(x)
        x = self.qmaxpool2d_1(x)
        x = self.qconv2(x)
        x = self.qrelu2(x)
        x = self.qmaxpool2d_2(x)
        x = x.view(-1, 5*5*40)
        x = self.qfc(x)
        return x

这里面改动的工作量并不小。这还只是一个小网络,有读者还用它去量化 mobilenet,遇到很多 bug,然后问我怎么解决。我心疼了他三秒钟后,表示爱莫能助,这种量化的方式确实很难定位问题,哪个地方写漏或者写错都有可能出错。

此外,这个工具也不支持 concat、add 这些算子的量化 (这些算子通常需要统计输入输出的数值范围)。因此,我一直强调这只是个入门的 demo,并不适合用来量化复杂点的网络。

其实,不少厂商都有针对 pytorch 开发自己的模型量化工具,例如英特尔的 OpenVino、高通的 AIMET 等。甚至 pytorch 在 1.3 版本开始就加入对量化的支持了。

这些工具都有什么特点呢?

以高通的 AIMET 以及 pytorch 官方的量化库为例。

AIMET

AIMET 的自动化程度当然比我那「破玩意」高很多,不需要自己手动一步步转量化 layer。但用过这个工具的同学应该也会发现,这个工具也存在很多坑。

高通在文档中说明了使用这个工具的一些注意事项。

比如说,在模型定义的时候,不能使用 torch.nn.functional 里面的函数,必须用 torch.nn.Module 定义。也就是 ReLU 这样的函数,要使用 torch.nn.ReLU来定义,不能在 forward 里面使用 torch.nn.functional.relu

此外,Module 不能重复使用。也就是说,如果有多个 ReLU,那就不能使用下面这种代码:

def __init__(self,...):
    ...
    self.relu = torch.nn.ReLU()
    ...

def forward(...):
    ...
    x = self.relu(x)
    ...
    x2 = self.relu(x2)
    ...

而应该这样写:

def __init__(self,...):
    ...
    self.relu = torch.nn.ReLU()
    self.relu2 = torch.nn.ReLU()
    ...

def forward(...):
    ...
    x = self.relu(x)
    ...
    x2 = self.relu2(x2)
    ...

这个写法和我自己写的那个 demo 其实很类似了,主要就是因为不同 layer 的量化参数是要单独统计的,不能共享。

在一条条遵守高通的规范定义好网络后,就可以用 AIMET 提供的工具一键量化并训练了:

def quantization_aware_training_range_learning(forward_pass):
    model = mnist_model.Net().to(device='cuda')
    # 这一步一键量化
    sim = QuantizationSimModel(model, quant_scheme=QuantScheme.training_range_learning_with_tf_init, default_output_bw=8,
                               default_param_bw=8, input_shapes=(1, 1, 28, 28))

    # Initialize the model with encodings
    sim.compute_encodings(forward_pass, forward_pass_callback_args=5)

    # 量化训练
    sim.model.train()
    mnist_model.train(sim.model, epochs=1, num_batches=100, use_cuda=True)

torch.quantization

再来看看 pytorch 官方早期推出的量化工具 (eager quantization)。

pytorch 的 eager quantization 也跟 AIMET 一样,不能使用 torch.nn.functional,而且要用 torch.nn.quantized.FloatFunctional 把 Add、Concat 这些封装起来。

针对后端推理引擎的不同,还有一些额外的操作要处理,具体地可以看看官方的文档,或者参考我之前的文章。这套操作下来总结一句话就是:劝退。

torch.fx:全新的版本

为什么这些量化工具都这么难用呢?这跟 pytorch 动态图的机制有关。

在做量化训练的时候,需要在网络结构的每个节点中插入伪量化节点。这些节点主要有两个作用:一是统计 weight 和 feature map 的 minmax 等量化参数;二是量化训练的时候承担 STE 的作用,可以让反向传播进行下去。

大家都知道,pytorch 中真正定义网络结构的地方是在 forward 函数。因此,要完成自动插入伪量化节点这个工作,就需要一套工具,可以自动对 forward 函数进行解析,得到整个网络的拓扑结构。

不过,遗憾的是,在 torch.fx 之前,这样的工具几乎没有。所以,大部分量化框架的做法,是在 __init__ 函数定义的 Module 中插入伪量化节点。因为 pytorch 会自动注册这部分 Module,所以可以逐个遍历这些 Module 并自动化地插入伪量化节点。但对于非 Module 定义的节点就爱莫能助了。

不过,自从 pytorch1.8 推出 FX 这套工具后,情况有了很大的改观。这套工具终于可以比较好地解析整个 forward 流程,并以 python 代码的形式返回整个网络结构。

有了这套工具加持后,可以做到真正自动化地插入伪量化节点,我在之前的文章中也对比了 FX 和 eager quantization 的差别,真的是全新的体验。

进阶版demo

下面简单展示一下,如何用 FX 来让我之前写的 demo 更加自动化一些。

还是以前面提到的这个简单的网络为例:

class Net(nn.Module):

    def __init__(self, num_channels=1):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(num_channels, 40, 3, 1)
        self.conv2 = nn.Conv2d(40, 40, 3, 1, groups=20)
        self.fc = nn.Linear(5*5*40, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 5*5*40)
        x = self.fc(x)
        return x

现在,我们希望在不做任何修改的情况下,借助 FX 自动完成插入伪量化节点的工作。

首先,我们用 FX 解析一下这个网络

class Net(nn.Module):
  
    def quantize(self, num_bits=8):
        graph_model = torch.fx.symbolic_trace(self)
        print(graph_model)

打印出来的结果如下:

Net(
  (conv1): Conv2d(1, 40, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), groups=20)
  (fc): Linear(in_features=1000, out_features=10, bias=True)
)

def forward(self, x):
    conv1 = self.conv1(x);  x = None
    relu = torch.nn.functional.relu(conv1, inplace = False);  conv1 = None
    max_pool2d = torch.nn.functional.max_pool2d(relu, 2, stride = 2, padding = 0, dilation = 1, ceil_mode = False, return_indices = False);  relu = None
    conv2 = self.conv2(max_pool2d);  max_pool2d = None
    relu_1 = torch.nn.functional.relu(conv2, inplace = False);  conv2 = None
    max_pool2d_1 = torch.nn.functional.max_pool2d(relu_1, 2, stride = 2, padding = 0, dilation = 1, ceil_mode = False, return_indices = False);  relu_1 = None
    view = max_pool2d_1.view(-1, 1000);  max_pool2d_1 = None
    fc = self.fc(view);  view = None
    return fc

FX 在解析完网络结构后,会动态修改 forward 函数的代码。换句话说,经过torch.fx.symbolic_trace(self) 处理后,我们原先的 forward 代码就被替换成上面这种形式了。

接下来,我们将 __init__ 中定义的 Module 替换成量化的形式:

class Net(nn.Module):
    def quantize(self, num_bits=8):
        graph_model = torch.fx.symbolic_trace(self)
        # module quant
        graph_model = self._module_quant(graph_model, num_bits)
        print(graph_model)

    def _module_quant(self, graph_model: GraphModule, num_bits=8):
        device = list(graph_model.parameters())[0].device
        reassign = {}
        for i, (name, mod) in enumerate(graph_model.named_children()):
            qi = False
            qo = True
            if i == 0:
                qi = True
            if isinstance(mod, nn.Conv2d):
                reassign[name] = QConv2d(mod, qi, qo, num_bits).to(device)
            elif isinstance(mod, nn.Linear):
                reassign[name] = QLinear(mod, qi, qo, num_bits).to(device)
        
        for key, value in reassign.items():
            graph_model._modules[key] = value
        
        return graph_model

_module_quant 这一步主要是遍历原来网络中定义好的 Module,把它们替换成对应的量化模块。AIMET 等工具的自动化基本只做到这一步。

经过这一步替换后,网络变成了:

Net(
  (conv1): QConv2d(
    (qi): QParam()
    (qo): QParam()
    (conv_module): Conv2d(1, 40, kernel_size=(3, 3), stride=(1, 1))
    (qw): QParam()
  )
  (conv2): QConv2d(
    (qo): QParam()
    (conv_module): Conv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), groups=20)
    (qw): QParam()
  )
  (fc): QLinear(
    (qo): QParam()
    (fc_module): Linear(in_features=1000, out_features=10, bias=True)
    (qw): QParam()
  )
)

def forward(self, x):
    conv1 = self.conv1(x);  x = None
    relu = torch.nn.functional.relu(conv1, inplace = False);  conv1 = None
    max_pool2d = torch.nn.functional.max_pool2d(relu, 2, stride = 2, padding = 0, dilation = 1, ceil_mode = False, return_indices = False);  relu = None
    conv2 = self.conv2(max_pool2d);  max_pool2d = None
    relu_1 = torch.nn.functional.relu(conv2, inplace = False);  conv2 = None
    max_pool2d_1 = torch.nn.functional.max_pool2d(relu_1, 2, stride = 2, padding = 0, dilation = 1, ceil_mode = False, return_indices = False);  relu_1 = None
    view = max_pool2d_1.view(-1, 1000);  max_pool2d_1 = None
    fc = self.fc(view);  view = None
    return fc

可以看出,Module 都是量化的模块了。

接下来才是 FX 发挥作用的地方,我们要把 forward 中使用到的各种 torch.nn.functional 等函数也替换成量化的模块。

class Net(nn.Module):
    def quantize(self, num_bits=8):
        graph_model = torch.fx.symbolic_trace(self)
        # module quant
        graph_model = self._module_quant(graph_model, num_bits)
        # function quant
        self._function_quant(graph_model, num_bits)
        print(graph_model)
 
    def _function_quant(self, graph_model: GraphModule, num_bits=8):
        device = list(graph_model.parameters())[0].device
        nodes = list(graph_model.graph.nodes)
        for i, node in enumerate(nodes):
            if node.op == "call_function":
                if node.target == F.relu:
                    setattr(graph_model, "qrelu_%d" % i, QReLU().to(device))
                    with graph_model.graph.inserting_after(node):
                        new_node = graph_model.graph.call_module("qrelu_%d" % i, node.args, node.kwargs)
                        node.replace_all_uses_with(new_node)
                elif node.target == F.max_pool2d:
                    setattr(graph_model, "qmaxpool2d_%d" % i, QMaxPooling2d().to(device))
                    with graph_model.graph.inserting_after(node):
                        new_node = graph_model.graph.call_module("qmaxpool2d_%d" % i, node.args, node.kwargs)
                        node.replace_all_uses_with(new_node)
                graph_model.graph.erase_node(node)
        
        graph_model.recompile()
        return graph_model

这里面的核心是 _function_quant 这个函数。这个函数会遍历整个网络中的节点,然后找出 torch.nn.functional 相关的节点,把它们替换成量化的模块。

在我们这个例子中,使用到了 torch.nn.functional.relutorch.nn.functional.max_pool2d 两个函数。因此,在捕获到这两个节点后,需要分别在原来的网络中加入 QReLUQMaxPooling2d 这两个量化的模块。同时,再通过 FX 中提供的对网络结构进行修改的函数,把原先的 relumax_pool2d 替换成这两个量化模块。

经过这一连串操作后,我们的网络变成了下面这个样子:

Net(
  (conv1): QConv2d(
    (qi): QParam()
    (qo): QParam()
    (conv_module): Conv2d(1, 40, kernel_size=(3, 3), stride=(1, 1))
    (qw): QParam()
  )
  (conv2): QConv2d(
    (qo): QParam()
    (conv_module): Conv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), groups=20)
    (qw): QParam()
  )
  (fc): QLinear(
    (qo): QParam()
    (fc_module): Linear(in_features=1000, out_features=10, bias=True)
    (qw): QParam()
  )
  (qrelu_2): QReLU(
    (qo): QParam()
  )
  (qmaxpool2d_3): QMaxPooling2d(
    (qo): QParam()
  )
  (qrelu_5): QReLU(
    (qo): QParam()
  )
  (qmaxpool2d_6): QMaxPooling2d(
    (qo): QParam()
  )
)

def forward(self, x):
    conv1 = self.conv1(x);  x = None
    qrelu_2 = self.qrelu_2(conv1, inplace = False);  conv1 = None
    qmaxpool2d_3 = self.qmaxpool2d_3(qrelu_2, 2, stride = 2, padding = 0, dilation = 1, ceil_mode = False, return_indices = False);  qrelu_2 = None
    conv2 = self.conv2(qmaxpool2d_3);  qmaxpool2d_3 = None
    qrelu_5 = self.qrelu_5(conv2, inplace = False);  conv2 = None
    qmaxpool2d_6 = self.qmaxpool2d_6(qrelu_5, 2, stride = 2, padding = 0, dilation = 1, ceil_mode = False, return_indices = False);  qrelu_5 = None
    view = qmaxpool2d_6.view(-1, 1000);  qmaxpool2d_6 = None
    fc = self.fc(view);  view = None
    return fc

看到没,原来 forward 中的 relumax_pool2d 现在都变成了 QReLUQMaxPooling2d。这样,我们在不做任何修改的情况下,完成了插入伪量化节点的工作。之后,就可以像训练普通的网络那样,直接做量化训练了。

完整的代码请看:https://github.com/Jermmy/pytorch-quantization-demo/blob/test_fx/model.py#L27

当然啦,这个新的 demo 仍然非常的简陋,很多功能并不完善 (比如不支持 add 和 concat 的量化,针对每个模型仍然需要单独定制 quantize 函数),因此不推荐大家在工作中使用这些代码。日常工作学习还是推荐大家使用更成熟的工具。

总结

这篇文章算是简单回顾了下 pytorch 前期一些量化工具的痛点,并通过新的工具 torch.fx 展示了一种更加自动化的对模型进行量化的思路。在 FX 的加持下,用 pytorch 做模型量化将变成一件很舒服的事情。


公众号后台回复“速查表”获取

21张速查表(神经网络、线性代数、可视化等)打包下载~

极市干货

算法竞赛:算法offer直通车、50万总奖池!高通人工智能创新应用大赛等你来战!

技术干货超简单正则表达式入门教程22 款神经网络设计和可视化的工具大汇总

极视角动态:芜湖市湾沚区联手极视角打造核酸检测便民服务系统上线!青岛市委常委、组织部部长于玉一行莅临极视角调研

CV技术社群邀请函 #

△长按添加极市小助手
添加极市小助手微信(ID : cvmart2)

备注:姓名-学校/公司-研究方向-城市(如:小极-北大-目标检测-深圳)


即可申请加入极市目标检测/图像分割/工业检测/人脸/医学影像/3D/SLAM/自动驾驶/超分辨率/姿态估计/ReID/GAN/图像增强/OCR/视频理解等技术交流群


极市&深大CV技术交流群已创建,欢迎深大校友加入,在群内自由交流学术心得,分享学术讯息,共建良好的技术交流氛围。

点击阅读原文进入CV社区

获取更多技术干货

【声明】内容源于网络
0
0
极市平台
为计算机视觉开发者提供全流程算法开发训练平台,以及大咖技术分享、社区交流、竞赛实践等丰富的内容与服务。
内容 8155
粉丝 0
极市平台 为计算机视觉开发者提供全流程算法开发训练平台,以及大咖技术分享、社区交流、竞赛实践等丰富的内容与服务。
总阅读3.2k
粉丝0
内容8.2k