大数跨境
0
0

蚂蚁集团inclusionAI 贡献SGLang-Jax方案,打造业界领先的开源推理引擎

蚂蚁集团inclusionAI 贡献SGLang-Jax方案,打造业界领先的开源推理引擎 蚂蚁技术AntTech
2025-11-04
2



10月19日,SGLang团队发布文章,宣布推出全新的开源推理引擎SGLang-Jax,这是一个完全基于Jax 和 XLA 构建的最先进的开源推理引擎,实现了高效、原生的TPU推理能力。蚂蚁集团inclusionAI团队一起参与该项目的建设,同时SgLang-Jax项目首发支持了蚂蚁自研的蚂蚁百灵非思考模型Ling和思考模型Ring,未来蚂蚁InclusionAI团队会和社区一起,持续投入SGLang-Jax项目,推动项目持续发展和演进。



小编对SGLang-Jax技术博客进行了翻译,希望能为模型研究、推理研究从业者带来一定的启发。原文翻译如下:






SGLang-Jax:一个用于原生 TPU 推理的开源解决方案

作者:SGLang-Jax 团队,2025年10月29日


我们很高兴推出 SGLang-Jax,这是一个完全基于 Jax 和 XLA 构建的最先进的开源推理引擎。它利用了 SGLang 的高性能服务器架构,并使用 Jax 来编译模型的前向传播。通过结合 SGLang 和 Jax,该项目提供了快速、原生的 TPU 推理能力,同时保持了对高级功能的支持,如连续批处理、前缀缓存、张量和专家并行、推测解码、内核融合以及高度优化的 TPU 内核。


基准测试表明,SGLang-Jax 的性能与其他 TPU 推理解决方案相当或更优。源代码可在(https://github.com/sgl-project/sglang-jax) 获取。



为何选择 Jax 后端?


虽然 SGLang 最初基于 PyTorch 构建,但社区一直渴望获得 Jax 支持。


我们构建 Jax 后端主要基于以下几个关键原因:


Jax 是从零开始为 TPU 设计的。为了实现无妥协的最大性能,Jax 是明确的选择。随着谷歌扩大对 TPU 的公共访问,我们预计 Jax + TPU 将获得显著关注,并实现成本高效的推理。


领先的 AI 实验室——包括 Google DeepMind、xAI、Anthropic 和 Apple——已经依赖 Jax。在训练和推理中使用相同的框架可以减少维护开销,并消除两个阶段之间的差异。


Jax + XLA 是一个经过验证的、编译驱动的技术栈,在 TPU 上表现出色,并在各种定制的类 TPU AI 芯片上表现良好。



架构


下图展示了 SGLang-Jax 的架构。整个技术栈是纯 Jax,代码简洁,依赖极少。


在输入侧,它通过 OpenAI 兼容的 API 接受请求,并利用 SGLang 高效的 RadixCache 进行前缀缓存,以及其重叠调度器进行低开销批处理。调度器为不同的批处理大小预编译 Jax 计算图。在模型侧,我们使用 Flax 实现模型,并使用shard_map实现各种并行策略。两个核心算子——注意力机制和 MoE——是作为自定义的 Pallas 内核实现的。


(SGLang-Jax 架构图)



关键优化


集成 Ragged Paged Attention v3

我们集成了 Ragged Paged Attention V3(RPA v3)(https://github.com/vllm-project/tpu-inference/tree/main/tpu_inference/kernels/ragged_paged_attention/v3)并对其进行了扩展以支持 SGLang 功能:

-我们根据不同场景调整了内核网格块配置,以实现更好的性能。

-我们使其与 RadixCache 兼容。

-为了支持 EAGLE 推测解码,我们在 RPA v3 中添加了自定义掩码,用于验证阶段。


减少调度开销

前向传播期间在 CPU 和 TPU 上的顺序操作可能会损害性能。然而,不同设备上的操作可以解耦——例如,在 TPU 上启动计算的同时立即准备下一个要运行的批次。为了提高性能,我们的调度器将 CPU 处理与 TPU 计算重叠进行。


在重叠事件循环中,调度器使用结果队列和线程事件来流水线化 CPU 和 TPU 工作。当 TPU 处理批次 N 时,CPU 准备批次 N+1。为了最大化 CPU 和 TPU 之间的重叠,SGLang-jax 根据性能分析结果仔细安排了操作顺序。对于 Qwen/Qwen3-32B 模型,我们将预填充和解码之间的时间间隔从大约 12 毫秒减少到 38 微秒,从大约 7 毫秒减少到 24 微秒。更多细节可以在我们之前的博客(https://lmsys.org/blog/2024-12-04-sglang-v0-4/)中找到。


(使用重叠调度器的性能分析图,批次间的间隙极小)


(未使用重叠调度器的性能分析图,注意批次间存在较大间隙(CPU 开销)



MoE 内核优化


MoE 层目前支持两种实现策略:EPMoE 和 FusedMoE。在 EPMoE 中,我们集成了 Megablox GMM 算子,取代了之前基于 jax ragged_dot 的实现。Megablox GMM 专为 MoE 工作负载设计,能高效处理由 group_sizes 描述的可变大小专家组,消除了不必要的计算和非连续内存访问。在典型配置下,与 jax 原生的 ragged_dot 实现相比,该算子带来了3-4倍的端到端 ITL 加速。结合高效的令牌置换(permute/unpermute)、通过 ragged_all_to_all 实现的专家并行通信以及自适应分块策略,EPMoE 显著提高了整体吞吐量,并在需要跨设备并行处理大量专家的场景中表现良好。相比之下,FusedMoE 使用密集的 einsum 操作融合所有专家计算,没有设备间通信开销。它更适用于单个专家规模大但专家总数较少的情况(例如,< 64 个专家)。它也可以作为一个轻量级的备选方案,便于调试和正确性验证。



推测解码


SGLang-jax 实现了基于 EAGLE 的推测解码,也称为多令牌预测。这种先进的推测解码技术通过使用一个轻量级的草稿头来预测多个令牌,然后通过一次完整模型的前向传播并行验证这些令牌,从而加速生成。为了实现基于树的 MTP-Verify,SGLang-jax 在 Ragged Paged Attention V3 之上添加了非因果掩码支持,使得在验证阶段能够并行解码基于树的、非因果的草稿令牌。我们目前支持 Eagle2 和 Eagle3,并计划继续优化内核实现,并在不同的 MTP 阶段添加对不同注意力后端的支持。



TPU 性能


经过所有优化后,SGLang-Jax 的性能与其他 TPU 推理解决方案相当或更优。与 GPU 解决方案相比,在 TPU 上运行的 SGLang-Jax 也具备竞争力。


您可以在(https://github.com/sgl-project/sglang-jax/issues/297)找到完整的基准测试结果和说明。



使用方法


安装 SGLang-Jax 并启动服务器

安装:

# with uvuv venv --python 3.12 && source .venv/bin/activateuv pip install sglang-jax# from sourcegit clone https://github.com/sgl-project/sglang-jaxcd sglang-jaxuv venv --python 3.12 && source .venv/bin/activateuv pip install -e python/


启动服务器:

MODEL_NAME="Qwen/Qwen3-8B"  # or "Qwen/Qwen3-32B"jax_COMPILATION_CACHE_DIR=/tmp/jit_cache \uv run python -u -m sgl_jax.launch_server \--model-path ${MODEL_NAME} \--trust-remote-code \--tp-size=4 \--device=tpu \--mem-fraction-static=0.8 \--chunked-prefill-size=2048 \--download-dir=/tmp \--dtype=bfloat16 \--max-running-requests 256 \--page-size=128



通过 GCP 控制台使用 TPU


您可以在菜单 → Compute Engine 下找到TPU 选项,并在控制台中点击创建TPU。注意:只有特定区域支持特定的TPU版本。请记得将TPU软件版本设置为 v2-alpha-tpuv6e。在Compute Engine菜单下,转到Settings → Metadata,点击SSH Keys按钮,并添加您的公钥。TPU服务器创建后,您可以使用控制台中显示的外部 IP 和公钥用户名登录。另请参阅:(https://docs.cloud.google.com/tpu/docs/setup-gcp-account) 




通过 SkyPilot 使用 TPU


我们推荐在日常开发中使用SkyPilot(https://github.com/skypilot-org/skypilot)。您可以快速设置SkyPilot,并在 sglang-jax 代码库中找到用于启动开发机和运行测试的脚本。


为 GCP 安装 SkyPilot:

(https://docs.skypilot.co/en/latest/getting-started/installation.html#gcp)


然后启动 sgl-jax.sky.yaml (https://github.com/sgl-project/sglang-jax/blob/main/scripts/tpu_resource.sky.yaml)

sky launch sgl-jax.sky.yaml --cluster=sgl-jax-skypilot-v6e-4 --infra=gcp -i 30 --down -y --use-spot

此命令将在各个区域中寻找成本最低的 TPU 抢占式实例,并在空闲 30 分钟后自动关闭实例。它还会为您安装 sglang-jax 环境。设置完成后,您可以直接使用 ssh <集群名称> 登录,无需跟踪外部 IP 地址。



发展路线图


社区正在与 Google Cloud 团队及多个合作伙伴就以下路线图进行合作:


  • 模型支持与优化

o 优化 Grok2, Ling/Ring, DeepSeek V3, 和 GPT-OSS

o 支持 MiMo-Audio, Wan 2.1, Qwen3 VL


  • TPU 优化内核

o 量化内核

o 通信与计算重叠内核

o MLA 内核


  • RL 集成与 tunixhttps://github.com/google/tunix)

o 权重同步


  • Pathways 和多主机支持


  • 高级服务特性

o 预填充-解码解耦

o 分层 KV 缓存

o 多 LoRA 批处理



致谢


  • SGLang-jax 团队:sii-xinglong, jimoosciuc, Prayer, aolemila, JamesBrianD, zkkython, neo, leos, pathfinder-pf, Jiacheng Yang, Hongzhen Chen, Ying Sheng, Ke Bao, Qinghan Chen

  • Google:Chris Yang, Shun Wang, Michael Zhang, Xiang Li, Xueqi Liu

  • InclusionAI:Junping Zhao, Guowei Wang, Yuhong Guo, Zhenxuan Pan

【声明】内容源于网络
0
0
蚂蚁技术AntTech
科技是蚂蚁创造未来的核心动力
内容 1081
粉丝 0
蚂蚁技术AntTech 科技是蚂蚁创造未来的核心动力
总阅读263
粉丝0
内容1.1k