-
Notifications
You must be signed in to change notification settings - Fork 376
[megatron] Accept dtype-string optimizer_config_kwargs (coerce exp_avg_dtype etc. to torch.dtype) #1805
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
dyurk-lila
wants to merge
2
commits into
NovaSky-AI:main
Choose a base branch
from
dyurk-lila:feat/optimizer-state-dtype-coercion
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
[megatron] Accept dtype-string optimizer_config_kwargs (coerce exp_avg_dtype etc. to torch.dtype) #1805
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
64 changes: 64 additions & 0 deletions
64
skyrl/backends/skyrl_train/distributed/megatron/optimizer_dtype.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,64 @@ | ||
| """Torch-only coercion for Megatron optimizer dtype kwargs.""" | ||
|
|
||
| from typing import Any, Dict, Set | ||
|
|
||
| import torch | ||
|
|
||
| # Megatron short names plus common YAML spellings. TE stores FP8 optimizer state | ||
| # as uint8, matching Megatron-LM's dtype map. | ||
| _DTYPE_NAME_TO_TORCH: Dict[str, torch.dtype] = { | ||
| "fp32": torch.float32, | ||
| "float32": torch.float32, | ||
| "float": torch.float32, | ||
| "bf16": torch.bfloat16, | ||
| "bfloat16": torch.bfloat16, | ||
| "fp16": torch.float16, | ||
| "float16": torch.float16, | ||
| "half": torch.float16, | ||
| "fp8": torch.uint8, | ||
| "float8": torch.uint8, | ||
| "uint8": torch.uint8, | ||
| } | ||
|
|
||
| # Only TE FusedAdam-backed fields get field-specific checks. ``main_grads_dtype`` | ||
| # is not forwarded at the pinned megatron-core rev, so it is coerced only and | ||
| # left to ``OptimizerConfig.__post_init__``. | ||
| _LEGAL_FIELD_DTYPES: Dict[str, Set[torch.dtype]] = { | ||
| "main_params_dtype": {torch.float32, torch.float16}, | ||
| "exp_avg_dtype": {torch.float32, torch.bfloat16, torch.float16, torch.uint8}, | ||
| "exp_avg_sq_dtype": {torch.float32, torch.bfloat16, torch.float16, torch.uint8}, | ||
| } | ||
|
|
||
|
|
||
| def coerce_optimizer_dtype_kwargs(optimizer_config_kwargs: Dict[str, Any] | None) -> Dict[str, Any]: | ||
| """Return kwargs with recognized ``*_dtype`` strings converted to ``torch.dtype``.""" | ||
| if optimizer_config_kwargs is None: | ||
| return {} | ||
|
|
||
| coerced: Dict[str, Any] = {} | ||
| for key, value in optimizer_config_kwargs.items(): | ||
| if not key.endswith("_dtype"): | ||
| coerced[key] = value | ||
| continue | ||
|
|
||
| if isinstance(value, torch.dtype): | ||
| dtype = value | ||
| elif isinstance(value, str): | ||
| name = value.strip().lower() | ||
| if name not in _DTYPE_NAME_TO_TORCH: | ||
| raise ValueError( | ||
| f"Unrecognized dtype name {value!r} for optimizer kwarg {key!r}. " | ||
| f"Expected one of {sorted(_DTYPE_NAME_TO_TORCH)} or a torch.dtype." | ||
| ) | ||
| dtype = _DTYPE_NAME_TO_TORCH[name] | ||
| else: | ||
| # Let Megatron validate non-string, non-dtype values. | ||
| coerced[key] = value | ||
| continue | ||
|
|
||
| legal = _LEGAL_FIELD_DTYPES.get(key) | ||
| if legal is not None and dtype not in legal: | ||
| legal_names = sorted({n for n, d in _DTYPE_NAME_TO_TORCH.items() if d in legal}) | ||
| raise ValueError(f"Illegal dtype {dtype} for optimizer kwarg {key!r}; legal values are {legal_names}.") | ||
| coerced[key] = dtype | ||
| return coerced | ||
154 changes: 154 additions & 0 deletions
154
tests/backends/skyrl_train/distributed/test_optimizer_dtype_coercion.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,154 @@ | ||
| """Tests for Megatron optimizer dtype coercion.""" | ||
|
|
||
| import sys | ||
|
|
||
| import pytest | ||
| import torch | ||
|
|
||
| from skyrl.backends.skyrl_train.distributed.megatron.optimizer_dtype import ( | ||
| coerce_optimizer_dtype_kwargs, | ||
| ) | ||
|
|
||
| _has_megatron = "megatron" in sys.modules or __import__("importlib").util.find_spec("megatron") is not None | ||
|
|
||
|
|
||
| class TestCoerceOptimizerDtypeKwargs: | ||
| def _coerce(self, kwargs: dict | None) -> dict: | ||
| return coerce_optimizer_dtype_kwargs(kwargs) | ||
|
|
||
| @pytest.mark.parametrize( | ||
| "name,expected", | ||
| [ | ||
| ("bf16", torch.bfloat16), | ||
| ("bfloat16", torch.bfloat16), | ||
| ("fp16", torch.float16), | ||
| ("float16", torch.float16), | ||
| ("half", torch.float16), | ||
| ("fp32", torch.float32), | ||
| ("float32", torch.float32), | ||
| ("float", torch.float32), | ||
| ("fp8", torch.uint8), | ||
| ("float8", torch.uint8), | ||
| ("uint8", torch.uint8), | ||
| ], | ||
| ) | ||
| def test_string_names_coerce_to_torch_dtype(self, name, expected): | ||
| out = self._coerce({"exp_avg_dtype": name}) | ||
| assert out["exp_avg_dtype"] == expected | ||
| assert isinstance(out["exp_avg_dtype"], torch.dtype) | ||
|
|
||
| def test_fp8_maps_to_uint8(self): | ||
| out = self._coerce({"exp_avg_sq_dtype": "fp8"}) | ||
| assert out["exp_avg_sq_dtype"] is torch.uint8 | ||
|
|
||
| def test_case_and_whitespace_insensitive(self): | ||
| out = self._coerce({"exp_avg_dtype": " BF16 "}) | ||
| assert out["exp_avg_dtype"] is torch.bfloat16 | ||
|
|
||
| def test_already_torch_dtype_passes_through(self): | ||
| out = self._coerce({"exp_avg_dtype": torch.bfloat16}) | ||
| assert out["exp_avg_dtype"] is torch.bfloat16 | ||
|
|
||
| def test_main_params_dtype_accepts_fp32_and_fp16(self): | ||
| assert self._coerce({"main_params_dtype": "fp32"})["main_params_dtype"] is torch.float32 | ||
| assert self._coerce({"main_params_dtype": "fp16"})["main_params_dtype"] is torch.float16 | ||
|
|
||
| @pytest.mark.parametrize("bad", ["bf16", "fp8"]) | ||
| def test_main_params_dtype_rejects_bf16_and_fp8(self, bad): | ||
| with pytest.raises(ValueError, match="main_params_dtype"): | ||
| self._coerce({"main_params_dtype": bad}) | ||
|
|
||
| @pytest.mark.parametrize( | ||
| "name,expected", [("bf16", torch.bfloat16), ("fp16", torch.float16), ("fp32", torch.float32)] | ||
| ) | ||
| def test_params_dtype_is_coerced_with_no_field_restriction(self, name, expected): | ||
| out = self._coerce({"params_dtype": name}) | ||
| assert out["params_dtype"] is expected | ||
|
|
||
| def test_main_grads_dtype_coerced_but_not_field_validated(self): | ||
| out = self._coerce({"main_grads_dtype": "bf16"}) | ||
| assert out["main_grads_dtype"] is torch.bfloat16 | ||
|
|
||
| def test_unrecognized_dtype_name_raises(self): | ||
| with pytest.raises(ValueError, match="Unrecognized dtype name"): | ||
| self._coerce({"exp_avg_dtype": "bf17"}) | ||
|
|
||
| def test_unrelated_kwargs_pass_through_untouched(self): | ||
| kwargs = { | ||
| "use_precision_aware_optimizer": True, | ||
| "optimizer_offload_fraction": 0.5, | ||
| "overlap_cpu_optimizer_d2h_h2d": False, | ||
| "exp_avg_dtype": "bf16", | ||
| } | ||
| out = self._coerce(kwargs) | ||
| assert out["use_precision_aware_optimizer"] is True | ||
| assert out["optimizer_offload_fraction"] == 0.5 | ||
| assert out["overlap_cpu_optimizer_d2h_h2d"] is False | ||
| assert out["exp_avg_dtype"] is torch.bfloat16 | ||
|
|
||
| def test_non_string_non_dtype_dtype_value_passes_through(self): | ||
| out = self._coerce({"main_grads_dtype": None}) | ||
| assert out["main_grads_dtype"] is None | ||
|
|
||
| def test_none_kwargs_returns_empty_dict(self): | ||
| assert self._coerce(None) == {} | ||
|
|
||
| def test_input_not_mutated(self): | ||
| kwargs = {"exp_avg_dtype": "bf16"} | ||
| self._coerce(kwargs) | ||
| assert kwargs["exp_avg_dtype"] == "bf16" | ||
|
|
||
|
|
||
| @pytest.mark.skipif(not _has_megatron, reason="megatron-core not installed") | ||
| class TestInitMegatronOptimConfigDtypeCoercion: | ||
| def test_string_dtype_kwargs_reach_optimizer_config(self): | ||
| from skyrl.backends.skyrl_train.distributed.megatron.optimizer import ( | ||
| init_megatron_optim_config, | ||
| ) | ||
| from skyrl.train.config import OptimizerConfig as SkyRLOptimizerConfig | ||
|
|
||
| optim_config = SkyRLOptimizerConfig() | ||
| config = init_megatron_optim_config( | ||
| optim_config, | ||
| { | ||
| "use_precision_aware_optimizer": True, | ||
| "exp_avg_dtype": "bf16", | ||
| "exp_avg_sq_dtype": "fp8", | ||
| "main_params_dtype": "fp32", | ||
| }, | ||
| ) | ||
| assert config.exp_avg_dtype is torch.bfloat16 | ||
| assert config.exp_avg_sq_dtype is torch.uint8 | ||
| assert config.main_params_dtype is torch.float32 | ||
|
|
||
| def test_params_dtype_string_override_reaches_optimizer_config(self): | ||
| from skyrl.backends.skyrl_train.distributed.megatron.optimizer import ( | ||
| init_megatron_optim_config, | ||
| ) | ||
| from skyrl.train.config import OptimizerConfig as SkyRLOptimizerConfig | ||
|
|
||
| config = init_megatron_optim_config(SkyRLOptimizerConfig(), {"params_dtype": "fp16"}) | ||
| assert config.params_dtype is torch.float16 | ||
|
|
||
| def test_default_kwargs_leave_dtypes_at_megatron_defaults(self): | ||
| from skyrl.backends.skyrl_train.distributed.megatron.optimizer import ( | ||
| init_megatron_optim_config, | ||
| ) | ||
| from skyrl.train.config import OptimizerConfig as SkyRLOptimizerConfig | ||
|
|
||
| config = init_megatron_optim_config(SkyRLOptimizerConfig(), {}) | ||
| assert config.exp_avg_dtype is torch.float32 | ||
| assert config.exp_avg_sq_dtype is torch.float32 | ||
| assert config.main_params_dtype is torch.float32 | ||
|
|
||
| def test_precision_aware_off_with_nonfp32_state_fast_fails_in_megatron(self): | ||
| from skyrl.backends.skyrl_train.distributed.megatron.optimizer import ( | ||
| init_megatron_optim_config, | ||
| ) | ||
| from skyrl.train.config import OptimizerConfig as SkyRLOptimizerConfig | ||
|
|
||
| with pytest.raises(AssertionError, match="exp_avg_dtype can only be fp32"): | ||
| init_megatron_optim_config( | ||
| SkyRLOptimizerConfig(), | ||
| {"use_precision_aware_optimizer": False, "exp_avg_dtype": "bf16"}, | ||
| ) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If
optimizer_config_kwargsisNone(e.g., if it is omitted or set tonullin the YAML configuration), calling.items()on it will raise anAttributeError. Adding a defensiveNonecheck at the beginning of the function ensures robustness and prevents runtime crashes.