fix(metal): register PR #2166 kernels in runtime-compile path #2169
Force-pushed from e9c3952 to 2286cc0.
EricLBuehler
left a comment
Hey @ljchang! Thanks for the PR.
The PR looks good broadly but I added a few very small comments.
    #include <metal_stdlib>
    using namespace metal;

    // The kernel instantiations below use `bfloat16_t` — the ggml/llama.cpp
I think this is the source of the merge conflict. If you remove this change (I made a fix on master), it should be resolved.
Dropped that commit entirely and rebased onto master. Your `typedef bfloat bfloat16_t` at `gdn.metal:8` covers it, so the conflict is gone. The PR is now `mod.rs`-only.
    "softmax_with_sinks.metal",
    include_str!("softmax_with_sinks.metal"),
    );
    // Kernels added by upstream PR #2166 (Gemma 4 Metal optimization).
Let's not reference specific PRs here; ideally this comment can be removed.
    "copy.metal", // Copy operations (includes utils.metal, copy_impl.metal)
    "scan.metal", // Scan operations (includes utils.metal, scan_impl.metal)
    "sort.metal", // Sort operations (includes utils.metal, sort_impl.metal)
    "flash_attn.metal", // Flash attention DK=512 variants (PR #2166)
Same here; ideally just describe the kernel and don't reference when it was added.
Done — the three entries now just describe the kernel (Flash attention DK=512 variants / Fused RMSNorm + residual / Two-stage top-k + softmax stats).
…path

PR EricLBuehler#2166 added three new top-level kernel files (rmsnorm_residual.metal, topk_logits.metal, flash_attn.metal) to mistralrs-quant's METAL_SOURCES list in build.rs, but did not register them in the runtime fallback path in metal_kernels/mod.rs.

With MISTRALRS_METAL_PRECOMPILE=0 set, the precompiled .metallib is skipped and Kernels::compile_kernels_at_runtime builds the library from include_str!'d sources -- so any kernel not in both the file_system map and the main_files list silently disappears. This caused Gemma 4 inference to fail at first call with:

    Error while loading function: rmsnorm_residual_bf16

on hosts that need precompile=0 (e.g. M5 / Metal 4 GPUs, where the precompiled metallib lacks function variants for the current GPU and the project relies on MTLDevice.newLibraryWithSource at runtime instead).

Fix: add the three new files to both the file_system map and the main_files list. Each has clean header dependencies (only metal_stdlib and metal_simdgroup, which main_source already pulls in as a preamble), so no further plumbing is needed.

Note: f8q8.metal (EricLBuehler#1883) and hqq_bitpack.metal (EricLBuehler#1586) are also in build.rs's METAL_SOURCES but missing from the runtime path. They predate this regression and aren't part of this fix; landing the EricLBuehler#2166 fix first keeps this PR surgical.

A future CI matrix entry for MISTRALRS_METAL_PRECOMPILE=0 would prevent recurrences of this class of bug.
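The failure mode in this commit message (a kernel listed for the precompiled path but absent from the runtime registration) can be expressed as a small consistency check. The following is a minimal sketch, not the project's code: `missing_from_runtime` is a hypothetical helper, and the three collections are abbreviated, illustrative stand-ins for `METAL_SOURCES`, the `file_system` map, and the `main_files` list.

```rust
use std::collections::{BTreeSet, HashMap};

/// Hypothetical helper: returns every build-time source that the
/// runtime-compile path would miss, i.e. any file absent from either
/// the `file_system` map or the `main_files` list.
fn missing_from_runtime<'a>(
    metal_sources: &[&'a str],
    file_system: &HashMap<&str, &str>,
    main_files: &[&str],
) -> BTreeSet<&'a str> {
    metal_sources
        .iter()
        .copied()
        .filter(|name| {
            let in_map = file_system.contains_key(*name);
            let in_list = main_files.iter().any(|f| *f == *name);
            // A kernel must be registered in BOTH places to be compiled.
            !(in_map && in_list)
        })
        .collect()
}

fn main() {
    // Illustrative stand-in for the build-time list (abbreviated).
    let metal_sources = [
        "softmax_with_sinks.metal",
        "rmsnorm_residual.metal",
        "topk_logits.metal",
        "flash_attn.metal",
    ];
    // Illustrative runtime registration before the fix: the three new
    // kernel files are not present in either collection.
    let file_system: HashMap<&str, &str> =
        HashMap::from([("softmax_with_sinks.metal", "/* kernel source */")]);
    let main_files = ["softmax_with_sinks.metal"];

    let missing = missing_from_runtime(&metal_sources, &file_system, &main_files);
    for name in &missing {
        println!("not registered for runtime compile: {name}");
    }
}
```

A check of this shape could run as a unit test, so a kernel added to only one list fails the build rather than the first inference call.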
Force-pushed from 28cdeca to 424b506.
Fixes #2168. `MISTRALRS_METAL_PRECOMPILE=0` builds (introduced by #1518) regressed in two places after recent kernel work.

Commit 1: register PR #2166's new kernels in the runtime-compile path
PR #2166 added three new top-level kernel files (`rmsnorm_residual.metal`, `topk_logits.metal`, `flash_attn.metal`) and registered them in `mistralrs-quant/build.rs`'s `METAL_SOURCES` for the precompiled metallib path, but did not register them in the runtime-compile path's `file_system` map and `main_files` list in `mistralrs-quant/src/metal_kernels/mod.rs`.

When precompile is disabled (e.g. on Apple Silicon with Metal 4, where the precompiled metallib lacks function variants for the current GPU), the runtime falls back to `MTLDevice.newLibraryWithSource` via `Kernels::compile_kernels_at_runtime`. Without the new files in both lists, the runtime library is built without those kernels and `library.get_function("rmsnorm_residual_bf16", ...)` fails on first inference with `Error while loading function: rmsnorm_residual_bf16`.

The fix is 13 lines: three entries each in `file_system` and `main_files`. Each file has clean header dependencies (only `<metal_stdlib>` and `<metal_simdgroup*>`, which `main_source`'s preamble already provides).

Commit 2: typedef `bfloat16_t` for runtime-compiled GDN kernels

`mistralrs-core/src/metal/kernels/gdn.metal` instantiates templates with `bfloat16_t`, the ggml/llama.cpp type-name convention inherited from the CUDA port. Apple Metal's stdlib doesn't expose `bfloat16_t`: it has `bfloat` (Metal 3.1+, macOS 14+) as the public type, and `bfloat16` as a forward-only struct (`__Reserved_Name__Do_not_use_bfloat16`) in `metal_extended_vector`.

The precompiled path resolves `bfloat16_t` via build-time headers, but the runtime-compiled path bails with an unknown-type error for `bfloat16_t`. Aliasing to the forward-only `bfloat16` fails differently (incomplete-type errors at template instantiation). The fix is a one-line `typedef bfloat bfloat16_t;` at the top of `gdn.metal` so both compile paths see a fully-defined type. This unblocks Qwen 3.5 / 3.6 hybrid models on Apple Silicon under precompile=0; Gemma 4 and other non-hybrid models were unaffected.

This fix overlaps with #2047 (`fix(metal): GDN bfloat16, PA scheduler, error handling, MLX SDPA fixes`), which has been open since 2026-04-02 and contains the same typedef plus several unrelated improvements. Happy to drop this commit if #2047 is prioritized.

Validation
- `cargo check --features metal` on `mistralrs-quant` with `MISTRALRS_METAL_PRECOMPILE=0`: clean.
- Before: `rmsnorm_residual_bf16` missing on first inference. After commit 1, before commit 2: GDN kernel compile fails with `bfloat16_t` unknown. After both: clean inference, PR #2166's decode perf wins intact.

Out of scope (separate work)

- `f8q8.metal` (#1883, "Add new quant method: F8Q8") and `hqq_bitpack.metal` (#1586, "Add blockwise fp8 quantize kernels") are also in `METAL_SOURCES` but missing from the runtime path. They predate the regression; leaving them out to keep this PR surgical.
- Registering each kernel in three places (`METAL_SOURCES` in `build.rs` + `file_system` + `main_files` in `mod.rs`) is the design smell that made this regression possible. Single-sourcing (e.g. `build.rs` emits a generated `kernels.rs` that `mod.rs` `include!`s) is a worthwhile follow-up refactor.

Suggestion
Adding `MISTRALRS_METAL_PRECOMPILE=0` to the CI matrix would catch this class of regression before it ships. Three precompile=0 incompatibilities have surfaced in the last six weeks (the two in this PR plus an `f8q8`/`hqq_bitpack` gap that affects pre-existing kernels).
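For the single-sourcing follow-up mentioned under "Out of scope", one possible shape is a generator that renders the kernel table from a single list. This is only a sketch under assumptions: `KERNEL_SOURCES`, `render_kernels_rs`, the generated `KERNEL_FILES` constant, and the `kernel_dir` parameter are all hypothetical names, not anything that exists in the repository today.

```rust
/// Hypothetical single source of truth: (file name, short description).
const KERNEL_SOURCES: &[(&str, &str)] = &[
    ("flash_attn.metal", "Flash attention DK=512 variants"),
    ("rmsnorm_residual.metal", "Fused RMSNorm + residual"),
    ("topk_logits.metal", "Two-stage top-k + softmax stats"),
];

/// Renders Rust source text that a build script could write to a
/// generated `kernels.rs`. `kernel_dir` is the directory holding the
/// `.metal` files (an assumption of this sketch).
fn render_kernels_rs(sources: &[(&str, &str)], kernel_dir: &str) -> String {
    let mut out = String::from("// @generated by build.rs; do not edit.\n");
    out.push_str("pub const KERNEL_FILES: &[(&str, &str)] = &[\n");
    for (name, description) in sources {
        // `{:?}` emits a quoted, escaped Rust string literal.
        out.push_str(&format!(
            "    // {description}\n    ({name:?}, include_str!({path:?})),\n",
            path = format!("{kernel_dir}/{name}"),
        ));
    }
    out.push_str("];\n");
    out
}

fn main() {
    // Print the generated module so the output can be inspected.
    print!("{}", render_kernels_rs(KERNEL_SOURCES, "src/metal_kernels"));
}
```

In an actual build script the rendered text would typically be written under Cargo's `OUT_DIR` and pulled in with `include!(concat!(env!("OUT_DIR"), "/kernels.rs"))`, with `kernel_dir` most likely an absolute path. Both the precompiled list and the runtime map could then be derived from the same table, so a new kernel is registered once.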