大数跨境
0
0

复数神经网络及其 PyTorch 实现

复数神经网络及其 PyTorch 实现 极市平台
2022-01-17
0
导读:几种复数操作进行介绍,并给出简单的 Pytorch 实现方法。
↑ 点击蓝字 关注极市平台

作者丨科技猛兽
编辑丨极市平台

极市导读

 

实数网络在图像领域取得极大成功,但在音频中,信号特征大多数是复数,如频谱等。简单分离实部虚部,或者考虑幅度和相位角都丢失了复数原本的关系。论文按照复数计算的定义,设计了深度复数网络,能对复数的输入数据进行卷积、激活、批规范化等操作。这里对论文提出的几种复数操作进行介绍,并给出简单的 Pytorch 实现方法。 >>加入极市CV技术交流群,走在计算机视觉的最前沿

虽然叫深度复数网络,但里面的操作实际上还是在实数空间进行的。但通过实数的层实现类似于复数计算的操作。

目录

1 PyTorch 中的复数张量形式

2 复数神经网络背景

3 复数卷积操作
3.1 复数卷积原理
3.2 复数卷积 PyTorch 实现
3.3 复数的反向传播
3.4 柯西-黎曼方程 (Cauchy–Riemann Equation)

4 复数 LSTM 操作
4.1 复数 LSTM 原理
4.2 复数 LSTM PyTorch 实现

5 复数激活函数
5.1 复数激活函数原理
5.2 复数激活函数 PyTorch 实现

6 复数 Dropout
6.1 复数 Dropout原理
6.2 复数 Dropout PyTorch 实现

7 复数权重初始化
7.1 复数权重初始化原理

8 复数 Batch Normalization
8.1 复数 BN 原理
8.2 复数 BN PyTorch 实现

9 完整模型搭建

1 PyTorch 中的复数张量形式

PyTorch 1.8 及之后都支持2种复数形式的 Tensor,它们分别是:

意味着 torch 中有表示 complex 的张量形式,即:

torch.complex(real, imag, *, out=None) → Tensor

构造一个复数张量,其实部等于 real,虚部等于 imag。

Parameters

  • real (Tensor): 复数张量的实数部分。必须为 float 或 double。
  • imag (Tensor): 复数张量的虚部。dtype 必须与实部 real 相同。

关键字参数:

out (Tensor): 如果输入为 torch.float32 ,则必须为 torch.complex64 。如果输入为 torch.float64 ,则必须为 torch.complex128

torch.is_complex(input)

返回 input 是不是复数形式,也就是torch.complex64, 和torch.complex128中的一种。

2 复数神经网络背景

众所周知, 从计算、生物和信号处理的角度来看,使用复数有许多优点。所以,复数相对于实数具有更强的表达能力。若能够借助复数设计神经网络,则非常具有吸引力。但是一个难题是如何设计配套的各种网络的 building block,比如说 complex BN,complex weight initialization 等等。

复数神经网络也有一些生物学上的优势,即:若网络中的数据都是实数,则只能代表某个中间输出的具体的值的大小;反之,若网络中的数据都是复数,则不仅能代表某个中间输出的具体的值的大小 (复数的模长),还可以代表时间的概念 (复数的相位)。具有相似相位的输入神经元是同步的 (synchronous),因为它们在复数运算中是相加的,而异步神经元相加则具有破坏性 (asynchronous),因此相互干扰。

复数神经网络也有一些信号处理方面的优势,即:复数蕴含着相位信息,而语音信号中的相位信息影响其可懂度。奥本海姆的研究表明,在图像的相位中存在的信息量足以恢复以其幅值编码的大部分信息。事实上,相位信息在对物体的形状,边缘和方向进行编码时,提供了对物体的详细描述。

本文开发了适当的工具和一个通用的框架来训练具有复杂参数的深层神经网络。

3 复数卷积操作

3.1 复数卷积原理

任意的一个复数  ,其实部为  ,虚部为  。作者将复数的实部和虚部表示为逻辑上不同的实值实体,并在内部使用实值算术模拟复数运算。假设一个卷积核,权重是  ,则它可以表示成  个复数权重。

复数域上执行传统的实值二维卷积:

复数卷积核: 

复数输入张量: 

复数卷积过程: 

在具体实现中,可以使用下图1所示的简单结构实现。

图1:复数域上执行传统的实值二维卷积的过程

如下图1所示,把上式写成矩阵的形式,就有:

3.2 复数卷积 PyTorch 实现

PyTorch 实现复数的操作基于 apply_complex 这个方法。

def apply_complex(fr, fi, input, dtype = torch.complex64):
return (fr(input.real)-fi(input.imag)).type(dtype) \
+ 1j*(fr(input.imag)+fi(input.real)).type(dtype)

这个函数需要传入2个操作 (nn.Conv2d, nn.Linear 等等)torch.complex64 类型的 input
fr(input.real): 卷积核的实部 * (输入的实部)。
fi(input.imag): 卷积核的虚部 * (输入的虚部)
fr(input.imag): 卷积核的实部 * (输入的虚部)
fi(input.real): 卷积核的虚部 * (输入的实部)
input 类型: torch.complex64
返回值类型: torch.complex64

因此,利用 Pytorch 的 nn.Conv2D 实现,严格遵守上面复数卷积的定义式:

class ComplexConv2d(Module):

def __init__(self,in_channels, out_channels, kernel_size=3, stride=1, padding = 0,
dilation=1, groups=1, bias=True):
super(ComplexConv2d, self).__init__()
self.conv_r = Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
self.conv_i = Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)

def forward(self,input):
return apply_complex(self.conv_r, self.conv_i, input)

同理还可以实现 Pytorch 的 nn.Linear和 Pytorch 的 nn.ConvTranspose2d:

class ComplexLinear(Module):

def __init__(self, in_features, out_features):
super(ComplexLinear, self).__init__()
self.fc_r = Linear(in_features, out_features)
self.fc_i = Linear(in_features, out_features)

def forward(self, input):
return apply_complex(self.fc_r, self.fc_i, input)

class ComplexConvTranspose2d(Module):

def __init__(self,in_channels, out_channels, kernel_size, stride=1, padding=0,
output_padding=0, groups=1, bias=True, dilation=1, padding_mode='zeros'):

super(ComplexConvTranspose2d, self).__init__()

self.conv_tran_r = ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding,
output_padding, groups, bias, dilation, padding_mode)
self.conv_tran_i = ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding,
output_padding, groups, bias, dilation, padding_mode)

具体实现的思路相似,都是借助了 apply_complex 函数,传入2个操作 (nn.Conv2d, nn.Linear 等等)torch.complex64 类型的 input,然后在 ComplexLinear (或 ComplexConvTranspose2d) 中分别计算。

3.3 复数的反向传播

为了在复数神经网络中进行反向传播,一个充分条件是网络训练的目标函数和激活函数对网络中每个 complex parameter 的实部和虚部都是可微的。通常损失函数都是实数,则复数 chain rule 如下:

如果  是实数损失函数,  为复变量,满足  ,则有:

如果现在有另一个复数  ,且  ,则根据偏导数的链式法则:

【声明】内容源于网络
0
0
极市平台
为计算机视觉开发者提供全流程算法开发训练平台,以及大咖技术分享、社区交流、竞赛实践等丰富的内容与服务。
内容 8155
粉丝 0
极市平台 为计算机视觉开发者提供全流程算法开发训练平台,以及大咖技术分享、社区交流、竞赛实践等丰富的内容与服务。
总阅读8.7k
粉丝0
内容8.2k