背景:MoE 与EP专家并行的诞生
MoE 的核心思想源于 "分而治之":将模型的FFN/MLP拆分为多个独立的 "专家"(Expert),每个输入 token 仅由部分专家(TOPK个)处理,再通过门控网络聚合结果。这种设计使模型参数量随专家数量线性增长,却无需按比例增加计算量(因每个 token 仅激活少数专家)。
然而,当专家数量规模达到数千甚至上万时,单设备已无法容纳所有专家 —— 这就需要专家并行(Expert Parallelism):将专家分布在不同GPU上,每个GPU仅保存部分专家,这与TP的思路其实差不多,都是把参数分布到不同GPU,但是粒度和效果与TP截然不同,之前写了篇文章分析过EP与TP的对比。此时,把专家按 GPU 切片后,每张卡只存 Expert数量/world_size 个专家。前向传播时,token 可能命中任意专家 ,我们必须把 token 发到目标GPU,算完再拿回来。这个“跨卡搬运 + 还原到原始GPU”就是 EP all-to-all,也是专家并行的灵魂:
-
如何将输入 token 高效分发到其选中的专家所在设备(all-to-all Dispatch 阶段)? -
如何将专家的输出结果回传到原始设备并聚合结果(all-to-all Combine 阶段)?
All-to-All 通信正是为解决这两个问题而生的分布式通信范式。
除了之前文章分析过EP与TP的对比之外,这里再补充几点:
在传统分布式训练 / 推理中,数据并行(Data Parallelism)和模型并行(Model Parallelism)已成为主流,但它们在 MoE 场景下存在显著局限:
数据并行DP:每个设备复制完整模型(包括所有专家),导致内存占用随专家数量线性增长;
模型并行TP:仅能拆分单个专家的参数,无法处理 MoE 中多个独立专家的分布式部署;
点对点通信:若直接用点对点(Point-to-Point)通信分发 token,会导致通信次数随设备数量平方级增长(O (N²)),且易出现负载不均衡。
All-to-All 通信的核心优势在于全局协同的数据交换:每个设备可同时向所有其他设备发送数据,并从所有其他设备接收数据,且通信次数仅为 O (N)(N 为设备数)。这种特性完美匹配 MoE 中 "每个 token 可能被分配到任意专家或设备" 的稀疏激活场景。
问题阐述: all-to-all Dispatch和combine
背景部分最后留下了all-to-all解决的问题,本节来重点阐述一下,如下图,在一个moe layer里面,完整的过程包括all-to-all dispatch,MoE function(即各个专家MLP),all-to-all combine三阶段。
每个方块代表1个token,每个颜色用来更好看一下这些token在dispatch和combine前后的设备变化。举个例子,A1一开始分布在rank1,它选中的专家在rank0,经过dispatch后,它就跑rank0去了,跑过去之后,就和rank0上的专家斗地主(图中省略),然后又回到自己家rank1去了(combine)。每个token都遵循类似的过程。
工作机制:从代码看 All-to-All 的实现细节
下面我们跟着代码来看看实现细节,完整代码比较长,这里摘重点
1 一些数据的初始化
初始化时, 明确专家数量、设备数及每个设备承载的专家数(num_local_experts),并预留足够的缓冲区以应对动态 token 数量(max_recv)。
class AllToAll:# 元数据维度:global_exp, src_rank, src_token, src_k, 记录token从哪儿来的,用于combine把token发回去meta = 4def __init__(self, cfg: MoEConfig, rank: int, world_size: int):self.cfg = cfgself.rank = rank # 当前设备编号self.world_size = world_size # 总设备数self.num_local_experts = cfg.num_experts // world_size # 每个设备上的专家数self.max_recv = cfg.max_num_tokens * world_size # 最大接收token数
2 Dispatch 阶段:将 token 分发到目标专家
Dispatch 的目标是:让每个 token 到达其选中的专家所在设备。代码中dispatch方法分为 6 个关键步骤:
步骤 1:计算发送 / 接收计数
首先统计当前设备需要向其他设备发送的 token 数量(send_counts),以及需要从其他设备接收的 token 数量(recv_counts):
# 统计每个目标设备需要接收的token数send_counts = [0] * self.world_sizetoken_map = [[] for _ in range(self.world_size)] # 记录每个设备接收了哪些tokenmeta_map = [[] for _ in range(self.world_size)] # 记录每个token的元数据for t, expert_list in enumerate(indices.tolist()): # indices:记录了每个token选中的专家列表for k, e in enumerate(expert_list):dst_rank = e // self.num_local_experts # 专家e所在的目标设备send_counts[dst_rank] += 1token_map[dst_rank].append(t) # 记录token索引t需要发送到dst_rank# 元数据:专家ID、源设备、源token索引、选中的专家序号(k)meta_map[dst_rank].extend([e, self.rank, t, k])
通过torch.distributed.all_to_all_single交换send_counts,得到每个设备需要接收的 token 数recv_counts_t:
send_counts_t = torch.tensor(send_counts, dtype=torch.long, device=device)recv_counts_t = torch.empty(self.world_size, dtype=torch.long, device=device)dist.all_to_all_single(recv_counts_t, send_counts_t) # 交换计数
步骤 2:构建发送缓冲区
将需要发送的 token 和元数据整理为连续的缓冲区:
# 构建token发送缓冲区(按目标设备拼接)send_buf = torch.cat([dp_x[idx_list] for idx_list in token_map], dim=0)# 构建元数据发送缓冲区(按META_DIM维度整理)send_meta = torch.tensor([v for sub in meta_map for v in sub], dtype=torch.int32, device=device).view(-1, self.meta)
步骤 3:All-to-All 数据交换
通过torch.distributed.all_to_all_single将send_buf和send_meta发送到目标设备,同时接收来自其他设备的数据:
# 接收token的缓冲区(大小为总接收数×隐藏维度)total_recv = int(recv_counts_t.sum().item())recv_buf = torch.empty(total_recv, cfg.hidden_dim, dtype=cfg.in_dtype, device=device)# 接收元数据的缓冲区recv_meta = torch.empty(total_recv, self.META_DIM, dtype=torch.int32, device=device)# 交换token数据dist.all_to_all_single(recv_buf, send_buf,output_split_sizes=recv_counts_t.tolist(), # 每个源设备接收的数量input_split_sizes=send_counts_t.tolist() # 每个目标设备发送的数量)# 交换元数据(元数据长度是token数×META_DIM)dist.all_to_all_single(recv_meta.view(-1), send_meta.view(-1),output_split_sizes=[c * self.META_DIM for c in recv_counts_t.tolist()],input_split_sizes=[c * self.META_DIM for c in send_counts_t.tolist()])
步骤 4:分发到本地专家
接收完成后,将recv_buf中的 token 按元数据分配到当前设备的本地专家的token缓冲区expert_token,它和expert_meta,expert_num_tokens也正是dispatch的输出:
# 初始化本地专家的token缓冲区expert_token = torch.empty((self.num_local_experts, self.max_recv, cfg.hidden_dim),dtype=cfg.in_dtype, device=device)expert_meta = torch.empty((self.num_local_experts, self.max_recv, self.META_DIM),dtype=torch.int32, device=device)expert_num_tokens = torch.zeros(self.num_local_experts, dtype=torch.int32, device=device)# 按元数据分配token到对应专家for i in range(total_recv):global_eid = int(recv_meta[i, 0].item()) # 全局专家IDlocal_eid = global_eid % self.num_local_experts # 本地专家ID(相对于当前设备)# 记录token数据、元数据和计数expert_token[local_eid, expert_num_tokens[local_eid]] = recv_buf[i]expert_meta[local_eid, expert_num_tokens[local_eid]] = recv_meta[i]expert_num_tokens[local_eid] += 1
至此,Dispatch 阶段完成:每个 token 已到达其选中的专家所在设备,并被正确分配到本地专家的缓冲区
3 Combine 阶段:聚合专家输出结果
Combine 的目标是将专家处理后的结果回传到原始 token 所在设备,并按门控权重聚合。combine方法的核心逻辑与 Dispatch 对称,但方向相反:
步骤 1:统计回传计数
首先确定当前设备的专家输出需要回传到哪些设备:
send_counts = [0] * self.world_sizey_map = [[] for _ in range(self.world_size)] # 专家输出的tokenmeta_map = [[] for _ in range(self.world_size)] # 回传的元数据for local_eid in range(self.num_local_experts):cnt = int(expert_num_tokens[local_eid].item()) # 该专家处理的token数for j in range(cnt):meta = expert_meta[local_eid, j] # 提取元数据(包含原始设备信息)dst_rank = int(meta[1].item()) # 原始token所在设备send_counts[dst_rank] += 1y_map[dst_rank].append(expert_y[local_eid, j].unsqueeze(0)) # 专家输出meta_map[dst_rank].extend(meta.tolist()) # 回传元数据
步骤 2:All-to-All 回传数据
通过dist.all_to_all_single将专家输出和元数据回传到原始设备:
# 构建发送缓冲区send_buf = torch.cat([torch.cat(sub_list, dim=0) if sub_list else torch.empty(0) for sub_list in y_map], dim=0)send_meta = torch.tensor([v for sub in meta_map for v in sub], dtype=torch.int32, device=device).view(-1, self.META_DIM)# 接收缓冲区total_recv = int(recv_counts_t.sum().item())recv_buf = torch.empty(total_recv, cfg.hidden_dim, dtype=cfg.out_dtype, device=device)recv_meta = torch.empty(total_recv, self.META_DIM, dtype=torch.int32, device=device)# 回传数据和元数据dist.all_to_all_single(recv_buf, send_buf, ...)dist.all_to_all_single(recv_meta.view(-1), send_meta.view(-1), ...)
步骤 3:加权聚合
最后,根据元数据找到原始 token,并按门控权重(weights)聚合结果,out_tokens即combine的输出:
for i in range(total_recv):src_token = int(recv_meta[i, 2].item())src_k = int(recv_meta[i, 3].item())w = weights[src_token, src_k].to(torch.float32)out_tokens[src_token] += recv_buf[i].to(torch.float32) * w
至此,Combine 阶段完成:所有专家的输出已按权重聚合到原始 token,并且原始token也回到了原始rank,形成最终结果。
常见坑
1.专家负载不均衡:某些卡 recv 爆掉 max_recv 缓冲区,然而某些卡却没几个token。
2. 通信 hang: 通信里面hang的情况真是太多了,bug贼难找,检查是否做了正确的资源管理,有没有什么僵尸进程,进程组管理是否正确等等。
3. 精度掉点: FP16 all-to-all 累加误差。combine 的 out_tokens 用 FP32 累加,最后再 down-cast可能可以解决。
总结
All-to-All 通信是 MoE 实现专家并行的核心—— 它解决了稀疏激活场景下跨设备数据分发与聚合的核心难题,使 MoE 在保持计算效率的同时突破模型容量限制。从代码中可以看到,其实现的关键在于 "计数交换 - 数据缓冲 - 元数据跟踪" 的闭环设计:通过动态统计发送 / 接收数量实现灵活分配,通过连续缓冲区提升通信效率,通过元数据确保 token 流转的可追溯性。
随着大模型向万亿参数级迈进,MoE 与 All-to-All 通信的结合将成为规模化训练 / 推理的标配技术,而对其实现功能和底层机制的理解(如本文代码解析,虽然仅展现它是做了个什么事),正是优化分布式系统性能的基础。

