大数跨境
0
0

Python深度学习之生成对抗神经网络(GAN)简介及实现

Python深度学习之生成对抗神经网络(GAN)简介及实现 数据皮皮侠
2020-03-28
1
导读:What I cannot create, I do not understand.

引言

What I cannot create, I do not understand.  -- Richard Feynman


自2014年Ian Goodfellow提出了GAN(Generative Adversarial Network)以来,对GAN的研究可谓如火如荼,各种各样的变体层出不穷。有位名叫Avinash Hindupur的国际友人建立了一个GAN Zoo,他的“动物园”里目前已经收集了近500种有名有姓的GAN.,下图是GAN相关论文的发表情况。

图1 近年GAN相关论文发表情况

大牛Yann LeCun甚至评价GAN为 “adversarial training is the coolest thing since sliced bread”。

图2 大牛评价

生成对抗网络是现在人工智能领域的当红技术之一,网络是近两年深度学习领域的新秀,火的不行,本文旨在浅显理解GAN,分享学习心得。

GAN原理

生成式对抗网络GAN (Generative adversarial networks) 是 Goodfellow 等在 2014 年提出的一种生成式模型(原文arxiv:https://arxiv.org/abs/1406.2661)。GAN的主要灵感来源于博弈论中零和博弈的思想,应用到深度学习神经网络上来说,就是通过生成网络G(Generator)和判别网络D(Discriminator)双方的博弈学习,相互提高,最终达到一个平衡点,最终的均衡点为纳什均衡点,此时生成器 G 生成的样本非常逼真,使得鉴别器 D 真假难分。两者的关系就像学生和老师,学生在不断犯错中取得进步,老师也在这个过程逐渐学会了评价学生的能力,成为了一个优秀的老师。或就像台大李宏毅老师所说,生成器和判别器的关系就像鸣佐一样,“写作敌人,念作朋友”。

图3 鸣佐

了解了GAN的基本思想后,让我们来具体看看生成器和判别器怎么实现。以生成图片为例进行说明,假设我们有两个网络,G(Generator)和D(Discriminator),它们的功能分别是:

l  G是一个生成图片的网络,它接收一个随机的噪声z,通过这个噪声生成图片,记做G(z)。

l  D是一个判别网络,判别一张图片是不是“真实的”。它的输入参数是x,x代表一张图片,输出D(x)代表x为真实图片的概率,如果为1,就代表100%是真实的图片,而输出为0,就代表不可能是真实的图片。

图4生成网络

图5生成网络和判别网络

 了解了生成器和判别器后,还有一个问题,整个网络是如何进行训练的呢?实际上,GAN 博弈学习的思想体现在在它的训练方式上,由于生成器 G 和判别器 D 的优化目标不一样,不能和之前的网络训练一样只采用一个损失函数。因此那么训练这样的两个模型的大方法就是:单独交替迭代训练。

  假设现在生成网络模型已经有了,那么给一堆随机数组,就会得到一堆假的样本集,现在我们人为的定义真假样本集的标签,因为我们希望真样本集的输出尽可能为1,假样本集为0,很明显这里我们就已经默认真样本集所有的类标签都为1,而假样本集的所有类标签都为0。这样单就判别网络来说,此时问题就变成了一个再简单不过的有监督的二分类问题了,直接送到神经网络模型中训练就完事了。判别器就训练完了,下面我们来看生成网络。对于生成网络,我们的目的是生成尽可能逼真的样本。那么原始的生成网络生成的样本你怎么知道它真不真呢?因此在训练生成网络的时候,我们需要联合判别网络一起才能达到训练的目的。

  那么现在来分析一下样本,原始的噪声数组Z我们有,也就是生成了假样本我们有,此时很关键的一点来了,我们要把这些假样本的标签都设置为1,也就是认为这些假样本在生成网络训练的时候是真样本。回过头想想判别器的目的,是为了能够生成迷惑判别器的样本。这样我们就能使得生成的假样本逐渐逼近为正样本。不过要注意的是在训练这个串接的网络的时候,一个很重要的操作就是不要判别网络的参数发生变化,也就是不让它参数发生更新,只是把误差一直传,传到生成网络那块后更新生成网络的参数。这样就完成了生成网络的训练了。因此我们的这个过程就是,得到初始的D0、G0后,先训练D0,然后固定D0 开始训练 G0, 训练的过程都可以使用 gradient descent,以此类推,训练 D1,G1,D2,G2,...

  了解了GAN的训练过程后,让我们来看看训练过程的数学表达。对于判别网络 D,它的目标是能够很好地分辨出真样本𝒙𝑟与假样本𝒙𝑓。以图片生成为例,它的目标是最小化图片的预测值和真实值之间的交叉熵损失函数:

图6  D交叉损失熵

对于生成网络G(𝒛),我们希望𝒙𝑓 = 𝐺(𝒛)能够很好地骗过判别网络 D,假样本𝒙𝑓在判别 网络的输出越接近真实的标签越好。也就是说,在训练生成网络时,希望判别网络的输出 𝐷(𝐺(𝒛))越逼近 1 越好,此时的交叉熵损失函数为:

图7  G的交叉损失熵

 然后我们得到到原始论文里的目标公式了

图8 训练目标公式

DAGAN

  前面提到,其实生成网络和判别网络都是一个巨大的神经网络。而我们知道深度学习中对图像处理应用最好的模型是CNN,那么如果在GAN框架中生成网络和判别网络均用CNN实现,将CNN与GAN结合会取得什么样的表现呢?  

历史上已经有人尝试将GAN和CNN相结合,但是失败了。DCGAN的作者在使用传统的监督学习CNN架构扩展GAN的过程中,也遇到了困难。在反复实验和尝试之后,作者提出了一系列的架构,可以让GAN+CNN更加稳定,可以deeper,并且产生更高分辨率的图像。核心的工作是对现有的CNN架构做了如下修改:

ü  取消所有pooling层。G网络中使用转置卷积(transposed convolutional layer)进行上采样,D网络中用加入stride的卷积代替pooling。主要就是使用了strided convolution 替代确定性的pooling 操作,从而可以让网络自己学习downsampling(下采样)。作者对generator和discriminator都采用了这种方法,让它们可以学习自己的空间下采样。

ü  在D和G中均使用batch normalization 。BN可以加速学习和收敛,其将每一层的输入变换到0均值和单位标准差(其实还需要shift 和 scale),这被证明是深度学习中非常重要的加速收敛和减缓过拟合的手段。它可以帮助由于初始化不当而导致的训练困难,可以让梯度flow to deeper layers。实践表明,这对于deep generator的有效learning是至关重要的,可以防止generator将所有的samples变成一个single point,这是GAN训练经常会碰到的问题。实践表明,如果直接将BN应用到all layers,会导致sample震荡和不稳定,所以我们只对生成器的输出层和鉴别器的输入层使用BN。

ü  去掉FC层,使网络变为全卷积网络。最近的做法比如使用global average pooling 去替代fully connected layer。global average pooling可以提高模型的稳定性,但是却降低了收敛速度。GAN的输入采用均匀分布初始化,可能会使用全连接层(矩阵相乘),然后得到的结果可以reshape成一个4 dimension的tensor,然后后面堆叠卷积层即可;对于鉴别器,最后的卷积层可以先flatten,然后送入一个sigmoid分类器。

ü  G网络中使用ReLU作为激活函数,最后一层使用tanh。D网络中使用LeakyReLU作为激活函数。对于generator,其输出层使用tanh 激活函数,其余层使用relu 激活函数。我们发现如果使用bounded activation可以加速模型的学习,覆盖训练样本的color space。对于discriminator,我们发现使用leaky Relu 更好一点,特别是对于生成高分辨率的图片。

图9  DCGAN 中的G结构

DCGAN  in TensorFlow

说了一通原理,下面让我们一起动手实现一个DCGAN项目吧。DCGAN原文作者是生成了卧室图片,在这里我们来做更有趣的事情,来生成动漫人物头像。

DCGAN在Tensorflow中已经有人造好了轮子,我们直接使用这个代码就可以了。配置好环境后(tensorflow0.12及以上,scipy,pillow等常用的包),在Github上下载好源码解压;创建data文件夹,并将下载好的人物头像数据文件夹faces放置进来。随后打开命令行窗口,进入项目所在文件夹,输入:

python main.py --input_height 96 --input_width 96 --output_height 48 --output_width 48 --dataset anime --crop --train --epoch 300 --input_fname_pattern "*.jpg"


生成的结果如下:

图10 运行结果(10个epochs)

因为只运行了10个epochs,效果不是很好,让我们一起来看一起大神的运行结果。

图11 300个epoch结果

已经可以看到足以以假乱真的结果了。

总结

GAN作为一个具有 “无限” 生成能力的模型, GAN 的直接应用就是建模, 生成与真实数据分布一致的数据样本, 例如可以生成图像、视频等。GAN 也可以用于解决标注数据不足时的学习问题, 例如无监督学习、半监督学习等。GAN 还可以用于语音和语言处理, 例如生成对话、由文本生成图像等。下面让我们来看看GAN的有趣的工作吧。

图12字迹识别

图13  Q版人物头像生成

图14  “莫奈在春天醒来”、“马变斑马”、“四季更迭”

最后再顺便说一句,前不久火爆全网的南开大学校长论文造假案中,论文图片被爆使用PS(最简单的方式)产生,其实在这里用GAN生成或许就能以假乱真了。

图15

最后,大佬的照片镇楼,愿大家早日成为调参大师

图16  Ian Goodfellow

参考文献

原理讲解:https://www.imooc.com/article/28569

https://zhuanlan.zhihu.com/p/24767059

http://blog.sciencenet.cn/blog-2374-1072502.html

源码分析:https://blog.csdn.net/cs123951/article/details/72870510

数据集:https://pan.baidu.com/s/1eSifHcA 提取码:g5qa

Github项目源码下载:https://github.com/carpedm20/DCGAN-tensorflow


本期作者:杨阳

本期编辑:刘昊昂




【声明】内容源于网络
0
0
数据皮皮侠
社科数据综合服务中心,立志服务百千万社科学者
内容 2137
粉丝 0
数据皮皮侠 社科数据综合服务中心,立志服务百千万社科学者
总阅读615
粉丝0
内容2.1k