Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
273 changes: 140 additions & 133 deletions aic_example_policies/aic_example_policies/ros/RunACT.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,10 @@
# limitations under the License.
#

import os
"""ACT policy runner. Heavy imports run in __init__ to keep module discovery fast."""

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

import time
import json
import torch
import numpy as np
import cv2
import draccus
from pathlib import Path
from typing import Callable, Dict, Any, List
from rclpy.node import Node
from geometry_msgs.msg import Twist, Vector3
from typing import Any, Dict

from aic_model.policy import (
GetObservationCallback,
Expand All @@ -37,37 +27,84 @@
)
from aic_model_interfaces.msg import Observation
from aic_task_interfaces.msg import Task
from rclpy.node import Node

from aic_control_interfaces.msg import (
MotionUpdate,
TrajectoryGenerationMode,
)
from geometry_msgs.msg import Wrench
class RunACT(Policy):
"""Loads ACT weights in ``__init__``; avoids top-level torch/lerobot imports."""

# LeRobot & Safetensors
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.act.configuration_act import ACTConfig
from safetensors.torch import load_file
from huggingface_hub import snapshot_download
_SIM_TIMEOUT_SEC = 5.0

@staticmethod
def _is_act_policy_checkpoint_dir(d: Path) -> bool:
return (d / "config.json").is_file() and (d / "model.safetensors").is_file()

def _resolve_local_act_policy_dir(self) -> Path:
"""Resolve ACT checkpoint in dev/install/container layouts."""
here = Path(__file__).resolve()
for d in here.parents:
for candidate in (
d / "resource" / "aic_act_policy",
d / "aic_example_policies" / "resource" / "aic_act_policy",
):
if self._is_act_policy_checkpoint_dir(candidate):
self.get_logger().info(f"Using local ACT policy: {candidate}")
return candidate

raise FileNotFoundError(
"No local ACT policy checkpoint under any ancestor of "
f"{here} (expected .../resource/aic_act_policy or "
".../aic_example_policies/resource/aic_act_policy with config.json and model.safetensors)"
)

def _load_policy_processors(self, policy_path: Path) -> None:
"""Load LeRobot pre/post processors from checkpoint."""
import torch

pre_cfg = policy_path / "policy_preprocessor.json"
post_cfg = policy_path / "policy_postprocessor.json"
processor_overrides = {}
if not torch.cuda.is_available():
processor_overrides = {"device_processor": {"device": "cpu"}}

from lerobot.processor.pipeline import PolicyProcessorPipeline

self.preprocessor = PolicyProcessorPipeline.from_pretrained(
pretrained_model_name_or_path=policy_path,
config_filename=pre_cfg.name,
overrides=processor_overrides,
)
self.postprocessor = PolicyProcessorPipeline.from_pretrained(
pretrained_model_name_or_path=policy_path,
config_filename=post_cfg.name,
overrides=processor_overrides,
)
self.get_logger().info(
"Loaded policy pre/post processors from checkpoint configs."
)

class RunACT(Policy):
def __init__(self, parent_node: Node):
super().__init__(parent_node)

import json

import draccus
import torch
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.act.modeling_act import ACTPolicy
from safetensors.torch import load_file

self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -------------------------------------------------------------------------
# 1. Configuration & Weights Loading
# -------------------------------------------------------------------------
repo_id = "grkw/aic_act_policy"
policy_path = self._resolve_local_act_policy_dir()

# Path to your checkpoint folder
policy_path = Path(
snapshot_download(
repo_id=repo_id,
allow_patterns=["config.json", "model.safetensors", "*.safetensors"],
)
)
# Torchvision ResNet etc. load via torch.hub; use vendored weights under resource/hub/.
hub_dir = policy_path.parent / "hub"
if hub_dir.is_dir():
torch.hub.set_dir(str(hub_dir.resolve()))
self.get_logger().info(f"torch.hub.set_dir({hub_dir}) (offline vision backbone)")

# Load Config Manually (Fixes 'Draccus' error by removing unknown 'type' field)
with open(policy_path / "config.json", "r") as f:
Expand All @@ -86,61 +123,24 @@ def __init__(self, parent_node: Node):

self.get_logger().info(f"ACT Policy loaded on {self.device} from {policy_path}")

# -------------------------------------------------------------------------
# 2. Normalization Stats Loading
# -------------------------------------------------------------------------
stats_path = (
policy_path / "policy_preprocessor_step_3_normalizer_processor.safetensors"
)
stats = load_file(stats_path)

# Helper to extract and shape stats for broadcasting
def get_stat(key, shape):
return stats[key].to(self.device).view(*shape)

# Image Stats (1, 3, 1, 1) for broadcasting against (Batch, Channel, Height, Width)
self.img_stats = {
"left": {
"mean": get_stat("observation.images.left_camera.mean", (1, 3, 1, 1)),
"std": get_stat("observation.images.left_camera.std", (1, 3, 1, 1)),
},
"center": {
"mean": get_stat("observation.images.center_camera.mean", (1, 3, 1, 1)),
"std": get_stat("observation.images.center_camera.std", (1, 3, 1, 1)),
},
"right": {
"mean": get_stat("observation.images.right_camera.mean", (1, 3, 1, 1)),
"std": get_stat("observation.images.right_camera.std", (1, 3, 1, 1)),
},
}
print(f"Image stats: {self.img_stats}")

# Robot State Stats (1, 26)
self.state_mean = get_stat("observation.state.mean", (1, -1))
self.state_std = get_stat("observation.state.std", (1, -1))
print(f"Robot state mean: {self.state_mean}")
print(f"Robot state std: {self.state_std}")

# Action Stats (1, 7) - Used for Un-normalization
self.action_mean = get_stat("action.mean", (1, -1))
self.action_std = get_stat("action.std", (1, -1))
print(f"Action mean: {self.action_mean}")
print(f"Action std: {self.action_std}")

# Config
self.image_scaling = 0.25 # Must match AICRobotAICControllerConfig

self.get_logger().info("Normalization statistics loaded successfully.")
self._load_policy_processors(policy_path)
self.get_logger().info("Using policy processors for normalization.")

@staticmethod
def _img_to_tensor(
self,
raw_img,
device: torch.device,
scale: float,
mean: torch.Tensor,
std: torch.Tensor,
) -> torch.Tensor:
"""Converts ROS Image -> Resized -> Permuted -> Normalized Tensor."""
):
"""Converts ROS Image -> Resized -> Permuted -> Float Tensor."""
import cv2
import numpy as np
import torch

device = self.device
scale = self.image_scaling

# 1. Bytes to Numpy (H, W, C)
img_np = np.frombuffer(raw_img.data, dtype=np.uint8).reshape(
raw_img.height, raw_img.width, 3
Expand All @@ -162,35 +162,22 @@ def _img_to_tensor(
.to(device)
)

# 4. Normalize (Apply Mean/Std)
# Formula: (x - mean) / std
return (tensor - mean) / std
return tensor

def prepare_observations(self, obs_msg: Observation) -> Dict[str, torch.Tensor]:
def prepare_observations(self, obs_msg: Observation) -> Dict[str, Any]:
"""Convert ROS Observation message into dictionary of normalized tensors."""
import torch

# --- Process Cameras ---
obs = {
"observation.images.left_camera": self._img_to_tensor(
obs_msg.left_image,
self.device,
self.image_scaling,
self.img_stats["left"]["mean"],
self.img_stats["left"]["std"],
),
"observation.images.center_camera": self._img_to_tensor(
obs_msg.center_image,
self.device,
self.image_scaling,
self.img_stats["center"]["mean"],
self.img_stats["center"]["std"],
),
"observation.images.right_camera": self._img_to_tensor(
obs_msg.right_image,
self.device,
self.image_scaling,
self.img_stats["right"]["mean"],
self.img_stats["right"]["std"],
),
}

Expand All @@ -199,38 +186,40 @@ def prepare_observations(self, obs_msg: Observation) -> Dict[str, torch.Tensor]:
tcp_pose = obs_msg.controller_state.tcp_pose
tcp_vel = obs_msg.controller_state.tcp_velocity

state_np = np.array(
[
# TCP Position (3)
tcp_pose.position.x,
tcp_pose.position.y,
tcp_pose.position.z,
# TCP Orientation (4)
tcp_pose.orientation.x,
tcp_pose.orientation.y,
tcp_pose.orientation.z,
tcp_pose.orientation.w,
# TCP Linear Vel (3)
tcp_vel.linear.x,
tcp_vel.linear.y,
tcp_vel.linear.z,
# TCP Angular Vel (3)
tcp_vel.angular.x,
tcp_vel.angular.y,
tcp_vel.angular.z,
# TCP Error (6)
*obs_msg.controller_state.tcp_error,
# Joint Positions (7)
*obs_msg.joint_states.position[:7],
],
dtype=np.float32,
state_values = [
# TCP Position (3)
tcp_pose.position.x,
tcp_pose.position.y,
tcp_pose.position.z,
# TCP Orientation (4)
tcp_pose.orientation.x,
tcp_pose.orientation.y,
tcp_pose.orientation.z,
tcp_pose.orientation.w,
# TCP Linear Vel (3)
tcp_vel.linear.x,
tcp_vel.linear.y,
tcp_vel.linear.z,
# TCP Angular Vel (3)
tcp_vel.angular.x,
tcp_vel.angular.y,
tcp_vel.angular.z,
# TCP Error (6)
*obs_msg.controller_state.tcp_error,
# Joint Positions (7)
*obs_msg.joint_states.position[:7],
]

obs["observation.state"] = (
torch.tensor(
state_values,
dtype=torch.float32,
device=self.device,
)
.unsqueeze(0)
)

# Normalize State
raw_state_tensor = (
torch.from_numpy(state_np).float().unsqueeze(0).to(self.device)
)
obs["observation.state"] = (raw_state_tensor - self.state_mean) / self.state_std
obs = self.preprocessor.process_observation(obs)

return obs

Expand All @@ -242,13 +231,20 @@ def insert_cable(
send_feedback: SendFeedbackCallback,
**kwargs,
):
import time

import torch
from geometry_msgs.msg import Twist, Vector3
from rclpy.duration import Duration

self.policy.reset()
self.get_logger().info(f"RunACT.insert_cable() enter. Task: {task}")

start_time = time.time()
clock = self.get_clock()
start = clock.now()
deadline = start + Duration(seconds=self._SIM_TIMEOUT_SEC)

# Run inference for 30 seconds
while time.time() - start_time < 30.0:
while clock.now() < deadline:
loop_start = time.time()

# 1. Get & Process Observation
Expand All @@ -266,11 +262,15 @@ def insert_cable(
normalized_action = self.policy.select_action(obs_tensors)

# 3. Un-normalize Action
# Formula: (norm * std) + mean
raw_action_tensor = (normalized_action * self.action_std) + self.action_mean
raw_action_tensor = self.postprocessor.process_action(normalized_action)

# 4. Extract and Command
# raw_action_tensor is [1, 7], taking [0] gives vector of 7
if isinstance(raw_action_tensor, dict):
if "action" in raw_action_tensor:
raw_action_tensor = raw_action_tensor["action"]
else:
raw_action_tensor = next(iter(raw_action_tensor.values()))
action = raw_action_tensor[0].cpu().numpy()

self.get_logger().info(f"Action: {action}")
Expand All @@ -291,10 +291,17 @@ def insert_cable(
elapsed = time.time() - loop_start
time.sleep(max(0, 0.25 - elapsed))

self.get_logger().info("RunACT.insert_cable() exiting...")
self.get_logger().info(
f"RunACT.insert_cable(): sim timeout after {self._SIM_TIMEOUT_SEC}s"
)
return True

def set_cartesian_twist_target(self, twist: Twist, frame_id: str = "base_link"):
def set_cartesian_twist_target(self, twist, frame_id: str = "base_link"):
import numpy as np

from aic_control_interfaces.msg import MotionUpdate, TrajectoryGenerationMode
from geometry_msgs.msg import Vector3, Wrench

motion_update_msg = MotionUpdate()
motion_update_msg.velocity = twist
motion_update_msg.header.frame_id = frame_id
Expand Down
1 change: 1 addition & 0 deletions aic_example_policies/package.xml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
<depend>rclpy</depend>
<depend>tf2_ros</depend>
<depend>sensor_msgs</depend>
<depend>std_msgs</depend>
<depend>std_srvs</depend>
<depend>trajectory_msgs</depend>

Expand Down
5 changes: 4 additions & 1 deletion aic_model/aic_model/aic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import importlib
import inspect
import traceback
import numpy as np
import rclpy
import threading
Expand Down Expand Up @@ -121,7 +122,9 @@ def on_configure(self, state: LifecycleState) -> TransitionCallbackReturn:
try:
self._policy = self._policy_class(self)
except Exception as e:
self.get_logger().error(f"Error instantiating policy: {e}")
self.get_logger().error(
f"Error instantiating policy: {e}\n{traceback.format_exc()}"
)
return TransitionCallbackReturn.ERROR
return TransitionCallbackReturn.SUCCESS

Expand Down
Loading