diff --git a/crypto/stark/src/constraints/evaluator.rs b/crypto/stark/src/constraints/evaluator.rs index 39d9af2a0..1622174d0 100644 --- a/crypto/stark/src/constraints/evaluator.rs +++ b/crypto/stark/src/constraints/evaluator.rs @@ -7,7 +7,7 @@ use crate::trace::LDETraceTable; use crate::traits::{AIR, TransitionEvaluationContext, ZerofierEvaluations}; use crate::{frame::Frame, prover::evaluate_polynomial_on_lde_domain}; use math::field::traits::{IsFFTField, IsField, IsSubFieldOf}; -#[cfg(not(feature = "parallel"))] +#[cfg(all(debug_assertions, not(feature = "parallel")))] use math::polynomial::Polynomial; use math::{fft::errors::FFTError, field::element::FieldElement}; #[cfg(feature = "parallel")] diff --git a/crypto/stark/src/lookup.rs b/crypto/stark/src/lookup.rs index 17ba7c5ec..d359155f7 100644 --- a/crypto/stark/src/lookup.rs +++ b/crypto/stark/src/lookup.rs @@ -1443,6 +1443,83 @@ where { } +/// Classifies a multiplicity as sparse-capable for a given chunk. +/// +/// Returns `Some(active_indices)` when the multiplicity has a cheap structural +/// test for zero (e.g. `Column`, `Sum`, `Sum3` where each term is a selector +/// column that is mostly 0), and `None` when the multiplicity is effectively +/// dense (e.g. `One`, or forms like `Negated`/`Diff`/`Linear` whose zero set is +/// not easily identifiable from a single column check). +/// +/// Soundness: `active_indices[..]` must include every row where the multiplicity +/// is non-zero. Here we mark a row as active if ANY of the contributing columns +/// is non-zero — which is an over-approximation of `m != 0` (we don't miss any +/// non-zero rows, we just might process some m=0 rows due to cancellation). +/// +/// The threshold check: if more than `SPARSE_ACTIVE_FRAC` of the chunk is active, +/// return None to fall back to the dense path (gather/scatter overhead not worth it). +#[inline] +fn sparse_active_rows( + multiplicity: &Multiplicity, + main_segment_cols: &[Vec>], + chunk_start: usize, + chunk_len: usize, +) -> Option> +where + F: IsField, +{ + let zero = FieldElement::::zero(); + // If more than this fraction of rows are active, skip sparse path. + // (Numerator/denominator comparison to avoid floating point.) + const SPARSE_NUM: usize = 7; + const SPARSE_DEN: usize = 8; + + let collect_if_sparse = |active: Vec| -> Option> { + if active.len() * SPARSE_DEN > chunk_len * SPARSE_NUM { + None + } else { + Some(active) + } + }; + + match multiplicity { + Multiplicity::One => None, + Multiplicity::Column(col) => { + let c = &main_segment_cols[*col]; + let active: Vec = (0..chunk_len) + .filter(|&i| c[chunk_start + i] != zero) + .collect(); + collect_if_sparse(active) + } + Multiplicity::Sum(col_a, col_b) => { + let ca = &main_segment_cols[*col_a]; + let cb = &main_segment_cols[*col_b]; + let active: Vec = (0..chunk_len) + .filter(|&i| { + let row = chunk_start + i; + ca[row] != zero || cb[row] != zero + }) + .collect(); + collect_if_sparse(active) + } + Multiplicity::Sum3(col_a, col_b, col_c) => { + let ca = &main_segment_cols[*col_a]; + let cb = &main_segment_cols[*col_b]; + let cc = &main_segment_cols[*col_c]; + let active: Vec = (0..chunk_len) + .filter(|&i| { + let row = chunk_start + i; + ca[row] != zero || cb[row] != zero || cc[row] != zero + }) + .collect(); + collect_if_sparse(active) + } + // Negated = 1 - col is dense (usually 1 when col is a bit flag). + // Diff and Linear could cancel arbitrarily; safest to treat as dense. + Multiplicity::Negated(_) | Multiplicity::Diff(_, _) | Multiplicity::Linear(_) => None, + } +} + /// Computes a term column for a table interaction without writing to the trace. /// /// Each row contains the LogUp quotient: `term[i] = sign * multiplicity[i] / fingerprint[i]` @@ -1453,6 +1530,10 @@ where /// With `parallel`: processes rows in chunks of `LOGUP_CHUNK_SIZE` via `par_chunks_mut`, /// giving good cache locality (each thread touches only CHUNK_SIZE rows before moving on). /// Without `parallel`: processes all rows as a single chunk (equivalent to the old sequential path). +/// +/// Sparse fast path: when the multiplicity's zero-set is structurally identifiable +/// (e.g., the interaction is gated by a selector column), we compute fingerprints +/// and terms only at the rows where the multiplicity is potentially non-zero. fn compute_logup_term_column( table_interaction: &BusInteraction, main_segment_cols: &[Vec>], @@ -1477,9 +1558,17 @@ where let process_chunk = |chunk_start: usize, result_chunk: &mut [FieldElement]| { let chunk_len = result_chunk.len(); - // Phase 1: Compute fingerprints - let mut fingerprints: Vec> = Vec::with_capacity(chunk_len); - for row in chunk_start..chunk_start + chunk_len { + // Try sparse path first: compute fingerprints/terms only at rows where + // the multiplicity is (potentially) non-zero. Rows initialize to zero in + // the result vector, so inactive rows naturally stay zero. + let active = sparse_active_rows( + &table_interaction.multiplicity, + main_segment_cols, + chunk_start, + chunk_len, + ); + + let compute_fp = |row: usize| -> FieldElement { let mut lc = &bus_id_f * &alpha_powers[0]; let mut alpha_offset = 1; for bv in &table_interaction.values { @@ -1493,44 +1582,93 @@ where ); alpha_offset += consumed; } - fingerprints.push(z - &lc); + z - &lc + }; - #[cfg(feature = "debug-checks")] - { - let mut base_elements: Vec> = vec![bus_id_f.clone()]; - base_elements.extend( - table_interaction - .values - .iter() - .flat_map(|bv| bv.combine_from(|col| main_segment_cols[col][row].clone())), - ); - let multiplicity = table_interaction - .multiplicity - .evaluate_at_row(main_segment_cols, row); - crate::bus_debug::log_interaction( - _table_name, - row, - table_interaction.bus_id, - table_interaction.is_sender, - &multiplicity.canonical(), - &base_elements, - fingerprints.last().unwrap(), - ); + match active { + Some(indices) => { + // Sparse path + let mut fingerprints: Vec> = indices + .iter() + .map(|&i| compute_fp(chunk_start + i)) + .collect(); + + #[cfg(feature = "debug-checks")] + { + // Log all rows for debug symmetry; inactive rows contribute zero. + for row in chunk_start..chunk_start + chunk_len { + let mut base_elements: Vec> = vec![bus_id_f.clone()]; + base_elements.extend(table_interaction.values.iter().flat_map(|bv| { + bv.combine_from(|col| main_segment_cols[col][row].clone()) + })); + let multiplicity = table_interaction + .multiplicity + .evaluate_at_row(main_segment_cols, row); + // Compute fp fresh for debug (cheap). + let fp = compute_fp(row); + crate::bus_debug::log_interaction( + _table_name, + row, + table_interaction.bus_id, + table_interaction.is_sender, + &multiplicity.canonical(), + &base_elements, + &fp, + ); + } + } + + FieldElement::inplace_batch_inverse(&mut fingerprints) + .expect("fingerprint is zero - probability of sampling zero is negligible"); + + for (k, &i) in indices.iter().enumerate() { + let row = chunk_start + i; + let m = table_interaction + .multiplicity + .evaluate_at_row(main_segment_cols, row); + let term = &m * &fingerprints[k]; + result_chunk[i] = if negate { -term } else { term }; + } } - } + None => { + // Dense path + let mut fingerprints: Vec> = Vec::with_capacity(chunk_len); + for row in chunk_start..chunk_start + chunk_len { + fingerprints.push(compute_fp(row)); + + #[cfg(feature = "debug-checks")] + { + let mut base_elements: Vec> = vec![bus_id_f.clone()]; + base_elements.extend(table_interaction.values.iter().flat_map(|bv| { + bv.combine_from(|col| main_segment_cols[col][row].clone()) + })); + let multiplicity = table_interaction + .multiplicity + .evaluate_at_row(main_segment_cols, row); + crate::bus_debug::log_interaction( + _table_name, + row, + table_interaction.bus_id, + table_interaction.is_sender, + &multiplicity.canonical(), + &base_elements, + fingerprints.last().unwrap(), + ); + } + } - // Phase 2: Batch-invert - FieldElement::inplace_batch_inverse(&mut fingerprints) - .expect("fingerprint is zero - probability of sampling zero is negligible"); + FieldElement::inplace_batch_inverse(&mut fingerprints) + .expect("fingerprint is zero - probability of sampling zero is negligible"); - // Phase 3: Compute terms - for (i, result_elem) in result_chunk.iter_mut().enumerate() { - let row = chunk_start + i; - let m = table_interaction - .multiplicity - .evaluate_at_row(main_segment_cols, row); - let term = &m * &fingerprints[i]; - *result_elem = if negate { -term } else { term }; + for (i, result_elem) in result_chunk.iter_mut().enumerate() { + let row = chunk_start + i; + let m = table_interaction + .multiplicity + .evaluate_at_row(main_segment_cols, row); + let term = &m * &fingerprints[i]; + *result_elem = if negate { -term } else { term }; + } + } } }; @@ -1583,52 +1721,130 @@ where let process_chunk = |chunk_start: usize, result_chunk: &mut [FieldElement]| { let chunk_len = result_chunk.len(); - // Phase 1: Compute fingerprints for both interactions - let compute_fps = |interaction: &BusInteraction, - bus_id_f: &FieldElement, - fps: &mut Vec>| { - for row in chunk_start..chunk_start + chunk_len { - let mut lc = bus_id_f * &alpha_powers[0]; - let mut alpha_offset = 1; - for bv in &interaction.values { - let consumed = bv.accumulate_fingerprint( - main_segment_cols, - row, - &alpha_powers, - alpha_offset, - &mut lc, - &shifts, - ); - alpha_offset += consumed; + // Build per-interaction active-row index lists. `None` means dense + // (treat all rows as active); `Some(idx)` means process only `idx`. + let active_a = sparse_active_rows( + &interaction_a.multiplicity, + main_segment_cols, + chunk_start, + chunk_len, + ); + let active_b = sparse_active_rows( + &interaction_b.multiplicity, + main_segment_cols, + chunk_start, + chunk_len, + ); + + // Dense fast path: preserves the original contiguous memory access + // pattern when sparse gather would be pure overhead. + if active_a.is_none() && active_b.is_none() { + let compute_fps = |interaction: &BusInteraction, + bus_id_f: &FieldElement, + fps: &mut Vec>| { + for row in chunk_start..chunk_start + chunk_len { + let mut lc = bus_id_f * &alpha_powers[0]; + let mut alpha_offset = 1; + for bv in &interaction.values { + let consumed = bv.accumulate_fingerprint( + main_segment_cols, + row, + &alpha_powers, + alpha_offset, + &mut lc, + &shifts, + ); + alpha_offset += consumed; + } + fps.push(z - &lc); } - fps.push(z - &lc); + }; + + let mut fingerprints: Vec> = Vec::with_capacity(2 * chunk_len); + compute_fps(interaction_a, &bus_id_a, &mut fingerprints); + compute_fps(interaction_b, &bus_id_b, &mut fingerprints); + + FieldElement::inplace_batch_inverse(&mut fingerprints) + .expect("fingerprint is zero - probability of sampling zero is negligible"); + + for (i, result_elem) in result_chunk.iter_mut().enumerate() { + let row = chunk_start + i; + let fp_a_inv = &fingerprints[i]; + let fp_b_inv = &fingerprints[chunk_len + i]; + let m_a = interaction_a + .multiplicity + .evaluate_at_row(main_segment_cols, row); + let m_b = interaction_b + .multiplicity + .evaluate_at_row(main_segment_cols, row); + let term_a = &m_a * fp_a_inv; + let term_b = &m_b * fp_b_inv; + let term_a = if negate_a { -term_a } else { term_a }; + let term_b = if negate_b { -term_b } else { term_b }; + *result_elem = term_a + term_b; } + return; + } + + // Sparse (or half-sparse) path: gather fingerprints only for active rows. + let compute_fp = |interaction: &BusInteraction, bus_id_f: &FieldElement, row: usize| { + let mut lc = bus_id_f * &alpha_powers[0]; + let mut alpha_offset = 1; + for bv in &interaction.values { + let consumed = bv.accumulate_fingerprint( + main_segment_cols, + row, + &alpha_powers, + alpha_offset, + &mut lc, + &shifts, + ); + alpha_offset += consumed; + } + z - &lc + }; + + let indices_a: Vec = match &active_a { + Some(v) => v.clone(), + None => (0..chunk_len).collect(), + }; + let indices_b: Vec = match &active_b { + Some(v) => v.clone(), + None => (0..chunk_len).collect(), }; - let mut fingerprints: Vec> = Vec::with_capacity(2 * chunk_len); - compute_fps(interaction_a, &bus_id_a, &mut fingerprints); - compute_fps(interaction_b, &bus_id_b, &mut fingerprints); + let mut fingerprints: Vec> = + Vec::with_capacity(indices_a.len() + indices_b.len()); + for &i in &indices_a { + fingerprints.push(compute_fp(interaction_a, &bus_id_a, chunk_start + i)); + } + for &i in &indices_b { + fingerprints.push(compute_fp(interaction_b, &bus_id_b, chunk_start + i)); + } - // Phase 2: Batch-invert FieldElement::inplace_batch_inverse(&mut fingerprints) .expect("fingerprint is zero - probability of sampling zero is negligible"); - // Phase 3: Compute terms - for (i, result_elem) in result_chunk.iter_mut().enumerate() { + let (fp_a_inv, fp_b_inv) = fingerprints.split_at(indices_a.len()); + + // Scatter a-terms (overwrite: result_chunk starts at zero). + for (k, &i) in indices_a.iter().enumerate() { let row = chunk_start + i; - let fp_a_inv = &fingerprints[i]; - let fp_b_inv = &fingerprints[chunk_len + i]; let m_a = interaction_a .multiplicity .evaluate_at_row(main_segment_cols, row); + let term = &m_a * &fp_a_inv[k]; + result_chunk[i] = if negate_a { -term } else { term }; + } + // Add b-terms on top of whatever a left behind (possibly zero). + for (k, &i) in indices_b.iter().enumerate() { + let row = chunk_start + i; let m_b = interaction_b .multiplicity .evaluate_at_row(main_segment_cols, row); - let term_a = &m_a * fp_a_inv; - let term_b = &m_b * fp_b_inv; - let term_a = if negate_a { -term_a } else { term_a }; - let term_b = if negate_b { -term_b } else { term_b }; - *result_elem = term_a + term_b; + let term = &m_b * &fp_b_inv[k]; + let term = if negate_b { -term } else { term }; + result_chunk[i] = &result_chunk[i] + term; } };