Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,6 @@ dist/
imgui.ini

# Ignore openpi debug outputs
results
results

output
5 changes: 4 additions & 1 deletion .gitmodules
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
[submodule "ManiSkill2_real2sim"]
path = ManiSkill2_real2sim
url = https://github.com/allenzren/ManiSkill2_real2sim
url = https://github.com/allenzren/ManiSkill2_real2sim
[submodule "scripts/g3_lerobotpi0/lerobot"]
path = scripts/g3_lerobotpi0/lerobot
url = https://github.com/huggingface/lerobot
55 changes: 55 additions & 0 deletions scripts/g3_lerobotpi0/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
## 環境構築

```
g3_shell -c4 --gres gpu:1
git clone https://github.com/airoa-org/SimplerEnv.git
cd SimplerEnv
git checkout origin/bechmark/g3-fractal-lerobotpi0
git submodule update --init --recursive

conda create -n simpler_env_lerobotpi0 python=3.11
conda activate simpler_env_lerobotpi0

pip install numpy==1.24.4
cd ManiSkill2_real2sim/
pip install -e .
cd ..
pip install -e .

cd scripts/g3_lerobotpi0/lerobot
conda install -c conda-forge evdev
pip install "av>=12.0.5"
pip install -e ".[pi0]"
pip install numpy==1.25.2 # opencv-pythonとのconflictが起きるが無視
pip install flash-attn==2.8.1
pip install pytest

pip install matplotlib
conda install -c conda-forge ffmpeg
conda install -c conda-forge libvulkan-loader libvulkan-headers
pip install mediapy
```


## 実行
**インタラクティブモード**
```
g3_shell -c4 --gres gpu:1
cd SimplerEnv
conda activate simpler_env_lerobotpi0
python scripts/g3_lerobotpi0/evaluate_fractal.py \
--ckpt-path /home/group_25b505/group_3/members/user_00029_25b505/lerobot-pi0-fractal
```
* `--ckpt-path`で学習済みモデルのパスを指定

**Jobを投げる**
```
cd SimplerEnv
g3_sbatch \
--gpus-per-node=1 \
--output=output/%j.out \
--time=24:00:00 \
scripts/g3_lerobotpi0/job.sh \
/home/group_25b505/group_3/members/user_00029_25b505/lerobot-pi0-fractal
```
* $1で学習済みモデルのパスを指定
173 changes: 173 additions & 0 deletions scripts/g3_lerobotpi0/evaluate_fractal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
import argparse
from typing import Dict
import time

import torch
import numpy as np
import cv2 as cv
from transformers import AutoTokenizer

from simpler_env.evaluation.adapter import AiroaToSimplerFractalAdapter
from simpler_env.evaluation.scores import run_comprehensive_evaluation
from simpler_env.policies.base import AiroaBasePolicy
from simpler_env.utils.geometry import euler2axangle
from simpler_env.utils.action.action_ensemble import ActionEnsembler

from lerobot.policies.pi0.modeling_pi0 import PI0Policy


class FractalLerobotPi0ToAiroaPolicy(AiroaBasePolicy):
def __init__(
self,
policy,
action_ensemble: bool,
action_ensemble_temp: float,
):
self.policy = policy
self.policy.eval()
self.action_ensemble = action_ensemble
self.action_ensemble_temp = action_ensemble_temp
self.image_size = (224, 224)

if self.action_ensemble:
self.action_ensembler = ActionEnsembler(
self.policy.config.n_action_steps, self.action_ensemble_temp
)
else:
self.action_ensembler = None

self.device = "cuda" if torch.cuda.is_available() else "cpu"

def step(self, obs: Dict) -> Dict:
image = self._resize_image(obs["image"])
prompt = obs["prompt"]
state = obs["state"]

obs_lerobotpi0 = {
"observation.state": torch.from_numpy(state).unsqueeze(0).float().to(self.device),
"observation.images.image": torch.from_numpy(image / 255).permute(2, 0, 1).unsqueeze(0).float().to(self.device),
"task": [prompt],
}

with torch.inference_mode():
actions = self.policy.select_action(obs_lerobotpi0)[0].cpu().numpy()

if self.action_ensemble:
action_chunk = [actions]
for _ in range(self.policy.config.n_action_steps-1):
actions = self.policy.select_action(obs_lerobotpi0)[0].cpu().numpy()
action_chunk.append(actions)
action_chunk = np.stack(action_chunk, axis=0)
actions = self.action_ensembler.ensemble_action(action_chunk)[None][0]

outputs = {
"actions": actions,
"terminate_episode": np.zeros(actions.shape[0]),
}

return outputs

def reset(self) -> None:
self.policy.reset()

def _resize_image(self, image: np.ndarray) -> np.ndarray:
image = cv.resize(image, tuple(self.image_size), interpolation=cv.INTER_AREA)
return image


class AiroaToSimplerFractalStickyActionAdapter(AiroaToSimplerFractalAdapter):
def __init__(self, policy):
super().__init__(policy)
self.sticky_gripper_num_repeat = 10 # same to lerobotpi0

def reset(self, task_description):
super().reset(task_description)
self.previous_gripper_action = None

def postprocess(self, outputs: Dict) -> Dict:
action = outputs["actions"]
roll, pitch, yaw = action[3:6]
action_rotation_ax, action_rotation_angle = euler2axangle(roll, pitch, yaw)

current_gripper_action = action[-1]

if self.previous_gripper_action is None:
relative_gripper_action = 0
self.previous_gripper_action = current_gripper_action
else:
relative_gripper_action = self.previous_gripper_action - current_gripper_action

# switch to sticky closing
if np.abs(relative_gripper_action) > 0.5 and (not self.sticky_action_is_on):
self.sticky_action_is_on = True
self.sticky_gripper_action = relative_gripper_action
self.previous_gripper_action = current_gripper_action

if self.sticky_action_is_on:
self.gripper_action_repeat += 1
relative_gripper_action = self.sticky_gripper_action

if self.gripper_action_repeat == self.sticky_gripper_num_repeat:
self.sticky_action_is_on = False
self.gripper_action_repeat = 0
self.sticky_gripper_action = 0.0

action = np.concatenate(
[
action[:3],
action_rotation_ax * action_rotation_angle,
[relative_gripper_action],
]
)

return {
"actions": action,
"terminate_episode": outputs["terminate_episode"],
}


def parse_args():
parser = argparse.ArgumentParser(description="Run Comprehensive ManiSkill2 Evaluation")
parser.add_argument(
"--ckpt-path",
type=str,
required=True,
help="Path to the checkpoint to evaluate.",
)
parser.add_argument(
"--action-ensemble",
type=bool,
default=True,
help="Whether to use action ensemble.",
)
parser.add_argument(
"--action-ensemble-temp",
type=float,
default=-0.8,
help="Temperature for action ensemble.",
)
return parser.parse_args()


if __name__ == "__main__":
args = parse_args()
ckpt_path = args.ckpt_path

policy = PI0Policy.from_pretrained(ckpt_path)
policy.language_tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
policy.model.paligemma_with_expert.paligemma.language_model = policy.model.paligemma_with_expert.paligemma.language_model.model
policy.model.paligemma_with_expert.gemma_expert.model = policy.model.paligemma_with_expert.gemma_expert.model.base_model
policy = FractalLerobotPi0ToAiroaPolicy(
policy=policy,
action_ensemble=args.action_ensemble,
action_ensemble_temp=args.action_ensemble_temp,
)

env_policy = AiroaToSimplerFractalStickyActionAdapter(policy=policy)

print("Policy initialized. Starting evaluation...")

final_scores = run_comprehensive_evaluation(env_policy=env_policy, ckpt_path=ckpt_path)

print("\nEvaluation finished.")
print(f"Final calculated scores: {final_scores}")
11 changes: 11 additions & 0 deletions scripts/g3_lerobotpi0/job.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#!/bin/bash


source ~/.bashrc
source $(conda info --base)/etc/profile.d/conda.sh
conda activate simpler_env_lerobotpi0

CKPT=$1

python scripts/g3_lerobotpi0/evaluate_fractal.py \
--ckpt-path $CKPT
1 change: 1 addition & 0 deletions scripts/g3_lerobotpi0/lerobot
Submodule lerobot added at 67196c
2 changes: 1 addition & 1 deletion simpler_env/evaluation/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def step(self, image: np.ndarray, eef_pos: np.ndarray, prompt: Optional[str] = N
}
return simpler_outputs, final_simpler_outputs

def _resize_image(image: np.ndarray) -> np.ndarray:
def _resize_image(self, image: np.ndarray) -> np.ndarray:
# Lanczos3相当の補間でリサイズ(OpenCVのINTER_LANCZOS4を使用)
resized = cv2.resize(image, (256, 256), interpolation=cv2.INTER_LANCZOS4)

Expand Down