Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion ax/generation_strategy/generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -1134,7 +1134,9 @@ def __new__(
only_in_statuses=[TrialStatus.RUNNING],
block_gen_if_met=True,
block_transition_if_unmet=False,
transition_to=None, # Re-set in GS constructor.
# MaxParallelism transitions to self,
# this will be confirmed in GS init
transition_to=f"GenerationStep_{str(index)}",
)
)

Expand Down
30 changes: 13 additions & 17 deletions ax/generation_strategy/generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,26 +384,22 @@ def _validate_and_set_step_sequence(self, steps: list[GenerationNode]) -> None:
step._step_index = idx
step._generation_strategy = self

# Set transition_to field for all but the last step, which transitions
# to itself.
# transition_to = step.name
try:
next_step = steps[idx + 1]
transition_to = GEN_STEP_NAME.format(
# Determine transition_to for steps, last step will transition to self
is_last_step = idx == len(steps) - 1
next_step_name = (
step.name
if is_last_step
else GEN_STEP_NAME.format(
step_index=idx + 1,
generator_name=next_step.generator_name,
generator_name=steps[idx + 1].generator_name,
)
# for transition_criteria in step.transition_criteria:
# if (
# transition_criteria.criterion_class
# != "MaxGenerationParallelism"
# ):
# transition_criteria._transition_to = transition_to
except IndexError: # Last step, steps[idx+1] was not possible
transition_to = step.name
)
for tc in step.transition_criteria:
if tc.criterion_class != "MaxGenerationParallelism":
tc._transition_to = transition_to
if tc.criterion_class == "MaxGenerationParallelism":
# MaxGenerationParallelism transitions to self (current step)
tc._transition_to = step.name
else:
tc._transition_to = next_step_name
self._curr = steps[0]

def _validate_and_set_node_graph(self, nodes: list[GenerationNode]) -> None:
Expand Down
1 change: 1 addition & 0 deletions ax/generation_strategy/tests/test_generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1238,6 +1238,7 @@ def test_gs_setup_with_nodes(self) -> None:
only_in_statuses=[TrialStatus.RUNNING],
block_gen_if_met=True,
block_transition_if_unmet=False,
transition_to="node_1",
),
]
node_1 = GenerationNode(
Expand Down
19 changes: 5 additions & 14 deletions ax/generation_strategy/tests/test_transition_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
IsSingleObjective,
MaxGenerationParallelism,
MinimumPreferenceOccurances,
MinimumTrialsInStatus,
MinTrials,
)
from ax.utils.common.logger import get_logger
Expand Down Expand Up @@ -247,6 +246,7 @@ def test_default_step_criterion_setup(self) -> None:
only_in_statuses=[TrialStatus.RUNNING],
block_gen_if_met=True,
block_transition_if_unmet=False,
transition_to="GenerationStep_1_BoTorch",
),
]
step_2_expected_transition_criteria = []
Expand Down Expand Up @@ -525,16 +525,6 @@ def test_repr(self) -> None:
+ "'transition_to': None, 'block_gen_if_met': False, "
"'block_transition_if_unmet': True})",
)
deprecated_min_trials_criterion = MinimumTrialsInStatus(
status=TrialStatus.COMPLETED, threshold=3
)
self.assertEqual(
str(deprecated_min_trials_criterion),
"MinimumTrialsInStatus({"
+ "'status': <enum 'TrialStatus'>.COMPLETED, "
+ "'threshold': 3, "
+ "'transition_to': None})",
)
max_parallelism = MaxGenerationParallelism(
only_in_statuses=[TrialStatus.EARLY_STOPPED],
threshold=3,
Expand All @@ -545,14 +535,15 @@ def test_repr(self) -> None:
)
self.assertEqual(
str(max_parallelism),
"MaxGenerationParallelism({'threshold': 3, 'only_in_statuses': "
"MaxGenerationParallelism({'threshold': 3, "
+ "'transition_to': 'GenerationStep_2', "
+ "'only_in_statuses': "
+ "[<enum 'TrialStatus'>.EARLY_STOPPED], "
+ "'not_in_statuses': [<enum 'TrialStatus'>.FAILED], "
+ "'transition_to': 'GenerationStep_2', "
+ "'block_transition_if_unmet': False, "
+ "'block_gen_if_met': True, "
+ "'use_all_trials_in_exp': False, "
+ "'continue_trial_generation': True})",
+ "'continue_trial_generation': False})",
)
auto_transition = AutoTransitionAfterGen(transition_to="GenerationStep_2")
self.assertEqual(
Expand Down
40 changes: 4 additions & 36 deletions ax/generation_strategy/transition_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from ax.generation_strategy.generation_node import GenerationNode

from ax.utils.common.base import SortableBase
from ax.utils.common.serialization import SerializationMixin, serialize_init_args
from ax.utils.common.serialization import serialize_init_args
from pyre_extensions import none_throws


Expand All @@ -32,7 +32,7 @@
)


class TransitionCriterion(SortableBase, SerializationMixin):
class TransitionCriterion(SortableBase):
"""
Simple class to describe a condition which must be met for this GenerationNode to
take an action such as generation, transition, etc.
Expand Down Expand Up @@ -424,13 +424,13 @@ class MaxGenerationParallelism(TrialBasedCriterion):
def __init__(
self,
threshold: int,
transition_to: str,
only_in_statuses: list[TrialStatus] | None = None,
not_in_statuses: list[TrialStatus] | None = None,
transition_to: str | None = None,
block_transition_if_unmet: bool | None = False,
block_gen_if_met: bool | None = True,
use_all_trials_in_exp: bool | None = False,
continue_trial_generation: bool | None = True,
continue_trial_generation: bool | None = False,
) -> None:
super().__init__(
threshold=threshold,
Expand Down Expand Up @@ -732,35 +732,3 @@ def block_continued_generation_error(
f"This criterion, {self.criterion_class} has been met but cannot "
"continue generation from its associated GenerationNode."
)


# TODO: Deprecate once legacy usecase is updated
class MinimumTrialsInStatus(TransitionCriterion):
"""
Deprecated and replaced with more flexible MinTrials criterion.
"""

def __init__(
self,
status: TrialStatus,
threshold: int,
transition_to: str | None = None,
) -> None:
self.status = status
self.threshold = threshold
super().__init__(transition_to=transition_to)

def is_met(
self,
experiment: Experiment,
curr_node: GenerationNode,
) -> bool:
return len(experiment.trial_indices_by_status[self.status]) >= self.threshold

def block_continued_generation_error(
self,
node_name: str | None,
experiment: Experiment | None,
trials_from_node: set[int],
) -> None:
pass
6 changes: 4 additions & 2 deletions ax/orchestration/tests/test_orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,8 @@ class TestAxOrchestrator(TestCase):
"GenerationNode(name='GenerationStep_1_BoTorch', "
"generator_specs=[GeneratorSpec(generator_enum=BoTorch, "
"generator_key_override=None)], "
"transition_criteria=[MaxGenerationParallelism(transition_to='None')])]), "
"transition_criteria=[MaxGenerationParallelism("
"transition_to='GenerationStep_1_BoTorch')])]), "
"options=OrchestratorOptions(max_pending_trials=10, "
"trial_type=<TrialType.TRIAL: 0>, batch_size=None, "
"total_trials=0, tolerated_trial_failure_rate=0.2, "
Expand Down Expand Up @@ -2857,7 +2858,8 @@ class TestAxOrchestratorMultiTypeExperiment(TestAxOrchestrator):
"GenerationNode(name='GenerationStep_1_BoTorch', "
"generator_specs=[GeneratorSpec(generator_enum=BoTorch, "
"generator_key_override=None)], "
"transition_criteria=[MaxGenerationParallelism(transition_to='None')])]), "
"transition_criteria="
"[MaxGenerationParallelism(transition_to='GenerationStep_1_BoTorch')])]), "
"options=OrchestratorOptions(max_pending_trials=10, "
"trial_type=<TrialType.TRIAL: 0>, batch_size=None, "
"total_trials=0, tolerated_trial_failure_rate=0.2, "
Expand Down
90 changes: 57 additions & 33 deletions ax/storage/json_store/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,7 @@
GenerationStrategy,
)
from ax.generation_strategy.generator_spec import GeneratorSpec
from ax.generation_strategy.transition_criterion import (
AuxiliaryExperimentCheck,
TransitionCriterion,
TrialBasedCriterion,
)
from ax.generation_strategy.transition_criterion import MinTrials, TransitionCriterion
from ax.generators.torch.botorch_modular.generator import BoTorchGenerator
from ax.generators.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec
from ax.generators.torch.botorch_modular.utils import ModelConfig
Expand All @@ -63,6 +59,7 @@
from ax.storage.utils import data_by_trial_to_data
from ax.utils.common.logger import get_logger
from ax.utils.common.serialization import (
extract_init_args,
SerializationMixin,
TClassDecoderRegistry,
TDecoderRegistry,
Expand Down Expand Up @@ -280,16 +277,16 @@ def object_from_json(
object_json["outcome_transform_options"] = (
outcome_transform_options_json
)
elif isclass(_class) and (
issubclass(_class, TrialBasedCriterion)
or issubclass(_class, AuxiliaryExperimentCheck)
elif (
isclass(_class)
and issubclass(_class, TransitionCriterion)
and _class is not TransitionCriterion # TransitionCriterion is abstract
):
# TrialBasedCriterion contains a list of `TrialStatus` for args.
# AuxiliaryExperimentCheck contains AuxiliaryExperimentPurpose objects
# They need to be unpacked by hand to properly retain the types.
return unpack_transition_criteria_from_json(
class_=_class,
transition_criteria_json=object_json,
# TransitionCriterion may contain nested Ax objects (TrialStatus, etc.)
# that need recursive deserialization via object_from_json.
return transition_criterion_from_json(
transition_criterion_class=_class,
object_json=object_json,
**vars(registry_kwargs),
)
elif isclass(_class) and issubclass(_class, SerializationMixin):
Expand Down Expand Up @@ -430,32 +427,51 @@ def generator_run_from_json(
return generator_run


def unpack_transition_criteria_from_json(
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use `typing.Type` to
# avoid runtime subscripting errors.
class_: type,
transition_criteria_json: dict[str, Any],
def transition_criterion_from_json(
transition_criterion_class: type[TransitionCriterion],
object_json: dict[str, Any],
decoder_registry: TDecoderRegistry = CORE_DECODER_REGISTRY,
class_decoder_registry: TClassDecoderRegistry = CORE_CLASS_DECODER_REGISTRY,
) -> TransitionCriterion | None:
"""Load Ax transition criteria that depend on Trials from JSON.

Since ``TrialBasedCriterion`` contain lists of ``TrialStatus``,
the json for these criterion needs to be carefully unpacked and
re-processed via ``object_from_json`` in order to maintain correct
typing. We pass in ``class_`` in order to correctly handle all classes
which inherit from ``TrialBasedCriterion`` (ex: ``MinTrials``).
) -> TransitionCriterion:
"""Load TransitionCriterion from JSON.

TransitionCriterion subclasses may contain nested Ax objects (like TrialStatus
enums and AuxiliaryExperimentPurpose) that need recursive deserialization via
object_from_json. We also use extract_init_args for backwards compatibility,
filtering to only valid constructor arguments.
"""
new_dict = {}
for key, value in transition_criteria_json.items():
new_val = object_from_json(
# Handle deprecated MinimumTrialsInStatus -> MinTrials conversion
if transition_criterion_class is MinTrials and "status" in object_json:
logger.warning(
"`MinimumTrialsInStatus` has been deprecated and removed. "
"Converting to `MinTrials` with equivalent functionality."
)
status = object_from_json(
object_json=object_json.get("status"),
decoder_registry=decoder_registry,
class_decoder_registry=class_decoder_registry,
)
return MinTrials(
threshold=object_json.get("threshold"),
only_in_statuses=[status],
transition_to=object_json.get("transition_to"),
use_all_trials_in_exp=True,
)

decoded = {
key: object_from_json(
object_json=value,
decoder_registry=decoder_registry,
class_decoder_registry=class_decoder_registry,
)
new_dict[key] = new_val
for key, value in object_json.items()
}

# filter to only valid constructor args (backwards compatibility)
init_args = extract_init_args(args=decoded, class_=transition_criterion_class)

return class_(**new_dict)
# pyre-ignore[45]: Class passed is always a concrete subclass.
return transition_criterion_class(**init_args)


def search_space_from_json(
Expand Down Expand Up @@ -806,6 +822,14 @@ def generation_node_from_json(
# if needed during _validate_and_set_step_sequence.
generation_node_json.pop("step_index", None)

# Backwards compatibility: For transition criteria with transition_to=None
# set transition_to to point to itself.
transition_criteria_json = generation_node_json.pop("transition_criteria")
if transition_criteria_json is not None:
for tc_json in transition_criteria_json:
if tc_json.get("transition_to") is None:
tc_json["transition_to"] = name

return GenerationNode(
name=name,
generator_specs=object_from_json(
Expand All @@ -820,7 +844,7 @@ def generation_node_from_json(
),
should_deduplicate=generation_node_json.pop("should_deduplicate", False),
transition_criteria=object_from_json(
object_json=generation_node_json.pop("transition_criteria"),
object_json=transition_criteria_json,
decoder_registry=decoder_registry,
class_decoder_registry=class_decoder_registry,
),
Expand Down
2 changes: 1 addition & 1 deletion ax/storage/json_store/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ def generation_strategy_to_dict(

def transition_criterion_to_dict(criterion: TransitionCriterion) -> dict[str, Any]:
"""Convert Ax TransitionCriterion to a dictionary."""
properties = criterion.serialize_init_args(obj=criterion)
properties = serialize_init_args(obj=criterion)
properties["__type"] = criterion.__class__.__name__
return properties

Expand Down
5 changes: 2 additions & 3 deletions ax/storage/json_store/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@
IsSingleObjective,
MaxGenerationParallelism,
MinimumPreferenceOccurances,
MinimumTrialsInStatus,
MinTrials,
TransitionCriterion,
)
Expand Down Expand Up @@ -219,7 +218,6 @@
MaxGenerationParallelism: transition_criterion_to_dict,
Metric: metric_to_dict,
MinTrials: transition_criterion_to_dict,
MinimumTrialsInStatus: transition_criterion_to_dict,
MinimumPreferenceOccurances: transition_criterion_to_dict,
AuxiliaryExperimentCheck: transition_criterion_to_dict,
GeneratorSpec: generator_spec_to_dict,
Expand Down Expand Up @@ -347,7 +345,8 @@
"MaxGenerationParallelism": MaxGenerationParallelism,
"Metric": Metric,
"MinTrials": MinTrials,
"MinimumTrialsInStatus": MinimumTrialsInStatus,
# DEPRECATED; backward compatibility for MinimumTrialsInStatus -> MinTrials
"MinimumTrialsInStatus": MinTrials,
"MinimumPreferenceOccurances": MinimumPreferenceOccurances,
"GeneratorRegistryBase": GeneratorRegistryBase,
"ModelRegistryBase": GeneratorRegistryBase,
Expand Down
5 changes: 5 additions & 0 deletions ax/storage/json_store/tests/test_json_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,11 @@ def test_EncodeDecode(self) -> None:
# _step_index is non-persistent state, unset on original to match
# the decoded object which will have _step_index=None.
original_object._step_index = None
# Transition_to is set during decode for backwards
# compatibility. Update original to match decoded.
for tc in original_object.transition_criteria:
if tc.transition_to is None:
tc._transition_to = original_object.name
if isinstance(original_object, torch.nn.Module):
self.assertIsInstance(
converted_object,
Expand Down
Loading