-
Notifications
You must be signed in to change notification settings - Fork 376
[megatron] Stream ChunkedDistributedLogprob.backward into a preallocated buffer (lower peak memory) #1806
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
dyurk-lila
wants to merge
3
commits into
NovaSky-AI:main
Choose a base branch
from
dyurk-lila:perf/streaming-chunked-logprob-backward
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
[megatron] Stream ChunkedDistributedLogprob.backward into a preallocated buffer (lower peak memory) #1806
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
194 changes: 194 additions & 0 deletions
194
tests/backends/skyrl_train/distributed/test_chunked_logprob_backward_streaming.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,194 @@ | ||
| """CPU parity tests for streamed ``ChunkedDistributedLogprob.backward``. | ||
|
|
||
| The backward path now writes each chunk directly into a preallocated fp32 | ||
| ``[B, S, V//TP]`` grad buffer. These tests compare against ``DistributedLogprob`` | ||
| on CPU; TP>1 coverage lives in | ||
| ``tests/backends/skyrl_train/gpu/gpu_ci/megatron/test_chunked_logprob_backward_tp.py``. | ||
|
|
||
| Run with: | ||
| uv run --isolated --extra skyrl-train --extra dev -- pytest -s \ | ||
| tests/backends/skyrl_train/distributed/test_chunked_logprob_backward_streaming.py | ||
| """ | ||
|
|
||
| import os | ||
| import sys | ||
| from types import ModuleType | ||
|
|
||
| import pytest | ||
| import torch | ||
| import torch.distributed as dist | ||
|
|
||
| from skyrl.backends.skyrl_train.distributed.utils import get_free_port | ||
|
|
||
# Stand-in ``megatron`` modules so CPU CI can import model_utils without megatron-core.
# The ``_stub_megatron_modules`` fixture installs them in ``sys.modules`` and restores
# the previous entries afterwards.
_MEGATRON_MODULES = [
    "megatron",
    "megatron.core",
    "megatron.core.parallel_state",
]

# One empty placeholder module per dotted name, keyed by that name.
_mock_modules: dict[str, ModuleType] = {_name: ModuleType(_name) for _name in _MEGATRON_MODULES}

# Hang every stub off its parent stub so attribute-style access resolves at each
# level (``megatron.core`` as well as ``megatron.core.parallel_state``), not only
# for the innermost module.
for _name, _module in _mock_modules.items():
    _parent, _, _leaf = _name.rpartition(".")
    if _parent in _mock_modules:
        setattr(_mock_modules[_parent], _leaf, _module)
|
|
||
|
|
||
@pytest.fixture(scope="module", autouse=True)
def _stub_megatron_modules():
    """Install the mock ``megatron`` modules for this module only.

    On teardown the previous ``sys.modules`` entries are put back (or the stub
    entries are removed when there was nothing before). Modules that were first
    imported while the stubs were active and that bound one of the stub objects
    as a module-level global are dropped from ``sys.modules`` as well; otherwise
    a later import in the same process would reuse a cached module that still
    points at the stubs.
    """
    saved = {name: sys.modules.get(name) for name in _MEGATRON_MODULES}
    imported_before = set(sys.modules)
    for name in _MEGATRON_MODULES:
        sys.modules[name] = _mock_modules[name]
    try:
        yield
    finally:
        for name, previous in saved.items():
            if previous is None:
                sys.modules.pop(name, None)
            else:
                sys.modules[name] = previous

        # Evict modules imported under the stubs that captured a stub object.
        stub_ids = {id(stub) for stub in _mock_modules.values()}
        for name in set(sys.modules) - imported_before:
            namespace = getattr(sys.modules.get(name), "__dict__", None)
            if not isinstance(namespace, dict):
                continue
            if any(id(value) in stub_ids for value in list(namespace.values())):
                sys.modules.pop(name, None)
|
|
||
|
|
||
@pytest.fixture(scope="module")
def tp_group():
    """Single-rank gloo TP group; only destroy it if this fixture created it.

    If a default process group already exists it is yielded untouched. Otherwise
    the rendezvous variables (``MASTER_ADDR``, ``MASTER_PORT``, ``RANK``,
    ``WORLD_SIZE``) are set, a one-rank gloo group is created, and on teardown
    the group is destroyed and the variables are returned to their previous
    state so they do not leak into the rest of the session.
    """
    if dist.is_initialized():
        # Someone else owns the group; reuse it and leave its lifetime alone.
        yield dist.group.WORLD
        return

    rendezvous_env = {
        "MASTER_ADDR": "localhost",
        "MASTER_PORT": str(get_free_port()),
        "RANK": "0",
        "WORLD_SIZE": "1",
    }
    previous_env = {key: os.environ.get(key) for key in rendezvous_env}
    os.environ.update(rendezvous_env)
    try:
        dist.init_process_group(backend="gloo", rank=0, world_size=1)
        yield dist.group.WORLD
    finally:
        if dist.is_initialized():
            dist.destroy_process_group()
        # Undo the environment changes, including when init itself failed.
        for key, value in previous_env.items():
            if value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = value
|
|
||
|
|
||
def _backward_grad(func_cls, logits, target, vocab_start, vocab_end, tp_group, *, chunk_size=None):
    """Run ``func_cls`` forward and backward and return the gradient w.r.t. the logits.

    ``logits`` is copied into a fresh leaf tensor, so the caller's tensor is not
    touched. When ``chunk_size`` is given it is passed positionally right before
    ``tp_group``; otherwise it is omitted from the call. The upstream gradient is
    a ramp from 0.5 to 1.5 over the output elements, i.e. deliberately non-uniform.
    """
    leaf = logits.detach().clone().requires_grad_(True)

    # The chunked variant takes one extra positional argument ahead of the TP group.
    chunk_args = () if chunk_size is None else (chunk_size,)
    out = func_cls.apply(leaf, target, vocab_start, vocab_end, *chunk_args, tp_group, False)

    ramp = torch.linspace(0.5, 1.5, steps=out.numel(), device=out.device, dtype=out.dtype)
    out.backward(ramp.reshape(out.shape))
    return leaf.grad.detach()
|
|
||
|
|
||
@pytest.mark.parametrize("chunk_size", [1, 3, 7, 16, 32, 64])
@pytest.mark.parametrize("with_oov_targets", [False, True])
def test_streamed_backward_bit_identical_to_non_chunked(tp_group, chunk_size, with_oov_targets):
    """The chunked gradient equals the non-chunked one bit for bit.

    Parametrized over several chunk sizes and over whether targets may fall
    outside the ``[0, vocab_size)`` range handled by this (single) rank.
    """
    from skyrl.backends.skyrl_train.distributed.megatron.model_utils import (
        ChunkedDistributedLogprob,
        DistributedLogprob,
    )

    torch.manual_seed(0)
    cpu = torch.device("cpu")

    # A sequence length of 30 gives ragged last chunks for most sizes and is
    # shorter than the two largest chunk sizes.
    batch_size, seq_len, vocab_size = 4, 30, 256

    # With OOV enabled, target ids are drawn from a range 64 wider than the vocab.
    target_high = vocab_size + (64 if with_oov_targets else 0)

    logits = torch.randn(batch_size, seq_len, vocab_size, dtype=torch.float32, device=cpu) * 2.0
    target = torch.randint(0, target_high, (batch_size, seq_len), device=cpu, dtype=torch.long)

    shared_args = (logits, target, 0, vocab_size, tp_group)
    grad_ref = _backward_grad(DistributedLogprob, *shared_args)
    grad_chunk = _backward_grad(ChunkedDistributedLogprob, *shared_args, chunk_size=chunk_size)

    assert grad_chunk.shape == grad_ref.shape == logits.shape
    assert grad_chunk.dtype == torch.float32
    assert torch.equal(grad_chunk, grad_ref), "streamed chunked grad must be bit-identical to non-chunked grad"
|
|
||
|
|
||
@pytest.mark.parametrize(
    "case",
    [
        # (batch, seq_len, vocab, chunk_size, mask_mode)
        # mask_mode: "default" (in-vocab and OOV targets interleaved),
        #            "all_in" (no OOV), "all_out" (all OOV)
        pytest.param((1, 1, 64, 4, "default"), id="seq1"),
        pytest.param((2, 8, 64, 32, "all_in"), id="all_in_vocab"),
        pytest.param((2, 8, 64, 32, "all_out"), id="all_out_vocab"),
        pytest.param((2, 9, 8, 4, "default"), id="ragged_tiny_vocab"),
    ],
)
def test_streamed_backward_edge_cases(tp_group, case):
    """Covers short sequences, mask extremes, ragged chunks, and tiny vocab.

    ``mask_mode`` selects how targets relate to ``[0, vocab_size)``:
    ``"all_in"`` keeps every target inside it, ``"all_out"`` puts every target
    outside it, and ``"default"`` mixes both by pushing every second target
    (in flattened order) out of range. With a single target, ``"default"``
    therefore stays in-vocab.

    Raises:
        ValueError: if ``mask_mode`` is not one of the three known modes.
    """
    from skyrl.backends.skyrl_train.distributed.megatron.model_utils import (
        ChunkedDistributedLogprob,
        DistributedLogprob,
    )

    batch_size, seq_len, vocab_size, chunk_size, mask_mode = case
    device = torch.device("cpu")
    torch.manual_seed(1)

    logits = torch.randn(batch_size, seq_len, vocab_size, dtype=torch.float32, device=device) * 2.0
    if mask_mode == "all_out":
        target = torch.full((batch_size, seq_len), vocab_size + 5, device=device, dtype=torch.long)
    elif mask_mode in ("default", "all_in"):
        target = torch.randint(0, vocab_size, (batch_size, seq_len), device=device, dtype=torch.long)
        if mask_mode == "default":
            # Shift every other target past the vocab so the mask really is mixed.
            target.view(-1)[1::2] += vocab_size
    else:
        raise ValueError(f"unknown mask_mode: {mask_mode!r}")

    grad_ref = _backward_grad(DistributedLogprob, logits, target, 0, vocab_size, tp_group)
    grad_chunk = _backward_grad(
        ChunkedDistributedLogprob,
        logits,
        target,
        0,
        vocab_size,
        tp_group,
        chunk_size=chunk_size,
    )

    assert torch.equal(grad_chunk, grad_ref), "streamed chunked grad must be bit-identical to non-chunked grad"
|
|
||
|
|
||
def test_streamed_backward_covers_sequence_without_overlap(tp_group):
    """A prime sequence length (17) with chunk size 5 forces a short final chunk.

    The chunked gradient must still equal the non-chunked gradient everywhere,
    which can only hold if every slice of the output buffer was written.
    """
    from skyrl.backends.skyrl_train.distributed.megatron.model_utils import (
        ChunkedDistributedLogprob,
        DistributedLogprob,
    )

    torch.manual_seed(2)
    cpu = torch.device("cpu")

    batch_size, seq_len, vocab_size, chunk_size = 3, 17, 128, 5

    logits = torch.randn(batch_size, seq_len, vocab_size, dtype=torch.float32, device=cpu) * 2.0
    target = torch.randint(0, vocab_size, (batch_size, seq_len), device=cpu, dtype=torch.long)

    shared_args = (logits, target, 0, vocab_size, tp_group)
    grad_ref = _backward_grad(DistributedLogprob, *shared_args)
    grad_chunk = _backward_grad(ChunkedDistributedLogprob, *shared_args, chunk_size=chunk_size)

    assert grad_chunk.shape == grad_ref.shape == logits.shape
    # Comparing for exact equality flags any slice left unwritten, without
    # depending on uninitialized memory happening to contain NaN/inf.
    assert torch.equal(grad_chunk, grad_ref), "every sequence slice of the preallocated buffer must be written"
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
146 changes: 146 additions & 0 deletions
146
tests/backends/skyrl_train/gpu/gpu_ci/megatron/test_chunked_logprob_backward_tp.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,146 @@ | ||
| """TP>1 parity for streamed ``ChunkedDistributedLogprob.backward``. | ||
|
|
||
| Spawns NCCL ranks, shards a shared full-vocab problem, and compares each local | ||
| grad shard with a single-process fp32 reference. | ||
|
|
||
| Requires ``TP`` free GPUs. It will NOT run on a CPU-only / macOS dev box. | ||
|
|
||
| Run with (>=2 free GPUs; the ``tp_size=4`` case is skipped unless 4 are present): | ||
| uv run --isolated --extra dev --extra megatron -- \ | ||
| pytest -s tests/backends/skyrl_train/gpu/gpu_ci/megatron/test_chunked_logprob_backward_tp.py | ||
| """ | ||
|
|
||
| import os | ||
|
|
||
| import pytest | ||
| import torch | ||
| import torch.distributed as dist | ||
| import torch.multiprocessing as mp | ||
|
|
||
| # Every rank rebuilds the same full-vocab problem. | ||
| _SEED = 0 | ||
| _BATCH = 4 | ||
| _SEQ_LEN = 30 | ||
| _VOCAB = 256 | ||
| _CHUNK_SIZE = 7 | ||
|
|
||
|
|
||
def _reference_full_grad(logits_fp32: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Return the gradient of the chosen-token log-probs w.r.t. full-vocab logits.

    The logits are copied into a fresh leaf, log-softmaxed over the last axis,
    and the entry at each ``target`` index is selected. Backward is seeded with
    a ramp from 0.5 to 1.5 over the selected entries.
    """
    leaf = logits_fp32.detach().clone().requires_grad_(True)
    picked = leaf.log_softmax(dim=-1).gather(-1, target.unsqueeze(-1)).squeeze(-1)

    ramp = torch.linspace(0.5, 1.5, steps=picked.numel(), device=picked.device, dtype=picked.dtype)
    # Seeding backward with the ramp is the same as differentiating sum(ramp * picked).
    picked.backward(gradient=ramp.reshape(picked.shape))
    return leaf.grad.detach()
|
|
||
|
|
||
def _set_ci_nccl_env():
    """Apply the gpu_ci NCCL env that ``mp.spawn`` children do not inherit.

    Three variables are always set. If ``run_p2p_access_check()`` returns a
    falsy value, NCCL P2P and SHM are disabled as well.
    """
    from skyrl.train.utils.utils import run_p2p_access_check

    os.environ.update(
        NCCL_CUMEM_ENABLE="0",
        CUDA_DEVICE_MAX_CONNECTIONS="1",
        NVTE_FUSED_ATTN="0",
    )
    # Avoid peer_access_supported here; it would spin up Ray.
    if run_p2p_access_check():
        return
    os.environ.update(NCCL_P2P_DISABLE="1", NCCL_SHM_DISABLE="1")
|
|
||
|
|
||
def _tp_worker(rank: int, world_size: int, master_port: str, result_path: str):
    """Run one TP rank: chunked forward and backward on this rank's vocab shard.

    Every rank seeds the RNG identically and builds the same full-vocab logits
    and targets, then keeps the ``[vocab_start, vocab_end)`` slice of the logits.
    The local gradient, the shared inputs and the forward output are written to
    ``f"{result_path}.{rank}"`` for the parent process to compare.
    """
    import megatron.core.parallel_state as mpu

    from skyrl.backends.skyrl_train.distributed.megatron.model_utils import (
        ChunkedDistributedLogprob,
    )

    torch.cuda.set_device(rank)
    # Must happen before init_process_group; spawned children miss the conftest runtime_env.
    _set_ci_nccl_env()

    os.environ.update(
        {
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": master_port,
            "RANK": str(rank),
            "WORLD_SIZE": str(world_size),
        }
    )
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)

    mpu.initialize_model_parallel(tensor_model_parallel_size=world_size)
    tp_group = mpu.get_tensor_model_parallel_group()

    device = torch.device("cuda", rank)
    torch.manual_seed(_SEED)

    # Equal-width, contiguous vocab shards; this rank owns [vocab_start, vocab_end).
    assert _VOCAB % world_size == 0
    shard_width = _VOCAB // world_size
    vocab_start = rank * shard_width
    vocab_end = vocab_start + shard_width

    # Same seed on every rank, so these draws are meant to match across ranks.
    logits_full = (torch.randn(_BATCH, _SEQ_LEN, _VOCAB, dtype=torch.float32, device=device) * 2.0).contiguous()
    target = torch.randint(0, _VOCAB, (_BATCH, _SEQ_LEN), device=device, dtype=torch.long)

    leaf = logits_full[:, :, vocab_start:vocab_end].detach().clone().requires_grad_(True)
    out = ChunkedDistributedLogprob.apply(leaf, target, vocab_start, vocab_end, _CHUNK_SIZE, tp_group, False)
    ramp = torch.linspace(0.5, 1.5, steps=out.numel(), device=out.device, dtype=out.dtype)
    out.backward(ramp.reshape(out.shape))

    # The parent process rebuilds the reference gradient from these shared inputs.
    payload = {
        "rank": rank,
        "vocab_start": vocab_start,
        "vocab_end": vocab_end,
        "grad_local": leaf.grad.detach().cpu(),
        "logits_full": logits_full.detach().cpu(),
        "target": target.detach().cpu(),
        "logprob_out": out.detach().cpu(),
    }
    torch.save(payload, f"{result_path}.{rank}")

    mpu.destroy_model_parallel()
    dist.destroy_process_group()
|
|
||
|
|
||
@pytest.mark.megatron
@pytest.mark.parametrize("tp_size", [2, 4])
def test_streamed_chunked_backward_matches_reference_at_tp(tmp_path, tp_size):
    """Per-rank streamed grads match the full-vocab reference slices.

    Spawns ``tp_size`` worker processes, loads the result file each one wrote,
    checks that all ranks saw the same inputs and that their vocab ranges cover
    ``[0, _VOCAB)`` exactly once, then compares every local gradient with the
    matching slice of a single-process reference gradient. Skipped when CUDA is
    unavailable or fewer than ``tp_size`` GPUs are visible.
    """
    if not torch.cuda.is_available():
        pytest.skip("CUDA required for the vocab-parallel backward")
    if torch.cuda.device_count() < tp_size:
        pytest.skip(f"requires {tp_size} GPUs, found {torch.cuda.device_count()}")

    from skyrl.backends.skyrl_train.distributed.utils import get_free_port

    master_port = str(get_free_port())
    result_path = str(tmp_path / "tp_grad")

    mp.spawn(_tp_worker, args=(tp_size, master_port, result_path), nprocs=tp_size, join=True)

    # The workers save only tensors, ints and string keys, so the restricted
    # unpickler is enough; being explicit avoids depending on the torch default.
    shards = [torch.load(f"{result_path}.{rank}", weights_only=True) for rank in range(tp_size)]

    # Sanity-check that the spawned ranks used the same problem.
    logits_full = shards[0]["logits_full"]
    target = shards[0]["target"]
    for shard in shards[1:]:
        assert torch.equal(shard["logits_full"], logits_full)
        assert torch.equal(shard["target"], target)

    grad_ref = _reference_full_grad(logits_full, target)

    # Rank vocab slices must tile [0, _VOCAB) exactly once.
    covered = torch.zeros(_VOCAB, dtype=torch.int64)
    for shard in shards:
        covered[shard["vocab_start"] : shard["vocab_end"]] += 1
    assert torch.equal(covered, torch.ones(_VOCAB, dtype=torch.int64)), "vocab slices must tile [0, vocab) once"

    # Allow fp32 reduction-order noise from the cross-rank all-reduce.
    for shard in shards:
        ref_slice = grad_ref[:, :, shard["vocab_start"] : shard["vocab_end"]]
        assert shard["grad_local"].shape == ref_slice.shape
        torch.testing.assert_close(shard["grad_local"], ref_slice, atol=1e-5, rtol=1e-4)
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.