diff --git a/pathwaysutils/elastic/elastic.py b/pathwaysutils/elastic/elastic.py new file mode 100644 index 0000000..c9d73b2 --- /dev/null +++ b/pathwaysutils/elastic/elastic.py @@ -0,0 +1,282 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Elasticity manager. + +This class provides a utility for elastic training. It provides a decorator that +retries a function in case of `jax.errors.JaxRuntimeError` caused by slice down +events. It also provides a utility for waiting for slices to become active. +""" + +import collections +from collections.abc import Mapping, Sequence +import logging +import time +import traceback + +import jax +import numpy as np +from pathwaysutils.debug import timing + + +_logger = logging.getLogger(__name__) + +_SIMPLE_EXECUTION_TEST_VALUE = 100 +_ELASTIC_DOWN_ERROR_TYPES = [ + "DATA_LOSS", +] +_ELASTIC_DOWN_ADDITIONAL_ERROR_TYPES = [ + "DEADLINE_EXCEEDED", + "NOT_FOUND", + "INTERNAL", +] + + +def _plus_one(x: jax.Array) -> jax.Array: + """Adds one to each element in the array. + + Used to test if a slice is active. + + Args: + x: The array to add one to. + + Returns: + The array with one added to each element. + """ + return x + 1 + + +def _simple_execution(devices: Sequence[jax.Device]) -> jax.Array: + """Simple execution to test if a slice is active. + + This function is used to test if a slice is active. It executes a simple + computation on the devices and returns the result. 
If any of the devices are + not active, the returned array will fail with a JaxRuntimeError when used. + + Simply executing this function is not enough to determine if the slice is + active. We also need to check the value of the returned array. + + Args: + devices: The devices to execute on. + + Returns: + The result of the execution. + """ + if not devices: + raise ValueError("No devices") + + test_input = np.zeros(len(devices), dtype=float) + ( + _SIMPLE_EXECUTION_TEST_VALUE - 1 + ) + + return jax.pmap(_plus_one, devices=devices)(test_input) + + +def get_slice_to_devices( + devices: Sequence[jax.Device], +) -> dict[int, Sequence[jax.Device]]: + """Returns the mapping from slice index to devices.""" + slice_to_devices = collections.defaultdict(list) + for d in devices: + slice_to_devices[d.slice_index].append(d) + return dict(slice_to_devices) + + +@timing.timeit +def get_active_slice_indices( + slice_to_devices: Mapping[int, Sequence[jax.Device]] | None = None, +) -> set[int]: + """Returns the set of active slice indices.""" + if slice_to_devices is None: + _logger.debug("slice_to_devices is None. 
Getting from jax.devices().") + slice_to_devices = get_slice_to_devices(tuple(jax.devices())) + + _logger.debug( + "Getting active slice indices for slices: %s", + sorted(list(slice_to_devices.keys())), + ) + + active_slice_indices = set() + + results = { + slice_index: _simple_execution(devices) + for slice_index, devices in slice_to_devices.items() + } + + for slice_index, x in results.items(): + _logger.debug("Checking slice_index=%s", slice_index) + expected = ( + np.zeros(len(slice_to_devices[slice_index]), dtype=float) + + _SIMPLE_EXECUTION_TEST_VALUE + ) + try: + with timing.Timer(f"Checking {slice_index=}"): + _logger.debug("Blocking until ready for slice_index=%s", slice_index) + jax.block_until_ready(x) + _logger.debug("Execution finished for slice_index=%s", slice_index) + if np.allclose(x, expected): + active_slice_indices.add(slice_index) + _logger.debug("slice_index=%s active", slice_index) + else: + _logger.error( + "Error with _simple_execution for slice_index=%s. " + "This should never happen. Expected: %s, Actual: %s", + slice_index, + expected, + x, + ) + raise ValueError( + f"Error with _simple_execution for slice_index={slice_index}." + ) + except jax.errors.JaxRuntimeError as error: + _logger.debug( + "Caught JaxRuntimeError for slice_index=%s: %s", slice_index, error + ) + if not is_error_due_to_slice_down(error): + _logger.info("Re-raising error for slice_index=%s", slice_index) + raise + _logger.debug("slice_index=%s bad", slice_index) + + _logger.debug("active_slice_indices=%s", active_slice_indices) + + return active_slice_indices + + +def wait_for_slices( + slice_count: int, + poll_interval: float | int = 10, + timeout: float | int | None = None, + slice_to_devices: Mapping[int, Sequence[jax.Device]] | None = None, +) -> set[int]: + """Waits until at least `slice_count` slices become active. + + Args: + slice_count: The number of slices to wait for. 
+ poll_interval: The minimum number of seconds to wait between availability + checks. If the check takes longer than this, the next check will start + immediately after the current check completes. Defaults to 10 seconds. + timeout: The maximum number of seconds to wait. If None, there is no + timeout. + slice_to_devices: A mapping from slice index to devices. If None, + `get_slice_to_devices(jax.devices())` is used. + + Returns: + The active slice indices + + Raises: + TimeoutError: If the timeout is reached before the slices become + active. + """ + if slice_to_devices is None: + _logger.debug("slice_to_devices is None. Getting from jax.devices().") + slice_to_devices = get_slice_to_devices(jax.devices()) + + _logger.info( + "Waiting for %s slices. Poll interval: %s, Timeout: %s", + slice_count, + poll_interval, + timeout, + ) + start_time = time.time() + + while True: + check_start_time = time.time() + + _logger.debug("Checking active slices...") + active_slice_indices = get_active_slice_indices(slice_to_devices) + if len(active_slice_indices) >= slice_count: + _logger.info( + "Sufficient slices active: %s >= %s. Active indices: %s", + len(active_slice_indices), + slice_count, + active_slice_indices, + ) + return active_slice_indices + + _logger.info( + "%s slices active. Wanting at least %s. Active indices: %s", + len(active_slice_indices), + slice_count, + active_slice_indices, + ) + + time_to_sleep = max(0, poll_interval - (time.time() - check_start_time)) + + if ( + timeout is not None + and (elapsed_time := time.time() - start_time) + time_to_sleep + >= timeout + ): + raise TimeoutError( + f"Timed out waiting for {slice_count} slices. Only" + f" {len(active_slice_indices)} active after" + f" {elapsed_time:.2f} seconds." + f" Next check would occur after the timeout of {timeout}" + " seconds." 
+ ) + + if time_to_sleep > 0: + _logger.debug("Sleeping for %.2f seconds.", time_to_sleep) + + time.sleep(time_to_sleep) + + +def is_error_due_to_slice_down(error: Exception) -> bool: + """Returns True if the error is due to slice down. + + The error types that are considered due to slice down are + jax.errors.JaxRuntimeError with the following error kind in the message: + - DATA_LOSS + - DEADLINE_EXCEEDED + - NOT_FOUND + - INTERNAL + + Args: + error: The error to check. + """ + error_due_to_slice_down = False + traceback_logging_level = logging.DEBUG + + if isinstance(error, jax.errors.JaxRuntimeError): + _logger.debug("Checking if JaxRuntimeError is due to slice down: %s", error) + if any( + error_type in str(error) for error_type in _ELASTIC_DOWN_ERROR_TYPES + ): + _logger.debug( + "Caught an error due to slice down (matched" + " _ELASTIC_DOWN_ERROR_TYPES)" + ) + + error_due_to_slice_down = True + + elif any( + error_type in str(error) + for error_type in _ELASTIC_DOWN_ADDITIONAL_ERROR_TYPES + ): + _logger.warning( + "Caught an error that may or may not be due to slice down (matched" + " _ELASTIC_DOWN_ADDITIONAL_ERROR_TYPES). This error will be treated" + " as due to slice down." + ) + traceback_logging_level = logging.WARNING + + error_due_to_slice_down = True + + if not error_due_to_slice_down: + _logger.debug("Caught an error not due to slice down") + + _logger.log( + traceback_logging_level, "\n".join(traceback.format_exception(error)) + ) + + return error_due_to_slice_down diff --git a/pathwaysutils/elastic/manager.py b/pathwaysutils/elastic/manager.py index ce57424..01721f0 100644 --- a/pathwaysutils/elastic/manager.py +++ b/pathwaysutils/elastic/manager.py @@ -18,59 +18,38 @@ events. It also provides a utility for waiting for slices to become active. 
""" -import collections -from collections.abc import Mapping, Sequence +import _thread +from collections.abc import Callable, Mapping, Sequence import functools -import itertools import logging -import time -import traceback -from typing import Any +import threading +from typing import Any, TypeVar import jax -import numpy as np -from pathwaysutils.debug import timing +from pathwaysutils.elastic import elastic _logger = logging.getLogger(__name__) -def _plus_one(x: jax.Array) -> jax.Array: - """Adds one to each element in the array. - - Used to test if a slice is active. +class ElasticRuntimeError(RuntimeError): + """Error raised when elasticity cannot continue.""" - Args: - x: The array to add one to. - Returns: - The array with one added to each element. - """ - return x + 1 +class NewSliceAvailableError(RuntimeError): + """Error raised when a new slice is available.""" -class ElasticRuntimeError(RuntimeError): - """Error raised when elasticity cannot continue.""" +_F = TypeVar("_F", bound=Callable[..., Any]) class Manager: """Utility class for elastic training.""" - _devices: Sequence[jax.Device] _total_slice_count: int | None = None slice_to_devices: Mapping[int, Sequence[jax.Device]] active_slice_indices: set[int] - _SIMPLE_EXECUTION_TEST_VALUE = 100 - _ELASTIC_DOWN_ERROR_TYPES = [ - "DATA_LOSS", - ] - _ELASTIC_DOWN_ADDITIONAL_ERROR_TYPES = [ - "DEADLINE_EXCEEDED", - "NOT_FOUND", - "INTERNAL", - ] - def __init__(self, devices: Sequence[jax.Device] | None = None) -> None: """Initializes the manager. 
@@ -79,24 +58,13 @@ def __init__(self, devices: Sequence[jax.Device] | None = None) -> None: """ if devices is None: devices = jax.devices() - self.devices = devices - - self.active_slice_indices = self.get_active_slice_indices() - - @property - def devices(self) -> Sequence[jax.Device]: - """Returns the devices.""" - return self._devices + self.slice_to_devices = elastic.get_slice_to_devices(devices) - @devices.setter - def devices(self, devices: Sequence[jax.Device]) -> None: - """Sets the devices.""" - self._devices = devices + self.all_slice_indices = set(self.slice_to_devices.keys()) - self.slice_to_devices = collections.defaultdict(list) - for d in self._devices: - self.slice_to_devices[d.slice_index].append(d) - self.slice_to_devices = dict(self.slice_to_devices) + self.active_slice_indices = elastic.get_active_slice_indices( + slice_to_devices=self.slice_to_devices + ) @property def total_slice_count(self) -> int: @@ -105,143 +73,6 @@ def total_slice_count(self) -> int: self._total_slice_count = len(self.slice_to_devices) return self._total_slice_count - def slice_device_count(self, slice_index: int) -> int: - """Returns the number of devices in a slice.""" - try: - return len(self.slice_to_devices[slice_index]) - except KeyError as error: - raise ValueError( - f"Slice {slice_index=} not found in {self.slice_to_devices=}" - ) from error - - def is_error_due_to_slice_down(self, error: Exception) -> bool: - """Returns True if the error is due to slice down. - - The error types that are considered due to slice down are - jax.errors.JaxRuntimeError with the following error kind in the message: - - DATA_LOSS - - DEADLINE_EXCEEDED - - NOT_FOUND - - INTERNAL - - Args: - error: The error to check. 
- """ - error_due_to_slice_down = False - traceback_logging_level = logging.DEBUG - - if isinstance(error, jax.errors.JaxRuntimeError): - if any( - error_type in str(error) - for error_type in self._ELASTIC_DOWN_ERROR_TYPES - ): - _logger.info("Caught an error due to slice down") - - error_due_to_slice_down = True - - elif any( - error_type in str(error) - for error_type in self._ELASTIC_DOWN_ADDITIONAL_ERROR_TYPES - ): - _logger.warning( - "Caught an error due that may or may not be due to slice down. This" - " error will be treated as due to slice down." - ) - traceback_logging_level = logging.WARNING - - error_due_to_slice_down = True - - if not error_due_to_slice_down: - _logger.info("Caught an error not due to slice down") - - _logger.log( - traceback_logging_level, "\n".join(traceback.format_exception(error)) - ) - - return error_due_to_slice_down - - def _simple_execution(self, devices: Sequence[jax.Device]) -> jax.Array: - """Simple execution to test if a slice is active. - - This function is used to test if a slice is active. It executes a simple - computation on the devices and returns the result. If any of the devices are - not active, the returned array will fail with a JaxRuntimeError used. - - Simply executing this function is not enough to determine if the slice is - active. We also need to check the value of the returned array. - - Args: - devices: The devices to execute on. - - Returns: - The result of the execution. 
- """ - if not devices: - raise ValueError("No devices") - - test_input = np.zeros(len(devices), dtype=float) + ( - self._SIMPLE_EXECUTION_TEST_VALUE - 1 - ) - - return jax.pmap(_plus_one, devices=devices)(test_input) - - @timing.timeit - def get_active_slice_indices(self) -> set[int]: - """Returns the set of active slices indices.""" - active_slice_indices = set() - - results = { - slice_index: self._simple_execution(devices) - for slice_index, devices in self.slice_to_devices.items() - } - - for slice_index, x in results.items(): - _logger.info("Checking slice_index=%s", slice_index) - expected = ( - np.zeros(self.slice_device_count(slice_index), dtype=float) - + self._SIMPLE_EXECUTION_TEST_VALUE - ) - try: - with timing.Timer(f"Checking {slice_index=}"): - jax.block_until_ready(x) - if np.allclose(x, expected): - active_slice_indices.add(slice_index) - _logger.info("slice_index=%s good", slice_index) - else: - _logger.error( - "Error with _simple_execution for slice_index=%s. " - "This should never happen. Expected: %s, Actual: %s", - slice_index, - expected, - x, - ) - raise ValueError( - f"Error with _simple_execution for slice_index={slice_index}." - ) - except jax.errors.JaxRuntimeError as error: - if not self.is_error_due_to_slice_down(error): - raise - _logger.info("slice_index=%s bad", slice_index) - - _logger.info("active_slice_indices=%s", active_slice_indices) - - return active_slice_indices - - @property - def active_slice_to_devices(self) -> dict[int, Sequence[jax.Device]]: - """The mapping from a active slice to its devices.""" - return { - slice_index: self.slice_to_devices[slice_index] - for slice_index in self.active_slice_indices - } - - @property - def active_devices(self) -> list[jax.Device]: - """Returns the active slice indices.""" - return list( - itertools.chain.from_iterable(self.active_slice_to_devices.values()) - ) - @property def default_device(self) -> jax.Device: """Returns the device that should be set to the default device. 
@@ -258,15 +89,20 @@ def active_slice_count(self) -> int: """Returns the number of slices.""" return len(self.active_slice_indices) + @property + def inactive_slice_indices(self) -> set[int]: + """Returns the set of inactive slice indices.""" + return self.all_slice_indices - self.active_slice_indices + def scale_by_active_slices(self, x: int | float) -> int | float: - """Scale x by the number of good slices.""" + """Scale x by the number of active slices.""" if isinstance(x, int): quotient, remainder = divmod( x * self.active_slice_count, self.total_slice_count ) if remainder: raise ValueError( - f"Cannot scale {x=} by good slices because it will result in a " + f"Cannot scale {x=} by active slices because it will result in a " f"remainder of {remainder=}." ) return quotient @@ -275,81 +111,58 @@ def scale_by_active_slices(self, x: int | float) -> int | float: else: raise ValueError(f"Unsupported type: {type(x)=}") - def wait_for_slices( + def _elasticity_retry_decorator( self, - slice_count: int | None = None, - poll_interval: float | int = 10, - timeout: float | int | None = None, - ) -> set[int]: - """Waits until after at least `slice_count` slices become active. - - Args: - slice_count: The number of slices to wait for. If None, waits for all - slices to become active. - poll_interval: The minimum number of seconds to wait between availability - checks. If the check takes longer than this, the next check will start - immediately after the current check completes. Defaults to 10 seconds. - timeout: The maximum number of seconds to wait. If None, there is no - timeout. - - Returns: - The good slice indices - - Raises: - TimeoutError: If the timeout is reached before the slices become - active. 
- """ - if slice_count is None: - slice_count = self.total_slice_count - - start_time = time.time() + max_retries: int, + pre_func: Callable[..., Any] | None = None, + ) -> Callable[[_F], _F]: + """Retries a function with elasticity fault tolerance.""" + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + for retry_index in range(max_retries): + try: + _logger.info( + "Elastic attempt %d out of %d", retry_index + 1, max_retries + ) + if pre_func is not None: + pre_func() + + with jax.default_device(self.default_device): + return func(*args, **kwargs) + except (jax.errors.JaxRuntimeError, NewSliceAvailableError) as error: + if isinstance( + error, jax.errors.JaxRuntimeError + ) and not elastic.is_error_due_to_slice_down(error): + raise - while True: - check_start_time = time.time() + _logger.info("Slice down event detected. Retrying.") - active_slice_indices = self.get_active_slice_indices() - if len(active_slice_indices) >= slice_count: - _logger.info( - "%s/%s slices are active", - len(active_slice_indices), - self.total_slice_count, - ) - return active_slice_indices - - _logger.info( - "%s/%s slices active. Wanting at least %s/%s.", - len(active_slice_indices), - self.total_slice_count, - slice_count, - self.total_slice_count, - ) + try: + _logger.debug("Cleaning up any ongoing traces") + jax.profiler.stop_trace() + except (RuntimeError, ValueError) as e: + _logger.debug("No ongoing traces to clean up") + except Exception: + _logger.exception("Error cleaning up ongoing traces") + raise - time_to_sleep = max(0, poll_interval - (time.time() - check_start_time)) - - if ( - timeout is not None - and (elapsed_time := time.time() - start_time) + time_to_sleep - >= timeout - ): - raise TimeoutError( - f"Timed out waiting for {slice_count} slices. Only" - f" {len(active_slice_indices)} active after" - f" {elapsed_time:.2f} seconds." - f" Next check would occur after the timeout of {timeout}" - " seconds." 
+ jax.clear_caches() + for array in jax.live_arrays(): + array.delete() + raise ElasticRuntimeError( + f"Elastic attempt {max_retries} out of {max_retries} failed." ) - if time_to_sleep > 0: - _logger.info("Sleeping for %.2f seconds.", time_to_sleep) - - time.sleep(time_to_sleep) + return wrapper + return decorator def pause_resume( self, max_retries: int, poll_interval: float | int = 10, timeout: float | None = None, - ) -> Any: + ) -> Callable[[_F], _F]: """Retries a function with pause/resume fault tolerance. This decorator wraps a function to automatically retry execution in case of @@ -379,37 +192,108 @@ def pause_resume( Exception: Any other exception raised by the wrapped function that is not due to a slice down event. """ + def pre_func(): + elastic.wait_for_slices( + slice_count=self.total_slice_count, + slice_to_devices=self.slice_to_devices, + poll_interval=poll_interval, + timeout=timeout, + ) + + return self._elasticity_retry_decorator( + max_retries=max_retries, pre_func=pre_func + ) + + def replica_resize( + self, + max_resizes: int, + poll_interval: float = 10, + ) -> Callable[[_F], _F]: + """Retries a function with replica/resize fault tolerance. + + Args: + max_resizes: The maximum number of times to retry the function after + resizing the replica count. + poll_interval: The number of seconds to wait between active slice checks. + Defaults to 10 seconds. + + Returns: + The result of the wrapped function. + + Raises: + ElasticRuntimeError: If all retry attempts fail. + Exception: Any other exception raised by the wrapped function that is not + due to a slice down event. 
+ """ + + def pre_func(): + elastic.wait_for_slices( + slice_count=1, + slice_to_devices=self.slice_to_devices, + poll_interval=poll_interval, + ) + self.active_slice_indices = elastic.get_active_slice_indices( + self.slice_to_devices + ) + + retry_decorator = self._elasticity_retry_decorator( + max_retries=max_resizes, pre_func=pre_func + ) + def decorator(func): @functools.wraps(func) def wrapper(*args, **kwargs): - for retry_index in range(max_retries): - try: - _logger.info( - "Elastic attempt %d out of %d", retry_index + 1, max_retries - ) - - self.wait_for_slices(poll_interval=poll_interval, timeout=timeout) - - return func(*args, **kwargs) - except jax.errors.JaxRuntimeError as error: - if not self.is_error_due_to_slice_down(error): - raise + stop_event = threading.Event() + new_slice_event = threading.Event() + def monitor(): + while not stop_event.wait(poll_interval): try: - _logger.info("Cleaning up any ongoing traces") - jax.profiler.stop_trace() - except (RuntimeError, ValueError) as e: - _logger.info("No ongoing traces to clean up") - except Exception: - _logger.exception("Error cleaning up ongoing traces") - raise + if not self.inactive_slice_indices: + _logger.debug("No inactive slices to check.") + continue + + _logger.debug( + "Checking inactive slices: %s", self.inactive_slice_indices + ) + inactive_slice_to_devices = { + i: self.slice_to_devices[i] + for i in self.inactive_slice_indices + } + newly_active_indices = elastic.get_active_slice_indices( + inactive_slice_to_devices + ) + + if newly_active_indices: + _logger.info( + "New slices found: %s. 
Interrupting main thread.", + newly_active_indices, + ) + new_slice_event.set() + _thread.interrupt_main() + return + + _logger.debug("No new slices found.") + except Exception: # pylint: disable=broad-exception-caught + _logger.exception("Error in monitor thread") + + monitor_thread = threading.Thread(target=monitor, daemon=True) + monitor_thread.start() + try: + return func(*args, **kwargs) + except KeyboardInterrupt: + if new_slice_event.is_set(): + raise NewSliceAvailableError("New slice available") from None + raise + finally: + stop_event.set() + try: + monitor_thread.join() + except KeyboardInterrupt: + if new_slice_event.is_set(): + raise NewSliceAvailableError("New slice available") from None + raise - jax.clear_caches() - for array in jax.live_arrays(): - array.delete() - raise ElasticRuntimeError( - f"Elastic attempt {max_retries} out of {max_retries} failed." - ) + return retry_decorator(wrapper) - return wrapper return decorator