大数跨境
0
0

详解贝叶斯流网络在离散化数据场景下的实现—Bayesian Flow Networks(三)

详解贝叶斯流网络在离散化数据场景下的实现—Bayesian Flow Networks(三) 极市平台
2023-11-18
0
↑ 点击蓝字 关注极市平台
作者丨CW不要無聊的風格
编辑丨极市平台

极市导读

 

本文接着来与大家探探 BFN 在面对离散化数据时是怎么玩的,详解离散化数据概念以及实现离散化过程。 >>加入极市CV技术交流群,走在计算机视觉的最前沿

前文回顾:
BTNs是怎么玩转生成即压缩的?详解结合贝叶斯统计和深度学习的生成模型 — Bayesian Flow Networks(一)
结合贝叶斯推断的去噪生成模型?详解BFN在连续型数据场景下的实现— Bayesian Flow Networks(二)

前言

上一篇文章解析了 BFN 在连续(continuous)数据场景下的具体实现后,本文紧接着来与大家探探 BFN 在面对离散化数据时是怎么玩的。离散化数据是比较有个性的,虽然它是个名词,但其实蕴含着一个过程,具有动词的属性。具体如何,接下来会在正文中详细说到。

本文会首先介绍离散化数据的概念,并进一步列出其与最易混淆的离散数据之间的区别;接着,CW 会介绍作者是如何对数据实现离散化的,同时一并介绍其在离散化数据场景下的一些框架设置;然后,就开始步入正题—— BFN 在离散化数据场景下的数学实现;最后谈了谈离散化可能产生的利弊BFN 在此场景下实现的伪代码。另外,和前面的文章一样,关于文中所引用到的数学结论的证明,CW 都放在附录中了,感兴趣的可以扫一扫~

什么是离散化(discretised)数据

何谓离散化数据?这名字可能会有误导性,让你下意识以为它是离散型(discrete)数据。然而,离散化数据的意思是将连续型(continuous)数据转换为离散型数据的过程或结果,其中隐含着一个转换过程。通常可以将连续数值划分为多个等长的数值区间,然后将位于同一区间内的连续数值都使用同一个离散数值(或类别)来表示。

离散化数据 vs 离散数据

为了让大家更清楚地了解离散化数据,CW 觉得有必要在此列举一下它与离散化数据的区别:

  • 离散数据是数据的一种类型,而离散化数据更多地代表着数据的一种处理方式
  • 离散数据它本身就长那样,不是由某某转换而来;而离散化数据的“根”是连续数据、是可还原的;
  • 离散数据的数值只能是整数,而离散化数据的数值可以是任意的

如何实现离散化

在该系列的前几篇文章中 已经说过, 作者在实现时会将数据都归一化至 区间内, 于是, 我们这里做离散化就是要将此区间分为多个等长的小区间。

假设分为 个区间, 那么每个区间的长度就是:

是第 个区间的左端点、中点和右端点, 并让 表示从 1 到 的整数集, 则对于 , 就有:

你问: 为何 是这么算的?

CW 答:让你上中学的弟弟妹妹教你。

开个玩笑(别打我)~

第1个区间的中点是 , 而因为相邻区间的中点之间的间距等于区间长度, 所以各区间中点就构成了公差为 的等差数列。根据等差数列的通项公式, 第 个区间的中点就是:

(你看, 是不是真的可以考虑让你们上中学的弟弟妹妹教你。你也不用怕丢脸, 待他/她答出来后你就说这是你故意考他/她的)

对于 维数据 , 记 维向量 表示数据各维分量经过离散化后所处的区间索引、 分别表示数据经过离散化后所处区间的左端点、中点和右端点 (同样是 维向量), 各维度上的计算方法见 式。

作者说, 对于原本是连续的 维数据, 在经过离散化后, 就以 来表示它。比如, 对于 8bit 的 RGB 图像, 它是3维数据, 若某像素值在 这个维度(channel)上的区间索引为 100 , 那么在该维度上对应的数值就是:

其它维度上的计算方法也遵循同样原理。

有一点需要注意, 由于是以 来表示原来连续的数值, 因此经过离散化后数据值就会位于 内, 而非原来的

在离散化场景下的一些设置

由于离散化数据就是在原本连续数值的基础上划分区间,因此 BFN 在离散化场景下的实现大多沿用了在连续数据场景下的实现,主要的不同在于对输出分布进行了离散化,而输入分布、发送者分布仍保持为连续的。

作者说之所以这样选择,一方面是因为数学上更易实现(贝叶斯更新面对离散化的分布会很复杂),另一方面是他觉得 BFN 以连续的输入分布参数作为输入会更好地“消化与吸收”,相比于以离散的概率值作为输入会更易学到数据的规律。

通过本系列的上一篇文章, 我们知道数据本身所带的噪声 (我们假设连续数据本身自带有一定程度的噪声)会为 ( 时输入分布的标准差)的设置提供参考(见上一篇文章“重构损失”那一章), 从而影响精度 )的设置。

而在这里的离散化场景中,区间的宽(长)度也起着类似的作用,它 的设置提供了参考,因为区间的宽度本身就代表着经过离散化后数值的精度(区间宽度约小则越精确,越接近于原来连续的情况)。

作者还举了个例子:对于 8-bit 的数据, 经过离散化后被分为 256 个宽度为 的区间,那么就将 设为 1e-3 一大约是区间宽度的 。他认为这样设置, 在最后时刻的精度就足够了, 从而足以令 BFN 准确地识别出数据属于哪个区间。

输出分布

如果把 BFN 在离散化数据场景下的实现看作是一首歌,那么其中的副歌部分(通常是一首歌的高潮部分)就是输出分布,因为其余部分的实现都直接“白嫖”自连续型数据场景。

在上一篇文章中,你们已经了解到:在连续型数据场景下,输出分布被建模为单点分布。而在本文的离散化数据场景中,数据的表示被分割为有限个数值区间,所有的数值可能性都被“囊括”在这些区间里,每个连续值都唯一地属于一个区间,属于同一个区间的连续值都会被无差别对待,相当于对连续值进行了分类,每个区间就代表一个类别。所以输出分布理应为每个区间都赋予一定概率,于是就建模为 分类分布(categorical distribution)

OK,了解大体思想后,现在来讲下具体做法。

在连续型数据场景中, BFN 的输出是其估计的高斯噪声向量; 而在这里, 作者将这个输出变为两个 维向量: , 前者代表估计的高斯噪声分布的均值, 后者是标准差的对数

之所以要对标准差取对数, 是因为如果直接将网络输出规定为标准差的话, 那么就要求是非负值,于是就要对网络的输出做截断, 或者利用一些函数将网络原本输出的负值映射为非负值等, 这样实现起来就“不太友好”,毕竟“强扭的瓜,不甜"。

回顾上一篇文章的式 (viii) :

就对应到以上所估计的高斯噪声 的分布的均值和标准差。

于是, 根据高斯分布的性质, 也服从高斯分布, 且其均值 和标准差 分别是:

注意, 这里是将   视作定值, 因为它已经是经过贝叶斯更新后的结果(是后验更新值)。至于 < min 的情况, 请回顾上一篇文章中"输出分布"那一章。

以上结果是网络对于原来连续(离散化之前)数据分布的估计,它依然是连续的高斯分布,而我们的目标是能够输出各区间概率的离散型分布。那么兄弟们,该如何完成 KPI 呢?

可以这么考虑:每个区间都有左右两个端点,对于上面估计出来的连续型分布,将数据小于等于右端点的概率减去小于等于左端点的概率,就得到了位于左右端点之间的概率,也就是位于对应区间内的概率。

而在分布中,小于等于某点的概率可以用 累积分布函数(CDF) 来表示,于是输出分布赋予某个区间的概率就是:网络估计出来的(连续型)分布的累积分布函数在对应区间右端点的函数值减去在左端点的函数值。

同时,考虑到原数据被归一化至 内,于是要对 CDF 进行“截断”: 将其在小于等于 -1 和 大于等于 1 时的函数值分别设为 0 和 1 。记数据的每一维分量的 CDF 为 , 则:

其中 就是原 CDF。

于是, 以第 个区间为例, 在输出分布中, 数据的第 维分量位于该区间的概率就是:

从而, 对于已知的 维数据 , 输出分布就是综合所有维度分量的联合概率:

注意, 在以上的表示中, 的离散化结果是明确的, 其中每一维分量 被确定地分配到对应区间 。实际上, 输出分布在每一维都会输出 个区间对应的概率, 对应于 维向量。于是, 在推理阶段, 若要生成 维数据, 则输出分布会返回 维向量, 然后再根据某种策略 (比如选概率最高的)看究竟要选取 个区间中的哪个, 最终得到 维的结果。

另外,正态分布的 CDF 没有解析式,通常要借助于 误差函数(erf) 来计算,并且一维标准正态分布的 CDF 与误差函数之间存在以下关系:

(具体推导过程请见附录)

于是, 在这里我们可以先把 这个一般形式的正态分布通过“减均值、除标准差”这个经典招式转换为标准正态分布, 然后再利用 erf 进行计算, 即:

以下展示的是将网络估计出来的连续型分布离散化至各区间概率的示意图。可以看到,由于前面提到的对 CDF 的截断操作,原本连续型分布的长尾部分都被集中在首尾区间,因此首尾区间的概率会有不服从于原来连续曲线的现象。

接收者分布

接收者分布中的数据由采样自输出分布的数据加噪而来,加噪方式如同发送者分布,于是接收者分布就是发送者分布在输出分布上的期望(有疑惑的请复习本系列前面两篇文章)。

输出分布在每个区间都有概率,而每个区间都以其中点 来表示。于是,结合接收者分布的定义,就有:

由此可知,接收者分布依然是连续型分布,而且它在每一维实质上是 混合高斯分布(GMM)。

重构损失

与上一篇文章的连续数据不同,在这里,数据是离散化的,重构它不会存在“无限精度”的问题(具体可回顾上一篇文章中“重构损失”那章)。于是,重构损失就如本系列首篇文章中所说的,是:最后时刻的负对数似然在贝叶斯流分布上的期望

同样,与前面几篇文章中所说的一样,重构损失是不用于训练的

离散时间的损失函数

在本系列前面的文章已展示过,离散时间的损失函数 为:

在本文的离散化数据场景下,发送者分布和接收者分布之间的 KL 散度是:

然而,以上并没有 解析形式(closed form),于是我们转向为采用 蒙特卡洛采样 来近似估计

也就是说,在计算发送者分布和接收者分布之间的 KL 散度时,先从发送者分布中采样,然后再利用采样结果去计算两者的对数概率密度差(根据定义,发送者分布和接收者分布的 KL 散度就是它们的对数概率密度差在发送者分布上的期望)。

连续时间的损失函数

在本系列的首篇文章中已说过,在连续时间的情况下,作者提出发送者分布和接收者分布之间的 KL 散度拥有以下泛化形式:

与上一篇文章的做法一样, 此处也是令 为恒等映射和 , 于是:

不同的是, 在各维度上都是离散化区间中点 在输出分布上的期望, 同时各维度相互独立, 于是:

根据前面文章中对应章节的结论(见本系列首篇文章中“连续时间的损失函数”那章),并且由于以下分布中各维分量相互独立,因此:

以上第2个等式是为了方便展现出 是各向同性的高斯分布。进一步,结合我们一起在前面文章中推导出来的 的表达式, 就有:

再代入 (详见上一篇文章式 (vii) 的推导), 最终得到:

“福利”与“麻烦”

关于 BFN 在离散化数据场景的主要实现 CW 已经基本讲完了,现在我们不妨来思考下离散化有什么好处或不便,想一想离散化操作有没有为网络的训练和性能带来福利或麻烦。

上图的红色曲线代表发送者分布、绿色曲线代表网络估计的连续分布、绿色柱状条代表其经过离散化后的输出分布、蓝色曲线是对应的接收者分布,而橙色虚曲线是在原来网络估计出的连续分布上加噪而形成的接收者分布。

注意,以上图示的数据是已经过离散化的。从中可以看出,发送者分布和在离散化的输出分布基础上形成的接收者分布之间的 KL 散度,比起连续的接收者分布之间的会更小。 这是因为,离散化会将一段范围内的连续值都“集中”至一点,相对没有那么“发散”,在网络训练得当的情况下,会使得接收者分布更“靠近”发送者分布,因为在这里,数据已经过离散化。

由此可以大致推测,与连续数据场景下相比,在离散化的情况下,接收者分布通常更易拟合发送者分布,因为离散化的数据更为“集中”,而连续数据实际上是不可能“定位”至具体某点的。

另一方面,回顾前两章所展示的损失函数形式,我们知道,无论是在连续时间还是离散时间的情况下,在计算 loss 时,每一维都有 的计算复杂度。 并且,在计算输出分布时也是如此。

随着 KKK 的增加,计算复杂度会随之变高,但同时数据的表示精度又会更加精确(区间的长度越小,越接近于连续数据),从这个角度看,又有点博弈的味道~

伪代码

前面都是在吹水,是时候给点真材实料了!这样才能让大家看到实操的可能性。虽然所谓的真材实料在这里其实也不是实打实的,毕竟是“伪”代码,但起码也是能给到你们一些实际感,相对不会那么空洞。

离散化的累积分布函数

CDF 没啥好说的,对照前文解释即可秒懂,就酱~

网络输出

如同前面交待的那样,输出分布在每一维都有 的计算复杂度,最终返回 维向量。

离散时间的训练过程

此处要注意和上一篇文章里连续数据场景的区别,这里需要显式地从发送者分布中采样噪声样本 ,这是由离散化数据场景的发送者分布和接收者分布之间的 KL 散度没有解析形式而引起的,于是需要在发送者分布上进行蒙特卡洛采样来近似估计这个 KL 散度(详见上文中的解释)。

连续时间的训练过程

在连续时间的情况下,这个 loss 是有解析形式的,于是就不需要在发送者分布上采用蒙特卡洛采样来近似计算了(即无需显式地从发送者分布中采样噪声样本 )。

采样生成

注意从离散化输出分布中采样出 这一步, 它是 维向量, 每一维是数据各维分量对应的离散化区间的索引, 通常根据以下方式进行采样:

首先根据输出分布计算出从第一个区间至最后一个的累积概率,然後使用随机数生成算法生成一個 0~1 的随机数,看这个数落在哪个累积概率区间,最终就以对应区间的索引作为采样结果。

另外,如上一篇文章中所述:在推理过程中,由于每一步网络生成的样本都是变化的,因此需要显式地根据当前的生成结果来构造观测样本,从而实现贝叶斯(后验)更新(而不能像训练时那样直接通过贝叶斯流分布实现)。

附:标准正态分布的累积分布函数与误差函数之间的关系推导

在这部分,CW 就和大家一起齐心协力(配合下嘛~)地推导下如何通过误差函数 erf 来计算标准正态分布的 CDF。

首先,明确下误差函数的数学定义:

最后一步是因为被积函数 是偶函数,所以可将积分区间缩半同时加倍积分值。

然后,再展示下标准正态分布的 概率密度函数(pdf),虽然我知道你们都会,但为了故事的连贯性(顺带凑凑字数),请允许我啰嗦一下下:

记 CDF 为 ,现在,我们就直奔终点:

即得到正文中的形式。也就是说, 如果想要知道标准正态分布的累积分布函数在某点 的函数值, 那么就要将 代入误差函数 erf 中, 并利用上式进行计算, 最终就能得到相应结果。

或许你们会好奇到:误差函数 erf 的值又该如何计算呢?毕竟它也没有解析形式。

OK,那么 CW 就顺带将这个问题也展开讲一下(又能愉快地吹水了,耶~!)。

误差函数 erf 的计算方法有多种,常见的有采用 数值方法 来近似计算。在这里,CW 介绍下另一种——使用 泰勒展开式 来近似计算。

的值可以通过在 处的泰勒展开来近似计算:

其中 代表误差函数在 处的 阶导。

先来手动计算下一些低阶导数,看看有没有什么共性:

如果照这么玩下去,你会发现:

于是,我们只需要关注奇数阶导数即可,因为偶数阶导数均为0。现在来

进一步看看泰勒展开式中含奇数阶导数的某一项:

其中第2步是将分子部分的奇数连乘项与分母中阶乘的部分约掉, 第3步是对分子、分母同时除 ( 个 2 相乘), 于是分母中的偶数项连乘项均变为原来的一半, 从而整体就成为 的阶乘。

另外, 可以验证, 当 时, 上式也正好是泰勒展开式中含有一阶导那一项的结果, 因此以上结论对所有奇数阶导数项均满足。

现在, 将上述结果代入到泰勒展开式中, 最终得到:

至于上式具体要展开到多少阶,就取决于你的心情了~

Next

下一篇文章 CW 将会为大家解析 BFN 在离散(discrete)数据场景下的具体实现。注意,是天生的离散数据,而非本文的离散化数据(从连续数据转换而来),期待您再次光临哦!

公众号后台回复“数据集”获取100+深度学习各方向资源整理

极市干货

技术专栏:多模态大模型超详细解读专栏搞懂Tranformer系列ICCV2023论文解读极市直播
极视角动态欢迎高校师生申报极视角2023年教育部产学合作协同育人项目新视野+智慧脑,「无人机+AI」成为道路智能巡检好帮手!
技术综述:四万字详解Neural ODE:用神经网络去刻画非离散的状态变化transformer的细节到底是怎么样的?Transformer 连环18问!

点击阅读原文进入CV社区

收获更多技术干货


【声明】内容源于网络
0
0
极市平台
为计算机视觉开发者提供全流程算法开发训练平台,以及大咖技术分享、社区交流、竞赛实践等丰富的内容与服务。
内容 8155
粉丝 0
极市平台 为计算机视觉开发者提供全流程算法开发训练平台,以及大咖技术分享、社区交流、竞赛实践等丰富的内容与服务。
总阅读7.6k
粉丝0
内容8.2k