大数跨境

全面解析MoE专家并行EP all-to-all算子(小白也能看懂)

全面解析MoE专家并行EP all-to-all算子(小白也能看懂) AI不止算法
2025-10-30
0
最近,又面了些候选人,主要是3-6年经验段的,给我的感觉就是水平一般,听说过很多东西,但是一深问就歇菜,我们面试的标准倒不是非要问倒你,只是想要摸一下底,毕竟100%回答上问题这是不现实的。在这么多次面试中,我看到很多人简历上写了TP PP等并行策略等等,也对dense模型或者moe模型做了很多性能分析,于是我常规下问了TP PP等基础知识后,又转向EP专家并行,问了下EP中非常重要的通信算子All-to-All,很遗憾,很多人只是知道EP里面有all-to-all,但是并不知道其实现过程,于是乎本文旨在帮助不了解EP all-to-all的朋友了解以下all-to-all的实现过程,仅关乎讲清实现原理,不谈其性能优化

背景:MoE 与EP专家并行的诞生


MoE 的核心思想源于 "分而治之":将模型的FFN/MLP拆分为多个独立的 "专家"(Expert),每个输入 token 仅由部分专家(TOPK个)处理,再通过门控网络聚合结果。这种设计使模型参数量随专家数量线性增长,却无需按比例增加计算量(因每个 token 仅激活少数专家)。

然而,当专家数量规模达到数千甚至上万时,单设备已无法容纳所有专家 —— 这就需要专家并行(Expert Parallelism):将专家分布在不同GPU上,每个GPU仅保存部分专家,这与TP的思路其实差不多,都是把参数分布到不同GPU,但是粒度和效果与TP截然不同,之前写了篇文章分析过EP与TP的对比。此时,把专家按 GPU 切片后,每张卡只存 Expert数量/world_size 个专家。前向传播时,token 可能命中任意专家 ,我们必须把 token 发到目标GPU,算完再拿回来这个“跨卡搬运 + 还原到原始GPU”就是 EP all-to-all,也是专家并行的灵魂:

  • 如何将输入 token 高效分发到其选中的专家所在设备(all-to-all Dispatch 阶段)?
  • 如何将专家的输出结果回传到原始设备并聚合结果(all-to-all Combine 阶段)?

All-to-All 通信正是为解决这两个问题而生的分布式通信范式。

动机:MoE场景下EP all-to-all无法被TP DP取代

除了之前文章分析过EP与TP的对比之外,这里再补充几点:


在传统分布式训练 / 推理中,数据并行(Data Parallelism)和模型并行(Model Parallelism)已成为主流,但它们在 MoE 场景下存在显著局限:

  • 数据并行DP:每个设备复制完整模型(包括所有专家),导致内存占用随专家数量线性增长;

  • 模型并行TP:仅能拆分单个专家的参数,无法处理 MoE 中多个独立专家的分布式部署;

  • 点对点通信:若直接用点对点(Point-to-Point)通信分发 token,会导致通信次数随设备数量平方级增长(O (N²)),且易出现负载不均衡。


All-to-All 通信的核心优势在于全局协同的数据交换:每个设备可同时向所有其他设备发送数据,并从所有其他设备接收数据,且通信次数仅为 O (N)(N 为设备数)。这种特性完美匹配 MoE 中 "每个 token 可能被分配到任意专家或设备" 的稀疏激活场景。


问题阐述: all-to-all Dispatch和combine


背景部分最后留下了all-to-all解决的问题,本节来重点阐述一下,如下图,在一个moe layer里面,完整的过程包括all-to-all dispatch,MoE function(即各个专家MLP),all-to-all combine三阶段。

每个方块代表1个token,每个颜色用来更好看一下这些token在dispatch和combine前后的设备变化。举个例子,A1一开始分布在rank1,它选中的专家在rank0,经过dispatch后,它就跑rank0去了,跑过去之后,就和rank0上的专家斗地主(图中省略),然后又回到自己家rank1去了(combine)。每个token都遵循类似的过程。 

工作机制:从代码看 All-to-All 的实现细节

下面我们跟着代码来看看实现细节,完整代码比较长,这里摘重点

1 一些数据的初始化

初始化时, 明确专家数量、设备数及每个设备承载的专家数(num_local_experts),并预留足够的缓冲区以应对动态 token 数量(max_recv

class AllToAll:    # 元数据维度:global_exp, src_rank, src_token, src_k, 记录token从哪儿来的,用于combine把token发回去    meta = 4    def __init__(self, cfg: MoEConfig, rank: int, world_size: int):        self.cfg = cfg        self.rank = rank  # 当前设备编号        self.world_size = world_size  # 总设备数        self.num_local_experts = cfg.num_experts // world_size  # 每个设备上的专家数        self.max_recv = cfg.max_num_tokens * world_size  # 最大接收token数

2 Dispatch 阶段:将 token 分发到目标专家


Dispatch 的目标是:让每个 token 到达其选中的专家所在设备。代码中dispatch方法分为 6 个关键步骤:

步骤 1:计算发送 / 接收计数

首先统计当前设备需要向其他设备发送的 token 数量(send_counts),以及需要从其他设备接收的 token 数量(recv_counts):

# 统计每个目标设备需要接收的token数send_counts = [0] * self.world_sizetoken_map = [[] for _ in range(self.world_size)]  # 记录每个设备接收了哪些tokenmeta_map = [[] for _ in range(self.world_size)]   # 记录每个token的元数据for t, expert_list in enumerate(indices.tolist()):  # indices:记录了每个token选中的专家列表    for k, e in enumerate(expert_list):        dst_rank = e // self.num_local_experts  # 专家e所在的目标设备        send_counts[dst_rank] += 1        token_map[dst_rank].append(t)  # 记录token索引t需要发送到dst_rank        # 元数据:专家ID、源设备、源token索引、选中的专家序号(k)        meta_map[dst_rank].extend([e, self.rank, t, k])

通过torch.distributed.all_to_all_single交换send_counts,得到每个设备需要接收的 token 数recv_counts_t

send_counts_t = torch.tensor(send_counts, dtype=torch.long, device=device)recv_counts_t = torch.empty(self.world_size, dtype=torch.long, device=device)dist.all_to_all_single(recv_counts_tsend_counts_t)  # 交换计数

步骤 2:构建发送缓冲区


将需要发送的 token 和元数据整理为连续的缓冲区:

# 构建token发送缓冲区(按目标设备拼接)send_buf = torch.cat([dp_x[idx_list] for idx_list in token_map], dim=0)# 构建元数据发送缓冲区(按META_DIM维度整理)send_meta = torch.tensor(    [v for sub in meta_map for v in sub], dtype=torch.int32, device=device).view(-1, self.meta)

步骤 3:All-to-All 数据交换


通过torch.distributed.all_to_all_singlesend_bufsend_meta发送到目标设备,同时接收来自其他设备的数据:

# 接收token的缓冲区(大小为总接收数×隐藏维度)total_recv = int(recv_counts_t.sum().item())recv_buf = torch.empty(total_recv, cfg.hidden_dim, dtype=cfg.in_dtype, device=device)# 接收元数据的缓冲区recv_meta = torch.empty(total_recv, self.META_DIM, dtype=torch.int32, device=device)# 交换token数据dist.all_to_all_single(    recv_buf, send_buf,    output_split_sizes=recv_counts_t.tolist(),  # 每个源设备接收的数量    input_split_sizes=send_counts_t.tolist()    # 每个目标设备发送的数量)# 交换元数据(元数据长度是token数×META_DIM)dist.all_to_all_single(    recv_meta.view(-1), send_meta.view(-1),    output_split_sizes=[c * self.META_DIM for c in recv_counts_t.tolist()],    input_split_sizes=[c * self.META_DIM for c in send_counts_t.tolist()])

步骤 4:分发到本地专家


接收完成后,将recv_buf中的 token 按元数据分配到当前设备的本地专家的token缓冲区expert_token,它和expert_meta,expert_num_tokens也正是dispatch的输出:

# 初始化本地专家的token缓冲区expert_token = torch.empty(    (self.num_local_experts, self.max_recv, cfg.hidden_dim),    dtype=cfg.in_dtype, device=device)expert_meta = torch.empty(    (self.num_local_experts, self.max_recv, self.META_DIM),    dtype=torch.int32, device=device)expert_num_tokens = torch.zeros(self.num_local_experts, dtype=torch.int32, device=device)# 按元数据分配token到对应专家for i in range(total_recv):    global_eid = int(recv_meta[i, 0].item())  # 全局专家ID    local_eid = global_eid % self.num_local_experts  # 本地专家ID(相对于当前设备)    # 记录token数据、元数据和计数    expert_token[local_eid, expert_num_tokens[local_eid]] = recv_buf[i]    expert_meta[local_eid, expert_num_tokens[local_eid]] = recv_meta[i]    expert_num_tokens[local_eid] += 1

至此,Dispatch 阶段完成:每个 token 已到达其选中的专家所在设备,并被正确分配到本地专家的缓冲区


3 Combine 阶段:聚合专家输出结果


Combine 的目标是将专家处理后的结果回传到原始 token 所在设备,并按门控权重聚合。combine方法的核心逻辑与 Dispatch 对称,但方向相反:

步骤 1:统计回传计数


首先确定当前设备的专家输出需要回传到哪些设备:

send_counts = [0] * self.world_sizey_map = [[] for _ in range(self.world_size)]  # 专家输出的tokenmeta_map = [[] for _ in range(self.world_size)]  # 回传的元数据for local_eid in range(self.num_local_experts):    cnt = int(expert_num_tokens[local_eid].item())  # 该专家处理的token数    for j in range(cnt):        meta = expert_meta[local_eid, j]  # 提取元数据(包含原始设备信息)        dst_rank = int(meta[1].item())  # 原始token所在设备        send_counts[dst_rank] += 1        y_map[dst_rank].append(expert_y[local_eid, j].unsqueeze(0))  # 专家输出        meta_map[dst_rank].extend(meta.tolist())  # 回传元数据

步骤 2:All-to-All 回传数据


通过dist.all_to_all_single将专家输出和元数据回传到原始设备:

# 构建发送缓冲区send_buf = torch.cat([torch.cat(sub_list, dim=0if sub_list else torch.empty(0for sub_list in y_map], dim=0)send_meta = torch.tensor([v for sub in meta_map for v in sub], dtype=torch.int32, device=device).view(-1self.META_DIM)# 接收缓冲区total_recv = int(recv_counts_t.sum().item())recv_buf = torch.empty(total_recv, cfg.hidden_dim, dtype=cfg.out_dtype, device=device)recv_meta = torch.empty(total_recv, self.META_DIM, dtype=torch.int32, device=device)# 回传数据和元数据dist.all_to_all_single(recv_buf, send_buf, ...)dist.all_to_all_single(recv_meta.view(-1), send_meta.view(-1), ...)

步骤 3:加权聚合


最后,根据元数据找到原始 token,并按门控权重(weights)聚合结果,out_tokens即combine的输出:

for i in range(total_recv):    src_token = int(recv_meta[i, 2].item())  # 原始token索引    src_k = int(recv_meta[i, 3].item())      # 该token选中的第k个专家    w = weights[src_token, src_k].to(torch.float32)  # 门控权重    out_tokens[src_token] += recv_buf[i].to(torch.float32) * w  # 加权累加

至此,Combine 阶段完成:所有专家的输出已按权重聚合到原始 token,并且原始token也回到了原始rank,形成最终结果。


常见坑


1.专家负载不均衡:某些卡 recv 爆掉 max_recv 缓冲区,然而某些卡却没几个token。

2. 通信 hang: 通信里面hang的情况真是太多了,bug贼难找,检查是否做了正确的资源管理,有没有什么僵尸进程,进程组管理是否正确等等。

3. 精度掉点: FP16 all-to-all 累加误差。combine 的 out_tokens 用 FP32 累加,最后再 down-cast可能可以解决。

总结


All-to-All 通信是 MoE 实现专家并行的核心—— 它解决了稀疏激活场景下跨设备数据分发与聚合的核心难题,使 MoE 在保持计算效率的同时突破模型容量限制。从代码中可以看到,其实现的关键在于 "计数交换 - 数据缓冲 - 元数据跟踪" 的闭环设计:通过动态统计发送 / 接收数量实现灵活分配,通过连续缓冲区提升通信效率,通过元数据确保 token 流转的可追溯性。


随着大模型向万亿参数级迈进,MoE 与 All-to-All 通信的结合将成为规模化训练 / 推理的标配技术,而对其实现功能和底层机制的理解(如本文代码解析,虽然仅展现它是做了个什么事),正是优化分布式系统性能的基础。



【声明】内容源于网络
0
0
AI不止算法
AI-HPC/AI工程/AI推理加速/AI算子开发的技术分享和入门转行学习的全套解决方案提供
内容 92
粉丝 0
AI不止算法 AI-HPC/AI工程/AI推理加速/AI算子开发的技术分享和入门转行学习的全套解决方案提供
总阅读93
粉丝0
内容92