diff --git a/swarm_gpt/core/drone_swarm.py b/swarm_gpt/core/drone_swarm.py index a7d702e..4394a2c 100644 --- a/swarm_gpt/core/drone_swarm.py +++ b/swarm_gpt/core/drone_swarm.py @@ -4,7 +4,9 @@ import asyncio import logging +import multiprocessing as mp import os +import threading from typing import TYPE_CHECKING, Any, Literal import numpy as np @@ -13,11 +15,12 @@ from cflib2.toc_cache import FileTocCache os.environ["SCIPY_ARRAY_API"] = "1" -from scipy.interpolate import interp1d from scipy.spatial.transform import Rotation as R if TYPE_CHECKING: from collections.abc import Awaitable, Callable, Iterable, Mapping + from concurrent.futures import Future + from multiprocessing.synchronize import Event from drone_estimators.ros_nodes.ros2_connector import ROSConnector from numpy.typing import NDArray as Array @@ -68,6 +71,9 @@ def __init__( self.toc_cache = FileTocCache("./cache") self.ros_connector: ROSConnector | None = None self._loop = asyncio.new_event_loop() + self._loop_thread: threading.Thread | None = None + self._estimator_stop_event: Event | None = None + self._estimator_future: Future[None] | None = None self._closed = False if not lighthouse: @@ -81,13 +87,23 @@ def __init__( self._run(self._connect()) if self.lighthouse: self._run(self._check_lighthouse_decks()) + else: + self._start_estimator_updater() self.reset() except BaseException: + if self._estimator_stop_event is not None: + self._estimator_stop_event.set() + if self._estimator_future is not None: + try: + self._estimator_future.result(timeout=max(1.0, 2 / self.update_freq)) + except Exception as exc: + logger.warning(f"Stopping estimator updater failed: {exc}") if self.cfs: try: self._run(self._disconnect()) except Exception as exc: logger.error(f"Disconnecting after initialization failure failed: {exc}") + self._stop_loop_thread() self._loop.close() if self.ros_connector is not None: self.ros_connector.close() @@ -112,12 +128,11 @@ def takeoff(self, height: float = 1.5, duration: float = 3.0): async def _takeoff(uri: str) -> None: cf = self._cf(uri) - await self._send_external_pose(uri) await self._change_commander_level(uri, "high") await cf.high_level_commander().take_off(height, None, duration, None) - await self._update_external_pose_during(uri, duration) + await asyncio.sleep(duration) - self._run(self._parallel_by_uri("Taking off", self.uris, _takeoff)) # TODO add timeout + self._run(self._parallel_by_uri("Taking off", self.uris, _takeoff, timeout=duration + 1.0)) def land(self, height: float = 0.0, duration: float = 3.0): """Land the drones at a given height over a given duration.""" @@ -127,10 +142,10 @@ async def _land(uri: str) -> None: await self._change_commander_level(uri, "high") high_level_commander = cf.high_level_commander() await high_level_commander.land(height, None, duration, None) - await self._update_external_pose_during(uri, duration) + await asyncio.sleep(duration) await high_level_commander.stop(None) - self._run(self._parallel_by_uri("Landing", self.uris, _land)) # TODO add timeout + self._run(self._parallel_by_uri("Landing", self.uris, _land, timeout=duration + 1.0)) def goto(self, target: dict[str, list], duration: float = 3.0): """Execute a high-level goto command for all drones. @@ -146,22 +161,19 @@ def goto(self, target: dict[str, list], duration: float = 3.0): async def _goto(uri: str) -> None: cf = self._cf(uri) - if not self.lighthouse: - await self._send_external_pose(uri) await self._change_commander_level(uri, "high") await cf.high_level_commander().go_to( *target[uri], duration, relative=False, linear=True, group_mask=None ) - await self._update_external_pose_during(uri, duration) + await asyncio.sleep(duration) self._run(self._parallel_by_uri("Goto", self.uris, _goto, timeout=duration + 1.0)) - def setpoint(self, target: dict[str, list], duration: float = 3.0): - """Stream a constant position+yaw setpoint to all drones. + def setpoint(self, target: dict[str, list]): + """Send one position+yaw setpoint to all drones and return. Args: target: Position+Yaw references in the form {'uri1': [target], ...}. - duration: Duration of the setpoint stream in seconds. """ self._validate_required_uris("pos", target) for uri, setpoint in target.items(): @@ -170,14 +182,9 @@ def setpoint(self, target: dict[str, list], duration: float = 3.0): async def _setpoint(uri: str) -> None: await self._change_commander_level(uri, "low") - ref = interp1d( - [0.0, duration], - [np.asarray(target[uri], dtype=float), np.asarray(target[uri], dtype=float)], - axis=0, - ) - await self._stream_reference(uri, duration, lambda t: np.asarray(ref(t), dtype=float)) + await self._cf(uri).commander().send_setpoint_position(*target[uri]) - self._run(self._parallel_by_uri("Setpoint", self.uris, _setpoint)) # TODO add timeout + self._run(self._parallel_by_uri("Setpoint", self.uris, _setpoint, timeout=0.5)) def execute_choreography( self, @@ -211,7 +218,11 @@ async def _execute(uri: str) -> None: (color_bot or {}).get(uri, {}), ) - self._run(self._parallel_by_uri("Choreography execution", self.uris, _execute)) # TODO add timeout + self._run( + self._parallel_by_uri( + "Choreography execution", self.uris, _execute, timeout=t_end + 1.0 + ) + ) def apply_colors(self, color_top: dict[str, Array] | None, color_bot: dict[str, Array] | None): """Apply colors to the drones. @@ -233,7 +244,7 @@ async def _apply_colors(uri: str) -> None: if uri in color_bot: await self._apply_drone_color(uri, color_bot[uri], "bot") - self._run(self._parallel_by_uri("Applying colors", self.uris, _apply_colors)) # TODO add timeout + self._run(self._parallel_by_uri("Applying colors", self.uris, _apply_colors, timeout=0.5)) def set_param(self, param: str, value: float): """Set a Crazyflie parameter on all active drones. @@ -246,7 +257,9 @@ def set_param(self, param: str, value: float): async def _set_param(uri: str) -> None: await self._cf(uri).param().set(param, value) - self._run(self._parallel_by_uri(f"Setting parameter {param}", self.uris, _set_param)) + self._run( + self._parallel_by_uri(f"Setting parameter {param}", self.uris, _set_param, timeout=0.5) + ) def emergency_stop(self, uri: str | None = None): """Send an emergency stop signal to one URI or all drones (default).""" @@ -255,7 +268,7 @@ def emergency_stop(self, uri: str | None = None): else: self._validate_known_uris("uri", {uri: None}) uris = [uri] - self._run(self._parallel_by_uri("Emergency stop", uris, self._emergency_stop)) # TODO add timeout + self._run(self._parallel_by_uri("Emergency stop", uris, self._emergency_stop, timeout=0.5)) def reset(self): """Reset all active drones.""" @@ -302,9 +315,13 @@ async def _close() -> None: active_uris = [uri for uri in self.uris if uri in self.active_uris] if active_uris: try: - await self._parallel_by_uri("Emergency stop", active_uris, self._emergency_stop) + await self._parallel_by_uri( + "Emergency stop", active_uris, self._emergency_stop, timeout=0.5 + ) await asyncio.sleep(0.1) - await self._parallel_by_uri("Shutdown LEDs", active_uris, _shutdown_leds) + await self._parallel_by_uri( + "Shutdown LEDs", active_uris, _shutdown_leds, timeout=0.5 + ) await asyncio.sleep(0.2) except RuntimeError as exc: logger.warning(f"Shutdown failed: {exc}") @@ -312,16 +329,49 @@ async def _close() -> None: await self._disconnect() try: + if self._estimator_stop_event is not None: + self._estimator_stop_event.set() + if self._estimator_future is not None: + try: + self._estimator_future.result(timeout=max(1.0, 2 / self.update_freq)) + except Exception as exc: + logger.warning(f"Stopping estimator updater failed: {exc}") self._run(_close()) finally: + self._estimator_future = None + self._estimator_stop_event = None + self._stop_loop_thread() self._loop.close() if self.ros_connector is not None: self.ros_connector.close() def _run(self, coroutine: Awaitable[Any]) -> Any: """Run a cflib2 coroutine on the swarm event loop.""" + if self._loop_thread is not None: + return asyncio.run_coroutine_threadsafe(coroutine, self._loop).result() return self._loop.run_until_complete(coroutine) + def _start_estimator_updater(self) -> None: + """Start continuous mocap updates after the radio links are connected.""" + ctx = mp.get_context("spawn") + self._estimator_stop_event = ctx.Event() + self._loop_thread = threading.Thread( + target=self._loop.run_forever, name="drone-swarm-event-loop", daemon=True + ) + self._loop_thread.start() + self._estimator_future = asyncio.run_coroutine_threadsafe( + self._update_estimators(self._estimator_stop_event), self._loop + ) + + def _stop_loop_thread(self) -> None: + """Stop the persistent event loop used by the estimator updater.""" + if self._loop_thread is None: + return + + self._loop.call_soon_threadsafe(self._loop.stop) + self._loop_thread.join() + self._loop_thread = None + def _cf(self, uri: str) -> Crazyflie: if uri not in self.active_uris: raise RuntimeError(f"Drone {uri} is not active.") @@ -403,7 +453,11 @@ async def _check_lighthouse_decks(self) -> None: ) async def _parallel_by_uri( - self, action_name: str, uris: Iterable[str], action: Callable[[str], Awaitable[None]], timeout: float | None = None + self, + action_name: str, + uris: Iterable[str], + action: Callable[[str], Awaitable[None]], + timeout: float | None = None, ) -> None: target_uris = [uri for uri in uris if uri in self.active_uris and uri in self.cfs] @@ -507,25 +561,30 @@ async def _apply_drone_color( if deck == "bot" or deck == "both": await param.set("colorLedBot.wrgb8888", color) - async def _send_external_pose(self, uri: str) -> None: - if self.lighthouse: - return - obs = await self._read_observation(uri) - await ( - self._cf(uri) - .localization() - .external_pose() - .send_external_pose( - pos=np.asarray(obs["pos"], dtype=float).tolist(), - quat=np.asarray(obs["quat"], dtype=float).tolist(), - ) - ) - - async def _update_external_pose_during(self, uri: str, duration: float) -> None: - end_time = asyncio.get_running_loop().time() + duration - while asyncio.get_running_loop().time() < end_time: - await self._send_external_pose(uri) - await asyncio.sleep(1 / self.update_freq) + async def _update_estimators(self, stop_event: Event) -> None: + """Continuously send mocap poses to every active Crazyflie estimator.""" + assert self.ros_connector is not None, "Estimator updates require lighthouse=False." + period = 1 / self.update_freq + next_tick = asyncio.get_running_loop().time() + while not stop_event.is_set(): + # These properties already copy the synchronized ROS arrays into fresh NumPy snapshots. + positions = self.ros_connector.pos + quaternions = self.ros_connector.quat + + async def _update_estimator(uri: str) -> None: + drone_name = f"cf{int(uri[-2:], 16):02d}" + await ( + self._cf(uri) + .localization() + .external_pose() + .send_external_pose( + pos=positions[drone_name].tolist(), quat=quaternions[drone_name].tolist() + ) + ) + + await self._parallel_by_uri("Updating estimators", self.uris, _update_estimator) + next_tick += period + await asyncio.sleep(max(0.0, next_tick - asyncio.get_running_loop().time())) async def _stream_reference( self, @@ -536,7 +595,6 @@ async def _stream_reference( color_bot: dict[float, Array] | None = None, ) -> None: commander = self._cf(uri).commander() - t_est = -np.inf top_cues = sorted((float(t), wrgb) for t, wrgb in (color_top or {}).items()) bot_cues = sorted((float(t), wrgb) for t, wrgb in (color_bot or {}).items()) i_next_top = 0 @@ -548,10 +606,6 @@ async def _stream_reference( t_col = -np.inf while (t_cur := asyncio.get_running_loop().time() - start_time) < duration: - if not self.lighthouse and t_cur - t_est >= 1 / self.update_freq: - await self._send_external_pose(uri) - t_est = t_cur - await commander.send_setpoint_position(*reference(t_cur)) if t_cur - t_col >= color_period: diff --git a/tests/unit/test_drone_swarm.py b/tests/unit/test_drone_swarm.py new file mode 100644 index 0000000..d96394a --- /dev/null +++ b/tests/unit/test_drone_swarm.py @@ -0,0 +1,194 @@ +import asyncio +import threading +from typing import Any + +import numpy as np +import pytest +from cflib2.error import DisconnectedError + +from swarm_gpt.core.drone_swarm import DroneSwarm + + +class FakeParam: + def __init__(self) -> None: + self.values: list[tuple[str, Any]] = [] + + async def set(self, name: str, value: Any) -> None: + self.values.append((name, value)) + + +class FakeCommander: + def __init__(self) -> None: + self.setpoints: list[tuple[float, ...]] = [] + + async def send_setpoint_position(self, *setpoint: float) -> None: + self.setpoints.append(setpoint) + + +class FakeExternalPose: + def __init__(self) -> None: + self.sent: list[tuple[list[float], list[float]]] = [] + + async def send_external_pose(self, *, pos: list[float], quat: list[float]) -> None: + self.sent.append((pos, quat)) + + +class FakeLocalization: + def __init__(self) -> None: + self.fake_external_pose = FakeExternalPose() + + def external_pose(self) -> FakeExternalPose: + return self.fake_external_pose + + +class FakeCrazyflie: + def __init__(self) -> None: + self.fake_param = FakeParam() + self.fake_commander = FakeCommander() + self.fake_localization = FakeLocalization() + + def param(self) -> FakeParam: + return self.fake_param + + def commander(self) -> FakeCommander: + return self.fake_commander + + def localization(self) -> FakeLocalization: + return self.fake_localization + + +class FakeROSConnector: + def __init__( + self, + positions: dict[str, list[float]], + quaternions: dict[str, list[float]], + stop_event: threading.Event, + ) -> None: + self.positions = positions + self.quaternions = quaternions + self.stop_event = stop_event + self.pos_reads = 0 + self.quat_reads = 0 + + @property + def pos(self) -> dict[str, np.ndarray]: + self.pos_reads += 1 + return {name: np.asarray(pos) for name, pos in self.positions.items()} + + @property + def quat(self) -> dict[str, np.ndarray]: + self.quat_reads += 1 + self.stop_event.set() + return {name: np.asarray(quat) for name, quat in self.quaternions.items()} + + +def make_swarm(uris: list[str]) -> DroneSwarm: + swarm = object.__new__(DroneSwarm) + swarm.uris = uris + swarm.cfs = {uri: FakeCrazyflie() for uri in uris} + swarm.active_uris = set(uris) + swarm._commander_levels = dict.fromkeys(uris) + swarm._loop = asyncio.new_event_loop() + swarm._loop_thread = None + return swarm + + +def test_setpoint_sends_once_and_returns(): + uris = ["radio://0/80/2M/E7E7E7E701", "radio://0/80/2M/E7E7E7E702"] + swarm = make_swarm(uris) + target = {uris[0]: [1.0, 2.0, 3.0, 4.0], uris[1]: [5.0, 6.0, 7.0, 8.0]} + + try: + swarm.setpoint(target) + finally: + swarm._loop.close() + + for uri in uris: + cf = swarm.cfs[uri] + assert cf.fake_commander.setpoints == [tuple(target[uri])] + assert cf.fake_param.values == [("commander.enHighLevel", 0)] + + +def test_estimator_updater_copies_batch_and_skips_inactive_drones(): + uris = [ + "radio://0/80/2M/E7E7E7E701", + "radio://0/80/2M/E7E7E7E702", + "radio://0/80/2M/E7E7E7E703", + ] + swarm = make_swarm(uris) + swarm.active_uris.remove(uris[2]) + swarm.lighthouse = False + swarm.update_freq = 1_000 + stop_event = threading.Event() + connector = FakeROSConnector( + positions={"cf01": [1.0, 2.0, 3.0], "cf02": [4.0, 5.0, 6.0], "cf03": [7.0, 8.0, 9.0]}, + quaternions={ + "cf01": [0.0, 0.0, 0.0, 1.0], + "cf02": [0.1, 0.2, 0.3, 0.9], + "cf03": [0.4, 0.5, 0.6, 0.7], + }, + stop_event=stop_event, + ) + swarm.ros_connector = connector + try: + asyncio.run(swarm._update_estimators(stop_event)) + finally: + swarm._loop.close() + + assert swarm.cfs[uris[0]].fake_localization.fake_external_pose.sent == [ + (connector.positions["cf01"], connector.quaternions["cf01"]) + ] + assert swarm.cfs[uris[1]].fake_localization.fake_external_pose.sent == [ + (connector.positions["cf02"], connector.quaternions["cf02"]) + ] + assert swarm.cfs[uris[2]].fake_localization.fake_external_pose.sent == [] + assert connector.pos_reads == 1 + assert connector.quat_reads == 1 + + +def test_disconnected_drone_is_warned_and_deactivated(capsys: pytest.CaptureFixture[str]): + uri = "radio://0/80/2M/E7E7E7E701" + swarm = make_swarm([uri]) + swarm._commander_levels[uri] = "low" + + async def fail_update(_uri: str) -> None: + raise DisconnectedError("link lost") + + try: + asyncio.run(swarm._parallel_by_uri("Updating estimators", [uri], fail_update)) + finally: + swarm._loop.close() + + assert uri not in swarm.active_uris + assert swarm._commander_levels[uri] is None + assert f"{uri} disconnected or unreachable" in capsys.readouterr().err + + +def test_estimator_updater_lifecycle(): + swarm = make_swarm([]) + swarm.lighthouse = False + swarm.update_freq = 1_000 + swarm._estimator_stop_event = None + swarm._estimator_future = None + swarm._closed = False + swarm.ros_connector = None + + async def update_until_stopped(stop_event: threading.Event) -> None: + while not stop_event.is_set(): + await asyncio.sleep(0.001) + + swarm._update_estimators = update_until_stopped + swarm._start_estimator_updater() + + try: + assert swarm._loop_thread is not None + assert swarm._loop_thread.is_alive() + swarm.close() + assert swarm._estimator_future is None + assert swarm._estimator_stop_event is None + finally: + if not swarm._loop.is_closed(): + swarm._stop_loop_thread() + swarm._loop.close() + + assert swarm._loop_thread is None diff --git a/tools/test_drone_swarm.py b/tools/test_drone_swarm.py index 2a78fe1..e4a8762 100644 --- a/tools/test_drone_swarm.py +++ b/tools/test_drone_swarm.py @@ -7,6 +7,7 @@ from __future__ import annotations import argparse +import time import tomllib from pathlib import Path from typing import Any @@ -35,8 +36,11 @@ def parse_args() -> argparse.Namespace: ) parser.add_argument("--limit", type=int, default=None, help="Use only the first N drones.") parser.add_argument("--height", type=float, default=0.6, help="Takeoff/test height in meters.") - parser.add_argument("--min-height", type=float, default=0.2, help="Minimum accepted z after takeoff.") + parser.add_argument( + "--min-height", type=float, default=0.2, help="Minimum accepted z after takeoff." + ) parser.add_argument("--takeoff-duration", type=float, default=3.0) + parser.add_argument("--setpoint-duration", type=float, default=2.0) parser.add_argument("--goto-duration", type=float, default=3.0) parser.add_argument("--choreo-duration", type=float, default=3.0) parser.add_argument("--land-duration", type=float, default=3.0) @@ -87,21 +91,19 @@ def check_observations(swarm: DroneSwarm) -> None: rpy = np.asarray(obs["rpy"], dtype=float) if pos.shape != (3,) or quat.shape != (4,) or rpy.shape != (3,): raise RuntimeError(f"Invalid observation shape for {uri}: {obs}") - if not np.all(np.isfinite(pos)) or not np.all(np.isfinite(quat)) or not np.all(np.isfinite(rpy)): + if ( + not np.all(np.isfinite(pos)) + or not np.all(np.isfinite(quat)) + or not np.all(np.isfinite(rpy)) + ): raise RuntimeError(f"Non-finite observation for {uri}: {obs}") if np.linalg.norm(quat) < 0.5: raise RuntimeError(f"Suspicious quaternion for {uri}: {quat}") - print( - f" {uri}: pos = {pos.round(3).tolist()}, " - f"rpy = {np.degrees(rpy).round(1).tolist()}" - ) + print(f" {uri}: pos = {pos.round(3).tolist()}, rpy = {np.degrees(rpy).round(1).tolist()}") def pose_for_uri( - swarm: DroneSwarm, - drones_by_uri: dict[str, dict[str, Any]], - uri: str, - height: float, + swarm: DroneSwarm, drones_by_uri: dict[str, dict[str, Any]], uri: str, height: float ) -> np.ndarray: """Return a [x, y, z, yaw_deg] pose for command references.""" if not swarm.is_active(uri): @@ -190,6 +192,15 @@ def run_smoke_test(args: argparse.Namespace) -> None: uri: pose_for_uri(swarm, drones_by_uri, uri, args.height) for uri in swarm.uris } + print("\nHolding position with manually streamed setpoints") + period = 1 / args.ctrl_freq + deadline = time.monotonic() + args.setpoint_duration + next_tick = time.monotonic() + while time.monotonic() < deadline: + swarm.setpoint(takeoff_poses) + next_tick += period + time.sleep(max(0.0, next_tick - time.monotonic())) + print("\nGoing to a small offset") goto_refs = {} goto_targets = {} @@ -198,7 +209,7 @@ def run_smoke_test(args: argparse.Namespace) -> None: target[0] += args.goto_dx target[1] += args.goto_dy goto_targets[uri] = target - goto_refs[uri] = [target] + goto_refs[uri] = target swarm.goto(goto_refs, duration=args.goto_duration) print("\nExecuting simple choreography") @@ -220,10 +231,7 @@ def run_smoke_test(args: argparse.Namespace) -> None: args.choreo_duration / 2.0: np.array([0.0, 60.0, 60.0, 0.0]), } swarm.execute_choreography( - choreography, - args.choreo_duration, - color_top=cue_color_top, - color_bot=cue_color_bot, + choreography, args.choreo_duration, color_top=cue_color_top, color_bot=cue_color_bot ) print("\nLanding")