diff --git a/examples/box_picking_drq/drq_policy.py b/examples/box_picking_drq/drq_policy.py new file mode 100644 index 00000000..7b9a7219 --- /dev/null +++ b/examples/box_picking_drq/drq_policy.py @@ -0,0 +1,660 @@ +#!/usr/bin/env python3 +import copy +import time +from functools import partial +import jax +import jax.numpy as jnp +import numpy as np +import pynput +import threading +import tqdm +from absl import app, flags +from flax.training import checkpoints +from datetime import datetime + +import gym +from gym.wrappers.record_episode_statistics import RecordEpisodeStatistics + +from serl_launcher.agents.continuous.drq import DrQAgent +from serl_launcher.common.evaluation import evaluate +from serl_launcher.utils.timer_utils import Timer +from serl_launcher.wrappers.chunking import ChunkingWrapper +from serl_launcher.utils.sampling_utils import TemporalActionEnsemble +from serl_launcher.utils.train_utils import ( + print_agent_params, + parameter_overview, +) + +from agentlace.trainer import TrainerServer, TrainerClient +from agentlace.data.data_store import QueuedDataStore + +from serl_launcher.utils.launcher import ( + make_voxel_drq_agent, + make_trainer_config, + make_wandb_logger, +) +from serl_launcher.data.data_store import MemoryEfficientReplayBufferDataStore +from serl_launcher.wrappers.serl_obs_wrappers import ( + SERLObsWrapper, + ScaleObservationWrapper, +) +from serl_launcher.wrappers.observation_statistics_wrapper import ( + ObservationStatisticsWrapper, +) +from ur_env.envs.relative_env import RelativeFrame +from ur_env.envs.wrappers import ( + SpacemouseIntervention, + Quat2MrpWrapper, + ObservationRotationWrapper, +) + +import ur_env + +# used to debug nan errors (also in jit-ed functions) +# jax.config.update("jax_debug_nans", True) + +devices = jax.local_devices() +num_devices = len(devices) +sharding = jax.sharding.PositionalSharding(devices) + +FLAGS = flags.FLAGS + +flags.DEFINE_string("env", "box_picking_camera_env", "Name of 
environment.") +flags.DEFINE_string("agent", "drq", "Name of agent.") +flags.DEFINE_string( + "exp_name", "box picking drq", "Name of the experiment for wandb logging." +) +flags.DEFINE_integer("max_traj_length", 100, "Maximum length of trajectory.") +flags.DEFINE_string( + "camera_mode", "rgb", "Camera mode, one of (rgb, depth, both, pointcloud)" +) + +flags.DEFINE_integer("seed", 1, "Random seed.") +flags.DEFINE_bool("save_model", False, "Whether to save model.") +flags.DEFINE_integer("batch_size", 256, "Batch size.") +flags.DEFINE_integer("utd_ratio", 4, "UTD ratio.") + +flags.DEFINE_string( + "state_mask", + "all", + "if all the states should be considered, see serl_launcher/common/encoding for more info", +) +flags.DEFINE_string("encoder_type", "voxnet-pretrained", "Encoder type.") +flags.DEFINE_integer( + "encoder_bottleneck_dim", 128, "bottleneck dimension of the encoder" +) +flags.DEFINE_multi_string( + "encoder_kwargs", None, "Encoder kwargs in the form ['dict key', 'dict value']" +) +flags.DEFINE_bool( + "enable_obs_rotation_wrapper", + False, + "Whether to enable observation rotation wrapper (train in one quaternion)", +) +flags.DEFINE_bool( + "enable_temporal_ensemble_sampling", + False, + "Whether to enable sampling the action from a temporal ensemble: action = 0.5*a0 + 0.3*a-1 + 0.2*a-2 + 0.1*a-3", +) + +flags.DEFINE_integer("max_steps", 1000000, "Maximum number of training steps.") +flags.DEFINE_integer( + "replay_buffer_capacity", 10000, "Replay buffer capacity." 
+) # quite low to forget demo trajectories + +flags.DEFINE_integer("random_steps", 0, "Sample random actions for this many steps.") +flags.DEFINE_integer("training_starts", 0, "Training starts after this step.") +flags.DEFINE_integer("steps_per_update", 10, "Number of steps per update the server.") + +flags.DEFINE_integer("log_period", 10, "Logging period.") +flags.DEFINE_integer("eval_period", 1000, "Evaluation period in seconds") +flags.DEFINE_integer("eval_n_trajs", 10, "Number of trajectories for evaluation.") + +# flag to indicate if this is a leaner or a actor +flags.DEFINE_boolean("learner", False, "Is this a learner or a trainer.") +flags.DEFINE_boolean("actor", False, "Is this a learner or a trainer.") +flags.DEFINE_string("ip", "localhost", "IP address of the learner.") +flags.DEFINE_string("demo_path", None, "Path to the demo data.") +flags.DEFINE_integer("checkpoint_period", 0, "Period to save checkpoints.") +flags.DEFINE_string( + "checkpoint_path", + "/home/nico/real-world-rl/serl/examples/box_picking_drq/checkpoints", + "Path to save checkpoints.", +) + +flags.DEFINE_integer( + "eval_checkpoint_step", 0, "evaluate the policy from ckpt at this step" +) +flags.DEFINE_string( + "log_rlds_path", + "/home/nico/real-world-rl/serl/examples/box_picking_drq/rlds", + "Path to save RLDS logs.", +) +flags.DEFINE_string("preload_rlds_path", None, "Path to preload RLDS data.") + +flags.DEFINE_boolean( + "debug", False, "Debug mode." +) # debug mode will disable wandb logging + + +def print_green(x): + return print("\033[92m {}\033[00m".format(x)) + + +PAUSE_EVENT_FLAG = threading.Event() +PAUSE_EVENT_FLAG.clear() # clear() to continue the actor/learner loop, set() to pause + + +def pause_callback(key): + """Callback for when a key is pressed""" + global PAUSE_EVENT_FLAG + try: + # chosen a rarely used key to avoid conflicts. 
this listener is always on, even when the program is not in focus + if not PAUSE_EVENT_FLAG.is_set() and key == pynput.keyboard.Key.pause: + print("Requested pause training") + # set the PAUSE FLAG to pause the actor/learner loop + PAUSE_EVENT_FLAG.set() + except AttributeError: + # print(f'{key} pressed') + pass + + +listener = pynput.keyboard.Listener( + on_press=pause_callback +) # to enable keyboard based pause +listener.start() + + +############################################################################## + + +def actor(agent: DrQAgent, data_store, env, sampling_rng): + """ + This is the actor loop, which runs when "--actor" is set to True. + """ + global PAUSE_EVENT_FLAG + + if FLAGS.eval_checkpoint_step: + wandb_logger = make_wandb_logger( + project="paper_evaluation_unseen" + if "eval" in FLAGS.env + else "paper_evaluation", + description=FLAGS.exp_name or FLAGS.env, + debug=FLAGS.debug, + ) + success_counter = 0 + + ckpt = checkpoints.restore_checkpoint( + FLAGS.checkpoint_path, + agent.state, + step=FLAGS.eval_checkpoint_step, + ) + agent = agent.replace(state=ckpt) + action_ensemble = TemporalActionEnsemble( + activated=FLAGS.enable_temporal_ensemble_sampling + ) + + trajectories = [] + traj_infos = [] + for episode in range(FLAGS.eval_n_trajs): + trajectory = [] + obs, _ = env.reset() + done = False + action_ensemble.reset() + start_time = time.time() + + while not done: + actions = agent.sample_actions( + observations=jax.device_put(obs), + argmax=True, + ) + actions = np.asarray(jax.device_get(actions)) + + ensembled_action = action_ensemble.sample( + actions + ) # will return actions if not activated + next_obs, reward, done, truncated, info = env.step(ensembled_action) + transition = dict( + observations=obs[ + "state" + ].copy(), # do not save voxel grid or images + actions=ensembled_action, + next_observations=next_obs["state"].copy(), + rewards=reward, + masks=1.0 - done, + dones=done, + ) + trajectory.append(transition) + obs = next_obs + + 
if done or truncated: + success_counter += reward > 50.0 + dt = time.time() - start_time + running_reward = np.sum( + np.asarray([t["rewards"] for t in trajectory]) + ) + running_reward = max(running_reward, -100.0) # -100 min value + + print(f"{success_counter}/{episode + 1} ", end=" ") + print(f"time: {dt:.3f}s running_rew: {running_reward:.2f}") + + trajectories.append( + {"traj": trajectory, "time": dt, "success": (reward > 50.0)} + ) + infos = { + "running_reward": running_reward, + "time": dt, + "success_rate": float(reward > 50.0), + "action_cost": np.linalg.norm( + np.asarray([t["actions"] for t in trajectory]), + axis=1, + ord=2, + ).mean(), + } + traj_infos.append(infos) + wandb_logger.log(infos, step=episode) + + # if pause event is requested, pause the actor + if PAUSE_EVENT_FLAG.is_set(): + print("Actor eval loop interrupted") + response = input("Do you want to continue (c), or exit (e)? ") + if response == "c": + # update PAUSE FLAG to continue training + PAUSE_EVENT_FLAG.clear() + print("Continuing") + else: + print("Stopping actor eval") + break + + traj_infos = { + k: [d[k] for d in traj_infos] for k in traj_infos[0] + } # list of dicts to dict of lists + mean_infos = {"mean_" + key: np.mean(val) for key, val in traj_infos.items()} + wandb_logger.log(mean_infos) + for key, value in mean_infos.items(): + print(f"{key}: {value:.3f}") + + filename = f"trajectories {'temp_ens' if action_ensemble.is_activated() else ''} {datetime.now().strftime('%m-%d %H%M')}.pkl" + with open(filename, "wb") as f: + import pickle + + pickle.dump(trajectories, f) + return # after done eval, return and exit + + client = TrainerClient( + "actor_env", + FLAGS.ip, + make_trainer_config(), + data_store, + wait_for_server=True, + ) + + # Function to update the agent with new params + def update_params(params): + nonlocal agent + agent = agent.replace(state=agent.state.replace(params=params)) + + client.recv_network_callback(update_params) + + obs, _ = env.reset() + + # 
training loop + timer = Timer() + running_return = 0.0 + + for step in tqdm.tqdm(range(FLAGS.max_steps), dynamic_ncols=True): + timer.tick("total") + + with timer.context("sample_actions"): + if step < FLAGS.random_steps: + actions = env.action_space.sample() + else: + sampling_rng, key = jax.random.split(sampling_rng) + actions = agent.sample_actions( + observations=jax.device_put(obs), + seed=key, + deterministic=False, + ) + actions = np.asarray(jax.device_get(actions)) + + # Step environment + with timer.context("step_env"): + next_obs, reward, done, truncated, info = env.step(actions) + + # override the action with the intervention action + if "intervene_action" in info: + actions = info.pop("intervene_action") + + reward = np.asarray(reward, dtype=np.float32) + info = np.asarray(info) + running_return = running_return * 0.99 + reward + transition = dict( + observations=obs, + actions=actions, + next_observations=next_obs, + rewards=reward, + masks=1.0 - done, + dones=done, + ) + data_store.insert(transition) + + obs = next_obs + if done or truncated: + stats = {"train": info} # send stats to the learner to log + client.request("send-stats", stats) + print(f"running return: {running_return}") + running_return = 0.0 + obs, _ = env.reset() + + if step % FLAGS.steps_per_update == 0: + client.update() + + timer.tock("total") + + if FLAGS.eval_period and step % FLAGS.eval_period == 0 and step: + with timer.context("eval"): + evaluate_info = evaluate( + policy_fn=partial(agent.sample_actions, argmax=True), + env=env, + num_episodes=FLAGS.eval_n_trajs, + ) + stats = {"eval": evaluate_info} + client.request("send-stats", stats) + + if step % FLAGS.log_period == 0: + stats = {"timer": timer.get_average_times()} + client.request("send-stats", stats) + + if PAUSE_EVENT_FLAG.is_set(): + print_green("Actor loop interrupted") + response = input( + "Do you want to continue (c), save replay buffer and exit (s) or simply exit (e)? 
" + ) + if response == "c": + print("Continuing") + PAUSE_EVENT_FLAG.clear() + else: + if response == "s": + print("Saving replay buffer") + data_store.save( + "replay_buffer_actor.npz" + ) # not yet supported for QueuedDataStore + else: + print("Replay buffer not saved") + print("Stopping actor client") + client.stop() + break + + +############################################################################## + + +def learner(rng, agent: DrQAgent, replay_buffer, wandb_logger=None): + """ + The learner loop, which runs when "--learner" is set to True. + """ + # To track the step in the training loop + update_steps = 0 + global PAUSE_EVENT_FLAG + + def stats_callback(type: str, payload: dict) -> dict: + """Callback for when server receives stats request.""" + assert type == "send-stats", f"Invalid request type: {type}" + if wandb_logger is not None: + wandb_logger.log(payload, step=update_steps) + return {} # not expecting a response + + # Create server + server = TrainerServer(make_trainer_config(), request_callback=stats_callback) + server.register_data_store("actor_env", replay_buffer) + server.start(threaded=True) + + # Loop to wait until replay_buffer is filled + pbar = tqdm.tqdm( + total=FLAGS.training_starts, + initial=len(replay_buffer), + desc="Filling up replay buffer", + position=0, + leave=True, + ) + while len(replay_buffer) < FLAGS.training_starts: + pbar.update(len(replay_buffer) - pbar.n) # Update progress bar + time.sleep(1) + pbar.update(len(replay_buffer) - pbar.n) # Update progress bar + pbar.close() + + # send the initial network to the actor + server.publish_network(agent.state.params) + print_green("sent initial network to actor") + + replay_iterator = replay_buffer.get_iterator( + sample_args={ + "batch_size": FLAGS.batch_size, + "pack_obs_and_next_obs": True, + }, + device=sharding.replicate(), + ) + + # wait till the replay buffer is filled with enough data + timer = Timer() + for step in tqdm.tqdm(range(FLAGS.max_steps), 
dynamic_ncols=True, desc="learner"): + timer.tick("learner_total") + + # run n-1 critic updates and 1 critic + actor update. + # This makes training on GPU faster by reducing the large batch transfer time from CPU to GPU + for critic_step in range(FLAGS.utd_ratio - 1): + with timer.context("sample_replay_buffer"): + batch = next(replay_iterator) + + with timer.context("train_critics"): + agent, critics_info = agent.update_critics( + batch, + ) + + with timer.context("train"): + batch = next(replay_iterator) + agent, update_info = agent.update_high_utd(batch, utd_ratio=1) + + timer.tock("learner_total") + + # publish the updated network + if step > 0 and step % (FLAGS.steps_per_update) == 0: + agent = jax.block_until_ready(agent) + server.publish_network(agent.state.params) + + if update_steps % FLAGS.log_period == 0 and wandb_logger: + wandb_logger.log(update_info, step=update_steps) + wandb_logger.log({"timer": timer.get_average_times()}, step=update_steps) + wandb_logger.log({"replay_buffer_size": len(replay_buffer)}) + + update_steps += 1 + + if FLAGS.checkpoint_period and update_steps % FLAGS.checkpoint_period == 0: + assert FLAGS.checkpoint_path is not None + checkpoints.save_checkpoint( + FLAGS.checkpoint_path, agent.state, step=update_steps, keep=100 + ) + + if PAUSE_EVENT_FLAG.is_set(): + print("Learner loop interrupted") + response = input( + "Do you want to continue (c), save training state and exit (s) or simply exit (e)? 
" + ) + if "c" in response: + print("Continuing") + PAUSE_EVENT_FLAG.clear() + else: + if response == "s": + print("Saving learner state") + agent_ckpt = checkpoints.save_checkpoint( + FLAGS.checkpoint_path, agent.state, step=update_steps, keep=100 + ) + replay_buffer.save( + "replay_buffer_learner.npz" + ) # not yet supported for QueuedDataStore + # TODO: save other parts of training state + else: + print("Training state not saved") + print("Stopping learner client") + break + + server.stop() + parameter_overview(agent) # print end state + + +############################################################################## + + +def main(_): + assert FLAGS.batch_size % num_devices == 0 + if FLAGS.checkpoint_path.split("/")[-1] == "checkpoints": + FLAGS.checkpoint_path = ( + FLAGS.checkpoint_path + + " " + + FLAGS.exp_name + + " " + + datetime.now().strftime("%m%d-%H:%M") + ) + + # seed + rng = jax.random.PRNGKey(FLAGS.seed) + + # create env and load dataset + env = gym.make( + FLAGS.env, + camera_mode=FLAGS.camera_mode, + fake_env=FLAGS.learner, + max_episode_length=FLAGS.max_traj_length, + ) + # if FLAGS.actor: + # env = SpacemouseIntervention(env) + env = RelativeFrame(env) + env = Quat2MrpWrapper(env) + env = ScaleObservationWrapper( + env + ) # scale obs space (after quat2mrp, but before serlobs) + env = ObservationStatisticsWrapper(env) + if FLAGS.enable_obs_rotation_wrapper: + env = ObservationRotationWrapper(env) + env = SERLObsWrapper(env) + env = ChunkingWrapper(env, obs_horizon=1, act_exec_horizon=None) + env = RecordEpisodeStatistics(env) + + image_keys = [key for key in env.observation_space.keys() if key != "state"] + print(f"image keys: {image_keys}") + + rng, sampling_rng = jax.random.split(rng) + + # assert FLAGS.encoder_kwargs is None or len(FLAGS.encoder_kwargs) % 2 == 0 + encoder_kwargs = { + "bottleneck_dim": FLAGS.encoder_bottleneck_dim, + **( + dict(zip(*[iter(FLAGS.encoder_kwargs)] * 2)) if FLAGS.encoder_kwargs else {} + ), + } + encoder_kwargs 
= { + k: (int(v) if str(v).isdigit() else v) for k, v in encoder_kwargs.items() + } + + agent: DrQAgent = make_voxel_drq_agent( + seed=FLAGS.seed, + sample_obs=env.observation_space.sample(), + sample_action=env.action_space.sample(), + image_keys=image_keys, + encoder_type=FLAGS.encoder_type, + state_mask=FLAGS.state_mask, + encoder_kwargs=encoder_kwargs, + ) + + # replicate agent across devices + # need the jnp.array to avoid a bug where device_put doesn't recognize primitives + agent: DrQAgent = jax.device_put( + jax.tree_map(jnp.array, agent), sharding.replicate() + ) + + # print useful info + print_agent_params(agent, image_keys) + parameter_overview(agent) + + if FLAGS.enable_obs_rotation_augmentation: + print("Batch Observation Rotation enabled!") + assert ( + not FLAGS.enable_obs_rotation_augmentation + or not FLAGS.enable_obs_rotation_wrapper + ) # both is pointless + + def create_replay_buffer_and_wandb_logger(): + replay_buffer = MemoryEfficientReplayBufferDataStore( + env.observation_space, + env.action_space, + capacity=FLAGS.replay_buffer_capacity, + image_keys=image_keys, + ) + # set up wandb and logging + wandb_logger = make_wandb_logger( + project="paper_experiments", + description=FLAGS.exp_name or FLAGS.env, + debug=FLAGS.debug, + ) + return replay_buffer, wandb_logger + + if FLAGS.learner: + sampling_rng = jax.device_put(sampling_rng, device=sharding.replicate()) + replay_buffer, wandb_logger = create_replay_buffer_and_wandb_logger() + + import pickle as pkl + + with open(FLAGS.demo_path, "rb") as f: + trajs = pkl.load(f) + + # check which observations can be ignored for this run + to_pop = [] + for obs_name in [i for i in trajs[0]["observations"].keys()]: + if obs_name not in env.observation_space.spaces: + to_pop.append(obs_name) + print(f"ignored {to_pop} observation in the demo trajectories") + + for traj in trajs: + for obs_name in to_pop: + traj["observations"].pop(obs_name) + traj["next_observations"].pop(obs_name) + + 
replay_buffer.insert(traj) + print(f"replay buffer size: {len(replay_buffer)}") + + # learner loop + print_green("starting learner loop") + try: + learner( + sampling_rng, + agent, + replay_buffer=replay_buffer, + wandb_logger=wandb_logger, + ) + except KeyboardInterrupt: + print_green("leraner loop interrupted") + finally: + # Wrap up the learner loop + env.close() + print("Learner loop finished") + + elif FLAGS.actor: + sampling_rng = jax.device_put(sampling_rng, sharding.replicate()) + data_store = QueuedDataStore(50000) # the queue size on the actor + + # actor loop + print_green("starting actor loop") + try: + actor(agent, data_store, env, sampling_rng) + print_green("actor loop finished") + except (KeyboardInterrupt, RuntimeError) as e: + print_green("actor loop interrupted: " + str(e)) + finally: + env.close() + + else: + raise NotImplementedError("Must be either a learner or an actor") + + +if __name__ == "__main__": + app.run(main) diff --git a/examples/box_picking_drq/record_demo.py b/examples/box_picking_drq/record_demo.py new file mode 100644 index 00000000..e9d8afac --- /dev/null +++ b/examples/box_picking_drq/record_demo.py @@ -0,0 +1,131 @@ +import gym +from tqdm import tqdm +import numpy as np +import copy +import pickle as pkl +import datetime +import os +import threading +from pynput import keyboard + +from ur_env.envs.relative_env import RelativeFrame +from ur_env.envs.wrappers import ( + SpacemouseIntervention, + Quat2MrpWrapper, + ObservationRotationWrapper, +) + +from serl_launcher.wrappers.serl_obs_wrappers import ( + SERLObsWrapper, + ScaleObservationWrapper, +) +from serl_launcher.wrappers.chunking import ChunkingWrapper + +import ur_env + +exit_program = threading.Event() + + +def on_space(key, info_dict): + if key == keyboard.Key.space: + for key, item in info_dict.items(): + print(f"{key}: {item}", end=" ") + print() + + +def on_esc(key): + if key == keyboard.Key.esc: + exit_program.set() + + +if __name__ == "__main__": + env = gym.make( 
+ "box_picking_camera_env", + camera_mode="pointcloud", + max_episode_length=100, + ) + env = SpacemouseIntervention(env) + env = RelativeFrame(env) + env = Quat2MrpWrapper(env) + env = ScaleObservationWrapper(env) + # env = ObservationRotationWrapper(env) # if it should be enabled + env = SERLObsWrapper(env) + env = ChunkingWrapper(env, obs_horizon=1, act_exec_horizon=None) + + obs, _ = env.reset() + + transitions = [] + success_count = 0 + success_needed = 20 + total_count = 0 + pbar = tqdm(total=success_needed) + + info_dict = { + "state": env.unwrapped.curr_pos, + "gripper_state": env.unwrapped.gripper_state, + "force": env.unwrapped.curr_force, + "reset_pose": env.unwrapped.curr_reset_pose, + } + listener_1 = keyboard.Listener( + daemon=True, on_press=lambda event: on_space(event, info_dict=info_dict) + ) + listener_1.start() + + listener_2 = keyboard.Listener(on_press=on_esc, daemon=True) + listener_2.start() + + uuid = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + file_name = f"box_picking_{success_needed}_demos_{uuid}.pkl" + file_dir = os.path.dirname(os.path.realpath(__file__)) # same dir as this script + file_path = os.path.join(file_dir, file_name) + + if not os.access(file_dir, os.W_OK): + raise PermissionError(f"No permission to write to {file_dir}") + + try: + running_reward = 0.0 + while success_count < success_needed: + if exit_program.is_set(): + raise KeyboardInterrupt # stop program, but clean up before + + next_obs, rew, done, truncated, info = env.step(action=np.zeros((7,))) + actions = info["intervene_action"] + + transition = copy.deepcopy( + dict( + observations=obs, + actions=actions, + next_observations=next_obs, + rewards=rew, + masks=1.0 - done, + dones=done, + ) + ) + transitions.append(transition) + + obs = next_obs + running_reward += rew + + if done or truncated: + success_count += int(rew > 0.99) + total_count += 1 + print( + f"{rew}\tGot {success_count} successes of {total_count} trials. 
{success_needed} successes needed." + ) + pbar.update(int(rew > 0.99)) + obs, _ = env.reset() + print("Reward total:", running_reward) + running_reward = 0.0 + + with open(file_path, "wb") as f: + pkl.dump(transitions, f) + print(f"saved {success_needed} demos to {file_path}") + + except KeyboardInterrupt as e: + print(f"\nProgram was interrupted, cleaning up... ", e.__str__()) + + finally: + pbar.close() + env.close() + listener_1.stop() + listener_2.stop() diff --git a/examples/box_picking_drq/run_actor.sh b/examples/box_picking_drq/run_actor.sh new file mode 100755 index 00000000..be65622f --- /dev/null +++ b/examples/box_picking_drq/run_actor.sh @@ -0,0 +1,20 @@ +export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ +export XLA_PYTHON_CLIENT_MEM_FRACTION=.1 && \ +python drq_policy.py "$@" \ + --actor \ + --env box_picking_camera_env \ + --max_traj_length 100 \ + --exp_name=box_picking \ + --camera_mode pointcloud \ + --seed 1 \ + --max_steps 20000 \ + --random_steps 0 \ + --training_starts 500 \ + --utd_ratio 8 \ + --batch_size 128 \ + --eval_period 1000 \ + --encoder_type voxnet-pretrained \ + --state_mask all \ + --encoder_bottleneck_dim 128 \ +# --enable_obs_rotation_wrapper \ +# --debug diff --git a/examples/box_picking_drq/run_evaluation.sh b/examples/box_picking_drq/run_evaluation.sh new file mode 100644 index 00000000..929e9618 --- /dev/null +++ b/examples/box_picking_drq/run_evaluation.sh @@ -0,0 +1,18 @@ +export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ +export XLA_PYTHON_CLIENT_MEM_FRACTION=.2 && \ +python drq_policy.py "$@" \ + --actor \ + --env box_picking_camera_env \ + --exp_name=drq_evaluation \ + --camera_mode pointcloud \ + --batch_size 128 \ + --max_traj_length 100 \ + --checkpoint_path "checkpoint folder path here"\ + --eval_checkpoint_step 10000 \ + --eval_n_trajs 20 \ + \ + --encoder_type voxnet-pretrained \ + --state_mask all \ + --encoder_bottleneck_dim 128 \ +# --enable_obs_rotation_wrapper \ +# --debug diff --git 
a/examples/box_picking_drq/run_learner.sh b/examples/box_picking_drq/run_learner.sh new file mode 100755 index 00000000..21f4d65f --- /dev/null +++ b/examples/box_picking_drq/run_learner.sh @@ -0,0 +1,22 @@ +export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ +export XLA_PYTHON_CLIENT_MEM_FRACTION=.3 && \ +python drq_policy.py "$@" \ + --learner \ + --env box_picking_camera_env \ + --exp_name=ox_picking \ + --camera_mode pointcloud \ + --max_traj_length 100 \ + --seed 1 \ + --max_steps 25000 \ + --random_steps 0 \ + --training_starts 500 \ + --utd_ratio 8 \ + --batch_size 128 \ + --eval_period 20000 \ + --checkpoint_period 1000 \ + --encoder_type voxnet-pretrained \ + --state_mask all \ + --encoder_bottleneck_dim 128 \ + --demo_path "demo path here *.pkl" \ +# --enable_obs_rotation_wrapper \ +# --debug diff --git a/serl_launcher/serl_launcher/agents/continuous/drq.py b/serl_launcher/serl_launcher/agents/continuous/drq.py index 9b2b690b..518f17ed 100644 --- a/serl_launcher/serl_launcher/agents/continuous/drq.py +++ b/serl_launcher/serl_launcher/agents/continuous/drq.py @@ -10,14 +10,22 @@ from serl_launcher.agents.continuous.sac import SACAgent from serl_launcher.common.common import JaxRLTrainState, ModuleDict, nonpytree_field -from serl_launcher.common.encoding import EncodingWrapper +from serl_launcher.common.encoding import ( + EncodingWrapper, + MaskedEncodingWrapper, + create_state_mask, +) from serl_launcher.common.optimizers import make_optimizer from serl_launcher.common.typing import Batch, Data, Params, PRNGKey from serl_launcher.networks.actor_critic_nets import Critic, Policy, ensemblize from serl_launcher.networks.lagrange import GeqLagrangeMultiplier from serl_launcher.networks.mlp import MLP +from serl_launcher.vision.voxel_grid_encoders import VoxNet from serl_launcher.utils.train_utils import _unpack, concat_batches -from serl_launcher.vision.data_augmentations import batched_random_crop +from serl_launcher.vision.data_augmentations import ( + 
batched_random_crop, + batched_random_shift_voxel, +) class DrQAgent(SACAgent): @@ -86,7 +94,7 @@ def create( # Config assert not entropy_per_dim, "Not implemented" if target_entropy is None: - target_entropy = -actions.shape[-1] / 2 + target_entropy = -actions.shape[-1] return cls( state=state, @@ -241,15 +249,203 @@ def create_drq( return agent + @classmethod + def create_voxel_drq( + cls, + rng: PRNGKey, + observations: Data, + actions: jnp.ndarray, + # Model architecture + encoder_type: str = "resnet", + use_proprio: bool = False, + state_mask: str = "all", + critic_network_kwargs: dict = { + "hidden_dims": [256, 256], + }, + policy_network_kwargs: dict = { + "hidden_dims": [256, 256], + }, + policy_kwargs: dict = { + "tanh_squash_distribution": True, + "std_parameterization": "uniform", + }, + encoder_kwargs: dict = {}, + critic_ensemble_size: int = 2, + critic_subsample_size: Optional[int] = None, + temperature_init: float = 1.0, + image_keys: Iterable[str] = ("image",), + **kwargs, + ): + """ + Create a new voxel-based agent. 
+ """ + + policy_network_kwargs["activate_final"] = True + critic_network_kwargs["activate_final"] = True + + if encoder_type == "small": + from serl_launcher.vision.small_encoders import SmallEncoder + + small_encoder = SmallEncoder( + features=(64, 64, 32, 32), + kernel_sizes=(3, 3, 3, 3), + strides=(2, 2, 1, 1), + padding="VALID", + pool_method="spatial_learned_embeddings", + bottleneck_dim=128, + spatial_block_size=8, + name=f"small_encoder", + ) + encoders = {image_key: small_encoder for image_key in image_keys} + elif encoder_type == "resnet": + from serl_launcher.vision.resnet_v1 import resnetv1_configs + + encoders = { + image_key: resnetv1_configs["resnetv1-10"]( + name=f"encoder_{image_key}", **encoder_kwargs + ) + for image_key in image_keys + } + elif encoder_type == "resnet-pretrained": + from serl_launcher.vision.resnet_v1 import ( + PreTrainedResNetEncoder, + resnetv1_configs, + ) + + pretrained_encoder = resnetv1_configs["resnetv1-10-frozen"]( + pre_pooling=True, + name="pretrained_encoder", + ) + + encoders = { + image_key: PreTrainedResNetEncoder( + rng=rng, + pretrained_encoder=pretrained_encoder, + name=f"encoder_{image_key}", + **encoder_kwargs, + ) + for image_key in image_keys + } + elif encoder_type == "resnet-pretrained-18": + # pretrained ResNet18 from pytorch + from serl_launcher.vision.resnet_v1_18 import resnetv1_18_configs + from serl_launcher.vision.resnet_v1 import PreTrainedResNetEncoder + + pretrained_encoder = resnetv1_18_configs["resnetv1-18-frozen"]( + name="pretrained_encoder", + ) + + encoders = { + image_key: PreTrainedResNetEncoder( + rng=rng, + pretrained_encoder=pretrained_encoder, + name=f"encoder_{image_key}", + **encoder_kwargs, + ) + for image_key in image_keys + } + elif encoder_type == "voxnet" or encoder_type == "voxnet-pretrained": + encoders = { + image_key: VoxNet( + bottleneck_dim=encoder_kwargs["bottleneck_dim"], + use_conv_bias=True, + final_activation=nn.tanh, + pretrained=encoder_type == "voxnet-pretrained", 
+ ) + for image_key in image_keys + } + elif encoder_type.lower() == "none": + encoders = None + else: + raise NotImplementedError(f"Unknown encoder type: {encoder_type}") + + state_mask_arr = create_state_mask(state_mask) + print(f"state_mask: {state_mask} {state_mask_arr.astype(jnp.int32)}") + encoder_def = MaskedEncodingWrapper( + encoder=encoders, + use_proprio=use_proprio, + enable_stacking=True, + image_keys=image_keys, + state_mask=state_mask_arr, + ) + + encoders = { + "critic": encoder_def, + "actor": encoder_def, + } + + # Define networks + critic_backbone = partial(MLP, **critic_network_kwargs) + critic_backbone = ensemblize(critic_backbone, critic_ensemble_size)( + name="critic_ensemble" + ) + critic_def = partial( + Critic, encoder=encoders["critic"], network=critic_backbone + )(name="critic") + + policy_def = Policy( + encoder=encoders["actor"], + network=MLP(**policy_network_kwargs), + action_dim=actions.shape[-1], + **policy_kwargs, + name="actor", + ) + + temperature_def = GeqLagrangeMultiplier( + init_value=temperature_init, + constraint_shape=(), + constraint_type="geq", + name="temperature", + ) + + agent = cls.create( + rng, + observations, + actions, + actor_def=policy_def, + critic_def=critic_def, + temperature_def=temperature_def, + critic_ensemble_size=critic_ensemble_size, + critic_subsample_size=critic_subsample_size, + image_keys=image_keys, + **kwargs, + ) + + if encoder_type == "resnet-pretrained": # load pretrained weights for ResNet-10 + from serl_launcher.utils.train_utils import load_resnet10_params + + agent = load_resnet10_params(agent, image_keys) + + if encoder_type == "voxnet-pretrained": + from serl_launcher.utils.train_utils import load_pretrained_VoxNet_params + + agent = load_pretrained_VoxNet_params(agent, image_keys) + + return agent + def data_augmentation_fn(self, rng, observations): + # TODO make it configurable: see https://github.com/rail-berkeley/serl/pull/67 + for pixel_key in self.config["image_keys"]: - 
observations = observations.copy( - add_or_replace={ - pixel_key: batched_random_crop( - observations[pixel_key], rng, padding=4, num_batch_dims=2 - ) - } - ) + # pointcloud augmentation + if "pointcloud" in pixel_key: + observations = observations.copy( + add_or_replace={ + pixel_key: batched_random_shift_voxel( + observations[pixel_key], rng, padding=3, num_batch_dims=2 + ) + } + ) + + # image augmentation + else: + observations = observations.copy( + add_or_replace={ + pixel_key: batched_random_crop( + observations[pixel_key], rng, padding=4, num_batch_dims=2 + ) + } + ) return observations @partial(jax.jit, static_argnames=("utd_ratio", "pmap_axis")) diff --git a/serl_launcher/serl_launcher/agents/continuous/sac.py b/serl_launcher/serl_launcher/agents/continuous/sac.py index ca933db4..e04733bb 100644 --- a/serl_launcher/serl_launcher/agents/continuous/sac.py +++ b/serl_launcher/serl_launcher/agents/continuous/sac.py @@ -167,9 +167,16 @@ def critic_loss_fn(self, batch, params: Params, rng: PRNGKey): ) chex.assert_shape(target_q, (batch_size,)) - if self.config["backup_entropy"]: + if self.config["backup_entropy"]: # not the same as in original jaxrl_m SAC implementation: https://github.com/dibyaghosh/jaxrl_m/blob/main/examples/mujoco/sac.py temperature = self.forward_temperature() - target_q = target_q - temperature * next_actions_log_probs + # target_q = target_q - temperature * next_actions_log_probs # serl original + target_q = ( + target_q + - self.config["discount"] + * batch["masks"] + * next_actions_log_probs + * temperature + ) # as in jaxrl_m predicted_qs = self.forward_critic( batch["observations"], batch["actions"], rng=rng, grad_params=params @@ -385,7 +392,7 @@ def create( # Config assert not entropy_per_dim, "Not implemented" if target_entropy is None: - target_entropy = -actions.shape[-1] / 2 + target_entropy = -actions.shape[-1] return cls( state=state, diff --git a/serl_launcher/serl_launcher/common/encoding.py 
b/serl_launcher/serl_launcher/common/encoding.py index 823782d5..c54915b7 100644 --- a/serl_launcher/serl_launcher/common/encoding.py +++ b/serl_launcher/serl_launcher/common/encoding.py @@ -7,6 +7,27 @@ from einops import rearrange, repeat +def create_state_mask(mask_str: str) -> jnp.ndarray: + all = jnp.ones((27,), dtype=jnp.bool) + none = jnp.zeros_like(all) + no_action = all.at[:7].set(False) + gripper = none.at[0 + 7 : 2 + 7].set(True) + no_ForceTorque = all.at[7 + 2 : 7 + 5].set(False).at[7 + 11 : 7 + 14].set(False) + action_only = none.at[:7].set(True) + masks = dict( + all=all, + none=jnp.zeros_like(all), + gripper=gripper, + position_gripper=gripper.at[5:11].set(True), + no_ForceTorque=no_ForceTorque, + no_ForceTorqueAction=jnp.bitwise_and(no_ForceTorque, no_action), + gripper_Zinfo=gripper.at[7 + 7].set(True), + action_only=action_only, + ) + assert mask_str in masks + return masks[mask_str] + + class EncodingWrapper(nn.Module): """ Encodes observations into a single flat encoding, adding additional @@ -72,6 +93,82 @@ def __call__( return encoded +class MaskedEncodingWrapper(nn.Module): + """ + Encodes observations into a single flat encoding, adding additional + functionality for adding proprioception and stopping the gradient. + + Args: + encoder: The encoder network. + use_proprio: Whether to concatenate proprioception (after encoding). 
+        state_mask: Which proprioceptive states to propagate, and which to ignore
+    """
+
+    encoder: nn.Module
+    use_proprio: bool
+    state_mask: jnp.ndarray
+    enable_stacking: bool = False
+    image_keys: Iterable[str] = ("image",)
+
+    @nn.compact
+    def __call__(
+        self,
+        observations: Dict[str, jnp.ndarray],
+        train=False,
+        stop_gradient=False,
+        is_encoded=False,
+    ) -> jnp.ndarray:
+        # encode images with encoder
+        if self.encoder is None:
+            # project state to embeddings as well
+            state = observations["state"]
+            if self.enable_stacking:
+                # Combine stacking and channels into a single dimension
+                if len(state.shape) == 2:
+                    state = rearrange(state, "T C -> (T C)")
+                if len(state.shape) == 3:
+                    state = rearrange(state, "B T C -> B (T C)")
+                # do not use proprio latent dim
+            return state
+
+        encoded = []
+        for image_key in self.image_keys:
+            image = observations[image_key]
+            if not is_encoded:
+                if self.enable_stacking:
+                    # Combine stacking and channels into a single dimension
+                    if len(image.shape) == 4:
+                        image = rearrange(image, "T H W C -> H W (T C)")
+                    if len(image.shape) == 5:
+                        image = rearrange(image, "B T H W C -> B H W (T C)")
+
+            image = self.encoder[image_key](image, train=train, encode=not is_encoded)
+
+            if stop_gradient:
+                image = jax.lax.stop_gradient(image)
+
+            encoded.append(image)
+
+        encoded = jnp.concatenate(encoded, axis=-1)
+
+        if self.use_proprio:
+            # project state to embeddings as well
+            state = observations["state"]
+            state = state[..., self.state_mask]  # only propagate non-zero mask entries
+            if state.shape[-1] != 0:
+                if self.enable_stacking:
+                    # Combine stacking and channels into a single dimension
+                    if len(state.shape) == 2:
+                        state = rearrange(state, "T C -> (T C)")
+                        encoded = encoded.reshape(-1)
+                    if len(state.shape) == 3:
+                        state = rearrange(state, "B T C -> B (T C)")
+                    # do not use proprio latent dim
+                encoded = jnp.concatenate([encoded, state], axis=-1)
+
+        return encoded
+
+
+class GCEncodingWrapper(nn.Module):
+    """
+    Encodes observations 
and goals into a single flat encoding. Handles all the diff --git a/serl_launcher/serl_launcher/networks/actor_critic_nets.py b/serl_launcher/serl_launcher/networks/actor_critic_nets.py index 189eef82..3746ae82 100644 --- a/serl_launcher/serl_launcher/networks/actor_critic_nets.py +++ b/serl_launcher/serl_launcher/networks/actor_critic_nets.py @@ -62,7 +62,8 @@ def __call__( obs_enc = self.encoder(observations) inputs = jnp.concatenate([obs_enc, actions], -1) - outputs = self.network(inputs, train=train) + outputs = self.network(inputs, train) + # train=train throws: "RuntimeWarning: kwargs are not supported in vmap, so "train" is(are) ignored" if self.init_final is not None: value = nn.Dense( 1, @@ -157,7 +158,7 @@ def ensemblize(cls, num_qs, out_axes=0): return nn.vmap( cls, variable_axes={"params": 0}, - split_rngs={"params": True}, + split_rngs={"params": True, "dropout": True}, in_axes=None, out_axes=out_axes, axis_size=num_qs, diff --git a/serl_launcher/serl_launcher/utils/launcher.py b/serl_launcher/serl_launcher/utils/launcher.py index 782221eb..c37a2620 100644 --- a/serl_launcher/serl_launcher/utils/launcher.py +++ b/serl_launcher/serl_launcher/utils/launcher.py @@ -116,6 +116,66 @@ def make_drq_agent( return agent +def make_voxel_drq_agent( + seed, + sample_obs, + sample_action, + image_keys=("image",), + encoder_type="voxnet", + state_mask="all", + encoder_kwargs=None, +): + if encoder_kwargs is None: + encoder_kwargs = dict(bottleneck_dim=128) + + agent = DrQAgent.create_voxel_drq( + jax.random.PRNGKey(seed), + sample_obs, + sample_action, + encoder_type=encoder_type, + use_proprio=True, + state_mask=state_mask, + image_keys=image_keys, + policy_kwargs=dict( + tanh_squash_distribution=True, + std_parameterization="exp", + std_min=1e-5, + std_max=5, + ), + critic_network_kwargs=dict( + activations=nn.tanh, + use_layer_norm=True, + hidden_dims=[256, 256], + dropout_rate=0.1, + ), + policy_network_kwargs=dict( + activations=nn.tanh, + use_layer_norm=True, 
+ hidden_dims=[256, 256], + dropout_rate=0.1, + ), + temperature_init=1e-2, + discount=0.99, # 0.99 + backup_entropy=True, + critic_ensemble_size=10, + critic_subsample_size=2, + encoder_kwargs=encoder_kwargs, + # dict( + # # pooling_method="spatial_learned_embeddings", + # bottleneck_dim=128, + # # num_spatial_blocks=8, + # # num_kp=64, + # ), + actor_optimizer_kwargs={ + "learning_rate": 3e-3, + }, + critic_optimizer_kwargs={ + "learning_rate": 3e-3, + }, + ) + return agent + + def make_vice_agent( seed, sample_obs, diff --git a/serl_launcher/serl_launcher/utils/sampling_utils.py b/serl_launcher/serl_launcher/utils/sampling_utils.py new file mode 100644 index 00000000..9827b2de --- /dev/null +++ b/serl_launcher/serl_launcher/utils/sampling_utils.py @@ -0,0 +1,30 @@ +import numpy as np + + +class TemporalActionEnsemble: + def __init__(self, activated=True, action_shape=(7,), ensemble=None): + if ensemble is None: + ensemble = [0.5, 0.3, 0.2, 0.1] + self.activated = activated + self.ensemble = np.asarray(ensemble) + self.buffer = np.zeros((len(ensemble), action_shape[0])) + + if activated: + print(f"Temporal Action Ensemble enabled: {self.ensemble}") + + def reset(self): + self.buffer[...] 
= 0.0 + + def sample(self, curr_action: np.ndarray): + if not self.activated: + return curr_action + + curr_action = curr_action.reshape(-1) + assert curr_action.shape[0] == self.buffer.shape[1] + + self.buffer = np.roll(self.buffer, axis=0, shift=1) + self.buffer[0, :] = curr_action + return np.dot(self.ensemble, self.buffer) + + def is_activated(self): + return self.activated diff --git a/serl_launcher/serl_launcher/utils/train_utils.py b/serl_launcher/serl_launcher/utils/train_utils.py index 31037317..200e1798 100644 --- a/serl_launcher/serl_launcher/utils/train_utils.py +++ b/serl_launcher/serl_launcher/utils/train_utils.py @@ -128,3 +128,110 @@ def load_resnet10_params(agent, image_keys=("image",), public=True): agent = agent.replace(state=agent.state.replace(params=new_params)) return agent + + +def load_pretrained_VoxNet_params(agent, image_keys=("pointcloud",)): + ckpt = jnp.load("/home/nico/Downloads/c-11.npz") + + new_params = agent.state.params + + for image_key in image_keys: + new_encoder_params = new_params["modules_actor"]["encoder"][ + f"encoder_{image_key}" + ] + to_replace = { + "conv_5x5x5": "voxnet/conv1/conv3d/", + "conv_3x3x3": "voxnet/conv2/conv3d/", + "conv_2x2x2": "voxnet/conv3/conv3d/", + } + replaced = [] + for key, weights in to_replace.items(): + if key in new_encoder_params: + shape = new_encoder_params[key]["kernel"].shape + new_encoder_params[key]["kernel"] = ( + new_encoder_params[key]["kernel"] + .at[:] + .set(ckpt[weights + "kernel:0"][..., : shape[-1]]) + ) + new_encoder_params[key]["bias"] = ( + new_encoder_params[key]["bias"] + .at[:] + .set(ckpt[weights + "bias:0"][: shape[-1]]) + ) + replaced.append(f"{key}:{shape}") + + print(f"replaced {replaced} in {image_key}") + + # replace LayerNorm params with pretrained BN ones + new_encoder_params["LayerNorm_0"]["bias"] = ( + new_encoder_params["LayerNorm_0"]["bias"] + .at[:] + .set(ckpt["voxnet/conv1/batch_normalization/beta:0"]) + ) + new_encoder_params["LayerNorm_0"]["scale"] = ( 
+        new_encoder_params["LayerNorm_0"]["scale"]
+        .at[:]
+        .set(ckpt["voxnet/conv1/batch_normalization/gamma:0"])
+    )
+
+    new_encoder_params["LayerNorm_1"]["bias"] = (
+        new_encoder_params["LayerNorm_1"]["bias"]
+        .at[:]
+        .set(ckpt["voxnet/conv2/batch_normalization/beta:0"])
+    )
+    new_encoder_params["LayerNorm_1"]["scale"] = (
+        new_encoder_params["LayerNorm_1"]["scale"]
+        .at[:]
+        .set(ckpt["voxnet/conv2/batch_normalization/gamma:0"])
+    )
+
+    agent = agent.replace(state=agent.state.replace(params=new_params))
+    return agent
+
+
+def print_agent_params(agent, image_keys=("image",)):
+    """
+    helper function to print the parameter count of the actor and critic networks
+    """
+
+    def get_size(params):
+        return sum(x.size for x in jax.tree.leaves(params))
+
+    total_param_count = get_size(agent.state.params)
+    actor, critic = (
+        agent.state.params["modules_actor"],
+        agent.state.params["modules_critic"],
+    )
+
+    # calculate encoder params
+    try:
+        pretrained_encoder_count = get_size(
+            actor["encoder"][f"encoder_{image_keys[0]}"]["pretrained_encoder"]
+        )
+    except Exception as e:
+        pretrained_encoder_count = 0
+
+    try:
+        encoder_count = get_size(actor["encoder"])
+    except Exception as e:
+        encoder_count = 0
+
+    actor_count = get_size(actor)
+    critic_count = get_size(critic)
+
+    print(f"\ntotal params: {total_param_count / 1e6:.3f}M")
+    print(
+        f"encoder params: {(encoder_count - pretrained_encoder_count) / 1e6:.3f}M pretrained encoder params: {pretrained_encoder_count / 1e6:.3f}M"
+    )
+    print(
+        f"actor params: {(actor_count - encoder_count) / 1e6:.3f}M critic_params: {critic_count / 1e6:.3f}M"
+    )
+    print(
+        f"total parameters to train: {(total_param_count - pretrained_encoder_count) / 1e6:.3f}M\n"
+    )
+
+
+def parameter_overview(agent):
+    from clu import parameter_overview
+
+    print(parameter_overview.get_parameter_overview(agent.state.params))
diff --git a/serl_launcher/serl_launcher/vision/data_augmentations.py 
b/serl_launcher/serl_launcher/vision/data_augmentations.py index 2c2440fa..3f449cd5 100644 --- a/serl_launcher/serl_launcher/vision/data_augmentations.py +++ b/serl_launcher/serl_launcher/vision/data_augmentations.py @@ -4,6 +4,7 @@ import jax.numpy as jnp +@partial(jax.jit, static_argnames="padding") def random_crop(img, rng, *, padding): crop_from = jax.random.randint(rng, (2,), 0, 2 * padding + 1) crop_from = jnp.concatenate([crop_from, jnp.zeros((1,), dtype=jnp.int32)]) @@ -36,6 +37,36 @@ def batched_random_crop(img, rng, *, padding, num_batch_dims: int = 1): return img +def random_shift_3d(img, rng, *, padding): + crop_from = jax.random.randint(rng, (3,), 0, 2 * padding + 1) + padded_img = jnp.pad( + img, + ( + (padding, padding), + (padding, padding), + (padding, padding), + ), + mode="constant" + ) + return jax.lax.dynamic_slice(padded_img, crop_from, img.shape) + + +@partial(jax.jit, static_argnames=("padding", "num_batch_dims")) +def batched_random_shift_voxel(img, rng, *, padding, num_batch_dims: int = 1): + original_shape = img.shape + img = jnp.reshape(img, (-1, *img.shape[num_batch_dims:])) + # shape (B, B2, X, Y, Z) + + rngs = jax.random.split(rng, img.shape[0]) + img = jax.vmap( + lambda i, r: random_shift_3d(i, r, padding=padding), in_axes=(0, 0), out_axes=0 + )(img, rngs) + + # Restore batch dims + img = jnp.reshape(img, original_shape) + return img + + def _maybe_apply(apply_fn, inputs, rng, apply_prob): should_apply = jax.random.uniform(rng, shape=()) <= apply_prob return jax.lax.cond(should_apply, inputs, apply_fn, inputs, lambda x: x) diff --git a/serl_launcher/serl_launcher/vision/resnet_v1.py b/serl_launcher/serl_launcher/vision/resnet_v1.py index e18769b3..e89a6f42 100644 --- a/serl_launcher/serl_launcher/vision/resnet_v1.py +++ b/serl_launcher/serl_launcher/vision/resnet_v1.py @@ -8,6 +8,7 @@ import numpy as np from serl_launcher.vision.film_conditioning_layer import FilmConditioning +from serl_launcher.common.typing import PRNGKey 
ModuleDef = Any @@ -198,7 +199,6 @@ class ResNetEncoder(nn.Module): norm: str = "group" add_spatial_coordinates: bool = False pooling_method: str = "avg" - use_spatial_softmax: bool = False softmax_temperature: float = 1.0 use_multiplicative_cond: bool = False num_spatial_blocks: int = 8 @@ -322,10 +322,11 @@ def __call__( class PreTrainedResNetEncoder(nn.Module): + rng: PRNGKey = None pooling_method: str = "avg" - use_spatial_softmax: bool = False softmax_temperature: float = 1.0 num_spatial_blocks: int = 8 + num_kp: Optional[int] = None # for Spatial Softmax bottleneck_dim: Optional[int] = None pretrained_encoder: nn.module = None @@ -348,8 +349,22 @@ def __call__( channel=channel, num_features=self.num_spatial_blocks, )(x) - x = nn.Dropout(0.1, deterministic=not train)(x) + x = nn.Dropout(0.1, deterministic=not train)(x, rng=self.rng) elif self.pooling_method == "spatial_softmax": + if self.num_kp is not None: + """ + implemented as in https://github.com/huggingface/lerobot/blob/ff8f6aa6cde2957f08547eb081aac12ca4669b6a/lerobot/common/policies/diffusion/modeling_diffusion.py#L316 + In this case it would result in 512 keypoints (corresponding to the 512 input channels). We can optionally + provide num_kp != None to control the number of keypoints. 
+ """ + x = nn.Conv( + features=self.num_kp, + kernel_size=1, + use_bias=False, + dtype=jnp.float32, + kernel_init=nn.initializers.kaiming_normal(), + name="spatial_softmax_conv", + )(x) height, width, channel = x.shape[-3:] pos_x, pos_y = jnp.meshgrid( jnp.linspace(-1.0, 1.0, height), jnp.linspace(-1.0, 1.0, width) diff --git a/serl_launcher/serl_launcher/vision/resnet_v1_18.py b/serl_launcher/serl_launcher/vision/resnet_v1_18.py new file mode 100644 index 00000000..cc8bbe0d --- /dev/null +++ b/serl_launcher/serl_launcher/vision/resnet_v1_18.py @@ -0,0 +1,492 @@ +import jax.lax +import jax.numpy as jnp +import flax.linen as nn +import functools +from typing import Any, Callable, Iterable, Optional, Tuple, Union +import h5py +import warnings + +from flax.linen.module import compact, merge_param +from jax.nn import initializers +from jax import lax + +from serl_launcher.vision.resnet_v1 import SpatialLearnedEmbeddings, SpatialSoftmax + +PRNGKey = Any +Array = Any +Shape = Tuple[int] +Dtype = Any + + +# ---------------------------------------------------------------# +# Normalization +# ---------------------------------------------------------------# +def batch_norm(x, train, epsilon=1e-05, momentum=0.99, params=None, dtype="float32"): + # we do not use running average in the implementation (set to False) + if params is None: + x = BatchNorm( + epsilon=epsilon, + momentum=momentum, + use_running_average=False, # was not train + dtype=dtype, + )(x) + else: + x = BatchNorm( + epsilon=epsilon, + momentum=momentum, + bias_init=lambda *_: jnp.array(params["bias"]), + scale_init=lambda *_: jnp.array(params["scale"]), + mean_init=lambda *_: jnp.array(params["mean"]), + var_init=lambda *_: jnp.array(params["var"]), + use_running_average=False, # was not train + dtype=dtype, + )(x) + return x + + +def _absolute_dims(rank, dims): + return tuple([rank + dim if dim < 0 else dim for dim in dims]) + + +class BatchNorm(nn.Module): + """BatchNorm Module. 
+ + Taken from: https://github.com/google/flax/blob/master/flax/linen/normalization.py + + Attributes: + use_running_average: if True, the statistics stored in batch_stats + will be used instead of computing the batch statistics on the input. + axis: the feature or non-batch axis of the input. + momentum: decay rate for the exponential moving average of the batch statistics. + epsilon: a small float added to variance to avoid dividing by zero. + dtype: the dtype of the computation (default: float32). + use_bias: if True, bias (beta) is added. + use_scale: if True, multiply by scale (gamma). + When the next layer is linear (also e.g. nn.relu), this can be disabled + since the scaling will be done by the next layer. + bias_init: initializer for bias, by default, zero. + scale_init: initializer for scale, by default, one. + axis_name: the axis name used to combine batch statistics from multiple + devices. See `jax.pmap` for a description of axis names (default: None). + axis_index_groups: groups of axis indices within that named axis + representing subsets of devices to reduce over (default: None). For + example, `[[0, 1], [2, 3]]` would independently batch-normalize over + the examples on the first two and last two devices. See `jax.lax.psum` + for more details. + """ + + use_running_average: Optional[bool] = None + axis: int = -1 + momentum: float = 0.99 + epsilon: float = 1e-5 + dtype: Dtype = jnp.float32 + use_bias: bool = True + use_scale: bool = True + bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros + scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones + mean_init: Callable[[Shape], Array] = lambda s: jnp.zeros(s, jnp.float32) + var_init: Callable[[Shape], Array] = lambda s: jnp.ones(s, jnp.float32) + axis_name: Optional[str] = None + axis_index_groups: Any = None + + @compact + def __call__(self, x, use_running_average: Optional[bool] = None): + """Normalizes the input using batch statistics. 
+ + NOTE: + During initialization (when parameters are mutable) the running average + of the batch statistics will not be updated. Therefore, the inputs + fed during initialization don't need to match that of the actual input + distribution and the reduction axis (set with `axis_name`) does not have + to exist. + Args: + x: the input to be normalized. + use_running_average: if true, the statistics stored in batch_stats + will be used instead of computing the batch statistics on the input. + Returns: + Normalized inputs (the same shape as inputs). + """ + use_running_average = merge_param( + "use_running_average", self.use_running_average, use_running_average + ) + x = jnp.asarray(x, jnp.float32) + axis = self.axis if isinstance(self.axis, tuple) else (self.axis,) + axis = _absolute_dims(x.ndim, axis) + feature_shape = tuple(d if i in axis else 1 for i, d in enumerate(x.shape)) + reduced_feature_shape = tuple(d for i, d in enumerate(x.shape) if i in axis) + reduction_axis = tuple(i for i in range(x.ndim) if i not in axis) + + # see NOTE above on initialization behavior + initializing = self.is_mutable_collection("params") + + if use_running_average: + ra_mean = self.variable( + "batch_stats", "mean", self.mean_init, reduced_feature_shape + ) + ra_var = self.variable( + "batch_stats", "var", self.var_init, reduced_feature_shape + ) + mean, var = ra_mean.value, ra_var.value + else: + mean = jnp.mean(x, axis=reduction_axis, keepdims=False) + mean2 = jnp.mean(lax.square(x), axis=reduction_axis, keepdims=False) + if self.axis_name is not None and not initializing: + concatenated_mean = jnp.concatenate([mean, mean2]) + mean, mean2 = jnp.split( + lax.pmean( + concatenated_mean, + axis_name=self.axis_name, + axis_index_groups=self.axis_index_groups, + ), + 2, + ) + var = mean2 - lax.square(mean) + + y = x - mean.reshape(feature_shape) + mul = lax.rsqrt(var + self.epsilon) + if self.use_scale: + scale = self.param("scale", self.scale_init, reduced_feature_shape).reshape( + 
feature_shape + ) + mul = mul * scale + y = y * mul + if self.use_bias: + bias = self.param("bias", self.bias_init, reduced_feature_shape).reshape( + feature_shape + ) + y = y + bias + return jnp.asarray(y, self.dtype) + + +LAYERS = {"resnet18": [2, 2, 2, 2]} + + +class BasicBlock(nn.Module): + """ + Basic Block. + + Attributes: + features (int): Number of output channels. + kernel_size (Tuple): Kernel size. + downsample (bool): If True, downsample spatial resolution. + stride (bool): If True, use strides (2, 2). Not used in this module. + The attribute is only here for compatibility with Bottleneck. + param_dict (h5py.Group): Parameter dict with pretrained parameters. + kernel_init (functools.partial): Kernel initializer. + bias_init (functools.partial): Bias initializer. + block_name (str): Name of block. + dtype (str): Data type. + """ + + features: int + kernel_size: Union[int, Iterable[int]] = (3, 3) + downsample: bool = False + stride: bool = True + param_dict: h5py.Group = None + kernel_init: functools.partial = nn.initializers.lecun_normal() + bias_init: functools.partial = nn.initializers.zeros + block_name: str = None + dtype: str = "float32" + + @nn.compact + def __call__(self, x, act, train=True): + """ + Run Basic Block. + + Args: + x (tensor): Input tensor of shape [N, H, W, C]. + act (dict): Dictionary containing activations. + train (bool): Training mode. + + Returns: + (tensor): Output shape of shape [N, H', W', features]. 
+ """ + residual = x + + x = nn.Conv( + features=self.features, + kernel_size=self.kernel_size, + strides=(2, 2) if self.downsample else (1, 1), + padding=((1, 1), (1, 1)), + kernel_init=self.kernel_init + if self.param_dict is None + else lambda *_: jnp.array(self.param_dict["conv1"]["weight"]), + use_bias=False, + dtype=self.dtype, + )(x) + + x = batch_norm( + x, + train=train, + epsilon=1e-05, + momentum=0.1, + params=None if self.param_dict is None else self.param_dict["bn1"], + dtype=self.dtype, + ) + x = nn.relu(x) + + x = nn.Conv( + features=self.features, + kernel_size=self.kernel_size, + strides=(1, 1), + padding=((1, 1), (1, 1)), + kernel_init=self.kernel_init + if self.param_dict is None + else lambda *_: jnp.array(self.param_dict["conv2"]["weight"]), + use_bias=False, + dtype=self.dtype, + )(x) + + x = batch_norm( + x, + train=train, + epsilon=1e-05, + momentum=0.1, + params=None if self.param_dict is None else self.param_dict["bn2"], + dtype=self.dtype, + ) + + if self.downsample: + residual = nn.Conv( + features=self.features, + kernel_size=(1, 1), + strides=(2, 2), + kernel_init=self.kernel_init + if self.param_dict is None + else lambda *_: jnp.array( + self.param_dict["downsample"]["conv"]["weight"] + ), + use_bias=False, + dtype=self.dtype, + )(residual) + + residual = batch_norm( + residual, + train=train, + epsilon=1e-05, + momentum=0.1, + params=None + if self.param_dict is None + else self.param_dict["downsample"]["bn"], + dtype=self.dtype, + ) + + x += residual + x = nn.relu(x) + act[self.block_name] = x + return x + + +class ResNet(nn.Module): + """ + ResNet. + + Attributes: + output (str): + Output of the module. 
Available options are: + - 'softmax': Output is a softmax tensor of shape [N, 1000] + - 'log_softmax': Output is a softmax tensor of shape [N, 1000] + - 'logits': Output is a tensor of shape [N, 1000] + - 'activations': Output is a dictionary containing the ResNet activations + pretrained (str): + Indicates if and what type of weights to load. Options are: + - 'imagenet': Loads the network parameters trained on ImageNet + - None: Parameters of the module are initialized randomly + normalize (bool): + If True, the input will be normalized with the ImageNet statistics. + architecture (str): + Which ResNet model to use: + - 'resnet18' + num_classes (int): + Number of classes. + block (nn.Module): + Type of residual block: + - BasicBlock + kernel_init (function): + A function that takes in a shape and returns a tensor. + bias_init (function): + A function that takes in a shape and returns a tensor. + ckpt_dir (str): + The directory to which the pretrained weights are downloaded. + Only relevant if a pretrained model is used. + If this argument is None, the weights will be saved to a temp directory. + dtype (str): Data type. + """ + + output: str = "softmax" + pretrained: str = "imagenet" + normalize: bool = True + architecture: str = "resnet18" + num_classes: int = 1000 + block: nn.Module = BasicBlock + kernel_init: functools.partial = nn.initializers.lecun_normal() + bias_init: functools.partial = nn.initializers.zeros + ckpt_dir: str = None + dtype: str = "float32" + pre_pooling: bool = True # skip pooling + + def setup(self): + # self.param_dict = None + if self.pretrained == "imagenet": + # ckpt_file = utils.download(self.ckpt_dir, URLS[self.architecture]) + self.param_dict = h5py.File(self.ckpt_dir, "r") + # print(f"loaded pretrained weights from {self.ckpt_dir}") + + @nn.compact + def __call__(self, observations, train=False): + """ + Args: + x (tensor): Input tensor of shape [N, H, W, 3]. Images must be in range [0, 1]. + train (bool): Training mode. 
+ + Returns: + (tensor): Out + if pre_pooling is True: features of shape (B, 7, 7, 512) + """ + # assert observations.shape[-3:] == (224, 224, 3) + + if self.normalize: + mean = jnp.array([0.485, 0.456, 0.406]).reshape(1, 1, 1, -1) + std = jnp.array([0.229, 0.224, 0.225]).reshape(1, 1, 1, -1) + x = (observations.astype(jnp.float32) / 255.0 - mean) / std + + if self.pretrained == "imagenet": + if self.num_classes != 1000: + warnings.warn( + f"The user specified parameter 'num_classes' was set to {self.num_classes} " + "but will be overwritten with 1000 to match the specified pretrained checkpoint 'imagenet', if ", + UserWarning, + ) + num_classes = 1000 + else: + num_classes = self.num_classes + + act = {} + + x = nn.Conv( + features=64, + kernel_size=(7, 7), + kernel_init=self.kernel_init + if self.param_dict is None + else lambda *_: jnp.array(self.param_dict["conv1"]["weight"]), + strides=(2, 2), + padding=((3, 3), (3, 3)), + use_bias=False, + dtype=self.dtype, + )(x) + act["conv1"] = x + + x = batch_norm( + x, + train=train, + epsilon=1e-05, + momentum=0.1, + params=None if self.param_dict is None else self.param_dict["bn1"], + dtype=self.dtype, + ) + x = nn.relu(x) + x = nn.max_pool( + x, window_shape=(3, 3), strides=(2, 2), padding=((1, 1), (1, 1)) + ) + + # Layer 1 + down = self.block.__name__ == "Bottleneck" + for i in range(LAYERS[self.architecture][0]): + params = ( + None + if self.param_dict is None + else self.param_dict["layer1"][f"block{i}"] + ) + x = self.block( + features=64, + kernel_size=(3, 3), + downsample=i == 0 and down, + stride=i != 0, + param_dict=params, + block_name=f"block1_{i}", + dtype=self.dtype, + )(x, act, train) + + # Layer 2 + for i in range(LAYERS[self.architecture][1]): + params = ( + None + if self.param_dict is None + else self.param_dict["layer2"][f"block{i}"] + ) + x = self.block( + features=128, + kernel_size=(3, 3), + downsample=i == 0, + param_dict=params, + block_name=f"block2_{i}", + dtype=self.dtype, + )(x, act, train) 
+ + # Layer 3 + for i in range(LAYERS[self.architecture][2]): + params = ( + None + if self.param_dict is None + else self.param_dict["layer3"][f"block{i}"] + ) + x = self.block( + features=256, + kernel_size=(3, 3), + downsample=i == 0, + param_dict=params, + block_name=f"block3_{i}", + dtype=self.dtype, + )(x, act, train) + + # Layer 4 + for i in range(LAYERS[self.architecture][3]): + params = ( + None + if self.param_dict is None + else self.param_dict["layer4"][f"block{i}"] + ) + x = self.block( + features=512, + kernel_size=(3, 3), + downsample=i == 0, + param_dict=params, + block_name=f"block4_{i}", + dtype=self.dtype, + )(x, act, train) + + # if we want the pre_pooling output, return here + if self.pre_pooling: + return jax.lax.stop_gradient(x) # shape (b, 7, 7, 512) + + # Classifier + x = jnp.mean(x, axis=(1, 2)) + x = nn.Dense( + features=num_classes, + kernel_init=self.kernel_init + if self.param_dict is None + else lambda *_: jnp.array(self.param_dict["fc"]["weight"]), + bias_init=self.bias_init + if self.param_dict is None + else lambda *_: jnp.array(self.param_dict["fc"]["bias"]), + dtype=self.dtype, + )(x) + act["fc"] = x + + if self.output == "softmax": + return nn.softmax(x) + if self.output == "log_softmax": + return nn.log_softmax(x) + if self.output == "activations": + return act + return x + + +resnetv1_18_configs = { + "resnetv1-18-frozen": functools.partial( + ResNet, + architecture="resnet18", + ckpt_dir="/examples/box_picking_drq/resnet18_weights.h5", # download from #TODO + pre_pooling=True, + ) +} diff --git a/serl_launcher/serl_launcher/vision/small_encoders.py b/serl_launcher/serl_launcher/vision/small_encoders.py index 630c7b9a..d0e37969 100644 --- a/serl_launcher/serl_launcher/vision/small_encoders.py +++ b/serl_launcher/serl_launcher/vision/small_encoders.py @@ -14,9 +14,12 @@ class SmallEncoder(nn.Module): pool_method: str = "spatial_learned_embeddings" bottleneck_dim: Optional[int] = None spatial_block_size: Optional[int] = 8 + 
num_kp: Optional[int] = 32 @nn.compact - def __call__(self, observations: jnp.ndarray, train=False) -> jnp.ndarray: + def __call__( + self, observations: jnp.ndarray, train=False, encode=True + ) -> jnp.ndarray: assert len(self.features) == len(self.strides) x = observations.astype(jnp.float32) / 255.0 @@ -44,6 +47,13 @@ def __call__(self, observations: jnp.ndarray, train=False) -> jnp.ndarray: raise ValueError( "spatial_block_size must be set when using spatial_learned_embeddings" ) + x = nn.Conv( # 512 to num_kp features (less complexity) + features=self.num_kp, + kernel_size=1, + use_bias=False, + dtype=jnp.float32, + kernel_init=nn.initializers.kaiming_normal(), + )(x) x = SpatialLearnedEmbeddings(*(x.shape[-3:]), self.spatial_block_size)(x) x = nn.Dropout(0.1, deterministic=not train)(x) diff --git a/serl_launcher/serl_launcher/vision/voxel_grid_encoders.py b/serl_launcher/serl_launcher/vision/voxel_grid_encoders.py new file mode 100644 index 00000000..45fabc33 --- /dev/null +++ b/serl_launcher/serl_launcher/vision/voxel_grid_encoders.py @@ -0,0 +1,148 @@ +from functools import partial +from typing import Any, Callable, Optional, Sequence, Tuple + +import flax.linen as nn +import jax.numpy as jnp +import jax.lax as lax + +import jax + + +class SpatialSoftArgmax3D(nn.Module): + """ + 3D Implementation of Spatial Soft Argmax + why arg-max and not max: see https://github.com/tensorflow/tensorflow/issues/6271#issuecomment-266893850 + """ + + x_len: int + y_len: int + z_len: int + channel: int + temperature: float = 1.0 + + def setup(self): + pos_x, pos_y, pos_z = jnp.meshgrid( + jnp.linspace(-1.0, 1.0, self.x_len), + jnp.linspace(-1.0, 1.0, self.y_len), + jnp.linspace(-1.0, 1.0, self.z_len), + indexing="ij", + ) + self.pos_x = pos_x.reshape(-1) # shape (x*y*z) + self.pos_y = pos_y.reshape(-1) + self.pos_z = pos_z.reshape(-1) + + @nn.compact + def __call__(self, features): + # add batch dim if missing + no_batch_dim = len(features.shape) < 5 + if no_batch_dim: + 
features = features[None] + + assert len(features.shape) == 5 + batch_size, num_featuremaps = features.shape[0], features.shape[-1] + features = features.transpose(0, 4, 1, 2, 3).reshape( + batch_size, num_featuremaps, self.x_len * self.y_len * self.z_len + ) + + softmax_attention = nn.softmax(features / self.temperature, axis=-1) + expected_x = jnp.sum(self.pos_x * softmax_attention, axis=-1) + expected_y = jnp.sum(self.pos_y * softmax_attention, axis=-1) + expected_z = jnp.sum(self.pos_z * softmax_attention, axis=-1) + expected_xyz = jnp.concatenate([expected_x, expected_y, expected_z], axis=-1) + + expected_xy = jnp.reshape(expected_xyz, (batch_size, 3, num_featuremaps)) + + if no_batch_dim: + expected_xy = expected_xy[0] + return expected_xy + + +class VoxNet(nn.Module): + """ + VoxNet-like implementation: https://github.com/AutoDeep/VoxNet/blob/master/src/nets/voxNet.py + """ + + use_conv_bias: bool = False + bottleneck_dim: Optional[int] = None + final_activation: Callable[[jnp.ndarray], jnp.ndarray] | str = nn.tanh + pretrained: bool = False + scale_factor: float = 1.0 + + @nn.compact + def __call__( + self, + observations: jnp.ndarray, + encode: bool = True, + train: bool = True, + ): + # observations has shape (B, X, Y, Z) + no_batch_dim = len(observations.shape) < 4 + if no_batch_dim: + observations = observations[None] + + observations = ( + observations.astype(jnp.float32)[..., None] / self.scale_factor + ) # add conv channel + + conv3d = partial( + nn.Conv, + kernel_init=nn.initializers.xavier_normal(), + use_bias=self.use_conv_bias, + padding="valid", + bias_init=nn.zeros_init(), + ) + l_relu = partial(nn.leaky_relu, negative_slope=0.1) + max_pool = partial(nn.max_pool, window_shape=(2, 2, 2), strides=(2, 2, 2)) + + if self.pretrained: + feature_dimensions = (64, 64, 32) + else: + feature_dimensions = (32, 16, 8) + + x = observations + x = conv3d( + features=feature_dimensions[0], + kernel_size=(5, 5, 5), + strides=(2, 2, 2), + name="conv_5x5x5", + 
)(x) + x = nn.LayerNorm()(x) + x = l_relu(x) + + x = conv3d( + features=feature_dimensions[1], + kernel_size=(3, 3, 3), + strides=(1, 1, 1), + name="conv_3x3x3", + )(x) + x = max_pool(x) + + if self.pretrained: + x = jax.lax.stop_gradient( + x + ) # unfortunately also cuts gradients of the LayerNorm above + + x = nn.LayerNorm()(x) + x = l_relu(x) + + x = conv3d( + features=feature_dimensions[ + 2 + ], # if pretrained, uses [..] out of 128 pretrained params as initial weights + kernel_size=(2, 2, 2), + strides=(2, 2, 2), + name="conv_2x2x2", + )(x) + x = nn.LayerNorm()(x) + x = l_relu(x) + + # x = SpatialSoftArgmax3D(10, 10, 8, 64)(x) # not used for now + + # reshape and dense (preserve batch dim) + x = jnp.reshape(x, (1 if no_batch_dim else x.shape[0], -1)) + if self.bottleneck_dim is not None: + x = nn.Dense(self.bottleneck_dim)(x) + x = nn.LayerNorm()(x) + x = self.final_activation(x) + + return x[0] if no_batch_dim else x diff --git a/serl_launcher/serl_launcher/wrappers/observation_statistics_wrapper.py b/serl_launcher/serl_launcher/wrappers/observation_statistics_wrapper.py new file mode 100644 index 00000000..90c4836a --- /dev/null +++ b/serl_launcher/serl_launcher/wrappers/observation_statistics_wrapper.py @@ -0,0 +1,93 @@ +import numpy as np +from collections import deque +import gym + + +class ObservationStatisticsWrapper(gym.Wrapper, gym.utils.RecordConstructorArgs): + """ + This wrapper will keep track of the observation statistics. + + At the end of an episode, the statistics of the episode will be added to ``info`` + using the key ``obsStat``. 
+ """ + + def __init__(self, env: gym.Env, deque_size: int = 100): + gym.utils.RecordConstructorArgs.__init__(self, deque_size=deque_size) + gym.Wrapper.__init__(self, env) + + self.buffer = {} + + # make buffer + for name, space in self.env.observation_space["state"].items(): + self.buffer[name] = np.zeros( + shape=(self.max_episode_length, space.shape[0]) + ) + + # may not be used + self.num_envs = getattr(env, "num_envs", 1) + self.episode_count = 0 + self.episode_start_times: np.ndarray = None + self.episode_returns = None + self.episode_lengths = None + self.return_queue = deque(maxlen=deque_size) + self.length_queue = deque(maxlen=deque_size) + self.is_vector_env = getattr(env, "is_vector_env", False) + + def step(self, action): + """Steps through the environment, recording the episode statistics.""" + ( + observations, + rewards, + terminations, + truncations, + infos, + ) = self.env.step(action) + assert isinstance( + infos, dict + ), f"`info` dtype is {type(infos)} while supported dtype is `dict`. This may be due to usage of other wrappers in the wrong order." 
+ + for name, obs in observations["state"].items(): + self.buffer[name][self.curr_path_length - 1, :] = obs + + dones = np.logical_or(terminations, truncations) + num_dones = np.sum(dones) + if num_dones: + calc_buffs = {} + calc_buffs.update( + { + name + "_mean": np.mean(obs[: self.curr_path_length], axis=0) + for name, obs in self.buffer.items() + } + ) + calc_buffs.update( + { + name + "_std": np.std(obs[: self.curr_path_length], axis=0) + for name, obs in self.buffer.items() + } + ) + buff = {} + for name, value in calc_buffs.items(): + for i in range(value.shape[0]): + buff[ + name + f"_{['x', 'y', 'z', 'rx', 'ry', 'rz', 'grip'][i]}" + ] = value[i] + infos["obsStat"] = buff + # print(buff) + + return ( + observations, + rewards, + terminations, + truncations, + infos, + ) + + def reset(self, **kwargs): + """Resets the environment using kwargs and resets the episode returns and lengths.""" + obs, info = super().reset(**kwargs) + + # reset buffer to zero + for name, value in self.buffer.items(): + value[...] 
= 0 + + return obs, info diff --git a/serl_launcher/serl_launcher/wrappers/serl_obs_wrappers.py b/serl_launcher/serl_launcher/wrappers/serl_obs_wrappers.py index 41c169f9..4bfbb721 100644 --- a/serl_launcher/serl_launcher/wrappers/serl_obs_wrappers.py +++ b/serl_launcher/serl_launcher/wrappers/serl_obs_wrappers.py @@ -23,3 +23,41 @@ def observation(self, obs): **(obs["images"]), } return obs + + +class ScaleObservationWrapper(gym.ObservationWrapper): + """ + This observation wrapper scales the observations with the provided hyperparams + (to somewhat normalize the observations space) + """ + + def __init__( + self, + env, + translation_scale=100.0, + rotation_scale=10.0, + force_scale=1.0, + torque_scale=10.0, + ): + super().__init__(env) + self.translation_scale = translation_scale + self.rotation_scale = rotation_scale + self.force_scale = force_scale + self.torque_scale = torque_scale + + def scale_wrapper_get_scales(self): + return dict( + translation_scale=self.translation_scale, + rotation_scale=self.rotation_scale, + force_scale=self.force_scale, + torque_scale=self.torque_scale, + ) + + def observation(self, obs): + obs["state"]["tcp_pose"][:3] *= self.translation_scale + obs["state"]["tcp_pose"][3:] *= self.rotation_scale + obs["state"]["tcp_vel"][:3] *= self.translation_scale + obs["state"]["tcp_vel"][3:] *= self.rotation_scale + obs["state"]["tcp_force"] *= self.force_scale + obs["state"]["tcp_torque"] *= self.torque_scale + return obs diff --git a/serl_robot_infra/franka_env/utils/transformations.py b/serl_robot_infra/franka_env/utils/transformations.py index 52a55527..3687237b 100644 --- a/serl_robot_infra/franka_env/utils/transformations.py +++ b/serl_robot_infra/franka_env/utils/transformations.py @@ -23,6 +23,14 @@ def construct_adjoint_matrix(tcp_pose): return adjoint_matrix +def construct_rotation_matrix(tcp_pose): + """ + Construct the adjoint matrix for a spatial velocity vector + :args: tcp_pose: (x, y, z, qx, qy, qz, qw) + """ + return 
R.from_quat(tcp_pose[3:]).as_matrix() + + def construct_homogeneous_matrix(tcp_pose): """ Construct the homogeneous transformation matrix from given pose. diff --git a/serl_robot_infra/robot_controllers/__init__.py b/serl_robot_infra/robot_controllers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/serl_robot_infra/robot_controllers/ur5_controller.py b/serl_robot_infra/robot_controllers/ur5_controller.py new file mode 100644 index 00000000..b0f04a7d --- /dev/null +++ b/serl_robot_infra/robot_controllers/ur5_controller.py @@ -0,0 +1,448 @@ +import datetime +import time +import threading +import asyncio +import numpy as np +from scipy.spatial.transform import Rotation as R +from rtde_control import RTDEControlInterface +from rtde_receive import RTDEReceiveInterface + +from ur_env.utils.vacuum_gripper import VacuumGripper +from ur_env.utils.rotations import rotvec_2_quat, quat_2_rotvec, pose2rotvec, pose2quat + +np.set_printoptions(precision=4, suppress=True) + + +def pos_difference(quat_pose_1: np.ndarray, quat_pose_2: np.ndarray): + assert quat_pose_1.shape == (7,) + assert quat_pose_2.shape == (7,) + p_diff = np.sum(np.abs(quat_pose_1[:3] - quat_pose_2[:3])) + + r_diff = ( + R.from_quat(quat_pose_1[3:]) * R.from_quat(quat_pose_2[3:]).inv() + ).magnitude() + return p_diff + r_diff + + +class UrImpedanceController(threading.Thread): + def __init__( + self, + robot_ip, + frequency=100, + kp=10000, + kd=2200, + config=None, + verbose=False, + *args, + **kwargs, + ): + super(UrImpedanceController, self).__init__(*args, **kwargs) + self._stop = threading.Event() + self._reset = threading.Event() + self._is_ready = threading.Event() + self._is_truncated = threading.Event() + self.lock = threading.Lock() + + self.robot_ip = robot_ip + self.frequency = frequency + self.kp = kp + self.kd = kd + self.gripper_timeout = { + "timeout": config.GRIPPER_TIMEOUT, + "last_grip": time.monotonic() - 1e6, + } + self.verbose = verbose + + self.target_pos = 
np.zeros( + (7,), dtype=np.float32 + ) # new as quat to avoid +- problems with axis angle repr. + self.target_grip = np.zeros((1,), dtype=np.float32) + self.curr_pos = np.zeros((7,), dtype=np.float32) + self.curr_vel = np.zeros((6,), dtype=np.float32) + self.gripper_state = np.zeros((2,), dtype=np.float32) + self.curr_Q = np.zeros((6,), dtype=np.float32) + self.curr_Qd = np.zeros((6,), dtype=np.float32) + self.curr_force_lowpass = np.zeros((6,), dtype=np.float32) # force of tool tip + self.curr_force = np.zeros((6,), dtype=np.float32) + + self.reset_Q = np.array( + [np.pi / 2.0, -np.pi / 2.0, np.pi / 2.0, -np.pi / 2.0, -np.pi / 2.0, 0.0], + dtype=np.float32, + ) # reset state in Joint Space + self.reset_Pose = np.zeros_like(self.reset_Q) + self.reset_height = np.array([0.1], dtype=np.float32) # TODO make customizable + + self.delta = config.ERROR_DELTA + self.fm_damping = config.FORCEMODE_DAMPING + self.fm_task_frame = config.FORCEMODE_TASK_FRAME + self.fm_selection_vector = config.FORCEMODE_SELECTION_VECTOR + self.fm_limits = config.FORCEMODE_LIMITS + + self.ur_control: RTDEControlInterface = None + self.ur_receive: RTDEReceiveInterface = None + self.robotiq_gripper: VacuumGripper = None + + # only temporary to test + self.hist_data = [[], []] + self.horizon = [0, 500] + self.err = 0 + self.noerr = 0 + + # log to file (reset every new run) + with open("/tmp/console2.txt", "w") as f: + f.write("reset\n") + self.second_console = open("/tmp/console2.txt", "a") + + def start(self): + super().start() + if self.verbose: + print(f"[RIC] Controller process spawned at {self.native_id}") + + def print(self, msg, both=False): + self.second_console.write(f"{datetime.datetime.now()} --> {msg}\n") + if both: + print(msg) + + async def start_ur_interfaces(self, gripper=True): + self.ur_control = RTDEControlInterface(self.robot_ip) + self.ur_receive = RTDEReceiveInterface(self.robot_ip) + if gripper: + self.robotiq_gripper = VacuumGripper(self.robot_ip) + await 
self.robotiq_gripper.connect() + await self.robotiq_gripper.activate() + if self.verbose: + gr_string = "(with gripper) " if gripper else "" + print(f"[RIC] Controller connected to robot {gr_string}at: {self.robot_ip}") + + async def restart_ur_interface(self): + self._is_truncated.set() + self.print("[RIC] forcemode failed, is now truncated!") + + # disconnect and reconnect, otherwise the controller won't take any commands + self.ur_control.disconnect() + try: + print(f"[RTDE] trying to reconnect") + self.ur_control.reconnect() + except RuntimeError: + self.ur_receive.disconnect() + for _ in range(10): + try: + self.ur_control.disconnect() + self.ur_receive.disconnect() + await self.start_ur_interfaces(gripper=False) + return + except Exception as e: + print(e) + time.sleep(0.2) + + def stop(self): + self._stop.set() + + def stopped(self): + return self._stop.is_set() + + def is_moving(self): + return np.linalg.norm(self.get_state()["vel"], 2) > 0.01 + + def set_target_pos(self, target_pos: np.ndarray): + if target_pos.shape == (7,): + target_orientation = target_pos[3:] + elif target_pos.shape == (6,): + target_orientation = rotvec_2_quat(target_pos[3:]) + else: + raise ValueError(f"[RIC] target pos has shape {target_pos.shape}") + + with self.lock: + self.target_pos[:3] = target_pos[:3] + self.target_pos[3:] = target_orientation + + self.print(f"target: {self.target_pos}") + + def set_reset_Q(self, reset_Q: np.ndarray): + with self.lock: + self.reset_Q[:] = reset_Q + self._reset.set() + + def set_reset_pose(self, reset_pose: np.ndarray): + with self.lock: + self.reset_Pose[:] = reset_pose + self._reset.set() + + def set_gripper_pos(self, target_grip: np.ndarray): + with self.lock: + self.target_grip[:] = target_grip + + def get_target_pos(self, copy=True): + with self.lock: + if copy: + return self.target_pos.copy() + else: + return self.target_pos + + async def _update_robot_state(self): + pos = self.ur_receive.getActualTCPPose() + vel = 
self.ur_receive.getActualTCPSpeed() + Q = self.ur_receive.getActualQ() + Qd = self.ur_receive.getActualQd() + force = self.ur_receive.getActualTCPForce() + pressure = await self.robotiq_gripper.get_current_pressure() + obj_status = await self.robotiq_gripper.get_object_status() + + # 3-> no object detected, 0-> sucking empty, [1, 2] obj detected + grip_status = [-1.0, 1.0, 1.0, 0.0][obj_status.value] + + pressure = ( + pressure if pressure < 99 else 0 + ) # 100 no obj, 99 sucking empty, so they are ignored + # grip status, 0->neutral, -1->bad (sucking but no obj), 1-> good (sucking and obj) + grip_status = 1.0 if pressure > 0 else grip_status + pressure /= 98.0 # pressure between [0, 1] + with self.lock: + self.curr_pos[:] = pose2quat(pos) + self.curr_vel[:] = vel + self.curr_Q[:] = Q + self.curr_Qd[:] = Qd + self.curr_force[:] = np.array(force) + # use moving average (5), since the force fluctuates heavily + self.curr_force_lowpass[:] = ( + 0.1 * np.array(force) + 0.9 * self.curr_force_lowpass[:] + ) + self.gripper_state[:] = [pressure, grip_status] + + def get_state(self): + with self.lock: + state = { + "pos": self.curr_pos, + "vel": self.curr_vel, + "Q": self.curr_Q, + "Qd": self.curr_Qd, + "force": self.curr_force_lowpass[:3], + "torque": self.curr_force_lowpass[3:], + "gripper": self.gripper_state, + } + return state + + def is_ready(self): + return self._is_ready.is_set() + + def is_reset(self): + return not self._reset.is_set() + + def _calculate_force(self): + target_pos = self.get_target_pos(copy=True) + with self.lock: + curr_pos = self.curr_pos + curr_vel = self.curr_vel + + # calc position for + kp, kd = self.kp, self.kd + diff_p = np.clip( + target_pos[:3] - curr_pos[:3], a_min=-self.delta, a_max=self.delta + ) + vel_delta = 2 * self.delta * self.frequency + diff_d = np.clip(-curr_vel[:3], a_min=-vel_delta, a_max=vel_delta) + force_pos = kp * diff_p + kd * diff_d + + # calc torque + rot_diff = R.from_quat(target_pos[3:]) * 
R.from_quat(curr_pos[3:]).inv() + vel_rot_diff = R.from_rotvec(curr_vel[3:]).inv() + torque = ( + rot_diff.as_rotvec() * 100 + vel_rot_diff.as_rotvec() * 22 + ) # TODO make customizable + + # check for big downward tcp force and adapt accordingly + if self.curr_force[2] > 3.5 and force_pos[2] < 0.0: + force_pos[2] = ( + max((1.5 - self.curr_force_lowpass[2]), 0.0) * force_pos[2] + + min(self.curr_force_lowpass[2] - 0.5, 1.0) * 20.0 + ) + + return np.concatenate((force_pos, torque)) + + async def send_gripper_command(self, force_release=False): + if force_release: + await self.robotiq_gripper.automatic_release() + self.target_grip[0] = 0.0 + return + + timeout_exceeded = ( + time.monotonic() - self.gripper_timeout["last_grip"] + ) * 1000 > self.gripper_timeout["timeout"] + # target grip above threshold and timeout exceeded and not gripping something already + if ( + self.target_grip[0] > 0.5 + and timeout_exceeded + and self.gripper_state[1] < 0.5 + ): + await self.robotiq_gripper.automatic_grip() + self.target_grip[0] = 0.0 + self.gripper_timeout["last_grip"] = time.monotonic() + # print("grip") + + # release if below neg threshold and gripper activated (grip_status not zero) + elif self.target_grip[0] < -0.5 and abs(self.gripper_state[1]) > 0.5: + await self.robotiq_gripper.automatic_release() + self.target_grip[0] = 0.0 + # print("release") + + def _truncate_check(self): + downward_force = self.curr_force_lowpass[2] > 20.0 + if downward_force: # TODO add better criteria + self._is_truncated.set() + else: + self._is_truncated.clear() + + def is_truncated(self): + return self._is_truncated.is_set() + + def run(self): + try: + asyncio.run( + self.run_async() + ) # gripper has to be awaited, both init and commands + finally: + self.stop() + + async def _go_to_reset_pose(self): + self.ur_control.forceModeStop() + + # first disable vaccum gripper + if self.robotiq_gripper: + await self.send_gripper_command(force_release=True) + time.sleep(0.01) + + # then move up (so 
no boxes are moved) + success = True + while self.curr_pos[2] < self.reset_height: + if ( + self.curr_Q[2] < 0.5 + ): # if the shoulder joint is near 180deg --> do not move into singularity + success = success and self.ur_control.speedJ( + [0.0, -1.0, 1.0, 0.0, 0.0, 0.0], acceleration=0.8 + ) + else: + success = success and self.ur_control.speedL( + [0.0, 0.0, 0.25, 0.0, 0.0, 0.0], acceleration=0.8 + ) + await self._update_robot_state() + time.sleep(0.01) + self.ur_control.speedStop(a=1.0) + + if self.reset_Pose.std() > 0.001: + success = success and self.ur_control.moveL( + self.reset_Pose, speed=0.5, acceleration=0.3 + ) + self.print( + f"[RIC] moving to {self.reset_Pose} with moveL (task space)", + both=self.verbose, + ) + self.reset_Pose[:] = 0.0 + else: + # then move to desired Jointspace position + success = success and self.ur_control.moveJ( + self.reset_Q, speed=1.0, acceleration=0.8 + ) + self.print( + f"[RIC] moving to {self.reset_Q} with moveJ (joint space)", + both=self.verbose, + ) + + time.sleep(0.1) # wait for 100ms + await self._update_robot_state() + with self.lock: + self.target_pos = self.curr_pos.copy() + + self.ur_control.forceModeSetDamping(self.fm_damping) # less damping = Faster + self.ur_control.zeroFtSensor() + + if not success: # restart if not successful + await self.restart_ur_interface() + else: + self._reset.clear() + + async def run_async(self): + await self.start_ur_interfaces(gripper=True) + + self.ur_control.forceModeSetDamping(self.fm_damping) # less damping = Faster + + try: + dt = 1.0 / self.frequency + self.ur_control.zeroFtSensor() + await self._update_robot_state() + self.target_pos = self.curr_pos.copy() + print(f"[RIC] target position set to curr pos: {self.target_pos}") + + self._is_ready.set() + + while not self.stopped(): + if self._reset.is_set(): + await self._update_robot_state() + await self._go_to_reset_pose() + + t_now = time.monotonic() + + # update robot state and check for truncation + await 
self._update_robot_state() + self._truncate_check() + + # calculate force + force = self._calculate_force() + # print(self.target_pos, self.curr_pos, force) + self.print( + f" p:{self.curr_pos} f:{self.curr_force_lowpass} gr:{self.gripper_state}" + ) # log to file + + # send command to robot + t_start = self.ur_control.initPeriod() + fm_successful = self.ur_control.forceMode( + self.fm_task_frame, + self.fm_selection_vector, + force, + 2, + self.fm_limits, + ) + if not fm_successful: # truncate if the robot ends up in a singularity + await self.restart_ur_interface() + await self._go_to_reset_pose() + + if self.robotiq_gripper: + await self.send_gripper_command() + + self.ur_control.waitPeriod(t_start) + + a = dt - (time.monotonic() - t_now) + time.sleep(max(0.0, a)) + self.err, self.noerr = self.err + int(a < 0.0), self.noerr + int( + a >= 0.0 + ) # some logging + if a < -0.04: # log if delay more than 50ms + self.print( + f"Controller Thread stopped for {(time.monotonic() - t_now)*1e3:.1f} ms" + ) + + finally: + if self.verbose: + print( + f"[RTDEPositionalController] >dt: {self.err}
5.13, auto_exposure False & auto_white_balance True, below only auto_exposure True + for sensor in self.profile.get_device().query_sensors(): + sensor.set_option(rs.option.enable_auto_exposure, True) + # sensor.set_option(rs.option.enable_auto_white_balance, True) + + # Create an align object + # rs.align allows us to perform alignment of depth frames to others frames + # The "align_to" is the stream type to which we plan to align depth frames. + align_to = rs.stream.color + self.align = rs.align(align_to) + + def read(self): + t = time.time() + frames = self.pipe.wait_for_frames() + tdiff = time.time() - t + if tdiff > 0.5: + print(f"wait for frames took {tdiff:.3f} seconds") + image, depth, pointcloud = None, None, None + + if self.rgb: + aligned_frames = self.align.process(frames) + color_frame = aligned_frames.get_color_frame() + + if color_frame.is_video_frame(): + image = np.asarray(color_frame.get_data()) + + if self.depth: + aligned_frames = self.align.process(frames) + depth_frame = aligned_frames.get_depth_frame() + + if depth_frame.is_depth_frame(): + depth = np.asanyarray(depth_frame.get_data()) + # clip max + depth = np.where( + (depth > self.max_clipping_distance), + 0.0, + self.max_clipping_distance - depth, + ) + + depth = (depth * (256.0 / self.max_clipping_distance)).astype(np.uint8) + depth = depth[..., None] + + if self.pointcloud: + depth_frame = self.decimation_filter.process(frames.get_depth_frame()) + depth_frame = self.threshold_filter.process(depth_frame) + depth_frame = self.temporal_filter.process(depth_frame) + if depth_frame.is_depth_frame(): + points = self.pc.calculate(depth_frame) + pointcloud = ( + np.asanyarray(points.get_vertices()).view(np.float32).reshape(-1, 3) + ) + + if isinstance(image, np.ndarray) and isinstance(depth, np.ndarray): + return True, np.concatenate((image, depth), axis=-1) + elif isinstance(image, np.ndarray): + return True, image + elif isinstance(depth, np.ndarray): + return True, depth + elif 
isinstance(pointcloud, np.ndarray): + return True, pointcloud + else: + return False, None + + def close(self): + self.pipe.stop() + self.cfg.disable_all_streams() diff --git a/serl_robot_infra/ur_env/camera/utils.py b/serl_robot_infra/ur_env/camera/utils.py new file mode 100644 index 00000000..aa6c6a75 --- /dev/null +++ b/serl_robot_infra/ur_env/camera/utils.py @@ -0,0 +1,307 @@ +import numpy as np +import open3d as o3d +from scipy.spatial.transform import Rotation as R +import threading +from typing import Any + + +def finetune_pointcloud_fusion(pc1: np.ndarray, pc2: np.ndarray): + pcd1, pcd2 = o3d.geometry.PointCloud(), o3d.geometry.PointCloud() + pcd1.points = o3d.utility.Vector3dVector(pc1) + pcd2.points = o3d.utility.Vector3dVector(pc2) + pcd1.estimate_normals() + pcd2.estimate_normals() + + def pairwise_registration(source, target, max_correspondence_distance): + # see https://www.open3d.org/docs/latest/tutorial/Advanced/multiway_registration.html + icp = o3d.pipelines.registration.registration_icp( + source, + target, + max_correspondence_distance, + np.eye(4), + o3d.pipelines.registration.TransformationEstimationPointToPlane(), + ) + transformation_icp = icp.transformation + information_icp = ( + o3d.pipelines.registration.get_information_matrix_from_point_clouds( + source, target, max_correspondence_distance, icp.transformation + ) + ) + return transformation_icp, information_icp + + with o3d.utility.VerbosityContextManager(o3d.utility.VerbosityLevel.Error) as cm: + transformation, info = pairwise_registration( + pcd1, pcd2, max_correspondence_distance=1e-3 + ) + + r = R.from_matrix(transformation[:3, :3].copy()).as_euler("xyz") + t = transformation[:3, 3].copy().flatten() + print(f"fusion result--> r: {r} t: {t}") + return transformation + + +def pointcloud_to_voxel_grid( + points: np.ndarray, + voxel_size: float, + min_bounds: np.ndarray, + max_bounds: np.ndarray, +): + points_filtered = crop_pointcloud( + points, min_bounds=min_bounds, 
max_bounds=max_bounds + ) + dimensions = np.ceil((max_bounds - min_bounds) / voxel_size).astype(int) + voxel_indices = ((points_filtered - min_bounds) / voxel_size).astype(int) + + voxel_grid = np.zeros(dimensions, dtype=np.bool_) + valid_indices = np.all((voxel_indices >= 0) & (voxel_indices < dimensions), axis=1) + voxel_grid[ + voxel_indices[valid_indices, 0], + voxel_indices[valid_indices, 1], + voxel_indices[valid_indices, 2], + ] = True + return voxel_grid, voxel_indices[valid_indices, :].astype(np.uint8) + + +def crop_pointcloud(points: np.ndarray, min_bounds: np.ndarray, max_bounds: np.ndarray): + within_bounds = np.all((points >= min_bounds) & (points <= max_bounds), axis=1) + return points[within_bounds] + + +def transform_point_cloud(points, transform_matrix): + if points.shape[1] == 3: + points = np.hstack([points, np.ones((points.shape[0], 1))]) + + transformed_points = np.dot(points, transform_matrix.T) + + if transformed_points.shape[1] == 4: + transformed_points = transformed_points[:, :3] + + return transformed_points + + +class PointCloudFusion: + def __init__( + self, + angle=30.0, + x_distance=0.195, + y_distance=-0.0, + voxel_grid_shape=(100, 100, 80), + ): + self.pcd1, self.pcd2 = None, None + + # 10cm width and 8cm height for the box + self.min_bounds = np.array([-0.05, -0.05, 0.075]) + self.max_bounds = np.array([0.05, 0.05, 0.155]) + + vox_size = (self.max_bounds - self.min_bounds) / voxel_grid_shape + assert np.all(np.isclose(vox_size, vox_size[0])) + self.voxel_size: float = float(vox_size[0]) + + self.original_pcds = [] + self._is_transformed = False + self.fine_transformed = False + + t1 = np.eye(4) + t1[:3, :3] = R.from_euler("xyz", [angle, 0.0, 0.0], degrees=True).as_matrix() + t1[1, 3] = x_distance / 2.0 + t1[0, 3] = y_distance / 2.0 + self.t1 = t1 + + t2 = np.eye(4) + t2[:3, :3] = R.from_euler("xyz", [-angle, 0.0, 0.0], degrees=True).as_matrix() + t2[1, 3] = -x_distance / 2.0 + t2[0, 3] = -y_distance / 2.0 + self.t2 = t2 + + def 
save_finetuned(self): + assert self.fine_transformed + t_finetuned = np.zeros((2, *self.t1.shape)) + t_finetuned[0, ...] = self.t1 + t_finetuned[1, ...] = self.t2 + with open("PointCloudFusionFinetuned.npy", "wb") as f: + np.save(f, t_finetuned) + + def get_voxelgrid_shape(self): + return np.ceil((self.max_bounds - self.min_bounds) / self.voxel_size).astype( + int + ) + + def load_finetuned(self): + from os.path import exists + + if not exists( + "/home/nico/real-world-rl/spacemouse_tests/PointCloudFusionFinetuned.npy" + ): + return False + with open( + "/home/nico/real-world-rl/spacemouse_tests/PointCloudFusionFinetuned.npy", + "rb", + ) as f: + t_finetuned = np.load(f) + self.t1 = t_finetuned[0, ...] + self.t2 = t_finetuned[1, ...] + self.fine_transformed = True + print(f"loaded finetuned Point Cloud fusion parameters!") + return True + + def append(self, pcd: np.ndarray): + if self.pcd1 is None: + self.original_pcds.append(pcd) + self.pcd1 = pcd + elif self.pcd2 is None: + self.original_pcds.append(pcd) + self.pcd2 = pcd + else: + raise NotImplementedError("3 pointclouds not supported") + + def calibrate_fusion(self): + assert self.is_complete() + # rough transform + if not self._is_transformed: + self._transform() + + # then calibrate + t = finetune_pointcloud_fusion(pc1=self.pcd1, pc2=self.pcd2) + return t + + def set_fine_tuned_transformation(self, transformation): + assert not self.fine_transformed + + t = transformation.copy()[:3, 3] / 2.0 # half the translation + rot = np.zeros((2, 3, 3)) + rot[0, ...] = transformation[:3, :3] + rot[1, ...] 
= np.eye(3) + r = R.from_matrix(rot).mean() # half the rotation + + t1_fine = np.eye(4) + t1_fine[:3, :3] = r.as_matrix() + t1_fine[:3, 3] = t + self.t1 = np.dot(self.t1, t1_fine) + + t2_fine = np.eye(4) + t2_fine[:3, :3] = r.inv().as_matrix() + t2_fine[:3, 3] = -t + self.t2 = np.dot(self.t2, t2_fine) + + self.fine_transformed = True + + def clear(self): + self.pcd1, self.pcd2 = None, None + self._is_transformed = False + self.original_pcds = [] + + def _transform(self): + assert not self.is_empty() + self.pcd1 = transform_point_cloud(points=self.pcd1, transform_matrix=self.t1) + if self.pcd2 is not None: + self.pcd2 = transform_point_cloud( + points=self.pcd2, transform_matrix=self.t2 + ) + self._is_transformed = True + + def voxelize(self, points: np.ndarray): + grid, indices = pointcloud_to_voxel_grid( + points, + voxel_size=self.voxel_size, + min_bounds=self.min_bounds, + max_bounds=self.max_bounds, + ) + return grid, indices + + def crop(self, points: np.ndarray): + return crop_pointcloud( + points=points, min_bounds=self.min_bounds, max_bounds=self.max_bounds + ) + + def get_pointcloud_representation(self, voxelize=True): + if self.is_complete(): + return self.fuse_pointclouds(voxelize=voxelize) + elif not self.is_empty(): + return self.get_first(voxelize=voxelize) + + def fuse_pointclouds(self, voxelize=True, cropped=True): + if not self._is_transformed: + self._transform() + swap = lambda x: np.moveaxis(x, 0, 1) + fused = swap(np.hstack([swap(self.pcd1), swap(self.pcd2)])) + return ( + self.voxelize(fused) + if voxelize + else (self.crop(fused) if cropped else fused) + ) + + def get_first(self, voxelize=True): + if not self._is_transformed: + self.pcd1 = transform_point_cloud(self.pcd1, transform_matrix=self.t1) + return self.voxelize(self.pcd1) if voxelize else self.crop(self.pcd1) + + def get_original_pcds(self): + if len(self.original_pcds) == 1: + return self.original_pcds[0] + else: + return self.original_pcds + + def is_complete(self): + return 
self.pcd1 is not None and self.pcd2 is not None + + def is_empty(self): + return self.pcd1 is None and self.pcd2 is None + + +class CalibrationTread(threading.Thread): + def __init__( + self, + pc_fusion: PointCloudFusion, + num_samples=20, + verbose=False, + *args, + **kwargs, + ): + super(CalibrationTread, self).__init__(*args, **kwargs) + self.pc_fusion = pc_fusion + self.samples = np.zeros((num_samples, 4, 4)) # transformation matrix samples + self.pc_backlog = [] + self.verbose = verbose + + def start(self): + super().start() + if self.verbose: + print(f"Calibration Thread started at {self.native_id}") + + def append_backlog(self, pc1, pc2): + self.pc_backlog.append([pc1, pc2]) + assert self.samples.shape[0] >= len(self.pc_backlog) + + def calibrate(self, visualize=False): + print(f"calibrating for {len(self.pc_backlog)} samples...") + for i, (pc1, pc2) in enumerate(self.pc_backlog): + self.pc_fusion.clear() + self.pc_fusion.append(pc1) + self.pc_fusion.append(pc2) + + self.samples[i, ...] 
= self.pc_fusion.calibrate_fusion() + + if visualize: + # visualize for testing + pc = self.pc_fusion.pcd1.copy() + pc2 = self.pc_fusion.pcd2.copy() + pc = transform_point_cloud( + points=pc, transform_matrix=self.samples[i] + ) # transform + + swap = lambda x: np.moveaxis(x, 0, 1) + fused = swap(np.hstack([swap(pc), swap(pc2)])) + + pc = o3d.geometry.PointCloud() + pc.points = o3d.utility.Vector3dVector(fused) + o3d.visualization.draw_geometries([pc]) + + rotations = R.from_matrix(self.samples[:, :3, :3]) + mean_rot = rotations.mean().as_matrix() + translation = np.mean(self.samples[:, :3, 3], axis=0) + + final = np.eye(4) + final[:3, :3] = mean_rot + final[:3, 3] = translation + print(f"calibration result: {final}") + self.pc_fusion.set_fine_tuned_transformation(final) diff --git a/serl_robot_infra/ur_env/camera/video_capture.py b/serl_robot_infra/ur_env/camera/video_capture.py new file mode 100644 index 00000000..842010ff --- /dev/null +++ b/serl_robot_infra/ur_env/camera/video_capture.py @@ -0,0 +1,39 @@ +import queue +import threading +import time + + +class VideoCapture: + def __init__(self, cap, name=None): + if name is None: + name = cap.name + self.name = name + self.q = queue.Queue() + self.cap = cap + self.t = threading.Thread(target=self._reader) + self.t.daemon = True + self.enable = True + self.t.start() + + # read frames as soon as they are available, keeping only most recent one + + def _reader(self): + while self.enable: + time.sleep(0.01) + ret, frame = self.cap.read() + if not ret: + break + if not self.q.empty(): + try: + self.q.get_nowait() # discard previous (unprocessed) frame + except queue.Empty: + pass + self.q.put(frame) + + def read(self): + return self.q.get(timeout=5) + + def close(self): + self.enable = False + self.t.join() + self.cap.close() diff --git a/serl_robot_infra/ur_env/envs/__init__.py b/serl_robot_infra/ur_env/envs/__init__.py new file mode 100644 index 00000000..b2b274f2 --- /dev/null +++ 
b/serl_robot_infra/ur_env/envs/__init__.py @@ -0,0 +1 @@ +from ur_env.envs.ur5_env import UR5Env, DefaultEnvConfig diff --git a/serl_robot_infra/ur_env/envs/basic_env/__init__.py b/serl_robot_infra/ur_env/envs/basic_env/__init__.py new file mode 100644 index 00000000..7de2b85c --- /dev/null +++ b/serl_robot_infra/ur_env/envs/basic_env/__init__.py @@ -0,0 +1 @@ +from ur_env.envs.basic_env.box_picking_basic_env import BoxPickingBasicEnv diff --git a/serl_robot_infra/ur_env/envs/basic_env/box_picking_basic_env.py b/serl_robot_infra/ur_env/envs/basic_env/box_picking_basic_env.py new file mode 100644 index 00000000..484eee06 --- /dev/null +++ b/serl_robot_infra/ur_env/envs/basic_env/box_picking_basic_env.py @@ -0,0 +1,30 @@ +import numpy as np +from typing import Tuple + +from ur_env.envs.ur5_env import UR5Env +from ur_env.envs.basic_env.config import UR5BasicConfig + + +class BoxPickingBasicEnv(UR5Env): + def __init__(self, **kwargs): + super().__init__(**kwargs, config=UR5BasicConfig) + + def compute_reward(self, obs, action) -> float: + # huge action gives negative reward (like in mountain car) + action_cost = 0.1 * np.sum(np.power(action, 2)) + step_cost = 0.01 + + gripper_state = obs["state"]["gripper_state"] + suction_cost = 0.1 * float(np.isclose(gripper_state[0], 0.99, atol=1e-4)) + + if self.reached_goal_state(obs): + return 10.0 - action_cost - step_cost - suction_cost + else: + return 0.0 - action_cost - step_cost - suction_cost + + def reached_goal_state(self, obs) -> bool: + # obs[0] == gripper pressure, obs[4] == force in Z-axis + state = obs["state"] + return ( + 0.1 < state["gripper_state"][0] < 0.85 and state["tcp_pose"][2] > 0.15 + ) # new min height with box diff --git a/serl_robot_infra/ur_env/envs/basic_env/config.py b/serl_robot_infra/ur_env/envs/basic_env/config.py new file mode 100644 index 00000000..b997ab6f --- /dev/null +++ b/serl_robot_infra/ur_env/envs/basic_env/config.py @@ -0,0 +1,29 @@ +from ur_env.envs.ur5_env import DefaultEnvConfig 
class UR5BasicConfig(DefaultEnvConfig):
    """Set the configuration for UR5Env (basic box-picking task)."""

    RESET_Q = np.array(
        [  # reset poses in joint space (multiple if preferred)
            [2.6331, -1.5022, 2.1151, -2.183, -1.5664, -0.4762],
            [1.983, -1.2533, 1.9069, -2.2314, -1.5495, 0.4462],
        ]
    )
    # randomize the reset pose each episode (used by UR5Env.go_to_rest)
    RANDOM_RESET = True
    # xy shift and rotation randomization ranges — presumably meters / MRP
    # magnitude; TODO confirm units
    RANDOM_XY_RANGE = (0.0,)
    RANDOM_ROT_RANGE = (0.04,)
    # safety-box bounds: first 3 entries clip xyz, last 3 clip orientation
    # (MRP relative to the reset pose) — see UR5Env.clip_safety_box
    ABS_POSE_LIMIT_HIGH = np.array([0.6, 0.1, 0.25, 0.05, 0.05, 0.2])
    ABS_POSE_LIMIT_LOW = np.array([-0.7, -0.85, -0.006, -0.05, -0.05, -0.2])
    # (low, high) bounds for UR5Env.xy_range — presumably radial xy distance;
    # TODO confirm against controller usage
    ABS_POSE_RANGE_LIMITS = np.array([0.36, 0.83])
    # per-component action scaling: [translation, rotation, gripper]
    # (see UR5Env.step)
    ACTION_SCALE = np.array([0.02, 0.1, 1.0], dtype=np.float32)

    ROBOT_IP: str = "172.22.22.2"
    CONTROLLER_HZ = 100
    GRIPPER_TIMEOUT = 2000  # in milliseconds
    ERROR_DELTA: float = 0.05
    # force-mode parameters forwarded to UrImpedanceController —
    # NOTE(review): semantics defined by the controller, not visible here
    FORCEMODE_DAMPING: float = 0.02
    FORCEMODE_TASK_FRAME = np.zeros(6)
    FORCEMODE_SELECTION_VECTOR = np.ones(6, dtype=np.int8)
    FORCEMODE_LIMITS = np.array([0.5, 0.5, 0.5, 1.0, 1.0, 1.0])
class BoxPickingCameraEnv(UR5Env):
    """Camera-based box-picking task with a shaped reward (+100 on success)."""

    def __init__(self, load_config=True, **kwargs):
        if load_config:
            super().__init__(**kwargs, config=UR5CameraConfig)
        else:
            super().__init__(**kwargs)

    def compute_reward(self, obs, action) -> float:
        """Shaped reward: suction shaping minus action/pose/step penalties."""
        state = obs["state"]

        action_cost = 0.1 * np.sum(np.square(action))
        # penalize jerky policies: squared distance to the previous action
        action_diff_cost = 0.1 * np.sum(
            np.square(state["action"] - self.last_action)
        )
        self.last_action[:] = action
        step_cost = 0.1

        # gripper_state[1] presumably encodes suction status — TODO confirm
        suction_reward = 0.3 * float(state["gripper_state"][1] > 0.5)
        suction_cost = 3.0 * float(state["gripper_state"][1] < -0.5)

        # quaternion dot product squared is 1 when aligned with the reset pose;
        # small deviations (< 0.005) are tolerated for free
        quat_dot = sum(state["tcp_pose"][3:] * self.curr_reset_pose[3:])
        orientation_cost = max(1.0 - quat_dot**2 - 0.005, 0.0) * 25.0

        # penalize xy drift beyond 5 cm from the reset position
        max_pose_diff = 0.05  # set to 5cm
        pos_diff = state["tcp_pose"][:2] - self.curr_reset_pose[:2]
        overshoot = np.abs(pos_diff - np.sign(pos_diff) * max_pose_diff)
        position_cost = 10.0 * np.sum(
            np.where(np.abs(pos_diff) > max_pose_diff, overshoot, 0.0)
        )

        total_cost = (
            action_cost
            + step_cost
            - suction_reward
            + suction_cost
            + orientation_cost
            + position_cost
            + action_diff_cost
        )
        cost_info = dict(
            action_cost=action_cost,
            step_cost=step_cost,
            suction_reward=suction_reward,
            suction_cost=suction_cost,
            orientation_cost=orientation_cost,
            position_cost=position_cost,
            action_diff_cost=action_diff_cost,
            total_cost=total_cost,
        )
        # accumulate per-episode sums; UR5Env.get_cost_infos clears them on done
        for key, value in cost_info.items():
            self.cost_infos[key] = value + self.cost_infos.get(key, 0.0)

        if self.reached_goal_state(obs):
            self.last_action[:] = 0.0
            return (
                100.0
                - action_cost
                - orientation_cost
                - position_cost
                - action_diff_cost
            )
        return (
            suction_reward
            - action_cost
            - orientation_cost
            - position_cost
            - suction_cost
            - step_cost
            - action_diff_cost
        )

    def reached_goal_state(self, obs) -> bool:
        """Success: gripper holds a box and TCP is >1 cm above the reset height."""
        # gripper_state[0] == gripper pressure
        state = obs["state"]
        has_box = 0.1 < state["gripper_state"][0] < 1.0
        lifted = state["tcp_pose"][2] > self.curr_reset_pose[2] + 0.01  # +1cm
        return has_box and lifted

    def close(self):
        super().close()
class UR5CameraConfig(DefaultEnvConfig):
    """Set the configuration for UR5Env (camera-based box-picking task)."""

    RESET_Q = np.array(
        [  # reset poses in joint space (multiple if preferred)
            [2.6331, -1.5022, 2.1151, -2.183, -1.5664, -0.4762],
            [1.983, -1.2533, 1.9069, -2.2314, -1.5495, 0.4462],
        ]
    )
    # randomize the reset pose each episode (used by UR5Env.go_to_rest)
    RANDOM_RESET = True
    # xy shift and rotation randomization ranges — presumably meters / MRP
    # magnitude; TODO confirm units
    RANDOM_XY_RANGE = (0.0,)
    RANDOM_ROT_RANGE = (0.04,)
    # safety-box bounds: first 3 entries clip xyz, last 3 clip orientation
    # (MRP relative to the reset pose) — see UR5Env.clip_safety_box
    ABS_POSE_LIMIT_HIGH = np.array([0.6, 0.1, 0.25, 0.05, 0.05, 0.2])
    ABS_POSE_LIMIT_LOW = np.array([-0.7, -0.85, -0.006, -0.05, -0.05, -0.2])
    # (low, high) bounds for UR5Env.xy_range — presumably radial xy distance;
    # TODO confirm against controller usage
    ABS_POSE_RANGE_LIMITS = np.array([0.36, 0.83])
    # per-component action scaling: [translation, rotation, gripper]
    # (see UR5Env.step)
    ACTION_SCALE = np.array([0.02, 0.1, 1.0], dtype=np.float32)

    ROBOT_IP: str = "172.22.22.2"
    CONTROLLER_HZ = 100
    GRIPPER_TIMEOUT = 2000  # in milliseconds
    ERROR_DELTA: float = 0.05
    # force-mode parameters forwarded to UrImpedanceController —
    # NOTE(review): semantics defined by the controller, not visible here
    FORCEMODE_DAMPING: float = 0.02
    FORCEMODE_TASK_FRAME = np.zeros(6)
    FORCEMODE_SELECTION_VECTOR = np.ones(6, dtype=np.int8)
    FORCEMODE_LIMITS = np.array([0.5, 0.5, 0.5, 1.0, 1.0, 1.0])

    # camera name -> RealSense serial number (see UR5Env.init_cameras)
    REALSENSE_CAMERAS = {"wrist": "218622277164", "wrist_2": "218622279756"}
class RelativeFrame(gym.Wrapper):
    """
    This wrapper transforms the observation and action to be expressed in the end-effector frame.
    Optionally, it can transform the tcp_pose into a relative frame defined as the reset pose.

    This wrapper is expected to be used on top of the base UR5 environment, which has the following
    observation space:
    {
        "state": spaces.Dict(
            {
                "tcp_pose": spaces.Box(-np.inf, np.inf, shape=(7,)),  # xyz + quat
                "tcp_vel": spaces.Box(-np.inf, np.inf, shape=(6,)),
                "tcp_force": spaces.Box(-np.inf, np.inf, shape=(3,)),
                "tcp_torque": spaces.Box(-np.inf, np.inf, shape=(3,)),
                "gripper_state": spaces.Box(-np.inf, np.inf, shape=(2,)),
            }
        ),
        ......
    }, and at least 6 DoF action space with (x, y, z, rx, ry, rz, ...)
    """

    def __init__(self, env: Env, include_relative_pose=True):
        super().__init__(env)
        # rotation of the CURRENT end-effector pose, refreshed every step
        self.rotation_matrix = np.eye((3))
        # rotation of the pose captured at reset; frozen for the episode
        self.rotation_matrix_reset = np.eye((3))

        self.include_relative_pose = include_relative_pose
        if self.include_relative_pose:
            # Homogeneous transformation matrix from reset pose's relative frame to base frame
            self.T_r_o_inv = np.zeros((4, 4))

    def step(self, action: np.ndarray):
        # action is assumed to be (x, y, z, rx, ry, rz, gripper)
        # Transform action from end-effector frame to base frame
        transformed_action = self.transform_action(action)

        obs, reward, done, truncated, info = self.env.step(transformed_action)

        # this is to convert the spacemouse intervention action back into the
        # frame the policy acts in
        if "intervene_action" in info:
            info["intervene_action"] = self.transform_action_inv(
                info["intervene_action"]
            )

        # Update rotation matrix to the new current end-effector orientation
        self.rotation_matrix = construct_rotation_matrix(obs["state"]["tcp_pose"])

        # Transform observation into the wrapper's frame
        transformed_obs = self.transform_observation(obs)
        return transformed_obs, reward, done, truncated, info

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)

        self.rotation_matrix = construct_rotation_matrix(obs["state"]["tcp_pose"])
        self.rotation_matrix_reset = self.rotation_matrix.copy()
        if self.include_relative_pose:
            # Update transformation matrix from the reset pose's relative frame to base frame
            self.T_r_o_inv = np.linalg.inv(
                construct_homogeneous_matrix(obs["state"]["tcp_pose"])
            )

        # Transform observation into the wrapper's frame
        return self.transform_observation(obs), info

    def transform_observation(self, obs):
        """
        Transform observations from spatial(base) frame into body(end-effector) frame
        using the rotation and homogeneous matrix.

        NOTE(review): velocities are rotated with the frozen RESET-pose rotation,
        while force/torque use the CURRENT end-effector rotation — confirm this
        asymmetry is intentional.
        """
        obs["state"]["tcp_vel"][:3] = (
            self.rotation_matrix_reset.transpose() @ obs["state"]["tcp_vel"][:3]
        )
        obs["state"]["tcp_vel"][3:6] = (
            self.rotation_matrix_reset.transpose() @ obs["state"]["tcp_vel"][3:6]
        )
        obs["state"]["tcp_force"] = (
            self.rotation_matrix.transpose() @ obs["state"]["tcp_force"]
        )
        obs["state"]["tcp_torque"] = (
            self.rotation_matrix.transpose() @ obs["state"]["tcp_torque"]
        )

        if self.include_relative_pose:
            # pose relative to the reset pose: T_b_r = T_r_o^-1 @ T_b_o
            T_b_o = construct_homogeneous_matrix(obs["state"]["tcp_pose"])
            T_b_r = self.T_r_o_inv @ T_b_o

            # Reconstruct transformed tcp_pose vector (xyz + quat)
            p_b_r = T_b_r[:3, 3]
            theta_b_r = R.from_matrix(T_b_r[:3, :3]).as_quat()
            obs["state"]["tcp_pose"] = np.concatenate((p_b_r, theta_b_r))

        return obs

    def transform_action(self, action: np.ndarray):
        """
        Transform action from body(end-effector) frame into spatial(base) frame
        using the reset-pose rotation matrix.
        """
        action = np.array(action)  # in case action is a jax read-only array
        action[:3] = self.rotation_matrix_reset @ action[:3]
        action[3:6] = self.rotation_matrix_reset @ action[3:6]
        return action

    def transform_action_inv(self, action: np.ndarray):
        """
        Transform action from spatial(base) frame into body(end-effector) frame
        using the reset-pose rotation matrix.
        """
        action = np.array(action)
        action[:3] = self.rotation_matrix_reset.transpose() @ action[:3]
        action[3:6] = self.rotation_matrix_reset.transpose() @ action[3:6]
        return action
+ """ + action = np.array(action) + action[:3] = self.rotation_matrix_reset.transpose() @ action[:3] + action[3:6] = self.rotation_matrix_reset.transpose() @ action[3:6] + return action diff --git a/serl_robot_infra/ur_env/envs/ur5_env.py b/serl_robot_infra/ur_env/envs/ur5_env.py new file mode 100644 index 00000000..37746109 --- /dev/null +++ b/serl_robot_infra/ur_env/envs/ur5_env.py @@ -0,0 +1,762 @@ +"""Gym Interface for UR5""" + +import time +import threading +import copy +import numpy as np +import gym +import cv2 +import queue +import warnings +import requests +import json +from typing import Dict, Tuple +from datetime import datetime +from collections import OrderedDict +from scipy.spatial.transform import Rotation as R +import open3d as o3d + +from ur_env.camera.video_capture import VideoCapture +from ur_env.camera.rs_capture import RSCapture + +from ur_env.camera.utils import PointCloudFusion, CalibrationTread + +from robot_controllers.ur5_controller import UrImpedanceController + + +class ImageDisplayer(threading.Thread): + def __init__(self, queue): + threading.Thread.__init__(self) + self.queue = queue + self.daemon = True # make this a daemon thread + + def run(self): + while True: + img_array = self.queue.get() # retrieve an image from the queue + if img_array is None: # None is our signal to exit + break + + frame = np.concatenate( + [v for k, v in img_array.items() if "full" not in k], axis=0 + ) + cv2.namedWindow("RealSense Cameras", cv2.WINDOW_NORMAL) + cv2.resizeWindow("RealSense Cameras", 300, 700) + cv2.imshow("RealSense Cameras", frame) + cv2.waitKey(1) + + +class PointCloudDisplayer: + def __init__(self): + self.window = o3d.visualization.Visualizer() + self.window.create_window(height=400, width=400, visible=True) + + self.pc = o3d.geometry.PointCloud() + self.window.get_render_option().load_from_json( + "/home/nico/.config/JetBrains/PyCharm2024.1/scratches/render_options.json" + ) + + self.param = o3d.io.read_pinhole_camera_parameters( + 
"/home/nico/.config/JetBrains/PyCharm2024.1/scratches/camera_parameters.json" + ) + self.ctr = self.window.get_view_control() + self.coord_frame = o3d.geometry.TriangleMesh.create_coordinate_frame( + size=0.01, origin=[0, 0, 0] + ) + + def display(self, points): + self.pc.clear() + # MASSIVE! speed up if float64 is used, see: https://github.com/isl-org/Open3D/issues/1045 + self.pc.points = o3d.utility.Vector3dVector(points.astype(np.float64) / 1000.0) + self.window.clear_geometries() + self.window.add_geometry(self.pc) + # self.window.add_geometry(self.coord_frame) + self.ctr.convert_from_pinhole_camera_parameters(self.param, True) + + self.window.poll_events() + # self.window.update_renderer() + + def close(self): + self.window.destroy_window() + + +############################################################################## + + +class DefaultEnvConfig: + """Default configuration for UR5Env. Fill in the values below.""" + + RESET_Q = np.zeros((6,)) + RANDOM_RESET = (False,) + RANDOM_XY_RANGE = (0.0,) + RANDOM_ROT_RANGE = (0.0,) + ABS_POSE_LIMIT_HIGH = np.zeros((6,)) + ABS_POSE_LIMIT_LOW = np.zeros((6,)) + ABS_POSE_RANGE_LIMITS = np.zeros((2,)) + ACTION_SCALE = np.zeros((3,), dtype=np.float32) + + ROBOT_IP: str = "localhost" + CONTROLLER_HZ: int = 0 + GRIPPER_TIMEOUT: int = 0 # in milliseconds + ERROR_DELTA: float = 0.0 + FORCEMODE_DAMPING: float = 0.0 + FORCEMODE_TASK_FRAME = np.zeros( + 6, + ) + FORCEMODE_SELECTION_VECTOR = np.ones( + 6, + ) + FORCEMODE_LIMITS = np.zeros( + 6, + ) + + REALSENSE_CAMERAS: Dict = { + "shoulder": "", + "wrist": "", + } + + +############################################################################## + + +class UR5Env(gym.Env): + def __init__( + self, + hz: int = 10, + fake_env=False, + config=DefaultEnvConfig, + max_episode_length: int = 100, + save_video: bool = False, + camera_mode: str = "rgb", # one of (rgb, grey, depth, both(rgb depth), pointcloud, none) + ): + self.max_episode_length = max_episode_length + 
    def __init__(
        self,
        hz: int = 10,
        fake_env=False,
        config=DefaultEnvConfig,
        max_episode_length: int = 100,
        save_video: bool = False,
        camera_mode: str = "rgb",  # one of (rgb, grey, depth, both(rgb depth), pointcloud, none)
    ):
        """Gym environment for a UR5 robot.

        Args:
            hz: target control rate of step() (rate-limited by sleeping).
            fake_env: if True, skip controller and camera startup entirely.
            config: a DefaultEnvConfig subclass with robot/camera settings.
            max_episode_length: step budget before done is forced.
            save_video: record display frames to ./videos on reset.
            camera_mode: which image observations to produce (see above).
        """
        self.max_episode_length = max_episode_length
        self.curr_path_length = 0
        self.action_scale = config.ACTION_SCALE

        self.config = config

        self.resetQ = config.RESET_Q
        self.curr_reset_pose = np.zeros((7,), dtype=np.float32)

        # latest robot state mirrors, refreshed by _update_currpos()
        self.curr_pos = np.zeros((7,), dtype=np.float32)
        self.curr_vel = np.zeros((6,), dtype=np.float32)
        self.curr_Q = np.zeros((6,), dtype=np.float32)
        self.curr_Qd = np.zeros((6,), dtype=np.float32)
        self.curr_force = np.zeros((3,), dtype=np.float32)
        self.curr_torque = np.zeros((3,), dtype=np.float32)

        self.gripper_state = np.zeros((2,), dtype=np.float32)
        self.random_reset = config.RANDOM_RESET
        self.random_xy_range = config.RANDOM_XY_RANGE
        self.random_rot_range = config.RANDOM_ROT_RANGE
        self.hz = hz
        np.random.seed(0)  # fix seed for fixed (random) initial rotations

        # "none" (string) is normalized to None == no image observations
        camera_mode = None if camera_mode.lower() == "none" else camera_mode
        if camera_mode is not None and save_video:
            print("Saving videos!")
        self.save_video = save_video
        self.recording_frames = []
        self.camera_mode = camera_mode

        # per-episode reward-term accumulators (see get_cost_infos)
        self.cost_infos = {}

        # safety boxes used by clip_safety_box(): xyz and orientation (MRP)
        self.xyz_bounding_box = gym.spaces.Box(
            config.ABS_POSE_LIMIT_LOW[:3],
            config.ABS_POSE_LIMIT_HIGH[:3],
            dtype=np.float64,
        )
        self.xy_range = gym.spaces.Box(
            config.ABS_POSE_RANGE_LIMITS[0],
            config.ABS_POSE_RANGE_LIMITS[1],
            dtype=np.float64,
        )
        self.mrp_bounding_box = gym.spaces.Box(
            config.ABS_POSE_LIMIT_LOW[3:],
            config.ABS_POSE_LIMIT_HIGH[3:],
            dtype=np.float64,
        )
        # Action/Observation Space: 7 DoF in [-1, 1] (xyz, mrp, gripper)
        self.action_space = gym.spaces.Box(
            np.ones((7,), dtype=np.float32) * -1,
            np.ones((7,), dtype=np.float32),
        )
        self.last_action = np.zeros(self.action_space.shape)

        # build the image observation space for the selected camera mode
        image_space_definition = {}
        if camera_mode in ["rgb", "grey", "both"]:
            channel = 1 if camera_mode == "grey" else 3
            if "wrist" in config.REALSENSE_CAMERAS.keys():
                image_space_definition["wrist"] = gym.spaces.Box(
                    0, 255, shape=(128, 128, channel), dtype=np.uint8
                )
            if "wrist_2" in config.REALSENSE_CAMERAS.keys():
                image_space_definition["wrist_2"] = gym.spaces.Box(
                    0, 255, shape=(128, 128, channel), dtype=np.uint8
                )

        if camera_mode in ["depth", "both"]:
            if "wrist" in config.REALSENSE_CAMERAS.keys():
                image_space_definition["wrist_depth"] = gym.spaces.Box(
                    0, 255, shape=(128, 128, 1), dtype=np.uint8
                )
            if "wrist_2" in config.REALSENSE_CAMERAS.keys():
                image_space_definition["wrist_2_depth"] = gym.spaces.Box(
                    0, 255, shape=(128, 128, 1), dtype=np.uint8
                )

        if camera_mode in ["pointcloud"]:
            # fused voxel-grid observation (see get_image)
            image_space_definition["wrist_pointcloud"] = gym.spaces.Box(
                0, 255, shape=(50, 50, 40), dtype=np.uint8
            )
        if camera_mode is not None and camera_mode not in [
            "rgb",
            "both",
            "depth",
            "pointcloud",
            "grey",
        ]:
            raise NotImplementedError(f"camera mode {camera_mode} not implemented")

        state_space = gym.spaces.Dict(
            {
                "tcp_pose": gym.spaces.Box(-np.inf, np.inf, shape=(7,)),  # xyz + quat
                "tcp_vel": gym.spaces.Box(-np.inf, np.inf, shape=(6,)),
                "gripper_state": gym.spaces.Box(-1.0, 1.0, shape=(2,)),
                "tcp_force": gym.spaces.Box(-np.inf, np.inf, shape=(3,)),
                "tcp_torque": gym.spaces.Box(-np.inf, np.inf, shape=(3,)),
                "action": gym.spaces.Box(-1.0, 1.0, shape=self.action_space.shape),
            }
        )

        obs_space_definition = {"state": state_space}
        if self.camera_mode in ["rgb", "both", "depth", "pointcloud", "grey"]:
            obs_space_definition["images"] = gym.spaces.Dict(image_space_definition)

        self.observation_space = gym.spaces.Dict(obs_space_definition)

        self.cycle_count = 0
        self.controller = None
        self.cap = None

        if fake_env:
            # no hardware: spaces are defined, but controller/cameras stay None
            print("[UR5Env] is fake!")
            return

        self.controller = UrImpedanceController(
            robot_ip=config.ROBOT_IP,
            frequency=config.CONTROLLER_HZ,
            kp=15000,
            kd=3300,
            config=config,
            verbose=False,
            plot=False,
        )
        self.controller.start()  # start Thread

        if self.camera_mode is not None:
            self.init_cameras(config.REALSENSE_CAMERAS)
            self.img_queue = queue.Queue()
            if self.camera_mode in ["pointcloud"]:
                self.displayer = (
                    PointCloudDisplayer()
                )  # o3d displayer cannot be threaded :/
            else:
                self.displayer = ImageDisplayer(self.img_queue)
                self.displayer.start()
            print("[CAM] Cameras are ready!")

        while not self.controller.is_ready():  # wait for controller
            time.sleep(0.1)
        print("[RIC] Controller has started and is ready!")

        if self.camera_mode in ["pointcloud"]:
            voxel_grid_shape = np.array(
                self.observation_space["images"]["wrist_pointcloud"].shape
            )
            # voxel_grid_shape[-1] *= 8  # do not use compacting for now
            # voxel_grid_shape *= 2
            print(f"pointcloud resolution set to: {voxel_grid_shape}")
            self.pointcloud_fusion = PointCloudFusion(
                angle=30.5,
                x_distance=0.185,
                y_distance=-0.01,
                voxel_grid_shape=voxel_grid_shape,
            )

            # load pre calibrated transformation, else calibrate interactively
            if not self.pointcloud_fusion.load_finetuned():
                # TODO make calibration more robust!
                self.calibration_thread = CalibrationTread(
                    pc_fusion=self.pointcloud_fusion, verbose=True
                )
                self.calibration_thread.start()

                self.calibrate_pointcloud_fusion(visualize=True)
    def clip_safety_box(self, next_pos: np.ndarray) -> np.ndarray:
        """Clip the pose to be within the safety box.

        Mutates and returns `next_pos` (xyz + quat): xyz is clipped to the
        Cartesian box, and the orientation is clipped in MRP space RELATIVE to
        the episode's reset orientation, then converted back to a quaternion.
        """
        next_pos[:3] = np.clip(
            next_pos[:3], self.xyz_bounding_box.low, self.xyz_bounding_box.high
        )
        # orientation offset from the reset orientation, expressed as MRP
        orientation_diff = (
            R.from_quat(next_pos[3:]) * R.from_quat(self.curr_reset_pose[3:]).inv()
        ).as_mrp()
        orientation_diff = np.clip(
            orientation_diff, self.mrp_bounding_box.low, self.mrp_bounding_box.high
        )
        # re-apply the clipped offset on top of the reset orientation
        next_pos[3:] = (
            R.from_mrp(orientation_diff) * R.from_quat(self.curr_reset_pose[3:])
        ).as_quat()

        return next_pos

    def get_cost_infos(self, done):
        """Return the accumulated reward-term breakdown at episode end.

        Returns {} while the episode is running; on `done`, returns the sums
        accumulated by compute_reward and resets the accumulator.
        """
        if not done:
            return {}
        cost_infos = self.cost_infos.copy()
        self.cost_infos = {}
        return cost_infos
    def step(self, action: np.ndarray) -> tuple:
        """Standard gym step function.

        `action` is 7-DoF in [-1, 1]: xyz translation, 3 rotation components
        (MRP), and a gripper command; each part is scaled by ACTION_SCALE.
        Returns (obs, reward, done, truncated, info) where info is the
        cost breakdown from get_cost_infos (non-empty only when done).
        """
        start_time = time.time()
        action = np.clip(action, self.action_space.low, self.action_space.high)

        # position: integrate the scaled translation into the current pose
        next_pos = self.curr_pos.copy()
        next_pos[:3] = next_pos[:3] + action[:3] * self.action_scale[0]

        # rotation: compose the scaled MRP increment onto the current quat
        # (the /4.0 further shrinks the increment — TODO confirm intent)
        next_pos[3:] = (
            R.from_mrp(action[3:6] * self.action_scale[1] / 4.0)
            * R.from_quat(next_pos[3:])
        ).as_quat()  # c * r --> applies c after r

        gripper_action = action[6] * self.action_scale[2]

        safe_pos = self.clip_safety_box(next_pos)
        self._send_pos_command(safe_pos)
        self._send_gripper_command(gripper_action)

        self.curr_path_length += 1

        obs = self._get_obs(action)

        reward = self.compute_reward(obs, action)
        truncated = self._is_truncated()
        reward = reward if not truncated else reward - 10.0  # truncation penalty
        done = (
            self.curr_path_length >= self.max_episode_length
            or self.reached_goal_state(obs)
            or truncated
        )

        # rate-limit the environment to self.hz by sleeping off the remainder
        dt = time.time() - start_time
        to_sleep = max(0, (1.0 / self.hz) - dt)
        if to_sleep == 0:
            warnings.warn(
                f"environment could not be within {self.hz} Hz, took {dt:.4f}s!"
            )
        time.sleep(to_sleep)

        return obs, reward, done, truncated, self.get_cost_infos(done)

    def compute_reward(self, obs, action) -> float:
        """Task reward; returns 0 here — overwrite for each task."""
        return 0.0  # overwrite for each task

    def reached_goal_state(self, obs) -> bool:
        """Task success predicate; always False here — overwrite for each task."""
        return False  # overwrite for each task
+ """ + + # Perform Carteasian reset + reset_Q = np.zeros((6)) + if self.resetQ.shape == (1, 6): + reset_Q[:] = self.resetQ.copy() + elif self.resetQ.shape[1] == 6 and self.resetQ.shape[0] > 1: + reset_Q[:] = self.resetQ[0, :].copy() # make random guess + self.resetQ[:] = np.roll(self.resetQ, -1, axis=0) # roll one (not random) + else: + raise ValueError(f"invalid resetQ dimension: {self.resetQ.shape}") + + self._send_reset_command(reset_Q) + + while not self.controller.is_reset(): + time.sleep(0.1) # wait for the reset operation + + self._update_currpos() + reset_pose = self.controller.get_target_pos() + + if self.random_reset: # randomize reset position in xy plane + reset_shift = np.random.uniform( + np.negative(self.random_xy_range), self.random_xy_range, (2,) + ) + reset_pose[:2] += reset_shift + + if self.random_rot_range[0] > 0.0: + random_rot = np.random.triangular( + np.negative(self.random_rot_range), + 0.0, + self.random_rot_range, + size=(3,), + ) + else: + random_rot = np.zeros((3,)) + reset_pose[3:][:] = ( + R.from_quat(reset_pose[3:]) * R.from_mrp(random_rot) + ).as_quat() + + self.curr_reset_pose[:] = reset_pose + + self.controller.set_target_pos( + reset_pose + ) # random movement after resetting + time.sleep(0.1) + while self.controller.is_moving(): + time.sleep(0.1) + else: + self.curr_reset_pose[:] = reset_pose + + def go_to_detected_box(self): + """ " + function for the demo + """ + if self.gripper_state[0] > 0.01: + reset_Q = self.curr_Q.copy() + reset_Q[:4] = [0.0, -np.pi / 2.0, np.pi / 2.0, -np.pi / 2.0] + self._send_reset_command(reset_Q) + while not self.controller.is_reset(): + time.sleep(0.1) # wait for the reset operation + + reset_Q[:4] = [np.pi / 2, -np.pi / 2.0, np.pi / 2.0, -np.pi / 2.0] + self._send_reset_command(reset_Q) + while not self.controller.is_reset(): + time.sleep(0.1) # wait for the reset operation + + # release the box + self._send_gripper_command(np.array(-1)) + time.sleep(0.1) + + # go back on top + reset_Q = [0.0, 
-np.pi / 2.0, np.pi / 2.0, -np.pi / 2.0, -np.pi / 2.0, 0.0] + self._send_reset_command(reset_Q) + while not self.controller.is_reset(): + time.sleep(0.1) # wait for the reset operation + time.sleep(0.5) + + def get_request(i=10): + if i == 0: + raise Exception("err") + try: + r = requests.get("http://192.168.1.204:5000/api/data") + r.raise_for_status() + boxes = r.json() + if len(boxes) == 0: + time.sleep(0.1) + return get_request(i) + else: + return boxes + + except (json.decoder.JSONDecodeError, requests.exceptions.HTTPError): + return get_request(i=i - 1) + + boxes = get_request() + + highest = list(boxes.keys())[ + np.argmax([b["world2box"]["pos"][1] for b in boxes.values()]) + ] + box = boxes[highest]["world2box"] + print( + f"pose: {[round(b, 2) for b in box['pos']]} {[round(b, 2) for b in box['rot']]}" + ) + + t = R.from_euler("xyz", [-np.pi / 2.0, np.pi, 0.0]) + pos = t.apply( + np.array(box["pos"]) + + np.array([0.0, 0.1 + boxes[highest]["size"][1] / 2.0, 0.0]) + ) + rot = ( + R.from_euler("xyz", t.apply(box["rot"])) + * R.from_euler("xyz", [np.pi, 0.0, 0.0]) + ).as_rotvec() + + init_pose = np.concatenate((pos, rot)) + + print(f"moving to {init_pose}") + self._send_taskspace_command(init_pose) + while not self.controller.is_reset(): + time.sleep(0.1) # wait for the reset operation + + self._update_currpos() + self.curr_reset_pose[:] = self.curr_pos + + def reset(self, **kwargs): + self.cycle_count += 1 + if self.save_video: + self.save_video_recording() + + self.go_to_rest() + self.curr_path_length = 0 + + obs = self._get_obs(np.zeros_like(self.last_action)) + return obs + + def save_video_recording(self): + try: + if len(self.recording_frames): + video_writer = cv2.VideoWriter( + f'./videos/{datetime.now().strftime("%m-%d_%H-%M")}.mp4', + cv2.VideoWriter_fourcc(*"mp4v"), + 10, + self.recording_frames[0].shape[:2][::-1], + ) + for frame in self.recording_frames: + video_writer.write(frame) + video_writer.release() + self.recording_frames.clear() + except 
    def init_cameras(self, name_serial_dict=None):
        """Init both cameras.

        Opens one threaded VideoCapture(RSCapture) per (name, serial) entry,
        enabling rgb/depth/pointcloud streams according to self.camera_mode.
        Re-initializing closes any cameras that are already open.
        """
        if self.cap is not None:  # close cameras if they are already open
            self.close_cameras()

        self.cap = OrderedDict()
        for cam_name, cam_serial in name_serial_dict.items():
            print(f"cam serial: {cam_serial}")
            rgb = self.camera_mode in ["rgb", "both", "grey"]
            depth = self.camera_mode in ["depth", "both"]
            pointcloud = self.camera_mode in ["pointcloud"]
            cap = VideoCapture(
                RSCapture(
                    name=cam_name,
                    serial_number=cam_serial,
                    rgb=rgb,
                    depth=depth,
                    pointcloud=pointcloud,
                )
            )
            self.cap[cam_name] = cap

    def crop_image(self, name, image) -> np.ndarray:
        """Crop realsense images to be a square.

        Keeps columns 124:604 (480 px wide) — square assuming 480-row frames;
        TODO confirm the configured RealSense resolution.
        """
        if name == "wrist":
            return image[:, 124:604, :]
        elif name == "wrist_2":
            return image[:, 124:604, :]
        else:
            raise ValueError(f"Camera {name} not recognized in cropping")
    def get_image(self) -> Dict[str, np.ndarray]:
        """Get images from the realsense cameras.

        Builds the image observations for the current camera_mode and pushes
        preview frames to the display thread. On a frozen camera (queue.Empty)
        it prompts the operator, re-opens the cameras and retries recursively.
        """
        images = {}
        display_images = {}
        if self.camera_mode == "pointcloud":
            self.pointcloud_fusion.clear()
        for key, cap in self.cap.items():
            try:
                image = cap.read()
                if self.camera_mode in ["rgb", "both", "grey"]:
                    rgb = image[..., :3].astype(np.uint8)
                    cropped_rgb = self.crop_image(key, rgb)
                    resized = cv2.resize(
                        cropped_rgb,
                        self.observation_space["images"][key].shape[:2][::-1],
                    )
                    # convert to grayscale here (luma weights)
                    if self.camera_mode == "grey":
                        grey = np.array([0.2989, 0.5870, 0.1140])
                        resized = np.dot(resized, grey)[..., None]
                        resized = resized.astype(np.uint8)
                        display_images[key] = np.repeat(resized, 3, axis=-1)
                    else:
                        display_images[key] = resized

                    # BGR -> RGB for the observation
                    images[key] = resized[..., ::-1]
                    display_images[key + "_full"] = cropped_rgb

                if self.camera_mode in ["depth", "both"]:
                    depth_key = key + "_depth"
                    depth = image[..., -1:]
                    cropped_depth = self.crop_image(key, depth)

                    # resize to 3x target, then 3x3 max-pool down to 128x128
                    resized = cv2.resize(
                        cropped_depth,
                        np.array(self.observation_space["images"][depth_key].shape[:2])
                        * 3,
                        # (128 * 3, 128 * 3) image
                    )[..., None]

                    resized = resized.reshape((128, 3, 128, 3, 1)).max(
                        (1, 3)
                    )  # max pool with 3x3

                    images[depth_key] = resized
                    display_images[depth_key] = cv2.applyColorMap(
                        resized, cv2.COLORMAP_JET
                    )
                    display_images[depth_key + "_full"] = cv2.applyColorMap(
                        cropped_depth, cv2.COLORMAP_JET
                    )

                if self.camera_mode in ["pointcloud"]:
                    pointcloud = image
                    self.pointcloud_fusion.append(pointcloud)

            except queue.Empty:
                # camera thread delivered nothing for 5 s; ask the operator,
                # re-open all cameras and retry from scratch
                input(
                    f"{key} camera frozen. Check connect, then press enter to relaunch..."
                )
                self.init_cameras(self.config.REALSENSE_CAMERAS)
                return self.get_image()

        if self.camera_mode in ["pointcloud"]:
            (
                voxel_grid,
                voxel_indices,
            ) = self.pointcloud_fusion.get_pointcloud_representation(voxelize=True)

            # downsample on 2x2x2 grid with sum of points (8 as max)
            # vs = self.observation_space["images"]["wrist_pointcloud"].shape
            # voxel_grid = np.sum(np.reshape(voxel_grid, (vs[0], 2, vs[1], 2, vs[2], 2)), axis=(1, 3, 5))
            images["wrist_pointcloud"] = voxel_grid.astype(np.uint8)

            self.displayer.display(voxel_indices)

        # self.recording_frames.append(
        #     np.concatenate([image for key, image in display_images.items() if "full" in key], axis=0)
        # )
        self.img_queue.put(display_images)

        return images
    def calibrate_pointcloud_fusion(self, save=True, visualize=False, num_samples=20):
        """Interactively calibrate the two-camera point-cloud fusion.

        Moves the robot through `num_samples` scripted poses, logging raw
        point-cloud pairs to the calibration thread, then runs the
        calibration, optionally saves/visualizes the result, and EXITS the
        process (the program must be restarted to use the new values).
        """
        self.reset()
        import open3d as o3d

        assert self.camera_mode in ["pointcloud"]
        print("calibrating pointcloud fusion...")
        # calibrate pc fusion here

        # show the uncalibrated fusion once for reference
        obs, reward, done, truncated, _ = self.step(np.zeros((7,)))
        pc = o3d.geometry.PointCloud()
        fused = self.pointcloud_fusion.fuse_pointclouds(voxelize=False, cropped=False)
        pc.points = o3d.utility.Vector3dVector(fused)
        o3d.visualization.draw_geometries([pc])

        # get samples: square-ish xy motion with constant rotation input
        for i in range(num_samples):
            # action = [np.sin(i * np.pi / 10.), np.cos(i * np.pi / 10.), 0., -.3 * np.sin(i * np.pi / 10.),
            #           -.3 * np.cos(i * np.pi / 10.), 0., 0.]
            action = [
                -1.0 if i % 4 < 2 else 1,
                -1.0 if i % 4 in [1, 2] else 1,
                0.0,
                0.0,
                0.0,
                1.0,
                0.0,
            ]

            print(action)
            obs, reward, done, truncated, _ = self.step(np.array(action))
            time.sleep(0.1)

            self.calibration_thread.append_backlog(
                *self.pointcloud_fusion.get_original_pcds()
            )

        # calibrate(): stop the robot first, then crunch the backlog
        self.controller.stop()
        time.sleep(1)
        self.calibration_thread.calibrate()

        if save:
            self.pointcloud_fusion.save_finetuned()

        if visualize:
            # replay every logged pair through the (now calibrated) fusion
            pc = o3d.geometry.PointCloud()
            for i in range(num_samples):
                pc.clear()
                pcs = self.calibration_thread.pc_backlog[i]
                self.pointcloud_fusion.clear()
                self.pointcloud_fusion.append(pcs[0])
                self.pointcloud_fusion.append(pcs[1])
                fused = self.pointcloud_fusion.fuse_pointclouds(
                    voxelize=False, cropped=False
                )
                pc.points = o3d.utility.Vector3dVector(fused)
                o3d.visualization.draw_geometries([pc])

        self.calibration_thread.join()
        exit(f"restart the program to use the calibrated values")

    def close_cameras(self):
        """Close both wrist cameras (best effort; failures are only logged)."""
        try:
            for cap in self.cap.values():
                cap.close()
        except Exception as e:
            print(f"Failed to close cameras: {e}")

    def _send_pos_command(self, target_pos: np.ndarray):
        """Internal function to send a position command to the robot."""
        self.controller.set_target_pos(target_pos=target_pos)

    def _send_gripper_command(self, gripper_pos: np.ndarray):
        # forwards the scaled gripper action to the controller
        self.controller.set_gripper_pos(gripper_pos)

    def _send_reset_command(self, reset_Q: np.ndarray):
        # joint-space reset target
        self.controller.set_reset_Q(reset_Q)

    def _send_taskspace_command(self, target_pos):
        # task-space (Cartesian) reset target
        self.controller.set_reset_pose(target_pos)
+ """ + state = self.controller.get_state() + + self.curr_pos[:] = state["pos"] + self.curr_vel[:] = state["vel"] + self.curr_force[:] = state["force"] + self.curr_torque[:] = state["torque"] + self.curr_Q[:] = state["Q"] + self.curr_Qd[:] = state["Qd"] + self.gripper_state[:] = state["gripper"] + + def _is_truncated(self): + return self.controller.is_truncated() + + def _get_obs(self, action) -> dict: + # get image before state observation, so they match better in time + + images = None + if self.camera_mode is not None: + images = self.get_image() + + self._update_currpos() + state_observation = { + "tcp_pose": self.curr_pos, + "tcp_vel": self.curr_vel, + "gripper_state": self.gripper_state, + "tcp_force": self.curr_force, + "tcp_torque": self.curr_torque, + "action": action, + } + + if images is not None: + return copy.deepcopy(dict(images=images, state=state_observation)) + else: + return copy.deepcopy(dict(state=state_observation)) + + def close(self): + if self.controller: + self.controller.stop() + super().close() diff --git a/serl_robot_infra/ur_env/envs/wrappers.py b/serl_robot_infra/ur_env/envs/wrappers.py new file mode 100644 index 00000000..8a8f8038 --- /dev/null +++ b/serl_robot_infra/ur_env/envs/wrappers.py @@ -0,0 +1,210 @@ +import gym +import numpy as np +from agentlace import action + +from ur_env.spacemouse.spacemouse_expert import SpaceMouseExpert +import time +from scipy.spatial.transform import Rotation as R + +from ur_env.utils.rotations import quat_2_euler, quat_2_mrp + +ROT90 = np.array([[0, -1, 0], [1, 0, 0], [0, 0, 1]]) +ROT_GENERAL = np.array([np.eye(3), ROT90, ROT90 @ ROT90, ROT90.transpose()]) + + +class SpacemouseIntervention(gym.ActionWrapper): + def __init__(self, env, gripper_action_span=3): + super().__init__(env) + + self.gripper_enabled = True + + self.expert = SpaceMouseExpert() + self.last_intervene = 0 + self.left = np.array([False] * gripper_action_span, dtype=np.bool_) + self.right = self.left.copy() + + self.invert_axes = 
def action(self, action: np.ndarray) -> np.ndarray:
    """
    Input:
    - action: policy action
    Output:
    - action: spacemouse action if nonzero; else, policy action
    """
    expert_a = self.get_deadspace_action()

    if (
        np.linalg.norm(expert_a) > 0.001 or self.left.any() or self.right.any()
    ):  # also read buttons with no movement
        self.last_intervene = time.time()

    if self.gripper_enabled:
        # +1 while the "close" button is latched, -1 for "open", 0 otherwise
        gripper_action = (
            np.zeros((1,)) + int(self.left.any()) - int(self.right.any())
        )
        expert_a = np.concatenate((expert_a, gripper_action), axis=0)

    # keep overriding the policy for 0.5 s after the last detected input
    if time.time() - self.last_intervene < 0.5:
        expert_a = self.adapt_spacemouse_output(expert_a)
        return expert_a

    return action

def get_deadspace_action(self) -> np.ndarray:
    """Read the SpaceMouse, apply a symmetric deadzone, and latch button history."""
    expert_a, buttons = self.expert.get_action()

    # rescale each axis so the output is continuous at the deadzone boundary
    positive = np.clip(
        (expert_a - self.deadspace) / (1.0 - self.deadspace), a_min=0.0, a_max=1.0
    )
    negative = np.clip(
        (expert_a + self.deadspace) / (1.0 - self.deadspace), a_min=-1.0, a_max=0.0
    )
    expert_a = positive + negative

    self.left, self.right = np.roll(self.left, -1), np.roll(
        self.right, -1
    )  # shift them one to the left
    # fix: some SpaceMouse models report more than two buttons — only the first
    # two are used; the previous tuple-unpack raised ValueError in that case
    self.left[-1], self.right[-1] = bool(buttons[0]), bool(buttons[1])

    return np.array(expert_a, dtype=np.float32)

def adapt_spacemouse_output(self, action: np.ndarray) -> np.ndarray:
    """
    Input:
    - action: spacemouse raw output
    Output:
    - action: spacemouse output adapted to force space (action)
    """
    position = self.unwrapped.curr_pos  # get position from ur_env
    z_angle = np.arctan2(position[1], position[0])  # get first joint angle

    z_rot = R.from_rotvec(np.array([0, 0, z_angle]))
    action[:6] *= self.invert_axes  # if some want to be inverted
    action[:3] = z_rot.apply(action[:3])  # z rotation invariant translation
    action[3:6] = z_rot.apply(action[3:6])  # z rotation invariant rotation

    return action

def step(self, action):
    """Step the env with the (possibly expert-overridden) action, exposing it in info."""
    new_action = self.action(action)
    obs, rew, done, truncated, info = self.env.step(new_action)
    info["intervene_action"] = new_action
    info["left"] = self.left.any()
    info["right"] = self.right.any()
    return obs, rew, done, truncated, info
action: {new_action}") + obs, rew, done, truncated, info = self.env.step(new_action) + info["intervene_action"] = new_action + info["left"] = self.left.any() + info["right"] = self.right.any() + return obs, rew, done, truncated, info + + +class Quat2EulerWrapper( + gym.ObservationWrapper +): # not used anymore (stay away from euler angles!) + """ + Convert the quaternion representation of the tcp pose to euler angles + """ + + def __init__(self, env: gym.Env): + super().__init__(env) + # from xyz + quat to xyz + euler + self.observation_space["state"]["tcp_pose"] = gym.spaces.Box( + -np.inf, np.inf, shape=(6,) + ) + + def observation(self, observation): + # convert tcp pose from quat to euler + tcp_pose = observation["state"]["tcp_pose"] + observation["state"]["tcp_pose"] = np.concatenate( + (tcp_pose[:3], quat_2_euler(tcp_pose[3:])) + ) + return observation + + +class Quat2MrpWrapper(gym.ObservationWrapper): + """ + Convert the quaternion representation of the tcp pose to euler angles + """ + + def __init__(self, env: gym.Env): + super().__init__(env) + # from xyz + quat to xyz + euler + self.observation_space["state"]["tcp_pose"] = gym.spaces.Box( + -np.inf, np.inf, shape=(6,) + ) + + def observation(self, observation): + # convert tcp pose from quat to euler + tcp_pose = observation["state"]["tcp_pose"] + observation["state"]["tcp_pose"] = np.concatenate( + (tcp_pose[:3], quat_2_mrp(tcp_pose[3:])) + ) + return observation + + +def rotate_state(state: np.ndarray, num_rot: int): + assert len(state.shape) == 1 and state.shape[0] % 3 == 0 + state = state.reshape((-1, 3)).transpose() + rotated = np.dot(ROT_GENERAL[num_rot % 4], state).transpose() + return rotated.reshape((-1)) + + +class ObservationRotationWrapper(gym.Wrapper): + """ + Convert every observation into the first and 5th octant (first quadrant in Z top view) of the Relative Frame + """ + + def __init__(self, env: gym.Env): + super().__init__(env) + print("Observation Rotation Wrapper enabled!") + 
def reset(self, **kwargs):
    """Reset the wrapped env and rotate the initial observation into a random quadrant.

    Fix: reset kwargs (e.g. seed=..., options=...) are now forwarded to the
    wrapped env; previously they were silently dropped.
    """
    obs, info = self.env.reset(**kwargs)
    obs = self.rotate_observation(obs, random=True)  # rotate initial state random
    return obs, info

def step(self, action: np.ndarray):
    """Un-rotate the incoming action, step the env, rotate the outgoing observation."""
    action = self.rotate_action(action=action)
    obs, reward, done, truncated, info = self.env.step(action)
    rotated_obs = self.rotate_observation(obs)
    return rotated_obs, reward, done, truncated, info

def rotate_observation(self, observation, random=False):
    """Rotate all state 3-vectors (and images, by 90 deg steps) into the canonical quadrant.

    Mutates `observation` in place and records the quadrant in num_rot_quadrant.
    """
    if not random:
        x, y = observation["state"]["tcp_pose"][:2]
        self.num_rot_quadrant = int(x < 0.0) * 2 + int(
            x * y < 0.0
        )  # save quadrant info
    else:
        self.num_rot_quadrant = (
            int(time.time_ns()) % 4
        )  # do not mess with seeded np.random

    for state in observation["state"].keys():
        if state == "gripper_state":
            continue  # scalar, rotation-invariant
        elif state == "action":
            observation["state"][state][:6] = rotate_state(
                observation["state"][state][:6], self.num_rot_quadrant
            )
        else:
            observation["state"][state][:] = rotate_state(
                observation["state"][state], self.num_rot_quadrant
            )  # rotate

    if "images" in observation:
        for image_keys in observation["images"].keys():
            observation["images"][image_keys][:] = np.rot90(
                observation["images"][image_keys],
                axes=(0, 1),
                k=self.num_rot_quadrant,
            )
    return observation

def rotate_action(self, action):
    """Rotate an action from the canonical quadrant back into the real frame."""
    rotated_action = action.copy()
    rotated_action[:6] = rotate_state(
        action[:6], 4 - self.num_rot_quadrant
    )  # inverse rotation
    return rotated_action
class SpaceMouseExpert:
    """
    Threaded interface to the SpaceMouse.

    A daemon thread continuously polls the device; get_action() returns the
    most recent axes/buttons snapshot under a lock.
    """

    def __init__(self):
        pyspacemouse.open()

        self.state_lock = threading.Lock()
        self.latest_data = {"action": np.zeros(6), "buttons": [0, 0]}
        # daemon thread: dies with the process, never blocks interpreter exit
        self.thread = threading.Thread(target=self._read_spacemouse, daemon=True)
        self.thread.start()

    def _read_spacemouse(self):
        """Poll the device forever, publishing the newest reading under the lock."""
        while True:
            state = pyspacemouse.read()
            # spacemouse axes remapped to match the robot base frame
            axes = np.array(
                [-state.y, state.x, state.z, -state.roll, -state.pitch, -state.yaw]
            )
            with self.state_lock:
                self.latest_data["action"] = axes
                self.latest_data["buttons"] = state.buttons

    def get_action(self) -> Tuple[np.ndarray, list]:
        """Returns the latest action and button state of the SpaceMouse."""
        with self.state_lock:
            return self.latest_data["action"], self.latest_data["buttons"]
def rotvec_2_quat(rotvec):
    """Axis-angle rotation vector -> quaternion (x, y, z, w)."""
    return R.from_rotvec(rotvec).as_quat()


def quat_2_rotvec(quat):
    """Quaternion (x, y, z, w) -> axis-angle rotation vector."""
    return R.from_quat(quat).as_rotvec()


def quat_2_euler(quat):
    """Quaternion -> intrinsic 'xyz' euler angles (radians)."""
    return R.from_quat(quat).as_euler("xyz")


def quat_2_mrp(quat):
    """Quaternion -> Modified Rodrigues Parameters."""
    return R.from_quat(quat).as_mrp()


def euler_2_quat(euler):
    """'xyz' euler angles -> quaternion.

    Fix: R.from_euler requires the axis sequence as its first argument;
    the previous call R.from_euler(euler) raised a TypeError. "xyz" is used
    to stay the inverse of quat_2_euler above.
    """
    return R.from_euler("xyz", euler).as_quat()


def pose2quat(rotvec_pose) -> np.ndarray:
    """[xyz, rotvec] pose -> [xyz, quaternion] pose."""
    return np.concatenate((rotvec_pose[:3], rotvec_2_quat(rotvec_pose[3:])))


def pose2rotvec(quat_pose) -> np.ndarray:
    """[xyz, quaternion] pose -> [xyz, rotvec] pose."""
    return np.concatenate((quat_pose[:3], quat_2_rotvec(quat_pose[3:])))
charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +Module to control Robotiq's gripper 2F-85 and Hand-E. +Originally from here: https://sdurobotics.gitlab.io/ur_rtde/_static/gripper_2f85.py +Adjusted for use with asyncio +""" + +import asyncio +from enum import Enum +from typing import Union, Tuple, OrderedDict + +# TODO: add blocking to release, gripping + + +class VacuumGripper: + """ + Communicates with the gripper directly, via socket with string commands, leveraging string names for variables. 
+ """ + + # WRITE VARIABLES (CAN ALSO READ) + ACT = ( + "ACT" # act : activate (1 while activated, can be reset to clear fault status) + ) + GTO = ( + "GTO" # gto : go to (will perform go to with the actions set in pos, for, spe) + ) + ATR = "ATR" # atr : auto-release (emergency slow move) + FOR = "FOR" # for : vacuum minimum relative pressure (0-255) + SPE = "SPE" # spe : grip timeout/release delay + POS = "POS" # pos : vacuum max pressure (0-255) + MOD = "MOD" # mod : mode - automatic vs advanced mode + + # READ VARIABLES + STA = "STA" # status (0 = is reset, 1 = activating, 3 = active) + PRE = "PRE" # position request (echo of last commanded position) + OBJ = "OBJ" # object detection (0 = unknown, 1 = minimum pressure value reached, 2 = maximum pressure reached, 3 = no obj detected) + FLT = "FLT" # fault (0=ok, see manual for errors if not zero) + + ENCODING = "UTF-8" # ASCII and UTF-8 both seem to work + + class GripperStatus(Enum): + """Gripper status reported by the gripper. The integer values have to match what the gripper sends.""" + + RESET = 0 + ACTIVATING = 1 + # UNUSED = 2 # This value is currently not used by the gripper firmware + ACTIVE = 3 + + class ObjectStatus(Enum): + """Object status reported by the gripper. The integer values have to match what the gripper sends.""" + + MOVING = 0 + DETECTED_MIN = 1 + DETECTED_MAX = 2 + NO_OBJ_DETECTED = 3 + + def __init__(self, hostname: str, port: int = 63352) -> None: + """Constructor. + + :param hostname: Hostname or ip of the robot arm. + :param port: Port. 
+ + """ + self.socket_reader = None + self.socket_writer = None + self.command_lock = asyncio.Lock() + + self.hostname = hostname + self.port = port + + async def connect(self) -> None: + """Connects to a gripper on the provided address""" + # print(self.hostname, self.port) + self.socket_reader, self.socket_writer = await asyncio.open_connection( + self.hostname, self.port + ) + + async def disconnect(self) -> None: + """Closes the connection with the gripper.""" + self.socket_writer.close() + await self.socket_writer.wait_closed() + + async def _set_vars(self, var_dict: "OrderedDict[str, Union[int, float]]") -> bool: + """Sends the appropriate command via socket to set the value of n variables, and waits for its 'ack' response. + + :param var_dict: Dictionary of variables to set (variable_name, value). + :return: True on successful reception of ack, false if no ack was received, indicating the set may not + have been effective. + """ + # construct unique command + cmd = "SET" + for variable, value in var_dict.items(): + cmd += f" {variable} {str(value)}" + cmd += "\n" # new line is required for the command to finish + # atomic commands send/rcv + async with self.command_lock: + self.socket_writer.write(cmd.encode(self.ENCODING)) + await self.socket_writer.drain() + response = await self.socket_reader.read(1024) + return self._is_ack(response) + + async def _set_var(self, variable: str, value: Union[int, float]) -> bool: + """Sends the appropriate command via socket to set the value of a variable, and waits for its 'ack' response. + + :param variable: Variable to set. + :param value: Value to set for the variable. + :return: True on successful reception of ack, false if no ack was received, indicating the set may not + have been effective. 
+ """ + return await self._set_vars(OrderedDict([(variable, value)])) + + async def _get_var(self, variable: str) -> int: + """Sends the appropriate command to retrieve the value of a variable from the gripper, blocking until the + response is received or the socket times out. + + :param variable: Name of the variable to retrieve. + :return: Value of the variable as integer. + """ + # atomic commands send/rcv + async with self.command_lock: + cmd = f"GET {variable}\n" + self.socket_writer.write(cmd.encode(self.ENCODING)) + await self.socket_writer.drain() + data = await self.socket_reader.read(1024) + + # expect data of the form 'VAR x', where VAR is an echo of the variable name, and X the value + # note some special variables (like FLT) may send 2 bytes, instead of an integer. We assume integer here + var_name, value_str = data.decode(self.ENCODING).split() + if var_name != variable: + raise ValueError( + f"Unexpected response {data} ({data.decode(self.ENCODING)}): does not match '{variable}'" + ) + value = int(value_str) + return value + + @staticmethod + def _is_ack(data: str) -> bool: + return data == b"ack" + + async def activate(self) -> None: + """Resets the activation flag in the gripper, and sets it back to one, clearing previous fault flags. + + :param auto_calibrate: Whether to calibrate the minimum and maximum positions based on actual motion. 
+ """ + # stop the vacuum generator + await self._set_var(self.GTO, 0) + # await self._set_var(self.GTO, 1) + + # to clear fault status + await self._set_var(self.ACT, 0) + await self._set_var(self.ACT, 1) + + # wait for activation to go through + while not await self.is_active(): + await asyncio.sleep(0.01) + + async def is_active(self) -> bool: + """Returns whether the gripper is active.""" + status = await self._get_var(self.STA) + return VacuumGripper.GripperStatus(status) == VacuumGripper.GripperStatus.ACTIVE + + async def get_current_pressure(self) -> int: + """Returns the current pressure as returned by the physical hardware, max pressure if not gripping.""" + return await self._get_var(self.POS) + + async def get_object_status(self) -> ObjectStatus: + a = await self._get_var(self.OBJ) + return VacuumGripper.ObjectStatus(a) + + async def get_fault_status(self) -> int: + value = await self._get_var(self.FLT) + return value + + async def automatic_grip(self) -> bool: + """Sends commands to grip using automatic mode. + In automatic mode, the pressure byte is used to send a grip/release request + + :return: A tuple with a bool indicating whether the action it was successfully sent, and an integer with + the actual position that was requested, after being adjusted to the min/max calibrated range. + """ + + # activate sets GTO to 0 and makes sure that the gripper is activated + await self.activate() + + # in automatic mode, any pressure (POS) value < 100 will lead to the grip command + var_dict = OrderedDict([(self.POS, 50), (self.MOD, 0)]) + + # first set the values, then set GTO + await self._set_vars(var_dict) + await self._set_var(self.GTO, 1) + + async def advanced_grip(self, min_pressure, max_pressure, timeout) -> bool: + """Sends commands to grip in advanced mode. 
async def continuous_grip(self, timeout) -> bool:
    """Run the vacuum generator continuously (advanced mode, POS = 0).

    :param timeout: grip timeout in ms, clipped to [0, 255].
    :return: True iff the gripper acknowledged the command.
    """
    clipped_timeout = min(max(0, timeout), 255)

    request = OrderedDict(
        [(self.MOD, 1), (self.POS, 0), (self.SPE, clipped_timeout), (self.GTO, 1)]
    )
    return await self._set_vars(request)

async def advanced_release(self, min_pressure, max_pressure, timeout) -> bool:
    """Sends commands to do advanced release. This allows to do a more controlled release than the automatic one.

    NOTE(review): min_pressure, max_pressure and timeout are currently unused —
    confirm whether they should be wired into the request.
    """
    request = OrderedDict([(self.POS, 255), (self.GTO, 1)])
    return await self._set_vars(request)

async def automatic_release(self) -> bool:
    """Sends commands to do automatic release (emergency slow move via ATR)."""
    return await self._set_vars(OrderedDict([(self.ACT, 1), (self.ATR, 1)]))