Skip to content

feat(metal): paged chunked prefill, env-tunable with a safe chunk-size floor#2208

Open
sergey-scherbina wants to merge 2 commits into
EricLBuehler:masterfrom
sergey-scherbina:metal-chunked-prefill
Open

feat(metal): paged chunked prefill, env-tunable with a safe chunk-size floor#2208
sergey-scherbina wants to merge 2 commits into
EricLBuehler:masterfrom
sergey-scherbina:metal-chunked-prefill

Conversation

@sergey-scherbina

@sergey-scherbina sergey-scherbina commented Jun 11, 2026

Copy link
Copy Markdown

What

Enable paged prompt-prefill chunking on Metal. The chunking loop already exists in
Pipeline::step (build_prompt_chunk_plan + set_prefix_cache_len /
set_prefill_toks) but was gated to self.device().is_cuda(), so Metal always
prefilled the whole prompt in one shot. This:

  • relaxes the gate to is_cuda() || is_metal();
  • makes the chunk size env-tunable via MISTRALRS_PREFILL_CHUNK (default 4096, 0
    disables) - the analogue of llama.cpp's n_ubatch;
  • floors the effective chunk to next_multiple_of(block_size).max(512).

Why

On Apple Silicon the prefill activation peak scales with prompt length, so a long
prompt (e.g. a coding-assistant context of ~20k tokens) drives the box into swap and
thrashes. Chunking bounds the peak to the chunk size. Measured on a 36 GB Mac with
Qwen3.6: a ~20k-token prefill's peak swap drops from ~5.9 GB to ~1.3 GB for ~13%
slower prefill, and the request completes instead of stalling.

The floor (important)

Sub-block / very small chunks are not just slow, they are wrong on hybrid models:
the GatedDeltaNet conv-state path overruns its buffer once the cumulative prefill
nears its ~1024 window (a narrow out-of-range panic on a ~2.5k-token prompt), and
sub-block chunks land mid paged-block and drop prompt tokens. MISTRALRS_PREFILL_CHUNK=8
reliably reproduced both. Flooring to a block-aligned >= 512 sidesteps it; 512 is the
smallest size verified faithful and deterministic, and sub-512 buys only marginal
activation-peak savings for many more kernel launches. The CUDA path only ever used
4096, so this regime was never exercised upstream.

Scope

mistralrs-core/src/pipeline/mod.rs, +29/-2 (two commits). Independent of the other
PRs in code; the floor's motivation is the hybrid (Qwen3.6) path, so this pairs well
landing after #2201.


Part of splitting the Qwen3.6 work into focused, reviewable PRs:

Suggested merge order: #2206 + #2207 -> #2201 -> #2208.

The paged prompt-prefill chunking added for CUDA was gated to is_cuda(), so Metal
always prefilled the whole prompt in one shot and the activation peak scaled with
prompt length (large prompts swap-thrash). Enable it on Metal too and make the
chunk size overridable via MISTRALRS_PREFILL_CHUNK (default 4096, 0 disables).

Verified on Qwen3.6-35B-A3B (Metal): a single-chunk run is byte-identical to the
unchunked path; multi-chunk output is coherent and deterministic (cross-chunk
attention reordering gives FP-level, not structural, differences). A ~20k-token
prefill's peak swap drops from ~5.9 GB to ~1.3 GB for ~13% slower prefill.
…l-chunk overrun

Small prefill chunks corrupt hybrid (GatedDeltaNet) inference. Sub-block chunks land
mid paged-block and drop prompt tokens (the model misreads its input); even
block-aligned chunks below ~512 overrun the conv-state buffer once the cumulative
prefill nears the conv's ~1024 window, panicking with a narrow out-of-range on a
~2.5k-token prompt. The CUDA path only ever used 4096, so this was never hit upstream.

Promote MISTRALRS_PREFILL_CHUNK to next_multiple_of(block_size).max(512) so a
too-small or unaligned value is silently raised to a verified-safe size instead of
corrupting output or crashing. 512 is the smallest size verified faithful and
deterministic; sub-512 saves only marginal activation peak for many more launches.
@github-actions

Copy link
Copy Markdown
Code Metrics Report
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
 Language              Files        Lines         Code     Comments       Blanks
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
 C Header                 23         4454         3116          790          548
 CSS                       3          281          252            5           24
 CUDA                    119        23575        19136         1696         2743
 Dockerfile                1           38           21            8            9
 HTML                      2           27           27            0            0
 JavaScript                3          392          387            2            3
 Jinja2                    7          694          656            5           33
 JSON                     26         9360         9357            0            3
 Makefile                  1            6            5            0            1
 MDX                       1          149            0          133           16
 Metal Shading Lan|       37        14287        11284         1136         1867
 PowerShell                1          357          276           33           48
 Python                  131        10342         8515          460         1367
 Shell                     2          549          379          101           69
 Plain Text                3         3723            0         2413         1310
 TOML                     29         1388         1211           41          136
 TypeScript               11         1607         1371           66          170
 YAML                      3           25           23            2            0
─────────────────────────────────────────────────────────────────────────────────
 Jupyter Notebooks         3          122           83           23           16
 |- Markdown               1           60           30           22            8
 |- Python                 1          122          113            1            8
 (Total)                              304          226           46           32
─────────────────────────────────────────────────────────────────────────────────
 Markdown                129         9703            0         6648         3055
 |- BASH                  61          600          520           47           33
 |- Dockerfile             2            5            5            0            0
 |- JSON                  18          700          700            0            0
 |- PowerShell             3            5            5            0            0
 |- Python                25          830          722            5          103
 |- Rust                  15          437          382            1           54
 |- TOML                  10          124           98            3           23
 |- YAML                   1           13           13            0            0
 (Total)                            12417         2445         6704         3268
─────────────────────────────────────────────────────────────────────────────────
 Rust                    625       270388       239956         5864        24568
 |- Markdown             397         9504          452         7882         1170
 (Total)                           279892       240408        13746        25738
─────────────────────────────────────────────────────────────────────────────────
 Svelte                   18         1831         1696           50           85
 |- CSS                    1            4            4            0            0
 |- JavaScript            18          876          727           24          125
 (Total)                             2711         2427           74          210
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
 Total                  1178       366578       301522        27461        37595
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant