大数跨境

(二)CUTLASS —— 掌握 Tensor Memory Accelerator(TMA)

(二)CUTLASS —— 掌握 Tensor Memory Accelerator(TMA) AI Infra之道
2026-06-14
2
导读:简介TMA用于在GPU的全局内存和线程块的共享内存之间进行异步内存拷贝,支持拷贝一维数组和多维数组,最高支持

简介

TMA用于在GPU的全局内存和线程块的共享内存之间进行异步内存拷贝,支持拷贝一维数组和多维数组,最高支持 5 维。

TMA的优势是:

  • 通过异步机制支持 warp-specialized kernel schedule,从而提升 GPU 利用率;
  • 通过 TMA copy descriptor,以单线程方式处理地址、stride 等辅助拷贝信息的计算,因此更加节省寄存器,并且天然能够处理 predication,例如越界检查。

TMA 典型场景 Prefetching Data

在一种迭代式的copy + compute模式中,这样做可以用当前迭代的计算来隐藏未来迭代的数据传输延迟,从而潜在地增加正在传输中的数据量,也就是 bytes-in-flight。TMA底层是DMA进行数据传输,不需要GPU和CPU参数数据搬运过程,此时可以按batch实现数据预期。

fetch_batch = compute_batch + num_stages

也就是,预取始终领先当前计算 num_stages 个 batch,例如当nums_stahes = 2时:

正在计算 batch 0,同时预取 batch 2正在计算 batch 1,同时预取 batch 3正在计算 batch 2,同时预取 batch 4

TMA Load

使用 TMA load 的 kernel 和使用其他内存拷贝方式的 kernel 有很大不同。因此,我们会先通过一个简单示例任务展示如何编写这样的 kernel,然后再解释其中涉及的概念。

示例任务

为了演示 TMA load 的用法,我们考虑一个简单任务:对一个二维 row-major 矩阵进行 tiling。

我们用numpy描述这种tiling。

A = np.random.uniform(M, N)for i in range(M):  for j in range(N):    cta_i_j = A.reshape(        M // CTA_M, CTA_M,        N // CTA_N, N    )[i, :, j, :]

Host code

我们要创建三个对象

  • 要从中拷贝数据的Gemm tensor;
  • 每个CTA内部用于接受数据的SMEM tensor layout;
  • 一个以这两个对象作为参数的tma_load对象;

一旦创建好这些对象,就可以将他们传递给device端的kernel。在kernel内部,会真正调用TMA load操作。

host端完整代码

template <typename T, int CTA_M, int CTA_N>void host_fn(T* data, int M, int N) {  using namespace cute;
  // create the GMEM tensor  auto gmem_layout = make_layout(make_shape(M, N), LayoutRight{});  auto gmem_tensor = make_tensor(make_gmem_ptr(T), gmem_layout);
  // create the SMEM layout  auto smem_layout = make_layout(make_shape(CTA_M, CTA_N), LayoutRight{});
  // create the TMA object  auto tma_load = make_tma_copy(SM90_TMA_LOAD{}, gmem_tensor, smem_layout);
  // invoke the kernel  tma_load_kernel<CTA_M, CTA_N>                 <<<dim3{M / CTA_M, N / CTA_N, 1}, 1>>>                 (tma_load, gmem_tensor, smem_layout);}

上面代码中tma_load对象是cute::make_tma_copy函数显示默认创建的。建议使用这种默认形式,以避免引入不必要的复杂性。

Kernel code

相关的 kernel 代码片段如下。这些代码行包含了许多重要的 TMA 概念,后面会逐一解释。

template <typename T, int CTA_Mint CTA_Nclass TmaLoad, class GmemTensor>void tma_load_kernel(__grid_constant__ const TmaLoad tma_load, GmemTensor gmem_tensor) {  using namespace cute;  constexpr int tma_transaction_bytes = CTA_M * CTA_N * sizeof(T);
  __shared__ T smem_data[CTA_M * CTA_N];  __shared__ uint64_t tma_load_mbar;  // host端创建tensor: engine是一段SMEM,layout是CTA_TILE  auto smem_layout = make_layout(make_shape(CTA_MCTA_N), LayoutRight{});  auto smem_tensor = make_tensor(make_smem_ptr(smem_data), smem_layout);
  if (threadIdx.x == 0) {    // 保存Gemm tensor的坐标    auto gmem_tensor_coord = tma_load.get_tma_tensor(shape(gmem_tensor));    // 保存Gemm cta tensor的坐标    auto gmem_tensor_coord_cta = local_tile(        gmem_tensor_coord,        Tile<Int<CTA_M>, Int<CTA_N>>{},        make_coord(blockIdx.x, blockIdx.y));
    initialize_barrier(tma_load_mbar, /* arrival count */ 1);
    set_barrier_transaction_bytes(tma_load_mbar, tma_transaction_bytes);
    auto tma_load_per_cta = tma_load.get_slice(0);    copy(tma_load.with(tma_load_mbar),         tma_load_per_cta.partition_S(gmem_tensor_coord_cta),         tma_load_per_cta.partition_D(smem_tensor));  }  __syncthreads();  wait_barrier(tma_load_mbar, /* phase */ 0);
  // after this line, the TMA load is finished}

如果我们有两个tensor需要从GEMM拷贝SMEM,那么每个tensor都必须有自己独立的TileCopy实例,并且每个实例都必须是__grid__ constant const,这是将cuTensorMap从host传递到device时的要求。下一个重点是:对于一次TMA copy,只有一个线程负责发起TMA操作。

Memory barrier

tma_load_mbar是一个异步事务barries,用于同步TMA load和kernel中后续消费SMEM数据部分。cute方法的initialize_barrier封装了PTX指令:

mbarrier.init.shared.b64

这个指令还需要一个额外的arrival_count参数。在我们的场景中,由于只有一个线程会发起 TMA load,所以 arrival count 应该设置为 1。此外,mbarrier 的初始 phase 总是被设置为 0 。

set_barrier_transaction_bytes封装了PTX指令:

mbarrier.arrive.expect_tx.shared::cta.b64

wait_barrier执行对mbarrier对象的wait操作。需要注意的是,所有线程都等待这个mbarrier,这和只有thread 0对arrive操作执行对比。

另外wait_barrier前调用__syncthreads()是必须的,因为它用于处理之前的线程分歧。

parity修饰符表示等待时需要提供一个phase bit,线程会睡眠,直到mbarrier的这个phase bit翻转。由于这是mbarrier初始化后第一次用于跟踪完成状态,所以我们传入的phase是0。如果同一个kernel需要执行多次TMA拷贝,并复用这个mbarrier,就需要翻转这个phase。

在 wait_barrier 之后,内存一致性模型会提供保证,TMA load写入SMEM的数据,对于此CTA中所有调用了mbarrier_wait的线程都可见。

使用 TMA 处理 remainder tiles 以及 stride 要求

上面的例子中,我们假设 m % CTA_M == 0; n % CTA_N == 0; 不过对于 TMA load 来说,这个假设其实可以完全去掉。当从GMEM向SMEM加载remainder tile时,我们不需要字节处理越界逻辑。TMA copy单元会自动对内存拷贝进行predication(mask操作),确保不会读取越界地址。

由于普通的cute tensor,更像是普通的指针,按照指针访问这段地址存在越界风险,而TMA通过坐标形式表示tensor,并结合硬件predication,可以避免这类问题。

不过对于TMA来说,需要特别注意GMEM本身的stride的要求,也就是16-byte boundary requirement。正如预期的一样,TMA并不支持拷贝任意stride的GMEM区域。相反,我们需要假设被拷贝的tile满足:

  • 存在一个连续方向,也就是stride=1;
  • 其他方向的stride必须是16 byte对齐;

如果不满足这个条件,需要在调用kernel之前,对tensor做padding。

TMA Store

和 TMA load 类似,实现 TMA store 也是一个两步流程:

在 host 端定义 TMA copy descriptor

在 kernel 内部发起 TMA store 操作

template <typename T, int CTA_M=32int CTA_N=32>void host_fn(T* data, int M, int N) {  using namespace cute;  // 创建一个GMEM tensor  auto gmem_layout = make_layout(make_shape(M, N), LayoutRight{});  auto gmem_tensor = make_tensor(make_gmem_ptr(T), gmem_layout);  // 创建一个 SMEM layout  auto smem_layout = make_layout(make_shape(CTA_MCTA_N), LayoutRight{});  // 创建 TMA object  auto tma_store = make_tma_copy(SM90_TMA_STORE{}, gmem_tensor, smem_layout);  // 调用tma store kernel  tma_store_kernel<CTA_MCTA_N>                  <<<dim3{M / CTA_M, N / CTA_N1}, CTA_M>>>                  (tma_store, gmem_tensor, smem_layout);}template <typename T, int CTA_Mint CTA_Nclass TmaStore, class GmemTensor>void tma_store_kernel(__grid_constant__ const TmaStore tma_store, GmemTensor gmem_tensor) {  using namespace cute;  __shared__ T smem_data[CTA_M * CTA_N];
  auto smem_layout = make_layout(make_shape(CTA_MCTA_N), LayoutRight{});  auto smem_tensor = make_tensor(make_smem_ptr(smem_data), smem_layout);
  // 验证数据:每个线程把自己的threadIdx.x写到smem_data的一整行里。  for (int j = 0; j < CTA_N; ++j) {    smem_tensor(threadIdx.x, j) = threadIdx.x;  }  // 确保所有线程写SMEM完成  __syncthreads();  tma_store_fence();
  if (threadIdx.x == 0) {    auto gmem_tensor_coord = tma_store.get_tma_tensor(shape(gmem_tensor));
    auto gmem_tensor_coord_cta = local_tile(      gmem_tensor_coord,      Tile<Int<CTA_M>, Int<CTA_N>>{},      make_coord(blockIdx.x, blockIdx.y));
    auto tma_store_per_cta = tma_store.get_slice(0);    copy(tma_store,         tma_store_per_cta.partition_S(smem_tensor),         tma_store_per_cta.partition_D(gmem_tensor_coord_per_cta));    // tma_store_arrive();  }  // tma_store_wait<0>();}

TMA load和TMA store代码间有个重要的区别是:TMA store中,不再使用mbarrier对象,而是使用另一种机制保证内存一致性:memory fence。

memory fence的作用是在执行该fence的线程前后发起的内存访问之间,建立一个有保证的顺序关系。

TMA Store arrive 和 wait

tma_store_arrive()会提交TMA store操作,更准确的说,它提交的是一个cp.async.bulk-group,而tma_store_wait()会等待,直到已提交但未完成的TMA store操作数最多剩下count个。若希望所有提交的TMA store都完成,就要设置count=0。由于TMA store是异步的,如果kernel内还有其他工作依赖TMA store完成,或者需要复用这段SMEM,这种同步机制很有必要。

深入理解 TMA 操作


TMA LOAD
TMA STORE
Direction
GMEM → SMEM
SMEM → GMEM
Sync method
Memory barrier
Proxy fence
When to sync
操作之后
操作之前

上表中对比了TMA load和store操作,无论哪种操作,我们都需要在host code中通过make_tma_copy方法创建一个类似TiledCopy对象,然后将这个对象传入kernel函数。在kernel中,我们使用它们调用cute::copy,从而真正发起对应的TMA操作。

本节中,我们深入研究,当kernel中调用TiledCopy对象时,底层到底发生了什么。

TMA Load 和 Store 的 PTX 指令

cute会利用这个tensor和layout,决定当kernel调用:

cute::copy(tma_load, ...)

底层的PTX指令是:

asm volatile (  "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes"  " [%0], [%1, {%3%4}], [%2];"  :  : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),    "r"(crd0), "r"(crd1)  : "memory");

所有TMA的操作,本质是对cp.async.bulk的封装。

TMA Store Reduce

若想处理Reduce类操作,我们先用python模拟这一过程:

for cta_idx in range(number_of_ctas):  gmem_dst[cta_idx] += smem_src[cta_idx]  # 或者这样:  gmem_dst[cta_idx] = max(gmem_dst[cta_idx], smem_src[cta_idx])  # 或者这样:  gmem_dst[cta_idx] = min(gmem_dst[cta_idx], smem_src[cta_idx])

一个普通的reduce实现,是把CTA的SMEM的值,累加到GMEM tensor的某个tile中,这个过程会分为散步:(一)一次GMEM load;(二)一次reduce计算;(三)一次GMEM store;这三步往往很慢。我们修改一下TMA store的TileCopy对象构造方式,就可以把三步过程压缩成一条PTX指令,即:

cp.reduce.async.bulk 代替原来的cp.async.bulk

具体来说,我们只需要在 host code 中做如下的一行修改:

// original: create a TMA store objectauto tma_store = make_tma_copy(SM90_TMA_STORE{}, gmem_tensor, smem_layout);// to create a TMA reduce sum objectauto tma_reduce_sum = make_tma_copy(SM90_TMA_REDUCE_ADD{}, gmem_tensor, smem_layout);

然后使用 tma_reduce_sum 即可,这样底层调用的就是优化后的TMA指令,将上述的三个过程,由硬件集成到一条指令,由硬件实现"读-算-存"过程。

TMA Load Multicast

正如名字所示,Multicast是指这样的情况:

我们希望把GMEM tensor中的同一个tile,拷贝到多个CTA的SMEM中,例如:

  • 一个输入矩阵的 column tile,可能会被多个 row tile 使用;
  • 或者反过来,一个 row tile 可能会被多个 column tile 使用。

此时.multicast operand允许我们进一步保证L2 cache hit,从而更高效地把同一份GMEM数据分发给同一个cluster内的多个CTA。

从普通TMA load扩展到multicast

我们考虑把前面的 TMA load 示例扩展成 multicast 版本。

一组CTA想要共同参与一次TMA load multicast,它们必须属于同一个threadblock cluster。

为了保持例子简单,我们只修改grid维度和cluster维度:

// old grid dimensions and implicit trivial cluster dimensionsdim3 grid_dims = dim3{M / CTA_M, N / CTA_N, 1};dim3 cluster_dums = dim3{111};// new grid dimensions and cluster dimensionsdim3 grid_dims = dim3{M / CTA_M, N / CTA_N, 2};dim3 cluster_dums = dim3{112};

使用cluster时,cluster dimensions必须能够整除grid dimensions,否则kernel无法launch。

在新的kernel中,我们会让同一个GMEM tile被加载到同一个cluster内每个CTA的SMEM中。

Host code 的变化

// original: create a TMA load objectauto tma_load = make_tma_copy(SM90_TMA_LOAD{}, gmem_tensor, smem_layout);// new1: create a TMA load multicast object for the given cluster sizeauto tma_load = make_tma_copy(    SM90_TMA_LOAD_MULTICAST{},    gmem_tensor,    smem_layout,    cute::_2{});// new2: create a TMA load multicast object for the predefineusing ClusterShape = Shape<_1, _1, _2>;auto tma_load = make_tma_copy(    SM90_TMA_LOAD_MULTICAST{},    gmem_tensor,    smem_layout,    size<2>(ClusterShape{});  // 取ClusterShape第二维度的数据);

Kernel code 的变化

template <typename T, int CTA_M, int CTA_N, class ClusterShape,          class TmaLoadclass GmemTensor>void tma_load_kernel(__grid_constant__ const TmaLoad tma_load,                     GmemTensor gmem_tensor) {  using namespace cute;  // 跟踪CTA在cluster内部的index,在cute中会读取特殊寄存器 %cluster_ctarank  // 这个值下面简称为ctaid  uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();  constexpr uint32_t cluster_size = size<2>(ClusterShape{}));  constexpr uint16_t tma_mcast_mask = (uint16_t(1) << cluster_size) - 1;  constexpr int tma_transaction_bytes = CTA_M * CTA_N * sizeof(T);
  __shared__ T smem_data[CTA_M * CTA_N];  __shared__ uint64_t tma_load_mbar;
  auto smem_layout = make_layout(make_shape(CTA_M, CTA_N), LayoutRight{});  auto smem_tensor = make_tensor(make_smem_ptr(T), smem_layout);  auto gmem_tensor_coord = tma_load.get_tma_tensor(shape(gmem_tensor));  auto gmem_tensor_coord_cta = local_tile(        gmem_tensor_coord,        Tile<Int<CTA_M>, Int<CTA_N>>{},        make_coord(blockIdx.x, blockIdx.y));
  if (threadIdx.x == 0) {    initialize_barrier(tma_load_mbar, /* arrival count */ 1);  }  __syncthreads();  cute::cluster_sync();  cutlass::arch::fence_barrier_init();
  if (threadIdx.x == 0) {    set_barrier_transaction_bytes(tma_load_mbar, tma_transaction_bytes);    auto tma_load_per_cta = tma_load.get_slice(block_rank_in_cluster);    copy(tma_load.with(tma_load_mbar, tma_mcast_mask),         tma_load_per_cta.partition_S(gmem_tensor_coord_per_cta),         tma_load_per_cta.partition_D(smem_tensor));  }  __syncthreads();  wait_barrier(tma_load_mbar, /* phase */ 0);
  // after this line, the TMA load is finished
  cute::cluster_sync();}

在copy操作中,我们会传入一个uint16 bitmask,用来指定哪些CTA参与TMA multicast load。mask中bit为1的位置表示对应的CTA是active的。最多支持16个CTA,也就是maxumum nonportable size。

block_rank_in_cluster即ctaid会用来指定从当前CTA发起TMA multicast load时,切入GMEM的offset。即:

block_rank_in_clusterauto tma_load_per_cta = tma_load.get_slice(block_rank_in_cluster);

不同CTA使用不同的slice,即:

  • ctaid = 0 的 CTA 使用 get_slice(0);
  • ctaid = 1 的 CTA 使用 get_slice(1)。

这样两个 CTA 就会协作完成同一个 tile 的 multicast load。这样多个CTA的SMEM中都可正确得到完整的GMEM数据。

这个过程可以描述为:每个CTA负责同一个GMEM tile的一部分,然后通过multicast,把自己负责的这部分广播到cluster内所有CTA的SMEM,最终每个CTA的SMEM都拼出完整的tile。

总结

在这篇博客中,我们通过几个简化示例,介绍了如何使用 CUTLASS 库提供的方法,在 CUDA kernel 中通过 TMA load、TMA store、TMA store reduce 和 TMA load multicast,完成 GMEM 和 SMEM 之间的内存拷贝。

我们首先对 TMA 做了整体概述,并说明了用户如何在 GPU kernel 中调用这些操作。随后,我们进一步深入到底层 PTX 指令,以帮助读者更深入地理解 TMA。

希望这篇博客能够帮助那些想要理解 TMA 的读者,也能够帮助已经了解 TMA、想要复习相关知识的读者,或者正在调试已有 TMA 项目的开发者。

【声明】内容源于网络
0
0
AI Infra之道
AI infra时代已经到来,大家一起打造中国的AI知识库
内容 16
粉丝 0
AI Infra之道 AI infra时代已经到来,大家一起打造中国的AI知识库
总阅读10
粉丝0
内容16