大数跨境
0
0

深度解析Swin Transformer:架构与关键运算

深度解析Swin Transformer:架构与关键运算 ai算法芯片与系统
2025-12-05
4
导读:Swin Transformer 通过层次化架构与移位窗口注意力机制,解决了 ViT 在高分辨率图像处理中的效率和多尺度建模问题。本文解析图像块合并构建特征金字塔的方法,重点阐述移位窗口注意力及其高效

 

摘要

Swin Transformer 是计算机视觉领域的一项里程碑式工作。它通过引入 层次化架构 与 移位窗口注意力 机制,系统性解决了标准 Vision Transformer (ViT) 在处理高分辨率图像时面临的计算效率和多尺度建模瓶颈。本博客将深入剖析其核心设计:详细阐述图像块合并 (Patch Merging) 如何构建特征金字塔,并重点解析 移位窗口注意力 (Shifted Window based Self-Attention) 及其高效实现方案(包括循环移位注意力掩码)如何在不牺牲全局建模能力的前提下,将自注意力计算复杂度从 O(N²) 降至 O(N)。本文还将结合详细的代码片段、运算复杂度分析和张量操作图解,为读者提供一个从理论到实践的完整视角。

目录

  1. 1. 引言:ViT 的局限与 Swin 的突破
  2. 2. Swin Transformer 整体架构概览
  3. 3. 核心概念一:图像块合并 (Patch Merging)
    • • 3.1 动机:构建特征金字塔
    • • 3.2 工作流程与 SpaceToDepth 的联系
    • • 3.3 代码实现与运算解析
  4. 4. 核心概念二:窗口注意力与移位窗口注意力
    • • 4.1 窗口自注意力 (Windowed Self-Attention):从 O(N²) 到 O(N)
    • • 4.2 移位窗口 (Shifted Windows):实现跨窗口信息交互
    • • 4.3 高效批处理:循环移位与注意力掩码
  5. 5. Swin Transformer Block 实现精讲
    • • 5.1 结构总览:W-MSA 与 SW-MSA 的交替与代码共享
    • • 5.2 关键运算分解:reshapepermuteroll 与 mask
    • • 5.3 注意力计算核心:Batch MM 与带掩码的 Softmax
  6. 6. 总结

1. 引言:ViT 的局限与 Swin 的突破

Vision Transformer (ViT) 在2020年的提出,标志着纯Transformer架构在图像分类任务上取得了突破性成功。它摒弃了CNN固有的归纳偏置(如局部性和平移等变性),完全依赖自注意力机制来建立图像块(Patch)之间的全局依赖关系。然而,这种“暴力”的全局建模方式带来了两个在密集预测任务(如目标检测、语义分割)中尤为突出的挑战:

  1. 1. 二次方计算复杂度: 标准自注意力机制需要计算所有输入token对之间的关联。对于一个由 N 个token组成的序列,其计算和内存复杂度为 O(N²)。当处理高分辨率图像(N 很大)时,这带来了难以承受的计算负担。
  2. 2. 单一尺度特征图: ViT通常将输入图像分割为固定大小的块,并通过线性投影转换为token序列。在整个网络中,这个序列的长度保持不变,从而输出单一分辨率的特征图。这与卷积神经网络(CNN)中通过池化或跨步卷积自然形成的多尺度、金字塔式特征图结构截然不同,而后者被广泛证明对处理不同尺度物体至关重要。

Swin Transformer 的提出,正是为了系统性地解决以上问题,其设计哲学是在保持Transformer强大建模能力的同时,引入类似于CNN的归纳偏置来提升效率。其核心思路可概括为两点:

  • • 局部性与层次化: 通过在非重叠的局部窗口内计算自注意力,将计算复杂度降至线性;同时,通过图像块合并操作进行渐进式下采样,构建层次化的特征金字塔。
  • • 跨窗口连接: 通过移位窗口划分这一巧妙设计,在连续的Transformer Block中引入规则的、跨越窗口边界的注意力连接,从而在不增加额外计算开销的前提下恢复全局建模能力。

理解Swin Transformer,实质上就是理解它如何通过移位窗口注意力图像块合并这两个核心创新,在效率与性能、局部与全局之间找到了一个精妙的平衡点。

2. Swin Transformer 整体架构概览

Swin Transformer 的整体设计采用了经典的金字塔结构,这与现代CNN(如ResNet、VGG)的设计理念一脉相承,使其能够轻松融入现有的大部分视觉任务框架中作为骨干网络(Backbone)。整个网络通常分为4个阶段(Stage),每个阶段在空间分辨率上逐级递减,而在特征维度(通道数)上逐级递增。

每个阶段的核心构成如下:

  1. 1. 图像块处理层
    • • Stage 1: 图像块嵌入层 (Patch Embedding)。其功能与ViT相同:将输入图像(例如 224x224x3)分割为不重叠的图像块(如 4x4),将每个图像块展平并通过一个线性层投影到指定维度的token(例如 C=96)。输出token序列的空间维度为 (H/4, W/4)
    • • Stage 2, 3, 4: 图像块合并层 (Patch Merging)。这是实现下采样的关键模块。它通过合并相邻的 2x2 局部窗口的token,将空间分辨率降低至一半,同时将特征维度增加一倍(具体过程见第3章)。经过三次合并,最终特征图的分辨率降至输入图像的 1/32
  2. 2. Swin Transformer Block 堆叠
    每个阶段的主体是由多个 Swin Transformer Block 堆叠而成。例如,Swin-Tiny版本的配置为 [2, 2, 6, 2]这里最精妙的设计在于,这些Block是以“成对”的方式工作的。在一个“配对”中:
    • • 第一个Block使用常规窗口划分 (Window Multi-head Self-Attention, W-MSA)。它将特征图划分为规则的非重叠 MxM 窗口,并在每个窗口内独立进行自注意力计算。
    • • 第二个Block紧接着使用移位窗口划分 (Shifted-Window Multi-head Self-Attention, SW-MSA)。它将特征图在水平和垂直方向各循环移位 (M/2, M/2) 个像素后,再应用同样的非重叠窗口划分和窗口内自注意力计算。

这种W-MSA与SW-MSA交替出现的模式,是Swin Transformer的灵魂。它确保了信息不仅能在窗口内部高效聚合,还能通过移位操作,在相邻的Transformer Block之间,跨越窗口边界进行传递,从而逐步建立起整个图像范围的全局上下文信息。

为了直观展示Swin Transformer如何通过层次化下采样构建特征金字塔,并形成类似CNN的金字塔结构,下图清晰地展示了从输入图像到四个阶段(Stage)输出的完整流程:

Swin Transformer整体架构

上图展示了Swin Transformer的层次化架构。可以看到,输入图像经过“Patch Embedding”后进入Stage 1,随后通过三次“Patch Merging”操作,特征图的空间尺寸逐级减半(从H/4到H/32),而通道数逐级翻倍(从C到8C),形成了一个经典的特征金字塔。每个Stage内部由若干个Swin Transformer Block组成,负责进行特征变换。

3. 核心概念一:图像块合并 (Patch Merging)

3.1 动机:构建特征金字塔

在计算机视觉中,多尺度特征表示是处理不同大小物体的关键。CNN通过卷积和池化层天然地实现了这一点。Swin Transformer 的 图像块合并层 就是Transformer世界中对标CNN池化层或跨步卷积的组件。它的核心目标是:在降低特征图空间分辨率(增大感受野)的同时,增加特征维度(丰富语义信息),从而构建一个层次化的特征金字塔。这个金字塔的每一层都对应着不同尺度的视觉特征,非常适合作为目标检测(如FPN)、语义分割(如U-Net)等任务的输入。

3.2 工作流程与 SpaceToDepth 的联系

图像块合并层可以看作一个“空间到深度”的操作,后接一个线性变换。我们以一个具体例子说明:假设Stage 1的输出特征图尺寸为 (B, 56, 56, 96),其中 B 为批大小,56 为高和宽上的token数,96 是通道数 C

其工作流程如下表所示:

步骤
操作描述
数学变换
输出形状
目的
输入
来自上一阶段的特征图
-
(B, 56, 56, 96)
-
1. 重排与拼接
提取 2x2 邻域,并在通道维拼接
SpaceToDepth
 (block_size=2)
(B, 28, 28, 384)
下采样2倍,通道扩至4倍
2. 归一化
应用LayerNorm稳定训练
LayerNorm(4*C) (B, 28, 28, 384)
规范化特征
3. 线性投影
全连接层降低通道数
Linear(4*C -> 2*C) (B, 28, 28, 192)
可控地增加维度至2C

这个过程与TensorFlow的 tf.nn.space_to_depth 或PyTorch的 torch.nn.PixelUnshuffle 操作在思想上同源。它们都是将空间维度上 block_size x block_size 区域内的像素或特征,“打包”到通道维度上。区别在于,标准的 SpaceToDepth 不改变总信息量(H*W*C -> (H/2)*(W/2)*(C*4)),而 Patch Merging 在“打包”之后,额外引入了一个可学习的线性投影。这个投影不仅可以将通道数从 4C 灵活地调整到 2C(而非固定的 4C),更重要的是,它允许模型学习如何最有效地组合来自四个相邻位置的特征信息,这比简单的拼接或池化更具表现力。

下图通过一个 4x4 网格的Excel表格类比,生动地展示了 2x2 邻域的提取过程:

图像块合并层操作

上图使用Excel表格直观解释了Patch Merging的第一步。一个4x4的网格被四种符号标记。每种符号代表一组特定的采样点:✅从(0,0)开始步长为2采样,❌从(1,0)开始,⚫从(0,1)开始,⬛从(1,1)开始。将这四组数据在通道维度拼接,就实现了将空间上2x2的区域合并到一个拥有4倍通道数的特征向量中,空间尺寸随之减半。

为了更清晰地理解“空间到深度”这一核心思想,下图对比了PyTorch中的PixelUnshuffle、ONNX中的SpaceToDepth操作与Swin Transformer中PatchMerging的相似性。它们都遵循着同一套将空间维度重排到通道维度的逻辑。

图像块合并和torch的UnShuffle以及onnx的SpaceToDepth在变换上有相似原理

上图通过示意图对比了PixelUnshuffle/SpaceToDepthPatchMerging的相似性。左半部分展示了标准操作:它将一个[C, H, W]的输入张量(这里H=W=4),通过将每个2x2空间块(共4个元素)展平并排列到通道维度,转换为[C*4, H/2, W/2]的输出。右半部分示意了PatchMerging的后续步骤:在完成上述空间到通道的重排后,还会接一个线性层(图中以“Linear”表示),将通道数从4C投影到2C,实现可控的下采样和特征融合。

通过下图,我们可以将Swin Transformer的层次化下采样过程与CNN的经典操作进行类比。Patch Merging起到了类似卷积网络中池化层或跨步卷积的作用,是构建视觉特征金字塔的关键步骤。

Swin Transformer通过图像块合并实现类似CNN的下采样

上图形象地说明了Swin Transformer如何通过Patch Merging模拟CNN的下采样功能。左侧代表CNN的典型操作:一个特征图经过池化层(如MaxPool)或跨步卷积(Strided Conv)后,空间尺寸(H, W)缩小,通道数可能增加。右侧代表Swin Transformer的操作:特征图经过Patch Merging层,通过提取并合并2x2邻域的特征,同样实现了空间尺寸减半、通道数增加(通过后续的线性投影)的效果,从而在Transformer架构中构建了多尺度特征表示。

3.3 代码实现与运算解析

让我们深入 PatchMerging 模块的代码,看上述思想如何转化为PyTorch张量操作。


   
    
   class PatchMerging(nn.Module):
    def
 __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super
().__init__()
        self
.input_resolution = input_resolution
        self
.dim = dim # 输入通道数 C
        self
.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) # 线性投影: 4C -> 2C
        self
.norm = norm_layer(4 * dim) # 在4C维上归一化

    def
 forward(self, x):
        """
        x: 输入张量,形状为 (B, H, W, C)
        输出: 形状为 (B, H/2 * W/2, 2*C)
        """

        B, H, W, C = x.shape
        # 步骤1:利用切片操作实现高效的2x2邻域提取

        # 这些切片操作是跨步的,直接跳步采样,没有冗余计算

        x0 = x[:, 0::2, 0::2, :]  # 从(0,0)开始,步长为2, 形状 (B, H/2, W/2, C)
        x1 = x[:, 1::2, 0::2, :]  # 从(1,0)开始,步长为2, 形状 (B, H/2, W/2, C)
        x2 = x[:, 0::2, 1::2, :]  # 从(0,1)开始,步长为2, 形状 (B, H/2, W/2, C)
        x3 = x[:, 1::2, 1::2, :]  # 从(1,1)开始,步长为2, 形状 (B, H/2, W/2, C)

        # 在最后一个维度(通道维,dim=-1)进行拼接

        x = torch.cat([x0, x1, x2, x3], dim=-1)  # 输出形状: (B, H/2, W/2, 4*C)

        # 将空间维度展平,以适应后续的线性层 (期望输入为 (B, seq_len, features))

        x = x.view(B, -1, 4 * C)  # 形状: (B, (H/2)*(W/2), 4*C)

        # 步骤2:应用层归一化和线性投影

        x = self.norm(x)
        x = self.reduction(x)  # 形状: (B, (H/2)*(W/2), 2*C)
        return
 x

关键张量操作解析

  • • x[:, 0::2, 0::2, :]: 这是NumPy/PyTorch中强大的切片语法。0::2 表示从索引0开始,到结束,步长为2。它高效地、无循环地实现了下采样操作,是 SpaceToDepth 逻辑的直观代码体现。
  • • torch.cat(..., dim=-1): 沿着通道维度拼接,这是将空间信息“折叠”到通道维的关键一步。
  • • view(B, -1, 4*C): -1 表示该维度由其他维度推断得出。这里将二维的空间网格 (H/2, W/2) 展平为一维的序列,是切换回Transformer类模型常用的 (B, N, C) 格式的必要步骤。

4. 核心概念二:窗口注意力与移位窗口注意力

4.1 窗口自注意力:从 O(N²) 到 O(N)

标准Transformer的自注意力机制之所以计算昂贵,是因为它需要计算一个 N x N 的注意力矩阵(N为序列长度),这导致了 O(N²) 的复杂度。在图像任务中,N = H * W,即使对于中等分辨率的特征图,N 也很大。

Swin Transformer 的 窗口自注意力 提供了一个优雅的解决方案:将全局计算限制在局部窗口内。具体而言,它将 H x W 的特征图均匀地划分为多个大小为 M x M 的不重叠窗口(为简化,假设 H 和 W 可被 M 整除)。然后,自注意力计算独立地在每个窗口内进行。

复杂度定量分析
对于一个 h x w 的token矩阵,其特征维度为 C

  • • 全局多头自注意力 (MSA) 的复杂度为:
    Ω(MSA) = 4hwC² + 2(hw)²C
    其中,4hwC² 来自生成Q、K、V的线性投影,2(hw)²C 来自注意力矩阵的计算((hw)² 是核心问题)。
  • • 基于窗口的多头自注意力 (W-MSA) 的复杂度为:
    Ω(W-MSA) = 4hwC² + 2M²hwC
    第一项线性投影开销不变。第二项中,由于每个窗口有  个token,每个窗口的计算复杂度为 O((M²)²) = O(M⁴)。总共有 (hw / M²) 个窗口,因此总复杂度为 (hw / M²) * M⁴ * C = 2M²hwC

比较: 当窗口大小 M 固定时(论文中默认设为7),W-MSA的复杂度与token数量 hw 呈线性关系,而MSA是二次方关系。这使得Swin Transformer能够高效处理高分辨率图像。

下图直观对比了两种注意力机制的计算范围差异,这是理解Swin Transformer效率提升的关键:

标准 ViT(左)与 Swin Transformer(右)的自注意力操作对比

上图对比了标准ViT(左)和Swin Transformer(右)的自注意力计算范围。在标准ViT中,每个token(灰色方块)需要与图像中所有其他token计算注意力(红色大框所示),导致O(N²)复杂度。在Swin Transformer中,特征图被划分为不重叠的窗口(红色小框),每个token仅与同一窗口内的其他token计算注意力,将复杂度降至O(N)

4.2 移位窗口:实现跨窗口信息交互

尽管窗口注意力极大地提升了效率,但它也引入了一个新问题:一个窗口内的token无法与其他窗口的token直接交互。这相当于在每个窗口周围筑起了“信息高墙”,限制了模型的全局建模能力。

为了解决这个问题,Swin Transformer 提出了 移位窗口注意力。其策略不是拆掉“墙”,而是在连续的层中巧妙地“开门”。具体而言,它在两个连续的Swin Transformer Block中使用不同的窗口划分方式:

  • • 第 l 个 Block: 使用常规窗口划分,起始于 (0, 0) 坐标。
  • • 第 l+1 个 Block: 使用移位窗口划分,将特征图在水平和垂直方向各循环移位 (⌊M/2⌋, ⌊M/2⌋) 个像素后,再应用常规的 M x M 窗口划分。

这个设计的精妙之处在于:第 l 个 Block中相邻窗口的边界区域,在第 l+1 个 Block中会因为移位而被包含进同一个新窗口内部。这样,通过一层常规窗口注意力加一层移位窗口注意力的组合,原本被隔开的两个窗口的token就能在下一层进行直接的注意力交互。多层堆叠后,信息理论上可以传播到特征图的任何位置。

下图通过一个 8x8 特征图(M=4)的Excel示意图,展示了移位操作如何创建新的窗口连接模式:

移位操作

上图以Excel表格形式演示了移位窗口的机制。左侧为“常规划分”,特征图被划分为4个4x4的非重叠窗口(A, B, C, D)。右侧为“移位划分(2,2)”,即将特征图向左上角各移动2格后,再进行同样的4x4窗口划分。可以看到,新形成的窗口(编号0-8)包含了来自原始不同窗口(A, B, C, D)的区域。例如,新窗口4包含了原始窗口A的右下角、B的左下角、C的右上角和D的左上角。这种机制使得原本不连通窗口的token在下一层得以交互。

4.3 高效批处理:循环移位与注意力掩码

直接实现上述移位划分会带来一个工程上的难题:窗口数量增加且大小不一致。例如,对于 8x8 的图,常规划分有4个窗口;移位 (2,2) 后, naive的划分会产生9个较小的、尺寸不一的窗口。这会导致无法进行整齐的批处理计算,降低GPU利用率。

Swin Transformer 采用了一个堪称经典的解决方案:循环移位 (Cyclic Shift) + 注意力掩码 (Attention Mask)。该方案的目标是:在保持批处理窗口数量不变且尺寸统一的前提下,模拟移位窗口划分的注意力模式。

其步骤如下,下图完整展示了这一过程:

Swin 中移位窗口块的示意图

上图详细说明了高效批处理方案“循环移位+注意力掩码”的原理。该图展示了一个8x8特征图的处理过程:(a) 原始特征图;(b) 应用(2,2)循环移位后,像素被“卷”到对侧,形成一个视觉上错位但尺寸不变的图;(c) 对移位后的图应用常规4x4窗口划分,此时每个窗口内包含了来自原始图中不相邻区域的像素。图中用数字0-8标记了在理想移位划分下应属于同一窗口的区域。可以看到,实际划分出的窗口中混合了不同数字(即不同区域)的像素。(d) 为了在计算注意力时屏蔽掉不属于同一区域的连接,需要引入注意力掩码。计算完成后,再执行反向循环移位(e),将特征图恢复原状。这个过程在数学上等价于直接进行移位窗口划分,但实现了规则的批处理。

  1. 1. 循环移位: 对输入特征图应用 torch.roll(x, shifts=(-M//2, -M//2), dims=(1, 2))。这会将左上角 M//2 区域的数据“卷”到右下角,形成一个看似“错位”但尺寸不变的 8x8 新图。
  2. 2. 常规窗口划分: 对移位后的特征图应用常规的 MxM 非重叠划分。此时,我们仍然得到4个 4x4 的规则窗口,非常适合批处理。关键点在于:现在每个窗口内部包含了来自原始特征图中不相邻区域(即本应属于不同移位窗口)的token。
  3. 3. 带掩码的注意力计算: 为了阻止那些在原始移位划分方案中不属于同一窗口的token之间进行注意力计算,需要引入一个预计算的二进制注意力掩码。在计算注意力权重(Softmax(QK^T/√d))之前,将这个掩码加到 QK^T 的结果上。掩码在允许连接的位置设为0,在需要屏蔽的位置设为一个很大的负数(如 -100)。这样,在后续的Softmax中,被屏蔽位置的权重就会趋于零。
  4. 4. 反向循环移位: 完成窗口注意力计算后,对输出特征图执行反向的 torch.roll(x, shifts=(M//2, M//2), dims=(1, 2)),将数据移回原始位置。

掩码生成原理: 在模型初始化时,对于给定的输入分辨率和窗口/移位大小,就可以预先计算好一个 attn_mask

  1. 1. 创建一个与移位后特征图同尺寸的整数矩阵 img_mask,并根据理想的移位窗口分区,为每个区域分配一个唯一的ID。例如,上图(c)中,左上角窗口ID为0,右上角窗口ID为1,左下角为2,右下角为3,中间四个小区域ID分别为4,5,6,7,边缘条带ID为8。
  2. 2. 将这个 img_mask 用 window_partition 函数划分为窗口,得到形状为 (nW, M, M) 的窗口掩码。
  3. 3. 通过广播操作计算 attn_mask = window_mask.unsqueeze(1) - window_mask.unsqueeze(2),得到一个形状为 (nW, M*M, M*M) 的矩阵。在这个矩阵中,同一区域相减为0,不同区域相减不为0
  4. 4. 最后,将 attn_mask 中不等于0的位置设为 -100,等于0的位置设为 0。这个掩码将在所有批处理数据中共享。

5. Swin Transformer Block 实现精讲

5.1 结构总览:W-MSA 与 SW-MSA 的交替与代码共享

一个 Swin Transformer Block 的结构遵循了Transformer编码器的经典设计,即“注意力-前馈网络” (Attention + FFN) 的子层结构,并配有残差连接 (Residual Connection) 和层归一化 (Layer Normalization)。其与标准Transformer块最主要的区别,就是用 W-MSA 或 SW-MSA 模块替换了标准的 MSA 模块。

一个关键的设计亮点是,常规窗口(W-MSA)和移位窗口(SW-MSA)的Block共享同一套代码实现。它们都是同一个 SwinTransformerBlock 类的实例,区别仅在于初始化时传入的 shift_size 参数:当 shift_size=0 时,该Block执行常规窗口注意力;当 shift_size=window_size//2 时,该Block执行移位窗口注意力。这种设计极大提高了代码的复用性和简洁性。

下图清晰地展示了这两个连续Block的结构与数据流,它们是构成Swin Transformer Stage的基本单元:

两个连续的 Swin Transformer Block

上图展示了构成Swin Transformer核心的连续两个Block。第一个Block使用常规窗口划分(W-MSA),其“Window Partition”模块将输入特征图划分为规则的非重叠窗口。第二个Block使用移位窗口划分(SW-MSA),在“Window Partition”前加入了“Shifted Window”操作,即循环移位。两个Block都包含层归一化(LN)、窗口多头自注意力(Window Attention)、残差连接(粉色箭头)和多层感知机(MLP)。正是这种“W-MSA + SW-MSA”的交替配对,使得模型既能高效计算,又能建立跨窗口的全局依赖。

5.2 关键运算分解:reshapepermuteroll 与 mask

1. 窗口划分 (window_partition) 与还原 (window_reverse)
这两个函数是连接常规张量和批处理窗口视图的桥梁。


   
    
   def window_partition(x, window_size):
    """
    将特征图划分为非重叠窗口。
    输入: x (B, H, W, C)
    输出: windows (B * num_windows, window_size, window_size, C)
    """

    B, H, W, C = x.shape
    # 1. 添加窗口网格维度: 将H和W分别分解为【窗口数量】和【窗口大小】

    # 结果形状: (B, H//M, M, W//M, M, C)

    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    
    # 2. 维度置换: 将窗口索引维度 (H//M, W//M) 移到批次维度附近

    # 目标: (B, H//M, W//M, M, M, C)

    # permute(0, 1, 3, 2, 4, 5) 将第2维(M)和第3维(W//M)交换

    windows = x.permute(0, 1, 3, 2, 4, 5)
    
    # 3. 合并前三维 (B, num_windows_h, num_windows_w) -> (B * num_windows)

    # contiguous()确保内存连续, view操作安全

    windows = windows.contiguous().view(-1, window_size, window_size, C)
    return
 windows

def
 window_reverse(windows, window_size, H, W):
    """
    window_partition的逆操作。
    输入: windows (B * num_windows, M, M, C)
    输出: x (B, H, W, C)
    """

    B = windows.shape[0] // ((H * W) // (window_size * window_size))
    # 1. 将窗口视图重塑回 (B, H//M, W//M, M, M, C)

    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    # 2. 维度置换,恢复空间顺序: (B, H//M, W//M, M, M, C) -> (B, H//M, M, W//M, M, C)

    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    # 3. 合并最后的两个M维度,恢复(H, W)

    x = x.view(B, H, W, -1)
    return
 x

运算意义permute 是这里的关键。在 window_partition 中,它把代表窗口内部空间的 M 维度挪到一起;在 window_reverse 中,它再把它们放回 H 和 W 维度中间的正确位置。这些操作都是 O(1) 复杂度的,仅改变张量的步长(stride)信息,而不实际移动数据。

2. 循环移位 (torch.roll)


   
    
   # 在forward中
if
 self.shift_size > 0:
    # 移位: 将特征图向左上角循环移动

    shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
else
:
    shifted_x = x
# ... 经过窗口注意力和还原后 ...

if
 self.shift_size > 0:
    # 反向移位: 移回原处

    x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))

torch.roll 是PyTorch提供的循环移位函数。shifts=(-2, -2) 表示在第1维(H)和第2维(W)上分别向上和向左移动2个单位。移出边界的元素会从另一端重新出现。这是实现“无缝”跨窗口连接且不丢失任何信息的数学基础。

5.3 注意力计算核心:Batch MM 与带掩码的 Softmax

在 WindowAttention 模块内部,注意力权重的计算是核心。

1. 批处理矩阵乘法 (torch.bmm 或 @)
经过 window_partition 和展平后,输入张量形状为 (B*nW, M*M, C),其中 B*nW 可以被视为新的批大小。注意力机制中的 QKV 投影和 QK^T 计算都利用批处理矩阵乘法高效完成。


   
    
   # 假设 q, k 的形状为 (batch, seq_len, head_dim)
attn = (q @ k.transpose(-2, -1)) * self.scale  # @ 操作符执行批处理矩阵乘法

这里,@ 操作符(或等价的 torch.matmul)会自动在 batch 维度上进行广播,一次性计算出所有窗口、所有注意力头的 QK^T 结果。这比在循环中逐个窗口计算要高效数个数量级,充分利用了GPU的并行计算能力。

2. 带掩码的 Softmax
对于 SW-MSA,需要应用预计算的注意力掩码 attn_mask


   
    
   if mask is not None:
    # mask 形状通常为 (nW, M*M, M*M) 或 (1, nW, 1, M*M, M*M)以便广播

    # mask 中,允许连接处为0,需要屏蔽处为一个很大的负数(如 -100)

    attn = attn + mask

attn = self.softmax(attn)  # 沿最后一个维度做Softmax
attn = self.attn_drop(attn)

掩码原理详解: mask 不是二进制的 0/1 矩阵,而是一个“惩罚”矩阵。在需要屏蔽的位置,mask 值为一个极大的负数(-100)。当这个 mask 与注意力分数 attn(在乘以 scale 后,数值通常不大)相加时,被屏蔽位置的分数会变得非常小。随后,Softmax 函数会对每一行进行指数归一化。由于 exp(-100) 近乎为0,这些被屏蔽位置对应的注意力权重在输出中也就基本为0,从而实现了精确的、可微的屏蔽效果。

完整前向传播流程整合


   
    
   def forward(self, x):
    H, W = self.input_resolution
    B, L, C = x.shape
    assert
 L == H * W, "input feature has wrong size"
    shortcut = x

    x = self.norm1(x)
    x = x.view(B, H, W, C)

    # 1. 循环移位 (仅在SW-MSA时进行,即shift_size > 0)

    if
 self.shift_size > 0:
        shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
    else
:
        shifted_x = x

    # 2. 窗口划分

    x_windows = window_partition(shifted_x, self.window_size)  # (nW*B, M, M, C)
    x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # (nW*B, M*M, C)

    # 3. 窗口注意力 (W-MSA/SW-MSA)

    # self.attn 内部包含了QKV投影、带掩码的Batch MM、Softmax和输出投影

    # 当shift_size=0时,传入的mask为None,即为W-MSA。

    # 当shift_size>0时,传入预计算的attn_mask,即为SW-MSA。

    attn_windows = self.attn(x_windows, mask=self.attn_mask)  # (nW*B, M*M, C)

    # 4. 窗口还原

    attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
    shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # (B, H, W, C)

    # 5. 反向循环移位 (仅在SW-MSA时进行)

    if
 self.shift_size > 0:
        x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
    else
:
        x = shifted_x

    x = x.view(B, H * W, C)

    # 6. FFN (前馈网络)

    x = shortcut + self.drop_path(x)  # 第一个残差连接
    x = x + self.drop_path(self.mlp(self.norm2(x)))  # 第二个残差连接
    return
 x

6. 总结

Swin Transformer 通过两项核心创新,成功地将 Transformer 架构高效地适配于广泛的视觉任务,并在效率与性能之间取得了卓越的平衡:

  1. 1. 层次化架构与图像块合并: 通过引入 图像块合并层,Swin Transformer 构建了与CNN相似的特征金字塔,使其能够直接作为骨干网络,无缝替换现有目标检测、分割等框架中的CNN。其设计与 SpaceToDepth 思想一脉相承,但通过可学习的线性投影增强了表达能力。
  2. 2. 移位窗口注意力
    • • 效率突破: 窗口自注意力 (W-MSA) 将全局计算的复杂度从 O(N²) 降至 O(N),使Transformer处理高分辨率图像成为可能。
    • • 效能保障: 移位窗口注意力 (SW-MSA) 通过 循环移位 (torch.roll) 与 注意力掩码 相结合的巧妙设计,在几乎不引入额外计算开销的前提下,恢复了跨窗口的信息流,保障了模型的全局建模能力。其高效实现深度依赖于 批处理矩阵乘法 (@)带掩码的 Softmax 以及通过 reshape/permute 实现的张量视图变换。
    • • 优雅实现: 常规与移位两种注意力模式通过同一个 SwinTransformerBlock 类实现,仅由 shift_size 参数控制,体现了出色的代码设计。

Swin Transformer 的设计深刻影响了后续的视觉Transformer研究。其“局部注意力+层次化结构”的设计范式被众多模型所借鉴,证明了在视觉任务中,将Transformer的全局动态建模能力与适当的、来源于图像的归纳偏置相结合,是一条行之有效且前景广阔的技术路径。理解Swin Transformer,不仅是掌握了一个强大的模型,更是把握了视觉Transformer领域一个关键的设计思想脉络。

 


【声明】内容源于网络
0
0
ai算法芯片与系统
长期关注ai领域,算法,芯片,软件(系统,框架,编译器,算子库)等联合设计
内容 196
粉丝 0
ai算法芯片与系统 长期关注ai领域,算法,芯片,软件(系统,框架,编译器,算子库)等联合设计
总阅读118
粉丝0
内容196