PERF: Speed up permutation cluster tests via Numba union-find + compact graph #13731
sharifhsn wants to merge 7 commits into mne-tools:main
this is a pretty large diff. Before we invest time in a review, a few questions:
tip: next time, if you name the changelog |
My bad, I meant to publish it as a draft. Yes, it's still in progress. Tests are all passing. AI was used to generate the code, which was then checked manually to catch bugs. I'm going to try to reduce the amount of code as much as possible while still keeping the main speedups.
Restructured this into 6 incremental commits ordered roughly by bang-for-buck, so you can review (and accept/reject) each optimization independently. Quick summary of what each one does and what it costs:
Commits 1-4 are small, pure NumPy/SciPy, and basically free in terms of complexity. They mainly help the global-adjacency path rather than the spatio-temporal one, so they don't show up much in the benchmark above. Commit 5 is where the first big measurable win kicks in (3.2×), and it's still straightforward NumPy.
Commit 6 is the heavy one. It adds a Numba JIT union-find kernel with CSR precomputation and a bincount shortcut for cluster sums. It's responsible for most of the remaining speedup (3.2× → 10.3×), but it's also +215 lines and adds a Numba dependency to this code path.
I have a stripped-down SciPy-only version that gets ~3.8× in about +84 lines if you'd prefer to keep it simpler. Happy to swap that in or simplify further if the Numba approach feels like too much for this module. Would appreciate your guidance on where you'd want to draw that line.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Update: Restructured from the initial 6 commits:
PR body updated accordingly.
The PR has been restructured extensively since that initial comment and is now ready for review. All 55 relevant cluster-level tests pass locally (both Numba and NumPy variants).
Looking at it quickly, commits 1-4 don't really make much of an impact in speed while slightly increasing the code complexity. However, commits 5 and especially 6 look interesting. Speeding up the cluster permutation test by 10x is a huge improvement! Since this is AI generated, could the AI generate some more information in the docstring about the role of |
Thanks for taking a look! Commits 1-4: Agreed — they don't measurably speed up the spatio-temporal path and add unnecessary complexity. Dropped all four; the PR is now 7 commits focused on the two performance wins (precomputed sum-of-squares + Numba union-find) plus cleanup, bugfix, docs, and test. Union-find documentation: Added an expanded docstring to |
For the default ttest_1samp_no_p, s^2=1 means sum(X^2) is constant across sign-flip permutations. Each permutation becomes a single signs @ X dot product instead of calling stat_fun. Also skips buffer_size verification for built-in stat functions. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
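To make the sum-of-squares argument concrete, here is a minimal NumPy sketch. It is not the PR's code: the helper names `ttest_1samp_direct` and `ttest_1samp_signflip` are made up for illustration, and it uses a plain one-sample t statistic without any variance regularization. Because every sign is ±1, squaring the flipped data gives back `X ** 2`, so the per-column sum of squares can be computed once and reused for every permutation.

```python
import numpy as np


def ttest_1samp_direct(X):
    """Plain one-sample t statistic per column (reference implementation)."""
    n = X.shape[0]
    return X.mean(axis=0) / np.sqrt(X.var(axis=0, ddof=1) / n)


def ttest_1samp_signflip(signs, X, sum_sq):
    """Same statistic for sign-flipped data, reusing the precomputed sum_sq.

    Hypothetical helper: only the mean depends on the signs, so a single
    vector-matrix product per permutation is enough.
    """
    n = X.shape[0]
    mean = (signs @ X) / n
    # Unbiased variance from the invariant sum of squares:
    # var = (sum(x^2) - n * mean^2) / (n - 1)
    var = (sum_sq - n * mean**2) / (n - 1)
    return mean / np.sqrt(var / n)


rng = np.random.default_rng(0)
X = rng.standard_normal((15, 50))         # 15 subjects x 50 features
sum_sq = (X**2).sum(axis=0)               # computed once, outside the loop
signs = rng.choice([-1.0, 1.0], size=15)  # one sign-flip permutation

t_fast = ttest_1samp_signflip(signs, X, sum_sq)
t_ref = ttest_1samp_direct(signs[:, None] * X)
print(np.allclose(t_fast, t_ref))
```

The comparison against the direct computation on the explicitly flipped data is the point of the sketch: the two agree up to floating-point error, which is what allows the call to the statistic function to be skipped.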
JIT-compiled union-find kernel (_st_fused_ccl) with path compression and union-by-rank, replacing the Python BFS in _get_clusters_st. Bundles tightly-coupled pieces: pre-computed CSR adjacency arrays, _sums_only flag to skip cluster list construction (uses np.bincount instead), and _csr_data parameter threading. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
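For readers unfamiliar with union-find, the following pure-Python sketch shows the general technique named in this commit (path compression plus union-by-rank), applied to a toy clustering problem. It is an illustration under simplifying assumptions, not the `_st_fused_ccl` kernel: there is no Numba, the graph is a plain edge list rather than CSR arrays, thresholding is one-tailed, and `cluster_sums` is a hypothetical name.

```python
def find(parent, i):
    """Return the root of i, flattening the path on the way (path compression)."""
    root = i
    while parent[root] != root:
        root = parent[root]
    while parent[i] != root:
        nxt = parent[i]
        parent[i] = root
        i = nxt
    return root


def union(parent, rank, a, b):
    """Merge the sets containing a and b, attaching the shallower tree (union-by-rank)."""
    ra, rb = find(parent, a), find(parent, b)
    if ra == rb:
        return
    if rank[ra] < rank[rb]:
        ra, rb = rb, ra
    parent[rb] = ra
    if rank[ra] == rank[rb]:
        rank[ra] += 1


def cluster_sums(stat, threshold, edges):
    """Sum the statistic over each connected supra-threshold component."""
    n = len(stat)
    keep = [s > threshold for s in stat]
    parent = list(range(n))
    rank = [0] * n
    # A single pass over the edges; an edge counts only if both ends survive.
    for a, b in edges:
        if keep[a] and keep[b]:
            union(parent, rank, a, b)
    sums = {}
    for i in range(n):
        if keep[i]:
            root = find(parent, i)
            sums[root] = sums.get(root, 0.0) + stat[i]
    return sorted(sums.values())


# Six nodes in a chain; nodes 2 and 5 fall below the threshold.
stat = [4.0, 3.5, 1.0, 5.0, 3.25, 0.5]
edges = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5)]
print(cluster_sums(stat, 3.0, edges))  # [7.5, 8.25]
```

With both heuristics in place, each find or union runs in near-constant amortized time, so labeling is essentially one linear pass over the edges. Accumulating sums per root also avoids building explicit cluster membership lists when only the sums are needed.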
Extract duplicated find+union logic into _union() with inline="always", simplify _sum_cluster_data, trim docstrings and comments to match codebase style. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Pre-existing bug: adjacency is None and adjacency is not False was equivalent to just adjacency is None, missing the adjacency is False case where step_down_include still needs reshaping. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
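The equivalence claimed in this commit message can be checked directly. In the sketch below, `old_condition` reproduces the expression quoted above, while `needs_reshape` is only an assumed form of the fix (the actual replacement expression is not shown in this thread), and the string stands in for a real sparse adjacency matrix.

```python
def old_condition(adjacency):
    # Expression quoted in the commit message.
    return adjacency is None and adjacency is not False


def needs_reshape(adjacency):
    # Hypothetical corrected check: reshape for both None and False.
    return adjacency is None or adjacency is False


for adj in (None, False, "sparse-matrix-placeholder"):
    # The second clause of the old check never changes the outcome.
    assert old_condition(adj) == (adj is None)
    print(repr(adj), old_condition(adj), needs_reshape(adj))
```

The only input where the two functions disagree is `adjacency=False`, which is the case the commit message says was missed.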
The NumPy variant of the fixture patched the JIT helper functions but not has_numba itself, so the Numba union-find code path was still used for spatio-temporal clustering even in the "NumPy" test variant. Patching has_numba ensures the Python BFS fallback is actually tested. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Address reviewer request for documentation explaining what the union-find algorithm does, why it is faster than the BFS it replaces, and a link to the Wikipedia article on disjoint-set data structures. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Reference issue
Related: #5439, #7784, #8095, #12609
What does this implement/fix?
Speeds up `spatio_temporal_cluster_1samp_test` (and the other `permutation_cluster_*` functions) by ~5-10x on realistic data. The PR is split into 7 incremental commits. Maintainers can accept or reject each layer independently.
Commit 1: Precompute sum-of-squares for sign-flip t-test (+29/−9 lines, 3.2x)
For the default `ttest_1samp_no_p`, s²=1 means `sum(X²)` is constant across permutations. Each permutation becomes a single `signs @ X` dot product instead of calling `stat_fun`. Also skips `buffer_size` verification for built-in stat functions.
Commit 2: Numba union-find for spatio-temporal CCL (+226/−11 lines, 10.3x cumulative)
JIT-compiled union-find kernel (`_st_fused_ccl`) with path compression and union-by-rank, replacing the Python BFS in `_get_clusters_st`. Bundles tightly-coupled pieces: pre-computed CSR adjacency arrays, `_sums_only` flag to skip cluster list construction (uses `np.bincount` instead), and `_csr_data` parameter threading. These are bundled because `_sums_only` only fires inside `if has_numba:` and CSR data is only consumed by the Numba kernel.
Commit 3: Extract `_union` helper + simplify (+36/−72 lines)
Extract duplicated find+union logic into `_union()` with `inline="always"`, simplify `_sum_cluster_data`, trim docstrings/comments to match codebase style.
Commit 4: Fix step-down reshape (+1/−1 lines)
Pre-existing bug: `adjacency is None and adjacency is not False` was equivalent to just `adjacency is None`, missing the `adjacency is False` case where `step_down_include` still needs reshaping.
Commit 5: Changelog entries
Commit 6: Test fixture (+1 line)
Patch `has_numba` in the `numba_conditional` fixture so the "NumPy" test variant actually exercises the Python BFS fallback path for spatio-temporal clustering.
Commit 7: Docstring (+18/−1 lines)
Expand the `_st_fused_ccl` docstring with algorithm description, complexity analysis, and Wikipedia reference, per reviewer request.
Commits 3-7 are cleanup, bugfix, docs, and tests; they don't affect performance. All optimizations fall back to the original code paths when Numba is not installed. No public API changes.
Benchmarks
Per-commit cumulative speedup (local, Apple M-series, `spatio_temporal_cluster_1samp_test`, ico-5, 15 subjects x 15 timepoints x 20,484 vertices, threshold=3.0, 512 permutations, median of 3 runs):
AWS HPC end-to-end (AMD EPYC 7R13, same data dimensions):
Per-permutation cost: 15.8 ms → 3.1 ms (5.2x). Projected 10,000 permutations: 31 s vs 159 s.
Reproduce benchmarks locally
Additional information
threshold=dict(...)) correctly falls back to the original code path