从这篇文章开始,将带来kittens三部曲,皆出自HazyResearch这个斯坦福的研究实验机构,我发现他们很喜欢给自己的项目加一个kittens的后缀,看来作者非常喜欢🐱:)
引言
ThunderKittens诞生自2024年10月,那个时候,cutlass才3.5版本,还在支持一系列H100 GEMM变种,DeepseekV3还没信儿,Triton是当时python kernel DSL的绝对扛把子,cuteDSL、tilelang、cuTile等一种python kernel DSL还在娘胎,懂王对格陵兰岛还没想法。
当时ThunderKittens(后文简称TK)在业界掀起来了一小股风暴,因为AI workload与 GPU 硬件的映射存在瓶颈,随着hopper这种DSL化的架构推出,纯手写custom kernel非常难且难以达到峰值性能,现有的框架中,cutlass/cute虽然支持完备,但用起来以及扩展难度都很大,triton虽然用起来容易,但是性能又差点意思。TK正是在这种背景下诞生,想要在高性能和易用性扩展性的tradeoff中找到一个balance,但是从我个人的观点和经验来看,这种很容易成为四不像,最后两头都不占,逐渐被市场抛弃,前车之鉴有大名鼎鼎的TF Keras、MXNet等等。我个人认为,如果你真的下定决心成为一个算子专家,还是下狠心学学cutlass吧,cutlass覆盖Gemm和conv为中心的所有场景,且生态非常活跃,他们团队在github经常回答issue,我提了非常多的issue,基本每个都有回应,而且有同仁说有些GPU编程的细节,比如TMA descriptor、Hopper异步引入的各种barrier代码、thread-level的操作细节,就应该被隐藏,我是不赞同的,这个也碍不了多少事,知道一下最底层代码是怎么写的总是好的,从现在这个时间点回头看,我感觉TK有点跟不上了(接受反驳,仅个人观点),至少从github活跃度来看,TK的更新频率很低,只有paper作者在更新,见下图,一方面确实毕竟是学术机构而非市场团队,但是如果你想要在工作中快速写一个高性能算子又想获得对硬件特性的一定控制权,又熟悉 pytorch(见下图TK的API完全类似pytorch),TK还是值得推荐和porting的,并且如果你想要以学习为目的,又实在觉得cutlass难,那么TK也是一个非常合适的学习对象
核心动机
回到TK本身,在当时那个时间,编写算子存在着下列两个矛盾:
1.高性能与易用性的矛盾:
- CUTLASS/CuTe:高性能但嵌套模板复杂,编译体积大,开发门槛高;
- 基于编译器的DSL Triton:开发算子简单,但优化不足,用户难以控制硬件特性
2.GPU高速迭代和算子开发时间的矛盾:FlashAttention-2 在 H100 上性能下降 47%,FlashAttention-3 耗时两年才适配 H100
所以TK在这俩矛盾的前提下,后续idea全部围绕以下三个问题来展开:
1. 如何用简洁的代码抽象覆盖kernel(以GEMM为代表)多数应用场景,从而降低开发与维护成本
2. 如下图,一个CUDA kernel生命周期在overlap下的breakdown等于compute和memory消耗时间的最大值加启动和kernel内同步的时间
如何协调 GPU 编程三大核心并行抽象(warp/block/grid),从而最大化TensorCore在Critical Path的比例且尽可能让各级memory单元掩盖在TensorCore和CUDACore下?
3. 如何避免现有框架的缺陷,例如cutlass的开发复杂性和triton 的硬件特性控制限制
TK核心idea
既然TK说它更简洁,那唯一的办法就是帮用户做好抽象,TK 的三级抽象(warp/block/grid)直接对应 GPU 的硬件层级,其实cutlass也是一样的,不过就是多了一层thread层级,TK把这个隐藏掉了
1.warp级别抽象
如果屏幕前的你写过tensorcore gemm的话,在warp层面如果shared memory layout没有弄好,可能就导致 bank 冲突;warp tile的尺寸和layout与tensorcore mma的尺寸和layout也不一定完全match,需要仔细的比对
所以TK提出
1) 一个基础的Tile 抽象: 16×16 大小(适配各代tensorcore mma尺寸),同时还在gemm搬运过程中经过的各级memory提供相应tile 抽象:Reg Tile、Smem Tile、Gmem Tile;这其实也免去了我们手动去索引数据,这在2026年的今天,各大DSL也具备这个能力
2) PyTorch-like的编码风格:这个见上图
3)自动布局优化:提供 3 种 shared memory swizzle layout(32/64/128 字节),编译时根据 Tile 大小选择最优布局,最小化 bank 冲突。这其实也是gemm优化的重要步骤
以上都有一个共同特点,那就是这些抽象都是模板化,意味着很多检查都可以编译期完成的,比如type check,layout check,这个非常重要,像如今的python kernel DSL,他们要解决的一个重要问题就是怎么降低这些check的开销,毕竟在运行时去check,那可受不了,当然,有了TVM-FFI,python kernel DSL比如cuteDSL和tilelang这个问题得到了很大的缓解。
2.block级别抽象
block 层面,我们需要协调多 warp 并行,cutlass需手动处理 warp 间异步执行、内存同步与overlap,这对开发者的代码水平要求比较高,且容易出错,TK在此处也用了抽象帮助我们屏蔽掉这些,如下code snippet
我们在用 TK写kernel的时候,只需要重点开发load,compute逻辑,大幅减少了cutlass中哪里都是的同步异步逻辑,只需要我们稍微写一丢丢同步逻辑,其它都被屏蔽了,并且不用我们自己去设定producer用多少寄存器,consumer又用多少寄存器。除此之外,compute完了,还有个store和finish逻辑,他们组成了“Load-Compute-Store-Finish” 四步模板。
这里存在一些可调的参数,TK也通过模板参数开放给你供你tune,比如用于produce和consumer的warpgroups
3.grid级别抽象
grid 级负责调度多个 block 执行,对应于cutlass里的tile scheduler,如果调度不当,可能造成现有问题
1)L2 cache命中率低(换成threadblock swizzle会不会熟悉一些?)
2)wave quantization(streamK主要解决的就这玩意)
对于通用Gemm来讲,这俩的优化手段也比较常见,解决问题一用threadblock swizzle,解决问题二用persistent grid+streamK
用TK写算子
我们看一下用TK写一个完整attention
using namespace kittens;using namespace kittens::prototype;using namespace kittens::prototype::lcf;template<int D, int NUM_WORKERS> struct attn_fwd_layout {using qo_tile = st_bf<64, D>;using kv_tile = st_bf<D==64?192:128, D>;using qo_global = kittens::gl<bf16, -1, -1, -1, D, qo_tile>;using kv_global = kittens::gl<bf16, -1, -1, -1, D, kv_tile>;struct globals { qo_global O, Q; kv_global K, V; };struct input_block { kv_tile k, v; };struct scratch_block { qo_tile q[NUM_WORKERS]; };struct common_state { int batch, head, seq; };struct consumer_state {rt_fl<16, qo_tile::cols> o_reg;col_vec<rt_fl<16, kv_tile::rows>> max_vec, norm_vec;col_vec<rt_fl<16, kv_tile::rows>> max_vec_last_scaled, max_vec_scaled;rt_fl<16, kv_tile::rows> att_block;rt_bf<16, kv_tile::rows> att_block_mma;};};template<int D> struct attn_fwd_template {static constexpr int NUM_CONSUMER_WARPS = 12, NUM_WORKERS = NUM_CONSUMER_WARPS/4, INPUT_PIPE_STAGES = 2;using layout = attn_fwd_layout<D, NUM_WORKERS>;__device__ static inline void common_setup(common_setup_args<layout> args) {args.common.batch = blockIdx.z; args.common.head = blockIdx.y; args.common.seq = blockIdx.x;args.num_iters = args.task_iter == 0 ? args.globals.K.rows/layout::kv_tile::rows : -1;}struct producer {__device__ static inline void setup(producer_setup_args<layout> args) {warpgroup::producer_registers();}__device__ static inline void load(producer_load_args<layout> args) {if(warpgroup::warpid() == 0) {tma::expect(args.inputs_arrived, args.input);tma::load_async(args.input.k, args.globals.K, {args.common.batch, args.common.head, args.iter, 0}, args.inputs_arrived);tma::load_async(args.input.v, args.globals.V, {args.common.batch, args.common.head, args.iter, 0}, args.inputs_arrived);}else if(laneid() == 0) arrive(args.inputs_arrived);}};struct consumer {__device__ static inline void setup(consumer_setup_args<layout> args) {warpgroup::consumer_registers<NUM_WORKERS>();if((args.common.seq*NUM_WORKERS + warpgroup::groupid())*layout::qo_tile::rows < args.globals.Q.rows) // out of bounds?warpgroup::load(args.scratch.q[warpgroup::groupid()], args.globals.Q,{args.common.batch, args.common.head, args.common.seq*NUM_WORKERS+warpgroup::groupid(), 0});zero(args.state.o_reg);zero(args.state.norm_vec);neg_infty(args.state.max_vec);warpgroup::sync(warpgroup::groupid());}__device__ static inline void compute(consumer_compute_args<layout> args) {constexpr float TEMPERATURE_SCALE = (D == 128) ? 0.08838834764f*1.44269504089f : 0.125f*1.44269504089f;// A = Q @ K.Twarpgroup::mm_ABt(args.state.att_block, args.scratch.q[warpgroup::groupid()], args.input.k);mul(args.state.max_vec_last_scaled, args.state.max_vec, TEMPERATURE_SCALE);warpgroup::mma_async_wait();// softmaxrow_max(args.state.max_vec, args.state.att_block, args.state.max_vec); // accumulate onto the max_vecmul(args.state.max_vec_scaled, args.state.max_vec, TEMPERATURE_SCALE);mul(args.state.att_block, args.state.att_block, TEMPERATURE_SCALE);sub_row(args.state.att_block, args.state.att_block, args.state.max_vec_scaled);exp2(args.state.att_block, args.state.att_block);sub(args.state.max_vec_last_scaled, args.state.max_vec_last_scaled, args.state.max_vec_scaled);exp2(args.state.max_vec_last_scaled, args.state.max_vec_last_scaled);mul(args.state.norm_vec, args.state.norm_vec, args.state.max_vec_last_scaled);row_sum(args.state.norm_vec, args.state.att_block, args.state.norm_vec); // accumulate onto the norm_vecmul_row(args.state.o_reg, args.state.o_reg, args.state.max_vec_last_scaled); // normalize o_reg before mmacopy(args.state.att_block_mma, args.state.att_block); // convert to bf16 for mma// O += A @ Vwarpgroup::mma_AB(args.state.o_reg, args.state.att_block_mma, args.input.v);warpgroup::mma_async_wait();if(laneid() == 0) arrive(args.inputs_finished); // done!}__device__ static inline void finish(consumer_finish_args<layout> args) {if((args.common.seq*NUM_WORKERS+warpgroup::groupid())*64 >= args.globals.Q.rows) return; // out of bounds?div_row(args.state.o_reg, args.state.o_reg, args.state.norm_vec);auto &o_smem = reinterpret_cast<typename layout::qo_tile&>(args.scratch.q[warpgroup::groupid()]);warpgroup::store(o_smem, args.state.o_reg);warpgroup::sync(warpgroup::groupid());if(warpgroup::warpid() == 0)tma::store_async(args.globals.O, o_smem, {args.common.batch, args.common.head, args.common.seq*NUM_WORKERS+warpgroup::groupid(), 0});}};};
我的直观感受有如下几点
1.相比TriDao的FA1/FA2/FA3,这毫无疑问是精简了非常多,load那就真的是纯load,就几个load API放那,compute那就真的是纯compute,不用你手写异步同步编排流水线细节什么的,epilogue(FA2提出了把除softmax的分母放在epilogue)和store就放在finish
2.但是也需要自己写一些带arrive和sync的简单primitive,这点对于CUDA水平适中的人能够接收
3.整体来看,确实简洁不少,但是这是2026年了,大家可以去看看tilelang/cuTile的FA实现,你会发现,那和pytorch FA实现真没啥区别。。。。我这里就不贴了,大家自己去看吧,技术确实是在发展并且很快。
最后
我聊聊TK相比2025/2026年这些python kernel DSL的优劣势吧
- TK 是 C++ 模板框架,无需 Python到C++到CUDA 的转换链路,避免了bridge带来的性能损耗和兼容性问题;
- TK 的抽象(Tile)直接对应 GPU 硬件层级,开发者能通过 NCU精准定位性能瓶颈且修改对应级别代码,而 Python DSL 自带的编译栈会增加调试难度;
- 兼容 C++ 生态:在一些只能用C++的项目里面,TK具有绝对优势,并且现在其实并不是所有人都喜欢用python,对于使用习惯了C++的人来说,讲真python会觉得有点反人类
至于在定制化算子以及运行时开销上面,老实说,虽然很多人觉得TK是C++模板,是封装更好的cutlass,定制化和运行时开销上面绝对打爆python kernel DSL,但其实,你去看下tilelang,就会发现,在tvm ffi的支持下,运行时开销已经很小了,定制化上面,这些python kernel DSL也支持越来越多的primitive,不输这些C++模板库了。
最后,话说回来,TK作为一个学习对象,绝对是OK的

