From c1464e2677eddc7d4571f862dea6ee136aee1da1 Mon Sep 17 00:00:00 2001 From: kurtislin Date: Fri, 3 Jul 2026 17:46:06 +0800 Subject: [PATCH 1/2] [fsdp] Exclude fully-padding microbatches from metric aggregation #1817 excluded fully-padding microbatches from metric aggregation for the Megatron backend. Apply the same skip to the shared forward_backward loops in worker.py used by FSDP, mirroring megatron_worker.py: padding microbatches still run forward/backward (per-rank collective counts stay equal), only the metric append is skipped. --- skyrl/backends/skyrl_train/workers/worker.py | 12 ++ .../test_forward_backward_padding_metrics.py | 130 ++++++++++++++++++ 2 files changed, 142 insertions(+) create mode 100644 tests/backends/skyrl_train/workers/test_forward_backward_padding_metrics.py diff --git a/skyrl/backends/skyrl_train/workers/worker.py b/skyrl/backends/skyrl_train/workers/worker.py index ed1a3e74b6..57d642c52c 100644 --- a/skyrl/backends/skyrl_train/workers/worker.py +++ b/skyrl/backends/skyrl_train/workers/worker.py @@ -772,6 +772,13 @@ def forward_backward( if "loss_fn_outputs" in metrics: all_loss_fn_outputs.extend(metrics.pop("loss_fn_outputs")) + # Skip fully-padding microbatches: their metrics (clip_ratio=0, policy_entropy=0, + # ...) are meaningless and would drag down the mean-reduced metrics. Summed + # metrics (e.g. policy_loss) are unaffected since padding contributes 0, but + # excluding them here keeps both reductions correct. + if experience.metadata and experience.metadata.get("is_padding_batch", False): + continue + for k, v in metrics.items(): all_metrics[k].append(v) @@ -1296,6 +1303,11 @@ def forward_backward(self, data: TrainingInputBatch) -> WorkerOutput: metrics = self._forward_backward_micro(experience) self._micro_batches_accumulated += 1 + # Skip fully-padding microbatches: their all-zero metrics (e.g. critic_loss=0) + # would drag down the mean-reduced metrics. Mirrors PolicyWorkerBase.forward_backward. + if experience.metadata and experience.metadata.get("is_padding_batch", False): + continue + for k, v in metrics.items(): all_metrics[k].append(v) diff --git a/tests/backends/skyrl_train/workers/test_forward_backward_padding_metrics.py b/tests/backends/skyrl_train/workers/test_forward_backward_padding_metrics.py new file mode 100644 index 0000000000..a1592c24c6 --- /dev/null +++ b/tests/backends/skyrl_train/workers/test_forward_backward_padding_metrics.py @@ -0,0 +1,130 @@ +""" +CPU-only tests that PolicyWorkerBase / CriticWorkerBase.forward_backward exclude +fully-padding microbatches from metric aggregation, mirroring the Megatron-side +behavior (megatron_worker.py). + +uv run --isolated --extra skyrl-train --extra dev pytest tests/backends/skyrl_train/workers/test_forward_backward_padding_metrics.py +""" + +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest +import torch + +from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch +from skyrl.backends.skyrl_train.workers.worker import CriticWorkerBase, PolicyWorkerBase +from skyrl.backends.skyrl_train.workers.worker_utils import TokenBasedBatchIterator + +MAX_TOKENS_PER_MICROBATCH = 15 + + +def _make_batch(seq_lens, num_actions=4): + """Dummy TrainingInputBatch with variable sequence lengths.""" + batch_size = len(seq_lens) + max_seq_len = max(seq_lens) + + sequences = torch.zeros((batch_size, max_seq_len), dtype=int, device="cpu") + attention_mask = torch.zeros((batch_size, max_seq_len), dtype=int, device="cpu") + for i, seq_len in enumerate(seq_lens): + sequences[i, :seq_len] = torch.randint(0, 100, (seq_len,), dtype=int, device="cpu") + attention_mask[i, :seq_len] = 1 + + data = TrainingInputBatch( + { + "sequences": sequences, + "attention_mask": attention_mask, + "action_log_probs": 0.4 * torch.ones((batch_size, num_actions), device="cpu"), + "base_action_log_probs": 0.3 * torch.ones((batch_size, num_actions), device="cpu"), + "values": 0.5 * torch.ones((batch_size, num_actions), device="cpu"), + "returns": 0.5 * torch.ones((batch_size, num_actions), device="cpu"), + "advantages": 0.6 * torch.ones((batch_size, num_actions), device="cpu"), + "loss_mask": torch.ones((batch_size, num_actions), dtype=int, device="cpu"), + "response_mask": torch.ones((batch_size, num_actions), dtype=int, device="cpu"), + } + ) + data.metadata = {"response_length": num_actions} + return data + + +@pytest.fixture +def force_one_padding_microbatch(monkeypatch): + """Force TokenBasedBatchIterator to add exactly one padding microbatch, as if another + DP rank had packed one more microbatch (dist is not initialized in CPU tests, so + _sync_num_microbatches would otherwise just return the local count).""" + original = TokenBasedBatchIterator._sync_num_microbatches + + def one_extra(self): + return original(self) + 1 + + monkeypatch.setattr(TokenBasedBatchIterator, "_sync_num_microbatches", one_extra) + + +def _identity_all_reduce(d, op=None, group=None): + return dict(d) + + +def test_policy_padding_microbatch_excluded_from_metrics(force_one_padding_microbatch): + """Mean-reduced metrics must ignore padding microbatches; summed metrics are unchanged. + + [10, 5], [10, 5] pack into 2 real microbatches at 15 tokens; the fixture forces one + extra padding microbatch. Without the skip, policy_entropy = (1.0 + 1.0 + 0.0) / 3. + """ + worker = PolicyWorkerBase.__new__(PolicyWorkerBase) + worker.cfg = SimpleNamespace( + micro_train_batch_size_per_gpu=1, + max_tokens_per_microbatch=MAX_TOKENS_PER_MICROBATCH, + algorithm=SimpleNamespace(policy_loss_type="regular"), + ) + worker.strategy = MagicMock() + worker.strategy.all_reduce.side_effect = _identity_all_reduce + worker.device_mesh = MagicMock() + + padding_flags = [] + + def fake_forward_backward_micro(experience, microbatch_weight, loss_fn=None, loss_fn_config=None): + is_padding = bool(experience.metadata and experience.metadata.get("is_padding_batch", False)) + padding_flags.append(is_padding) + if is_padding: + # A fully-padding microbatch has an all-zero loss mask, so its masked-mean + # metrics come out as exactly 0 (see TokenBasedBatchIterator._create_padding_microbatch). + return {"policy_entropy": 0.0, "policy_loss": 0.0} + return {"policy_entropy": 1.0, "policy_loss": 0.5} + + worker._forward_backward_micro = fake_forward_backward_micro + + out = worker.forward_backward(_make_batch([10, 10, 5, 5])) + + # The padding microbatch still ran forward/backward (collective parity across DP ranks)... + assert len(padding_flags) == 3 and padding_flags.count(True) == 1 + # ...but is excluded from mean-reduced metrics (2/3 without the skip). + assert out.metrics["policy_entropy"] == pytest.approx(1.0) + # Summed metrics are unaffected either way: padding contributes 0 to a sum. + assert out.metrics["policy_loss"] == pytest.approx(1.0) + # Diagnostics still count the padding microbatch. + assert out.metrics["num_microbatches"] == 3.0 + assert out.metrics["num_padding_microbatches"] == 1.0 + + +def test_critic_padding_microbatch_excluded_from_metrics(force_one_padding_microbatch): + """critic_loss is mean-reduced on the critic path (reduce_metrics without + sum_loss_metrics), so a padding microbatch's 0.0 would directly bias it: + (0.5 + 0.5 + 0.0) / 3 without the skip.""" + worker = CriticWorkerBase.__new__(CriticWorkerBase) + worker.cfg = SimpleNamespace( + micro_train_batch_size_per_gpu=1, + max_tokens_per_microbatch=MAX_TOKENS_PER_MICROBATCH, + ) + worker.strategy = MagicMock() + worker.strategy.all_reduce.side_effect = _identity_all_reduce + + def fake_forward_backward_micro(experience, microbatch_weight=None): + if experience.metadata and experience.metadata.get("is_padding_batch", False): + return {"critic_loss": 0.0} + return {"critic_loss": 0.5} + + worker._forward_backward_micro = fake_forward_backward_micro + + out = worker.forward_backward(_make_batch([10, 10, 5, 5])) + + assert out.metrics["critic_loss"] == pytest.approx(0.5) From 8ffaf830cd320b9db121c3a53149feb33d54de84 Mon Sep 17 00:00:00 2001 From: kurtislin Date: Fri, 3 Jul 2026 18:09:45 +0800 Subject: [PATCH 2/2] [fsdp] Exclude padding microbatches' loss_fn_outputs too Move the padding skip above the loss_fn_outputs extraction so dummy per-sample entries from padding microbatches are not returned (review feedback). Test asserts only real-sample outputs remain. --- skyrl/backends/skyrl_train/workers/worker.py | 16 +++++++++------- .../test_forward_backward_padding_metrics.py | 15 ++++++++++++--- 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/skyrl/backends/skyrl_train/workers/worker.py b/skyrl/backends/skyrl_train/workers/worker.py index 57d642c52c..6ee8025d0c 100644 --- a/skyrl/backends/skyrl_train/workers/worker.py +++ b/skyrl/backends/skyrl_train/workers/worker.py @@ -768,17 +768,19 @@ def forward_backward( experience, microbatch_weight, loss_fn=loss_fn, loss_fn_config=loss_fn_config ) + # Skip fully-padding microbatches: their loss_fn_outputs are dummy entries + # for samples that don't exist, and their metrics (clip_ratio=0, + # policy_entropy=0, ...) are meaningless and would drag down the + # mean-reduced metrics. Summed metrics (e.g. policy_loss) are unaffected + # since padding contributes 0, but excluding them here keeps both + # reductions correct. + if experience.metadata and experience.metadata.get("is_padding_batch", False): + continue + # Extract loss_fn_outputs before reduce_metrics (it's not a scalar metric) if "loss_fn_outputs" in metrics: all_loss_fn_outputs.extend(metrics.pop("loss_fn_outputs")) - # Skip fully-padding microbatches: their metrics (clip_ratio=0, policy_entropy=0, - # ...) are meaningless and would drag down the mean-reduced metrics. Summed - # metrics (e.g. policy_loss) are unaffected since padding contributes 0, but - # excluding them here keeps both reductions correct. - if experience.metadata and experience.metadata.get("is_padding_batch", False): - continue - for k, v in metrics.items(): all_metrics[k].append(v) diff --git a/tests/backends/skyrl_train/workers/test_forward_backward_padding_metrics.py b/tests/backends/skyrl_train/workers/test_forward_backward_padding_metrics.py index a1592c24c6..fbbd4cf125 100644 --- a/tests/backends/skyrl_train/workers/test_forward_backward_padding_metrics.py +++ b/tests/backends/skyrl_train/workers/test_forward_backward_padding_metrics.py @@ -87,9 +87,14 @@ def fake_forward_backward_micro(experience, microbatch_weight, loss_fn=None, los padding_flags.append(is_padding) if is_padding: # A fully-padding microbatch has an all-zero loss mask, so its masked-mean - # metrics come out as exactly 0 (see TokenBasedBatchIterator._create_padding_microbatch). - return {"policy_entropy": 0.0, "policy_loss": 0.0} - return {"policy_entropy": 1.0, "policy_loss": 0.5} + # metrics come out as exactly 0, and its loss_fn_outputs are dummy entries + # (see TokenBasedBatchIterator._create_padding_microbatch). + return {"policy_entropy": 0.0, "policy_loss": 0.0, "loss_fn_outputs": [{"logprobs": [0.0]}]} + return { + "policy_entropy": 1.0, + "policy_loss": 0.5, + "loss_fn_outputs": [{"logprobs": [1.0]}, {"logprobs": [1.0]}], + } worker._forward_backward_micro = fake_forward_backward_micro @@ -104,6 +109,10 @@ def fake_forward_backward_micro(experience, microbatch_weight, loss_fn=None, los # Diagnostics still count the padding microbatch. assert out.metrics["num_microbatches"] == 3.0 assert out.metrics["num_padding_microbatches"] == 1.0 + # loss_fn_outputs from the padding microbatch are excluded too: 2 real microbatches + # x 2 samples each remain, with no dummy [0.0] entry. + assert len(out.loss_fn_outputs) == 4 + assert all(entry == {"logprobs": [1.0]} for entry in out.loss_fn_outputs) def test_critic_padding_microbatch_excluded_from_metrics(force_one_padding_microbatch):