
01
引言
为什么我们需要了解计算深度学习模型中的参数数量?我们一般情况下是不需要这么做的。但是,如果我们需要减小模型的大小,甚至缩短模型推理所需的时间,那么了解模型量化前后的参数数量就会派上用场。
计算深度学习模型中的可训练参数量被认为太琐碎了,因为往往很多代码框架里已经可以帮我们自动做到这一点。但我想把我之前的笔记放在这里,供大家学习参考。
闲话少说,让我们直接开始吧。
02
前置条件
Feed-Forward Neural Network (FFN)
-
Recurrent Neural Network (RNN) -
Convolutional Neural Network (CNN)
from keras.layers import Input, Dense, SimpleRNN, LSTM, GRU, Conv2Dfrom keras.layers import Bidirectionalfrom keras.models import Model
使用上述库函数在建立模型后,通过调用model.count_params()来验证有多少参数用以训练。
03
前馈神经网络--FFN
i: 输入维度
h: 隐藏层大小
o: 网络输出维度
num_params = (connections between layers + biases) in every layer= (i×h + h) + (h×o + o)
我们先来看个图例,如下:

num_params = (3×5+5) + (5×2+2)= 32
我们用代码实现上述过程,如下:
input = Input((None, 3))dense = Dense(5)(input)output= Dense(2)(dense)model = Model(input, output)print(f"train params of the model is {model.count_params()}")
04
循环神经网络--RNN
g: 一个单元中的 FFN 数量(一般来说RNN结构中FFN数量为1,而GRU结构中FFN数量为3个,LSTM结构中FFN数量为4个)
h: 隐藏单元的大小
i: 输入大小
在RNN中对于每个FFN,最开始输入状态和隐藏状态是concat在一起作为输入的,因此每个 FFN 具有(h+i)×h + h 个参数。所以总的参数量的计算公式为:
num_params = g × [(h+i)×h + h]
观察上图,我们将g=4 h=2 i=3带入上式,得到上述LSTM的参数量为:
num_params = g × [(h+i)×h + h]= 4 × [(2+3)×2 + 2]= 48
input = Input((None, 3))lstm = LSTM(2)(input)model = Model(input, lstm)print(f"train params of the model is {model.count_params()}")
结果如下:

05
卷积神经网络--CNN
i: 输入特征图的通道数
f: 滤波器的尺寸
o: 输出的通道数(等于滤波器的个数)
则对应卷积层的参数量计算公式为:
num_params = weights + biases= [i × (f×f) × o] + o

观察上图,我们知道 i=1 f=2 o=3 带入上式,得到结果为:
num_params = [i × (f×f) × o] + o= [1 × (2×2) × 3] + 3= 15
input = Input((None, None, 1))conv2d = Conv2D(kernel_size=2, filters=3)(input)model = Model(input, conv2d)print(f"train params of the model is {model.count_params()}")
得到结果如下:
06
复杂例子
由于卷积神经网络多在计算机视觉领域得到应用,我们再来看个稍微复杂点的例子,针对2个通道输入使用3个2X2的卷积核进行卷积操作,图示如下:

观察上图,我们知道 i=2 f=2 o=3 带入上式,得到结果为:
num_params = [i × (f×f) × o] + o= [2 × (2×2) × 3] + 3= 27
我们用代码进行验证,如下所示:
input = Input((None, None, 2))conv2d = Conv2D(kernel_size=2, filters=3)(input)model = Model(input, conv2d)print(f"train params of the model is {model.count_params()}")
得到结果如下:
07
总结
本文重点介绍了FFN/RNN/CNN等核心组件的参数量的计算方法,并给出了详细的图示和对应的代码实现,学会上述核心组件的计算方法可以加深大家对常见网络结构的深入理解。
您学废了嘛?
点击上方小卡片关注我
新年寄语:
所求皆如愿,
所行皆坦途。
多喜乐,长安宁。

