简介
TMA用于在GPU的全局内存和线程块的共享内存之间进行异步内存拷贝,支持拷贝一维数组和多维数组,最高支持 5 维。
TMA的优势是:
-
通过异步机制支持 warp-specialized kernel schedule,从而提升 GPU 利用率; -
通过 TMA copy descriptor,以单线程方式处理地址、stride 等辅助拷贝信息的计算,因此更加节省寄存器,并且天然能够处理 predication,例如越界检查。
TMA 典型场景 Prefetching Data
在一种迭代式的copy + compute模式中,这样做可以用当前迭代的计算来隐藏未来迭代的数据传输延迟,从而潜在地增加正在传输中的数据量,也就是 bytes-in-flight。TMA底层是DMA进行数据传输,不需要GPU和CPU参数数据搬运过程,此时可以按batch实现数据预期。
fetch_batch = compute_batch + num_stages
也就是,预取始终领先当前计算 num_stages 个 batch,例如当nums_stahes = 2时:
正在计算 batch 0,同时预取 batch 2正在计算 batch 1,同时预取 batch 3正在计算 batch 2,同时预取 batch 4
TMA Load
使用 TMA load 的 kernel 和使用其他内存拷贝方式的 kernel 有很大不同。因此,我们会先通过一个简单示例任务展示如何编写这样的 kernel,然后再解释其中涉及的概念。
示例任务
为了演示 TMA load 的用法,我们考虑一个简单任务:对一个二维 row-major 矩阵进行 tiling。
我们用numpy描述这种tiling。
A = np.random.uniform(M, N)for i in range(M):for j in range(N):cta_i_j = A.reshape(M // CTA_M, CTA_M,N // CTA_N, N)[i, :, j, :]
Host code
我们要创建三个对象
-
要从中拷贝数据的Gemm tensor; -
每个CTA内部用于接受数据的SMEM tensor layout; -
一个以这两个对象作为参数的tma_load对象;
一旦创建好这些对象,就可以将他们传递给device端的kernel。在kernel内部,会真正调用TMA load操作。
host端完整代码
template <typename T, int CTA_M, int CTA_N>void host_fn(T* data, int M, int N) {using namespace cute;// create the GMEM tensorauto gmem_layout = make_layout(make_shape(M, N), LayoutRight{});auto gmem_tensor = make_tensor(make_gmem_ptr(T), gmem_layout);// create the SMEM layoutauto smem_layout = make_layout(make_shape(CTA_M, CTA_N), LayoutRight{});// create the TMA objectauto tma_load = make_tma_copy(SM90_TMA_LOAD{}, gmem_tensor, smem_layout);// invoke the kerneltma_load_kernel<CTA_M, CTA_N><<<dim3{M / CTA_M, N / CTA_N, 1}, 1>>>(tma_load, gmem_tensor, smem_layout);}
上面代码中tma_load对象是cute::make_tma_copy函数显示默认创建的。建议使用这种默认形式,以避免引入不必要的复杂性。
Kernel code
相关的 kernel 代码片段如下。这些代码行包含了许多重要的 TMA 概念,后面会逐一解释。
template <typename T, int CTA_M, int CTA_N, class TmaLoad, class GmemTensor>void tma_load_kernel(__grid_constant__ const TmaLoad tma_load, GmemTensor gmem_tensor) {using namespace cute;constexpr int tma_transaction_bytes = CTA_M * CTA_N * sizeof(T);__shared__ T smem_data[CTA_M * CTA_N];__shared__ uint64_t tma_load_mbar;// host端创建tensor: engine是一段SMEM,layout是CTA_TILEauto smem_layout = make_layout(make_shape(CTA_M, CTA_N), LayoutRight{});auto smem_tensor = make_tensor(make_smem_ptr(smem_data), smem_layout);if (threadIdx.x == 0) {// 保存Gemm tensor的坐标auto gmem_tensor_coord = tma_load.get_tma_tensor(shape(gmem_tensor));// 保存Gemm cta tensor的坐标auto gmem_tensor_coord_cta = local_tile(gmem_tensor_coord,Tile<Int<CTA_M>, Int<CTA_N>>{},make_coord(blockIdx.x, blockIdx.y));initialize_barrier(tma_load_mbar, /* arrival count */ 1);set_barrier_transaction_bytes(tma_load_mbar, tma_transaction_bytes);auto tma_load_per_cta = tma_load.get_slice(0);copy(tma_load.with(tma_load_mbar),tma_load_per_cta.partition_S(gmem_tensor_coord_cta),tma_load_per_cta.partition_D(smem_tensor));}__syncthreads();wait_barrier(tma_load_mbar, /* phase */ 0);// after this line, the TMA load is finished}
如果我们有两个tensor需要从GEMM拷贝SMEM,那么每个tensor都必须有自己独立的TileCopy实例,并且每个实例都必须是__grid__ constant const,这是将cuTensorMap从host传递到device时的要求。下一个重点是:对于一次TMA copy,只有一个线程负责发起TMA操作。
Memory barrier
tma_load_mbar是一个异步事务barries,用于同步TMA load和kernel中后续消费SMEM数据部分。cute方法的initialize_barrier封装了PTX指令:
mbarrier.init.shared.b64
这个指令还需要一个额外的arrival_count参数。在我们的场景中,由于只有一个线程会发起 TMA load,所以 arrival count 应该设置为 1。此外,mbarrier 的初始 phase 总是被设置为 0 。
set_barrier_transaction_bytes封装了PTX指令:
mbarrier.arrive.expect_tx.shared::cta.b64
wait_barrier执行对mbarrier对象的wait操作。需要注意的是,所有线程都等待这个mbarrier,这和只有thread 0对arrive操作执行对比。
另外wait_barrier前调用__syncthreads()是必须的,因为它用于处理之前的线程分歧。
parity修饰符表示等待时需要提供一个phase bit,线程会睡眠,直到mbarrier的这个phase bit翻转。由于这是mbarrier初始化后第一次用于跟踪完成状态,所以我们传入的phase是0。如果同一个kernel需要执行多次TMA拷贝,并复用这个mbarrier,就需要翻转这个phase。
在 wait_barrier 之后,内存一致性模型会提供保证,TMA load写入SMEM的数据,对于此CTA中所有调用了mbarrier_wait的线程都可见。
使用 TMA 处理 remainder tiles 以及 stride 要求
上面的例子中,我们假设 m % CTA_M == 0; n % CTA_N == 0; 不过对于 TMA load 来说,这个假设其实可以完全去掉。当从GMEM向SMEM加载remainder tile时,我们不需要字节处理越界逻辑。TMA copy单元会自动对内存拷贝进行predication(mask操作),确保不会读取越界地址。
由于普通的cute tensor,更像是普通的指针,按照指针访问这段地址存在越界风险,而TMA通过坐标形式表示tensor,并结合硬件predication,可以避免这类问题。
不过对于TMA来说,需要特别注意GMEM本身的stride的要求,也就是16-byte boundary requirement。正如预期的一样,TMA并不支持拷贝任意stride的GMEM区域。相反,我们需要假设被拷贝的tile满足:
-
存在一个连续方向,也就是stride=1; -
其他方向的stride必须是16 byte对齐;
如果不满足这个条件,需要在调用kernel之前,对tensor做padding。
TMA Store
和 TMA load 类似,实现 TMA store 也是一个两步流程:
在 host 端定义 TMA copy descriptor
在 kernel 内部发起 TMA store 操作
template <typename T, int CTA_M=32, int CTA_N=32>void host_fn(T* data, int M, int N) {using namespace cute;// 创建一个GMEM tensorauto gmem_layout = make_layout(make_shape(M, N), LayoutRight{});auto gmem_tensor = make_tensor(make_gmem_ptr(T), gmem_layout);// 创建一个 SMEM layoutauto smem_layout = make_layout(make_shape(CTA_M, CTA_N), LayoutRight{});// 创建 TMA objectauto tma_store = make_tma_copy(SM90_TMA_STORE{}, gmem_tensor, smem_layout);// 调用tma store kerneltma_store_kernel<CTA_M, CTA_N><<<dim3{M / CTA_M, N / CTA_N, 1}, CTA_M>>>(tma_store, gmem_tensor, smem_layout);}template <typename T, int CTA_M, int CTA_N, class TmaStore, class GmemTensor>void tma_store_kernel(__grid_constant__ const TmaStore tma_store, GmemTensor gmem_tensor) {using namespace cute;__shared__ T smem_data[CTA_M * CTA_N];auto smem_layout = make_layout(make_shape(CTA_M, CTA_N), LayoutRight{});auto smem_tensor = make_tensor(make_smem_ptr(smem_data), smem_layout);// 验证数据:每个线程把自己的threadIdx.x写到smem_data的一整行里。for (int j = 0; j < CTA_N; ++j) {smem_tensor(threadIdx.x, j) = threadIdx.x;}// 确保所有线程写SMEM完成__syncthreads();tma_store_fence();if (threadIdx.x == 0) {auto gmem_tensor_coord = tma_store.get_tma_tensor(shape(gmem_tensor));auto gmem_tensor_coord_cta = local_tile(gmem_tensor_coord,Tile<Int<CTA_M>, Int<CTA_N>>{},make_coord(blockIdx.x, blockIdx.y));auto tma_store_per_cta = tma_store.get_slice(0);copy(tma_store,tma_store_per_cta.partition_S(smem_tensor),tma_store_per_cta.partition_D(gmem_tensor_coord_per_cta));// tma_store_arrive();}// tma_store_wait<0>();}
TMA load和TMA store代码间有个重要的区别是:TMA store中,不再使用mbarrier对象,而是使用另一种机制保证内存一致性:memory fence。
memory fence的作用是在执行该fence的线程前后发起的内存访问之间,建立一个有保证的顺序关系。
TMA Store arrive 和 wait
tma_store_arrive()会提交TMA store操作,更准确的说,它提交的是一个cp.async.bulk-group,而tma_store_wait()会等待,直到已提交但未完成的TMA store操作数最多剩下count个。若希望所有提交的TMA store都完成,就要设置count=0。由于TMA store是异步的,如果kernel内还有其他工作依赖TMA store完成,或者需要复用这段SMEM,这种同步机制很有必要。
深入理解 TMA 操作
|
|
|
|
|---|---|---|
|
|
|
|
|
|
|
|
|
|
|
|
上表中对比了TMA load和store操作,无论哪种操作,我们都需要在host code中通过make_tma_copy方法创建一个类似TiledCopy对象,然后将这个对象传入kernel函数。在kernel中,我们使用它们调用cute::copy,从而真正发起对应的TMA操作。
本节中,我们深入研究,当kernel中调用TiledCopy对象时,底层到底发生了什么。
TMA Load 和 Store 的 PTX 指令
cute会利用这个tensor和layout,决定当kernel调用:
cute::copy(tma_load, ...)
底层的PTX指令是:
asm volatile ("cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes"" [%0], [%1, {%3, %4}], [%2];":: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),"r"(crd0), "r"(crd1): "memory");
所有TMA的操作,本质是对cp.async.bulk的封装。
TMA Store Reduce
若想处理Reduce类操作,我们先用python模拟这一过程:
for cta_idx in range(number_of_ctas):gmem_dst[cta_idx] += smem_src[cta_idx]# 或者这样:gmem_dst[cta_idx] = max(gmem_dst[cta_idx], smem_src[cta_idx])# 或者这样:gmem_dst[cta_idx] = min(gmem_dst[cta_idx], smem_src[cta_idx])
一个普通的reduce实现,是把CTA的SMEM的值,累加到GMEM tensor的某个tile中,这个过程会分为散步:(一)一次GMEM load;(二)一次reduce计算;(三)一次GMEM store;这三步往往很慢。我们修改一下TMA store的TileCopy对象构造方式,就可以把三步过程压缩成一条PTX指令,即:
cp.reduce.async.bulk 代替原来的cp.async.bulk
具体来说,我们只需要在 host code 中做如下的一行修改:
// original: create a TMA store objectauto tma_store = make_tma_copy(SM90_TMA_STORE{}, gmem_tensor, smem_layout);// to create a TMA reduce sum objectauto tma_reduce_sum = make_tma_copy(SM90_TMA_REDUCE_ADD{}, gmem_tensor, smem_layout);
然后使用 tma_reduce_sum 即可,这样底层调用的就是优化后的TMA指令,将上述的三个过程,由硬件集成到一条指令,由硬件实现"读-算-存"过程。
TMA Load Multicast
正如名字所示,Multicast是指这样的情况:
我们希望把GMEM tensor中的同一个tile,拷贝到多个CTA的SMEM中,例如:
-
一个输入矩阵的 column tile,可能会被多个 row tile 使用; -
或者反过来,一个 row tile 可能会被多个 column tile 使用。
此时.multicast operand允许我们进一步保证L2 cache hit,从而更高效地把同一份GMEM数据分发给同一个cluster内的多个CTA。
从普通TMA load扩展到multicast
我们考虑把前面的 TMA load 示例扩展成 multicast 版本。
一组CTA想要共同参与一次TMA load multicast,它们必须属于同一个threadblock cluster。
为了保持例子简单,我们只修改grid维度和cluster维度:
// old grid dimensions and implicit trivial cluster dimensionsdim3 grid_dims = dim3{M / CTA_M, N / CTA_N, 1};dim3 cluster_dums = dim3{1, 1, 1};// new grid dimensions and cluster dimensionsdim3 grid_dims = dim3{M / CTA_M, N / CTA_N, 2};dim3 cluster_dums = dim3{1, 1, 2};
使用cluster时,cluster dimensions必须能够整除grid dimensions,否则kernel无法launch。
在新的kernel中,我们会让同一个GMEM tile被加载到同一个cluster内每个CTA的SMEM中。
Host code 的变化
// original: create a TMA load objectauto tma_load = make_tma_copy(SM90_TMA_LOAD{}, gmem_tensor, smem_layout);// new1: create a TMA load multicast object for the given cluster sizeauto tma_load = make_tma_copy(SM90_TMA_LOAD_MULTICAST{},gmem_tensor,smem_layout,cute::_2{});// new2: create a TMA load multicast object for the predefineusing ClusterShape = Shape<_1, _1, _2>;auto tma_load = make_tma_copy(SM90_TMA_LOAD_MULTICAST{},gmem_tensor,smem_layout,size<2>(ClusterShape{}); // 取ClusterShape第二维度的数据);
Kernel code 的变化
template <typename T, int CTA_M, int CTA_N, class ClusterShape,class TmaLoad, class GmemTensor>void tma_load_kernel(__grid_constant__ const TmaLoad tma_load,GmemTensor gmem_tensor) {using namespace cute;// 跟踪CTA在cluster内部的index,在cute中会读取特殊寄存器 %cluster_ctarank// 这个值下面简称为ctaiduint32_t block_rank_in_cluster = cute::block_rank_in_cluster();constexpr uint32_t cluster_size = size<2>(ClusterShape{}));constexpr uint16_t tma_mcast_mask = (uint16_t(1) << cluster_size) - 1;constexpr int tma_transaction_bytes = CTA_M * CTA_N * sizeof(T);__shared__ T smem_data[CTA_M * CTA_N];__shared__ uint64_t tma_load_mbar;auto smem_layout = make_layout(make_shape(CTA_M, CTA_N), LayoutRight{});auto smem_tensor = make_tensor(make_smem_ptr(T), smem_layout);auto gmem_tensor_coord = tma_load.get_tma_tensor(shape(gmem_tensor));auto gmem_tensor_coord_cta = local_tile(gmem_tensor_coord,Tile<Int<CTA_M>, Int<CTA_N>>{},make_coord(blockIdx.x, blockIdx.y));if (threadIdx.x == 0) {initialize_barrier(tma_load_mbar, /* arrival count */ 1);}__syncthreads();cute::cluster_sync();cutlass::arch::fence_barrier_init();if (threadIdx.x == 0) {set_barrier_transaction_bytes(tma_load_mbar, tma_transaction_bytes);auto tma_load_per_cta = tma_load.get_slice(block_rank_in_cluster);copy(tma_load.with(tma_load_mbar, tma_mcast_mask),tma_load_per_cta.partition_S(gmem_tensor_coord_per_cta),tma_load_per_cta.partition_D(smem_tensor));}__syncthreads();wait_barrier(tma_load_mbar, /* phase */ 0);// after this line, the TMA load is finishedcute::cluster_sync();}
在copy操作中,我们会传入一个uint16 bitmask,用来指定哪些CTA参与TMA multicast load。mask中bit为1的位置表示对应的CTA是active的。最多支持16个CTA,也就是maxumum nonportable size。
block_rank_in_cluster即ctaid会用来指定从当前CTA发起TMA multicast load时,切入GMEM的offset。即:
block_rank_in_clusterauto tma_load_per_cta = tma_load.get_slice(block_rank_in_cluster);
不同CTA使用不同的slice,即:
-
ctaid = 0 的 CTA 使用 get_slice(0); -
ctaid = 1 的 CTA 使用 get_slice(1)。
这样两个 CTA 就会协作完成同一个 tile 的 multicast load。这样多个CTA的SMEM中都可正确得到完整的GMEM数据。
这个过程可以描述为:每个CTA负责同一个GMEM tile的一部分,然后通过multicast,把自己负责的这部分广播到cluster内所有CTA的SMEM,最终每个CTA的SMEM都拼出完整的tile。
总结
在这篇博客中,我们通过几个简化示例,介绍了如何使用 CUTLASS 库提供的方法,在 CUDA kernel 中通过 TMA load、TMA store、TMA store reduce 和 TMA load multicast,完成 GMEM 和 SMEM 之间的内存拷贝。
我们首先对 TMA 做了整体概述,并说明了用户如何在 GPU kernel 中调用这些操作。随后,我们进一步深入到底层 PTX 指令,以帮助读者更深入地理解 TMA。
希望这篇博客能够帮助那些想要理解 TMA 的读者,也能够帮助已经了解 TMA、想要复习相关知识的读者,或者正在调试已有 TMA 项目的开发者。

