大数跨境
0
0

结合贝叶斯推断的去噪生成模型?详解BFN在连续型数据场景下的实现— Bayesian Flow Networks(二)

结合贝叶斯推断的去噪生成模型?详解BFN在连续型数据场景下的实现— Bayesian Flow Networks(二) 极市平台
2023-11-08
0
导读:本文也展示了对应的伪代码实现,包括:网络的输出与预测、离散时间和连续时间情况下的 loss 计算 以及 采样生成的过程。
↑ 点击蓝字 关注极市平台
作者丨CW不要無聊的風格
编辑丨极市平台

极市导读

 

本文在介绍数学上的具体实现之余,也展示了对应的伪代码实现,包括:网络的输出与预测、离散时间和连续时间情况下的 loss 计算以及采样生成的过程。最后,基于自己的理解谈了谈 BFN 与 diffusion models 的联系,揭示了它在作者的实现方式下,其实也走上了去噪生成的道路。 >>加入极市CV技术交流群,走在计算机视觉的最前沿

前言

在本系列的上一篇文章中,CW 向大家介绍了 BFN 的背景、动机 及其 优势,并对其数学框架进行了解析,相信大家对 BFN 已经有了一定程度的认识。

BTNs是怎么玩转生成即压缩的?详解结合贝叶斯统计和深度学习的生成模型 — Bayesian Flow Networks(一)

这篇文章是本系列的第二篇,主要讲述 BFN 在连续(continuous)数据场景下的具体实现,内容结构与上一篇文章中“BFN 的数学框架”那一章相对应,各位客观可以自行搭配食用~

本文在介绍数学上的具体实现之余,也展示了对应的伪代码实现,包括:网络的输出与预测离散时间和连续时间情况下的 loss 计算 以及 采样生成的过程。最后,CW 基于自己的理解谈了谈 BFN 与 diffusion models 的联系,揭示了它在作者的实现方式下,其实也走上了去噪生成的道路。

同时,切莫忘记本系列的文章内容均以两位主人公 Alice & Bob 的消息传输过程为背景来展开叙述,这个过程实质上是一场“数据压缩游戏”,BFN 训练的底层框架思想也是以尽可能对数据进行压缩为目标的。

另外,有以下两点 CW 提一下:

  • 时间变量的取值范围是
  • 数据的取值范围被归一化至 , 是 维数据

OK,不吹水了,直奔主题吧~!

输入分布:看似无意,实则有心

输入分布建模为协方差矩阵是对角阵的多维正态分布(diagonal normal)

对于正态分布的参数集 , 其包含均值和方差, 但这里有些非主流, 将方差用精度 来表示, 它是方差的倒数, 于是:

其中, 均值 维向量, 而精度 是标量数值。

输入分布则定义为 :

其中, 阶单位阵。以上定义说明输入分布是各向同性(isotropic)的, 即各维分量的方差都一致。进一步, 作者将先验(prior)定义为标准多维正态分布(standard multivariate normal):

注意, 以上的 0 是 维向量, 而非一个数值。

另外,此处给出一个 WARNING: 在 BFN 的方法论里,并不会直接从这个分布去采样来作为最终的生成结果, 而是将该分布的参数作为 BFN 的输入,去辅助其输出包含上下文语义的向量,然后再构建出最终的生成结果(详情待后文揭晓,不准偷看哦~)。

由于输入分布的参数会根据观测样本进行后验更新,因此这些参数就隐含着样本信息,从而 BFN 通过将其作为输入就可以获取数据样本的经验性信息,于是可以习得数据的经验先验,并且生成接近原数据分布的结果。

可谓:看似无意,实则有心~

发送者分布:“噪”家伙

发送者分布就是对原数据加噪, 最常见的做法就是将其设为:以原数据为均值、精度的导数 为方差的正态分布:

由以上表达式可以看出,它和输入分布一样也是各向同性的。

贝叶斯更新函数:抱紧贝叶斯爸爸的大腿

在贝叶斯统计的世界中,若后验分布与先验分布同属一种分布,则两者为共轭分布。

对于这点,正态分布正好契合了一个优秀的性质:当先验与似然都是正态分布时,根据贝叶斯公式,所得的后验分布也将是正态分布,从而先验与后验即为共轭分布。并且,后验的均值与方差可近似根据先验和似然的均值与方差来计算,具体可表述为:

对于一维未知数据 的高斯先验 , 在观测到服从已知为方差 的高斯分布 的噪声样本后, 其贝叶斯后验也是高斯分布, 记为 , 其中 , 具体推导过程请见附录。

在 BFN 的玩法下, 先验 是根据观测样本 并通过贝叶斯推断来实现后验更新的, 在两者都是一维变量的情况下, 就会有形如以下的关系:

而实际上 是输入分布 的参数, 是从发送者分布 中采样的, 这两个分布都是各向同性的正态分布, 即各维分量均不相关, 因此各维度上的贝叶斯更新是独立进行而不相关的, 于是可以直接套用以上一维情况的关系, 只不过将上面三项从左往右依次替换为更新后的输入分布 、发送者分布、更新前的输入分布 , 即:

记更新前的输入分布参数为: , 更新后为: , 则贝叶斯更新函数表示为:

由上图可看出, 随着后验更新的不断进行, 输入分布会逐渐向数据真值靠拢。并且, 结合 式可以知道, 每次更新后, 都会位于 之间, 因为 。于是, 随着精度 的提高, 观测样本 会越来越接近真实数据, 从而 也会趋向于真实数据。

贝叶斯更新分布:对后验更新进行更为全面的考虑

贝叶斯更新函数根据一个观测样本来实现后验更新,对应计算出一个值,而 是多维随机变量,我们更希望得到它所服从的分布,于是就有了贝叶斯更新分布来满足该需求。它通过在贝叶斯函数上边缘化(marginalise over)观测样本来实现,这考虑了观测样本的所有可能性,是贝叶斯更新函数在发送者分布上的期望,就像是一个“拥有大局观”的贝叶斯更新函数。

做人也一样,面对事情时要尽可能全面地进行考虑,避免傲慢与偏见。CW 真没有在教你做人,逃~

我们现在要求出式 所服从的分布, 这时, 关于正态分布的一个有用的性质就派上用场了:

而观测样本所服从的发送者分布恰好是正态分布: , 同时因为这里的边缘分布是对 的概率密度进行积分, 所以将 看作定值。对式 , 令:

同时利用式 所示的性质, 即得:

因为在参数集 中只有 是随机变量(而 是标量), 于是贝叶斯更新分布就是关于 的更新分布:

关于贝叶斯更新分布的“精度可加性”证明

记第 步时对应的精度为 , 根据 式, 有:

记以上分布的均值 , 则通过高斯分布采样的重参数化, 可表示为:

记第 步的精度为 , 由于:

因此, 根据 式所示关于正态分布的性质, 就有:

此时再加上另一个关于正态分布的 "buff":

于是:

这就代表, 可通过以 为条件参数的贝叶斯更新分布来实现更新, 同时其中的精度参数具有可加性:

而以上这个分布实质上考虑了 的所有可能性, 因为在推导过程中 是随机变量, 涉及了它的分布, 考虑了它所有的可能性, 所以就相当于是 进行边缘化的边缘分布, 也就是在 上求期望:

加噪设置:以线性熵减为宗旨

这一章主要来推导 "accuracy schedule": 的表达式,使其能够根据时间变量的值来进行计算,从而控制观测样本所包含噪声的程度,这等价于控制着发送者分布和接收者分布的方差(两者的方差一致)。

作者的推导思路源于一个“宗旨”—— 输入分布的期望熵要线性地随时间减小, 这代表数据信息以恒定速率注入到输入分布中(通过其参数 )。因为拥有的数据信息越多,就越趋向于稳定状态,确定性越高,从而熵就越小,而数据信息以恒定速率注入,则熵就会对应地随时间线性减小。

作者在 paper 中“一声不吭”就甩出如下所示的输入分布的期望熵,它是在贝叶斯流分布上的期望:

不要懵逼(虽然我知道你第一眼看到以上结果肯定懵逼.. 只怪作者太无情),这实质上就是多维正态分布的联合熵:

其中, 是协方差矩阵的行列式。在我们这里, 输入分布的协方差矩阵是对角元相等的对角阵 (对角元是各维分量的方差), 其行列式为所有对角元的乘积。对于 维数据, 这个值就是在 时刻下方差的 次幂。

那么, 输入分布在 时刻的方差是多少呢?

在前面讲输入分布那章, CW 已告诉过大家, 在起始时刻: , 精度: , 于是, 对应的方差就是:

根据贝叶斯更新函数(见 式), 精度是累加式地更新。并且根据定义, 精度从起始累加至 时刻的结果就是:

所以, 时刻的精度就可表示为:

顺理成章, 对应的方差就是:

再将以上结果代入到前面多维正态分布联合熵的表达式中, 就可以得到作者那“毫无道理”的 的表达式了。而关于多维正态分布联合熵的推导, 不嫌弃的话可以瞄瞄 的这篇水文:

https://zhuanlan.zhihu.com/p/663232072

你或许会说:“慢着.. 的定义里不是还有个 吗? 前面只是求出了输入分布的熵而已, 还需进一步在贝叶斯流分布上求期望才对! "

没错! 很仔细哈 但是你可以看到, 的取值是无关的, 它仅由 决定, 所以输入分布的熵在贝叶斯流分布上的期望就是其本身:

, 接下来, 我们就以 的单调递减函数为出发点来推导出 的表达式。

在初值与终值确定的情况下, 可表示为:

其中 时刻的方差(将 代入前面 的表达式中便知)。进一步, 我们还可以推导出:

至于 , 作者说按照经验去设定, 它要能够使得在最终步时, Alice 向 Bob 传输的观测样本与原数据样本尽量接近(也就是噪声程度尽量低), 这样 BFN 重构原数据样本也会比较容易。但同时它也不能太小, 否则重构变得过于简单, 这场游戏的 loss 就耗在了不必要的数据传输成本上了。

总的来说, 以上推导过程就是预设一个最终步的噪声强度 , 然后以每一时刻下输入分布的期望熵 是时间变量的单调递减函数为出发点来最终推导出对应时刻下的 accuracy schedule( ) 和

在实际操作时, 只要预设好 , 那么后续在各时刻下只需代入对应时间变量的值, 即可求出

贝叶斯流分布:全局视角下的后验更新

贝叶斯流分布是贝叶斯更新分布的边缘分布,它考虑了由起始直至当前时刻,中间每一步更新的 的可能性,相当于有一个全局视角,从一开始就对中间每个环节的可能性进行考虑,计算方式是用当前时刻的贝叶斯更新分布在前面各步的贝叶斯更新分布上求期望:

根据 式, 即结合了精度可加性的贝叶斯更新分布表达式, 并且应用 式, 可得:

看到上式,不知道你们是否感觉均值和方差的常系数部分有些“臃肿”,不如干脆定义一个函数来表示它,使其看起来更“清爽”一些:

于是,贝叶斯流分布的最终形式为:

输出分布:这家伙比较懒,选择了单点分布

输出分布的关键在于怎样去拟合原数据分布,对于这个问题,我们可以巧用贝叶斯流分布。

根据采样的重参数化技术,有:

咦! 看到以上这现象, 是个聪明的高级动物都会选择偷懒——直接让 BFN 预测(生成)的结果 也遵循以上形式:

其中, 就是 BFN 的输出, 它对贝叶斯更新时所使用的噪声向量进行估计(看到此处, 了解 diffusion models 的朋友们是否闻到了熟悉的味道? )。

这里有一点需要特别注意下: 根据游戏规则, 当前输入分布使用的 是根据上一轮游戏的贝叶斯更新而来的, 也就是在上一轮游戏中, Bob 收到 Alice 消息后对自己的猜测策略进行纠正, 纠正后的结果被用于当前这轮游戏中继续猜测新收到的消息。

所以,在离散时间的情况下,就有:

即这个 是上一个离散区间的末尾时刻(也是当前区间的起始时刻)。

还没完~ 毕竟我要求的是分布,而不是一个估计值,那么 (viii)式到底服从什么分布呢?

回看 Bob 和 Alice 的游戏过程,其实前者只要求 BFN 能够输出一个对原数据的预估值,然后对它进行加噪,从而构建一个对 Alice 传输的消息的猜测结果即可。同时,BFN 的“升级途径”(训练)也是通过 Bob 和 Alice 的消息匹配与校正(接收者分布去拟合发送者分布,而非直接去拟合原数据分布)来完成的。所以,BFN 能够生成一个预测值也是能满足需求的。

那么,有没有一种分布形式,如单点分布那样,使得我们从其中采样后得到的结果,必定是我们期望的那个结果呢?如果有,那么就可以将输出分布设为这种分布,从而使其采样结果就是 BFN 的预测(生成)结果。

幸运的是,老天爷没有亏待我们,世界上还就有满足这个需求的分布,它就是 狄拉克分布 δ,于是,我们就可以将输出分布设为:

还有一个 bug 要解决: 若 , 则 , 这会导致 (viii) 式没有意义, 对应的实际情况就是游戏一开始还没有传输过消息, 因此也无法进行贝叶斯更新, 那么用 式去预测原数据就没有道理, 于是 BFN 在游戏刚开始时必须以某种先验作为预测结果。

联想到前面我们将输入分布的先验 设为 0 , 也就是假设原数据的先验均值为 0 , 那么就也有一定理由在最开始时将对数据的估计值设为 0 。

由于实际玩起来的时候, 不仅是离散时间的情况, 还可能是连续时间, 因此作者在实际操作时, 是这样设定的:

另外, 作者还对  ( )  式的结果裁剪至  [ 1 , 1 ]  范围内, 保持与原数据一致。

接收者分布:希望变得与发送者分布一样

经过前面的环节,我们已经知道了输出分布和发送者分布,那么接收者分布就“顺势可为”了。它在输出分布上采用与发送者分布同样的加噪方式来加噪,以尽力拟合发送者分布:

接收者分布 to 发送者分布:“我希望变得与你一样~”

离散时间的损失函数

在上一篇文章中,CW 已和大家一起推导过离散时间的 表达式为:

通过前面的内容,我们已经知道发送者和接收者分布均为各向同性的正态分布,且两者的协方差矩阵相等。而同样在上一篇文章的附录中,CW 已为大家推导出两个多维正态分布的 KL 散度的表达式,并且,当两者的协方差矩阵一致时, KL 散度为:

特别地,这两个正态分布的协方差矩阵是对角元均等于精度倒数 的对角阵,于是就有:

为发送者分布的均值 为接收者分布的均值 ,从而可得 的表达式为:

又因:

将上式再代入到前面 的表达式中,最终得到:

连续时间的损失函数

在上一篇文章中,CW 已向大家介绍过作者对于连续时间下发送者和接收者分布的 KL 散度提出了一种泛化形式:

其中 是映射到某向量空间的 embedding 函数, 是在同一向量空间下的某分布, 代表两个概率分布的卷积操作, 代表正态分布, 为常数。

并且, 在上一篇文章中也推导出了 表达式为:

在这里的连续数据场景下, 为恒等映射, 为输出分布: , 于是:

根据上一篇文章中“连续时间的损失函数”那一章的结论,并且由于 各维分量相互独立,因此:

最终:

由此可看出, 实质上与离散时间的 的形式是类似的, 两者的差异主要是由精度 的计算方式不同而引起的, 后者的精度是离散的 , 而前者的精度是连续的

重构损失

对于真实的连续型数据,如果想要完全重构,那么就需要无限的精度,因为真正的连续数据取任何一个具体值的概率都为0,这会导致重构损失变得没有意义。

我们不妨假设数据已经经过一定程度的离散化(discretised), 毕竞它们已被置于数字计算机上; 又或者, 我们可以认为数据本身就含有一定程度的噪声, 且呈现出的数据分布是: 以原数据为均值、某固定值 为方差的各向同性的正态分布 。于是, 重构损失就可以定义为该正态分布与最后时刻含同样噪声强度的输出分布 之间的 KL 散度。考虑到 的取值有概率性, 于是将其进一步在贝叶斯流分布上求期望:

以上结果是两个各向同性且协方差相等的多维正态分布之间的 KL 散度,具体推导过程请见本系列上一篇文章的附录。

需要注意的是, 这个重构损失并未用于训练。因为前面推导出的 已经承担了这个重构损失的角色(请看它们的表达式)。并且, 由于预设的 通常会比较小, 经历多次在 的优化后, 到 时重构损失通常会非常小, 因此这项重构损失对于网络训练的帮助不大。

伪代码

这一章将展示 BFN 在连续数据场景下运作的伪代码,结合前面的理论部分,以便大家能看到实操的可能性。

网络输出与预测

以上这段伪代码就是 BFN 接收输入分布的参数: , 然后估计出对应时刻的噪声, 最终通过 (viii) 式转换为预测(生成)结果的过程。

注意, 就是 , 完全由时间变量决定,根据式 (vii) 计算。

离散时间的训练过程

整个训练过程主要包括:

  • 采样离散的时间步(对应变量 )
  • 根据时间步计算出时间变量 的值
  • 贝叶斯更新(对应于以上 的采样)
  • 生成预测样本
  • 计算

在离散时间的情况下, 有一个关键点是时间变量 的计算, 它是上一个时间步结束的末尾时刻, 因为 当前所用的 是上一次贝叶斯更新的值。

连续时间的训练过程

连续时间的情况与上类似,只不过时间变量的取值在这里就方便多了——可以直接从连续的均匀分布中采样。

另外,无论是在连续时间还是离散时间的情况下,因为有贝叶斯流分布的存在,所以直接从中采样即可完成后验更新,而无需先构造带噪的观测样本,然后再根据贝叶斯公式来进行计算。

采样生成

采样过程以标准正态分布为先验而开始。

你或许有些困惑:为何这里不能够像前面展示的训练过程那样,直接用贝叶斯流分布来进行后验更新呢?

要知道,在训练时,每个数据样本是确定(不变)的;而在推理时,BFN 在每一步生成的数据都是变化的,因此贝叶斯更新的过程“不足以形成一个流”,必须根据每次变化的样本来实施后验更新。 或者,你回顾下贝叶斯流分布的推导过程也能够更清晰地理解这一现象。

论:BFN 实质上是去噪生成模型?

相信大家也看出来了,BFN 也是通过加噪和去噪来完成训练和生成的。其中,发送者分布、接收者分布就是加噪过程的产物,而输出分布则是去噪过程的产物。

你可能会喷:搞了半天,还是 diffusion models 那一套?还是个“真扩散模型”(微笑脸),呵呵!

且慢,虽然 BFN 和 diffusion models 的 “战术” 一致——都是 加噪&去噪,但两者在实现上还是有区别:BFN 是把加噪过程“藏”在了贝叶斯更新里, 在前面的伪代码中, 从贝叶斯流分布中采样这一步就隐含着加噪过程,这个过程作为实现贝叶斯更新的媒介。 另外,它也无需像 diffusion models 那样定义一个严格的马尔可夫扩散过程,在概念上的约束似乎没那么强。

不过,虽然作者说 BFN 在概念上更简单——不像 diffusion models 那样需要定义前向加噪和对应的逆向去噪过程,但 CW 却没有被他骗到——无论是训练还是采样生成,都需要先后经历 贝叶斯更新 和 BFN 的输出预测,两者就分别对应了加噪和去噪的过程, 实质上是一样的。并且,对于如今“大力出奇迹”的暴力深度学习时代,或许大家对“贝叶斯统计”这种东西还更为陌生,而人家 diffusion models 的扩散过程起码还有个对应的热力学现象来做形象的解释。

为何说 BFN 和 diffusion models 是“战术”上一致,而非“战略”上一致呢?

因为 加噪&去噪 只是 BFN 执行层面上的一种方案,其战略思想其实是通过观测样本来对先验进行后验更新,并且再利用神经网络(在高维空间中)强大的(上下文)学习能力来做进一步的校正。 只不过在这里,作者使用了加噪这种方式来构造观测样本,从而再通过估计对应的噪声来作为神经网络的学习途径,这样更简单,也更容易(而且还有执行同样战术的 diffusion models 作为借鉴,比较稳)。它不像 diffusion models 在诞生之初就被约束在非平衡热力学的扩散模型框架之下。

总的来说,BFN 在作者的实现下,你说它是 “隐式加噪的去噪生成模型”,没毛病。但 CW 觉得 BFN 的思想不仅局限于去噪生成,还可以延伸出许多好玩的方案。

附:正态分布的贝叶斯后验的均值与方差推导

在这部分, 与大家来一起推导在“贝叶斯更新函数"那一章中提到的正态分布的贝叶斯后验的均值与方差的计算公式, 这也是式 的由来。

记先验为: 、似然函数为: 、后验为: , 我们的目标是推导出以下结果:

首先, 根据 贝叶斯定理, 有以下关系:

对右边两项代入各自正态分布的概率密度公式,得:

可见,忽略掉一些常系数的影响,后验也是正态分布,且存在 (*)式的对应关系,于是目标得证。

令:

(*)式即为文中 (i)式的结果。

Next

下一篇,CW 将为大家分享 BFN 在离散化(discretised)数据场景下的具体实现,其中许多部分都会“继承”本文所述的方法,所以还请您不要走远,欢迎再次光临哦~!

公众号后台回复“数据集”获取100+深度学习各方向资源整理

极市干货

技术专栏:多模态大模型超详细解读专栏搞懂Tranformer系列ICCV2023论文解读极市直播
极视角动态欢迎高校师生申报极视角2023年教育部产学合作协同育人项目新视野+智慧脑,「无人机+AI」成为道路智能巡检好帮手!
技术综述:四万字详解Neural ODE:用神经网络去刻画非离散的状态变化transformer的细节到底是怎么样的?Transformer 连环18问!

点击阅读原文进入CV社区

收获更多技术干货


【声明】内容源于网络
0
0
极市平台
为计算机视觉开发者提供全流程算法开发训练平台,以及大咖技术分享、社区交流、竞赛实践等丰富的内容与服务。
内容 8155
粉丝 0
极市平台 为计算机视觉开发者提供全流程算法开发训练平台,以及大咖技术分享、社区交流、竞赛实践等丰富的内容与服务。
总阅读3.2k
粉丝0
内容8.2k