引言
本文提出了SparseDiT,一种通过空间 (Model Structure) 和时间维度 (Timestep) 的token稀疏化来提高Diffusion Transformer(DiT)计算效率的新框架。
在空间维度上,SparseDiT采用三段式架构:底层使用Poolingformer进行高效全局特征提取,中层利用Sparse-dense generation token model (SDTM)平衡全局与局部的特征,顶层则采用密集token来提炼高频细节。
在时间维度上,SparseDiT动态调整去噪阶段中的token数量,随着时间步 (Timestep)进逐渐增加token以在实现高效性的同时捕捉细节。这种空间和时间适应性策略提高了计算效率,同时保持生成质量。实验证明SparseDiT在图片生成、视频生成和文生图等多个生成任务上取得优异表现,例如在512x512分辨率上的图片生成任务中,在减少FLOPs(55%)和提升推理速度(175%)的同时,仍能保持接近的生成质量。
论文地址:https://arxiv.org/pdf/2412.06028
代码地址:https://github.com/changsn/SparseDiT
现存问题及挑战介绍
尽管Diffusion Transformer (DiT)在生成性能上表现出色,但其在自注意力和采样步骤上的高计算复杂度限制了其在实际应用中的广泛性。大多数现有方法通过加速采样过程来降低复杂性,但忽略了DiT自身结构上的效率问题。
与U-Net相比,DiT在token级别的自注意力引入了更多计算开销,因此需要设计针对DiT的创新方法来智能管理token密度,从而实现效率和生成质量之间的平衡。
工作价值以及方法介绍
工作价值
当前工作通过SparseDiT方法有效解决了DiT模型中的计算效率问题。SparseDiT通过动态token稀疏化策略,不仅减少了模型计算复杂度,同时在多个生成任务上保持了高质量的生成性能。此外,该方法在多个实验验证中均显示出减少了的计算量(如FLOPs),并显著提高了推理速度,这对于大规模应用和实际部署尤为重要。因此,该工作为高质量、高效率的扩散模型提供了可扩展的解决方案。
方法介绍
SparseDiT的核心创新在于通过空间和时间维度对token进行稀疏化,以提升扩散模型的计算效率,同时维持生成质量。该方法的设计可以分为两个主要部分:空间上的token密度管理和时间上的timestep-wise修剪策略。
空间维度:三段式架构
底层架构:Poolingformer
在底层Transformer,SparseDiT使用Poolingformer结构替换传统自注意力机制以捕获全局特征。实验发现,底层自注意力的复杂运算并不能带来额外信息,反而可以通过全局平均pooling实现效率的提高。Poolingformer通过删除查询和键操作,直接对值进行全局平均池化,集成到输入token中,从而减少计算开销。
上图实验显示,不经过任何finetuning的情况下,直接将注意力层替换为全局pooling,不会对画面造成很大的影响,说明底层的注意力层产生的效果有限。
中层架构:Sparse-dense token module
中层结构采用Sparse-dense token module (SDTM) 技术,将表示过程分为全局结构提取和局部细节增强。Sparse token负责全局结构信息的捕获,有效降低计算成本,而dense token则用于细节增强,稳定训练过程。SDTM通过交互注意力层实现sparse token与dense token的相互转换,其中sparse transformer处理sparse token,dense transformer处理由sparse token恢复而来的dense token,同时实现了信息的保留与算力的节省。
顶层架构:标准Transformer
在顶层,SparseDiT继续使用标准transformer层,以dense token处理模式专注于高频细节的提炼,确保生成质量。
时间维度:Time-wise pruning rate
动态Time-wise pruning rate是SparseDiT的另一个关键创新,旨在随着去噪的进行而动态调整token密度。具体来说:
早期阶段:
在早期去噪阶段,由于以低频结构为主,SparseDiT应用较高的剪枝率来保存计算资源。此时,计算复杂度较低,节约了不必要的token操作。
后期阶段:
随着去噪阶段的推进,渐进式减少剪枝率以增加token密度,确保高频细节能够被准确捕获。此时,token需求逐步增加,反映了对细节需求的增长。
通过这种时空双重适应性策略,SparseDiT在保持生成细节的同时,大幅度提高了计算效率,表现在通过减少FLOPs和加速推理速度。
实验
在论文中,SparseDiT在以下三个生成任务中进行了实验并取得了显著效果:Class-conditional image generation、Class-conditional video generation和Text-to-image generation。
Class-conditional image generation实验设置
在256×256分辨率下,SparseDiT-XL实现了43%的FLOPs减少,与87%的推理速度提升,同时FID分数仅增加了0.11。这表明,即使只使用约25%的tokens,依然能保持类似的性能。在512×512分辨率条件下,SparseDiT在高剪枝率情况下表现出更优质的性能-效率trade-off,通过剪枝超过90%的tokens,得到55%的FLOPs减少及175%的速度提升,FID分数仅增加了0.09。这些结果证明SparseDiT解决了DiT架构中的计算负担问题,并在保持性能质量同时带来了显著的计算效率提升。
Class-conditional video generation实验设置
在FaceForensics、SkyTimelapse、UCF101和Taichi-HD等四个公众数据集上进行,分辨率为256×256。SparseDiT在视频数据的额外时间维度上应用了更高的剪枝率,达到了FLOPs减少56%的效果,同时保持了竞争性的FVD评分,证明其在视频生成任务上的有效性。视频生成任务的结果展示了SparseDiT不仅在空间维度上应用有效,在时间维度上也能够显著提升效率。
Text-to-image generation实验设置
使用PixArt-α模型为基础模型,采用SAM数据集进行训练与评估,进行文本到图像生成。SparseDiT在该任务上达到了与原始PixArt-α模型相当的FID分数,同时显著加快了生成速度,显示出方案在文本到图像生成任务中的有效性。
|往期内容回看



