From 38af3c55e235d96719029cc865a4114a4258398d Mon Sep 17 00:00:00 2001 From: Ben Denham Date: Sun, 18 May 2025 15:48:53 +1200 Subject: [PATCH 1/7] Implement CustomSerializer support --- labtech/serialization.py | 121 ++++++++++++++++++++++++++++++++------- 1 file changed, 101 insertions(+), 20 deletions(-) diff --git a/labtech/serialization.py b/labtech/serialization.py index c907f05..55a4639 100644 --- a/labtech/serialization.py +++ b/labtech/serialization.py @@ -1,8 +1,9 @@ """Serialization/deserialization of tasks to/from JSON.""" +from abc import ABC, abstractmethod from dataclasses import fields from enum import Enum -from typing import Optional, Type, Union, cast +from typing import Any, Optional, Sequence, Type, Union, cast from frozendict import frozendict @@ -16,7 +17,95 @@ dict[str, 'jsonable'], list['jsonable']] +class CustomSerializer(ABC): + """Base class for custom serializers that can convert complex + objects into JSON-compatible representations.""" + + @abstractmethod + def handles(self, value: Any) -> bool: + """Returns True if value should be serialized by this + serializer.""" + + @abstractmethod + def serialize(self, serializer: 'Serializer', value: Any) -> jsonable: + """Convert value into a JSON-compatible representation + composed only of dictionaries, lists, strings, numbers and + `None`. + + Also receives the full Serializer, which can be used to call + `serializer.serialize_value()` to serialize nested elements + within value.""" + + @abstractmethod + def deserialize(self, serializer: 'Serializer', serialized: jsonable) -> Any: + """Convert the serialized representation returned by + serialize() back into the original value. + + Also receives the full Serializer, which can be used to call + `serializer.deserialize_value()` to deserialize nested elements + within serialized.""" + + class Serializer: + """Serializer for producing serialized JSON representations of + Task objects, and deserializing JSON back into Task objects.""" + + def __init__(self, custom_serializer_classes: Optional[Sequence[Type[CustomSerializer]]] = None): + """ + Args: + custom_serializer_classes: A list of classes that inherit from + [CustomSerializer][labtech.serialization.CustomSerializer] that + extend the types of task parameters that can be serialized. When + a value is serialized, the `handle()` method of an instance of + each custom_serializer_class is called in the provided order to + determine whether it should be used to serialize that value. + Custom serializers are applied before default serialization is + applied. + + """ + self.custom_serializers = [ + custom_serializer_class() for custom_serializer_class + in ([] if custom_serializer_classes is None else custom_serializer_classes) + ] + + def _is_serialized_custom(self, serialized: jsonable) -> bool: + return isinstance(serialized, dict) and bool(serialized.get('_is_custom', False)) + + def _serialize_custom(self, custom_serializer: CustomSerializer, value: Any) -> dict[str, jsonable]: + return { + '_is_custom': True, + '__class__': self.serialize_class(custom_serializer.__class__), + 'value': custom_serializer.serialize( + serializer=self, + value=value, + ), + } + + def _deserialize_custom(self, serialized: dict[str, jsonable]) -> Any: + if not self._is_serialized_custom(serialized): + raise SerializationError(("deserialize_custom() must be called with a " + f"serialized custom value, received: '{serialized}'")) + + custom_serializer = self.deserialize_class(serialized['__class__'])() + return custom_serializer.deserialize_value( + serializer=self, + serialized=serialized['value'], + ) + + def _is_serialized_enum(self, serialized: jsonable) -> bool: + return isinstance(serialized, dict) and bool(serialized.get('_is_enum', False)) + + def _serialize_enum(self, value: Enum) -> jsonable: + return { + '_is_enum': True, + '__class__': self.serialize_class(value.__class__), + 'name': value.name, + } + + def _deserialize_enum(self, serialized: dict[str, jsonable]) -> Enum: + enum_cls = self.deserialize_class(serialized['__class__']) + return enum_cls[serialized['name']] + def is_serialized_task(self, serialized: jsonable) -> bool: return isinstance(serialized, dict) and bool(serialized.get('_is_task', False)) @@ -69,7 +158,11 @@ def deserialize_task(self, serialized: dict[str, jsonable], *, result_meta: Opti task._set_result_meta(result_meta) return task - def serialize_value(self, value) -> jsonable: + def serialize_value(self, value: Any) -> jsonable: + for custom_serializer in self.custom_serializers: + if custom_serializer.handles(value): + return self._serialize_custom(custom_serializer, value) + if is_task(value): return self.serialize_task(value) elif isinstance(value, tuple): @@ -80,7 +173,7 @@ def serialize_value(self, value) -> jsonable: for key, value in value.items() } elif isinstance(value, Enum): - return self.serialize_enum(value) + return self._serialize_enum(value) elif ((value is None) or isinstance(value, str) or isinstance(value, bool) @@ -91,7 +184,9 @@ def serialize_value(self, value) -> jsonable: "that your task's parameters only use supported types.")) def deserialize_value(self, value: jsonable): - if self.is_serialized_task(value): + if self._is_serialized_custom(value): + return self._deserialize_custom(cast(dict[str, jsonable], value)) + elif self.is_serialized_task(value): return self.deserialize_task(cast(dict[str, jsonable], value), result_meta=None) elif isinstance(value, list): return tuple([self.deserialize_value(item) for item in value]) @@ -100,24 +195,10 @@ def deserialize_value(self, value: jsonable): ensure_dict_key_str(k, exception_type=SerializationError): self.deserialize_value(v) for k, v in value.items() }) - elif self.is_serialized_enum(value): - return self.deserialize_enum(cast(dict[str, jsonable], value)) + elif self._is_serialized_enum(value): + return self._deserialize_enum(cast(dict[str, jsonable], value)) return value - def is_serialized_enum(self, serialized: jsonable) -> bool: - return isinstance(serialized, dict) and bool(serialized.get('_is_enum', False)) - - def serialize_enum(self, value: Enum) -> jsonable: - return { - '_is_enum': True, - '__class__': self.serialize_class(value.__class__), - 'name': value.name, - } - - def deserialize_enum(self, serialized: dict[str, jsonable]) -> Enum: - enum_cls = self.deserialize_class(serialized['__class__']) - return enum_cls[serialized['name']] - def serialize_class(self, cls: Type) -> jsonable: return f'{cls.__module__}.{cls.__qualname__}' From 2eebaa83f17552015b3afe2d7e750ba68263539f Mon Sep 17 00:00:00 2001 From: Ben Denham Date: Sun, 18 May 2025 17:54:36 +1200 Subject: [PATCH 2/7] Implement support for full custom parameter handlers. --- labtech/__init__.py | 2 + labtech/cache.py | 6 +-- labtech/exceptions.py | 4 ++ labtech/params.py | 69 ++++++++++++++++++++++++++++++ labtech/serialization.py | 92 +++++++++------------------------------- labtech/tasks.py | 33 +++++++++++++- labtech/types.py | 78 ++++++++++++++++++++++++++++++++++ 7 files changed, 207 insertions(+), 77 deletions(-) create mode 100644 labtech/params.py diff --git a/labtech/__init__.py b/labtech/__init__.py index 68e04dd..acd5655 100644 --- a/labtech/__init__.py +++ b/labtech/__init__.py @@ -34,6 +34,7 @@ def run(self): __version__ = '0.7.1' from .lab import Lab +from .params import param_handler from .tasks import task from .types import is_task, is_task_type from .utils import logger @@ -41,6 +42,7 @@ def run(self): __all__ = [ 'is_task_type', 'is_task', + 'param_handler', 'task', 'Lab', 'logger', diff --git a/labtech/cache.py b/labtech/cache.py index b685ca5..9fe264e 100644 --- a/labtech/cache.py +++ b/labtech/cache.py @@ -10,8 +10,8 @@ from . import __version__ as labtech_version from .exceptions import CacheError, TaskNotFound -from .serialization import Serializer -from .types import Cache, ResultMeta, ResultT, Storage, Task, TaskResult, TaskT +from .serialization import DefaultSerializer +from .types import Cache, ResultMeta, ResultT, Serializer, Storage, Task, TaskResult, TaskT class NullCache(Cache): @@ -50,7 +50,7 @@ class BaseCache(Cache): METADATA_FILENAME = 'metadata.json' def __init__(self, *, serializer: Optional[Serializer] = None): - self.serializer = serializer or Serializer() + self.serializer = serializer or DefaultSerializer() def cache_key(self, task: Task) -> str: serialized_str = json.dumps(self.serializer.serialize_task(task)).encode('utf-8') diff --git a/labtech/exceptions.py b/labtech/exceptions.py index b67be65..73dc674 100644 --- a/labtech/exceptions.py +++ b/labtech/exceptions.py @@ -24,6 +24,10 @@ class TaskError(LabtechError): """Raised for failures when handling Task objects.""" +class ParamHandlerError(LabtechError): + """Raised for failures in custom parameter handlers.""" + + class StorageError(LabtechError): """Raised for failures when interacting with Storage objects.""" diff --git a/labtech/params.py b/labtech/params.py new file mode 100644 index 0000000..0812dc8 --- /dev/null +++ b/labtech/params.py @@ -0,0 +1,69 @@ +from inspect import isclass +from typing import TypedDict + +from .exceptions import ParamHandlerError +from .types import ParamHandler + + +class ParamHandlerEntry(TypedDict): + handler: ParamHandler + priority: int + + +CUSTOM_PARAM_HANDLER_ENTRIES: dict[str, ParamHandlerEntry] = {} +CUSTOM_PARAM_HANDLERS: list[ParamHandler] = [] + + +def param_handler(*args, priority: int = 1000): + """Class decorator for declaring custom parameter handlers that + can define how Labtech should handle the processing, + serialization, and deserialization of additional parameter types. + + Defining a custom parameter handler is an advanced feature of + Labtech, and you are responsible for ensuring: + + * The decorated class implements all methods of the + [`ParamHandler`][labtech.types.ParamHandler] protocol. + * To ensure tasks are reproducible, you should only define + handlers for customer parameter types that are **immutable**. + * Because tasks are hashable representations of their parameters, + you should only define handlers for customer parameter types that + are **hashable**. + * Because serialized parameters will reference the module path and + class name of the custom parameter handler that was used to + serialize them, you should avoid moving or renaming custom + parameter handlers once they are in use. + + Args: + priority: Determines the order in which custom parameter handlers are + applied when processing a parameter value. Lower priority values + are applied first. + + """ + + def decorator(cls): + global CUSTOM_PARAM_HANDLERS + + if not isinstance(cls, ParamHandler): + raise ParamHandlerError( + (f"Cannot register '{cls.__qualname__}' as a custom parameter handler, " + "as it does not implement all methods of the 'ParamHandler' protocol.") + ) + + CUSTOM_PARAM_HANDLER_ENTRIES[f'{cls.__module__}.{cls.__qualname__}'] = ParamHandlerEntry( + handler=cls(), + priority=priority, + ) + CUSTOM_PARAM_HANDLERS = [ + entry['handler'] for entry in + # Sort param handlers by priority, keeping insertion order + # where priorities are equal. + sorted(CUSTOM_PARAM_HANDLERS.values(), key=lambda entry: entry['priority']) + ] + + return cls + + if len(args) > 0 and isclass(args[0]): + return decorator(args[0], *args[1:]) + else: + return decorator diff --git a/labtech/serialization.py b/labtech/serialization.py index 55a4639..26bf17a 100644 --- a/labtech/serialization.py +++ b/labtech/serialization.py @@ -1,83 +1,30 @@ """Serialization/deserialization of tasks to/from JSON.""" -from abc import ABC, abstractmethod from dataclasses import fields from enum import Enum -from typing import Any, Optional, Sequence, Type, Union, cast +from typing import Any, Optional, Type, cast from frozendict import frozendict from .exceptions import SerializationError -from .types import ResultMeta, Task, is_task +from .params import CUSTOM_PARAM_HANDLERS +from .types import ParamHandler, ResultMeta, Serializer, Task, is_task, jsonable from .utils import ensure_dict_key_str -# Type to represent any value that can be handled by Python's default -# json encoder and decoder. -jsonable = Union[None, str, bool, float, int, - dict[str, 'jsonable'], list['jsonable']] - - -class CustomSerializer(ABC): - """Base class for custom serializers that can convert complex - objects into JSON-compatible representations.""" - - @abstractmethod - def handles(self, value: Any) -> bool: - """Returns True if value should be serialized by this - serializer.""" - - @abstractmethod - def serialize(self, serializer: 'Serializer', value: Any) -> jsonable: - """Convert value into a JSON-compatible representation - composed only of dictionaries, lists, strings, numbers and - `None`. - - Also receives the full Serializer, which can be used to call - `serializer.serialize_value()` to serialize nested elements - within value.""" - - @abstractmethod - def deserialize(self, serializer: 'Serializer', serialized: jsonable) -> Any: - """Convert the serialized representation returned by - serialize() back into the original value. - - Also receives the full Serializer, which can be used to call - `serializer.deserialize_value()` to deserialize nested elements - within serialized.""" - - -class Serializer: - """Serializer for producing serialized JSON representations of - Task objects, and deserializing JSON back into Task objects.""" - - def __init__(self, custom_serializer_classes: Optional[Sequence[Type[CustomSerializer]]] = None): - """ - Args: - custom_serializer_classes: A list of classes that inherit from - [CustomSerializer][labtech.serialization.CustomSerializer] that - extend the types of task parameters that can be serialized. When - a value is serialized, the `handle()` method of an instance of - each custom_serializer_class is called in the provided order to - determine whether it should be used to serialize that value. - Custom serializers are applied before default serialization is - applied. - - """ - self.custom_serializers = [ - custom_serializer_class() for custom_serializer_class - in ([] if custom_serializer_classes is None else custom_serializer_classes) - ] + +class DefaultSerializer(Serializer): + """Default Serializer implementation.""" def _is_serialized_custom(self, serialized: jsonable) -> bool: return isinstance(serialized, dict) and bool(serialized.get('_is_custom', False)) - def _serialize_custom(self, custom_serializer: CustomSerializer, value: Any) -> dict[str, jsonable]: + def _serialize_custom(self, custom_param_handler: ParamHandler, value: Any) -> dict[str, jsonable]: return { '_is_custom': True, - '__class__': self.serialize_class(custom_serializer.__class__), - 'value': custom_serializer.serialize( - serializer=self, + '__class__': self.serialize_class(custom_param_handler.__class__), + 'value': custom_param_handler.serialize( value=value, + serializer=self, ), } @@ -86,10 +33,10 @@ def _deserialize_custom(self, serialized: dict[str, jsonable]) -> Any: raise SerializationError(("deserialize_custom() must be called with a " f"serialized custom value, received: '{serialized}'")) - custom_serializer = self.deserialize_class(serialized['__class__'])() - return custom_serializer.deserialize_value( - serializer=self, + custom_param_handler = self.deserialize_class(serialized['__class__'])() + return custom_param_handler.deserialize_value( serialized=serialized['value'], + serializer=self, ) def _is_serialized_enum(self, serialized: jsonable) -> bool: @@ -106,8 +53,7 @@ def _deserialize_enum(self, serialized: dict[str, jsonable]) -> Enum: enum_cls = self.deserialize_class(serialized['__class__']) return enum_cls[serialized['name']] - - def is_serialized_task(self, serialized: jsonable) -> bool: + def _is_serialized_task(self, serialized: jsonable) -> bool: return isinstance(serialized, dict) and bool(serialized.get('_is_task', False)) def serialize_task(self, task: Task) -> dict[str, jsonable]: @@ -133,7 +79,7 @@ def serialize_task(self, task: Task) -> dict[str, jsonable]: return serialized def deserialize_task(self, serialized: dict[str, jsonable], *, result_meta: Optional[ResultMeta]) -> Task: - if not self.is_serialized_task(serialized): + if not self._is_serialized_task(serialized): raise SerializationError(("deserialize_task() must be called with a " f"serialized Task, received: '{serialized}'")) @@ -159,9 +105,9 @@ def deserialize_task(self, serialized: dict[str, jsonable], *, result_meta: Opti return task def serialize_value(self, value: Any) -> jsonable: - for custom_serializer in self.custom_serializers: - if custom_serializer.handles(value): - return self._serialize_custom(custom_serializer, value) + for custom_param_handler in CUSTOM_PARAM_HANDLERS: + if custom_param_handler.handles(value): + return self._serialize_custom(custom_param_handler, value) if is_task(value): return self.serialize_task(value) @@ -186,7 +132,7 @@ def serialize_value(self, value: Any) -> jsonable: def deserialize_value(self, value: jsonable): if self._is_serialized_custom(value): return self._deserialize_custom(cast(dict[str, jsonable], value)) - elif self.is_serialized_task(value): + elif self._is_serialized_task(value): return self.deserialize_task(cast(dict[str, jsonable], value), result_meta=None) elif isinstance(value, list): return tuple([self.deserialize_value(item) for item in value]) diff --git a/labtech/tasks.py b/labtech/tasks.py index cc648d4..e05a8e8 100644 --- a/labtech/tasks.py +++ b/labtech/tasks.py @@ -1,7 +1,9 @@ """Utilities for defining tasks.""" +from collections.abc import Hashable from dataclasses import dataclass, fields from enum import Enum +from functools import partial from inspect import isclass from types import UnionType from typing import Any, Optional, Sequence, TypeAlias, Union, cast @@ -10,6 +12,7 @@ from .cache import NullCache, PickleCache from .exceptions import TaskError +from .params import CUSTOM_PARAM_HANDLERS from .types import Cache, LabContext, ResultMeta, ResultsMap, ResultT, Task, TaskInfo, is_task, is_task_type from .utils import ensure_dict_key_str @@ -32,7 +35,20 @@ class CacheDefault: def immutable_param_value(key: str, value: Any) -> Any: - """Converts a parameter value to an immutable equivalent that is hashable.""" + """Converts a parameter value to an immutable equivalent that is + hashable (so that the task itself is hashable to be stored in + sets).""" + # Any value handled by custom_param_handlers is expected to be + # immutable and hashable. + for custom_param_handler in CUSTOM_PARAM_HANDLERS: + if custom_param_handler.handles(value): + if not isinstance(value, Hashable): + raise TaskError( + (f"Type '{type(value).__qualname__}' in parameter value '{key}' is handled " + f"by '{type(custom_param_handler).__qualname__}', but is not hashable.") + ) + return value + if isinstance(value, list) or isinstance(value, tuple): return tuple(immutable_param_value(f'{key}[{i}]', item) for i, item in enumerate(value)) if isinstance(value, dict) or isinstance(value, frozendict): @@ -152,6 +168,9 @@ def task(*args, * Note: Mutable `list` and `dict` collections will be converted to immutable `tuple` and [`frozendict`](https://pypi.org/project/frozendict/) collections. + * Immutable and hashable values for which a + [custom parameter handler][labtech.params.param_handler] has been + registered. The task type is expected to define a `run()` method that takes no arguments (other than `self`). The `run()` method should execute @@ -207,6 +226,10 @@ def run(self): documentation of each runner backend for supported options. The implementation may make use of the task's parameter values. + Because serialized tasks will reference the module path and class + name of the task type, you should avoid moving or renaming task + types once they are in use. + Args: cache: The Cache that controls how task results are formatted for caching. Can be set to an instance of any @@ -289,6 +312,14 @@ def find_tasks_in_param(param_value: Any, searched_coll_ids: Optional[set[int]] if id(param_value) in searched_coll_ids: return [] + for custom_param_handler in CUSTOM_PARAM_HANDLERS: + if custom_param_handler.handles(param_value): + searched_coll_ids = searched_coll_ids | {id(param_value)} + return custom_param_handler.find_tasks( + value=param_value, + find_tasks_in_param=partial(find_tasks_in_param, searched_coll_ids=searched_coll_ids), + ) + if is_task(param_value): return [param_value] elif isinstance(param_value, list) or isinstance(param_value, tuple): diff --git a/labtech/types.py b/labtech/types.py index a2cfbcd..2726598 100644 --- a/labtech/types.py +++ b/labtech/types.py @@ -16,8 +16,15 @@ Sequence, Type, TypeVar, + Union, + runtime_checkable, ) +# Type to represent any value that can be handled by Python's default +# json encoder and decoder. +jsonable = Union[None, str, bool, float, int, + dict[str, 'jsonable'], list['jsonable']] + @dataclass(frozen=True) class TaskInfo: @@ -222,6 +229,77 @@ def delete(self, storage: Storage, task: Task) -> None: `storage`.""" +class Serializer(ABC): + """Serializer for producing serialized JSON representations of + Task objects, and deserializing JSON back into Task objects.""" + + @abstractmethod + def serialize_task(self, task: Task) -> dict[str, jsonable]: + """Convert the given task into a JSON-compatible + representation composed only of dictionaries, lists, strings, + numbers and `None`.""" + + @abstractmethod + def deserialize_task(self, serialized: dict[str, jsonable], *, result_meta: Optional[ResultMeta]) -> Task: + """Convert the given serialized representation returned by + serialize_task() back into the original task.""" + + @abstractmethod + def serialize_value(self, value: Any) -> jsonable: + """Convert the given value into a JSON-compatible + representation composed only of dictionaries, lists, strings, + numbers and `None`.""" + + @abstractmethod + def deserialize_value(self, value: jsonable): + """Convert the given serialized representation returned by + serialize_value() back into the original value.""" + + @abstractmethod + def serialize_class(self, cls: Type) -> jsonable: + """Convert the given class into a string representation.""" + + @abstractmethod + def deserialize_class(self, serialized_class: jsonable) -> Type: + """Load the class named in the given serialized representation + returned by serialize_class().""" + + +@runtime_checkable +class ParamHandler(Protocol): + """Protocol for custom parameter handlers that can define how + Labtech should handle the processing, serialization, and + deserialization of additional parameter types.""" + + def handles(self, value: Any) -> bool: + """Returns True if the given parameter value should be handled + by this class.""" + + def find_tasks(self, value: Any, *, find_tasks_in_param: Callable[[Any], Sequence[Task]]) -> Sequence[Task]: + """Given a parameter value, return all tasks within it (not + including tasks within those tasks). + + The provided `find_tasks_in_param` should be called to find + tasks in anynested elements within the value.""" + + def serialize(self, value: Any, *, serializer: Serializer) -> jsonable: + """Convert the given parameter value into a JSON-compatible + representation composed only of dictionaries, lists, strings, + numbers and `None`. + + Also receives the full Serializer, which can be used to call + `serializer.serialize_value()` to serialize nested elements + within the value.""" + + def deserialize(self, serialized: jsonable, *, serializer: Serializer) -> Any: + """Convert the given serialized representation returned by + serialize() back into the original parameter value. + + Also receives the full Serializer, which can be used to call + `serializer.deserialize_value()` to deserialize nested elements + within the serialized representation.""" + + TaskMonitorInfoValue = datetime | str | int | float TaskMonitorInfoItem = TaskMonitorInfoValue | tuple[TaskMonitorInfoValue, str] TaskMonitorInfo = dict[str, TaskMonitorInfoItem] From 35610911010e2788b5e726c03f6ecaa493b1f0f1 Mon Sep 17 00:00:00 2001 From: Ben Denham Date: Sun, 18 May 2025 21:02:00 +1200 Subject: [PATCH 3/7] Share custom parameter handlers to remote ray processes. --- labtech/exceptions.py | 5 ++++ labtech/params.py | 63 +++++++++++++++++++++++++++++++--------- labtech/runners/ray.py | 8 ++++- labtech/serialization.py | 17 ++++++----- labtech/tasks.py | 6 ++-- labtech/types.py | 9 ++++-- labtech/utils.py | 6 ++++ 7 files changed, 87 insertions(+), 27 deletions(-) diff --git a/labtech/exceptions.py b/labtech/exceptions.py index 73dc674..987a69a 100644 --- a/labtech/exceptions.py +++ b/labtech/exceptions.py @@ -28,6 +28,11 @@ class ParamHandlerError(LabtechError): """Raised for failures in custom parameter handlers.""" +class UnregisteredParamHandlerError(LabtechError): + """Raised when attempting to lookup a custom parameter handler + that is not registered.""" + + class StorageError(LabtechError): """Raised for failures when interacting with Storage objects.""" diff --git a/labtech/params.py b/labtech/params.py index 0812dc8..7cb873c 100644 --- a/labtech/params.py +++ b/labtech/params.py @@ -1,8 +1,9 @@ from inspect import isclass from typing import TypedDict -from .exceptions import ParamHandlerError +from .exceptions import ParamHandlerError, UnregisteredParamHandlerError from .types import ParamHandler +from .utils import fully_qualified_class_name class ParamHandlerEntry(TypedDict): @@ -10,8 +11,18 @@ class ParamHandlerEntry(TypedDict): priority: int -CUSTOM_PARAM_HANDLER_ENTRIES: dict[str, ParamHandlerEntry] = {} -CUSTOM_PARAM_HANDLERS: list[ParamHandler] = [] +_CUSTOM_PARAM_HANDLER_ENTRIES: dict[str, ParamHandlerEntry] = {} +_CUSTOM_PARAM_HANDLERS = [] + + +def _update_custom_param_handlers() -> None: + global _CUSTOM_PARAM_HANDLERS + _CUSTOM_PARAM_HANDLERS = [ + entry['handler'] for entry in + # Sort param handlers by priority, keeping insertion order + # where priorities are equal. + sorted(_CUSTOM_PARAM_HANDLER_ENTRIES.values(), key=lambda entry: entry['priority']) + ] def param_handler(*args, priority: int = 1000): @@ -25,10 +36,11 @@ def param_handler(*args, priority: int = 1000): * The decorated class implements all methods of the [`ParamHandler`][labtech.types.ParamHandler] protocol. * To ensure tasks are reproducible, you should only define - handlers for customer parameter types that are **immutable**. + handlers for custom parameter types that are **immutable and + composed only of immutable elements**. * Because tasks are hashable representations of their parameters, - you should only define handlers for customer parameter types that - are **hashable**. + you should only define handlers for custom parameter types that + are **hashable and composed only of hashable elements**. * Because serialized parameters will reference the module path and class name of the custom parameter handler that was used to serialize them, you should avoid moving or renaming custom @@ -42,7 +54,7 @@ class name of the custom parameter handler that was used to """ def decorator(cls): - global CUSTOM_PARAM_HANDLERS + global _CUSTOM_PARAM_HANDLERS if not isinstance(cls, ParamHandler): raise ParamHandlerError( @@ -50,16 +62,11 @@ def decorator(cls): "as it does not implement all methods of the 'ParamHandler' protocol.") ) - CUSTOM_PARAM_HANDLER_ENTRIES[f'{cls.__module__}.{cls.__qualname__}'] = ParamHandlerEntry( + _CUSTOM_PARAM_HANDLER_ENTRIES[fully_qualified_class_name(cls)] = ParamHandlerEntry( handler=cls(), priority=priority, ) - CUSTOM_PARAM_HANDLERS = [ - entry['handler'] for entry in - # Sort param handlers by priority, keeping insertion order - # where priorities are equal. - sorted(CUSTOM_PARAM_HANDLERS.values(), key=lambda entry: entry['priority']) - ] + _update_custom_param_handlers() return cls @@ -67,3 +74,31 @@ def decorator(cls): return decorator(args[0], *args[1:]) else: return decorator + + +def get_custom_param_handler_entries() -> dict[str, ParamHandlerEntry]: + return _CUSTOM_PARAM_HANDLER_ENTRIES + + +def set_custom_param_handler_entries(custom_param_handler_entries: dict[str, ParamHandlerEntry]) -> None: + global _CUSTOM_PARAM_HANDLER_ENTRIES + _CUSTOM_PARAM_HANDLER_ENTRIES = custom_param_handler_entries + _update_custom_param_handlers() + + +def get_custom_param_handlers() -> list[ParamHandler]: + return _CUSTOM_PARAM_HANDLERS + + +def lookup_custom_param_handler(fq_class_name: str) -> ParamHandler: + try: + entry = _CUSTOM_PARAM_HANDLER_ENTRIES[fq_class_name] + except KeyError: + raise UnregisteredParamHandlerError(fully_qualified_class_name) + return entry['handler'] + + +def clear_custom_param_handlers() -> None: + global _CUSTOM_PARAM_HANDLER_ENTRIES + _CUSTOM_PARAM_HANDLER_ENTRIES = {} + _update_custom_param_handlers() diff --git a/labtech/runners/ray.py b/labtech/runners/ray.py index 4cf8170..9e3c5d6 100644 --- a/labtech/runners/ray.py +++ b/labtech/runners/ray.py @@ -6,6 +6,7 @@ from typing import Iterator, Optional, Sequence from labtech.exceptions import RunnerError +from labtech.params import ParamHandlerEntry, get_custom_param_handler_entries, set_custom_param_handler_entries from labtech.tasks import get_direct_dependencies from labtech.types import LabContext, ResultMeta, ResultT, Runner, RunnerBackend, Storage, Task, TaskMonitorInfo, TaskResult, is_task from labtech.utils import logger @@ -31,7 +32,8 @@ class TaskDetail: # arguments, even though they work. @ray.remote(num_returns=2) # type: ignore[arg-type] def _ray_func(*task_refs_args, task: Task[ResultT], task_name: str, use_cache: bool, - context: LabContext, storage: Storage) -> tuple[ResultMeta, ResultT]: + context: LabContext, storage: Storage, + custom_param_handler_entries: dict[str, ParamHandlerEntry]) -> tuple[ResultMeta, ResultT]: # task_refs_args is expected to be a flattened list of (task, # result_meta, result_value) triples - passed this way to ensure # refs are top-level to trigger locality-aware scheduling: @@ -52,6 +54,8 @@ def _ray_func(*task_refs_args, task: Task[ResultT], task_name: str, use_cache: b value=result_value, ) + set_custom_param_handler_entries(custom_param_handler_entries) + for dependency_task in get_direct_dependencies(task, all_identities=True): dependency_task._set_results_map(results_map) @@ -83,6 +87,7 @@ def __init__(self, *, context: LabContext, storage: Storage, logger.debug('Uploading context and storage objects to ray object store') self.context_ref = ray.put(context) self.storage_ref = ray.put(storage) + self.custom_param_handler_entries_ref = ray.put(get_custom_param_handler_entries()) logger.debug('Uploaded context and storage objects to ray object store') self.cancelled = False @@ -137,6 +142,7 @@ def submit_task(self, task: Task, task_name: str, use_cache: bool) -> None: use_cache=use_cache, context=self.context_ref, storage=self.storage_ref, + custom_param_handler_entries=self.custom_param_handler_entries_ref, ) ) result_meta_ref, result_value_ref = result_refs diff --git a/labtech/serialization.py b/labtech/serialization.py index 26bf17a..776f246 100644 --- a/labtech/serialization.py +++ b/labtech/serialization.py @@ -6,10 +6,10 @@ from frozendict import frozendict -from .exceptions import SerializationError -from .params import CUSTOM_PARAM_HANDLERS +from .exceptions import SerializationError, UnregisteredParamHandlerError +from .params import get_custom_param_handlers, lookup_custom_param_handler from .types import ParamHandler, ResultMeta, Serializer, Task, is_task, jsonable -from .utils import ensure_dict_key_str +from .utils import ensure_dict_key_str, fully_qualified_class_name class DefaultSerializer(Serializer): @@ -33,8 +33,11 @@ def _deserialize_custom(self, serialized: dict[str, jsonable]) -> Any: raise SerializationError(("deserialize_custom() must be called with a " f"serialized custom value, received: '{serialized}'")) - custom_param_handler = self.deserialize_class(serialized['__class__'])() - return custom_param_handler.deserialize_value( + try: + custom_param_handler = lookup_custom_param_handler(cast(str, serialized['__class__'])) + except UnregisteredParamHandlerError: + custom_param_handler = self.deserialize_class(serialized['__class__'])() + return custom_param_handler.deserialize( serialized=serialized['value'], serializer=self, ) @@ -105,7 +108,7 @@ def deserialize_task(self, serialized: dict[str, jsonable], *, result_meta: Opti return task def serialize_value(self, value: Any) -> jsonable: - for custom_param_handler in CUSTOM_PARAM_HANDLERS: + for custom_param_handler in get_custom_param_handlers(): if custom_param_handler.handles(value): return self._serialize_custom(custom_param_handler, value) @@ -146,7 +149,7 @@ def deserialize_value(self, value: jsonable): return value def serialize_class(self, cls: Type) -> jsonable: - return f'{cls.__module__}.{cls.__qualname__}' + return fully_qualified_class_name(cls) def deserialize_class(self, serialized_class: jsonable) -> Type: cls_module, cls_name = cast(str, serialized_class).rsplit('.', 1) diff --git a/labtech/tasks.py b/labtech/tasks.py index e05a8e8..aac8298 100644 --- a/labtech/tasks.py +++ b/labtech/tasks.py @@ -12,7 +12,7 @@ from .cache import NullCache, PickleCache from .exceptions import TaskError -from .params import CUSTOM_PARAM_HANDLERS +from .params import get_custom_param_handlers from .types import Cache, LabContext, ResultMeta, ResultsMap, ResultT, Task, TaskInfo, is_task, is_task_type from .utils import ensure_dict_key_str @@ -40,7 +40,7 @@ def immutable_param_value(key: str, value: Any) -> Any: sets).""" # Any value handled by custom_param_handlers is expected to be # immutable and hashable. - for custom_param_handler in CUSTOM_PARAM_HANDLERS: + for custom_param_handler in get_custom_param_handlers(): if custom_param_handler.handles(value): if not isinstance(value, Hashable): raise TaskError( @@ -312,7 +312,7 @@ def find_tasks_in_param(param_value: Any, searched_coll_ids: Optional[set[int]] if id(param_value) in searched_coll_ids: return [] - for custom_param_handler in CUSTOM_PARAM_HANDLERS: + for custom_param_handler in get_custom_param_handlers(): if custom_param_handler.handles(param_value): searched_coll_ids = searched_coll_ids | {id(param_value)} return custom_param_handler.find_tasks( diff --git a/labtech/types.py b/labtech/types.py index 2726598..7ef9f2d 100644 --- a/labtech/types.py +++ b/labtech/types.py @@ -320,6 +320,11 @@ def submit_task(self, task: Task, task_name: str, use_cache: bool) -> None: effectively calling: ``` + # custom_param_handler_entries should be saved from + # labtech.params.get_custom_param_handler_entries() on the main process + # and set in any remote processes that don't inherit from the main process: + labtech.params.set_custom_param_handler_entries(custom_param_handler_entries) + for dependency_task in get_direct_dependencies(task, all_identities=True): # Where results_map is expected to contain the TaskResult for # each dependency_task. @@ -335,8 +340,8 @@ def submit_task(self, task: Task, task_name: str, use_cache: bool) -> None: return labtech.runners.base.run_or_load_task( task=task, use_cache=use_cache, - filtered_context=task.filter_context(self.context), - storage=self.storage, + filtered_context=task.filter_context(context), + storage=storage, ) finally: current_process.name = orig_process_name diff --git a/labtech/utils.py b/labtech/utils.py index 4203908..6b65bd1 100644 --- a/labtech/utils.py +++ b/labtech/utils.py @@ -117,6 +117,10 @@ def ensure_dict_key_str(value, *, exception_type: Type[Exception]) -> str: return cast(str, value) +def fully_qualified_class_name(cls: Type) -> str: + return f'{cls.__module__}.{cls.__qualname__}' + + def is_ipython() -> bool: return hasattr(builtins, '__IPYTHON__') @@ -132,10 +136,12 @@ class tqdm_notebook(base_tqdm_notebook): __all__ = [ + 'make_logger_handler', 'logger', 'OrderedSet', 'LoggerFileProxy', 'ensure_dict_key_str', + 'fully_qualified_class_name', 'is_ipython', 'tqdm', 'tqdm_notebook', From 39b45212f8e73c042c38795db6467d822a166839 Mon Sep 17 00:00:00 2001 From: Ben Denham Date: Sun, 18 May 2025 21:02:14 +1200 Subject: [PATCH 4/7] Add unit tests for custom parameter handlers. --- tests/integration/test_e2e.py | 39 +++++++++++++-- tests/labtech/test_params.py | 90 +++++++++++++++++++++++++++++++++ tests/labtech/test_tasks.py | 94 +++++++++++++++++++++++++++++++++++ 3 files changed, 218 insertions(+), 5 deletions(-) create mode 100644 tests/labtech/test_params.py diff --git a/tests/integration/test_e2e.py b/tests/integration/test_e2e.py index f302337..dc42296 100644 --- a/tests/integration/test_e2e.py +++ b/tests/integration/test_e2e.py @@ -1,6 +1,7 @@ """Test a set of tasks packed with usage of features end-to-end. Loosely based on tasks from the tutorial.""" +from datetime import datetime from tempfile import TemporaryDirectory from typing import Any, Protocol, TypedDict @@ -8,10 +9,33 @@ import ray import labtech +from labtech.params import clear_custom_param_handlers from labtech.runners.ray import RayRunnerBackend from labtech.types import Task +@pytest.fixture(autouse=True) +def datetime_param_handler(): + + @labtech.param_handler + class DatetimeParamHandler: + + def handles(self, value): + return isinstance(value, datetime) + + def find_tasks(self, value, *, find_tasks_in_param): + return [] + + def serialize(self, value, *, serializer): + return value.timestamp() + + def deserialize(self, serialized, *, serializer): + return datetime.fromtimestamp(serialized) + + yield + clear_custom_param_handlers() + + @labtech.task(cache=None) class ClassifierTask: n_estimators: int @@ -53,6 +77,7 @@ def run(self) -> dict: @labtech.task class WrappingExperiment(ExperimentTask): experiment: ExperimentTask + dt: datetime @property def dataset_key(self): @@ -60,7 +85,8 @@ def dataset_key(self): def run(self) -> dict: return { - 'inner_experiment': self.experiment.result + 'inner_experiment': self.experiment.result, + 'dt': self.dt, } @@ -92,6 +118,7 @@ class Evaluation(TypedDict): def basic_evaluation(context: dict[str, Any]) -> Evaluation: """Evaluation of a standard setup of multiple levels of dependency.""" + now = datetime.now() classifier_tasks = [ ClassifierTask( n_estimators=n_estimators, @@ -109,6 +136,7 @@ def basic_evaluation(context: dict[str, Any]) -> Evaluation: wrapping_experiments = [ WrappingExperiment( experiment=classifier_experiment, + dt=now, ) for classifier_experiment in classifier_experiments ] @@ -125,10 +153,10 @@ def basic_evaluation(context: dict[str, Any]) -> Evaluation: '2': {'dataset': 'aaa', 'classifier': {'n_estimators': 2}}, '3': {'dataset': 'bbb', 'classifier': {'n_estimators': 1}}, '4': {'dataset': 'bbb', 'classifier': {'n_estimators': 2}}, - '5': {'inner_experiment': {'dataset': 'aaa', 'classifier': {'n_estimators': 1}}}, - '6': {'inner_experiment': {'dataset': 'aaa', 'classifier': {'n_estimators': 2}}}, - '7': {'inner_experiment': {'dataset': 'bbb', 'classifier': {'n_estimators': 1}}}, - '8': {'inner_experiment': {'dataset': 'bbb', 'classifier': {'n_estimators': 2}}}, + '5': {'inner_experiment': {'dataset': 'aaa', 'classifier': {'n_estimators': 1}}, 'dt': now}, + '6': {'inner_experiment': {'dataset': 'aaa', 'classifier': {'n_estimators': 2}}, 'dt': now}, + '7': {'inner_experiment': {'dataset': 'bbb', 'classifier': {'n_estimators': 1}}, 'dt': now}, + '8': {'inner_experiment': {'dataset': 'bbb', 'classifier': {'n_estimators': 2}}, 'dt': now}, }, ) @@ -196,6 +224,7 @@ def test_e2e(self, max_workers: int, runner_backend: str, evaluation_key: str, c cached_result = lab.run_task(cached_tasks[0]) assert cached_result == evaluation['expected_result'] + class TestE2ERay: def setup_method(self, method): diff --git a/tests/labtech/test_params.py b/tests/labtech/test_params.py new file mode 100644 index 0000000..78b4732 --- /dev/null +++ b/tests/labtech/test_params.py @@ -0,0 +1,90 @@ +import pytest + +import labtech +from labtech.exceptions import ParamHandlerError +from labtech.params import clear_custom_param_handlers, get_custom_param_handlers + + +class TestParamHandler: + + def teardown_method(self, method): + clear_custom_param_handlers() + + def test_register(self): + + @labtech.param_handler + class FrozensetParamHandler: + + def handles(self, value): + return isinstance(value, frozenset) + + def find_tasks(self, value, *, find_tasks_in_param): + return [ + task + for item in sorted(value, key=hash) + for task in find_tasks_in_param(item) + ] + + def serialize(self, value, *, serializer): + return list(sorted(value, key=hash)) + + def deserialize(self, value, *, serializer): + return frozenset(value) + + assert [type(handler) for handler in get_custom_param_handlers()] == [ + FrozensetParamHandler, + ] + + def test_register_priority(self): + + class FrozensetParamHandler: + + def handles(self, value): + return isinstance(value, frozenset) + + def find_tasks(self, value, *, find_tasks_in_param): + return [ + task + for item in sorted(value, key=hash) + for task in find_tasks_in_param(item) + ] + + def serialize(self, value, *, serializer): + return list(sorted(value, key=hash)) + + def deserialize(self, value, *, serializer): + return frozenset(value) + + @labtech.param_handler(priority=2000) + class FrozensetParamHandlerOne(FrozensetParamHandler): + pass + + @labtech.param_handler + class FrozensetParamHandlerTwo(FrozensetParamHandler): + pass + + @labtech.param_handler + class FrozensetParamHandlerThree(FrozensetParamHandler): + pass + + @labtech.param_handler(priority=100) + class FrozensetParamHandlerFour(FrozensetParamHandler): + pass + + assert [type(handler) for handler in get_custom_param_handlers()] == [ + FrozensetParamHandlerFour, + FrozensetParamHandlerTwo, + FrozensetParamHandlerThree, + FrozensetParamHandlerOne, + ] + + def test_register_noncompliant(self): + with pytest.raises( + ParamHandlerError, match=( + "Cannot register 'TestParamHandler.test_register_noncompliant..CustomParamHandler' " + "as a custom parameter handler, as it does not implement all methods of the 'ParamHandler' protocol." + ), + ): + @labtech.param_handler + class CustomParamHandler: + pass diff --git a/tests/labtech/test_tasks.py b/tests/labtech/test_tasks.py index 763fdfc..622345f 100644 --- a/tests/labtech/test_tasks.py +++ b/tests/labtech/test_tasks.py @@ -9,6 +9,7 @@ import labtech.tasks from labtech.cache import BaseCache, NullCache, PickleCache from labtech.exceptions import TaskError +from labtech.params import clear_custom_param_handlers from labtech.tasks import _RESERVED_ATTRS, ParamScalar, find_tasks_in_param, immutable_param_value from labtech.types import ResultT, Storage, Task, TaskInfo @@ -262,6 +263,52 @@ def run(self) -> None: class TestImmutableParamValue: + + def setup_method(self, method): + + @labtech.param_handler + class FrozensetParamHandler: + + def handles(self, value): + return isinstance(value, frozenset) + + def find_tasks(self, value, *, find_tasks_in_param): + return [ + task + for item in sorted(value, key=hash) + for task in find_tasks_in_param(item) + ] + + def serialize(self, value, *, serializer): + return list(sorted(value, key=hash)) + + def deserialize(self, value, *, serializer): + return frozenset(value) + + @labtech.param_handler + class SetParamHandler: + """This is not a valid param handler, because sets are not + hashable.""" + + def handles(self, value): + return isinstance(value, set) + + def find_tasks(self, value, *, find_tasks_in_param): + return [ + task + for item in sorted(value, key=hash) + for task in find_tasks_in_param(item) + ] + + def serialize(self, value, *, serializer): + return list(sorted(value, key=hash)) + + def deserialize(self, value, *, serializer): + return set(value) + + def teardown_method(self, method): + clear_custom_param_handlers() + def test_empty_list(self) -> None: assert immutable_param_value("hello", []) == () @@ -306,6 +353,20 @@ def test_nested_list_dict(self) -> None: def test_scalar(self, scalar: ParamScalar) -> None: assert immutable_param_value("hello", scalar) is scalar + def test_custom_param(self) -> None: + example_frozenset = frozenset(['one', 2, frozenset([3, 'four'])]) + assert immutable_param_value("hello", example_frozenset) is example_frozenset + + def test_custom_param_unhashable(self, scalar: ParamScalar) -> None: + example_set = set(['one', 2, frozenset([3, 'four'])]) + with pytest.raises( + TaskError, match=( + "Type 'set' in parameter value 'hello' is handled by " + "'TestImmutableParamValue.setup_method..SetParamHandler', but is not hashable." + ) + ): + immutable_param_value("hello", example_set) + def test_unhandled(self) -> None: with pytest.raises( TaskError, match="Unsupported type '_BadObject' in parameter value 'hello'." @@ -321,6 +382,31 @@ def test_multiple_nested_error(self) -> None: class TestFindTasksInParam: + + def setup_method(self, method): + + @labtech.param_handler + class FrozensetParamHandler: + + def handles(self, value): + return isinstance(value, frozenset) + + def find_tasks(self, value, *, find_tasks_in_param): + return [ + task + for item in sorted(value, key=hash) + for task in find_tasks_in_param(item) + ] + + def serialize(self, value, *, serializer): + return list(sorted(value, key=hash)) + + def deserialize(self, value, *, serializer): + return frozenset(value) + + def teardown_method(self, method): + clear_custom_param_handlers() + def test_scalar(self, scalar: ParamScalar) -> None: assert find_tasks_in_param(scalar) == [] @@ -363,6 +449,14 @@ def test_searched_coll_ids(self) -> None: task2 ] + def test_custom_param_handler(self) -> None: + task1 = ExampleTask(1) + task2 = ExampleTask(2) + assert find_tasks_in_param(frozenset([1, task1, frozenset([task2, 2])])) == [ + task1, + task2, + ] + def test_unhandled(self) -> None: match = re.escape( "Unexpected type _BadObject encountered in task parameter value." From 2d57bcba1a24d5932cfae0ca59c0cd8d7e5a2040 Mon Sep 17 00:00:00 2001 From: Ben Denham Date: Sun, 18 May 2025 21:36:56 +1200 Subject: [PATCH 5/7] Simplify custom parameter handlers to a single global variable --- labtech/params.py | 108 +++++++++++++++++----------------- labtech/runners/ray.py | 10 ++-- labtech/serialization.py | 6 +- labtech/tasks.py | 6 +- labtech/types.py | 8 +-- tests/integration/test_e2e.py | 4 +- tests/labtech/test_params.py | 8 +-- tests/labtech/test_tasks.py | 6 +- 8 files changed, 79 insertions(+), 77 deletions(-) diff --git a/labtech/params.py b/labtech/params.py index 7cb873c..5bdc314 100644 --- a/labtech/params.py +++ b/labtech/params.py @@ -1,5 +1,6 @@ +#from functools import cached_property from inspect import isclass -from typing import TypedDict +from typing import Optional, Type, TypedDict from .exceptions import ParamHandlerError, UnregisteredParamHandlerError from .types import ParamHandler @@ -11,18 +12,60 @@ class ParamHandlerEntry(TypedDict): priority: int -_CUSTOM_PARAM_HANDLER_ENTRIES: dict[str, ParamHandlerEntry] = {} -_CUSTOM_PARAM_HANDLERS = [] +class ParamHandlerManager: + def __init__(self) -> None: + self._entries: dict[str, ParamHandlerEntry] = {} + self._prioritised_handlers: Optional[list[ParamHandler]] = None -def _update_custom_param_handlers() -> None: - global _CUSTOM_PARAM_HANDLERS - _CUSTOM_PARAM_HANDLERS = [ - entry['handler'] for entry in - # Sort param handlers by priority, keeping insertion order - # where priorities are equal. - sorted(_CUSTOM_PARAM_HANDLER_ENTRIES.values(), key=lambda entry: entry['priority']) - ] + def register(self, cls: Type[ParamHandler], *, priority: int) -> None: + if not isinstance(cls, ParamHandler): + raise ParamHandlerError( + (f"Cannot register '{cls.__qualname__}' as a custom parameter handler, " + "as it does not implement all methods of the 'ParamHandler' protocol.") + ) + + self._entries[fully_qualified_class_name(cls)] = ParamHandlerEntry( + handler=cls(), + priority=priority, + ) + # Clear cache + self._prioritised_handlers = None + + def lookup(self, fq_class_name: str) -> ParamHandler: + try: + entry = self._entries[fq_class_name] + except KeyError: + raise UnregisteredParamHandlerError(fully_qualified_class_name) + return entry['handler'] + + def clear(self) -> None: + self._entries = {} + # Clear cache + self._prioritised_handlers = None + + @property + def prioritised_handlers(self) -> list[ParamHandler]: + if self._prioritised_handlers is None: + self._prioritised_handlers = [ + entry['handler'] for entry in + # Sort param handlers by priority, keeping insertion order + # where priorities are equal. + sorted(self._entries.values(), key=lambda entry: entry['priority']) + ] + return self._prioritised_handlers + + +_PARAM_HANDLER_MANAGER = ParamHandlerManager() + + +def get_param_handler_manager() -> ParamHandlerManager: + return _PARAM_HANDLER_MANAGER + + +def set_param_handler_manager(param_handler_manager: ParamHandlerManager) -> None: + global _PARAM_HANDLER_MANAGER + _PARAM_HANDLER_MANAGER = param_handler_manager def param_handler(*args, priority: int = 1000): @@ -54,51 +97,10 @@ class name of the custom parameter handler that was used to """ def decorator(cls): - global _CUSTOM_PARAM_HANDLERS - - if not isinstance(cls, ParamHandler): - raise ParamHandlerError( - (f"Cannot register '{cls.__qualname__}' as a custom parameter handler, " - "as it does not implement all methods of the 'ParamHandler' protocol.") - ) - - _CUSTOM_PARAM_HANDLER_ENTRIES[fully_qualified_class_name(cls)] = ParamHandlerEntry( - handler=cls(), - priority=priority, - ) - _update_custom_param_handlers() - + get_param_handler_manager().register(cls, priority=priority) return cls if len(args) > 0 and isclass(args[0]): return decorator(args[0], *args[1:]) else: return decorator - - -def get_custom_param_handler_entries() -> dict[str, ParamHandlerEntry]: - return _CUSTOM_PARAM_HANDLER_ENTRIES - - -def set_custom_param_handler_entries(custom_param_handler_entries: dict[str, ParamHandlerEntry]) -> None: - global _CUSTOM_PARAM_HANDLER_ENTRIES - _CUSTOM_PARAM_HANDLER_ENTRIES = custom_param_handler_entries - _update_custom_param_handlers() - - -def get_custom_param_handlers() -> list[ParamHandler]: - return _CUSTOM_PARAM_HANDLERS - - -def lookup_custom_param_handler(fq_class_name: str) -> ParamHandler: - try: - entry = _CUSTOM_PARAM_HANDLER_ENTRIES[fq_class_name] - except KeyError: - raise UnregisteredParamHandlerError(fully_qualified_class_name) - return entry['handler'] - - -def clear_custom_param_handlers() -> None: - global _CUSTOM_PARAM_HANDLER_ENTRIES - _CUSTOM_PARAM_HANDLER_ENTRIES = {} - _update_custom_param_handlers() diff --git a/labtech/runners/ray.py b/labtech/runners/ray.py index 9e3c5d6..a1351e0 100644 --- a/labtech/runners/ray.py +++ b/labtech/runners/ray.py @@ -6,7 +6,7 @@ from typing import Iterator, Optional, Sequence from labtech.exceptions import RunnerError -from labtech.params import ParamHandlerEntry, get_custom_param_handler_entries, set_custom_param_handler_entries +from labtech.params import ParamHandlerManager, get_param_handler_manager, set_param_handler_manager from labtech.tasks import get_direct_dependencies from labtech.types import LabContext, ResultMeta, ResultT, Runner, RunnerBackend, Storage, Task, TaskMonitorInfo, TaskResult, is_task from labtech.utils import logger @@ -33,7 +33,7 @@ class TaskDetail: @ray.remote(num_returns=2) # type: ignore[arg-type] def _ray_func(*task_refs_args, task: Task[ResultT], task_name: str, use_cache: bool, context: LabContext, storage: Storage, - custom_param_handler_entries: dict[str, ParamHandlerEntry]) -> tuple[ResultMeta, ResultT]: + param_handler_manager: ParamHandlerManager) -> tuple[ResultMeta, ResultT]: # task_refs_args is expected to be a flattened list of (task, # result_meta, result_value) triples - passed this way to ensure # refs are top-level to trigger locality-aware scheduling: @@ -54,7 +54,7 @@ def _ray_func(*task_refs_args, task: Task[ResultT], task_name: str, use_cache: b value=result_value, ) - set_custom_param_handler_entries(custom_param_handler_entries) + set_param_handler_manager(param_handler_manager) for dependency_task in get_direct_dependencies(task, all_identities=True): dependency_task._set_results_map(results_map) @@ -87,7 +87,7 @@ def __init__(self, *, context: LabContext, storage: Storage, logger.debug('Uploading context and storage objects to ray object store') self.context_ref = ray.put(context) self.storage_ref = ray.put(storage) - self.custom_param_handler_entries_ref = ray.put(get_custom_param_handler_entries()) + self.param_handler_manager_ref = ray.put(get_param_handler_manager()) logger.debug('Uploaded context and storage objects to ray object store') self.cancelled = False @@ -142,7 +142,7 @@ def submit_task(self, task: Task, task_name: str, use_cache: bool) -> None: use_cache=use_cache, context=self.context_ref, storage=self.storage_ref, - custom_param_handler_entries=self.custom_param_handler_entries_ref, + param_handler_manager=self.param_handler_manager_ref, ) ) result_meta_ref, result_value_ref = result_refs diff --git a/labtech/serialization.py b/labtech/serialization.py index 776f246..a46b46d 100644 --- a/labtech/serialization.py +++ b/labtech/serialization.py @@ -7,7 +7,7 @@ from frozendict import frozendict from .exceptions import SerializationError, UnregisteredParamHandlerError -from .params import get_custom_param_handlers, lookup_custom_param_handler +from .params import get_param_handler_manager from .types import ParamHandler, ResultMeta, Serializer, Task, is_task, jsonable from .utils import ensure_dict_key_str, fully_qualified_class_name @@ -34,7 +34,7 @@ def _deserialize_custom(self, serialized: dict[str, jsonable]) -> Any: f"serialized custom value, received: '{serialized}'")) try: - custom_param_handler = lookup_custom_param_handler(cast(str, serialized['__class__'])) + custom_param_handler = get_param_handler_manager().lookup(cast(str, serialized['__class__'])) except UnregisteredParamHandlerError: custom_param_handler = self.deserialize_class(serialized['__class__'])() return custom_param_handler.deserialize( @@ -108,7 +108,7 @@ def deserialize_task(self, serialized: dict[str, jsonable], *, result_meta: Opti return task def serialize_value(self, value: Any) -> jsonable: - for custom_param_handler in get_custom_param_handlers(): + for custom_param_handler in get_param_handler_manager().prioritised_handlers: if custom_param_handler.handles(value): return self._serialize_custom(custom_param_handler, value) diff --git a/labtech/tasks.py b/labtech/tasks.py index aac8298..7062d83 100644 --- a/labtech/tasks.py +++ b/labtech/tasks.py @@ -12,7 +12,7 @@ from .cache import NullCache, PickleCache from .exceptions import TaskError -from .params import get_custom_param_handlers +from .params import get_param_handler_manager from .types import Cache, LabContext, ResultMeta, ResultsMap, ResultT, Task, TaskInfo, is_task, is_task_type from .utils import ensure_dict_key_str @@ -40,7 +40,7 @@ def immutable_param_value(key: str, value: Any) -> Any: sets).""" # Any value handled by custom_param_handlers is expected to be # immutable and hashable. - for custom_param_handler in get_custom_param_handlers(): + for custom_param_handler in get_param_handler_manager().prioritised_handlers: if custom_param_handler.handles(value): if not isinstance(value, Hashable): raise TaskError( @@ -312,7 +312,7 @@ def find_tasks_in_param(param_value: Any, searched_coll_ids: Optional[set[int]] if id(param_value) in searched_coll_ids: return [] - for custom_param_handler in get_custom_param_handlers(): + for custom_param_handler in get_param_handler_manager().prioritised_handlers: if custom_param_handler.handles(param_value): searched_coll_ids = searched_coll_ids | {id(param_value)} return custom_param_handler.find_tasks( diff --git a/labtech/types.py b/labtech/types.py index 7ef9f2d..a11584b 100644 --- a/labtech/types.py +++ b/labtech/types.py @@ -320,10 +320,10 @@ def submit_task(self, task: Task, task_name: str, use_cache: bool) -> None: effectively calling: ``` - # custom_param_handler_entries should be saved from - # labtech.params.get_custom_param_handler_entries() on the main process - # and set in any remote processes that don't inherit from the main process: - labtech.params.set_custom_param_handler_entries(custom_param_handler_entries) + # param_handler_manager should be saved from + # labtech.params.get_param_handler_manager() on the main process + # and set in remote processes that don't inherit from the main process: + labtech.params.set_param_handler_manager(param_handler_manager) for dependency_task in get_direct_dependencies(task, all_identities=True): # Where results_map is expected to contain the TaskResult for diff --git a/tests/integration/test_e2e.py b/tests/integration/test_e2e.py index dc42296..14d4731 100644 --- a/tests/integration/test_e2e.py +++ b/tests/integration/test_e2e.py @@ -9,7 +9,7 @@ import ray import labtech -from labtech.params import clear_custom_param_handlers +from labtech.params import get_param_handler_manager from labtech.runners.ray import RayRunnerBackend from labtech.types import Task @@ -33,7 +33,7 @@ def deserialize(self, serialized, *, serializer): return datetime.fromtimestamp(serialized) yield - clear_custom_param_handlers() + get_param_handler_manager().clear() @labtech.task(cache=None) diff --git a/tests/labtech/test_params.py b/tests/labtech/test_params.py index 78b4732..f8ad192 100644 --- a/tests/labtech/test_params.py +++ b/tests/labtech/test_params.py @@ -2,13 +2,13 @@ import labtech from labtech.exceptions import ParamHandlerError -from labtech.params import clear_custom_param_handlers, get_custom_param_handlers +from labtech.params import get_param_handler_manager class TestParamHandler: def teardown_method(self, method): - clear_custom_param_handlers() + get_param_handler_manager().clear() def test_register(self): @@ -31,7 +31,7 @@ def serialize(self, value, *, serializer): def deserialize(self, value, *, serializer): return frozenset(value) - assert [type(handler) for handler in get_custom_param_handlers()] == [ + assert [type(handler) for handler in get_param_handler_manager().prioritised_handlers] == [ FrozensetParamHandler, ] @@ -71,7 +71,7 @@ class FrozensetParamHandlerThree(FrozensetParamHandler): class FrozensetParamHandlerFour(FrozensetParamHandler): pass - assert [type(handler) for handler in get_custom_param_handlers()] == [ + assert [type(handler) for handler in get_param_handler_manager().prioritised_handlers] == [ FrozensetParamHandlerFour, FrozensetParamHandlerTwo, FrozensetParamHandlerThree, diff --git a/tests/labtech/test_tasks.py b/tests/labtech/test_tasks.py index 622345f..a3ee7af 100644 --- a/tests/labtech/test_tasks.py +++ b/tests/labtech/test_tasks.py @@ -9,7 +9,7 @@ import labtech.tasks from labtech.cache import BaseCache, NullCache, PickleCache from labtech.exceptions import TaskError -from labtech.params import clear_custom_param_handlers +from labtech.params import get_param_handler_manager from labtech.tasks import _RESERVED_ATTRS, ParamScalar, find_tasks_in_param, immutable_param_value from labtech.types import ResultT, Storage, Task, TaskInfo @@ -307,7 +307,7 @@ def deserialize(self, value, *, serializer): return set(value) def teardown_method(self, method): - clear_custom_param_handlers() + get_param_handler_manager().clear() def test_empty_list(self) -> None: assert immutable_param_value("hello", []) == () @@ -405,7 +405,7 @@ def deserialize(self, value, *, serializer): return frozenset(value) def teardown_method(self, method): - clear_custom_param_handlers() + get_param_handler_manager().clear() def test_scalar(self, scalar: ParamScalar) -> None: assert find_tasks_in_param(scalar) == [] From 319f3a0634da01a17f7ed3a43b39890d4a24e83b Mon Sep 17 00:00:00 2001 From: Ben Denham Date: Mon, 19 May 2025 00:22:40 +1200 Subject: [PATCH 6/7] Pass param_handler_manager to runner factory, and simplify interface with get/instantiate methods. Move Runner and RunnerBackend to avoid circular imports. --- docs/runners.md | 6 +- labtech/lab.py | 10 +-- labtech/monitor.py | 10 +-- labtech/params.py | 17 +++-- labtech/runners/__init__.py | 3 + labtech/runners/base.py | 129 +++++++++++++++++++++++++++++++++- labtech/runners/process.py | 11 +-- labtech/runners/ray.py | 18 ++--- labtech/runners/serial.py | 8 ++- labtech/runners/thread.py | 8 ++- labtech/serialization.py | 10 +-- labtech/tasks.py | 6 +- labtech/types.py | 123 +------------------------------- tests/integration/test_e2e.py | 28 ++++---- tests/labtech/test_params.py | 20 +++--- tests/labtech/test_tasks.py | 18 ++--- 16 files changed, 221 insertions(+), 204 deletions(-) diff --git a/docs/runners.md b/docs/runners.md index 68af181..a7610e9 100644 --- a/docs/runners.md +++ b/docs/runners.md @@ -36,12 +36,12 @@ See: [Multi-Machine Clusters](./distributed.md) You can define your own Runner Backend to execute tasks with a different form of parallelism or distributed computing platform by defining an implementation of the -[`RunnerBackend`][labtech.types.RunnerBackend] abstract base class: +[`RunnerBackend`][labtech.runners.RunnerBackend] abstract base class: -::: labtech.types.RunnerBackend +::: labtech.runners.RunnerBackend options: heading_level: 4 -::: labtech.types.Runner +::: labtech.runners.Runner options: heading_level: 4 diff --git a/labtech/lab.py b/labtech/lab.py index 477252d..813119e 100644 --- a/labtech/lab.py +++ b/labtech/lab.py @@ -11,10 +11,11 @@ from .exceptions import LabError, TaskNotFound from .monitor import TaskMonitor -from .runners import ForkRunnerBackend, SerialRunnerBackend, SpawnRunnerBackend, ThreadRunnerBackend +from .params import ParamHandlerManager +from .runners import ForkRunnerBackend, RunnerBackend, SerialRunnerBackend, SpawnRunnerBackend, ThreadRunnerBackend from .storage import LocalStorage, NullStorage from .tasks import get_direct_dependencies -from .types import LabContext, ResultMeta, ResultT, RunnerBackend, Storage, Task, TaskT, is_task, is_task_type +from .types import LabContext, ResultMeta, ResultT, Storage, Task, TaskT, is_task, is_task_type from .utils import OrderedSet, base_tqdm, is_ipython, logger, tqdm, tqdm_notebook @@ -201,13 +202,14 @@ def run(self, tasks: Sequence[Task]) -> dict[Task, Any]: runner = self.lab.runner_backend.build_runner( context=self.lab.context, max_workers=self.lab.max_workers, + param_handler_manager=ParamHandlerManager.get(), storage=self.lab._storage, ) task_monitor = None if not self.disable_top: task_monitor = TaskMonitor( - runner=runner, + get_task_infos=runner.get_task_infos, top_format=self.top_format, top_sort=self.top_sort, top_n=self.top_n, @@ -359,7 +361,7 @@ def __init__(self, *, useful when troubleshooting issues running tasks on different threads and processes. * Any instance of a - [`RunnerBackend`][labtech.types.RunnerBackend], + [`RunnerBackend`][labtech.runners.RunnerBackend], allowing for custom task management implementations. For details on the differences between `'fork'` and diff --git a/labtech/monitor.py b/labtech/monitor.py index 6f52e16..3c66136 100644 --- a/labtech/monitor.py +++ b/labtech/monitor.py @@ -2,12 +2,12 @@ from datetime import datetime from itertools import zip_longest from string import Template -from typing import Optional, Sequence, cast +from typing import Callable, Optional, Sequence, cast import psutil from .exceptions import LabError -from .types import Runner, TaskMonitorInfo, TaskMonitorInfoItem, TaskMonitorInfoValue +from .types import TaskMonitorInfo, TaskMonitorInfoItem, TaskMonitorInfoValue from .utils import tqdm @@ -82,9 +82,9 @@ def show(self) -> None: class TaskMonitor: - def __init__(self, *, runner: Runner, notebook: bool, + def __init__(self, *, get_task_infos: Callable[[], Sequence[TaskMonitorInfo]], notebook: bool, top_format: str, top_sort: str, top_n: int): - self.runner = runner + self.get_task_infos = get_task_infos self.top_template = Template(top_format) self.top_sort = top_sort self.top_sort_key = top_sort @@ -103,7 +103,7 @@ def _top_task_lines(self) -> tuple[int, list[str]]: # Make (shallow) copies of dictionaries to avoid mutating # original dictionaries provided by runner. info.copy() - for info in self.runner.get_task_infos() + for info in self.get_task_infos() ] total_task_count = len(task_infos) diff --git a/labtech/params.py b/labtech/params.py index 5bdc314..4ef2657 100644 --- a/labtech/params.py +++ b/labtech/params.py @@ -55,17 +55,16 @@ def prioritised_handlers(self) -> list[ParamHandler]: ] return self._prioritised_handlers + def instantiate(self) -> None: + global _PARAM_HANDLER_MANAGER + _PARAM_HANDLER_MANAGER = self -_PARAM_HANDLER_MANAGER = ParamHandlerManager() - + @staticmethod + def get() -> 'ParamHandlerManager': + return _PARAM_HANDLER_MANAGER -def get_param_handler_manager() -> ParamHandlerManager: - return _PARAM_HANDLER_MANAGER - -def set_param_handler_manager(param_handler_manager: ParamHandlerManager) -> None: - global _PARAM_HANDLER_MANAGER - _PARAM_HANDLER_MANAGER = param_handler_manager +_PARAM_HANDLER_MANAGER = ParamHandlerManager() def param_handler(*args, priority: int = 1000): @@ -97,7 +96,7 @@ class name of the custom parameter handler that was used to """ def decorator(cls): - get_param_handler_manager().register(cls, priority=priority) + ParamHandlerManager.get().register(cls, priority=priority) return cls if len(args) > 0 and isclass(args[0]): diff --git a/labtech/runners/__init__.py b/labtech/runners/__init__.py index 8b7fa77..955898a 100644 --- a/labtech/runners/__init__.py +++ b/labtech/runners/__init__.py @@ -1,8 +1,11 @@ +from .base import Runner, RunnerBackend from .process import ForkRunnerBackend, SpawnRunnerBackend from .serial import SerialRunnerBackend from .thread import ThreadRunnerBackend __all__ = [ + 'Runner', + 'RunnerBackend', 'ForkRunnerBackend', 'SpawnRunnerBackend', 'SerialRunnerBackend', diff --git a/labtech/runners/base.py b/labtech/runners/base.py index ba2e257..08722c1 100644 --- a/labtech/runners/base.py +++ b/labtech/runners/base.py @@ -1,17 +1,142 @@ +from abc import ABC, abstractmethod from contextlib import contextmanager from dataclasses import fields from datetime import datetime from enum import Enum -from typing import Any +from typing import Any, Iterator, Optional, Sequence from frozendict import frozendict from labtech.exceptions import LabError +from labtech.params import ParamHandlerManager from labtech.tasks import is_task -from labtech.types import LabContext, ResultMeta, Storage, Task, TaskResult +from labtech.types import LabContext, ResultMeta, Storage, Task, TaskMonitorInfo, TaskResult from labtech.utils import logger +class Runner(ABC): + """Manages the execution of [Tasks][labtech.types.Task], typically + by delegating to a parallel processing framework.""" + + @abstractmethod + def submit_task(self, task: Task, task_name: str, use_cache: bool) -> None: + """Submit the given task object to be run and have its result cached. + + It is up to the Runner to decide when to start running the + task (i.e. when resources become available). + + The implementation of this method should run the task by + effectively calling: + + ``` + # param_handler_manager needs to be instantiated in remote processes + # that don't inherit from the main process: + param_handler_manager.instantiate() + + for dependency_task in get_direct_dependencies(task, all_identities=True): + # Where results_map is expected to contain the TaskResult for + # each dependency_task. + dependency_task._set_results_map(results_map) + + current_process = multiprocessing.current_process() + orig_process_name = current_process.name + try: + # If the thread name or similar is set instead of the process + # name, then the Runner should update the handler of the global + # labtech.utils.logger to include that instead of the process name. + current_process.name = task_name + return labtech.runners.base.run_or_load_task( + task=task, + use_cache=use_cache, + filtered_context=task.filter_context(context), + storage=storage, + ) + finally: + current_process.name = orig_process_name + ``` + + Args: + task: The task to execute. + task_name: Name to use when referring to the task in logs. + use_cache: If True, the task's result should be fetched from the + cache if it is available (fetching should still be done in a + delegated process). + + """ + + @abstractmethod + def wait(self, *, timeout_seconds: Optional[float]) -> Iterator[tuple[Task, ResultMeta | BaseException]]: + """Wait up to timeout_seconds or until at least one of the + submitted tasks is done, then return an iterator of tasks in a + done state and a list of tasks in all other states. + + Each task is returned as a pair where the first value is the + task itself, and the second value is either: + + * For a successfully completed task: Metadata of the result. + * For a task that fails with any BaseException descendant: The exception + that was raised. + + Cancelled tasks are never returned. + + """ + + @abstractmethod + def cancel(self) -> None: + """Cancel all submitted tasks that have not yet been started.""" + + @abstractmethod + def stop(self) -> None: + """Stop all currently running tasks.""" + + @abstractmethod + def close(self) -> None: + """Clean up any resources used by the Runner after all tasks + are finished, cancelled, or stopped.""" + + @abstractmethod + def pending_task_count(self) -> int: + """Returns the number of tasks that have been submitted but + not yet cancelled or returned from a call to wait().""" + + @abstractmethod + def get_result(self, task: Task) -> TaskResult: + """Returns the in-memory result for a task that was + successfully run by this Runner. Raises a KeyError for a + result with no in-memory result.""" + + @abstractmethod + def remove_results(self, tasks: Sequence[Task]) -> None: + """Removes the in-memory results for tasks that were + sucessfully run by this Runner. Ignores tasks that have no + in-memory result.""" + + @abstractmethod + def get_task_infos(self) -> list[TaskMonitorInfo]: + """Returns a snapshot of monitoring information about each + task that is currently running.""" + + +class RunnerBackend(ABC): + """Factory class to construct [Runner][labtech.runners.Runner] objects.""" + + @abstractmethod + def build_runner(self, *, context: LabContext, storage: Storage, + param_handler_manager: ParamHandlerManager, + max_workers: Optional[int]) -> Runner: + """Return a Runner prepared with the given configuration. + + Args: + context: Additional variables made available to tasks that aren't + considered when saving to/loading from the cache. + storage: Where task results should be cached to. + param_handler_manager: Custom parameter handling configuration + to be instantiated on remote processes. + max_workers: The maximum number of parallel worker processes for + running tasks. + """ + + @contextmanager def optional_mlflow(task: Task): """Context manager to set mlflow "run" configuration for a task if diff --git a/labtech/runners/process.py b/labtech/runners/process.py index 0be4f7b..53ff3f6 100644 --- a/labtech/runners/process.py +++ b/labtech/runners/process.py @@ -18,11 +18,12 @@ from labtech.exceptions import RunnerError, TaskDiedError from labtech.monitor import get_process_info +from labtech.params import ParamHandlerManager from labtech.tasks import get_direct_dependencies -from labtech.types import LabContext, ResultMeta, ResultsMap, Runner, RunnerBackend, Storage, Task, TaskMonitorInfo, TaskResult +from labtech.types import LabContext, ResultMeta, ResultsMap, Storage, Task, TaskMonitorInfo, TaskResult from labtech.utils import LoggerFileProxy, logger -from .base import run_or_load_task +from .base import Runner, RunnerBackend, run_or_load_task class FutureStateError(Exception): @@ -479,7 +480,8 @@ class SpawnRunnerBackend(RunnerBackend): """ - def build_runner(self, *, context: LabContext, storage: Storage, max_workers: Optional[int]) -> SpawnProcessRunner: + def build_runner(self, *, context: LabContext, storage: Storage, + param_handler_manager: ParamHandlerManager, max_workers: Optional[int]) -> SpawnProcessRunner: return SpawnProcessRunner( context=context, storage=storage, @@ -557,7 +559,8 @@ class ForkRunnerBackend(RunnerBackend): """ - def build_runner(self, *, context: LabContext, storage: Storage, max_workers: Optional[int]) -> ForkProcessRunner: + def build_runner(self, *, context: LabContext, storage: Storage, + param_handler_manager: ParamHandlerManager, max_workers: Optional[int]) -> ForkProcessRunner: return ForkProcessRunner( context=context, storage=storage, diff --git a/labtech/runners/ray.py b/labtech/runners/ray.py index a1351e0..cd53089 100644 --- a/labtech/runners/ray.py +++ b/labtech/runners/ray.py @@ -6,12 +6,12 @@ from typing import Iterator, Optional, Sequence from labtech.exceptions import RunnerError -from labtech.params import ParamHandlerManager, get_param_handler_manager, set_param_handler_manager +from labtech.params import ParamHandlerManager from labtech.tasks import get_direct_dependencies -from labtech.types import LabContext, ResultMeta, ResultT, Runner, RunnerBackend, Storage, Task, TaskMonitorInfo, TaskResult, is_task +from labtech.types import LabContext, ResultMeta, ResultT, Storage, Task, TaskMonitorInfo, TaskResult, is_task from labtech.utils import logger -from .base import run_or_load_task +from .base import Runner, RunnerBackend, run_or_load_task try: import ray @@ -34,6 +34,8 @@ class TaskDetail: def _ray_func(*task_refs_args, task: Task[ResultT], task_name: str, use_cache: bool, context: LabContext, storage: Storage, param_handler_manager: ParamHandlerManager) -> tuple[ResultMeta, ResultT]: + param_handler_manager.instantiate() + # task_refs_args is expected to be a flattened list of (task, # result_meta, result_value) triples - passed this way to ensure # refs are top-level to trigger locality-aware scheduling: @@ -54,8 +56,6 @@ def _ray_func(*task_refs_args, task: Task[ResultT], task_name: str, use_cache: b value=result_value, ) - set_param_handler_manager(param_handler_manager) - for dependency_task in get_direct_dependencies(task, all_identities=True): dependency_task._set_results_map(results_map) @@ -76,7 +76,7 @@ def _ray_func(*task_refs_args, task: Task[ResultT], task_name: str, use_cache: b class RayRunner(Runner): - def __init__(self, *, context: LabContext, storage: Storage, + def __init__(self, *, context: LabContext, storage: Storage, param_handler_manager: ParamHandlerManager, monitor_interval_seconds: float, monitor_timeout_seconds: int) -> None: self.monitor_interval_seconds = monitor_interval_seconds self.monitor_timeout_seconds = monitor_timeout_seconds @@ -87,7 +87,7 @@ def __init__(self, *, context: LabContext, storage: Storage, logger.debug('Uploading context and storage objects to ray object store') self.context_ref = ray.put(context) self.storage_ref = ray.put(storage) - self.param_handler_manager_ref = ray.put(get_param_handler_manager()) + self.param_handler_manager_ref = ray.put(param_handler_manager) logger.debug('Uploaded context and storage objects to ray object store') self.cancelled = False @@ -319,7 +319,8 @@ def __init__(self, monitor_interval_seconds: float = 1, monitor_timeout_seconds: self.monitor_interval_seconds = monitor_interval_seconds self.monitor_timeout_seconds = monitor_timeout_seconds - def build_runner(self, *, context: LabContext, storage: Storage, max_workers: Optional[int]) -> Runner: + def build_runner(self, *, context: LabContext, storage: Storage, + param_handler_manager: ParamHandlerManager, max_workers: Optional[int]) -> Runner: if max_workers is not None: raise RunnerError(( 'Remove max_workers from your Lab configuration, as RayRunnerBackend only supports max_workers=None. ' @@ -330,6 +331,7 @@ def build_runner(self, *, context: LabContext, storage: Storage, max_workers: Op return RayRunner( context=context, storage=storage, + param_handler_manager=param_handler_manager, monitor_interval_seconds=self.monitor_interval_seconds, monitor_timeout_seconds=self.monitor_timeout_seconds, ) diff --git a/labtech/runners/serial.py b/labtech/runners/serial.py index c6c2f6f..b9515f2 100644 --- a/labtech/runners/serial.py +++ b/labtech/runners/serial.py @@ -6,11 +6,12 @@ import psutil from labtech.monitor import get_process_info +from labtech.params import ParamHandlerManager from labtech.tasks import get_direct_dependencies -from labtech.types import LabContext, ResultMeta, Runner, RunnerBackend, Storage, Task, TaskMonitorInfo, TaskResult +from labtech.types import LabContext, ResultMeta, Storage, Task, TaskMonitorInfo, TaskResult from labtech.utils import logger -from .base import run_or_load_task +from .base import Runner, RunnerBackend, run_or_load_task @dataclass(frozen=True) @@ -114,7 +115,8 @@ class SerialRunnerBackend(RunnerBackend): """Runner Backend that runs each task serially in the main process and thread.""" - def build_runner(self, *, context: LabContext, storage: Storage, max_workers: Optional[int]) -> SerialRunner: + def build_runner(self, *, context: LabContext, storage: Storage, + param_handler_manager: ParamHandlerManager, max_workers: Optional[int]) -> SerialRunner: return SerialRunner( context=context, storage=storage, diff --git a/labtech/runners/thread.py b/labtech/runners/thread.py index ccf5eba..699fd19 100644 --- a/labtech/runners/thread.py +++ b/labtech/runners/thread.py @@ -9,11 +9,12 @@ from labtech.exceptions import RunnerError from labtech.monitor import get_process_info +from labtech.params import ParamHandlerManager from labtech.tasks import get_direct_dependencies -from labtech.types import LabContext, ResultMeta, Runner, RunnerBackend, Storage, Task, TaskMonitorInfo, TaskResult +from labtech.types import LabContext, ResultMeta, Storage, Task, TaskMonitorInfo, TaskResult from labtech.utils import OrderedSet, logger, make_logger_handler -from .base import run_or_load_task +from .base import Runner, RunnerBackend, run_or_load_task class KillThread(Exception): @@ -171,7 +172,8 @@ class ThreadRunnerBackend(RunnerBackend): """ - def build_runner(self, *, context: LabContext, storage: Storage, max_workers: Optional[int]) -> ThreadRunner: + def build_runner(self, *, context: LabContext, storage: Storage, + param_handler_manager: ParamHandlerManager, max_workers: Optional[int]) -> ThreadRunner: return ThreadRunner( context=context, storage=storage, diff --git a/labtech/serialization.py b/labtech/serialization.py index a46b46d..66ea59d 100644 --- a/labtech/serialization.py +++ b/labtech/serialization.py @@ -7,7 +7,7 @@ from frozendict import frozendict from .exceptions import SerializationError, UnregisteredParamHandlerError -from .params import get_param_handler_manager +from .params import ParamHandlerManager from .types import ParamHandler, ResultMeta, Serializer, Task, is_task, jsonable from .utils import ensure_dict_key_str, fully_qualified_class_name @@ -34,7 +34,7 @@ def _deserialize_custom(self, serialized: dict[str, jsonable]) -> Any: f"serialized custom value, received: '{serialized}'")) try: - custom_param_handler = get_param_handler_manager().lookup(cast(str, serialized['__class__'])) + custom_param_handler = ParamHandlerManager.get().lookup(cast(str, serialized['__class__'])) except UnregisteredParamHandlerError: custom_param_handler = self.deserialize_class(serialized['__class__'])() return custom_param_handler.deserialize( @@ -108,7 +108,7 @@ def deserialize_task(self, serialized: dict[str, jsonable], *, result_meta: Opti return task def serialize_value(self, value: Any) -> jsonable: - for custom_param_handler in get_param_handler_manager().prioritised_handlers: + for custom_param_handler in ParamHandlerManager.get().prioritised_handlers: if custom_param_handler.handles(value): return self._serialize_custom(custom_param_handler, value) @@ -118,8 +118,8 @@ def serialize_value(self, value: Any) -> jsonable: return [self.serialize_value(item) for item in value] elif isinstance(value, frozendict): return { - ensure_dict_key_str(key, exception_type=SerializationError): self.serialize_value(value) - for key, value in value.items() + ensure_dict_key_str(k, exception_type=SerializationError): self.serialize_value(v) + for k, v in value.items() } elif isinstance(value, Enum): return self._serialize_enum(value) diff --git a/labtech/tasks.py b/labtech/tasks.py index 7062d83..f52b9c5 100644 --- a/labtech/tasks.py +++ b/labtech/tasks.py @@ -12,7 +12,7 @@ from .cache import NullCache, PickleCache from .exceptions import TaskError -from .params import get_param_handler_manager +from .params import ParamHandlerManager from .types import Cache, LabContext, ResultMeta, ResultsMap, ResultT, Task, TaskInfo, is_task, is_task_type from .utils import ensure_dict_key_str @@ -40,7 +40,7 @@ def immutable_param_value(key: str, value: Any) -> Any: sets).""" # Any value handled by custom_param_handlers is expected to be # immutable and hashable. - for custom_param_handler in get_param_handler_manager().prioritised_handlers: + for custom_param_handler in ParamHandlerManager.get().prioritised_handlers: if custom_param_handler.handles(value): if not isinstance(value, Hashable): raise TaskError( @@ -312,7 +312,7 @@ def find_tasks_in_param(param_value: Any, searched_coll_ids: Optional[set[int]] if id(param_value) in searched_coll_ids: return [] - for custom_param_handler in get_param_handler_manager().prioritised_handlers: + for custom_param_handler in ParamHandlerManager.get().prioritised_handlers: if custom_param_handler.handles(param_value): searched_coll_ids = searched_coll_ids | {id(param_value)} return custom_param_handler.find_tasks( diff --git a/labtech/types.py b/labtech/types.py index a11584b..19434dc 100644 --- a/labtech/types.py +++ b/labtech/types.py @@ -9,7 +9,6 @@ Any, Callable, Generic, - Iterator, Literal, Optional, Protocol, @@ -275,7 +274,7 @@ def handles(self, value: Any) -> bool: """Returns True if the given parameter value should be handled by this class.""" - def find_tasks(self, value: Any, *, find_tasks_in_param: Callable[[Any], Sequence[Task]]) -> Sequence[Task]: + def find_tasks(self, value: Any, *, find_tasks_in_param: Callable[[Any], Sequence[Task]]) -> list[Task]: """Given a parameter value, return all tasks within it (not including tasks within those tasks). @@ -303,123 +302,3 @@ def deserialize(self, serialized: jsonable, *, serializer: Serializer) -> Any: TaskMonitorInfoValue = datetime | str | int | float TaskMonitorInfoItem = TaskMonitorInfoValue | tuple[TaskMonitorInfoValue, str] TaskMonitorInfo = dict[str, TaskMonitorInfoItem] - - -class Runner(ABC): - """Manages the execution of [Tasks][labtech.types.Task], typically - by delegating to a parallel processing framework.""" - - @abstractmethod - def submit_task(self, task: Task, task_name: str, use_cache: bool) -> None: - """Submit the given task object to be run and have its result cached. - - It is up to the Runner to decide when to start running the - task (i.e. when resources become available). - - The implementation of this method should run the task by - effectively calling: - - ``` - # param_handler_manager should be saved from - # labtech.params.get_param_handler_manager() on the main process - # and set in remote processes that don't inherit from the main process: - labtech.params.set_param_handler_manager(param_handler_manager) - - for dependency_task in get_direct_dependencies(task, all_identities=True): - # Where results_map is expected to contain the TaskResult for - # each dependency_task. - dependency_task._set_results_map(results_map) - - current_process = multiprocessing.current_process() - orig_process_name = current_process.name - try: - # If the thread name or similar is set instead of the process - # name, then the Runner should update the handler of the global - # labtech.utils.logger to include that instead of the process name. - current_process.name = task_name - return labtech.runners.base.run_or_load_task( - task=task, - use_cache=use_cache, - filtered_context=task.filter_context(context), - storage=storage, - ) - finally: - current_process.name = orig_process_name - ``` - - Args: - task: The task to execute. - task_name: Name to use when referring to the task in logs. - use_cache: If True, the task's result should be fetched from the - cache if it is available (fetching should still be done in a - delegated process). - - """ - - @abstractmethod - def wait(self, *, timeout_seconds: Optional[float]) -> Iterator[tuple[Task, ResultMeta | BaseException]]: - """Wait up to timeout_seconds or until at least one of the - submitted tasks is done, then return an iterator of tasks in a - done state and a list of tasks in all other states. - - Each task is returned as a pair where the first value is the - task itself, and the second value is either: - - * For a successfully completed task: Metadata of the result. - * For a task that fails with any BaseException descendant: The exception - that was raised. - - Cancelled tasks are never returned. - - """ - - @abstractmethod - def cancel(self) -> None: - """Cancel all submitted tasks that have not yet been started.""" - - @abstractmethod - def stop(self) -> None: - """Stop all currently running tasks.""" - - @abstractmethod - def close(self) -> None: - """Clean up any resources used by the Runner after all tasks - are finished, cancelled, or stopped.""" - - @abstractmethod - def pending_task_count(self) -> int: - """Returns the number of tasks that have been submitted but - not yet cancelled or returned from a call to wait().""" - - @abstractmethod - def get_result(self, task: Task) -> TaskResult: - """Returns the in-memory result for a task that was - successfully run by this Runner. Raises a KeyError for a - result with no in-memory result.""" - - @abstractmethod - def remove_results(self, tasks: Sequence[Task]) -> None: - """Removes the in-memory results for tasks that were - sucessfully run by this Runner. Ignores tasks that have no - in-memory result.""" - - @abstractmethod - def get_task_infos(self) -> list[TaskMonitorInfo]: - """Returns a snapshot of monitoring information about each - task that is currently running.""" - - -class RunnerBackend(ABC): - """Factory class to construct [Runner][labtech.types.Runner] objects.""" - - @abstractmethod - def build_runner(self, *, context: LabContext, storage: Storage, max_workers: Optional[int]) -> Runner: - """Return a Runner prepared with the given configuration. - - Args: - context: Additional variables made available to tasks that aren't - considered when saving to/loading from the cache. - storage: Where task results should be cached to. - max_workers: The maximum number of parallel worker processes for - running tasks. - """ diff --git a/tests/integration/test_e2e.py b/tests/integration/test_e2e.py index 14d4731..10de836 100644 --- a/tests/integration/test_e2e.py +++ b/tests/integration/test_e2e.py @@ -9,31 +9,31 @@ import ray import labtech -from labtech.params import get_param_handler_manager +from labtech.params import ParamHandlerManager from labtech.runners.ray import RayRunnerBackend from labtech.types import Task -@pytest.fixture(autouse=True) -def datetime_param_handler(): +class DatetimeParamHandler: - @labtech.param_handler - class DatetimeParamHandler: + def handles(self, value): + return isinstance(value, datetime) - def handles(self, value): - return isinstance(value, datetime) + def find_tasks(self, value, *, find_tasks_in_param): + return [] - def find_tasks(self, value, *, find_tasks_in_param): - return [] + def serialize(self, value, *, serializer): + return value.timestamp() - def serialize(self, value, *, serializer): - return value.timestamp() + def deserialize(self, serialized, *, serializer): + return datetime.fromtimestamp(serialized) - def deserialize(self, serialized, *, serializer): - return datetime.fromtimestamp(serialized) +@pytest.fixture(autouse=True) +def datetime_param_handler(): + labtech.param_handler(DatetimeParamHandler) yield - get_param_handler_manager().clear() + ParamHandlerManager.get().clear() @labtech.task(cache=None) diff --git a/tests/labtech/test_params.py b/tests/labtech/test_params.py index f8ad192..3dc4b9f 100644 --- a/tests/labtech/test_params.py +++ b/tests/labtech/test_params.py @@ -2,13 +2,13 @@ import labtech from labtech.exceptions import ParamHandlerError -from labtech.params import get_param_handler_manager +from labtech.params import ParamHandlerManager class TestParamHandler: def teardown_method(self, method): - get_param_handler_manager().clear() + ParamHandlerManager.get().clear() def test_register(self): @@ -26,12 +26,12 @@ def find_tasks(self, value, *, find_tasks_in_param): ] def serialize(self, value, *, serializer): - return list(sorted(value, key=hash)) + return [serializer.serialize_value(item) for item in sorted(value, key=hash)] - def deserialize(self, value, *, serializer): - return frozenset(value) + def deserialize(self, serialized, *, serializer): + return frozenset([serializer.deserialize_value(item) for item in serialized]) - assert [type(handler) for handler in get_param_handler_manager().prioritised_handlers] == [ + assert [type(handler) for handler in ParamHandlerManager.get().prioritised_handlers] == [ FrozensetParamHandler, ] @@ -50,10 +50,10 @@ def find_tasks(self, value, *, find_tasks_in_param): ] def serialize(self, value, *, serializer): - return list(sorted(value, key=hash)) + return [serializer.serialize_value(item) for item in sorted(value, key=hash)] - def deserialize(self, value, *, serializer): - return frozenset(value) + def deserialize(self, serialized, *, serializer): + return frozenset([serializer.deserialize_value(item) for item in serialized]) @labtech.param_handler(priority=2000) class FrozensetParamHandlerOne(FrozensetParamHandler): @@ -71,7 +71,7 @@ class FrozensetParamHandlerThree(FrozensetParamHandler): class FrozensetParamHandlerFour(FrozensetParamHandler): pass - assert [type(handler) for handler in get_param_handler_manager().prioritised_handlers] == [ + assert [type(handler) for handler in ParamHandlerManager.get().prioritised_handlers] == [ FrozensetParamHandlerFour, FrozensetParamHandlerTwo, FrozensetParamHandlerThree, diff --git a/tests/labtech/test_tasks.py b/tests/labtech/test_tasks.py index a3ee7af..8d6dfb4 100644 --- a/tests/labtech/test_tasks.py +++ b/tests/labtech/test_tasks.py @@ -9,7 +9,7 @@ import labtech.tasks from labtech.cache import BaseCache, NullCache, PickleCache from labtech.exceptions import TaskError -from labtech.params import get_param_handler_manager +from labtech.params import ParamHandlerManager from labtech.tasks import _RESERVED_ATTRS, ParamScalar, find_tasks_in_param, immutable_param_value from labtech.types import ResultT, Storage, Task, TaskInfo @@ -280,10 +280,10 @@ def find_tasks(self, value, *, find_tasks_in_param): ] def serialize(self, value, *, serializer): - return list(sorted(value, key=hash)) + return [serializer.serialize_value(item) for item in sorted(value, key=hash)] - def deserialize(self, value, *, serializer): - return frozenset(value) + def deserialize(self, serialized, *, serializer): + return frozenset([serializer.deserialize_value(item) for item in serialized]) @labtech.param_handler class SetParamHandler: @@ -301,13 +301,13 @@ def find_tasks(self, value, *, find_tasks_in_param): ] def serialize(self, value, *, serializer): - return list(sorted(value, key=hash)) + return [serializer.serialize_value(item) for item in sorted(value, key=hash)] - def deserialize(self, value, *, serializer): - return set(value) + def deserialize(self, serialized, *, serializer): + return set([serializer.deserialize_value(item) for item in serialized]) def teardown_method(self, method): - get_param_handler_manager().clear() + ParamHandlerManager.get().clear() def test_empty_list(self) -> None: assert immutable_param_value("hello", []) == () @@ -405,7 +405,7 @@ def deserialize(self, value, *, serializer): return frozenset(value) def teardown_method(self, method): - get_param_handler_manager().clear() + ParamHandlerManager.get().clear() def test_scalar(self, scalar: ParamScalar) -> None: assert find_tasks_in_param(scalar) == [] From 8d901ab0000227696268bfa4042efbea8543b09f Mon Sep 17 00:00:00 2001 From: Ben Denham Date: Mon, 19 May 2025 16:01:00 +1200 Subject: [PATCH 7/7] Add docs for custom parameter handling --- README.md | 1 + docs/cookbook.md | 73 ++++++++++++++++++++++++++++++++++++++++++++++++ docs/params.md | 11 ++++++++ mkdocs.yml | 1 + 4 files changed, 86 insertions(+) create mode 100644 docs/params.md diff --git a/README.md b/README.md index 730d63d..7287b17 100644 --- a/README.md +++ b/README.md @@ -200,6 +200,7 @@ To learn more, dive into the following resources: * [Option for customising task execution backends](https://ben-denham.github.io/labtech/runners) * [Diagramming tools](https://ben-denham.github.io/labtech/diagram) * [Distributing across multiple machines](https://ben-denham.github.io/labtech/distributed) +* [Extensible parameter types](https://ben-denham.github.io/labtech/params) * [More examples](https://github.com/ben-denham/labtech/tree/main/examples) diff --git a/docs/cookbook.md b/docs/cookbook.md index 20688cd..7d421e4 100644 --- a/docs/cookbook.md +++ b/docs/cookbook.md @@ -76,6 +76,7 @@ object as a parameter to a task: * Constructing the object in a dependent task * Passing the object in an `Enum` parameter * Passing the object in the lab context +* Defining a custom parameter handler #### Constructing objects in dependent tasks @@ -285,6 +286,78 @@ lab = labtech.Lab( results = lab.run_tasks(experiments) ``` +#### Defining a custom parameter handler + +Advanced users may want to extend Labtech to support additional types +of parameters. To do so, you can declare a custom parameter type +handler class with the [`@param_handler`](/params) decorator. + +The example below demonstrates defining a handler for [Scipy +probability distributions](https://docs.scipy.org/doc/scipy/reference/stats.html): + +``` {.python .code} +import scipy.stats +from scipy.stats.distributions import rv_frozen + + +@labtech.param_handler +class DistributionParamHandler: + """ + There are two important limitations to this implementation: + + 1. Distributions with complex arguments are not supported + (e.g. rv_histogram, which takes arrays as arguments). + 2. Equivalent distributions expressed with different arguments + (e.g. positional vs keyword arguments) will be treated as + different parameter values for the purposes of caching. + + """ + + def handles(self, value): + return isinstance(value, rv_frozen) + + def find_tasks(self, value, *, find_tasks_in_param): + return [] + + def serialize(self, value, *, serializer): + return { + 'name': value.dist.name, + 'args': [serializer.serialize_value(arg) for arg in value.args], + 'kwds': { + key: serializer.serialize_value(kwd) for key, kwd in + sorted(value.kwds.items(), key=lambda pair: pair[0]) + }, + } + + def deserialize(self, serialized, *, serializer): + dist_cls = getattr(scipy.stats, serialized['name']) + args = [serializer.deserialize_value(arg) for arg in serialized['args']] + kwds = { + key: serializer.deserialize_value(kwd) + for key, kwd in serialized['kwds'] + } + return dist_cls(*args, **kwds) + + +@labtech.task +class Experiment: + distribution: rv_frozen + + def run(self): + return self.distribution.mean() + + +experiments = [ + Experiment(distribution=distribution) + for distribution in [ + scipy.stats.norm(loc=42), + scipy.stats.expon(loc=2), + ] +] +lab = labtech.Lab(storage=None) +results = lab.run_tasks(experiments) +``` + ### How can I control multi-processing myself within a task? diff --git a/docs/params.md b/docs/params.md new file mode 100644 index 0000000..b2cd3ac --- /dev/null +++ b/docs/params.md @@ -0,0 +1,11 @@ +To extend the types of data that can be used as Labtech parameters, +you can define a class that implements the +[`ParamHandler`][labtech.types.ParamHandler] protocol and decorate it +with the [`@param_handler`][labtech.param_handler] decorator. A full +example is given [in the cookbook](/cookbook#defining-a-custom-parameter-handler). + +::: labtech.param_handler + +::: labtech.types.ParamHandler + +::: labtech.types.Serializer diff --git a/mkdocs.yml b/mkdocs.yml index 0d47de8..bbbdc07 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -10,6 +10,7 @@ nav: - Multi-Machine Clusters: 'distributed.md' - Diagramming: 'diagram.md' - Caches and Storage: 'caching.md' + - Custom Parameter Handling: 'params.md' plugins: - search - mkdocstrings: