Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions skyrl/backends/skyrl_train/distributed/megatron/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,10 +214,16 @@ def backward(
seq_size = int(vocab_parallel_logits.shape[1])
num_chunks = (seq_size + chunk_size - 1) // chunk_size

all_grad_input = []

batch_size = int(vocab_parallel_logits.shape[0])

# Stream chunk grads into a preallocated buffer instead of keeping every
# chunk alive until torch.cat. Each chunk owns one contiguous seq slice.
grad_input = torch.empty(
(batch_size, seq_size, partition_vocab_size),
dtype=torch.float32,
device=vocab_parallel_logits.device,
)

for chunk_idx in range(num_chunks):
chunk_start = chunk_idx * chunk_size
chunk_end = min(seq_size, (chunk_idx + 1) * chunk_size)
Expand Down Expand Up @@ -251,15 +257,14 @@ def backward(
flat_chosen = flat_idx.masked_select(valid_mask.reshape(-1)) + chunk_masked_target.masked_select(valid_mask)

# `neg` is zero-copy; the subsequent mul_ writes in place.
grad_input = softmax_output.neg_()
grad_input.mul_(chunk_grad_output.unsqueeze(-1))
chunk_grad_input = softmax_output.neg_()
chunk_grad_input.mul_(chunk_grad_output.unsqueeze(-1))

grad_output_selected = chunk_grad_output.masked_select(valid_mask)
grad_input.view(-1).scatter_add_(0, flat_chosen, grad_output_selected)

all_grad_input.append(grad_input)
chunk_grad_input.view(-1).scatter_add_(0, flat_chosen, grad_output_selected)

grad_input = torch.cat(all_grad_input, dim=1)
# Write this chunk into its non-overlapping sequence slice.
grad_input[:, chunk_start:chunk_end, :] = chunk_grad_input

# if you add an argument to the forward method, then you must add a corresponding None here
return grad_input, None, None, None, None, None, None
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
"""CPU parity tests for streamed ``ChunkedDistributedLogprob.backward``.

The backward path now writes each chunk directly into a preallocated fp32
``[B, S, V//TP]`` grad buffer. These tests compare against ``DistributedLogprob``
on CPU; TP>1 coverage lives in
``tests/backends/skyrl_train/gpu/gpu_ci/megatron/test_chunked_logprob_backward_tp.py``.

Run with:
uv run --isolated --extra skyrl-train --extra dev -- pytest -s \
tests/backends/skyrl_train/distributed/test_chunked_logprob_backward_streaming.py
"""

import os
import sys
from types import ModuleType

import pytest
import torch
import torch.distributed as dist

from skyrl.backends.skyrl_train.distributed.utils import get_free_port

# CPU CI has no megatron-core, so model_utils is imported against empty
# placeholder modules. The module-scoped fixture below swaps them into
# ``sys.modules`` and puts back whatever was registered there before.

_MEGATRON_MODULES = [
    "megatron",
    "megatron.core",
    "megatron.core.parallel_state",
]

# One empty placeholder module per dotted name, keyed by that name.
_mock_modules: dict[str, ModuleType] = {_dotted: ModuleType(_dotted) for _dotted in _MEGATRON_MODULES}

# Expose the leaf placeholder as an attribute of its parent package so
# ``megatron.core.parallel_state`` also resolves through attribute access.
setattr(
    _mock_modules["megatron.core"],
    "parallel_state",
    _mock_modules["megatron.core.parallel_state"],
)


@pytest.fixture(scope="module", autouse=True)
def _stub_megatron_modules():
    """Install the mock ``megatron`` modules for this module only.

    Every name in ``_MEGATRON_MODULES`` is pointed at its placeholder from
    ``_mock_modules`` while the tests of this module run. On teardown each
    name is put back exactly as it was: names that were registered get their
    previous value again, names that were not registered are removed.

    Yields:
        None. The fixture exists only for its ``sys.modules`` side effect.
    """
    # Record only the names that are really present. Using ``sys.modules.get``
    # here would make an absent entry indistinguishable from one explicitly
    # set to ``None`` (which blocks the import), and the teardown would then
    # drop such an entry instead of restoring it.
    saved = {name: sys.modules[name] for name in _MEGATRON_MODULES if name in sys.modules}
    for name in _MEGATRON_MODULES:
        sys.modules[name] = _mock_modules[name]
    try:
        yield
    finally:
        for name in _MEGATRON_MODULES:
            if name in saved:
                sys.modules[name] = saved[name]
            else:
                sys.modules.pop(name, None)


@pytest.fixture(scope="module")
def tp_group():
    """Yield a one-rank gloo process group to use as the TP group.

    If ``torch.distributed`` is already initialized, the existing default
    group is reused and left alone on teardown. Otherwise the rendezvous
    environment variables are set, a single-rank gloo group is created, and
    that group is destroyed again once the module's tests are done.
    """
    owns_group = not dist.is_initialized()
    if owns_group:
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(get_free_port())
        os.environ.update({"RANK": "0", "WORLD_SIZE": "1"})
        dist.init_process_group(backend="gloo", rank=0, world_size=1)

    yield dist.group.WORLD

    # Only tear down a group this fixture created, and only if it still exists.
    if owns_group and dist.is_initialized():
        dist.destroy_process_group()
Comment thread
dyurk-lila marked this conversation as resolved.


def _backward_grad(func_cls, logits, target, vocab_start, vocab_end, tp_group, *, chunk_size=None):
    """Run ``func_cls.apply`` forward and backward and return the grad w.r.t. the logits.

    The input ``logits`` tensor is not touched: a detached clone is used as
    the autograd leaf. The upstream gradient is a ``linspace`` ramp from 0.5
    to 1.5 over the output elements, so every output position gets a
    different weight.

    Args:
        func_cls: Object whose ``apply`` is called with the arguments below.
        logits: Logits tensor to differentiate with respect to.
        target: Target ids forwarded to ``apply``.
        vocab_start: Forwarded to ``apply`` unchanged.
        vocab_end: Forwarded to ``apply`` unchanged.
        tp_group: Process group forwarded to ``apply``.
        chunk_size: When not ``None``, inserted into the ``apply`` call
            immediately before ``tp_group``.

    Returns:
        The detached gradient accumulated on the cloned leaf.
    """
    leaf = logits.detach().clone().requires_grad_(True)

    # The chunked signature differs only by one extra positional argument.
    chunk_args = () if chunk_size is None else (chunk_size,)
    out = func_cls.apply(leaf, target, vocab_start, vocab_end, *chunk_args, tp_group, False)

    ramp = torch.linspace(0.5, 1.5, steps=out.numel(), device=out.device, dtype=out.dtype)
    out.backward(ramp.reshape(out.shape))
    return leaf.grad.detach()


@pytest.mark.parametrize("chunk_size", [1, 3, 7, 16, 32, 64])
@pytest.mark.parametrize("with_oov_targets", [False, True])
def test_streamed_backward_bit_identical_to_non_chunked(tp_group, chunk_size, with_oov_targets):
    """The chunked grad equals the non-chunked grad bit for bit.

    Swept over several chunk sizes, with and without target ids that lie
    beyond the ``[0, vocab_size)`` range handed to both functions.
    """
    from skyrl.backends.skyrl_train.distributed.megatron.model_utils import (
        ChunkedDistributedLogprob,
        DistributedLogprob,
    )

    cpu = torch.device("cpu")
    torch.manual_seed(0)

    # A sequence of 30 leaves a short last chunk for sizes 7 and 16 and fits
    # into a single chunk for sizes 32 and 64.
    shape = (4, 30)
    vocab_size = 256
    # Target ids at or above vocab_size are the out-of-vocab ones.
    oov_margin = 64 if with_oov_targets else 0

    logits = torch.randn(*shape, vocab_size, dtype=torch.float32, device=cpu) * 2.0
    target = torch.randint(0, vocab_size + oov_margin, shape, device=cpu, dtype=torch.long)

    shared_args = (logits, target, 0, vocab_size, tp_group)
    grad_ref = _backward_grad(DistributedLogprob, *shared_args)
    grad_chunk = _backward_grad(ChunkedDistributedLogprob, *shared_args, chunk_size=chunk_size)

    assert grad_chunk.shape == grad_ref.shape == logits.shape
    assert grad_chunk.dtype == torch.float32
    assert torch.equal(grad_chunk, grad_ref), "streamed chunked grad must be bit-identical to non-chunked grad"


@pytest.mark.parametrize(
    "case",
    [
        # (batch, seq_len, vocab, chunk_size, mask_mode)
        # mask_mode:
        #   "default" - ids drawn from [0, 2 * vocab), so in- and out-of-vocab mix
        #   "all_in"  - every id inside [0, vocab)
        #   "all_out" - every id outside [0, vocab)
        pytest.param((1, 1, 64, 4, "default"), id="seq1"),
        pytest.param((2, 8, 64, 32, "all_in"), id="all_in_vocab"),
        pytest.param((2, 8, 64, 32, "all_out"), id="all_out_vocab"),
        pytest.param((2, 9, 8, 4, "default"), id="ragged_tiny_vocab"),
    ],
)
def test_streamed_backward_edge_cases(tp_group, case):
    """Chunked and non-chunked grads agree exactly on small corner-case shapes.

    Covers a length-1 sequence, targets that are all inside or all outside
    the vocab range, and a ragged last chunk over a tiny vocab.

    Args:
        tp_group: Process group fixture forwarded to both functions.
        case: ``(batch, seq_len, vocab, chunk_size, mask_mode)`` tuple.

    Raises:
        ValueError: If ``mask_mode`` is not one of the documented modes.
    """
    from skyrl.backends.skyrl_train.distributed.megatron.model_utils import (
        ChunkedDistributedLogprob,
        DistributedLogprob,
    )

    batch_size, seq_len, vocab_size, chunk_size, mask_mode = case
    device = torch.device("cpu")
    torch.manual_seed(1)

    logits = torch.randn(batch_size, seq_len, vocab_size, dtype=torch.float32, device=device) * 2.0
    target_shape = (batch_size, seq_len)
    if mask_mode == "all_out":
        target = torch.full(target_shape, vocab_size + 5, device=device, dtype=torch.long)
    elif mask_mode == "all_in":
        target = torch.randint(0, vocab_size, target_shape, device=device, dtype=torch.long)
    elif mask_mode == "default":
        # Previously this branch sampled from [0, vocab_size) exactly like
        # "all_in", so the "mixed" cases never contained an out-of-vocab id.
        # Doubling the range makes each id out-of-vocab with probability 1/2.
        target = torch.randint(0, 2 * vocab_size, target_shape, device=device, dtype=torch.long)
    else:
        raise ValueError(f"unknown mask_mode: {mask_mode!r}")

    grad_ref = _backward_grad(DistributedLogprob, logits, target, 0, vocab_size, tp_group)
    grad_chunk = _backward_grad(
        ChunkedDistributedLogprob,
        logits,
        target,
        0,
        vocab_size,
        tp_group,
        chunk_size=chunk_size,
    )

    assert torch.equal(grad_chunk, grad_ref), "streamed chunked grad must be bit-identical to non-chunked grad"


def test_streamed_backward_covers_sequence_without_overlap(tp_group):
    """A prime sequence length with a ragged last chunk still fills the whole grad.

    With ``seq_len=17`` and ``chunk_size=5`` the chunks are 5, 5, 5 and 2
    positions long. The chunked grad must equal the non-chunked one
    everywhere.
    """
    from skyrl.backends.skyrl_train.distributed.megatron.model_utils import (
        ChunkedDistributedLogprob,
        DistributedLogprob,
    )

    cpu = torch.device("cpu")
    torch.manual_seed(2)

    shape = (3, 17)
    vocab_size = 128

    logits = torch.randn(*shape, vocab_size, dtype=torch.float32, device=cpu) * 2.0
    target = torch.randint(0, vocab_size, shape, device=cpu, dtype=torch.long)

    shared_args = (logits, target, 0, vocab_size, tp_group)
    grad_ref = _backward_grad(DistributedLogprob, *shared_args)
    grad_chunk = _backward_grad(ChunkedDistributedLogprob, *shared_args, chunk_size=5)

    assert grad_chunk.shape == grad_ref.shape == logits.shape
    # Comparing for exact equality flags any slice of the torch.empty buffer
    # that was never written, without depending on it holding NaN/inf.
    assert torch.equal(grad_chunk, grad_ref), "every sequence slice of the preallocated buffer must be written"
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
ChunkedDistributedLogprob,
DistributedLogprob,
)
from skyrl.train.utils.utils import get_free_port
from skyrl.backends.skyrl_train.distributed.utils import get_free_port


@pytest.fixture(scope="module")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
"""TP>1 parity for streamed ``ChunkedDistributedLogprob.backward``.

Spawns NCCL ranks, shards a shared full-vocab problem, and compares each local
grad shard with a single-process fp32 reference.

Requires ``TP`` free GPUs. It will NOT run on a CPU-only / macOS dev box.

Run with (>=2 free GPUs; the ``tp_size=4`` case is skipped unless 4 are present):
uv run --isolated --extra dev --extra megatron -- \
pytest -s tests/backends/skyrl_train/gpu/gpu_ci/megatron/test_chunked_logprob_backward_tp.py
"""

import os

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

# Every rank rebuilds the same full-vocab problem.
# Seed passed to torch.manual_seed in each spawned worker.
_SEED = 0
# Leading two dimensions of the logits and target tensors.
_BATCH = 4
_SEQ_LEN = 30
# Full vocabulary size; the worker asserts it divides evenly by the TP size.
_VOCAB = 256
# Chunk size handed to ChunkedDistributedLogprob.apply. 30 is not a multiple
# of 7, so the last chunk is shorter than the others.
_CHUNK_SIZE = 7


def _reference_full_grad(logits_fp32: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Return d(weighted chosen-token logprob sum)/d(logits) over the full vocab.

    The logprob of each target id is taken from ``log_softmax`` over the last
    dimension, weighted by a ``linspace`` ramp from 0.5 to 1.5 across all
    positions, and summed. The gradient of that sum with respect to a
    detached clone of ``logits_fp32`` is returned.
    """
    leaf = logits_fp32.detach().clone().requires_grad_(True)
    picked = torch.log_softmax(leaf, dim=-1).gather(-1, target.unsqueeze(-1)).squeeze(-1)

    ramp = torch.linspace(0.5, 1.5, steps=picked.numel(), device=picked.device, dtype=picked.dtype)
    weighted_total = (ramp.reshape(picked.shape) * picked).sum()
    weighted_total.backward()
    return leaf.grad.detach()


def _set_ci_nccl_env():
    """Set the gpu_ci NCCL environment variables in the current process.

    Needed because ``mp.spawn`` children do not pick these settings up on
    their own. Three variables are always set; the P2P and SHM transports are
    additionally disabled when ``run_p2p_access_check`` returns a falsy value.
    """
    from skyrl.train.utils.utils import run_p2p_access_check

    os.environ.update(
        {
            "NCCL_CUMEM_ENABLE": "0",
            "CUDA_DEVICE_MAX_CONNECTIONS": "1",
            "NVTE_FUSED_ATTN": "0",
        }
    )

    # Avoid peer_access_supported here; it would spin up Ray.
    if run_p2p_access_check():
        return
    os.environ["NCCL_P2P_DISABLE"] = "1"
    os.environ["NCCL_SHM_DISABLE"] = "1"


def _tp_worker(rank: int, world_size: int, master_port: str, result_path: str):
    """Entry point for one spawned TP rank.

    Joins an NCCL process group, builds the seeded full-vocab problem, runs
    ``ChunkedDistributedLogprob`` forward and backward on this rank's vocab
    slice, and writes the results to ``f"{result_path}.{rank}"``.

    Args:
        rank: Index of this process; also used as its CUDA device index.
        world_size: Number of spawned ranks, used as the tensor-parallel size.
        master_port: Rendezvous port, exported as ``MASTER_PORT``.
        result_path: Path prefix of the per-rank output file.
    """
    import megatron.core.parallel_state as mpu

    from skyrl.backends.skyrl_train.distributed.megatron.model_utils import (
        ChunkedDistributedLogprob,
    )

    torch.cuda.set_device(rank)
    # Must run before init_process_group: spawned children do not get the
    # conftest runtime_env.
    _set_ci_nccl_env()

    os.environ.update(
        {
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": master_port,
            "RANK": str(rank),
            "WORLD_SIZE": str(world_size),
        }
    )
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)

    mpu.initialize_model_parallel(tensor_model_parallel_size=world_size)
    tp_group = mpu.get_tensor_model_parallel_group()

    device = torch.device("cuda", rank)
    torch.manual_seed(_SEED)

    # Each rank owns one equal, contiguous slice of the vocab dimension.
    assert _VOCAB % world_size == 0
    shard_width = _VOCAB // world_size
    vocab_start = rank * shard_width
    vocab_end = vocab_start + shard_width

    # All ranks draw the same seeded full-vocab tensors (logits first, then
    # targets) and then keep only their own slice.
    logits_full = (torch.randn(_BATCH, _SEQ_LEN, _VOCAB, dtype=torch.float32, device=device) * 2.0).contiguous()
    target = torch.randint(0, _VOCAB, (_BATCH, _SEQ_LEN), device=device, dtype=torch.long)

    leaf = logits_full[:, :, vocab_start:vocab_end].detach().clone().requires_grad_(True)
    out = ChunkedDistributedLogprob.apply(leaf, target, vocab_start, vocab_end, _CHUNK_SIZE, tp_group, False)
    ramp = torch.linspace(0.5, 1.5, steps=out.numel(), device=out.device, dtype=out.dtype)
    out.backward(ramp.reshape(out.shape))

    # The parent process rebuilds the reference grad from the saved inputs.
    payload = {
        "rank": rank,
        "vocab_start": vocab_start,
        "vocab_end": vocab_end,
        "grad_local": leaf.grad.detach().cpu(),
        "logits_full": logits_full.detach().cpu(),
        "target": target.detach().cpu(),
        "logprob_out": out.detach().cpu(),
    }
    torch.save(payload, f"{result_path}.{rank}")

    mpu.destroy_model_parallel()
    dist.destroy_process_group()


@pytest.mark.megatron
@pytest.mark.parametrize("tp_size", [2, 4])
def test_streamed_chunked_backward_matches_reference_at_tp(tmp_path, tp_size):
    """Each rank's streamed grad shard matches its slice of the full-vocab reference.

    Spawns ``tp_size`` workers, loads the file each one writes, and compares
    every local grad against the matching vocab slice of a single-process
    reference grad. Skipped without CUDA or with fewer than ``tp_size`` GPUs.
    """
    if not torch.cuda.is_available():
        pytest.skip("CUDA required for the vocab-parallel backward")
    if torch.cuda.device_count() < tp_size:
        pytest.skip(f"requires {tp_size} GPUs, found {torch.cuda.device_count()}")

    from skyrl.backends.skyrl_train.distributed.utils import get_free_port

    result_path = str(tmp_path / "tp_grad")
    spawn_args = (tp_size, str(get_free_port()), result_path)
    mp.spawn(_tp_worker, args=spawn_args, nprocs=tp_size, join=True)

    dumps = [torch.load(f"{result_path}.{rank}") for rank in range(tp_size)]

    # All ranks must have generated the same inputs as rank 0.
    first, *others = dumps
    logits_full = first["logits_full"]
    target = first["target"]
    for dump in others:
        assert torch.equal(dump["logits_full"], logits_full)
        assert torch.equal(dump["target"], target)

    grad_ref = _reference_full_grad(logits_full, target)

    # Count how often each vocab index is claimed; every index must be hit once.
    hits = torch.zeros(_VOCAB, dtype=torch.int64)
    for dump in dumps:
        hits[dump["vocab_start"] : dump["vocab_end"]] += 1
    assert torch.equal(hits, torch.ones(_VOCAB, dtype=torch.int64)), "vocab slices must tile [0, vocab) once"

    # The tolerance absorbs fp32 reduction-order noise from the cross-rank
    # all-reduce.
    for dump in dumps:
        lo, hi = dump["vocab_start"], dump["vocab_end"]
        expected = grad_ref[:, :, lo:hi]
        assert dump["grad_local"].shape == expected.shape
        torch.testing.assert_close(dump["grad_local"], expected, atol=1e-5, rtol=1e-4)
Loading