Fix Qwen3.6 (qwen3_5 / qwen3_5_moe) on Metal: RMSNorm, AFQ, lm_head, hybrid KV cache #2201

Open
sergey-scherbina wants to merge 1 commit into EricLBuehler:master from sergey-scherbina:qwen36-fixes

sergey-scherbina commented Jun 10, 2026

Adds support for Qwen3.6 - both the dense qwen3_5 and the MoE qwen3_5_moe
(35B-A3B) - and makes them numerically correct on Metal.

Root causes fixed

  • RMSNorm +1 convention. GemmaRmsNorm::new bakes weight = on_disk + 1.0, but
    the sanitized mlx-community/Qwen3.6-* checkpoints ship raw RMSNorm weights
    (MLX's should_shift_norm_weights is false for these: conv1d (…,4,1), no MTP). The
    +1 over-scaled every norm (~2.1x), and the SiLU MoE compounded that into a ~14x
    blow-up in the expert outputs and an over-peaked router. Fixed by using an
    unshifted norm at the affected sites.
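  • AFQ packed loading and lm_head for the MLX-quantized checkpoints. The MoE
    router (mlp.gate, shared_expert_gate) is stored at 8-bit while the rest is
    4-bit, so AFQ parameters are now resolved per tensor. The dense qwen3_5 model
    also looks for lm_head under the MLX language_model.lm_head.* wrapper.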
  • Hybrid KV cache: full-attention layers use the paged KV pool; the
    linear-attention (GatedDeltaNet) layers carry a recurrent state. The cache/config
    plumbing wires both per layer_types.
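
The RMSNorm +1 convention can be illustrated with a minimal sketch. This is not the `GemmaRmsNorm` code from the PR: `rms_norm`, its `shift` flag, and the sample values are hypothetical, chosen only to show how adding 1.0 to a weight that is already the effective scale inflates every output. With an assumed raw weight of 0.9, the inflation is (1 + 0.9) / 0.9 ≈ 2.11, in the same range as the ~2.1x reported above; real checkpoint weights vary per channel.

```rust
/// Root-mean-square normalization of one vector, followed by a per-channel scale.
/// `shift = true` models the Gemma-style convention, where the stored weight is
/// an offset and the effective scale is `1.0 + w`. `shift = false` uses the
/// stored weight directly (the "unshifted" case).
fn rms_norm(x: &[f32], weight: &[f32], eps: f32, shift: bool) -> Vec<f32> {
    assert_eq!(x.len(), weight.len());
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    let inv_rms = 1.0 / (mean_sq + eps).sqrt();
    x.iter()
        .zip(weight)
        .map(|(v, w)| {
            let scale = if shift { 1.0 + *w } else { *w };
            *v * inv_rms * scale
        })
        .collect()
}

fn main() {
    // Hypothetical inputs: the weight 0.9 is an assumption for illustration.
    let x = [1.0_f32, -2.0, 3.0, -4.0];
    let raw_weight = [0.9_f32; 4];

    let unshifted = rms_norm(&x, &raw_weight, 1e-6, false);
    let shifted = rms_norm(&x, &raw_weight, 1e-6, true);

    // Applying the shift to a raw weight multiplies each output by (1 + w) / w.
    for (u, s) in unshifted.iter().zip(&shifted) {
        println!("unshifted = {u:+.4}, shifted = {s:+.4}, ratio = {:.3}", s / u);
    }
}
```

Because the error is multiplicative at every norm, it compounds through the layers, which is consistent with the larger blow-up described for the MoE experts.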

The MoE experts ship pre-fused as switch_mlp.{gate,up,down}_proj; the existing
FusedExperts path loads them via afq_packed_linear_b and applies router weights
correctly - no layout change was needed (verified; the expert layout was a red herring).

Correctness

"Hello" renders identically to mlx_lm (embedding byte-for-byte, all layer
last-position norms within bf16 rounding, identical top-1 logit), and greedy
generation begins identically. Verified on Qwen3.6-27B-4bit and Qwen3.6-35B-A3B-4bit
on Apple Silicon Metal.

Notes

  • Rebased onto current master (clean).
  • This PR is now scoped to Qwen3.6 only. It was split out of a larger branch into
    reviewable pieces; the zero-element KV buffer fix is a prerequisite for running
    this on Metal. Suggested merge order: zero-buffer + engine-reap → this → chunked-prefill.

Scope

gdn/*, vision_models/qwen3_5{,_moe}/*, paged_attention/config.rs, layers.rs,
mistralrs-quant/{afq,lib} - 12 files, +614/-74.


Part of splitting the Qwen3.6 work into focused, reviewable PRs:

Suggested merge order: #2206 + #2207 -> #2201 -> #2208.

github-actions Bot commented Jun 10, 2026

Code Metrics Report
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
 Language              Files        Lines         Code     Comments       Blanks
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
 C Header                 23         4454         3116          790          548
 CSS                       3          281          252            5           24
 CUDA                    119        23575        19136         1696         2743
 Dockerfile                1           38           21            8            9
 HTML                      2           27           27            0            0
 JavaScript                3          392          387            2            3
 Jinja2                    7          694          656            5           33
 JSON                     26         9360         9357            0            3
 Makefile                  1            6            5            0            1
 MDX                       1          149            0          133           16
 Metal Shading Lan|       37        14287        11284         1136         1867
 PowerShell                1          357          276           33           48
 Python                  131        10342         8515          460         1367
 Shell                     2          549          379          101           69
 Plain Text                3         3723            0         2413         1310
 TOML                     29         1388         1211           41          136
 TypeScript               11         1607         1371           66          170
 YAML                      3           25           23            2            0
─────────────────────────────────────────────────────────────────────────────────
 Jupyter Notebooks         3          122           83           23           16
 |- Markdown               1           60           30           22            8
 |- Python                 1          122          113            1            8
 (Total)                              304          226           46           32
─────────────────────────────────────────────────────────────────────────────────
 Markdown                129         9703            0         6648         3055
 |- BASH                  61          600          520           47           33
 |- Dockerfile             2            5            5            0            0
 |- JSON                  18          700          700            0            0
 |- PowerShell             3            5            5            0            0
 |- Python                25          830          722            5          103
 |- Rust                  15          437          382            1           54
 |- TOML                  10          124           98            3           23
 |- YAML                   1           13           13            0            0
 (Total)                            12417         2445         6704         3268
─────────────────────────────────────────────────────────────────────────────────
 Rust                    625       270388       239956         5864        24568
 |- Markdown             397         9504          452         7882         1170
 (Total)                           279892       240408        13746        25738
─────────────────────────────────────────────────────────────────────────────────
 Svelte                   18         1831         1696           50           85
 |- CSS                    1            4            4            0            0
 |- JavaScript            18          876          727           24          125
 (Total)                             2711         2427           74          210
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
 Total                  1178       366578       301522        27461        37595
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

sergey-scherbina force-pushed the qwen36-fixes branch 2 times, most recently from 4d19c55 to 5474ace, June 10, 2026 17:35
… cache

The qwen3_5_moe support from EricLBuehler#2196 had several bugs that made mlx-community
Qwen3.6 checkpoints load incorrectly or not at all on Metal. Fixes:

- RMSNorm convention. `GemmaRmsNorm::new` bakes `weight + 1.0`, but sanitized
  MLX checkpoints ship raw RMSNorm weights, so every norm was over-scaled and
  the MoE experts blew up ~14x. Add `GemmaRmsNorm::new_unshifted` and a
  `checkpoint_shifts_norm_weights()` probe (mirrors mlx_lm qwen3_5.py: shift
  only when MTP weights or unsanitized conv1d are present, detected via conv1d
  layout). Applied to both the MoE (`qwen3_5_moe`) and dense (`qwen3_5`) text
  models, all five norm sites.

- Per-tensor AFQ quantization. MLX checkpoints run the MoE router (`mlp.gate`,
  `shared_expert_gate`) at 8-bit while the rest is 4-bit. `QuantizedConfig::Afq`
  gains an `overrides` map (parsed from sibling keys via `collect_afq_overrides`)
  and `afq_params_for_path()`; `AfqLayer::afq_linear_b`/experts apply it. The
  top-level `quantization_config` is also propagated into `text_config`.

- Dense `qwen3_5` lm_head. The model loaded `lm_head` unconditionally, which
  missed the MLX `language_model.lm_head.*` wrapper and failed with
  "cannot find tensor lm_head.weight". The prefix is now picked using the same
  condition as the embed-tokens wrapper.

- Hybrid KV cache waste. PagedAttention sized the KV pool by all 64 layers, but
  only the full-attention layers hold a context KV cache (the linear-attention
  GatedDeltaNet layers carry fixed state). Add
  `ModelConfigLike::num_kv_cache_layers()` (default `num_layers()`) and a
  `HybridKvCacheConfig` wrapper that reports the full-attention count and 0 KV
  heads for linear layers, so those allocate no cache (~4x less KV). Wired via
  the model's `model_config()`; loader-side device-map estimate is a follow-up.

- GatedDeltaNet AFQ loading (SplitAfq path, conv1d layout), qwen3_next ISQ
  skipping already-quantised experts, and removal of debug instrumentation.
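
The hybrid KV cache accounting can be sketched as follows. This is a reduced illustration, not the code in the commit: `ModelConfigSketch` is a hypothetical stand-in that keeps only the two counting methods named above, and `every_nth_full` assumes that one layer in every four is full attention. That ratio is consistent with the ~4x figure but is an assumption, since the commit message does not state the layer pattern.

```rust
#[derive(Clone, Copy, Debug, PartialEq)]
enum LayerType {
    FullAttention,
    LinearAttention,
}

/// Hypothetical, reduced stand-in for the config trait: only layer counts.
trait ModelConfigSketch {
    fn num_layers(&self) -> usize;

    /// Layers that need a context-length KV cache. Defaults to every layer,
    /// so existing (non-hybrid) models keep their current behavior.
    fn num_kv_cache_layers(&self) -> usize {
        self.num_layers()
    }
}

/// A conventional model where every layer is full attention.
struct DenseConfig {
    num_layers: usize,
}

impl ModelConfigSketch for DenseConfig {
    fn num_layers(&self) -> usize {
        self.num_layers
    }
}

/// A hybrid model that mixes full-attention and linear-attention layers.
struct HybridConfig {
    layer_types: Vec<LayerType>,
}

impl ModelConfigSketch for HybridConfig {
    fn num_layers(&self) -> usize {
        self.layer_types.len()
    }

    /// Only full-attention layers hold a cache that grows with context.
    fn num_kv_cache_layers(&self) -> usize {
        self.layer_types
            .iter()
            .filter(|t| **t == LayerType::FullAttention)
            .count()
    }
}

/// Assumed layout: the last layer of each group of `interval` is full attention.
fn every_nth_full(num_layers: usize, interval: usize) -> Vec<LayerType> {
    (0..num_layers)
        .map(|i| {
            if (i + 1) % interval == 0 {
                LayerType::FullAttention
            } else {
                LayerType::LinearAttention
            }
        })
        .collect()
}

fn main() {
    let hybrid = HybridConfig { layer_types: every_nth_full(64, 4) };
    let dense = DenseConfig { num_layers: 64 };
    println!(
        "hybrid: {} of {} layers need a KV cache",
        hybrid.num_kv_cache_layers(),
        hybrid.num_layers()
    );
    println!(
        "dense: {} of {} layers need a KV cache",
        dense.num_kv_cache_layers(),
        dense.num_layers()
    );
}
```

Keeping the default at `num_layers()` means only hybrid models have to override the count, so sizing the pool by `num_kv_cache_layers()` does not change allocation for other architectures.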

Verified against mlx-community/Qwen3.6-{27B,35B-A3B}-4bit on Metal: byte-identical
embedding, layer norms within bf16 rounding, top-1 logit matches mlx_lm, coherent
generation and tool calls.
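
The per-tensor AFQ lookup described in the commit message can be sketched as a default plus an override map. The types and names here (`AfqParams`, `AfqConfig`, `params_for_path`, `example_config`) are hypothetical and do not reproduce `QuantizedConfig::Afq` or `afq_params_for_path()`. The tensor paths and the group size of 64 are also assumptions made for illustration, as is the exact-match lookup.

```rust
use std::collections::HashMap;

/// Quantization parameters for one tensor.
#[derive(Clone, Copy, Debug, PartialEq)]
struct AfqParams {
    bits: u8,
    group_size: usize,
}

/// Checkpoint-wide default plus per-tensor overrides keyed by module path.
struct AfqConfig {
    default: AfqParams,
    overrides: HashMap<String, AfqParams>,
}

impl AfqConfig {
    /// Returns the override for `path` if one exists, else the default.
    fn params_for_path(&self, path: &str) -> AfqParams {
        self.overrides.get(path).copied().unwrap_or(self.default)
    }
}

/// Builds a hypothetical config: 4-bit by default, 8-bit for the router gates.
fn example_config() -> AfqConfig {
    let eight_bit = AfqParams { bits: 8, group_size: 64 };
    let mut overrides = HashMap::new();
    overrides.insert("model.layers.0.mlp.gate".to_string(), eight_bit);
    overrides.insert("model.layers.0.mlp.shared_expert_gate".to_string(), eight_bit);
    AfqConfig {
        default: AfqParams { bits: 4, group_size: 64 },
        overrides,
    }
}

fn main() {
    let cfg = example_config();
    for path in [
        "model.layers.0.mlp.gate",
        "model.layers.0.self_attn.q_proj",
    ] {
        let p = cfg.params_for_path(path);
        println!("{path}: {} bits, group size {}", p.bits, p.group_size);
    }
}
```

Loading a router tensor with the checkpoint-wide 4-bit default would misread its packed 8-bit data, which is why the lookup has to happen per tensor path rather than once per model.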