大数跨境
0
0

预训练无条件扩散生成模型的 training-free 条件生成食谱(一)

预训练无条件扩散生成模型的 training-free 条件生成食谱(一) 极市平台
2024-08-22
0
↑ 点击蓝字 关注极市平台
作者丨CW不要無聊的風格
编辑丨极市平台

极市导读

 

如何使用预训练的无条件扩散生成模型来进行条件生成的指南 >>加入极市CV技术交流群,走在计算机视觉的最前沿

前言

对于扩散模型来说,可控(条件)生成和采样加速是两大热门方向,其中 CW 认为前者的可玩性和观赏性会更胜一筹,特别是研究如何利用已经训练好的无条件扩散模型来生成自己想要的图像(而不是随机生成),这点个人觉得非常好玩。于是,在调研了一些广为流传(maybe 好用)的方法之后,CW 决定把它们制成一份食谱,以便日后做(炼)菜(丹)时能够从中借鉴和调配。

既然要为日后重新调配做打算,那么就会涉及代码实现,但食谱中仅收录每种方法里关键操作的代码。另外,本食谱对于每种方法提供偏向直觉性的解释,这是考虑到厨师做菜通常不会纠结分子和原子的化学变化。对于严谨的推导和证明过程,还请各位食客自行参考论文。

注意,这份食谱所收录的方法都是 training-free 的,也就是不会对模型进行二次训练,也可称作 "sampling-based",因为它们仅对模型的采样过程进行改造和调整。并且所用的扩散模型是无条件的,也就是没有不会有条件项作为模型输入。只要手头上有预训练的无条件扩散模型,在给出条件信息(比如参考图像、文本 prompt 等)的情况下,就能够实现可控生成。由于这种方法并没有将条件信息作为输入来进行训练,因此通常也被叫作是 "zero-shot"无监督的。

题外话:本文的标题有点 "Engnese" 的味道,因为其实它来自于我脑子里的 "The training-free recipe for conditional generation with pretrained unconditional diffusion models",但纠结一番后还是想接地气一些,于是成了如今这个样子。

keywords: conditional generation, training-free, zero-shot, unsupervised, diffusion models, inverse problems, image inpainting, image restoration, image editing

图生图

这一章介绍的方法都是在给定一张参考图像的条件下进行采样生成。 在这种情况下,通常是希望模型能够生成与参考图像相似的图片,或者对参考图像进行修复(填充)和编辑等。

SDEdit

SDEdit: Guided Image Synthesis and Editing with Stochastic Differential Equations 可谓是出身名门,不信你可以去看看这篇 paper 的作者们,每一位都是不得了的大佬!

但这篇论文的思路却出奇简单,或许这就是所谓的大道至简吧~

SDEdit 巧妙地利用了扩散模型采样过程所对应的 reverse SDE 的性质——采样起点可以从中间任意时刻开始

the reverse SDE can be solved not only from , but also from any intermediate time

具体地, 先对参考图像(论文中称作 "guide")进行加噪, 噪声强度根据扩散模型的 noise schedule 选择某个时刻 所对应的噪声强度; 然后从该时刻开始根据 reverse SDE 对加噪后的参考图像进行迭代去噪。当采样至最后时刻 , 若感觉效果不满意, 还可以将生成的结果重新加噪回 时刻, 然后再次实施同样的迭代去噪过程, 这样的玩法可以反复多次直至得到满意的效果为止。

对于一些图像编辑和修复的场景,可能我们只想要改变参考图像的某部分,让这部分就用模型采样生成的结果来代替,而另一些部分则希望保持原封不动。对于这种需求,可以在前面的基础上额外利用 mask 来指示参考图像的哪些位置是需要保持不变的,而哪些是希望用模型的生成结果来代替的(比如在 mask 中值为1的部分代表将要用模型的生成结果来覆盖)。

  • SDEdit 做法的背后隐藏着什么

CW 觉得上面这幅图很好地阐释了 SDEdit 的原理:真实的图像(扩散模型学到的数据分布)和输入的参考图(上图是油画)本来隶属于不同的数据流形,经过噪声扰动后,两个流形产生了交集(扰动使得流形扩大了范围)。然后从交点开始经由扩散模型(reverse SDE)迭代去噪,从而使数据点不断向真实图像所在的流形区域移动(reverse SDE 决定了这个方向),最终将油画 project 为真实图像。

由于加噪主要破坏的是高频信息,因此在去噪后能保留住参考图的大体结构和语义(低频信息),同时又去掉了纹理等细节,从而生成的图像既能与参考图有一定相似性又显得真实(服从图像数据分布)。

  • 真实度与相似性之间的 trade-off

一个需要重点考虑的问题是生成图像的真实度(realistic)与参考图像的相似性(论文中称为 "faithful")之间的 trade-off。由上面那幅图也可以直观感受到,如果将数据点被过份地移动到真实图像的流形区域(从而远离参考图像的流形),那么虽然生成的结果显得很真实,但与参考图的相似性就难免会降低;相反,若数据点太靠近参考图所在的流形(于是离真实图像所在的流形较远),则生成的结果虽然与参考图的相似度很高,但整幅图像看起来就显得不那么真实了。

而影响这个问题的 keypoint 就是采样起始时刻 的选择。极端地考虑, 对应采样过程结束, 于是生成的结果就是参考图本身(完全处于参考图流形中); 相反, 代表经历整个 reverse SDE 过程, 由于这是一个随机过程, 因此最终生成的结果就"不那么受控"、随机性比较大, 可能会严重偏离参考图像(被过份地移动至真实图像的流形, 于是远离参考图像所在的流形)。

于是, 越靠近1,生成图像的真实度就会越高,同时由于采样方差更大,因此多样性也越高;而 越靠近0,生成结果与参考图的相似度就越高,同时因为采样方差较小,所以多样性就会变低。

在作者的实验中, 是比较合适的选择, 而实际在面对不同任务时, 建议通过实验去手动调整看效果如何来确定 的选择。

  • 局限性

SDEdit 的局限性首先在于其适用场景有限,若参考图像与扩散模型学到的数据域相差较远,可能难以生成既与参考图像相似又看起来真实的图片(考虑参考图像流形与真实图像流形相距甚远而得不到交集的情况)。

其次,为了兼顾生成图像的真实度与参考图像的相似度,需要手动调整采样起始时刻

还有,为了获得满意的生成效果,可能要重复多次“加噪-去噪”的采样过程。

  • 核心代码

以下是 SDEdit 在面临带 mask 的图像编辑场景时的核心代码实现,CW 在此仅截取关键部分,完整源码请自行参考官方repo(https://github.com/ermongroup/SDEdit/blob/main/runners/image_editing.py%23L119).

# `sample_step` 代表重复 “加噪-去噪” 的采样次数
for it in range(self.args.sample_step):
    # 采样起始时刻
    total_noise_levels = self.args.t
    a = (1 - self.betas).cumprod(dim=0)

    e = torch.randn_like(x0)
    # 将参考图像加噪至采样起始时刻对应的噪声强度
    x = x0 * a[total_noise_levels - 1].sqrt() + e * (1.0 - a[total_noise_levels - 1]).sqrt()
    
    ...  # 省略

    with tqdm(total=total_noise_levels, desc="Iteration {}".format(it)) as progress_bar:
        # 迭代去噪采样
        for i in reversed(range(total_noise_levels)):
            t = (torch.ones(n) * i).to(self.device)
            # 扩散模型的去噪采样: p(x_{t-1} | x_t),这里使用的是 VP-SDE
            x_ = image_editing_denoising_step_flexible_mask(
                x, t=t, model=model,
                logvar=self.logvar,
                betas=self.betas
            )
            
            # 将参考图加噪至当前时刻对应的噪声强度
            x = x0 * a[i].sqrt() + e * (1.0 - a[i]).sqrt()
            # 把扩散模型的采样结果放到参考图像的对应位置
            x[:, (mask != 1.)] = x_[:, (mask != 1.)]

            ...  # 省略

    # 将整个采样过程获得的结果放到参考图像对应位置
    x0[:, (mask != 1.)] = x[:, (mask != 1.)]

扩散模型 VP-SDE 的去噪采样:

def image_editing_denoising_step_flexible_mask(x, t, *,
                                               model,
                                               logvar,
                                               betas)
:

    """
    Sample from p(x_{t-1} | x_t)
    """


    alphas = 1.0 - betas
    alphas_cumprod = alphas.cumprod(dim=0)

    # 此处模型的输出是估计的噪声
    model_output = model(x, t)
    # 估计的噪声乘以该项等于 \beta(t) 乘以 score
    weighted_score = betas / torch.sqrt(1 - alphas_cumprod)
    # 参考 Algorithm 5 
    mean = extract(1 / torch.sqrt(alphas), t, x.shape) * (x - extract(weighted_score, t, x.shape) * model_output)

    logvar = extract(logvar, t, x.shape)
    noise = torch.randn_like(x)
    mask = 1 - (t == 0).float()
    mask = mask.reshape((x.shape[0],) + (1,) * (len(x.shape) - 1))
    # 若是采样的最后一步,则无需加上随机噪声,这是根据 Tweedie's formula 计算采样均值去噪的结果
    # 为避免最终生成的图像模糊
    sample = mean + mask * torch.exp(0.5 * logvar) * noise
    sample = sample.float()

    return sample

ILVR

ILVR: Conditioning Method for Denoising Diffusion Probabilistic Models 的核心思想与 SDEdit 在一定程度上有相同之处——引入参考图像的低频信息(图像的大体结构和语义)

ILVR 使用了线性低通滤波器,在采样过程中将每步采样图像的低通部分替换为参考图像的低通部分,对参考图像进行低通滤波前会先对其按当前时间步的噪声强度进行加噪( ) 。另外,也可以不必在采样的全程都使用替换操作,即可以设置采样到某个时间步之后就不再使用参考图像,而是进行无条件采样生成,这样做会促进生成结果的多样性,但同时也会降低与参考图像的相似性(后续统称条件一致性),这是一种 trade-off.

由上展示的算法流程可知,所谓的 "latent variable" 就是指图像经过低通滤波后剩下的低频信息,ILVR 的做法可看作是“带约束的采样过程”—— 要求每步采样结果的低频信息都要与参考图像的低频信息进行对齐

ILVR 有一个需要重点考虑的因素就是低通滤波器的下采样因子(对应上面的 scale N)。下采样因子的值越小(下采样倍率越小),所保留的高频细节就越多,从而生成的图像就会与参考图像更加相似,但多样性和真实度就会相对降低。

另外,作者的实验结果表明,ILVR 对于低通滤波器的选择还是比较鲁棒(robust)的。

  • 局限性与麻烦点

ILVR 的最大限制之一是要求低通滤波器必须是线性的,因而无法使用神经网络来提取参考图像中更为丰富的语义。其次就是为了在多样性和条件一致性之间做 trade-off,需要手工调整下采样因子和采样过程中参考图像注入的时间范围(conditioning range)。

  • 核心代码

以下代码展示了 ILVR 的关键操作,截取自官方 repo(https://github.com/jychoi118/ilvr_adm/blob/main/guided_diffusion/gaussian_diffusion.py%23L550).

    def p_sample_loop_progressive(...):
        ...  # 省略
    
        # 下采样和上采样操作
        if resizers is not None:
            down, up = resizers

        # T -> 0 迭代采样
        for i in indices:
            t = th.tensor([i] * shape[0], device=device)
            with th.no_grad():
                # p(x_{t-1} | x_t)
                out = self.p_sample(
                    model,
                    img,
                    t,
                    clip_denoised=clip_denoised,
                    denoised_fn=denoised_fn,
                    cond_fn=cond_fn,
                    model_kwargs=model_kwargs,
                )

                #### ILVR ####
                if resizers is not None:
                    # 若当前时间步超过 `range_t` 则不使用参考图像
                    # 即回归到正常的无条件生成
                    if i > range_t:
                        # 由于要适配采样结果的 shape,因此下采样后要再上采样回来
                        out["sample"] = out["sample"] - up(down(out["sample"])) + \
                            up(
                                down(
                                    self.q_sample(
                                        model_kwargs["ref_img"], t, 
                                        th.randn(*shape, device=device)
                                    )
                                )
                            )

                yield out

                ...  # 省略

以上用到的 resizers 包含了下采样(即 ILVR 的低通滤波器)和上采样操作,具体实现可参考这里(https://github.com/jychoi118/ilvr_adm/blob/main/resizer.py%23L8),源头是来自这个库(https://github.com/assafshocher/ResizeRight),其专门针对图像 resize 操作的诸多问题(issue)进行解决,并无缝支持 Numpy & Pytorch(因而完全可微分).

RePaint

RePaint: Inpainting using Denoising Diffusion Probabilistic Models 主要是针对图像修复(image inpainting)任务而提出的,它的做法其实与 SDEdit 的 editing with mask(如前面的 Algorithm 3)相似,输入除了参考图像以外还有一个 mask 用于指示待修复的部分,这部分将用模型采样生成的结果来填充,其余部分则维持与参考图像一致。

(ps: 以上对于 的计算有误, 右边应该对应 , 但结合后文所展示的代码实现后, 或许能够知道为何上面会是 )

由于 mask 内由模型生成的部分与 mask 外参考图像的部分是没有相互进行考虑的,相当于将模型的采样结果“硬塞”到参考图像中被 mask 掉的区域,因此这样难免会造成整幅图像语义上存在不和谐的现象。虽然每次在生成下一步采样结果时,模型都试图对这两部分进行协调,但随着采样过程的进行,采样方差逐渐减小,于是模型对这种不和谐现象进行改造的力度则愈发变小,最终很可能就无法生成一幅语义和谐的图像。

对此, 作者引入 "resampling" 技术, 即对每步采样后的结果重新加噪 (e.g. ), 回到噪声强度更高的阶段, 然后再实施去噪采样, 也可以指定对采样结果加噪多步(e.g. )。因为加噪后采样方差变大了, 所以留给模型纠正这种不和谐现象的空间也就更广了。

以上展示的算法流程是每步采样后仅加噪1步的示例,这个加噪步数在论文中被称为 "jump length"; 代表每个时间步要实施 resampling 多少次, 对应论文中的超参

在实际进行采样时,可以预先计算出时间步的变化情况,作者称其为 "time schedule"。这样,每步迭代时,先与下一个时间步进行比较,若比当前时间步大,则代表要实施加噪;否则,就代表要进行去噪采样。

RePaint 的主要局限显而易见 —— resampling 可能导致生成满意的结果需要较长时间,并且也无法保证模型一定就能纠正语义不和谐的问题。另外,适用场景也十分有限,仅仅是 image inpainting with mask 这种。

  • 核心代码

以下展示了 RePaint 的核心代码,截取自官方 repo(https://github.com/andreas128/RePaint/blob/main/guided_diffusion/gaussian_diffusion.py%23L510),涉及 time schedule 的计算、结合 mask 的去噪采样 以及 resampling.

            # 预先计算 time schedule
            times = get_schedule_jump(**conf.schedule_jump_params)
            # 把相邻时间步配对
            time_pairs = list(zip(times[:-1], times[1:]))
            
            ...  # 省略

            for t_last, t_cur in time_pairs:
                ...  # 省略

                t_last_t = th.tensor([t_last] * shape[0],  # pylint: disable=not-callable
                                     device=device)

                # 去噪采样
                if t_cur < t_last:  # reverse
                    with th.no_grad():
                        image_before_step = image_after_step.clone()
                        # p(x_{t-1} | x_t)
                        out = self.p_sample(
                            model,
                            image_after_step,
                            t_last_t,
                            clip_denoised=clip_denoised,
                            denoised_fn=denoised_fn,
                            cond_fn=cond_fn,
                            model_kwargs=model_kwargs,
                            conf=conf,
                            pred_xstart=pred_xstart
                        )
                        image_after_step = out["sample"]
                        pred_xstart = out["pred_xstart"]

                        sample_idxs[t_cur] += 1

                        yield out
                # 加噪 resampling
                else:
                    # 代表加噪几步,默认是1步
                    t_shift = conf.get('inpa_inj_time_shift'1)
                    image_before_step = image_after_step.clone()
                    # 回到更噪的水平 p(x_{t+t_shift} | x_t)
                    image_after_step = self.undo(
                        image_before_step, image_after_step,
                        est_x_0=out['pred_xstart'], t=t_last_t+t_shift, debug=False
                    )
                    pred_xstart = out["pred_xstart"]

去噪采样:

def p_sample(...):
    ...  # 省略

    # 第一步采样之后 `pred_xstart` 才不是 `None`
    # 这说明在采样了一步之后(从而当前不是纯噪声)才会将参考图像与采样结果相结合
    if pred_xstart is not None:
        ''' gt 指参考图像,将参考图像加噪至当前时间步的噪声水平 '''

        gt_weight = th.sqrt(alpha_cumprod)
        gt_part = gt_weight * gt

        noise_weight = th.sqrt((1 - alpha_cumprod))
        noise_part = noise_weight * th.randn_like(x)

        # 加噪后的参考图像
        weighed_gt = gt_part + noise_part
        # 根据 mask 将参考图像与采样结果相结合
        x = gt_keep_mask * weighed_gt + (1 - gt_keep_mask) * x
    
    # p(x_{t-1} | x_t)
    out = self.p_mean_variance(
        model,
        x,
        t,
        clip_denoised=clip_denoised,
        denoised_fn=denoised_fn,
        model_kwargs=model_kwargs,
    )

    ...  # 省略    

由上可知,RePaint 在实现时,是先将参考图像与上一步采样结果相融合,然后再给扩散模型去采样下一步结果的,并非像上面 Algorithm 1 所示是先采样出结果再融合参考图像。并且,第一步采样是不考虑参考图像的,只有待当前采样图像变为非纯噪声之后才会考虑融合参考图像。

resampling 加噪就是正常的扩散模型的加噪过程:

def undo(self, image_before_step, img_after_model, est_x_0, t, debug=False):
return self._undo(img_after_model, t)

def _undo(self, img_out, t):
beta = _extract_into_tensor(self.betas, t, img_out.shape)
img_in_est = th.sqrt(1 - beta) * img_out + th.sqrt(beta) * th.randn_like(img_out)

return img_in_est

总结

SDEdit 应该是最早成功应用预训练无条件扩散模型来实现图生图的 training-free 方法之一,可谓是这个领域的开创性先锋,其后出现的 ILVR 和 RePaint 估计或多或少都有从它那里受到启发。比如 ILVR 的 conditioning range 其实就与 SDEdit 从中间时刻开始采样是相同的原理,限制了条件(即参考图)的影响时间范围;而 RePaint 的 resampling 其实几乎就是 SDEdit 反复 “加噪-去噪”的那一套。

最后就为这三兄弟来个总结叭~


线性逆问题

接下来介绍的这批方法主要是针对图像处理的线性逆问题(linear inverse problem) 而提出的,在数学形式上通常可以将逆问题表示为:

其中 是观测量(measurements), 代表观测噪声, 是线性退化(degradation)算子, 通常以矩阵形式来表示, 则是期望恢复的原图。

构造不同形式的 则可以对应到不同类型的图像修复(restoration)任务, 比如:去噪 (denoising)、编辑和填充(inpainting)、超分(super resolution)、上色(colorization)、去模糊 (deblurring) 等。其实, 上一章所介绍的一些场景都可以被归纳至线性逆问题中。

在这种逆问题的设定下, 生成模型的目标就是要建模后验分布 。为便于求解, 通常会将观测噪声假设为高斯噪声。再进一步结合已知的观测量和退化算子, 使得后验分布的建模变得 tractable.

DDRM

DDRM(Denoising Diffusion Restoration Models) 从变分推断(variational inference) 的角度出发,构造了一种后验采样的方法,从而可以利用预训练的(即无需二次训练)无条件扩散模型来解决任意线性逆问题。

DDRM 将求解逆问题的后验分布定义为以观测量为条件的马尔科夫链,这是模型建模的过程;同时将采样过程构造为以观测量为条件的变分分布,这是模型建模所要对齐的目标。

与正常的扩散模型一样,对齐以上两个过程实质上等价于最大化 ELBO(Evidence Lower BOund),并且作者在论文附录 C 中证明了,在一定条件下这个 ELBO 与无条件训练的 DDPM/DDIM 的 ELBO 是等价的,这是预训练无条件扩散模型能够直接用于 DDRM 求解逆问题的关键之一。

另外, 作者构造的变分分布满足边缘分布 , 即扩散模型 VE-SDE 的形式; 同时在附录 B 中, 作者证明了通过适当转换, 也可使用 VP-SDE 的形式, 只需要在每步采样时将 除以 即可转换到 VP-SDE 的状态。

DDRM 最骚的操作在于对退化算子 进行奇异值分解(SVD), 从而在谱空间(spectral space) 中看待扩散过程。之所以这么做, 是因为奇异值分解能够根据奇异值大小来决定观测量 对于原图 的影响程度, 进一步将观测噪声与扩散过程中 的噪声进行"绑定"(通过构造变分分布的方差来维持原扩散过程的噪声强度), 从而确保无条件扩散模型能够正确去噪。

并且,通过 SVD,我们可以知道观测量中缺失的但原始信号 应该拥有的部分,它们对应为奇异值等于0的那些部分,然后利用扩散模型来生成它们。同时,由于扩散模型的采样过程是去噪过程,因此观测噪声也会被去噪,有利于恢复原图质量。

作者构造的变分分布的具体形式如下:

的伪逆,在后者奇异值为 的位置,前者则对应为 。对于以上(4),(5),咋一看容易懵逼,但是冷静下来分析就会发现这波构造还是有理可依的。

从顶层设计来说,分为奇异值大于 0 和等于 0 的情况。等于 0 代表观测量没有对原始信号提供有用信息,因此直接使用扩散模型的无条件采样即可;大于 0 时则取决于谱空间中的观测噪声强度( )与原扩散过程的噪声强度 的情况。但无论两者大小如何, 都保证方差与原来扩散模型的一样是 ,从而满足边缘分布

为何说以上构造出来的变分分布维持了方差是 ? 这看上去不像啊..

先让我们把目光聚焦至谱空间:

然后再把上面 (4),(5) 中含有 的地方用上式代进去,就会发现分布的方差是 ,均值是 。至于 (5) 中的第一个分布, 由于使用的扩散过程是 VE-SDE, 因此 , 于是 就是标准高斯噪声, 从而整个分布的方差就是 , 均值则是

所以说, 作者所构造出来的变分分布妥妥地满足了边缘分布 , 和扩散过程是一致的, 这也是无条件扩散模型能够直接用于 DDRM 求解逆问题的关键!

OK,变分分布和无条件扩散模型“对齐”了,现在就可以考虑生成模型所建模的分布了。

虽然上面写着 "trainable parameters ",但前面说过,作者证明了将 (7), (8) 和 (4), (5) 进行对齐所对应的 ELBO 等价于无条件扩散模型在训练时所对应的 ELBO。于是,非常幸运地,相当于预训练无条件扩散模型就足以 cover 住 (7), (8) 了!

在论文附录 C 中,作者证明了当 的情况下,DDRM 的 ELBO 能够完美对齐无条件扩散模型的 ELBO,即:

选择其它值时, DDRM 的 ELBO 也正比于 , 所以预训练的无条件扩散模型也是一个不错的最优解平替。

顺便吹一下, 作者还在论文附录 H 中证明了 ILVR 可以被看作是 DDRM 在没有观测噪声情况下的特例, 其中退化算子 是下采样矩阵。

  • 局限性

DDRM 虽然思路新奇,足够秀,但确实不太“新手友好”,首先难免会吐槽这点。

然后,SVD 虽然对于矩阵分析是很有帮助的,但在高维空间下计算量会比较大

其次,作者 hand-craft 的变分分布和采样分布难免让人“感觉不稳”,面临实际场景时的效果还有待考量。

最后,DDRM 的只能解决线性的逆问题,并且需要预先知道退化算子的数学形式(当然也可以通过构造来近似)。

  • 核心代码

DDRM 涉及的数学公式不少,代码实现当然也比较复杂,有点那种一眼看去还没入门就想放弃的赶脚..

但其实只要耐下心来,对照着论文公式去读代码,就会发现还是“有理可依、有迹可循”的(不对照着论文公式去读的话容易晕,还请三思而后行)。 CW 将 DDRM 的核心代码从官方库(https://github.com/bahjat-kawar/ddrm/blob/master/functions/denoising.py%23L11)截取下来并附上了关键注释,各位友友们要是实在太闲不妨可以看看。

由于 DDRM 在理论上使用的是 VE-SDE,但在代码实现中,扩散模型的采样用的却是 VP-SDE 的形式,因此需要来回在 VE-SDE 和 VP-SDE 之间的状态( )进行切换,这点需要注意一下。

def efficient_generalized_steps(x, seq, model, b, H_funcs, y_0, sigma_0, etaB, etaA, etaC, cls_fn=None, classes=None):
    with torch.no_grad():
        ''' setup vectors used in the algorithm '''

        # 退化矩阵 H 的奇异值
        singulars = H_funcs.singulars()
        # 奇异矩阵
        Sigma = torch.zeros(x.shape[1]*x.shape[2]*x.shape[3], device=x.device)
        Sigma[:singulars.shape[0]] = singulars

        # $U^Ty$
        U_t_y = H_funcs.Ut(y_0)
        # $\bar{y}$ 观测量在谱空间中的表示
        Sig_inv_U_t_y = U_t_y / singulars[:U_t_y.shape[-1]]

        ''' initialize x_T as given in the paper '''

        # VP-SDE 的 \alpha_T
        largest_alphas = compute_alpha(b, (torch.ones(x.size(0)) * seq[-1]).to(x.device).long())
        # VE-SDE 的 sigma_T,根据论文附录 B 中 \alpha_t 和 \sigma_t 的转换关系计算
        largest_sigmas = (1 - largest_alphas).sqrt() / largest_alphas.sqrt()

        # 判断奇异值是否足够大
        large_singulars_index = torch.where(singulars * largest_sigmas[0000] > sigma_0)
        inv_singulars_and_zero = torch.zeros(x.shape[1] * x.shape[2] * x.shape[3]).to(singulars.device)
        # 奇异值比较大的位置则设为 $\frac{\sigma_y}{s_i}$,否则为0
        inv_singulars_and_zero[large_singulars_index] = sigma_0 / singulars[large_singulars_index]
        inv_singulars_and_zero = inv_singulars_and_zero.view(1-1)     

        ''' implement p(x_T | x_0, y) as given in the paper '''

        # if eigenvalue is too small, we just treat it as zero (only for init) 
        init_y = torch.zeros(x.shape[0], x.shape[1] * x.shape[2] * x.shape[3]).to(x.device)
        init_y[:, large_singulars_index[0]] = U_t_y[:, large_singulars_index[0]] / singulars[large_singulars_index].view(1-1)
        init_y = init_y.view(*x.size())
        # 根据论文(7)的分布计算
        remaining_s = largest_sigmas.view(-11) ** 2 - inv_singulars_and_zero ** 2
        remaining_s = remaining_s.view(x.shape[0], x.shape[1], x.shape[2], x.shape[3]).clamp_min(0.0).sqrt()
        # \bar{x}_T
        init_y = init_y + remaining_s * x

        # 从 VE-SDE 转换为 VP-SDE 的状态,根据论文附录 B 计算
        # 由于 $\sigma_T$ 较大,因此这里将 $\sqrt{1 + \sigma_T^2}$ 近似为 $\sigma_T$
        init_y = init_y / largest_sigmas
   
        ''' setup iteration variables '''

        # 由谱空间转换回来:$V^Tx$ -> $x$
        x = H_funcs.V(init_y.view(x.size(0), -1)).view(*x.size())
        n = x.size(0)
        seq_next = [-1] + list(seq[:-1])
        x0_preds = []
        xs = [x]

        ''' iterate over the timesteps '''

        for i, j in tqdm(zip(reversed(seq), reversed(seq_next))):
            t = (torch.ones(n) * i).to(x.device)
            next_t = (torch.ones(n) * j).to(x.device)
            at = compute_alpha(b, t.long())
            at_next = compute_alpha(b, next_t.long())
            xt = xs[-1].to('cuda')

            ''' 扩散模型估计噪声 '''

            if cls_fn == None:
                et = model(xt, t)
            else:
                et = model(xt, t, classes)
                et = et[:, :3]
                et = et - (1 - at).sqrt()[0,0,0,0] * cls_fn(x,t,classes)
            
            if et.size(1) == 6:
                et = et[:, :3]
            
            # prediction of x0
            # 根据当前状态 $x_t$ 和 估计的噪声 `et` 预测原图 `x0_t`
            x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()

            ''' variational inference conditioned on y '''

            sigma = (1 - at).sqrt()[0000] / at.sqrt()[0000]
            sigma_next = (1 - at_next).sqrt()[0000] / at_next.sqrt()[0000]

            # 由 VP-SDE 的状态转换至 VE-SDE 的状态以便执行 DDRM 在谱空间的采样过程
            # 根据论文附录 B 计算:$x_t = x_t \sqrt{1+sigma_t} = \frac{x}{\alpha_t}$
            xt_mod = xt / at.sqrt()[0000]
            # $\bar{x}_t$
            V_t_x = H_funcs.Vt(xt_mod)
            SVt_x = (V_t_x * Sigma)[:, :U_t_y.shape[1]]
            # $\bar{x}_{\theta, t}$
            V_t_x0 = H_funcs.Vt(x0_t)
            SVt_x0 = (V_t_x0 * Sigma)[:, :U_t_y.shape[1]]

            falses = torch.zeros(V_t_x0.shape[1] - singulars.shape[0], dtype=torch.bool, device=xt.device)
            cond_before_lite = singulars * sigma_next > sigma_0
            cond_after_lite = singulars * sigma_next < sigma_0
            cond_before = torch.hstack((cond_before_lite, falses))
            cond_after = torch.hstack((cond_after_lite, falses))

            ''' 用于 p^{(t)}_{\theta}(\bar{x}^{(i)}_{t-1} | x_t, y) '''

            # 用于 $s_i = 0$ 的情况
            # $\eta \sigma_{t-1}$
            std_nextC = sigma_next * etaC
            # \sqrt{1-\eta^2} \sigma_{t-1}$
            sigma_tilde_nextC = torch.sqrt(sigma_next ** 2 - std_nextC ** 2)

            # 用于 $\sigma_t < \frac{\sigma_y}{s_i}$ 的情况
            std_nextA = sigma_next * etaA
            sigma_tilde_nextA = torch.sqrt(sigma_next**2 - std_nextA**2)
            
            # 用于 $\sigma_t >= \frac{\sigma_y}{s_i}$ 的情况
            diff_sigma_t_nextB = torch.sqrt(sigma_next ** 2 - sigma_0 ** 2 / singulars[cond_before_lite] ** 2 * (etaB ** 2))

            # missing pixels, $s_i = 0$
            Vt_xt_mod_next = V_t_x0 + sigma_tilde_nextC * H_funcs.Vt(et) + std_nextC * torch.randn_like(V_t_x0)

            # less noisy than y (after) $\sigma_t < \frac{\sigma_y}{s_i}$
            Vt_xt_mod_next[:, cond_after] = \
                V_t_x0[:, cond_after] + sigma_tilde_nextA * ((U_t_y - SVt_x0) / sigma_0)[:, cond_after_lite] + std_nextA * torch.randn_like(V_t_x0[:, cond_after])
            
            #noisier than y (before) $\sigma_t >= \frac{\sigma_y}{s_i}$
            Vt_xt_mod_next[:, cond_before] = \
                (Sig_inv_U_t_y[:, cond_before_lite] * etaB + (1 - etaB) * V_t_x0[:, cond_before] + diff_sigma_t_nextB * torch.randn_like(U_t_y)[:, cond_before_lite])

            # aggregate all 3 cases and give next prediction
            # 由谱空间转换回来 \bar{x}_{t-1}=V^Tx_{t-1} -> x_{t-1}
            xt_mod_next = H_funcs.V(Vt_xt_mod_next)
            # 从 VE-SDE 状态转换至 VP-SDE 状态
            # 根据附录 B 计算:x_t = $\frac{x_t}{\sqrt{1+\sigma_{t-1}^2}} = x_t \sqrt{\alpha_t}$
            xt_next = (at_next.sqrt()[0000] * xt_mod_next).view(*x.shape)

            x0_preds.append(x0_t.to('cpu'))
            xs.append(xt_next.to('cpu'))

至于退化矩阵 H_funcs 的计算,可以根据不同场景来构造,具体见这部分代码,CW 就不再这里赘述了。

DDNM

如果说前面的 DDRM 让你感到思路新奇,那么即将出场的 DDNM(Zero-Shot Image Restoration Using Denoising Diffusion Null-Space Model) 比起它来可谓是有过之而无不及~

DDNM 可以解决任意形式的线性逆问题,它的核心思想是利用了向量空间的 range-null space decomposition(零值域分解),通过在采样过程中持续校正(refine)零空间的部分(null-space contents),就可以使得预训练无条件扩散模型生成真实且满足条件一致性(data-consistent)的结果,可谓是“novel 感”满满~

OK,CW 知道以上这段话会让大部分友友们感觉说了等于没说,DDNM 究竟是怎么 work 的?什么是 range-null space decomposition?校正 null-space contents 具体又是什么鬼?...

稍安勿躁,接下来 CW 为您一一揭晓~

  • Range-Null Space Decomposition

假设有矩阵 , 其伪逆(pseudo-inverse) 为 , 满足 . 那么, 对于向量 就可以把其投影到 值空间(range-space), 因为: ;另一方面, 则可以将其投影至零空间(null-space), 因为:. 有意思的是, 任意向量 都可以分解为两部分,其中一部分位于矩阵 的 range-space, 而另一部分则位于 null-space:

  • Refine Null-Space

先简单考虑不含噪声的线性逆问题 , 其中 代表观测量, 是原图(i.e. groundtruth), 是线性退化 (degradation)矩阵。解决该问题的目标就是要找到(生成)一张图 从而满足:

其中 代表原图所服从的分布。以上目标中的前一项代表了(条件)一致性(consistency),后一项则代表真实性(realness)。如果对原图进行 range-null space decomposition,根据前面的式子,可以得到:

于是不妨遵循以上形式来构造我们想要生成的目标:

这样,我们求解的目标就变为 。并且,无论其取值如何,都不影响 consistency,因为:

于是, consistency 的问题我们不用操心, 只需重点关注  ( 𝐼 𝐴 𝐴 ) 𝑥 ¯ , 也就是 null-space contents, 以便尽可能满足  𝑥 ^ 𝑞 ( 𝑥 ) , 即 realness.

说到生成, 那可是扩散模型的拿手好戏! 由于 是已知量, 因此扩散模型只负责生成 这部分就好。但这部分并非是用每步采样的结果(i.e. ), 因为刚刚在构造目标形式解时是依据原图的 range-null space 启发而来的, 所以 这部分要用扩散模型每步采样时所估计的原图 来代替。若使用 DDPM, 则为:

用上式来代替 并代入到前面所构造的目标形式解中:

这起到了校正无条件扩散模型所估计的 的效果,从而约束其满足 consistency.

DDPM 在采样时所依据的分布 是高斯分布,其中均值和方差为:

将以上 用前面校正后的结果 替代, 便可以得到下一步的采样结果:

整个采样流程如下:

  • 同时满足 consistency & realness

通过前面部分可知, 我们对于逆问题所构造的目标形式解 是天然满足 consistency 的, 这点不用担心。并且, 最后一步的采样结果就是 , 于是 consistency 依然满足。而 realness 所对应的 null-space contents, 我们交给了扩散模型来 handle, 即 , 但它与 range-space 所对应的 天然是"割裂"的, 论文中称之为 "disharmony" 现象。

不过,由于 DDNM 在每步采样时都将 range-space contents 和 null-space contents 代入到采样分布中得到下一步采样结果 , 它相当于是 的 noised version. 于是,这其中的 noise 无意间便起到了消除 disharmony 的作用——想象下,将原本是“隔开”的、没有交集的 range-sapce 和 null-space 通过添加噪声来扰动,使它们的覆盖范围都变大,于是便产生了交集,类似于在 SDEdit 那一章所展示的原理图。

这么一来,就会使得 range-space in harmony with null-space,于是最终生成的结果就可以同时满足 consistency & realness.

以上针对的线性逆问题均以不含噪声(noise-free)为前提,如今我们一起来把目光转向到带有观测噪声的情况。

对于含有观测噪声的情况, 由于在 中夹带的噪声 (源于 ) 会被引入到下一步的采样结果 中, 因此作者选择对 这部分进行 scale(于是引入 ), 这部分可看作是 range-space contents 对于 的校正, 是保证 consistency 的关键。然后进一步构造采样方差( )与原本无条件扩散模型的采样方差进行绑定, 这与前面 DDRM 的做法类似, 为的是不改变无条件扩散模型在采样时的噪声强度, 这样模型基于原来每个时间步的 SNR(信噪比) 所估计的噪声就能够适用, 从而可以成功去噪。

那么,对 range-space contents 的 correction 进行 scale 的道理在哪呢?

可以这么想:由于夹带了噪声,就相当于 range-space contents 的校正作用受到了影响(被噪声扰动),从数学表达式来看,以上 (16) 式中的蓝色部分是可以被合并到其前一项的(从而得到 ),这就直接体现了 correction 的程度受到了影响,因此不妨直接引入一个 scale 系数来表达这种影响,同时还能够控制影响程度。

在之前 noise-free 的情况下,原本的下一步采样结果为:

结合 (17),(18) 式, 如今 那部分就会多出一个强度为 的噪声项(想象下 (17) 式括号内还有一项 , 并且 变为 。暂且考虑最简单的情况, 即 , 并假设 具有单位尺度, 同时记上式 的系数为 , 要使得 的方差不改变 (等于 ), 则需满足:

之所以在 的情况下令 是希望在观测噪声不大(对于 range-space contents 的影响较小)的情况下尽量削弱其影响,以尽量与前面 noise-free 的情况接近,从而最大化 range-space contents 的校正作用,以便保证 consistency.

对于一般化的情况,作者在论文附录 I 中给出了详细的解析,原理类似,还请各位靓仔靓女们自行去细细品味,CW 要偷懒一下~

另外,作者不小心发现 DDNM 在面临一些场景时生成结果的真实性(realness)不太能看,比如:大尺度均值池化下的 SR(超分) 任务、低采样率的 CS(压缩感知) 任务以及大部分内容被 mask 掉的 inpainting 任务等。作者对此归因为 range-space contents 过于局部(从而模型看到的内容太少)以至于无法引导模型生成一个全局 harmony(同时满足 consistency & realness) 的结果。

In these cases, the range-space contents A†y is too local to guide the reverse diffusion process toward yielding a global harmony result.

既然生成的结果不够真实,那就“回流再造”,直至 KPI 达标(生成满意的结果)为止,模型还不是如打工人一样是牛马!在这种念头的驱动下,作者发现前辈 RePaint 就这么玩过(RePaint 中提出的 resampling 方法),于是“借鉴”过来小改一下随即对外宣称为 "time-travel" 技术。

具体地,在采样的每个时间步,可以选择重新加噪回前面更 noised 的阶段,然后再继续实施去噪采样。至于具体加噪几步,则可以用一个超参 来控制。直观来看,这是一种“以更好的过去来产生更好的未来”的思路。

Intuitively, the time-travel trick produces a better “past”, which in turn produces a better “future”.

DDNM 在时间步 的采样依赖于 , 在直觉上, 它应该比之前在 时间步所估计的 来得好。于是, 从 加噪回去, 相当于利用了更好的 所携带的信息, 这一步相当于是 , 然后重新由 时间步进行去噪采样直至 时, 就会产生更好的结果, 因为中间每步迭代生成都继承了"更好的产生更好的", 最终生成的结果也就更好。

结合 time-travel 技术,在解决含噪声的逆问题时,DDNM 进化为 ,整个采样流程如下:

  • 局限性

DDNM 主要局限在于需要明确知道退化矩阵的数学形式并计算其伪逆(或者人为构造),并且只能解决线性的逆问题。其次,time-travel 会导致生成满意的效果需要较长时间。另外,对于有观测噪声的情况,需要将噪声强度绑定到扩散模型的采样方差,引入了手工设计的参数

CW 顺便提一嘴,在 DDNM 论文的附录 H 中,作者证明了 RePaint 和 ILVR 都是 DDNM 在 noise-free(不含观测噪声)下的特例;同样,在这种情况(i.e. noise-free)下,DDRM 也可视作是 DDNM 的特例。

  • 核心代码

以下代码截取自官方 repo(https://github.com/wyhuai/DDNM/blob/main/guided_diffusion/diffusion.py%23L368),对应以上 Algorithm 2.

            # x_T
            x = torch.randn(...)
            n = x.size(0)

            with torch.no_grad():   
                x0_preds = []
                xs = [x]
                
                times = get_schedule_jump(config.time_travel.T_sampling, 
                                          config.time_travel.travel_length, 
                                          config.time_travel.travel_repeat)
                time_pairs = list(zip(times[:-1], times[1:]))          
                skip = config.diffusion.num_diffusion_timesteps // config.time_travel.T_sampling

                for i, j in tqdm.tqdm(time_pairs):
                    i, j = i * skip, j * skip
                    if j < 0:
                        j=-1 

                    if j < i: # normal sampling 
                        t = (torch.ones(n) * i).to(x.device)
                        next_t = (torch.ones(n) * j).to(x.device)

                        at = compute_alpha(self.betas, t.long())
                        at_next = compute_alpha(self.betas, next_t.long())
                        sigma_t = (1 - at_next**2).sqrt()

                        xt = xs[-1].to('cuda')
                        # 扩散模型估计噪声
                        et = model(xt, t)
                        if et.size(1) == 6:
                            et = et[:, :3]

                        # Eq. 12 $x_{0|t}$
                        x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()

                        # Eq. 19
                        if sigma_t >= at_next * sigma_y:
                            lambda_t = 1.
                            gamma_t = (sigma_t**2 - (at_next*sigma_y)**2).sqrt()
                        else:
                            lambda_t = (sigma_t) / (at_next*sigma_y)
                            gamma_t = 0.

                        # Eq. 17 $\hat{x}_{0|t}$
                        # range-space contents correction
                        x0_t_hat = x0_t - lambda_t * Ap(A(x0_t) - y)

                        ''' DDIM 采样 '''
                        eta = self.args.eta
                        c1 = (1 - at_next).sqrt() * eta
                        c2 = (1 - at_next).sqrt() * ((1 - eta ** 2) ** 0.5)

                        # different from the paper, we use DDIM here instead of DDPM
                        xt_next = at_next.sqrt() * x0_t_hat + \
                            gamma_t * (c1 * torch.randn_like(x0_t) + c2 * et)

                        x0_preds.append(x0_t.to('cpu'))
                        xs.append(xt_next.to('cpu'))    
                    else# time-travel back
                        next_t = (torch.ones(n) * j).to(x.device)
                        at_next = compute_alpha(self.betas, next_t.long())
                        x0_t = x0_preds[-1].to('cuda')

                        xt_next = at_next.sqrt() * x0_t + torch.randn_like(x0_t) * (1 - at_next).sqrt()
                        xs.append(xt_next.to('cpu'))
                
                ...  # 省略

至于线性退化矩阵的实现,作者给出了一些简单场景的参考:

def color2gray(x):
    coef=1/3
    x = x[:,0,:,:] * coef + x[:,1,:,:]*coef +  x[:,2,:,:]*coef
    return x.repeat(1,3,1,1)

def gray2color(x):
    x = x[:,0,:,:]
    coef=1/3
    base = coef**2 + coef**2 + coef**2
    return th.stack((x*coef/base, x*coef/base, x*coef/base), 1)    
    
def PatchUpsample(x, scale):
    n, c, h, w = x.shape
    x = torch.zeros(n,c,h,scale,w,scale) + x.view(n,c,h,1,w,1)
    return x.view(n,c,scale*h,scale*w)

# Implementation of A and its pseudo-inverse Ap    
    
if IR_mode=="colorization":
    A = color2gray
    Ap = gray2color
elif IR_mode=="inpainting":
    A = lambda z: z * mask
    Ap = A
elif args.deg =='denoising':
    A = lambda z: z
    Ap = A
elif IR_mode=="super resolution":
    A = torch.nn.AdaptiveAvgPool2d((256 // scale, 256 // scale))
    Ap = lambda z: PatchUpsample(z, scale)
elif IR_mode=="old photo restoration":
    A1 = lambda z: z * mask
    A1p = A1
    
    A2 = color2gray
    A2p = gray2color
    
    A3 = torch.nn.AdaptiveAvgPool2d((256 // scale, 256 // scale))
    A3p = lambda z: PatchUpsample(z, scale)
    
    A = lambda z: A3(A2(A1(z)))
    Ap = lambda z: A1p(A2p(A3p(z)))

y = A(x_orig)

此外,作者还给出了一个应用 SVD 来实现线性退化矩阵 AA 的版本,具体可参考https://github.com/wyhuai/DDNM/blob/main/functions/svd_operators.py%23L211,该版本的采样过程对应于论文附录 I 的内容,其代码实现可参考https://github.com/wyhuai/DDNM/blob/main/guided_diffusion/diffusion.py%23L419

总结

不难发现,DDRM 和 DDNM 在思想上是有些相似的——都是利用了向量/矩阵空间的分解技术,在那个空间去看待扩散过程,以便将观测信息注入至无条件扩散模型(的采样过程)中。为了能够成功利用扩散模型进行去噪,二者都通过 hand-craft 的方式来与原来无条件的情况(无条件扩散模型预训练时)对齐每个时间步的噪声强度。另外,两篇论文的作者们都偏好将以往的方法视作它们的特例,就有一种喜欢强调自己方法更 general 的调性(不带褒贬地说出事实)~


Guidance 的影子

在本文介绍的所有方法中,除了 SDEdit 以外,其余或多或少都瞥见了 guidance 的影子。ILVR 和 RePaint 是对每步的采样结果 进行校正,而 DDRM 和 DDNM 则是对每步所估计的原图 进行校正。校正是为满足一致性(consistency),这本质上同属 guidance 的效果。

具体地,对于 ILVR:

对于 RePaint:

对于 DDRM 来说, guidance term 则更为明显一把采样分布均值中的 单独提取出来, 其余部分便可视作 guidance term, 只不过这是在 spectral space 中对 做引导(即对 进行校正)。

最后,对于 DDNM,前面的式 (17) 更是将“guidance 风范”展露无遗,连 guidance scale 都用上了:

嗯,相信各位友友们都猜到了 —— CW 将会在下一篇文章里大(吹)聊(水) guidance 方法,从而继续为这份食谱增添一些诱人的配方~


公众号后台回复“数据集”获取100+深度学习各方向资源整理

极市干货

技术专栏:多模态大模型超详细解读专栏搞懂Tranformer系列ICCV2023论文解读极市直播
极视角动态欢迎高校师生申报极视角2023年教育部产学合作协同育人项目新视野+智慧脑,「无人机+AI」成为道路智能巡检好帮手!
技术综述:四万字详解Neural ODE:用神经网络去刻画非离散的状态变化transformer的细节到底是怎么样的?Transformer 连环18问!

点击阅读原文进入CV社区

收获更多技术干货

【声明】内容源于网络
0
0
极市平台
为计算机视觉开发者提供全流程算法开发训练平台,以及大咖技术分享、社区交流、竞赛实践等丰富的内容与服务。
内容 8155
粉丝 0
极市平台 为计算机视觉开发者提供全流程算法开发训练平台,以及大咖技术分享、社区交流、竞赛实践等丰富的内容与服务。
总阅读5.7k
粉丝0
内容8.2k