diff --git a/examples/box_picking_drq/drq_policy.py b/examples/box_picking_drq/drq_policy.py new file mode 100644 index 00000000..7b9a7219 --- /dev/null +++ b/examples/box_picking_drq/drq_policy.py @@ -0,0 +1,660 @@ +#!/usr/bin/env python3 +import copy +import time +from functools import partial +import jax +import jax.numpy as jnp +import numpy as np +import pynput +import threading +import tqdm +from absl import app, flags +from flax.training import checkpoints +from datetime import datetime + +import gym +from gym.wrappers.record_episode_statistics import RecordEpisodeStatistics + +from serl_launcher.agents.continuous.drq import DrQAgent +from serl_launcher.common.evaluation import evaluate +from serl_launcher.utils.timer_utils import Timer +from serl_launcher.wrappers.chunking import ChunkingWrapper +from serl_launcher.utils.sampling_utils import TemporalActionEnsemble +from serl_launcher.utils.train_utils import ( + print_agent_params, + parameter_overview, +) + +from agentlace.trainer import TrainerServer, TrainerClient +from agentlace.data.data_store import QueuedDataStore + +from serl_launcher.utils.launcher import ( + make_voxel_drq_agent, + make_trainer_config, + make_wandb_logger, +) +from serl_launcher.data.data_store import MemoryEfficientReplayBufferDataStore +from serl_launcher.wrappers.serl_obs_wrappers import ( + SERLObsWrapper, + ScaleObservationWrapper, +) +from serl_launcher.wrappers.observation_statistics_wrapper import ( + ObservationStatisticsWrapper, +) +from ur_env.envs.relative_env import RelativeFrame +from ur_env.envs.wrappers import ( + SpacemouseIntervention, + Quat2MrpWrapper, + ObservationRotationWrapper, +) + +import ur_env + +# used to debug nan errors (also in jit-ed functions) +# jax.config.update("jax_debug_nans", True) + +devices = jax.local_devices() +num_devices = len(devices) +sharding = jax.sharding.PositionalSharding(devices) + +FLAGS = flags.FLAGS + +flags.DEFINE_string("env", "box_picking_camera_env", "Name of 
environment.") +flags.DEFINE_string("agent", "drq", "Name of agent.") +flags.DEFINE_string( + "exp_name", "box picking drq", "Name of the experiment for wandb logging." +) +flags.DEFINE_integer("max_traj_length", 100, "Maximum length of trajectory.") +flags.DEFINE_string( + "camera_mode", "rgb", "Camera mode, one of (rgb, depth, both, pointcloud)" +) + +flags.DEFINE_integer("seed", 1, "Random seed.") +flags.DEFINE_bool("save_model", False, "Whether to save model.") +flags.DEFINE_integer("batch_size", 256, "Batch size.") +flags.DEFINE_integer("utd_ratio", 4, "UTD ratio.") + +flags.DEFINE_string( + "state_mask", + "all", + "if all the states should be considered, see serl_launcher/common/encoding for more info", +) +flags.DEFINE_string("encoder_type", "voxnet-pretrained", "Encoder type.") +flags.DEFINE_integer( + "encoder_bottleneck_dim", 128, "bottleneck dimension of the encoder" +) +flags.DEFINE_multi_string( + "encoder_kwargs", None, "Encoder kwargs in the form ['dict key', 'dict value']" +) +flags.DEFINE_bool( + "enable_obs_rotation_wrapper", + False, + "Whether to enable observation rotation wrapper (train in one quaternion)", +) +flags.DEFINE_bool( + "enable_obs_rotation_augmentation", + False, + "Whether to rotate observations as a batch augmentation during training (read in main()).", +) +flags.DEFINE_bool( + "enable_temporal_ensemble_sampling", + False, + "Whether to enable sampling the action from a temporal ensemble: action = 0.5*a0 + 0.3*a-1 + 0.2*a-2 + 0.1*a-3", +) + +flags.DEFINE_integer("max_steps", 1000000, "Maximum number of training steps.") +flags.DEFINE_integer( + "replay_buffer_capacity", 10000, "Replay buffer capacity."
+) # quite low to forget demo trajectories + +flags.DEFINE_integer("random_steps", 0, "Sample random actions for this many steps.") +flags.DEFINE_integer("training_starts", 0, "Training starts after this step.") +flags.DEFINE_integer("steps_per_update", 10, "Number of steps per update the server.") + +flags.DEFINE_integer("log_period", 10, "Logging period.") +flags.DEFINE_integer("eval_period", 1000, "Evaluation period in seconds") +flags.DEFINE_integer("eval_n_trajs", 10, "Number of trajectories for evaluation.") + +# flag to indicate if this is a leaner or a actor +flags.DEFINE_boolean("learner", False, "Is this a learner or a trainer.") +flags.DEFINE_boolean("actor", False, "Is this a learner or a trainer.") +flags.DEFINE_string("ip", "localhost", "IP address of the learner.") +flags.DEFINE_string("demo_path", None, "Path to the demo data.") +flags.DEFINE_integer("checkpoint_period", 0, "Period to save checkpoints.") +flags.DEFINE_string( + "checkpoint_path", + "/home/nico/real-world-rl/serl/examples/box_picking_drq/checkpoints", + "Path to save checkpoints.", +) + +flags.DEFINE_integer( + "eval_checkpoint_step", 0, "evaluate the policy from ckpt at this step" +) +flags.DEFINE_string( + "log_rlds_path", + "/home/nico/real-world-rl/serl/examples/box_picking_drq/rlds", + "Path to save RLDS logs.", +) +flags.DEFINE_string("preload_rlds_path", None, "Path to preload RLDS data.") + +flags.DEFINE_boolean( + "debug", False, "Debug mode." +) # debug mode will disable wandb logging + + +def print_green(x): + return print("\033[92m {}\033[00m".format(x)) + + +PAUSE_EVENT_FLAG = threading.Event() +PAUSE_EVENT_FLAG.clear() # clear() to continue the actor/learner loop, set() to pause + + +def pause_callback(key): + """Callback for when a key is pressed""" + global PAUSE_EVENT_FLAG + try: + # chosen a rarely used key to avoid conflicts. 
this listener is always on, even when the program is not in focus + if not PAUSE_EVENT_FLAG.is_set() and key == pynput.keyboard.Key.pause: + print("Requested pause training") + # set the PAUSE FLAG to pause the actor/learner loop + PAUSE_EVENT_FLAG.set() + except AttributeError: + # print(f'{key} pressed') + pass + + +listener = pynput.keyboard.Listener( + on_press=pause_callback +) # to enable keyboard based pause +listener.start() + + +############################################################################## + + +def actor(agent: DrQAgent, data_store, env, sampling_rng): + """ + This is the actor loop, which runs when "--actor" is set to True. + """ + global PAUSE_EVENT_FLAG + + if FLAGS.eval_checkpoint_step: + wandb_logger = make_wandb_logger( + project="paper_evaluation_unseen" + if "eval" in FLAGS.env + else "paper_evaluation", + description=FLAGS.exp_name or FLAGS.env, + debug=FLAGS.debug, + ) + success_counter = 0 + + ckpt = checkpoints.restore_checkpoint( + FLAGS.checkpoint_path, + agent.state, + step=FLAGS.eval_checkpoint_step, + ) + agent = agent.replace(state=ckpt) + action_ensemble = TemporalActionEnsemble( + activated=FLAGS.enable_temporal_ensemble_sampling + ) + + trajectories = [] + traj_infos = [] + for episode in range(FLAGS.eval_n_trajs): + trajectory = [] + obs, _ = env.reset() + done = False + action_ensemble.reset() + start_time = time.time() + + while not done: + actions = agent.sample_actions( + observations=jax.device_put(obs), + argmax=True, + ) + actions = np.asarray(jax.device_get(actions)) + + ensembled_action = action_ensemble.sample( + actions + ) # will return actions if not activated + next_obs, reward, done, truncated, info = env.step(ensembled_action) + transition = dict( + observations=obs[ + "state" + ].copy(), # do not save voxel grid or images + actions=ensembled_action, + next_observations=next_obs["state"].copy(), + rewards=reward, + masks=1.0 - done, + dones=done, + ) + trajectory.append(transition) + obs = next_obs + + 
if done or truncated: + success_counter += reward > 50.0 + dt = time.time() - start_time + running_reward = np.sum( + np.asarray([t["rewards"] for t in trajectory]) + ) + running_reward = max(running_reward, -100.0) # -100 min value + + print(f"{success_counter}/{episode + 1} ", end=" ") + print(f"time: {dt:.3f}s running_rew: {running_reward:.2f}") + + trajectories.append( + {"traj": trajectory, "time": dt, "success": (reward > 50.0)} + ) + infos = { + "running_reward": running_reward, + "time": dt, + "success_rate": float(reward > 50.0), + "action_cost": np.linalg.norm( + np.asarray([t["actions"] for t in trajectory]), + axis=1, + ord=2, + ).mean(), + } + traj_infos.append(infos) + wandb_logger.log(infos, step=episode) + + # if pause event is requested, pause the actor + if PAUSE_EVENT_FLAG.is_set(): + print("Actor eval loop interrupted") + response = input("Do you want to continue (c), or exit (e)? ") + if response == "c": + # update PAUSE FLAG to continue training + PAUSE_EVENT_FLAG.clear() + print("Continuing") + else: + print("Stopping actor eval") + break + + traj_infos = { + k: [d[k] for d in traj_infos] for k in traj_infos[0] + } # list of dicts to dict of lists + mean_infos = {"mean_" + key: np.mean(val) for key, val in traj_infos.items()} + wandb_logger.log(mean_infos) + for key, value in mean_infos.items(): + print(f"{key}: {value:.3f}") + + filename = f"trajectories {'temp_ens' if action_ensemble.is_activated() else ''} {datetime.now().strftime('%m-%d %H%M')}.pkl" + with open(filename, "wb") as f: + import pickle + + pickle.dump(trajectories, f) + return # after done eval, return and exit + + client = TrainerClient( + "actor_env", + FLAGS.ip, + make_trainer_config(), + data_store, + wait_for_server=True, + ) + + # Function to update the agent with new params + def update_params(params): + nonlocal agent + agent = agent.replace(state=agent.state.replace(params=params)) + + client.recv_network_callback(update_params) + + obs, _ = env.reset() + + # 
training loop + timer = Timer() + running_return = 0.0 + + for step in tqdm.tqdm(range(FLAGS.max_steps), dynamic_ncols=True): + timer.tick("total") + + with timer.context("sample_actions"): + if step < FLAGS.random_steps: + actions = env.action_space.sample() + else: + sampling_rng, key = jax.random.split(sampling_rng) + actions = agent.sample_actions( + observations=jax.device_put(obs), + seed=key, + deterministic=False, + ) + actions = np.asarray(jax.device_get(actions)) + + # Step environment + with timer.context("step_env"): + next_obs, reward, done, truncated, info = env.step(actions) + + # override the action with the intervention action + if "intervene_action" in info: + actions = info.pop("intervene_action") + + reward = np.asarray(reward, dtype=np.float32) + info = np.asarray(info) + running_return = running_return * 0.99 + reward + transition = dict( + observations=obs, + actions=actions, + next_observations=next_obs, + rewards=reward, + masks=1.0 - done, + dones=done, + ) + data_store.insert(transition) + + obs = next_obs + if done or truncated: + stats = {"train": info} # send stats to the learner to log + client.request("send-stats", stats) + print(f"running return: {running_return}") + running_return = 0.0 + obs, _ = env.reset() + + if step % FLAGS.steps_per_update == 0: + client.update() + + timer.tock("total") + + if FLAGS.eval_period and step % FLAGS.eval_period == 0 and step: + with timer.context("eval"): + evaluate_info = evaluate( + policy_fn=partial(agent.sample_actions, argmax=True), + env=env, + num_episodes=FLAGS.eval_n_trajs, + ) + stats = {"eval": evaluate_info} + client.request("send-stats", stats) + + if step % FLAGS.log_period == 0: + stats = {"timer": timer.get_average_times()} + client.request("send-stats", stats) + + if PAUSE_EVENT_FLAG.is_set(): + print_green("Actor loop interrupted") + response = input( + "Do you want to continue (c), save replay buffer and exit (s) or simply exit (e)? 
" + ) + if response == "c": + print("Continuing") + PAUSE_EVENT_FLAG.clear() + else: + if response == "s": + print("Saving replay buffer") + data_store.save( + "replay_buffer_actor.npz" + ) # not yet supported for QueuedDataStore + else: + print("Replay buffer not saved") + print("Stopping actor client") + client.stop() + break + + +############################################################################## + + +def learner(rng, agent: DrQAgent, replay_buffer, wandb_logger=None): + """ + The learner loop, which runs when "--learner" is set to True. + """ + # To track the step in the training loop + update_steps = 0 + global PAUSE_EVENT_FLAG + + def stats_callback(type: str, payload: dict) -> dict: + """Callback for when server receives stats request.""" + assert type == "send-stats", f"Invalid request type: {type}" + if wandb_logger is not None: + wandb_logger.log(payload, step=update_steps) + return {} # not expecting a response + + # Create server + server = TrainerServer(make_trainer_config(), request_callback=stats_callback) + server.register_data_store("actor_env", replay_buffer) + server.start(threaded=True) + + # Loop to wait until replay_buffer is filled + pbar = tqdm.tqdm( + total=FLAGS.training_starts, + initial=len(replay_buffer), + desc="Filling up replay buffer", + position=0, + leave=True, + ) + while len(replay_buffer) < FLAGS.training_starts: + pbar.update(len(replay_buffer) - pbar.n) # Update progress bar + time.sleep(1) + pbar.update(len(replay_buffer) - pbar.n) # Update progress bar + pbar.close() + + # send the initial network to the actor + server.publish_network(agent.state.params) + print_green("sent initial network to actor") + + replay_iterator = replay_buffer.get_iterator( + sample_args={ + "batch_size": FLAGS.batch_size, + "pack_obs_and_next_obs": True, + }, + device=sharding.replicate(), + ) + + # wait till the replay buffer is filled with enough data + timer = Timer() + for step in tqdm.tqdm(range(FLAGS.max_steps), 
dynamic_ncols=True, desc="learner"): + timer.tick("learner_total") + + # run n-1 critic updates and 1 critic + actor update. + # This makes training on GPU faster by reducing the large batch transfer time from CPU to GPU + for critic_step in range(FLAGS.utd_ratio - 1): + with timer.context("sample_replay_buffer"): + batch = next(replay_iterator) + + with timer.context("train_critics"): + agent, critics_info = agent.update_critics( + batch, + ) + + with timer.context("train"): + batch = next(replay_iterator) + agent, update_info = agent.update_high_utd(batch, utd_ratio=1) + + timer.tock("learner_total") + + # publish the updated network + if step > 0 and step % (FLAGS.steps_per_update) == 0: + agent = jax.block_until_ready(agent) + server.publish_network(agent.state.params) + + if update_steps % FLAGS.log_period == 0 and wandb_logger: + wandb_logger.log(update_info, step=update_steps) + wandb_logger.log({"timer": timer.get_average_times()}, step=update_steps) + wandb_logger.log({"replay_buffer_size": len(replay_buffer)}) + + update_steps += 1 + + if FLAGS.checkpoint_period and update_steps % FLAGS.checkpoint_period == 0: + assert FLAGS.checkpoint_path is not None + checkpoints.save_checkpoint( + FLAGS.checkpoint_path, agent.state, step=update_steps, keep=100 + ) + + if PAUSE_EVENT_FLAG.is_set(): + print("Learner loop interrupted") + response = input( + "Do you want to continue (c), save training state and exit (s) or simply exit (e)? 
" + ) + if "c" in response: + print("Continuing") + PAUSE_EVENT_FLAG.clear() + else: + if response == "s": + print("Saving learner state") + agent_ckpt = checkpoints.save_checkpoint( + FLAGS.checkpoint_path, agent.state, step=update_steps, keep=100 + ) + replay_buffer.save( + "replay_buffer_learner.npz" + ) # not yet supported for QueuedDataStore + # TODO: save other parts of training state + else: + print("Training state not saved") + print("Stopping learner client") + break + + server.stop() + parameter_overview(agent) # print end state + + +############################################################################## + + +def main(_): + assert FLAGS.batch_size % num_devices == 0 + if FLAGS.checkpoint_path.split("/")[-1] == "checkpoints": + FLAGS.checkpoint_path = ( + FLAGS.checkpoint_path + + " " + + FLAGS.exp_name + + " " + + datetime.now().strftime("%m%d-%H:%M") + ) + + # seed + rng = jax.random.PRNGKey(FLAGS.seed) + + # create env and load dataset + env = gym.make( + FLAGS.env, + camera_mode=FLAGS.camera_mode, + fake_env=FLAGS.learner, + max_episode_length=FLAGS.max_traj_length, + ) + # if FLAGS.actor: + # env = SpacemouseIntervention(env) + env = RelativeFrame(env) + env = Quat2MrpWrapper(env) + env = ScaleObservationWrapper( + env + ) # scale obs space (after quat2mrp, but before serlobs) + env = ObservationStatisticsWrapper(env) + if FLAGS.enable_obs_rotation_wrapper: + env = ObservationRotationWrapper(env) + env = SERLObsWrapper(env) + env = ChunkingWrapper(env, obs_horizon=1, act_exec_horizon=None) + env = RecordEpisodeStatistics(env) + + image_keys = [key for key in env.observation_space.keys() if key != "state"] + print(f"image keys: {image_keys}") + + rng, sampling_rng = jax.random.split(rng) + + # assert FLAGS.encoder_kwargs is None or len(FLAGS.encoder_kwargs) % 2 == 0 + encoder_kwargs = { + "bottleneck_dim": FLAGS.encoder_bottleneck_dim, + **( + dict(zip(*[iter(FLAGS.encoder_kwargs)] * 2)) if FLAGS.encoder_kwargs else {} + ), + } + encoder_kwargs 
= { + k: (int(v) if str(v).isdigit() else v) for k, v in encoder_kwargs.items() + } + + agent: DrQAgent = make_voxel_drq_agent( + seed=FLAGS.seed, + sample_obs=env.observation_space.sample(), + sample_action=env.action_space.sample(), + image_keys=image_keys, + encoder_type=FLAGS.encoder_type, + state_mask=FLAGS.state_mask, + encoder_kwargs=encoder_kwargs, + ) + + # replicate agent across devices + # need the jnp.array to avoid a bug where device_put doesn't recognize primitives + agent: DrQAgent = jax.device_put( + jax.tree_map(jnp.array, agent), sharding.replicate() + ) + + # print useful info + print_agent_params(agent, image_keys) + parameter_overview(agent) + + if FLAGS.enable_obs_rotation_augmentation: + print("Batch Observation Rotation enabled!") + assert ( + not FLAGS.enable_obs_rotation_augmentation + or not FLAGS.enable_obs_rotation_wrapper + ) # both is pointless + + def create_replay_buffer_and_wandb_logger(): + replay_buffer = MemoryEfficientReplayBufferDataStore( + env.observation_space, + env.action_space, + capacity=FLAGS.replay_buffer_capacity, + image_keys=image_keys, + ) + # set up wandb and logging + wandb_logger = make_wandb_logger( + project="paper_experiments", + description=FLAGS.exp_name or FLAGS.env, + debug=FLAGS.debug, + ) + return replay_buffer, wandb_logger + + if FLAGS.learner: + sampling_rng = jax.device_put(sampling_rng, device=sharding.replicate()) + replay_buffer, wandb_logger = create_replay_buffer_and_wandb_logger() + + import pickle as pkl + + with open(FLAGS.demo_path, "rb") as f: + trajs = pkl.load(f) + + # check which observations can be ignored for this run + to_pop = [] + for obs_name in [i for i in trajs[0]["observations"].keys()]: + if obs_name not in env.observation_space.spaces: + to_pop.append(obs_name) + print(f"ignored {to_pop} observation in the demo trajectories") + + for traj in trajs: + for obs_name in to_pop: + traj["observations"].pop(obs_name) + traj["next_observations"].pop(obs_name) + + 
replay_buffer.insert(traj) + print(f"replay buffer size: {len(replay_buffer)}") + + # learner loop + print_green("starting learner loop") + try: + learner( + sampling_rng, + agent, + replay_buffer=replay_buffer, + wandb_logger=wandb_logger, + ) + except KeyboardInterrupt: + print_green("learner loop interrupted") + finally: + # Wrap up the learner loop + env.close() + print("Learner loop finished") + + elif FLAGS.actor: + sampling_rng = jax.device_put(sampling_rng, sharding.replicate()) + data_store = QueuedDataStore(50000) # the queue size on the actor + + # actor loop + print_green("starting actor loop") + try: + actor(agent, data_store, env, sampling_rng) + print_green("actor loop finished") + except (KeyboardInterrupt, RuntimeError) as e: + print_green("actor loop interrupted: " + str(e)) + finally: + env.close() + + else: + raise NotImplementedError("Must be either a learner or an actor") + + +if __name__ == "__main__": + app.run(main) diff --git a/examples/box_picking_drq/record_demo.py b/examples/box_picking_drq/record_demo.py new file mode 100644 index 00000000..e9d8afac --- /dev/null +++ b/examples/box_picking_drq/record_demo.py @@ -0,0 +1,131 @@ +import gym +from tqdm import tqdm +import numpy as np +import copy +import pickle as pkl +import datetime +import os +import threading +from pynput import keyboard + +from ur_env.envs.relative_env import RelativeFrame +from ur_env.envs.wrappers import ( + SpacemouseIntervention, + Quat2MrpWrapper, + ObservationRotationWrapper, +) + +from serl_launcher.wrappers.serl_obs_wrappers import ( + SERLObsWrapper, + ScaleObservationWrapper, +) +from serl_launcher.wrappers.chunking import ChunkingWrapper + +import ur_env + +exit_program = threading.Event() + + +def on_space(key, info_dict): + if key == keyboard.Key.space: + for name, item in info_dict.items(): + print(f"{name}: {item}", end=" ") + print() + + +def on_esc(key): + if key == keyboard.Key.esc: + exit_program.set() + + +if __name__ == "__main__": + env = gym.make(
+ "box_picking_camera_env", + camera_mode="pointcloud", + max_episode_length=100, + ) + env = SpacemouseIntervention(env) + env = RelativeFrame(env) + env = Quat2MrpWrapper(env) + env = ScaleObservationWrapper(env) + # env = ObservationRotationWrapper(env) # if it should be enabled + env = SERLObsWrapper(env) + env = ChunkingWrapper(env, obs_horizon=1, act_exec_horizon=None) + + obs, _ = env.reset() + + transitions = [] + success_count = 0 + success_needed = 20 + total_count = 0 + pbar = tqdm(total=success_needed) + + info_dict = { + "state": env.unwrapped.curr_pos, + "gripper_state": env.unwrapped.gripper_state, + "force": env.unwrapped.curr_force, + "reset_pose": env.unwrapped.curr_reset_pose, + } + listener_1 = keyboard.Listener( + daemon=True, on_press=lambda event: on_space(event, info_dict=info_dict) + ) + listener_1.start() + + listener_2 = keyboard.Listener(on_press=on_esc, daemon=True) + listener_2.start() + + uuid = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + file_name = f"box_picking_{success_needed}_demos_{uuid}.pkl" + file_dir = os.path.dirname(os.path.realpath(__file__)) # same dir as this script + file_path = os.path.join(file_dir, file_name) + + if not os.access(file_dir, os.W_OK): + raise PermissionError(f"No permission to write to {file_dir}") + + try: + running_reward = 0.0 + while success_count < success_needed: + if exit_program.is_set(): + raise KeyboardInterrupt # stop program, but clean up before + + next_obs, rew, done, truncated, info = env.step(action=np.zeros((7,))) + actions = info["intervene_action"] + + transition = copy.deepcopy( + dict( + observations=obs, + actions=actions, + next_observations=next_obs, + rewards=rew, + masks=1.0 - done, + dones=done, + ) + ) + transitions.append(transition) + + obs = next_obs + running_reward += rew + + if done or truncated: + success_count += int(rew > 0.99) + total_count += 1 + print( + f"{rew}\tGot {success_count} successes of {total_count} trials. 
{success_needed} successes needed." + ) + pbar.update(int(rew > 0.99)) + obs, _ = env.reset() + print("Reward total:", running_reward) + running_reward = 0.0 + + with open(file_path, "wb") as f: + pkl.dump(transitions, f) + print(f"saved {success_needed} demos to {file_path}") + + except KeyboardInterrupt as e: + print(f"\nProgram was interrupted, cleaning up... ", e.__str__()) + + finally: + pbar.close() + env.close() + listener_1.stop() + listener_2.stop() diff --git a/examples/box_picking_drq/run_actor.sh b/examples/box_picking_drq/run_actor.sh new file mode 100755 index 00000000..be65622f --- /dev/null +++ b/examples/box_picking_drq/run_actor.sh @@ -0,0 +1,20 @@ +export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ +export XLA_PYTHON_CLIENT_MEM_FRACTION=.1 && \ +python drq_policy.py "$@" \ + --actor \ + --env box_picking_camera_env \ + --max_traj_length 100 \ + --exp_name=box_picking \ + --camera_mode pointcloud \ + --seed 1 \ + --max_steps 20000 \ + --random_steps 0 \ + --training_starts 500 \ + --utd_ratio 8 \ + --batch_size 128 \ + --eval_period 1000 \ + --encoder_type voxnet-pretrained \ + --state_mask all \ + --encoder_bottleneck_dim 128 \ +# --enable_obs_rotation_wrapper \ +# --debug diff --git a/examples/box_picking_drq/run_evaluation.sh b/examples/box_picking_drq/run_evaluation.sh new file mode 100644 index 00000000..929e9618 --- /dev/null +++ b/examples/box_picking_drq/run_evaluation.sh @@ -0,0 +1,18 @@ +export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ +export XLA_PYTHON_CLIENT_MEM_FRACTION=.2 && \ +python drq_policy.py "$@" \ + --actor \ + --env box_picking_camera_env \ + --exp_name=drq_evaluation \ + --camera_mode pointcloud \ + --batch_size 128 \ + --max_traj_length 100 \ + --checkpoint_path "checkpoint folder path here"\ + --eval_checkpoint_step 10000 \ + --eval_n_trajs 20 \ + \ + --encoder_type voxnet-pretrained \ + --state_mask all \ + --encoder_bottleneck_dim 128 \ +# --enable_obs_rotation_wrapper \ +# --debug diff --git 
a/examples/box_picking_drq/run_learner.sh b/examples/box_picking_drq/run_learner.sh new file mode 100755 index 00000000..21f4d65f --- /dev/null +++ b/examples/box_picking_drq/run_learner.sh @@ -0,0 +1,22 @@ +export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ +export XLA_PYTHON_CLIENT_MEM_FRACTION=.3 && \ +python drq_policy.py "$@" \ + --learner \ + --env box_picking_camera_env \ + --exp_name=box_picking \ + --camera_mode pointcloud \ + --max_traj_length 100 \ + --seed 1 \ + --max_steps 25000 \ + --random_steps 0 \ + --training_starts 500 \ + --utd_ratio 8 \ + --batch_size 128 \ + --eval_period 20000 \ + --checkpoint_period 1000 \ + --encoder_type voxnet-pretrained \ + --state_mask all \ + --encoder_bottleneck_dim 128 \ + --demo_path "demo path here *.pkl" \ +# --enable_obs_rotation_wrapper \ +# --debug diff --git a/serl_launcher/serl_launcher/agents/continuous/drq.py b/serl_launcher/serl_launcher/agents/continuous/drq.py index 9b2b690b..518f17ed 100644 --- a/serl_launcher/serl_launcher/agents/continuous/drq.py +++ b/serl_launcher/serl_launcher/agents/continuous/drq.py @@ -10,14 +10,22 @@ from serl_launcher.agents.continuous.sac import SACAgent from serl_launcher.common.common import JaxRLTrainState, ModuleDict, nonpytree_field -from serl_launcher.common.encoding import EncodingWrapper +from serl_launcher.common.encoding import ( + EncodingWrapper, + MaskedEncodingWrapper, + create_state_mask, +) from serl_launcher.common.optimizers import make_optimizer from serl_launcher.common.typing import Batch, Data, Params, PRNGKey from serl_launcher.networks.actor_critic_nets import Critic, Policy, ensemblize from serl_launcher.networks.lagrange import GeqLagrangeMultiplier from serl_launcher.networks.mlp import MLP +from serl_launcher.vision.voxel_grid_encoders import VoxNet from serl_launcher.utils.train_utils import _unpack, concat_batches -from serl_launcher.vision.data_augmentations import batched_random_crop +from serl_launcher.vision.data_augmentations import ( +
batched_random_crop, + batched_random_shift_voxel, +) class DrQAgent(SACAgent): @@ -86,7 +94,7 @@ def create( # Config assert not entropy_per_dim, "Not implemented" if target_entropy is None: - target_entropy = -actions.shape[-1] / 2 + target_entropy = -actions.shape[-1] return cls( state=state, @@ -241,15 +249,203 @@ def create_drq( return agent + @classmethod + def create_voxel_drq( + cls, + rng: PRNGKey, + observations: Data, + actions: jnp.ndarray, + # Model architecture + encoder_type: str = "resnet", + use_proprio: bool = False, + state_mask: str = "all", + critic_network_kwargs: dict = { + "hidden_dims": [256, 256], + }, + policy_network_kwargs: dict = { + "hidden_dims": [256, 256], + }, + policy_kwargs: dict = { + "tanh_squash_distribution": True, + "std_parameterization": "uniform", + }, + encoder_kwargs: dict = {}, + critic_ensemble_size: int = 2, + critic_subsample_size: Optional[int] = None, + temperature_init: float = 1.0, + image_keys: Iterable[str] = ("image",), + **kwargs, + ): + """ + Create a new voxel-based agent. 
+ """ + + policy_network_kwargs["activate_final"] = True + critic_network_kwargs["activate_final"] = True + + if encoder_type == "small": + from serl_launcher.vision.small_encoders import SmallEncoder + + small_encoder = SmallEncoder( + features=(64, 64, 32, 32), + kernel_sizes=(3, 3, 3, 3), + strides=(2, 2, 1, 1), + padding="VALID", + pool_method="spatial_learned_embeddings", + bottleneck_dim=128, + spatial_block_size=8, + name=f"small_encoder", + ) + encoders = {image_key: small_encoder for image_key in image_keys} + elif encoder_type == "resnet": + from serl_launcher.vision.resnet_v1 import resnetv1_configs + + encoders = { + image_key: resnetv1_configs["resnetv1-10"]( + name=f"encoder_{image_key}", **encoder_kwargs + ) + for image_key in image_keys + } + elif encoder_type == "resnet-pretrained": + from serl_launcher.vision.resnet_v1 import ( + PreTrainedResNetEncoder, + resnetv1_configs, + ) + + pretrained_encoder = resnetv1_configs["resnetv1-10-frozen"]( + pre_pooling=True, + name="pretrained_encoder", + ) + + encoders = { + image_key: PreTrainedResNetEncoder( + rng=rng, + pretrained_encoder=pretrained_encoder, + name=f"encoder_{image_key}", + **encoder_kwargs, + ) + for image_key in image_keys + } + elif encoder_type == "resnet-pretrained-18": + # pretrained ResNet18 from pytorch + from serl_launcher.vision.resnet_v1_18 import resnetv1_18_configs + from serl_launcher.vision.resnet_v1 import PreTrainedResNetEncoder + + pretrained_encoder = resnetv1_18_configs["resnetv1-18-frozen"]( + name="pretrained_encoder", + ) + + encoders = { + image_key: PreTrainedResNetEncoder( + rng=rng, + pretrained_encoder=pretrained_encoder, + name=f"encoder_{image_key}", + **encoder_kwargs, + ) + for image_key in image_keys + } + elif encoder_type == "voxnet" or encoder_type == "voxnet-pretrained": + encoders = { + image_key: VoxNet( + bottleneck_dim=encoder_kwargs["bottleneck_dim"], + use_conv_bias=True, + final_activation=nn.tanh, + pretrained=encoder_type == "voxnet-pretrained", 
+ ) + for image_key in image_keys + } + elif encoder_type.lower() == "none": + encoders = None + else: + raise NotImplementedError(f"Unknown encoder type: {encoder_type}") + + state_mask_arr = create_state_mask(state_mask) + print(f"state_mask: {state_mask} {state_mask_arr.astype(jnp.int32)}") + encoder_def = MaskedEncodingWrapper( + encoder=encoders, + use_proprio=use_proprio, + enable_stacking=True, + image_keys=image_keys, + state_mask=state_mask_arr, + ) + + encoders = { + "critic": encoder_def, + "actor": encoder_def, + } + + # Define networks + critic_backbone = partial(MLP, **critic_network_kwargs) + critic_backbone = ensemblize(critic_backbone, critic_ensemble_size)( + name="critic_ensemble" + ) + critic_def = partial( + Critic, encoder=encoders["critic"], network=critic_backbone + )(name="critic") + + policy_def = Policy( + encoder=encoders["actor"], + network=MLP(**policy_network_kwargs), + action_dim=actions.shape[-1], + **policy_kwargs, + name="actor", + ) + + temperature_def = GeqLagrangeMultiplier( + init_value=temperature_init, + constraint_shape=(), + constraint_type="geq", + name="temperature", + ) + + agent = cls.create( + rng, + observations, + actions, + actor_def=policy_def, + critic_def=critic_def, + temperature_def=temperature_def, + critic_ensemble_size=critic_ensemble_size, + critic_subsample_size=critic_subsample_size, + image_keys=image_keys, + **kwargs, + ) + + if encoder_type == "resnet-pretrained": # load pretrained weights for ResNet-10 + from serl_launcher.utils.train_utils import load_resnet10_params + + agent = load_resnet10_params(agent, image_keys) + + if encoder_type == "voxnet-pretrained": + from serl_launcher.utils.train_utils import load_pretrained_VoxNet_params + + agent = load_pretrained_VoxNet_params(agent, image_keys) + + return agent + def data_augmentation_fn(self, rng, observations): + # TODO make it configurable: see https://github.com/rail-berkeley/serl/pull/67 + for pixel_key in self.config["image_keys"]: - 
observations = observations.copy( - add_or_replace={ - pixel_key: batched_random_crop( - observations[pixel_key], rng, padding=4, num_batch_dims=2 - ) - } - ) + # pointcloud augmentation + if "pointcloud" in pixel_key: + observations = observations.copy( + add_or_replace={ + pixel_key: batched_random_shift_voxel( + observations[pixel_key], rng, padding=3, num_batch_dims=2 + ) + } + ) + + # image augmentation + else: + observations = observations.copy( + add_or_replace={ + pixel_key: batched_random_crop( + observations[pixel_key], rng, padding=4, num_batch_dims=2 + ) + } + ) return observations @partial(jax.jit, static_argnames=("utd_ratio", "pmap_axis")) diff --git a/serl_launcher/serl_launcher/agents/continuous/sac.py b/serl_launcher/serl_launcher/agents/continuous/sac.py index ca933db4..e04733bb 100644 --- a/serl_launcher/serl_launcher/agents/continuous/sac.py +++ b/serl_launcher/serl_launcher/agents/continuous/sac.py @@ -167,9 +167,16 @@ def critic_loss_fn(self, batch, params: Params, rng: PRNGKey): ) chex.assert_shape(target_q, (batch_size,)) - if self.config["backup_entropy"]: + if self.config["backup_entropy"]: # not the same as in original jaxrl_m SAC implementation: https://github.com/dibyaghosh/jaxrl_m/blob/main/examples/mujoco/sac.py temperature = self.forward_temperature() - target_q = target_q - temperature * next_actions_log_probs + # target_q = target_q - temperature * next_actions_log_probs # serl original + target_q = ( + target_q + - self.config["discount"] + * batch["masks"] + * next_actions_log_probs + * temperature + ) # as in jaxrl_m predicted_qs = self.forward_critic( batch["observations"], batch["actions"], rng=rng, grad_params=params @@ -385,7 +392,7 @@ def create( # Config assert not entropy_per_dim, "Not implemented" if target_entropy is None: - target_entropy = -actions.shape[-1] / 2 + target_entropy = -actions.shape[-1] return cls( state=state, diff --git a/serl_launcher/serl_launcher/common/encoding.py 
b/serl_launcher/serl_launcher/common/encoding.py index 823782d5..c54915b7 100644 --- a/serl_launcher/serl_launcher/common/encoding.py +++ b/serl_launcher/serl_launcher/common/encoding.py @@ -7,6 +7,27 @@ from einops import rearrange, repeat +def create_state_mask(mask_str: str) -> jnp.ndarray: + all = jnp.ones((27,), dtype=jnp.bool) + none = jnp.zeros_like(all) + no_action = all.at[:7].set(False) + gripper = none.at[0 + 7 : 2 + 7].set(True) + no_ForceTorque = all.at[7 + 2 : 7 + 5].set(False).at[7 + 11 : 7 + 14].set(False) + action_only = none.at[:7].set(True) + masks = dict( + all=all, + none=jnp.zeros_like(all), + gripper=gripper, + position_gripper=gripper.at[5:11].set(True), + no_ForceTorque=no_ForceTorque, + no_ForceTorqueAction=jnp.bitwise_and(no_ForceTorque, no_action), + gripper_Zinfo=gripper.at[7 + 7].set(True), + action_only=action_only, + ) + assert mask_str in masks + return masks[mask_str] + + class EncodingWrapper(nn.Module): """ Encodes observations into a single flat encoding, adding additional @@ -72,6 +93,82 @@ def __call__( return encoded +class MaskedEncodingWrapper(nn.Module): + """ + Encodes observations into a single flat encoding, adding additional + functionality for adding proprioception and stopping the gradient. + + Args: + encoder: The encoder network. + use_proprio: Whether to concatenate proprioception (after encoding). 
+        state_mask: Which proprioceptive states to propagate, and which to ignore
+    """
+
+    encoder: nn.Module
+    use_proprio: bool
+    state_mask: jnp.ndarray
+    enable_stacking: bool = False
+    image_keys: Iterable[str] = ("image",)
+
+    @nn.compact
+    def __call__(
+        self,
+        observations: Dict[str, jnp.ndarray],
+        train=False,
+        stop_gradient=False,
+        is_encoded=False,
+    ) -> jnp.ndarray:
+        # encode images with encoder
+        if self.encoder is None:
+            # project state to embeddings as well
+            state = observations["state"]
+            if self.enable_stacking:
+                # Combine stacking and channels into a single dimension
+                if len(state.shape) == 2:
+                    state = rearrange(state, "T C -> (T C)")
+                if len(state.shape) == 3:
+                    state = rearrange(state, "B T C -> B (T C)")
+            # do not use proprio latent dim
+            return state
+
+        encoded = []
+        for image_key in self.image_keys:
+            image = observations[image_key]
+            if not is_encoded:
+                if self.enable_stacking:
+                    # Combine stacking and channels into a single dimension
+                    if len(image.shape) == 4:
+                        image = rearrange(image, "T H W C -> H W (T C)")
+                    if len(image.shape) == 5:
+                        image = rearrange(image, "B T H W C -> B H W (T C)")
+
+            image = self.encoder[image_key](image, train=train, encode=not is_encoded)
+
+            if stop_gradient:
+                image = jax.lax.stop_gradient(image)
+
+            encoded.append(image)
+
+        encoded = jnp.concatenate(encoded, axis=-1)
+
+        if self.use_proprio:
+            # project state to embeddings as well
+            state = observations["state"]
+            state = state[..., self.state_mask]  # only propagate non-zero mask entries
+            if state.shape[-1] != 0:
+                if self.enable_stacking:
+                    # Combine stacking and channels into a single dimension
+                    if len(state.shape) == 2:
+                        state = rearrange(state, "T C -> (T C)")
+                        encoded = encoded.reshape(-1)
+                    if len(state.shape) == 3:
+                        state = rearrange(state, "B T C -> B (T C)")
+                    # do not use proprio latent dim
+                encoded = jnp.concatenate([encoded, state], axis=-1)
+
+        return encoded
+
+
+class GCEncodingWrapper(nn.Module):
+    """
+    Encodes observations 
and goals into a single flat encoding. Handles all the diff --git a/serl_launcher/serl_launcher/networks/actor_critic_nets.py b/serl_launcher/serl_launcher/networks/actor_critic_nets.py index 189eef82..3746ae82 100644 --- a/serl_launcher/serl_launcher/networks/actor_critic_nets.py +++ b/serl_launcher/serl_launcher/networks/actor_critic_nets.py @@ -62,7 +62,8 @@ def __call__( obs_enc = self.encoder(observations) inputs = jnp.concatenate([obs_enc, actions], -1) - outputs = self.network(inputs, train=train) + outputs = self.network(inputs, train) + # train=train throws: "RuntimeWarning: kwargs are not supported in vmap, so "train" is(are) ignored" if self.init_final is not None: value = nn.Dense( 1, @@ -157,7 +158,7 @@ def ensemblize(cls, num_qs, out_axes=0): return nn.vmap( cls, variable_axes={"params": 0}, - split_rngs={"params": True}, + split_rngs={"params": True, "dropout": True}, in_axes=None, out_axes=out_axes, axis_size=num_qs, diff --git a/serl_launcher/serl_launcher/utils/launcher.py b/serl_launcher/serl_launcher/utils/launcher.py index 782221eb..c37a2620 100644 --- a/serl_launcher/serl_launcher/utils/launcher.py +++ b/serl_launcher/serl_launcher/utils/launcher.py @@ -116,6 +116,66 @@ def make_drq_agent( return agent +def make_voxel_drq_agent( + seed, + sample_obs, + sample_action, + image_keys=("image",), + encoder_type="voxnet", + state_mask="all", + encoder_kwargs=None, +): + if encoder_kwargs is None: + encoder_kwargs = dict(bottleneck_dim=128) + + agent = DrQAgent.create_voxel_drq( + jax.random.PRNGKey(seed), + sample_obs, + sample_action, + encoder_type=encoder_type, + use_proprio=True, + state_mask=state_mask, + image_keys=image_keys, + policy_kwargs=dict( + tanh_squash_distribution=True, + std_parameterization="exp", + std_min=1e-5, + std_max=5, + ), + critic_network_kwargs=dict( + activations=nn.tanh, + use_layer_norm=True, + hidden_dims=[256, 256], + dropout_rate=0.1, + ), + policy_network_kwargs=dict( + activations=nn.tanh, + use_layer_norm=True, 
+ hidden_dims=[256, 256], + dropout_rate=0.1, + ), + temperature_init=1e-2, + discount=0.99, # 0.99 + backup_entropy=True, + critic_ensemble_size=10, + critic_subsample_size=2, + encoder_kwargs=encoder_kwargs, + # dict( + # # pooling_method="spatial_learned_embeddings", + # bottleneck_dim=128, + # # num_spatial_blocks=8, + # # num_kp=64, + # ), + actor_optimizer_kwargs={ + "learning_rate": 3e-3, + }, + critic_optimizer_kwargs={ + "learning_rate": 3e-3, + }, + ) + return agent + + def make_vice_agent( seed, sample_obs, diff --git a/serl_launcher/serl_launcher/utils/sampling_utils.py b/serl_launcher/serl_launcher/utils/sampling_utils.py new file mode 100644 index 00000000..9827b2de --- /dev/null +++ b/serl_launcher/serl_launcher/utils/sampling_utils.py @@ -0,0 +1,30 @@ +import numpy as np + + +class TemporalActionEnsemble: + def __init__(self, activated=True, action_shape=(7,), ensemble=None): + if ensemble is None: + ensemble = [0.5, 0.3, 0.2, 0.1] + self.activated = activated + self.ensemble = np.asarray(ensemble) + self.buffer = np.zeros((len(ensemble), action_shape[0])) + + if activated: + print(f"Temporal Action Ensemble enabled: {self.ensemble}") + + def reset(self): + self.buffer[...] 
= 0.0 + + def sample(self, curr_action: np.ndarray): + if not self.activated: + return curr_action + + curr_action = curr_action.reshape(-1) + assert curr_action.shape[0] == self.buffer.shape[1] + + self.buffer = np.roll(self.buffer, axis=0, shift=1) + self.buffer[0, :] = curr_action + return np.dot(self.ensemble, self.buffer) + + def is_activated(self): + return self.activated diff --git a/serl_launcher/serl_launcher/utils/train_utils.py b/serl_launcher/serl_launcher/utils/train_utils.py index 31037317..200e1798 100644 --- a/serl_launcher/serl_launcher/utils/train_utils.py +++ b/serl_launcher/serl_launcher/utils/train_utils.py @@ -128,3 +128,110 @@ def load_resnet10_params(agent, image_keys=("image",), public=True): agent = agent.replace(state=agent.state.replace(params=new_params)) return agent + + +def load_pretrained_VoxNet_params(agent, image_keys=("pointcloud",)): + ckpt = jnp.load("/home/nico/Downloads/c-11.npz") + + new_params = agent.state.params + + for image_key in image_keys: + new_encoder_params = new_params["modules_actor"]["encoder"][ + f"encoder_{image_key}" + ] + to_replace = { + "conv_5x5x5": "voxnet/conv1/conv3d/", + "conv_3x3x3": "voxnet/conv2/conv3d/", + "conv_2x2x2": "voxnet/conv3/conv3d/", + } + replaced = [] + for key, weights in to_replace.items(): + if key in new_encoder_params: + shape = new_encoder_params[key]["kernel"].shape + new_encoder_params[key]["kernel"] = ( + new_encoder_params[key]["kernel"] + .at[:] + .set(ckpt[weights + "kernel:0"][..., : shape[-1]]) + ) + new_encoder_params[key]["bias"] = ( + new_encoder_params[key]["bias"] + .at[:] + .set(ckpt[weights + "bias:0"][: shape[-1]]) + ) + replaced.append(f"{key}:{shape}") + + print(f"replaced {replaced} in {image_key}") + + # replace LayerNorm params with pretrained BN ones + new_encoder_params["LayerNorm_0"]["bias"] = ( + new_encoder_params["LayerNorm_0"]["bias"] + .at[:] + .set(ckpt["voxnet/conv1/batch_normalization/beta:0"]) + ) + new_encoder_params["LayerNorm_0"]["scale"] = ( 
+ new_encoder_params["LayerNorm_0"]["scale"] + .at[:] + .set(ckpt["voxnet/conv1/batch_normalization/gamma:0"]) + ) + + new_encoder_params["LayerNorm_1"]["bias"] = ( + new_encoder_params["LayerNorm_0"]["bias"] + .at[:] + .set(ckpt["voxnet/conv2/batch_normalization/beta:0"]) + ) + new_encoder_params["LayerNorm_1"]["scale"] = ( + new_encoder_params["LayerNorm_0"]["scale"] + .at[:] + .set(ckpt["voxnet/conv2/batch_normalization/gamma:0"]) + ) + + agent = agent.replace(state=agent.state.replace(params=new_params)) + return agent + + +def print_agent_params(agent, image_keys=("image",)): + """ + helper function to print the parameter count of the actor and critic networks + """ + + def get_size(params): + return sum(x.size for x in jax.tree.leaves(params)) + + total_param_count = get_size(agent.state.params) + actor, critic = ( + agent.state.params["modules_actor"], + agent.state.params["modules_critic"], + ) + + # calculate encoder params + try: + pretrained_encoder_count = get_size( + actor["encoder"][f"encoder_{image_keys[0]}"]["pretrained_encoder"] + ) + except Exception as e: + pretrained_encoder_count = 0 + + try: + encoder_count = get_size(actor["encoder"]) + except Exception as e: + encoder_count = 0 + + actor_count = get_size(actor) + critic_count = get_size(critic) + + print(f"\ntotal params: {total_param_count / 1e6:.3f}M") + print( + f"encoder params: {(encoder_count - pretrained_encoder_count) / 1e6:.3f}M pretrained encoder params: {pretrained_encoder_count / 1e6:.3f}M" + ) + print( + f"actor params: {(actor_count - encoder_count) / 1e6:.3f}M critic_params: {critic_count / 1e6:.3f}M" + ) + print( + f"total parameters to train: {(total_param_count - pretrained_encoder_count) / 1e6:.3f}M\n" + ) + + +def parameter_overview(agent): + from clu import parameter_overview + + print(parameter_overview.get_parameter_overview(agent.state.params)) diff --git a/serl_launcher/serl_launcher/vision/data_augmentations.py 
b/serl_launcher/serl_launcher/vision/data_augmentations.py index 2c2440fa..3f449cd5 100644 --- a/serl_launcher/serl_launcher/vision/data_augmentations.py +++ b/serl_launcher/serl_launcher/vision/data_augmentations.py @@ -4,6 +4,7 @@ import jax.numpy as jnp +@partial(jax.jit, static_argnames="padding") def random_crop(img, rng, *, padding): crop_from = jax.random.randint(rng, (2,), 0, 2 * padding + 1) crop_from = jnp.concatenate([crop_from, jnp.zeros((1,), dtype=jnp.int32)]) @@ -36,6 +37,36 @@ def batched_random_crop(img, rng, *, padding, num_batch_dims: int = 1): return img +def random_shift_3d(img, rng, *, padding): + crop_from = jax.random.randint(rng, (3,), 0, 2 * padding + 1) + padded_img = jnp.pad( + img, + ( + (padding, padding), + (padding, padding), + (padding, padding), + ), + mode="constant" + ) + return jax.lax.dynamic_slice(padded_img, crop_from, img.shape) + + +@partial(jax.jit, static_argnames=("padding", "num_batch_dims")) +def batched_random_shift_voxel(img, rng, *, padding, num_batch_dims: int = 1): + original_shape = img.shape + img = jnp.reshape(img, (-1, *img.shape[num_batch_dims:])) + # shape (B, B2, X, Y, Z) + + rngs = jax.random.split(rng, img.shape[0]) + img = jax.vmap( + lambda i, r: random_shift_3d(i, r, padding=padding), in_axes=(0, 0), out_axes=0 + )(img, rngs) + + # Restore batch dims + img = jnp.reshape(img, original_shape) + return img + + def _maybe_apply(apply_fn, inputs, rng, apply_prob): should_apply = jax.random.uniform(rng, shape=()) <= apply_prob return jax.lax.cond(should_apply, inputs, apply_fn, inputs, lambda x: x) diff --git a/serl_launcher/serl_launcher/vision/resnet_v1.py b/serl_launcher/serl_launcher/vision/resnet_v1.py index e18769b3..e89a6f42 100644 --- a/serl_launcher/serl_launcher/vision/resnet_v1.py +++ b/serl_launcher/serl_launcher/vision/resnet_v1.py @@ -8,6 +8,7 @@ import numpy as np from serl_launcher.vision.film_conditioning_layer import FilmConditioning +from serl_launcher.common.typing import PRNGKey 
ModuleDef = Any @@ -198,7 +199,6 @@ class ResNetEncoder(nn.Module): norm: str = "group" add_spatial_coordinates: bool = False pooling_method: str = "avg" - use_spatial_softmax: bool = False softmax_temperature: float = 1.0 use_multiplicative_cond: bool = False num_spatial_blocks: int = 8 @@ -322,10 +322,11 @@ def __call__( class PreTrainedResNetEncoder(nn.Module): + rng: PRNGKey = None pooling_method: str = "avg" - use_spatial_softmax: bool = False softmax_temperature: float = 1.0 num_spatial_blocks: int = 8 + num_kp: Optional[int] = None # for Spatial Softmax bottleneck_dim: Optional[int] = None pretrained_encoder: nn.module = None @@ -348,8 +349,22 @@ def __call__( channel=channel, num_features=self.num_spatial_blocks, )(x) - x = nn.Dropout(0.1, deterministic=not train)(x) + x = nn.Dropout(0.1, deterministic=not train)(x, rng=self.rng) elif self.pooling_method == "spatial_softmax": + if self.num_kp is not None: + """ + implemented as in https://github.com/huggingface/lerobot/blob/ff8f6aa6cde2957f08547eb081aac12ca4669b6a/lerobot/common/policies/diffusion/modeling_diffusion.py#L316 + In this case it would result in 512 keypoints (corresponding to the 512 input channels). We can optionally + provide num_kp != None to control the number of keypoints. 
+ """ + x = nn.Conv( + features=self.num_kp, + kernel_size=1, + use_bias=False, + dtype=jnp.float32, + kernel_init=nn.initializers.kaiming_normal(), + name="spatial_softmax_conv", + )(x) height, width, channel = x.shape[-3:] pos_x, pos_y = jnp.meshgrid( jnp.linspace(-1.0, 1.0, height), jnp.linspace(-1.0, 1.0, width) diff --git a/serl_launcher/serl_launcher/vision/resnet_v1_18.py b/serl_launcher/serl_launcher/vision/resnet_v1_18.py new file mode 100644 index 00000000..cc8bbe0d --- /dev/null +++ b/serl_launcher/serl_launcher/vision/resnet_v1_18.py @@ -0,0 +1,492 @@ +import jax.lax +import jax.numpy as jnp +import flax.linen as nn +import functools +from typing import Any, Callable, Iterable, Optional, Tuple, Union +import h5py +import warnings + +from flax.linen.module import compact, merge_param +from jax.nn import initializers +from jax import lax + +from serl_launcher.vision.resnet_v1 import SpatialLearnedEmbeddings, SpatialSoftmax + +PRNGKey = Any +Array = Any +Shape = Tuple[int] +Dtype = Any + + +# ---------------------------------------------------------------# +# Normalization +# ---------------------------------------------------------------# +def batch_norm(x, train, epsilon=1e-05, momentum=0.99, params=None, dtype="float32"): + # we do not use running average in the implementation (set to False) + if params is None: + x = BatchNorm( + epsilon=epsilon, + momentum=momentum, + use_running_average=False, # was not train + dtype=dtype, + )(x) + else: + x = BatchNorm( + epsilon=epsilon, + momentum=momentum, + bias_init=lambda *_: jnp.array(params["bias"]), + scale_init=lambda *_: jnp.array(params["scale"]), + mean_init=lambda *_: jnp.array(params["mean"]), + var_init=lambda *_: jnp.array(params["var"]), + use_running_average=False, # was not train + dtype=dtype, + )(x) + return x + + +def _absolute_dims(rank, dims): + return tuple([rank + dim if dim < 0 else dim for dim in dims]) + + +class BatchNorm(nn.Module): + """BatchNorm Module. 
+ + Taken from: https://github.com/google/flax/blob/master/flax/linen/normalization.py + + Attributes: + use_running_average: if True, the statistics stored in batch_stats + will be used instead of computing the batch statistics on the input. + axis: the feature or non-batch axis of the input. + momentum: decay rate for the exponential moving average of the batch statistics. + epsilon: a small float added to variance to avoid dividing by zero. + dtype: the dtype of the computation (default: float32). + use_bias: if True, bias (beta) is added. + use_scale: if True, multiply by scale (gamma). + When the next layer is linear (also e.g. nn.relu), this can be disabled + since the scaling will be done by the next layer. + bias_init: initializer for bias, by default, zero. + scale_init: initializer for scale, by default, one. + axis_name: the axis name used to combine batch statistics from multiple + devices. See `jax.pmap` for a description of axis names (default: None). + axis_index_groups: groups of axis indices within that named axis + representing subsets of devices to reduce over (default: None). For + example, `[[0, 1], [2, 3]]` would independently batch-normalize over + the examples on the first two and last two devices. See `jax.lax.psum` + for more details. + """ + + use_running_average: Optional[bool] = None + axis: int = -1 + momentum: float = 0.99 + epsilon: float = 1e-5 + dtype: Dtype = jnp.float32 + use_bias: bool = True + use_scale: bool = True + bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros + scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones + mean_init: Callable[[Shape], Array] = lambda s: jnp.zeros(s, jnp.float32) + var_init: Callable[[Shape], Array] = lambda s: jnp.ones(s, jnp.float32) + axis_name: Optional[str] = None + axis_index_groups: Any = None + + @compact + def __call__(self, x, use_running_average: Optional[bool] = None): + """Normalizes the input using batch statistics. 
+ + NOTE: + During initialization (when parameters are mutable) the running average + of the batch statistics will not be updated. Therefore, the inputs + fed during initialization don't need to match that of the actual input + distribution and the reduction axis (set with `axis_name`) does not have + to exist. + Args: + x: the input to be normalized. + use_running_average: if true, the statistics stored in batch_stats + will be used instead of computing the batch statistics on the input. + Returns: + Normalized inputs (the same shape as inputs). + """ + use_running_average = merge_param( + "use_running_average", self.use_running_average, use_running_average + ) + x = jnp.asarray(x, jnp.float32) + axis = self.axis if isinstance(self.axis, tuple) else (self.axis,) + axis = _absolute_dims(x.ndim, axis) + feature_shape = tuple(d if i in axis else 1 for i, d in enumerate(x.shape)) + reduced_feature_shape = tuple(d for i, d in enumerate(x.shape) if i in axis) + reduction_axis = tuple(i for i in range(x.ndim) if i not in axis) + + # see NOTE above on initialization behavior + initializing = self.is_mutable_collection("params") + + if use_running_average: + ra_mean = self.variable( + "batch_stats", "mean", self.mean_init, reduced_feature_shape + ) + ra_var = self.variable( + "batch_stats", "var", self.var_init, reduced_feature_shape + ) + mean, var = ra_mean.value, ra_var.value + else: + mean = jnp.mean(x, axis=reduction_axis, keepdims=False) + mean2 = jnp.mean(lax.square(x), axis=reduction_axis, keepdims=False) + if self.axis_name is not None and not initializing: + concatenated_mean = jnp.concatenate([mean, mean2]) + mean, mean2 = jnp.split( + lax.pmean( + concatenated_mean, + axis_name=self.axis_name, + axis_index_groups=self.axis_index_groups, + ), + 2, + ) + var = mean2 - lax.square(mean) + + y = x - mean.reshape(feature_shape) + mul = lax.rsqrt(var + self.epsilon) + if self.use_scale: + scale = self.param("scale", self.scale_init, reduced_feature_shape).reshape( + 
feature_shape + ) + mul = mul * scale + y = y * mul + if self.use_bias: + bias = self.param("bias", self.bias_init, reduced_feature_shape).reshape( + feature_shape + ) + y = y + bias + return jnp.asarray(y, self.dtype) + + +LAYERS = {"resnet18": [2, 2, 2, 2]} + + +class BasicBlock(nn.Module): + """ + Basic Block. + + Attributes: + features (int): Number of output channels. + kernel_size (Tuple): Kernel size. + downsample (bool): If True, downsample spatial resolution. + stride (bool): If True, use strides (2, 2). Not used in this module. + The attribute is only here for compatibility with Bottleneck. + param_dict (h5py.Group): Parameter dict with pretrained parameters. + kernel_init (functools.partial): Kernel initializer. + bias_init (functools.partial): Bias initializer. + block_name (str): Name of block. + dtype (str): Data type. + """ + + features: int + kernel_size: Union[int, Iterable[int]] = (3, 3) + downsample: bool = False + stride: bool = True + param_dict: h5py.Group = None + kernel_init: functools.partial = nn.initializers.lecun_normal() + bias_init: functools.partial = nn.initializers.zeros + block_name: str = None + dtype: str = "float32" + + @nn.compact + def __call__(self, x, act, train=True): + """ + Run Basic Block. + + Args: + x (tensor): Input tensor of shape [N, H, W, C]. + act (dict): Dictionary containing activations. + train (bool): Training mode. + + Returns: + (tensor): Output shape of shape [N, H', W', features]. 
+ """ + residual = x + + x = nn.Conv( + features=self.features, + kernel_size=self.kernel_size, + strides=(2, 2) if self.downsample else (1, 1), + padding=((1, 1), (1, 1)), + kernel_init=self.kernel_init + if self.param_dict is None + else lambda *_: jnp.array(self.param_dict["conv1"]["weight"]), + use_bias=False, + dtype=self.dtype, + )(x) + + x = batch_norm( + x, + train=train, + epsilon=1e-05, + momentum=0.1, + params=None if self.param_dict is None else self.param_dict["bn1"], + dtype=self.dtype, + ) + x = nn.relu(x) + + x = nn.Conv( + features=self.features, + kernel_size=self.kernel_size, + strides=(1, 1), + padding=((1, 1), (1, 1)), + kernel_init=self.kernel_init + if self.param_dict is None + else lambda *_: jnp.array(self.param_dict["conv2"]["weight"]), + use_bias=False, + dtype=self.dtype, + )(x) + + x = batch_norm( + x, + train=train, + epsilon=1e-05, + momentum=0.1, + params=None if self.param_dict is None else self.param_dict["bn2"], + dtype=self.dtype, + ) + + if self.downsample: + residual = nn.Conv( + features=self.features, + kernel_size=(1, 1), + strides=(2, 2), + kernel_init=self.kernel_init + if self.param_dict is None + else lambda *_: jnp.array( + self.param_dict["downsample"]["conv"]["weight"] + ), + use_bias=False, + dtype=self.dtype, + )(residual) + + residual = batch_norm( + residual, + train=train, + epsilon=1e-05, + momentum=0.1, + params=None + if self.param_dict is None + else self.param_dict["downsample"]["bn"], + dtype=self.dtype, + ) + + x += residual + x = nn.relu(x) + act[self.block_name] = x + return x + + +class ResNet(nn.Module): + """ + ResNet. + + Attributes: + output (str): + Output of the module. 
Available options are:
+                - 'softmax': Output is a softmax tensor of shape [N, 1000]
+                - 'log_softmax': Output is a log_softmax tensor of shape [N, 1000]
+                - 'logits': Output is a tensor of shape [N, 1000]
+                - 'activations': Output is a dictionary containing the ResNet activations
+        pretrained (str):
+            Indicates if and what type of weights to load. Options are:
+                - 'imagenet': Loads the network parameters trained on ImageNet
+                - None: Parameters of the module are initialized randomly
+        normalize (bool):
+            If True, the input will be normalized with the ImageNet statistics.
+        architecture (str):
+            Which ResNet model to use:
+                - 'resnet18'
+        num_classes (int):
+            Number of classes.
+        block (nn.Module):
+            Type of residual block:
+                - BasicBlock
+        kernel_init (function):
+            A function that takes in a shape and returns a tensor.
+        bias_init (function):
+            A function that takes in a shape and returns a tensor.
+        ckpt_dir (str):
+            The directory to which the pretrained weights are downloaded.
+            Only relevant if a pretrained model is used.
+            If this argument is None, the weights will be saved to a temp directory.
+        dtype (str): Data type.
+    """
+
+    output: str = "softmax"
+    pretrained: str = "imagenet"
+    normalize: bool = True
+    architecture: str = "resnet18"
+    num_classes: int = 1000
+    block: nn.Module = BasicBlock
+    kernel_init: functools.partial = nn.initializers.lecun_normal()
+    bias_init: functools.partial = nn.initializers.zeros
+    ckpt_dir: str = None
+    dtype: str = "float32"
+    pre_pooling: bool = True  # skip pooling
+
+    def setup(self):
+        # self.param_dict = None
+        if self.pretrained == "imagenet":
+            # ckpt_file = utils.download(self.ckpt_dir, URLS[self.architecture])
+            self.param_dict = h5py.File(self.ckpt_dir, "r")
+            # print(f"loaded pretrained weights from {self.ckpt_dir}")
+
+    @nn.compact
+    def __call__(self, observations, train=False):
+        """
+        Args:
+            x (tensor): Input tensor of shape [N, H, W, 3]. Images must be in range [0, 1].
+            train (bool): Training mode. 
+ + Returns: + (tensor): Out + if pre_pooling is True: features of shape (B, 7, 7, 512) + """ + # assert observations.shape[-3:] == (224, 224, 3) + + if self.normalize: + mean = jnp.array([0.485, 0.456, 0.406]).reshape(1, 1, 1, -1) + std = jnp.array([0.229, 0.224, 0.225]).reshape(1, 1, 1, -1) + x = (observations.astype(jnp.float32) / 255.0 - mean) / std + + if self.pretrained == "imagenet": + if self.num_classes != 1000: + warnings.warn( + f"The user specified parameter 'num_classes' was set to {self.num_classes} " + "but will be overwritten with 1000 to match the specified pretrained checkpoint 'imagenet', if ", + UserWarning, + ) + num_classes = 1000 + else: + num_classes = self.num_classes + + act = {} + + x = nn.Conv( + features=64, + kernel_size=(7, 7), + kernel_init=self.kernel_init + if self.param_dict is None + else lambda *_: jnp.array(self.param_dict["conv1"]["weight"]), + strides=(2, 2), + padding=((3, 3), (3, 3)), + use_bias=False, + dtype=self.dtype, + )(x) + act["conv1"] = x + + x = batch_norm( + x, + train=train, + epsilon=1e-05, + momentum=0.1, + params=None if self.param_dict is None else self.param_dict["bn1"], + dtype=self.dtype, + ) + x = nn.relu(x) + x = nn.max_pool( + x, window_shape=(3, 3), strides=(2, 2), padding=((1, 1), (1, 1)) + ) + + # Layer 1 + down = self.block.__name__ == "Bottleneck" + for i in range(LAYERS[self.architecture][0]): + params = ( + None + if self.param_dict is None + else self.param_dict["layer1"][f"block{i}"] + ) + x = self.block( + features=64, + kernel_size=(3, 3), + downsample=i == 0 and down, + stride=i != 0, + param_dict=params, + block_name=f"block1_{i}", + dtype=self.dtype, + )(x, act, train) + + # Layer 2 + for i in range(LAYERS[self.architecture][1]): + params = ( + None + if self.param_dict is None + else self.param_dict["layer2"][f"block{i}"] + ) + x = self.block( + features=128, + kernel_size=(3, 3), + downsample=i == 0, + param_dict=params, + block_name=f"block2_{i}", + dtype=self.dtype, + )(x, act, train) 
+ + # Layer 3 + for i in range(LAYERS[self.architecture][2]): + params = ( + None + if self.param_dict is None + else self.param_dict["layer3"][f"block{i}"] + ) + x = self.block( + features=256, + kernel_size=(3, 3), + downsample=i == 0, + param_dict=params, + block_name=f"block3_{i}", + dtype=self.dtype, + )(x, act, train) + + # Layer 4 + for i in range(LAYERS[self.architecture][3]): + params = ( + None + if self.param_dict is None + else self.param_dict["layer4"][f"block{i}"] + ) + x = self.block( + features=512, + kernel_size=(3, 3), + downsample=i == 0, + param_dict=params, + block_name=f"block4_{i}", + dtype=self.dtype, + )(x, act, train) + + # if we want the pre_pooling output, return here + if self.pre_pooling: + return jax.lax.stop_gradient(x) # shape (b, 7, 7, 512) + + # Classifier + x = jnp.mean(x, axis=(1, 2)) + x = nn.Dense( + features=num_classes, + kernel_init=self.kernel_init + if self.param_dict is None + else lambda *_: jnp.array(self.param_dict["fc"]["weight"]), + bias_init=self.bias_init + if self.param_dict is None + else lambda *_: jnp.array(self.param_dict["fc"]["bias"]), + dtype=self.dtype, + )(x) + act["fc"] = x + + if self.output == "softmax": + return nn.softmax(x) + if self.output == "log_softmax": + return nn.log_softmax(x) + if self.output == "activations": + return act + return x + + +resnetv1_18_configs = { + "resnetv1-18-frozen": functools.partial( + ResNet, + architecture="resnet18", + ckpt_dir="/examples/box_picking_drq/resnet18_weights.h5", # download from #TODO + pre_pooling=True, + ) +} diff --git a/serl_launcher/serl_launcher/vision/small_encoders.py b/serl_launcher/serl_launcher/vision/small_encoders.py index 630c7b9a..d0e37969 100644 --- a/serl_launcher/serl_launcher/vision/small_encoders.py +++ b/serl_launcher/serl_launcher/vision/small_encoders.py @@ -14,9 +14,12 @@ class SmallEncoder(nn.Module): pool_method: str = "spatial_learned_embeddings" bottleneck_dim: Optional[int] = None spatial_block_size: Optional[int] = 8 + 
num_kp: Optional[int] = 32 @nn.compact - def __call__(self, observations: jnp.ndarray, train=False) -> jnp.ndarray: + def __call__( + self, observations: jnp.ndarray, train=False, encode=True + ) -> jnp.ndarray: assert len(self.features) == len(self.strides) x = observations.astype(jnp.float32) / 255.0 @@ -44,6 +47,13 @@ def __call__(self, observations: jnp.ndarray, train=False) -> jnp.ndarray: raise ValueError( "spatial_block_size must be set when using spatial_learned_embeddings" ) + x = nn.Conv( # 512 to num_kp features (less complexity) + features=self.num_kp, + kernel_size=1, + use_bias=False, + dtype=jnp.float32, + kernel_init=nn.initializers.kaiming_normal(), + )(x) x = SpatialLearnedEmbeddings(*(x.shape[-3:]), self.spatial_block_size)(x) x = nn.Dropout(0.1, deterministic=not train)(x) diff --git a/serl_launcher/serl_launcher/vision/voxel_grid_encoders.py b/serl_launcher/serl_launcher/vision/voxel_grid_encoders.py new file mode 100644 index 00000000..45fabc33 --- /dev/null +++ b/serl_launcher/serl_launcher/vision/voxel_grid_encoders.py @@ -0,0 +1,148 @@ +from functools import partial +from typing import Any, Callable, Optional, Sequence, Tuple + +import flax.linen as nn +import jax.numpy as jnp +import jax.lax as lax + +import jax + + +class SpatialSoftArgmax3D(nn.Module): + """ + 3D Implementation of Spatial Soft Argmax + why arg-max and not max: see https://github.com/tensorflow/tensorflow/issues/6271#issuecomment-266893850 + """ + + x_len: int + y_len: int + z_len: int + channel: int + temperature: float = 1.0 + + def setup(self): + pos_x, pos_y, pos_z = jnp.meshgrid( + jnp.linspace(-1.0, 1.0, self.x_len), + jnp.linspace(-1.0, 1.0, self.y_len), + jnp.linspace(-1.0, 1.0, self.z_len), + indexing="ij", + ) + self.pos_x = pos_x.reshape(-1) # shape (x*y*z) + self.pos_y = pos_y.reshape(-1) + self.pos_z = pos_z.reshape(-1) + + @nn.compact + def __call__(self, features): + # add batch dim if missing + no_batch_dim = len(features.shape) < 5 + if no_batch_dim: + 
features = features[None] + + assert len(features.shape) == 5 + batch_size, num_featuremaps = features.shape[0], features.shape[-1] + features = features.transpose(0, 4, 1, 2, 3).reshape( + batch_size, num_featuremaps, self.x_len * self.y_len * self.z_len + ) + + softmax_attention = nn.softmax(features / self.temperature, axis=-1) + expected_x = jnp.sum(self.pos_x * softmax_attention, axis=-1) + expected_y = jnp.sum(self.pos_y * softmax_attention, axis=-1) + expected_z = jnp.sum(self.pos_z * softmax_attention, axis=-1) + expected_xyz = jnp.concatenate([expected_x, expected_y, expected_z], axis=-1) + + expected_xy = jnp.reshape(expected_xyz, (batch_size, 3, num_featuremaps)) + + if no_batch_dim: + expected_xy = expected_xy[0] + return expected_xy + + +class VoxNet(nn.Module): + """ + VoxNet-like implementation: https://github.com/AutoDeep/VoxNet/blob/master/src/nets/voxNet.py + """ + + use_conv_bias: bool = False + bottleneck_dim: Optional[int] = None + final_activation: Callable[[jnp.ndarray], jnp.ndarray] | str = nn.tanh + pretrained: bool = False + scale_factor: float = 1.0 + + @nn.compact + def __call__( + self, + observations: jnp.ndarray, + encode: bool = True, + train: bool = True, + ): + # observations has shape (B, X, Y, Z) + no_batch_dim = len(observations.shape) < 4 + if no_batch_dim: + observations = observations[None] + + observations = ( + observations.astype(jnp.float32)[..., None] / self.scale_factor + ) # add conv channel + + conv3d = partial( + nn.Conv, + kernel_init=nn.initializers.xavier_normal(), + use_bias=self.use_conv_bias, + padding="valid", + bias_init=nn.zeros_init(), + ) + l_relu = partial(nn.leaky_relu, negative_slope=0.1) + max_pool = partial(nn.max_pool, window_shape=(2, 2, 2), strides=(2, 2, 2)) + + if self.pretrained: + feature_dimensions = (64, 64, 32) + else: + feature_dimensions = (32, 16, 8) + + x = observations + x = conv3d( + features=feature_dimensions[0], + kernel_size=(5, 5, 5), + strides=(2, 2, 2), + name="conv_5x5x5", + 
)(x) + x = nn.LayerNorm()(x) + x = l_relu(x) + + x = conv3d( + features=feature_dimensions[1], + kernel_size=(3, 3, 3), + strides=(1, 1, 1), + name="conv_3x3x3", + )(x) + x = max_pool(x) + + if self.pretrained: + x = jax.lax.stop_gradient( + x + ) # unfortunately also cuts gradients of the LayerNorm above + + x = nn.LayerNorm()(x) + x = l_relu(x) + + x = conv3d( + features=feature_dimensions[ + 2 + ], # if pretrained, uses [..] out of 128 pretrained params as initial weights + kernel_size=(2, 2, 2), + strides=(2, 2, 2), + name="conv_2x2x2", + )(x) + x = nn.LayerNorm()(x) + x = l_relu(x) + + # x = SpatialSoftArgmax3D(10, 10, 8, 64)(x) # not used for now + + # reshape and dense (preserve batch dim) + x = jnp.reshape(x, (1 if no_batch_dim else x.shape[0], -1)) + if self.bottleneck_dim is not None: + x = nn.Dense(self.bottleneck_dim)(x) + x = nn.LayerNorm()(x) + x = self.final_activation(x) + + return x[0] if no_batch_dim else x diff --git a/serl_launcher/serl_launcher/wrappers/observation_statistics_wrapper.py b/serl_launcher/serl_launcher/wrappers/observation_statistics_wrapper.py new file mode 100644 index 00000000..90c4836a --- /dev/null +++ b/serl_launcher/serl_launcher/wrappers/observation_statistics_wrapper.py @@ -0,0 +1,93 @@ +import numpy as np +from collections import deque +import gym + + +class ObservationStatisticsWrapper(gym.Wrapper, gym.utils.RecordConstructorArgs): + """ + This wrapper will keep track of the observation statistics. + + At the end of an episode, the statistics of the episode will be added to ``info`` + using the key ``obsStat``. 
+ """ + + def __init__(self, env: gym.Env, deque_size: int = 100): + gym.utils.RecordConstructorArgs.__init__(self, deque_size=deque_size) + gym.Wrapper.__init__(self, env) + + self.buffer = {} + + # make buffer + for name, space in self.env.observation_space["state"].items(): + self.buffer[name] = np.zeros( + shape=(self.max_episode_length, space.shape[0]) + ) + + # may not be used + self.num_envs = getattr(env, "num_envs", 1) + self.episode_count = 0 + self.episode_start_times: np.ndarray = None + self.episode_returns = None + self.episode_lengths = None + self.return_queue = deque(maxlen=deque_size) + self.length_queue = deque(maxlen=deque_size) + self.is_vector_env = getattr(env, "is_vector_env", False) + + def step(self, action): + """Steps through the environment, recording the episode statistics.""" + ( + observations, + rewards, + terminations, + truncations, + infos, + ) = self.env.step(action) + assert isinstance( + infos, dict + ), f"`info` dtype is {type(infos)} while supported dtype is `dict`. This may be due to usage of other wrappers in the wrong order." 
+ + for name, obs in observations["state"].items(): + self.buffer[name][self.curr_path_length - 1, :] = obs + + dones = np.logical_or(terminations, truncations) + num_dones = np.sum(dones) + if num_dones: + calc_buffs = {} + calc_buffs.update( + { + name + "_mean": np.mean(obs[: self.curr_path_length], axis=0) + for name, obs in self.buffer.items() + } + ) + calc_buffs.update( + { + name + "_std": np.std(obs[: self.curr_path_length], axis=0) + for name, obs in self.buffer.items() + } + ) + buff = {} + for name, value in calc_buffs.items(): + for i in range(value.shape[0]): + buff[ + name + f"_{['x', 'y', 'z', 'rx', 'ry', 'rz', 'grip'][i]}" + ] = value[i] + infos["obsStat"] = buff + # print(buff) + + return ( + observations, + rewards, + terminations, + truncations, + infos, + ) + + def reset(self, **kwargs): + """Resets the environment using kwargs and resets the episode returns and lengths.""" + obs, info = super().reset(**kwargs) + + # reset buffer to zero + for name, value in self.buffer.items(): + value[...] 
= 0
+
+        return obs, info
diff --git a/serl_launcher/serl_launcher/wrappers/serl_obs_wrappers.py b/serl_launcher/serl_launcher/wrappers/serl_obs_wrappers.py
index 41c169f9..4bfbb721 100644
--- a/serl_launcher/serl_launcher/wrappers/serl_obs_wrappers.py
+++ b/serl_launcher/serl_launcher/wrappers/serl_obs_wrappers.py
@@ -23,3 +23,41 @@ def observation(self, obs):
             **(obs["images"]),
         }
         return obs
+
+
+class ScaleObservationWrapper(gym.ObservationWrapper):
+    """
+    This observation wrapper scales the observations with the provided hyperparams
+    (to somewhat normalize the observation space)
+    """
+
+    def __init__(
+        self,
+        env,
+        translation_scale=100.0,
+        rotation_scale=10.0,
+        force_scale=1.0,
+        torque_scale=10.0,
+    ):
+        super().__init__(env)
+        self.translation_scale = translation_scale
+        self.rotation_scale = rotation_scale
+        self.force_scale = force_scale
+        self.torque_scale = torque_scale
+
+    def scale_wrapper_get_scales(self):
+        return dict(
+            translation_scale=self.translation_scale,
+            rotation_scale=self.rotation_scale,
+            force_scale=self.force_scale,
+            torque_scale=self.torque_scale,
+        )
+
+    def observation(self, obs):
+        obs["state"]["tcp_pose"][:3] *= self.translation_scale
+        obs["state"]["tcp_pose"][3:] *= self.rotation_scale
+        obs["state"]["tcp_vel"][:3] *= self.translation_scale
+        obs["state"]["tcp_vel"][3:] *= self.rotation_scale
+        obs["state"]["tcp_force"] *= self.force_scale
+        obs["state"]["tcp_torque"] *= self.torque_scale
+        return obs
diff --git a/serl_robot_infra/franka_env/utils/transformations.py b/serl_robot_infra/franka_env/utils/transformations.py
index 52a55527..3687237b 100644
--- a/serl_robot_infra/franka_env/utils/transformations.py
+++ b/serl_robot_infra/franka_env/utils/transformations.py
@@ -23,6 +23,14 @@ def construct_adjoint_matrix(tcp_pose):
     return adjoint_matrix
+
+
+def construct_rotation_matrix(tcp_pose):
+    """
+    Construct the rotation matrix from the quaternion part of the given pose
+    :arg tcp_pose: (x, y, z, qx, qy, qz, qw)
+    """
+    return
R.from_quat(tcp_pose[3:]).as_matrix() + + def construct_homogeneous_matrix(tcp_pose): """ Construct the homogeneous transformation matrix from given pose. diff --git a/serl_robot_infra/robot_controllers/__init__.py b/serl_robot_infra/robot_controllers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/serl_robot_infra/robot_controllers/ur5_controller.py b/serl_robot_infra/robot_controllers/ur5_controller.py new file mode 100644 index 00000000..b0f04a7d --- /dev/null +++ b/serl_robot_infra/robot_controllers/ur5_controller.py @@ -0,0 +1,448 @@ +import datetime +import time +import threading +import asyncio +import numpy as np +from scipy.spatial.transform import Rotation as R +from rtde_control import RTDEControlInterface +from rtde_receive import RTDEReceiveInterface + +from ur_env.utils.vacuum_gripper import VacuumGripper +from ur_env.utils.rotations import rotvec_2_quat, quat_2_rotvec, pose2rotvec, pose2quat + +np.set_printoptions(precision=4, suppress=True) + + +def pos_difference(quat_pose_1: np.ndarray, quat_pose_2: np.ndarray): + assert quat_pose_1.shape == (7,) + assert quat_pose_2.shape == (7,) + p_diff = np.sum(np.abs(quat_pose_1[:3] - quat_pose_2[:3])) + + r_diff = ( + R.from_quat(quat_pose_1[3:]) * R.from_quat(quat_pose_2[3:]).inv() + ).magnitude() + return p_diff + r_diff + + +class UrImpedanceController(threading.Thread): + def __init__( + self, + robot_ip, + frequency=100, + kp=10000, + kd=2200, + config=None, + verbose=False, + *args, + **kwargs, + ): + super(UrImpedanceController, self).__init__(*args, **kwargs) + self._stop = threading.Event() + self._reset = threading.Event() + self._is_ready = threading.Event() + self._is_truncated = threading.Event() + self.lock = threading.Lock() + + self.robot_ip = robot_ip + self.frequency = frequency + self.kp = kp + self.kd = kd + self.gripper_timeout = { + "timeout": config.GRIPPER_TIMEOUT, + "last_grip": time.monotonic() - 1e6, + } + self.verbose = verbose + + self.target_pos = 
np.zeros( + (7,), dtype=np.float32 + ) # new as quat to avoid +- problems with axis angle repr. + self.target_grip = np.zeros((1,), dtype=np.float32) + self.curr_pos = np.zeros((7,), dtype=np.float32) + self.curr_vel = np.zeros((6,), dtype=np.float32) + self.gripper_state = np.zeros((2,), dtype=np.float32) + self.curr_Q = np.zeros((6,), dtype=np.float32) + self.curr_Qd = np.zeros((6,), dtype=np.float32) + self.curr_force_lowpass = np.zeros((6,), dtype=np.float32) # force of tool tip + self.curr_force = np.zeros((6,), dtype=np.float32) + + self.reset_Q = np.array( + [np.pi / 2.0, -np.pi / 2.0, np.pi / 2.0, -np.pi / 2.0, -np.pi / 2.0, 0.0], + dtype=np.float32, + ) # reset state in Joint Space + self.reset_Pose = np.zeros_like(self.reset_Q) + self.reset_height = np.array([0.1], dtype=np.float32) # TODO make customizable + + self.delta = config.ERROR_DELTA + self.fm_damping = config.FORCEMODE_DAMPING + self.fm_task_frame = config.FORCEMODE_TASK_FRAME + self.fm_selection_vector = config.FORCEMODE_SELECTION_VECTOR + self.fm_limits = config.FORCEMODE_LIMITS + + self.ur_control: RTDEControlInterface = None + self.ur_receive: RTDEReceiveInterface = None + self.robotiq_gripper: VacuumGripper = None + + # only temporary to test + self.hist_data = [[], []] + self.horizon = [0, 500] + self.err = 0 + self.noerr = 0 + + # log to file (reset every new run) + with open("/tmp/console2.txt", "w") as f: + f.write("reset\n") + self.second_console = open("/tmp/console2.txt", "a") + + def start(self): + super().start() + if self.verbose: + print(f"[RIC] Controller process spawned at {self.native_id}") + + def print(self, msg, both=False): + self.second_console.write(f"{datetime.datetime.now()} --> {msg}\n") + if both: + print(msg) + + async def start_ur_interfaces(self, gripper=True): + self.ur_control = RTDEControlInterface(self.robot_ip) + self.ur_receive = RTDEReceiveInterface(self.robot_ip) + if gripper: + self.robotiq_gripper = VacuumGripper(self.robot_ip) + await 
self.robotiq_gripper.connect() + await self.robotiq_gripper.activate() + if self.verbose: + gr_string = "(with gripper) " if gripper else "" + print(f"[RIC] Controller connected to robot {gr_string}at: {self.robot_ip}") + + async def restart_ur_interface(self): + self._is_truncated.set() + self.print("[RIC] forcemode failed, is now truncated!") + + # disconnect and reconnect, otherwise the controller won't take any commands + self.ur_control.disconnect() + try: + print(f"[RTDE] trying to reconnect") + self.ur_control.reconnect() + except RuntimeError: + self.ur_receive.disconnect() + for _ in range(10): + try: + self.ur_control.disconnect() + self.ur_receive.disconnect() + await self.start_ur_interfaces(gripper=False) + return + except Exception as e: + print(e) + time.sleep(0.2) + + def stop(self): + self._stop.set() + + def stopped(self): + return self._stop.is_set() + + def is_moving(self): + return np.linalg.norm(self.get_state()["vel"], 2) > 0.01 + + def set_target_pos(self, target_pos: np.ndarray): + if target_pos.shape == (7,): + target_orientation = target_pos[3:] + elif target_pos.shape == (6,): + target_orientation = rotvec_2_quat(target_pos[3:]) + else: + raise ValueError(f"[RIC] target pos has shape {target_pos.shape}") + + with self.lock: + self.target_pos[:3] = target_pos[:3] + self.target_pos[3:] = target_orientation + + self.print(f"target: {self.target_pos}") + + def set_reset_Q(self, reset_Q: np.ndarray): + with self.lock: + self.reset_Q[:] = reset_Q + self._reset.set() + + def set_reset_pose(self, reset_pose: np.ndarray): + with self.lock: + self.reset_Pose[:] = reset_pose + self._reset.set() + + def set_gripper_pos(self, target_grip: np.ndarray): + with self.lock: + self.target_grip[:] = target_grip + + def get_target_pos(self, copy=True): + with self.lock: + if copy: + return self.target_pos.copy() + else: + return self.target_pos + + async def _update_robot_state(self): + pos = self.ur_receive.getActualTCPPose() + vel = 
self.ur_receive.getActualTCPSpeed()
+        Q = self.ur_receive.getActualQ()
+        Qd = self.ur_receive.getActualQd()
+        force = self.ur_receive.getActualTCPForce()
+        pressure = await self.robotiq_gripper.get_current_pressure()
+        obj_status = await self.robotiq_gripper.get_object_status()
+
+        # 3-> no object detected, 0-> sucking empty, [1, 2] obj detected
+        grip_status = [-1.0, 1.0, 1.0, 0.0][obj_status.value]
+
+        pressure = (
+            pressure if pressure < 99 else 0
+        )  # 100 no obj, 99 sucking empty, so they are ignored
+        # grip status, 0->neutral, -1->bad (sucking but no obj), 1-> good (sucking and obj)
+        grip_status = 1.0 if pressure > 0 else grip_status
+        pressure /= 98.0  # pressure between [0, 1]
+        with self.lock:
+            self.curr_pos[:] = pose2quat(pos)
+            self.curr_vel[:] = vel
+            self.curr_Q[:] = Q
+            self.curr_Qd[:] = Qd
+            self.curr_force[:] = np.array(force)
+            # exponential low-pass filter (alpha=0.1), since the force fluctuates heavily
+            self.curr_force_lowpass[:] = (
+                0.1 * np.array(force) + 0.9 * self.curr_force_lowpass[:]
+            )
+            self.gripper_state[:] = [pressure, grip_status]
+
+    def get_state(self):
+        with self.lock:
+            state = {
+                "pos": self.curr_pos,
+                "vel": self.curr_vel,
+                "Q": self.curr_Q,
+                "Qd": self.curr_Qd,
+                "force": self.curr_force_lowpass[:3],
+                "torque": self.curr_force_lowpass[3:],
+                "gripper": self.gripper_state,
+            }
+        return state
+
+    def is_ready(self):
+        return self._is_ready.is_set()
+
+    def is_reset(self):
+        return not self._reset.is_set()
+
+    def _calculate_force(self):
+        target_pos = self.get_target_pos(copy=True)
+        with self.lock:
+            curr_pos = self.curr_pos
+            curr_vel = self.curr_vel
+
+        # calc positional force (PD: kp on clipped position error, kd on velocity)
+        kp, kd = self.kp, self.kd
+        diff_p = np.clip(
+            target_pos[:3] - curr_pos[:3], a_min=-self.delta, a_max=self.delta
+        )
+        vel_delta = 2 * self.delta * self.frequency
+        diff_d = np.clip(-curr_vel[:3], a_min=-vel_delta, a_max=vel_delta)
+        force_pos = kp * diff_p + kd * diff_d
+
+        # calc torque
+        rot_diff = R.from_quat(target_pos[3:]) *
R.from_quat(curr_pos[3:]).inv() + vel_rot_diff = R.from_rotvec(curr_vel[3:]).inv() + torque = ( + rot_diff.as_rotvec() * 100 + vel_rot_diff.as_rotvec() * 22 + ) # TODO make customizable + + # check for big downward tcp force and adapt accordingly + if self.curr_force[2] > 3.5 and force_pos[2] < 0.0: + force_pos[2] = ( + max((1.5 - self.curr_force_lowpass[2]), 0.0) * force_pos[2] + + min(self.curr_force_lowpass[2] - 0.5, 1.0) * 20.0 + ) + + return np.concatenate((force_pos, torque)) + + async def send_gripper_command(self, force_release=False): + if force_release: + await self.robotiq_gripper.automatic_release() + self.target_grip[0] = 0.0 + return + + timeout_exceeded = ( + time.monotonic() - self.gripper_timeout["last_grip"] + ) * 1000 > self.gripper_timeout["timeout"] + # target grip above threshold and timeout exceeded and not gripping something already + if ( + self.target_grip[0] > 0.5 + and timeout_exceeded + and self.gripper_state[1] < 0.5 + ): + await self.robotiq_gripper.automatic_grip() + self.target_grip[0] = 0.0 + self.gripper_timeout["last_grip"] = time.monotonic() + # print("grip") + + # release if below neg threshold and gripper activated (grip_status not zero) + elif self.target_grip[0] < -0.5 and abs(self.gripper_state[1]) > 0.5: + await self.robotiq_gripper.automatic_release() + self.target_grip[0] = 0.0 + # print("release") + + def _truncate_check(self): + downward_force = self.curr_force_lowpass[2] > 20.0 + if downward_force: # TODO add better criteria + self._is_truncated.set() + else: + self._is_truncated.clear() + + def is_truncated(self): + return self._is_truncated.is_set() + + def run(self): + try: + asyncio.run( + self.run_async() + ) # gripper has to be awaited, both init and commands + finally: + self.stop() + + async def _go_to_reset_pose(self): + self.ur_control.forceModeStop() + + # first disable vaccum gripper + if self.robotiq_gripper: + await self.send_gripper_command(force_release=True) + time.sleep(0.01) + + # then move up (so 
no boxes are moved) + success = True + while self.curr_pos[2] < self.reset_height: + if ( + self.curr_Q[2] < 0.5 + ): # if the shoulder joint is near 180deg --> do not move into singularity + success = success and self.ur_control.speedJ( + [0.0, -1.0, 1.0, 0.0, 0.0, 0.0], acceleration=0.8 + ) + else: + success = success and self.ur_control.speedL( + [0.0, 0.0, 0.25, 0.0, 0.0, 0.0], acceleration=0.8 + ) + await self._update_robot_state() + time.sleep(0.01) + self.ur_control.speedStop(a=1.0) + + if self.reset_Pose.std() > 0.001: + success = success and self.ur_control.moveL( + self.reset_Pose, speed=0.5, acceleration=0.3 + ) + self.print( + f"[RIC] moving to {self.reset_Pose} with moveL (task space)", + both=self.verbose, + ) + self.reset_Pose[:] = 0.0 + else: + # then move to desired Jointspace position + success = success and self.ur_control.moveJ( + self.reset_Q, speed=1.0, acceleration=0.8 + ) + self.print( + f"[RIC] moving to {self.reset_Q} with moveJ (joint space)", + both=self.verbose, + ) + + time.sleep(0.1) # wait for 100ms + await self._update_robot_state() + with self.lock: + self.target_pos = self.curr_pos.copy() + + self.ur_control.forceModeSetDamping(self.fm_damping) # less damping = Faster + self.ur_control.zeroFtSensor() + + if not success: # restart if not successful + await self.restart_ur_interface() + else: + self._reset.clear() + + async def run_async(self): + await self.start_ur_interfaces(gripper=True) + + self.ur_control.forceModeSetDamping(self.fm_damping) # less damping = Faster + + try: + dt = 1.0 / self.frequency + self.ur_control.zeroFtSensor() + await self._update_robot_state() + self.target_pos = self.curr_pos.copy() + print(f"[RIC] target position set to curr pos: {self.target_pos}") + + self._is_ready.set() + + while not self.stopped(): + if self._reset.is_set(): + await self._update_robot_state() + await self._go_to_reset_pose() + + t_now = time.monotonic() + + # update robot state and check for truncation + await 
self._update_robot_state() + self._truncate_check() + + # calculate force + force = self._calculate_force() + # print(self.target_pos, self.curr_pos, force) + self.print( + f" p:{self.curr_pos} f:{self.curr_force_lowpass} gr:{self.gripper_state}" + ) # log to file + + # send command to robot + t_start = self.ur_control.initPeriod() + fm_successful = self.ur_control.forceMode( + self.fm_task_frame, + self.fm_selection_vector, + force, + 2, + self.fm_limits, + ) + if not fm_successful: # truncate if the robot ends up in a singularity + await self.restart_ur_interface() + await self._go_to_reset_pose() + + if self.robotiq_gripper: + await self.send_gripper_command() + + self.ur_control.waitPeriod(t_start) + + a = dt - (time.monotonic() - t_now) + time.sleep(max(0.0, a)) + self.err, self.noerr = self.err + int(a < 0.0), self.noerr + int( + a >= 0.0 + ) # some logging + if a < -0.04: # log if delay more than 50ms + self.print( + f"Controller Thread stopped for {(time.monotonic() - t_now)*1e3:.1f} ms" + ) + + finally: + if self.verbose: + print( + f"[RTDEPositionalController] >dt: {self.err}