Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 104 additions & 50 deletions swarm_gpt/core/drone_swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@

import asyncio
import logging
import multiprocessing as mp
import os
import threading
from typing import TYPE_CHECKING, Any, Literal

import numpy as np
Expand All @@ -13,11 +15,12 @@
from cflib2.toc_cache import FileTocCache

os.environ["SCIPY_ARRAY_API"] = "1"
from scipy.interpolate import interp1d
from scipy.spatial.transform import Rotation as R

if TYPE_CHECKING:
from collections.abc import Awaitable, Callable, Iterable, Mapping
from concurrent.futures import Future
from multiprocessing.synchronize import Event

from drone_estimators.ros_nodes.ros2_connector import ROSConnector
from numpy.typing import NDArray as Array
Expand Down Expand Up @@ -68,6 +71,9 @@ def __init__(
self.toc_cache = FileTocCache("./cache")
self.ros_connector: ROSConnector | None = None
self._loop = asyncio.new_event_loop()
self._loop_thread: threading.Thread | None = None
self._estimator_stop_event: Event | None = None
self._estimator_future: Future[None] | None = None
self._closed = False

if not lighthouse:
Expand All @@ -81,13 +87,23 @@ def __init__(
self._run(self._connect())
if self.lighthouse:
self._run(self._check_lighthouse_decks())
else:
self._start_estimator_updater()
self.reset()
except BaseException:
if self._estimator_stop_event is not None:
self._estimator_stop_event.set()
if self._estimator_future is not None:
try:
self._estimator_future.result(timeout=max(1.0, 2 / self.update_freq))
except Exception as exc:
logger.warning(f"Stopping estimator updater failed: {exc}")
if self.cfs:
try:
self._run(self._disconnect())
except Exception as exc:
logger.error(f"Disconnecting after initialization failure failed: {exc}")
self._stop_loop_thread()
self._loop.close()
if self.ros_connector is not None:
self.ros_connector.close()
Expand All @@ -112,12 +128,11 @@ def takeoff(self, height: float = 1.5, duration: float = 3.0):

async def _takeoff(uri: str) -> None:
cf = self._cf(uri)
await self._send_external_pose(uri)
await self._change_commander_level(uri, "high")
await cf.high_level_commander().take_off(height, None, duration, None)
await self._update_external_pose_during(uri, duration)
await asyncio.sleep(duration)

self._run(self._parallel_by_uri("Taking off", self.uris, _takeoff)) # TODO add timeout
self._run(self._parallel_by_uri("Taking off", self.uris, _takeoff, timeout=duration + 1.0))

def land(self, height: float = 0.0, duration: float = 3.0):
"""Land the drones at a given height over a given duration."""
Expand All @@ -127,10 +142,10 @@ async def _land(uri: str) -> None:
await self._change_commander_level(uri, "high")
high_level_commander = cf.high_level_commander()
await high_level_commander.land(height, None, duration, None)
await self._update_external_pose_during(uri, duration)
await asyncio.sleep(duration)
await high_level_commander.stop(None)

self._run(self._parallel_by_uri("Landing", self.uris, _land)) # TODO add timeout
self._run(self._parallel_by_uri("Landing", self.uris, _land, timeout=duration + 1.0))

def goto(self, target: dict[str, list], duration: float = 3.0):
"""Execute a high-level goto command for all drones.
Expand All @@ -146,22 +161,19 @@ def goto(self, target: dict[str, list], duration: float = 3.0):

async def _goto(uri: str) -> None:
cf = self._cf(uri)
if not self.lighthouse:
await self._send_external_pose(uri)
await self._change_commander_level(uri, "high")
await cf.high_level_commander().go_to(
*target[uri], duration, relative=False, linear=True, group_mask=None
)
await self._update_external_pose_during(uri, duration)
await asyncio.sleep(duration)

self._run(self._parallel_by_uri("Goto", self.uris, _goto, timeout=duration + 1.0))

def setpoint(self, target: dict[str, list], duration: float = 3.0):
"""Stream a constant position+yaw setpoint to all drones.
def setpoint(self, target: dict[str, list]):
"""Send one position+yaw setpoint to all drones and return.

Args:
target: Position+Yaw references in the form {'uri1': [target], ...}.
duration: Duration of the setpoint stream in seconds.
"""
self._validate_required_uris("pos", target)
for uri, setpoint in target.items():
Expand All @@ -170,14 +182,9 @@ def setpoint(self, target: dict[str, list], duration: float = 3.0):

async def _setpoint(uri: str) -> None:
await self._change_commander_level(uri, "low")
ref = interp1d(
[0.0, duration],
[np.asarray(target[uri], dtype=float), np.asarray(target[uri], dtype=float)],
axis=0,
)
await self._stream_reference(uri, duration, lambda t: np.asarray(ref(t), dtype=float))
await self._cf(uri).commander().send_setpoint_position(*target[uri])

self._run(self._parallel_by_uri("Setpoint", self.uris, _setpoint)) # TODO add timeout
self._run(self._parallel_by_uri("Setpoint", self.uris, _setpoint, timeout=0.5))

def execute_choreography(
self,
Expand Down Expand Up @@ -211,7 +218,11 @@ async def _execute(uri: str) -> None:
(color_bot or {}).get(uri, {}),
)

self._run(self._parallel_by_uri("Choreography execution", self.uris, _execute)) # TODO add timeout
self._run(
self._parallel_by_uri(
"Choreography execution", self.uris, _execute, timeout=t_end + 1.0
)
)

def apply_colors(self, color_top: dict[str, Array] | None, color_bot: dict[str, Array] | None):
"""Apply colors to the drones.
Expand All @@ -233,7 +244,7 @@ async def _apply_colors(uri: str) -> None:
if uri in color_bot:
await self._apply_drone_color(uri, color_bot[uri], "bot")

self._run(self._parallel_by_uri("Applying colors", self.uris, _apply_colors)) # TODO add timeout
self._run(self._parallel_by_uri("Applying colors", self.uris, _apply_colors, timeout=0.5))

def set_param(self, param: str, value: float):
"""Set a Crazyflie parameter on all active drones.
Expand All @@ -246,7 +257,9 @@ def set_param(self, param: str, value: float):
async def _set_param(uri: str) -> None:
await self._cf(uri).param().set(param, value)

self._run(self._parallel_by_uri(f"Setting parameter {param}", self.uris, _set_param))
self._run(
self._parallel_by_uri(f"Setting parameter {param}", self.uris, _set_param, timeout=0.5)
)

def emergency_stop(self, uri: str | None = None):
"""Send an emergency stop signal to one URI or all drones (default)."""
Expand All @@ -255,7 +268,7 @@ def emergency_stop(self, uri: str | None = None):
else:
self._validate_known_uris("uri", {uri: None})
uris = [uri]
self._run(self._parallel_by_uri("Emergency stop", uris, self._emergency_stop)) # TODO add timeout
self._run(self._parallel_by_uri("Emergency stop", uris, self._emergency_stop, timeout=0.5))

def reset(self):
"""Reset all active drones."""
Expand Down Expand Up @@ -302,26 +315,63 @@ async def _close() -> None:
active_uris = [uri for uri in self.uris if uri in self.active_uris]
if active_uris:
try:
await self._parallel_by_uri("Emergency stop", active_uris, self._emergency_stop)
await self._parallel_by_uri(
"Emergency stop", active_uris, self._emergency_stop, timeout=0.5
)
await asyncio.sleep(0.1)
await self._parallel_by_uri("Shutdown LEDs", active_uris, _shutdown_leds)
await self._parallel_by_uri(
"Shutdown LEDs", active_uris, _shutdown_leds, timeout=0.5
)
await asyncio.sleep(0.2)
except RuntimeError as exc:
logger.warning(f"Shutdown failed: {exc}")

await self._disconnect()

try:
if self._estimator_stop_event is not None:
self._estimator_stop_event.set()
if self._estimator_future is not None:
try:
self._estimator_future.result(timeout=max(1.0, 2 / self.update_freq))
except Exception as exc:
logger.warning(f"Stopping estimator updater failed: {exc}")
self._run(_close())
finally:
self._estimator_future = None
self._estimator_stop_event = None
self._stop_loop_thread()
self._loop.close()
if self.ros_connector is not None:
self.ros_connector.close()

def _run(self, coroutine: Awaitable[Any]) -> Any:
"""Run a cflib2 coroutine on the swarm event loop."""
if self._loop_thread is not None:
return asyncio.run_coroutine_threadsafe(coroutine, self._loop).result()
return self._loop.run_until_complete(coroutine)

def _start_estimator_updater(self) -> None:
"""Start continuous mocap updates after the radio links are connected."""
ctx = mp.get_context("spawn")
self._estimator_stop_event = ctx.Event()
self._loop_thread = threading.Thread(
target=self._loop.run_forever, name="drone-swarm-event-loop", daemon=True
)
self._loop_thread.start()
self._estimator_future = asyncio.run_coroutine_threadsafe(
self._update_estimators(self._estimator_stop_event), self._loop
)

def _stop_loop_thread(self) -> None:
"""Stop the persistent event loop used by the estimator updater."""
if self._loop_thread is None:
return

self._loop.call_soon_threadsafe(self._loop.stop)
self._loop_thread.join()
self._loop_thread = None

def _cf(self, uri: str) -> Crazyflie:
if uri not in self.active_uris:
raise RuntimeError(f"Drone {uri} is not active.")
Expand Down Expand Up @@ -403,7 +453,11 @@ async def _check_lighthouse_decks(self) -> None:
)

async def _parallel_by_uri(
self, action_name: str, uris: Iterable[str], action: Callable[[str], Awaitable[None]], timeout: float | None = None
self,
action_name: str,
uris: Iterable[str],
action: Callable[[str], Awaitable[None]],
timeout: float | None = None,
) -> None:
target_uris = [uri for uri in uris if uri in self.active_uris and uri in self.cfs]

Expand Down Expand Up @@ -507,25 +561,30 @@ async def _apply_drone_color(
if deck == "bot" or deck == "both":
await param.set("colorLedBot.wrgb8888", color)

async def _send_external_pose(self, uri: str) -> None:
if self.lighthouse:
return
obs = await self._read_observation(uri)
await (
self._cf(uri)
.localization()
.external_pose()
.send_external_pose(
pos=np.asarray(obs["pos"], dtype=float).tolist(),
quat=np.asarray(obs["quat"], dtype=float).tolist(),
)
)

async def _update_external_pose_during(self, uri: str, duration: float) -> None:
end_time = asyncio.get_running_loop().time() + duration
while asyncio.get_running_loop().time() < end_time:
await self._send_external_pose(uri)
await asyncio.sleep(1 / self.update_freq)
async def _update_estimators(self, stop_event: Event) -> None:
"""Continuously send mocap poses to every active Crazyflie estimator."""
assert self.ros_connector is not None, "Estimator updates require lighthouse=False."
period = 1 / self.update_freq
next_tick = asyncio.get_running_loop().time()
while not stop_event.is_set():
# These properties already copy the synchronized ROS arrays into fresh NumPy snapshots.
positions = self.ros_connector.pos
quaternions = self.ros_connector.quat

async def _update_estimator(uri: str) -> None:
drone_name = f"cf{int(uri[-2:], 16):02d}"
await (
self._cf(uri)
.localization()
.external_pose()
.send_external_pose(
pos=positions[drone_name].tolist(), quat=quaternions[drone_name].tolist()
)
)

await self._parallel_by_uri("Updating estimators", self.uris, _update_estimator)
next_tick += period
await asyncio.sleep(max(0.0, next_tick - asyncio.get_running_loop().time()))

async def _stream_reference(
self,
Expand All @@ -536,7 +595,6 @@ async def _stream_reference(
color_bot: dict[float, Array] | None = None,
) -> None:
commander = self._cf(uri).commander()
t_est = -np.inf
top_cues = sorted((float(t), wrgb) for t, wrgb in (color_top or {}).items())
bot_cues = sorted((float(t), wrgb) for t, wrgb in (color_bot or {}).items())
i_next_top = 0
Expand All @@ -548,10 +606,6 @@ async def _stream_reference(
t_col = -np.inf

while (t_cur := asyncio.get_running_loop().time() - start_time) < duration:
if not self.lighthouse and t_cur - t_est >= 1 / self.update_freq:
await self._send_external_pose(uri)
t_est = t_cur

await commander.send_setpoint_position(*reference(t_cur))

if t_cur - t_col >= color_period:
Expand Down
Loading
Loading