Fix top-k ternary search timeout fallback producing inconsistent pivot/stats in Triton kernel #10

Draft
zhenwei-intel with Copilot wants to merge 2 commits into main from copilot/fix-timeout-range-collapse-bug

Copilot AI commented Jun 23, 2026

The top-k ternary search in _topk_topp_kernel had a correctness bug in its timeout/range-collapse fallback. k_pivot was set to a fresh midpoint that was never evaluated during the search, min_larger/num_min_larger were taken from k_pivot_0 (the 1/3-position pivot), and k_pivots_num kept its initial value of 0. This broke the invariant that all four values must describe the same threshold, causing uint32 underflow in num_keep and incorrect surviving-token counts when tied logits exist near the truncation boundary.
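
To make the underflow concrete, here is a small worked example in plain Python that emulates 32-bit unsigned subtraction. It assumes the unguarded expression had the form num_duplicate_logit - (k_pivots_num - k), which is inferred from the guarded version in this PR rather than quoted from the kernel; u32_sub and the sample values are purely illustrative.

```python
UINT32_MASK = 0xFFFFFFFF


def u32_sub(a, b):
    """Subtract the way a 32-bit unsigned integer would: wrap modulo 2**32."""
    return (a - b) & UINT32_MASK


k = 5
k_pivots_num = 0         # stale initial value left behind by the old fallback
num_duplicate_logit = 3  # illustrative duplicate count taken from another pivot

# Assumed unguarded form: num_duplicate_logit - (k_pivots_num - k)
excess = u32_sub(k_pivots_num, k)                # 0 - 5 wraps to 4294967291
num_keep = u32_sub(num_duplicate_logit, excess)  # wraps again and lands on 8

print(excess, num_keep)
```

Under these assumptions the intermediate value wraps to a huge number, and the final num_keep (8) ends up larger than the number of duplicates that exist (3), which is the kind of wrong surviving-token count described above.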

Changes

vllm/v1/sample/ops/topk_topp_triton.py

  • Both fallback branches (buffer-path ~L279 and full-vocab-path ~L373): instead of fabricating a midpoint pivot, snap to the best already-evaluated candidate — prefer k_pivot_1 when k_pivots_num_1 >= k, otherwise use k_pivot_0. All four variables (k_pivot, k_pivots_num, min_larger, num_min_larger) are now assigned together from the same evaluated pivot:
    # Before
    if num_iters >= 18 or tl.abs(min_range - max_range) < 1e-9:
        k_pivot = (max_range + min_range) / 2.0  # never evaluated
        min_larger = min_larger_0                 # mismatched pivot
        num_min_larger = num_min_larger_0         # mismatched pivot
        # k_pivots_num left at 0 → uint32 underflow downstream
        found_pivot = 1

    # After
    if num_iters >= 18 or tl.abs(min_range - max_range) < 1e-9:
        if k_pivots_num_1 >= k:
            k_pivot = k_pivot_1
            k_pivots_num = k_pivots_num_1
            min_larger = min_larger_1
            num_min_larger = num_min_larger_1
        else:
            k_pivot = k_pivot_0
            k_pivots_num = k_pivots_num_0
            min_larger = min_larger_0
            num_min_larger = num_min_larger_0
        found_pivot = 1
  • num_keep computation: guard against uint32 underflow with tl.where:
    excess = tl.where(k_pivots_num > k, k_pivots_num - k, 0)
    num_keep = num_duplicate_logit - excess
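
The two changes above can be sketched outside Triton as follows. This is a plain-Python illustration, not the kernel code: PivotStats, evaluate_pivot, pick_fallback, and guarded_num_keep are hypothetical names, "larger" is assumed to mean strictly greater than the pivot, and using num_min_larger as the duplicate count is an assumption made only for the example.

```python
from typing import NamedTuple


class PivotStats(NamedTuple):
    """The four values that must describe one and the same threshold."""
    k_pivot: float
    k_pivots_num: int
    min_larger: float
    num_min_larger: int


def evaluate_pivot(logits, pivot):
    """Compute all four statistics for a single candidate threshold.

    Assumption: "larger" means strictly greater than the pivot.
    """
    larger = [x for x in logits if x > pivot]
    if not larger:
        return PivotStats(pivot, 0, float("inf"), 0)
    min_larger = min(larger)
    return PivotStats(pivot, len(larger), min_larger, larger.count(min_larger))


def pick_fallback(stats_0, stats_1, k):
    """Snap to an already-evaluated pivot; never mix fields from two pivots."""
    return stats_1 if stats_1.k_pivots_num >= k else stats_0


def guarded_num_keep(num_duplicate_logit, k_pivots_num, k):
    """Subtract the excess only when it is positive, mirroring the tl.where guard."""
    excess = k_pivots_num - k if k_pivots_num > k else 0
    return num_duplicate_logit - excess


logits = [5.0] * 7 + [1.0, 0.5]
k = 5
stats_0 = evaluate_pivot(logits, 2.0)  # lower candidate pivot
stats_1 = evaluate_pivot(logits, 4.0)  # upper candidate pivot
chosen = pick_fallback(stats_0, stats_1, k)
# Assumption for illustration: the duplicate count is the multiplicity of min_larger.
num_keep = guarded_num_keep(chosen.num_min_larger, chosen.k_pivots_num, k)
print(chosen, num_keep)
```

Bundling the four fields in one record makes it structurally impossible to pair a pivot with statistics from a different pivot, which is the invariant the fix restores by assigning all four variables in the same branch.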

tests/v1/sample/test_topk_topp_sampler.py

  • Added test_topk_tied_logits_exact_count (parametrized over k∈{5,10,20}): constructs inputs where k+50 tokens share identical logit values, forcing the search interval to collapse and the fallback branch to fire. Asserts surviving-token count exactly matches the PyTorch reference implementation.
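
One way to see why such inputs stall a threshold search is sketched below in plain Python. It assumes the tied tokens hold the highest logit and that the search is looking for a threshold that leaves exactly k survivors; make_tied_logits and count_at_least are hypothetical helpers and are not taken from the test file.

```python
def make_tied_logits(k, vocab_size=200, tied_value=5.0):
    """Build logits where k + 50 tokens share the top value (hypothetical layout)."""
    num_tied = k + 50
    # Remaining tokens get distinct, strictly smaller values.
    rest = [tied_value - 1.0 - 0.01 * i for i in range(vocab_size - num_tied)]
    return [tied_value] * num_tied + rest


def count_at_least(logits, pivot):
    """Number of tokens that would survive a threshold at `pivot`."""
    return sum(1 for x in logits if x >= pivot)


k = 10
logits = make_tied_logits(k)
lo, hi = min(logits), max(logits)
# Sample thresholds across the whole value range.
pivots = [lo + (hi - lo) * i / 1000 for i in range(1001)]
counts = {count_at_least(logits, p) for p in pivots}
exact_hit = k in counts
print(exact_hit)
```

With this layout the survivor count jumps from 0 straight to at least k+50 as the threshold crosses the tied value, so no sampled threshold leaves exactly k tokens and a search of this kind can only narrow its interval until it collapses or times out.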

Copilot AI changed the title from "[WIP] Fix correctness bug in top-k ternary search fallback" to "Fix top-k ternary search timeout fallback producing inconsistent pivot/stats in Triton kernel" Jun 23, 2026
Copilot AI requested a review from zhenwei-intel June 23, 2026 06:46