极市导读
文章主要内容分为两部分。第一部分介绍如何利用 probability flow ODE 对真实图片样本计算 likelihood(似然),同时也详细解析了关于评估指标 bpd 的计算细节;第二部分介绍条件生成,本文选取了图像补全和上色两个场景来讲故事。 >>加入极市CV技术交流群,走在计算机视觉的最前沿
写在出师前
本文作为本系列的出师篇(终篇),在内容上比起前两篇会更偏应用层面,也更好玩。
Score-based SDE 生成模型从入门到出师系列(一):用随机微分方程建模图像生成任务并统一分数和扩散模型
SDE 扩散生成模型从入门到出师系列(二):揭秘随机微分方程如何应用于采样生成
主要内容分为两部分。第一部分介绍如何利用 probability flow ODE 对真实图片样本计算 likelihood(似然),同时也详细解析了关于评估指标 bpd 的计算细节;第二部分介绍条件生成,本文选取了图像补全和上色两个场景来讲故事。每部分都贯彻不无聊的风格,在理论解析的同时结合代码实现(当然还有不正经的文风)。
文风虽然不太正经,但 CW 的确是尽全力和大家一起攻陷所有知识的防御门,这点中国人不骗中国人!
在进入正题之前,先来看一个小故事~
前传:关于 Likelihood 的故事
Likelihood 是个很有意思的家伙,它中文名叫 “似然”,不由让人感觉它有种 “似乎 然而.." 的意味在; 它的样子长得像概率,却又不是常规意义上的概率,通常长 这种样子,表示对于已知事实的可能性(概率) ,其中 代表已经发生(出现)的事情(数据),而 嘛.. 就是 “想象力和事实所孕育的孩子" 了。
由于人类对于未知充满了恐惧,因此会假设自己看到的这批数据服从某种分布(就是这么不讲道理),同时老天爷也可怜人类,于是让人类所见的这批数据恰好服从其假设的那种分布。但甜头也是有额度的,所以决定这种分布的参数就不能泄露给人类(但老天爷偷偷告诉了 CW,它叫 ),电视剧都有演啦一一天机不可泄露嘛!
这下没招了,可怜的人类只能㮫猜,于是乎 就这么诞生了一一它就是人类对于真正的分布参数 ( ) 瞎猜的结果。不过, 也不是一成不变的,毕竟人类还是很上进的,不会一直想要躺平。于是,随着自身阅历增长、在见到越来越多社会事实(数据)后,人类就根据当下事实来调整 ,至于调整策略嘛就见仁见智了。像 CW,就选择使用大力出奇迹的深(玄)度(学)学(炼)习(丹)。
随着经验积累, 就会变得越来越像 。最终,在经历了足够多且惨烈的社会毒打之后, 与 的样子就会变得超级像以至于几乎一模一样。总结 的一生,先是由于人类想象力(瞎猜)而诞生,后期成长靠的是社会毒打(事实&数据),因此可谓是想象力与事实的孩子了~
故事结束,以大白话来说, likelihood 的意思就是:对于已经发生(出现)的事情(数据) ,在当前所拥有的参数 这个条件下, 这事(数据)发生(出现)的概率。可以认为,在 likelihood 的概念范畴里, 是常量, 是变量.
计算 Likelihood 的神器:Probability Flow ODE
通过上一篇文章,我们知道 probability flow ODE 的主要功能是采样;但除此之外,它还有一个非常香的功能,那就是能够对真实的图片样本计算 likelihood,是妥妥的神器。
整体计算方法使用了在 Neural ODE(https//arxiv.org/abs/1806.07366) 这篇 paper 中提出的“瞬时变量变换公式(the instantaneous change of variables formula)”,并且利用 "Skilling-Hutchinson trace estimator(https//blog.shakirm.com/2015/09/machine-learning-trick-of-the-day-3-hutchinsons-trick/)"(后文简称迹估计器) 对雅克比矩阵(Jocabian)的迹(trace)做高效的无偏估计,以减轻计算负担。
接下来,CW 会进一步阐述 likelihood 的计算原理,同时浅浅推导下 Skilling-Hutchinson trace estimator 为何能够对大型矩阵的迹进行无偏估计。至于 the instantaneous change of variables formula 的证明,在 Neural ODE 论文的附录中陈天琪(https//rtqichen.github.io/)大佬已经给出非常详细的证明过程,其中还用到了 Jacobi's formula(https//handwiki.org/wiki/Jacobi%2527s_formula)。如果对这部分证明搞不明白但却有非常强烈欲望想搞明白的可以偷偷私信 CW,我保证不告诉别人!
计算原理
为简化表示,首先将 probability flow ODE 的漂移系数记为 ,即:
probability flow ODE 根据瞬时变量变换公式,样本的对数似然为:
其中 代表漂移系数对 求导所得的雅克比矩阵, 则代表对矩阵求迹。哦,对了, 这个东西也叫 “散度(divergence)”。另外, 是先验,也是我们知道的,比如 VE SDE 的是 ;而 VP SDE 和 sub-VP SDE 的都是 , 因此以上式子要计算出结果在理论上是 OK 的。
然而,与我们打交道的 多为包含很多像素的图像数据,于是直接去计算雅克比矩阵的迹会很难搞(需要强大计算资源,而这又需要强大的金钱..),所以我们必须寻求更为聪明的方式去搞定这个计算难题。
你永远不用担心这个世界上缺乏牛逼的人,这不,一不小心就发现这篇 pape(https//link.springer.com/chapter/10.1007/978-94-015-7860-8_48)的作者们提出了一种估计大型矩阵的迹的方法——Skilling-Hutchinson trace estimator(很明显,这个方法就是以作者们的名字来命名的)。有了这个 buff 之后,上面那个雅克比矩阵的迹就可以转换为:
其中 是服从某种 0 均值、单位协方差分布的随机变量,比如标准高斯变量就是一个很好的选择,毕竟是老熟人嘛 利用自动微分系统, 这个 vector-Jacobian product 是能够被高效计算的,计算消耗与仅计算其中的雅克比矩阵 差不多。
有了以上这个 trick 的神助,如今难搞的雅克比矩阵的迹也变得轻松了,那么剩下对它的积分又该如何计算呢?还用问嘛!CW 在前面那么多篇幅的内容就已经告诉你——当然是白嫖一些网红数值求解器最为轻松,比如作者在实验中就使用了 RK45 ODE solver(https://www.sciencedirect.com/science/article/pii/0771050X80900133?ref=pdf_download&fr=RR-2&rr=8817fa866db4cf49n)。
关于似然的计算流程已经阐述完毕,接下来推导一下矩阵迹估计方法 Skilling-Hutchinson trace estimator,走起~!
Skilling-Hutchinson Trace Estimator
为了简化表示, CW 在这里就不使用上面的雅克比矩阵来做证明了,因为这个估计方法是普适(对于任意大型矩阵)的,所以干脆就以 来表示这个目标矩阵。于是,我们所要求证的目标就是:
Step by step,我们就由起点(以上等式左边)出发:
最后一步可以将迹去掉是因为 的结果是标量,求迹也就是其本身。Q.E.D.
品鉴专家:bits per dimension(bpd)
不过,计算出 likelihood 并非终点,真正品鉴生成模型好坏的是一个叫 "bits per dimension(bpd)" 的家伙,也就是要将 likelihood 进一步化为 bpd 来评估模型的效果。从 bpd 这个名字也可以看出,这个指标计算的是每个 dimension 上所需的编码位数(以2为底的 bit)。对于图像来说,它计算的就是编码每个像素所需的位数(bits) 。嗯,这个意思大家是 get 到了,但是对于 bpd 的具体计算方式可能还有点懵,关键在于编码像素所需的位数到底是啥?
这里需要抱一下香浓.. 哦是香农巨佬的大腿
根据他提出的信息论,事件的信息量通常以 表示;同时,这个信息量也代表计算机编码这件事所用到的最小位数(bits)。在这种表示方式下,概率越小的事件所包含的信息量就越大,反之亦然。这也是符合我们凡人的认知,因为对于一件事情,如果你都知道它是必然发生的,那么对你来说它根本就没有什么惊喜可言,而这种惊喜程度的大小就是所谓的“信息量”。
现在,可以将 bpd 和 likelihood 关联起来了。首先根据 probability flow ODE 计算出样本的对数似然: ,然后在前面加个负号就得到了编码这幅生成图像所需的位数(bits)。
且慢! 由于前面计算出来的对数似然是以自然底数 为底的(炼丹的大多喜欢这么玩),因此还需要先转换为以2为底来表示,转换公式是: 。不知道这公式的去复习高数一一高中数学,不是高等数学,想什么呢..
于是,编码这幅图像所需的 bits 就是: 。然而,这还不是 bpd,因为它是对应到每个像素的,而我们现在算出来的是对应到整幅图像的。很好办,将我们前面算出来的除以图像的像素总数就 OK了。比如这幅生成图像的像素个数为 ,那么 bpd 就是: 。
在评估时,计算出来的 bpd 还会在图像分布上求期望,从而变为: ,在具体计算时通常就是对所有图像样本计算出来的 bpd 求平均即可。
劳动生产:代码实现
根据瞬时变量变换公式得出的 (i) 式表明,想要求得生成样本的对数似然,就得去解 ODE,因为其中含有积分项,被积函数( 也就是散度)可看成是某个状态对时间的导数,即: ,那么整个积分 就可看成是初值 的 ODE 问题,从而积分所得的结果对应就是 。
既然是 ODE 的初值问题(initial value problem),那么自然容易想到利用 scipy 的接口:scipy.integrate.solve_ivp(https//docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.solve_ivp.html)。
另外,这个ODE 依赖于 ,而这又依赖于 probability flow ODE 来获取每个时刻的状态 ,同时初值 又是已经拿到手的真实图像样本,因此这又是另一个 初值问题。于是,这里就需要使用 scipy 接口来同时求解两个 ODE 初值问题。
总结起来,代码实现的核心部分包括:
-
计算 probability flow ODE 的漂移系数 -
利用 Skilling-Hutchinson trace estimator 计算散度 -
将以上两者以及对应的两个ODE 初值塞进 scipy 接口进行求解 -
解出 ODE 后分离出目标状态 并根据 (i) 式计算对数似然 -
根据 bpd 定义计算结果
from scipy import integrate
def ts_to_flattened_np(ts: Tensor) -> np.ndarray:
return ts.detach().cpu().numpy().reshape((-1,))
def flattened_np_to_ts(arr: np.ndarray, shape) -> Tensor:
torch.from_numpy(arr.reshape(shape))
@torch.no_grad()
def likelihood(
sde: SDE, score_fn: Union[Module, Callable], data: Tensor,
hutchinson_type: str = 'Gaussian', method: str = 'RK45',
rtol:float = 1e-5, atol:float = 1e-5, eps: float = 1e-5,
scale_factor: float = 256.
) -> Tuple[Tensor, Tensor, int]:
def get_ode_drift(x: Tensor, t: Tensor) -> Tensor:
""" 计算 probability flow ODE 的漂移系数, 对应论文公式(38). """
# ODE 作为逆向 SDE 的特例
rsde = sde.reverse(score_fn, probability_flow=True)
# 取索引[0]代表漂移系数(索引[1]对应扩散系数)
f = rsde.sde(x, t)[0]
return f
def get_div(x: Tensor, t: Tensor, random_vec: Tensor) -> Tensor:
""" 利用 Hutchinson-Skilling trace estimator 计算散度,
即 $\frac{dlogp}{dt}$, 对应论文公式(39)的被积函数. """
# 注意这里要开启计算图才能计算梯度,
# 因为外围整个 likelihood 函数被 torch.no_grad() 装饰了.
with torch.enable_grad():
x.requires_grad_(True)
# 注意漂移系数 f 要在 x 启用梯度后再计算,
# 这样 f 对 x 才能计算出梯度.
# (不能直接在函数外面把 f 计算好传进来)
f = get_ode_drift(x, t)
# $\epsilon^T \nabla f$
grad = torch.autograd.grad((random_vec * f).sum(), x)[0]
# 记得重新将梯度计算图取消
x.requires_grad_(False)
div = (grad * random_vec).sum(dim=tuple(range(1, x.ndim)))
return div
if hutchinson_type not in ("Gaussian", "Rademacher"):
raise ValueError(f"hutchinson_type must be 'Gaussian' or 'Rademacher', got: {hutchinson_type}")
# Hutchinson-Skilling trace estimator 中使用到的随机向量, 其服从0均值、单位方差的分布
epsilon_vec = torch.randn_like(data) if hutchinson_type == "Gaussian" \
else torch.randint_like(data, 0, 2, dtype=data.dtype) * 2 - 1.
shape = data.shape
# $\Delta logp$ 的初值
delta_log_p_init = np.zeros((shape[0],))
# x 的初值 x(0), 在这里是模型生成的图像
data_init = ts_to_flattened_np(data)
# 以下同时解两个 ODE, 因此这里将两个初值拼接起来
ode_init = np.concatenate(
[data_init, delta_log_p_init],
axis=0
)
# 用于标记拼接后的向量中哪部分对应图像数据
data_range = -shape[0]
def ode_fn(t: float, x: np.ndarray) -> np.ndarray:
x_ts = flattened_np_to_ts(x[:data_range], shape).to(data)
t_ts = torch.ones(shape[0], device=x_ts.device) * t
f = get_ode_drift(x_ts, t_ts)
div = get_div(x_ts, t_ts, epsilon_vec)
# 因为要同时解两个 ODE, 所以将两个 ODE 方程的状态(state)变化拼接在一起
return np.concatenate(
[ts_to_flattened_np(f), ts_to_flattened_np(div)],
axis=0
)
T = sde.T
solution = integrate.solve_ivp(
ode_fn, (eps, T), ode_init,
rtol=rtol, atol=atol, method=method
)
# 最后时刻 T 对应的终止状态
z_delta_log_p = solution.y[:, -1]
# x(T)
z = flattened_np_to_ts(z_delta_log_p[:data_range], shape).to(data)
# 先验的对数概率密度 $log p(x(T))$
prior_log_p = sde.prior_log_p(z)
# 论文公式(39)的积分项
delta_log_p = flattened_np_to_ts(z_delta_log_p[data_range:], (shape[0],)).to(data)
# 根据瞬时变量变换公式计算生产样本的对数似然: $log p(x(0))$
log_p = prior_log_p + delta_log_p
# 一张图片的像素总数, 对应 bpd 的 number of dims
N = np.prod(shape[1:])
# 一张图片的负对数似然除以其像素总数后转换为以2为底的对数表示
# 即得到 bpd
bpd = -log_p / np.log(2) / N
# 通常在数据预处理时会对图像的数值进行缩放,
# 比如 [0,255]->[0,1] 或 [0,255]->[-1,1]
# 前者对应缩小256倍;后者对应缩小128倍,
# 因此要每个像素的对数似然要减去 log_2(256) 或 log_2(128),
# 从而负对数似然就等同于加上这个值
offset = np.log(scale_factor) / np.log(2)
bpd = bpd + offset
# 数值求解 ODE 所经历的迭代次数
nfe = solution.nfev
return bpd, z, nfe
整体代码实现有不少需要注意的地方,CW 重点展开讲讲以下两点:
i) Skilling-Hutchinson trace estimator 所对应的 get_div() 方法
根据 Skilling-Hutchinson trace estimator,应该是
,而代码中确是 torch.autograd.grad((random_vec * f).sum(), x)[0] ,咋一看容易憛逼,但其实你盯住它,用点力..
就会发现代码中的形式和公式实质上是等价的。因为(random_vec * f).sum() 其实等于
,所以代码实现的其实是先将随机向量与漂移系数点乘再求导,公式则是先拿漂移系数计算导数再与随机向量点乘,而随机向量与
无关,因此这就如标量求导的情况一样,将常数项拿到导数外面:
在实现时还有一点需要注意的是,漂移系数 f 不能从外部传参进来,必须在内部待 x 开启梯度(requires_grad_())后重新计算,这样 f 才能成功对 x 求导计算梯度,而在 get_div() 方法外部都是没有梯度计算图的, 因为 likelihood() 方法整体都被 @torch.no_grad() 修饰了起来。
ii) bpd 最后要根据图像预处理的数值缩放加上 offset
众所周知,8-bits 数字图像像素值的取值在 {0~255} 的整数集合中。为了方便模型训练,通常在数据预处理阶段对其进行去量化(dequantize)(https//arxiv.org/abs/1306.0186)操作,使其成为连续的浮点值同时缩放至 [0, 1] 取值范围,常见的方式有均匀去量化: ,也就是对原图(每个像素)加上 均匀分布的噪声再除以 256 。同理,若要缩放至 区间,则对应为: ,相当于缩放了 。
于是,模型实际接触到的是缩放后的连续浮点值图像,其估计的 likelihood(bpd) 也对应处理后的图像而非原始 8-bits 图像。那么,我们得想办法将 likelihood(bpd) 转换为原图空间。
记处理后的像素为 ,那么逆变换 的反函数)则为 。根据 “随机变量的变换法则(change of variables)" (https://mathworld.wolfram.com/ChangeofVariablesTheorem.html),有: 。也就是说,原图每个像素的概率密度是如今处理后图像每个像素的概率密度的 。将整幅图的所有像素看作是独立同分布的话,则总共含有 个像素的原图分布的概率密度将会是去量化后图像分布的概率密度的 。
记 8-bits 原图和去量化后的图像分别为 ,那么两者在对数似然上的关系为: ,从而在 bpd 上的转换就是:
这就是以上代码中 bpd + offset 的原因。当原图缩放至 [0,1] 也就是缩放了 \frac{1}{256}\frac{1}{256} 时,scale_factor = 256 ,对应 offset = 8;同理,当原图缩放至 [-1,1] 也就是缩放了 \frac{1}{128}\frac{1}{128} 时,scale_fator = 128,对应 offset = 7。
在大多数开源代码计算 bpd 时,都会涉及这种转换关系,但 CW 发现全网(包括大部分国际朋友)对这个问题还是挺 confused 的,而且也没有看到哪位大侠把这点讲明白。或许大家更倾向于默默知道的同时低调打怪升级,好吧,我承认只有我比较爱吹水~
最后,再写个弱鸡脚本来跑跑以上实现的计算似然的方法 likelihood()。
device = "cuda"
bpd_num_repeats = 5
bpd_bs = 128
bpd_dataset = MNIST('.', train=False, transform=transforms.ToTensor(), download=True)
data_loader = DataLoader(
bpd_dataset, batch_size=bpd_bs,
num_workers=2, pin_memory=True
)
sde = subVPSDE()
model = ScoreModel(sde.p_0t, in_channel=1).to(device)
# 加载训练好的权重
ckpt = f"{sde.__class__.__name__}-{n_epochs}_epochs.pth"
state_dict = torch.load(ckpt, map_location=device)
model.load_state_dict(state_dict)
model.eval()
# 当数据集规模不大时,通常会对整个数据集重复计算多次求平均结果,以尽力模拟在数据分布上求期望的效果
avg_mean_bpd = 0.
for i in range(1, bpd_num_repeats + 1):
mean_bpd = 0.
num_items = 0
for data, _ in data_loader:
data = data.to(device)
bpd = likelihood(sde, model, data)[0]
num_items += bpd.size(0)
mean_bpd = mean_bpd + bpd.sum().item()
mean_bpd /= num_items
print(f"repeat: [{i}]/[{bpd_num_repeats}]; mean bpd: {mean_bpd:6f}; num items: {num_items}")
avg_mean_bpd += mean_bpd
avg_mean_bpd /= bpd_num_repeats
print(f"mean bpd of {bpd_num_repeats} repeats: {avg_mean_bpd:6f}")
结果:
repeat: [1]/[5]; mean bpd: 8.684804; num items: 10000
repeat: [2]/[5]; mean bpd: 8.684535; num items: 10000
repeat: [3]/[5]; mean bpd: 8.684938; num items: 10000
repeat: [4]/[5]; mean bpd: 8.684896; num items: 10000
repeat: [5]/[5]; mean bpd: 8.683588; num items: 10000
mean bpd of 5 repeats: 8.684552
嗯,还是能输出东西的,行,就酱吧~
条件生成的内功心法:Conditional Reverse-Time SDE
使用 score models 的一大好处就是可以很方便地拓展至条件生成场景,所谓的 “条件" 可以是指定的物体类别、文字描述、部分画面缺失的图像、黑白图像(目标是将其变成彩图)等。记条件向量为 ,那么在条件生成的场景下, score 就变为 ,从而采样所用到的 reversetime SDE 就对应为:
根据贝叶斯定理一 于是 。因为 与 无关,所以前者对后者的导数为 0 。等式右边第一项是在无条件情况下的 score,而第二项通常也是 tractable 的。比如 代表物体类别时(也就是当下以物体类别为条件进行图像生成,例如要生成一张狗的图片),我只需要在模型尾部额外接一个分类头训练一个 classifier 让其输出 即可。不过需要注意的是,这个分类器接收的输入不是原始图像(特征),而是加噪后的图像(特征) ,这和普通的图像分类模型不一样。
你吐槽到: “还要额外训一个分类器那么麻烦,哪里方便拓展了! ” 哎呦 不错哦 !
确实存在无需训练分类器的方法,它就是大名鼎鼎的 classifier-free guidance,CW 在之前解析 NCSN 的那篇文章(https://zhuanlan.zhihu.com/p/597490389)中也有介绍过,这里就不再啰嗦了。
class-conditional 这个场景比较老套,下面 CW 来讲两个稍微好玩一点的场景。
Imputation: 构建完好无缺的画面
图像领域中的 imputation 是指将损坏/部分内容缺失的不完整图像恢复(补全)为完整的图像,比如:
若要模型有能力完成这一任务,那么就势必要求它对图像的内容和结构有一定程度的理解,因为它只有理解了图像中各部分内容之间的相互关系(上下文关系),才能够顺利恢复出缺失的图像内容。
作者对 score-based SDE model 可是信心满满,认为它理解上下文的能力肯定是有的,关键在于如何建模这项任务以让模型能在不重新训练的情况下就能够顺利完成 imputation 任务。其实现在我们面临的处境有点像在玩预训练 LLM 时,需要将模型未见过的任务形式通过设计合适的 prompt 转换为模型训练时熟悉的输入形式,从而令其能够顺利完成这个在训练集中未接触过的任务。
既然我们手上拥有的是 score models,是专业吐 score 的,那么无论是面对哪种形式的条件,其完成条件生成的内心功法都始终是前面提到的 conditional reverse-time SDE,所以关键点就在于 "条件 score" 了。如今在 imputation 这个场景中, 就对应图像中完好无损的那些部分,如果直接从这个形式出发,那么就代表们需要将图像中完好无损的部分喂给模型,然后让其在对应时刻输出对于整幅图像(包含损坏的部分)的 score。这和模型在训练时的玩法不一致,因为在训练时模型拿到的可是完整的图像,现在来的却是 “残缺品”,怎么玩嘛..
既然我们风人能力有限,那么不妨来学习下宋戌大神是如何来玩这个游戏的。 在 imputation 任务中,图像有完好和损坏的部分,于是不妨将它们分别表示为 ,那么最终目标就是要从 中采样,也就是走完 conditional reversetime SDE 至终点 ,同时获得采样结果。
进一步,将 看作是另一个独立的扩散过程,对应到以下伊藤 SDE:
于是,对应以 为条件的 conditional reverse-time SDE 就表示为:
但是,在大多数情况下,上式中的 是 intractable 的。面对 intractable 的东西,当然就是去 approximate 它啦!
记条件 为事件 ,然后以 作为媒介,引入
最后的 代表从 采样出来的向量,这说明最后一步是蒙特卡洛那一套。另外,以上第2 3步的近似可以认为是最外围有个期望 ,这隐含了以 为条件。
于是,我们的目标一一条件 score 就近似为:
以上最后一步出现的向量 代表将 时刻下图像完好部分与损坏部分的特征拼接在一起(这不恰好就是完整的一幅图所对应的特征吗!),即有:
也就是说,如今在完成 imputation 任务时,模型的输入就是 ,即整幅图像(包含完好与损坏部分)在对应时刻的加橾结果,与模型在无条件场景下训练时的输入形式一致。
你又会问:“那对应图像损坏部分(i.e. )的 score 要怎么取呢?”
别忘了, score 终究是个向量,现在模型输出的 score 是对应整幅图(也就是向量 ) 的,而 是 的一部分,那么只需要将模型输出的 score 在对应维度上 “抠” 出来即可对应到 的部分。
纸上得来终觉浅,绝知此事要打码
花里胡哨讲了一大堆理论,始终不如打码来得实际。不是让你打马赛克,是写代码!
首先要明确输入形式:输入图像本身是完好的样本,不过额外有 mask 来指示图片中哪些位置对应是完好无损的、哪些是缺失(被破坏)的。其中,mask 中值为1的部分代表完好的部分;而值为0的部分则代表缺失部分,同时也是我需要重建的目标。
OK,现在来说说 inpainting 的实现流程。这是个有条件的采样过程,既然是采样过程,那么在上一篇文章中与大家打过照面的那些采样方法都可以拿来这里打辅助。考虑到 PC Sampling 比较有特色,于是 pick 它来助攻一波叭~
整体 workflow 与 naive PC sampling 一致,从先验开始,交替使用 corrector & predictor。在最后一步采样时,可以有选择性地根据 Tweedie's formula 去噪——不加上高斯噪声项,即返回最后一步采样结果的均值。
不过,这里需要预先对 corrector & predictor 进行改造,以实现 inpainting 功能,实质是在原有 corrector & predictor 的基础上实施后处理,因为利用它们是为得到 (当然,不是直接能得到,而是要在它俩的输出结果上施加 mask 才可);而同时我们还需要 ,这需要额外从 SDE class 的 p_0t() 方法即 中采样,然后再叠加上 mask 才可获得(因为 代表的是原图完好无损的部分)。所以,后处理就是对 corrector/predictor 的输出结果施加 mask 得到 后,接着在 中采样并施加 mask 得到 ,最后将两者进行拼接得到新一轮的 (作为该轮采样结果)。
这里特别需要注意的是先验的设置,根据以上的推导与公式可能会以为先验是 (标准高斯噪声) 与 的拼接。但更 “聪明” 的方式应该是将后者改为 ,即原图完好无损的部分。这样做的理由是: 一方面它可以作为引导条件,把握住了大方向以避免随机生成(否则从一整幅纯噪声开始谁知道它最终会出来个什么玩意儿..); 另一方面,这样也符合 “先验”的意思一一我们预先知道/确信的东西。既然我们都提前知道了最终要生成的图像中的某些部分了(也就是原图本身完好无损的部分),何不干脆利用起来作为一种约束/正则。
from typing import Callable
@torch.no_grad()
def pc_impainting(
sde: SDE, image: Tensor, mask: Tensor,
corrector_fn: Callable, predictor_fn: Callable,
denoise: bool = True, eps: float = 1e-5
) -> Tensor:
def impaint_fn_wrapper(update_fn: Callable) -> Callable:
""" 对原有采样函数进行封装, 加入后处理以实现 impainting 功能. """
def _impaint_fn(x: Tensor, t: Tensor) -> Tuple[Tensor, Tensor]:
t = t.repeat(x.size(0)).to(x.device)
# 将 u(t) 输入模型去估计 score, 得到对应时间的采样结果和均值
x, x_mean = update_fn(x, t)
# $x(t) \sim p_0t(x(t)|x(0))$
image_t_mean, image_t_std = sde.p_0t(image, t)
image_t = image_t_mean + image_t_std[:, None, None, None] * torch.randn_like(image)
# x * (1. - mask) 即 z(t);
# image_t * mask 是 u(t) 的另一部分, 即 $\hat{\Omega}(x(t)) \sim p_t(\Omega(x(t)) | \Omega(x(0))=A)$
# 将两者拼起来得到新的 u(t)
x = x * (1. - mask) + image_t * mask
x_mean = x_mean * (1. - mask) + image_t_mean * mask
return x, x_mean
return _impaint_fn
# 将原始的 corrector & predictor 改为能够实现 inpainting 的功能
corrector_fn = impaint_fn_wrapper(corrector_fn)
predictor_fn = impaint_fn_wrapper(predictor_fn)
# SDE 的终止时刻及离散步数
T, N = sde.T, sde.N
# (B,C,H,W)
shape = image.shape
# 将原图像中完好的部分保留作为条件, 以引导模型能够利用上下文生成与原图相关的图像
x = sde.prior_sampling(shape).to(image.device) * (1. - mask) + image * mask
timesteps = torch.linspace(T, eps, N)
# PC sampling 过程
for t in timesteps:
x, x_mean = corrector_fn(x, t)
x, x_mean = predictor_fn(x, t)
# 如果 denoise=True 则使用 Tweedie's formula 去噪,
# 等价于返回采样结果的均值.
return x_mean if denoise else x
接着简单写一个测试流程,看看上面实现的 pc_impainting() 能不能正常玩起来。就搞个 CIFAR10 玩玩吧,拿其中一个 batch 看看效果,mask 就设计为将每张图的右半部分都遮住(涂黑)。
%matplotlib inline
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
def show_images(images: Tensor):
num = images.size(0)
images = images.clamp(0., 1.)
grids = make_grid(images, nrow=int(np.sqrt(num)))
plt.figure(figsize=(6,6))
plt.axis('off')
plt.imshow(grids.permute(1, 2, 0).cpu(), vmin=0., vmax=1.)
plt.show()
test_bs = 64 #@param {'type': 'integer'}
cifar10_test = CIFAR10('.', train=False, download=True, transform=ToTensor())
data_loader = DataLoader(
cifar10_test, test_bs,
num_workers=2, pin_memory=True
)
images = next(iter(data_loader))[0].to(device)
show_images(images)
mask = torch.ones_like(images)
mask[:, :, :, 16:] = 0.
show_images(images * mask)
最后,把你在 CIFAR10 上训练后的模型(正常无条件训练即可)拿出来搬砖:
snr = 0.16
n_corrector_steps = 1
model.eval()
predictor_fn = ReverseDiffusionPredictor(sde, model).update_fn
corrector_fn = LangevinDynamicsCorrector(sde, model, snr, n_corrector_steps).update_fn
inpainted_images = pc_impainting(
sde, images, mask,
corrector_fn, predictor_fn,
denoise=True, eps=eps
)
show_images(inpainted_images)
Colorization: 创造多姿多彩的画面
看到 color 这个单词也就知道这玩的是什么花样了——上色呗!
在有了前面 imputation 的背景后,完成 colorization 就很巧妙了,因为它可以看成是 imputation 的 special case。
你仔细想想,通常黑白的灰度图和 RGB 彩图在结构上最大的区别是什么?揭晓答案——前者是单通道,而后者是三通道。那么,其实可以将灰度图看成是三通道都相等的图像,只不过三个通道耦合在了同一个向量空间。于是,不妨对灰度图进行解耦,以将各通道分离至独立的空间中。
可以将其中一个通道视为图像的已知部分(对应于前面 imputation 中图像完好无损的部分),同时将另外两通道视为未知部分(对应于前面 imputation 中被涂成乌漆嘛黑的图像部分)。我们的目标是去重建未知部分,使得灰度图变成彩图。当然,最后要将图像还原(耦合)回原来的图像空间中。
至于如何对灰度图进行通道解耦,则可以考虑使用正交(orthogonal)矩阵以实施正交线性变换,从而使得通道之间处于相互正交的图像空间中。
在实验中,作者使用的正交矩阵如下:
手动上色:代码实现
现在,是时候动起双手来实现图像上色了。
逻辑基本上与前面 imputation 的类似,只不过这里需要额外实现正交线性变换,实质上就是矩阵相乘,可以使用 Pytorch 的爱因斯坦求和接口:torch.einsum()。另外,这里的 mask 根据输入图像制作,以第一个通道为已知部分,其余的通道视为未知的重建目标。
注意,以下输入的灰度图本身已是三通道(只不过三个通道对应空间位置的像素值均相等)。
@torch.no_grad()
def pc_colorize(
sde: SDE, gray_scale_img: Tensor,
corrector_fn: Callable, predictor_fn: Callable,
denoise: bool = True, eps: float = 1e-5
) -> Tensor:
# 解耦灰度图通道所用的正交矩阵
M = torch.tensor([
[5.7735014e-01, -8.1649649e-01, 4.7008697e-08],
[5.7735026e-01, 4.0824834e-01, 7.0710671e-01],
[5.7735026e-01, 4.0824822e-01, -7.0710683e-01]
])
# 以上正交矩阵的逆
inv_M = torch.inverse(M)
def transform(img: Tensor, op: str = "decouple") -> Tensor:
""" 对灰度图实施正交线性变换, 以实现解耦/耦合通道的效果 """
if op not in ["couple", "decouple"]:
raise ValueError(f"op must be 'couple' or 'decouple', got: {op}")
m = M if op == "decouple" else inv_M
return torch.einsum('bihw, ij -> bjhw', img, m.to(img.device))
# 为灰度图制作 mask, 将首个通道视作已知部分,
# 其余通道视为未知 impainting 目标
mask = torch.cat(
[torch.ones_like(gray_scale_img[:, :1, ...]),
torch.zeros_like(gray_scale_img[:, 1:, ...])],
dim=1
)
def colorize_fn_wrapper(update_fn: Callable) -> Callable:
""" 对原采样函数进行封装, 加入后处理以实现 colorize 功能. """
def _colorize_fn(x: Tensor, t: Tensor) -> Tuple[Tensor, Tensor]:
t = t.repeat(x.size(0)).to(x.device)
x, x_mean = update_fn(x, t)
# $x(t) \sim p_0t(x(t)|x(0))$ 注意要先解耦
img_t_mean, img_t_std = sde.p_0t(transform(gray_scale_img, op="decouple"), t)
img_t = img_t_mean + img_t_std[:, None, None, None] * torch.randn_like(gray_scale_img)
# (1. - mask) 部分对应 z(t);
# mask 部分对应u(t) 的另一部分, 即 $\hat{\Omega}(x(t)) \sim p_t(\Omega(x(t)) | \Omega(x(0))=A)$
# 先在解耦空间中拼接两部分, 在耦合回去原空间
x = transform(
transform(x, op="decouple") * (1. - mask) + \
img_t * mask,
op = "couple"
)
x_mean = transform(
transform(x_mean, op="decouple") * (1. - mask) + \
img_t_mean * mask,
op = "couple"
)
return x, x_mean
return _colorize_fn
corrector_fn = colorize_fn_wrapper(corrector_fn)
predictor_fn = colorize_fn_wrapper(predictor_fn)
# SDE 的终止时刻及离散步数
T, N = sde.T, sde.N
# (B,C,H,W)
shape = gray_scale_img.shape
# 先验的设置:
# 先解耦到对应空间, 然后根据 mask 拼接原图和噪声部分, 最后再耦合回原空间.
x = transform(
transform(gray_scale_img, op="decouple") * mask + \
transform(sde.prior_sampling(shape).to(gray_scale_img.device), op="decouple") * (1. - mask),
op = "couple"
)
timesteps = torch.linspace(T, eps, N)
# PC sampling 过程
for t in timesteps:
x, x_mean = corrector_fn(x, t)
x, x_mean = predictor_fn(x, t)
# 如果 denoise=True 则使用 Tweedie's formula 去噪,
# 等价于返回采样结果的均值.
return x_mean if denoise else x
有了前面的 inpainting 基础,如今在这部分就没有太多可言了,需要小心的就是在采样过程中需要反复对图像进行解耦与耦合。
同样地,最后写一个测试流程把上面的 pc_colorize() 跑起来。
继续沿用前面 imputation 部分的 CIFAR10,并将图像的三个通道取均值作为灰度图,然后再复制出额外的两个通道,作为最终的三通道灰度图。
images = next(iter(data_iter))[0].to(device)
show_images(images)
gray_scale_img = torch.mean(images, dim=1, keepdims=True).repeat(1, 3, 1, 1)
show_images(gray_scale_img)
最后是模型的打工时间:
snr = 0.16
n_corrector_steps = 1
model.eval()
predictor_fn = ReverseDiffusionPredictor(sde, model).update_fn
corrector_fn = LangevinDynamicsCorrector(sde, model, snr, n_corrector_steps).update_fn
colorized_images = pc_colorize(
sde, gray_scale_img,
corrector_fn, predictor_fn,
denoise=True, eps=1e-5
)
show_images(colorized_images)
出师感言
CW 认为,Score-based SDE 是扩散生成模型史上里程碑式的方法之一,在目前的历史环境下,尽可能彻底搞明白它是非常有必要的!但同时它也是一道坎,在没有牢固的统计学与随机过程基础的情况下初次接触它很可能会破防,而且目前网上的解析都不太全面,导致跨过这道坎需要花费不少时间与精力。
正因如此,CW 才萌生了要写这个系列的 motivation,毕竟自己也是跌跌撞撞地走过来的。同样,个人能力有限,单纯就该系列的内容也许不足以全面覆盖 score-based SDE 的所有,但与网上友人们的解析还是可以形成互补的,CW 特意写下了网上难以挖掘到却又挺重要的一些细节,以助力大家全面攻陷这些知识点的防御门,同时也希望本系列能够帮助大家以更低的成本来跨过这道坎。

公众号后台回复“数据集”获取100+深度学习各方向资源整理
极市干货

点击阅读原文进入CV社区
收获更多技术干货

