摘要:通用矩阵乘法(GEMM)在GPU上的优化是一个模块化问题。高性能实现需要指定诸多超参数,例如分块形状、数学与拷贝指令以及线程束专用化方案。这些超参数在很大程度上相互独立;并且,最佳选择可能因硬件、问题形状或其他用户需求而有显著差异。通过3.x版本的重构,CUTLASS旨在通过一个分层、可组合、正交的构建块系统,最大化地覆盖GEMM实现的配置空间,同时提升代码可读性并扩展对Hopper和Blackwell等后续NVIDIA架构的支持。本文将深入探讨CUTLASS 3.x中层次化GEMM系统的设计原则,并解析其如何基于第一篇介绍的底层CuTe抽象构建GEMM核函数。
目录
-
• CUTLASS 3.x中的新概念性GEMM层次结构 -
• Collective层:主循环 -
• Collective层:结尾 -
• 核函数(Kernel)层 -
• 设备(Device)层 -
• 总结
CUTLASS 3.x中的新概念性GEMM层次结构
CUTLASS 3.x发展出了一套独立于特定硬件特性的概念性GEMM层次结构。它被组织为五个层级:
该结构由内向外,从硬件指令抽象逐步组合为完整的设备端核函数:
以下是调整后的表格,每行内容限制在20个字符以内:
|
|
|
|
|---|---|---|
| 原子(Atom)层 | cute::Mma_Atom<><br>
cute::Copy_Atom<>
|
及其关联的元信息。 这是与硬件直接对话的 最小单元。 |
| 分片MMA/拷贝(Tiled MMA/Copy)层 | cute::TiledMma<>
cute::TiledCopy<>
|
空间上的微核函数
允许对架构特定的原子操作 进行任意的交错排列和分块, 以在多个线程/线程束间 分配工作。 |
| 集合(Collective)层 | cutlass::gemm::collective::CollectiveMma<>
cutlass::epilogue::collective::CollectiveEpilogue<>
|
时间上的微核函数
利用架构特定的同步原语 来编排一个或多个空间微核 函数的执行,以计算单个 输出分块。通常对应一个 线程块或集群。 |
| 核函数(Kernel)层 | cutlass::gemm::kernel::GemmUniversal<> |
设备端代码
用于在覆盖整个问题空间的 线程块/集群网格上执行 核函数。负责分块调度和 核函数入口逻辑。 |
| 设备(Device)层 | cutlass::gemm::device::GemmUniversalAdapter<> |
主机端设置与接口
负责核函数参数管理、 工作空间分配、启动配置 以及可复用句柄的提供。 |
每一层都作为前一层抽象的组合点,这些抽象可以通过模板参数进行高度定制。用户既可以停留在最高层,信赖CUTLASS的编译时逻辑来提供高性能的GEMM实现;也可以深入到较低层次,使用该层次暴露的高级修改能力。
由原子层和分片MMA/拷贝层提供的空间微核函数是CuTe的领域(在系列第一部分已讨论)。本文将重点介绍由更高层级实现的时间组织和核函数级编排。
以下是一个在CUTLASS 3.x中定义GEMM核函数的基本示例,它清晰地展示了自底向上的组合过程:
// 步骤 1:使用Collective Builder生成所需的主循环集合体特化
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass, // 架构与运算类(如Sm90, TensorOp)
ElementA, LayoutA, AlignmentA, // 矩阵A的数据类型、内存布局、对齐
ElementB, LayoutB, AlignmentB, // 矩阵B的数据类型、内存布局、对齐
ElementAccumulator, // 累加器数据类型
TilesShape, ClusterShape, // 分块形状与集群形状
cutlass::gemm::collective::StageCountAuto, // 流水线阶段数(自动)
cutlass::gemm::collective::KernelScheduleAuto // 核函数调度策略(自动)
>::CollectiveOp;
// 步骤 2:指定结尾集合体类型
using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue<
cutlass::gemm::TagToStrideC_t<LayoutC>, // 矩阵C的步长
cutlass::gemm::TagToStrideC_t<LayoutC>, // 矩阵D的步长
cutlass::epilogue::thread::LinearCombination<ElementC, 1, ElementAccumulator, ElementAccumulator>
>; // 融合操作:D = alpha * AB + beta * C
// 步骤 3:在核函数层将主循环和结尾组合在一起
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>, // 问题形状 [M, N, K, L]
CollectiveMainloop,
CollectiveEpilogue
>;
// 步骤 4:用设备适配器包装核函数类,获得一个主机端可用的核函数句柄
using GemmHandle = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
Collective层:主循环
一个 “集合体” 是一组相互协作以执行工作的线程,并且可以被并行重复以形成整个核函数。通常,这对应一个线程块或集群。TiledMMA和TiledCopy对象描述的是计算和拷贝工作在并行工作者间的空间分配;而Collective层则负责组织这些工作的时间序列,包括设置流水线、线程束专用化方案,以及使用硬件加速的同步原语来管理流水线和异步操作。
CUTLASS 3.x的GEMM核函数包含一个主循环集合体,它是CollectiveMma类模板的一个实例,定义了一个集合体单次主循环迭代的基本要素,最重要的是加载和矩阵乘加(MMA)过程。其定义示例如下:
using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma<
DispatchPolicy, // 分发策略,决定算法与架构特化
TileShape, // 集合体输出的分块形状
ElementA, StrideA, // 矩阵A的数据类型和全局内存步长
ElementB, StrideB, // 矩阵B的数据类型和全局内存步长
TiledMma, // 分片MMA对象,定义计算的空间组织
GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, TransformA, // 矩阵A的加载路径
GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, TransformB // 矩阵B的加载路径
>;
主循环集合体是底层抽象的组合点:一个TiledMma、两个从全局内存到共享内存的TiledCopy操作,以及可选的用于寄存器源MMA的共享内存到寄存器的拷贝原子。这些抽象在很大程度上是正交的,允许不同的MMA操作与不同的拷贝操作结合,同时最大化代码复用。
最关键的部分是分发策略,它将CollectiveMma特化到特定的算法或GPU架构。例如,分发策略MainloopSm90TmaGmmaWarpSpecialized将CollectiveMma特化为适用于Hopper架构的TMA线程束专用化实现。它本身也是一个模板,可针对流水线阶段数、集群形状以及核函数调度策略进行参数化。
Collective Builder
尽管CollectiveMma提供了众多调优旋钮,允许用户精确指定GEMM主循环,但随之而来的是复杂性。通常,用户希望根据流水线、硬件能力和资源可用性等高层考虑来推导这些对象及相关SMEM布局。CUTLASS提供了CollectiveBuilder接口来完成这种推导。其使用示例如下:
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, // 架构标签,如 cute::arch::Sm90
OpClass, // 运算类,如 cute::arch::OpClassTensorOp
ElementA, LayoutA, AlignmentA,
ElementB, LayoutB, AlignmentB,
ElementAccumulator,
TileShape, ClusterShape,
StageCount, // 流水线阶段数,例如 StageCountAuto
KernelSchedule // 核函数调度,例如 KernelScheduleAuto
>::CollectiveOp;
CollectiveBuilder根据用户友好的条件选择,并自动推导出CollectiveMma模板所需的底层参数。
Collective层:结尾
结尾集合体是Collective API的另一半。它负责在主循环每次迭代后,对工作分块进行后处理并存储输出。与主循环类似,这意味着结尾集合体也是一个组合点,组合了拷贝操作和数学运算。不同的是,这些数学运算本身可以通过 Epilogue Visitor Tree (EVT) 范式进行高度组合。这对于AI工作负载尤其有用,因为它们经常需要在GEMM之后立即计算激活函数。CUTLASS的结尾集合体处理了这种激活函数与核函数的融合,从而消除了不必要的数据移动。
CUTLASS提供了多个结尾实现。CollectiveBuilder为结尾提供了更统一的高层接口:
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OpClass,
TileShape, ClusterShape,
EpilogueTileType, // 结尾子分块,用于优化计算-拷贝重叠
ElementAccumulator, // 主循环输出(累加器)的数据类型
ElementCompute, // 结尾计算使用的中间数据类型
ElementC, GmemLayoutTagC, AlignmentC, // 矩阵C的信息
ElementD, GmemLayoutTagD, AlignmentD, // 矩阵D的信息
EpilogueScheduleType, // 结尾调度类型
FusionOpOrCallbacks // 融合操作或回调
>::CollectiveOp;
CUTLASS提供了丰富的预定义融合操作。用户也可以使用Epilogue Visitor Trees构建自定义的融合操作。
核函数(Kernel)层
Collective层完整定义了一个集合体在核函数执行期间的计算过程。而核函数层的任务则是将这些集合体扩展到一个覆盖整个动态问题空间的线程块或集群网格上。核函数层通过拼接主循环和结尾集合体的原始过程来组装设备核函数。
核函数层的主要入口API是cutlass::gemm::kernel::GemmUniversal类。它是一个无状态的通用设备核函数,通过组合一个主循环集合体和一个结尾集合体来实现GEMM。其基本用法如下:
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
ProblemShape, // 例如 Shape<int, int, int> 表示通用GEMM
CollectiveMainloop,
CollectiveEpilogue
>;
与TiledMma和TiledCopy一样,CollectiveMainloop和CollectiveEpilogue是正交的抽象,通过GemmUniversal组合在一起。
分块调度
核函数层也是指定分块调度器的组合点。正如核函数调度定义了集合体内部工作的时间组织,分块调度器定义了工作在不同集合体间的分配顺序。最基本的调度器是每个输出分块分配一个线程块。CUTLASS 3.x为Hopper实现了两种额外的调度器:一种是持久化调度器,每个SM启动一个线程块,每个线程块在其生命周期内计算多个输出分块;另一种是Stream-K调度器,它同样是持久化的,但为了更好地负载均衡,会将某些输出分块的工作沿K维度进行分割。在Blackwell架构上,则使用支持集群启动控制的调度器。我们可以扩展上述核函数以使用Stream-K调度器:
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int,int,int,int>,
CollectiveMainloop,
CollectiveEpilogue,
cutlass::gemm::StreamKScheduler // 指定Stream-K调度器
>;
设备(Device)层
与核函数启动相关的主机端逻辑,包括支持集群的启动、在不同设备或CUDA流上的启动,都在设备层实现。设备层的主要入口点是cutlass::gemm::device::GemmUniversalAdapter,它将一个GemmUniversal核函数包装在一个有状态的、可复用的句柄中。这使得用户可以方便地管理核函数参数、工作空间并执行启动。
以下示例展示了如何使用GemmUniversalAdapter启动核函数:
using GemmHandle = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using Arguments = typename GemmHandle::Arguments;
// 1. 构建参数
Arguments args {
cutlass::Gemm::kBatched, // 模式(此处为批处理GEMM)
cute::make_shape(M, N, K, L), // 问题形状
{A, stride_A, B, stride_B}, // 主循环参数
{{alpha, beta}, C, stride_C, D, stride_D}, // 结尾参数
make_kernel_hardware_info(device_id), // 硬件信息
{} // 调度器参数(默认)
};
GemmHandle gemm;
// 2. 检查可行性
cutlass::Status status = GemmHandle::can_implement(args);
if (status != cutlass::Status::kSuccess) { /* 处理错误 */ }
// 3. 分配工作空间
size_t workspace_size = GemmHandle::get_workspace_size(args);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// 4. 初始化句柄状态
status = gemm.initialize(args, workspace.get());
if (status != cutlass::Status::kSuccess) { /* 处理错误 */ }
// 5. 启动核函数
status = gemm.run(); // 可在此处指定CUDA流
if (status != cutlass::Status::kSuccess) { /* 处理错误 */ }
总结
本文探讨了CUTLASS库如何被概念化为一个层次化结构,其中每一层的对象都由更低层的正交对象组合而成。这种设计使得大量可深度定制的GEMM实现能够以高水平的代码复用得以实现。其核心优势在于:
-
• 正交性与可组合性:计算、拷贝、流水线、调度等关注点分离,允许独立选择和优化。 -
• 层次化抽象:从硬件指令到完整核函数,每一层解决特定问题,降低了复杂度。 -
• 灵活性与自动化:高级Builder接口提供“自动”选项,而底层模板暴露所有调优旋钮,满足从快速原型到极致性能的所有需求。 -
• 面向未来架构:该设计已成功扩展到Hopper和Blackwell等新架构,证明了其生命力。
这种设计哲学不仅适用于GEMM,也为其他GPU计算模式提供了蓝图。在下一篇文章中,我们将探讨CUTLASS 4.0引入的变化,特别是CuTe Python DSL,它进一步将这种强大的抽象能力带入了Python生态。
本文翻译重构自CUTLASS 3.x: Orthogonal, Reusable, and Composable Abstractions for GEMM Kernel Design[1]
引用链接
[1] CUTLASS 3.x: Orthogonal, Reusable, and Composable Abstractions for GEMM Kernel Design: https://developer.nvidia.com/blog/cutlass-3-x-orthogonal-reusable-and-composable-abstractions-for-gemm-kernel-design/

