diff --git a/ax/generation_strategy/generation_node.py b/ax/generation_strategy/generation_node.py index 60b375777ba..c4a58ab73c9 100644 --- a/ax/generation_strategy/generation_node.py +++ b/ax/generation_strategy/generation_node.py @@ -1134,7 +1134,9 @@ def __new__( only_in_statuses=[TrialStatus.RUNNING], block_gen_if_met=True, block_transition_if_unmet=False, - transition_to=None, # Re-set in GS constructor. + # MaxParallelism transitions to self, + # this will be confirmed in GS init + transition_to=f"GenerationStep_{str(index)}", ) ) diff --git a/ax/generation_strategy/generation_strategy.py b/ax/generation_strategy/generation_strategy.py index f37724e9bc8..fd4e8581567 100644 --- a/ax/generation_strategy/generation_strategy.py +++ b/ax/generation_strategy/generation_strategy.py @@ -384,26 +384,22 @@ def _validate_and_set_step_sequence(self, steps: list[GenerationNode]) -> None: step._step_index = idx step._generation_strategy = self - # Set transition_to field for all but the last step, which transitions - # to itself. - # transition_to = step.name - try: - next_step = steps[idx + 1] - transition_to = GEN_STEP_NAME.format( + # Determine transition_to for steps, last step will transition to self + is_last_step = idx == len(steps) - 1 + next_step_name = ( + step.name + if is_last_step + else GEN_STEP_NAME.format( step_index=idx + 1, - generator_name=next_step.generator_name, + generator_name=steps[idx + 1].generator_name, ) - # for transition_criteria in step.transition_criteria: - # if ( - # transition_criteria.criterion_class - # != "MaxGenerationParallelism" - # ): - # transition_criteria._transition_to = transition_to - except IndexError: # Last step, steps[idx+1] was not possible - transition_to = step.name + ) for tc in step.transition_criteria: - if tc.criterion_class != "MaxGenerationParallelism": - tc._transition_to = transition_to + if tc.criterion_class == "MaxGenerationParallelism": + # MaxGenerationParallelism transitions to self (current step) + tc._transition_to = step.name + else: + tc._transition_to = next_step_name self._curr = steps[0] def _validate_and_set_node_graph(self, nodes: list[GenerationNode]) -> None: diff --git a/ax/generation_strategy/tests/test_generation_strategy.py b/ax/generation_strategy/tests/test_generation_strategy.py index 011e3b1c849..24883ffe01d 100644 --- a/ax/generation_strategy/tests/test_generation_strategy.py +++ b/ax/generation_strategy/tests/test_generation_strategy.py @@ -1238,6 +1238,7 @@ def test_gs_setup_with_nodes(self) -> None: only_in_statuses=[TrialStatus.RUNNING], block_gen_if_met=True, block_transition_if_unmet=False, + transition_to="node_1", ), ] node_1 = GenerationNode( diff --git a/ax/generation_strategy/tests/test_transition_criterion.py b/ax/generation_strategy/tests/test_transition_criterion.py index fc858c9eaea..afd039b6d2f 100644 --- a/ax/generation_strategy/tests/test_transition_criterion.py +++ b/ax/generation_strategy/tests/test_transition_criterion.py @@ -27,7 +27,6 @@ IsSingleObjective, MaxGenerationParallelism, MinimumPreferenceOccurances, - MinimumTrialsInStatus, MinTrials, ) from ax.utils.common.logger import get_logger @@ -247,6 +246,7 @@ def test_default_step_criterion_setup(self) -> None: only_in_statuses=[TrialStatus.RUNNING], block_gen_if_met=True, block_transition_if_unmet=False, + transition_to="GenerationStep_1_BoTorch", ), ] step_2_expected_transition_criteria = [] @@ -525,16 +525,6 @@ def test_repr(self) -> None: + "'transition_to': None, 'block_gen_if_met': False, " "'block_transition_if_unmet': True})", ) - deprecated_min_trials_criterion = MinimumTrialsInStatus( - status=TrialStatus.COMPLETED, threshold=3 - ) - self.assertEqual( - str(deprecated_min_trials_criterion), - "MinimumTrialsInStatus({" - + "'status': .COMPLETED, " - + "'threshold': 3, " - + "'transition_to': None})", - ) max_parallelism = MaxGenerationParallelism( only_in_statuses=[TrialStatus.EARLY_STOPPED], threshold=3, @@ -545,14 +535,15 @@ def test_repr(self) -> None: ) self.assertEqual( str(max_parallelism), - "MaxGenerationParallelism({'threshold': 3, 'only_in_statuses': " + "MaxGenerationParallelism({'threshold': 3, " + + "'transition_to': 'GenerationStep_2', " + + "'only_in_statuses': " + "[.EARLY_STOPPED], " + "'not_in_statuses': [.FAILED], " - + "'transition_to': 'GenerationStep_2', " + "'block_transition_if_unmet': False, " + "'block_gen_if_met': True, " + "'use_all_trials_in_exp': False, " - + "'continue_trial_generation': True})", + + "'continue_trial_generation': False})", ) auto_transition = AutoTransitionAfterGen(transition_to="GenerationStep_2") self.assertEqual( diff --git a/ax/generation_strategy/transition_criterion.py b/ax/generation_strategy/transition_criterion.py index ffcc2f48961..aab6e1f20ee 100644 --- a/ax/generation_strategy/transition_criterion.py +++ b/ax/generation_strategy/transition_criterion.py @@ -21,7 +21,7 @@ from ax.generation_strategy.generation_node import GenerationNode from ax.utils.common.base import SortableBase -from ax.utils.common.serialization import SerializationMixin, serialize_init_args +from ax.utils.common.serialization import serialize_init_args from pyre_extensions import none_throws @@ -32,7 +32,7 @@ ) -class TransitionCriterion(SortableBase, SerializationMixin): +class TransitionCriterion(SortableBase): """ Simple class to describe a condition which must be met for this GenerationNode to take an action such as generation, transition, etc. @@ -424,13 +424,13 @@ class MaxGenerationParallelism(TrialBasedCriterion): def __init__( self, threshold: int, + transition_to: str, only_in_statuses: list[TrialStatus] | None = None, not_in_statuses: list[TrialStatus] | None = None, - transition_to: str | None = None, block_transition_if_unmet: bool | None = False, block_gen_if_met: bool | None = True, use_all_trials_in_exp: bool | None = False, - continue_trial_generation: bool | None = True, + continue_trial_generation: bool | None = False, ) -> None: super().__init__( threshold=threshold, @@ -732,35 +732,3 @@ def block_continued_generation_error( f"This criterion, {self.criterion_class} has been met but cannot " "continue generation from its associated GenerationNode." ) - - -# TODO: Deprecate once legacy usecase is updated -class MinimumTrialsInStatus(TransitionCriterion): - """ - Deprecated and replaced with more flexible MinTrials criterion. - """ - - def __init__( - self, - status: TrialStatus, - threshold: int, - transition_to: str | None = None, - ) -> None: - self.status = status - self.threshold = threshold - super().__init__(transition_to=transition_to) - - def is_met( - self, - experiment: Experiment, - curr_node: GenerationNode, - ) -> bool: - return len(experiment.trial_indices_by_status[self.status]) >= self.threshold - - def block_continued_generation_error( - self, - node_name: str | None, - experiment: Experiment | None, - trials_from_node: set[int], - ) -> None: - pass diff --git a/ax/orchestration/tests/test_orchestrator.py b/ax/orchestration/tests/test_orchestrator.py index d6f7bfd470f..a42d8ac11f4 100644 --- a/ax/orchestration/tests/test_orchestrator.py +++ b/ax/orchestration/tests/test_orchestrator.py @@ -163,7 +163,8 @@ class TestAxOrchestrator(TestCase): "GenerationNode(name='GenerationStep_1_BoTorch', " "generator_specs=[GeneratorSpec(generator_enum=BoTorch, " "generator_key_override=None)], " - "transition_criteria=[MaxGenerationParallelism(transition_to='None')])]), " + "transition_criteria=[MaxGenerationParallelism(" + "transition_to='GenerationStep_1_BoTorch')])]), " "options=OrchestratorOptions(max_pending_trials=10, " "trial_type=, batch_size=None, " "total_trials=0, tolerated_trial_failure_rate=0.2, " @@ -2857,7 +2858,8 @@ class TestAxOrchestratorMultiTypeExperiment(TestAxOrchestrator): "GenerationNode(name='GenerationStep_1_BoTorch', " "generator_specs=[GeneratorSpec(generator_enum=BoTorch, " "generator_key_override=None)], " - "transition_criteria=[MaxGenerationParallelism(transition_to='None')])]), " + "transition_criteria=" + "[MaxGenerationParallelism(transition_to='GenerationStep_1_BoTorch')])]), " "options=OrchestratorOptions(max_pending_trials=10, " "trial_type=, batch_size=None, " "total_trials=0, tolerated_trial_failure_rate=0.2, " diff --git a/ax/storage/json_store/decoder.py b/ax/storage/json_store/decoder.py index b9685c0f6fb..e22af04e345 100644 --- a/ax/storage/json_store/decoder.py +++ b/ax/storage/json_store/decoder.py @@ -42,11 +42,7 @@ GenerationStrategy, ) from ax.generation_strategy.generator_spec import GeneratorSpec -from ax.generation_strategy.transition_criterion import ( - AuxiliaryExperimentCheck, - TransitionCriterion, - TrialBasedCriterion, -) +from ax.generation_strategy.transition_criterion import MinTrials, TransitionCriterion from ax.generators.torch.botorch_modular.generator import BoTorchGenerator from ax.generators.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec from ax.generators.torch.botorch_modular.utils import ModelConfig @@ -63,6 +59,7 @@ from ax.storage.utils import data_by_trial_to_data from ax.utils.common.logger import get_logger from ax.utils.common.serialization import ( + extract_init_args, SerializationMixin, TClassDecoderRegistry, TDecoderRegistry, @@ -280,16 +277,16 @@ def object_from_json( object_json["outcome_transform_options"] = ( outcome_transform_options_json ) - elif isclass(_class) and ( - issubclass(_class, TrialBasedCriterion) - or issubclass(_class, AuxiliaryExperimentCheck) + elif ( + isclass(_class) + and issubclass(_class, TransitionCriterion) + and _class is not TransitionCriterion # TransitionCriterion is abstract ): - # TrialBasedCriterion contains a list of `TrialStatus` for args. - # AuxiliaryExperimentCheck contains AuxiliaryExperimentPurpose objects - # They need to be unpacked by hand to properly retain the types. - return unpack_transition_criteria_from_json( - class_=_class, - transition_criteria_json=object_json, + # TransitionCriterion may contain nested Ax objects (TrialStatus, etc.) + # that need recursive deserialization via object_from_json. + return transition_criterion_from_json( + transition_criterion_class=_class, + object_json=object_json, **vars(registry_kwargs), ) elif isclass(_class) and issubclass(_class, SerializationMixin): @@ -430,32 +427,51 @@ def generator_run_from_json( return generator_run -def unpack_transition_criteria_from_json( - # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use `typing.Type` to - # avoid runtime subscripting errors. - class_: type, - transition_criteria_json: dict[str, Any], +def transition_criterion_from_json( + transition_criterion_class: type[TransitionCriterion], + object_json: dict[str, Any], decoder_registry: TDecoderRegistry = CORE_DECODER_REGISTRY, class_decoder_registry: TClassDecoderRegistry = CORE_CLASS_DECODER_REGISTRY, -) -> TransitionCriterion | None: - """Load Ax transition criteria that depend on Trials from JSON. - - Since ``TrialBasedCriterion`` contain lists of ``TrialStatus``, - the json for these criterion needs to be carefully unpacked and - re-processed via ``object_from_json`` in order to maintain correct - typing. We pass in ``class_`` in order to correctly handle all classes - which inherit from ``TrialBasedCriterion`` (ex: ``MinTrials``). +) -> TransitionCriterion: + """Load TransitionCriterion from JSON. + + TransitionCriterion subclasses may contain nested Ax objects (like TrialStatus + enums and AuxiliaryExperimentPurpose) that need recursive deserialization via + object_from_json. We also use extract_init_args for backwards compatibility, + filtering to only valid constructor arguments. """ - new_dict = {} - for key, value in transition_criteria_json.items(): - new_val = object_from_json( + # Handle deprecated MinimumTrialsInStatus -> MinTrials conversion + if transition_criterion_class is MinTrials and "status" in object_json: + logger.warning( + "`MinimumTrialsInStatus` has been deprecated and removed. " + "Converting to `MinTrials` with equivalent functionality." + ) + status = object_from_json( + object_json=object_json.get("status"), + decoder_registry=decoder_registry, + class_decoder_registry=class_decoder_registry, + ) + return MinTrials( + threshold=object_json.get("threshold"), + only_in_statuses=[status], + transition_to=object_json.get("transition_to"), + use_all_trials_in_exp=True, + ) + + decoded = { + key: object_from_json( object_json=value, decoder_registry=decoder_registry, class_decoder_registry=class_decoder_registry, ) - new_dict[key] = new_val + for key, value in object_json.items() + } + + # filter to only valid constructor args (backwards compatibility) + init_args = extract_init_args(args=decoded, class_=transition_criterion_class) - return class_(**new_dict) + # pyre-ignore[45]: Class passed is always a concrete subclass. + return transition_criterion_class(**init_args) def search_space_from_json( @@ -806,6 +822,14 @@ def generation_node_from_json( # if needed during _validate_and_set_step_sequence. generation_node_json.pop("step_index", None) + # Backwards compatibility: For transition criteria with transition_to=None + # set transition_to to point to itself. + transition_criteria_json = generation_node_json.pop("transition_criteria") + if transition_criteria_json is not None: + for tc_json in transition_criteria_json: + if tc_json.get("transition_to") is None: + tc_json["transition_to"] = name + return GenerationNode( name=name, generator_specs=object_from_json( @@ -820,7 +844,7 @@ def generation_node_from_json( ), should_deduplicate=generation_node_json.pop("should_deduplicate", False), transition_criteria=object_from_json( - object_json=generation_node_json.pop("transition_criteria"), + object_json=transition_criteria_json, decoder_registry=decoder_registry, class_decoder_registry=class_decoder_registry, ), diff --git a/ax/storage/json_store/encoders.py b/ax/storage/json_store/encoders.py index 837766d8b45..bfd6157f129 100644 --- a/ax/storage/json_store/encoders.py +++ b/ax/storage/json_store/encoders.py @@ -440,7 +440,7 @@ def generation_strategy_to_dict( def transition_criterion_to_dict(criterion: TransitionCriterion) -> dict[str, Any]: """Convert Ax TransitionCriterion to a dictionary.""" - properties = criterion.serialize_init_args(obj=criterion) + properties = serialize_init_args(obj=criterion) properties["__type"] = criterion.__class__.__name__ return properties diff --git a/ax/storage/json_store/registry.py b/ax/storage/json_store/registry.py index f7298b80832..7c16d97b7c8 100644 --- a/ax/storage/json_store/registry.py +++ b/ax/storage/json_store/registry.py @@ -78,7 +78,6 @@ IsSingleObjective, MaxGenerationParallelism, MinimumPreferenceOccurances, - MinimumTrialsInStatus, MinTrials, TransitionCriterion, ) @@ -219,7 +218,6 @@ MaxGenerationParallelism: transition_criterion_to_dict, Metric: metric_to_dict, MinTrials: transition_criterion_to_dict, - MinimumTrialsInStatus: transition_criterion_to_dict, MinimumPreferenceOccurances: transition_criterion_to_dict, AuxiliaryExperimentCheck: transition_criterion_to_dict, GeneratorSpec: generator_spec_to_dict, @@ -347,7 +345,8 @@ "MaxGenerationParallelism": MaxGenerationParallelism, "Metric": Metric, "MinTrials": MinTrials, - "MinimumTrialsInStatus": MinimumTrialsInStatus, + # DEPRECATED; backward compatibility for MinimumTrialsInStatus -> MinTrials + "MinimumTrialsInStatus": MinTrials, "MinimumPreferenceOccurances": MinimumPreferenceOccurances, "GeneratorRegistryBase": GeneratorRegistryBase, "ModelRegistryBase": GeneratorRegistryBase, diff --git a/ax/storage/json_store/tests/test_json_store.py b/ax/storage/json_store/tests/test_json_store.py index 37464ebe0df..9bd36c4f097 100644 --- a/ax/storage/json_store/tests/test_json_store.py +++ b/ax/storage/json_store/tests/test_json_store.py @@ -503,6 +503,11 @@ def test_EncodeDecode(self) -> None: # _step_index is non-persistent state, unset on original to match # the decoded object which will have _step_index=None. original_object._step_index = None + # Transition_to is set during decode for backwards + # compatibility. Update original to match decoded. + for tc in original_object.transition_criteria: + if tc.transition_to is None: + tc._transition_to = original_object.name if isinstance(original_object, torch.nn.Module): self.assertIsInstance( converted_object, diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py index a449e4050ad..805aa04dec2 100644 --- a/ax/storage/sqa_store/tests/test_sqa_store.py +++ b/ax/storage/sqa_store/tests/test_sqa_store.py @@ -54,9 +54,18 @@ ) from ax.exceptions.storage import JSONDecodeError, SQADecodeError, SQAEncodeError from ax.generation_strategy.dispatch_utils import choose_generation_strategy_legacy +from ax.generation_strategy.transition_criterion import MaxGenerationParallelism from ax.generators.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec from ax.metrics.branin import BraninMetric from ax.runners.synthetic import SyntheticRunner +from ax.storage.json_store.decoder import ( + generation_node_from_json, + transition_criterion_from_json, +) +from ax.storage.json_store.registry import ( + CORE_CLASS_DECODER_REGISTRY, + CORE_DECODER_REGISTRY, +) from ax.storage.metric_registry import CORE_METRIC_REGISTRY, register_metrics from ax.storage.registry_bundle import RegistryBundle from ax.storage.runner_registry import CORE_RUNNER_REGISTRY, register_runner @@ -3311,3 +3320,74 @@ def test_load_candidate_source_auxiliary_experiments(self) -> None: self.assertIsNotNone(tl_metadata_0.overlap_parameters) # Should have 2 overlapping parameters (w and x) self.assertEqual(len(none_throws(tl_metadata_0.overlap_parameters)), 2) + + def test_transition_criterion_deserialize_with_extra_fields(self) -> None: + """Test that deserialization gracefully handles extra/unknown fields + ie this validates that backwards compatibility is maintained""" + # Simulate old serialized format with extra fields that no longer exist + old_format_json = { + "threshold": 5, + "only_in_statuses": [{"__type": "TrialStatus", "name": "RUNNING"}], + "not_in_statuses": None, + "transition_to": "test_node", + "block_gen_if_met": True, + "block_transition_if_unmet": False, + "use_all_trials_in_exp": False, + "continue_trial_generation": False, + "some_deprecated_field": "should_be_ignored", + } + + # Should not raise, extra field should be ignored + criterion = assert_is_instance( + transition_criterion_from_json( + transition_criterion_class=MaxGenerationParallelism, + object_json=old_format_json, + decoder_registry=CORE_DECODER_REGISTRY, + class_decoder_registry=CORE_CLASS_DECODER_REGISTRY, + ), + MaxGenerationParallelism, + ) + self.assertEqual(criterion.threshold, 5) + self.assertEqual(criterion.transition_to, "test_node") + + def test_gen_node_deserialize_with_tc_transition_to_none( + self, + ) -> None: + """Test backwards compatibility when loading a MaxGenerationParallelism + that was stored with transition_to=None + """ + old_format_node_json = { + "__type": "GenerationNode", + "name": "test_node", + "generator_specs": [ + { + "__type": "GeneratorSpec", + "generator_enum": {"__type": "Generators", "name": "SOBOL"}, + "generator_kwargs": {}, + "generator_gen_kwargs": {}, + } + ], + "transition_criteria": [ + { + "__type": "MaxGenerationParallelism", + "threshold": 3, + "only_in_statuses": [{"__type": "TrialStatus", "name": "RUNNING"}], + "transition_to": None, # Old default + } + ], + } + + node = generation_node_from_json( + generation_node_json=old_format_node_json, + decoder_registry=CORE_DECODER_REGISTRY, + class_decoder_registry=CORE_CLASS_DECODER_REGISTRY, + ) + self.assertEqual(node.name, "test_node") + self.assertEqual(len(node.transition_criteria), 1) + criterion = assert_is_instance( + node.transition_criteria[0], + MaxGenerationParallelism, + ) + self.assertEqual(criterion.threshold, 3) + # transition_to should now be set to the node name (pointing to itself) + self.assertEqual(criterion.transition_to, "test_node") diff --git a/ax/utils/testing/core_stubs.py b/ax/utils/testing/core_stubs.py index a531cce849e..5a6b2371c8b 100644 --- a/ax/utils/testing/core_stubs.py +++ b/ax/utils/testing/core_stubs.py @@ -87,9 +87,10 @@ ) from ax.generation_strategy.generator_spec import GeneratorSpec from ax.generation_strategy.transition_criterion import ( + AutoTransitionAfterGen, MaxGenerationParallelism, MinTrials, - TrialBasedCriterion, + TransitionCriterion, ) from ax.generators.torch.botorch_modular.acquisition import Acquisition from ax.generators.torch.botorch_modular.generator import BoTorchGenerator @@ -177,7 +178,7 @@ def get_experiment_with_map_data_type() -> Experiment: return experiment -def get_trial_based_criterion() -> list[TrialBasedCriterion]: +def get_trial_based_criterion() -> list[TransitionCriterion]: return [ MinTrials( threshold=3, @@ -190,6 +191,10 @@ def get_trial_based_criterion() -> list[TrialBasedCriterion]: not_in_statuses=[ TrialStatus.RUNNING, ], + transition_to="Sobol", + ), + AutoTransitionAfterGen( + transition_to="next_node", ), ] diff --git a/ax/utils/testing/modeling_stubs.py b/ax/utils/testing/modeling_stubs.py index a9fd730b931..a30ff3c54ee 100644 --- a/ax/utils/testing/modeling_stubs.py +++ b/ax/utils/testing/modeling_stubs.py @@ -228,10 +228,11 @@ def sobol_gpei_generation_node_gs( ), MaxGenerationParallelism( threshold=1000, - transition_to=None, + transition_to="MBM_node", block_gen_if_met=True, only_in_statuses=[TrialStatus.RUNNING], not_in_statuses=None, + continue_trial_generation=False, ), ] auto_mbm_criterion = [AutoTransitionAfterGen(transition_to="MBM_node")]