大数跨境
0
0

模型核心组件——判别器

模型核心组件——判别器 海博瑞数据
2026-01-05
2

击蓝字 / 关注我们

Beginning of Winter

DATGAN模型中的判别器是一个全连接神经网络,它的核心任务是区分输入的样本是来自真实数据集还是由生成器伪造的数据。为了提升训练的稳定性和生成数据的多样性,该判别器在传统的全连接网络基础上,集成了批次多样性(Batch Diversity)机制,并根据不同的损失函数策略(如WGGP或SGAN,在下一篇文章中具体介绍)自适应调整归一化层。

初始化与配置

在实例化判别器时,模型会根据传入的参数进行灵活配置,为后续的网络构建打下基础。

核心参数配置:

  • 网络规模:由层数和隐藏层大小控制网络的深度和宽度

  • 损失策略:通过loss_function参数决定输出层的激活方式及中间层的归一化方式(支持 WGGP或SGAN)

  • 正则化:可选的L2正则化,防止过拟合

  • 多样性参数:预设了Minibatch Discrimination,也就是批次多样性的参数,包括:n_kernels(特征图的核,取值为5)和kernel_dim(核的维度,取值为3)

网络构建

判别器的网络结构每一层都包含了一套复杂的“特征提取-增强-处理”组件。

单层内部结构

对于隐藏层中的每一层,包含以下组件:

  1. 主特征提取:一个标准的全连接层,用于提取输入数据的非线性特征

  2. 多样性特征提取:另一个全连接层,专门用于为批次多样性机制生成特征(维度为n_kernels * kernel_dim)

  3. 归一化处理:

   a. WGGP模式:使用LayerNormalization。这是为了适应 Wasserstein GAN with Gradient Penalty的要求,不使用Batch Norm以避免破坏样本间的独立性假设(除了 Minibatch Discrimination 部分)

   b. 其他模式:使用标准的BatchNormalization

  1. 正则化与激活:包括Dropout(0.5)层以增强鲁棒性,以及LeakyReLU激活函数以保持梯度流动

输出层设计:

  • 如果不使用SGAN(如WGAN),输出层是线性的(无激活函数)

  • 如果使用SGAN,输出层使用sigmoid激活函数,输出 [0, 1] 的概率值

核心机制:批次多样性

这是该判别器最显著的特征,旨在解决GAN训练中常见的模式崩溃(Mode Collapse)问题。通过让判别器观察这一批次数据的内在差异性,迫使生成器生成多样化的样本,而不仅仅是重复生成判别器认为真实的某一种样本。

算法流程:

  1. 张量重塑:将输入特征重塑为 (-1, n_kernels, kernel_dim) 的3D张量

  2. 差异计算:计算批次内所有样本两两之间的L1距离(绝对值差)

  3. 特征聚合:对距离取负指数 (exp(-abs_diffs))并求和。这意味着样本间越相似,该值越大;样本越分散,该值越小

  4. 特征拼接:在前向传播函数中,将计算出的群体多样性特征直接拼接到个体主特征之后,作为下一层的输入

前向传播

数据在网络中的流动过程如下:

  1. 输入接收:接收维度为 (N, n_features) 的张量。

  2. 层级迭代:遍历所有的隐藏层:

    a. 通过主全连接层计算个体特征

    b. 通过辅助全连接层并计算批次多样性特征

    c. 将个体特征与多样性特征在通道维度拼接

    d. 依次通过归一化,Dropout,LeakyReLU层

  1. 最终输出:通过输出层得到判别评分(概率或数值)

总结

DATGAN的判别器不仅仅是一个简单的二分类器,它是一个融合了现代GAN训练技巧(如批次多样性和Gradient Penalty兼容性)的复杂特征提取器。它通过同时关注个体真实性和群体多样性,为生成器提供了高质量的反馈信号。

下期预告

模型核心组件——损失函数


【声明】内容源于网络
0
0
海博瑞数据
海博瑞(北京)数据科技有限公司主要为国内外制药企业、科研院校提供全方位的研究设计、数据管理、统计分析和第三方稽查和质控服务。管理团队由行业内专家组成,为您提供专业、高效的全程解决方案。
内容 25
粉丝 0
海博瑞数据 海博瑞(北京)数据科技有限公司主要为国内外制药企业、科研院校提供全方位的研究设计、数据管理、统计分析和第三方稽查和质控服务。管理团队由行业内专家组成,为您提供专业、高效的全程解决方案。
总阅读13
粉丝0
内容25