小编对SGLang-Jax技术博客进行了翻译,希望能为模型研究、推理研究从业者带来一定的启发。原文翻译如下:
作者:SGLang-Jax 团队,2025年10月29日
我们很高兴推出 SGLang-Jax,这是一个完全基于 Jax 和 XLA 构建的最先进的开源推理引擎。它利用了 SGLang 的高性能服务器架构,并使用 Jax 来编译模型的前向传播。通过结合 SGLang 和 Jax,该项目提供了快速、原生的 TPU 推理能力,同时保持了对高级功能的支持,如连续批处理、前缀缓存、张量和专家并行、推测解码、内核融合以及高度优化的 TPU 内核。
基准测试表明,SGLang-Jax 的性能与其他 TPU 推理解决方案相当或更优。源代码可在(https://github.com/sgl-project/sglang-jax) 获取。
为何选择 Jax 后端?
虽然 SGLang 最初基于 PyTorch 构建,但社区一直渴望获得 Jax 支持。
我们构建 Jax 后端主要基于以下几个关键原因:
Jax 是从零开始为 TPU 设计的。为了实现无妥协的最大性能,Jax 是明确的选择。随着谷歌扩大对 TPU 的公共访问,我们预计 Jax + TPU 将获得显著关注,并实现成本高效的推理。
领先的 AI 实验室——包括 Google DeepMind、xAI、Anthropic 和 Apple——已经依赖 Jax。在训练和推理中使用相同的框架可以减少维护开销,并消除两个阶段之间的差异。
Jax + XLA 是一个经过验证的、编译驱动的技术栈,在 TPU 上表现出色,并在各种定制的类 TPU AI 芯片上表现良好。
架构
下图展示了 SGLang-Jax 的架构。整个技术栈是纯 Jax,代码简洁,依赖极少。
在输入侧,它通过 OpenAI 兼容的 API 接受请求,并利用 SGLang 高效的 RadixCache 进行前缀缓存,以及其重叠调度器进行低开销批处理。调度器为不同的批处理大小预编译 Jax 计算图。在模型侧,我们使用 Flax 实现模型,并使用shard_map实现各种并行策略。两个核心算子——注意力机制和 MoE——是作为自定义的 Pallas 内核实现的。
(SGLang-Jax 架构图)
关键优化
MoE 内核优化
MoE 层目前支持两种实现策略:EPMoE 和 FusedMoE。在 EPMoE 中,我们集成了 Megablox GMM 算子,取代了之前基于 jax ragged_dot 的实现。Megablox GMM 专为 MoE 工作负载设计,能高效处理由 group_sizes 描述的可变大小专家组,消除了不必要的计算和非连续内存访问。在典型配置下,与 jax 原生的 ragged_dot 实现相比,该算子带来了3-4倍的端到端 ITL 加速。结合高效的令牌置换(permute/unpermute)、通过 ragged_all_to_all 实现的专家并行通信以及自适应分块策略,EPMoE 显著提高了整体吞吐量,并在需要跨设备并行处理大量专家的场景中表现良好。相比之下,FusedMoE 使用密集的 einsum 操作融合所有专家计算,没有设备间通信开销。它更适用于单个专家规模大但专家总数较少的情况(例如,< 64 个专家)。它也可以作为一个轻量级的备选方案,便于调试和正确性验证。
推测解码
SGLang-jax 实现了基于 EAGLE 的推测解码,也称为多令牌预测。这种先进的推测解码技术通过使用一个轻量级的草稿头来预测多个令牌,然后通过一次完整模型的前向传播并行验证这些令牌,从而加速生成。为了实现基于树的 MTP-Verify,SGLang-jax 在 Ragged Paged Attention V3 之上添加了非因果掩码支持,使得在验证阶段能够并行解码基于树的、非因果的草稿令牌。我们目前支持 Eagle2 和 Eagle3,并计划继续优化内核实现,并在不同的 MTP 阶段添加对不同注意力后端的支持。
TPU 性能
经过所有优化后,SGLang-Jax 的性能与其他 TPU 推理解决方案相当或更优。与 GPU 解决方案相比,在 TPU 上运行的 SGLang-Jax 也具备竞争力。
您可以在(https://github.com/sgl-project/sglang-jax/issues/297)找到完整的基准测试结果和说明。
使用方法
通过 GCP 控制台使用 TPU
您可以在菜单 → Compute Engine 下找到TPU 选项,并在控制台中点击创建TPU。注意:只有特定区域支持特定的TPU版本。请记得将TPU软件版本设置为 v2-alpha-tpuv6e。在Compute Engine菜单下,转到Settings → Metadata,点击SSH Keys按钮,并添加您的公钥。TPU服务器创建后,您可以使用控制台中显示的外部 IP 和公钥用户名登录。另请参阅:(https://docs.cloud.google.com/tpu/docs/setup-gcp-account)
通过 SkyPilot 使用 TPU
我们推荐在日常开发中使用SkyPilot(https://github.com/skypilot-org/skypilot)。您可以快速设置SkyPilot,并在 sglang-jax 代码库中找到用于启动开发机和运行测试的脚本。
为 GCP 安装 SkyPilot:
(https://docs.skypilot.co/en/latest/getting-started/installation.html#gcp)
然后启动 sgl-jax.sky.yaml (https://github.com/sgl-project/sglang-jax/blob/main/scripts/tpu_resource.sky.yaml):
sky launch sgl-jax.sky.yaml --cluster=sgl-jax-skypilot-v6e-4 --infra=gcp -i 30 --down -y --use-spot
此命令将在各个区域中寻找成本最低的 TPU 抢占式实例,并在空闲 30 分钟后自动关闭实例。它还会为您安装 sglang-jax 环境。设置完成后,您可以直接使用 ssh <集群名称> 登录,无需跟踪外部 IP 地址。
发展路线图
社区正在与 Google Cloud 团队及多个合作伙伴就以下路线图进行合作:
致谢
SGLang-jax 团队:sii-xinglong, jimoosciuc, Prayer, aolemila, JamesBrianD, zkkython, neo, leos, pathfinder-pf, Jiacheng Yang, Hongzhen Chen, Ying Sheng, Ke Bao, Qinghan Chen
Google:Chris Yang, Shun Wang, Michael Zhang, Xiang Li, Xueqi Liu
InclusionAI:Junping Zhao, Guowei Wang, Yuhong Guo, Zhenxuan Pan

