Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions robosuite/controllers/composite/composite_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ def __init__(self, sim: MjSim, robot_model: RobotModel, grippers: Dict[str, Grip
# task space actions (such as end effector poses) to joint actions (such as joint angles or joint torques)

self._whole_body_controller_action_split_indexes: OrderedDict = OrderedDict()
self.input_action_goal = None

def _init_controllers(self):
for part_name in self.part_controller_config.keys():
Expand Down Expand Up @@ -323,6 +324,7 @@ def setup_whole_body_controller_action_split_idx(self):
previous_idx = last_idx

def set_goal(self, all_action):
self.input_action_goal = all_action.copy()
target_qpos = self.joint_action_policy.solve(all_action[: self.joint_action_policy.control_dim])
# create new all_action vector with the IK solver's actions first
all_action = np.concatenate([target_qpos, all_action[self.joint_action_policy.control_dim :]])
Expand Down Expand Up @@ -433,6 +435,7 @@ def _validate_composite_controller_specific_config(self) -> None:

# Loop through ref_names and validate against mujoco model
original_ref_names = self.composite_controller_specific_config.get("ref_name", [])
original_ref_names = [name.format(idn=self.robot_model.idn) for name in original_ref_names]
for ref_name in original_ref_names:
if ref_name in self.sim.model.site_names: # Check if the site exists in the mujoco model
valid_ref_names.append(ref_name)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"type": "WHOLE_BODY_IK",
"composite_controller_specific_configs": {
"ref_name": ["gripper0_right_grip_site", "gripper0_left_grip_site"],
"ref_name": ["gripper{idn}_right_grip_site", "gripper{idn}_left_grip_site"],
"interpolation": null,
"actuation_part_names": ["torso", "head", "right", "left", "base", "legs"],
"max_dq": 4,
Expand Down
35 changes: 20 additions & 15 deletions robosuite/devices/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@ def _reset_internal_state(self):
self.active_arm_indices = [0] * len(self.all_robot_arms)
self.active_robot = 0
self.base_modes = [False] * len(self.all_robot_arms)

self._prev_target = {arm: None for arm in self.all_robot_arms[self.active_robot]}
# need to keep track of all previous targets. If not when using absolute actions,
# the robot will execute the other robots pose when switching robots
self._prev_target = [{arm: None for arm in self.all_robot_arms[i]} for i in range(self.num_robots)]

@property
def active_arm(self):
Expand Down Expand Up @@ -177,7 +178,9 @@ def input2action(self, mirror_actions=False) -> Optional[Dict]:

return ac_dict

def get_arm_action(self, robot, arm, norm_delta, goal_update_mode="target"):
def get_arm_action(self, robot, arm, norm_delta, goal_update_mode="target", robot_idx=None):
if robot_idx is None:
robot_idx = self.active_robot
assert np.all(norm_delta <= 1.0) and np.all(norm_delta >= -1.0)

assert goal_update_mode in [
Expand All @@ -195,7 +198,7 @@ def get_arm_action(self, robot, arm, norm_delta, goal_update_mode="target"):
"abs": abs_action,
}
elif robot.composite_controller_config["type"] in ["WHOLE_BODY_MINK_IK", "HYBRID_WHOLE_BODY_MINK_IK"]:
ref_frame = self.env.robots[0].composite_controller.composite_controller_specific_config.get(
ref_frame = robot.composite_controller.composite_controller_specific_config.get(
"ik_input_ref_frame", "world"
)

Expand All @@ -204,8 +207,9 @@ def get_arm_action(self, robot, arm, norm_delta, goal_update_mode="target"):
delta_action[3:6] *= 0.15

# general case
if goal_update_mode == "achieved" or self._prev_target[arm] is None:
site_name = f"gripper0_{arm}_grip_site"
if goal_update_mode == "achieved" or self._prev_target[self.active_robot][arm] is None:
id = robot.robot_model.idn
site_name = f"gripper{id}_{arm}_grip_site"
# update next target based on current achieved pose
pos = self.env.sim.data.get_site_xpos(site_name).copy()
ori = self.env.sim.data.get_site_xmat(site_name).copy()
Expand All @@ -214,16 +218,16 @@ def get_arm_action(self, robot, arm, norm_delta, goal_update_mode="target"):
pose_in_world = np.eye(4)
pose_in_world[:3, 3] = pos
pose_in_world[:3, :3] = ori
pose_in_base = self.env.robots[0].composite_controller.joint_action_policy.transform_pose(
pose_in_base = robot.composite_controller.joint_action_policy.transform_pose(
src_frame_pose=pose_in_world,
src_frame="world", # mocap pose is world coordinates
dst_frame=ref_frame,
)
pos, ori = pose_in_base[:3, 3], pose_in_base[:3, :3]
else:
# update next target based on previous target pose
pos = self._prev_target[arm][0:3].copy()
ori = T.quat2mat(T.axisangle2quat(self._prev_target[arm][3:6].copy()))
pos = self._prev_target[self.active_robot][arm][0:3].copy()
ori = T.quat2mat(T.axisangle2quat(self._prev_target[self.active_robot][arm][3:6].copy()))

# new positions computed in world frame coordinates
new_pos = pos + delta_action[0:3]
Expand All @@ -232,22 +236,23 @@ def get_arm_action(self, robot, arm, norm_delta, goal_update_mode="target"):
new_axisangle = T.quat2axisangle(T.mat2quat(new_ori))

abs_action = np.concatenate([new_pos, new_axisangle])
self._prev_target[arm] = abs_action.copy()
self._prev_target[self.active_robot][arm] = abs_action.copy()

return {
"delta": delta_action,
"abs": abs_action,
}
elif robot.composite_controller_config["type"] in ["WHOLE_BODY_IK"]:
if goal_update_mode == "achieved" or self._prev_target[arm] is None:
site_name = f"gripper0_{arm}_grip_site"
if goal_update_mode == "achieved" or self._prev_target[self.active_robot][arm] is None:
id = robot.robot_model.idn
site_name = f"gripper{id}_{arm}_grip_site"
# update next target based on current achieved pose
pos = self.env.sim.data.get_site_xpos(site_name).copy()
ori = self.env.sim.data.get_site_xmat(site_name).copy()
else:
# update next target based on previous target pose
pos = self._prev_target[arm][0:3].copy()
ori = T.quat2mat(T.axisangle2quat(self._prev_target[arm][3:6].copy()))
pos = self._prev_target[self.active_robot][arm][0:3].copy()
ori = T.quat2mat(T.axisangle2quat(self._prev_target[self.active_robot][arm][3:6].copy()))

delta_action = norm_delta.copy()
delta_action[0:3] *= 0.05
Expand All @@ -260,7 +265,7 @@ def get_arm_action(self, robot, arm, norm_delta, goal_update_mode="target"):
new_axisangle = T.quat2axisangle(T.mat2quat(new_ori))

abs_action = np.concatenate([new_pos, new_axisangle])
self._prev_target[arm] = abs_action.copy()
self._prev_target[self.active_robot][arm] = abs_action.copy()

return {
"delta": delta_action,
Expand Down
115 changes: 114 additions & 1 deletion robosuite/scripts/collect_human_demonstrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""

import argparse
import copy
import datetime
import json
import os
Expand All @@ -15,11 +16,95 @@
import numpy as np

import robosuite as suite
import robosuite.utils.transform_utils as T
from robosuite.controllers import load_composite_controller_config
from robosuite.controllers.composite.composite_controller import WholeBody
from robosuite.controllers.parts.arm import InverseKinematicsController, OperationalSpaceController
from robosuite.utils.log_utils import ROBOSUITE_DEFAULT_LOGGER
from robosuite.wrappers import DataCollectionWrapper, VisualizationWrapper


def get_arm_ref(env, robot, arm):
"""
Extracts the reference pose of the specified arm of the robot.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's reference pose? the world pose target? docs might be confusing to users

"""
if robot.composite_controller_config["type"] in ["WHOLE_BODY_MINK_IK", "HYBRID_WHOLE_BODY_MINK_IK"]:
ref_frame = robot.composite_controller.composite_controller_specific_config.get("ik_input_ref_frame", "world")

# general case
id = robot.robot_model.idn
site_name = f"gripper{id}_{arm}_grip_site"
# update next target based on current achieved pose
pos = env.sim.data.get_site_xpos(site_name).copy()
ori = env.sim.data.get_site_xmat(site_name).copy()
# convert target in world coordinate to
pose_in_world = np.eye(4)
pose_in_world[:3, 3] = pos
pose_in_world[:3, :3] = ori
pose_in_base = robot.composite_controller.joint_action_policy.transform_pose(
src_frame_pose=pose_in_world,
src_frame="world", # mocap pose is world coordinates
dst_frame=ref_frame,
)
pos, ori = pose_in_base[:3, 3], pose_in_base[:3, :3]
axisangle = T.quat2axisangle(T.mat2quat(ori))

abs_action = np.concatenate([pos, axisangle])

return abs_action
elif robot.composite_controller_config["type"] in ["WHOLE_BODY_IK"]:
id = robot.robot_model.idn
site_name = f"gripper{id}_{arm}_grip_site"
# update next target based on current achieved pose
pos = env.sim.data.get_site_xpos(site_name).copy()
ori = env.sim.data.get_site_xmat(site_name).copy()

axisangle = T.quat2axisangle(T.mat2quat(ori))

abs_action = np.concatenate([pos, axisangle])

return abs_action
else:
raise NotImplementedError


def get_abs_arm_static_target(robot, arm, env):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why's the fn called get_abs_arm_static_target? what's an absolute arm? does it mean arm's controller is absolute and doesn't work for relative action space? Can the function just be to get_arm_target() or get_arm_current_target?

"""
Extracts the internal goal absolute target for the specified arm of the robot. If the
internal goal is not set, it extracts the current achieved pose of the arm.
This is needed to maintain the arm position of a robot when the device is not controlling it.


Args:
robot (Robot): the robot to extract the arm target from
arm (str): the arm to extract the target for

Returns:
np.ndarray: the absolute target for the arm
"""
abs_action = np.zeros(6)
if isinstance(robot.composite_controller, WholeBody):
controller_input_type = robot.composite_controller.joint_action_policy.input_type
prev_action = robot.composite_controller.input_action_goal
if prev_action is None:
abs_action = get_arm_ref(env, robot, arm)
else:
start_idx, end_idx = robot.composite_controller._whole_body_controller_action_split_indexes[arm]
abs_action = prev_action[start_idx:end_idx]
elif isinstance(robot.part_controllers[arm], OperationalSpaceController):
controller_input_type = robot.part_controllers[arm].input_type
abs_action = robot.part_controllers[arm].delta_to_abs_action(np.zeros(6), "achieved")
else:
ROBOSUITE_DEFAULT_LOGGER.warning(
"Unable to extract absolute target for arm {} of robot {} returning zeros. This is only a problem is it is a TwoArmEnv".format(
arm, robot.robot_model.idn
)
)

assert controller_input_type == "absolute", f"Only calculating absolute targets"
return abs_action


def collect_human_trajectory(env, device, arm, max_fr):
"""
Use the device (keyboard or SpaceNav 3D mouse) to collect a demonstration.
Expand Down Expand Up @@ -82,9 +167,29 @@ def collect_human_trajectory(env, device, arm, max_fr):
action_dict[arm] = input_ac_dict[f"{arm}_abs"]
else:
raise ValueError
all_robot_action_dicts = []
# set inactive robot action dicts
for i, robot in enumerate(env.robots):
if i == device.active_robot:
all_robot_action_dicts.append(action_dict)
continue

inactive_robot_ac_dict = copy.deepcopy(all_prev_gripper_actions[i])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Setting inactive robot action as prev gripper action is confusing to me, as in active robot action seems more than just gripper. Can we rename things to be more intuitive?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah good points my bad, Ill refactor accordingly

if isinstance(robot.composite_controller, WholeBody):
controller_input_type = robot.composite_controller.joint_action_policy.input_type
else:
controller_input_type = robot.part_controllers[arm].input_type

# update action when controller input type is absolute so that the inactive robot does not move
if controller_input_type == "absolute":
for arm in robot.arms:
# set goal mode to acheived to avoid changing the target
inactive_robot_ac_dict[arm] = get_abs_arm_static_target(robot, arm, env)

all_robot_action_dicts.append(inactive_robot_ac_dict)

# Maintain gripper state for each robot but only update the active robot with action
env_action = [robot.create_action_vector(all_prev_gripper_actions[i]) for i, robot in enumerate(env.robots)]
env_action = [robot.create_action_vector(all_robot_action_dicts[i]) for i, robot in enumerate(env.robots)]
env_action[device.active_robot] = active_robot.create_action_vector(action_dict)
env_action = np.concatenate(env_action)
for gripper_ac in all_prev_gripper_actions[device.active_robot]:
Expand Down Expand Up @@ -301,6 +406,14 @@ def gather_demonstrations_as_hdf5(directory, out_dir, env_info):
# Check if we're using a multi-armed environment and use env_configuration argument if so
if "TwoArm" in args.environment:
config["env_configuration"] = args.config
# make fresh copies of the controller config for each robot since they will have different references
config["controller_configs"] = [
load_composite_controller_config(
controller=args.controller,
robot=robot,
)
for robot in args.robots
]

# Create environment
env = suite.make(
Expand Down
Loading