fix(QwenShardingStrategy): skip weight sharding and ShardedMoE wrap for Qwen3MoeSparseMoeBlock#2164

Open
FlashShift wants to merge 1 commit into exo-explore:main from FlashShift:fix/qwen3moe-switchlinear-sharding

@FlashShift

NOTE: This is an AI-produced patch. I have tested it, and it works correctly, but obviously please be cautious with any AI product.

Description:
Problem
Running Qwen3 MoE models (e.g. Qwen3-235B-A22B, Qwen3-397B-A17B) with tensor parallelism across multiple nodes produces a SIGSEGV in libjaccl.dylib inside jaccl::MeshImpl::all_gather, crashing the runner process during prefill.
The crash has two compounding causes, both in QwenShardingStrategy.shard_model:

  1. Wrong sharding axis on SwitchLinear weights
    Qwen3MoeSparseMoeBlock.switch_mlp uses SwitchLinear, which stores stacked expert weights as a single 3-D tensor of shape [num_experts, hidden, input]. The all_to_sharded_linear_in_place / sharded_to_all_linear_in_place helpers call shard_inplace, which is designed for nn.Linear and slices on ndim-2 (the output/hidden dimension). For a 3-D stacked weight this is axis 1 (the hidden dimension), not axis 0 (the expert dimension). The result is silently corrupted weight tensors on every rank.
    Qwen3NextSparseMoeBlock and Qwen3_5SparseMoeBlock are not affected — their switch_mlp and shared_expert weights are standard nn.Linear.
  2. ShardedMoE wrapping an unsharded block causes a JACCL buffer overrun
    Because the weights are not validly sharded, ShardedMoE.__call__ calls mx.distributed.all_sum on the full unsharded output of each rank. This is mathematically wrong (it scales activations by world_size) and, critically, the output tensor during prefill is much larger than the per-token decode tensors JACCL sized its collective buffers for at initialization time. The write overruns the receive buffer, producing:
    SIGSEGV: jaccl::MeshImpl::all_gather(char const*, char*, long long) + 1380
    KERN_INVALID_ADDRESS at 0x0000008210bb0000
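
Both causes can be illustrated at the level of shapes and sums, without MLX. The sketch below is only an illustration: shard_shape and all_sum are hypothetical stand-ins written for this example (they are not the exo helpers or the MLX collective), and the sizes are made up.

```python
def shard_shape(shape, axis, world_size):
    """Return the per-rank shape after an even split of `shape` along `axis`.

    Hypothetical stand-in for a shard helper; it only tracks shapes.
    """
    if shape[axis] % world_size != 0:
        raise ValueError("axis length must be divisible by world_size")
    out = list(shape)
    out[axis] //= world_size
    return tuple(out)


def all_sum(per_rank_outputs):
    """Simulate an all-reduce sum: element-wise sum over all ranks."""
    return [sum(values) for values in zip(*per_rank_outputs)]


# Illustrative sizes only (assumptions, not taken from any real model config).
num_experts, hidden, inp = 8, 16, 32
world_size = 2

linear_shape = (hidden, inp)               # nn.Linear-style 2-D weight
switch_shape = (num_experts, hidden, inp)  # SwitchLinear-style stacked weight

# Cause 1: slicing on axis ndim - 2, as described above.
linear_sharded = shard_shape(linear_shape, len(linear_shape) - 2, world_size)
switch_sharded = shard_shape(switch_shape, len(switch_shape) - 2, world_size)
# What an expert-axis split would look like instead.
expert_sharded = shard_shape(switch_shape, 0, world_size)

print(linear_sharded)   # axis 0 of the 2-D weight is split
print(switch_sharded)   # axis 1 (hidden) of the 3-D weight is split
print(expert_sharded)   # axis 0 (experts) of the 3-D weight is split

# Cause 2: every rank holds the same full output, so summing across
# ranks multiplies each activation by world_size.
full_output = [1.0, 2.0, 3.0]
summed = all_sum([full_output for _ in range(world_size)])
print(summed)
```

For the 3-D weight, ndim - 2 selects the hidden axis rather than the expert axis, and the simulated all_sum returns each activation multiplied by world_size.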
Fix
Guard both the shard_inplace calls and the ShardedMoE wrap with if not isinstance(layer.mlp, Qwen3MoeSparseMoeBlock). This leaves Qwen3MoeSparseMoeBlock as a plain unsharded module on each rank. Pipeline parallelism handles cross-node communication correctly without any MoE-level collective.
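
The shape of the guard can be sketched roughly as follows. This is not the actual diff: the classes are empty stubs, and shard_moe_layers, OtherMoeBlock, Layer and the shard_weights callback are hypothetical names used only to make the example self-contained.

```python
class Qwen3MoeSparseMoeBlock:
    """Stub standing in for the real SwitchLinear-based MoE block."""


class OtherMoeBlock:
    """Stub for any MoE block that is safe to shard (hypothetical)."""


class ShardedMoE:
    """Stub wrapper; the real one performs a collective after the MoE."""

    def __init__(self, inner):
        self.inner = inner


class Layer:
    def __init__(self, mlp):
        self.mlp = mlp


def shard_moe_layers(layers, shard_weights):
    """Shard and wrap every MoE block except Qwen3MoeSparseMoeBlock."""
    for layer in layers:
        if not isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
            shard_weights(layer.mlp)           # stands in for the shard_inplace calls
            layer.mlp = ShardedMoE(layer.mlp)  # wrap only validly sharded blocks
    return layers


# Usage: record which blocks the callback was asked to shard.
sharded_blocks = []
layers = [Layer(Qwen3MoeSparseMoeBlock()), Layer(OtherMoeBlock())]
shard_moe_layers(layers, sharded_blocks.append)
```

After the call, the first layer still holds a bare Qwen3MoeSparseMoeBlock, while the second has been sharded and wrapped.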
Testing
Verified on two Mac Studio nodes (4 GPUs total) running mlx-community/Qwen3-235B-A22B-6bit. Previously crashed immediately on prefill warmup; with this fix inference completes successfully.
Notes
A proper tensor-parallel implementation for SwitchLinear-based MoE blocks (sharding on the expert axis) would give memory reduction benefits and is worth a follow-up, but requires a custom collective that understands the 3-D weight layout and the sparse routing indices.
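
To give a rough idea of what the follow-up would need to track, the sketch below partitions experts across ranks and filters routing indices by owner. It is a toy illustration under assumed names (expert_range, local_assignments) and made-up routing data, not a proposed implementation; it leaves out the collective that would combine per-rank outputs.

```python
def expert_range(rank, world_size, num_experts):
    """Contiguous block of expert ids owned by `rank` (even split assumed)."""
    if num_experts % world_size != 0:
        raise ValueError("num_experts must be divisible by world_size")
    per_rank = num_experts // world_size
    start = rank * per_rank
    return range(start, start + per_rank)


def local_assignments(routing, rank, world_size, num_experts):
    """Keep only the (token, expert) pairs whose expert lives on this rank."""
    owned = expert_range(rank, world_size, num_experts)
    return [
        (token, expert)
        for token, experts in enumerate(routing)
        for expert in experts
        if expert in owned
    ]


# Made-up top-2 routing for three tokens over 8 experts.
routing = [[0, 5], [3, 4], [7, 1]]
rank0 = local_assignments(routing, rank=0, world_size=2, num_experts=8)
rank1 = local_assignments(routing, rank=1, world_size=2, num_experts=8)
print(rank0)
print(rank1)
```

Each rank would then run only its own experts on its own pairs, and a collective would have to sum the partial outputs back per token.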
