diff --git a/skyrl/backends/skyrl_train/distributed/megatron/model_utils.py b/skyrl/backends/skyrl_train/distributed/megatron/model_utils.py index 2d37fbf0d3..626edc1b87 100644 --- a/skyrl/backends/skyrl_train/distributed/megatron/model_utils.py +++ b/skyrl/backends/skyrl_train/distributed/megatron/model_utils.py @@ -214,10 +214,16 @@ def backward( seq_size = int(vocab_parallel_logits.shape[1]) num_chunks = (seq_size + chunk_size - 1) // chunk_size - all_grad_input = [] - batch_size = int(vocab_parallel_logits.shape[0]) + # Stream chunk grads into a preallocated buffer instead of keeping every + # chunk alive until torch.cat. Each chunk owns one contiguous seq slice. + grad_input = torch.empty( + (batch_size, seq_size, partition_vocab_size), + dtype=torch.float32, + device=vocab_parallel_logits.device, + ) + for chunk_idx in range(num_chunks): chunk_start = chunk_idx * chunk_size chunk_end = min(seq_size, (chunk_idx + 1) * chunk_size) @@ -251,15 +257,14 @@ def backward( flat_chosen = flat_idx.masked_select(valid_mask.reshape(-1)) + chunk_masked_target.masked_select(valid_mask) # `neg` is zero-copy; the subsequent mul_ writes in place. - grad_input = softmax_output.neg_() - grad_input.mul_(chunk_grad_output.unsqueeze(-1)) + chunk_grad_input = softmax_output.neg_() + chunk_grad_input.mul_(chunk_grad_output.unsqueeze(-1)) grad_output_selected = chunk_grad_output.masked_select(valid_mask) - grad_input.view(-1).scatter_add_(0, flat_chosen, grad_output_selected) - - all_grad_input.append(grad_input) + chunk_grad_input.view(-1).scatter_add_(0, flat_chosen, grad_output_selected) - grad_input = torch.cat(all_grad_input, dim=1) + # Write this chunk into its non-overlapping sequence slice. 
+ grad_input[:, chunk_start:chunk_end, :] = chunk_grad_input # if you add an argument to the forward method, then you must add a corresponding None here return grad_input, None, None, None, None, None, None diff --git a/tests/backends/skyrl_train/distributed/test_chunked_logprob_backward_streaming.py b/tests/backends/skyrl_train/distributed/test_chunked_logprob_backward_streaming.py new file mode 100644 index 0000000000..99de88fa44 --- /dev/null +++ b/tests/backends/skyrl_train/distributed/test_chunked_logprob_backward_streaming.py @@ -0,0 +1,194 @@ +"""CPU parity tests for streamed ``ChunkedDistributedLogprob.backward``. + +The backward path now writes each chunk directly into a preallocated fp32 +``[B, S, V//TP]`` grad buffer. These tests compare against ``DistributedLogprob`` +on CPU; TP>1 coverage lives in +``tests/backends/skyrl_train/gpu/gpu_ci/megatron/test_chunked_logprob_backward_tp.py``. + +Run with: + uv run --isolated --extra skyrl-train --extra dev -- pytest -s \ + tests/backends/skyrl_train/distributed/test_chunked_logprob_backward_streaming.py +""" + +import os +import sys +from types import ModuleType + +import pytest +import torch +import torch.distributed as dist + +from skyrl.backends.skyrl_train.distributed.utils import get_free_port + +# Stub megatron so CPU CI can import model_utils without megatron-core. +# The fixture restores prior modules, leaving GPU lanes with real megatron intact. 
+
+_MEGATRON_MODULES = [
+    "megatron",
+    "megatron.core",
+    "megatron.core.parallel_state",
+]
+
+_mock_modules: dict[str, ModuleType] = {}
+for _name in _MEGATRON_MODULES:
+    _mock_modules[_name] = ModuleType(_name)
+
+_mock_modules["megatron.core"].parallel_state = _mock_modules["megatron.core.parallel_state"]
+
+
+@pytest.fixture(scope="module", autouse=True)
+def _stub_megatron_modules():
+    """Install the mock ``megatron`` modules for this module only."""
+    saved = {_name: sys.modules.get(_name) for _name in _MEGATRON_MODULES}
+    for _name in _MEGATRON_MODULES:
+        sys.modules[_name] = _mock_modules[_name]
+    try:
+        yield
+    finally:
+        for _name in _MEGATRON_MODULES:
+            if saved[_name] is None:
+                sys.modules.pop(_name, None)
+            else:
+                sys.modules[_name] = saved[_name]
+
+
+@pytest.fixture(scope="module")
+def tp_group():
+    """Single-rank gloo TP group; only destroy it if this fixture created it."""
+    initialized_here = False
+    if not dist.is_initialized():
+        os.environ["MASTER_ADDR"] = "localhost"
+        os.environ["MASTER_PORT"] = str(get_free_port())
+        os.environ["RANK"] = "0"
+        os.environ["WORLD_SIZE"] = "1"
+        dist.init_process_group(backend="gloo", rank=0, world_size=1)
+        initialized_here = True
+    yield dist.group.WORLD
+    if initialized_here and dist.is_initialized():
+        dist.destroy_process_group()
+
+
+def _backward_grad(func_cls, logits, target, vocab_start, vocab_end, tp_group, *, chunk_size=None):
+    """Return the input grad using a non-uniform upstream gradient."""
+    leaf = logits.detach().clone().requires_grad_(True)
+    if chunk_size is None:
+        out = func_cls.apply(leaf, target, vocab_start, vocab_end, tp_group, False)
+    else:
+        out = func_cls.apply(leaf, target, vocab_start, vocab_end, chunk_size, tp_group, False)
+    grad_seed = torch.linspace(0.5, 1.5, steps=out.numel(), device=out.device, dtype=out.dtype).reshape(out.shape)
+    out.backward(grad_seed)
+    return leaf.grad.detach()
+
+
+@pytest.mark.parametrize("chunk_size", [1, 3, 7, 16, 32, 64])
+@pytest.mark.parametrize("with_oov_targets", [False, True])
+def test_streamed_backward_bit_identical_to_non_chunked(tp_group, chunk_size, with_oov_targets):
+    """Chunked backward matches non-chunked exactly across chunk sizes and OOV targets."""
+    from skyrl.backends.skyrl_train.distributed.megatron.model_utils import (
+        ChunkedDistributedLogprob,
+        DistributedLogprob,
+    )
+
+    device = torch.device("cpu")
+    torch.manual_seed(0)
+
+    batch_size = 4
+    # Covers ragged chunks and chunk_size > seq_len.
+    seq_len = 30
+    vocab_size = 256
+
+    target_high = vocab_size + 64 if with_oov_targets else vocab_size
+
+    logits = torch.randn(batch_size, seq_len, vocab_size, dtype=torch.float32, device=device) * 2.0
+    target = torch.randint(0, target_high, (batch_size, seq_len), device=device, dtype=torch.long)
+
+    grad_ref = _backward_grad(DistributedLogprob, logits, target, 0, vocab_size, tp_group)
+    grad_chunk = _backward_grad(
+        ChunkedDistributedLogprob,
+        logits,
+        target,
+        0,
+        vocab_size,
+        tp_group,
+        chunk_size=chunk_size,
+    )
+
+    assert grad_chunk.shape == grad_ref.shape == logits.shape
+    assert grad_chunk.dtype == torch.float32
+    assert torch.equal(grad_chunk, grad_ref), "streamed chunked grad must be bit-identical to non-chunked grad"
+
+
+@pytest.mark.parametrize(
+    "case",
+    [
+        # (batch, seq_len, vocab, chunk_size, mask_mode)
+        # mask_mode: "default" (mixed in-vocab/OOV), "all_in" (no OOV), "all_out" (all OOV)
+        pytest.param((1, 1, 64, 4, "default"), id="seq1"),
+        pytest.param((2, 8, 64, 32, "all_in"), id="all_in_vocab"),
+        pytest.param((2, 8, 64, 32, "all_out"), id="all_out_vocab"),
+        pytest.param((2, 9, 8, 4, "default"), id="ragged_tiny_vocab"),
+    ],
+)
+def test_streamed_backward_edge_cases(tp_group, case):
+    """Covers short sequences, mask extremes, ragged chunks, and tiny vocab."""
+    from skyrl.backends.skyrl_train.distributed.megatron.model_utils import (
+        ChunkedDistributedLogprob,
+        DistributedLogprob,
+    )
+
+    batch_size, seq_len, vocab_size, chunk_size, mask_mode = case
+    device = torch.device("cpu")
+    torch.manual_seed(1)
+
+    logits = torch.randn(batch_size, seq_len, vocab_size, dtype=torch.float32, device=device) * 2.0
+    # "default" samples twice the vocab range so in-vocab and OOV targets are mixed.
+    target_low = vocab_size if mask_mode == "all_out" else 0
+    target_high = vocab_size if mask_mode == "all_in" else 2 * vocab_size
+    target = torch.randint(target_low, target_high, (batch_size, seq_len), device=device, dtype=torch.long)
+
+    grad_ref = _backward_grad(DistributedLogprob, logits, target, 0, vocab_size, tp_group)
+    grad_chunk = _backward_grad(
+        ChunkedDistributedLogprob,
+        logits,
+        target,
+        0,
+        vocab_size,
+        tp_group,
+        chunk_size=chunk_size,
+    )
+
+    assert torch.equal(grad_chunk, grad_ref), "streamed chunked grad must be bit-identical to non-chunked grad"
+
+
+def test_streamed_backward_covers_sequence_without_overlap(tp_group):
+    """Prime seq length checks that ragged tiling writes every buffer slice."""
+    from skyrl.backends.skyrl_train.distributed.megatron.model_utils import (
+        ChunkedDistributedLogprob,
+        DistributedLogprob,
+    )
+
+    device = torch.device("cpu")
+    torch.manual_seed(2)
+
+    batch_size = 3
+    seq_len = 17
+    vocab_size = 128
+    chunk_size = 5
+
+    logits = torch.randn(batch_size, seq_len, vocab_size, dtype=torch.float32, device=device) * 2.0
+    target = torch.randint(0, vocab_size, (batch_size, seq_len), device=device, dtype=torch.long)
+
+    grad_ref = _backward_grad(DistributedLogprob, logits, target, 0, vocab_size, tp_group)
+    grad_chunk = _backward_grad(
+        ChunkedDistributedLogprob,
+        logits,
+        target,
+        0,
+        vocab_size,
+        tp_group,
+        chunk_size=chunk_size,
+    )
+
+    assert grad_chunk.shape == grad_ref.shape == logits.shape
+    # Exact match catches unwritten torch.empty slices without relying on NaN/inf.
+    assert torch.equal(grad_chunk, grad_ref), "every sequence slice of the preallocated buffer must be written"
diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/megatron/test_chunked_logprob_backward.py b/tests/backends/skyrl_train/gpu/gpu_ci/megatron/test_chunked_logprob_backward.py
index cd3fc8630d..845a14aaac 100644
--- a/tests/backends/skyrl_train/gpu/gpu_ci/megatron/test_chunked_logprob_backward.py
+++ b/tests/backends/skyrl_train/gpu/gpu_ci/megatron/test_chunked_logprob_backward.py
@@ -12,7 +12,7 @@
     ChunkedDistributedLogprob,
     DistributedLogprob,
 )
-from skyrl.train.utils.utils import get_free_port
+from skyrl.backends.skyrl_train.distributed.utils import get_free_port
 
 
 @pytest.fixture(scope="module")
diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/megatron/test_chunked_logprob_backward_tp.py b/tests/backends/skyrl_train/gpu/gpu_ci/megatron/test_chunked_logprob_backward_tp.py
new file mode 100644
index 0000000000..72cfa5f134
--- /dev/null
+++ b/tests/backends/skyrl_train/gpu/gpu_ci/megatron/test_chunked_logprob_backward_tp.py
@@ -0,0 +1,146 @@
+"""TP>1 parity for streamed ``ChunkedDistributedLogprob.backward``.
+
+Spawns NCCL ranks, shards a shared full-vocab problem, and compares each local
+grad shard with a single-process fp32 reference.
+
+Requires ``TP`` free GPUs. It will NOT run on a CPU-only / macOS dev box.
+
+Run with (>=2 free GPUs; the ``tp_size=4`` case is skipped unless 4 are present):
+    uv run --isolated --extra dev --extra megatron -- \
+        pytest -s tests/backends/skyrl_train/gpu/gpu_ci/megatron/test_chunked_logprob_backward_tp.py
+"""
+
+import os
+
+import pytest
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+# Every rank rebuilds the same full-vocab problem.
+_SEED = 0
+_BATCH = 4
+_SEQ_LEN = 30
+_VOCAB = 256
+_CHUNK_SIZE = 7
+
+
+def _reference_full_grad(logits_fp32: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+    """Full-vocab reference grad for chosen-token logprobs."""
+    leaf = logits_fp32.detach().clone().requires_grad_(True)
+    log_probs = torch.log_softmax(leaf, dim=-1)
+    chosen = torch.gather(log_probs, -1, target.unsqueeze(-1)).squeeze(-1)
+    grad_seed = torch.linspace(0.5, 1.5, steps=chosen.numel(), device=chosen.device, dtype=chosen.dtype).reshape(
+        chosen.shape
+    )
+    (grad_seed * chosen).sum().backward()
+    return leaf.grad.detach()
+
+
+def _set_ci_nccl_env():
+    """Apply the gpu_ci NCCL env that ``mp.spawn`` children do not inherit."""
+    from skyrl.train.utils.utils import run_p2p_access_check
+
+    os.environ["NCCL_CUMEM_ENABLE"] = "0"
+    os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
+    os.environ["NVTE_FUSED_ATTN"] = "0"
+    # Avoid peer_access_supported here; it would spin up Ray.
+    if not run_p2p_access_check():
+        os.environ["NCCL_P2P_DISABLE"] = "1"
+        os.environ["NCCL_SHM_DISABLE"] = "1"
+
+
+def _tp_worker(rank: int, world_size: int, master_port: str, result_path: str):
+    """One TP rank: shard the vocab, run chunked forward+backward, save the local grad and inputs."""
+    import megatron.core.parallel_state as mpu
+
+    from skyrl.backends.skyrl_train.distributed.megatron.model_utils import (
+        ChunkedDistributedLogprob,
+    )
+
+    torch.cuda.set_device(rank)
+    # Set before init_process_group; spawned children miss the conftest runtime_env.
+    _set_ci_nccl_env()
+
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = master_port
+    os.environ["RANK"] = str(rank)
+    os.environ["WORLD_SIZE"] = str(world_size)
+    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
+
+    mpu.initialize_model_parallel(tensor_model_parallel_size=world_size)
+    tp_group = mpu.get_tensor_model_parallel_group()
+
+    device = torch.device("cuda", rank)
+    torch.manual_seed(_SEED)
+
+    # Build the same full-vocab problem on every rank, then keep this rank's shard.
+    assert _VOCAB % world_size == 0
+    partition = _VOCAB // world_size
+    vocab_start = rank * partition
+    vocab_end = vocab_start + partition
+
+    logits_full = (torch.randn(_BATCH, _SEQ_LEN, _VOCAB, dtype=torch.float32, device=device) * 2.0).contiguous()
+    target = torch.randint(0, _VOCAB, (_BATCH, _SEQ_LEN), device=device, dtype=torch.long)
+
+    leaf = logits_full[:, :, vocab_start:vocab_end].detach().clone().requires_grad_(True)
+    out = ChunkedDistributedLogprob.apply(leaf, target, vocab_start, vocab_end, _CHUNK_SIZE, tp_group, False)
+    grad_seed = torch.linspace(0.5, 1.5, steps=out.numel(), device=out.device, dtype=out.dtype).reshape(out.shape)
+    out.backward(grad_seed)
+
+    # The parent test process later builds the reference from rank 0's saved copy of these inputs.
+    torch.save(
+        {
+            "rank": rank,
+            "vocab_start": vocab_start,
+            "vocab_end": vocab_end,
+            "grad_local": leaf.grad.detach().cpu(),
+            "logits_full": logits_full.detach().cpu(),
+            "target": target.detach().cpu(),
+            "logprob_out": out.detach().cpu(),
+        },
+        f"{result_path}.{rank}",
+    )
+
+    mpu.destroy_model_parallel()
+    dist.destroy_process_group()
+
+
+@pytest.mark.megatron
+@pytest.mark.parametrize("tp_size", [2, 4])
+def test_streamed_chunked_backward_matches_reference_at_tp(tmp_path, tp_size):
+    """Per-rank streamed grads match the full-vocab reference slices."""
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA required for the vocab-parallel backward")
+    if torch.cuda.device_count() < tp_size:
+        pytest.skip(f"requires {tp_size} GPUs, found {torch.cuda.device_count()}")
+
+    from skyrl.backends.skyrl_train.distributed.utils import get_free_port
+
+    master_port = str(get_free_port())
+    result_path = str(tmp_path / "tp_grad")
+
+    mp.spawn(_tp_worker, args=(tp_size, master_port, result_path), nprocs=tp_size, join=True)
+
+    shards = [torch.load(f"{result_path}.{rank}") for rank in range(tp_size)]
+
+    # Sanity-check that the spawned ranks used the same problem.
+    logits_full = shards[0]["logits_full"]
+    target = shards[0]["target"]
+    for shard in shards[1:]:
+        assert torch.equal(shard["logits_full"], logits_full)
+        assert torch.equal(shard["target"], target)
+
+    grad_ref = _reference_full_grad(logits_full, target)
+
+    # Rank vocab slices must tile [0, _VOCAB) exactly once.
+    covered = torch.zeros(_VOCAB, dtype=torch.int64)
+    for shard in shards:
+        covered[shard["vocab_start"] : shard["vocab_end"]] += 1
+    assert torch.equal(covered, torch.ones(_VOCAB, dtype=torch.int64)), "vocab slices must tile [0, vocab) once"
+
+    # Allow fp32 reduction-order noise from the cross-rank all-reduce.
+    for shard in shards:
+        ref_slice = grad_ref[:, :, shard["vocab_start"] : shard["vocab_end"]]
+        assert shard["grad_local"].shape == ref_slice.shape
+        torch.testing.assert_close(shard["grad_local"], ref_slice, atol=1e-5, rtol=1e-4)