极市导读
来看看S-DSB到底是个什么东西! >>加入极市CV技术交流群,走在计算机视觉的最前沿
邂逅
前段时间的某个早上,在“晨例”喝手冲(咖啡)的时候不小心看了场介绍 “简化薛定谔扩散桥(Simplified-Diffusion Schrödinger Bridge)” 的直播,是 paper(https//ar5iv.labs.arxiv.org/html/2403.14623) 原作的技术分享,感觉挺有意思的,而且近来这种基于扩散薛定谔桥(DSB)的框架貌似蠢蠢欲动想接替扩散模型(Score-based Generative Model)成为视觉生成界的新宠。于是在这种契机之下,CW 就没忍住把 S-DSB 的 paper 和 源码实现都看完了..
你们也知道 CW 的个性,看完了不无聊的 paper 通常都忍不住要吹水一番;而对于开源的 paper 也忍不住要亲手撸一遍,总忍不住把好玩的东西都内化到自己身上留下烙印。唉~坏毛病——又得费精力码字了。
简化扩散薛定谔桥(S-DSB)是针对扩散薛定谔桥(DSB)(https//arxiv.org/abs/2303.16852)收敛慢和训练难的问题进行改进而提出的,主要的改进措施包括:
-
简化了 DSB 的 loss 函数 -
将 SGM 适配到 DSB 框架 -
提出了两种重参数化方式
话又说来,先不论 S-DSB 是怎么回事,那所谓的什么 DSB 和薛定谔的猫到底是什么关系呀?它和 SGM 相比又是怎样的一种玩法呢?
嗯,好问题!DSB(扩散薛定谔桥) 中的“薛定谔”,确实是薛定谔的猫的主人,大名鼎鼎的量(浪)子物理学家——薛定谔(https//baike.baidu.com/item/%25e5%259f%2583%25e5%25b0%2594%25e6%25b8%25a9%25c2%25b7%25e8%2596%259b%25e5%25ae%259a%25e8%25b0%2594/2124805)。大家也知道,他是搞量子力学的,这么一来 DSB 就像是个“跨界er”—— 由量子力学跨界来搞 CV 生成。这..是要来一波降维打击和砸场子?
同时,在当下,基于 score 的扩散生成模型 —— SGM 可谓中流砥柱,但这在一定程度上也得益于大家都很给面子。
毕竟嘛.. 它的毛病显而易见:
-
先验被限制为简单的有解析形式的分布(如高斯分布)而不能是任意分布 -
只能单向生成(由先验向目标数据分布),而不能在任意两个分布之间建立双向映射
然而,DSB 在先天上就不存在这些毛病,这很降维打击!并且它的目标是计算任意两个可采样分布之间的最优传输(Optimal Transport,OT)(https//www.damtp.cam.ac.uk/research/cia/files/teaching/Optimal_Transport_Notes.pdf),有 OT 理论的加持。注意是“任意”分布!并没有将分布限制为某种形式或要求具有解析式。一旦完成训练,它就可以完成双向生成——将 A 分布的数据喂给模型,就能采样生成 B 分布的数据,反之亦然,并且 A, B 可以是任意分布,这天花板可高多了!
由此也可以发现,DSB 天生就是条件生成的优秀选手,比如在训练时给的两个分布是猫和狗的图片,那么训练完毕后它自然就可以根据一张猫的图片,生成一张“与这只猫很像的狗”的图片(反之亦然),不需要额外为“根据猫来生成狗”这个条件而进行设计。或者说,在 DSB 这就不存在 conditional generation,因为它天生就是 conditional 基因。
WOW!! 这简直就是不无聊的风格,很符合 CW 的口味!
不多说了,接下来就切换到稍微正经的模式来进入到本文的世界中吧~
【注】
本文会反复使用一些简称,主要有:
SB —— 薛定谔桥
DSB —— 扩散薛定谔桥
S-DSB —— 简化扩散薛定谔桥
SGM —— 泛指扩散模型
S-DSB 诞生的源动力
一个新事物的诞生通常源于当下时代背景的一些因素,但因素众多,不可能逐一分析,于是作者选择抓“主要矛盾”——精准定位到 SGM 和 DSB 这俩家伙身上,然后仔细挑遍了它俩的毛病,经过筛选后,最终列出的“重大症状”如下:
-
SGM
-
针对不同任务需要很费心思去设计加噪方案(noise schedule) -
先验(prior)被限制为一种预定义的简单分布(e.g. Gaussian)
-
DSB
-
模型收敛慢、训练难、拟合能力不足 -
对空间和计算资源的要求高,在训练过程中,网络学习的 target 需要经历两次独立的 forward()才能获得
另外,DSB 没有利用好 SGM 当前的发展成果,相当于有很好的资源却浪费了。
是病就得治,为了让打造健康和谐的生成模型生态,作者提出了一种 DSB 的改进版本,它就是本文的主角—— S-DSB(Simplified-DSB),这个新玩法对原来 DSB 的 loss 函数进行了简化,并且将 SGM 拿到 DSB 框架中去玩(bridge the gap between DSB & SGM),同时还可以利用预训练的 SGM 做初始化,从而加快收敛。这种方法不仅解决了以上 DSB 的问题,还进一步提升了原有 SGM 的性能。
此外,作者还提出了两种重参数化技巧(reparameterization trick)应用到 DSB 的输出空间,从而对应到不同的训练目标,进一步提升了模型的收敛速度,并且肉眼可见比 DSB 生成的效果要好。
这还不止,作者设计实验进一步证明,即使在没有使用预训练 SGM 的情况下,S-DSB 也能取得与 DSB 不相上下的性能。
其实作者大大你是不是想说:我 S-DSB 完爆你 DSB!
【配图】
前情回顾
在主角 S-DSB 正式登场之前,先来回顾下历史情节—— SGM 和 SB&DSB 的玩法,以更好地承接剧本。
SGM 的玩法
SGM 通过两个互逆的过程(dual process)将两个分布 “连接” 了起来:前向过程(forward process) 将数据分布 转换为先验分布 ,而逆向过程则相反。这两个过程都建模为马尔科夫链,在离散时间步的情况下,联合分布可表示为:
但是,直接根据上面的贝叶斯定理去计算逆向过程的 通常是 intractable 的,毕竟 都不知道。因此 将前向过程建模为不断施加高斯噪声的马尔科夫链,使得:
其中, 起到了 noise scale 的作用。然后,再通过贝叶斯定理,就能得到逆向过程的 也近似等同于高斯分布。不 show 明白你们肯定不服,好吧,具体推导如下:
然后对 在 处进行1阶泰勒展开:
将上式代入到前面的式子中,得到:
同理,这种玩法可以拓展至连续时间的情况,从而前向过程就对应到以下 SDE:
其中 代表布朗运动, 是漂移(drift)项, 是扩散(diffusion)项。这样,如 (i) 式的 transition kernel 实质上就是对上式使用欧拉-丸山(Euler-Maruyama)法进行数值离散的结果。
类似地,逆向过程则对应到以下 reverse-time SDE:
注意,上式中 也是正的,而在通常所见的 reverse-time SDE 中, 是一个负值的微元。
SB 和 DSB 的玩法
既然 DSB 与 SGM “同病相怜”,那么也同样先来回顾下它的玩法。然而考虑到 DSB 是在 SB 基础上发展而来的,所以干脆就先从 SB 开始下手叭~
-
SB
薛定谔桥(Schrödinger Bridge) 本是数学和物理学中的一个概念,起源于量子力学中的薛定谔方程,同时也多应用在最优运输问题中。它本质上是一种条件随机过程,通常指在给定起点和终点分布的条件下,对某种随机过程(e.g. 布朗运动)的变换。
恰好 SGM 玩的就是典型的起点和终点分布问题——前向过程是复杂的数据分布到简单的高斯分布,逆向过程则相反。于是,把 SB 引入 SGM 的游戏中就成了很顺手的事情。
薛定谔桥问题是在一定约束条件(i.e. 给定起点和终点)下的最优传输问题,其寻求在两点所有可能的轨迹中,使得 cost 最小的那一条。若将上一章介绍的 作为目标轨迹,将其记为 “参考分布" 则 的目标就是要找到一个分布 , 满足:
待找到这样的 之后,就可以使用祖先采样(ancestral sampling)由 开始,不断根据 来采样得到 ;同理,也可以反过来由 开始,不断根据 来采样得出 。
SB 是求解一个可逆的双向映射,相比起 SGM,它没有将其中一个分布预设为已知解析形式的简单分布(SGM 的做法通常是设置 )。也就是说,SB 的目标是实现任意两个分布之间的转换。
在大多数情况下,SB 的目标都不好求(且通常没有解析形式),于是为了降低求解难度,经常会用到一种叫作 "Iterative Proportional Fitting(IPF)"(https//www.sciencedirect.com/science/article/abs/pii/016771529390257J%3Fvia%253Dihub) 的招数,它本质上属于一种弱化了边界条件的迭代优化方法:
其中 代表迭代的轮数、 。可以看出,在每一轮优化中,仅有 或 这其中一个边界条件,而不像原始 做法那样需要同时满足两者。以这种方式,当每一轮都取得最优解并且迭代轮数足够大时,最终将得到 的目标解 。
然而,在每轮优化中,这种方法仍需要计算和优化联合概率分布 ,在面对高维数据时对计算和内存资源就会有很高的要求。
-
DSB
联想到扩散模型 经常 “玩弄” 和 ,也就是在相邻时间步之间的条件分布,于是 DSB 就汲取了这个灵感来近似 IPF 的做法, 它将原本在每轮中优化联合概率分布拆解为优化一系列相邻时间步之间的条件分布,从而将 在前向过程(forward process)中解构为一系列 ;在逆向过程(backward process)中则解构为一系列 ,然后在训练过程中交替地去优化 forward 和 backward 两条轨迹:
如上,在奇数轮(epoch) 优化的是 ,也就是 所对应的 backward 轨迹,得到 ;而在偶数轮 优化的则是 ,对应 的 forward 轨迹,得到 。每轮在优化时都最小化与上一轮所得结果的 KL 散度。
既然都向 SGM 汲取了灵感,不如将它的建模方法也借(白)鉴(嫖)过来好了。
于是,DSB 就将 和 都建模为高斯分布,这样两者都拥有解析形式,方便求解和计算。同时,分别使用两个模型去建模 forward 和 backward 轨迹。经过一系列花哨的数学推导 (细节请参考 DSB 的 paper),最终得到 loss 函数如下:
在实现时,如果当前处于第 个 epoch,且 为奇数,则优化的是 backward 轨迹,建模 backward 过程的模型就会被训练来去估计 ,也就是 backward 高斯分布 的均值;类似地,若 为偶数,则优化的是 forward 轨迹,建模 forward 过程的模型会被训练为去估计 ,对应 forward 高斯分布 的均值。
师出同门:各生成模型实质是扩散薛定谔桥的特例
在 S-DSB 正式亮相之前还有个彩蛋——作者揭出目前市面上流行的一些生成模型如 SGM, FM(Flow Matching)(https//arxiv.org/abs/2210.02747) 等都可看作扩散薛定谔桥的特例。这是因为,之前已经有大佬的 paper(https//arxiv.org/abs/1308.0215) 指出薛定谔桥的最优解实际上遵循以下形式的 SDE:
更为重要的是, 在 时刻的分布 就是:
通过与前面所介绍的 SGM 的 SDE 形式相对比,容易看出 SGM 实际上就是式(11)在 是常数(从而 ) 且 时的特例。根据 在 时刻的分布(即上式),这同时有:
将其代入到上图式(11)中的第二个式子则正好就是我们所熟悉的 SGM 的 reverse-time SDE(这里的 reverse-time SDE 中 为负,而前面章节所展示的是 为正的形式)!
进一步来说,只要设置不同形式的 和 ,就能够对应到不同生成模型的玩法。比如在 SDE 统一扩散模型的那篇 paper 中所提出的 Variance Preserving (VP) 和 Variance Exploding(VE) 的加噪方案(noise schedule),就是将漂移项 设成 ,其中 是非负实数。
主角登场:S-DSB
主角 S-DSB 登场了!请大家倾耳拭目,一起来看看 S-DSB 是怎么为它前辈 DSB 做手术治病的。
手术一:简化 loss 函数
S-DSB 的命名由来主要是源于它对 DSB 的 loss 函数进行了简化,因此拿了 "simplified" 的头衔。具体地,它将 DSB 的 loss 函数简化为以下形式:
在 和 的假设条件下,S-DSB 的这个简化版 loss 函数可近似等于原来 DSB 的 loss 函数,比如对于 backward model 的 target,容易证明当 时 , 即两个 loss 函数的 target 近似相等:
那么,这个简化版的 loss 有何优势或者说能够带来什么好处呢?当然是有的,不然作者也不敢写这篇 paper 来讲故事。
概括起来,S-DSB 的简化版 loss 主要有以下两个好处:
-
直观来看,这种 loss 形式就是要网络去学会预测 “下一个状态” —对于 backward model 来说,它要学会基于 去预测 ,其轨迹方向是 ,在当前是 的情况下,下一个状态就是 ;而对于 forward model 来说,它则要学会基于 去预测 。这种 loss 形式既直观又易于理解,并且与 SGM 实现了 “对齐” (SGM 所用的 loss 通常也是这种形式)。 -
更为重要的是,相比于原来 DSB 的 loss,这种简化版 loss 在计算网络学习的 target 时不必运行两次模型的 forward()过程,从而节省了计算和内存资源。比如对于 backward model 来说, DSB 的 loss 需要运行 ,这是2次模型的forward(),而在这里网络的 target 就直接是 。另外,模型运行forward()的次数也被成为 NFE-Number of Forward Evaluations.
手术二:让预训练 SGM 助攻一波
虽然 S-DSB 对 loss 函数进行了简化,但它的训练方式还是处于 IPF 框架之下的。也就是说,它需要交替训练两个 model,在训练其中一个时,其学习的 target 轨迹实际上由另一个 model(固定住权重) 给出,因此每一轮的训练效果实质上会依赖于上一轮的收敛性!作者在 paper 的 3.2 节也给出了一小段数学解释,感兴趣可以去瞄瞄,挺直观易懂的~
另外,通过对原有 DSB 的 loss 函数进行简化,S-DSB 的 loss 已经与 SGM 实现了对齐,这就代表,S-DSB 的训练目标与 SGM 其实是一致的。
炼过丹的都知道,当遇到训练比较难收敛时,如果有预训练的权重,那么想都不用想当然是利(白)用(嫖)起来啊!
目前我们的 S-DSB 正是面临这种处境,刚好市面上又大把已经训好的 SGM,而 S-DSB 的训练 target 又与 SGM 一致,这..八字也太合了叭,你要不加载预训练的 SGM 权重就简直是违反天时地利人和!
正是基于这种先天条件与后天境遇的结合,作者决定让 pre-trained SGM 来助攻一波,将参考分布 根据 pre-trained SGM 的 noise schedule 进行设置就可以使用它来初始化 backward model; 同理,forward model 也是可以用 pre-trained SGM 来进行初始化的。两个模型可以选择不同的预训练权重,也可以选择一样的,至于效果嘛..还得看具体场景,毕竟是炼丹=玄学
手术三:重参数化(Reparameterization)
除了简化 loss + 白嫖预训练 SGM 以外,S-DSB 还有招,那就是 SGM 的网红招数——重参数化。又白嫖..?
这招通常能够降低模型的训练难度,因为模型学习的 target 通常是随时间动态变化的,这又源于整个过程本质上服从 SDE。于是,让模型直接去预测这种动态变化的 target,通常比较困难,从而需要更长时间才能收敛。经过重参数化后,通常能够使模型在不同 timestep 都拥有一致的 target。
作者大大当然也晓得这种网红技巧带来的好处,而且如果能用上这么 in 的 trick,多少也能让 paper “漂亮”几分。
于是,其果断决定在 S-DSB 中引入重参数化技巧,而且还给重参数化后的 S-DSB 起了个新名字 Neparameterized DSB -DSB)”。OK,决心是下了,但实际该如何行动呢?
要使用哪种重参数化方式,得考虑决定因素有哪些。如果在 backward 或 forward 轨迹中,始终有一致性的状态(比如它们的终点 或 ) 决定着整条轨迹的变化,那么可以考虑让模型学会预测终点就好了,这样在每个 timestep,模型的 target 都是一致的。
顺着这种 “幻想",作者还真捣鼓出了不得了的东西:
(证明过程请参考论文附录 B.1~B.2,写得非常详细,而且没有过份跳跃,顺着看下来能够明明白白,赞一个!)
以上提出的这个命题说明:在学习 forward 轨迹时,模型需要预测的下一状态 其实是由当前状态 和 forward 轨迹的终点 所决定的; 类似地,在学习 backward 轨迹时,模型需要预测的 是由前一个状态 和 backward 轨迹的终点 所决定的。
嗯哼? 这说明什么,是不是有点 feel 了
-
Terminal Reparameterized DSB(TR-DSB)
不卖关子了(相信聪明的你们也想到了),上面那个命题其实已经指明了道路——让模型学会去预测终点,因为无论在哪条轨迹上(forward or backward),模型所要预测的下一个状态都是由当前状态和终点所共同决定,而当前状态是已经拿到手的,所以只要模型能够预测出终点值,那么就可以由当前状态和模型所估计的终点值来计算出下一个状态。不断迭代这个过程最终就能抵达轨迹的终点,从而完成整个采样过程,这本质上属于祖先采样的方式。
OK,既然重参数化后,模型预测的目标政变了,那么就得修改原来 S-DSB 的 loss 函数——将 backward model 的 target 改为 ;同时将 forward model 的 target 改为 :
这种重参数化的方式是基于终点(terminal)的重参数化,于是在这种模式下的 S-DSB 就被称作 "Terminal Reparameterized DSB(TR-DSB)"。
待训练完毕后,对于 backward model,我们就可以将模型的输出看作是 ,代入到命题 3 的 (19)式中的第一式从而完成 的采样过程; 同理,对于 forward model,则将其输出看作是 代入到命题 3 的(19)式中的第二式来完成 的采样生成。
-
Flow Reparameterized DSB(FR-DSB)
然而,基于命题 3 的重参数化方式也并非仅有预测终点这一种,仔细观察(19)式,不难发现其实还可以令 backward model 去预测 ; 同时 forward model 则去预测 ,对应到以下 loss 函数:
这种预测目标并非是强行瞎搞的,而是有直观意义的:对于 backward 轨迹 来说, 代表由当前状态指向终点的向量,而 则代表在 forward 的加噪过程中, 的噪声尺度总和,相当于变化了多少,我们可以把它当作 与 的 “距离” ;同理,对于 forward 轨迹来说, 代表由当前状态指向终点的向量, 则同样看作

