大数跨境
0
0

简单了解下CUDA Green Context

简单了解下CUDA Green Context 极市平台
2025-07-17
0
↑ 点击蓝字 关注极市平台
作者丨GiantPandaLLM
来源丨GiantPandaLLM
编辑丨极市平台

极市导读

 

本文介绍了 CUDA Green Context 的实现和使用方法,文章基于 Flashinfer 的实现,详细讲解了如何创建和使用 Green Context,并探讨了其在实际应用中的性能表现。 >>加入极市CV技术交流群,走在计算机视觉的最前沿

0x0. 前言

在DeepSeek V3的blog(https://zhuanlan.zhihu.com/p/27181462601) 中提到的TBO作用在Prefill阶段时,我们可以从它的调度图上看到对于计算的Stream使用了108个SM,而通信的Stream则使用了剩下的24个SM。之前一直比较好奇这个SM划分是怎么做到的,最近关注到Flashinfer引入了CUDA Green Context可以比较方便的来实现这个功能(要求CUDA 12.0+),所以这里就基于Flashinfer相关的实现来简单了解一下CUDA Green Context的实现。从NV论坛和CCCL的支持来看这个feature似乎也是处于实验阶段, 在CUDA-Samples里面也找不到例子,所以我这里的介绍只是起一个科普作用,可以关注后续的演进。

相关的PR为:https://github.com/flashinfer-ai/flashinfer/pull/1163

0x1. CUDA Green Context和普通Context的区别

根据CUDA Green Context的文档:https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html#group__CUDA__GREEN__CONTEXTS_1g6115d21604653f4eafb257f725538ab6在CUDA 12.0+中引入了CUDA Green Context,它和普通Context的区别在于:

  • CUDA Green Contexts 提供了资源隔离功能,可以让每个上下文在执行时不干扰其他上下文,这对于需要高并发的任务尤为重要。
  • 对于多线程的应用程序,CUDA Green Contexts 可以有效降低因上下文切换导致的性能损失,使得多线程的 CUDA 应用能更顺畅地运行。
  • 通过在多个上下文之间进行并行处理,能够提高 GPU 的使用率,从而提升整体计算吞吐量。适合需要执行多个独立运算的场景。

普通的CUDA Context Stream无法实现资源隔离或要实现资源隔离需要做一些很tricky的魔法,并且普通的多Stream并行执行kernel的时候也容易因为某个kernel用满了SM导致无法overlap。可以参考https://mp.weixin.qq.com/s/Y6r-rjBEEN5akPHmx6jS3w 这里的cuda kernel执行和nsys图。而CUDA Green Context Stream可以通过划分SM实现资源隔离让overlap更容易做到,这是我的理解。

0x2. CUDA Green Context怎么用

FlashInfer中引入CUDA Green Context的代码就对应下面这个代码片段:

from typing import List, Tuple  

import cuda.bindings.driver as driver  
import cuda.bindings.runtime as runtime  
import cuda.cudart as cudart  
import cuda.nvrtc as nvrtc  
import torch  
from cuda.bindings.driver import CUdevice, CUdevResource  

def_cudaGetErrorEnum(error):  
"""获取CUDA错误枚举的名称字符串  

    Args:  
        error: CUDA错误对象,可能是driver、runtime或nvrtc的错误类型  

    Returns:  
        错误名称的字符串表示  
    "
""  
if isinstance(error, driver.CUresult):  
# 处理CUDA Driver API错误  
        err, name = driver.cuGetErrorName(error)  
return name if err == driver.CUresult.CUDA_SUCCESS else"<unknown>"  
elif isinstance(error, runtime.cudaError_t):  
# 处理CUDA Runtime API错误  
return cudart.cudaGetErrorName(error)[1]  
elif isinstance(error, nvrtc.nvrtcResult):  
# 处理NVRTC编译错误  
return nvrtc.nvrtcGetErrorString(error)[1]  
else:  
raise RuntimeError(f"Unknown error type: {error}")  

defcheckCudaErrors(result):  
"""检查CUDA API调用的返回结果,如果有错误则抛出异常  

    Args:  
        result: CUDA API调用的返回结果,通常是一个元组  

    Returns:  
        如果没有错误,返回结果数据部分(去除错误码)  

    Raises:  
        RuntimeError: 如果CUDA调用出现错误  
    "
""  
if result[0].value:  
# 如果错误码非零,说明有错误发生  
raise RuntimeError(  
f"CUDA error code={result[0].value}({_cudaGetErrorEnum(result[0])})"  
        )  
# 根据返回结果的长度来决定返回什么  
if len(result) == 1:  
returnNone# 只有错误码,没有数据  
elif len(result) == 2:  
return result[1]  # 返回数据部分  
else:  
return result[1:]  # 返回多个数据项  

defget_cudevice(dev: torch.device) -> CUdevice:  
"""获取指定PyTorch设备对应的CUDA设备句柄  

    Args:  
        dev: PyTorch设备对象  

    Returns:  
        CUDA设备句柄  
    "
""  
try:  
# 尝试直接获取CUDA设备  
        cu_dev = checkCudaErrors(driver.cuDeviceGet(dev.index))  
except RuntimeError as e:  
# 如果失败,先初始化设备再获取  
        runtime.cudaInitDevice(dev.index, 0, 0)  
        cu_dev = checkCudaErrors(driver.cuDeviceGet(dev.index))  
return cu_dev  

defget_device_resource(cu_dev: CUdevice) -> CUdevResource:  
"""获取指定CUDA设备的SM(流处理器)资源  

    Args:  
        cu_dev: CUDA设备句柄  

    Returns:  
        设备的SM资源对象  
    "
""  
return checkCudaErrors(  
        driver.cuDeviceGetDevResource(  
            cu_dev, driver.CUdevResourceType.CU_DEV_RESOURCE_TYPE_SM  
        )  
    )  

defsplit_resource(  
    resource: CUdevResource,  
    num_groups: int,  
    min_count: int,  
) -> Tuple[CUdevResource, CUdevResource]:  
"""将SM资源按指定数量分割成多个组  

    Args:  
        resource: 要分割的SM资源  
        num_groups: 要分割成的组数  
        min_count: 每组最少的SM数量  

    Returns:  
        分割后的资源组列表和剩余的资源  
    "
""  
    results, _, remaining = checkCudaErrors(  
        driver.cuDevSmResourceSplitByCount(  
            num_groups,      # 分组数量  
            resource,        # 原始资源  
0,              # useFlags - 使用标志,0表示默认  
            min_count,      # 每组最小SM数量  
        )  
    )  
return results, remaining  

defcreate_green_ctx_streams(  
    cu_dev: CUdevice, resources: List[CUdevResource]  
) -> List[torch.Stream]:  
"""为每个SM资源组创建对应的Green Context和Stream  

    Args:  
        cu_dev: CUDA设备句柄  
        resources: SM资源组列表  

    Returns:  
        对应每个资源组的PyTorch Stream列表  
    "
""  
    streams = []  
for split in resources:  
# 为每个分割的资源创建描述符  
        desc = checkCudaErrors(driver.cuDevResourceGenerateDesc([split], 1))  

# 创建Green Context,这是CUDA 12.0+的新特性  
# Green Context允许在不同的SM分区上并发执行多个kernel  
        green_ctx = checkCudaErrors(  
            driver.cuGreenCtxCreate(  
                desc,    # 资源描述符  
                cu_dev,  # 设备句柄  
                driver.CUgreenCtxCreate_flags.CU_GREEN_CTX_DEFAULT_STREAM  # 创建标志  
            )  
        )  

# 在Green Context中创建Stream  
        stream = checkCudaErrors(  
            driver.cuGreenCtxStreamCreate(  
                green_ctx,  # Green Context  
                driver.CUstream_flags.CU_STREAM_NON_BLOCKING,  # 非阻塞Stream  
0,          # priority - 优先级,0表示默认  
            )  
        )  

# 将CUDA Driver API的Stream转换为PyTorch的Stream  
        streams.append(torch.cuda.get_stream_from_external(stream))  

return streams  

defsplit_device_green_ctx(  
    dev: torch.device, num_groups: int, min_count: int  
) -> Tuple[List[torch.Stream], List[CUdevResource]]:  
r"""  
    将设备分割成多个Green Context,为每个组和剩余的SM返回对应的Stream和资源。  
    Green Context允许在不同的SM分区上并发执行多个kernel。  

    Args:  
        dev: 要分割的设备  
        num_groups: 要分割成的组数  
        min_count: 每组所需的最少SM数量,会根据对齐和粒度要求进行调整  

    Returns:  
        streams: 对应于Green Context的torch.Stream对象列表  
        resources: 对应于Green Context的CUdevResource对象列表  

    Example:  
        >>> from flashinfer.green_ctx import split_device_green_ctx  
        >>> import torch  
        >>> dev = torch.device("
cuda:0")  
        >>> streams, resources = split_device_green_ctx(dev, 2, 16)  
        >>> print([r.sm.smCount for r in resources])  
        [16, 16, 100]  
        >>> with torch.cuda.stream(streams[0]):  
        ...     x = torch.randn(8192, 8192, device=dev, dtype=torch.bfloat16)  
        ...     y = torch.randn(8192, 8192, device=dev, dtype=torch.bfloat16)  
        ...     z = x @ y  
        ...     print(z.shape)  
        ...  
        torch.Size([8192, 8192])  

    Note:  
        返回的streams和resources的长度为 ``num_groups + 1``,  
        其中最后一个是剩余的SM。  

    Raises:  
        RuntimeError: 当请求的SM分配超过设备容量时:  
        ``num_groups * round_up(min_count, 8) > num_sm``  
    "
""  
# 1. 获取CUDA设备句柄  
    cu_dev = get_cudevice(dev)  

# 2. 获取设备的SM资源  
    resource = get_device_resource(cu_dev)  

# 3. 将SM资源分割成指定数量的组  
    results, remaining = split_resource(resource, num_groups, min_count)  

# 4. 将分割的结果和剩余资源合并成一个列表  
    resources = results + [remaining]  

# 5. 为每个资源组创建对应的Green Context和Stream  
    streams = create_green_ctx_streams(cu_dev, resources)  

return streams, resources  

可以看到这个CUDA Green Context的使用相对还是比较简单的,主要就是通过cuDevResourceGenerateDesc来生成资源描述符,然后通过cuGreenCtxCreate来创建Green Context,最后通过cuGreenCtxStreamCreate来创建CUDA Green Context的Stream。

用法也比较简单,可以参考这里的单测代码:

@pytest.mark.parametrize("device", ["cuda:0"])  
@pytest.mark.parametrize("num_groups", [1, 2, 3])  
@pytest.mark.parametrize("min_count", [16, 32])  
deftest_green_ctx_kernel_execution(  
    device: str,  
    num_groups: int,  
    min_count: int,  
):  
    streams, resources = green_ctx.split_device_green_ctx(  
        torch.device(device), num_groups, min_count  
    )  
    num_partitions = num_groups + 1  
assert len(streams) == num_partitions  
assert len(resources) == num_partitions  

for stream in streams:  
with torch.cuda.stream(stream):  
            x = torch.randn(8192, 8192, device=device, dtype=torch.bfloat16)  
            y = torch.randn(8192, 8192, device=device, dtype=torch.bfloat16)  
            z = x @ y  
            print(z.shape)  

这个用法只是展示了一下怎么利用FlashInfer的CUDA Green Context来实现SM的分割和创建多Streams,并没有看到如何利用CUDA Green Context来实现kernel overlap相关例子。我尝试使用这里提供的api来实现一个没有依赖的M,N,K=8192,8192,8192torch.matmultorch.sigmoidkernel overlap,使用10个SM做torch.sigmoid,剩下的SM做torch.matmul。但是测试之后发现这个性能相比于baseline直接顺序执行的版本反而耗时高了快2倍,不确定是不是FlashInfer CUDA Green Context这里的打开方式不正确,之后有这个feature的发展或者相关应用的话继续关注一下。


公众号后台回复“数据集”获取100+深度学习各方向资源整理

极市干货

技术专栏:多模态大模型超详细解读专栏搞懂Tranformer系列大视觉模型 (LVM) 解读扩散模型系列极市直播
技术综述:小目标检测那点事大模型面试八股含答案万字长文!人体姿态估计(HPE)入门教程

点击阅读原文进入CV社区

收获更多技术干货

【声明】内容源于网络
0
0
极市平台
为计算机视觉开发者提供全流程算法开发训练平台,以及大咖技术分享、社区交流、竞赛实践等丰富的内容与服务。
内容 8155
粉丝 0
极市平台 为计算机视觉开发者提供全流程算法开发训练平台,以及大咖技术分享、社区交流、竞赛实践等丰富的内容与服务。
总阅读3.2k
粉丝0
内容8.2k