diff --git a/rclpy/rclpy/executors.py b/rclpy/rclpy/executors.py index d03d770d6..db51a8cee 100644 --- a/rclpy/rclpy/executors.py +++ b/rclpy/rclpy/executors.py @@ -14,14 +14,13 @@ from collections import deque from concurrent.futures import ThreadPoolExecutor +from contextlib import contextmanager from contextlib import ExitStack from dataclasses import dataclass from functools import partial import inspect import os -from threading import Condition -from threading import Lock -from threading import RLock +import threading import time from types import TracebackType from typing import Any @@ -87,38 +86,110 @@ class _WorkTracker: """Track the amount of work that is in progress.""" def __init__(self) -> None: - # Number of tasks that are being executed - self._num_work_executing = 0 - self._work_condition = Condition() - - def __enter__(self) -> None: - """Increment the amount of executing work by 1.""" - with self._work_condition: - self._num_work_executing += 1 - - def __exit__(self, exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], exctb: Optional[TracebackType]) -> None: - """Decrement the amount of work executing by 1.""" + self._work_condition = threading.Condition() + # Per-thread reentrant counter of in-flight work. A thread has + # an entry iff it is currently inside __enter__/__exit__; the + # value is the reentrance depth. The set of keys is the + # set of threads currently running a callback. + self._executing_thread_counts: Dict[threading.Thread, int] = {} + # Threads whose in-flight callback (if any) is committed to + # finishing rather than making further progress -- i.e., the + # thread is parked in wait() or has already returned from + # wait() and is finishing the surrounding shutdown. Their + # callback work should not block other waiters' drain checks; + # otherwise two callbacks on different worker threads both + # calling Executor.shutdown() would deadlock on each other. + # An entry is removed only when the owning callback ends + # (__exit__) or, for external callers with no in-flight + # callback, immediately when their wait() returns. + self._waiting_threads: Set[threading.Thread] = set() + + @contextmanager + def track_callback(self) -> Generator[None, None, None]: + """ + Track an in-flight callback for the duration of the context. + + The owning thread is captured at enter time and used to decrement + the per-thread count when the context exits -- even if the exit + runs on a different thread. That happens when a coroutine using + this context manager is suspended at an inner ``await`` and then + closed via GC (e.g. during executor teardown): ``coro.close()`` + raises ``GeneratorExit`` at the suspension point on whatever + thread the GC happened to run on, and the ``with`` block's + unwinding -- including this finally -- runs on that thread, not + the original worker thread. Using ``threading.current_thread()`` + in the finally would either lose the decrement (best case) or + ``KeyError`` (current case). + """ + owner = threading.current_thread() with self._work_condition: - self._num_work_executing -= 1 - self._work_condition.notify_all() + self._executing_thread_counts[owner] = ( + self._executing_thread_counts.get(owner, 0) + 1) + try: + yield + finally: + with self._work_condition: + count = self._executing_thread_counts[owner] - 1 + if count == 0: + del self._executing_thread_counts[owner] + # The thread's callback has ended, so it's no longer + # "committed to finishing" -- drop its waiter + # membership. + self._waiting_threads.discard(owner) + else: + self._executing_thread_counts[owner] = count + self._work_condition.notify_all() def wait(self, timeout_sec: Optional[float] = None) -> bool: """ Wait until all work completes. + Work being executed by the calling thread is excluded from the wait, + since that work is necessarily blocked on this call returning. Work + being executed by any other thread that has itself entered wait() is + also excluded, so concurrent shutdown() calls from inside callbacks + on different worker threads don't deadlock waiting for each other. + :param timeout_sec: Seconds to wait. Block forever if None or negative. Don't wait if 0 :type timeout_sec: float or None :rtype: bool True if all work completed """ if timeout_sec is not None and timeout_sec < 0.0: timeout_sec = None - # Wait for all work to complete + current = threading.current_thread() + + def other_work_drained() -> bool: + # True once every thread with in-flight work is itself in the + # waiting set, i.e., committed to finishing rather than making + # progress on its callback. + return self._executing_thread_counts.keys() <= self._waiting_threads + with self._work_condition: - if not self._work_condition.wait_for( - lambda: self._num_work_executing == 0, timeout_sec): - return False - return True + added_self = current not in self._waiting_threads + if added_self: + self._waiting_threads.add(current) + # A new waiter may have just satisfied an existing + # waiter's condition (its in-flight work is now excluded). + self._work_condition.notify_all() + drained = False + try: + drained = self._work_condition.wait_for(other_work_drained, timeout_sec) + return drained + finally: + # Keep the waiter membership while a callback is still + # in flight on this thread -- removing it now would let + # other concurrent waiters' predicates flip back to + # False and re-block until our callback ends, even + # though our callback is committed to finishing (we are + # past wait() and the rest of shutdown is cleanup). + # __exit__ will drop the membership when the callback + # ends. For external callers with no in-flight + # callback, no __exit__ will run, so discard here. + # However, if we timed out (not drained), we are NOT committed + # to finishing successfully, so we must discard the membership. + if added_self and (not drained or current not in self._executing_thread_counts): + self._waiting_threads.discard(current) + self._work_condition.notify_all() @overload @@ -212,12 +283,12 @@ def __init__(self, *, context: Optional[Context] = None) -> None: super().__init__() self._context = get_default_context() if context is None else context self._nodes: Set[Node] = set() - self._nodes_lock = RLock() + self._nodes_lock = threading.RLock() # all tasks that are not complete or canceled self._pending_tasks: Dict[Task[Any], TaskData] = {} # tasks that are ready to execute self._ready_tasks: Deque[Task[Any]] = deque() - self._tasks_lock = Lock() + self._tasks_lock = threading.Lock() # This is triggered when wait_for_ready_callbacks should rebuild the wait list self._guard: Optional[GuardCondition] = GuardCondition( callback=None, callback_group=None, context=self._context) @@ -225,7 +296,7 @@ def __init__(self, *, context: Optional[Context] = None) -> None: self._is_shutdown = False self._work_tracker = _WorkTracker() # Protect against shutdown() being called in parallel in two threads - self._shutdown_lock = Lock() + self._shutdown_lock = threading.Lock() # State for wait_for_ready_callbacks to reuse generator self._cb_iter: Optional[YieldedCallback] = None self._last_args: Optional[tuple[object, ...]] = None @@ -238,24 +309,28 @@ def __init__(self, *, context: Optional[Context] = None) -> None: # True when the executor is spinning self._is_spinning = False # Protects access to _is_spinning - self._is_spinning_lock = Lock() + self._is_spinning_cond = threading.Condition() + self._spinning_thread: Optional[threading.Thread] = None def _enter_spin(self) -> None: """Mark the executor as spinning and prevent concurrent spins.""" - with self._is_spinning_lock: + with self._is_spinning_cond: if self._is_spinning: raise RuntimeError('Executor is already spinning') self._is_spinning = True + self._spinning_thread = threading.current_thread() def _exit_spin(self) -> None: """Clear the spinning flag.""" - with self._is_spinning_lock: + with self._is_spinning_cond: self._is_spinning = False + self._spinning_thread = None + self._is_spinning_cond.notify_all() @property def is_spinning(self) -> bool: """Return whether the executor is currently spinning.""" - with self._is_spinning_lock: + with self._is_spinning_cond: return self._is_spinning @property @@ -315,19 +390,55 @@ def shutdown(self, timeout_sec: Optional[float] = None) -> bool: timeout expires before all outstanding work is done. """ with self._shutdown_lock: - if not self._is_shutdown: + initiated_shutdown = not self._is_shutdown + if initiated_shutdown: self._is_shutdown = True # Tell executor it's been shut down if self._guard: self._guard.trigger() - if not self._is_shutdown: - if not self._work_tracker.wait(timeout_sec): - return False + # The timeout applies to the whole shutdown operation — both the + # callback drain and the spinner exit — not to each wait + # individually. Convert it into a deadline once; each wait below + # gets only the time remaining against that deadline. + if timeout_sec is None or timeout_sec < 0: + deadline: Optional[float] = None # block forever + else: + deadline = time.monotonic() + timeout_sec + + def remaining_timeout() -> Optional[float]: + if deadline is None: + return None + return max(0.0, deadline - time.monotonic()) + + # Wait for any in-flight callbacks on OTHER threads to drain. Done + # unconditionally (not just for the initiating call) so that: + # - concurrent shutdown() calls don't race past the wait and start + # destroying state while callbacks are still running, and + # - a caller who got False back from a timed-out shutdown() can + # simply call shutdown() again (with a longer or no timeout) and + # have the second call actually wait + finish cleanup. + # _work_tracker.wait excludes work being executed by the calling + # thread, so this is safe from inside a callback — it will not + # self-deadlock. + if not self._work_tracker.wait(remaining_timeout()): + return False # Clean up stuff that won't be used anymore with self._nodes_lock: self._nodes = set() + with self._is_spinning_cond: + if self._spinning_thread is not threading.current_thread(): + # Wait for the spin thread to acknowledge shutdown and + # exit before we destroy the guards (which the spinner + # may still be holding in its wait_set). If the wait + # times out, return False per the contract — don't + # destroy resources that the spinner might still touch. + if not self._is_spinning_cond.wait_for( + lambda: not self._is_spinning, + timeout=remaining_timeout()): + return False + with self._shutdown_lock: if self._guard: self._guard.destroy() @@ -668,7 +779,7 @@ async def handler(entity: 'EntityT', gc: GuardCondition, is_shutdown: bool, entity._executor_event = False gc.trigger() return - with work_tracker: + with work_tracker.track_callback(): # The take_from_wait_list method here is expected to return either an async def # method or None if there is no work to do. call_coroutine = take_from_wait_list(entity) @@ -1087,7 +1198,7 @@ def __init__( 'Use the SingleThreadedExecutor instead.') self._futures: List[Future[Any]] = [] self._executor = ThreadPoolExecutor(num_threads) - self._futures_lock = Lock() + self._futures_lock = threading.Lock() def _spin_once_impl( self, @@ -1157,10 +1268,25 @@ def shutdown( :param timeout_sec: Seconds to wait. Block forever if ``None`` or negative. Don't wait if 0. :param wait_for_threads: If true, this function will block until all executor threads - have joined. + have joined. When shutdown() is called from inside a callback running on one of + this executor's worker threads, the *current* thread is necessarily excluded from + that join (Python cannot join a thread with itself) -- the rest of the callback + will finish after this returns and the worker will exit naturally. :return: ``True`` if all outstanding callbacks finished executing, or ``False`` if the timeout expires before all outstanding work is done. """ success: bool = super().shutdown(timeout_sec) - self._executor.shutdown(wait=wait_for_threads) + # Always tell the pool to shut down without waiting: if shutdown() + # was called from inside a callback running on one of these + # workers, letting ThreadPoolExecutor.shutdown(wait=True) join the + # current thread would raise RuntimeError. We do the joins below + # ourselves so we can skip the current thread. + self._executor.shutdown(wait=False) + if wait_for_threads: + current = threading.current_thread() + # Snapshot before iterating; _threads is stable post-shutdown + # (no new workers are spawned) but we copy defensively. + for t in list(self._executor._threads): + if t is not current: + t.join() return success diff --git a/rclpy/test/test_executor.py b/rclpy/test/test_executor.py index 18472109d..635c007a9 100644 --- a/rclpy/test/test_executor.py +++ b/rclpy/test/test_executor.py @@ -17,6 +17,7 @@ import threading import time from typing import Generator +from typing import List from typing import Optional from typing import Protocol from typing import Set @@ -24,8 +25,10 @@ import warnings import rclpy +from rclpy.callback_groups import MutuallyExclusiveCallbackGroup from rclpy.callback_groups import ReentrantCallbackGroup from rclpy.context import Context +from rclpy.executors import _WorkTracker from rclpy.executors import Executor from rclpy.executors import MultiThreadedExecutor from rclpy.executors import ShutdownException @@ -80,10 +83,9 @@ def test_single_threaded_executor_executes(self) -> None: finally: executor.shutdown() - @unittest.skip('Flaky on CI - see issue #1648') def test_executor_immediate_shutdown(self) -> None: self.assertIsNotNone(self.node.handle) - for cls in [SingleThreadedExecutor, EventsExecutor]: + for cls in [SingleThreadedExecutor, MultiThreadedExecutor, EventsExecutor]: with self.subTest(cls=cls): executor = cls(context=self.context) try: @@ -767,7 +769,6 @@ def timer_callback() -> None: self.node.destroy_timer(tmr) def test_shutdown_executor_from_callback(self) -> None: - """https://github.com/ros2/rclpy/issues/944: allow for executor shutdown from callback.""" self.assertIsNotNone(self.node.handle) timer_period = 0.1 # TODO(bmartin427) This seems like an invalid test to me? executor.shutdown() is @@ -788,6 +789,251 @@ def timer_callback() -> None: self.assertTrue(shutdown_event.wait(120)) self.node.destroy_timer(tmr) + def test_shutdown_timeout_then_retry(self) -> None: + self.assertIsNotNone(self.node.handle) + executor = SingleThreadedExecutor(context=self.context) + + callback_started = threading.Event() + callback_should_finish = threading.Event() + + def long_callback() -> None: + callback_started.set() + # Bound the wait so a broken test fails rather than hangs. + callback_should_finish.wait(timeout=10) + + tmr = self.node.create_timer(0.1, long_callback) + executor.add_node(self.node) + spin_thread = threading.Thread(target=executor.spin, daemon=True) + spin_thread.start() + + try: + # Wait for the callback to be running so the work_tracker count + # is guaranteed to be non-zero when shutdown's wait runs. + self.assertTrue(callback_started.wait(timeout=5)) + + # First shutdown: this is the one that flips _is_shutdown to + # True. Times out because the callback is still in flight. + self.assertFalse(executor.shutdown(timeout_sec=0.1)) + + # Second shutdown: _is_shutdown is already True. The callback + # is STILL in flight. The wait must run again (regardless of + # who initiated) and time out -- if it incorrectly skipped + # the wait, this would return True and race cleanup against + # the running callback. + self.assertFalse(executor.shutdown(timeout_sec=0.1)) + + # Release the callback. + callback_should_finish.set() + + # Final shutdown: callbacks have drained (or will momentarily). + # The wait should now succeed and cleanup should complete. + self.assertTrue(executor.shutdown(timeout_sec=5)) + finally: + # Guard rails in case an assertion above interrupts the flow. + callback_should_finish.set() + spin_thread.join(timeout=5) + self.node.destroy_timer(tmr) + + def test_work_tracker_coroutine_closed_on_different_thread(self) -> None: + class YieldOnce: + + def __await__(self) -> Generator[None, None, None]: + yield None + return None + + wt = _WorkTracker() + + async def callback_like() -> None: + with wt.track_callback(): + await YieldOnce() + + coro = callback_like() + + # Start the coroutine on thread A: runs through track_callback's + # __enter__ (incrementing the count for thread A) and suspends + # at ``await YieldOnce()``. + def thread_a() -> None: + coro.send(None) + + ta = threading.Thread(target=thread_a, name='WorkerA') + ta.start() + ta.join(timeout=5) + self.assertFalse(ta.is_alive()) + + # Close the coroutine from thread B (a different thread, mimicking + # the GC thread). Pre-fix this raises KeyError because the + # __exit__ looked up _executing_thread_counts[current_thread] + # and current_thread is thread B, not the thread that entered. + errors: List[BaseException] = [] + + def thread_b() -> None: + try: + coro.close() + except BaseException as e: + errors.append(e) + + tb = threading.Thread(target=thread_b, name='GCThread') + tb.start() + tb.join(timeout=5) + self.assertFalse(tb.is_alive()) + + self.assertFalse(errors, f'close() raised: {errors!r}') + self.assertFalse( + wt._executing_thread_counts, + f'work tracker not cleaned up: {wt._executing_thread_counts!r}') + + def test_shutdown_from_multithreaded_executor_callback(self) -> None: + self.assertIsNotNone(self.node.handle) + executor = MultiThreadedExecutor(num_threads=2, context=self.context) + + shutdown_returned = threading.Event() + shutdown_error: List[BaseException] = [] + + def timer_callback() -> None: + try: + # Default wait_for_threads=True is what triggers the bug. + executor.shutdown(timeout_sec=5) + except BaseException as e: + shutdown_error.append(e) + finally: + shutdown_returned.set() + + tmr = self.node.create_timer(0.1, timer_callback) + executor.add_node(self.node) + spin_thread = threading.Thread(target=executor.spin, daemon=True) + spin_thread.start() + + try: + self.assertTrue( + shutdown_returned.wait(timeout=15), + 'shutdown() never returned from inside the callback') + self.assertFalse( + shutdown_error, + f'shutdown() raised: {shutdown_error!r}') + finally: + spin_thread.join(timeout=5) + self.node.destroy_timer(tmr) + + def test_concurrent_shutdown_from_two_callbacks(self) -> None: + self.assertIsNotNone(self.node.handle) + executor = MultiThreadedExecutor(num_threads=2, context=self.context) + + # Distinct callback groups so the two timers can be dispatched to + # separate worker threads concurrently. (The default callback + # group is MutuallyExclusive at the node level.) + cb_group_a = MutuallyExclusiveCallbackGroup() + cb_group_b = MutuallyExclusiveCallbackGroup() + + # Use a barrier to make both callbacks reach shutdown() at the + # same time, so they are both inside _work_tracker.wait + # simultaneously -- the scenario the regression guards against. + barrier = threading.Barrier(2) + results: List[bool] = [] + results_lock = threading.Lock() + all_done = threading.Event() + + def shutdown_from_callback() -> None: + try: + barrier.wait(timeout=5) + except threading.BrokenBarrierError: + return + ok = executor.shutdown(timeout_sec=5, wait_for_threads=False) + with results_lock: + results.append(ok) + if len(results) == 2: + all_done.set() + + tmr_a = self.node.create_timer( + 0.05, shutdown_from_callback, callback_group=cb_group_a) + tmr_b = self.node.create_timer( + 0.05, shutdown_from_callback, callback_group=cb_group_b) + + executor.add_node(self.node) + spin_thread = threading.Thread(target=executor.spin, daemon=True) + spin_thread.start() + + try: + # Wait for both callbacks to finish their shutdown calls -- + # don't gate on the spinner exiting, since the spinner can + # exit before either callback has appended its result. + self.assertTrue( + all_done.wait(timeout=15), + f'only {len(results)}/2 shutdowns completed -- ' + 'concurrent shutdown deadlocked') + spin_thread.join(timeout=5) + self.assertFalse(spin_thread.is_alive(), 'spin thread did not exit') + with results_lock: + self.assertTrue( + all(results), + f'shutdown() returned False (timed out): {results}') + finally: + barrier.abort() + spin_thread.join(timeout=5) + self.node.destroy_timer(tmr_a) + self.node.destroy_timer(tmr_b) + + def test_work_tracker_waiter_leak_on_timeout(self) -> None: + wt = _WorkTracker() + + worker_a_running = threading.Event() + worker_a_should_exit = threading.Event() + worker_b_running = threading.Event() + + errors_a = [] + wait_returned_a = [] + + def worker_a_thread(): + try: + with wt.track_callback(): + worker_a_running.set() + # Call wait with a short timeout. This times out because + # Worker B is executing a callback and not waiting. + res = wt.wait(timeout_sec=0.1) + wait_returned_a.append(res) + # Stay alive inside the callback context + worker_a_should_exit.wait() + except Exception as e: + errors_a.append(e) + + def worker_b_thread(): + with wt.track_callback(): + worker_b_running.set() + # Run until Worker A's wait times out + time.sleep(0.5) + + tb = threading.Thread(target=worker_b_thread, name='WorkerB') + tb.start() + + ta = threading.Thread(target=worker_a_thread, name='WorkerA') + ta.start() + + self.assertTrue(worker_a_running.wait(timeout=2.0)) + self.assertTrue(worker_b_running.wait(timeout=2.0)) + + # Wait for Worker B to finish + tb.join(timeout=2.0) + + self.assertFalse(errors_a) + self.assertEqual(wait_returned_a, [False]) + + # MainThread calls wait(). Since Worker B finished, only Worker A is active. + # However, Worker A leaked into _waiting_threads from the timeout. + # MainThread's wait() will prematurely evaluate to True and exit instantly. + start_time = time.monotonic() + res_main = wt.wait(timeout_sec=0.2) + elapsed = time.monotonic() - start_time + + try: + # Under the bug, res_main is True and elapsed is ~0.0s. + # In the corrected code, it correctly blocks/returns False after 0.2s. + self.assertFalse( + res_main, + 'MainThread wait should have timed out because WorkerA is still running') + self.assertGreaterEqual(elapsed, 0.15) + finally: + worker_a_should_exit.set() + ta.join(timeout=2.0) + def test_context_manager(self) -> None: self.assertIsNotNone(self.node.handle)