
Commit 1d6c751

[FIX] qwen3_5_moe / llama4 / qwen2_moe / qwen3_next awq layer grouping (#2634)
* fix qwen3.5 moe awq layer grouping
* fix llama4/qwen2_moe/qwen3_next awq layer grouping

Signed-off-by: ZX-ModelCloud <zx@modelcloud.ai>
1 parent b6ca848 commit 1d6c751

File tree

5 files changed: +21 / -20 lines


gptqmodel/models/definitions/llama4.py

Lines changed: 2 additions & 2 deletions
@@ -32,10 +32,10 @@ class Llama4QModel(BaseQModel):
         "self_attn": ("q_proj:0", "k_proj:0", "v_proj:0", "o_proj:1"),
         "post_attention_layernorm": ("post_attention_layernorm:!",),
         "feed_forward:moe": {
-            "experts": {
+            "experts:0": {
                 "#": ("gate_proj:0", "up_proj:0", "down_proj:1"),
             },
-            "shared_expert": ("gate_proj:0", "up_proj:0", "down_proj:1"),
+            "shared_expert:0": ("gate_proj:0", "up_proj:0", "down_proj:1"),
         },
     }
 ]
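
Reading these module trees: the trailing ":N" on a projection marks the quantization subset it belongs to within the layer (gate/up projections in subset 0, down projections in subset 1), ":!" marks modules that are excluded from quantization, and "#" is a placeholder for the dynamic expert index. Moving the experts and shared_expert parents into group ":0" merges the shared expert's projections into the same AWQ subsets as the routed experts' projections, rather than leaving them in a separate group; this is the grouping the new test at the end of this commit asserts. The same change is applied to the other MoE definitions below.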

gptqmodel/models/definitions/qwen2_moe.py

Lines changed: 0 additions & 14 deletions
@@ -35,17 +35,3 @@ class Qwen2MoeQModel(BaseQModel):
         },
     }
 ]
-
-# module_tree_overrides = {
-#     METHOD.AWQ: [
-#         {
-#             "mlp:moe:?": {
-#                 "gate": ("gate:!",),
-#                 "shared_expert": None,
-#                 "experts": {
-#                     "#": ("gate_proj:0", "up_proj:0", "down_proj:1"),
-#                 },
-#             },
-#         }
-#     ]
-# }

gptqmodel/models/definitions/qwen3_5_moe.py

Lines changed: 2 additions & 2 deletions
@@ -55,10 +55,10 @@ class Qwen3_5_MoeQModel(BaseQModel):
         "mlp:moe:?": {
             "gate": ("gate:!",),  # <-- 0.5MB per layer. Not worth quantizing
             "shared_expert_gate": ("shared_expert_gate:!",),
-            "experts": {
+            "experts:0": {
                 "#": ("gate_proj:0", "up_proj:0", "down_proj:1"),
             },
-            "shared_expert": ("gate_proj:0", "up_proj:0", "down_proj:1"),
+            "shared_expert:0": ("gate_proj:0", "up_proj:0", "down_proj:1"),
         },
     }
 ]
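
To check the resulting grouping locally, the module tree can be expanded the same way the new test in this commit does. A minimal sketch, assuming gptqmodel is importable from a source checkout:

# Minimal sketch: dump the AWQ layer grouping built from the fixed module_tree.
# Assumes gptqmodel is importable from a source checkout of this repository.
from gptqmodel.models.definitions.qwen3_5_moe import Qwen3_5_MoeQModel

blocks = Qwen3_5_MoeQModel.build_layer_modules(Qwen3_5_MoeQModel.module_tree)
for block in blocks:
    print(block)

# Per the new test below, one block should now contain both
# "mlp.shared_expert.gate_proj" and "mlp.experts.{expert_index}.gate_proj"
# (plus the matching up_proj entries), and another block should pair
# "mlp.shared_expert.down_proj" with "mlp.experts.{expert_index}.down_proj".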

gptqmodel/models/definitions/qwen3_next.py

Lines changed: 2 additions & 2 deletions
@@ -44,10 +44,10 @@ class Qwen3NextGPTQ(BaseQModel):
             # MoE router + shared expert (Qwen3NextSparseMoeBlock)
             "gate": ("gate:!",),  # router gate linear
             "shared_expert_gate": ("shared_expert_gate:!",),  # <-- single (1, N) logic projections should not be quantized
-            "shared_expert": ("gate_proj:0", "up_proj:0", "down_proj:1"),
+            "shared_expert:0": ("gate_proj:0", "up_proj:0", "down_proj:1"),

             # Experts list with dynamic index
-            "experts": {
+            "experts:0": {
                 "#": ("gate_proj:0", "up_proj:0", "down_proj:1"),
             },
         },

tests/module_tree/test_subset.py

Lines changed: 15 additions & 0 deletions
@@ -27,6 +27,7 @@
 from gptqmodel.looper.named_module import NamedModule
 from gptqmodel.looper.stage_subset import build_subset_plan, run_subset_stage
 from gptqmodel.models.definitions.qwen2_moe import Qwen2MoeQModel
+from gptqmodel.models.definitions.qwen3_5_moe import Qwen3_5_MoeQModel
 from gptqmodel.models.definitions.qwen3_moe import Qwen3MoeQModel
 from gptqmodel.nn_modules.hooked_linear import replace_module_with_hooked_legacy
 from gptqmodel.quantization import FORMAT, METHOD
@@ -111,6 +112,20 @@ def test_qwen2_moe_shared_expert_merges_with_experts():
     assert len(expert_gate_blocks) == 1


+def test_qwen3_5_moe_shared_expert_merges_with_experts():
+    blocks = Qwen3_5_MoeQModel.build_layer_modules(Qwen3_5_MoeQModel.module_tree)
+    print("blocks", blocks)
+    gate_block = next(block for block in blocks if "mlp.shared_expert.gate_proj" in block)
+    assert "mlp.experts.{expert_index}.gate_proj" in gate_block
+    assert "mlp.experts.{expert_index}.up_proj" in gate_block
+
+    down_block = next(block for block in blocks if "mlp.shared_expert.down_proj" in block)
+    assert "mlp.experts.{expert_index}.down_proj" in down_block
+
+    expert_gate_blocks = [block for block in blocks if "mlp.experts.{expert_index}.gate_proj" in block]
+    assert len(expert_gate_blocks) == 1
+
+
 def test_awq_processor_enables_subset_early_stop():
     calibration = [{"input_ids": torch.tensor([1, 2, 3])}]
     qcfg = _make_quant_config()
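
The new grouping test can be run on its own with pytest, e.g. pytest tests/module_tree/test_subset.py -k qwen3_5_moe (assuming the repository's usual test dependencies are installed).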
