[feat] Add Spec Decoding and MTP training support#1832

Open
zanderjiang wants to merge 39 commits into NovaSky-AI:main from zanderjiang:mtp

@zanderjiang

This PR adds end-to-end Multi-Token Prediction support: train a model's native MTP head(s) with a decoupled draft loss on the Megatron backend, and reuse those heads for vLLM MTP speculative decoding to speed up rollout. The draft loss is autograd-decoupled, so it trains only the MTP head and never perturbs the policy.

Training (Megatron)

  • Decoupled draft loss (mtp/soft_ce.py): vocab-parallel soft-CE distillation against the policy's own detached next-token distribution, plus hard-CE and top-k variants; packed-sequence aware (masks cross-segment positions).
  • Decoupled hidden capture (mtp/hidden_capture.py): a forward pre-hook records the MTP block's inputs, then the block is replayed on the detached trunk output, so the draft gradient reaches the head (and optionally the shared output/embedding) but never the policy backbone.
  • Native loss disabled (mtp/native_loss_patch.py): patches Megatron's in-forward process_mtp_loss to a no-op. Left active, it backpropagates a hard-CE loss into the policy trunk (inflated grad-norm, entropy collapse); the "pass no labels to skip it" trick broke when megatron-core started deriving labels from input_ids. The patch fails loudly on API drift.
  • Output projection (mtp/adapter.py): projects MTP hidden states through the shared (tied) or detached (untied) output layer.
  • Optional isolated optimizer (mtp/mtp_optim.py, megatron_config.mtp_separate_optimizer): gives the head its own grad buffer + DistributedOptimizer so the policy's grad-norm/clip and distributed reduction stay byte-identical to a no-MTP build.
  • Integration: megatron_model_wrapper.py (the decoupled loss in forward_backward), megatron_worker.py (config wiring, per-worker head build/disable, head-optimizer checkpoint save/load), model_bridges.py (HF↔Megatron MTP weight round-trip).
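
For readers unfamiliar with the soft-CE target named in the first bullet, the following is a minimal single-process sketch in plain Python. It ignores tensor parallelism, packing, and autograd, and the function names are illustrative rather than taken from mtp/soft_ce.py. The teacher logits stand in for the policy's detached next-token logits and the student logits for the MTP head's output.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a plain list of floats."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def soft_ce(student_logits, teacher_logits):
    """H(p, q) = -sum_v p_v * log q_v, with p = softmax(teacher), q = softmax(student)."""
    teacher_p = [math.exp(lp) for lp in log_softmax(teacher_logits)]
    student_logp = log_softmax(student_logits)
    return -sum(p * lq for p, lq in zip(teacher_p, student_logp))


teacher = [2.0, 0.5, -1.0]   # stand-in for detached policy logits
student = [1.0, 1.0, 0.0]    # stand-in for MTP-head logits

loss = soft_ce(student, teacher)
entropy = soft_ce(teacher, teacher)  # lower bound, reached when student == teacher
```

The loss never drops below the teacher's own entropy, so a draft head that matches the policy distribution exactly sits at that floor rather than at zero.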

Inference (vLLM)

  • speculative_config wired through to the engine (inference_servers/, vllm_engine.py); draft depth = num_speculative_tokens (depth > 1 reuses a single trained head autoregressively).
  • Per-step draft acceptance metrics (spec_decode_metrics.py): vllm/draft_acceptance_rate plus a per-position breakdown showing per-depth decay.
  • Weight sync keeps the draft head aligned with the trained policy each step.
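
As an illustration of what an overall acceptance rate and a per-position breakdown could look like, here is a small sketch. The counter layout (number of draft rounds plus accepted counts per draft depth) and the function name are assumptions made for this example; the actual definitions in spec_decode_metrics.py may differ.

```python
def acceptance_metrics(num_drafts, accepted_per_pos):
    """Summarize draft acceptance.

    num_drafts: number of speculative rounds (hypothetical counter).
    accepted_per_pos[i]: rounds in which the draft token at depth i was accepted.
    """
    depth = len(accepted_per_pos)
    total_drafted = num_drafts * depth
    total_accepted = sum(accepted_per_pos)
    return {
        "draft_acceptance_rate": total_accepted / total_drafted if total_drafted else 0.0,
        "per_position": [a / num_drafts if num_drafts else 0.0 for a in accepted_per_pos],
    }


# 100 rounds at draft depth 3; acceptance decays with depth.
metrics = acceptance_metrics(100, [80, 60, 30])
```

With these made-up counts the per-position rates fall from 0.8 to 0.3, which is the kind of per-depth decay the breakdown is meant to expose.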

Config

  • trainer.mtp.{enabled, num_speculative_tokens, loss_type, loss_weight} — single high-level knob; _apply_mtp_config propagates it to both training and inference.
  • megatron_config.mtp_*: mtp_num_layers, mtp_loss_type/weight/topk/chunk_size, mtp_detach_trunk, mtp_detach_shared_output, mtp_separate_optimizer.
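
A rough sketch of how one high-level block might fan out to both sides is given below. Everything here is hypothetical: the dataclass, the plain-dict configs, the "soft_ce" value, and the choice of a single MTP layer are assumptions for illustration, and the real _apply_mtp_config may behave differently.

```python
from dataclasses import dataclass


@dataclass
class MTPConfig:
    # Mirrors the trainer.mtp.* fields listed above; defaults are assumptions.
    enabled: bool = False
    num_speculative_tokens: int = 1
    loss_type: str = "soft_ce"
    loss_weight: float = 1.0


def apply_mtp_config(mtp, megatron_config, engine_config):
    """Copy the high-level knob into training and inference configs (illustrative only)."""
    if not mtp.enabled:
        return
    # Training side: assume one trained head, reused autoregressively for depth > 1.
    megatron_config["mtp_num_layers"] = 1
    megatron_config["mtp_loss_type"] = mtp.loss_type
    megatron_config["mtp_loss_weight"] = mtp.loss_weight
    # Inference side: draft depth is carried by num_speculative_tokens.
    engine_config["speculative_config"] = {
        "num_speculative_tokens": mtp.num_speculative_tokens,
    }


megatron_config, engine_config = {}, {}
apply_mtp_config(MTPConfig(enabled=True, num_speculative_tokens=3), megatron_config, engine_config)
```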

Tests

  • Unit: soft_ce, adapter, hidden_capture, mtp_config, torch_utils, spec-decode metrics/client, vLLM CLI-arg wiring.
  • Integration: <test_name> (matches in-forward output), test_mtp_packed_vs_unpacked (packing parity), test_mtp_weight_roundtrip (HF↔Megatron).

Examples

examples/train/spec_decode/: MiMo-7B-RL DAPO, Qwen3.5-9B DAPO, and MiMo-7B SearchR1 — all with MTP speculative decoding.

zanderjiang and others added 30 commits June 3, 2026 01:05
Resolve conflicts across 6 files. Trivial additive merges:
- skyrl_train_backend.py / main_base.py / ray_wrapped_inference_engine.py:
  keep both speculative_config (mtp) and use_expandable_segments (main).
- megatron_worker.py: keep both enable_mtp (mtp) and language_model_only (main)
  kwargs on init_configs and its ref-worker call site.

Non-trivial (confirmed with author):
- megatron_model_wrapper.py loss scaling: adopt main's kl_entropy_microbatch_scale
  (drop the cp_size correction); fold the decoupled MTP draft_loss in with the same
  micro-batch correction.
- megatron_model_wrapper.py de-padding: adopt main's packed-logits path (keep logits
  packed under remove_microbatch_padding, loss_func uses packed_targets); depad() now
  only recovers left-padding. Add a guard that MTP requires remove_microbatch_padding=
  False (the teacher main-logits and de-padded student logits must share [b,s,v]).
- vllm_worker.py: keep main's bracketed/non-bracketed weight update + synchronize, and
  add the MTP/Eagle drafter reload after the main load in both paths.
…ull)

Isolate the MTP/draft head into its own Megatron DDP grad buffer + DistributedOptimizer
so the policy's distributed grad reduction is byte-identical to a no-MTP model, while the
head co-trains at full strength:
- mtp/separate_optim.py: pre-wrap freeze, SeparateMTPOptimizer (own DDP buffer + optimizer +
  scheduler), policy-finalize exclusion, hidden() to exclude the head from the policy grad-norm/clip.
- megatron_worker.py: wire C-full (freeze pre-wrap -> separate optimizer -> per-iter buffer zero ->
  separate grad-norm/clip + step), plus checkpoint sidecar save/load for the head optimizer state.
- config: mtp_separate_optimizer toggle (MegatronConfig).
- trainer.py: guard the rollout/train logprob-diff metric against an empty loss_mask batch.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Replace the brittle mark_pre_eval()/sample_split() approach (which relied on
the caller's generate/eval ordering and threaded generation-time values back
in) with an explicit start()/pause()/resume()/stop() window that owns its own
timing. The scraper accumulates only un-paused active time, so the trainer and
evaluate() just bracket their generation calls.

- start(label)/stop() return metrics nested under {label}/ (vllm/train/*,
  vllm/eval/*); pause/resume gate the throughput denominator.
- trainer: wrap train generation (handles dynamic sampling's repeated
  generates) and eval generation; evaluate()/evaluate_step_wise() resume/pause
  around each eval generation.
- fully-async sample() path unchanged.
- Rewrite split tests for the window API.
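
The window idea in this commit can be sketched as a timer that only accumulates un-paused time. This is a minimal stand-in, not the VLLMMetricsScraper itself: the class name, the injectable clock, the `tokens` argument to `stop()`, and the metric keys are all assumptions made so the example is self-contained.

```python
import time


class ActiveWindow:
    """Accumulates only un-paused time between start() and stop()."""

    def __init__(self, clock=time.monotonic):
        self._clock = clock
        self._label = None
        self._active = 0.0
        self._resumed_at = None

    def start(self, label):
        self._label = label
        self._active = 0.0
        self._resumed_at = self._clock()

    def pause(self):
        if self._resumed_at is not None:
            self._active += self._clock() - self._resumed_at
            self._resumed_at = None

    def resume(self):
        if self._resumed_at is None:
            self._resumed_at = self._clock()

    def stop(self, tokens):
        self.pause()  # close the last active interval
        rate = tokens / self._active if self._active > 0 else 0.0
        return {
            f"{self._label}/active_seconds": self._active,
            f"{self._label}/tokens_per_second": rate,
        }


# Fake clock: start at t=0, pause at t=2, resume at t=10, stop at t=13.
ticks = iter([0.0, 2.0, 10.0, 13.0])
window = ActiveWindow(clock=lambda: next(ticks))
window.start("vllm/train")
window.pause()
window.resume()
metrics = window.stop(tokens=100)
```

The eight seconds between pause and resume never enter the denominator, so the throughput reflects generation time only.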
…olicy grad-norm/clip

Megatron's get_main_grads_for_grad_norm() collects a param's grad when `grad is not None` — it does
NOT check requires_grad. The previous hidden() only flipped requires_grad, so the head's already-
populated grad still entered the policy grad-norm, inflating policy/grad_norm (observed ~2.5 vs ~0.3
no-MTP) and over-clipping the policy via the shared max_grad_norm — a head->policy coupling through
the clip. hidden() now stashes+clears grad/main_grad/decoupled_grad for the policy step and restores
them before the separate MTP optimizer steps.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
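
The stash-and-clear behaviour described in this commit can be illustrated with a small context manager. This is a sketch on plain Python objects, not the PR's hidden(): `grads_for_norm` is a stand-in for a collector that checks only `grad is not None`, and the parameter objects are `SimpleNamespace` placeholders.

```python
from contextlib import contextmanager
from types import SimpleNamespace

GRAD_ATTRS = ("grad", "main_grad", "decoupled_grad")


@contextmanager
def hidden(params):
    """Temporarily clear gradient attributes, restoring them on exit."""
    stash = [{name: getattr(p, name, None) for name in GRAD_ATTRS} for p in params]
    for p in params:
        for name in GRAD_ATTRS:
            if hasattr(p, name):
                setattr(p, name, None)
    try:
        yield
    finally:
        for p, saved in zip(params, stash):
            for name, value in saved.items():
                if hasattr(p, name):
                    setattr(p, name, value)


def grads_for_norm(params):
    # Stand-in collector: looks only at `grad is not None`, not at requires_grad.
    return [p.grad for p in params if p.grad is not None]


policy = SimpleNamespace(grad=0.3, main_grad=0.3, decoupled_grad=None)
head = SimpleNamespace(grad=2.5, main_grad=2.5, decoupled_grad=None)

with hidden([head]):
    during = grads_for_norm([policy, head])  # head is invisible to the policy step
after = grads_for_norm([policy, head])       # head grads are back for its own step
```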
…key-patch

Root cause of the MiMo entropy collapse / inflated policy grad-norm: GPTModel/HybridModel.forward
call process_mtp_loss unconditionally when MTP heads are built, and a megatron-core update made it
derive labels from input_ids ("e.g. RL training") — so SkyRL's "pass no labels to short-circuit it"
no longer works. The native hard-CE MTP loss then back-props into the policy trunk (grad-norm 2.5 vs
0.3, entropy collapse), independent of mtp_loss_weight.

Replace process_mtp_loss with a no-op at its call sites (gpt_model, hybrid_model) so only SkyRL's
decoupled soft-CE loss trains the head. Applied at config time, before any forward. The patch raises
loudly if Megatron renames/removes the function rather than silently re-enabling the native loss.
Confirmed: with the native loss off, policy grad_norm drops 2.5 -> ~0.3 (matches no-MTP) and entropy
stops collapsing, while the decoupled head loss keeps training.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
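
A fail-loud monkey-patch of this kind might look like the sketch below. The helper name is hypothetical, the module is a `SimpleNamespace` stand-in rather than a real megatron-core module, and the no-op's behaviour (returning its first argument unchanged) is an assumption, since the signature of process_mtp_loss is not shown in this PR.

```python
from types import SimpleNamespace


def disable_native_mtp_loss(module, name="process_mtp_loss"):
    """Replace module.<name> with a no-op; raise if the symbol is missing."""
    original = getattr(module, name, None)
    if not callable(original):
        raise RuntimeError(
            f"{name} not found on {module!r}; refusing to run with the native MTP loss possibly active"
        )

    def _noop(hidden_states, *args, **kwargs):
        # Assumption: passing hidden states through is a safe no-op.
        return hidden_states

    setattr(module, name, _noop)
    return original


# Stand-in for a module that looks the function up by name at call time.
fake_gpt_model = SimpleNamespace(
    process_mtp_loss=lambda hidden_states, **kw: ("native-loss", hidden_states)
)
disable_native_mtp_loss(fake_gpt_model)
patched_result = fake_gpt_model.process_mtp_loss("h", labels=None)

# API drift: the function was renamed, so the patch must refuse to proceed.
renamed_module = SimpleNamespace(compute_mtp_loss=lambda h: h)
try:
    disable_native_mtp_loss(renamed_module)
    raised = False
except RuntimeError:
    raised = True
```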
zanderjiang and others added 8 commits June 24, 2026 16:36
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Bring in logger: VLLMMetricsScraper window API (train/eval rollout timing),
worker train/rollout logprob-diff metric, fully-async trainer updates, and
upstream main PRs. Conflicts resolved keeping the MTP fixes (native-loss patch,
C-full separate optimizer, hidden() grad-clear, empty-batch guard, MEAN_LOSS_METRICS)
and taking logger's updated vllm_metrics_scraper + its trainer/eval integration.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces decoupled Multi-Token Prediction (MTP) draft-head training and speculative decoding integration with Megatron and vLLM, including new training scripts, hidden state capture, isolated optimizers, custom draft-head losses, and weight reloading helpers. The review feedback highlights critical robustness and performance improvements in the newly added MTP modules. Specifically, it addresses potential deadlocks and crashes on tensor-parallel ranks with empty local vocabulary shards within the softmax, top-k soft CE, and one-hot logit generation functions. It also suggests vectorizing the segment-shifting mask operation to reduce CPU overhead and passing the weight list directly to the drafter model to prevent iterator exhaustion.

Comment on lines +96 to +118
def _vocab_parallel_softmax(vocab_parallel_logits, group):
    """Global softmax over a TP-sharded vocab dim. Allocates one full-vocab output (the ``- logits_max``
    subtraction) and does the rest in place on it; the input is not mutated."""
    logits_max = torch.amax(vocab_parallel_logits, dim=-1, keepdim=True)
    torch.distributed.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=group)
    exp_logits = (vocab_parallel_logits - logits_max).exp_()  # new buffer, then in place
    sum_exp = exp_logits.sum(dim=-1, keepdim=True)
    torch.distributed.all_reduce(sum_exp, op=torch.distributed.ReduceOp.SUM, group=group)
    return exp_logits.div_(sum_exp)


def _vocab_parallel_log_softmax(vocab_parallel_logits, group):
    """Global log-softmax over a TP-sharded vocab dim. Uses ``torch.logsumexp`` (a fused
    ``[.,.,V]->[.,.,1]`` reduction) so it never materializes a full-vocab ``exp`` temporary -- multiple
    GiB at a 248K vocab. Allocates one full-vocab output buffer; the input is not mutated."""
    logits_max = torch.amax(vocab_parallel_logits, dim=-1, keepdim=True)
    torch.distributed.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=group)
    shifted = vocab_parallel_logits - logits_max  # new buffer (input untouched); values now <= 0
    # exp() of the *local* logsumexp recovers this shard's sum-of-exp, which reduces across TP.
    local_sum_exp = torch.logsumexp(shifted, dim=-1, keepdim=True).exp_()
    torch.distributed.all_reduce(local_sum_exp, op=torch.distributed.ReduceOp.SUM, group=group)
    return shifted.sub_(local_sum_exp.log_())  # in place on the owned buffer -> log-softmax

critical

When the vocabulary size is not a multiple of tp_size, Megatron pads the vocabulary. This padding can result in some ranks having an empty local vocabulary shard (vocab_parallel_logits.shape[-1] == 0). In such cases, calling torch.amax or torch.logsumexp on an empty tensor will raise a RuntimeError. Furthermore, if some ranks skip the collective all_reduce calls, the entire training run will deadlock.

We should make these functions robust to empty local shards by participating in the collective communication with safe dummy values (e.g., -inf for max, 0 for sum).

def _vocab_parallel_softmax(vocab_parallel_logits, group):
    """Global softmax over a TP-sharded vocab dim. Allocates one full-vocab output (the ``- logits_max``
    subtraction) and does the rest in place on it; the input is not mutated."""
    if vocab_parallel_logits.shape[-1] == 0:
        logits_max = torch.full((*vocab_parallel_logits.shape[:-1], 1), float('-inf'), dtype=vocab_parallel_logits.dtype, device=vocab_parallel_logits.device)
    else:
        logits_max = torch.amax(vocab_parallel_logits, dim=-1, keepdim=True)
    torch.distributed.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=group)
    if vocab_parallel_logits.shape[-1] == 0:
        exp_logits = torch.empty_like(vocab_parallel_logits)
        sum_exp = torch.zeros((*vocab_parallel_logits.shape[:-1], 1), dtype=vocab_parallel_logits.dtype, device=vocab_parallel_logits.device)
    else:
        exp_logits = (vocab_parallel_logits - logits_max).exp_()  # new buffer, then in place
        sum_exp = exp_logits.sum(dim=-1, keepdim=True)
    torch.distributed.all_reduce(sum_exp, op=torch.distributed.ReduceOp.SUM, group=group)
    return exp_logits.div_(sum_exp)


def _vocab_parallel_log_softmax(vocab_parallel_logits, group):
    """Global log-softmax over a TP-sharded vocab dim. Uses ``torch.logsumexp`` (a fused
    ``[.,.,V]->[.,.,1]`` reduction) so it never materializes a full-vocab ``exp`` temporary -- multiple
    GiB at a 248K vocab. Allocates one full-vocab output buffer; the input is not mutated."""
    if vocab_parallel_logits.shape[-1] == 0:
        logits_max = torch.full((*vocab_parallel_logits.shape[:-1], 1), float('-inf'), dtype=vocab_parallel_logits.dtype, device=vocab_parallel_logits.device)
    else:
        logits_max = torch.amax(vocab_parallel_logits, dim=-1, keepdim=True)
    torch.distributed.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=group)
    if vocab_parallel_logits.shape[-1] == 0:
        shifted = torch.empty_like(vocab_parallel_logits)
        local_sum_exp = torch.zeros((*vocab_parallel_logits.shape[:-1], 1), dtype=vocab_parallel_logits.dtype, device=vocab_parallel_logits.device)
    else:
        shifted = vocab_parallel_logits - logits_max  # new buffer (input untouched); values now <= 0
        # exp() of the *local* logsumexp recovers this shard's sum-of-exp, which reduces across TP.
        local_sum_exp = torch.logsumexp(shifted, dim=-1, keepdim=True).exp_()
    torch.distributed.all_reduce(local_sum_exp, op=torch.distributed.ReduceOp.SUM, group=group)
    return shifted.sub_(local_sum_exp.log_())  # in place on the owned buffer -> log-softmax
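
The neutral-value idea (-inf for the MAX reduction, 0 for the SUM reduction) can be checked without a GPU. The sketch below simulates the two all_reduce steps with plain Python lists, one list per rank; it is an illustration of the arithmetic only and does not exercise torch.distributed.

```python
import math


def sharded_softmax(shards):
    """Softmax over a vocab split across ranks; `shards` holds one list per rank."""
    # Step 1: local max per rank, -inf for an empty shard, then a simulated MAX all-reduce.
    local_max = [max(s) if s else float("-inf") for s in shards]
    global_max = max(local_max)
    # Step 2: local sum of exp, 0 for an empty shard, then a simulated SUM all-reduce.
    local_sum = [sum(math.exp(x - global_max) for x in s) for s in shards]
    global_sum = sum(local_sum)
    # Step 3: each rank normalizes only its own columns.
    return [[math.exp(x - global_max) / global_sum for x in s] for s in shards]


# Three ranks, the last one holding an empty shard.
probs = sharded_softmax([[1.0, 2.0], [3.0], []])
```

Because -inf never wins a MAX and 0 never changes a SUM, the empty rank takes part in both reductions without affecting the result, and its own output is simply empty.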

Comment on lines +253 to +309
    @staticmethod
    def forward(ctx, student_logits, teacher_logits, k, group, roll_shift):
        ws = torch.distributed.get_world_size(group) if group is not None else 1

        def ar(t, op):
            if ws > 1:
                torch.distributed.all_reduce(t, op=op, group=group)
            return t

        k_eff = min(int(k), teacher_logits.shape[-1])
        # Teacher's local top-k (teacher is detached). When roll_shift != 0, teacher_logits is the
        # UN-rolled policy logits: position t's draft target is the policy distribution at t+roll_shift,
        # so we top-k the policy logits once (no full-vocab copy) and roll only the small [B, S, k]
        # result -- avoiding a ~[S, vocab] rolled-teacher copy. The wrapped boundary positions are
        # zeroed by the caller's shifted loss mask.
        t_vals, t_idx = teacher_logits.topk(k_eff, dim=-1)
        if roll_shift:
            t_vals = torch.roll(t_vals, shifts=-int(roll_shift), dims=1)
            t_idx = torch.roll(t_idx, shifts=-int(roll_shift), dims=1)
        t_vals = t_vals.float()
        # student at position t, teacher's top-k indices for position t (already rolled).
        s_vals = student_logits.gather(-1, t_idx).float()

        # Stable-softmax shift = global max over the union (across the TP group).
        t_max = ar(t_vals.max(dim=-1, keepdim=True).values.clone(), torch.distributed.ReduceOp.MAX)
        s_max = ar(s_vals.max(dim=-1, keepdim=True).values.clone(), torch.distributed.ReduceOp.MAX)

        # Teacher probs over the union (denominator summed across the group).
        t_exp = (t_vals - t_max).exp()
        t_denom = ar(t_exp.sum(dim=-1, keepdim=True), torch.distributed.ReduceOp.SUM)
        t_p = t_exp / t_denom

        # Student probs / log-probs over the union (denominator summed across the group).
        s_exp = (s_vals - s_max).exp()
        s_denom = ar(s_exp.sum(dim=-1, keepdim=True), torch.distributed.ReduceOp.SUM)
        q_s = s_exp / s_denom
        s_logp = (s_vals - s_max) - s_denom.log()

        # Per-rank partial CE summed across the group -> soft CE over the global union.
        per_token_loss = ar(-(t_p * s_logp).sum(dim=-1), torch.distributed.ReduceOp.SUM)

        ctx.save_for_backward(q_s, t_p, t_idx)
        ctx.vocab_size = student_logits.shape[-1]
        ctx.input_dtype = student_logits.dtype
        return per_token_loss

    @staticmethod
    def backward(ctx, grad_output):
        q_s, t_p, t_idx = ctx.saved_tensors
        # d(H(p,q))/d(student_logit_v) = softmax(student)_v - softmax(teacher)_v, over the union; zero
        # elsewhere. Scatter the k per-token grads back to this rank's own vocab columns. The full-vocab
        # buffer is allocated directly in the input dtype (the only fp32->input cast is the tiny [.., k]
        # grad); an fp32 buffer here would double the largest transient of the whole loss.
        grad_topk = (q_s - t_p) * grad_output.unsqueeze(-1)
        grad_student = torch.zeros(*t_idx.shape[:-1], ctx.vocab_size, dtype=ctx.input_dtype, device=grad_topk.device)
        grad_student.scatter_(-1, t_idx, grad_topk.to(ctx.input_dtype))
        return grad_student, None, None, None, None

critical

Similar to the exact softmax, the top-k soft CE implementation can crash or deadlock if a rank has an empty local vocabulary shard (teacher_logits.shape[-1] == 0), leading to k_eff == 0. We should add a guard to handle k_eff == 0 gracefully by participating in the collective communication with safe dummy values and returning early.

    @staticmethod
    def forward(ctx, student_logits, teacher_logits, k, group, roll_shift):
        ws = torch.distributed.get_world_size(group) if group is not None else 1

        def ar(t, op):
            if ws > 1:
                torch.distributed.all_reduce(t, op=op, group=group)
            return t

        k_eff = min(int(k), teacher_logits.shape[-1])

        if k_eff == 0:
            lead = teacher_logits.shape[:-1]
            dev = teacher_logits.device
            t_idx = torch.empty((*lead, 0), dtype=torch.long, device=dev)
            t_p = torch.empty((*lead, 0), dtype=torch.float32, device=dev)
            q_s = torch.empty((*student_logits.shape[:-1], 0), dtype=torch.float32, device=student_logits.device)
            # The non-empty path below issues five all_reduce calls (t_max, s_max, t_denom, s_denom,
            # per_token_loss). An empty rank must issue the same five, in the same order and with the
            # same shapes/dtypes, using neutral values (-inf for MAX, 0 for SUM); a single all_reduce
            # here would leave the collectives mismatched across ranks.
            neg_inf = torch.full((*lead, 1), float('-inf'), dtype=torch.float32, device=dev)
            ar(neg_inf.clone(), torch.distributed.ReduceOp.MAX)  # pairs with t_max
            ar(neg_inf, torch.distributed.ReduceOp.MAX)  # pairs with s_max
            ar(torch.zeros((*lead, 1), dtype=torch.float32, device=dev), torch.distributed.ReduceOp.SUM)  # t_denom
            ar(torch.zeros((*lead, 1), dtype=torch.float32, device=dev), torch.distributed.ReduceOp.SUM)  # s_denom
            per_token_loss = torch.zeros(lead, dtype=torch.float32, device=dev)
            ar(per_token_loss, torch.distributed.ReduceOp.SUM)

            ctx.save_for_backward(q_s, t_p, t_idx)
            ctx.vocab_size = student_logits.shape[-1]
            ctx.input_dtype = student_logits.dtype
            return per_token_loss

        # Teacher's local top-k (teacher is detached). When roll_shift != 0, teacher_logits is the
        # UN-rolled policy logits: position t's draft target is the policy distribution at t+roll_shift,
        # so we top-k the policy logits once (no full-vocab copy) and roll only the small [B, S, k]
        # result -- avoiding a ~[S, vocab] rolled-teacher copy. The wrapped boundary positions are
        # zeroed by the caller's shifted loss mask.
        t_vals, t_idx = teacher_logits.topk(k_eff, dim=-1)
        if roll_shift:
            t_vals = torch.roll(t_vals, shifts=-int(roll_shift), dims=1)
            t_idx = torch.roll(t_idx, shifts=-int(roll_shift), dims=1)
        t_vals = t_vals.float()
        # student at position t, teacher's top-k indices for position t (already rolled).
        s_vals = student_logits.gather(-1, t_idx).float()

        # Stable-softmax shift = global max over the union (across the TP group).
        t_max = ar(t_vals.max(dim=-1, keepdim=True).values.clone(), torch.distributed.ReduceOp.MAX)
        s_max = ar(s_vals.max(dim=-1, keepdim=True).values.clone(), torch.distributed.ReduceOp.MAX)

        # Teacher probs over the union (denominator summed across the group).
        t_exp = (t_vals - t_max).exp()
        t_denom = ar(t_exp.sum(dim=-1, keepdim=True), torch.distributed.ReduceOp.SUM)
        t_p = t_exp / t_denom

        # Student probs / log-probs over the union (denominator summed across the group).
        s_exp = (s_vals - s_max).exp()
        s_denom = ar(s_exp.sum(dim=-1, keepdim=True), torch.distributed.ReduceOp.SUM)
        q_s = s_exp / s_denom
        s_logp = (s_vals - s_max) - s_denom.log()

        # Per-rank partial CE summed across the group -> soft CE over the global union.
        per_token_loss = ar(-(t_p * s_logp).sum(dim=-1), torch.distributed.ReduceOp.SUM)

        ctx.save_for_backward(q_s, t_p, t_idx)
        ctx.vocab_size = student_logits.shape[-1]
        ctx.input_dtype = student_logits.dtype
        return per_token_loss

    @staticmethod
    def backward(ctx, grad_output):
        q_s, t_p, t_idx = ctx.saved_tensors
        if t_idx.shape[-1] == 0:
            grad_student = torch.zeros(*t_idx.shape[:-1], ctx.vocab_size, dtype=ctx.input_dtype, device=t_idx.device)
            return grad_student, None, None, None, None
        # d(H(p,q))/d(student_logit_v) = softmax(student)_v - softmax(teacher)_v, over the union; zero
        # elsewhere. Scatter the k per-token grads back to this rank's own vocab columns. The full-vocab
        # buffer is allocated directly in the input dtype (the only fp32->input cast is the tiny [.., k]
        # grad); an fp32 buffer here would double the largest transient of the whole loss.
        grad_topk = (q_s - t_p) * grad_output.unsqueeze(-1)
        grad_student = torch.zeros(*t_idx.shape[:-1], ctx.vocab_size, dtype=ctx.input_dtype, device=grad_topk.device)
        grad_student.scatter_(-1, t_idx, grad_topk.to(ctx.input_dtype))
        return grad_student, None, None, None, None
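
The backward pass relies on the identity that the gradient of the soft cross-entropy with respect to the student logits is softmax(student) minus the teacher probabilities over the selected columns. The sketch below checks that identity numerically in plain Python on a single rank with a small made-up candidate set; the helper names are illustrative and it is not a test of the autograd function itself.

```python
import math


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def soft_ce(student_logits, teacher_p):
    """H(p, q) over the selected columns, with q = softmax(student)."""
    q = softmax(student_logits)
    return -sum(p * math.log(qv) for p, qv in zip(teacher_p, q))


def analytic_grad(student_logits, teacher_p):
    # Closed form used in the backward pass: q - p.
    q = softmax(student_logits)
    return [qv - p for qv, p in zip(q, teacher_p)]


def numeric_grad(student_logits, teacher_p, eps=1e-6):
    # Central finite differences, one logit at a time.
    grads = []
    for i in range(len(student_logits)):
        up = list(student_logits)
        down = list(student_logits)
        up[i] += eps
        down[i] -= eps
        grads.append((soft_ce(up, teacher_p) - soft_ce(down, teacher_p)) / (2 * eps))
    return grads


teacher_p = softmax([2.0, 0.5, -1.0, 0.0])  # teacher probs over k = 4 candidates
student = [0.3, -0.2, 1.1, 0.0]             # student logits gathered at the same indices
grad = analytic_grad(student, teacher_p)
max_err = max(abs(a - n) for a, n in zip(grad, numeric_grad(student, teacher_p)))
```

The gradient entries sum to zero because both distributions sum to one, which is a cheap sanity check on any implementation of this loss.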

Comment on lines +59 to +74
    # Packed: roll each [start, end) segment independently, zeroing its last `shift` positions.
    bounds = cu_seqlens.tolist() if torch.is_tensor(cu_seqlens) else list(cu_seqlens)
    rolled = torch.zeros_like(mask)
    for i in range(len(bounds) - 1):
        start, end = int(bounds[i]), int(bounds[i + 1])
        seg_len = end - start
        if seg_len <= 0:
            continue
        seg_rolled = torch.roll(mask[:, start:end], shifts=-shift, dims=1)
        # Zero the wrapped tail; the whole segment if it is no longer than the shift.
        if shift < seg_len:
            seg_rolled[:, -shift:] = 0
        else:
            seg_rolled.zero_()
        rolled[:, start:end] = seg_rolled
    return mask * rolled

high

The current implementation of shift_mask_for_mtp loops over each segment in Python when sample packing is enabled (cu_seqlens is not None). For batches with many packed sequences, this Python loop launches multiple CUDA kernels per segment, leading to significant CPU overhead and GPU starvation.

We can fully vectorize this operation using torch.bucketize to determine segment boundaries, avoiding any Python loops and drastically improving performance.

Suggested change
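    # Packed: roll each [start, end) segment independently, zeroing its last `shift` positions.
    positions = torch.arange(mask.shape[1], device=mask.device)
    # as_tensor also accepts a plain list, which the loop version supports.
    cu_seqlens_dev = torch.as_tensor(cu_seqlens, dtype=torch.long, device=mask.device)
    # right=True assigns position p to the segment with cu_seqlens[i-1] <= p < cu_seqlens[i];
    # the default (right=False) would put each segment's first token in the previous bucket.
    segment_ids = torch.bucketize(positions, cu_seqlens_dev, right=True)
    shifted_positions = (positions + shift).clamp(max=mask.shape[1] - 1)
    boundary_mask = (segment_ids == segment_ids[shifted_positions]) & ((positions + shift) < mask.shape[1])
    rolled = torch.roll(mask, shifts=-shift, dims=1)
    rolled[:, ~boundary_mask] = 0
    return mask * rolled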
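
Which side of a boundary a position falls on matters for any bucket-based rewrite of the loop. torch.bucketize with its default right=False follows the same convention as Python's bisect_left, and right=True follows bisect_right, so the two conventions can be compared against the loop semantics with the standard library alone. The sketch below works on a single 1-D mask as a plain list; the function names are illustrative.

```python
from bisect import bisect_left, bisect_right


def shift_mask_loop(mask, cu_seqlens, shift):
    """Reference: per-segment shift, zeroing each segment's last `shift` positions."""
    rolled = [0] * len(mask)
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        for p in range(start, end):
            q = p + shift
            rolled[p] = mask[q] if q < end else 0
    return [m * r for m, r in zip(mask, rolled)]


def shift_mask_bucketed(mask, cu_seqlens, shift, side):
    """Bucket-based variant; `side` picks the boundary convention."""
    n = len(mask)
    seg = [side(cu_seqlens, p) for p in range(n)]
    out = []
    for p in range(n):
        q = p + shift
        same_segment = q < n and seg[q] == seg[p]
        out.append(mask[p] * mask[q] if same_segment else 0)
    return out


mask = [1, 1, 1, 1, 1]
cu_seqlens = [0, 3, 5]  # two packed segments: [0, 3) and [3, 5)

reference = shift_mask_loop(mask, cu_seqlens, shift=1)
right_side = shift_mask_bucketed(mask, cu_seqlens, 1, bisect_right)  # like right=True
left_side = shift_mask_bucketed(mask, cu_seqlens, 1, bisect_left)    # like right=False
```

On this input only the right-sided convention reproduces the loop: with the left-sided one, position 3 (the first token of the second segment) lands in the first segment's bucket, so position 2 is wrongly allowed to look across the boundary.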

Comment on lines +342 to +363
def _onehot_vp_logits(
    labels: torch.Tensor,
    like: torch.Tensor,
    vocab_start_index: int,
) -> torch.Tensor:
    """Build vocab-parallel logits whose global softmax is a one-hot over ``labels``.

    Each rank holds ``vocab_size`` columns starting at ``vocab_start_index``. The
    column matching the label (on whichever rank owns it) is set high and all
    others low, so a *global* softmax across the TP group recovers the one-hot
    distribution. Reused to express hard cross-entropy as soft cross-entropy with
    a one-hot teacher.
    """
    vocab_size = like.shape[-1]
    local_idx = labels.long() - vocab_start_index  # [batch, seq]
    holds = (local_idx >= 0) & (local_idx < vocab_size)
    onehot = torch.full_like(like, -30.0)
    safe_idx = local_idx.clamp(0, vocab_size - 1).unsqueeze(-1)
    hot = torch.where(holds.unsqueeze(-1), 30.0, -30.0).to(like.dtype)
    onehot.scatter_(-1, safe_idx, hot)
    return onehot

high

If vocab_size is 0 (which happens on ranks with empty vocabulary shards), local_idx.clamp(0, vocab_size - 1) clamps to the inverted range [0, -1], so every index becomes -1 and the following scatter_ into a zero-width dimension will raise an error. We should add a guard to return early when vocab_size == 0.

def _onehot_vp_logits(
    labels: torch.Tensor,
    like: torch.Tensor,
    vocab_start_index: int,
) -> torch.Tensor:
    """Build vocab-parallel logits whose global softmax is a one-hot over ``labels``.

    Each rank holds ``vocab_size`` columns starting at ``vocab_start_index``. The
    column matching the label (on whichever rank owns it) is set high and all
    others low, so a *global* softmax across the TP group recovers the one-hot
    distribution. Reused to express hard cross-entropy as soft cross-entropy with
    a one-hot teacher.
    """
    vocab_size = like.shape[-1]
    if vocab_size == 0:
        return torch.empty_like(like)
    local_idx = labels.long() - vocab_start_index  # [batch, seq]
    holds = (local_idx >= 0) & (local_idx < vocab_size)
    onehot = torch.full_like(like, -30.0)
    safe_idx = local_idx.clamp(0, vocab_size - 1).unsqueeze(-1)
    hot = torch.where(holds.unsqueeze(-1), 30.0, -30.0).to(like.dtype)
    onehot.scatter_(-1, safe_idx, hot)
    return onehot
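
To see why logits of +30 and -30 are enough to stand in for a one-hot teacher, the sketch below builds per-rank shards in plain Python, concatenates them to mimic the global softmax, and compares the resulting soft cross-entropy with ordinary hard cross-entropy. The helper names and the two-rank, eight-token vocabulary are made up for the example.

```python
import math

HIGH, LOW = 30.0, -30.0


def onehot_shard(label, vocab_start, local_vocab):
    """Logits for one rank's columns: HIGH on the label's column if owned, LOW elsewhere."""
    return [HIGH if vocab_start + j == label else LOW for j in range(local_vocab)]


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


label = 5
shards = [onehot_shard(label, start, 4) for start in (0, 4)]  # two ranks, 4 columns each
teacher_p = softmax(shards[0] + shards[1])                    # global softmax over 8 columns

student = [0.1, -0.4, 0.7, 0.0, 1.2, 0.5, -1.0, 0.3]
student_logp = [math.log(q) for q in softmax(student)]
soft = -sum(p * lq for p, lq in zip(teacher_p, student_logp))
hard = -student_logp[label]
```

Each non-label column ends up with probability on the order of exp(-60), so the soft and hard losses agree to well within floating-point noise.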

Comment on lines +35 to +42
    drafter = getattr(model_runner, "drafter", None)
    drafter_model = getattr(drafter, "model", None)
    if drafter_model is None or not hasattr(drafter_model, "load_weights"):
        # No spec decoding, or a drafter without a weight-loadable model (e.g. ngram).
        return False
    weights: Iterable[Tuple[str, torch.Tensor]] = iter(weight_list)
    drafter_model.load_weights(weights)
    return True

medium

Converting weight_list to a one-shot iterator using iter(weight_list) before passing it to drafter_model.load_weights can cause issues if the underlying load_weights implementation needs to iterate over the weights multiple times (e.g., to load different parameter groups or perform multi-pass validation). Since weight_list is already a list (which is a reusable Iterable), we should pass it directly to avoid exhausting the iterator.

    drafter = getattr(model_runner, "drafter", None)
    drafter_model = getattr(drafter, "model", None)
    if drafter_model is None or not hasattr(drafter_model, "load_weights"):
        # No spec decoding, or a drafter without a weight-loadable model (e.g. ngram).
        return False
    drafter_model.load_weights(weight_list)
    return True
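
The difference between a list and a one-shot iterator is easy to reproduce in isolation. The loader below is a hypothetical two-pass function written for this example; whether any particular drafter's load_weights actually iterates more than once is not shown in this thread.

```python
def load_two_pass(weights):
    """Hypothetical loader that walks the weights twice."""
    names = [name for name, _ in weights]                 # pass 1: collect names
    loaded = {name: tensor for name, tensor in weights}   # pass 2: load values
    return names, loaded


weight_list = [("mtp.proj.weight", 1), ("mtp.norm.weight", 2)]

names_from_list, loaded_from_list = load_two_pass(weight_list)
names_from_iter, loaded_from_iter = load_two_pass(iter(weight_list))
```

With the list, both passes see every entry. With `iter(weight_list)`, the first pass drains the iterator and the second pass silently loads nothing, which is the failure mode the comment warns about.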

Resolved conflicts:
- utils.py: union both env-var forwarding blocks (mtp UV/debug vars + main's inference-server
  health-check timeout)
- evaluate.py: keep the eval-timing comment
- test_vllm_metrics_scraper.py: drop the stale sample_split/mark_pre_eval tests -- mtp's scraper
  now uses the start/pause/resume window API (consistent with main's canonical scraper)

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@erictang000 erictang000 mentioned this pull request Jun 29, 2026
33 tasks