大数跨境
0
0

机械学习-深度学习模型之“GAN” 干!

机械学习-深度学习模型之“GAN” 干! XR Engineering Technology
2022-05-21
1
导读:大家是否在抖音上见过一种场景,一张人物照片随着年龄增大的图像变化,预测某人老的时候长什么样子。 是否
    大家是否在抖音上见过一种场景,一张人物照片随着年龄增大的图像变化,预测某人老的时候长什么样子。 是否还见过一种场景,老照片修复与动态视频展示。这些实现离不开今天的主题:生成性对抗网络(深度学习算法中的一种重要模型)。
一、原理解析
    GAN是一种无监督深度神经网络。主要包含两种核心结构:生成器和对判别器。
    为了便于理解,给大家讲个故事:
    当下防疫抗疫是我们每个人必要坚守的事情!但是,每个人的想法不一样,突然有一天有一批人存在侥幸心理,想要在学校管控政策下“悄悄出校”。这份当然是违反纪律的事情!现在出现几种情况:2个人直接在摄像头下翻过栅栏,逃出去了;1个人伪造出校码,也出去了;。。。。。手段是各不相同啊。于是领导人员发现了有学生逃走开始侦察,结果一看摄像头2个学生翻墙,“来来来,都给我抓回来,给处分”,但是,伪造出校码的侥幸逃脱。后续还有学生想出去,发现直接翻墙这种方式太容易被抓,于是学生们就开始“卷”如何在学校不知情下,出去。当然,学校领导也在探索如何更高效的发现这种情况,于是,双向“卷”。我想办法,我有对策。我再有办法,我还有对策。。。。
    上述的违规学生就是生成器,学校管理制度就是判别器。
    大家现在是不是清楚GAN的原理了。给我掌声

二、实施过程
    Step1: 固定「判别器D」,训练「生成器G」
    Step2: 固定「生成器G」,训练「判别器D」
    Step3: 循环step1和step2,使双方能力都越来越强


三、数学原理
    GAN建模本质上需要定义两件事:架构损失函数。我们已经描述了Generative Adversarial Networks的架构。它包含两个网络:
----》架构
  • 生成网络G,其采用密度为pz的随机输入z,并返回输出xg = G(z),该输出应遵循(训练后)目标概率分布。

  • 一个判别网络D,它取一个可以是“真”的输入x(xt其密度用pt表示)或“生成”的一个(xg,其密度pg是由密度pz引起的密度通过G)并将x的概率D(x)返回为“真实”数据。

    在图上进行了标注:


    生成过程可以理解为“进化”过程,但是,必须要遵守原系统的概率分布特点。比如,根据图像来生成图像(图像看成像素数据,均匀分布),原图像是一张自拍照,生成的图像也必须是自拍照的性质,不能生成一个盘子。再直白一些:龙生龙,凤生凤,老鼠的儿子会打洞!!这不就明白了。
    这样的生成过程中重要的一点就是变换方法。生成N个不相关的均匀随机变量,我们可以使用变换方法。为此,我们需要将N维随机变量表示为应用于简单N维随机变量的非常复杂函数的结果!起初变换函数涉及累积分布函数实现逆变换。但是,很多应用对象在这个变换过程中无法公式表达,为此,必须从数据中学习它的变换方式,如何学习:神经网络!!!
    通过一个神经网络对变换函数进行建模,该神经网络将一个简单的N维均匀随机变量作为输入,并作为输出返回另一个N维随机变量,在训练之后,该随机变量应遵循正确的“目标概率分布” 。
    详细见:生成对抗网络(GAN) - 知乎 (zhihu.com)
  ----》损失函数如下:


---》整体算法如下:

四、Python实现代码

    基于Keras框架实现,现在Keras已经被Tensorflow合并,它是一种神经网络高级的API。反正就是高级,直接用就行。但是在pip install 的时候,需要核对版本信息,否则安装完,不能使用。还有很多问题,需要调试,根据终端反馈来回调试就行。为此,调试了一下Erik Linder创建的框架:
import numpy as npfrom matplotlib import pyplot as pltfrom keras.models import Sequential # Keras model modulefrom keras.layers import Dense, Dropout, Activation, Flattenfrom keras.layers import Inputfrom keras.layers import LeakyReLUfrom keras.datasets import mnist# 加载MNIST数据集用于训练(X_train, y_train), (X_test, y_test) = mnist.load_data()
class Adam: def __init__(self, lr=0.001, beta1=0.9, beta2=0.999): self.lr = lr self.beta1 = beta1 self.beta2 = beta2 self.iter = 0 self.m = None self.v = None
def update(self, params, grads): if self.m is None: self.m, self.v = {}, {} for key, val in params.items(): self.m[key] = np.zeros_like(val) self.v[key] = np.zeros_like(val)
self.iter += 1 lr_t = self.lr * np.sqrt(1.0 - self.beta2**self.iter) / (1.0 - self.beta1**self.iter)
for key in params.keys(): self.m[key] += (1 - self.beta1) * (grads[key] - self.m[key]) self.v[key] += (1 - self.beta2) * (grads[key ]**2 - self.v[key])
params[key] -= lr_t * self.m[key] / (np.sqrt(self.v[key]) + 1e-7)# 定义GAN网络框架class GAN(): def __init__(self): self.img_rows = 28 self.img_cols = 28 self.channels = 1 self.img_shap = (self.img_rows, self.img_cols, self.channels) # Adam优化器(SGD等,自己选一个) optimizer = Adam(0.0002, 0.5)
# 构建和编译判别器 self.discriminator = self.build_discriminator() self.discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
        # 构建和编译生成器 self.generator = self.build_generator() self.generator.compile(loss='binary_crossentropy', optimizer=optimizer)
# 生成器工作 z = Input(shape=(100,)) img = self.generator(z)
# 仅训练生成器 self.discriminator.trainable = False
#将生成的图像输入到对抗器中,并验证真实性 valid = self.discriminator(img) # 实现过程:input => generates images => determines validity self.combined = Model(z, valid) self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)
def build_generator(self): noise_shape = (100,)
model = Sequential()
model.add(Dense(256, input_shape=noise_shape)) model.add(LeakyReLU(alpha=0.2)) model.add(BatchNormalization(momentum=0.8)) model.add(Dense(512)) model.add(LeakyReLU(alpha=0.2)) model.add(BatchNormalization(momentum=0.8)) model.add(Dense(1024)) model.add(LeakyReLU(alpha=0.2)) model.add(BatchNormalization(momentum=0.8)) model.add(Dense(np.prod(self.img_shape), activation='tanh')) model.add(Reshape(self.img_shape))
model.summary()
noise = Input(shape=noise_shape) img = model(noise)
return Model(noise, img)
def build_discriminator(self):
img_shape = (self.img_rows, self.img_cols, self.channels)
model = Sequential()
model.add(Flatten(input_shape=img_shape)) model.add(Dense(512)) model.add(LeakyReLU(alpha=0.2)) model.add(Dense(256)) model.add(LeakyReLU(alpha=0.2)) model.add(Dense(1, activation='sigmoid')) model.summary()
img = Input(shape=img_shape) validity = model(img)
return model(img, validity)
def train(self, epochs, batch_size=128, save_interval=50):
# 加载数据集 (X_train, _), (_, _) = mnist.load_data()
X_train = (X_train.astype(np.float32) - 127.5) / 127.5 X_train = np.expand_dims(X_train, axis=3)
half_batch = int(batch_size / 2)
for epoch in range(epochs):            #  训练判别器
            # 随机选择被输入的数据 idx = np.random.randint(0, X_train.shape[0], half_batch) imgs = X_train[idx]
            noise = np.random.normal(01, (half_batch, 100))            gen_imgs = self.generator.predict(noise)             d_loss_real = self.discriminator.train_on_batch(imgs, np.ones((half_batch, 1))) d_loss_fake = self.discriminator.train_on_batch(gen_imgs, np.zeros((half_batch, 1))) d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
            #   训练生成器
noise = np.random.normal(0, 1, (batch_size, 100))
            # 生成器开始“欺骗”            valid_y = np.array([1] * batch_size) g_loss = self.combined.train_on_batch(noise, valid_y)
# 输出过程 print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss))
# 保存生成 if epoch % save_interval == 0: self.save_imgs(epoch)
def save_imgs(self, epoch): r, c = 5, 5 noise = np.random.normal(0, 1, (r * c, 100)) gen_imgs = self.generator.predict(noise)
gen_imgs = 0.5 * gen_imgs + 0.5
fig, axs = plt.subplots(r, c) cnt = 0 for i in range(r): for j in range(c): axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray') axs[i, j].axis('off') cnt += 1 fig.savefig("gan/images/mnist_%d.png" % epoch) plt.close()
if __name__ == '__main__': gan = GAN() gan.train(epochs=30000, batch_size=32, save_interval=200)

运行结果如下:忽略None, 哇哈哈。

    这东西就是慢慢探索,GAN模型已经有上百种了,万变不离其宗,原理都是一样的,探索吧!

【声明】内容源于网络
0
0
XR Engineering Technology
专注于工程技术领域的XR理论和技术研发应用
内容 6
粉丝 0
XR Engineering Technology 专注于工程技术领域的XR理论和技术研发应用
总阅读9
粉丝0
内容6