diff --git a/.gitignore b/.gitignore index 9d6232dd..672039d1 100644 --- a/.gitignore +++ b/.gitignore @@ -172,3 +172,7 @@ MUJOCO_LOG.TXT _METADATA checkpoint wandb/ + +# VS Code settings +*.code-workspace + diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 00000000..6b76b4fa --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,15 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Python Debugger: Current File", + "type": "debugpy", + "request": "launch", + "program": "${file}", + "console": "integratedTerminal" + } + ] +} \ No newline at end of file diff --git a/README.md b/README.md index f6c9bccd..1e399f51 100644 --- a/README.md +++ b/README.md @@ -53,7 +53,7 @@ We fixed a major issue in the intervention action frame. See release [v0.1.1](ht - For GPU: ```bash - pip install --upgrade "jax[cuda12_pip]==0.4.35" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html + pip install --upgrade "jax[cuda12]==0.6.2" ``` - For TPU @@ -69,6 +69,27 @@ We fixed a major issue in the intervention action frame. See release [v0.1.1](ht pip install -r requirements.txt ``` + + +4. **Install the franka_sim** + ```bash + cd franka_sim + pip install -e . + pip install -r requirements.txt + ``` + +5. **Install the serl_robot_infra** + ```bash + cd serl_robot_infra + pip install -e . + ``` + +6. **Install the demos** + ```bash + cd demos + pip install -e . + ``` + ## Overview and Code Structure SERL provides a set of common libraries for users to train RL policies for robotic manipulation tasks. The main structure of running the RL experiments involves having an actor node and a learner node, both of which interact with the robot gym environment. 
Both nodes run asynchronously, with data being sent from the actor to the learner node via the network using [agentlace](https://github.com/youliangtan/agentlace). The learner will periodically synchronize the policy with the actor. This design provides flexibility for parallel training and inference. diff --git a/demos/demos/__init__.py b/demos/demos/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/demos/demos/demoHandling.py b/demos/demos/demoHandling.py new file mode 100644 index 00000000..04a5e974 --- /dev/null +++ b/demos/demos/demoHandling.py @@ -0,0 +1,157 @@ +import os +from pathlib import Path +import numpy as np +from agentlace.data.data_store import QueuedDataStore + +class DemoHandling: + """ + Loads an .npz file containing demonstration data into a data object. + This class is designed to work with Gymnasium-style demonstration data + and is intended to be used with a QueuedDataStore or similar data store. + + The .npz file should contain the following arrays: + - 'obs' : shape (N, T+1, *obs_shape*), list of observations + - 'acs' : shape (N, T, *act_shape*), list of actions + - 'rewards' : shape (N, T), list of rewards + - 'terminateds' : shape (N, T), list of terminated flags + - 'truncateds' : shape (N, T), list of truncated flags + - 'info' : shape (N, T), list of info dicts + - 'dones' : shape (N, T), list of done flags (if available) + + Parameters + ---------- + demo_dir : str + Directory where demo .npz files live by default. + file_name : str + Name of the demo file to load. If not provided, a default will be used. 
+ """ + def __init__( + self, + demo_dir: str = '/data/data/serl/demos', + file_name: str = 'data_franka_reach_random_20.npz' + ): + + self.debug = False # Set to True for debugging purposes + self.demo_dir = demo_dir + self.transition_ctr = 0 # Global counter for transitions across all episodes + + # Load the demo data from the .npz file + + # Check if the demo directory exists + if not os.path.exists(self.demo_dir): + raise FileNotFoundError(f"Demo directory '{self.demo_dir}' does not exist.") + + # Construct the full path to the demo file + self.demo_npz_path = os.path.join(self.demo_dir, file_name) + if not os.path.isfile(self.demo_npz_path): + raise FileNotFoundError(f"Demo file '{self.demo_npz_path}' does not exist.") + + # Load the .npz file + self.data = np.load(self.demo_npz_path, allow_pickle=True) + + def get_num_transitions(self): + """ + Returns the total number of transitions counted in the demo data. + """ + return int(self.data["transition_ctr"]) if "transition_ctr" in self.data else 0 + + def get_num_demos(self): + """ + Returns the total number of demonstrations in the demo data. + """ + return int(self.data["num_demos"]) if "num_demos" in self.data else 0 + + def insert_data_to_buffer(self,data_store: QueuedDataStore): + """ + Load a raw Gymnasium-style .npz of expert episodes into data_store. + The .npz file must contain arrays named 'obs', 'acs', 'rewards', + 'terminateds', 'truncateds', 'info', and optionally 'dones'. + Each episode is processed, and transitions are inserted into the data_store. + Inserted transitions in data store will remain in the data_store as pointers. + + ***Note*** + Need to insert obs and acs in the same way as async_sac_state via jax + + Parameters + ---------- + data_store : QueuedDataStore + + Returns + ------- + None + """ + + obs_buffer = self.data['obs'] # shape (N, T+1, ...) + act_buffer = self.data['acs'] # shape (N, T, ...) 
+ rew_buffer = self.data['rewards'] # shape (N, T) + term_buffer = self.data['terminateds'] # shape (N, T) + trunc_buffer = self.data['truncateds'] # shape (N, T) + info_buffer = self.data['info'] # shape (N, T) + done_buffer = self.data['dones'] # shape (N, T) #.get('dones', term_buffer | trunc_buffer) + + num_demos = self.get_num_demos() + if num_demos == 0: + raise ValueError("No demonstrations found in the provided .npz file.") + + num_transitions = self.get_num_transitions() + if num_transitions == 0: + raise ValueError("No transitions found in the provided .npz file.") + + + # Extract the number of episodes and transitions + for ep in range(num_demos): + ep_obs = obs_buffer[ep] + ep_acts = act_buffer[ep] + ep_rews = rew_buffer[ep] + ep_terms = term_buffer[ep] + ep_trunc = trunc_buffer[ep] + ep_done = done_buffer[ep] + ep_info = info_buffer[ep] + + T = len(ep_acts) + for t in range(T): + obs_t = np.asarray(ep_obs[t], dtype=np.float32) + next_obs_t = np.asarray(ep_obs[t+1], dtype=np.float32) + a_t = np.asarray(ep_acts[t], dtype=np.float32) + r_t = float(ep_rews[t]) + done_t = bool(ep_done[t] or ep_terms[t] or ep_trunc[t]) + #info_t = ep_info[t] + # masks will be created right before insert below + + if self.debug: + np.set_printoptions(precision=3, suppress=True) + + print(f"Demo {ep:2}, Step {t:3} \n " + f"Obs: [{obs_t[0]:.2f} {obs_t[1]:.2f} {obs_t[2]:.2f}] \n " + f"Action: [{a_t[0]:.2f} {a_t[1]:.2f} {a_t[2]:.2f}] \n " + f"Reward: {r_t:.2f} \n " + f"Done: {done_t}") + + # Insert using SERLs data_store/ReplayBuffer insert mechanism directly. 
+ data_store.insert( + dict( + observations =obs_t, + actions =a_t, + next_observations=next_obs_t, + rewards =r_t, + masks =1.0 - done_t, + dones =done_t + ) + ) + + print(f"Loaded a total of {num_transitions} from {num_demos} episodes from '{self.demo_npz_path}' ") + + +# if __name__ == "__main__": +# # Instantiate a DemoHandling object +# handler = DemoHandling(demo_dir='/data/data/serl/demos', +# file_name='data_franka_reach_random_20.npz') + +# # Idenitfy the total number of transitions in the datastore +# print(f'We have {handler.data["transition_ctr"]} transitions in the datastore.') + +# # Simulate SERL's datastore creation w/ capacity 2000 +# ds = QueuedDataStore(2000) + +# # Insert the demo data into the datastore +# handler.insert_data_to_buffer(ds) diff --git a/demos/demos/franka_pick_n_place_drq_demo_script.py b/demos/demos/franka_pick_n_place_drq_demo_script.py new file mode 100644 index 00000000..3e1a8458 --- /dev/null +++ b/demos/demos/franka_pick_n_place_drq_demo_script.py @@ -0,0 +1,511 @@ +#!/usr/bin/env python3 +import os +import time +from datetime import datetime + +import numpy as np + +# Logging +from absl import app, flags, logging +from oxe_envlogger.envlogger import AutoOXEEnvLogger + +# DRL +import gym +import mujoco +from serl_launcher.wrappers.serl_obs_wrappers import SERLObsWrapper +from serl_launcher.wrappers.chunking import ChunkingWrapper + +# Needed to create the franka environment +import franka_sim + +# Teleoperation imports +import sys, select, termios, tty + +# RLDS/TFDS +import json, glob, inspect +import tensorflow as tf +import tensorflow_datasets as tfds +from tensorflow_datasets import folder_dataset +#------------------------------------------------------------------------------------------- +# Flags +#------------------------------------------------------------------------------------------- +flags.DEFINE_string("env", "PandaPickCubeVision-v0", "Name of environment.") +flags.DEFINE_string("exp_name", None, "Name of the 
experiment for wandb logging.") +flags.DEFINE_integer("max_traj_length", 200, "Maximum length of trajectory.") +flags.DEFINE_boolean("debug", True, "Debug mode.") # debug mode will disable wandb logging +#flags.DEFINE_string("preload_rlds_path", None, "Path to preload RLDS data.") +flags.DEFINE_string("output_dir", "/data/data/serl/demos/franka_reach_drq_demo_script", + "Directory to save the output data. This is where the RLDS logs will be saved.") +flags.DEFINE_integer("num_demos", 2, "Number of episodes to log.") +flags.DEFINE_boolean("enable_envlogger", True, "Enable envlogger.") +flags.DEFINE_string("teleop_mode", "keyboard", "Teleoperation mode: 'keyboard' or 'spacemouse'.") + +FLAGS = flags.FLAGS + +#------------------------------------------------------------------------------------------- +## Telop Config Variables +#------------------------------------------------------------------------------------------- +ACTION_MAX = 10 # Maximum action value for clipping actions + +# Bind, xyz, gripper vals to keys +moveBindings = { + 'i':(1,0,0,0), + ',':(-1,0,0,0), + 'j':(0,1,0,0), + 'l':(0,-1,0,0), + 'u':(0,0,0,1), + 'o':(0,0,0,-1), + 'm':(0,0,1,0), + '.':(0,0,-1,0), + 'g':(0,0,0,0.1), # open gripper + 'h':(0,0,0,-0.1), # close gripper + + } + +# Extend bindings to include camera controls +camBindings = { + 'a': ("azimuth", -5), # rotate left + 'd': ("azimuth", 5), # rotate right + 'w': ("elevation", 2), # tilt up + 's': ("elevation", -2), # tilt down + 'q': ("distance", -0.1),# zoom in + 'e': ("distance", 0.1), # zoom out +} + +def activate_weld(env, constraint_name="grasp_weld"): + """ + Activate a weld constraint during pick portion of a demo + :param env: The environment containing the model + :param constraint_name: The name of the weld constraint to activate + :return: True if the weld was successfully activated, False if the constraint was not found + """ + + try: + # Activate the weld constraint + env.unwrapped.model.eq(constraint_name).active = 1 + 
print("Activated weld") + return True + + except KeyError: + print(f"Warning: Constraint '{constraint_name}' not found") + return False + +def deactivate_weld(env, constraint_name="grasp_weld"): + """ + Deactivate a weld constraint during place portion of a demo + :param env: The environment containing the model + :param constraint_name: The name of the weld constraint to deactivate + :return: True if the weld was successfully deactivated, False if the constraint was not + found + """ + + try: + # Deactivate the weld constraint + env.unwrapped.model.eq(constraint_name).active = 0 + print("Deactivated weld") + return True + + except KeyError: + print(f"Warning: Constraint '{constraint_name}' not found") + return False + +def update_camera(viewer,key): + """ + Update the camera view based on keyboard input. Assumes higher level function has checked for the existance of key in camBindings. + Controls: + 'a' : rotate left + 'd' : rotate right + 'w' : tilt up + 's' : tilt down + 'q' : zoom in + 'e' : zoom out + """ + if hasattr(viewer, 'cam'): + # Get current camera parameters + attr, delta = camBindings[key] + val = getattr(viewer.cam, attr) + + setattr(viewer.cam, attr, val + delta) + +def close_logger_and_env(env): + """ + Best-effort shutdown: + 1) close embedded envloggers/writers if present + 2) close dm_env (if you have a DeepMind-style env) + 3) close Gym/Gymnasium env + 4) close the Mujoco viewer + Also walks through common wrapper attributes (env, _env, unwrapped). 
+ """ + import logging + + visited = set() + + def _safe_close(obj, label=""): + if obj is None or id(obj) in visited: + return + visited.add(id(obj)) + + # 1) Close common logger/writer attributes first (flush TFRecord) + for name in ("_envlogger", "envlogger", "logger", "_logger", "writer"): + try: + logger_obj = getattr(obj, name, None) + if logger_obj is not None and hasattr(logger_obj, "close"): + logger_obj.close() + except Exception as e: + logging.warning(f"close_logger_and_env: {label}.{name}.close() raised {e!r}") + + # 2) Close dm_env if present + try: + dm = getattr(obj, "dm_env", None) + if dm is not None and hasattr(dm, "close"): + dm.close() + except Exception as e: + logging.warning(f"close_logger_and_env: {label}.dm_env.close() raised {e!r}") + + # 3) Close Gym/Gymnasium env + try: + if hasattr(obj, "close"): + obj.close() + except Exception as e: + logging.warning(f"close_logger_and_env: {label}.close() raised {e!r}") + + # 4) Close Mujoco viewer if accessible + try: + # Some stacks keep viewer at env._viewer.viewer; others just env._viewer + viewer = None + vwrap = getattr(obj, "_viewer", None) + if vwrap is not None: + viewer = getattr(vwrap, "viewer", vwrap) + if viewer is not None and hasattr(viewer, "close"): + viewer.close() + except Exception as e: + logging.warning(f"close_logger_and_env: {label} viewer close raised {e!r}") + + # Recurse into common wrapper links + for child_name in ("env", "_env", "unwrapped", "environment", "base_env"): + child = getattr(obj, child_name, None) + if child is not None and child is not obj: + _safe_close(child, f"{label}.{child_name}" if label else child_name) + + _safe_close(env, "env") + +def finalize_tfds_metadata_beamless(builder_dir: str): + """ + Beam-free finalize: count TFRecord examples per shard and write + numShards/shardLengths into dataset_info.json so TFDS will load. 
+ """ + import os, json, glob + import tensorflow as tf + + info_path = os.path.join(builder_dir, "dataset_info.json") + if not os.path.exists(info_path): + raise FileNotFoundError(f"Missing dataset_info.json in {builder_dir}") + + with open(info_path) as f: + info = json.load(f) + + ds_name = info["name"] # e.g. "PandaReachSparseCube-v0" + file_fmt = info.get("fileFormat", "tfrecord") + tmpl_str = info["splits"][0]["filepathTemplate"] + + # Prefer strict pattern "-.-" + shard_paths = sorted(glob.glob(os.path.join(builder_dir, f"{ds_name}-*.{file_fmt}-*"))) + if not shard_paths: + # Fallback to any tfrecord-like file + shard_paths = sorted(glob.glob(os.path.join(builder_dir, f"*.{file_fmt}*"))) + if not shard_paths: + raise FileNotFoundError( + f"No {file_fmt} shards found in {builder_dir}. " + f"Expected like '{ds_name}-train.{file_fmt}-00000' per template '{tmpl_str}'." + ) + + # Count episodes (1 Example = 1 episode with envlogger/RLDS) + shard_lengths = [sum(1 for _ in tf.data.TFRecordDataset(p)) for p in shard_paths] + + # Write lengths for each split using this template + for s in info["splits"]: + if s.get("filepathTemplate") == tmpl_str: + s["numShards"] = len(shard_paths) + s["shardLengths"] = shard_lengths + # Re-write dataset_info.json with updated shard info + with open(info_path, "w") as f: + json.dump(info, f, indent=2) + + # Sanity log + import tensorflow_datasets as tfds + b = tfds.builder_from_directory(builder_dir) + print("[finalize] splits:", {k: v.num_examples for k, v in b.info.splits.items()}) + +def ensure_dir_exists(): + """ + For oxe_envlogger + RLDS compatibility, data must be written in the following format: + /data/data/serl/demos/franka_reach_drq_demo_script/ + └── session_20250821_222412/ + └── PandaReachSparseCube-v0/ + └── 0.1.0/ + dataset_info.json + features.json + PandaReachSparseCube-v0-train.tfrecord-00000-of-00001 + + We can have a base path with customized sessions inside. 
+ Inside each session we have: env-version-files + + + Returns + ------- + out_path : str + The path to the output directory. + """ + # Customize the path + root = FLAGS.output_dir + session = datetime.now().strftime("session_%Y%m%d_%H%M%S") + session_root = os.path.join(root, f"{FLAGS.num_demos}_demos_{session}") + #session_root = os.path.join(root, f"{session}_num_demos_{FLAGS.num_demos}") + + # Dataset details + dataset_name = FLAGS.env + version = "0.1.0" # Part of the RLDS format; required. + + # Create output filename with configuration details + dataset_dir = os.path.join(session_root, dataset_name, version) + os.makedirs(dataset_dir, exist_ok=True) + logging.info(f"TFDS builder dir: {dataset_dir}") + + return dataset_dir + +def getKey(settings): + """ + Waits (blocking) for a keypress and returns the pressed key. + + Parameters + ---------- + settings : list + Original terminal settings so they can be restored after reading. + + Returns + ------- + key : str + The key pressed by the user, or '' (empty string) if no key was pressed. + """ + # Put terminal into raw mode so keypress is captured instantly + tty.setraw(sys.stdin.fileno()) + + # Wait for human to input action, we will not provide a timeout so it is blocking. + rlist, _, _ = select.select([sys.stdin], [], []) # select.select(rlist, wlist, xlist[, timeout]) + + if rlist: + # Read exactly one character if a key is pressed + key = sys.stdin.read(1) + else: + key = '' + + # Restore terminal to original settings + termios.tcsetattr(sys.stdin, termios.TCSADRAIN, settings) + return key + +def get_kb_demo_action(env,speed=0.075): + """ + Reads keyboard input and maps it to a 4D action vector ([x, y, z, gripper]) for robot control or camera action. + TODO: currently can only read one key at a time. Needs to be extended to read multiple keys to handle both. + Otherwise, no-op actions are still considered steps in the loop. + + The function uses blocking keyboard input to allow interactive + teleoperation. 
Keys are mapped to directions in Cartesian space: + - 'i' : +x (forward) + - ',' : -x (backward) + - 'j' : +y (left) + - 'l' : -y (right) + - 'm' : +z (up) + - '.' : -z (down) + - 'k' : stop (zero vector) + + Parameters + ---------- + speed : float, optional + Step size for each key press (default 0.075). + + Returns + ------- + np.ndarray + Action vector of shape (4,), where the entries correspond to the + [x, y, z] translation command plus the gripper command. Example: [0.075, 0.0, 0.0, 0.0]. + """ + # Save current terminal settings so we can restore later + settings = termios.tcgetattr(sys.stdin) + + # Initialize action as a zero vector (no movement) + action = np.zeros(4, dtype=float) + + try: + # Capture the pressed key + key = getKey(settings) + + # Check keys for camera first, if so, update camera and get another key for action + if key in camBindings: + if hasattr(env.unwrapped, "_viewer"): + update_camera(env.unwrapped._viewer.viewer,key) + key = getKey(settings) # get another key for action + + elif key in moveBindings: + # Lookup (x, y, z) direction and scale by speed + dx, dy, dz, g= moveBindings[key] + action = np.array([dx, dy, dz, g], dtype=float) * speed + + elif key == 'k': + # 'k' means stop → zero vector + action = np.zeros(4, dtype=float) + + elif key == '\x03': # CTRL-C + raise KeyboardInterrupt + + finally: + # Restore terminal even if something goes wrong + termios.tcsetattr(sys.stdin, termios.TCSADRAIN, settings) + + # Clip action values to prevent excessive commands + action = np.clip(action, -ACTION_MAX, ACTION_MAX) + + return action + +def set_front_cam_view(env): + """ + Set the camera view to a front-facing perspective for better visualization. + + Args: + env: The environment instance containing the viewer. + + Returns: + viewer: The viewer with updated camera settings. 
+ """ + viewer = env.unwrapped._viewer.viewer # Access the viewer from the environment + + if hasattr(viewer, 'cam'): + viewer.cam.lookat[:] = [0, 0, 0.1] # Center of robot (adjust as needed) + viewer.cam.distance = 2.0 # Camera distance + viewer.cam.azimuth = 155 # 0 = right, 90 = front, 180 = left + viewer.cam.elevation = -30 # Negative = above, positive = below + + # Hide menu + viewer._hide_overlay = True + + return viewer + +############################################################################## +def main(unused_argv): + logging.info(f'Creating gym environment...') + + # Render mode configuration based on debug flag + if FLAGS.debug: + _render_mode = 'human' # Render mode for the environment, can be 'human' or 'rgb_array' + else: + _render_mode = 'rgb_array' # Use 'rgb_array' for automated testing without GUI + + # Create the environment with the specified render mode and wrappers + env = gym.make(FLAGS.env, render_mode=_render_mode) + + if FLAGS.env == "PandaPickCube-v0": + env = gym.wrappers.FlattenObservation(env) + + if FLAGS.env == "PandaReachSparseCube-v0" or FLAGS.env == "PandaPickCubeVision-v0": + env = SERLObsWrapper( + env, + target_hw=(128, 128), + img_dtype=np.uint8, # or np.float32 + normalize=False, # True if using float32 in [0,1] + ) + env = ChunkingWrapper(env, obs_horizon=1, act_exec_horizon=None) + + logging.info(f'Done creating {FLAGS.env} environment.') + + # Set camera to front view if viewer is available + if hasattr(env.unwrapped, '_viewer'): + viewer = set_front_cam_view(env) + if viewer: + logging.info('Camera view set to front-facing perspective.') + else: + logging.warning('Failed to set camera view. Viewer not available.') + + # Wrap with oxe_envlogger to record demos + dataset_dir = None + session_root = None + if FLAGS.enable_envlogger: + + dataset_dir = ensure_dir_exists() + + # Will save as many episodes as possible into files of 200MB each by default. 
+ env = AutoOXEEnvLogger( + env=env, + dataset_name=FLAGS.env, + directory=dataset_dir, + #split_name="train", # "train", "test", or "validation" + ) + logging.info('Recording %r demos...', FLAGS.num_demos) + + #--- LOOP DEMOS/EPISODES --- + # Loop through the number of demos specified by the user to record demonstrations + try: + for i in range(FLAGS.num_demos): + + # Log custom metadata during new episode: language embeddings randomly. + if FLAGS.enable_envlogger: + # The "language_embedding" is a standard field used in robotics datasets (like the OXE format that envlogger creates) to store a numerical representation of a natural language instruction for an episode. + # How to Reconcile it with 5 Random Numbers: The five random numbers are just placeholder data. This script is a demonstration and doesn't involve a real language model. + env.set_episode_metadata({ + "language_embedding": np.random.random((5,)).astype(np.float32) + }) + env.set_step_metadata({"timestamp": time.time()}) + + logging.info('episode %r', i) + + # Start a new episode + env.reset() + terminated = False + truncated = False + + step = 0 + + # Termination occurs when the hand reaches the target or the maximum trajectory length is reached + while not (terminated or truncated): + + # Get action from the demo function + action = get_kb_demo_action(env) + + # example to log custom step metadata + if FLAGS.enable_envlogger: + env.set_step_metadata({"timestamp": np.float32(time.time())}) + + return_step = env.step(action) + + # NOTE: to handle gym.Env.step() return value change in gym 0.26 + if len(return_step) == 5: + obs, reward, terminated, truncated, info = return_step + else: + obs, reward, terminated, info = return_step + truncated = False + + print(f" step: {step}", f"reward: {reward:.3f}\n") + step += 1 + + logging.info('Done recording %r demos.', FLAGS.num_demos) + + finally: + # Finalize TFDS metadata so SERL/TFDS can load the split + if FLAGS.enable_envlogger and dataset_dir is not 
None: + + # Close the environment to flush/write data. AutoOXEEnvLogger implements dm_env.close() + # which flushes data to disk. This is important to ensure all data is written before finalizing metadata. + # If you skip this step, some data may not be written and the dataset may be incomplete. + logging.info("Closing environment to flush data to disk...") + + # closes logger(s) + dm_env + gym + viewer + # env.unwrapped.unwrapped.unwrapped.close() # close mujoco viewer + # env.env.env.env.close() + close_logger_and_env(env) + + # Scan files and write metadata + finalize_tfds_metadata_beamless(dataset_dir) + + # Note: async_drq_sim will read from the dataset_dir you printed above. + +if __name__ == '__main__': + app.run(main) \ No newline at end of file diff --git a/demos/demos/franka_pick_place_demo_script.py b/demos/demos/franka_pick_place_demo_script.py new file mode 100755 index 00000000..e5abbc11 --- /dev/null +++ b/demos/demos/franka_pick_place_demo_script.py @@ -0,0 +1,407 @@ +""" +Scripted Controller for Franka FR3 Robot - Demonstration Data Generation + +This script implements a scripted controller for generating expert demonstration data +for Deep Reinforcement Learning (DRL) algorithms using the Franka FR3 robot in a +pick-and-place task. The generated data serves as bootstrapping demonstrations for +training RL agents with stable-baselines3. + +The controller uses a 4-phase hierarchical approach: +1. Approach Object (move gripper above object) +2. Grasp Object (move to object and close gripper) +3. Transport to Goal (move grasped object to target) +4. 
Maintain Position (hold final position) + +Output: Compressed NPZ file containing action, observation, and info sequences +""" +import os +import numpy as np +import gymnasium as gym +from gymnasium.wrappers import TimeLimit +from time import sleep +from panda_mujoco_gym.envs import FrankaPickAndPlaceEnv + +# Global variables to store episode data across all iterations +observations = [] # List storing observation sequences for each episode +actions = [] # List storing action sequences for each episode +rewards = [] # List storing reward sequences for each episode +infos = [] # List storing info dictionaries for each episode +terminateds = [] # List storing terminated flags for each episode +truncateds = [] # List storing truncated flags for each episode +dones = [] # List storing done flags (terminated or truncated) for each episode + +# Robot configuration +robot = 'franka' # Robot type used in the environment, can be 'franka' or 'fetch' + +# Weld constraint flag +weld_flag = True # Flag to activate weld constraint during pick-and-place + + +def activate_weld(env, constraint_name="grasp_weld"): + """ + Activate a weld constraint during pick portion of a demo + :param env: The environment containing the model + :param constraint_name: The name of the weld constraint to activate + :return: True if the weld was successfully activated, False if the constraint was not found + """ + + try: + # Activate the weld constraint + env.unwrapped.model.eq(constraint_name).active = 1 + print("Activated weld") + return True + + except KeyError: + print(f"Warning: Constraint '{constraint_name}' not found") + return False + +def deactivate_weld(env, constraint_name="grasp_weld"): + """ + Deactivate a weld constraint during place portion of a demo + :param env: The environment containing the model + :param constraint_name: The name of the weld constraint to deactivate + :return: True if the weld was successfully deactivated, False if the constraint was not + found + """ + + try: + # 
Deactivate the weld constraint + env.unwrapped.model.eq(constraint_name).active = 0 + print("Deactivated weld") + return True + + except KeyError: + print(f"Warning: Constraint '{constraint_name}' not found") + return False + +def main(): + """ + Orchestrates the data generation process by running multiple episodes + of the pick-and-place task. + + Creates environment, runs scripted episodes, and saves demonstration data + to compressed NPZ file for use with stable-baselines3. + """ + # Initialize Fetch pick-and-place environment + env = FrankaPickAndPlaceEnv(reward_type="sparse", render_mode="rgb_array") + env = TimeLimit(env, max_episode_steps=50) + + # Adjust physical settings + # env.model.opt.timestep = 0.001 # Smaller timestep for more accurate physics. Default is 0.002. + # env.model.opt.iterations = 100 # More solver iterations for better contact resolution. Default is 50. + + # Configuration parameters + initStateSpace = "random" # Initial state space configuration + + # Demos configs + attempted_demos = 1 # Number of demonstration episodes to generate--ADJUST THIS VALUE FOR MORE OR LESS DEMOS** + + num_demos = 0 # Counter for successful demonstration episodes + + # Reset environment to initial state - render for the first time. + obs, _ = env.reset() + print("Reset!") + + # Generate demonstration episodes + while len(actions) < attempted_demos: + obs,_ = env.reset() # Reset environment for new episode + print(f"We will run a total of: {attempted_demos} demos!!") + print("Demo: #", len(actions)+1) + + # Execute pick-and-place task + res = pick_and_place_demo(env, obs) + + # Print success message + if res: + num_demos += 1 + print("Episode completed successfully!") + print(f"Total successful demos: {num_demos}/{attempted_demos}") + + ## Write data to demos folder + # 1. Get the absolute path of this script + #script_path = os.path.abspath(__file__) + + # 2. 
Extract its directory + #script_dir = os.path.dirname(script_path) + script_dir = '/home/student/data/franka_baselines/demos/pick_n_place' # Assumes data folder in user directory. + + # 3. Create output filename with configuration details + fileName = "data_" + robot + fileName += "_" + initStateSpace + fileName += "_" + str(attempted_demos) + fileName += ".npz" + + # 3. Build a filename in that same directory + out_path = os.path.join(script_dir, fileName) + + # Save collected data to compressed numpy NPZ file + # Set acs,obs,info as keys in dict + np.savez_compressed(out_path, + acs = actions, + obs = observations, + rewards = rewards, + info = infos, + terminateds = terminateds, + truncateds = truncateds, + dones = dones) + + print(f"Data saved to {fileName}.") + +def pick_and_place_demo(env, lastObs): + """ + Executes a scripted pick-and-place sequence using a hierarchical approach. + + Implements 4-phase control strategy: + 1. Approach: Move gripper above object (3cm offset) + 2. Grasp: Move to object and close gripper + 3. Transport: Move grasped object to goal position + 4. Maintain: Hold position until episode ends + + Store observations, actions, and info in global lists for later replay buffer inclusion. 
+ + Args: + env: Gymnasium environment instance + lastObs: Last observation containing goal and object state information + - desired_goal: Target position for object placement + - observations: + ee_position[0:3], + ee_velocity[3:6], + fingers_width[6], + object_position[7:10], + object_rotation[10:13], + object_velp[13:16], + object_velr[16:19], + """ + + ## Init goal, current_pos, and object position from last observation + goal = np.zeros(3, dtype=np.float32) + current_pos = np.zeros(3, dtype=np.float32) + object_pos = np.zeros(3, dtype=np.float32) + object_rel_pos = np.zeros(3, dtype=np.float32) + fgr_pos = np.zeros(1, dtype=np.float32) + + # Initialize episode data collection + episodeObs = [] # Observations for this episode + episodeAcs = [] # Actions for this episode + episodeRews = [] # Rewards for this episode + episodeInfo = [] # Info for this episode + episodeTerminated = [] # Terminated flags for this episode + episodeTruncated = [] # Truncated flags for this episode + episodeDones = [] # Done flags (terminated or truncated) for this episode + + # Proportional control gain for action scaling -- empirically tuned + Kp = 8.0 + + # pre_pick_offset + pre_pick_offset = np.array([0,0,0.03], dtype=float) # Offset to approach object safely (3cm) + + # Error thresholds + error_threshold = 0.011 # Threshold for stopping condition (Xmm) + + finger_delta_fast = 0.05 # Action delta for fingers 5cm per step (will get clipped by controller)... more of a scalar. 
+ finger_delta_slow = 0.005 # Franka has a range from 0 to 4cm per finger + + ## Extract data + # Extract desired position from desired_goal dict + goal = lastObs["desired_goal"][0:3] + + # Current robot end-effector position from observation dict + current_pos = lastObs["observation"][0:3] + + # Current object position from observation dict: + object_pos = lastObs["observation"][7:10] + + # Relative position between end-effector and object + object_rel_pos = object_pos - current_pos + + ## Phase 1: Approach Object (Above) + # Create target position 3cm above the object. Use copy() method. + error = object_rel_pos.copy() + error+=pre_pick_offset # Move 3cm above object for safe approach. Fingers should still end up surrounding object. + + timeStep = 0 # Track total timesteps in episode + episodeObs.append(lastObs) + + # Phase 1: Move gripper to position above object + # Terminate when distance to above-object position < 5mm + print(f"----------------------------------------------- Phase 1: Approach Object -----------------------------------------------") + while np.linalg.norm(error) >= error_threshold and timeStep <= env._max_episode_steps: + env.render() # Visual feedback + + # Initialize action vector [x, y, z, gripper] + action = np.array([0., 0., 0., 0.]) + + # Proportional control with gain of 6 + # action = Kp * error + action[:3] = error * Kp + + # Open gripper for approach + action[ len(action)-1 ] = 0.05 + + # Unpack new Gymnasium step API + new_obs, reward, terminated, truncated, info = env.step(action) + done = terminated or truncated + timeStep += 1 + + # Store episode data + episodeAcs.append(action) + episodeInfo.append(info) + episodeRews.append(reward) + episodeObs.append(new_obs) + episodeTerminated.append(terminated) + episodeTruncated.append(truncated) + episodeDones.append(done) + + # Update state information + fgr_pos = new_obs["observation"][6] + current_pos = new_obs["observation"][0:3] + object_pos = new_obs['observation'][7:10] + error = 
(object_pos+pre_pick_offset) - current_pos # Error with regard to offset position + + # Print debug information + print( + f"Time Step: {timeStep}, Error: {np.linalg.norm(error):.4f}, " + f"Eff_pos: {np.array2string(current_pos, precision=3)}, " + f"obj_pos: {np.array2string(object_pos, precision=3)}, " + f"fgr_pos: {np.array2string(fgr_pos, precision=2)}, " + f"Error: {np.array2string(error, precision=3)}, " + f"Action: {np.array2string(action, precision=3)}" + ) + + # Phase 2: Descend Grasp Object + # Move gripper directly to object and close gripper + # Terminate when relative distance to object < 5mm + print(f"----------------------------------------------- Phase 2: Grip -----------------------------------------------") + error = object_pos - current_pos # remove offset + while (np.linalg.norm(error) >= error_threshold or fgr_pos>=0.39) and timeStep <= env._max_episode_steps: # Cube of width 4cm, each finger open to 2cm + env.render() + + # Initialize action vector [x, y, z, gripper] + action = np.array([0., 0., 0., 0.]) + + # Direct proportional control to object position + action[:3] = error * Kp + + # Close gripper to grasp object + action[len(action)-1] = -finger_delta_fast * 2 + + # Execute action and collect data + new_obs, reward, terminated, truncated, info = env.step(action) + done = terminated or truncated + timeStep += 1 + + # Store episode data + episodeObs.append(new_obs) + episodeRews.append(reward) + episodeAcs.append(action) + episodeInfo.append(info) + episodeTerminated.append(terminated) + episodeTruncated.append(truncated) + episodeDones.append(done) + + # Update state information + fgr_pos = new_obs["observation"][6] + current_pos = new_obs["observation"][0:3] + object_pos = new_obs['observation'][7:10] + error = object_pos - current_pos #- np.array([0.,0.,0.01]) # Grab lower + + # Print debug information + print( + f"Time Step: {timeStep}, Error: {np.linalg.norm(error):.4f}, " + f"Eff_pos: {np.array2string(current_pos, precision=3)}, " + 
f"obj_pos: {np.array2string(object_pos, precision=3)}, " + f"fgr_pos: {np.array2string(fgr_pos, precision=3)}, " + f"Error: {np.array2string(error, precision=3)}, " + f"Action: {np.array2string(action, precision=3)}" + ) + #sleep(0.5) # Optional: Slow down for better visualization + + # Phase 3: Transport to Goal + # Move grasped object to desired goal position + # Terminate when distance between object and goal < 1cm + print(f"----------------------------------------------- Phase 3: Transport to Goal -----------------------------------------------") + + # Weld activation + if weld_flag: + activate_weld(env, constraint_name="grasp_weld") + + # Set error between goal and hand assuming the object is grasped + gh_error = goal - current_pos # Error between goal and hand position + ho_error = object_pos - current_pos # Error between object and hand position + while np.linalg.norm(gh_error) >= 0.01 and timeStep <= env._max_episode_steps: + env.render() + + action = np.array([0., 0., 0., 0.]) + + # Proportional control toward goal position + action[:3] = gh_error[:3] * Kp + + # Maintain grip on object + #action[len(action)-1] = 0 + + # Execute action and collect data + new_obs, reward, terminated, truncated, info = env.step(action) + done = terminated or truncated + timeStep += 1 + + # Store episode data + episodeObs.append(new_obs) + episodeRews.append(reward) + episodeAcs.append(action) + episodeInfo.append(info) + episodeTerminated.append(terminated) + episodeTruncated.append(truncated) + episodeDones.append(done) + + # Update state information + fgr_pos = new_obs["observation"][6] + current_pos = new_obs["observation"][0:3] + object_pos = new_obs['observation'][7:10] + gh_error = goal - current_pos # Error between goal and hand position + ho_error = object_pos - current_pos # Error between object and hand position + + # Print debug information + print( + f"Time Step: {timeStep}, Error Norm: {np.linalg.norm(gh_error):.4f}, " + f"Eff_pos: {np.array2string(current_pos, 
precision=3)}, " + f"goal_pos: {np.array2string(goal, precision=3)}, " + f"fgr_pos: {np.array2string(fgr_pos, precision=2)}, " + f"Error: {np.array2string(gh_error, precision=3)}, " + f"Action: {np.array2string(action, precision=3)}" + ) + + sleep(0.5) # Optional: Slow down for better visualization + + ## Check for success and store episode data + gh_norm = np.linalg.norm(gh_error) + ho_nomr = np.linalg.norm(ho_error) + if gh_norm < error_threshold and ho_nomr < error_threshold: + + # Store complete episode data in global lists only if we succeeded (avoid bad demos) + actions.append(episodeAcs) + observations.append(episodeObs) + infos.append(episodeInfo) + rewards.append(episodeRews) + + # Optionally, also store the done/terminated/truncated flags globally if needed: + terminateds.append(episodeTerminated) + truncateds.append(episodeTruncated) + dones.append(episodeDones) + + # Deactivate weld constraint after successful pick + if weld_flag: + deactivate_weld(env, constraint_name="grasp_weld") + + # Close mujoco viewer + env.close() + + # Break out of the loop to start a new episode + return True + + # If we reach here, the episode was not successful + if weld_flag: + print("Failed to transport object to goal position. Deactivating weld.") + deactivate_weld(env, constraint_name="grasp_weld") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/demos/demos/franka_reach_demo_script.py b/demos/demos/franka_reach_demo_script.py new file mode 100755 index 00000000..f97848a4 --- /dev/null +++ b/demos/demos/franka_reach_demo_script.py @@ -0,0 +1,453 @@ +""" +Scripted Controller for Franka FR3 Robot - Demonstration Data Generation + +This script implements a scripted controller for generating expert demonstration data +for Deep Reinforcement Learning (DRL) algorithms using the Franka FR3 robot in a +pick-and-place task. The generated data serves as bootstrapping demonstrations for +training RL agents with your desired algorithm. 
+
+Note that different error thresholds lead to different rewards and in turn done values.
+Adjust carefully. Depending on the controller and clipping scaling (ACTION_MAX) settings, you may get very different behaviors.
+The current program clips actions and leads to small increments that allow more precise movements and close the error.
+
+The controller uses a 4-phase hierarchical approach:
+1. Approach Object (move gripper above object)
+2. Grasp Object (move to object and close gripper)
+3. Transport to Goal (move grasped object to target)
+4. Maintain Position (hold final position)
+
+Output: Compressed NPZ file containing action, observation, and info sequences
+
+TODO: Convert to a class-based structure for better modularity and reusability.
+"""
+import os
+import numpy as np
+import gym
+from time import sleep, perf_counter
+from datetime import datetime
+
+import franka_sim
+import franka_sim.envs.panda_reach_gym_env as panda_reach_env
+
+# Global variables to store episode data across all iterations
+observations = [] # List storing observation sequences for each episode
+actions = [] # List storing action sequences for each episode
+rewards = [] # List storing reward sequences for each episode
+infos = [] # List storing info dictionaries for each episode
+terminateds = [] # List storing terminated flags for each episode
+truncateds = [] # List storing truncated flags for each episode
+dones = [] # List storing done flags (terminated or truncated) for each episode
+transition_ctr = 0 # Global counter for transitions across all episodes
+
+#-------------------------------------------------------------------------------------------
+## Key Config Variables
+#-------------------------------------------------------------------------------------------
+# Proportional and derivative control gain for action scaling -- empirically tuned
+Kp = 10.0 # Values between 20 and 24 seem to be somewhat stable for Kv = 24
+Kv = 10.0
+
+ACTION_MAX = 10 # Maximum
action value for clipping actions
+ERROR_THRESHOLD = 0.008 # Note!! When this number is changed, the way rewards are computed in the PandaReachCubeEnv.step() L220 must also be changed such that done=True only at the end of a successful run.
+
+# Number of demonstration episodes to generate
+NUM_DEMOS = 20
+
+# Robot configuration
+robot = 'franka' # Robot type used in the environment, can be 'franka' or 'fetch'
+task = 'reach' # Task type used in the environment, can be 'reach' or 'pick-and-place'
+
+# Debug mode for rendering and visualization
+DEBUG = False
+
+if DEBUG:
+    _render_mode = 'human' # Render mode for the environment, can be 'human' or 'rgb_array'
+else:
+    _render_mode = 'rgb_array' # Use 'rgb_array' for automated testing without GUI
+
+# Indices for franka_sim reach environment observations
+if robot == 'franka' and task == 'reach':
+
+    opi = np.array([0, 3]) # Indices for object position in observation
+    gpi = np.array([3]) # Indices for gripper position in observation
+    rpi = np.array([4, 7]) # Indices for robot position in observation
+    rvi = np.array([7, 10]) # Indices for robot velocity in observation
+
+# Weld constraint flag
+weld_flag = True # Flag to activate weld constraint during pick-and-place
+
+#-------------------------------------------------------------------------------------------
+# Franka sim environments do not have weld constraints like the franka_mujoco environments.
+# def activate_weld(env, constraint_name="grasp_weld"): +# """ +# Activate a weld constraint during pick portion of a demo +# :param env: The environment containing the model +# :param constraint_name: The name of the weld constraint to activate +# :return: True if the weld was successfully activated, False if the constraint was not found +# """ + +# try: +# # Activate the weld constraint +# env.unwrapped.model.eq(constraint_name).active = 1 +# print("Activated weld") +# return True + +# except KeyError: +# print(f"Warning: Constraint '{constraint_name}' not found") +# return False + +# def deactivate_weld(env, constraint_name="grasp_weld"): +# """ +# Deactivate a weld constraint during place portion of a demo +# :param env: The environment containing the model +# :param constraint_name: The name of the weld constraint to deactivate +# :return: True if the weld was successfully deactivated, False if the constraint was not +# found +# """ + +# try: +# # Deactivate the weld constraint +# env.unwrapped.model.eq(constraint_name).active = 0 +# print("Deactivated weld") +# return True + +# except KeyError: +# print(f"Warning: Constraint '{constraint_name}' not found") +# return False + +def set_front_cam_view(env): + """ + Set the camera view to a front-facing perspective for better visualization. + + Args: + env: The environment instance containing the viewer. + + Returns: + viewer: The viewer with updated camera settings. 
+ """ + viewer = env.unwrapped._viewer.viewer # Access the viewer from the environment + + if hasattr(viewer, 'cam'): + viewer.cam.lookat[:] = [0, 0, 0.1] # Center of robot (adjust as needed) + viewer.cam.distance = 3.0 # Camera distance + viewer.cam.azimuth = 135 # 0 = right, 90 = front, 180 = left + viewer.cam.elevation = -30 # Negative = above, positive = below + + # Hide menu + viewer._hide_overlay = True + + return viewer + +def store_transition_data(episode_dict, new_obs, rewards, action, info, terminated, truncated, done): + """ + Store transition data in the episode dictionary and update global counter. + """ + global transition_ctr + transition_ctr += 1 + + episode_dict["observations"].append(new_obs) + episode_dict["rewards"].append(rewards) + episode_dict["actions"].append(action) + episode_dict["infos"].append(info) + episode_dict["terminateds"].append(terminated) + episode_dict["truncateds"].append(truncated) + episode_dict["dones"].append(done) + +def store_episode_data(episode_data): + """ + Store complete episode data in global lists only if we succeeded (avoid bad demos). + """ + actions.append(episode_data["actions"]) + observations.append(episode_data["observations"]) + infos.append(episode_data["infos"]) + rewards.append(episode_data["rewards"]) + + # Optionally, also store the done/terminated/truncated flags globally if needed: + terminateds.append(episode_data["terminateds"]) + truncateds.append(episode_data["truncateds"]) + dones.append(episode_data["dones"]) + +def update_state_info(episode_data, time_step, dt, error, reward): + """ + Update and return the current state information. Always get the latest entry with [-1] + + Args: + new_obs (dict): New observation dictionary containing the current state. + time_step (int): Current time step in the episode. + dt (float): Current time step in the episode. + error (np.ndarray): Current error vector between object and end-effector positions. 
+
+        reward (float): Current reward value for the action taken.
+
+    Returns:
+        object_pos (np.ndarray): Current position of the object in the environment.
+        gripper_pos (float): Current position of the gripper.
+        current_pos (np.ndarray): Current position of the end-effector.
+        current_vel (np.ndarray): Current velocity of the end-effector.
+    """
+    object_pos = episode_data["observations"][-1][ opi[0]:opi[1] ] # Block position
+    gripper_pos = episode_data["observations"][-1][ gpi[0] ] # Gripper position
+    current_pos = episode_data["observations"][-1][ rpi[0]:rpi[1] ] # Panda/tcp position
+    current_vel = episode_data["observations"][-1][ rvi[0]:rvi[1] ] # Panda/tcp velocity
+
+
+    # Print debug information
+    print(
+        f"Step: {time_step}, ErrNorm: {np.linalg.norm(error):.4f}, "
+        f"bot_pos: {np.array2string(current_pos, precision=3)}, "
+        f"obj_pos: {np.array2string(object_pos, precision=3)}, "
+        f"fgr_pos: {np.array2string(gripper_pos, precision=2)}, "
+        f"err: {np.array2string(error, precision=3)}, "
+        f"Action: {np.array2string(episode_data['actions'][-1], precision=3)}, "
+        f"dt: {dt:.4f}, reward: {reward:.3f}"
+    )
+
+
+    return object_pos, gripper_pos, current_pos, current_vel
+
+def compute_error(object_pos, current_pos, prev_error, dt):
+    """Compute the error and its derivative between the object position and the current end-effector position.
+    Args:
+        object_pos (np.ndarray): The position of the object in the environment.
+        current_pos (np.ndarray): The current position of the end-effector.
+        prev_error (np.ndarray): The previous error value for derivative calculation.
+        dt (float): Time step for derivative calculation.
+
+    Returns:
+        error (np.ndarray): The current error vector between the object and end-effector positions.
+        derror (np.ndarray): The derivative of the error vector.
+ """ + error = object_pos - current_pos # Calculate the error vector + derror = (error - prev_error) / dt # Calculate the derivative of the error vector + + prev_error = error.copy() # Update previous error for next iteration + return error, derror + +def demo(env, lastObs): + """ + Executes a scripted reach sequence using a hierarchical approach. + + Implements 1-phase control strategy: + 1. Approach: Move gripper above object (3cm offset) + + Store observations, actions, and info in global lists for later replay buffer inclusion. + + Gripper: + - The gripper in Mujoco ranges from a value of 0 to 0.4, where 0 is fully open and 0.4 is fully closed. + + Args: + env: Gymnasium environment instance + lastObs: Flattened observations set as object_pos, gripper_pos, panda/tcp_pos, panda/tcp_vel + - observations: + object_pos[0:3], + gripper_pos[3] + panda/tcp_pos[4:7], + panda/tcp_vel[7:10] + + Returns: + """ + + ## Init goal, current_pos, and object position from last observation + object_pos = np.zeros(3, dtype=np.float32) + current_pos = np.zeros(3, dtype=np.float32) + gripper_pos = np.zeros(1, dtype=np.float32) + object_rel_pos = np.zeros(3, dtype=np.float32) + + + # Initialize (single) episode data collection + episodeObs = [] # Observations for this episode + episodeAcs = [] # Actions for this episode + episodeRews = [] # Rewards for this episode + episodeInfo = [] # Info for this episode + episodeTerminated = [] # Terminated flags for this episode + episodeTruncated = [] # Truncated flags for this episode + episodeDones = [] # Done flags (terminated or truncated) for this episode + + # Dictionary to store episode data + episode_data = { + "observations": episodeObs, + "actions": episodeAcs, + "rewards": episodeRews, + "infos": episodeInfo, + "terminateds": episodeTerminated, + "truncateds": episodeTruncated, + "dones": episodeDones + } + + # close gripper + fgr_pos = 0 + + # Error thresholds + error_threshold = ERROR_THRESHOLD # Threshold for stopping condition 
(Xmm) + + finger_delta_fast = 0.05 # Action delta for fingers 5cm per step (will get clipped by controller)... more of a scalar. + finger_delta_slow = 0.005 # Franka has a range from 0 to 4cm per finger + + ## Extract data + object_pos = lastObs[opi[0]:opi[1]] # block pos + current_pos = lastObs[rpi[0]:rpi[1]] # panda/tcp_pos + + # Relative position between end-effector and object + dt = env.unwrapped.model.opt.timestep # Mujoco time step + prev_error = np.zeros_like(object_pos) + error, derror = compute_error(object_pos, current_pos, prev_error, dt) + + time_step = 0 # Track total time_steps in episode + episodeObs.append(lastObs) # Store initial observation + + # Initialize previous time for dt calculation + prev_time = perf_counter() # Start time for dt calculation + + # Phase 1: Reach + # Terminate when distance to above-object position < error_threshold + print(f"----------------------------------------------- Phase 1: Reach -----------------------------------------------") + while np.linalg.norm(error) >= error_threshold and time_step <= env.spec.max_episode_steps: + env.render() # Visual feedback + + # Record current time and compute dt + curr_time = perf_counter() + dt = curr_time - prev_time + prev_time = curr_time + + # Initialize action vector [x, y, z] + action = np.array([0., 0., 0.]) + + # Proportional control with gain of 6 + action[:3] = error * Kp + derror * Kv + prev_error = error.copy() # Update previous error for next iteration + + # Clip action to prevent excessive movements + action = np.clip(action/ACTION_MAX, -0.1, 0.1) # + + # Keep gripper closed -- no need. only 3 dimensions of control + #action[ len(action)-1 ] = -finger_delta_fast # Maintain gripper closed. 
+ + # Unpack new Gymnasium step API + new_obs, reward, terminated, truncated, info = env.step(action) + done = terminated or truncated + + # Store episode data + store_transition_data(episode_data, new_obs, reward, action, info, terminated, truncated, done) + + # Update and print state information + object_pos,gripper_pos,cur_pos,cur_vel = update_state_info(episode_data, time_step, dt, error,reward) + + # Update error for next iteration + error, derror = compute_error(object_pos, cur_pos, prev_error, dt) + + # Update time step + time_step += 1 + + # Sleep + #if DEBUG: + sleep(0.25) # Activated when DEBUG is True for better visualization. + + # Store complete episode data in global lists only if we succeeded (avoid bad demos) + store_episode_data(episode_data) + + # Deactivate weld constraint after successful pick -- franka_sim env does not have weld like franka_mujoco env. + # if weld_flag: + # deactivate_weld(env, constraint_name="grasp_weld") + + # Break out of the loop to start a new episode + return True + + # # If we reach here, the episode was not successful + # if weld_flag: + # print("Failed to transport object to goal position. Deactivating weld.") + # deactivate_weld(env, constraint_name="grasp_weld") + +def main(): + """ + Orchestrates the data generation process by running multiple episodes + of the task. + + Creates environment, runs scripted episodes, and saves demonstration data + to compressed NPZ. + + Arguments that can be configured with flags: + - env + - render + - demo_ctr + + """ + # Initialize the Panda environment. + env = gym.make("PandaReachCube-v0", render_mode=_render_mode) + env = gym.wrappers.FlattenObservation(env) + + # Adjust physical settings + # env.model.opt.time_step = 0.001 # Smaller time_step for more accurate physics. Default is 0.002. + # env.model.opt.iterations = 100 # More solver iterations for better contact resolution. Default is 50. 
+ + # Configuration parameters + initStateSpace = "random" # Initial state space configuration + + # Demos configs + num_demos = NUM_DEMOS # Number of demonstration episodes to generate--ADJUST THIS VALUE FOR MORE OR LESS DEMOS** + + demo_ctr = 0 # Counter for successful demonstration episodes + + # Reset environment to initial state - render for the first time. + obs, _ = env.reset() # For reach environment expect 10 observations: r_pos, r_vel, finger, object_pos. + + # Adjust camera view for better visualization + viewer = set_front_cam_view(env) + + print("Reset!") + + # Generate demonstration episodes + while len(actions) < num_demos: + obs,_ = env.reset() # Reset environment for new episode + + print(f"We will run a total of: {num_demos} demos!!") + print("Demo: #", len(actions)+1) + + # Execute pick-and-place task + res = demo(env, obs) + + # Print success message + if res: + demo_ctr += 1 + print("Episode completed successfully!") + print(f"Total successful demos: {demo_ctr}/{num_demos}") + + # Close the environment after all episodes are done + env.close() + + ## Write data to demos folder. Assumes mounted /data folder and internal data folder. + script_dir = '/data/data/serl/demos' + + # Create output filename with configuration details + fileName = "data_" + robot + "_" + task + fileName += "_" + initStateSpace + fileName += "_" + str(num_demos) + + # Add timestamp to filename for uniqueness + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + fileName += "_" + timestamp + fileName += ".npz" + + # Build a filename in that same directory + out_path = os.path.join(script_dir, fileName) + + # Ensure the directory exists + os.makedirs(script_dir, exist_ok=True) + + # Save collected data to compressed numpy NPZ file + # Set acs,obs,info as keys in dict and values as np.arrays of type objects. This allows you to handle different lengths. 
+ np.savez_compressed( + out_path, + acs=np.array(actions, dtype=object), + obs=np.array(observations, dtype=object), + rewards=np.array(rewards, dtype=object), + info=np.array(infos, dtype=object), + terminateds=np.array(terminateds, dtype=object), + truncateds=np.array(truncateds, dtype=object), + dones=np.array(dones, dtype=object), + transition_ctr=transition_ctr, + num_demos=num_demos + ) + + print(f"Data saved to {fileName}.") + print(f"Total successful demos: {demo_ctr}/{num_demos}") + +if __name__ == "__main__": + main() diff --git a/demos/demos/franka_reach_drq_demo_script.py b/demos/demos/franka_reach_drq_demo_script.py new file mode 100644 index 00000000..b793cda3 --- /dev/null +++ b/demos/demos/franka_reach_drq_demo_script.py @@ -0,0 +1,471 @@ +#!/usr/bin/env python3 +import os +import time +from datetime import datetime + +import numpy as np + +# Logging +from absl import app, flags, logging +from oxe_envlogger.envlogger import AutoOXEEnvLogger + +# DRL +import gym +import mujoco +from serl_launcher.wrappers.serl_obs_wrappers import SERLObsWrapper +from serl_launcher.wrappers.chunking import ChunkingWrapper + +# Needed to create the franka environment +import franka_sim + +# Teleoperation imports +import sys, select, termios, tty + +# RLDS/TFDS +import json, glob, inspect +import tensorflow as tf +import tensorflow_datasets as tfds +from tensorflow_datasets import folder_dataset +#------------------------------------------------------------------------------------------- +# Flags +#------------------------------------------------------------------------------------------- +flags.DEFINE_string("env", "PandaReachSparseCube-v0", "Name of environment.") +flags.DEFINE_string("exp_name", None, "Name of the experiment for wandb logging.") +flags.DEFINE_integer("max_traj_length", 200, "Maximum length of trajectory.") +flags.DEFINE_boolean("debug", True, "Debug mode.") # debug mode will disable wandb logging +#flags.DEFINE_string("preload_rlds_path", None, 
"Path to preload RLDS data.")
+flags.DEFINE_string("output_dir", "/data/data/serl/demos/franka_reach_drq_demo_script",
+                    "Directory to save the output data. This is where the RLDS logs will be saved.")
+flags.DEFINE_integer("num_demos", 2, "Number of episodes to log.")
+flags.DEFINE_boolean("enable_envlogger", True, "Enable envlogger.")
+flags.DEFINE_string("teleop_mode", "keyboard", "Teleoperation mode: 'keyboard' or 'spacemouse'.")
+
+FLAGS = flags.FLAGS
+
+#-------------------------------------------------------------------------------------------
+## Teleop Config Variables
+#-------------------------------------------------------------------------------------------
+ACTION_MAX = 10 # Maximum action value for clipping actions
+
+# Bind, xyz, gripper vals to keys
+moveBindings = {
+        'i':(1,0,0,0),
+        ',':(-1,0,0,0),
+        'j':(0,1,0,0),
+        'l':(0,-1,0,0),
+        'u':(0,0,0,1),
+        'o':(0,0,0,-1),
+        'm':(0,0,1,0),
+        '.':(0,0,-1,0),
+    }
+
+# Extend bindings to include camera controls
+camBindings = {
+    'a': ("azimuth", -5),   # rotate left
+    'd': ("azimuth", 5),    # rotate right
+    'w': ("elevation", 2),  # tilt up
+    's': ("elevation", -2), # tilt down
+    'q': ("distance", -0.1),# zoom in
+    'e': ("distance", 0.1), # zoom out
+}
+
+def update_camera(viewer,key):
+    """
+    Update the camera view based on keyboard input. Assumes higher level function has checked for the existence of key in camBindings.
+    Controls:
+    'a' : rotate left
+    'd' : rotate right
+    'w' : tilt up
+    's' : tilt down
+    'q' : zoom in
+    'e' : zoom out
+    """
+    if hasattr(viewer, 'cam'):
+        # Get current camera parameters
+        attr, delta = camBindings[key]
+        val = getattr(viewer.cam, attr)
+
+        setattr(viewer.cam, attr, val + delta)
+
+def close_logger_and_env(env):
+    """
+    Best-effort shutdown:
+      1) close embedded envloggers/writers if present
+      2) close dm_env (if you have a DeepMind-style env)
+      3) close Gym/Gymnasium env
+      4) close the Mujoco viewer
+    Also walks through common wrapper attributes (env, _env, unwrapped).
+ """ + import logging + + visited = set() + + def _safe_close(obj, label=""): + if obj is None or id(obj) in visited: + return + visited.add(id(obj)) + + # 1) Close common logger/writer attributes first (flush TFRecord) + for name in ("_envlogger", "envlogger", "logger", "_logger", "writer"): + try: + logger_obj = getattr(obj, name, None) + if logger_obj is not None and hasattr(logger_obj, "close"): + logger_obj.close() + except Exception as e: + logging.warning(f"close_logger_and_env: {label}.{name}.close() raised {e!r}") + + # 2) Close dm_env if present + try: + dm = getattr(obj, "dm_env", None) + if dm is not None and hasattr(dm, "close"): + dm.close() + except Exception as e: + logging.warning(f"close_logger_and_env: {label}.dm_env.close() raised {e!r}") + + # 3) Close Gym/Gymnasium env + try: + if hasattr(obj, "close"): + obj.close() + except Exception as e: + logging.warning(f"close_logger_and_env: {label}.close() raised {e!r}") + + # 4) Close Mujoco viewer if accessible + try: + # Some stacks keep viewer at env._viewer.viewer; others just env._viewer + viewer = None + vwrap = getattr(obj, "_viewer", None) + if vwrap is not None: + viewer = getattr(vwrap, "viewer", vwrap) + if viewer is not None and hasattr(viewer, "close"): + viewer.close() + except Exception as e: + logging.warning(f"close_logger_and_env: {label} viewer close raised {e!r}") + + # Recurse into common wrapper links + for child_name in ("env", "_env", "unwrapped", "environment", "base_env"): + child = getattr(obj, child_name, None) + if child is not None and child is not obj: + _safe_close(child, f"{label}.{child_name}" if label else child_name) + + _safe_close(env, "env") + +def finalize_tfds_metadata_beamless(builder_dir: str): + """ + Beam-free finalize: count TFRecord examples per shard and write + numShards/shardLengths into dataset_info.json so TFDS will load. 
+ """ + import os, json, glob + import tensorflow as tf + + info_path = os.path.join(builder_dir, "dataset_info.json") + if not os.path.exists(info_path): + raise FileNotFoundError(f"Missing dataset_info.json in {builder_dir}") + + with open(info_path) as f: + info = json.load(f) + + ds_name = info["name"] # e.g. "PandaReachSparseCube-v0" + file_fmt = info.get("fileFormat", "tfrecord") + tmpl_str = info["splits"][0]["filepathTemplate"] + + # Prefer strict pattern "-.-" + shard_paths = sorted(glob.glob(os.path.join(builder_dir, f"{ds_name}-*.{file_fmt}-*"))) + if not shard_paths: + # Fallback to any tfrecord-like file + shard_paths = sorted(glob.glob(os.path.join(builder_dir, f"*.{file_fmt}*"))) + if not shard_paths: + raise FileNotFoundError( + f"No {file_fmt} shards found in {builder_dir}. " + f"Expected like '{ds_name}-train.{file_fmt}-00000' per template '{tmpl_str}'." + ) + + # Count episodes (1 Example = 1 episode with envlogger/RLDS) + shard_lengths = [sum(1 for _ in tf.data.TFRecordDataset(p)) for p in shard_paths] + + # Write lengths for each split using this template + for s in info["splits"]: + if s.get("filepathTemplate") == tmpl_str: + s["numShards"] = len(shard_paths) + s["shardLengths"] = shard_lengths + # Re-write dataset_info.json with updated shard info + with open(info_path, "w") as f: + json.dump(info, f, indent=2) + + # Sanity log + import tensorflow_datasets as tfds + b = tfds.builder_from_directory(builder_dir) + print("[finalize] splits:", {k: v.num_examples for k, v in b.info.splits.items()}) + +def ensure_dir_exists(): + """ + For oxe_envlogger + RLDS compatibility, data must be written in the following format: + /data/data/serl/demos/franka_reach_drq_demo_script/ + └── session_20250821_222412/ + └── PandaReachSparseCube-v0/ + └── 0.1.0/ + dataset_info.json + features.json + PandaReachSparseCube-v0-train.tfrecord-00000-of-00001 + + We can have a base path with customized sessions inside. 
+ Inside each session we have: env-version-files + + + Returns + ------- + out_path : str + The path to the output directory. + """ + # Customize the path + root = FLAGS.output_dir + session = datetime.now().strftime("session_%Y%m%d_%H%M%S") + session_root = os.path.join(root, f"{FLAGS.num_demos}_demos_{session}") + #session_root = os.path.join(root, f"{session}_num_demos_{FLAGS.num_demos}") + + # Dataset details + dataset_name = FLAGS.env + version = "0.1.0" # PArt of RLDS format. needed. + + # Create output filename with configuration details + dataset_dir = os.path.join(session_root, dataset_name, version) + os.makedirs(dataset_dir, exist_ok=True) + logging.info(f"TFDS builder dir: {dataset_dir}") + + return dataset_dir + +def getKey(settings): + """ + Waits briefly for a keypress and returns the pressed key. + + Parameters + ---------- + settings : list + Original terminal settings so they can be restored after reading. + + Returns + ------- + key : str + The key pressed by the user, or '' (empty string) if no key was pressed. + """ + # Put terminal into raw mode so keypress is captured instantly + tty.setraw(sys.stdin.fileno()) + + # Wait for human to input action, we will not provide a timeout so it is blocking. + rlist, _, _ = select.select([sys.stdin], [], []) # select.select(rlist, wlist, xlist[, timeout]) + + if rlist: + # Read exactly one character if a key is pressed + key = sys.stdin.read(1) + else: + key = '' + + # Restore terminal to original settings + termios.tcsetattr(sys.stdin, termios.TCSADRAIN, settings) + return key + +def get_kb_demo_action(env,speed=0.075): + """ + Reads keyboard input and maps it to a 3D action vector for robot control or camera action. + TODO: currently can only read one key at a time. Needs to be extended to read multiple keys to handle both. + Otherwise, none actions are still considered steps in the loop. + + The function uses non-blocking keyboard input to allow interactive + teleoperation. 
Keys are mapped to directions in Cartesian space: + - 'i' : +x (forward) + - ',' : -x (backward) + - 'j' : +y (left) + - 'l' : -y (right) + - 'm' : +z (up) + - '.' : -z (down) + - 'k' : stop (zero vector) + + Parameters + ---------- + speed : float, optional + Step size for each key press (default 0.2). + + Returns + ------- + np.ndarray + Action vector of shape (3,), where each entry corresponds to + [x, y, z] translation command. Example: [0.2, 0.0, 0.0]. + """ + # Save current terminal settings so we can restore later + settings = termios.tcgetattr(sys.stdin) + + # Initialize action as a zero vector (no movement) + action = np.zeros(4, dtype=float) + + try: + # Capture the pressed key + key = getKey(settings) + + # Check keys for camera first, if so, update camera and get another key for action + if key in camBindings: + if hasattr(env.unwrapped, "_viewer"): + update_camera(env.unwrapped._viewer.viewer,key) + key = getKey(settings) # get another key for action + + elif key in moveBindings: + # Lookup (x, y, z) direction and scale by speed + dx, dy, dz, g= moveBindings[key] + action = np.array([dx, dy, dz, g], dtype=float) * speed + + elif key == 'k': + # 'k' means stop → zero vector + action = np.zeros(4, dtype=float) + + elif key == '\x03': # CTRL-C + raise KeyboardInterrupt + + finally: + # Restore terminal even if something goes wrong + termios.tcsetattr(sys.stdin, termios.TCSADRAIN, settings) + + # Clip action values to prevent excessive commands + action = np.clip(action, -ACTION_MAX, ACTION_MAX) + + return action + +def set_front_cam_view(env): + """ + Set the camera view to a front-facing perspective for better visualization. + + Args: + env: The environment instance containing the viewer. + + Returns: + viewer: The viewer with updated camera settings. 
+ """ + viewer = env.unwrapped._viewer.viewer # Access the viewer from the environment + + if hasattr(viewer, 'cam'): + viewer.cam.lookat[:] = [0, 0, 0.1] # Center of robot (adjust as needed) + viewer.cam.distance = 2.0 # Camera distance + viewer.cam.azimuth = 155 # 0 = right, 90 = front, 180 = left + viewer.cam.elevation = -30 # Negative = above, positive = below + + # Hide menu + viewer._hide_overlay = True + + return viewer + +############################################################################## +def main(unused_argv): + logging.info(f'Creating gym environment...') + + # Render mode configuration based on debug flag + if FLAGS.debug: + _render_mode = 'human' # Render mode for the environment, can be 'human' or 'rgb_array' + else: + _render_mode = 'rgb_array' # Use 'rgb_array' for automated testing without GUI + + # Create the environment with the specified render mode and wrappers + env = gym.make(FLAGS.env, render_mode=_render_mode) + + if FLAGS.env == "PandaPickCube-v0": + env = gym.wrappers.FlattenObservation(env) + + if FLAGS.env == "PandaReachSparseCube-v0": + env = SERLObsWrapper( + env, + target_hw=(128, 128), + img_dtype=np.uint8, # or np.float32 + normalize=False, # True if using float32 in [0,1] + ) + env = ChunkingWrapper(env, obs_horizon=1, act_exec_horizon=None) + + logging.info(f'Done creating {FLAGS.env} environment.') + + # Set camera to front view if viewer is available + if hasattr(env.unwrapped, '_viewer'): + viewer = set_front_cam_view(env) + if viewer: + logging.info('Camera view set to front-facing perspective.') + else: + logging.warning('Failed to set camera view. Viewer not available.') + + # Wrap with oxe_envlogger to record demos + dataset_dir = None + session_root = None + if FLAGS.enable_envlogger: + + dataset_dir = ensure_dir_exists() + + # Will save as many episodes as possible into files of 200MB each by default. 
+ env = AutoOXEEnvLogger( + env=env, + dataset_name=FLAGS.env, + directory=dataset_dir, + #split_name="train", # "train", "test", or "validation" + ) + logging.info('Recording %r demos...', FLAGS.num_demos) + + #--- LOOP DEMOS/EPISODES --- + # Loop through the number of demos specified by the user to record demonstrations + try: + for i in range(FLAGS.num_demos): + + # Log custom metadata during new episode: language embeddings randomly. + if FLAGS.enable_envlogger: + # The "language_embedding" is a standard field used in robotics datasets (like the OXE format that envlogger creates) to store a numerical representation of a natural language instruction for an episode. + # How to Reconcile it with 5 Random Numbers: The five random numbers are just placeholder data. This script is a demonstration and doesn't involve a real language model. + env.set_episode_metadata({ + "language_embedding": np.random.random((5,)).astype(np.float32) + }) + env.set_step_metadata({"timestamp": time.time()}) + + logging.info('episode %r', i) + + # Start a new episode + env.reset() + terminated = False + truncated = False + + step = 0 + + # Termination occurs when the hand reaches the target or the maximum trajectory length is reached + while not (terminated or truncated): + + # Get action from the demo function + action = get_kb_demo_action(env) + + # example to log custom step metadata + if FLAGS.enable_envlogger: + env.set_step_metadata({"timestamp": np.float32(time.time())}) + + return_step = env.step(action) + + # NOTE: to handle gym.Env.step() return value change in gym 0.26 + if len(return_step) == 5: + obs, reward, terminated, truncated, info = return_step + else: + obs, reward, terminated, info = return_step + truncated = False + + print(f" step: {step}", f"reward: {reward:.3f}\n") + step += 1 + + logging.info('Done recording %r demos.', FLAGS.num_demos) + + finally: + # Finalize TFDS metadata so SERL/TFDS can load the split + if FLAGS.enable_envlogger and dataset_dir is not 
None: + + # Close the environment to flush/write data. AutoOXEEnvLogger implements dm_env.close() + # which flushes data to disk. This is important to ensure all data is written before finalizing metadata. + # If you skip this step, some data may not be written and the dataset may be incomplete. + logging.info("Closing environment to flush data to disk...") + + # closes logger(s) + dm_env + gym + viewer + # env.unwrapped.unwrapped.unwrapped.close() # close mujoco viewer + # env.env.env.env.close() + close_logger_and_env(env) + + # Scan files and write metadata + finalize_tfds_metadata_beamless(dataset_dir) + + # Note: async_drq_sim will read from the dataset_dir you printed above. + +if __name__ == '__main__': + app.run(main) \ No newline at end of file diff --git a/demos/demos/load_demo_test.py b/demos/demos/load_demo_test.py new file mode 100644 index 00000000..8fd33aad --- /dev/null +++ b/demos/demos/load_demo_test.py @@ -0,0 +1,127 @@ +# Updated and advanced train.py that includes logging, vectorized environments, and periodic recorded evaluations +import os +from pathlib import Path +import numpy as np + +def load_demos_to_her_buffer_gymnasium(data_store, demo_npz_path: str, combine_done: bool = True): + """ + Load a raw Gymnasium-style .npz of expert episodes into model.replay_buffer. + + demo_npz_path must contain at least these arrays: + - 'episodeObs' : shape (T+1, *obs_shape*), list of observations + - 'episodeAcs' : shape (T, *act_shape*), list of actions + - 'episodeRews' : shape (T,), list of rewards + - 'episodeTerminated' : shape (T,), list of terminated flags + - 'episodeTruncated' : shape (T,), list of truncated flags + - 'episodeInfo' : shape (T,), list of info dicts + + Parameters + ---------- + data_store : DataStore + The data store to which the demo transitions will be added. + demo_npz_path : str + Path to the .npz file you saved from your demo collector. + combine_done : bool, default=True + If True, `done = terminated or truncated`. 
If False, `done = terminated` only. + """ + + # Load all demo data. Structure: var_name[num_demo][time_step][key if dict] = value + data = np.load(demo_npz_path, allow_pickle=True) + + obs_buffer = data['obs'] # length T+1 + act_buffer = data['acs'] # length T + rew_buffer = data['rewards'] # length T + term_buffer = data['terminateds'] # length T + trunc_buffer = data['truncateds'] # length T + info_buffer = data['info'] # length T + done_buffer = data['dones'] # length T, if available + + # Extract number of demonstrations + num_demos = obs_buffer.shape[0] + + # Extract rollout data for a single episode + for ep in range(num_demos): + ep_obs = obs_buffer[ep] # this is a length‐(T+1) array of dicts + ep_acts = act_buffer[ep] # length‐T array of actions + ep_rews = rew_buffer[ep] + ep_terms = term_buffer[ep] + ep_trunc = trunc_buffer[ep] + ep_done = done_buffer[ep] + ep_info = info_buffer[ep] # length‐T array of dicts + + # Length of episode: + T = len(ep_acts) + + # Extract single transitions from the episode data + for t in range(T): + # raw single‐step data: + obs_t = ep_obs[t] # dict[str, np.ndarray] (obs_dim,) + next_obs_t = ep_obs[t+1] + a_t = ep_acts[t] # np.ndarray (action_dim,) + r_t = float(ep_rews[t]) + done_t = bool(ep_done[t] or ep_terms[t] or ep_trunc[t]) + + # Rehydrate info dict and inject the timeout flag + raw_info = ep_info[t] # dict[str,Any] + if isinstance(raw_info, str): + import ast + info_t = ast.literal_eval(raw_info) + else: + info_t = raw_info.copy() + # Append truncated information to info_t + info_t["TimeLimit.truncated"] = bool(ep_trunc[t]) + + # Enter transition into data_store or QueuedDataStore + data_store.insert( + dict( + observations=obs_t, + actions=a_t, + next_observations=next_obs_t, + rewards=r_t, + masks=1.0 - done_t, + dones=done_t + ) + ) + + print(f"Can load {num_demos} transitions successfullly from {demo_npz_path}." 
+ f"(combine_done={combine_done}).") + +def get_demo_path(relative_path: str) -> str: + """ + Given a path relative to this script file, return + the absolute, normalized path as a string. + + Example: + # If your demos live at ../../../demos/data.npz + demo_file = get_demo_path("../../../demos/data_franka_random_10.npz") + """ + # 1) Resolve this script’s directory + script_dir = Path(__file__).resolve().parent + + # 2) Join with the user-supplied relative path and normalize + full_path = (script_dir / relative_path).resolve() + + return str(full_path) + + +def main(): + """ + Load demos + """ + + # Get abs path to demo file + script_dir = '/data/data/serl/demos' + default_file = 'data_franka_reach_random_20.npz' + + prompt = f"Please input the name of the file to load [{default_file}]: " + + file_name = input(prompt) or default_file + demo_file = os.path.join(script_dir, file_name) + + # Load the demo file into data_store as in async_sac_state.py + from agentlace.data.data_store import QueuedDataStore + data_store = QueuedDataStore(2000) + load_demos_to_her_buffer_gymnasium(data_store, demo_file, combine_done=True) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/demos/oxe_envlogger b/demos/oxe_envlogger new file mode 160000 index 00000000..de3c48bc --- /dev/null +++ b/demos/oxe_envlogger @@ -0,0 +1 @@ +Subproject commit de3c48bcf094ebba350ce0ba183efca4478a501a diff --git a/demos/requirements.txt b/demos/requirements.txt new file mode 100644 index 00000000..febfe029 --- /dev/null +++ b/demos/requirements.txt @@ -0,0 +1,4 @@ +tensorflow-metadata==1.17.2 +apache-beam==2.67.0 +protobuf>=4.21.6,<4.22 +git+https://github.com/rail-berkeley/oxe_envlogger.git@main?? 
diff --git a/demos/setup.py b/demos/setup.py new file mode 100644 index 00000000..d6259a7e --- /dev/null +++ b/demos/setup.py @@ -0,0 +1,7 @@ +from setuptools import setup, find_packages + +setup( + name="demos", + version="0.1", + packages=find_packages(), +) diff --git a/docs/real_franka.md b/docs/real_franka.md old mode 100644 new mode 100755 diff --git a/docs/sim_quick_start.md b/docs/sim_quick_start.md old mode 100644 new mode 100755 diff --git a/examples/async_bin_relocation_fwbw_drq/async_drq_randomized.py b/examples/async_bin_relocation_fwbw_drq/async_drq_randomized.py index 78a8eee4..7669559f 100644 --- a/examples/async_bin_relocation_fwbw_drq/async_drq_randomized.py +++ b/examples/async_bin_relocation_fwbw_drq/async_drq_randomized.py @@ -485,7 +485,7 @@ def main(_): # replicate agent across devices # need the jnp.array to avoid a bug where device_put doesn't recognize primitives agent: DrQAgent = jax.device_put( - jax.tree_map(jnp.array, agent), sharding.replicate() + jax.tree.map(jnp.array, agent), sharding.replicate() ) agents[v] = agent else: @@ -500,7 +500,7 @@ def main(_): # replicate agent across devices # need the jnp.array to avoid a bug where device_put doesn't recognize primitives agent: DrQAgent = jax.device_put( - jax.tree_map(jnp.array, agent), sharding.replicate() + jax.tree.map(jnp.array, agent), sharding.replicate() ) if FLAGS.learner: diff --git a/examples/async_cable_route_drq/async_drq_randomized.py b/examples/async_cable_route_drq/async_drq_randomized.py index 90770fc4..3c15a441 100644 --- a/examples/async_cable_route_drq/async_drq_randomized.py +++ b/examples/async_cable_route_drq/async_drq_randomized.py @@ -370,7 +370,7 @@ def main(_): # replicate agent across devices # need the jnp.array to avoid a bug where device_put doesn't recognize primitives agent: DrQAgent = jax.device_put( - jax.tree_map(jnp.array, agent), sharding.replicate() + jax.tree.map(jnp.array, agent), sharding.replicate() ) if FLAGS.learner: diff --git 
a/examples/async_drq_sim/.vscode/launch.json b/examples/async_drq_sim/.vscode/launch.json new file mode 100644 index 00000000..4e3fbf20 --- /dev/null +++ b/examples/async_drq_sim/.vscode/launch.json @@ -0,0 +1,169 @@ +{ + // VSCode debug configuration version + "version": "0.2.0", + "configurations": [ + // LEARNER + { + // Configuration for the RL agent learner component + "name": "Python: Learner", + "type": "debugpy", + "request": "launch", + "program": "${workspaceFolder}/async_drq_sim.py", + // Command-line arguments matching run_learner.sh + "args": [ + "--render", + "--env", "PandaReachSparseCube-v0", // Environment to use + "--agent", "drq", // Agent type (drq or sac) + "--exp_name", "PandaReachCubeVision-v0_001", // Experiment name for wandb logging + "--max_traj_length", "100", // Max episode length/ Max episode length + "--seed", "42", // Random seed for reproducibility + // "--save_model", // Save model checkpoints + "--batch_size", "256", // Training batch size + "--critic_actor_ratio", "4", // Critic-to-actor update ratio + "--max_steps", "100_000 ", // Maximum training steps + "--replay_buffer_capacity", "200_000", // Replay buffer capacity + "--random_steps", "0", // Number of random steps at beginning + "--training_starts", "300", // Start training after buffer has this many samples + "--steps_per_update", "30", // Number of env steps per update + "--learner", // REQUIRED: Indicates this is a learner instance + "--encoder_type", "resnet-pretrained", // Use pixel-based observations + + // Fractal + // "--replay_buffer_type", "fractal_symmetry_replay_buffer_parallel", // Use fractal symmetry replay buffer + // "--branch_method", "constant", + // "--split_method", "constant", + // "--starting_branch_count", "3", // Start with 27 branches + // "--workspace_width", "0.5", + + // Demonstration data loading options + // "--load_demos", // Load demo dataset + // "--demo_dir", "/data/data/serl/demos", + // "--file_name", "data_franka_reach_random_5_2.npz", 
// Name of the demo file to load + "--preload_rlds_path", "/media/bison/hdd/data/serl/demos/franka_reach_drq_demo_script/2_demos_session_20250916_100650/PandaReachSparseCube-v0/0.1.0", // Preload RLDS dataset for faster loading + + // Dissasociated Fractals + // "--branch_method", "disassociated", // branch method type + // "--split_method", "disassociated", // split method type + // "--disassociated_type", "octahedron", // Type of disassociated test to perform + // "--min_branch_count", "3", // Minimum branch count for disassociated testing + // "--max_branch_count", "9", // Maximum branch count for disassociated testing + // "--num_depth_sectors", "13", // Desired number of sectors to divide rollouts into for branch count splitting + //"--alpha", "1" // alpha + + ], + "console": "integratedTerminal", // Use integrated terminal to see output + "justMyCode": false, // Allow stepping into libraries + // Environment variables from run_learner.sh + "env": { + "XLA_PYTHON_CLIENT_PREALLOCATE": "false", // Don't preallocate JAX/XLA memory + "XLA_PYTHON_CLIENT_MEM_FRACTION": ".5" // Limit JAX memory usage to 50% + }, + // Additional helpful debugging options + "showReturnValue": true, // Show function return values + "purpose": ["debug-in-terminal"], // Run in terminal for better output + "cwd": "${workspaceFolder}" // Set working directory explicitly + }, + // ACTOR + { + // Configuration for the RL agent actor component + "name": "Python: Actor", + "type": "debugpy", + "request": "launch", + "program": "${workspaceFolder}/async_drq_sim.py", + // Command-line arguments matching run_actor.sh + "args": [ + "--render", + "--env", "PandaReachSparseCube-v0", // Environment to use + "--agent", "drq", // Agent type (drq or sac) + "--exp_name", "PandaPickCubeVision-v0_sparse_001", // Experiment name for wandb logging + "--max_traj_length", "100", // Max episode length/ Max episode length + "--seed", "42", // Random seed for reproducibility + // "--save_model", // Save model 
checkpoints + "--batch_size", "256", // Training batch size + "--critic_actor_ratio", "8", // Critic-to-actor update ratio + "--max_steps", "50_000", // Maximum training steps + "--replay_buffer_capacity", "200_000", // Replay buffer capacity + "--random_steps", "300", // Number of random steps at beginning + "--training_starts", "300", // Start training after buffer has this many samples + "--steps_per_update", "30", // Number of env steps per update + "--actor", // REQUIRED: Indicates this is a learner instance + "--encoder_type", "resnet-pretrained", // Use pixel-based observations + + // Fractal + // "--replay_buffer_type", "fractal_symmetry_replay_buffer_parallel", // Use fractal symmetry replay buffer + // "--branch_method", "constant", + // "--split_method", "constant", + // "--starting_branch_count", "3", // Start with 27 branches + // "--workspace_width", "0.5", + + // Demonstration data loading options + // "--load_demos", // Load demo dataset + // "--demo_dir", "/data/data/serl/demos", + // "--file_name", "data_franka_reach_random_5_2.npz", // Name of the demo file to load + "--preload_rlds_path", "/media/bison/hdd/data/serl/demos/franka_reach_drq_demo_script/2_demos_session_20250916_100650/PandaReachSparseCube-v0/0.1.0", // Preload RLDS dataset for faster loading + + // Dissasociated Fractals + // "--branch_method", "disassociated", // branch method type + // "--split_method", "disassociated", // split method type + // "--disassociated_type", "octahedron", // Type of disassociated test to perform + // "--min_branch_count", "3", // Minimum branch count for disassociated testing + // "--max_branch_count", "9", // Maximum branch count for disassociated testing + // "--num_depth_sectors", "13", // Desired number of sectors to divide rollouts into for branch count splitting + //"--alpha", "1" // alpha + + ], + "console": "integratedTerminal", // Use integrated terminal to see output + "justMyCode": false, // Allow stepping into libraries + // Environment 
variables from run_actor.sh + "env": { + "XLA_PYTHON_CLIENT_PREALLOCATE": "false", // Don't preallocate JAX/XLA memory + "XLA_PYTHON_CLIENT_MEM_FRACTION": ".5" // Limit JAX memory usage to 50% + }, + "showReturnValue": true, // Show function return values + "purpose": ["debug-in-terminal"], // Run in terminal for better output + "cwd": "${workspaceFolder}" // Set working directory explicitly + } + ], + // Compound configurations to launch multiple configurations together + "compounds": [ + { + // Launch both learner and actor at the same time + "name": "Learner + Actor", + "configurations": ["Python: Learner", "Python: Actor"], + // Note: Actor will connect to learner via TrainerClient, assuming localhost IP + // Both will run in separate debug sessions with independent controls + } + ], + + /* + DEBUGGING TIPS: + + - IMPORTANT: You must select either Learner or Actor configuration when debugging + (the NotImplementedError occurs if neither --learner nor --actor flag is specified) + + - If you get an error about 'utd_ratio', use the "Python: Learner (Fix utd_ratio)" configuration + which adds this missing parameter + + - Set breakpoints in learner() or actor() functions to step through the main training loops + + - Key places to set breakpoints: + * In actor(): near the action sampling logic (step < FLAGS.random_steps) + * In learner(): where agent.update_high_utd() is called + * Server/client communication points (client.update(), server.publish_network()) + + - For memory issues: Watch replay buffer size growth with breakpoints in data_store.insert() + + - JAX issues: Set breakpoints after jax.device_put() calls to ensure proper device placement + + - The "utd_ratio" parameter seems to be used in the learner function but isn't defined + in the FLAGS. Use the special configuration or add a --utd_ratio flag to fix. + + DEBUG WORKFLOW: + + 1. Start with "Python: Learner (Fix utd_ratio)" configuration + 2. Set breakpoints at key sections you want to monitor + 3. 
Run the debugger and observe variable values at each step + 4. Once learner is properly running, launch the Actor in a separate instance + 5. Watch for communication between the two + */ +} \ No newline at end of file diff --git a/examples/async_drq_sim/async_drq_sim.py b/examples/async_drq_sim/async_drq_sim.py index a84bc580..37af0fb5 100644 --- a/examples/async_drq_sim/async_drq_sim.py +++ b/examples/async_drq_sim/async_drq_sim.py @@ -38,9 +38,10 @@ FLAGS = flags.FLAGS -flags.DEFINE_string("env", "PandaPickCubeVision-v0", "Name of environment.") +flags.DEFINE_string("env", "PandaReachSparseCube-v0", "Name of environment.") flags.DEFINE_string("agent", "drq", "Name of agent.") flags.DEFINE_string("exp_name", None, "Name of the experiment for wandb logging.") +flags.DEFINE_string("run_name", None, "Name of run for wandb logging") flags.DEFINE_integer("max_traj_length", 1000, "Maximum length of trajectory.") flags.DEFINE_integer("seed", 42, "Random seed.") flags.DEFINE_bool("save_model", False, "Whether to save model.") @@ -69,6 +70,21 @@ flags.DEFINE_integer("checkpoint_period", 0, "Period to save checkpoints.") flags.DEFINE_string("checkpoint_path", None, "Path to save checkpoints.") +# flags for replay buffer +flags.DEFINE_string("replay_buffer_type", "memory_efficient_replay_buffer", "Which replay buffer to use") +flags.DEFINE_string("branch_method", None, "Method for how many branches to generate") +flags.DEFINE_string("split_method", None, "Method for when to change number of branches generated") +flags.DEFINE_float("workspace_width", 0.5, "Workspace width in meters") +flags.DEFINE_integer("max_depth",None,"Maximum layers of depth") +flags.DEFINE_integer("starting_branch_count", None, "Initial number of branches") +flags.DEFINE_integer("branching_factor", None, "Rate of change of branches per dimension (x,y)") # For fractal_branch and fractal_contraction +flags.DEFINE_float("alpha",None,"alpha value") +flags.DEFINE_enum("disassociated_type", None, 
["octahedron", "hourglass"], + "Type of disassociated fracal rollout. Octahedron: expand from min to max then contract to min," + + " Hourglass: Contract from max to min then expand to max") +flags.DEFINE_integer("min_branch_count", None, "Minimum number of branches for disassociated fractal rollout") +flags.DEFINE_integer("max_branch_count", None, "Maximum number of branches for disassociated fractal rollout") + flags.DEFINE_boolean( "debug", False, "Debug mode." ) # debug mode will disable wandb logging @@ -108,7 +124,7 @@ def update_params(params): client.recv_network_callback(update_params) eval_env = gym.make(FLAGS.env) - if FLAGS.env == "PandaPickCubeVision-v0": + if FLAGS.env == "PandaReachSparseCube-v0": eval_env = SERLObsWrapper(eval_env) eval_env = ChunkingWrapper(eval_env, obs_horizon=1, act_exec_horizon=None) eval_env = RecordEpisodeStatistics(eval_env) @@ -191,9 +207,12 @@ def learner( """ # set up wandb and logging wandb_logger = make_wandb_logger( - project="serl_dev", + project=FLAGS.exp_name, + name=FLAGS.run_name, description=FLAGS.exp_name or FLAGS.env, + # wandb_output_dir=FLAGS.wandb_output_dir, debug=FLAGS.debug, + # offline=FLAGS.wandb_offline, ) # To track the step in the training loop @@ -325,11 +344,17 @@ def main(_): else: env = gym.make(FLAGS.env) - if FLAGS.env == "PandaPickCube-v0": - env = gym.wrappers.FlattenObservation(env) - if FLAGS.env == "PandaPickCubeVision-v0": + if FLAGS.env in {"PandaPickCube-v0", "PandaReachCube-v0", "PandaPickSparseCube-v0", "PandaReachSparseCube-v0", "PandaPickCubeVision-v0", "PandaReachCubeVision-v0", "PandaPickSparseCubeVision-v0", "PandaReachSparseCubeVision-v0"}: + x_obs_idx=np.array([0,4]) + y_obs_idx=np.array([1,5]) + else: + raise NotImplementedError(f"Unknown observation layout for {FLAGS.env}") + + if FLAGS.env == "PandaReachSparseCube-v0": env = SERLObsWrapper(env) env = ChunkingWrapper(env, obs_horizon=1, act_exec_horizon=None) + else: + env = gym.wrappers.FlattenObservation(env) image_keys = 
[key for key in env.observation_space.keys() if key != "state"] @@ -345,7 +370,7 @@ def main(_): # replicate agent across devices # need the jnp.array to avoid a bug where device_put doesn't recognize primitives agent: DrQAgent = jax.device_put( - jax.tree_map(jnp.array, agent), sharding.replicate() + jax.tree.map(jnp.array, agent), sharding.replicate() ) if FLAGS.learner: @@ -354,7 +379,21 @@ def main(_): env, capacity=FLAGS.replay_buffer_capacity, rlds_logger_path=FLAGS.log_rlds_path, - type="memory_efficient_replay_buffer", + type=FLAGS.replay_buffer_type, + branch_method=FLAGS.branch_method, + split_method=FLAGS.split_method, + branching_factor=FLAGS.branching_factor, + starting_branch_count=FLAGS.starting_branch_count, + workspace_width=FLAGS.workspace_width, + max_traj_length=FLAGS.max_traj_length, + x_obs_idx=x_obs_idx, + y_obs_idx=y_obs_idx, + preload_rlds_path=FLAGS.preload_rlds_path, + max_depth=FLAGS.max_depth, + alpha=FLAGS.alpha, + disassociated_type=FLAGS.disassociated_type, + min_branch_count=FLAGS.min_branch_count, + max_branch_count=FLAGS.max_branch_count, image_keys=image_keys, ) @@ -370,14 +409,39 @@ def preload_data_transform(data, metadata) -> Optional[Dict[str, Any]]: # NOTE: Create your own custom data transform function here if you # are loading this via with --preload_rlds_path with tf rlds data # This default does nothing + # See: https://colab.research.google.com/github/google-research/rlds/blob/main/rlds/examples/rlds_tutorial.ipynb#scrollTo=X1KXM8IGecRO + # https://www.tensorflow.org/guide/data + # https://github.com/google-research/rlds/blob/main/docs/transformations.md + # Batch: rlds.transformations.batch (https://colab.research.google.com/github/google-research/rlds/blob/main/rlds/examples/rlds_tutorial.ipynb#scrollTo=TGT3YfzFOrBm) + # Reverb: rlds.transformations.pattern_map (https://colab.research.google.com/github/google-research/rlds/blob/main/rlds/examples/rlds_dataset_patterns.ipynb ) + # Nested data set manipulation: 
rlds.transformations.episode_length/.sum_dataset/.final_step/.map_nested_steps + # Concatenation: rlds.transformations.concatenate / .concat_if_terminal (https://colab.research.google.com/github/google-research/rlds/blob/main/rlds/examples/rlds_examples.ipynb#scrollTo=pWNhxwJzOUJv) + # Stats: rlds.transformations.mean_and_std (https://colab.research.google.com/github/google-research/rlds/blob/main/rlds/examples/rlds_tutorial.ipynb#scrollTo=Z0TITfo_4oZr) + # Truncation: rlds.transformations.truncate_after_condition + # Alignment: rlds.transformations.shift_keys + # Zero Init: rlds.transformations.zeros_from_spec return data demo_buffer = make_replay_buffer( env, capacity=FLAGS.replay_buffer_capacity, - type="memory_efficient_replay_buffer", - image_keys=image_keys, + rlds_logger_path=FLAGS.log_rlds_path, + type=FLAGS.replay_buffer_type, + branch_method=FLAGS.branch_method, + split_method=FLAGS.split_method, + branching_factor=FLAGS.branching_factor, + starting_branch_count=FLAGS.starting_branch_count, + workspace_width=FLAGS.workspace_width, + max_traj_length=FLAGS.max_traj_length, + x_obs_idx=x_obs_idx, + y_obs_idx=y_obs_idx, preload_rlds_path=FLAGS.preload_rlds_path, + max_depth=FLAGS.max_depth, + alpha=FLAGS.alpha, + disassociated_type=FLAGS.disassociated_type, + min_branch_count=FLAGS.min_branch_count, + max_branch_count=FLAGS.max_branch_count, + image_keys=image_keys, preload_data_transform=preload_data_transform, ) diff --git a/examples/async_drq_sim/automated_tests.sh b/examples/async_drq_sim/automated_tests.sh new file mode 100644 index 00000000..07a58f48 --- /dev/null +++ b/examples/async_drq_sim/automated_tests.sh @@ -0,0 +1,64 @@ +#!/bin/bash + +SEEDS=$1 +# WANDB_OUTPUT_DIR=~/wandb_logs +TEST="async_sac_state_sim.py" +CONDA_ENV="serl" +ENV="PandaReachSparseCube-v0" +MAX_STEPS=1000000 +TRAINING_STARTS=1000 +RANDOM_STEPS=1000 +BATCH_SIZE=128 +EXP_NAME="FIRST-TESTS-$ENV" +REPLAY_BUFFER_TYPE="fractal_symmetry_replay_buffer" 
+PRELOAD_RLDS="/data/data/serl/demos/franka_reach_drq_demo_script/10_demos_session_202500914_213515/PandaReachSparseCube-v0/0.1.0" +BASE_ARGS="--env $ENV --exp_name $EXP_NAME --training_starts $TRAINING_STARTS --random_steps $RANDOM_STEPS --batch_size $BATCH_SIZE --preload_rlds_path $PRELOAD_RLDS --encoder_type resnet-pretrained" +ARGS="" + +function run_test { + + for seed in $(seq 1 1 $SEEDS) + do + # OPEN_PORTS=$( comm -23 <(seq 49152 65535 | sort) <(ss -Htan | awk '{print $4}' | cut -d':' -f2 | sort -u) | shuf | head -n 2 ) + # PORTS=( $OPEN_PORTS ) + # PORT_NUMBER=${PORTS[0]} + # BROADCAST_PORT=${PORTS[1]} + + # ARGS+=" --port_number $PORT_NUMBER --broadcast_port $BROADCAST_PORT" + + echo "Running constant with args: $ARGS" + tmux respawn-pane -k -t serl_session:0.1 + tmux respawn-pane -k -t serl_session:0.2 + tmux send-keys -t serl_session:0.1 "conda activate $CONDA_ENV && bash run_actor.sh --max_steps 2000000000 --seed $seed $BASE_ARGS $ARGS" C-m + tmux send-keys -t serl_session:0.2 "conda activate $CONDA_ENV && bash run_learner.sh --max_steps $MAX_STEPS --seed $seed $BASE_ARGS $ARGS" C-m "exit" C-m + + # Wait for learner to finish + while ! tmux capture-pane -t serl_session:0.2 -p | grep "logout" > /dev/null; + do + sleep 100 + done + echo "Finished!" 
+ done +} + +# BASELINE TESTING +for replay_buffer_capacity in 1000000 +do + ARGS="--run_name baseline --replay_buffer_type memory_efficient_replay_buffer --replay_buffer_capacity $replay_buffer_capacity" + run_test +done + +# CONSTANT TESTING +for starting_branch_count in 1 27 +do + for workspace_width in 0.5 + do + for replay_buffer_capacity in 1000000 + do + ARGS="--run_name constant-$starting_branch_count^1 --replay_buffer_type $REPLAY_BUFFER_TYPE --replay_buffer_capacity $replay_buffer_capacity --workspace_width $workspace_width --branch_method 'constant' --starting_branch_count $starting_branch_count" + run_test + done + done +done + +tmux kill-window -t serl_session:$SEED diff --git a/examples/async_drq_sim/automated_tests_helper.sh b/examples/async_drq_sim/automated_tests_helper.sh new file mode 100644 index 00000000..3c02ecb6 --- /dev/null +++ b/examples/async_drq_sim/automated_tests_helper.sh @@ -0,0 +1,7 @@ +#!/bin/bash + +export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ +export XLA_PYTHON_CLIENT_MEM_FRACTION=.5 && \ +export MUJOCO_GL=egl + +python async_drq_sim.py "$@" \ No newline at end of file diff --git a/examples/async_drq_sim/run_actor.sh b/examples/async_drq_sim/run_actor.sh index 52fcfc41..ece2d80e 100644 --- a/examples/async_drq_sim/run_actor.sh +++ b/examples/async_drq_sim/run_actor.sh @@ -1,10 +1,6 @@ export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ -export XLA_PYTHON_CLIENT_MEM_FRACTION=.1 && \ -python async_drq_sim.py "$@" \ - --actor \ - --render \ - --exp_name=serl_dev_drq_sim_test_resnet \ - --seed 0 \ - --random_steps 1000 \ - --encoder_type resnet-pretrained \ - --debug +export XLA_PYTHON_CLIENT_MEM_FRACTION=.2 && \ +export MUJOCO_GL=egl && \ +export TF_GPU_ALLOCATOR=cuda_malloc_async && \ + +python async_drq_sim.py --actor "$@" diff --git a/examples/async_drq_sim/run_learner.sh b/examples/async_drq_sim/run_learner.sh index 39445448..48deba7e 100644 --- a/examples/async_drq_sim/run_learner.sh +++ b/examples/async_drq_sim/run_learner.sh 
@@ -1,11 +1,6 @@ export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ -export XLA_PYTHON_CLIENT_MEM_FRACTION=.2 && \ -python async_drq_sim.py "$@" \ - --learner \ - --exp_name=serl_dev_drq_sim_test_resnet \ - --seed 0 \ - --training_starts 1000 \ - --critic_actor_ratio 4 \ - --encoder_type resnet-pretrained \ - # --demo_path franka_lift_cube_image_20_trajs.pkl \ - --debug # wandb is disabled when debug +export XLA_PYTHON_CLIENT_MEM_FRACTION=.4 && \ +export MUJOCO_GL=egl && \ +export TF_GPU_ALLOCATOR=cuda_malloc_async && \ + +python async_drq_sim.py --learner "$@" diff --git a/examples/async_drq_sim/tmux_launch_tests.sh b/examples/async_drq_sim/tmux_launch_tests.sh new file mode 100644 index 00000000..aa7a6a57 --- /dev/null +++ b/examples/async_drq_sim/tmux_launch_tests.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +SEEDS=1 +# Create a new tmux session +tmux new-session -d -s serl_session +tmux setw -g remain-on-exit on + +# Split the window horizontally +tmux split-window -v +tmux split-pane -h -t serl_session:0.1 + +# Navigate to the activate the conda environment in the first pane +tmux send-keys -t serl_session:0.0 "bash automated_tests.sh $SEEDS" C-m + + +# Attach to the tmux session +tmux attach-session -t serl_session + +# kill the tmux session by running the following command +# tmux kill-session -t serl_session diff --git a/examples/async_drq_sim/tmux_rlpd_launch.sh b/examples/async_drq_sim/tmux_rlpd_launch.sh old mode 100644 new mode 100755 diff --git a/examples/async_pcb_insert_drq/async_drq_randomized.py b/examples/async_pcb_insert_drq/async_drq_randomized.py index 8248379e..5a1770ff 100644 --- a/examples/async_pcb_insert_drq/async_drq_randomized.py +++ b/examples/async_pcb_insert_drq/async_drq_randomized.py @@ -440,7 +440,7 @@ def main(_): # replicate agent across devices # need the jnp.array to avoid a bug where device_put doesn't recognize primitives agent: DrQAgent = jax.device_put( - jax.tree_map(jnp.array, agent), sharding.replicate() + jax.tree.map(jnp.array, 
agent), sharding.replicate() ) if FLAGS.learner: diff --git a/examples/async_peg_insert_drq/async_drq_randomized.py b/examples/async_peg_insert_drq/async_drq_randomized.py index 4fd76f08..2d089c96 100644 --- a/examples/async_peg_insert_drq/async_drq_randomized.py +++ b/examples/async_peg_insert_drq/async_drq_randomized.py @@ -45,6 +45,7 @@ flags.DEFINE_integer("max_traj_length", 100, "Maximum length of trajectory.") flags.DEFINE_integer("seed", 42, "Random seed.") flags.DEFINE_bool("save_model", False, "Whether to save model.") +flags.DEFINE_integer("batch_size", 256, "Batch size.") flags.DEFINE_integer("critic_actor_ratio", 4, "critic to actor update ratio.") flags.DEFINE_integer("max_steps", 1000000, "Maximum number of training steps.") @@ -347,7 +348,7 @@ def main(_): # replicate agent across devices # need the jnp.array to avoid a bug where device_put doesn't recognize primitives agent: DrQAgent = jax.device_put( - jax.tree_map(jnp.array, agent), sharding.replicate() + jax.tree.map(jnp.array, agent), sharding.replicate() ) if FLAGS.learner: diff --git a/examples/async_peg_insert_drq/docs/TestSpaceMouse.py b/examples/async_peg_insert_drq/docs/TestSpaceMouse.py new file mode 100644 index 00000000..0e8b86ce --- /dev/null +++ b/examples/async_peg_insert_drq/docs/TestSpaceMouse.py @@ -0,0 +1,286 @@ +#!/usr/bin/env python3 +""" +SpaceMouse Test Script (Fixed Version) + +This script tests if a 3Dconnexion SpaceMouse is properly connected and accessible. +It's designed to handle different versions of pyspacemouse and potential API inconsistencies. + +Usage: + python FixedTestSpaceMouse.py + +Press Ctrl+C to exit the program. 
+""" + +import time +import sys +import traceback +import threading + +# Try importing the pyspacemouse library +try: + import pyspacemouse + print("Successfully imported pyspacemouse library") + + # Print the library version if available + try: + version = getattr(pyspacemouse, "__version__", "unknown") + print(f"pyspacemouse version: {version}") + except: + print("Could not determine pyspacemouse version") +except ImportError as e: + print(f"Failed to import pyspacemouse: {e}") + print("Try installing it with: pip install pyspacemouse") + sys.exit(1) +except Exception as e: + print(f"Unexpected error importing pyspacemouse: {e}") + traceback.print_exc() + sys.exit(1) + +def print_device_info(device): + """Print detailed information about the SpaceMouse device.""" + print("\n===== DEVICE INFORMATION =====") + + # Check if device is a string or an object + if isinstance(device, str): + print(f"Device information: {device}") + return + + # Try to access device attributes safely + try: + attrs = [ + ("Manufacturer", "manufacturer_string", "Unknown"), + ("Product", "product_string", "Unknown"), + ("Vendor ID", "vendor_id", "Unknown"), + ("Product ID", "product_id", "Unknown"), + ("Serial Number", "serial_number", "Unknown"), + ("Release Number", "release_number", "Unknown"), + ("Interface Number", "interface_number", "Unknown") + ] + + for label, attr, default in attrs: + value = getattr(device, attr, default) + if attr in ["vendor_id", "product_id"] and value != "Unknown": + print(f"{label}: 0x{value:04x}") + else: + print(f"{label}: {value}") + except Exception as e: + print(f"Error accessing device attributes: {e}") + print(f"Raw device data: {device}") + + print("===============================\n") + +def test_device_connection(): + """Test if the SpaceMouse device is connected and accessible.""" + print("Attempting to detect SpaceMouse devices...") + + try: + # List all available devices + try: + devices = pyspacemouse.list_devices() + print(f"list_devices() 
returned: {devices}") + + if not devices: + print("No SpaceMouse devices found!") + return None + + print(f"Found {len(devices)} devices.") + + # Safely print device information + for i, device in enumerate(devices): + print(f"Device {i+1}:") + if isinstance(device, str): + print(f" {device}") + else: + try: + vendor_id = getattr(device, "vendor_id", "Unknown") + product_id = getattr(device, "product_id", "Unknown") + product_string = getattr(device, "product_string", "Unknown Device") + + if vendor_id != "Unknown" and product_id != "Unknown": + print(f" {product_string} (Vendor ID: 0x{vendor_id:04x}, Product ID: 0x{product_id:04x})") + else: + print(f" {product_string}") + except Exception as e: + print(f" Error printing device {i+1} info: {e}") + print(f" Raw device data: {device}") + + # Use the first device found + return devices[0] + + except AttributeError: + # Alternative approach if list_devices doesn't work as expected + print("list_devices() method not working as expected, trying open() directly...") + if pyspacemouse.open(): + print("Successfully opened a SpaceMouse device directly") + return "SpaceMouse Device" + else: + print("Failed to open any SpaceMouse device directly") + return None + + except Exception as e: + print(f"Error detecting devices: {e}") + traceback.print_exc() + return None + +def open_device(device): + """Try to open the SpaceMouse device.""" + print("Attempting to open the SpaceMouse...") + try: + # Check if device is already open + if hasattr(pyspacemouse, "is_open") and pyspacemouse.is_open(): + print("Device is already open") + return True + + # Try to open the device + if isinstance(device, str): + # If device is a string, just try to open any device + if pyspacemouse.open(callback=None): + print("Successfully opened a SpaceMouse device") + return True + else: + # Try to open the specific device + try: + if pyspacemouse.open(callback=None, device=device): + print("Successfully opened the SpaceMouse device") + return True + 
except TypeError: + # If device parameter isn't supported, try without it + if pyspacemouse.open(callback=None): + print("Successfully opened a SpaceMouse device") + return True + + print("Failed to open the SpaceMouse device") + print("Check if another program is already using it") + return False + + except Exception as e: + print(f"Error opening device: {e}") + traceback.print_exc() + return False + +def safe_read(): + """Safely read from the device, handling potential API differences.""" + try: + state = pyspacemouse.read() + return state + except Exception as e: + print(f"Error reading from device: {e}") + return None + +def monitor_button_state(stop_event): + """Monitor and print button state changes in a separate thread.""" + last_buttons = None + + while not stop_event.is_set(): + try: + state = safe_read() + if state and hasattr(state, "buttons") and state.buttons is not None: + current_buttons = state.buttons + if last_buttons != current_buttons: + buttons_pressed = [i for i, pressed in enumerate(current_buttons) if pressed] + if buttons_pressed: + print(f"Buttons pressed: {buttons_pressed}") + last_buttons = current_buttons.copy() if hasattr(current_buttons, "copy") else current_buttons + time.sleep(0.05) # Small sleep to prevent 100% CPU usage + except Exception as e: + print(f"Error in button monitoring thread: {e}") + time.sleep(1) # Wait a bit longer on error + +def main(): + """Main function to test the SpaceMouse.""" + print("SpaceMouse Test Script (Fixed Version)") + print("-------------------------------------") + + # First, check if we can detect any devices + device = test_device_connection() + if not device: + print("\nNo SpaceMouse detected. Please ensure:") + print("1. The device is connected to your computer") + print("2. You have installed libhidapi (sudo apt-get install libhidapi-dev libhidapi-hidraw0)") + print("3. You have proper permissions (sudo usermod -a -G plugdev $USER)") + print("4. 
You've created proper udev rules if needed") + sys.exit(1) + + # Print detailed device information + print_device_info(device) + + # Try to open the device + if not open_device(device): + sys.exit(1) + + # Test basic functionality + print("\nTesting basic device functionality...") + test_state = safe_read() + if test_state: + print("Successfully read initial state from device:") + try: + attrs = ["x", "y", "z", "roll", "pitch", "yaw", "buttons"] + for attr in attrs: + if hasattr(test_state, attr): + value = getattr(test_state, attr) + print(f" {attr}: {value}") + else: + print(f" {attr}: Not available") + except Exception as e: + print(f"Error reading attributes: {e}") + print(f"Raw state: {test_state}") + else: + print("Could not read initial state from device!") + print("The device might be connected but not functioning correctly.") + sys.exit(1) + + # Create a thread to monitor button presses + stop_event = threading.Event() + button_thread = threading.Thread(target=monitor_button_state, args=(stop_event,)) + button_thread.daemon = True + button_thread.start() + + # Monitor device movement + print("\nMove your SpaceMouse to see the values") + print("Press Ctrl+C to exit") + print("\nReading SpaceMouse state...") + + try: + last_print_time = time.time() + while True: + # Read the current state + state = safe_read() + + if state: + current_time = time.time() + + # Check if any movement data is available + x = getattr(state, "x", 0) or 0 + y = getattr(state, "y", 0) or 0 + z = getattr(state, "z", 0) or 0 + roll = getattr(state, "roll", 0) or 0 + pitch = getattr(state, "pitch", 0) or 0 + yaw = getattr(state, "yaw", 0) or 0 + + if any(abs(val) > 0.01 for val in [x, y, z, roll, pitch, yaw]): + # Print at most 10 times per second + if current_time - last_print_time >= 0.1: + print(f"\rPosition: X:{x:6.2f} Y:{y:6.2f} Z:{z:6.2f} | " + f"Rotation: Roll:{roll:6.2f} Pitch:{pitch:6.2f} Yaw:{yaw:6.2f}", + end="", flush=True) + last_print_time = current_time + + 
time.sleep(0.01) # Small sleep to prevent 100% CPU usage + + except KeyboardInterrupt: + print("\n\nExiting...") + except Exception as e: + print(f"\n\nError reading from device: {e}") + traceback.print_exc() + finally: + # Clean up + stop_event.set() + button_thread.join(timeout=1.0) + try: + pyspacemouse.close() + print("\nClosed SpaceMouse connection") + except Exception as e: + print(f"\nError closing device: {e}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/async_peg_insert_drq/docs/listGym_env.py b/examples/async_peg_insert_drq/docs/listGym_env.py new file mode 100644 index 00000000..c6179294 --- /dev/null +++ b/examples/async_peg_insert_drq/docs/listGym_env.py @@ -0,0 +1,5 @@ +import gymnasium as gym + +# Correct method to list environments +for env in gym.envs.registry.keys(): + print(env) diff --git a/examples/async_peg_insert_drq/record_demo.py b/examples/async_peg_insert_drq/record_demo.py index 5a0fd02b..1840a5d3 100644 --- a/examples/async_peg_insert_drq/record_demo.py +++ b/examples/async_peg_insert_drq/record_demo.py @@ -32,7 +32,7 @@ transitions = [] success_count = 0 - success_needed = 20 + success_needed = 30 total_count = 0 pbar = tqdm(total=success_needed) diff --git a/examples/async_peg_insert_drq/run_actor.sh b/examples/async_peg_insert_drq/run_actor.sh index a251e756..d11b8429 100644 --- a/examples/async_peg_insert_drq/run_actor.sh +++ b/examples/async_peg_insert_drq/run_actor.sh @@ -1,12 +1,34 @@ export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ -export XLA_PYTHON_CLIENT_MEM_FRACTION=.1 && \ +export XLA_PYTHON_CLIENT_MEM_FRACTION=.5 && \ +export SCRIPT_DIR=$(dirname "$(realpath "$0")") && \ +export ENV_NAME="FrankaPegInsert-Vision-v0" && \ +export TIMESTAMP=$(date +"%m-%d-%Y-%H-%M-%S") && \ +export CHECKPOINT_DIR="$SCRIPT_DIR/checkpoints/checkpoints-$TIMESTAMP" && \ +export CHECKPOINT_EVAL="/home/student/code/cleiver/serl/examples/async_peg_insert_drq/checkpoints/checkpoints-07-14-2025-23-15-59" && 
\ + + +# Create checkpoint directory if it doesn't exist +if [ ! -d "$CHECKPOINT_DIR" ]; then + echo "Creating checkpoint directory: $CHECKPOINT_DIR" + mkdir -p "$CHECKPOINT_DIR" || { + echo "Failed to create checkpoint directory!" >&2 + exit 1 + } +fi + python async_drq_randomized.py "$@" \ --actor \ --render \ - --env FrankaPegInsert-Vision-v0 \ - --exp_name=serl_dev_drq_rlpd10demos_peg_insert_random_resnet \ + --env $ENV_NAME \ + --exp_name=serl-peg-insert \ + --max_steps 25000 \ --seed 0 \ --random_steps 0 \ --training_starts 200 \ --encoder_type resnet-pretrained \ - --demo_path peg_insert_20_demos_2023-12-25_16-13-25.pkl \ + --demo_path peg_insert_30_demos_2025-07-14_22-57-59.pkl \ + --checkpoint_period 1000 \ + --checkpoint_path "$CHECKPOINT_DIR" \ + # --eval_checkpoint_step=5000 \ + # --eval_n_trajs=5 \ + #--debug # wandb is disabled when debug diff --git a/examples/async_peg_insert_drq/run_learner.sh b/examples/async_peg_insert_drq/run_learner.sh index c2823a19..63435ada 100644 --- a/examples/async_peg_insert_drq/run_learner.sh +++ b/examples/async_peg_insert_drq/run_learner.sh @@ -1,16 +1,32 @@ export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ -export XLA_PYTHON_CLIENT_MEM_FRACTION=.2 && \ +export XLA_PYTHON_CLIENT_MEM_FRACTION=.6 && \ +export SCRIPT_DIR=$(dirname "$(realpath "$0")") && \ +export ENV_NAME="FrankaPegInsert-Vision-v0" && \ +export TIMESTAMP=$(date +"%m-%d-%Y-%H-%M-%S") && \ +export CHECKPOINT_DIR="$SCRIPT_DIR/checkpoints/checkpoints-$TIMESTAMP" && \ + +# Create checkpoint directory if it doesn't exist +if [ ! -d "$CHECKPOINT_DIR" ]; then + echo "Creating checkpoint directory: $CHECKPOINT_DIR" + mkdir -p "$CHECKPOINT_DIR" || { + echo "Failed to create checkpoint directory!" 
>&2 + exit 1 + } +fi + python async_drq_randomized.py "$@" \ --learner \ - --env FrankaPegInsert-Vision-v0 \ - --exp_name=serl_dev_drq_rlpd10demos_peg_insert_random_resnet_097 \ + --env $ENV_NAME \ + --exp_name=serl-peg-insert \ --seed 0 \ + --max_steps 25000 \ --random_steps 1000 \ --training_starts 200 \ --critic_actor_ratio 4 \ - --batch_size 256 \ + --batch_size 128 \ --eval_period 2000 \ --encoder_type resnet-pretrained \ - --demo_path peg_insert_20_demos_2023-12-25_16-13-25.pkl \ + --demo_path peg_insert_30_demos_2025-07-14_22-57-59.pkl\ --checkpoint_period 1000 \ - --checkpoint_path /home/undergrad/code/serl_dev/examples/async_peg_insert_drq/5x5_20degs_20demos_rand_peg_insert_097 + --checkpoint_path "$CHECKPOINT_DIR" \ + #--debug # wandb is disabled when debug \ No newline at end of file diff --git a/examples/async_sac_state_sim/.vscode/launch.json b/examples/async_sac_state_sim/.vscode/launch.json new file mode 100644 index 00000000..a1768ccb --- /dev/null +++ b/examples/async_sac_state_sim/.vscode/launch.json @@ -0,0 +1,159 @@ +{ + // VSCode debug configuration version + "version": "0.2.0", + "configurations": [ + + + + // LEARNER + { + // Configuration for the RL agent learner component + "name": "Python: Learner", + "type": "debugpy", + "request": "launch", + "program": "${workspaceFolder}/async_sac_state_sim.py", + // Command-line arguments matching run_learner.sh + "args": [ + + "--learner", // REQUIRED: Indicates this is a learner instance + "--env", "PandaReachCube-v0", // Environment to use + "--exp_name", "PandaReachCube-v0_state_sim_3d_parallel_const_3^1_batch_256_replay_1M_utd_8", // Experiment name for wandb logging + "--seed", "0", // Random seed for reproducibility + "--random_steps", "1_000", // Number of random steps at beginning + "--max_steps", "50_000 ", // Maximum training steps + "--training_starts", "1_000", // Start training after buffer has this many samples + "--critic_actor_ratio", "8", // Critic-to-actor update ratio + 
"--batch_size", "256", // Training batch size + "--replay_buffer_capacity", "1_000_000", // Replay buffer capacity + + // Fractal + "--replay_buffer_type", "fractal_symmetry_replay_buffer_parallel", // Use fractal symmetry replay buffer + "--branch_method", "constant", + "--split_method", "constant", + "--starting_branch_count", "3", // Start with 27 branches + "--workspace_width", "0.5", + + // Demonstration data loading options + // "--load_demos", // Load demo dataset + // "--demo_dir", "/data/data/serl/demos", + // "--file_name", "data_franka_reach_random_5_2.npz", // Name of the demo file to load + + // Dissasociated Fractals + // "--branch_method", "disassociated", // branch method type + // "--split_method", "disassociated", // split method type + // "--disassociated_type", "octahedron", // Type of disassociated test to perform + // "--min_branch_count", "3", // Minimum branch count for disassociated testing + // "--max_branch_count", "9", // Maximum branch count for disassociated testing + // "--num_depth_sectors", "13", // Desired number of sectors to divide rollouts into for branch count splitting + "--alpha", "1" // alpha + + ], + "console": "integratedTerminal", // Use integrated terminal to see output + "justMyCode": false, // Allow stepping into libraries + // Environment variables from run_learner.sh + "env": { + "XLA_PYTHON_CLIENT_PREALLOCATE": "false", // Don't preallocate JAX/XLA memory + "XLA_PYTHON_CLIENT_MEM_FRACTION": ".5" // Limit JAX memory usage to 50% + }, + // Additional helpful debugging options + "showReturnValue": true, // Show function return values + "purpose": ["debug-in-terminal"], // Run in terminal for better output + "cwd": "${workspaceFolder}" // Set working directory explicitly + }, + // ACTOR + { + // Configuration for the RL agent actor component + "name": "Python: Actor", + "type": "debugpy", + "request": "launch", + "program": "${workspaceFolder}/async_sac_state_sim.py", + // Command-line arguments matching run_actor.sh + 
"args": [ + "--actor", // REQUIRED: Indicates this is an actor instance + "--env", "PandaReachCube-v0", // Environment to use + "--exp_name", "PandaReachCube-v0_state_sim_3d_parallel_const_3^1_batch_256_replay_1M_utd_8", // Experiment name for wandb logging + "--seed", "0", // Random seed for reproducibility + "--random_steps", "1_000", // Number of random steps at beginning + "--max_steps", "500_000 ", // Maximum training steps + "--training_starts", "1_000", // Start training after buffer has this many samples + "--critic_actor_ratio", "8", // Critic-to-actor update ratio + "--batch_size", "256", // Training batch size + "--replay_buffer_capacity", "1_000_000", // Replay buffer capacity + + // Fractal + "--replay_buffer_type", "fractal_symmetry_replay_buffer_parallel", // Use fractal symmetry replay buffer + "--branch_method", "constant", + "--split_method", "constant", + "--starting_branch_count", "3", // Start with 27 branches + "--workspace_width", "0.5", + + // Demonstration data loading options + // "--load_demos", // Load demo dataset + // "--demo_dir", "/data/data/serl/demos", + // "--file_name", "data_franka_reach_random_5_2.npz", // Name of the demo file to load + + // Dissasociated Fractals + // "--branch_method", "disassociated", // branch method type + // "--split_method", "disassociated", // split method type + // "--disassociated_type", "octahedron", // Type of disassociated test to perform + // "--min_branch_count", "3", // Minimum branch count for disassociated testing + // "--max_branch_count", "9", // Maximum branch count for disassociated testing + // "--num_depth_sectors", "13", // Desired number of sectors to divide rollouts into for branch count splitting + "--alpha", "1" // alpha + + ], + "console": "integratedTerminal", // Use integrated terminal to see output + "justMyCode": false, // Allow stepping into libraries + // Environment variables from run_actor.sh + "env": { + "XLA_PYTHON_CLIENT_PREALLOCATE": "false", // Don't preallocate 
JAX/XLA memory + "XLA_PYTHON_CLIENT_MEM_FRACTION": ".5" // Limit JAX memory usage to 50% + }, + "showReturnValue": true, // Show function return values + "purpose": ["debug-in-terminal"], // Run in terminal for better output + "cwd": "${workspaceFolder}" // Set working directory explicitly + } + ], + // Compound configurations to launch multiple configurations together + "compounds": [ + { + // Launch both learner and actor at the same time + "name": "Learner + Actor", + "configurations": ["Python: Learner", "Python: Actor"], + // Note: Actor will connect to learner via TrainerClient, assuming localhost IP + // Both will run in separate debug sessions with independent controls + } + ], + + /* + DEBUGGING TIPS: + + - IMPORTANT: You must select either Learner or Actor configuration when debugging + (the NotImplementedError occurs if neither --learner nor --actor flag is specified) + + - If you get an error about 'utd_ratio', use the "Python: Learner (Fix utd_ratio)" configuration + which adds this missing parameter + + - Set breakpoints in learner() or actor() functions to step through the main training loops + + - Key places to set breakpoints: + * In actor(): near the action sampling logic (step < FLAGS.random_steps) + * In learner(): where agent.update_high_utd() is called + * Server/client communication points (client.update(), server.publish_network()) + + - For memory issues: Watch replay buffer size growth with breakpoints in data_store.insert() + + - JAX issues: Set breakpoints after jax.device_put() calls to ensure proper device placement + + - The "utd_ratio" parameter seems to be used in the learner function but isn't defined + in the FLAGS. Use the special configuration or add a --utd_ratio flag to fix. + + DEBUG WORKFLOW: + + 1. Start with "Python: Learner (Fix utd_ratio)" configuration + 2. Set breakpoints at key sections you want to monitor + 3. Run the debugger and observe variable values at each step + 4. 
Once learner is properly running, launch the Actor in a separate instance + 5. Watch for communication between the two + */ +} \ No newline at end of file diff --git a/examples/async_sac_state_sim/async_sac_state_sim.py b/examples/async_sac_state_sim/async_sac_state_sim.py index 90a1acf8..ad55f04a 100644 --- a/examples/async_sac_state_sim/async_sac_state_sim.py +++ b/examples/async_sac_state_sim/async_sac_state_sim.py @@ -27,16 +27,23 @@ import franka_sim +from demos.demoHandling import DemoHandling + FLAGS = flags.FLAGS flags.DEFINE_string("env", "HalfCheetah-v4", "Name of environment.") flags.DEFINE_string("agent", "sac", "Name of agent.") flags.DEFINE_string("exp_name", None, "Name of the experiment for wandb logging.") +flags.DEFINE_string("run_name", None, "Name of run for wandb logging") flags.DEFINE_integer("max_traj_length", 100, "Maximum length of trajectory.") flags.DEFINE_integer("seed", 42, "Random seed.") flags.DEFINE_bool("save_model", False, "Whether to save model.") flags.DEFINE_integer("batch_size", 256, "Batch size.") flags.DEFINE_integer("critic_actor_ratio", 8, "critic to actor update ratio.") +flags.DEFINE_integer("port_number", 5488, "Port for server") +flags.DEFINE_integer("broadcast_port", 5489, "Port for server") +flags.DEFINE_boolean("wandb_offline", False, "Save locally to be synced with 'wandb sync ") +flags.DEFINE_string("wandb_output_dir", None, "Where to save local wandb files") flags.DEFINE_integer("max_steps", 1000000, "Maximum number of training steps.") flags.DEFINE_integer("replay_buffer_capacity", 1000000, "Replay buffer capacity.") @@ -49,7 +56,7 @@ flags.DEFINE_integer("eval_period", 2000, "Evaluation period.") flags.DEFINE_integer("eval_n_trajs", 5, "Number of trajectories for evaluation.") -# flag to indicate if this is a leaner or a actor +# flag to indicate if this is a learner or a actor flags.DEFINE_boolean("learner", False, "Is this a learner or a trainer.") flags.DEFINE_boolean("actor", False, "Is this a learner or a 
trainer.") flags.DEFINE_boolean("render", False, "Render the environment.") @@ -57,14 +64,34 @@ flags.DEFINE_integer("checkpoint_period", 0, "Period to save checkpoints.") flags.DEFINE_string("checkpoint_path", None, "Path to save checkpoints.") -flags.DEFINE_boolean( - "debug", False, "Debug mode." -) # debug mode will disable wandb logging - +# flags for replay buffer +flags.DEFINE_string("replay_buffer_type", "replay_buffer", "Which replay buffer to use") +flags.DEFINE_string("branch_method", None, "Method for how many branches to generate") +flags.DEFINE_string("split_method", None, "Method for when to change number of branches generated") +flags.DEFINE_float("workspace_width", 0.5, "Workspace width in meters") +flags.DEFINE_integer("max_depth",None,"Maximum layers of depth") +flags.DEFINE_integer("starting_branch_count", None, "Initial number of branches") +flags.DEFINE_integer("branching_factor", None, "Rate of change of branches per dimension (x,y)") # For fractal_branch and fractal_contraction +flags.DEFINE_float("alpha",None,"alpha value") +flags.DEFINE_enum("disassociated_type", None, ["octahedron", "hourglass"], + "Type of disassociated fracal rollout. 
Octahedron: expand from min to max then contract to min," + + " Hourglass: Contract from max to min then expand to max") +flags.DEFINE_integer("min_branch_count", None, "Minimum number of branches for disassociated fractal rollout") +flags.DEFINE_integer("max_branch_count", None, "Maximum number of branches for disassociated fractal rollout") + +# Debug +flags.DEFINE_boolean("debug", False, "Debug mode.") # debug mode will disable wandb logging + +# Logging flags.DEFINE_string("log_rlds_path", None, "Path to save RLDS logs.") flags.DEFINE_string("preload_rlds_path", None, "Path to preload RLDS data.") +# Load demonstation data +flags.DEFINE_boolean("load_demos", False, "Whether to load demo dataset.") +flags.DEFINE_string("demo_dir", "/data/data/serl/demos", "Path to demo dataset.") +flags.DEFINE_string("file_name", "data_franka_reach_random_20.npz", "Name of the demo file to load.") + def print_green(x): return print("\033[92m {}\033[00m".format(x)) @@ -72,14 +99,14 @@ def print_green(x): ############################################################################## -def actor(agent: SACAgent, data_store, env, sampling_rng): +def actor(agent: SACAgent, data_store, env, sampling_rng, demos_handler=None): """ This is the actor loop, which runs when "--actor" is set to True. """ client = TrainerClient( "actor_env", FLAGS.ip, - make_trainer_config(), + make_trainer_config(port_number=FLAGS.port_number, broadcast_port=FLAGS.broadcast_port), data_store, wait_for_server=True, ) @@ -92,8 +119,8 @@ def update_params(params): client.recv_network_callback(update_params) eval_env = gym.make(FLAGS.env) - if FLAGS.env == "PandaPickCube-v0": - eval_env = gym.wrappers.FlattenObservation(eval_env) + #if FLAGS.env == "PandaPickCube-v0": + eval_env = gym.wrappers.FlattenObservation(eval_env) ## Note!! 
eval_env = RecordEpisodeStatistics(eval_env) obs, _ = env.reset() @@ -102,6 +129,16 @@ def update_params(params): # training loop timer = Timer() running_return = 0.0 + + # Load demos: handler.run will insert all transition demo data into the data store. + if FLAGS.load_demos: + with timer.context("sample and step into env with loaded demos"): + + # Insert complete demonstration into the data store + print(f"Inserting {demos_handler.data['transition_ctr']} transitions into the data store.") + demos_handler.insert_data_to_buffer(data_store) + FLAGS.random_steps = 0 # Set random steps to 0 since we have demo data + # For subsequent steps, sample actions from the agent for step in tqdm.tqdm(range(FLAGS.max_steps), dynamic_ncols=True): timer.tick("total") @@ -117,30 +154,30 @@ def update_params(params): ) actions = np.asarray(jax.device_get(actions)) - # Step environment - with timer.context("step_env"): + # Step environment + with timer.context("step_env"): - next_obs, reward, done, truncated, info = env.step(actions) - next_obs = np.asarray(next_obs, dtype=np.float32) - reward = np.asarray(reward, dtype=np.float32) + next_obs, reward, done, truncated, info = env.step(actions) + next_obs = np.asarray(next_obs, dtype=np.float32) + reward = np.asarray(reward, dtype=np.float32) - running_return += reward + running_return += reward - data_store.insert( - dict( - observations=obs, - actions=actions, - next_observations=next_obs, - rewards=reward, - masks=1.0 - done, - dones=done or truncated, + data_store.insert( + dict( + observations=obs, + actions=actions, + next_observations=next_obs, + rewards=reward, + masks=1.0 - done, + dones=done or truncated, + ) ) - ) - obs = next_obs - if done or truncated: - running_return = 0.0 - obs, _ = env.reset() + obs = next_obs + if done or truncated: + running_return = 0.0 + obs, _ = env.reset() if FLAGS.render: env.render() @@ -167,21 +204,23 @@ def update_params(params): 
############################################################################## - + def learner(rng, agent: SACAgent, replay_buffer, replay_iterator): """ The learner loop, which runs when "--learner" is set to True. """ # set up wandb and logging wandb_logger = make_wandb_logger( - project="serl_dev", + project=FLAGS.exp_name, + name=FLAGS.run_name, description=FLAGS.exp_name or FLAGS.env, + # wandb_output_dir=FLAGS.wandb_output_dir, debug=FLAGS.debug, + # offline=FLAGS.wandb_offline, ) # To track the step in the training loop update_steps = 0 - def stats_callback(type: str, payload: dict) -> dict: """Callback for when server receives stats request.""" assert type == "send-stats", f"Invalid request type: {type}" @@ -190,7 +229,7 @@ def stats_callback(type: str, payload: dict) -> dict: return {} # not expecting a response # Create server - server = TrainerServer(make_trainer_config(), request_callback=stats_callback) + server = TrainerServer(make_trainer_config(port_number=FLAGS.port_number, broadcast_port=FLAGS.broadcast_port), request_callback=stats_callback) server.register_data_store("actor_env", replay_buffer) server.start(threaded=True) @@ -228,7 +267,7 @@ def stats_callback(type: str, payload: dict) -> dict: batch = next(replay_iterator) with timer.context("train"): - agent, update_info = agent.update_high_utd(batch, utd_ratio=FLAGS.utd_ratio) + agent, update_info = agent.update_high_utd(batch, utd_ratio=FLAGS.critic_actor_ratio) agent = jax.block_until_ready(agent) # publish the updated network @@ -265,9 +304,14 @@ def main(_): env = gym.make(FLAGS.env, render_mode="human") else: env = gym.make(FLAGS.env) - - if FLAGS.env == "PandaPickCube-v0": - env = gym.wrappers.FlattenObservation(env) + + if FLAGS.env in {"PandaPickCube-v0", "PandaReachCube-v0", "PandaPickSparseCube-v0", "PandaReachSparseCube-v0"}: + x_obs_idx=np.array([0,4]) + y_obs_idx=np.array([1,5]) + else: + raise NotImplementedError(f"Unknown observation layout for {FLAGS.env}") + + env = 
gym.wrappers.FlattenObservation(env) rng, sampling_rng = jax.random.split(rng) agent: SACAgent = make_sac_agent( @@ -279,17 +323,56 @@ def main(_): # replicate agent across devices # need the jnp.array to avoid a bug where device_put doesn't recognize primitives agent: SACAgent = jax.device_put( - jax.tree_map(jnp.array, agent), sharding.replicate() + jax.tree.map(jnp.array, agent), sharding.replicate() ) + # Demo Data + if FLAGS.load_demos: + print_green("Setting demo parameters") + # Create a handler for the demo data + demos_handler = DemoHandling( + demo_dir=FLAGS.demo_dir, + file_name=FLAGS.file_name, + ) + + # 1. Modify actor data_store size + # Extract number of demo transitions + demo_transitions = demos_handler.get_num_transitions() + + if demo_transitions > 2000: + qds_size = demo_transitions + 1000 # Increment the queue size on the actor + else: + qds_size = 2000 # the original queue size on the actor + + # 2. Modify training starts (since we have good data) + FLAGS.training_starts = 1 + + else: + demos_handler = None + qds_size = 2000 # the original queue size on the actor + + if FLAGS.learner: sampling_rng = jax.device_put(sampling_rng, device=sharding.replicate()) replay_buffer = make_replay_buffer( env, capacity=FLAGS.replay_buffer_capacity, rlds_logger_path=FLAGS.log_rlds_path, - type="replay_buffer", + type=FLAGS.replay_buffer_type, + branch_method=FLAGS.branch_method, + split_method=FLAGS.split_method, + branching_factor=FLAGS.branching_factor, + starting_branch_count=FLAGS.starting_branch_count, + workspace_width=FLAGS.workspace_width, + max_traj_length=FLAGS.max_traj_length, + x_obs_idx=x_obs_idx, + y_obs_idx=y_obs_idx, preload_rlds_path=FLAGS.preload_rlds_path, + max_depth=FLAGS.max_depth, + alpha=FLAGS.alpha, + disassociated_type=FLAGS.disassociated_type, + min_branch_count=FLAGS.min_branch_count, + max_branch_count=FLAGS.max_branch_count, ) replay_iterator = replay_buffer.get_iterator( sample_args={ @@ -308,11 +391,20 @@ def main(_): elif 
FLAGS.actor: sampling_rng = jax.device_put(sampling_rng, sharding.replicate()) - data_store = QueuedDataStore(2000) # the queue size on the actor + + if FLAGS.load_demos: + print_green("loading demo data") + + # Create a data store for the actor + data_store = QueuedDataStore(qds_size) # the queue size on the actor + else: + print_green("no demo data, using empty data store") + # Create a data store for the actor + data_store = QueuedDataStore(2000) # the queue size on the actor # actor loop print_green("starting actor loop") - actor(agent, data_store, env, sampling_rng) + actor(agent, data_store, env, sampling_rng, demos_handler) else: raise NotImplementedError("Must be either a learner or an actor") diff --git a/examples/async_sac_state_sim/automated_tests.sh b/examples/async_sac_state_sim/automated_tests.sh new file mode 100644 index 00000000..30766bef --- /dev/null +++ b/examples/async_sac_state_sim/automated_tests.sh @@ -0,0 +1,130 @@ +#!/bin/bash + +SEEDS=$1 +WANDB_OUTPUT_DIR=~/wandb_logs +TEST="async_sac_state_sim.py" +CONDA_ENV="serl" +ENV="PandaReachCube-v0" +MAX_STEPS=25000 +TRAINING_STARTS=1000 +RANDOM_STEPS=1000 +CRITIC_ACTOR_RATIO=8 +EXP_NAME="GENERAL-RETESTING-$ENV" +REPLAY_BUFFER_TYPE="fractal_symmetry_replay_buffer" + +BASE_ARGS="--env $ENV --exp_name $EXP_NAME --wandb_output_dir $WANDB_OUTPUT_DIR --training_starts $TRAINING_STARTS --random_steps $RANDOM_STEPS" +ARGS="" + +function run_test { + + for seed in $(seq 1 1 $SEEDS) + do + # OPEN_PORTS=$( comm -23 <(seq 49152 65535 | sort) <(ss -Htan | awk '{print $4}' | cut -d':' -f2 | sort -u) | shuf | head -n 2 ) + # PORTS=( $OPEN_PORTS ) + # PORT_NUMBER=${PORTS[0]} + # BROADCAST_PORT=${PORTS[1]} + + # ARGS+=" --port_number $PORT_NUMBER --broadcast_port $BROADCAST_PORT" + + echo "Running constant with args: $ARGS" + tmux respawn-pane -k -t serl_session:0.1 + tmux respawn-pane -k -t serl_session:0.2 + tmux send-keys -t serl_session:0.1 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --actor 
--max_steps 2000000000 --seed $seed $BASE_ARGS $ARGS" C-m + tmux send-keys -t serl_session:0.2 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --learner --max_steps $MAX_STEPS --seed $seed $BASE_ARGS $ARGS" C-m "exit" C-m + + # Wait for learner to finish + while ! tmux capture-pane -t serl_session:0.2 -p | grep "logout" > /dev/null; + do + sleep 1 + done + echo "Finished!" + done +} + +# BASELINE TESTING +for CRITIC_ACTOR_RATIO in 8 +do + for batch_size in 256 2048 + do + for replay_buffer_capacity in 1000 1000000 + do + ARGS="--run_name baseline --critic_actor_ratio $CRITIC_ACTOR_RATIO --replay_buffer_type replay_buffer --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity" + run_test + done + done +done + +# CONSTANT TESTING +for CRITIC_ACTOR_RATIO in 8 +do + for starting_branch_count in 2 8 64 + do + for batch_size in 256 + do + for workspace_width in 10 1 .1 + do + for replay_buffer_capacity in $((1000 * $starting_branch_count * $starting_branch_count)) $((1000000 * $starting_branch_count * $starting_branch_count)) + do + ARGS="--run_name constant-$starting_branch_count^1 --steps_per_update $steps_per_update --critic_actor_ratio $CRITIC_ACTOR_RATIO --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --workspace_width $workspace_width --branch_method 'constant' --starting_branch_count $starting_branch_count" + run_test + done + done + done + done +done + +# # FRACTAL TESTING +# for batch_size in 256 +# do +# for replay_buffer_capacity in 1000000 +# do +# for workspace_width in 0.5 +# do +# for alpha in 0.9 +# do +# for branching_factor in 3 9 +# do +# for max_depth in 2 4 +# do +# # Fractal Expansion +# ARGS="--run_name fractal_expansion-$branching_factor^$max_depth-alpha-$alpha-workspace_width-$workspace_width-batch-size-$batch_size-capacity-$replay_buffer_capacity --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity 
$replay_buffer_capacity --workspace_width $workspace_width --branch_method 'fractal' --alpha $alpha --branching_factor $branching_factor --max_depth $max_depth" +# run_test + +# # Fractal Contraction +# ARGS="--run_name fractal_contraction-$branching_factor^$max_depth-alpha-$alpha-workspace_width-$workspace_width-batch-size-$batch_size-capacity-$replay_buffer_capacity --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --workspace_width $workspace_width --branch_method 'contraction' --alpha $alpha --branching_factor $branching_factor --max_depth $max_depth" +# run_test +# done +# done +# done +# done +# done +# done + +# # DISASSOCIATIVE TESTING +# for batch_size in 256 +# do +# for replay_buffer_capacity in 1000000 +# do +# for workspace_width in 0.5 +# do +# for alpha in 0.9 +# do +# for min_branch_count in 1 3 9 +# do +# for max_branch_count in 3 9 27 +# do +# # Disassociative (Hourglass) +# ARGS="--run_name disassociative-hourglass-$min_branch_count:$max_branch_count-alpha-$alpha-workspace_width-$workspace_width-batch-size-$batch_size-capacity-$replay_buffer_capacity --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --workspace_width $workspace_width --branch_method 'disassociated' --min_branch_count $min_branch_count --max_branch_count $max_branch_count --disassociated_type 'hourglass' --alpha $alpha" +# run_test + +# # Disassociative (Octahedron) +# ARGS="--run_name disassociative-hourglass-$min_branch_count:$max_branch_count-alpha-$alpha-workspace_width-$workspace_width-batch-size-$batch_size-capacity-$replay_buffer_capacity --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --workspace_width $workspace_width --branch_method 'disassociated' --min_branch_count $min_branch_count --max_branch_count $max_branch_count --disassociated_type 'octahedron' --alpha $alpha" + +# 
done +# done +# done +# done +# done +# done + +tmux kill-window -t serl_session:$SEED diff --git a/examples/async_sac_state_sim/automated_tests_helper.sh b/examples/async_sac_state_sim/automated_tests_helper.sh new file mode 100644 index 00000000..90702a2f --- /dev/null +++ b/examples/async_sac_state_sim/automated_tests_helper.sh @@ -0,0 +1,6 @@ +#!/bin/bash + +export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ +export XLA_PYTHON_CLIENT_MEM_FRACTION=.5 && \ + +python async_sac_state_sim.py "$@" \ No newline at end of file diff --git a/examples/async_sac_state_sim/run_actor.sh b/examples/async_sac_state_sim/run_actor.sh index 57677916..35df30e6 100644 --- a/examples/async_sac_state_sim/run_actor.sh +++ b/examples/async_sac_state_sim/run_actor.sh @@ -1,10 +1,44 @@ +#!/bin/bash export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ -export XLA_PYTHON_CLIENT_MEM_FRACTION=.05 && \ -python async_sac_state_sim.py "$@" \ +export XLA_PYTHON_CLIENT_MEM_FRACTION=.5 && \ +export SCRIPT_DIR=$(dirname "$(realpath "$0")") && \ +export TIMESTAMP=$(date +"%m-%d-%Y-%H-%M-%S") && \ +export CHECKPOINT_DIR="/data/fsrb_testing/checkpoints-$TIMESTAMP" && \ + +# Create checkpoint directory if it doesn't exist +# if [ ! -d "$CHECKPOINT_DIR" ]; then +# echo "Creating checkpoint directory: $CHECKPOINT_DIR" +# mkdir -p "$CHECKPOINT_DIR" || { +# echo "Failed to create checkpoint directory!" 
>&2 +# exit 1 +# } +# fi + +python async_sac_state_sim.py \ --actor \ - --render \ - --env PandaPickCube-v0 \ - --exp_name=serl_dev_sim_test \ - --seed 0 \ + --env PandaReachCube-v0 \ + --exp_name this_is_a_fake_test_experiment \ + --run_name this_is_a_custom_run_name \ + --replay_buffer_type fractal_symmetry_replay_buffer \ + --max_steps 50_000 \ + --training_starts 1000 \ --random_steps 1000 \ - --debug + --critic_actor_ratio 8 \ + --batch_size 256 \ + --replay_buffer_capacity 1_000_000 \ + --save_model True \ + --branch_method constant \ + --split_method constant \ + --starting_branch_count 3 \ + --workspace_width 0.5 \ + --alpha 1 \ + # --debug # wandb is disabled when debug + # --load_demos \ + # --demo_dir /data/data/serl/demos \ + # --file_name data_franka_reach_random_5_2.npz \ + # --max_traj_length 100 \ + # --max_depth 4 \ + # --branching_factor 3 \ + # --checkpoint_period 10000 \ + # --checkpoint_path "$CHECKPOINT_DIR" \ + #--render \ No newline at end of file diff --git a/examples/async_sac_state_sim/run_learner.sh b/examples/async_sac_state_sim/run_learner.sh index 10a203c1..7415620f 100644 --- a/examples/async_sac_state_sim/run_learner.sh +++ b/examples/async_sac_state_sim/run_learner.sh @@ -1,11 +1,43 @@ +#!/bin/bash export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ -export XLA_PYTHON_CLIENT_MEM_FRACTION=.05 && \ -python async_sac_state_sim.py "$@" \ +export XLA_PYTHON_CLIENT_MEM_FRACTION=.5 && \ +export SCRIPT_DIR=$(dirname "$(realpath "$0")") && \ +export TIMESTAMP=$(date +"%m-%d-%Y-%H-%M-%S") && \ +export CHECKPOINT_DIR="/data/fsrb_testing/checkpoints-$TIMESTAMP" && \ + +# Create checkpoint directory if it doesn't exist +# if [ ! -d "$CHECKPOINT_DIR" ]; then +# echo "Creating checkpoint directory: $CHECKPOINT_DIR" +# mkdir -p "$CHECKPOINT_DIR" || { +# echo "Failed to create checkpoint directory!" 
>&2 +# exit 1 +# } +# fi + +python async_sac_state_sim.py "$@"\ --learner \ - --env PandaPickCube-v0 \ - --exp_name=serl_dev_sim_test \ - --seed 0 \ + --env PandaReachCube-v0 \ + --exp_name this_is_a_fake_test_experiment \ + --run_name this_is_a_custom_run_name \ + --replay_buffer_type fractal_symmetry_replay_buffer \ + --max_steps 50_000 \ --training_starts 1000 \ + --random_steps 1000 \ --critic_actor_ratio 8 \ --batch_size 256 \ - --debug # wandb is disabled when debug + --replay_buffer_capacity 1_000_000 \ + --save_model True \ + --branch_method constant \ + --split_method constant \ + --starting_branch_count 3 \ + --workspace_width 0.5 \ + --alpha 1 \ + # --debug # wandb is disabled when debug + # --load_demos \ + # --demo_dir /data/data/serl/demos \ + # --file_name data_franka_reach_random_5_2.npz \ + # --max_traj_length 100 \ + # --max_depth 4 \ + # --branching_factor 3 \ + # --checkpoint_period 10000 \ + # --checkpoint_path "$CHECKPOINT_DIR" \ \ No newline at end of file diff --git a/examples/async_sac_state_sim/tmux_launch.sh b/examples/async_sac_state_sim/tmux_launch.sh index 78ff94a8..dff31d54 100644 --- a/examples/async_sac_state_sim/tmux_launch.sh +++ b/examples/async_sac_state_sim/tmux_launch.sh @@ -1,9 +1,9 @@ #!/bin/bash -EXAMPLE_DIR=${EXAMPLE_DIR:-"examples/async_sac_state_sim"} +# EXAMPLE_DIR=${EXAMPLE_DIR:-"examples/async_sac_state_sim"} CONDA_ENV=${CONDA_ENV:-"serl"} -cd $EXAMPLE_DIR +# cd $EXAMPLE_DIR echo "Running from $(pwd)" # Create a new tmux session @@ -13,10 +13,10 @@ tmux new-session -d -s serl_session tmux split-window -v # Navigate to the activate the conda environment in the first pane -tmux send-keys -t serl_session:0.0 "conda activate $CONDA_ENV && bash run_actor.sh" C-m +tmux send-keys -t serl_session:0.0 "conda activate $CONDA_ENV && bash run_actor.sh '$@'" C-m # Navigate to the activate the conda environment in the second pane -tmux send-keys -t serl_session:0.1 "conda activate $CONDA_ENV && bash run_learner.sh" C-m +tmux 
send-keys -t serl_session:0.1 "conda activate $CONDA_ENV && bash run_learner.sh '$@'" C-m # Attach to the tmux session tmux attach-session -t serl_session diff --git a/examples/async_sac_state_sim/tmux_launch_tests.sh b/examples/async_sac_state_sim/tmux_launch_tests.sh new file mode 100644 index 00000000..80d47f46 --- /dev/null +++ b/examples/async_sac_state_sim/tmux_launch_tests.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +SEEDS=5 +# Create a new tmux session +tmux new-session -d -s serl_session +tmux setw -g remain-on-exit on + +# Split the window horizontally +tmux split-window -v +tmux split-pane -h -t serl_session:0.1 + +# Navigate to the activate the conda environment in the first pane +tmux send-keys -t serl_session:0.0 "bash automated_tests.sh $SEEDS" C-m + + +# Attach to the tmux session +tmux attach-session -t serl_session + +# kill the tmux session by running the following command +# tmux kill-session -t serl_session diff --git a/franka_sim/README.md b/franka_sim/README.md index 486f0e05..37fef730 100644 --- a/franka_sim/README.md +++ b/franka_sim/README.md @@ -8,7 +8,7 @@ It includes a state-based and a vision-based Franka lift cube task environment. - run `pip install -r requirements.txt` to install sim dependencies. # Explore the Environments -- Run `python franka_sim/test/test_gym_env_human.py` to launch a display window and visualize the task. +- Run `python3 franka_sim/test/test_gym_env_human.py` to launch a display window and visualize the task. # Credits: - This simulation is initially built by [Kevin Zakka](https://kzakka.com/). 
diff --git a/franka_sim/franka_sim/__init__.py b/franka_sim/franka_sim/__init__.py index 967e9e5d..ff6ecc2b 100644 --- a/franka_sim/franka_sim/__init__.py +++ b/franka_sim/franka_sim/__init__.py @@ -5,6 +5,7 @@ "GymRenderingSpec", ] +# Register environments from franka_sim.envs where specific classes are found in the PandaXXX.py scripts from gym.envs.registration import register register( @@ -15,6 +16,17 @@ register( id="PandaPickCubeVision-v0", entry_point="franka_sim.envs:PandaPickCubeGymEnv", - max_episode_steps=100, + max_episode_steps=200, kwargs={"image_obs": True}, ) +register( + id="PandaReachCube-v0", + entry_point="franka_sim.envs:PandaReachCubeGymEnv", + max_episode_steps=100, +) +register( + id="PandaReachSparseCube-v0", + entry_point="franka_sim.envs:PandaReachSparseCubeGymEnv", + max_episode_steps=200, +) + diff --git a/franka_sim/franka_sim/envs/__init__.py b/franka_sim/franka_sim/envs/__init__.py index 50c68828..8315e983 100644 --- a/franka_sim/franka_sim/envs/__init__.py +++ b/franka_sim/franka_sim/envs/__init__.py @@ -1,5 +1,9 @@ from franka_sim.envs.panda_pick_gym_env import PandaPickCubeGymEnv +from franka_sim.envs.panda_reach_gym_env import PandaReachCubeGymEnv +from franka_sim.envs.panda_reach_sparse_gym_env import PandaReachSparseCubeGymEnv __all__ = [ "PandaPickCubeGymEnv", + "PandaReachCubeGymEnv", + "PandaReachSparseCubeGymEnv" ] diff --git a/franka_sim/franka_sim/envs/panda_pick_gym_env.py b/franka_sim/franka_sim/envs/panda_pick_gym_env.py index f0a8d25b..c3d3324e 100644 --- a/franka_sim/franka_sim/envs/panda_pick_gym_env.py +++ b/franka_sim/franka_sim/envs/panda_pick_gym_env.py @@ -144,6 +144,9 @@ def __init__( self._viewer = MujocoRenderer( self.model, self.data, + width=960, + height=960, + camera_id=0 ) self._viewer.render(self.render_mode) @@ -226,7 +229,7 @@ def render(self): rendered_frames = [] for cam_id in self.camera_id: rendered_frames.append( - self._viewer.render(render_mode="rgb_array", camera_id=cam_id) + 
self._viewer.render(render_mode="rgb_array") ) return rendered_frames @@ -291,7 +294,7 @@ def _compute_reward(self) -> float: if __name__ == "__main__": env = PandaPickCubeGymEnv(render_mode="human") env.reset() - for i in range(100): + for i in range(1000): env.step(np.random.uniform(-1, 1, 4)) env.render() env.close() diff --git a/franka_sim/franka_sim/envs/panda_reach_gym_env.py b/franka_sim/franka_sim/envs/panda_reach_gym_env.py new file mode 100644 index 00000000..b1980d7c --- /dev/null +++ b/franka_sim/franka_sim/envs/panda_reach_gym_env.py @@ -0,0 +1,296 @@ +from pathlib import Path +from typing import Any, Literal, Tuple, Dict + +import gym +import mujoco +import numpy as np +from gym import spaces + +try: + import mujoco_py +except ImportError as e: + MUJOCO_PY_IMPORT_ERROR = e +else: + MUJOCO_PY_IMPORT_ERROR = None + +from franka_sim.controllers import opspace +from franka_sim.mujoco_gym_env import GymRenderingSpec, MujocoGymEnv +from gym.envs.registration import register + +_HERE = Path(__file__).parent +_XML_PATH = _HERE / "xmls" / "arena.xml" +_PANDA_HOME = np.asarray((0, -0.785, 0, -2.35, 0, 1.57, np.pi / 4)) +_CARTESIAN_BOUNDS = np.asarray([[0.2, -0.3, 0], [0.6, 0.3, 0.5]]) +_SAMPLING_BOUNDS = np.asarray([[0.25, -0.25], [0.55, 0.25]]) + + +class PandaReachCubeGymEnv(MujocoGymEnv): + metadata = {"render_modes": ["rgb_array", "human"]} + + def __init__( + self, + action_scale: np.ndarray = np.asarray([0.1, 1]), + seed: int = 0, + control_dt: float = 0.02, + physics_dt: float = 0.002, + time_limit: float = 10.0, + render_spec: GymRenderingSpec = GymRenderingSpec(), + render_mode: Literal["rgb_array", "human"] = None, + image_obs: bool = False, + demo: str = "None", + ): + self._action_scale = action_scale + + super().__init__( + xml_path=_XML_PATH, + seed=seed, + control_dt=control_dt, + physics_dt=physics_dt, + time_limit=time_limit, + render_spec=render_spec, + ) + + self.metadata = { + "render_modes": [ + "human", + "rgb_array", + ], + "render_fps": 
int(np.round(1.0 / self.control_dt)), + } + + self.render_mode = render_mode + self.camera_id = (0, 1) + self.image_obs = image_obs + + # Caching. + self._panda_dof_ids = np.asarray( + [self._model.joint(f"joint{i}").id for i in range(1, 8)] + ) + self._panda_ctrl_ids = np.asarray( + [self._model.actuator(f"actuator{i}").id for i in range(1, 8)] + ) + self._gripper_ctrl_id = self._model.actuator("fingers_actuator").id + self._pinch_site_id = self._model.site("pinch").id + self._block_z = self._model.geom("block").size[2] + + self.observation_space = gym.spaces.Dict( + { + "state": gym.spaces.Dict( + { + "panda/tcp_pos": spaces.Box( + -np.inf, np.inf, shape=(3,), dtype=np.float32 + ), + "panda/tcp_vel": spaces.Box( + -np.inf, np.inf, shape=(3,), dtype=np.float32 + ), + "panda/gripper_pos": spaces.Box( + -np.inf, np.inf, shape=(1,), dtype=np.float32 + ), + "block_pos": spaces.Box( + -np.inf, np.inf, shape=(3,), dtype=np.float32 + ), + } + ), + } + ) + + if self.image_obs: + self.observation_space = gym.spaces.Dict( + { + "state": gym.spaces.Dict( + { + "panda/tcp_pos": spaces.Box( + -np.inf, np.inf, shape=(3,), dtype=np.float32 + ), + "panda/tcp_vel": spaces.Box( + -np.inf, np.inf, shape=(3,), dtype=np.float32 + ), + "panda/gripper_pos": spaces.Box( + -np.inf, np.inf, shape=(1,), dtype=np.float32 + ), + } + ), + "images": gym.spaces.Dict( + { + "front": gym.spaces.Box( + low=0, + high=255, + shape=(render_spec.height, render_spec.width, 3), + dtype=np.uint8, + ), + "wrist": gym.spaces.Box( + low=0, + high=255, + shape=(render_spec.height, render_spec.width, 3), + dtype=np.uint8, + ), + } + ), + } + ) + + self.action_space = gym.spaces.Box( + low=np.asarray([-1.0, -1.0, -1.0]), + high=np.asarray([1.0, 1.0, 1.0]), + dtype=np.float32, + ) + + # NOTE: gymnasium is used here since MujocoRenderer is not available in gym. 
It + # is possible to add a similar viewer feature with gym, but that can be a future TODO + from gymnasium.envs.mujoco.mujoco_rendering import MujocoRenderer + + self._viewer = MujocoRenderer( + self.model, + self.data, + width=960, + height=960, + camera_id=0 + ) + if self.render_mode: + self._viewer.render(self.render_mode) + + self.demo = demo + + def reset( + self, seed=None, **kwargs + ) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]: + """Reset the environment.""" + mujoco.mj_resetData(self._model, self._data) + + # Reset arm to home position. + self._data.qpos[self._panda_dof_ids] = _PANDA_HOME + mujoco.mj_forward(self._model, self._data) + + # Reset mocap body to home position. + tcp_pos = self._data.sensor("2f85/pinch_pos").data + self._data.mocap_pos[0] = tcp_pos + + # Sample a new block position. + block_xy = np.random.uniform(*_SAMPLING_BOUNDS) + self._data.jnt("block").qpos[:3] = (*block_xy, self._block_z) + mujoco.mj_forward(self._model, self._data) + + # Cache the initial block height. + # self._z_init = self._data.sensor("block_pos").data[2] + # self._z_success = self._z_init + 0.2 + + obs = self._compute_observation() + return obs, {} + + def step( + self, action: np.ndarray + ) -> Tuple[Dict[str, np.ndarray], float, bool, bool, Dict[str, Any]]: + """ + take a step in the environment. + Params: + action: np.ndarray + + Returns: + observation: dict[str, np.ndarray], + reward: float, + done: bool, + truncated: bool, + info: dict[str, Any] + """ + x, y, z = action + + # Set the mocap position. + pos = self._data.mocap_pos[0].copy() + dpos = np.asarray([x, y, z]) * self._action_scale[0] + npos = np.clip(pos + dpos, *_CARTESIAN_BOUNDS) + self._data.mocap_pos[0] = npos + + # Set gripper grasp. 
+ self._data.ctrl[self._gripper_ctrl_id] = 0 # Fully open position + + for _ in range(self._n_substeps): + tau = opspace( + model=self._model, + data=self._data, + site_id=self._pinch_site_id, + dof_ids=self._panda_dof_ids, + pos=self._data.mocap_pos[0], + ori=self._data.mocap_quat[0], + joint=_PANDA_HOME, + gravity_comp=True, + ) + self._data.ctrl[self._panda_ctrl_ids] = tau + mujoco.mj_step(self._model, self._data) + + obs = self._compute_observation() + rew = self._compute_reward() + + # For demo reach environment, finger --- THIS SHOULD NEVER BE MERGED TO OTHER BRANCHES AFFECTING REGULAR USE OF ENVIRONMENTS. ONLY FOR DEMOS. + # IF ACCIDENTALLY MERGED, IT WILL REDUCE PERFORMANCE OF THE AGENT. + if self.demo == "franka_reach_demo": + if rew >= 0.85: # Demo ERROR_THRESHOLD @ 0.3->0.675; ERROR_THRESHOLD @ 0.2->0.55 + terminated = True + else: + # Check if the time limit is exceeded. + if self._time_limit is not None: + terminated = self.time_limit_exceeded() + else: + terminated = False + + return obs, rew, terminated, False, {} + + def render(self): + rendered_frames = [] + for cam_id in self.camera_id: + rendered_frames.append( + self._viewer.render(render_mode="rgb_array") + ) + return rendered_frames + + # Helper methods. 
+ + def _compute_observation(self) -> dict: + obs = {} + obs["state"] = {} + + tcp_pos = self._data.sensor("2f85/pinch_pos").data + obs["state"]["panda/tcp_pos"] = tcp_pos.astype(np.float32) + + tcp_vel = self._data.sensor("2f85/pinch_vel").data + obs["state"]["panda/tcp_vel"] = tcp_vel.astype(np.float32) + + gripper_pos = np.array( + self._data.ctrl[self._gripper_ctrl_id] / 255, dtype=np.float32 + ) + obs["state"]["panda/gripper_pos"] = gripper_pos + + if self.image_obs: + obs["images"] = {} + obs["images"]["front"], obs["images"]["wrist"] = self.render() + else: + block_pos = self._data.sensor("block_pos").data.astype(np.float32) + obs["state"]["block_pos"] = block_pos + + if self.render_mode == "human": + self._viewer.render(self.render_mode) + + return obs + + def _compute_reward(self) -> float: + # Get positions + block_pos = self._data.sensor("block_pos").data + tcp_pos = self._data.sensor("2f85/pinch_pos").data + + # Calculate distance + dist = np.linalg.norm(block_pos - tcp_pos) + + # Distance-based reward. 
Note at norm of 0.015, reward will be 0.5 + r_close = np.exp(-20 * dist) + r_close = np.clip(r_close, 0.0, 1.0) + + return r_close + + +if __name__ == "__main__": + # Create wrapped environment + env = PandaReachCubeGymEnv(render_mode="human") + env.reset() + for i in range(5000): + env.step(np.random.uniform(-1, 1, 3)) + env.render() + env.close() diff --git a/franka_sim/franka_sim/envs/panda_reach_sparse_gym_env.py b/franka_sim/franka_sim/envs/panda_reach_sparse_gym_env.py new file mode 100644 index 00000000..840be6da --- /dev/null +++ b/franka_sim/franka_sim/envs/panda_reach_sparse_gym_env.py @@ -0,0 +1,304 @@ +from pathlib import Path +from typing import Any, Literal, Tuple, Dict + +import gym +import mujoco +import numpy as np +from gym import spaces + +try: + import mujoco_py +except ImportError as e: + MUJOCO_PY_IMPORT_ERROR = e +else: + MUJOCO_PY_IMPORT_ERROR = None + +from franka_sim.controllers import opspace +from franka_sim.mujoco_gym_env import GymRenderingSpec, MujocoGymEnv +from gym.envs.registration import register + +_HERE = Path(__file__).parent +_XML_PATH = _HERE / "xmls" / "arena.xml" +_PANDA_HOME = np.asarray((0, -0.785, 0, -2.35, 0, 1.57, np.pi / 4)) +_CARTESIAN_BOUNDS = np.asarray([[0.2, -0.3, 0], [0.6, 0.3, 0.5]]) +_SAMPLING_BOUNDS = np.asarray([[0.25, -0.25], [0.55, 0.25]]) + + +class PandaReachSparseCubeGymEnv(MujocoGymEnv): + metadata = {"render_modes": ["rgb_array", "human"]} + + def __init__( + self, + action_scale: np.ndarray = np.asarray([0.1, 1]), + seed: int = 0, + control_dt: float = 0.02, + physics_dt: float = 0.002, + time_limit: float = 10.0, + render_spec: GymRenderingSpec = GymRenderingSpec(), + render_mode: Literal["rgb_array", "human"] = "rgb_array", + image_obs: bool = True, + demo: str = "None", + ): + self._action_scale = action_scale + + super().__init__( + xml_path=_XML_PATH, + seed=seed, + control_dt=control_dt, + physics_dt=physics_dt, + time_limit=time_limit, + render_spec=render_spec, + ) + + self.metadata = { + 
"render_modes": [ + "human", + "rgb_array", + ], + "render_fps": int(np.round(1.0 / self.control_dt)), + } + + self.render_mode = render_mode + self.camera_id = (0, 1) + self.image_obs = image_obs + + # Caching. + self._panda_dof_ids = np.asarray( + [self._model.joint(f"joint{i}").id for i in range(1, 8)] + ) + self._panda_ctrl_ids = np.asarray( + [self._model.actuator(f"actuator{i}").id for i in range(1, 8)] + ) + self._gripper_ctrl_id = self._model.actuator("fingers_actuator").id + self._pinch_site_id = self._model.site("pinch").id + self._block_z = self._model.geom("block").size[2] + + self.observation_space = gym.spaces.Dict( + { + "state": gym.spaces.Dict( + { + "panda/tcp_pos": spaces.Box( + -np.inf, np.inf, shape=(3,), dtype=np.float32 + ), + "panda/tcp_vel": spaces.Box( + -np.inf, np.inf, shape=(3,), dtype=np.float32 + ), + "panda/gripper_pos": spaces.Box( + -np.inf, np.inf, shape=(1,), dtype=np.float32 + ), + "block_pos": spaces.Box( + -np.inf, np.inf, shape=(3,), dtype=np.float32 + ), + } + ), + } + ) + + if self.image_obs: + self.observation_space = gym.spaces.Dict( + { + "state": gym.spaces.Dict( + { + "panda/tcp_pos": spaces.Box( + -np.inf, np.inf, shape=(3,), dtype=np.float32 + ), + "panda/tcp_vel": spaces.Box( + -np.inf, np.inf, shape=(3,), dtype=np.float32 + ), + "panda/gripper_pos": spaces.Box( + -np.inf, np.inf, shape=(1,), dtype=np.float32 + ), + } + ), + "images": gym.spaces.Dict( + { + "front": gym.spaces.Box( + low=0, + high=255, + shape=(render_spec.height, render_spec.width, 3), + dtype=np.uint8, + ), + "wrist": gym.spaces.Box( + low=0, + high=255, + shape=(render_spec.height, render_spec.width, 3), + dtype=np.uint8, + ), + } + ), + } + ) + + self.action_space = gym.spaces.Box( + low=np.asarray([-1.0, -1.0, -1.0, -1.0]), + high=np.asarray([1.0, 1.0, 1.0, 1.0]), + dtype=np.float32, + ) + + # NOTE: gymnasium is used here since MujocoRenderer is not available in gym. 
It + # is possible to add a similar viewer feature with gym, but that can be a future TODO + from gymnasium.envs.mujoco.mujoco_rendering import MujocoRenderer + + self._viewer = MujocoRenderer( + self.model, + self.data, + width=128, + height=128, + camera_id=0 + ) + self._viewer.render(self.render_mode) + + def reset( + self, seed=None, **kwargs + ) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]: + """Reset the environment.""" + mujoco.mj_resetData(self._model, self._data) + + # Reset arm to home position. + self._data.qpos[self._panda_dof_ids] = _PANDA_HOME + mujoco.mj_forward(self._model, self._data) + + # Reset mocap body to home position. + tcp_pos = self._data.sensor("2f85/pinch_pos").data + self._data.mocap_pos[0] = tcp_pos + + # Sample a new block position. + block_xy = np.random.uniform(*_SAMPLING_BOUNDS) + self._data.jnt("block").qpos[:3] = (*block_xy, self._block_z) + mujoco.mj_forward(self._model, self._data) + + # Cache the initial block height. + # self._z_init = self._data.sensor("block_pos").data[2] + # self._z_success = self._z_init + 0.2 + + obs = self._compute_observation() + return obs, {} + + def step( + self, action: np.ndarray + ) -> Tuple[Dict[str, np.ndarray], float, bool, bool, Dict[str, Any]]: + """ + take a step in the environment. + Params: + action: np.ndarray + + Returns: + observation: dict[str, np.ndarray], + reward: float, + done: bool, + truncated: bool, + info: dict[str, Any] + """ + x, y, z, _ = action + + # Set the mocap position. + pos = self._data.mocap_pos[0].copy() + dpos = np.asarray([x, y, z]) * self._action_scale[0] + npos = np.clip(pos + dpos, *_CARTESIAN_BOUNDS) + self._data.mocap_pos[0] = npos + + # Set gripper grasp. + self._data.ctrl[self._gripper_ctrl_id] = 0 # Keep gripper closed. 
+ # g = self._data.ctrl[self._gripper_ctrl_id] / 255 + # dg = grasp * self._action_scale[1] + # ng = np.clip(g + dg, 0.0, 1.0) + # self._data.ctrl[self._gripper_ctrl_id] = ng * 255 + + for _ in range(self._n_substeps): + tau = opspace( + model=self._model, + data=self._data, + site_id=self._pinch_site_id, + dof_ids=self._panda_dof_ids, + pos=self._data.mocap_pos[0], + ori=self._data.mocap_quat[0], + joint=_PANDA_HOME, + gravity_comp=True, + ) + self._data.ctrl[self._panda_ctrl_ids] = tau + mujoco.mj_step(self._model, self._data) + + # Compute observation. + obs = self._compute_observation() + + # For sparse reward we return 1 if the task is achieved, else 0. + rew = self._compute_reward() + + # # For demo reach environment, finger --- THIS SHOULD NEVER BE MERGED TO OTHER BRANCHES AFFECTING REGULAR USE OF ENVIRONMENTS. ONLY FOR DEMOS. + # # IF ACCIDENTALLY MERGED, IT WILL REDUCE PERFORMANCE OF THE AGENT. + # if self.demo == "franka_reach_demo": + # if rew >= 0.85: # Demo ERROR_THRESHOLD @ 0.3->0.675; ERROR_THRESHOLD @ 0.2->0.55 + # terminated = True + # else: + # Check if the time limit is exceeded. + + # Episode is terminated if the reward is 1.0 (i.e. the task is achieved). + terminated = (rew == 1.0) + + if self._time_limit is not None: + terminated = terminated or self.time_limit_exceeded() + + return obs, rew, terminated, False, {} + + def render(self): + rendered_frames = [] + for cam_id in self.camera_id: + rendered_frames.append( + self._viewer.render(render_mode="rgb_array") + ) + return rendered_frames + + # Helper methods. 
+ + def _compute_observation(self) -> dict: + obs = {} + obs["state"] = {} + + tcp_pos = self._data.sensor("2f85/pinch_pos").data + obs["state"]["panda/tcp_pos"] = tcp_pos.astype(np.float32) + + tcp_vel = self._data.sensor("2f85/pinch_vel").data + obs["state"]["panda/tcp_vel"] = tcp_vel.astype(np.float32) + + gripper_pos = np.array( + self._data.ctrl[self._gripper_ctrl_id] / 255, dtype=np.float32 + ) + obs["state"]["panda/gripper_pos"] = gripper_pos + + if self.image_obs: + obs["images"] = {} + obs["images"]["front"], obs["images"]["wrist"] = self.render() + else: + block_pos = self._data.sensor("block_pos").data.astype(np.float32) + obs["state"]["block_pos"] = block_pos + + if self.render_mode == "human": + self._viewer.render(self.render_mode) + + return obs + + def _compute_reward(self) -> float: + # Get positions + block_pos = self._data.sensor("block_pos").data + tcp_pos = self._data.sensor("2f85/pinch_pos").data + + # Calculate distance + dist = np.linalg.norm(block_pos - tcp_pos) + + # Distance-based reward. 
Note at norm of 0.015, reward will be 0.5 + if dist < 0.015: + reward = 1.0 + else: + reward = 0.0 + + return reward + + +if __name__ == "__main__": + # Create wrapped environment + env = PandaReachSparseCubeGymEnv(render_mode="human") + env.reset() + for i in range(5000): + env.step(np.random.uniform(-1, 1, 4)) + env.render() + env.close() diff --git a/franka_sim/franka_sim/envs/xmls/arena.xml b/franka_sim/franka_sim/envs/xmls/arena.xml index e8b69cd9..5c766831 100644 --- a/franka_sim/franka_sim/envs/xmls/arena.xml +++ b/franka_sim/franka_sim/envs/xmls/arena.xml @@ -19,7 +19,7 @@ - + diff --git a/franka_sim/franka_sim/test/test_gym_env_human.py b/franka_sim/franka_sim/test/test_gym_env_human.py index 592bb24e..971a4516 100644 --- a/franka_sim/franka_sim/test/test_gym_env_human.py +++ b/franka_sim/franka_sim/test/test_gym_env_human.py @@ -6,7 +6,7 @@ from franka_sim import envs -env = envs.PandaPickCubeGymEnv(action_scale=(0.1, 1)) +env = envs.PandaPickCubeGymEnv(action_scale=(0.1, 1)) # or render_mode="human") action_spec = env.action_space diff --git a/serl_launcher/requirements.txt b/serl_launcher/requirements.txt index 86f92e95..9b92279f 100644 --- a/serl_launcher/requirements.txt +++ b/serl_launcher/requirements.txt @@ -1,6 +1,7 @@ +opencv-python<4.12.0.88 gym >= 0.26 numpy>=1.24.3 -flax>=0.8.0 +flax>=0.8.0, < 0.10.6 distrax>=0.1.2 ml_collections >= 0.1.0 tqdm >= 4.60.0 @@ -8,7 +9,7 @@ chex>=0.1.85 optax>=0.1.5 orbax-checkpoint>=0.5.10 absl-py >= 0.12.0 -scipy==1.11.4 +scipy>=1.11.4 wandb >= 0.12.14 tensorflow>=2.16.0 tensorflow_probability>=0.24.0 diff --git a/serl_launcher/serl_launcher/agents/continuous/sac.py b/serl_launcher/serl_launcher/agents/continuous/sac.py index ca933db4..dd723924 100644 --- a/serl_launcher/serl_launcher/agents/continuous/sac.py +++ b/serl_launcher/serl_launcher/agents/continuous/sac.py @@ -575,11 +575,11 @@ def scan_body(carry: Tuple[SACAgent], data: Tuple[Batch]): def make_minibatch(data: jnp.ndarray): return jnp.reshape(data, 
(utd_ratio, minibatch_size) + data.shape[1:]) - minibatches = jax.tree_map(make_minibatch, batch) + minibatches = jax.tree.map(make_minibatch, batch) (agent,), critic_infos = jax.lax.scan(scan_body, (self,), (minibatches,)) - critic_infos = jax.tree_map(lambda x: jnp.mean(x, axis=0), critic_infos) + critic_infos = jax.tree.map(lambda x: jnp.mean(x, axis=0), critic_infos) del critic_infos["actor"] del critic_infos["temperature"] diff --git a/serl_launcher/serl_launcher/common/common.py b/serl_launcher/serl_launcher/common/common.py index a3a53e74..1e80fa06 100644 --- a/serl_launcher/serl_launcher/common/common.py +++ b/serl_launcher/serl_launcher/common/common.py @@ -22,7 +22,7 @@ def shard_batch(batch, sharding): batch: A pytree of arrays. sharding: A jax Sharding object with shape (num_devices,). """ - return jax.tree_map( + return jax.tree.map( lambda x: jax.device_put( x, sharding.reshape(sharding.shape[0], *((1,) * (x.ndim - 1))) ), @@ -115,7 +115,7 @@ class JaxRLTrainState(struct.PyTreeNode): @staticmethod def _tx_tree_map(*args, **kwargs): - return jax.tree_map( + return jax.tree.map( *args, is_leaf=lambda x: isinstance(x, optax.GradientTransformation), **kwargs, @@ -128,7 +128,7 @@ def target_update(self, tau: float) -> "JaxRLTrainState": new_target_params = tau * params + (1 - tau) * target_params """ - new_target_params = jax.tree_map( + new_target_params = jax.tree.map( lambda p, tp: p * tau + tp * (1 - tau), self.params, self.target_params ) return self.replace(target_params=new_target_params) @@ -158,7 +158,7 @@ def apply_gradients(self, *, grads: Any) -> "JaxRLTrainState": ) # apply all the updates additively - updates_acc = jax.tree_map( + updates_acc = jax.tree.map( lambda *xs: jnp.sum(jnp.array(xs), axis=0), *updates_flat ) new_params = optax.apply_updates(self.params, updates_acc) @@ -200,7 +200,7 @@ def apply_loss_fns( rngs = jax.tree_util.tree_unflatten(treedef, rngs) # compute gradients - grads_and_aux = jax.tree_map( + grads_and_aux = 
jax.tree.map( lambda loss_fn, rng: jax.grad(loss_fn, has_aux=has_aux)(self.params, rng), loss_fns, rngs, @@ -214,8 +214,8 @@ def apply_loss_fns( grads_and_aux = jax.lax.pmean(grads_and_aux, axis_name=pmap_axis) if has_aux: - grads = jax.tree_map(lambda _, x: x[0], loss_fns, grads_and_aux) - aux = jax.tree_map(lambda _, x: x[1], loss_fns, grads_and_aux) + grads = jax.tree.map(lambda _, x: x[0], loss_fns, grads_and_aux) + aux = jax.tree.map(lambda _, x: x[1], loss_fns, grads_and_aux) return self.apply_gradients(grads=grads), aux else: return self.apply_gradients(grads=grads_and_aux) diff --git a/serl_launcher/serl_launcher/common/wandb.py b/serl_launcher/serl_launcher/common/wandb.py index 2ef341aa..4b27cb9e 100644 --- a/serl_launcher/serl_launcher/common/wandb.py +++ b/serl_launcher/serl_launcher/common/wandb.py @@ -6,6 +6,7 @@ import absl.flags as flags import ml_collections import wandb +import uuid def _recursive_flatten_dict(d: dict): @@ -41,12 +42,13 @@ def __init__( variant, wandb_output_dir=None, debug=False, + offline=False, ): self.config = wandb_config if self.config.unique_identifier == "": self.config.unique_identifier = datetime.datetime.now().strftime( - "%Y%m%d_%H%M%S" - ) + "%m-%d-%Y" + ) + str(uuid.uuid1()) self.config.experiment_id = ( self.experiment_id @@ -65,11 +67,15 @@ def __init__( if debug: mode = "disabled" else: - mode = "online" + if offline: + mode = "offline" + else: + mode = "online" self.run = wandb.init( config=self._variant, project=self.config.project, + name=self.config.name, entity=self.config.entity, group=self.config.group, tags=self.config.tag, diff --git a/serl_launcher/serl_launcher/data/data_store.py b/serl_launcher/serl_launcher/data/data_store.py index 20e65813..a6f074ee 100644 --- a/serl_launcher/serl_launcher/data/data_store.py +++ b/serl_launcher/serl_launcher/data/data_store.py @@ -7,6 +7,9 @@ from serl_launcher.data.memory_efficient_replay_buffer import ( MemoryEfficientReplayBuffer, ) +from 
serl_launcher.data.fractal_symmetry_replay_buffer import ( + FractalSymmetryReplayBuffer +) from agentlace.data.data_store import DataStoreBase @@ -112,8 +115,6 @@ def insert(self, data): RLDSStepType.TRUNCATION, }: self.step_type = RLDSStepType.RESTART - elif self.step_type == RLDSStepType.TRUNCATION: - self.step_type = RLDSStepType.RESTART elif not data["masks"]: # 0 is done, 1 is not done self.step_type = RLDSStepType.TERMINATION elif data["dones"]: @@ -143,6 +144,69 @@ def latest_data_id(self): def get_latest_data(self, from_id: int): raise NotImplementedError # TODO +class FractalSymmetryReplayBufferDataStore(FractalSymmetryReplayBuffer, DataStoreBase): + def __init__( + self, + observation_space: gym.Space, + action_space: gym.Space, + capacity: int, + workspace_width: int, + x_obs_idx, + y_obs_idx, + branch_method: str, + split_method: str, + rlds_logger: Optional[RLDSLogger] = None, + image_keys: Iterable[str] = None, + **kwargs: dict, + ): + FractalSymmetryReplayBuffer.__init__(self, observation_space, action_space, capacity, workspace_width, x_obs_idx, y_obs_idx, branch_method, split_method, image_keys, **kwargs) + DataStoreBase.__init__(self, capacity) + self._lock = Lock() + self._logger = None + + if rlds_logger: + self.step_type = RLDSStepType.TERMINATION # to init the state for restart + self._logger = rlds_logger + + # ensure thread safety + def insert(self, data): + with self._lock: + super(FractalSymmetryReplayBufferDataStore, self).insert(data) + + # TODO: Data logging currently does NOT WORK as shown if we want to log our transformed transitions + # add data to the rlds logger + if self._logger: + if self.step_type in { + RLDSStepType.TERMINATION, + RLDSStepType.TRUNCATION, + }: + self.step_type = RLDSStepType.RESTART + elif not data["masks"]: # 0 is done, 1 is not done + self.step_type = RLDSStepType.TERMINATION + elif data["dones"]: + self.step_type = RLDSStepType.TRUNCATION + else: + self.step_type = RLDSStepType.TRANSITION + + self._logger( 
+ action=data["actions"], + obs=data["next_observations"], # TODO: check if this is correct + reward=data["rewards"], + step_type=self.step_type, + ) + + # ensure thread safety + def sample(self, *args, **kwargs): + with self._lock: + return super(FractalSymmetryReplayBufferDataStore, self).sample(*args, **kwargs) + + # NOTE: method for DataStoreBase + def latest_data_id(self): + return self._insert_index + + # NOTE: method for DataStoreBase + def get_latest_data(self, from_id: int): + raise NotImplementedError # TODO def populate_data_store( data_store: DataStoreBase, @@ -162,7 +226,6 @@ def populate_data_store( print(f"Loaded {len(data_store)} transitions.") return data_store - def populate_data_store_with_z_axis_only( data_store: DataStoreBase, demos_path: str, diff --git a/serl_launcher/serl_launcher/data/dataset.py b/serl_launcher/serl_launcher/data/dataset.py index 0760fb02..2daba136 100644 --- a/serl_launcher/serl_launcher/data/dataset.py +++ b/serl_launcher/serl_launcher/data/dataset.py @@ -1,45 +1,99 @@ -from functools import partial +from functools import partial # unused. from typing import Dict, Iterable, Optional, Tuple, Union import jax import jax.numpy as jnp import numpy as np + +# frozen_dict is an immutable & nested dictionary-like structure to manage parameters and states in NNs. Key in Flax, model params passed explicitly vs stored in mutable objects. from flax.core import frozen_dict from gym.utils import seeding +# Nested data structures where the leaves are NumPy arrays, and internal nodes are dictionaries with string keys. DataType = Union[np.ndarray, Dict[str, "DataType"]] DatasetDict = Dict[str, DataType] +# Utility functions to check lengths and subselect data from a dataset dictionary. def _check_lengths(dataset_dict: DatasetDict, dataset_len: Optional[int] = None) -> int: + """ + Check the lengths of items in a dataset dictionary. 
+
+    Upon initializing a Dataset, _check_lengths is invoked to assert that all data arrays are of equal length. This is critical because:
+    The Dataset assumes a uniform length across features to support indexing, sampling, and batching.
+    Inconsistent lengths would lead to silent errors or runtime failures during sampling, model training, or evaluation.
+
+    If all items are of the same length, return that length.
+    If items are of different lengths, raise an assertion error.
+    If the dataset is empty, return 0.
+    If the dataset is not a dictionary, raise a TypeError.
+    Args:
+        dataset_dict (DatasetDict): The dataset dictionary to check.
+        dataset_len (Optional[int]): The length to compare against, if provided.
+    Returns:
+        int: The length of the dataset if all items are of the same length.
+    Raises:
+        TypeError: If the dataset is not a dictionary or contains unsupported types.
+    """
+
     for v in dataset_dict.values():
         if isinstance(v, dict):
             dataset_len = dataset_len or _check_lengths(v, dataset_len)
+
         elif isinstance(v, np.ndarray):
             item_len = len(v)
             dataset_len = dataset_len or item_len
             assert dataset_len == item_len, "Inconsistent item lengths in the dataset."
+
         else:
             raise TypeError("Unsupported type.")
 
     return dataset_len
 
 
 def _subselect(dataset_dict: DatasetDict, index: np.ndarray) -> DatasetDict:
+    """
+    Subselect enables flexible, consistent indexing into complex datasets to either split or filter data.
+    It is especially important when working with nested dictionary-based datasets — a common structure in modern machine learning.
+    Our dataset will be deeply structured and we want indexing to be applied consistently at every depth.
+    Used by Dataset.split() and Dataset.filter() methods to extract subsets of data based on indices.
+
+    Args:
+        dataset_dict (DatasetDict): The dataset dictionary to subselect from.
+        index (np.ndarray): The indices to select from the dataset.
+    Returns:
+        DatasetDict: A new dataset dictionary with items selected based on the index.
+ Raises: + TypeError: If the dataset contains unsupported types. + """ + + new_dataset_dict = {} for k, v in dataset_dict.items(): if isinstance(v, dict): new_v = _subselect(v, index) + elif isinstance(v, np.ndarray): new_v = v[index] + else: raise TypeError("Unsupported type.") + new_dataset_dict[k] = new_v return new_dataset_dict -def _sample( - dataset_dict: Union[np.ndarray, DatasetDict], indx: np.ndarray -) -> DatasetDict: +def _sample(dataset_dict: Union[np.ndarray, DatasetDict], indx: np.ndarray) -> DatasetDict: + """ + This function is used to extract a subset of data from the dataset dictionary, which can be either a NumPy array or a nested dictionary structure. + Args: + dataset_dict (Union[np.ndarray, DatasetDict]): The dataset dictionary or array to sample from. + indx (np.ndarray): The indices to sample from the dataset. + Returns: + DatasetDict: A new dataset dictionary with items sampled based on the indices. + Raises: + TypeError: If the dataset is not a NumPy array or a dictionary. + """ + if isinstance(dataset_dict, np.ndarray): return dataset_dict[indx] elif isinstance(dataset_dict, dict): @@ -52,7 +106,11 @@ def _sample( class Dataset(object): - def __init__(self, dataset_dict: DatasetDict, seed: Optional[int] = None): + + def __init__(self, + dataset_dict: DatasetDict, + seed: Optional[int] = None ): + self.dataset_dict = dataset_dict self.dataset_len = _check_lengths(dataset_dict) @@ -63,6 +121,7 @@ def __init__(self, dataset_dict: DatasetDict, seed: Optional[int] = None): if seed is not None: self.seed(seed) + # @property decorator is used here to expose np_random as a read-only attribute @property def np_random(self) -> np.random.RandomState: if self._np_random is None: @@ -70,20 +129,43 @@ def np_random(self) -> np.random.RandomState: return self._np_random def seed(self, seed: Optional[int] = None) -> list: + """ + Set the random seed for reproducibility. Ensures a valid RNG is encapsulated behind an attribute-like interface. 
+ Users do not need to call self.seed() or self.get_np_random() explicitly. Just self.np_random above. + + Args: + seed (Optional[int]): The seed to set. If None, a random seed will be generated. + Returns: + list: A list containing the seed used. + """ self._np_random, self._seed = seeding.np_random(seed) return [self._seed] def __len__(self) -> int: return self.dataset_len - def sample( - self, - batch_size: int, - keys: Optional[Iterable[str]] = None, - indx: Optional[np.ndarray] = None, - ) -> frozen_dict.FrozenDict: + def sample(self, + batch_size: int, + keys: Optional[Iterable[str]] = None, + indx: Optional[np.ndarray] = None, + ) -> frozen_dict.FrozenDict: + """ + Sample a random batch of data from the dataset. + + This method allows for flexible sampling of data, either by specifying keys or using random indices. + This is useful for training models, where you might want to sample a batch of data points from a larger dataset. + Args: + batch_size (int): The number of samples to return. + keys (Optional[Iterable[str]]): The keys to sample from the dataset. If None, all keys will be sampled. + indx (Optional[np.ndarray]): Specific indices to sample from. If None, random indices will be generated. + Returns: + frozen_dict.FrozenDict: A frozen dictionary containing the sampled data. + """ + if indx is None: if hasattr(self.np_random, "integers"): + + # Generate batch_size num of rand ints, each sampled uniformly at random from the range [0, len(self) - 1]. indx = self.np_random.integers(len(self), size=batch_size) else: indx = self.np_random.randint(len(self), size=batch_size) @@ -101,7 +183,17 @@ def sample( return frozen_dict.freeze(batch) - def sample_jax(self, batch_size: int, keys: Optional[Iterable[str]] = None): + def sample_jax(self, + batch_size: int, + keys: Optional[Iterable[str]] = None): + """ + Sample a batch of data from the dataset using JAX. This method is optimized for performance and can be used in JAX-based training loops. 
+ Args: + batch_size (int): The number of samples to return. + keys (Optional[Iterable[str]]): The keys to sample from the dataset. If None, all keys will be sampled. + Returns: + Tuple[int, frozen_dict.FrozenDict]: A tuple containing the maximum index sampled and a frozen dictionary with the sampled data. + """ if not hasattr(self, "rng"): self.rng = jax.random.PRNGKey(self._seed or 42) @@ -118,7 +210,7 @@ def _sample_jax(rng, src, max_indx: int): return ( rng, indx.max(), - jax.tree_map(lambda d: jnp.take(d, indx, axis=0), src), + jax.tree.map(lambda d: jnp.take(d, indx, axis=0), src), ) self._sample_jax = _sample_jax @@ -129,12 +221,25 @@ def _sample_jax(rng, src, max_indx: int): return indx_max, sample def split(self, ratio: float) -> Tuple["Dataset", "Dataset"]: + """ + Split the dataset into two parts based on the given ratio. The first part will contain a fraction of the dataset specified by the ratio, and the second part will contain the rest. + This method is useful for creating training and testing datasets, where you want to split the data into two parts for model evaluation. + Args: + ratio (float): The fraction of the dataset to include in the first part. Must be between 0 and 1. + Returns: + Tuple[Dataset, Dataset]: A tuple containing two Dataset objects, the first part and the second part of the split dataset. + Raises: + AssertionError: If the ratio is not between 0 and 1. + """ + assert 0 < ratio and ratio < 1 - train_index = np.index_exp[: int(self.dataset_len * ratio)] - test_index = np.index_exp[int(self.dataset_len * ratio) :] + train_index = np.index_exp[: int(self.dataset_len * ratio)] # First part of the dataset. + test_index = np.index_exp[int(self.dataset_len * ratio) :] # Second part of the dataset. + # Shuffle the indices to ensure random sampling. 
index = np.arange(len(self), dtype=np.int32) self.np_random.shuffle(index) + train_index = index[: int(self.dataset_len * ratio)] test_index = index[int(self.dataset_len * ratio) :] @@ -143,53 +248,102 @@ def split(self, ratio: float) -> Tuple["Dataset", "Dataset"]: return Dataset(train_dataset_dict), Dataset(test_dataset_dict) def _trajectory_boundaries_and_returns(self) -> Tuple[list, list, list]: + """ + This method computes the boundaries of episodes in the dataset and calculates the returns for each episode. + It identifies the start and end indices of each episode based on the 'dones' array in the dataset. + The returns for each episode are calculated by summing the rewards within each episode. + This is useful for reinforcement learning tasks where episodes are defined by sequences of states, actions, and rewards. + Returns: + Tuple[list, list, list]: A tuple containing three lists: + - episode_starts: The starting indices of each episode. + - episode_ends: The ending indices of each episode. + - episode_returns: The total returns for each episode. + """ + + # Initialize lists (note plural) to store episode boundaries and returns. episode_starts = [0] episode_ends = [] + # Initialize variables to track the current episode return and a list to store returns. episode_return = 0 episode_returns = [] + # Iterate through the dataset to find episode boundaries and calculate returns. + # The dataset_dict is expected to have 'rewards' and 'dones' keys. for i in range(len(self)): episode_return += self.dataset_dict["rewards"][i] + # If the current index indicates the end of an episode, store the return and update boundaries. if self.dataset_dict["dones"][i]: episode_returns.append(episode_return) - episode_ends.append(i + 1) + episode_ends.append(i + 1) # Store the end index of the episode including the current index. + + # If this is not the last episode, set the start of the next episode. 
             if i + 1 < len(self):
                 episode_starts.append(i + 1)
                 episode_return = 0.0
 
         return episode_starts, episode_ends, episode_returns
 
-    def filter(
-        self, take_top: Optional[float] = None, threshold: Optional[float] = None
+    def filter(self,
+               take_top: Optional[float] = None,
+               threshold: Optional[float] = None
     ):
+        """
+        Filter the dataset based on episode returns. This method allows you to keep only the episodes that meet a certain return threshold or are among the top returns.
+        This is useful for focusing on high-performing episodes in reinforcement learning tasks.
+        Args:
+            take_top (Optional[float]): If specified, keep only the top N percent of episodes based on their returns.
+            threshold (Optional[float]): If specified, keep only the episodes with returns greater than or equal to this value.
+        Raises:
+            AssertionError: If both take_top and threshold are specified, or if neither is specified.
+        """
         assert (take_top is None and threshold is not None) or (
             take_top is not None and threshold is None
         )
 
+        # Create a tuple of lists of episode boundaries and returns.
         (
             episode_starts,
             episode_ends,
             episode_returns,
         ) = self._trajectory_boundaries_and_returns()
 
+        # If no threshold is specified, calculate it based on the top N percent of returns.
         if take_top is not None:
+            # np.percentile gives the value below which XX% of the data lies
             threshold = np.percentile(episode_returns, 100 - take_top)
 
+        # create a boolean index array to filter episodes based on the threshold.
         bool_indx = np.full((len(self),), False, dtype=bool)
 
         for i in range(len(episode_returns)):
             if episode_returns[i] >= threshold:
                 bool_indx[episode_starts[i] : episode_ends[i]] = True
 
+        # Return a new dataset dictionary containing only the episodes that meet the threshold.
         self.dataset_dict = _subselect(self.dataset_dict, bool_indx)
 
+        # Update the dataset length after filtering.
         self.dataset_len = _check_lengths(self.dataset_dict)
 
     def normalize_returns(self, scaling: float = 1000):
+        """
+        Normalize the returns in the dataset to a specified scaling factor. This is useful for stabilizing training in reinforcement learning tasks.
+        Normally done per batch of episodes before training a model to update the policy.
+
+        Args:
+            scaling (float): The scaling factor to normalize the returns. Default is 1000.
+        Raises:
+            AssertionError: If the dataset does not contain 'rewards' or 'dones' keys.
+        """
+
+        # Extract episode returns
         (_, _, episode_returns) = self._trajectory_boundaries_and_returns()
-        self.dataset_dict["rewards"] /= np.max(episode_returns) - np.min(
-            episode_returns
-        )
+
+        # Normalize rewards in the dataset from 0-1 by dividing by the max range of returns.
+        self.dataset_dict["rewards"] /= np.max(episode_returns) - np.min(episode_returns)
+
+        # Scale rewards. Note that large scaling factors can lead to numerical instability, small scaling factors can lead to poor learning.
+        # Scaling allows you to control the learning dynamics.
self.dataset_dict["rewards"] *= scaling diff --git a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py new file mode 100644 index 00000000..e3ccf68c --- /dev/null +++ b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py @@ -0,0 +1,440 @@ +import copy +from typing import Iterable, Optional + +import gym +import numpy as np +from serl_launcher.data.dataset import DatasetDict, _sample +from serl_launcher.data.replay_buffer import ReplayBuffer +from flax.core import frozen_dict + +class FractalSymmetryReplayBuffer(ReplayBuffer): + def __init__( + self, + observation_space: gym.Space, + action_space: gym.Space, + capacity: int, + workspace_width: int, + x_obs_idx : np.ndarray, + y_obs_idx : np.ndarray, + branch_method: str, + split_method: str, + img_keys: list, + kwargs: dict, + ): + + # Initialize values + self.debug_time = True + self.current_branch_count = 1 + self.update_max_traj_length = False + self.workspace_width = workspace_width + self.img_keys = img_keys + self._img_insert_index_ = 0 + + # Set the idx value (changes depending on environment/wrapper) of the x and y observations and next_observations + self.x_obs_idx = x_obs_idx + self.y_obs_idx = y_obs_idx + + # Set initial fractal config values + self.timestep = 0 + self.current_depth = 0 + + self.split_method = split_method + self.branch_method = branch_method + + self._handle_methods_(kwargs) + + # Warn about unused kwargs + for k in kwargs.keys(): + print(f"\033[33mWARNING \033[0m argument \"{k}\" not used") + + # Account for images + self._num_stack = None + next_observation_space = None + if self.img_keys: + self.img_buffer = {} + next_observation_space_dict = copy.deepcopy(observation_space.spaces) + for k in img_keys: + img_obs_space = observation_space.spaces[k] + if self._num_stack is None: + self._num_stack = img_obs_space.shape[0] + img_buffer_size = ((self.expected_branches + capacity - 1) // 
self.expected_branches) * (self._num_stack + 1) + buffer_shape = list(img_obs_space.shape[1:]) + buffer_shape.insert(0, img_buffer_size) + self.img_buffer[k] = np.empty(buffer_shape, img_obs_space.dtype) + + observation_space.spaces[k] = gym.spaces.Box(low=float('-inf'), high=float('inf'), shape=(), dtype=np.int32) + next_observation_space_dict.pop(k) + next_observation_space = gym.spaces.Dict(next_observation_space_dict) + + + # Init replay buffer class + super().__init__( + observation_space=observation_space, + next_observation_space=next_observation_space, + action_space=action_space, + capacity=capacity, + ) + + self.generate_transform_deltas() + + def _handle_method_arg_(self, value, method_type, method, kwargs): + if hasattr(self, value): + return + assert value in kwargs.keys(), f"\033[31mERROR: \033[0m{value} must be defined for {method_type} \"{method}\"" + setattr(FractalSymmetryReplayBuffer, value, kwargs[value]) + del kwargs[value] + + def _handle_methods_(self, kwargs): + + # Initialize branch_method + match self.branch_method: + case "fractal": + self._handle_method_arg_("max_depth", "branch_method", self.branch_method, kwargs) + self._handle_method_arg_("branching_factor", "branch_method", self.branch_method, kwargs) + + self.branch = self.fractal_branch + if not self.split_method: + self.split_method = "time" + self.expected_branches = (self.branching_factor ** self.max_depth) ** 2 + + case "contraction": + self._handle_method_arg_("max_depth", "branch_method", self.branch_method, kwargs) + self._handle_method_arg_("branching_factor", "branch_method", self.branch_method, kwargs) + + self.branch = self.fractal_contraction + if not self.split_method: + self.split_method = "time" + self.expected_branches = (self.branching_factor ** self.max_depth) ** 2 + + case "linear": + raise NotImplementedError("linear branch method is not yet implemented") + # self.branch = self.linear_branch + + + case "disassociated": + 
self._handle_method_arg_("min_branch_count", "branch_method", self.branch_method, kwargs) + self._handle_method_arg_("max_branch_count", "branch_method", self.branch_method, kwargs) + + if self.min_branch_count > self.max_branch_count: + raise ValueError(f"min_branch_count ({self.min_branch_count}) is larger than max_branch_count ({self.max_branch_count})") + + match kwargs["disassociated_type"]: + case "hourglass": + self.starting_branch_count = self.max_branch_count + case "octahedron": + self.starting_branch_count = self.min_branch_count + case _: + raise ValueError(f"incorrect value passed to disassociated_type") + + self.disassociated_type = kwargs["disassociated_type"] + del kwargs["disassociated_type"] + self.branch = self.disassociated_branch + if not self.split_method: + self.split_method = "time" + self.expected_branches = self.max_branch_count ** 2 + + case "constant": + self._handle_method_arg_("starting_branch_count", "branch_method", self.branch_method, kwargs) + + self.branch = self.constant_branch + if not self.split_method: + self.split_method = "never" + self.expected_branches = self.starting_branch_count ** 2 + + case _: + raise ValueError("incorrect value passed to branch_method") + + match self.split_method: + case "time": + self._handle_method_arg_("max_depth", "split_method", self.split_method, kwargs) + self._handle_method_arg_("max_traj_length", "split_method", self.split_method, kwargs) + self._handle_method_arg_("alpha", "split_method", self.split_method, kwargs) + + self.update_max_traj_length = True + self.split = self.time_split + + case "constant": + self.split = self.constant_split + + case "never": + self.split = self.never_split + + case _: + raise ValueError("incorrect value passed to split_method") + + if hasattr(self, "starting_branch_count"): + self.current_branch_count = self.starting_branch_count + + def generate_transform_deltas(self): + + obs_state = self.dataset_dict["observations"] + if self.img_keys: + obs_state = 
self.dataset_dict["observations"]["state"] + + obs_size = obs_state.shape[-1] + total_branches = self.current_branch_count ** 2 + + self.transform_deltas = np.zeros(shape=(total_branches, obs_size), dtype=np.float32) + + idx = np.arange(total_branches) + x_deltas, y_deltas = np.divmod(idx, self.current_branch_count) + + x_deltas = (2 * x_deltas + 1) * self.workspace_width / (2 * self.current_branch_count) + y_deltas = (2 * y_deltas + 1) * self.workspace_width / (2 * self.current_branch_count) + x_deltas = np.repeat(x_deltas, self.x_obs_idx.size) + y_deltas = np.repeat(y_deltas, self.y_obs_idx.size) + x_deltas = np.reshape(x_deltas, (total_branches, self.x_obs_idx.size)) + y_deltas = np.reshape(y_deltas, (total_branches, self.y_obs_idx.size)) + + self.transform_deltas[..., self.x_obs_idx] = x_deltas + self.transform_deltas[..., self.y_obs_idx] = y_deltas + + if self._num_stack: + self.transform_deltas = np.expand_dims(self.transform_deltas, axis=1) + self.transform_deltas = np.repeat(self.transform_deltas, self._num_stack, axis=1) + + def fractal_branch(self): + ''' + Computes the number of branches for the current depth using an exponential growth rule. + + This method implements a "fractal branching" strategy, where the number of branches + increases exponentially with depth. At each depth `d`, the number of branches is calculated as: + + num_branches = branching_factor ** current_depth + + where: + - branching_factor: The base number of branches at each split. + - current_depth: The current depth in the fractal tree (self.current_depth). + + Returns: + int: The computed number of branches for the current depth. + ''' + # return a new number of branches = branching_factor ^ depth + return self.branching_factor ** self.current_depth + + def fractal_contraction(self): + ''' + Computes the number of branches for the current depth using a contraction rule. 
+
+        This method implements a "fractal contraction" branching strategy, where the number
+        of branches decreases exponentially with depth. At each depth `d`, the number of branches
+        is calculated as:
+
+            num_branches = branching_factor ** (max_depth - d + 1)
+
+        where:
+        - branching_factor: The factor by which the number of branches contracts at each depth.
+        - max_depth: The maximum fractal depth (self.max_depth).
+        - d: The current depth (self.current_depth).
+
+        Returns:
+            int: The computed number of branches for the current depth.
+        '''
+
+        return self.branching_factor ** (self.max_depth - self.current_depth + 1)
+
+    def constant_branch(self):
+        '''
+        Used to create pure translations with no further branching.
+        self.current_branch_count used to set the total number of transformations.
+        '''
+        # return current number of branches
+        return self.current_branch_count
+
+    def disassociated_branch(self):
+        '''
+        Used to create branches for disassociated fractal methods.
+        self.min_branch_count specifies the minimum branch count desired during the fractal rollout
+        self.max_branch_count specifies the maximum branch count desired during the fractal rollout
+        self.disassociated_type specifies whether to expand and then contract or to contract and then expand
+        self.steps_per_depth specifies the number of timesteps to take before splitting
+            (calculated indirectly via self.max_traj_length / self.num_depth_sectors)
+        self.num_depth_sectors specifies the number of sectors the rollout should be divided into for even splitting
+        '''
+        if self.disassociated_type == "hourglass":
+            return int((self.max_branch_count - self.min_branch_count)/(self.max_depth/2) * np.abs(self.current_depth - (self.max_depth/2)) + self.min_branch_count)
+        elif self.disassociated_type == "octahedron":
+            return int((self.min_branch_count - self.max_branch_count)/(self.max_depth/2) * np.abs(self.current_depth - (self.max_depth/2)) + self.max_branch_count)
+
+    def linear_branch(self):
+        # return a new
number of branches = branches_count + n + return self.current_branch_count + self.branching_factor + + def time_split(self, data_dict: DatasetDict): + if self.timestep % (self.max_traj_length//self.max_depth) or self.current_depth >= self.max_depth: + return False + self.current_depth += 1 + return True + + def constant_split(self, data_dict: DatasetDict): + self.current_depth += 1 + return True + + def never_split(self, data_dict: DatasetDict): + return False + + def insert_images(self, observation: dict): + for k in self.img_keys: + if self._num_stack: + self.img_buffer[k][self._img_insert_index_] = observation[k][0, ...] + else: + self.img_buffer[k][self._img_insert_index_] = observation[k] + self._img_insert_index_ += 1 + + def insert(self, data: DatasetDict): + + data_dict = copy.deepcopy(data) + + if self.img_keys: + obs = data_dict["observations"]["state"] + n_obs = data_dict["next_observations"]["state"] + else: + obs = data_dict["observations"] + n_obs = data_dict["next_observations"] + + actions = data_dict["actions"] + rewards = data_dict["rewards"] + masks = data_dict["masks"] + dones = data_dict["dones"] + + # Update number of branches if needed + if self.split(data_dict): + temp = self.current_branch_count + self.current_branch_count = self.branch() + # Update transform_deltas if needed + if temp != self.current_branch_count: + self.generate_transform_deltas() + + # Initialize to extreme x and y + base_diff = -self.workspace_width/2 + obs[..., self.x_obs_idx] += base_diff + obs[..., self.y_obs_idx] += base_diff + n_obs[..., self.x_obs_idx] += base_diff + n_obs[..., self.y_obs_idx] += base_diff + + # Transform transitions + num_transforms = self.current_branch_count ** 2 + + obs_shape = np.ones(len(obs.shape) + 1, dtype=int) + obs_shape[0] = num_transforms + obs = np.tile(obs, obs_shape) + n_obs = np.tile(n_obs, obs_shape) + actions = np.tile(actions, (num_transforms, 1)) + rewards = np.tile(rewards, num_transforms) + masks = np.tile(masks, 
num_transforms) + dones = np.tile(dones, num_transforms) + + obs += self.transform_deltas + n_obs += self.transform_deltas + + # Insert images + if self.img_keys: + if self.timestep == 0: + for i in range(self._num_stack): + self.insert_images(data_dict["observations"]) + self.insert_images(data_dict["next_observations"]) + + for k in self.img_keys: + data_dict["observations"][k] = (self._img_insert_index_ - 1) % len(self.img_buffer[k]) + data_dict["observations"][k] = np.tile(data_dict["observations"][k], num_transforms) + data_dict["next_observations"].pop(k) + + # Pack back into dictionary and insert + if self.img_keys: + data_dict["observations"]["state"] = obs + data_dict["next_observations"]["state"] = n_obs + else: + data_dict["observations"] = obs + data_dict["next_observations"] = n_obs + + data_dict["actions"] = actions + data_dict["rewards"] = rewards + data_dict["masks"] = masks + data_dict["dones"] = dones + + super().insert(data_dict, batch_size=num_transforms) + + # Reset current_depth, timestep, and max_traj_length + self.timestep += 1 + if data_dict["dones"][0]: + self.current_depth = 0 + if self.update_max_traj_length: + self.max_traj_length = int(self.timestep * self.alpha + self.max_traj_length * (1 - self.alpha)) + self.timestep = 0 + + def sample( + self, batch_size: int, keys: Optional[Iterable[str]] = None, indx: Optional[np.ndarray] = None, pack_obs_and_next_obs: bool = False, + ) -> frozen_dict.FrozenDict: + """Samples from the replay buffer. + + Args: + batch_size: Minibatch size. + keys: Keys to sample. + indx: Take indices instead of sampling. + pack_obs_and_next_obs: whether to pack img and next_img into one image. + It's useful when they have overlapping frames. + + Returns: + A frozen dictionary. 
+ """ + # If no images, sample normally + if not self.img_keys: + return super().sample(batch_size, keys, indx) + + # Generate random indexes for sampling + if indx is None: + if hasattr(self.np_random, "integers"): + indx = self.np_random.integers(len(self), size=batch_size) + else: + indx = self.np_random.randint(len(self), size=batch_size) + + for i in range(batch_size): + while indx[i] >= self._size: + if hasattr(self.np_random, "integers"): + indx[i] = self.np_random.integers(len(self)) + else: + indx[i] = self.np_random.randint(len(self)) + else: + raise NotImplementedError() + + # Sample w/o images + if keys is None: + keys = self.dataset_dict.keys() + else: + assert "observations" in keys + + keys = list(keys) + keys.remove("observations") + + batch = super().sample(batch_size, keys, indx) + batch = batch.unfreeze() + + obs_keys = self.dataset_dict["observations"].keys() + obs_keys = list(obs_keys) + for k in self.img_keys: + obs_keys.remove(k) + + batch["observations"] = {} + for k in obs_keys: + batch["observations"][k] = _sample( + self.dataset_dict["observations"][k], indx + ) + + # Sample images + for k in self.img_keys: + obs_imgs = self.img_buffer[k] + obs_imgs = np.lib.stride_tricks.sliding_window_view( + obs_imgs, self._num_stack + 1, axis=0 + ) + obs_imgs = obs_imgs[self.dataset_dict["observations"][k][indx] - self._num_stack] + # transpose from (B, H, W, C, T) to (B, T, H, W, C) to follow jaxrl_m convention + obs_imgs = obs_imgs.transpose((0, 4, 1, 2, 3)) + + if pack_obs_and_next_obs: + batch["observations"][k] = obs_imgs + else: + batch["observations"][k] = obs_imgs[:, :-1, ...] + if "next_observations" in keys: + batch["next_observations"][k] = obs_imgs[:, 1:, ...] 
+ + return frozen_dict.freeze(batch) diff --git a/serl_launcher/serl_launcher/data/fsrb_test.py b/serl_launcher/serl_launcher/data/fsrb_test.py new file mode 100644 index 00000000..f6b9e256 --- /dev/null +++ b/serl_launcher/serl_launcher/data/fsrb_test.py @@ -0,0 +1,306 @@ +import gym.wrappers +import numpy as np +import gym +from serl_launcher.utils.launcher import make_replay_buffer +from serl_launcher.data.fractal_symmetry_replay_buffer import FractalSymmetryReplayBuffer +from serl_launcher.wrappers.serl_obs_wrappers import SERLObsWrapper +from serl_launcher.wrappers.chunking import ChunkingWrapper +from absl import app, flags +import franka_sim +# import pandas as pd + +FLAGS = flags.FLAGS + +flags.DEFINE_integer("capacity", 10, "Replay buffer capacity.") +flags.DEFINE_string("branch_method", "constant", "Method for determining the number of transforms per dimension (x,y)") +flags.DEFINE_string("split_method", "never", "Method for determining whether to change the number of transforms per dimension (x,y)") +flags.DEFINE_float("workspace_width", 0.5, "workspace width in meters") +flags.DEFINE_integer("max_depth", 4, "Maximum level of depth") # For fractal_branch only +flags.DEFINE_integer("max_steps",100,"Maximum steps") +flags.DEFINE_integer("branching_factor", 3, "Rate of change of number of transforms per dimension (x,y)") # For fractal_branch only +flags.DEFINE_integer("starting_branch_count", 1, "Initial number of transforms per dimension (x,y)") # For constant_branch only +flags.DEFINE_integer("alpha",1,"alpha value") +# Density Workspace width +flags.DEFINE_string("workspace_width_method",'increase', 'Controls workspace width dimensions configurations') + +def main(_): + + x_obs_idx = np.array([0, 4]) + y_obs_idx = np.array([1, 5]) + + # Initialize replay buffer + env = gym.make("PandaPickCubeVision-v0") + env = SERLObsWrapper(env) + env = ChunkingWrapper(env, obs_horizon=3, act_exec_horizon=None) + + # env = gym.make("PandaReachCube-v0") + # env = 
gym.wrappers.FlattenObservation(env) + + image_keys = [key for key in env.observation_space.keys() if key != "state"] + + their_buffer = make_replay_buffer( + env, + type="memory_efficient_replay_buffer", + capacity=FLAGS.capacity, + image_keys=image_keys, + + ) + + replay_buffer = make_replay_buffer( + env, + type="fractal_symmetry_replay_buffer", + capacity=FLAGS.capacity, + split_method=FLAGS.split_method, + branch_method=FLAGS.branch_method, + workspace_width=FLAGS.workspace_width, + x_obs_idx=x_obs_idx, + y_obs_idx= y_obs_idx, + image_keys=image_keys, + # max_depth=FLAGS.max_depth, + max_traj_length = 100, + # branching_factor=FLAGS.branching_factor, + # alpha = FLAGS.alpha, + starting_branch_count = FLAGS.starting_branch_count, + ) + + observation, info = env.reset() + action = env.action_space.sample() + next_observation, reward, terminated, truncated, info = env.step(action) + + # observation = np.zeros_like(observation) + for k in observation.keys(): + observation[k] = np.zeros_like(observation[k]) + next_observation[k] = np.ones_like(next_observation[k]) + + action = np.ones_like(action) + # next_observation = np.ones_like(next_observation) + reward = 1 + + data_dict = dict( + observations=observation, + next_observations=next_observation, + actions=action, + rewards=reward, + masks=not truncated and not terminated, + dones=truncated or terminated, + ) + + del env, observation, next_observation, action, reward, truncated, terminated, info, y_obs_idx, x_obs_idx, _ + + for i in range(6): + + replay_buffer.insert(data_dict) + their_buffer.insert(data_dict) + assert(replay_buffer.dataset_dict["observations"]["state"][i % 10].all() == their_buffer.dataset_dict["observations"]["state"][(i + 3) % 10].all()) + assert(replay_buffer.dataset_dict["next_observations"]["state"][i % 10].all() == their_buffer.dataset_dict["next_observations"]["state"][(i + 3) % 10].all()) + assert(replay_buffer.img_buffer["front"][replay_buffer.dataset_dict["observations"]["front"][i % 
10]].all() == their_buffer.dataset_dict["observations"]["front"][(i + 3) % 10].all()) + + data_dict["observations"]["state"] += 1 + data_dict["next_observations"]["state"] += 1 + for k in image_keys: + data_dict["observations"][k] += 1 + data_dict["next_observations"][k] += 1 + + replay_buffer.sample(batch_size=3, indx=np.array([2,3,4])) + their_buffer.sample(batch_size=3, indx=np.array([5,6,7])) + + + + # branch() tests + + #------------------------------------------------------------------- + # Fractal Associative Expansions + #------------------------------------------------------------------- + replay_buffer.branching_factor = 3 + + replay_buffer.current_depth = 1 + result = replay_buffer.fractal_branch() + expected = 3 + assert result == expected, f"\033[31mTEST FAILED\033[0m fractal_branch() test failed (expected {expected} but got {result})" + + replay_buffer.current_depth = 2 + result = replay_buffer.fractal_branch() + expected = 9 + assert result == expected, f"\033[31mTEST FAILED\033[0m fractal_branch() test failed (expected {expected} but got {result})" + + replay_buffer.current_depth = 3 + result = replay_buffer.fractal_branch() + expected = 27 + assert result == expected, f"\033[31mTEST FAILED\033[0m fractal_branch() test failed (expected {expected} but got {result})" + + replay_buffer.current_depth = 4 + result = replay_buffer.fractal_branch() + expected = 81 + assert result == expected, f"\033[31mTEST FAILED\033[0m fractal_branch() test failed (expected {expected} but got {result})" + + replay_buffer.current_depth = 0 + + del result, expected + + print("\033[32mTEST PASSED \033[0m fractal_branch() tests passed") + + #------------------------------------------------------------------- + # Fractal Associative Contractions + #------------------------------------------------------------------- + replay_buffer.branching_factor = 3 + + replay_buffer.current_depth = 1 + result = replay_buffer.fractal_contraction() + expected = 81 + assert result == 
expected, f"\033[31mTEST FAILED\033[0m fractal_branch() test failed (expected {expected} but got {result})" + + replay_buffer.current_depth = 2 + result = replay_buffer.fractal_contraction() + expected = 27 + assert result == expected, f"\033[31mTEST FAILED\033[0m fractal_branch() test failed (expected {expected} but got {result})" + + replay_buffer.current_depth = 3 + result = replay_buffer.fractal_contraction() + expected = 9 + assert result == expected, f"\033[31mTEST FAILED\033[0m fractal_branch() test failed (expected {expected} but got {result})" + + replay_buffer.current_depth = 4 + result = replay_buffer.fractal_contraction() + expected = 3 + assert result == expected, f"\033[31mTEST FAILED\033[0m fractal_branch() test failed (expected {expected} but got {result})" + + replay_buffer.current_depth = 0 + + del result, expected + + print("\033[32mTEST PASSED \033[0m fractal_contraction() tests passed") + + #------------------------------------------------------------------- + # split() tests + #------------------------------------------------------------------- + + ## time + replay_buffer.max_steps = 100 + replay_buffer.max_depth = 4 + + replay_buffer.timestep = 0 + result = replay_buffer.time_split(data_dict) + expected = True + assert result == expected, f"\033[31mTEST FAILED\033[0m split() test failed (expected {expected} but got {result})" + + replay_buffer.timestep = 25 + result = replay_buffer.time_split(data_dict) + expected = True + assert result == expected, f"\033[31mTEST FAILED\033[0m split() test failed (expected {expected} but got {result})" + + replay_buffer.timestep = 50 + result = replay_buffer.time_split(data_dict) + expected = True + assert result == expected, f"\033[31mTEST FAILED\033[0m split() test failed (expected {expected} but got {result})" + + replay_buffer.timestep = 75 + result = replay_buffer.time_split(data_dict) + expected = True + assert result == expected, f"\033[31mTEST FAILED\033[0m split() test failed (expected {expected} 
but got {result})" + + + replay_buffer.timestep = 100 + result = replay_buffer.time_split(data_dict) + expected = False + assert result == expected, f"\033[31mTEST FAILED\033[0m split() test failed (expected {expected} but got {result})" + + del result, expected + + print("\033[32mTEST PASSED \033[0m time_split() test passed") + + # insert() tests + # insert() tests + initial_size = len(replay_buffer.dataset_dict['observations'][0]) * replay_buffer._insert_index % len(replay_buffer.dataset_dict['observations']) + + replay_buffer.insert(data_dict) + final_size = len(replay_buffer.dataset_dict['observations'][0]) * replay_buffer._insert_index % len(replay_buffer.dataset_dict['observations']) + + result = final_size > initial_size + expected = True + assert result == expected, f"\033[31mTEST FAILED\033[0m insert() test failed (expected buffer size to increase from {initial_size} to {final_size})" + del result, expected, initial_size, final_size + + + print("\033[32mTEST PASSED \033[0m insert() tests passed") + + #------------------------------------------------------------------- + # Fractal Expansions with workspace_width_modification + #------------------------------------------------------------------- + print('\nWorkspace width tests....') + + replay_buffer.branching_factor = 3 + replay_buffer.current_depth = 1 + + if FLAGS.workspace_width_method == 'constant': + result = replay_buffer.get_workspace_width() + expected = FLAGS.workspace_width + assert result == expected, f"\033[31mTEST FAILED\033[0m get_workspace_width() test failed (expected {expected} but got {result})" + + elif FLAGS.workspace_width_method == 'decrease': + result = replay_buffer.get_workspace_width() + expected = FLAGS.workspace_width - 0.05*replay_buffer.current_depth + assert result == expected, f"\033[31mTEST FAILED\033[0m get_workspace_width() test failed (expected {expected} but got {result})" + + elif FLAGS.workspace_width_method == 'increase': + result = 
replay_buffer.get_workspace_width() + expected = FLAGS.workspace_width + 0.05*replay_buffer.current_depth + assert result == expected, f"\033[31mTEST FAILED\033[0m get_workspace_width() test failed (expected {expected} but got {result})" + + else: + raise NameError('There is no workspace width method with that name.') + + + replay_buffer.current_depth = 2 + + if FLAGS.workspace_width_method == 'constant': + result = replay_buffer.get_workspace_width() + expected = FLAGS.workspace_width + assert result == expected, f"\033[31mTEST FAILED\033[0m get_workspace_width() test failed (expected {expected} but got {result})" + + elif FLAGS.workspace_width_method == 'decrease': + result = replay_buffer.get_workspace_width() + expected = FLAGS.workspace_width - 0.05*replay_buffer.current_depth + assert result == expected, f"\033[31mTEST FAILED\033[0m get_workspace_width() test failed (expected {expected} but got {result})" + + elif FLAGS.workspace_width_method == 'increase': + result = replay_buffer.get_workspace_width() + expected = FLAGS.workspace_width + 0.05*replay_buffer.current_depth + assert result == expected, f"\033[31mTEST FAILED\033[0m get_workspace_width() test failed (expected {expected} but got {result})" + + else: + raise NameError('There is no workspace width method with that name.') + + + replay_buffer.current_depth = 3 + + if FLAGS.workspace_width_method == 'constant': + result = replay_buffer.get_workspace_width() + expected = FLAGS.workspace_width + assert result == expected, f"\033[31mTEST FAILED\033[0m get_workspace_width() test failed (expected {expected} but got {result})" + + elif FLAGS.workspace_width_method == 'decrease': + result = replay_buffer.get_workspace_width() + expected = FLAGS.workspace_width - 0.05*replay_buffer.current_depth + assert result == expected, f"\033[31mTEST FAILED\033[0m get_workspace_width() test failed (expected {expected} but got {result})" + + elif FLAGS.workspace_width_method == 'increase': + result = 
replay_buffer.get_workspace_width() + expected = FLAGS.workspace_width + 0.05*replay_buffer.current_depth + assert result == expected, f"\033[31mTEST FAILED\033[0m get_workspace_width() test failed (expected {expected} but got {result})" + + else: + raise NameError('There is no workspace width method with that name.') + + replay_buffer.current_depth = 0 + + del result, expected + + print("\n\033[32mTEST PASSED \033[0m workspace_width_method() test passed") + + + print("\nfinished!\n") + + + +if __name__ == "__main__": + app.run(main) \ No newline at end of file diff --git a/serl_launcher/serl_launcher/data/replay_buffer.py b/serl_launcher/serl_launcher/data/replay_buffer.py index f7d798a2..0a1ab414 100644 --- a/serl_launcher/serl_launcher/data/replay_buffer.py +++ b/serl_launcher/serl_launcher/data/replay_buffer.py @@ -22,17 +22,24 @@ def _init_replay_dict( def _insert_recursively( - dataset_dict: DatasetDict, data_dict: DatasetDict, insert_index: int + dataset_dict: DatasetDict, data_dict: DatasetDict, insert_index: int, capacity: int, batch_size: int = None, ): if isinstance(dataset_dict, np.ndarray): - dataset_dict[insert_index] = data_dict + if batch_size: + if insert_index + batch_size > capacity: + dataset_dict[insert_index:capacity] = data_dict[0:(capacity - insert_index)] + dataset_dict[0:(insert_index + batch_size - capacity)] = data_dict[(capacity - insert_index):batch_size] + else: + dataset_dict[insert_index:(insert_index + batch_size)] = data_dict + else: + dataset_dict[insert_index] = data_dict elif isinstance(dataset_dict, dict): assert dataset_dict.keys() == data_dict.keys(), ( dataset_dict.keys(), data_dict.keys(), ) for k in dataset_dict.keys(): - _insert_recursively(dataset_dict[k], data_dict[k], insert_index) + _insert_recursively(dataset_dict[k], data_dict[k], insert_index, capacity, batch_size) else: raise TypeError() @@ -68,15 +75,20 @@ def __init__( def __len__(self) -> int: return self._size - def insert(self, data_dict: DatasetDict): - 
_insert_recursively(self.dataset_dict, data_dict, self._insert_index) + def insert(self, data_dict: DatasetDict, batch_size : int = None): + _insert_recursively(self.dataset_dict, data_dict, self._insert_index, self._capacity, batch_size) - self._insert_index = (self._insert_index + 1) % self._capacity - self._size = min(self._size + 1, self._capacity) + if batch_size: + self._insert_index = (self._insert_index + batch_size) % self._capacity + self._size = min(self._size + batch_size, self._capacity) + else: + self._insert_index = (self._insert_index + 1) % self._capacity + self._size = min(self._size + 1, self._capacity) def get_iterator(self, queue_size: int = 2, sample_args: dict = {}, device=None): # See https://flax.readthedocs.io/en/latest/_modules/flax/jax_utils.html#prefetch_to_device - # queue_size = 2 should be ok for one GPU. + + # queue_size = 2 should be ok for one GPU. See more at https://chatgpt.com/share/687af063-d6b0-8004-92b6-0e88b9c5f1e8 queue = collections.deque() def enqueue(n): diff --git a/serl_launcher/serl_launcher/networks/reward_classifier.py b/serl_launcher/serl_launcher/networks/reward_classifier.py index b05e196c..6b60f17c 100644 --- a/serl_launcher/serl_launcher/networks/reward_classifier.py +++ b/serl_launcher/serl_launcher/networks/reward_classifier.py @@ -67,7 +67,7 @@ def create_classifier( with open(pretrained_encoder_path, "rb") as f: encoder_params = pkl.load(f) - param_count = sum(x.size for x in jax.tree_leaves(encoder_params)) + param_count = sum(x.size for x in jax.tree.leaves(encoder_params)) print( f"Loaded {param_count/1e6}M parameters from ResNet-10 pretrained on ImageNet-1K" ) diff --git a/serl_launcher/serl_launcher/utils/launcher.py b/serl_launcher/serl_launcher/utils/launcher.py index 782221eb..d9cdf5f4 100644 --- a/serl_launcher/serl_launcher/utils/launcher.py +++ b/serl_launcher/serl_launcher/utils/launcher.py @@ -18,6 +18,7 @@ from serl_launcher.data.data_store import ( MemoryEfficientReplayBufferDataStore, 
ReplayBufferDataStore, + FractalSymmetryReplayBufferDataStore, ) ############################################################################## @@ -179,13 +180,17 @@ def make_trainer_config(port_number: int = 5488, broadcast_port: int = 5489): def make_wandb_logger( project: str = "agentlace", + name: str = "placeholder_run_name", description: str = "serl_launcher", + wandb_output_dir: str = None, debug: bool = False, + offline: bool = False, ): wandb_config = WandBLogger.get_default_config() wandb_config.update( { "project": project, + "name": name, "exp_descriptor": description, "tag": description, } @@ -193,7 +198,9 @@ def make_wandb_logger( wandb_logger = WandBLogger( wandb_config=wandb_config, variant={}, + wandb_output_dir=wandb_output_dir, debug=debug, + offline=offline ) return wandb_logger @@ -206,6 +213,13 @@ def make_replay_buffer( image_keys: list = [], # used only type=="memory_efficient_replay_buffer" preload_rlds_path: Optional[str] = None, preload_data_transform: Optional[callable] = None, + branch_method: str = None, # used only type=="fractal_symmetry_replay_buffer" + split_method : str = None, # used only type=="fractal_symmetry_replay_buffer" + workspace_width : float = None, # used only type=="fractal_symmetry_replay_buffer" + workspace_width_method : str = None, # used only type=="fractal_symmetry_replay_buffer" + x_obs_idx = None, + y_obs_idx = None, + **kwargs: dict # used only type=="fractal_symmetry_replay_buffer" ): """ This is the high-level helper function to @@ -215,7 +229,7 @@ def make_replay_buffer( - env: gym or gymasium environment - capacity: capacity of the replay buffer - rlds_logger_path: path to save RLDS logs - - type: support only for "replay_buffer" and "memory_efficient_replay_buffer" + - type: support only for "replay_buffer", "memory_efficient_replay_buffer", and "fractal_symmetry_replay_buffer" - image_keys: list of image keys, used only "memory_efficient_replay_buffer" - preload_rlds_path: path to preloaded RLDS 
trajectories - preload_data_transform: data transformation function for preloaded RLDS data @@ -254,9 +268,31 @@ def make_replay_buffer( rlds_logger=rlds_logger, image_keys=image_keys, ) + elif type == "fractal_symmetry_replay_buffer": + replay_buffer = FractalSymmetryReplayBufferDataStore( + env.observation_space, + env.action_space, + capacity=capacity, + branch_method=branch_method, + split_method=split_method, + workspace_width=workspace_width, + x_obs_idx=x_obs_idx, + y_obs_idx=y_obs_idx, + rlds_logger=rlds_logger, + image_keys=image_keys, + kwargs=kwargs, + ) + else: raise ValueError(f"Unsupported replay_buffer_type: {type}") + # Load RLDS or oxe_envlogger recroded data with tfds.builder_from_directory. + # Choose number of episodes by passing split="train[:N%]" or split="test[:N%]" + # or ds = tfds.builder_from_directory(builder_dir).as_dataset(split="train").take(5) + # See more details: https://www.tensorflow.org/datasets/splits + # + # It's also possible to filter specirfic episodes, i.e. 
by time: + # ds = ds.filter(lambda ep: tf.strings.regex_full_match(ep['some/session_id'], "20250821_222412")) if preload_rlds_path: print(f" - Preloaded {preload_rlds_path} to replay buffer") dataset = tfds.builder_from_directory(preload_rlds_path).as_dataset(split="all") diff --git a/serl_launcher/serl_launcher/utils/train_utils.py b/serl_launcher/serl_launcher/utils/train_utils.py index 31037317..dd8816a7 100644 --- a/serl_launcher/serl_launcher/utils/train_utils.py +++ b/serl_launcher/serl_launcher/utils/train_utils.py @@ -108,7 +108,7 @@ def load_resnet10_params(agent, image_keys=("image",), public=True): with open(file_path, "rb") as f: encoder_params = pkl.load(f) - param_count = sum(x.size for x in jax.tree_leaves(encoder_params)) + param_count = sum(x.size for x in jax.tree.leaves(encoder_params)) print( f"Loaded {param_count/1e6}M parameters from ResNet-10 pretrained on ImageNet-1K" ) diff --git a/serl_launcher/serl_launcher/vision/data_augmentations.py b/serl_launcher/serl_launcher/vision/data_augmentations.py index 2c2440fa..f455bbf1 100644 --- a/serl_launcher/serl_launcher/vision/data_augmentations.py +++ b/serl_launcher/serl_launcher/vision/data_augmentations.py @@ -169,7 +169,7 @@ def hsv_to_rgb(h, s, v): def adjust_brightness(rgb_tuple, delta): - return jax.tree_map(lambda x: x + delta, rgb_tuple) + return jax.tree.map(lambda x: x + delta, rgb_tuple) def adjust_contrast(image, factor): @@ -177,7 +177,7 @@ def _adjust_contrast_channel(channel): mean = jnp.mean(channel, axis=(-2, -1), keepdims=True) return factor * (channel - mean) + mean - return jax.tree_map(_adjust_contrast_channel, image) + return jax.tree.map(_adjust_contrast_channel, image) def adjust_saturation(h, s, v, factor): @@ -256,7 +256,7 @@ def identity_fn(x, unused_rng, unused_param): def cond_fn(args, i): def clip(args): - return jax.tree_map(lambda arg: jnp.clip(arg, 0.0, 1.0), args) + return jax.tree.map(lambda arg: jnp.clip(arg, 0.0, 1.0), args) out = jax.lax.cond( should_apply & 
should_apply_color & (i == idx), @@ -275,7 +275,7 @@ def clip(args): random_hue_cond = _make_cond(_random_hue, idx=3) def _color_jitter(x): - rgb_tuple = tuple(jax.tree_map(jnp.squeeze, jnp.split(x, 3, axis=-1))) + rgb_tuple = tuple(jax.tree.map(jnp.squeeze, jnp.split(x, 3, axis=-1))) if shuffle: order = jax.random.permutation(perm_rng, jnp.arange(4, dtype=jnp.int32)) else: diff --git a/serl_launcher/serl_launcher/wrappers/chunking.py b/serl_launcher/serl_launcher/wrappers/chunking.py index 175c9a5f..404a31c5 100644 --- a/serl_launcher/serl_launcher/wrappers/chunking.py +++ b/serl_launcher/serl_launcher/wrappers/chunking.py @@ -9,7 +9,7 @@ def stack_obs(obs): dict_list = {k: [dic[k] for dic in obs] for k in obs[0]} - return jax.tree_map( + return jax.tree.map( lambda x: np.stack(x), dict_list, is_leaf=lambda x: isinstance(x, list) ) diff --git a/serl_launcher/serl_launcher/wrappers/remap.py b/serl_launcher/serl_launcher/wrappers/remap.py index 7acb2d93..1d724dc1 100644 --- a/serl_launcher/serl_launcher/wrappers/remap.py +++ b/serl_launcher/serl_launcher/wrappers/remap.py @@ -31,4 +31,4 @@ def __init__(self, env: gym.Env, new_structure: Any): raise TypeError(f"Unsupported type {type(new_structure)}") def observation(self, observation): - return jax.tree_map(lambda x: observation[x], self.new_structure) + return jax.tree.map(lambda x: observation[x], self.new_structure) diff --git a/serl_launcher/serl_launcher/wrappers/serl_obs_wrappers.py b/serl_launcher/serl_launcher/wrappers/serl_obs_wrappers.py index 41c169f9..0c789762 100644 --- a/serl_launcher/serl_launcher/wrappers/serl_obs_wrappers.py +++ b/serl_launcher/serl_launcher/wrappers/serl_obs_wrappers.py @@ -2,24 +2,184 @@ from gym.spaces import flatten_space, flatten +# Optional resizers +import numpy as np +try: + import cv2 + _HAS_CV2 = True +except Exception: + _HAS_CV2 = False + +try: + from PIL import Image + _HAS_PIL = True +except Exception: + _HAS_PIL = False + +def _resize_hwc(img: np.ndarray, hw: 
tuple[int, int]) -> np.ndarray: + """ + Resize HxWxC image to (H', W', C) without changing dtype/range. + Supports cv2, PIL, or pure NumPy resizing. + If img is float32, it will be scaled to [0,1] if not already in that range. + If img is uint8, it will be resized as-is. + + Args: + img (np.ndarray): Input image in HxWxC format. + hw (tuple[int, int]): Target height and width (H', W'). + Returns: + np.ndarray: Resized image in H'xW'xC format. + """ + H, W = hw + if _HAS_CV2: + # cv2 wants (W, H) + return cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA) + + if _HAS_PIL: + pil = Image.fromarray(img if img.dtype == np.uint8 else np.clip(img, 0, 255).astype(np.uint8)) + pil = pil.resize((W, H), resample=Image.Resampling.BILINEAR) + out = np.asarray(pil) + if img.dtype != np.uint8: + # If original was float, map back to float [0,1] + out = out.astype(np.float32) / 255.0 + return out + + # Pure NumPy (simple nearest neighbor) + y_idx = (np.linspace(0, img.shape[0] - 1, H)).astype(np.int32) + x_idx = (np.linspace(0, img.shape[1] - 1, W)).astype(np.int32) + return img[y_idx][:, x_idx] + class SERLObsWrapper(gym.ObservationWrapper): """ - This observation wrapper treat the observation space as a dictionary - of a flattened state space and the images. + Observation wrapper for SERL environments. + Flattens the 'state' space and resizes images to a target height and width. + Supports both uint8 and float32 images, with optional normalization. + The observation space is a Dict with 'state' and resized image spaces. + + Args: + env (gym.Env): The environment to wrap. + target_hw (tuple[int, int]): Target height and width for resized images. + img_dtype (np.dtype): Data type for images, either np.uint8 or np.float32. + normalize (bool): If True, scales float32 images to [0,1]. + image_parent_key (str): Key in the observation dict where images are stored. + + Defaults to "images". 
+ Returns: + gym.spaces.Dict: The new observation space with flattened state and resized images. """ - def __init__(self, env): + def __init__( + self, + env, + target_hw=(128, 128), # (H, W) for resized images + img_dtype=np.uint8, # np.uint8 for [0..255], or np.float32 for [0..1] + normalize=False, # if True and img_dtype=float32, scale to [0,1] + image_parent_key="images", # where images live in the original obs dict + ): super().__init__(env) - self.observation_space = gym.spaces.Dict( - { - "state": flatten_space(self.env.observation_space["state"]), - **(self.env.observation_space["images"]), - } - ) + assert isinstance(self.env.observation_space, gym.spaces.Dict), \ + "Expected Dict observation_space with keys {'state', 'images'}" + + # ---- Build new observation_space ---- + base_space = self.env.observation_space + assert "state" in base_space.spaces, "Missing 'state' in observation_space" + assert image_parent_key in base_space.spaces, f"Missing '{image_parent_key}' in observation_space" + img_space_dict = base_space.spaces[image_parent_key] + assert isinstance(img_space_dict, gym.spaces.Dict), \ + f"'{image_parent_key}' must be a Dict of image spaces" + + # Flattened state space + state_space = flatten_space(base_space.spaces["state"]) + + + # Image spaces (resized) + H, W = target_hw + image_spaces = {} + for k, sp in img_space_dict.spaces.items(): + # Assume HWC input; preserve channel count + if hasattr(sp, "shape") and sp.shape is not None: + if len(sp.shape) != 3: + raise ValueError(f"Image space '{k}' must be HxWxC; got shape {sp.shape}") + C = sp.shape[-1] + else: + raise ValueError(f"Image space '{k}' missing shape") + + if img_dtype == np.uint8: + low, high = 0, 255 + elif img_dtype == np.float32: + low, high = 0.0, 1.0 if normalize else float(getattr(sp, "high", 1.0)) + else: + raise ValueError("img_dtype must be np.uint8 or np.float32") + + image_spaces[k] = gym.spaces.Box( + low=low, + high=high, + shape=(H, W, C), + dtype=img_dtype, + ) + + 
# Final Dict space: {'state': ..., 'front': Box(...), 'wrist': Box(...), ...} + self.observation_space = gym.spaces.Dict({ + "state": state_space, + **image_spaces + }) + # self.observation_space = gym.spaces.Dict( + # { + # "state": flatten_space(self.env.observation_space["state"]), + # **(self.env.observation_space["images"]), + # } + # ) + + # Store config + self._target_hw = target_hw + self._img_dtype = img_dtype + self._normalize = normalize + self._image_parent_key = image_parent_key + + # def observation(self, obs): + # obs = { + # "state": flatten(self.env.observation_space["state"], obs["state"]), + # **(obs["images"]), + # } + # return obs def observation(self, obs): - obs = { - "state": flatten(self.env.observation_space["state"], obs["state"]), - **(obs["images"]), - } - return obs + # Flatten state using original (pre-flatten) state space definition + flat_state = flatten(self.env.observation_space.spaces["state"], obs["state"]) + + # Pull original images dict + imgs = obs[self._image_parent_key] + + # Resize & cast each image to match observation_space spec + out = {"state": flat_state} + for k, sp in self.observation_space.spaces.items(): + if k == "state": + continue + img = imgs[k] + # Ensure HWC + if img.ndim != 3: + raise ValueError(f"Image '{k}' must be HxWxC; got shape {img.shape}") + + # If float32 images in [0,1] but we want uint8, scale up before resize for best quality + want_uint8 = (self._img_dtype == np.uint8) + if want_uint8: + if img.dtype != np.uint8: + # Assume 0..1 range; if 0..255 float, clip and cast + img = np.clip(img, 0.0, 1.0) if img.max() <= 1.0 else np.clip(img/255.0, 0.0, 1.0) + img = (img * 255.0 + 0.5).astype(np.uint8) + resized = _resize_hwc(img, self._target_hw).astype(np.uint8) + else: + # float32 output + if img.dtype == np.uint8: + if self._normalize: + img = img.astype(np.float32) / 255.0 + else: + img = img.astype(np.float32) # keep 0..255 range if you really want that + else: + img = img.astype(np.float32) + if 
self._normalize and img.max() > 1.0: + img = img / 255.0 + resized = _resize_hwc(img, self._target_hw).astype(np.float32) + + out[k] = resized + + return out \ No newline at end of file diff --git a/serl_robot_infra/franka_env/envs/peg_env/config.py b/serl_robot_infra/franka_env/envs/peg_env/config.py index d2bcae9b..c027b5a9 100644 --- a/serl_robot_infra/franka_env/envs/peg_env/config.py +++ b/serl_robot_infra/franka_env/envs/peg_env/config.py @@ -7,24 +7,24 @@ class PegEnvConfig(DefaultEnvConfig): SERVER_URL: str = "http://127.0.0.1:5000/" REALSENSE_CAMERAS = { - "wrist_1": "130322274175", - "wrist_2": "127122270572", + "wrist_1": "218622274083", + "wrist_2": "218622271526", } TARGET_POSE = np.array( - [ - 0.5906439143742067, - 0.07771711953459341, - 0.0937835826958042, - 3.1099675, - 0.0146619, - -0.0078615, + [0.6179721091801964, + -0.08069386463219706, + 0.07628962570607248, + -3.1161359527499712, + 0.04124456532930543, + 1.5939026317635385, ] + ) RESET_POSE = TARGET_POSE + np.array([0.0, 0.0, 0.1, 0.0, 0.0, 0.0]) REWARD_THRESHOLD: np.ndarray = np.array([0.01, 0.01, 0.01, 0.2, 0.2, 0.2]) APPLY_GRIPPER_PENALTY = False ACTION_SCALE = np.array([0.02, 0.1, 1]) - RANDOM_RESET = True + RANDOM_RESET = False #Turn to true after basic task is finished RANDOM_XY_RANGE = 0.05 RANDOM_RZ_RANGE = np.pi / 6 ABS_POSE_LIMIT_LOW = np.array( diff --git a/serl_robot_infra/robot_servers/franka_server.py b/serl_robot_infra/robot_servers/franka_server.py index 0c582257..5d45378e 100644 --- a/serl_robot_infra/robot_servers/franka_server.py +++ b/serl_robot_infra/robot_servers/franka_server.py @@ -17,7 +17,7 @@ FLAGS = flags.FLAGS flags.DEFINE_string( - "robot_ip", "172.16.0.2", "IP address of the franka robot's controller box" + "robot_ip", "10.200.110.10", "IP address of the franka robot's controller box" ) flags.DEFINE_string( "gripper_ip", "192.168.1.114", "IP address of the robotiq gripper if being used"