Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions docs/content/docs/configuration/config.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,23 @@ Some rules for configuring these parameters:
`optimizer_config_kwargs.use_precision_aware_optimizer=true` can cause checkpointing to fail. See: https://github.com/nvidia/megatron-lm/issues/1820. We recommend leaving this setting to `false`.
</Callout>

`optimizer_config_kwargs` accepts string values for Megatron `*_dtype` fields:

```yaml
optimizer_config_kwargs:
use_precision_aware_optimizer: true
exp_avg_dtype: bf16
exp_avg_sq_dtype: fp8
main_params_dtype: fp32
```

Accepted names are case-insensitive: `fp32` (`float32`, `float`), `fp16` (`float16`, `half`), `bf16` (`bfloat16`), and `fp8` (`float8`, `uint8`). `fp8` maps to `torch.uint8`, matching TransformerEngine optimizer state storage.

Field-specific checks:

- `main_params_dtype` (master weights): `fp32`, `fp16`
- `exp_avg_dtype` / `exp_avg_sq_dtype`: `fp32`, `fp16`, `bf16`, `fp8`

## Optimizer Configuration

For both the critic and policy model, we provide a common optimizer configuration
Expand Down
14 changes: 13 additions & 1 deletion docs/content/docs/examples/megatron.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,22 @@ empty_cuda_cache: true

These default values can be overridden by passing in the corresponding arguments to `trainer.policy.megatron_config` in the launch script.

`optimizer_config_kwargs` can set optimizer-state dtypes from YAML:

```yaml
optimizer_config_kwargs:
use_precision_aware_optimizer: true
exp_avg_dtype: bf16
exp_avg_sq_dtype: fp8
main_params_dtype: fp32
```

See the [Megatron configuration guide](../configuration/config#megatron-configuration) for accepted aliases and per-field checks.

## Parallelism Resources

Understanding and configuring parallelism strategies for large models can be challenging.
Some helpful resources for understanding and tuning large scale parallelism strategies can be found at the [Huggingface Ultra-Scale Playbook](https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=finding_the_best_training_configuration),
the [The Mesh Parallelism Zoo](https://blog.ezyang.com/2025/08/the-parallelism-mesh-zoo/), and the [Visualizing 6-D Parallelism](https://main-horse.github.io/posts/visualizing-6d).

Below, we show a diagram displaying how all 5 parallelism strategies - tensor, pipeline, context, expert, and data parallelism - can be utilized in SkyRL, as well as how dispatching data across these parallel groups works.
Below, we show a diagram displaying how all 5 parallelism strategies - tensor, pipeline, context, expert, and data parallelism - can be utilized in SkyRL, as well as how dispatching data across these parallel groups works.
6 changes: 5 additions & 1 deletion skyrl/backends/skyrl_train/distributed/megatron/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler
from omegaconf import DictConfig

from skyrl.backends.skyrl_train.distributed.megatron.optimizer_dtype import (
coerce_optimizer_dtype_kwargs,
)
from skyrl.train.config import OptimizerConfig as SkyRLOptimizerConfig


Expand All @@ -45,7 +48,8 @@ def init_megatron_optim_config(
"params_dtype": torch.bfloat16,
"use_distributed_optimizer": True,
}
optim_args.update(optimizer_config_kwargs)
# YAML dtype overrides arrive as strings; Megatron expects torch.dtype.
optim_args.update(coerce_optimizer_dtype_kwargs(optimizer_config_kwargs))

config = OptimizerConfig(**optim_args)
return config
Expand Down
64 changes: 64 additions & 0 deletions skyrl/backends/skyrl_train/distributed/megatron/optimizer_dtype.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""Torch-only coercion for Megatron optimizer dtype kwargs."""

from typing import Any, Dict, Set

import torch

# Megatron short names plus common YAML spellings. TE stores FP8 optimizer state
# as uint8, matching Megatron-LM's dtype map.
_DTYPE_NAME_TO_TORCH: Dict[str, torch.dtype] = {
"fp32": torch.float32,
"float32": torch.float32,
"float": torch.float32,
"bf16": torch.bfloat16,
"bfloat16": torch.bfloat16,
"fp16": torch.float16,
"float16": torch.float16,
"half": torch.float16,
"fp8": torch.uint8,
"float8": torch.uint8,
"uint8": torch.uint8,
}

# Only TE FusedAdam-backed fields get field-specific checks. ``main_grads_dtype``
# is not forwarded at the pinned megatron-core rev, so it is coerced only and
# left to ``OptimizerConfig.__post_init__``.
_LEGAL_FIELD_DTYPES: Dict[str, Set[torch.dtype]] = {
"main_params_dtype": {torch.float32, torch.float16},
"exp_avg_dtype": {torch.float32, torch.bfloat16, torch.float16, torch.uint8},
"exp_avg_sq_dtype": {torch.float32, torch.bfloat16, torch.float16, torch.uint8},
}


def coerce_optimizer_dtype_kwargs(optimizer_config_kwargs: Dict[str, Any] | None) -> Dict[str, Any]:
"""Return kwargs with recognized ``*_dtype`` strings converted to ``torch.dtype``."""
if optimizer_config_kwargs is None:
return {}

coerced: Dict[str, Any] = {}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If optimizer_config_kwargs is None (e.g., if it is omitted or set to null in the YAML configuration), calling .items() on it will raise an AttributeError. Adding a defensive None check at the beginning of the function ensures robustness and prevents runtime crashes.

    if optimizer_config_kwargs is None:
        return {}
    coerced: Dict[str, Any] = {}

for key, value in optimizer_config_kwargs.items():
if not key.endswith("_dtype"):
coerced[key] = value
continue

if isinstance(value, torch.dtype):
dtype = value
elif isinstance(value, str):
name = value.strip().lower()
if name not in _DTYPE_NAME_TO_TORCH:
raise ValueError(
f"Unrecognized dtype name {value!r} for optimizer kwarg {key!r}. "
f"Expected one of {sorted(_DTYPE_NAME_TO_TORCH)} or a torch.dtype."
)
dtype = _DTYPE_NAME_TO_TORCH[name]
else:
# Let Megatron validate non-string, non-dtype values.
coerced[key] = value
continue

legal = _LEGAL_FIELD_DTYPES.get(key)
if legal is not None and dtype not in legal:
legal_names = sorted({n for n, d in _DTYPE_NAME_TO_TORCH.items() if d in legal})
raise ValueError(f"Illegal dtype {dtype} for optimizer kwarg {key!r}; legal values are {legal_names}.")
coerced[key] = dtype
return coerced
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
"""Tests for Megatron optimizer dtype coercion."""

import sys

import pytest
import torch

from skyrl.backends.skyrl_train.distributed.megatron.optimizer_dtype import (
coerce_optimizer_dtype_kwargs,
)

_has_megatron = "megatron" in sys.modules or __import__("importlib").util.find_spec("megatron") is not None


class TestCoerceOptimizerDtypeKwargs:
def _coerce(self, kwargs: dict | None) -> dict:
return coerce_optimizer_dtype_kwargs(kwargs)

@pytest.mark.parametrize(
"name,expected",
[
("bf16", torch.bfloat16),
("bfloat16", torch.bfloat16),
("fp16", torch.float16),
("float16", torch.float16),
("half", torch.float16),
("fp32", torch.float32),
("float32", torch.float32),
("float", torch.float32),
("fp8", torch.uint8),
("float8", torch.uint8),
("uint8", torch.uint8),
],
)
def test_string_names_coerce_to_torch_dtype(self, name, expected):
out = self._coerce({"exp_avg_dtype": name})
assert out["exp_avg_dtype"] == expected
assert isinstance(out["exp_avg_dtype"], torch.dtype)

def test_fp8_maps_to_uint8(self):
out = self._coerce({"exp_avg_sq_dtype": "fp8"})
assert out["exp_avg_sq_dtype"] is torch.uint8

def test_case_and_whitespace_insensitive(self):
out = self._coerce({"exp_avg_dtype": " BF16 "})
assert out["exp_avg_dtype"] is torch.bfloat16

def test_already_torch_dtype_passes_through(self):
out = self._coerce({"exp_avg_dtype": torch.bfloat16})
assert out["exp_avg_dtype"] is torch.bfloat16

def test_main_params_dtype_accepts_fp32_and_fp16(self):
assert self._coerce({"main_params_dtype": "fp32"})["main_params_dtype"] is torch.float32
assert self._coerce({"main_params_dtype": "fp16"})["main_params_dtype"] is torch.float16

@pytest.mark.parametrize("bad", ["bf16", "fp8"])
def test_main_params_dtype_rejects_bf16_and_fp8(self, bad):
with pytest.raises(ValueError, match="main_params_dtype"):
self._coerce({"main_params_dtype": bad})

@pytest.mark.parametrize(
"name,expected", [("bf16", torch.bfloat16), ("fp16", torch.float16), ("fp32", torch.float32)]
)
def test_params_dtype_is_coerced_with_no_field_restriction(self, name, expected):
out = self._coerce({"params_dtype": name})
assert out["params_dtype"] is expected

def test_main_grads_dtype_coerced_but_not_field_validated(self):
out = self._coerce({"main_grads_dtype": "bf16"})
assert out["main_grads_dtype"] is torch.bfloat16

def test_unrecognized_dtype_name_raises(self):
with pytest.raises(ValueError, match="Unrecognized dtype name"):
self._coerce({"exp_avg_dtype": "bf17"})

def test_unrelated_kwargs_pass_through_untouched(self):
kwargs = {
"use_precision_aware_optimizer": True,
"optimizer_offload_fraction": 0.5,
"overlap_cpu_optimizer_d2h_h2d": False,
"exp_avg_dtype": "bf16",
}
out = self._coerce(kwargs)
assert out["use_precision_aware_optimizer"] is True
assert out["optimizer_offload_fraction"] == 0.5
assert out["overlap_cpu_optimizer_d2h_h2d"] is False
assert out["exp_avg_dtype"] is torch.bfloat16

def test_non_string_non_dtype_dtype_value_passes_through(self):
out = self._coerce({"main_grads_dtype": None})
assert out["main_grads_dtype"] is None

def test_none_kwargs_returns_empty_dict(self):
assert self._coerce(None) == {}

def test_input_not_mutated(self):
kwargs = {"exp_avg_dtype": "bf16"}
self._coerce(kwargs)
assert kwargs["exp_avg_dtype"] == "bf16"


@pytest.mark.skipif(not _has_megatron, reason="megatron-core not installed")
class TestInitMegatronOptimConfigDtypeCoercion:
def test_string_dtype_kwargs_reach_optimizer_config(self):
from skyrl.backends.skyrl_train.distributed.megatron.optimizer import (
init_megatron_optim_config,
)
from skyrl.train.config import OptimizerConfig as SkyRLOptimizerConfig

optim_config = SkyRLOptimizerConfig()
config = init_megatron_optim_config(
optim_config,
{
"use_precision_aware_optimizer": True,
"exp_avg_dtype": "bf16",
"exp_avg_sq_dtype": "fp8",
"main_params_dtype": "fp32",
},
)
assert config.exp_avg_dtype is torch.bfloat16
assert config.exp_avg_sq_dtype is torch.uint8
assert config.main_params_dtype is torch.float32

def test_params_dtype_string_override_reaches_optimizer_config(self):
from skyrl.backends.skyrl_train.distributed.megatron.optimizer import (
init_megatron_optim_config,
)
from skyrl.train.config import OptimizerConfig as SkyRLOptimizerConfig

config = init_megatron_optim_config(SkyRLOptimizerConfig(), {"params_dtype": "fp16"})
assert config.params_dtype is torch.float16

def test_default_kwargs_leave_dtypes_at_megatron_defaults(self):
from skyrl.backends.skyrl_train.distributed.megatron.optimizer import (
init_megatron_optim_config,
)
from skyrl.train.config import OptimizerConfig as SkyRLOptimizerConfig

config = init_megatron_optim_config(SkyRLOptimizerConfig(), {})
assert config.exp_avg_dtype is torch.float32
assert config.exp_avg_sq_dtype is torch.float32
assert config.main_params_dtype is torch.float32

def test_precision_aware_off_with_nonfp32_state_fast_fails_in_megatron(self):
from skyrl.backends.skyrl_train.distributed.megatron.optimizer import (
init_megatron_optim_config,
)
from skyrl.train.config import OptimizerConfig as SkyRLOptimizerConfig

with pytest.raises(AssertionError, match="exp_avg_dtype can only be fp32"):
init_megatron_optim_config(
SkyRLOptimizerConfig(),
{"use_precision_aware_optimizer": False, "exp_avg_dtype": "bf16"},
)
Loading