Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions envs/common_env.py → envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,29 @@
from typing import List
from einops import rearrange
import dm_env
from robots.common_robot import AssembledRobot
from robots.common import Robot
import logging

logger = logging.getLogger(__name__)


class CommonEnvConfig(object):
def __init__(self) -> None:
self.robots = []
def __init__(self, robots:List[Robot]) -> None:
"""For most real robot environments, the config only needs to specify the robot instances,
which have all sensors' configurations."""
self.robots = robots

def __post_init__(self):
assert self.robots, "There should be at least one robot in the environment."


class CommonEnv:
"""
An environment is a combination of robots, scenes and objects. It should be able to reset and step.
The environment will return observations based on the state of the robot, the position of the sensors, and the current scene and object conditions. And for RL and data collection, it should also send rewards and done signals.
The environment will return observations based on the state of the robots, the position of the sensors, and the current scene and object conditions. And for RL and data collection, it should also send rewards and done signals to Gaia.
"""
def __init__(self, *args, **kwargs) -> None:

def __init__(self, config: CommonEnvConfig) -> None:
raise NotImplementedError

def reset(self) -> dm_env.TimeStep:
Expand All @@ -27,9 +36,6 @@ def reset(self) -> dm_env.TimeStep:
def step(self, action) -> dm_env.TimeStep:
raise NotImplementedError

def get_reward(self):
raise NotImplementedError


def get_image(ts: dm_env.TimeStep, camera_names, mode=0):
# TODO: remove this function
Expand Down
49 changes: 23 additions & 26 deletions envs/make_env.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import List
from robots.common_robot import AssembledRobot
from robots.common import Robot


def make_environment(env_config):
Expand All @@ -22,47 +22,44 @@ def make_environment(env_config):
), "The length of start_joint should be equal to joint_num or joint_num*robot_num"
print(f"Start joint: {start_joint}")

robot_instances: List[AssembledRobot] = []
robot_instances: List[Robot] = []
if "airbot_play" in robot_name:
# set up can
# from utils import CAN_Tools
import airbot

from robots.airbots.airbot_play.airbot_play import AIRBOTPlayPos, AIRBOTPlayConfig
vel = 2.0
for i in range(robot_num):
# if 'v' not in can:
# if not CAN_Tools.check_can_status(can):
# success, error = CAN_Tools.activate_can_interface(can, 1000000)
# if not success: raise Exception(error)
airbot_player = airbot.create_agent(
"/usr/share/airbot_models/airbot_play_with_gripper.urdf",
"down",
can_buses[i],
vel,
eef_mode[i],
bigarm_type[i],
forearm_type[i],
airbot_player = AIRBOTPlayPos(
AIRBOTPlayConfig(
model_path="/usr/share/airbot_models/airbot_play_with_gripper.urdf",
gravity_mode="down",
can_bus=can_buses[i],
vel=vel,
eef_mode=eef_mode,
bigarm_type=bigarm_type,
forearm_type=forearm_type,
joint_vel=6.0,
dt=25,
)
)
robot_instances.append(
AssembledRobot(
Robot(
airbot_player,
1 / fps,
start_joint[joint_num * i : joint_num * (i + 1)],
)
)
elif "fake" in robot_name or "none" in robot_name:
from robots.common_robot import AssembledFakeRobot
from robots.airbots.airbot_play.airbot_play import AIRBOTPlayPosFake

if check_images:
AssembledFakeRobot.real_camera = True
AIRBOTPlayPosFake.real_camera = True
for i in range(robot_num):
robot_instances.append(
AssembledFakeRobot(
AIRBOTPlayPosFake(
1 / fps, start_joint[joint_num * i : joint_num * (i + 1)]
)
)
elif "ros" in robot_name:
from robots.common_robot import AssembledRosRobot
from robots.airbots.airbot_play.airbot_play_ros1 import AIRBOTPlayPos
import rospy

rospy.init_node("replay_episodes")
Expand All @@ -72,7 +69,7 @@ def make_environment(env_config):
gripper_action_topic = f"{namespace}/gripper_group_position_controller/command"
for i in range(robot_num):
robot_instances.append(
AssembledRosRobot(
AIRBOTPlayPos(
states_topic,
arm_action_topic,
gripper_action_topic,
Expand All @@ -82,10 +79,10 @@ def make_environment(env_config):
)
)
elif "mmk" in robot_name:
from robots.common_robot import AssembledMmkRobot
from robots.airbots.airbot_kits.airbot_mmk import AIRBOTMMK2, AIRBOTMMK2Config

for i in range(robot_num):
robot_instances.append(AssembledMmkRobot())
robot_instances.append(AIRBOTMMK2())
elif robot_name == "none":
print("No direct robot is used")
else:
Expand Down
Empty file added habitats/README.md
Empty file.
41 changes: 41 additions & 0 deletions habitats/common/creator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from dataclasses import dataclass
from typing import List
from importlib import import_module


@dataclass
class CommonConfig(object):
module_path = None
class_name = None
instance_name = None
default_act = None

def __post_init__(self):
assert self.default_act is not None, "default_act must be set"
assert self.instance_name is not None, "instance_name must be set"
assert self.module_path is not None, "class_path must be set"
assert self.class_name is not None, "class_name must be set"


class Configer(object):
@staticmethod
def config2dict(config):
return {key: value for key, value in config.__dict__.items()}

@staticmethod
def config2tuple(config):
return tuple(config.__dict__.values())


class Creator(object):
@staticmethod
def instancer(configs: List[CommonConfig]) -> list:
instances = []
for config in configs:
module = import_module(config.module_path)
instances.append(getattr(module, config.class_name)(config))
return instances

@classmethod
def give_eyes(cls, robot, configs):
setattr(robot, "eyes", cls.instancer(configs))
6 changes: 6 additions & 0 deletions habitats/common/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
def get_values_by_names(sub_names: tuple, all_names: tuple, all_values: tuple) -> tuple:
"""根据子名称列表获取所有值列表中对应的值列表,返回子值列表"""
sub_values = [0.0 for _ in range(len(sub_names))]
for i, name in enumerate(sub_names):
sub_values[i] = all_values[all_names.index(name)]
return tuple(sub_values)
25 changes: 25 additions & 0 deletions habitats/fake/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from typing import Optional
import logging
from habitats.common.creator import CommonConfig

logger = logging.getLogger(__name__)


class FakeCommon(object):
def __init__(self, config: CommonConfig) -> None:
self.reset(config)

def reset(self, config: Optional[CommonConfig] = None):
if config is not None:
self.config = config
self._state = self.config.default_act
logger.debug(f"Reset {self.config.instance_name}: {self._state}")

def step(self, action):
self._state = action
logger.debug(f"Step {self.config.instance_name}: {self._state}")

@property
def state(self):
logger.debug(f"Get {self.config.instance_name} state: {self._state}")
return self._state
Empty file added habitats/ros/ros1/moveit1.py
Empty file.
13 changes: 8 additions & 5 deletions robots/ros_robots/ros1_robot.py → habitats/ros/ros1/topicor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
from sensor_msgs.msg import JointState
from geometry_msgs.msg import Pose, PoseStamped
from ros_tools import Lister
from convert_all import flatten_dict
from ros_robot_config import EEF_POSE_POSITION, EEF_POSE_ORIENTATION, ACTIONS_TOPIC_CONFIG, OBSERVATIONS_TOPIC_CONFIG, EXAMPLE_CONFIG
from data_process.convert_all import flatten_dict
# TODO: add these to the config param, or configure the path
from habitats.ros.common.robot_config import EEF_POSE_POSITION, EEF_POSE_ORIENTATION, ACTIONS_TOPIC_CONFIG, OBSERVATIONS_TOPIC_CONFIG, EXAMPLE_CONFIG


class AssembledROS1Robot(object):
class ROS1Interface(object):
"""Use the keys and values in config as the keys to get all the configuration"""

# TODO: make this a common class for all ROS interfaces include both low dim and high dim data, such as images, point clouds, etc.
def __init__(self, config: Dict[str, dict] = None) -> None:
self.params = config["param"]
self.actions_dim = flatten_dict(self.params["actions_dim"])
Expand Down Expand Up @@ -86,6 +87,7 @@ def __init__(self, config: Dict[str, dict] = None) -> None:

def _target_cmd_pub_thread(self):
"""Publish thread for all publishers"""
# TODO: change this to a timer
rate = rospy.Rate(self.params["control_freq"])
while not rospy.is_shutdown():
for key, pub in self.action_pubs.items():
Expand All @@ -99,6 +101,7 @@ def _target_cmd_pub_thread(self):

def _current_state_callback(self, data, args):
"""Callback function used for all subcribers"""
# TODO: for joint states, support the case where the joint names are not in a specific order
key, lister, preprocess = args
self.current_data[key] = preprocess(lister(data))
rospy.logdebug(f"Current data: {self.current_data[key]}")
Expand Down Expand Up @@ -183,7 +186,7 @@ def reset(self) -> list:
if __name__ == "__main__":

rospy.init_node("test_mmk")
ros1_robot = AssembledROS1Robot(EXAMPLE_CONFIG)
ros1_robot = ROS1Interface(EXAMPLE_CONFIG)
ros1_robot.wait_for_current_states()
current = ros1_robot.get_current_states()
print("Current states:", current)
Expand Down
Loading