From cb5632a137dfbe7d4ce7c3334048e3b7acfed5ca Mon Sep 17 00:00:00 2001 From: Haizhou Ge <1352674740@qq.com> Date: Thu, 22 Aug 2024 13:33:49 +0800 Subject: [PATCH] big refector and need to fix bugs --- envs/{common_env.py => common.py} | 22 +- envs/make_env.py | 49 ++-- habitats/README.md | 0 habitats/common/creator.py | 41 +++ habitats/common/utils.py | 6 + habitats/fake/common.py | 25 ++ .../ros/common/robot_config.py | 0 habitats/ros/ros1/moveit1.py | 0 .../ros/ros1/topicor.py | 13 +- habitats/ros/ros2/moveit2.py | 215 ++++++++++++++++ .../ros/ros2/topicor.py | 8 +- policies/onnx/onnx_policy.py | 136 ++++++++++ policy_evaluate.py | 2 +- robot_utils.py | 24 -- robots/airbots/airbot_kits/airbot_mmk.py | 6 + robots/airbots/airbot_kits/airbot_tok.py | 5 + robots/airbots/airbot_play.py | 76 ------ robots/airbots/airbot_play/airbot_play.py | 157 +++++++++++ .../airbot_play/airbot_play_moveit1.py | 0 .../airbot_play/airbot_play_moveit2.py | 0 .../airbots/airbot_play/airbot_play_ros1.py | 0 .../airbots/airbot_play/airbot_play_ros2.py | 0 robots/common.py | 18 ++ robots/common_robot.py | 243 ------------------ test_fake.ipynb | 170 ++++++++++++ utils.py | 4 + 26 files changed, 833 insertions(+), 387 deletions(-) rename envs/{common_env.py => common.py} (77%) create mode 100644 habitats/README.md create mode 100644 habitats/common/creator.py create mode 100644 habitats/common/utils.py create mode 100644 habitats/fake/common.py rename robots/ros_robots/ros_robot_config.py => habitats/ros/common/robot_config.py (100%) create mode 100644 habitats/ros/ros1/moveit1.py rename robots/ros_robots/ros1_robot.py => habitats/ros/ros1/topicor.py (92%) create mode 100644 habitats/ros/ros2/moveit2.py rename robots/ros_robots/ros2_robot.py => habitats/ros/ros2/topicor.py (96%) create mode 100644 policies/onnx/onnx_policy.py create mode 100644 robots/airbots/airbot_kits/airbot_mmk.py create mode 100644 robots/airbots/airbot_kits/airbot_tok.py delete mode 100644 robots/airbots/airbot_play.py create mode 100644 robots/airbots/airbot_play/airbot_play.py create mode 100644 robots/airbots/airbot_play/airbot_play_moveit1.py create mode 100644 robots/airbots/airbot_play/airbot_play_moveit2.py create mode 100644 robots/airbots/airbot_play/airbot_play_ros1.py create mode 100644 robots/airbots/airbot_play/airbot_play_ros2.py create mode 100644 robots/common.py delete mode 100644 robots/common_robot.py create mode 100644 test_fake.ipynb diff --git a/envs/common_env.py b/envs/common.py similarity index 77% rename from envs/common_env.py rename to envs/common.py index dff32a8..2e67a79 100644 --- a/envs/common_env.py +++ b/envs/common.py @@ -5,20 +5,29 @@ from typing import List from einops import rearrange import dm_env -from robots.common_robot import AssembledRobot +from robots.common import Robot +import logging + +logger = logging.getLogger(__name__) class CommonEnvConfig(object): - def __init__(self) -> None: - self.robots = [] + def __init__(self, robots:List[Robot]) -> None: + """For most real robot environments, the config only needs to specify the robot instances, + which have all sensors' configurations.""" + self.robots = robots + + def __post_init__(self): + assert self.robots, "There should be at least one robot in the environment." class CommonEnv: """ An environment is a combination of robots, scenes and objects. It should be able to reset and step. - The environment will return observations based on the state of the robot, the position of the sensors, and the current scene and object conditions. And for RL and data collection, it should also send rewards and done signals. + The environment will return observations based on the state of the robots, the position of the sensors, and the current scene and object conditions. And for RL and data collection, it should also send rewards and done signals to Gaia. """ - def __init__(self, *args, **kwargs) -> None: + + def __init__(self, config: CommonEnvConfig) -> None: raise NotImplementedError def reset(self) -> dm_env.TimeStep: @@ -27,9 +36,6 @@ def reset(self) -> dm_env.TimeStep: def step(self, action) -> dm_env.TimeStep: raise NotImplementedError - def get_reward(self): - raise NotImplementedError - def get_image(ts: dm_env.TimeStep, camera_names, mode=0): # TODO: remove this function diff --git a/envs/make_env.py b/envs/make_env.py index 8138752..0a3761f 100644 --- a/envs/make_env.py +++ b/envs/make_env.py @@ -1,5 +1,5 @@ from typing import List -from robots.common_robot import AssembledRobot +from robots.common import Robot def make_environment(env_config): @@ -22,47 +22,44 @@ def make_environment(env_config): ), "The length of start_joint should be equal to joint_num or joint_num*robot_num" print(f"Start joint: {start_joint}") - robot_instances: List[AssembledRobot] = [] + robot_instances: List[Robot] = [] if "airbot_play" in robot_name: - # set up can - # from utils import CAN_Tools - import airbot - + from robots.airbots.airbot_play.airbot_play import AIRBOTPlayPos, AIRBOTPlayConfig vel = 2.0 for i in range(robot_num): - # if 'v' not in can: - # if not CAN_Tools.check_can_status(can): - # success, error = CAN_Tools.activate_can_interface(can, 1000000) - # if not success: raise Exception(error) - airbot_player = airbot.create_agent( - "/usr/share/airbot_models/airbot_play_with_gripper.urdf", - "down", - can_buses[i], - vel, - eef_mode[i], - bigarm_type[i], - forearm_type[i], + airbot_player = AIRBOTPlayPos( + AIRBOTPlayConfig( + model_path="/usr/share/airbot_models/airbot_play_with_gripper.urdf", + gravity_mode="down", + can_bus=can_buses[i], + vel=vel, + eef_mode=eef_mode, + bigarm_type=bigarm_type, + forearm_type=forearm_type, + joint_vel=6.0, + dt=25, + ) ) robot_instances.append( - AssembledRobot( + Robot( airbot_player, 1 / fps, start_joint[joint_num * i : joint_num * (i + 1)], ) ) elif "fake" in robot_name or "none" in robot_name: - from robots.common_robot import AssembledFakeRobot + from robots.airbots.airbot_play.airbot_play import AIRBOTPlayPosFake if check_images: - AssembledFakeRobot.real_camera = True + AIRBOTPlayPosFake.real_camera = True for i in range(robot_num): robot_instances.append( - AssembledFakeRobot( + AIRBOTPlayPosFake( 1 / fps, start_joint[joint_num * i : joint_num * (i + 1)] ) ) elif "ros" in robot_name: - from robots.common_robot import AssembledRosRobot + from robots.airbots.airbot_play.airbot_play_ros1 import AIRBOTPlayPos import rospy rospy.init_node("replay_episodes") @@ -72,7 +69,7 @@ def make_environment(env_config): gripper_action_topic = f"{namespace}/gripper_group_position_controller/command" for i in range(robot_num): robot_instances.append( - AssembledRosRobot( + AIRBOTPlayPos( states_topic, arm_action_topic, gripper_action_topic, @@ -82,10 +79,10 @@ def make_environment(env_config): ) ) elif "mmk" in robot_name: - from robots.common_robot import AssembledMmkRobot + from robots.airbots.airbot_kits.airbot_mmk import AIRBOTMMK2, AIRBOTMMK2Config for i in range(robot_num): - robot_instances.append(AssembledMmkRobot()) + robot_instances.append(AIRBOTMMK2()) elif robot_name == "none": print("No direct robot is used") else: diff --git a/habitats/README.md b/habitats/README.md new file mode 100644 index 0000000..e69de29 diff --git a/habitats/common/creator.py b/habitats/common/creator.py new file mode 100644 index 0000000..4fc119b --- /dev/null +++ b/habitats/common/creator.py @@ -0,0 +1,41 @@ +from dataclasses import dataclass +from typing import List +from importlib import import_module + + +@dataclass +class CommonConfig(object): + module_path = None + class_name = None + instance_name = None + default_act = None + + def __post_init__(self): + assert self.default_act is not None, "default_act must be set" + assert self.instance_name is not None, "instance_name must be set" + assert self.module_path is not None, "class_path must be set" + assert self.class_name is not None, "class_name must be set" + + +class Configer(object): + @staticmethod + def config2dict(config): + return {key: value for key, value in config.__dict__.items()} + + @staticmethod + def config2tuple(config): + return tuple(config.__dict__.values()) + + +class Creator(object): + @staticmethod + def instancer(configs: List[CommonConfig]) -> list: + instances = [] + for config in configs: + module = import_module(config.module_path) + instances.append(getattr(module, config.class_name)(config)) + return instances + + @classmethod + def give_eyes(cls, robot, configs): + setattr(robot, "eyes", cls.instancer(configs)) diff --git a/habitats/common/utils.py b/habitats/common/utils.py new file mode 100644 index 0000000..146f74b --- /dev/null +++ b/habitats/common/utils.py @@ -0,0 +1,6 @@ +def get_values_by_names(sub_names: tuple, all_names: tuple, all_values: tuple) -> tuple: + """根据子名称列表获取所有值列表中对应的值列表,返回子值列表""" + sub_values = [0.0 for _ in range(len(sub_names))] + for i, name in enumerate(sub_names): + sub_values[i] = all_values[all_names.index(name)] + return tuple(sub_values) diff --git a/habitats/fake/common.py b/habitats/fake/common.py new file mode 100644 index 0000000..0a0e5f3 --- /dev/null +++ b/habitats/fake/common.py @@ -0,0 +1,25 @@ +from typing import Optional +import logging +from habitats.common.creator import CommonConfig + +logger = logging.getLogger(__name__) + + +class FakeCommon(object): + def __init__(self, config: CommonConfig) -> None: + self.reset(config) + + def reset(self, config: Optional[CommonConfig] = None): + if config is not None: + self.config = config + self._state = self.config.default_act + logger.debug(f"Reset {self.config.instance_name}: {self._state}") + + def step(self, action): + self._state = action + logger.debug(f"Step {self.config.instance_name}: {self._state}") + + @property + def state(self): + logger.debug(f"Get {self.config.instance_name} state: {self._state}") + return self._state diff --git a/robots/ros_robots/ros_robot_config.py b/habitats/ros/common/robot_config.py similarity index 100% rename from robots/ros_robots/ros_robot_config.py rename to habitats/ros/common/robot_config.py diff --git a/habitats/ros/ros1/moveit1.py b/habitats/ros/ros1/moveit1.py new file mode 100644 index 0000000..e69de29 diff --git a/robots/ros_robots/ros1_robot.py b/habitats/ros/ros1/topicor.py similarity index 92% rename from robots/ros_robots/ros1_robot.py rename to habitats/ros/ros1/topicor.py index e407ba1..ee22678 100644 --- a/robots/ros_robots/ros1_robot.py +++ b/habitats/ros/ros1/topicor.py @@ -4,13 +4,14 @@ from sensor_msgs.msg import JointState from geometry_msgs.msg import Pose, PoseStamped from ros_tools import Lister -from convert_all import flatten_dict -from ros_robot_config import EEF_POSE_POSITION, EEF_POSE_ORIENTATION, ACTIONS_TOPIC_CONFIG, OBSERVATIONS_TOPIC_CONFIG, EXAMPLE_CONFIG +from data_process.convert_all import flatten_dict +# TODO: add these to the config param, or configure the path +from habitats.ros.common.robot_config import EEF_POSE_POSITION, EEF_POSE_ORIENTATION, ACTIONS_TOPIC_CONFIG, OBSERVATIONS_TOPIC_CONFIG, EXAMPLE_CONFIG -class AssembledROS1Robot(object): +class ROS1Interface(object): """Use the keys and values in config as the keys to get all the configuration""" - + # TODO: make this a common class for all ROS interfaces include both low dim and high dim data, such as images, point clouds, etc. def __init__(self, config: Dict[str, dict] = None) -> None: self.params = config["param"] self.actions_dim = flatten_dict(self.params["actions_dim"]) @@ -86,6 +87,7 @@ def __init__(self, config: Dict[str, dict] = None) -> None: def _target_cmd_pub_thread(self): """Publish thread for all publishers""" + # TODO: change this to a timer rate = rospy.Rate(self.params["control_freq"]) while not rospy.is_shutdown(): for key, pub in self.action_pubs.items(): @@ -99,6 +101,7 @@ def _target_cmd_pub_thread(self): def _current_state_callback(self, data, args): """Callback function used for all subcribers""" + # TODO: for joint states, support the case where the joint names are not in a specific order key, lister, preprocess = args self.current_data[key] = preprocess(lister(data)) rospy.logdebug(f"Current data: {self.current_data[key]}") @@ -183,7 +186,7 @@ def reset(self) -> list: if __name__ == "__main__": rospy.init_node("test_mmk") - ros1_robot = AssembledROS1Robot(EXAMPLE_CONFIG) + ros1_robot = ROS1Interface(EXAMPLE_CONFIG) ros1_robot.wait_for_current_states() current = ros1_robot.get_current_states() print("Current states:", current) diff --git a/habitats/ros/ros2/moveit2.py b/habitats/ros/ros2/moveit2.py new file mode 100644 index 0000000..f6097c0 --- /dev/null +++ b/habitats/ros/ros2/moveit2.py @@ -0,0 +1,215 @@ +import logging, os, time +from typing import List + +# moveit usage +from rclpy.impl.rcutils_logger import RcutilsLogger +from ament_index_python.packages import get_package_share_directory +from moveit.planning import ( + MoveItPy, + PlanningComponent, + PlanningSceneMonitor, + TrajectoryExecutionManager, +) +from moveit_configs_utils import MoveItConfigsBuilder +from moveit_msgs.msg import MotionPlanResponse +from moveit_configs_utils import MoveItConfigsBuilder +from moveit.core.robot_state import RobotState +from moveit.core.robot_model import RobotModel, JointModelGroup, VariableBounds +from moveit.core.planning_scene import PlanningScene +from moveit.core.kinematic_constraints import construct_joint_constraint + +# from moveit.core.robot_trajectory import RobotTrajectory +from moveit.core.controller_manager import ExecutionStatus + +# from moveit.core.planning_interface import MotionPlanResponse +from geometry_msgs.msg import Pose, PoseStamped + + +class AirbotPlayMoveit(object): + + def __init__(self, robot_name: str, node_name: str) -> None: + # instantiate MoveItPy instance and get planning component + moveit_config = ( + MoveItConfigsBuilder(robot_name) + .robot_description(file_path="config/airbot.urdf.xacro") + .joint_limits() + .robot_description_kinematics() + .trajectory_execution(file_path="config/moveit_controllers.yaml") + .pilz_cartesian_limits() + .planning_pipelines( + pipelines=["ompl", "pilz_industrial_motion_planner", "stomp"] + ) + .planning_scene_monitor( + publish_planning_scene=True, + publish_geometry_updates=True, + publish_state_updates=True, + publish_transforms_updates=True, + publish_robot_description_semantic=True, + ) + .moveit_cpp( # must be set + file_path=os.path.join( + get_package_share_directory("airbot_moveit_config"), + "config", + "moveit_cpp.yaml", + ) + ) + .to_moveit_configs() + ).to_dict() + self.robot = MoveItPy(node_name=node_name, config_dict=moveit_config) + self.arm: PlanningComponent = self.robot.get_planning_component("arm") + self.logger = logging.getLogger(__name__) + self.logger.setLevel(logging.DEBUG) + self.arm.planning_group_name + self.plan_scene_monitor: PlanningSceneMonitor = ( + self.robot.get_planning_scene_monitor() + ) + self.traj_manager: TrajectoryExecutionManager = ( + self.robot.get_trajactory_execution_manager() + ) + self.robot_model: RobotModel = self.robot.get_robot_model() + self.arm_joint_model_group: JointModelGroup = ( + self.robot_model.get_joint_model_group("arm") + ) + self.logger.info(f"robot_model.end_effectors: {self.robot_model.end_effectors}") + self.arm_links = self.arm_joint_model_group.link_model_names + self.arm_active_joints = self.arm_joint_model_group.active_joint_model_names + self.logger.info(f"all_arm_links: {self.arm_links}") + self.pose_target_link = self.arm_links[-1] + self.pose_reference_link = self.arm_links[0] + # self.plan_scene.wait_for_current_robot_state(rclpy.clock.Clock().now().to_msg(), 1.0) # TODO: error now + self.arm_active_joints_num = len(self.arm_active_joints) + self.logger.info("AIRBOT Play Moveit instance created") + + def go_home(self): + self.arm.set_start_state_to_current_state() + self.arm.set_goal_state(configuration_name="start") + self.move() + + def move(self, wait=False, sleep_time=0): + self.plan_and_execute( + self.robot, self.arm, self.logger, wait=wait, sleep_time=sleep_time + ) + + def go(self, target, wait=False, sleep_time=0, is_pose=False): + self.arm.set_start_state_to_current_state() + if is_pose: + pose = Pose() + pose.position.x = target[0] + pose.position.y = target[1] + pose.position.z = target[2] + pose.orientation.x = target[3] + pose.orientation.y = target[4] + pose.orientation.z = target[5] + pose.orientation.w = target[6] + pose_goal = PoseStamped() + pose_goal.pose = pose + pose_goal.header.frame_id = self.pose_reference_link + self.arm.set_goal_state( + pose_stamped_msg=pose_goal, pose_link=self.pose_target_link + ) + elif isinstance(target, str): + self.arm.set_goal_state(configuration_name=target) + else: + state = RobotState(self.robot_model) + if not isinstance(target, dict): + target = dict(zip(self.arm_active_joints, target)) + state.joint_positions = target + joint_constraint = construct_joint_constraint( + robot_state=state, + joint_model_group=self.arm_joint_model_group, + ) + self.arm.set_goal_state(motion_plan_constraints=[joint_constraint]) + self.move(wait, sleep_time) + + def get_current_pose(self): + current_pose = None + with self.plan_scene_monitor.read_only() as scene: + scene: PlanningScene + current_state: RobotState = scene.current_state + current_state.update() + pose: Pose = current_state.get_pose(self.pose_target_link) + current_pose = [ + pose.position.x, + pose.position.y, + pose.position.z, + pose.orientation.x, + pose.orientation.y, + pose.orientation.z, + pose.orientation.w, + ] + return current_pose + + def get_current_joint_positions(self): + current_joint_positions = None + with self.plan_scene_monitor.read_only() as scene: + scene: PlanningScene + current_state: RobotState = scene.current_state + current_state.update() + current_joint_positions: dict = current_state.joint_positions + return list(current_joint_positions.values()) + + def get_joint_limits(self): + bounds: List[VariableBounds] = ( + self.arm_joint_model_group.active_joint_model_bounds + ) + joint_bounds = {} + for index, joint_name in enumerate(self.arm_active_joints): + joint_bounds[joint_name] = {} + joint_bounds[joint_name]["position"] = bounds[index].position_bounded + joint_bounds[joint_name]["velocity"] = bounds[index].velocity_bounded + joint_bounds[joint_name]["acceleration"] = bounds[ + index + ].acceleration_bounded + joint_bounds[joint_name]["jerk"] = bounds[index].jerk_bounded + return joint_bounds + + def set_pose_target_link(self, link_name): + assert ( + link_name in self.arm_links + ), f"link_name: {link_name} not in all_arm_links" + self.pose_target_link = link_name + + def wait(self): + self.traj_manager.wait_for_execution() + + def stop(self): + self.traj_manager.stop_execution() + + def get_last_execution_status(self) -> ExecutionStatus: + return self.traj_manager.get_last_execution_status() + + @staticmethod + def plan_and_execute( + robot: MoveItPy, + planning_component: PlanningComponent, + logger: RcutilsLogger, + single_plan_parameters=None, + multi_plan_parameters=None, + wait=False, + sleep_time=0.0, + ): + """Helper function to plan and execute a motion.""" + # plan to goal + if multi_plan_parameters is not None: + plan_result = planning_component.plan( + multi_plan_parameters=multi_plan_parameters + ) + elif single_plan_parameters is not None: + plan_result = planning_component.plan( + single_plan_parameters=single_plan_parameters + ) + else: + plan_result = planning_component.plan() + + # execute the plan + if plan_result.error_code.val == 1: + plan_result: MotionPlanResponse + logger.info("Executing plan") + robot_trajectory = plan_result.trajectory + robot.execute(robot_trajectory, controllers=[]) + if wait: + robot.get_trajactory_execution_manager().wait_for_execution() + else: + logger.error("Planning failed") + + time.sleep(sleep_time) diff --git a/robots/ros_robots/ros2_robot.py b/habitats/ros/ros2/topicor.py similarity index 96% rename from robots/ros_robots/ros2_robot.py rename to habitats/ros/ros2/topicor.py index db444a2..b9dc00b 100644 --- a/robots/ros_robots/ros2_robot.py +++ b/habitats/ros/ros2/topicor.py @@ -6,11 +6,11 @@ from sensor_msgs.msg import JointState from geometry_msgs.msg import Pose, PoseStamped from ros_tools import Lister -from convert_all import flatten_dict -from ros_robot_config import EEF_POSE_POSITION, EEF_POSE_ORIENTATION, ACTIONS_TOPIC_CONFIG, OBSERVATIONS_TOPIC_CONFIG, EXAMPLE_CONFIG +from data_process.convert_all import flatten_dict +from habitats.ros.common.robot_config import EEF_POSE_POSITION, EEF_POSE_ORIENTATION, ACTIONS_TOPIC_CONFIG, OBSERVATIONS_TOPIC_CONFIG, EXAMPLE_CONFIG -class AssembledROS2Robot(object): +class ROS2Interface(object): """Use the keys and values in config as the keys to get all the configuration""" def __init__(self, config: Dict[str, dict] = None) -> None: @@ -184,7 +184,7 @@ def shutdown(self): def main(args=None): rclpy.init(args=args) - ros2_robot = AssembledROS2Robot(EXAMPLE_CONFIG) + ros2_robot = ROS2Interface(EXAMPLE_CONFIG) ros2_robot.wait_for_current_states() current = ros2_robot.get_current_states() ros2_robot.node.get_logger().info(f"Current states: {current}") diff --git a/policies/onnx/onnx_policy.py b/policies/onnx/onnx_policy.py new file mode 100644 index 0000000..69240b9 --- /dev/null +++ b/policies/onnx/onnx_policy.py @@ -0,0 +1,136 @@ +import onnx +import onnxruntime as ort +import logging +import torch + +logging.basicConfig(level=logging.DEBUG) + + +class ONNX(object): + # define the observation space and action space + # these are used to check the input and output data interaction + # between the policy and the environment + observation_space = {"state": 7, "image": [1, 3, 480, 640]} + action_space = 7 + def __init__(self, path) -> None: + self.path = path + # 加载和检查模型 + onnx_model = onnx.load(path) + onnx.checker.check_model(onnx_model) + # 获取输出层,包含层名称、维度信息 + # output = onnx_model.graph.output + # logging.info(output) + + def eval(self): + # 创建推理会话 + self.input_names = [] + self.output_names = [] + self.ort_session = ort.InferenceSession(self.path) + for i in self.ort_session.get_inputs(): + self.input_names.append(i.name) + for i in self.ort_session.get_outputs(): + self.output_names.append(i.name) + + def __call__(self, qpos, image, actions=None, is_pad=None) -> torch.Tensor: + # TODO: change the input data to just one dictionary + + # input_data = { + # self.input_names[0]: qpos, + # self.input_names[1]: image, + # self.input_names[2]: actions, + # } + # Convert PyTorch tensors to NumPy arrays + # qpos = qpos.cpu().numpy() + # qpos = torch.tensor(qpos, device="cuda").cpu().numpy().astype(np.float32) + qpos = qpos.astype("float32") + image = image.cpu().numpy().astype(np.float32) + input_feed = {} + # for name in self.input_names: + # input_feed[name] = input_data + # logging.debug(f"input_names: {self.input_names}") + input_feed = { + self.input_names[0]: qpos, + self.input_names[1]: image, + } + output = self.ort_session.run(self.output_names, input_feed) + output = torch.tensor(output[0], device="cuda") + # logging.debug(f"output shape: {output.shape}") + return output + + +if __name__ == "__main__": + + import numpy as np + + onnx_policy = ONNX( + "/home/ghz/Work/OpenGHz/Imitate-All/onnx_output/act_policy_4d.onnx" + ) + onnx_policy.eval() + + qpos = np.array( + [ + -0.000190738, + -0.766194, + 0.702869, + 1.53601, + -0.964942, + -1.57607, + 1.01381, + ] + ) + img = torch.randn([1, 3, 480, 640], device="cuda") + + qpos_mean = [ + 0.01466561, + -1.1554501, + 1.1064852, + 1.5773835, + -0.97277683, + -1.5242718, + 0.48118713, + ] + qpos_std = [ + 0.14543493, + 0.28088057, + 0.27077574, + 0.190825, + 0.20736837, + 0.22112855, + 0.36785737, + ] + + action_mean = np.array( + [ + 0.01609754, + -1.1573728, + 1.1107943, + 1.5749228, + -0.97393966, + -1.517563, + 0.45444244, + ] + ) + action_std = np.array( + [ + 0.14884493, + 0.28656977, + 0.27433002, + 0.19996108, + 0.22217035, + 0.22898215, + 0.4153349, + ] + ) + + pre_process = lambda s_qpos: (s_qpos - qpos_mean) / qpos_std + post_process = lambda a: a * action_std + action_mean + + logging.debug(f"raw qpos: {qpos}") + qpos = pre_process(qpos) + logging.debug(f"pre qpos: {qpos}") + out_put = onnx_policy(qpos.reshape(1, -1), img) + raw_action = out_put.cpu().numpy()[0][0] + + logging.debug(f"raw action: {raw_action}") + action = post_process(raw_action) # de-standardize action + logging.debug(f"post action: {action}") diff --git a/policy_evaluate.py b/policy_evaluate.py index 1dd5289..a0ffb0b 100644 --- a/policy_evaluate.py +++ b/policy_evaluate.py @@ -7,7 +7,7 @@ from visualize_episodes import save_videos from task_configs.config_tools.basic_configer import basic_parser, get_all_config from policies.common.maker import make_policy -from envs.common_env import get_image, CommonEnv +from envs.common import get_image, CommonEnv logging.basicConfig(level=logging.INFO) diff --git a/robot_utils.py b/robot_utils.py index c0cd5f7..a2c5093 100644 --- a/robot_utils.py +++ b/robot_utils.py @@ -233,30 +233,6 @@ def dt_helper(l): print() -def calibrate_linear_vel(base_action: np.ndarray, c=None): - if c is None: - c = 0.0 - v = base_action[..., 0] - w = base_action[..., 1] - base_action = base_action.copy() - base_action[..., 0] = v - c * w - return base_action - -def smooth_base_action(base_action): - return np.stack( - [ - np.convolve(base_action[:, i], np.ones(5) / 5, mode="same") - for i in range(base_action.shape[1]) - ], - axis=-1, - ).astype(np.float32) - -def postprocess_base_action(base_action): - linear_vel, angular_vel = base_action - angular_vel *= 0.9 - return np.array([linear_vel, angular_vel]) - - if __name__ == "__main__": show_images = False recorder = ImageRecorderVideo(cameras=[0], is_debug=False, show_images=show_images) diff --git a/robots/airbots/airbot_kits/airbot_mmk.py b/robots/airbots/airbot_kits/airbot_mmk.py new file mode 100644 index 0000000..f34a8fa --- /dev/null +++ b/robots/airbots/airbot_kits/airbot_mmk.py @@ -0,0 +1,6 @@ + +class AIRBOTMMK2Config(object): + pass + +class AIRBOTMMK2(object): + pass \ No newline at end of file diff --git a/robots/airbots/airbot_kits/airbot_tok.py b/robots/airbots/airbot_kits/airbot_tok.py new file mode 100644 index 0000000..992e897 --- /dev/null +++ b/robots/airbots/airbot_kits/airbot_tok.py @@ -0,0 +1,5 @@ +class AIRBOTTOK2Config(object): + pass + +class AIRBOTTOK2(object): + pass \ No newline at end of file diff --git a/robots/airbots/airbot_play.py b/robots/airbots/airbot_play.py deleted file mode 100644 index 58ee6ef..0000000 --- a/robots/airbots/airbot_play.py +++ /dev/null @@ -1,76 +0,0 @@ -import airbot -from robots.common_robot import Configer - - -class AIRBOTPlayConfig(object): - def __init__(self) -> None: - # init * 7 - self.model_path = "/usr/share/airbot_models/airbot_play_with_gripper.urdf" - self.gravity_mode = "down" - self.can_bus = "can0" - self.vel = 2.0 - self.eef_mode = "none" - self.bigarm_type = "OD" - self.forearm_type = "DM" - # other - self.joint_vel = 6.0 - - -class AIRBOTPlay(object): - def __init__(self, config: AIRBOTPlayConfig) -> None: - self.config = config - self.robot = airbot.create_agent(*Configer.config2tuple(config)[:7]) - self._arm_joints_num = 6 - self._joints_num = 7 - self.end_effector_open = 1 - self.end_effector_close = 0 - - def _set_eef(self, target, ctrl_type): - if len(target) == 1: - target = target[0] - if ctrl_type == "pos": - self.robot.set_target_end(target) - elif ctrl_type == "vel": - self.robot.set_target_end_v(target) - elif ctrl_type == "eff": - self.robot.set_target_end_t(target) - else: - raise ValueError(f"Invalid type: {ctrl_type}") - - def get_current_joint_positions(self): - joints = self.robot.get_current_joint_q() - if self.config.eef_mode in ["gripper"]: - joints += [self.robot.get_current_end()] - return joints - - def get_current_joint_velocities(self): - joints = self.robot.get_current_joint_v() - if self.config.eef_mode in ["gripper"]: - joints += [self.robot.get_current_end_v()] - return joints - - def get_current_joint_efforts(self): - joints = self.robot.get_current_joint_t() - if self.config.eef_mode in ["gripper"]: - joints += [self.robot.get_current_end_t()] - return joints - - def set_joint_position_target(self, qpos, qvel=None, blocking=False): - if qvel is None: - qvel = self.config.joint_vel - use_planning = blocking - self.robot.set_target_joint_q( - qpos[: self._arm_joints_num], use_planning, qvel[0], blocking - ) - if len(qpos) - self._arm_joints_num > 0: - self._set_eef(qpos[self._arm_joints_num :], "pos") - - def set_joint_velocity_target(self, qvel, blocking=False): - self.robot.set_target_joint_v(qvel[: self._arm_joints_num]) - if len(qvel) - self._arm_joints_num > 0: - self._set_eef(qvel[self._arm_joints_num :], "vel") - - def set_joint_effort_target(self, qeffort, blocking=False): - self.robot.set_target_joint_t(qeffort[: self._arm_joints_num]) - if len(qeffort) - self._arm_joints_num > 0: - self._set_eef(qeffort[self._arm_joints_num :], "eff") diff --git a/robots/airbots/airbot_play/airbot_play.py b/robots/airbots/airbot_play/airbot_play.py new file mode 100644 index 0000000..fa1af3a --- /dev/null +++ b/robots/airbots/airbot_play/airbot_play.py @@ -0,0 +1,157 @@ +from habitats.common.creator import Configer +from habitats.fake.common import FakeCommon +from dataclasses import dataclass +import logging + +logger = logging.getLogger(__name__) + +try: + import airbot +except ImportError: + airbot = None + logger.warning("Failed to import airbot python package, you can only use the fake robot.") + +""" +This file contains the basic class for AIRBOTPlay robot, which is implemented in airbot python package and a fake class for testing. +""" + +@dataclass +class AIRBOTPlayConfig(object): + # init * 7 + model_path = "/usr/share/airbot_models/airbot_play_with_gripper.urdf" + gravity_mode = "down" + can_bus = "can0" + vel = 2.0 + eef_mode = "none" + bigarm_type = "OD" + forearm_type = "DM" + # other + joint_vel = 6.0 + dt = 25 + # common + module_path = logger.name + class_name = "AIRBOTPlayPos" + instance_name = "airbot_player" + default_act = [0.0] * 7 + + +class AIRBOTPlay(object): + def __init__(self, config: AIRBOTPlayConfig) -> None: + self.config = config + self.robot = airbot.create_agent(*Configer.config2tuple(config)[:7]) + self._arm_joints_num = 6 + self._joints_num = 7 + + def _set_eef(self, target, ctrl_type): + if len(target) == 1: + target = target[0] + if ctrl_type == "pos": + self.robot.set_target_end(target) + elif ctrl_type == "vel": + self.robot.set_target_end_v(target) + elif ctrl_type == "eff": + self.robot.set_target_end_t(target) + else: + raise ValueError(f"Invalid type: {ctrl_type}") + + def get_current_joint_positions(self): + joints = self.robot.get_current_joint_q() + if self.config.eef_mode in ["gripper"]: + joints += [self.robot.get_current_end()] + return joints + + def get_current_joint_velocities(self): + joints = self.robot.get_current_joint_v() + if self.config.eef_mode in ["gripper"]: + joints += [self.robot.get_current_end_v()] + return joints + + def get_current_joint_efforts(self): + joints = self.robot.get_current_joint_t() + if self.config.eef_mode in ["gripper"]: + joints += [self.robot.get_current_end_t()] + return joints + + def set_joint_position_target(self, qpos, qvel=None, blocking=False): + if qvel is None: + qvel = self.config.joint_vel + use_planning = blocking + self.robot.set_target_joint_q( + qpos[: self._arm_joints_num], use_planning, qvel[0], blocking + ) + if len(qpos) - self._arm_joints_num > 0: + self._set_eef(qpos[self._arm_joints_num :], "pos") + + def set_joint_velocity_target(self, qvel, blocking=False): + self.robot.set_target_joint_v(qvel[: self._arm_joints_num]) + if len(qvel) - self._arm_joints_num > 0: + self._set_eef(qvel[self._arm_joints_num :], "vel") + + def set_joint_effort_target(self, qeffort, blocking=False): + self.robot.set_target_joint_t(qeffort[: self._arm_joints_num]) + if len(qeffort) - self._arm_joints_num > 0: + self._set_eef(qeffort[self._arm_joints_num :], "eff") + + +class AIRBOTPlayPos(AIRBOTPlay): + def __init__(self, config: AIRBOTPlayConfig) -> None: + super().__init__(config) + # TODO: 应该在创建policy的后处理中进行动作的拆分,在预处理中进行状态的拼接,不需强制要求定义state_dim和action_dim + + def reset(self): + self.set_joint_position_target(self.config.default_act) + + def act(self, action): + self.set_joint_position_target(action) + + @property + def state(self): + return self.get_current_joint_positions() + + +class AIRBOTPlayVel(AIRBOTPlay): + def __init__(self, config: AIRBOTPlayConfig) -> None: + super().__init__(config) + + def reset(self): + self.set_joint_velocity_target(self.config.default_act) + + def act(self, action): + self.set_joint_velocity_target(action) + + @property + def state(self): + return self.get_current_joint_velocities() + + +class AIRBOTPlayMIT(AIRBOTPlay): + def __init__(self, config: AIRBOTPlayConfig) -> None: + super().__init__(config) + + def reset(self): + self.act(self.config.default_act) + + def act(self, action): + self.set_joint_mit_target(action) + + @property + def state(self): + return self.get_current_joint_positions() + self.get_current_joint_velocities() + + +class AIRBOTPlayPosFake(FakeCommon): + """A fake robot for AIRBOTPlayPos.""" + + +def make_robot(robot_config, robot_type): + """A factory function to create a robot instance based on the robot config and type.""" + if robot_type == "pos": + return AIRBOTPlayPos(robot_config) + elif robot_type == "vel": + return AIRBOTPlayVel(robot_config) + elif robot_type == "mit": + return AIRBOTPlayMIT(robot_config) + elif robot_type == "fake": + return AIRBOTPlayPosFake(robot_config) + else: + raise ValueError(f"Invalid robot type: {robot_type}") diff --git a/robots/airbots/airbot_play/airbot_play_moveit1.py b/robots/airbots/airbot_play/airbot_play_moveit1.py new file mode 100644 index 0000000..e69de29 diff --git a/robots/airbots/airbot_play/airbot_play_moveit2.py b/robots/airbots/airbot_play/airbot_play_moveit2.py new file mode 100644 index 0000000..e69de29 diff --git a/robots/airbots/airbot_play/airbot_play_ros1.py b/robots/airbots/airbot_play/airbot_play_ros1.py new file mode 100644 index 0000000..e69de29 diff --git a/robots/airbots/airbot_play/airbot_play_ros2.py b/robots/airbots/airbot_play/airbot_play_ros2.py new file mode 100644 index 0000000..e69de29 diff --git a/robots/common.py b/robots/common.py new file mode 100644 index 0000000..3855c31 --- /dev/null +++ b/robots/common.py @@ -0,0 +1,18 @@ +"""A robot is a physical instance that has its proprioception state and can interact with the environment by subjective initiative actions. The robot's state is the information that can be obtained from the robot's body sensors while the actions are the control commands that can be sent to the robot's actuators. Vision, touch and the other external sensation (obtained by tactile sensors, cameras, lidar, radar, ultrasonic sensors, etc.) are not included in the robot's state, but in the environment. However, in addition to being related to the external environment, external observation also depends on the robot's state and the position and posture of the corresponding sensors. So the robot instance should have the full information and configurations of its external sensors to let the environment obtaining correct observations.""" + +class Robot(object): + """Assume the __init__ method of the robot class is the same as the reset method. + So you can inherit this class to save writing the initialization function.""" + + def __init__(self, config) -> None: + self.reset(config) + + def reset(self, config): + raise NotImplementedError + + def step(self, action): + raise NotImplementedError + + @property + def state(self): + raise NotImplementedError diff --git a/robots/common_robot.py b/robots/common_robot.py deleted file mode 100644 index 285180f..0000000 --- a/robots/common_robot.py +++ /dev/null @@ -1,243 +0,0 @@ -class Configer(object): - @staticmethod - def config2dict(config): - return {key: value for key, value in config.__dict__.items()} - - @staticmethod - def config2tuple(config): - return tuple(config.__dict__.values()) - - -class AssembledRobot(object): - def __init__(self, airbot_player, dt, default_joints): - self.robot = airbot_player - self._arm_joints_num = 6 - self.joints_num = 7 - self.dt = dt - self.default_joints = default_joints - self.default_velocities = [1.0] * self.joints_num - self.end_effector_open = 1 - self.end_effector_close = 0 - - def get_current_joint_positions(self): - return self.robot.get_current_joint_q() + [self.robot.get_current_end()] - - def get_current_joint_velocities(self): - return self.robot.get_current_joint_v() + [self.robot.get_current_end_v()] - - def get_current_joint_efforts(self): - return self.robot.get_current_joint_t() + [self.robot.get_current_end_t()] - - def set_joint_position_target( - self, qpos, qvel=None, blocking=False - ): # TODO: add blocking - if qvel is None: - qvel = self.default_velocities - use_planning = blocking - self.robot.set_target_joint_q( - qpos[: self._arm_joints_num], use_planning, qvel[0], blocking - ) - if len(qpos) == self.joints_num: - # 若不默认归一化,则需要对末端进行归一化操作 - self.robot.set_target_end(qpos[self._arm_joints_num]) - - def set_joint_velocity_target(self, qvel, blocking=False): - self.robot.set_target_joint_v(qvel[: self._arm_joints_num]) - if len(qvel) == self.joints_num: - self.robot.set_target_end_v(qvel[self._arm_joints_num]) - - def set_joint_effort_target(self, qeffort, blocking=False): - self.robot.set_target_joint_t(qeffort[: self._arm_joints_num]) - if len(qeffort) == self.joints_num: - self.robot.set_target_end_t(qeffort[self._arm_joints_num]) - - -class AssembledFakeRobot(object): - real_camera = False - - def __init__(self, dt, default_joints): - self.robot = "fake robot" - self.joints_num = 7 - self.dt = dt - self.default_joints = default_joints - self.end_effector_open = 1 - self.end_effector_close = 0 - assert len(default_joints) == self.joints_num - self._show = False - - def show(self): - self._show = True - - def get_current_joint_positions(self): - return self.default_joints - - def get_current_joint_velocities(self): - return self.default_joints - - def get_current_joint_efforts(self): - return self.default_joints - - def set_joint_position_target( - self, qpos, qvel=None, blocking=False - ): # TODO: add blocking - if self._show: - print(f"Setting joint position target to {qpos}") - - def set_joint_velocity_target(self, qvel, blocking=False): - if self._show: - print(f"Setting joint velocity target to {qvel}") - - def set_joint_effort_target(self, qeffort, blocking=False): - if self._show: - print(f"Setting joint effort target to {qeffort}") - - def set_end_effector_value(self, value): - if self._show: - print(f"Setting end effector value to {value}") - - def get_end_effector_value(self): - return [self.end_effector_open] - - -try: - import rospy - from sensor_msgs.msg import JointState - from std_msgs.msg import Float64MultiArray - import numpy as np - from threading import Thread - - from robot_tools.datar import get_values_by_names -except ImportError as e: - print(f"Error: {e}") - - -class AssembledRosRobot(object): - - def __init__( - self, - states_topic, - arm_action_topic, - gripper_action_topic, - states_num, - default_joints, - dt, - ) -> None: - if rospy.get_name() == "/unnamed": - rospy.init_node("ros_robot_node") - self.dt = dt - self.default_joints = default_joints - self.arm_joint_names = ( - "joint1", - "joint2", - "joint3", - "joint4", - "joint5", - "joint6", - ) - self.gripper_joint_names = ("endleft", "endright") - self.arm_joints_num = len(self.arm_joint_names) - self.all_joints_num = self.arm_joints_num + 1 - self.symmetry = 0.04 - self.end_effector_open = 0 - self.end_effector_close = 0 - - # subscribe to the states topics - assert len(default_joints) == self.all_joints_num - self.action_cmd = { - "arm": default_joints[:-1], - "gripper": self._eef_cmd_convert(default_joints[-1]), - } - self.body_current_data = { - "/observations/qpos": np.random.rand(states_num), - "/observations/qvel": np.random.rand(states_num), - "/observations/effort": np.random.rand(states_num), - "/action": np.random.rand(states_num), - } - self.states_suber = rospy.Subscriber( - states_topic, JointState, self.joint_states_callback - ) - self.arm_cmd_pub = rospy.Publisher( - arm_action_topic, Float64MultiArray, queue_size=10 - ) - self.gripper_cmd_pub = rospy.Publisher( - gripper_action_topic, Float64MultiArray, queue_size=10 - ) - Thread(target=self.publish_action, daemon=True).start() - - def _eef_cmd_convert(self, cmd): - value = cmd * self.symmetry - return [value, -value] - - def joint_states_callback(self, data): - arm_joints_pos = get_values_by_names( - self.arm_joint_names, data.name, data.position - ) - gripper_joints_pos = get_values_by_names( - self.gripper_joint_names, data.name, data.position - ) - gripper_joints_pos = [gripper_joints_pos[0] / self.symmetry] - self.body_current_data["/observations/qpos"] = list(arm_joints_pos) + list( - gripper_joints_pos - ) - arm_joints_vel = get_values_by_names( - self.arm_joint_names, data.name, data.velocity - ) - gripper_joints_vel = get_values_by_names( - self.gripper_joint_names, data.name, data.velocity - ) - gripper_joints_vel = [gripper_joints_vel[0]] - self.body_current_data["/observations/qvel"] = list(arm_joints_vel) + list( - gripper_joints_vel - ) - arm_joints_effort = get_values_by_names( - self.arm_joint_names, data.name, data.effort - ) - gripper_joints_effort = get_values_by_names( - self.gripper_joint_names, data.name, data.effort - ) - gripper_joints_effort = [gripper_joints_effort[0]] - self.body_current_data["/observations/effort"] = list(arm_joints_effort) + list( - gripper_joints_effort - ) - - def publish_action(self): - rate = rospy.Rate(200) - while not rospy.is_shutdown(): - self.arm_cmd_pub.publish(Float64MultiArray(data=self.action_cmd["arm"])) - self.gripper_cmd_pub.publish( - Float64MultiArray(data=self.action_cmd["gripper"]) - ) - rate.sleep() - - def get_current_joint_positions(self): - return self.body_current_data["/observations/qpos"] - - def get_current_joint_velocities(self): - return self.body_current_data["/observations/qvel"] - - def get_current_joint_efforts(self): - return self.body_current_data["/observations/effort"] - - def set_joint_position_target( - self, qpos, qvel=None, blocking=False - ): # TODO: add blocking - self.action_cmd["arm"] = qpos[: self.arm_joints_num] - if len(qpos) == self.all_joints_num: - self.action_cmd["gripper"] = self._eef_cmd_convert( - qpos[self.arm_joints_num] - ) - - def set_target_joint_q(self, qpos, qvel=None, blocking=False): - self.set_joint_position_target(qpos, qvel, blocking) - - def set_target_end(self, cmd): - self.action_cmd["gripper"] = self._eef_cmd_convert(cmd) - - def set_end_effector_value(self, value): - self.set_target_end(value) - - def set_joint_velocity_target(self, qvel, blocking=False): - print("Not implemented yet") - - def set_joint_effort_target(self, qeffort, blocking=False): - print("Not implemented yet") diff --git a/test_fake.ipynb b/test_fake.ipynb new file mode 100644 index 0000000..69471b5 --- /dev/null +++ b/test_fake.ipynb @@ -0,0 +1,170 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "DEBUG:habitats.fake.common:Reset FakeRGBCam: [[[0 0 0]\n", + " [0 0 0]\n", + " [0 0 0]\n", + " ...\n", + " [0 0 0]\n", + " [0 0 0]\n", + " [0 0 0]]\n", + "\n", + " [[0 0 0]\n", + " [0 0 0]\n", + " [0 0 0]\n", + " ...\n", + " [0 0 0]\n", + " [0 0 0]\n", + " [0 0 0]]\n", + "\n", + " [[0 0 0]\n", + " [0 0 0]\n", + " [0 0 0]\n", + " ...\n", + " [0 0 0]\n", + " [0 0 0]\n", + " [0 0 0]]\n", + "\n", + " ...\n", + "\n", + " [[0 0 0]\n", + " [0 0 0]\n", + " [0 0 0]\n", + " ...\n", + " [0 0 0]\n", + " [0 0 0]\n", + " [0 0 0]]\n", + "\n", + " [[0 0 0]\n", + " [0 0 0]\n", + " [0 0 0]\n", + " ...\n", + " [0 0 0]\n", + " [0 0 0]\n", + " [0 0 0]]\n", + "\n", + " [[0 0 0]\n", + " [0 0 0]\n", + " [0 0 0]\n", + " ...\n", + " [0 0 0]\n", + " [0 0 0]\n", + " [0 0 0]]]\n", + "DEBUG:habitats.fake.common:Get FakeRGBCam state: [[[0 0 0]\n", + " [0 0 0]\n", + " [0 0 0]\n", + " ...\n", + " [0 0 0]\n", + " [0 0 0]\n", + " [0 0 0]]\n", + "\n", + " [[0 0 0]\n", + " [0 0 0]\n", + " [0 0 0]\n", + " ...\n", + " [0 0 0]\n", + " [0 0 0]\n", + " [0 0 0]]\n", + "\n", + " [[0 0 0]\n", + " [0 0 0]\n", + " [0 0 0]\n", + " ...\n", + " [0 0 0]\n", + " [0 0 0]\n", + " [0 0 0]]\n", + "\n", + " ...\n", + "\n", + " [[0 0 0]\n", + " [0 0 0]\n", + " [0 0 0]\n", + " ...\n", + " [0 0 0]\n", + " [0 0 0]\n", + " [0 0 0]]\n", + "\n", + " [[0 0 0]\n", + " [0 0 0]\n", + " [0 0 0]\n", + " ...\n", + " [0 0 0]\n", + " [0 0 0]\n", + " [0 0 0]]\n", + "\n", + " [[0 0 0]\n", + " [0 0 0]\n", + " [0 0 0]\n", + " ...\n", + " [0 0 0]\n", + " [0 0 0]\n", + " [0 0 0]]]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "habitats.fake.common\n", + "(480, 640, 3)\n" + ] + } + ], + "source": [ + "from habitats.fake.cameras.rgb_cam import FakeRGBCam, FakeRGBCamConfig\n", + "import logging\n", + "\n", + "logging.basicConfig(level=logging.DEBUG)\n", + "config = FakeRGBCamConfig()\n", + "fake_rgb_cam = FakeRGBCam(config)\n", + "print(fake_rgb_cam.state.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from robots.fake_robots.arm import FakeArmPos, FakeArmConfig\n", + "import logging\n", + "\n", + "logging.basicConfig(level=logging.DEBUG)\n", + "\n", + "fake_arm = FakeArmPos(FakeArmConfig())\n", + "fake_arm.step((1, 2, 3, 4, 5, 6))\n", + "fake_arm.state\n", + "fake_arm.reset()\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "aloha", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/utils.py b/utils.py index ecf4b43..e5e84b1 100644 --- a/utils.py +++ b/utils.py @@ -487,3 +487,7 @@ def check_all_gpus_idle(utilization_threshold=10) -> Tuple[List[int], int]: print(f"Error occurred: {e}") return [] +def merge_custom(): + #TODO: merge custom folder + # subprocess.run(["cp", "-r", "custom/robots"], "robots") + pass \ No newline at end of file