
PERF: Speed up permutation cluster tests via Numba union-find + compact graph#13731

Open
sharifhsn wants to merge 7 commits into mne-tools:main from sharifhsn:perf-opt

Conversation


@sharifhsn sharifhsn commented Mar 9, 2026

Reference issue

Related: #5439, #7784, #8095, #12609

What does this implement/fix?

Speeds up spatio_temporal_cluster_1samp_test (and the other permutation_cluster_* functions) by ~5-10x on realistic data. The PR is split into 7 incremental commits. Maintainers can accept or reject each layer independently.

Commit 1 — Precompute sum-of-squares for sign-flip t-test (+29/−9 lines, 3.2x)
For the default ttest_1samp_no_p, sign flips satisfy s² = 1, so sum(X²) is constant across permutations. Each permutation then reduces to a single signs @ X dot product instead of a full stat_fun call. Also skips buffer_size verification for built-in stat functions.
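The idea can be sketched in a few lines: because each sign flip squares to 1, the sum of squares is invariant across permutations and only the per-permutation sums change. A minimal NumPy illustration (variable names are mine, not the MNE internals), cross-checked against the naive per-permutation t-test:

```python
import numpy as np
from scipy import stats

def sign_flip_t(X, signs):
    """t-statistics for all sign-flip permutations with one matmul each.

    X: (n_subjects, n_tests); signs: (n_perms, n_subjects) of +/-1.
    Illustrative sketch, not the MNE implementation.
    """
    n = X.shape[0]
    sumsq = np.sum(X * X, axis=0)          # invariant under sign flips
    sums = signs @ X                        # (n_perms, n_tests): one dot product per perm
    var = (sumsq - sums**2 / n) / (n - 1)   # sample variance from sums + sumsq
    return sums / n / np.sqrt(var / n)

rng = np.random.default_rng(0)
X = rng.standard_normal((15, 100))
signs = rng.choice([-1.0, 1.0], size=(64, 15))
t_fast = sign_flip_t(X, signs)

# cross-check against calling the t-test once per permutation
t_naive = np.array(
    [stats.ttest_1samp(signs[i][:, None] * X, 0)[0] for i in range(64)]
)
np.testing.assert_allclose(t_fast, t_naive, rtol=1e-10)
```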

Commit 2 — Numba union-find for spatio-temporal CCL (+226/−11 lines, 10.3x cumulative)
JIT-compiled union-find kernel (_st_fused_ccl) with path compression and union-by-rank, replacing the Python BFS in _get_clusters_st. Bundles tightly-coupled pieces: pre-computed CSR adjacency arrays, _sums_only flag to skip cluster list construction (uses np.bincount instead), and _csr_data parameter threading. These are bundled because _sums_only only fires inside if has_numba: and CSR data is only consumed by the Numba kernel.
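As a rough illustration of the technique (not the actual JIT kernel, which is Numba-compiled and also handles the temporal dimension), here is union-find with path compression and union by rank labeling supra-threshold nodes of a CSR-encoded graph:

```python
import numpy as np

def label_clusters(mask, indptr, indices):
    """Connected components of supra-threshold nodes via union-find.

    mask: boolean array over n nodes; (indptr, indices): CSR adjacency.
    Pure-Python stand-in for _st_fused_ccl; the real kernel differs.
    """
    n = mask.size
    parent = np.arange(n)
    rank = np.zeros(n, dtype=np.int64)

    def find(i):                            # find root with path compression
        root = i
        while parent[root] != root:
            root = parent[root]
        while parent[i] != root:            # compress the path just walked
            parent[i], i = root, parent[i]
        return root

    for u in range(n):
        if not mask[u]:
            continue
        for v in indices[indptr[u]:indptr[u + 1]]:
            if v > u and mask[v]:           # visit each undirected edge once
                ru, rv = find(u), find(v)
                if ru == rv:
                    continue
                if rank[ru] < rank[rv]:     # union by rank
                    ru, rv = rv, ru
                parent[rv] = ru
                rank[ru] += rank[ru] == rank[rv]
    return np.array([find(i) if mask[i] else -1 for i in range(n)])

# demo: chain 0-1-2 plus pair 3-4, isolated node 5; node 2 below threshold
indptr = np.array([0, 1, 3, 4, 5, 6, 6])
indices = np.array([1, 0, 2, 1, 4, 3])
mask = np.array([True, True, False, True, True, True])
labels = label_clusters(mask, indptr, indices)
```

With near-constant amortized cost per edge, this avoids the queue allocations and Python-level neighbor iteration of a BFS, which is where the speedup comes from once JIT-compiled.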

Commit 3 — Extract _union helper + simplify (+36/−72 lines)
Extract duplicated find+union logic into _union() with inline="always", simplify _sum_cluster_data, trim docstrings/comments to match codebase style.

Commit 4 — Fix step-down reshape (+1/−1 lines)
Pre-existing bug: `adjacency is None and adjacency is not False` was equivalent to just `adjacency is None`, missing the `adjacency is False` case where step_down_include still needs reshaping.
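The logical identity behind the bug can be checked directly; this is just a self-contained demonstration of the condition, not MNE code:

```python
# The old guard collapses to `adjacency is None`: when adjacency is None,
# `adjacency is not False` is always True; for any other value the first
# operand is already False, so the second never matters.
for adjacency in (None, False, object()):
    old = adjacency is None and adjacency is not False
    assert old == (adjacency is None)

# The intended guard also catches the `adjacency is False` case:
caught = [a for a in (None, False, object()) if a is None or a is False]
assert len(caught) == 2
```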

Commit 5 — Changelog entries

Commit 6 — Test fixture (+1 line)
Patch has_numba in numba_conditional fixture so the "NumPy" test variant actually exercises the Python BFS fallback path for spatio-temporal clustering.
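A hypothetical sketch of the fixture pattern (the actual numba_conditional fixture in MNE differs; the attribute path below is illustrative):

```python
import pytest

@pytest.fixture(params=["Numba", "NumPy"])
def numba_conditional(request, monkeypatch):
    """Run each test twice: once as-is, once with the Numba flag forced off."""
    if request.param == "NumPy":
        # Illustrative target: patching the module-level flag ensures the
        # pure-Python BFS fallback path is what actually executes.
        monkeypatch.setattr(
            "mne.stats.cluster_level.has_numba", False, raising=False
        )
    return request.param
```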

Commit 7 — Docstring (+18/−1 lines)
Expand _st_fused_ccl docstring with algorithm description, complexity analysis, and Wikipedia reference, per reviewer request.

Commits 3-7 are cleanup, bugfix, docs, and tests — they don't affect performance. All optimizations fall back to the original code paths when Numba is not installed. No public API changes.

Benchmarks

Per-commit cumulative speedup (local, Apple M-series, spatio_temporal_cluster_1samp_test, ico-5, 15 subjects x 15 timepoints x 20,484 vertices, threshold=3.0, 512 permutations, median of 3 runs):

| Cumulative through | ms/perm | Speedup | Net lines |
|---|---|---|---|
| main (baseline) | 16.94 | 1.0x | |
| commit 1 (precomputed sum_sq) | 5.37 | 3.2x | +20 |
| commit 2 (Numba union-find) | 1.64 | 10.3x | +235 |

AWS HPC end-to-end (AMD EPYC 7R13, same data dimensions):

| Permutations | Before | After | Speedup |
|---|---|---|---|
| 256 | 4.12 s | 0.86 s | 4.8x |
| 1024 | 16.35 s | 3.25 s | 5.0x |
| 4096 | 65.00 s | 12.64 s | 5.1x |

Per-permutation cost: 15.8 ms → 3.1 ms (5.2x). Projected cost at 10,000 permutations: 31 s optimized vs 159 s baseline.

Reproduce benchmarks locally
"""Quick benchmark: perf-opt vs baseline on realistic source-space data."""
import time
import numpy as np
import mne
from mne.stats import spatio_temporal_cluster_1samp_test
from mne.stats import cluster_level as cl

# Load fsaverage ico-5 adjacency
subjects_dir = mne.datasets.sample.data_path() / "subjects"
src = mne.setup_source_space(
    "fsaverage", spacing="ico5", subjects_dir=subjects_dir, add_dist=False
)
adjacency = mne.spatial_src_adjacency(src)

# Synthetic data: 15 subjects x 15 timepoints x 20,484 vertices
rng = np.random.default_rng(42)
X = rng.standard_normal((15, 15, adjacency.shape[0]))
X[:, 5:10, 1000:1100] += 1.0  # inject focal activation

# Warmup JIT
spatio_temporal_cluster_1samp_test(
    X, adjacency=adjacency, n_permutations=64,
    threshold=3.0, tail=1, verbose=False, seed=42
)

# Optimized
t0 = time.perf_counter()
spatio_temporal_cluster_1samp_test(
    X, adjacency=adjacency, n_permutations=512,
    threshold=3.0, tail=1, verbose=False, seed=42
)
t_opt = time.perf_counter() - t0

# Baseline (disable Numba path)
saved = cl.has_numba
cl.has_numba = False
t0 = time.perf_counter()
spatio_temporal_cluster_1samp_test(
    X, adjacency=adjacency, n_permutations=512,
    threshold=3.0, tail=1, verbose=False, seed=42
)
t_base = time.perf_counter() - t0
cl.has_numba = saved

print(f"Optimized: {t_opt:.2f}s  Baseline: {t_base:.2f}s  Speedup: {t_base/t_opt:.1f}x")

Additional information

  • Numba JIT warmup happens once on first call; subsequent calls pay no warmup cost
  • TFCE (threshold=dict(...)) correctly falls back to the original code path
  • Custom stat functions still benefit from the CCL and overhead optimizations but not the precomputed sum-of-squares
  • AI (Claude) was used to generate the code, which was checked over manually

@drammock
Member

drammock commented Mar 9, 2026

this is a pretty large diff. Before we invest time in a review, a few questions:

  1. is this ready for review? if not, please mark as draft until it is ready
  2. have you run relevant tests and doc build locally? are they all passing?
  3. please disclose the way(s) in which you used AI to assist in this contribution (if any)

tip: next time, if you name the changelog newfeature.rst instead of XXXXX.newfeature.rst then one of our CIs will automatically rename it to include the PR number for you.

@sharifhsn sharifhsn marked this pull request as draft March 9, 2026 19:21
@sharifhsn
Author

sharifhsn commented Mar 9, 2026

Note: The PR has since been restructured and is now ready for review. See the latest comments and PR body for the current state.

My bad, I meant to publish it as a draft. Yes, it's still in progress.

Tests are all passing.

AI was used to generate the code which was checked over manually to catch bugs.

I'm going to try to reduce the amount of code as much as possible while still maintaining the main speedups.

@sharifhsn sharifhsn force-pushed the perf-opt branch 2 times, most recently from 50c67c1 to 9b7edfd on March 9, 2026 19:43
@sharifhsn sharifhsn changed the title PERF: Speed up permutation cluster tests ~15× via Numba JIT kernels PERF: Speed up permutation cluster tests via Numba union-find + compact graph Mar 9, 2026
@sharifhsn sharifhsn marked this pull request as ready for review March 10, 2026 07:56
@sharifhsn
Author

sharifhsn commented Mar 10, 2026

Note: This comment describes an earlier version of the PR. The current 7-commit structure is in the updated PR body — commits 1-4 from this comment have been dropped per reviewer feedback.

Restructured this into 6 incremental commits ordered roughly by bang-for-buck, so you can review (and accept/reject) each optimization independently.

Quick summary of what each one does and what it costs:

  1. Vectorize p-value computation: np.searchsorted replacing O(n·k) list comp
  2. Compact graph reindexing : subgraph of supra-threshold vertices only
  3. Vectorize cluster sums: np.add.reduceat replacing per-cluster loop
  4. Vectorize sign-order generation: bit-shifting instead of np.binary_repr
  5. Precompute sum-of-squares for sign-flip t-test: 3.2× end-to-end
  6. Numba union-find for spatio-temporal CCL (+215 lines): 10.3× end-to-end

Commits 1-4 are small, pure NumPy/SciPy, and basically free in terms of complexity. They mainly help the global-adjacency path rather than the spatio-temporal one, so they don't show up much in the benchmark above.

Commit 5 is where the first big measurable win kicks in (3.2×), and it's still straightforward NumPy.

Commit 6 is the heavy one. It adds a Numba JIT union-find kernel with CSR precomputation and a bincount shortcut for cluster sums. It's responsible for most of the remaining speedup (3.2× → 10.3×), but it's also +215 lines and adds a Numba dependency to this code path. I have a stripped-down scipy-only version that gets ~3.8× in about +84 lines if you'd prefer to keep it simpler. Happy to swap that in or simplify further if the Numba approach feels like too much for this module. Would appreciate your guidance on where you'd want to draw that line.
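For reference, the vectorizations named in items 1 and 3 of the list above can be sketched in a few lines (illustrative data, not the MNE code):

```python
import numpy as np

# Item 1: p-values from a permutation null distribution. The naive version
# scans the whole null for each observed statistic; the vectorized version
# sorts once and binary-searches each observed value.
rng = np.random.default_rng(1)
h0 = rng.standard_normal(4096)          # permutation max-stat distribution
obs = np.array([0.5, 2.0, 3.5])         # observed cluster statistics

p_naive = np.array([(h0 >= o).mean() for o in obs])   # O(n_obs * n_perm)

h0_sorted = np.sort(h0)
# searchsorted(..., side="left") counts elements strictly below each obs,
# so n - count gives the number of null stats >= obs
p_fast = 1.0 - np.searchsorted(h0_sorted, obs, side="left") / h0.size
assert np.allclose(p_fast, p_naive)

# Item 3: per-cluster sums with np.add.reduceat, assuming the data has been
# reordered so each cluster occupies a contiguous slice
data = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
starts = np.array([0, 2, 3])            # cluster slices: [0:2], [2:3], [3:5]
sums = np.add.reduceat(data, starts)
```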

sharifhsn added a commit to sharifhsn/mne-python that referenced this pull request Mar 10, 2026
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
sharifhsn added a commit to sharifhsn/mne-python that referenced this pull request Mar 10, 2026
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@sharifhsn
Author

sharifhsn commented Mar 10, 2026

Note: This comment describes an earlier version. Commits 1-4 have been dropped per reviewer feedback. The current 7-commit structure is in the updated PR body.

Update: Restructured from the initial 6 commits:

  • Commit 7 (new): Extracted duplicated union-find code into _union() helper with inline="always", simplified _sum_cluster_data, and trimmed docstrings/comments to match codebase conventions. Net −36 lines from commit 6.
  • Commit 8 (new): Split out a pre-existing bugfix, changing `adjacency is None and adjacency is not False` to `adjacency is None or adjacency is False`. The old condition always evaluated to just `adjacency is None`, missing the `adjacency is False` case in step-down-in-jumps. This is a 1-line fix to existing code, separated so it can be merged independently.
  • Commit 9: Changelog entries (newfeature + bugfix).
  • Commit 10 (new): Patch has_numba in the numba_conditional test fixture so the "NumPy" variant actually exercises the Python BFS fallback for spatio-temporal clustering.

PR body updated accordingly.

sharifhsn added a commit to sharifhsn/mne-python that referenced this pull request Mar 10, 2026
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
sharifhsn added a commit to sharifhsn/mne-python that referenced this pull request Mar 10, 2026
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@sharifhsn
Author

sharifhsn commented Mar 10, 2026

The PR has been restructured extensively since that initial comment and is now ready for review. All 55 relevant cluster-level tests pass locally (both Numba and NumPy variants).

@wmvanvliet
Contributor

Looking at it quickly, commits 1-4 don't really make much of an impact in speed while slightly increasing the code complexity. However, commits 5 and especially 6 look interesting. Speeding up the cluster permutation test by 10x is a huge improvement! Since this is AI generated, could the AI generate some more information in the docstring about the role of union_find and how it goes about doing it exactly? It's hard to tell from the code alone what is actually going on and why it is so fast. If this algorithm has been published somewhere, it would be good to link to an article. Or if it's some well-known algorithm with a wikipedia page we could link there. Something to aid human readers of the code.

@sharifhsn
Author

sharifhsn commented Mar 15, 2026

Thanks for taking a look!

Commits 1-4: Agreed — they don't measurably speed up the spatio-temporal path and add unnecessary complexity. Dropped all four; the PR is now 7 commits focused on the two performance wins (precomputed sum-of-squares + Numba union-find) plus cleanup, bugfix, docs, and test.

Union-find documentation: Added an expanded docstring to _st_fused_ccl explaining what the algorithm does, why it's faster than the BFS it replaces, and a link to the Wikipedia article on disjoint-set data structures (commit 7).

sharifhsn and others added 6 commits March 15, 2026 12:36
For the default ttest_1samp_no_p, s^2=1 means sum(X^2) is constant
across sign-flip permutations. Each permutation becomes a single
signs @ X dot product instead of calling stat_fun. Also skips
buffer_size verification for built-in stat functions.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
JIT-compiled union-find kernel (_st_fused_ccl) with path compression
and union-by-rank, replacing the Python BFS in _get_clusters_st.
Bundles tightly-coupled pieces: pre-computed CSR adjacency arrays,
_sums_only flag to skip cluster list construction (uses np.bincount
instead), and _csr_data parameter threading.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Extract duplicated find+union logic into _union() with
inline="always", simplify _sum_cluster_data, trim docstrings
and comments to match codebase style.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Pre-existing bug: adjacency is None and adjacency is not False was
equivalent to just adjacency is None, missing the adjacency is False
case where step_down_include still needs reshaping.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The NumPy variant of the fixture patched the JIT helper functions but
not has_numba itself, so the Numba union-find code path was still
used for spatio-temporal clustering even in the "NumPy" test variant.
Patching has_numba ensures the Python BFS fallback is actually tested.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Address reviewer request for documentation explaining what the
union-find algorithm does, why it is faster than the BFS it replaces,
and a link to the Wikipedia article on disjoint-set data structures.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
