[megatron] Enable Nemotron-3-Ultra-550B GRPO RL + fix multi-rank (EP>16/PP>2) weight sync#1816
[megatron] Enable Nemotron-3-Ultra-550B GRPO RL + fix multi-rank (EP>16/PP>2) weight sync#1816erictang000 wants to merge 5 commits into
Conversation
…ht sync Adds an end-to-end full-finetuning GRPO recipe for NVIDIA-Nemotron-3-Ultra-550B (NemotronH hybrid Mamba2+attention, latent MoE, reasoning) colocated with vLLM on 8x8 H200 (EFA), plus the weight-sync/reload correctness fixes it depends on. Validated: avg_raw_reward ~0.9, GSM8K eval ~0.94, grad_norm > 0. Core fix (general, affects any MoE synced at EP>16 / PP>2): the CUDA-IPC weight transport sent only rank-0's per-param slicing metadata, but each Megatron rank packs its own (per-rank-divergent) buffer. Each vLLM worker rebuilt its own GPU's buffer yet sliced it with rank-0's metadata -> correct bytes loaded under wrong names -> coherent- but-garbage generations and reward stuck at 0. Now sends per-GPU metadata and each worker slices its own buffer (cuda_ipc_strategy.py, new_inference_worker_wrap.py). Also included: - Preserve fp32 for the MoE router bias (gate.e_score_correction_bias) through sync; bf16 ULP at its ~25-57 magnitude collapses the per-expert offsets and corrupts routing. - Guard vLLM layerwise-reload get_numel_loaded (cf. vllm-project/vllm#44814): the composed-weight-loader double-count finalizes a layer early and drops Mamba mixer.D (uninitialized -> NaN). Mirrors the existing conv_weights reload workaround. - Forward HF_*/cache dirs and SKYRL_WAIT_UNTIL_INFERENCE_SERVER_HEALTHY_TIMEOUT_S to Ray worker actors (prepare_runtime_environment + GPU-CI conftest). - Reasoning-aware GSM8K reward: strip the <think> trace, score strict `#### <n>` else last-number with comma/$ normalization. - Nemotron-Ultra logprob round-trip test params (EP16/PP2 baseline; EP32/PP4 regressions). - Example recipe (run_megatron_nemotron_ultra.sh), staging helper, and README. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
There was a problem hiding this comment.
Code Review
This pull request adds support for training the Nemotron-3-Ultra-550B model using GRPO RL on GSM8K with Megatron and vLLM. Key changes include a new multi-node launch script, a staging script, robust GSM8K reward parsing for reasoning models, and critical fixes for CUDA-IPC weight synchronization, vLLM layerwise-reload, and native fp32 precision syncing for MoE router biases. The review feedback suggests improving the number normalization function to handle float representations of integers, simplifying regex group extraction in the staging script, and double-quoting the arguments in the shell script to prevent word splitting.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| def _norm_num(s: str) -> str: | ||
| """Normalize a parsed number to compare against the (comma-free integer) ground truth.""" | ||
| return s.strip().rstrip(".").replace(",", "").replace("$", "") |
There was a problem hiding this comment.
The current implementation of _norm_num does not normalize float representations of integers (e.g., '72.0' or '72.'). If the model outputs a decimal representation of an integer, the comparison against the ground truth (which is typically a clean integer string like '72') will fail, leading to a reward of 0.0 instead of 1.0.
We can make this more robust by attempting to parse the value as a float and checking if it represents an integer.
| def _norm_num(s: str) -> str: | |
| """Normalize a parsed number to compare against the (comma-free integer) ground truth.""" | |
| return s.strip().rstrip(".").replace(",", "").replace("$", "") | |
| def _norm_num(s: str) -> str: | |
| """Normalize a parsed number to compare against the (comma-free integer) ground truth.""" | |
| val = s.strip().rstrip(".").replace(",", "").replace("$", "") | |
| try: | |
| f_val = float(val) | |
| if f_val.is_integer(): | |
| return str(int(f_val)) | |
| except ValueError: | |
| pass | |
| return val |
|
|
||
| def to_row(example, idx, split): | ||
| q = example["question"] | ||
| sol = re.search(r"#### (\-?[0-9\.\,]+)", example["answer"]).group(0).split("#### ")[1].replace(",", "") |
There was a problem hiding this comment.
Using group(0) followed by .split('#### ')[1] is redundant and can be simplified by directly accessing the captured group group(1). Additionally, wrapping this in a safe check prevents potential AttributeError if re.search ever returns None.
match = re.search(r"#### (\-?[0-9\.\,]+)", example["answer"])
sol = match.group(1).replace(",", "") if match else ""| trainer.max_ckpts_to_keep=3 \ | ||
| trainer.ckpt_interval=20 \ | ||
| trainer.ckpt_path="$HOME/ckpts/gsm8k_nemotron_ultra_ckpt" \ | ||
| $@ |
…dings
Adds a sweep harness (real Megatron fwd+bwd on fabricated rollouts, no vLLM
generation) to map, on 64xH200, the max max_tokens_per_microbatch and the
parallelism (TP/PP/CP/EP/DP) that maximizes full-FT GRPO training throughput
for NVIDIA-Nemotron-3-Ultra-550B, plus a long-context (variable-length) study.
Findings (examples/train/megatron/NEMOTRON_ULTRA_THROUGHPUT.md):
- Max MTPM ~= 64k tokens/microbatch at the validated TP8/PP4/EP16/DP2 config.
- Highest throughput: TP8/PP2/EP32/DP4 (~8.5k tok/s, +11% over PP4/DP2) for
short/medium seqs; config space is pinned (PP8 invalid for 108 layers, EP8
OOMs, TP4 doubles activations via sequence parallelism).
- Long context: single-sequence ceiling ~40-48k tokens (CP1/PP4/DP2); CP gives
little net benefit (CP=2 forces PP2 whose 2x weights cancel the savings). Long
seqs are more throughput-efficient per token (~12k tok/s at ~39k mean).
New files:
- examples/train_scripts/full_context/{trainer_ultra_sweep,main_ultra_sweep,analyze_sweep}.py
- examples/train/megatron/run_ultra_sweep.sh
- examples/train/megatron/NEMOTRON_ULTRA_THROUGHPUT.md
worker.py: get_cuda_memory() now also returns max_allocated/max_reserved
high-water marks (capture in-step peak even when queried after offload).
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Follow-up to the throughput sweep. CP composes with EP in Megatron-Core (EP divides TP*CP*DP, with ETP=1), so adding CP does NOT force EP down. Measured: - TP8/PP4/CP2/EP16/DP1 fits a single 96k sequence (128k OOMs) -- CP2 roughly doubles the single-sequence ceiling (~40-48k -> ~96k) while keeping PP4's low weights and baseline expert memory. Best long-context config. - TP8/PP2/CP4/EP32/DP1 is valid and loads but still OOMs at 128k: dropping to PP2 to free GPUs for CP4 doubles the weights and eats the budget CP frees. So the 60k+-30k distribution is mostly trainable with PP4/CP2 (clamp ~96k, ~10% truncated) at the cost of DP->1; the full 131k tail still OOMs. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
|
Heads up — I think this PR fixes the CUDA-IPC weight-sync path but leaves a sibling bug on the non-colocated NCCL broadcast path ( The new On the broadcast new-inference path ( CUDA-IPC is immune because it derives metadata per-chunk directly from each already-regrouped chunk ( Fix is to mirror the same per-bucket dtype regrouping in # in the bucketing branch of get_weight_metadata, replacing the raw-order append:
for index_group in self.bucket_index_groups:
bucket_tasks = [fresh_tasks[i] for i in index_group]
# Mirror extract_weights' per-bucket dtype regrouping EXACTLY: the NCCL
# broadcast path packs the byte stream in extract_weights() chunk order but
# slices it on the receiver using this metadata's order, so the two must
# agree positionally. Raw export order desyncs on any dtype-mixed bucket.
groups = {} # out_dtype -> (names, dtype_names, shapes)
for name, tensor in self.bridge.export_hf_weights(
self.actor_module, show_progress=False, conversion_tasks=bucket_tasks,
):
out_dtype = _sync_dtype_for(name, dtype, tensor.dtype)
g = groups.setdefault(out_dtype, ([], [], []))
g[0].append(name)
g[1].append(str(out_dtype).split(".")[-1])
g[2].append(list(tensor.shape))
del tensor
for g_names, g_dtype_names, g_shapes in groups.values():
names.extend(g_names); dtype_names.extend(g_dtype_names); shapes.extend(g_shapes)Verified: with the IPC fix from this PR but the broadcast path unpatched, a non-colocated GRPO run on Nemotron-3-Ultra-550B at EP16/PP4 reproduces the garbage-generation/reward-0 symptom; the regrouping above resolves it. Happy to send a follow-up PR if useful. Update (corrected fix — the mirror approach above is fragile): Shipping the The lesson: keeping two independently-authored copies of the per-bucket dtype regrouping (one in The robust fix is to make both passes consume one chunk-plan generator, so the chunk count/order is identical by construction: class MegatronWeightExtractor(WeightExtractor):
def _iter_chunk_plan(self, dtype, *, materialize):
"""Single source of truth for the per-(bucket, out-dtype) chunk order.
get_weight_metadata() calls with materialize=False (discards tensors so
it never holds the full model on rank 0); extract_weights() calls with
materialize=True. Both walk the SAME grouping, so the broadcast byte
stream and the receiver's slice plan cannot desync."""
self._ensure_buckets_initialized()
device = torch.cuda.current_device()
fresh_tasks = self.bridge.get_conversion_tasks(self.actor_module)
for index_group in self.bucket_index_groups:
bucket_tasks = [fresh_tasks[i] for i in index_group]
groups = {} # out_dtype -> (names, dtypes, shapes, tensors)
for name, tensor in self.bridge.export_hf_weights(
self.actor_module, show_progress=False, conversion_tasks=bucket_tasks,
):
out_dtype = _sync_dtype_for(name, dtype, tensor.dtype)
t = tensor.to(device=device, dtype=out_dtype, non_blocking=True) if materialize else None
g = groups.setdefault(out_dtype, ([], [], [], []))
g[0].append(name); g[1].append(out_dtype); g[2].append(list(tensor.shape)); g[3].append(t)
if not materialize:
del tensor
for names, dtypes, shapes, tensors in groups.values():
yield names, dtypes, shapes, tensors
def get_weight_metadata(self, dtype):
names, dtype_names, shapes = [], [], []
for c_names, c_dtypes, c_shapes, _ in self._iter_chunk_plan(dtype, materialize=False):
names += c_names
dtype_names += [str(d).split(".")[-1] for d in c_dtypes]
shapes += c_shapes
return {"names": names, "dtype_names": dtype_names, "shapes": shapes}
def extract_weights(self, dtype):
for names, dtypes, shapes, tensors in self._iter_chunk_plan(dtype, materialize=True):
if tensors:
yield WeightChunk(names=names, dtypes=[str(d) for d in dtypes], shapes=shapes, tensors=tensors)Cheap insurance so a residual desync is a loud error instead of a 30-min hang: have Validated: with the single-source-of-truth version (fp32 router precision retained), the first non-colocated weight-sync on Nemotron-3-Ultra-550B EP16/PP4 completes in ~214 s instead of hanging. Happy to send this as a follow-up PR. |
… + varlen full-ctx harness
Re-tune of glm47_355b_128k_megatron_ablation after peak-mem improvements + the switch to
token-based dynamic micro-batching (trainer.max_tokens_per_microbatch).
- ai_docs/glm47_355b_128k_max_tokens_throughput_0702.md: full writeup (Stages 1-3).
- examples/train/megatron/stage_glm47.py: stage GLM-4.7 (~667GB) to node-local NVMe on all
nodes (subprocess watchdog + resume — a single snapshot_download hung 6/8 nodes).
- examples/train_scripts/full_context/run_full_ctx_glm47_355b.sh: parameterized 355B full-ctx
benchmark harness (all TP/PP/CP/EP/grad/recompute/batch/max_tokens knobs env-overridable).
- full_context/{main_full_ctx,trainer_full_ctx}.py: mature peak-mem + step-timing logging,
stubbed inference, and Stage-3 variable-length dummy batch (trainer.dummy_variable_length).
Result (FP32 grads): winner TP8/PP4/CP2/EP16/ETP1/DP1. Realistic 256-sample varlen batch
(avg ~70K) ~8.9K tok/s/cluster (139 tok/s/GPU) at max_tokens=256K, peak 66GB. all-128K
train_batch=32 ~7.0K tok/s. max_tokens sweet spot 256K (128K/GPU); 384K OOMs at full pipeline.
DP2/DP4 do not help under FP32 grads (grad-reduce 2x/4x bytes; PP2 OOMs) — need bf16 grads.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
What
Enables end-to-end full-finetuning GRPO RL of NVIDIA-Nemotron-3-Ultra-550B-A55B (NemotronH hybrid Mamba2 + attention, latent MoE with 512 experts, reasoning model) colocated with vLLM on 8× nodes of 8×H200 (64 GPUs, EFA) — and fixes the weight-sync/reload correctness bugs that block it (and other large MoE models).
Validated: trains end-to-end with the included recipe —
avg_raw_reward ≈ 0.9, GSM8Keval ≈ 0.94,grad_norm > 0. Megatron mesh TP8 / PP4 / EP16 / ETP1 (DP2); vLLM TP8 × PP4.Replication guide:
examples/train/megatron/README_nemotron_ultra.md.Why (root cause)
vLLM produced coherent-looking garbage after every weight sync → all rewards 0 → no learning. The bridge export was proven bit-correct (0 mismatches over 108k tensors), which localized the bug to the CUDA-IPC weight transport:
This is general to any MoE synced at EP>16/PP>2, not just Nemotron.
Changes
cuda_ipc_strategy.py,new_inference_worker_wrap.py)gate.e_score_correction_bias) preserved through sync — bf16 ULP at its ~25–57 magnitude collapses the tiny per-expert offsets (std ~7e-4) and corrupts routing. (megatron_worker.py)composed_weight_loaderdouble-count finalizes a layer early and drops Mambamixer.D(uninitialized → NaN). Guarded monkeypatch mirroring the existingconv_weightsworkaround; remove once on a vLLM that includes #44814. (layerwise_reload.py)HF_*/cache dirs andSKYRL_WAIT_UNTIL_INFERENCE_SERVER_HEALTHY_TIMEOUT_S→ Ray worker actors. (utils.py, GPU-CIconftest.py)<think>, score strict#### <n>else last-number with comma/$ normalization. (skyrl_gym/envs/gsm8k/env.py)run_megatron_nemotron_ultra.sh,stage_nemotron_ultra.py(model+data staging incl.chat_template.jinja),README_nemotron_ultra.md.h100-marked / gated on the 550B model.Testing
test_logprobs_matching_roundtrip[nemotron3-ultra…]round-trip on 8×8 H200 (EP16/PP2 and EP16/PP4 pass; pre-sync Megatron↔vLLM ~0.06, post-sync ~0.15–0.30).Note
The
get_numel_loadedandconv_weightsreload patches are temporary vLLM workarounds (pending #44814 and #42481); the per-GPU-metadata transport fix and the rest are permanent.🤖 Generated with Claude Code