极市导读
本文通过一个叫“reflow”的方法,实现梦想中的“一步生成”:只需一步计算就直接产生高质量的结果,而不需要调用计算量大的数值求解器来迭代式地模拟整个扩散过程。 >>加入极市CV技术交流群,走在计算机视觉的最前沿
Diffusion Generative Models(扩散式生成模型)已经在各种生成式建模任务中大放异彩,但是,其复杂的数学推导却常常让大家望而却步,缓慢的生成速度也极大地阻碍了研究的快速迭代和高效部署。研究过DDPM的同学可能见到过这种画风的变分法(Variational Inference)推导(截取自What are Diffusion Modelshttps://lilianweng.github.io/posts/2021-07-11-diffusion-models/):
总体上推导的难度和对数学的要求还是比较高的。在连续时间的形式下,还需要随机微分方程(Stochastic Differential Equation (SDE))的知识,有不低的入门门槛。除此以外,扩散式生成模型的一个众所周知的老大难问题就是_生成速度慢_:生成一张图需要模拟一整个基于复杂的深度模型的扩散过程。缓慢的生成速度是阻碍这些模型更广泛的普及的一个主要瓶颈。
Rectified Flow,一个"简简单单走直线“生成模型,是我们对这些挑战的一个回答:极度简单,一步生成。我们的方法有以下要点:
(1)我们无需一般扩散模型复杂的推导,代之以一个简单的 “沿直线生成” 的思想。算法理解上不需要变分法或随机微分方程等基础知识。我们的方法是基于一个简单的常微分方程(ODE),通过构造一个“尽量走直线”的连续运动系统来产生想要的数据分布。
(2)“尽量走直线”的目的是让我们模型实现快速生成。通过一个叫“reflow”的方法,我们可以实现梦想中的“一步生成”:只需一步计算就直接产生高质量的结果,而不需要调用计算量大的数值求解器来迭代式地模拟整个扩散过程。
(3)通常的扩散模型是把高斯白噪声转换成想要的数据(比如图片)。我们的方法可以把任何一种数据或噪声(比如猫脸照片)转换成另外一种数据(比如人脸照片)。所以我们的方法不仅可以做生成模型,还可以应用于很多更广泛的迁移学习 (比如domain transfer)任务上。
有兴趣的同学可以参见我们的论文 (Arxiv 或 OpenReview,以及和最优传输(optimal transport)相关的深入理论Arxiv)。代码,示例Colab Notebook和预训练模型已经开源在github。一个英文版简介在这里。欢迎大家使用和交流!
arxiv:https://arxiv.org/abs/2209.03003
OpenReview:https://openreview.net/forum?id=XVjTT1nw5z
深入理论Arxiv:https://arxiv.org/abs/2209.14577
github:https://github.com/gnobitab/RectifiedFlow
英文版简介:https://www.cs.utexas.edu/~lqiang/rectflow/html/intro.html
问题-传输映射(将一个分布搬运到另一个分布)
我们先定义好要解决的问题。无论是从噪声生成图片(generative modeling),还是将人脸转化为猫脸 (domain transfer),都可以这样概括成将一个分布转化成另一个分布的问题:
给定从两个分布 和 中的采样, 我们希望找到一个传输映射 使得, 当 时,
比如, 在生成模型里, 是高斯噪声分布, 是数据的分布(比如图片), 我们想找到一个方法, 把噪声 映射成一个服从 的数据 。在数据迁移 (domain transfer) 里, 分别是人脸和猫脸的图片。所以这个问题是生成模型和数据迁移的统一表述。
在我们的框架下, 映射 是通过以下连续运动系统, 也就是一个常微分方程(ordinary differential equation (ODE)), 或者叫流模型(flow), 来隐式定义的:

我们可以想象从 里采样出来的 是一个粒子。它从 时刻开始连续运动, 在 时刻以 为速度。直到 时刻得到 。我们希望 服从分布 。这里我们假设 是一个神经网络。我们的任务是从数据里学习出 来达到 的目的。
走直线,走得快
除了希望 ,我们还希望这个连续运动系统能够在计算机里快速地模拟出来。注意到, 在实际计算过程中, 上面的连续系统通常是用Euler法(或其变种)在离散化的时间上近似:
这里 是一个步长参数。我们需要适当的选择 来平衡速度和精度: 需要足够小来保证近似的精 度, 但同时小的 意味着我们从 到 要跑很多步, 速度就慢。
那么问题来了, 什么样的系统能最快地用Euler法来模拟呢? 也就是说, 什么样的体系能允许我们在用较大的步长 的同时还能得到很好的精度呢?
答案是“走直线"。如下图所示, 如果粒子的运动轨迹是弯曲的, 我们需要很细的离散化来得到很好的结果。如果粒子的轨迹是直线, 那么即使我们取最大的步长 , 只用一步走到 时刻,还是能得到正确的结果! 所以, 我们希望我们学习出来的速度模型 既能保证 , 又能给出尽量直的轨迹。怎么同时实现这两个目的在数学上是一个非常不简单(non-trivial)的问题, 涉及最优传输 (optimal transport) 的一些深刻理论。但是我们发现其实可以用一个非常简单的方法来解决这个问题。
Rectified Flow-基于直线ODE学习生成模型
假设我们有从两个分布中的采样 (比如 是从 里出来的随机噪声, 是一个随机的数据(服从 )。我们把 和 用一个线性插值连接起来, 得到
这里 和 是随机, 或者说, 以任意方式配对的。你也许觉得 和 应该用一种有意义的方式配对好, 这样能够得到更好的效果。我们先忽略这个问题, 待会回来解决它。
现在, 如果我们拿 对时间 求导, 我们其实已经可以得到一个能够将数据从 传输到 的"ODE"了,
但是, 这个"ODE"并不实用而且很奇怪, 所以要打个引号: 它不是一个“因果"(causal), 或者“可前向模拟"(forward simulatable)的系统, 因为要计算 在 时刻的速度 需要提前(在 时)知道ODE轨迹的终点 。如果我们都已经知道 了, 那其实也就没有必要模拟ODE了。
那么我们能不能学习 , 使得我们想要的“可前向模拟”的ODE 能尽可能逼近刚才这个"不可前向模拟"的过程呢? 最简单的方法就是优化 来最小化这两个系统的速度函数(分别是 和 之间的平方误差:

这是一个标准的优化任务。我们可以将 设置成一个神经网络, 并用随机梯度下降或者Adam来优化,进而得到我们的可模拟ODE模型。
这就是我们的基本方法。数学上, 我们可以证明这样学出来的 确实可以保证生成想要的分布 。对数学感兴趣的同学可以看一看论文里的理论推导。下面我们只用这个图来给一些直观的解释。

图(a): 在我们用直线连接 和 时, 有些线会在中间的地方相交, 这是导致 非因果的原因(在交叉点, 既可以沿蓝线走, 也可以沿绿线走, 因此粒子不知该向岔路的哪边走)。
图(b) :我们学习出的ODE因为必须是因果的,所以不能出现道路相交的情况,它会在原来相交的地方把道路交换成不交叉的形式。这样,我们学习出来的ODE仍然保留了原来的基本路径,但是做了一个重组来避免相交的情况。这样的结果是,图(a)和图(b)里的系统在每个时刻 ttt的边际分布是一样的,即使总体的路径不一样。
我们的方法起名为Rectified Flow。这里rectified是“拉直”,“规整”的意思。我们这个框架其实也可以用来推导和解释其他的扩散模型(如DDPM)。我们论文里有详细说明,这里就不赘述了。我们现在的算法版本应该是在已知的算法空间里最简单的选项了。我们提供了Colab Notebook(https://colab.research.google.com/drive/1CyUP5xbA3pjH55HDWOA8vRgk2EEyEl_P?usp=sharing)来帮助大家通过实践来理解这个过程。
Reflow-拉直轨迹,一步生成
因为Rectified Flow要在直线轨迹的交叉点做路径重组,所以上面的ODE模型(或者说flow)的轨迹仍然可能是弯曲的 (如上面的图(b)),不能达到一步生成。我们提出一个“_Reflow”方法_,将ODE的轨迹进一步变直。
具体的做法非常简单: 假设我们从 里采样出一批 。然后, 从 出发, 我们模拟上面学出的flow(叫它1-Rectified Flow), 得到 。我们用这样得到的 对来学一个新的"2-Rectified Flow":

这里, 2-Rectified Flow和1-Rectified Flow在训练过程中唯一的区别就是数据配对不同: 在1Rectified Flow中, 与 是随机或者任意配对的; 在2-Rectified Flow中, 与 是通过1Rectified Flow配对的。上面的动图中, 图(c) 展示了Reflow的效果。因为从1-Rectified Flow里出来的 已经有很好的配对, 他们的直线插值交叉数减少, 所以2-Rectified Flow的轨迹也就(比起1-Rectified Flow)变得很直了(虽然仔细看还不完美)。理论上, 我们可以重复Reflow多次, 从而得到3-Rectified Flow, 4-Rectified Flow... 我们可以证明这个过程其实是在单调地减小最优传输理论中的传输代价(transport cost), 而且最终收玫到完全直的状态。当然, 实际中, 因为每次 优化得不完美, 多次Reflow会积累误差, 所以我们不建议做太多次的Reflow。幸运的是, 在我们的实验中, 我们发现对生成图片和很多我们感兴趣的问题而言, 像上面的图(c)一样, 1次Reflow已经可以得到非常直的轨迹了, 配合蒸馏足够达到一步生成的效果了。
Reflow与Distillation
给定一个配对 , 要想实现一步生成, 也就是 , 我们好像也可以通过优化下面的平方误差来直接"蒸馏(distillation)"出一个一步模型:

这个目标函数和上面的Reflow的目标函数很像, 只是把所有的时间 都设成 了。
尽管如此, Distillation和Reflow是有本质的区别的。Distillation试图一五一十地复现 配对的关系。但是, 如果 的配对是随机的, Distillation最多只能得到 在给定 时的条件平均, 也就是 , 并不能成功地完全匹配 。即使 有确定的一一对应关系, 他们的配对关系也可能很复杂, 导致直接蒸馏很困难。
Reflow解决了Distillation的这些困难。它的意义在于 :
-
给定任何 配对, 就算是随机的配对, 他都能学出一个给出正确边际分布(marginal distribution)的flow。Reflow不会去试图完全复现 的配对关系, 而只注重于得到正确的边际分布。 -
从Reflow出的ODE里采样, 我们还可以得到一个更好的配对 , 从而给出更好的 flow。重复这个过程可以最终得到保证一步生成的直线ODE。
形象地来讲, 如果 太复杂, Reflow会“拒绝”完全复现 , 转而给出一个新的, 更简单的, 但仍然满足 的配对 。所以, Distillation更像“模仿者”, 只会机械地模仿, 就算问题无解也要“硬做”。Reflow更像"创造者", 懂得变通, 发现新方法来解决问题。
当然,Reflow和Distillation也可以组合使用:先用Reflow得到比较好的配对,最后再用已经很好的配对进行Distillation 。我们在论文里发现,这个结合的策略确实有用。
下面, 我们进一步基于具体例子解释一下Reflow对配对的提高效果。如果一个配对 是好的, 那么从这个配对里随机产生的两条直线 就不会相交。在我们的论文里,这种直线不相交的配对我们叫做"Straight Coupling"。我们的Reflow过程就是在不停地降低这个相交概率的过程。下图我们展示随着Reflow的不断进行,配对的直线交叉数确实逐渐降低。在图中,对每种配对方法,我们随机选择两个配对,分别用直线段连接它们,然后若它们相交,就用红色点标出这两条直线段的交点。对于这种交叉的配对,Reflow就有可能改善它们。我们重复10000次并统计交叉的概率。我们发现:(1)每次Reflow都降低了交叉的概率和L2传输代价(2)即使2-Rectified Flow在肉眼观察时已经很直,但它的交叉概率仍不为0,更多的Reflow次数就可能进一步使它变直并降低传输代价。相比之下,单纯的蒸馏是不能改善配对的,这是Reflow与蒸馏的本质区别。
理论保证
Rectified Flow不仅简洁,而且在理论上也有很好的性质。我们在此给出一些理论保证的非正式表述,如果大家对理论部分感兴趣,欢迎大家阅读我们文章的细节。
-
边际分布不变: 当 取得最优值时, 对任意时间 , 我们有 和 的分布相等。因为 , 因此 确实可以将 转移到 。 -
降低传输损失: 每次Reflow都可以降低两个分布之间的传输代价。特别的, Reflow并不优化一个特定的损失函数, 而是同时优化所有的凸损失函数。 -
拉直ODE轨迹: 通过不停重复Reflow, ODE轨迹的直线性(Straightness)以 的速率下降, 这里, 是reflow的次数。
实验结果-Rectified Flow能做到什么?
-
使用Runge Kutta-45求解器,1-Rectified Flow在CIFAR10上得到IS=9.6, FID=2.58,recall=0.57,基本与之前的VP SDE/sub-VP SDE[2]相同,但是平均只需要127步进行模拟。 -
Reflow可以使ODE 轨迹变直,因此2-Rectified Flow和3-Rectified Flow在仅用一步(N=1) 时也可以有效的生成图片(FID=12.21/8.15)。 -
Reflow可以降低传输损失,因此在进行蒸馏时会得到更好的表现。用2-Rectified Flow+蒸馏,我们在仅用一步生成时得到了FID=4.85,远超之前最好的仅基于蒸馏/基于GAN loss的快速扩散式生成模型(当用一步采样时FID=8.91) 。同时,比起GAN,Rectified Flow+蒸馏有更好的多样性(recall>0.5)。
我们的方法也可以用于高清图片生成或无监督图像转换。
同期相关工作
有意思的是,今年ICLR在openreview上出现了好几篇投稿论文提出了类似的想法。
(1) Flow Matching for Generative Modeling
(2) Building Normalizing Flows with Stochastic Interpolants
(3) Iterative -alpha (de)Blending: Learning a Deterministic Mapping Between Arbitrary Densities
(4) Action Matching: A Variational Method for Learning Stochastic Dynamics from Samples
这些工作都或多或少地提出了用拟合插值过程来构建生成式ODE模型的方法。除此之外,我们的工作还阐明了这个路径相交重组的直观解释和最优传输的内在联系,提出了Reflow算法,实现了一步生成,建立了比较完善的理论基础。大家不约而同地在一个地方发力,说明这个方法的出现是有很大的必然性的。因为它的简单形式和很好的效果,相信以后有很大的潜力。
如有任何问题,欢迎留言或者发邮件!
主要论文:
X. Liu, C. Gong, Q. Liu. Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow. ICLR2023, arXiv:2209.03003
Q. Liu. Rectified flow: A marginal preserving approach to optimal transport. arXiv preprint arXiv:2209.14577, 2022.
参考文献:
[1] Song Y, Sohl-Dickstein J, Kingma D P, et al. Score-Based Generative Modeling through Stochastic Differential Equations. International Conference on Learning Representations.
[2] Ho J, Jain A, Abbeel P. Denoising diffusion probabilistic models. Advances in Neural Information Processing Systems, 2020, 33: 6840-6851.
[3] Song J, Meng C, Ermon S. Denoising Diffusion Implicit Models. International Conference on Learning Representations.
[4] Lu C, Zhou Y, Bao F, et al. DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model Sampling in Around 10 Steps. Advances in Neural Information Processing Systems.
[5] Bansal A, Borgnia E, Chu H M, et al. Cold diffusion: Inverting arbitrary image transforms without noise. arXiv preprint arXiv:2208.09392, 2022.
[6] Liu X, Wu L, Ye M. Learning Diffusion Bridges on Constrained Domains//International Conference on Learning Representations.
[7] Liu Q. Rectified flow: A marginal preserving approach to optimal transport. arXiv preprint arXiv:2209.14577, 2022.
公众号后台回复“极市直播”获取100+期极市技术直播回放+PPT


