fix(QwenShardingStrategy): skip weight sharding and ShardedMoE wrap for Qwen3MoeSparseMoeBlock #2164
Open
FlashShift wants to merge 1 commit into
NOTE: This is an AI-produced patch. I have tested it, and it works correctly, but obviously please be cautious with any AI product.
Description:
Problem
Running Qwen3 MoE models (e.g. Qwen3-235B-A22B, Qwen3-397B-A17B) with tensor parallelism across multiple nodes produces a SIGSEGV in libjaccl.dylib inside jaccl::MeshImpl::all_gather, crashing the runner process during prefill.
The crash has two compounding causes, both in QwenShardingStrategy.shard_model:
1. Qwen3MoeSparseMoeBlock.switch_mlp uses SwitchLinear, which stores stacked expert weights as a single 3-D tensor of shape [num_experts, hidden, input]. The all_to_sharded_linear_in_place / sharded_to_all_linear_in_place helpers call shard_inplace, which is designed for nn.Linear and slices on axis ndim-2 (the output/hidden dimension). For a 3-D stacked weight that is axis 1 (the hidden dimension), not axis 0 (the expert dimension). The result is silently corrupted weight tensors on every rank.
   Qwen3NextSparseMoeBlock and Qwen3_5SparseMoeBlock are not affected: their switch_mlp and shared_expert weights are standard nn.Linear.
2. Because the weights are not validly sharded, ShardedMoE.__call__ calls mx.distributed.all_sum on the full unsharded output of each rank. This is mathematically wrong (it scales activations by world_size) and, critically, the output tensor during prefill is much larger than the per-token decode tensors that JACCL sized its collective buffers for at initialization time. The write overruns the receive buffer, producing:
SIGSEGV: jaccl::MeshImpl::all_gather(char const*, char*, long long) + 1380
KERN_INVALID_ADDRESS at 0x0000008210bb0000
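Both causes can be illustrated with plain Python, without MLX. This is only a sketch of the arithmetic: `shard_axis_for_linear`, `shard_shape`, and `all_sum` are hypothetical helpers written for this example (they are not the project's `shard_inplace` or `mx.distributed.all_sum`), and the shapes and values are made up.

```python
def shard_axis_for_linear(ndim: int) -> int:
    """Axis a Linear-oriented sharder slices: the second-to-last one."""
    return ndim - 2


def shard_shape(shape, axis, world_size):
    """Shape of one rank's slice when `axis` is split evenly across ranks."""
    if shape[axis] % world_size != 0:
        raise ValueError("axis length must divide evenly across ranks")
    out = list(shape)
    out[axis] //= world_size
    return tuple(out)


def all_sum(per_rank_outputs):
    """Toy stand-in for an all-reduce sum: element-wise sum over ranks."""
    return [sum(vals) for vals in zip(*per_rank_outputs)]


world_size = 4  # illustrative value

# Linear-style weight [out_features, in_features]: ndim-2 is axis 0, as intended.
linear_shape = (1024, 512)
linear_axis = shard_axis_for_linear(len(linear_shape))
linear_shard = shard_shape(linear_shape, linear_axis, world_size)

# Stacked expert weight [num_experts, hidden, input]: ndim-2 is axis 1,
# so the hidden dimension is cut and the expert dimension is left whole.
stacked_shape = (8, 1024, 512)
stacked_axis = shard_axis_for_linear(len(stacked_shape))
stacked_shard = shard_shape(stacked_shape, stacked_axis, world_size)

# Second cause: if every rank holds the same full output, summing across
# ranks multiplies each activation by world_size.
full_output = [0.5, -1.0, 2.0]
summed = all_sum([full_output] * world_size)
```

With these inputs the stacked weight keeps all 8 experts on every rank while its hidden axis shrinks to 256, and the summed activations come out four times too large.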
Fix
Guard both the shard_inplace calls and the ShardedMoE wrap with if not isinstance(layer.mlp, Qwen3MoeSparseMoeBlock). This leaves Qwen3MoeSparseMoeBlock as a plain unsharded module on each rank. Pipeline parallelism handles cross-node communication correctly without any MoE-level collective.
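The shape of the guard can be sketched as follows. Every class here is an empty stand-in defined for the example, and `shard_layer` and its `shard_weights` callback are hypothetical names; the actual change lives inside QwenShardingStrategy.shard_model and is not reproduced here.

```python
class Qwen3MoeSparseMoeBlock:
    """Stand-in for the SwitchLinear-based block that must be skipped."""


class OtherSparseMoeBlock:
    """Stand-in for a block whose weights are safe to shard."""


class ShardedMoE:
    """Stand-in for the wrapper that adds the cross-rank collective."""

    def __init__(self, inner):
        self.inner = inner


class Layer:
    def __init__(self, mlp):
        self.mlp = mlp


def shard_layer(layer, shard_weights):
    """Shard and wrap the MLP unless it is a Qwen3MoeSparseMoeBlock."""
    if not isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
        shard_weights(layer.mlp)           # the shard_inplace-style calls
        layer.mlp = ShardedMoE(layer.mlp)  # the collective wrapper
    return layer


sharded_blocks = []  # records which blocks had their weights sharded
skipped = shard_layer(Layer(Qwen3MoeSparseMoeBlock()), sharded_blocks.append)
wrapped = shard_layer(Layer(OtherSparseMoeBlock()), sharded_blocks.append)
```

Putting both the weight sharding and the wrap under the same condition matters: wrapping without sharding would reintroduce the all_sum on unsharded outputs described under Problem.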
Testing
Verified on two Mac Studio nodes (4 GPUs total) running mlx-community/Qwen3-235B-A22B-6bit. Before the fix, the runner crashed immediately during prefill warmup; with the fix, inference completes successfully.
Notes
A proper tensor-parallel implementation for SwitchLinear-based MoE blocks (sharding on the expert axis) would give memory reduction benefits and is worth a follow-up, but requires a custom collective that understands the 3-D weight layout and the sparse routing indices.
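To make the follow-up idea concrete, here is a toy model of sharding on the expert axis. It is an assumption-laden sketch, not a design for the custom collective: each "expert" is reduced to a single scalar multiplier, `moe_forward` is a hypothetical function, the routing table is hand-written, and the cross-rank reduction is simulated with a plain Python sum.

```python
def moe_forward(tokens, expert_scales, routes, owned=None):
    """Toy MoE forward pass.

    Expert e multiplies a token by expert_scales[e]. routes[t] lists the
    (expert_index, gate_weight) pairs chosen for token t. When `owned` is
    given, only experts in that set contribute (one rank's partial output).
    """
    out = []
    for token, picks in zip(tokens, routes):
        acc = 0.0
        for expert, gate in picks:
            if owned is None or expert in owned:
                acc += gate * expert_scales[expert] * token
        out.append(acc)
    return out


expert_scales = [1.0, 2.0, 3.0, 4.0]   # four toy experts
tokens = [1.0, 2.0, 4.0]
routes = [
    [(0, 0.5), (3, 0.5)],
    [(1, 0.25), (2, 0.75)],
    [(2, 0.5), (3, 0.5)],
]

world_size = 2
per_rank = len(expert_scales) // world_size

# Each rank keeps a contiguous slice of the expert axis (axis 0).
partials = [
    moe_forward(
        tokens,
        expert_scales,
        routes,
        owned=set(range(rank * per_rank, (rank + 1) * per_rank)),
    )
    for rank in range(world_size)
]

# Summing the per-rank partial outputs recovers the unsharded result.
combined = [sum(vals) for vals in zip(*partials)]
reference = moe_forward(tokens, expert_scales, routes)
```

In this toy setting a sum across ranks is correct because each expert lives on exactly one rank, so no contribution is counted twice. A real implementation would still have to deal with the 3-D weight layout and the routing indices noted above.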