
[megatron] Enable Nemotron-3-Ultra-550B GRPO RL + fix multi-rank (EP>16/PP>2) weight sync #1816

Open
erictang000 wants to merge 5 commits into NovaSky-AI:main from erictang000:nemotron-3-ultra-550b-rl


@erictang000

Collaborator

What

Enables end-to-end full-finetuning GRPO RL of NVIDIA-Nemotron-3-Ultra-550B-A55B (NemotronH hybrid Mamba2 + attention, latent MoE with 512 experts, reasoning model) colocated with vLLM on 8 nodes of 8×H200 (64 GPUs, EFA), and fixes the weight-sync/reload correctness bugs that block it (and other large MoE models).

Validated: trains end-to-end with the included recipe — avg_raw_reward ≈ 0.9, GSM8K eval ≈ 0.94, grad_norm > 0. Megatron mesh TP8 / PP4 / EP16 / ETP1 (DP2); vLLM TP8 × PP4.

Replication guide: examples/train/megatron/README_nemotron_ultra.md.

Why (root cause)

vLLM produced coherent-looking garbage after every weight sync → all rewards 0 → no learning. The bridge export was proven bit-correct (0 mismatches over 108k tensors), which localized the bug to the CUDA-IPC weight transport:

Each Megatron rank packs its own contiguous buffer (params and order differ per rank, since expert chunks carry per-EP-rank names) and registers one IPC handle per physical GPU, but only rank-0's slicing metadata (names/sizes/shapes) was sent. Each vLLM worker rebuilt its own GPU's buffer yet sliced it with rank-0's metadata → correct bytes loaded under the wrong names → coherent-but-garbage output, no crash. The layout was identical across PP ranks at PP=2 (so it worked there) but diverges at PP>2 / EP>16.
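To make the failure mode concrete, here is a minimal pure-Python sketch of the same pattern. It is not the code in cuda_ipc_strategy.py: `pack`, `unpack`, and the toy parameter names are hypothetical, and plain lists stand in for the packed GPU buffers.

```python
from math import prod


def pack(params):
    """Flatten {name: (shape, values)} in insertion order into one buffer.

    Returns the flat buffer plus the slicing metadata needed to undo it.
    """
    buffer, meta = [], []
    for name, (shape, values) in params.items():
        assert len(values) == prod(shape)
        buffer.extend(values)
        meta.append((name, shape))
    return buffer, meta


def unpack(buffer, meta):
    """Slice a flat buffer back into named params using the given metadata."""
    out, offset = {}, 0
    for name, shape in meta:
        n = prod(shape)
        out[name] = buffer[offset:offset + n]
        offset += n
    return out


# Two ranks hold different experts, packed in different orders (toy data).
rank0 = {"experts.0.w": ((2,), [1.0, 1.0]), "experts.1.w": ((2,), [2.0, 2.0])}
rank1 = {"experts.3.w": ((2,), [4.0, 4.0]), "experts.2.w": ((2,), [3.0, 3.0])}

buf0, meta0 = pack(rank0)
buf1, meta1 = pack(rank1)

# Bug: rank 1's buffer sliced with rank 0's metadata. Sizes line up, so
# nothing crashes, but the values land under the wrong names.
buggy = unpack(buf1, meta0)

# Fix: every buffer is sliced with the metadata of the rank that packed it.
fixed = unpack(buf1, meta1)
```

In this toy case `buggy` contains rank 1's expert values under rank 0's expert names, while `fixed` recovers the names the values were packed under.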

This is general to any MoE synced at EP>16/PP>2, not just Nemotron.

Changes

  • Weight-sync transport fix (core): send per-GPU slicing metadata; each vLLM worker slices its own buffer with its own metadata. Verified: EP16/PP4 post-sync logprob diff 2.0 → 0.15. (cuda_ipc_strategy.py, new_inference_worker_wrap.py)
  • fp32 MoE router bias (gate.e_score_correction_bias) preserved through sync — bf16 ULP at its ~25–57 magnitude collapses the tiny per-expert offsets (std ~7e-4) and corrupts routing. (megatron_worker.py)
  • vLLM layerwise-reload guard (cf. vllm-project/vllm#44814, "[Bugfix] Fix layerwise reload dropping params after a composed weight loader"): the composed_weight_loader double-count finalizes a layer early and drops Mamba mixer.D (uninitialized → NaN). Guarded monkeypatch mirroring the existing conv_weights workaround; remove once running a vLLM release that includes #44814. (layerwise_reload.py)
  • Worker env forwarding: HF_*/cache dirs and SKYRL_WAIT_UNTIL_INFERENCE_SERVER_HEALTHY_TIMEOUT_S → Ray worker actors. (utils.py, GPU-CI conftest.py)
  • Reasoning-aware GSM8K reward: strip <think>, score strict #### <n> else last-number with comma/$ normalization. (skyrl_gym/envs/gsm8k/env.py)
  • Example + tooling: run_megatron_nemotron_ultra.sh, stage_nemotron_ultra.py (model+data staging incl. chat_template.jinja), README_nemotron_ultra.md.
  • Tests: Nemotron-Ultra logprob round-trip params — EP16/PP2 baseline (passes) + EP32/PP4 (reproduced the bug pre-fix; PP4 now passes post-fix). All h100-marked / gated on the 550B model.
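The router-bias item above can be checked with a few lines of arithmetic. bf16 keeps 8 significant bits, so representable values are spaced 0.125 apart in [16, 32) and 0.25 apart in [32, 64), far coarser than offsets of order 7e-4. The sketch below emulates bf16 rounding with the standard library only; `to_bf16` and the sample bias values are illustrative assumptions, not code or data from this PR, and NaN/overflow handling is deliberately omitted.

```python
import struct


def to_bf16(x: float) -> float:
    """Round a float to bfloat16 (round-to-nearest-even) and back.

    bf16 is the top 16 bits of an IEEE float32, so we round the float32
    bit pattern to a multiple of 2**16. NaN and overflow are not handled.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (bits + rounding_bias) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


# Hypothetical per-expert biases: a large common value plus tiny offsets.
biases = [40.0, 40.0007, 40.0014, 39.9993]
rounded = [to_bf16(b) for b in biases]

# Every entry rounds to 40.0, so the per-expert differences are gone.
distinct_before = len(set(biases))
distinct_after = len(set(rounded))
```

Here `distinct_before` is 4 and `distinct_after` is 1, which is the collapse the fp32 sync path avoids.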

Testing

  • test_logprobs_matching_roundtrip[nemotron3-ultra…] round-trip on 8×8 H200 (EP16/PP2 and EP16/PP4 pass; pre-sync Megatron↔vLLM ~0.06, post-sync ~0.15–0.30).
  • Full GRPO training run: reward ~0.9, GSM8K eval ~0.94, non-zero grad_norm.

Note

The get_numel_loaded and conv_weights reload patches are temporary vLLM workarounds (pending #44814 and #42481); the per-GPU-metadata transport fix and the rest are permanent.

🤖 Generated with Claude Code

…ht sync

Adds an end-to-end full-finetuning GRPO recipe for NVIDIA-Nemotron-3-Ultra-550B
(NemotronH hybrid Mamba2+attention, latent MoE, reasoning) colocated with vLLM on
8x8 H200 (EFA), plus the weight-sync/reload correctness fixes it depends on.
Validated: avg_raw_reward ~0.9, GSM8K eval ~0.94, grad_norm > 0.

Core fix (general, affects any MoE synced at EP>16 / PP>2): the CUDA-IPC weight
transport sent only rank-0's per-param slicing metadata, but each Megatron rank packs
its own (per-rank-divergent) buffer. Each vLLM worker rebuilt its own GPU's buffer yet
sliced it with rank-0's metadata -> correct bytes loaded under wrong names -> coherent-
but-garbage generations and reward stuck at 0. Now sends per-GPU metadata and each
worker slices its own buffer (cuda_ipc_strategy.py, new_inference_worker_wrap.py).

Also included:
- Preserve fp32 for the MoE router bias (gate.e_score_correction_bias) through sync;
  bf16 ULP at its ~25-57 magnitude collapses the per-expert offsets and corrupts routing.
- Guard vLLM layerwise-reload get_numel_loaded (cf. vllm-project/vllm#44814): the
  composed-weight-loader double-count finalizes a layer early and drops Mamba mixer.D
  (uninitialized -> NaN). Mirrors the existing conv_weights reload workaround.
- Forward HF_*/cache dirs and SKYRL_WAIT_UNTIL_INFERENCE_SERVER_HEALTHY_TIMEOUT_S to
  Ray worker actors (prepare_runtime_environment + GPU-CI conftest).
- Reasoning-aware GSM8K reward: strip the <think> trace, score strict `#### <n>` else
  last-number with comma/$ normalization.
- Nemotron-Ultra logprob round-trip test params (EP16/PP2 baseline; EP32/PP4 regressions).
- Example recipe (run_megatron_nemotron_ultra.sh), staging helper, and README.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor


Code Review

This pull request adds support for training the Nemotron-3-Ultra-550B model using GRPO RL on GSM8K with Megatron and vLLM. Key changes include a new multi-node launch script, a staging script, robust GSM8K reward parsing for reasoning models, and critical fixes for CUDA-IPC weight synchronization, vLLM layerwise-reload, and native fp32 precision syncing for MoE router biases. The review feedback suggests improving the number normalization function to handle float representations of integers, simplifying regex group extraction in the staging script, and double-quoting the arguments in the shell script to prevent word splitting.


Comment on lines +10 to +12
def _norm_num(s: str) -> str:
    """Normalize a parsed number to compare against the (comma-free integer) ground truth."""
    return s.strip().rstrip(".").replace(",", "").replace("$", "")


medium

The current implementation of _norm_num does not normalize float representations of integers (e.g., '72.0' or '72.'). If the model outputs a decimal representation of an integer, the comparison against the ground truth (which is typically a clean integer string like '72') will fail, leading to a reward of 0.0 instead of 1.0.

We can make this more robust by attempting to parse the value as a float and checking if it represents an integer.

Suggested change
- def _norm_num(s: str) -> str:
-     """Normalize a parsed number to compare against the (comma-free integer) ground truth."""
-     return s.strip().rstrip(".").replace(",", "").replace("$", "")
+ def _norm_num(s: str) -> str:
+     """Normalize a parsed number to compare against the (comma-free integer) ground truth."""
+     val = s.strip().rstrip(".").replace(",", "").replace("$", "")
+     try:
+         f_val = float(val)
+         if f_val.is_integer():
+             return str(int(f_val))
+     except ValueError:
+         pass
+     return val
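For context on where this normalization sits, the following is a rough sketch of the reasoning-aware scoring the PR description outlines (strip the `<think>` trace, prefer a strict `#### <n>` match, otherwise fall back to the last number). It is not the code in skyrl_gym/envs/gsm8k/env.py: `score`, the fallback regex, and the assumption that the trace is wrapped in a closed `<think>...</think>` pair are all illustrative.

```python
import re


def norm_num(s: str) -> str:
    """Drop whitespace, a trailing period, commas and dollar signs."""
    return s.strip().rstrip(".").replace(",", "").replace("$", "")


def score(response: str, ground_truth: str) -> float:
    """Return 1.0 if the final answer matches the ground truth, else 0.0."""
    # Remove the reasoning trace so numbers inside it are never scored.
    answer = re.sub(r"<think>.*?</think>", "", response, flags=re.DOTALL)

    # Strict format first: "#### <number>".
    strict = re.search(r"#### (\-?[0-9\.\,]+)", answer)
    if strict:
        candidate = strict.group(1)
    else:
        # Fallback: the last number-like token in the visible answer.
        numbers = re.findall(r"\-?\$?[0-9][0-9\.\,]*", answer)
        if not numbers:
            return 0.0
        candidate = numbers[-1]

    return 1.0 if norm_num(candidate) == ground_truth else 0.0


strict_hit = score("<think>maybe 70?</think>The answer is #### 72", "72")
fallback_hit = score("<think>hmm</think>She earns $1,234.", "1234")
trace_only = score("<think>it is 72</think>I am not sure.", "72")
```

The third call scores 0.0 because the only number appears inside the stripped trace.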


def to_row(example, idx, split):
    q = example["question"]
    sol = re.search(r"#### (\-?[0-9\.\,]+)", example["answer"]).group(0).split("#### ")[1].replace(",", "")


medium

Using group(0) followed by .split('#### ')[1] is redundant; accessing the captured group directly with group(1) gives the same string. Additionally, checking the match before using it prevents a potential AttributeError if re.search ever returns None.

            match = re.search(r"#### (\-?[0-9\.\,]+)", example["answer"])
            sol = match.group(1).replace(",", "") if match else ""
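A quick side-by-side check of the two forms, using a made-up GSM8K-style answer string; `extract_original` and `extract_safe` are hypothetical wrappers around the two expressions above, not functions from the staging script.

```python
import re

PATTERN = r"#### (\-?[0-9\.\,]+)"


def extract_original(answer: str) -> str:
    """The form in the diff: full match, then split off the marker."""
    return re.search(PATTERN, answer).group(0).split("#### ")[1].replace(",", "")


def extract_safe(answer: str) -> str:
    """The suggested form: captured group, with a None check."""
    match = re.search(PATTERN, answer)
    return match.group(1).replace(",", "") if match else ""


sample = "Natalia sold 48 + 24 = 72 clips, then 1,000 more.\n#### 1,072"

same_result = extract_original(sample) == extract_safe(sample)
missing_marker = extract_safe("no final answer marker here")
```

Both forms agree on well-formed input; they differ only when the marker is absent, where the original raises AttributeError and the safe form returns an empty string.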

trainer.max_ckpts_to_keep=3 \
trainer.ckpt_interval=20 \
trainer.ckpt_path="$HOME/ckpts/gsm8k_nemotron_ultra_ckpt" \
$@


medium

In shell scripts, it is highly recommended to double-quote "$@" to preserve arguments containing spaces or special characters and prevent word splitting.

Suggested change
- $@
+ "$@"

erictang000 and others added 2 commits June 20, 2026 12:43
…dings

Adds a sweep harness (real Megatron fwd+bwd on fabricated rollouts, no vLLM
generation) to map, on 64xH200, the max max_tokens_per_microbatch and the
parallelism (TP/PP/CP/EP/DP) that maximizes full-FT GRPO training throughput
for NVIDIA-Nemotron-3-Ultra-550B, plus a long-context (variable-length) study.

Findings (examples/train/megatron/NEMOTRON_ULTRA_THROUGHPUT.md):
- Max MTPM ~= 64k tokens/microbatch at the validated TP8/PP4/EP16/DP2 config.
- Highest throughput: TP8/PP2/EP32/DP4 (~8.5k tok/s, +11% over PP4/DP2) for
  short/medium seqs; config space is pinned (PP8 invalid for 108 layers, EP8
  OOMs, TP4 doubles activations via sequence parallelism).
- Long context: single-sequence ceiling ~40-48k tokens (CP1/PP4/DP2); CP gives
  little net benefit (CP=2 forces PP2 whose 2x weights cancel the savings). Long
  seqs are more throughput-efficient per token (~12k tok/s at ~39k mean).

New files:
- examples/train_scripts/full_context/{trainer_ultra_sweep,main_ultra_sweep,analyze_sweep}.py
- examples/train/megatron/run_ultra_sweep.sh
- examples/train/megatron/NEMOTRON_ULTRA_THROUGHPUT.md

worker.py: get_cuda_memory() now also returns max_allocated/max_reserved
high-water marks (capture in-step peak even when queried after offload).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

Follow-up to the throughput sweep. CP composes with EP in Megatron-Core
(EP divides TP*CP*DP, with ETP=1), so adding CP does NOT force EP down.
Measured:
- TP8/PP4/CP2/EP16/DP1 fits a single 96k sequence (128k OOMs) -- CP2 roughly
  doubles the single-sequence ceiling (~40-48k -> ~96k) while keeping PP4's
  low weights and baseline expert memory. Best long-context config.
- TP8/PP2/CP4/EP32/DP1 is valid and loads but still OOMs at 128k: dropping to
  PP2 to free GPUs for CP4 doubles the weights and eats the budget CP frees.
So the 60k+-30k distribution is mostly trainable with PP4/CP2 (clamp ~96k,
~10% truncated) at the cost of DP->1; the full 131k tail still OOMs.
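The config arithmetic in these commit messages can be sanity-checked with a small script. This is only a sketch: the world size (64) and layer count (108) come from the text above, the EP rule is the one stated in this commit ("EP divides TP*CP*DP, with ETP=1"), and requiring the layer count to divide evenly by PP is a simplifying assumption inferred from "PP8 invalid for 108 layers". `Mesh` and `problems` are hypothetical names, and memory fit (the OOM cases) is not modelled.

```python
from dataclasses import dataclass

WORLD_SIZE = 64   # 8 nodes x 8 H200, from the PR description
NUM_LAYERS = 108  # from "PP8 invalid for 108 layers"


@dataclass(frozen=True)
class Mesh:
    tp: int
    pp: int
    ep: int
    dp: int
    cp: int = 1


def problems(m: Mesh) -> list:
    """Return the list of violated constraints (empty means shape-valid)."""
    issues = []
    if m.tp * m.pp * m.cp * m.dp != WORLD_SIZE:
        issues.append("TP*PP*CP*DP != world size")
    if (m.tp * m.cp * m.dp) % m.ep != 0:
        issues.append("EP does not divide TP*CP*DP")
    if NUM_LAYERS % m.pp != 0:
        issues.append("layers not divisible by PP")
    return issues


validated = problems(Mesh(tp=8, pp=4, ep=16, dp=2))         # TP8/PP4/EP16/DP2
fastest = problems(Mesh(tp=8, pp=2, ep=32, dp=4))           # TP8/PP2/EP32/DP4
long_ctx = problems(Mesh(tp=8, pp=4, ep=16, dp=1, cp=2))    # TP8/PP4/CP2/EP16/DP1
cp4 = problems(Mesh(tp=8, pp=2, ep=32, dp=1, cp=4))         # TP8/PP2/CP4/EP32/DP1
pp8 = problems(Mesh(tp=8, pp=8, ep=8, dp=1))                # PP8 case
```

Under these rules the four measured configs are shape-valid and only the PP8 mesh is rejected, because 108 is not a multiple of 8.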

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@dyurk-lila

dyurk-lila commented Jun 25, 2026


Heads up — I think this PR fixes the CUDA-IPC weight-sync path but leaves a sibling bug on the non-colocated NCCL broadcast path (colocate_all=false, weight_sync_backend=nccl), with the same coherent-but-garbage / reward-0 symptom.

The new extract_weights() regroups each bucket's params by output dtype — it yields one bf16 chunk, then a separate fp32 chunk for the routing-critical params (e_score_correction_bias via _FP32_SYNC_SUFFIXES / _sync_dtype_for). But get_weight_metadata() (same class) still appends names/dtype_names/shapes in raw export_hf_weights order, with no dtype regrouping — despite its comment claiming "same order as extract_weights".

On the broadcast new-inference path (broadcast_strategy._send_chunks_vllm_native → NCCLWeightTransferEngine.trainer_send_weights), the producer packs the byte stream in extract_weights() order, but the receiver slices it using get_weight_metadata() order. On any dtype-mixed bucket (an fp32 router bias interleaved among bf16 params, which happens with the default 1.0 GB bucket threshold because the router bias is a regular non-grouped task that shares a bucket with bf16 params), the two orders desync at the first dtype switch, and every param after it loads under the wrong name. The failure is silent: byte totals still match, so nothing crashes.

CUDA-IPC is immune because it derives metadata per-chunk directly from each already-regrouped chunk (cuda_ipc_strategy.py), so it never reads the separately-built flat metadata. That's why this PR's IPC fix doesn't cover it.

Fix is to mirror the same per-bucket dtype regrouping in get_weight_metadata:

# in the bucketing branch of get_weight_metadata, replacing the raw-order append:
for index_group in self.bucket_index_groups:
    bucket_tasks = [fresh_tasks[i] for i in index_group]
    # Mirror extract_weights' per-bucket dtype regrouping EXACTLY: the NCCL
    # broadcast path packs the byte stream in extract_weights() chunk order but
    # slices it on the receiver using this metadata's order, so the two must
    # agree positionally. Raw export order desyncs on any dtype-mixed bucket.
    groups = {}  # out_dtype -> (names, dtype_names, shapes)
    for name, tensor in self.bridge.export_hf_weights(
        self.actor_module, show_progress=False, conversion_tasks=bucket_tasks,
    ):
        out_dtype = _sync_dtype_for(name, dtype, tensor.dtype)
        g = groups.setdefault(out_dtype, ([], [], []))
        g[0].append(name)
        g[1].append(str(out_dtype).split(".")[-1])
        g[2].append(list(tensor.shape))
        del tensor
    for g_names, g_dtype_names, g_shapes in groups.values():
        names.extend(g_names); dtype_names.extend(g_dtype_names); shapes.extend(g_shapes)

Verified: with the IPC fix from this PR but the broadcast path unpatched, a non-colocated GRPO run on Nemotron-3-Ultra-550B at EP16/PP4 reproduces the garbage-generation/reward-0 symptom; the regrouping above resolves it. Happy to send a follow-up PR if useful.
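A toy reproduction of the ordering mismatch, independent of Megatron and vLLM. The bucket contents, `regroup_by_dtype`, and the byte sizes are illustrative assumptions; only the grouping idea (regroup a bucket by output dtype, keeping first-seen dtype order) follows the description above.

```python
BYTES = {"bf16": 2, "fp32": 4}


def regroup_by_dtype(params):
    """Group (name, dtype, numel) triples by dtype, first-seen dtype first."""
    groups = {}
    for p in params:
        groups.setdefault(p[1], []).append(p)
    return [p for group in groups.values() for p in group]


# One dtype-mixed bucket in raw export order (hypothetical names).
bucket = [
    ("layers.0.q_proj.weight", "bf16", 4),
    ("layers.0.gate.e_score_correction_bias", "fp32", 2),
    ("layers.0.o_proj.weight", "bf16", 4),
]

# Producer side: the byte stream follows the regrouped order.
stream_order = [name for name, _, _ in regroup_by_dtype(bucket)]
# Receiver side (unpatched): metadata still follows raw export order.
raw_order = [name for name, _, _ in bucket]

# Byte totals are identical, so no size check can catch the problem.
stream_bytes = sum(BYTES[d] * n for _, d, n in regroup_by_dtype(bucket))
raw_bytes = sum(BYTES[d] * n for _, d, n in bucket)

# The first position where the two orders disagree.
first_mismatch = next(
    i for i, (a, b) in enumerate(zip(stream_order, raw_order)) if a != b
)
```

In this bucket the orders diverge at index 1, the first dtype switch, while both sides agree on 24 total bytes.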


Update (corrected fix — the mirror approach above is fragile):

Shipping the get_weight_metadata mirror exactly as written above fixed the garbage-generation symptom on smaller models, but on Nemotron-3-Ultra-550B at EP16/PP4 it was not enough: the run progressed past init and rollouts, then deadlocked in the first weight sync. An NCCL collective hung at a deep SeqNum (the EP16 expert ALLGATHER, NumelOut = 16 × NumelIn) until the 1,800,000 ms watchdog fired. Thousands of collectives completed first, so it's a positional desync at one specific tensor, not a rendezvous failure. Stack: broadcast_strategy._send_chunks_vllm_native → weight_iterator → broadcast_to_inference_engines.

The lesson: keeping two independently-authored copies of the per-bucket dtype regrouping (one in get_weight_metadata, one in extract_weights) is brittle — they're equal only by inspection, and each still re-runs export_hf_weights separately, so any drift in iteration/grouping between the two passes desyncs the broadcast byte stream from the receiver's slice plan. On a hybrid MoE with many dtype-mixed expert/router buckets, that drift turns a silent garbage bug into a hard hang.

The robust fix is to make both passes consume one chunk-plan generator, so the chunk count/order is identical by construction:

class MegatronWeightExtractor(WeightExtractor):
    def _iter_chunk_plan(self, dtype, *, materialize):
        """Single source of truth for the per-(bucket, out-dtype) chunk order.
        get_weight_metadata() calls with materialize=False (discards tensors so
        it never holds the full model on rank 0); extract_weights() calls with
        materialize=True. Both walk the SAME grouping, so the broadcast byte
        stream and the receiver's slice plan cannot desync."""
        self._ensure_buckets_initialized()
        device = torch.cuda.current_device()
        fresh_tasks = self.bridge.get_conversion_tasks(self.actor_module)
        for index_group in self.bucket_index_groups:
            bucket_tasks = [fresh_tasks[i] for i in index_group]
            groups = {}  # out_dtype -> (names, dtypes, shapes, tensors)
            for name, tensor in self.bridge.export_hf_weights(
                self.actor_module, show_progress=False, conversion_tasks=bucket_tasks,
            ):
                out_dtype = _sync_dtype_for(name, dtype, tensor.dtype)
                t = tensor.to(device=device, dtype=out_dtype, non_blocking=True) if materialize else None
                g = groups.setdefault(out_dtype, ([], [], [], []))
                g[0].append(name); g[1].append(out_dtype); g[2].append(list(tensor.shape)); g[3].append(t)
                if not materialize:
                    del tensor
            for names, dtypes, shapes, tensors in groups.values():
                yield names, dtypes, shapes, tensors

    def get_weight_metadata(self, dtype):
        names, dtype_names, shapes = [], [], []
        for c_names, c_dtypes, c_shapes, _ in self._iter_chunk_plan(dtype, materialize=False):
            names += c_names
            dtype_names += [str(d).split(".")[-1] for d in c_dtypes]
            shapes += c_shapes
        return {"names": names, "dtype_names": dtype_names, "shapes": shapes}

    def extract_weights(self, dtype):
        for names, dtypes, shapes, tensors in self._iter_chunk_plan(dtype, materialize=True):
            if tensors:
                yield WeightChunk(names=names, dtypes=[str(d) for d in dtypes], shapes=shapes, tensors=tensors)

Cheap insurance so a residual desync is a loud error instead of a 30-min hang: have weight_iterator record the streamed names and, after the send, assert they match weight_metadata["names"] (length + first-mismatch index).
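A minimal sketch of that check, assuming the streamed names have been collected into a list during the send. `assert_stream_matches_metadata` is a hypothetical helper, not an existing function in the codebase, and how `weight_iterator` would record the names is left out.

```python
def assert_stream_matches_metadata(streamed_names, metadata_names):
    """Raise a descriptive error if the streamed order differs from metadata.

    Reports both lengths and the first index where the names disagree, so
    a desync surfaces right after the send instead of as a later hang.
    """
    first_mismatch = next(
        (i for i, (s, m) in enumerate(zip(streamed_names, metadata_names)) if s != m),
        None,
    )
    if first_mismatch is None and len(streamed_names) == len(metadata_names):
        return
    detail = (
        f"first mismatch at index {first_mismatch}: "
        f"streamed {streamed_names[first_mismatch]!r}, "
        f"metadata {metadata_names[first_mismatch]!r}"
        if first_mismatch is not None
        else "one list is a prefix of the other"
    )
    raise RuntimeError(
        f"weight-sync desync: streamed {len(streamed_names)} names, "
        f"metadata has {len(metadata_names)}; {detail}"
    )


# Matching order passes silently.
assert_stream_matches_metadata(["a", "b", "c"], ["a", "b", "c"])

# A swapped pair is reported with its position.
try:
    assert_stream_matches_metadata(["a", "c", "b"], ["a", "b", "c"])
    error_message = ""
except RuntimeError as exc:
    error_message = str(exc)
```

The check costs one pass over the name lists per sync, which is negligible next to the transfer itself.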

Validated: with the single-source-of-truth version (fp32 router precision retained), the first non-colocated weight-sync on Nemotron-3-Ultra-550B EP16/PP4 completes in ~214 s instead of hanging. Happy to send this as a follow-up PR.

erictang000 and others added 2 commits July 1, 2026 23:42
… + varlen full-ctx harness

Re-tune of glm47_355b_128k_megatron_ablation after peak-mem improvements + the switch to
token-based dynamic micro-batching (trainer.max_tokens_per_microbatch).

- ai_docs/glm47_355b_128k_max_tokens_throughput_0702.md: full writeup (Stages 1-3).
- examples/train/megatron/stage_glm47.py: stage GLM-4.7 (~667GB) to node-local NVMe on all
  nodes (subprocess watchdog + resume — a single snapshot_download hung 6/8 nodes).
- examples/train_scripts/full_context/run_full_ctx_glm47_355b.sh: parameterized 355B full-ctx
  benchmark harness (all TP/PP/CP/EP/grad/recompute/batch/max_tokens knobs env-overridable).
- full_context/{main_full_ctx,trainer_full_ctx}.py: mature peak-mem + step-timing logging,
  stubbed inference, and Stage-3 variable-length dummy batch (trainer.dummy_variable_length).

Result (FP32 grads): winner TP8/PP4/CP2/EP16/ETP1/DP1. Realistic 256-sample varlen batch
(avg ~70K) ~8.9K tok/s/cluster (139 tok/s/GPU) at max_tokens=256K, peak 66GB. all-128K
train_batch=32 ~7.0K tok/s. max_tokens sweet spot 256K (128K/GPU); 384K OOMs at full pipeline.
DP2/DP4 do not help under FP32 grads (grad-reduce 2x/4x bytes; PP2 OOMs) — need bf16 grads.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>