From 475f793b13d994004d2a677fdb6b14cf58017437 Mon Sep 17 00:00:00 2001 From: CleiverRuiz-LU Date: Fri, 4 Apr 2025 15:56:25 -0500 Subject: [PATCH 001/141] init luanch file for debugging PickandPlace SAC --- .../async_sac_state_sim/.vscode/launch.json | 110 ++++++++++++++++++ 1 file changed, 110 insertions(+) create mode 100644 examples/async_sac_state_sim/.vscode/launch.json diff --git a/examples/async_sac_state_sim/.vscode/launch.json b/examples/async_sac_state_sim/.vscode/launch.json new file mode 100644 index 00000000..4010fde4 --- /dev/null +++ b/examples/async_sac_state_sim/.vscode/launch.json @@ -0,0 +1,110 @@ +{ + // VSCode debug configuration version + "version": "0.2.0", + "configurations": [ + // LEARNER + { + // Configuration for the RL agent learner component + "name": "Python: Learner", + "type": "debugpy", + "request": "launch", + "program": "${workspaceFolder}/async_sac_state_sim.py", + // Command-line arguments matching run_learner.sh + "args": [ + "--learner", // REQUIRED: Indicates this is a learner instance + "--env", "PandaPickCube-v0", // Environment to use + "--exp_name=serl_dev_sim_test", // Experiment name for wandb logging + "--seed", "0", // Random seed for reproducibility + "--max_steps", "1000000", // Maximum training steps + "--training_starts", "1000", // Start training after buffer has this many samples + "--critic_actor_ratio", "8", // Critic-to-actor update ratio + "--batch_size", "256", // Training batch size + "--save_model", "True", // Whether to save model checkpoints + "--checkpoint_period", "10000", // How often to save checkpoints + "--checkpoint_path", "./checkpoints", // Where to save checkpoints + "--debug" // Debug mode (disables wandb) + ], + "console": "integratedTerminal", // Use integrated terminal to see output + "justMyCode": false, // Allow stepping into libraries + // Environment variables from run_learner.sh + "env": { + "XLA_PYTHON_CLIENT_PREALLOCATE": "false", // Don't preallocate JAX/XLA memory + 
"XLA_PYTHON_CLIENT_MEM_FRACTION": ".05" // Limit JAX memory usage to 5% + }, + // Additional helpful debugging options + "showReturnValue": true, // Show function return values + "purpose": ["debug-in-terminal"], // Run in terminal for better output + "cwd": "${workspaceFolder}" // Set working directory explicitly + }, + // ACTOR + { + // Configuration for the RL agent actor component + "name": "Python: Actor", + "type": "debugpy", + "request": "launch", + "program": "${workspaceFolder}/async_sac_state_sim.py", + // Command-line arguments matching run_actor.sh + "args": [ + "--actor", // REQUIRED: Indicates this is an actor instance + "--render", // Enable rendering the environment + "--env", "PandaPickCube-v0", // Environment to use + "--exp_name=serl_dev_sim_test", // Experiment name for wandb logging + "--seed", "0", // Random seed for reproducibility + "--random_steps", "1000", // Number of random steps at beginning + "--debug" // Debug mode (disables wandb) + ], + "console": "integratedTerminal", // Use integrated terminal to see output + "justMyCode": false, // Allow stepping into libraries + // Environment variables from run_actor.sh + "env": { + "XLA_PYTHON_CLIENT_PREALLOCATE": "false", // Don't preallocate JAX/XLA memory + "XLA_PYTHON_CLIENT_MEM_FRACTION": ".05" // Limit JAX memory usage to 5% + }, + "showReturnValue": true, // Show function return values + "purpose": ["debug-in-terminal"], // Run in terminal for better output + "cwd": "${workspaceFolder}" // Set working directory explicitly + } + ], + // Compound configurations to launch multiple configurations together + "compounds": [ + { + // Launch both learner and actor at the same time + "name": "Learner + Actor", + "configurations": ["Python: Learner", "Python: Actor"], + // Note: Actor will connect to learner via TrainerClient, assuming localhost IP + // Both will run in separate debug sessions with independent controls + } + ], + + /* + DEBUGGING TIPS: + + - IMPORTANT: You must select either 
Learner or Actor configuration when debugging + (the NotImplementedError occurs if neither --learner nor --actor flag is specified) + + - If you get an error about 'utd_ratio', use the "Python: Learner (Fix utd_ratio)" configuration + which adds this missing parameter + + - Set breakpoints in learner() or actor() functions to step through the main training loops + + - Key places to set breakpoints: + * In actor(): near the action sampling logic (step < FLAGS.random_steps) + * In learner(): where agent.update_high_utd() is called + * Server/client communication points (client.update(), server.publish_network()) + + - For memory issues: Watch replay buffer size growth with breakpoints in data_store.insert() + + - JAX issues: Set breakpoints after jax.device_put() calls to ensure proper device placement + + - The "utd_ratio" parameter seems to be used in the learner function but isn't defined + in the FLAGS. Use the special configuration or add a --utd_ratio flag to fix. + + DEBUG WORKFLOW: + + 1. Start with "Python: Learner (Fix utd_ratio)" configuration + 2. Set breakpoints at key sections you want to monitor + 3. Run the debugger and observe variable values at each step + 4. Once learner is properly running, launch the Actor in a separate instance + 5. 
Watch for communication between the two + */ +} \ No newline at end of file From 44321d84a08a77040c911a8578b062846f418832 Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Thu, 10 Apr 2025 16:52:02 -0500 Subject: [PATCH 002/141] Peg Insertion Task demos with corrected flags --- examples/async_peg_insert_drq/README.md | 11 + .../async_drq_randomized.py | 1 + .../docs/TestSpaceMouse.py | 286 ++++++++++++++++++ .../async_peg_insert_drq/docs/listGym_env.py | 5 + examples/async_peg_insert_drq/record_demo.py | 1 + examples/async_peg_insert_drq/run_actor.sh | 2 +- examples/async_peg_insert_drq/run_learner.sh | 4 +- 7 files changed, 307 insertions(+), 3 deletions(-) create mode 100644 examples/async_peg_insert_drq/README.md create mode 100644 examples/async_peg_insert_drq/docs/TestSpaceMouse.py create mode 100644 examples/async_peg_insert_drq/docs/listGym_env.py diff --git a/examples/async_peg_insert_drq/README.md b/examples/async_peg_insert_drq/README.md new file mode 100644 index 00000000..068ec121 --- /dev/null +++ b/examples/async_peg_insert_drq/README.md @@ -0,0 +1,11 @@ + + + + + + + + + +for record demo +make sure to connect to RealSense camera diff --git a/examples/async_peg_insert_drq/async_drq_randomized.py b/examples/async_peg_insert_drq/async_drq_randomized.py index 4fd76f08..7b8ac779 100644 --- a/examples/async_peg_insert_drq/async_drq_randomized.py +++ b/examples/async_peg_insert_drq/async_drq_randomized.py @@ -45,6 +45,7 @@ flags.DEFINE_integer("max_traj_length", 100, "Maximum length of trajectory.") flags.DEFINE_integer("seed", 42, "Random seed.") flags.DEFINE_bool("save_model", False, "Whether to save model.") +flags.DEFINE_integer("batch_size", 256, "Batch size.") flags.DEFINE_integer("critic_actor_ratio", 4, "critic to actor update ratio.") flags.DEFINE_integer("max_steps", 1000000, "Maximum number of training steps.") diff --git a/examples/async_peg_insert_drq/docs/TestSpaceMouse.py b/examples/async_peg_insert_drq/docs/TestSpaceMouse.py new file mode 
100644 index 00000000..0e8b86ce --- /dev/null +++ b/examples/async_peg_insert_drq/docs/TestSpaceMouse.py @@ -0,0 +1,286 @@ +#!/usr/bin/env python3 +""" +SpaceMouse Test Script (Fixed Version) + +This script tests if a 3Dconnexion SpaceMouse is properly connected and accessible. +It's designed to handle different versions of pyspacemouse and potential API inconsistencies. + +Usage: + python FixedTestSpaceMouse.py + +Press Ctrl+C to exit the program. +""" + +import time +import sys +import traceback +import threading + +# Try importing the pyspacemouse library +try: + import pyspacemouse + print("Successfully imported pyspacemouse library") + + # Print the library version if available + try: + version = getattr(pyspacemouse, "__version__", "unknown") + print(f"pyspacemouse version: {version}") + except: + print("Could not determine pyspacemouse version") +except ImportError as e: + print(f"Failed to import pyspacemouse: {e}") + print("Try installing it with: pip install pyspacemouse") + sys.exit(1) +except Exception as e: + print(f"Unexpected error importing pyspacemouse: {e}") + traceback.print_exc() + sys.exit(1) + +def print_device_info(device): + """Print detailed information about the SpaceMouse device.""" + print("\n===== DEVICE INFORMATION =====") + + # Check if device is a string or an object + if isinstance(device, str): + print(f"Device information: {device}") + return + + # Try to access device attributes safely + try: + attrs = [ + ("Manufacturer", "manufacturer_string", "Unknown"), + ("Product", "product_string", "Unknown"), + ("Vendor ID", "vendor_id", "Unknown"), + ("Product ID", "product_id", "Unknown"), + ("Serial Number", "serial_number", "Unknown"), + ("Release Number", "release_number", "Unknown"), + ("Interface Number", "interface_number", "Unknown") + ] + + for label, attr, default in attrs: + value = getattr(device, attr, default) + if attr in ["vendor_id", "product_id"] and value != "Unknown": + print(f"{label}: 0x{value:04x}") + else: + 
print(f"{label}: {value}") + except Exception as e: + print(f"Error accessing device attributes: {e}") + print(f"Raw device data: {device}") + + print("===============================\n") + +def test_device_connection(): + """Test if the SpaceMouse device is connected and accessible.""" + print("Attempting to detect SpaceMouse devices...") + + try: + # List all available devices + try: + devices = pyspacemouse.list_devices() + print(f"list_devices() returned: {devices}") + + if not devices: + print("No SpaceMouse devices found!") + return None + + print(f"Found {len(devices)} devices.") + + # Safely print device information + for i, device in enumerate(devices): + print(f"Device {i+1}:") + if isinstance(device, str): + print(f" {device}") + else: + try: + vendor_id = getattr(device, "vendor_id", "Unknown") + product_id = getattr(device, "product_id", "Unknown") + product_string = getattr(device, "product_string", "Unknown Device") + + if vendor_id != "Unknown" and product_id != "Unknown": + print(f" {product_string} (Vendor ID: 0x{vendor_id:04x}, Product ID: 0x{product_id:04x})") + else: + print(f" {product_string}") + except Exception as e: + print(f" Error printing device {i+1} info: {e}") + print(f" Raw device data: {device}") + + # Use the first device found + return devices[0] + + except AttributeError: + # Alternative approach if list_devices doesn't work as expected + print("list_devices() method not working as expected, trying open() directly...") + if pyspacemouse.open(): + print("Successfully opened a SpaceMouse device directly") + return "SpaceMouse Device" + else: + print("Failed to open any SpaceMouse device directly") + return None + + except Exception as e: + print(f"Error detecting devices: {e}") + traceback.print_exc() + return None + +def open_device(device): + """Try to open the SpaceMouse device.""" + print("Attempting to open the SpaceMouse...") + try: + # Check if device is already open + if hasattr(pyspacemouse, "is_open") and 
pyspacemouse.is_open(): + print("Device is already open") + return True + + # Try to open the device + if isinstance(device, str): + # If device is a string, just try to open any device + if pyspacemouse.open(callback=None): + print("Successfully opened a SpaceMouse device") + return True + else: + # Try to open the specific device + try: + if pyspacemouse.open(callback=None, device=device): + print("Successfully opened the SpaceMouse device") + return True + except TypeError: + # If device parameter isn't supported, try without it + if pyspacemouse.open(callback=None): + print("Successfully opened a SpaceMouse device") + return True + + print("Failed to open the SpaceMouse device") + print("Check if another program is already using it") + return False + + except Exception as e: + print(f"Error opening device: {e}") + traceback.print_exc() + return False + +def safe_read(): + """Safely read from the device, handling potential API differences.""" + try: + state = pyspacemouse.read() + return state + except Exception as e: + print(f"Error reading from device: {e}") + return None + +def monitor_button_state(stop_event): + """Monitor and print button state changes in a separate thread.""" + last_buttons = None + + while not stop_event.is_set(): + try: + state = safe_read() + if state and hasattr(state, "buttons") and state.buttons is not None: + current_buttons = state.buttons + if last_buttons != current_buttons: + buttons_pressed = [i for i, pressed in enumerate(current_buttons) if pressed] + if buttons_pressed: + print(f"Buttons pressed: {buttons_pressed}") + last_buttons = current_buttons.copy() if hasattr(current_buttons, "copy") else current_buttons + time.sleep(0.05) # Small sleep to prevent 100% CPU usage + except Exception as e: + print(f"Error in button monitoring thread: {e}") + time.sleep(1) # Wait a bit longer on error + +def main(): + """Main function to test the SpaceMouse.""" + print("SpaceMouse Test Script (Fixed Version)") + 
print("-------------------------------------") + + # First, check if we can detect any devices + device = test_device_connection() + if not device: + print("\nNo SpaceMouse detected. Please ensure:") + print("1. The device is connected to your computer") + print("2. You have installed libhidapi (sudo apt-get install libhidapi-dev libhidapi-hidraw0)") + print("3. You have proper permissions (sudo usermod -a -G plugdev $USER)") + print("4. You've created proper udev rules if needed") + sys.exit(1) + + # Print detailed device information + print_device_info(device) + + # Try to open the device + if not open_device(device): + sys.exit(1) + + # Test basic functionality + print("\nTesting basic device functionality...") + test_state = safe_read() + if test_state: + print("Successfully read initial state from device:") + try: + attrs = ["x", "y", "z", "roll", "pitch", "yaw", "buttons"] + for attr in attrs: + if hasattr(test_state, attr): + value = getattr(test_state, attr) + print(f" {attr}: {value}") + else: + print(f" {attr}: Not available") + except Exception as e: + print(f"Error reading attributes: {e}") + print(f"Raw state: {test_state}") + else: + print("Could not read initial state from device!") + print("The device might be connected but not functioning correctly.") + sys.exit(1) + + # Create a thread to monitor button presses + stop_event = threading.Event() + button_thread = threading.Thread(target=monitor_button_state, args=(stop_event,)) + button_thread.daemon = True + button_thread.start() + + # Monitor device movement + print("\nMove your SpaceMouse to see the values") + print("Press Ctrl+C to exit") + print("\nReading SpaceMouse state...") + + try: + last_print_time = time.time() + while True: + # Read the current state + state = safe_read() + + if state: + current_time = time.time() + + # Check if any movement data is available + x = getattr(state, "x", 0) or 0 + y = getattr(state, "y", 0) or 0 + z = getattr(state, "z", 0) or 0 + roll = getattr(state, 
"roll", 0) or 0 + pitch = getattr(state, "pitch", 0) or 0 + yaw = getattr(state, "yaw", 0) or 0 + + if any(abs(val) > 0.01 for val in [x, y, z, roll, pitch, yaw]): + # Print at most 10 times per second + if current_time - last_print_time >= 0.1: + print(f"\rPosition: X:{x:6.2f} Y:{y:6.2f} Z:{z:6.2f} | " + f"Rotation: Roll:{roll:6.2f} Pitch:{pitch:6.2f} Yaw:{yaw:6.2f}", + end="", flush=True) + last_print_time = current_time + + time.sleep(0.01) # Small sleep to prevent 100% CPU usage + + except KeyboardInterrupt: + print("\n\nExiting...") + except Exception as e: + print(f"\n\nError reading from device: {e}") + traceback.print_exc() + finally: + # Clean up + stop_event.set() + button_thread.join(timeout=1.0) + try: + pyspacemouse.close() + print("\nClosed SpaceMouse connection") + except Exception as e: + print(f"\nError closing device: {e}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/async_peg_insert_drq/docs/listGym_env.py b/examples/async_peg_insert_drq/docs/listGym_env.py new file mode 100644 index 00000000..c6179294 --- /dev/null +++ b/examples/async_peg_insert_drq/docs/listGym_env.py @@ -0,0 +1,5 @@ +import gymnasium as gym + +# Correct method to list environments +for env in gym.envs.registry.keys(): + print(env) diff --git a/examples/async_peg_insert_drq/record_demo.py b/examples/async_peg_insert_drq/record_demo.py index 5a0fd02b..7faef6e3 100644 --- a/examples/async_peg_insert_drq/record_demo.py +++ b/examples/async_peg_insert_drq/record_demo.py @@ -29,6 +29,7 @@ env = ChunkingWrapper(env, obs_horizon=1, act_exec_horizon=None) obs, _ = env.reset() + #obs = env.reset() #Changed to this to for potential older Gym version. 
transitions = [] success_count = 0 diff --git a/examples/async_peg_insert_drq/run_actor.sh b/examples/async_peg_insert_drq/run_actor.sh index a251e756..b21f954c 100644 --- a/examples/async_peg_insert_drq/run_actor.sh +++ b/examples/async_peg_insert_drq/run_actor.sh @@ -9,4 +9,4 @@ python async_drq_randomized.py "$@" \ --random_steps 0 \ --training_starts 200 \ --encoder_type resnet-pretrained \ - --demo_path peg_insert_20_demos_2023-12-25_16-13-25.pkl \ + --demo_path peg_insert_20_demos_2025-04-10_15-36-26.pkl \ diff --git a/examples/async_peg_insert_drq/run_learner.sh b/examples/async_peg_insert_drq/run_learner.sh index c2823a19..2a6db997 100644 --- a/examples/async_peg_insert_drq/run_learner.sh +++ b/examples/async_peg_insert_drq/run_learner.sh @@ -11,6 +11,6 @@ python async_drq_randomized.py "$@" \ --batch_size 256 \ --eval_period 2000 \ --encoder_type resnet-pretrained \ - --demo_path peg_insert_20_demos_2023-12-25_16-13-25.pkl \ + --demo_path peg_insert_20_demos_2025-04-10_15-36-26.pkl \ --checkpoint_period 1000 \ - --checkpoint_path /home/undergrad/code/serl_dev/examples/async_peg_insert_drq/5x5_20degs_20demos_rand_peg_insert_097 + --checkpoint_path /home/student/robot/robot_ws/src/serl/examples/async_peg_insert_drq/checkpoints \ No newline at end of file From dcff2b2c59b5052e418aefcc5947ce6982b4a3a3 Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Thu, 10 Apr 2025 16:53:53 -0500 Subject: [PATCH 003/141] insertion task config and franka_server files corrected --- .../franka_env/envs/peg_env/config.py | 15 ++++----------- serl_robot_infra/robot_servers/franka_server.py | 2 +- 2 files changed, 5 insertions(+), 12 deletions(-) diff --git a/serl_robot_infra/franka_env/envs/peg_env/config.py b/serl_robot_infra/franka_env/envs/peg_env/config.py index d2bcae9b..f59c133e 100644 --- a/serl_robot_infra/franka_env/envs/peg_env/config.py +++ b/serl_robot_infra/franka_env/envs/peg_env/config.py @@ -7,24 +7,17 @@ class PegEnvConfig(DefaultEnvConfig): SERVER_URL: str = 
"http://127.0.0.1:5000/" REALSENSE_CAMERAS = { - "wrist_1": "130322274175", - "wrist_2": "127122270572", + "wrist_1": "218622274083", + "wrist_2": "218622271526", } TARGET_POSE = np.array( - [ - 0.5906439143742067, - 0.07771711953459341, - 0.0937835826958042, - 3.1099675, - 0.0146619, - -0.0078615, - ] + [0.45744842827647436,0.0761972695608443,0.07560009685145638,3.136626282737411,-0.02836925376945154,1.596492563779626] ) RESET_POSE = TARGET_POSE + np.array([0.0, 0.0, 0.1, 0.0, 0.0, 0.0]) REWARD_THRESHOLD: np.ndarray = np.array([0.01, 0.01, 0.01, 0.2, 0.2, 0.2]) APPLY_GRIPPER_PENALTY = False ACTION_SCALE = np.array([0.02, 0.1, 1]) - RANDOM_RESET = True + RANDOM_RESET = False #Turn to true after basic task is finished RANDOM_XY_RANGE = 0.05 RANDOM_RZ_RANGE = np.pi / 6 ABS_POSE_LIMIT_LOW = np.array( diff --git a/serl_robot_infra/robot_servers/franka_server.py b/serl_robot_infra/robot_servers/franka_server.py index 0c582257..5d45378e 100644 --- a/serl_robot_infra/robot_servers/franka_server.py +++ b/serl_robot_infra/robot_servers/franka_server.py @@ -17,7 +17,7 @@ FLAGS = flags.FLAGS flags.DEFINE_string( - "robot_ip", "172.16.0.2", "IP address of the franka robot's controller box" + "robot_ip", "10.200.110.10", "IP address of the franka robot's controller box" ) flags.DEFINE_string( "gripper_ip", "192.168.1.114", "IP address of the robotiq gripper if being used" From 7e23bbd16d9bd4a887f643852784ba2a63da8d60 Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Thu, 10 Apr 2025 16:59:28 -0500 Subject: [PATCH 004/141] Pick and Place bash script corrections --- examples/async_sac_state_sim/async_sac_state_sim.py | 2 +- examples/async_sac_state_sim/run_actor.sh | 10 +++++++++- examples/async_sac_state_sim/run_learner.sh | 6 +++++- 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/examples/async_sac_state_sim/async_sac_state_sim.py b/examples/async_sac_state_sim/async_sac_state_sim.py index 90a1acf8..27d43c3c 100644 --- 
a/examples/async_sac_state_sim/async_sac_state_sim.py +++ b/examples/async_sac_state_sim/async_sac_state_sim.py @@ -228,7 +228,7 @@ def stats_callback(type: str, payload: dict) -> dict: batch = next(replay_iterator) with timer.context("train"): - agent, update_info = agent.update_high_utd(batch, utd_ratio=FLAGS.utd_ratio) + agent, update_info = agent.update_high_utd(batch, utd_ratio=FLAGS.critic_actor_ratio) agent = jax.block_until_ready(agent) # publish the updated network diff --git a/examples/async_sac_state_sim/run_actor.sh b/examples/async_sac_state_sim/run_actor.sh index 57677916..b5e73f5b 100644 --- a/examples/async_sac_state_sim/run_actor.sh +++ b/examples/async_sac_state_sim/run_actor.sh @@ -7,4 +7,12 @@ python async_sac_state_sim.py "$@" \ --exp_name=serl_dev_sim_test \ --seed 0 \ --random_steps 1000 \ - --debug + --max_steps 1000000 \ + --training_starts 1000 \ + --critic_actor_ratio 8 \ + --batch_size 256 \ + --save_model True \ + --checkpoint_period 10000 \ + --checkpoint_path "/home/student/robot/robot_ws/src/serl/examples/async_sac_state_sim/checkpoints" \ + #--debug # wandb is disabled when debug + diff --git a/examples/async_sac_state_sim/run_learner.sh b/examples/async_sac_state_sim/run_learner.sh index 10a203c1..11ea66ef 100644 --- a/examples/async_sac_state_sim/run_learner.sh +++ b/examples/async_sac_state_sim/run_learner.sh @@ -5,7 +5,11 @@ python async_sac_state_sim.py "$@" \ --env PandaPickCube-v0 \ --exp_name=serl_dev_sim_test \ --seed 0 \ + --max_steps 1000000 \ --training_starts 1000 \ --critic_actor_ratio 8 \ --batch_size 256 \ - --debug # wandb is disabled when debug + --save_model True \ + --checkpoint_period 10000 \ + --checkpoint_path "/home/student/robot/robot_ws/src/serl/examples/async_sac_state_sim/checkpoints" \ + #--debug # wandb is disabled when debug From 422d207e717c6ce13bc9106327daf5cf0b6ec456 Mon Sep 17 00:00:00 2001 From: bison Date: Wed, 11 Jun 2025 16:46:55 -0500 Subject: [PATCH 005/141] Fix bug in the way Mujoco renders 
the panda_pick_gym_env.py and shows when calling the test/test_gym_env_human.py. Need to add width and height & cameraid to the MujocoRenderer instantiation. --- franka_sim/README.md | 2 +- franka_sim/franka_sim/envs/panda_pick_gym_env.py | 3 +++ franka_sim/franka_sim/test/test_gym_env_human.py | 2 +- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/franka_sim/README.md b/franka_sim/README.md index 486f0e05..37fef730 100644 --- a/franka_sim/README.md +++ b/franka_sim/README.md @@ -8,7 +8,7 @@ It includes a state-based and a vision-based Franka lift cube task environment. - run `pip install -r requirements.txt` to install sim dependencies. # Explore the Environments -- Run `python franka_sim/test/test_gym_env_human.py` to launch a display window and visualize the task. +- Run `python3 franka_sim/test/test_gym_env_human.py` to launch a display window and visualize the task. # Credits: - This simulation is initially built by [Kevin Zakka](https://kzakka.com/). diff --git a/franka_sim/franka_sim/envs/panda_pick_gym_env.py b/franka_sim/franka_sim/envs/panda_pick_gym_env.py index f0a8d25b..58ecdc4b 100644 --- a/franka_sim/franka_sim/envs/panda_pick_gym_env.py +++ b/franka_sim/franka_sim/envs/panda_pick_gym_env.py @@ -144,6 +144,9 @@ def __init__( self._viewer = MujocoRenderer( self.model, self.data, + width=960, + height=960, + camera_id=0 ) self._viewer.render(self.render_mode) diff --git a/franka_sim/franka_sim/test/test_gym_env_human.py b/franka_sim/franka_sim/test/test_gym_env_human.py index 592bb24e..971a4516 100644 --- a/franka_sim/franka_sim/test/test_gym_env_human.py +++ b/franka_sim/franka_sim/test/test_gym_env_human.py @@ -6,7 +6,7 @@ from franka_sim import envs -env = envs.PandaPickCubeGymEnv(action_scale=(0.1, 1)) +env = envs.PandaPickCubeGymEnv(action_scale=(0.1, 1)) # or render_mode="human") action_spec = env.action_space From a3bae6dcd1ff726226b6827a6fccb62519e2218f Mon Sep 17 00:00:00 2001 From: bison Date: Thu, 12 Jun 2025 15:58:42 -0500 
Subject: [PATCH 006/141] Documented dataset.py. Included detailed explanatory notes and docstrings to facilitate student understanding. --- serl_launcher/serl_launcher/data/dataset.py | 194 ++++++++++++++++++-- 1 file changed, 174 insertions(+), 20 deletions(-) diff --git a/serl_launcher/serl_launcher/data/dataset.py b/serl_launcher/serl_launcher/data/dataset.py index 0760fb02..f47eae9a 100644 --- a/serl_launcher/serl_launcher/data/dataset.py +++ b/serl_launcher/serl_launcher/data/dataset.py @@ -1,45 +1,99 @@ -from functools import partial +from functools import partial # unused. from typing import Dict, Iterable, Optional, Tuple, Union import jax import jax.numpy as jnp import numpy as np + +# frozen_dict is an immutable & nested dictionary-like structure to manage parameters and states in NNs. Key in Flax, model params passed explicitly vs stored in mutable objects. from flax.core import frozen_dict from gym.utils import seeding +# Nested data structures where the leaves are NumPy arrays, and internal nodes are dictionaries with string keys. DataType = Union[np.ndarray, Dict[str, "DataType"]] DatasetDict = Dict[str, DataType] +# Utility functions to check lengths and subselect data from a dataset dictionary. def _check_lengths(dataset_dict: DatasetDict, dataset_len: Optional[int] = None) -> int: + """ + Check the lengths of items in a dataset dictionary. + + Upon initializing a Dataset, _check_lengths is invoked to assert that all data arrays are of equal length. This is critical because: + The Dataset assumes a uniform length across features to support indexing, sampling, and batching. + Inconsistent lengths would lead to silent errors or runtime failures during sampling, model training, or evaluation. + + If all items are of the same length, return that length. + If items are of different lengths, raise an assertion error. + If the dataset is empty, return 0. + If the dataset is not a dictionary, raise a TypeError. 
+ Args: + dataset_dict (DatasetDict): The dataset dictionary to check. + dataset_len (Optional[int]): The length to compare against, if provided. + Returns: + int: The length of the dataset if all items are of the same length. + Raises: + TypeError: If the dataset is not a dictionary or contains unsupported types. + """ + for v in dataset_dict.values(): if isinstance(v, dict): dataset_len = dataset_len or _check_lengths(v, dataset_len) + elif isinstance(v, np.ndarray): item_len = len(v) dataset_len = dataset_len or item_len assert dataset_len == item_len, "Inconsistent item lengths in the dataset." + else: raise TypeError("Unsupported type.") return dataset_len def _subselect(dataset_dict: DatasetDict, index: np.ndarray) -> DatasetDict: + """ + Subselect enables flexible, consistent indexing into complex datasets to either split or filter data. + It is especially important when working with nested dictionary-based datasets — a common structure in modern machine learning. + Our dataset will be deeply structure and we want indexing to be applied consistently at every depth. + Used by Dataset.split() and Dataset.filter() methods to extract subsets of data based on indices. + + Args: + dataset_dict (DatasetDict): The dataset dictionary to subselect from. + index (np.ndarray): The indices to select from the dataset. + Returns: + DatasetDict: A new dataset dictionary with items selected based on the index. + Raises: + TypeError: If the dataset contains unsupported types. 
+ """ + + new_dataset_dict = {} for k, v in dataset_dict.items(): if isinstance(v, dict): new_v = _subselect(v, index) + elif isinstance(v, np.ndarray): new_v = v[index] + else: raise TypeError("Unsupported type.") + new_dataset_dict[k] = new_v return new_dataset_dict -def _sample( - dataset_dict: Union[np.ndarray, DatasetDict], indx: np.ndarray -) -> DatasetDict: +def _sample(dataset_dict: Union[np.ndarray, DatasetDict], indx: np.ndarray) -> DatasetDict: + """ + This function is used to extract a subset of data from the dataset dictionary, which can be either a NumPy array or a nested dictionary structure. + Args: + dataset_dict (Union[np.ndarray, DatasetDict]): The dataset dictionary or array to sample from. + indx (np.ndarray): The indices to sample from the dataset. + Returns: + DatasetDict: A new dataset dictionary with items sampled based on the indices. + Raises: + TypeError: If the dataset is not a NumPy array or a dictionary. + """ + if isinstance(dataset_dict, np.ndarray): return dataset_dict[indx] elif isinstance(dataset_dict, dict): @@ -52,7 +106,11 @@ def _sample( class Dataset(object): - def __init__(self, dataset_dict: DatasetDict, seed: Optional[int] = None): + + def __init__(self, + dataset_dict: DatasetDict, + seed: Optional[int] = None ): + self.dataset_dict = dataset_dict self.dataset_len = _check_lengths(dataset_dict) @@ -63,6 +121,7 @@ def __init__(self, dataset_dict: DatasetDict, seed: Optional[int] = None): if seed is not None: self.seed(seed) + # @property decorator is used here to expose np_random as a read-only attribute @property def np_random(self) -> np.random.RandomState: if self._np_random is None: @@ -70,20 +129,43 @@ def np_random(self) -> np.random.RandomState: return self._np_random def seed(self, seed: Optional[int] = None) -> list: + """ + Set the random seed for reproducibility. Ensures a valid RNG is encapsulated behind an attribute-like interface. + Users do not need to call self.seed() or self.get_np_random() explicitly. 
Just self.np_random above. + + Args: + seed (Optional[int]): The seed to set. If None, a random seed will be generated. + Returns: + list: A list containing the seed used. + """ self._np_random, self._seed = seeding.np_random(seed) return [self._seed] def __len__(self) -> int: return self.dataset_len - def sample( - self, - batch_size: int, - keys: Optional[Iterable[str]] = None, - indx: Optional[np.ndarray] = None, - ) -> frozen_dict.FrozenDict: + def sample(self, + batch_size: int, + keys: Optional[Iterable[str]] = None, + indx: Optional[np.ndarray] = None, + ) -> frozen_dict.FrozenDict: + """ + Sample a random batch of data from the dataset. + + This method allows for flexible sampling of data, either by specifying keys or using random indices. + This is useful for training models, where you might want to sample a batch of data points from a larger dataset. + Args: + batch_size (int): The number of samples to return. + keys (Optional[Iterable[str]]): The keys to sample from the dataset. If None, all keys will be sampled. + indx (Optional[np.ndarray]): Specific indices to sample from. If None, random indices will be generated. + Returns: + frozen_dict.FrozenDict: A frozen dictionary containing the sampled data. + """ + if indx is None: if hasattr(self.np_random, "integers"): + + # Generate batch_size num of rand ints, each sampled uniformly at random from the range [0, len(self) - 1]. indx = self.np_random.integers(len(self), size=batch_size) else: indx = self.np_random.randint(len(self), size=batch_size) @@ -101,7 +183,17 @@ def sample( return frozen_dict.freeze(batch) - def sample_jax(self, batch_size: int, keys: Optional[Iterable[str]] = None): + def sample_jax(self, + batch_size: int, + keys: Optional[Iterable[str]] = None): + """ + Sample a batch of data from the dataset using JAX. This method is optimized for performance and can be used in JAX-based training loops. + Args: + batch_size (int): The number of samples to return. 
+ keys (Optional[Iterable[str]]): The keys to sample from the dataset. If None, all keys will be sampled. + Returns: + Tuple[int, frozen_dict.FrozenDict]: A tuple containing the maximum index sampled and a frozen dictionary with the sampled data. + """ if not hasattr(self, "rng"): self.rng = jax.random.PRNGKey(self._seed or 42) @@ -129,12 +221,25 @@ def _sample_jax(rng, src, max_indx: int): return indx_max, sample def split(self, ratio: float) -> Tuple["Dataset", "Dataset"]: + """ + Split the dataset into two parts based on the given ratio. The first part will contain a fraction of the dataset specified by the ratio, and the second part will contain the rest. + This method is useful for creating training and testing datasets, where you want to split the data into two parts for model evaluation. + Args: + ratio (float): The fraction of the dataset to include in the first part. Must be between 0 and 1. + Returns: + Tuple[Dataset, Dataset]: A tuple containing two Dataset objects, the first part and the second part of the split dataset. + Raises: + AssertionError: If the ratio is not between 0 and 1. + """ + assert 0 < ratio and ratio < 1 - train_index = np.index_exp[: int(self.dataset_len * ratio)] - test_index = np.index_exp[int(self.dataset_len * ratio) :] + train_index = np.index_exp[: int(self.dataset_len * ratio)] # First part of the dataset. + test_index = np.index_exp[int(self.dataset_len * ratio) :] # Second part of the dataset. + # Shuffle the indices to ensure random sampling. 
index = np.arange(len(self), dtype=np.int32) self.np_random.shuffle(index) + train_index = index[: int(self.dataset_len * ratio)] test_index = index[int(self.dataset_len * ratio) :] @@ -143,53 +248,102 @@ def split(self, ratio: float) -> Tuple["Dataset", "Dataset"]: return Dataset(train_dataset_dict), Dataset(test_dataset_dict) def _trajectory_boundaries_and_returns(self) -> Tuple[list, list, list]: + """ + This method computes the boundaries of episodes in the dataset and calculates the returns for each episode. + It identifies the start and end indices of each episode based on the 'dones' array in the dataset. + The returns for each episode are calculated by summing the rewards within each episode. + This is useful for reinforcement learning tasks where episodes are defined by sequences of states, actions, and rewards. + Returns: + Tuple[list, list, list]: A tuple containing three lists: + - episode_starts: The starting indices of each episode. + - episode_ends: The ending indices of each episode. + - episode_returns: The total returns for each episode. + """ + + # Initialize lists (note plural) to store episode boundaries and returns. episode_starts = [0] episode_ends = [] + # Initialize variables to track the current episode return and a list to store returns. episode_return = 0 episode_returns = [] + # Iterate through the dataset to find episode boundaries and calculate returns. + # The dataset_dict is expected to have 'rewards' and 'dones' keys. for i in range(len(self)): episode_return += self.dataset_dict["rewards"][i] + # If the current index indicates the end of an episode, store the return and update boundaries. if self.dataset_dict["dones"][i]: episode_returns.append(episode_return) - episode_ends.append(i + 1) + episode_ends.append(i + 1) # Store the end index of the episode including the current index. + + # If this is not the last episode, set the start of the next episode. 
if i + 1 < len(self): episode_starts.append(i + 1) episode_return = 0.0 return episode_starts, episode_ends, episode_returns - def filter( - self, take_top: Optional[float] = None, threshold: Optional[float] = None + def filter(self, + take_top: Optional[float] = None, + threshold: Optional[float] = None ): + """ + Filter the dataset based on episode returns. This method allows you to keep only the episodes that meet a certain return threshold or are among the top returns. + This is useful for focusing on high-performing episodes in reinforcement learning tasks. + Args: + take_top (Optional[float]): If specified, keep only the top N percent of episodes based on their returns. + threshold (Optional[float]): If specified, keep only the episodes with returns greater than or equal to this value. + Raises: + AssertionError: If both take_top and threshold are specified, or if neither is specified. + """ assert (take_top is None and threshold is not None) or ( take_top is not None and threshold is None ) + # Create a tuple of lists of episode boundaries and returns. ( episode_starts, episode_ends, episode_returns, ) = self._trajectory_boundaries_and_returns() + # If no threshold is specified, calculate it based on the top N percent of returns. if take_top is not None: + # np.percentile gives the value below which XX% of the data lies threshold = np.percentile(episode_returns, 100 - take_top) + # create a boolean index array to filter episodes based on the threshold. bool_indx = np.full((len(self),), False, dtype=bool) for i in range(len(episode_returns)): if episode_returns[i] >= threshold: bool_indx[episode_starts[i] : episode_ends[i]] = True + # Return a new dataset dictionary containing only the episodes that meet the threshold. self.dataset_dict = _subselect(self.dataset_dict, bool_indx) + # Update the dataset length after filtering. 
self.dataset_len = _check_lengths(self.dataset_dict) def normalize_returns(self, scaling: float = 1000): + """ + Normalize the returns in the dataset to a specified scaling factor. This is useful for stabilizing training in reinforcement learning tasks. + Normally done per batch of episodes before training a model to update the policy. + + Args: + scaling (float): The scaling factor to normalize the returns. Default is 1000. + Raises: + AssertionError: If the dataset does not contain 'rewards' or 'dones' keys. + """ + + # Extract episode returns (_, _, episode_returns) = self._trajectory_boundaries_and_returns() - self.dataset_dict["rewards"] /= np.max(episode_returns) - np.min( - episode_returns - ) + + # Normalize rewards in the dataset from 0-1 by dividing by the max range of returns. + self.dataset_dict["rewards"] /= np.max(episode_returns) - np.min(episode_returns) + + # Scale rewards. Note that large scaling factors can lead to numerical instability, small scaling factors can lead to poor learning. + # Scaling allows you to control the learning dynamics. 
self.dataset_dict["rewards"] *= scaling From 99e84de64e358c1cb3bf3cc2f8ac7737d3a3e932 Mon Sep 17 00:00:00 2001 From: bison Date: Thu, 12 Jun 2025 16:02:17 -0500 Subject: [PATCH 007/141] ignore vs workspaces --- .gitignore | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.gitignore b/.gitignore index 9d6232dd..bf70074e 100644 --- a/.gitignore +++ b/.gitignore @@ -172,3 +172,7 @@ MUJOCO_LOG.TXT _METADATA checkpoint wandb/ + +# VS Code settings +.vscode/ +*.code-workspace From 5bced50f7558b949064c97c8237986f1be3d0dd6 Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Mon, 16 Jun 2025 15:14:55 -0500 Subject: [PATCH 008/141] learner param settings changes: xla_python_client_mem_fraction and batch size --- examples/async_drq_sim/run_learner.sh | 2 +- examples/async_drq_sim/tmux_rlpd_launch.sh | 0 examples/async_peg_insert_drq/record_demo.py | 1 - examples/async_peg_insert_drq/run_actor.sh | 5 +++-- examples/async_peg_insert_drq/run_learner.sh | 7 ++++--- 5 files changed, 8 insertions(+), 7 deletions(-) mode change 100644 => 100755 examples/async_drq_sim/tmux_rlpd_launch.sh diff --git a/examples/async_drq_sim/run_learner.sh b/examples/async_drq_sim/run_learner.sh index 39445448..d36331c5 100644 --- a/examples/async_drq_sim/run_learner.sh +++ b/examples/async_drq_sim/run_learner.sh @@ -7,5 +7,5 @@ python async_drq_sim.py "$@" \ --training_starts 1000 \ --critic_actor_ratio 4 \ --encoder_type resnet-pretrained \ - # --demo_path franka_lift_cube_image_20_trajs.pkl \ + --demo_path franka_lift_cube_image_20_trajs.pkl \ --debug # wandb is disabled when debug diff --git a/examples/async_drq_sim/tmux_rlpd_launch.sh b/examples/async_drq_sim/tmux_rlpd_launch.sh old mode 100644 new mode 100755 diff --git a/examples/async_peg_insert_drq/record_demo.py b/examples/async_peg_insert_drq/record_demo.py index 7faef6e3..5a0fd02b 100644 --- a/examples/async_peg_insert_drq/record_demo.py +++ b/examples/async_peg_insert_drq/record_demo.py @@ -29,7 +29,6 @@ env = ChunkingWrapper(env, 
obs_horizon=1, act_exec_horizon=None) obs, _ = env.reset() - #obs = env.reset() #Changed to this to for potential older Gym version. transitions = [] success_count = 0 diff --git a/examples/async_peg_insert_drq/run_actor.sh b/examples/async_peg_insert_drq/run_actor.sh index b21f954c..a21431bf 100644 --- a/examples/async_peg_insert_drq/run_actor.sh +++ b/examples/async_peg_insert_drq/run_actor.sh @@ -1,12 +1,13 @@ export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ -export XLA_PYTHON_CLIENT_MEM_FRACTION=.1 && \ +export XLA_PYTHON_CLIENT_MEM_FRACTION=.3 && \ python async_drq_randomized.py "$@" \ --actor \ --render \ --env FrankaPegInsert-Vision-v0 \ --exp_name=serl_dev_drq_rlpd10demos_peg_insert_random_resnet \ + --max_steps 25000 \ --seed 0 \ --random_steps 0 \ --training_starts 200 \ --encoder_type resnet-pretrained \ - --demo_path peg_insert_20_demos_2025-04-10_15-36-26.pkl \ + --demo_path peg_insert_20_demos_2025-04-14_15-36-51.pkl \ diff --git a/examples/async_peg_insert_drq/run_learner.sh b/examples/async_peg_insert_drq/run_learner.sh index 2a6db997..0495db45 100644 --- a/examples/async_peg_insert_drq/run_learner.sh +++ b/examples/async_peg_insert_drq/run_learner.sh @@ -1,16 +1,17 @@ export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ -export XLA_PYTHON_CLIENT_MEM_FRACTION=.2 && \ +export XLA_PYTHON_CLIENT_MEM_FRACTION=.6 && \ python async_drq_randomized.py "$@" \ --learner \ --env FrankaPegInsert-Vision-v0 \ --exp_name=serl_dev_drq_rlpd10demos_peg_insert_random_resnet_097 \ --seed 0 \ + --max_steps 25000 \ --random_steps 1000 \ --training_starts 200 \ --critic_actor_ratio 4 \ - --batch_size 256 \ + --batch_size 128 \ --eval_period 2000 \ --encoder_type resnet-pretrained \ - --demo_path peg_insert_20_demos_2025-04-10_15-36-26.pkl \ + --demo_path peg_insert_20_demos_2025-04-14_15-36-51.pkl\ --checkpoint_period 1000 \ --checkpoint_path /home/student/robot/robot_ws/src/serl/examples/async_peg_insert_drq/checkpoints \ No newline at end of file From 
8465a9754d1911dfbaf01d747959e37150450752 Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Mon, 16 Jun 2025 15:22:58 -0500 Subject: [PATCH 009/141] small adjustment to target pose in PegEnvConfig --- franka_sim/franka_sim/envs/panda_pick_gym_env.py | 5 ++++- serl_robot_infra/franka_env/envs/peg_env/config.py | 9 ++++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/franka_sim/franka_sim/envs/panda_pick_gym_env.py b/franka_sim/franka_sim/envs/panda_pick_gym_env.py index f0a8d25b..b76e0ea7 100644 --- a/franka_sim/franka_sim/envs/panda_pick_gym_env.py +++ b/franka_sim/franka_sim/envs/panda_pick_gym_env.py @@ -144,6 +144,9 @@ def __init__( self._viewer = MujocoRenderer( self.model, self.data, + width=960, + height=960, + camera_id=0 ) self._viewer.render(self.render_mode) @@ -226,7 +229,7 @@ def render(self): rendered_frames = [] for cam_id in self.camera_id: rendered_frames.append( - self._viewer.render(render_mode="rgb_array", camera_id=cam_id) + self._viewer.render(render_mode="rgb_array") ) return rendered_frames diff --git a/serl_robot_infra/franka_env/envs/peg_env/config.py b/serl_robot_infra/franka_env/envs/peg_env/config.py index f59c133e..97f263c5 100644 --- a/serl_robot_infra/franka_env/envs/peg_env/config.py +++ b/serl_robot_infra/franka_env/envs/peg_env/config.py @@ -11,7 +11,14 @@ class PegEnvConfig(DefaultEnvConfig): "wrist_2": "218622271526", } TARGET_POSE = np.array( - [0.45744842827647436,0.0761972695608443,0.07560009685145638,3.136626282737411,-0.02836925376945154,1.596492563779626] + [0.5130075190601081, + -0.23745299669341197, + 0.07808493744892302, + 3.104175639241274, + -0.024134743891730315, + -3.1304483269500967, + ] + ) RESET_POSE = TARGET_POSE + np.array([0.0, 0.0, 0.1, 0.0, 0.0, 0.0]) REWARD_THRESHOLD: np.ndarray = np.array([0.01, 0.01, 0.01, 0.2, 0.2, 0.2]) From eccbe901eb8a536fe6f2346d1a1d4ecabf195e2f Mon Sep 17 00:00:00 2001 From: Cleiver Ruiz-Martinez Date: Wed, 2 Jul 2025 16:38:23 -0500 Subject: [PATCH 010/141] reminder 
to install custom packages from serl --- README.md | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/README.md b/README.md index f6c9bccd..7f8db41f 100644 --- a/README.md +++ b/README.md @@ -69,6 +69,21 @@ We fixed a major issue in the intervention action frame. See release [v0.1.1](ht pip install -r requirements.txt ``` + + +4. **Install the franka_sim** + ```bash + cd franka_sim + pip install -e . + pip install -r requirements.txt + ``` + +5. **Install the serl_robot_infra** + ```bash + cd serl_robot_infra + pip install -e . + ``` + ## Overview and Code Structure SERL provides a set of common libraries for users to train RL policies for robotic manipulation tasks. The main structure of running the RL experiments involves having an actor node and a learner node, both of which interact with the robot gym environment. Both nodes run asynchronously, with data being sent from the actor to the learner node via the network using [agentlace](https://github.com/youliangtan/agentlace). The learner will periodically synchronize the policy with the actor. This design provides flexibility for parallel training and inference. 
From bd049a0dd9937edb0e3e26f625c5daeed82950b5 Mon Sep 17 00:00:00 2001 From: ryanvanderstelt Date: Thu, 3 Jul 2025 18:02:37 -0500 Subject: [PATCH 011/141] created file for fractal symmetry replay buffer --- .../data/fractal_symmetry_replay_buffer.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py diff --git a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py new file mode 100644 index 00000000..f0be46df --- /dev/null +++ b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py @@ -0,0 +1,13 @@ +import copy +from typing import Iterable, Optional, Tuple + +import gym +import numpy as np +from serl_launcher.data.dataset import DatasetDict, _sample +from serl_launcher.data.replay_buffer import ReplayBuffer +from flax.core import frozen_dict +from gym.spaces import Box + +class FractalSymmetryReplayBuffer(ReplayBuffer): + def __init__(): + return From b6877196c36243f4be54cc766fb894c95aba6932 Mon Sep 17 00:00:00 2001 From: Cleiver Ivan Ruiz-Martinez Date: Fri, 4 Jul 2025 17:26:28 -0500 Subject: [PATCH 012/141] phase 1 of reach env, no V&V test yet --- franka_sim/franka_sim/__init__.py | 12 + franka_sim/franka_sim/envs/__init__.py | 2 + .../franka_sim/envs/panda_pick_gym_env.py | 2 +- .../franka_sim/envs/panda_reach_gym_env.py | 281 ++++++++++++++++++ 4 files changed, 296 insertions(+), 1 deletion(-) create mode 100644 franka_sim/franka_sim/envs/panda_reach_gym_env.py diff --git a/franka_sim/franka_sim/__init__.py b/franka_sim/franka_sim/__init__.py index 967e9e5d..dab351ab 100644 --- a/franka_sim/franka_sim/__init__.py +++ b/franka_sim/franka_sim/__init__.py @@ -18,3 +18,15 @@ max_episode_steps=100, kwargs={"image_obs": True}, ) +register( + id="PandaReachCube-v0", + entry_point="franka_sim.envs:PandaReachCubeGymEnv", + max_episode_steps=100, +) +# register( +# 
id="PandaReachCubeVision-v0", +# entry_point="franka_sim.envs:PandaReachCubeGymEnv", +# max_episode_steps=100, +# kwargs={"image_obs": True}, +# ) + diff --git a/franka_sim/franka_sim/envs/__init__.py b/franka_sim/franka_sim/envs/__init__.py index 50c68828..e3c90e39 100644 --- a/franka_sim/franka_sim/envs/__init__.py +++ b/franka_sim/franka_sim/envs/__init__.py @@ -1,5 +1,7 @@ from franka_sim.envs.panda_pick_gym_env import PandaPickCubeGymEnv +from franka_sim.envs.panda_reach_gym_env import PandaReachCubeGymEnv __all__ = [ "PandaPickCubeGymEnv", + "PandaReachCubeGymEnv", ] diff --git a/franka_sim/franka_sim/envs/panda_pick_gym_env.py b/franka_sim/franka_sim/envs/panda_pick_gym_env.py index b76e0ea7..c3d3324e 100644 --- a/franka_sim/franka_sim/envs/panda_pick_gym_env.py +++ b/franka_sim/franka_sim/envs/panda_pick_gym_env.py @@ -294,7 +294,7 @@ def _compute_reward(self) -> float: if __name__ == "__main__": env = PandaPickCubeGymEnv(render_mode="human") env.reset() - for i in range(100): + for i in range(1000): env.step(np.random.uniform(-1, 1, 4)) env.render() env.close() diff --git a/franka_sim/franka_sim/envs/panda_reach_gym_env.py b/franka_sim/franka_sim/envs/panda_reach_gym_env.py new file mode 100644 index 00000000..8d1ec251 --- /dev/null +++ b/franka_sim/franka_sim/envs/panda_reach_gym_env.py @@ -0,0 +1,281 @@ +from pathlib import Path +from typing import Any, Literal, Tuple, Dict + +import gym +import mujoco +import numpy as np +from gym import spaces + +try: + import mujoco_py +except ImportError as e: + MUJOCO_PY_IMPORT_ERROR = e +else: + MUJOCO_PY_IMPORT_ERROR = None + +from franka_sim.controllers import opspace +from franka_sim.mujoco_gym_env import GymRenderingSpec, MujocoGymEnv +from gym.envs.registration import register + +_HERE = Path(__file__).parent +_XML_PATH = _HERE / "xmls" / "arena.xml" +_PANDA_HOME = np.asarray((0, -0.785, 0, -2.35, 0, 1.57, np.pi / 4)) +_CARTESIAN_BOUNDS = np.asarray([[0.2, -0.3, 0], [0.6, 0.3, 0.5]]) +_SAMPLING_BOUNDS = 
np.asarray([[0.25, -0.25], [0.55, 0.25]]) + + +class PandaReachCubeGymEnv(MujocoGymEnv): + metadata = {"render_modes": ["rgb_array", "human"]} + + def __init__( + self, + action_scale: np.ndarray = np.asarray([0.1, 1]), + seed: int = 0, + control_dt: float = 0.02, + physics_dt: float = 0.002, + time_limit: float = 10.0, + render_spec: GymRenderingSpec = GymRenderingSpec(), + render_mode: Literal["rgb_array", "human"] = "rgb_array", + image_obs: bool = False, + ): + self._action_scale = action_scale + + super().__init__( + xml_path=_XML_PATH, + seed=seed, + control_dt=control_dt, + physics_dt=physics_dt, + time_limit=time_limit, + render_spec=render_spec, + ) + + self.metadata = { + "render_modes": [ + "human", + "rgb_array", + ], + "render_fps": int(np.round(1.0 / self.control_dt)), + } + + self.render_mode = render_mode + self.camera_id = (0, 1) + self.image_obs = image_obs + + # Caching. + self._panda_dof_ids = np.asarray( + [self._model.joint(f"joint{i}").id for i in range(1, 8)] + ) + self._panda_ctrl_ids = np.asarray( + [self._model.actuator(f"actuator{i}").id for i in range(1, 8)] + ) + self._gripper_ctrl_id = self._model.actuator("fingers_actuator").id + self._pinch_site_id = self._model.site("pinch").id + self._block_z = self._model.geom("block").size[2] + + self.observation_space = gym.spaces.Dict( + { + "state": gym.spaces.Dict( + { + "panda/tcp_pos": spaces.Box( + -np.inf, np.inf, shape=(3,), dtype=np.float32 + ), + "panda/tcp_vel": spaces.Box( + -np.inf, np.inf, shape=(3,), dtype=np.float32 + ), + "panda/gripper_pos": spaces.Box( + -np.inf, np.inf, shape=(1,), dtype=np.float32 + ), + "block_pos": spaces.Box( + -np.inf, np.inf, shape=(3,), dtype=np.float32 + ), + } + ), + } + ) + + if self.image_obs: + self.observation_space = gym.spaces.Dict( + { + "state": gym.spaces.Dict( + { + "panda/tcp_pos": spaces.Box( + -np.inf, np.inf, shape=(3,), dtype=np.float32 + ), + "panda/tcp_vel": spaces.Box( + -np.inf, np.inf, shape=(3,), dtype=np.float32 + ), + 
"panda/gripper_pos": spaces.Box( + -np.inf, np.inf, shape=(1,), dtype=np.float32 + ), + } + ), + "images": gym.spaces.Dict( + { + "front": gym.spaces.Box( + low=0, + high=255, + shape=(render_spec.height, render_spec.width, 3), + dtype=np.uint8, + ), + "wrist": gym.spaces.Box( + low=0, + high=255, + shape=(render_spec.height, render_spec.width, 3), + dtype=np.uint8, + ), + } + ), + } + ) + + self.action_space = gym.spaces.Box( + low=np.asarray([-1.0, -1.0, -1.0]), + high=np.asarray([1.0, 1.0, 1.0]), + dtype=np.float32, + ) + + # NOTE: gymnasium is used here since MujocoRenderer is not available in gym. It + # is possible to add a similar viewer feature with gym, but that can be a future TODO + from gymnasium.envs.mujoco.mujoco_rendering import MujocoRenderer + + self._viewer = MujocoRenderer( + self.model, + self.data, + width=960, + height=960, + camera_id=0 + ) + self._viewer.render(self.render_mode) + + def reset( + self, seed=None, **kwargs + ) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]: + """Reset the environment.""" + mujoco.mj_resetData(self._model, self._data) + + # Reset arm to home position. + self._data.qpos[self._panda_dof_ids] = _PANDA_HOME + mujoco.mj_forward(self._model, self._data) + + # Reset mocap body to home position. + tcp_pos = self._data.sensor("2f85/pinch_pos").data + self._data.mocap_pos[0] = tcp_pos + + # Sample a new block position. + block_xy = np.random.uniform(*_SAMPLING_BOUNDS) + self._data.jnt("block").qpos[:3] = (*block_xy, self._block_z) + mujoco.mj_forward(self._model, self._data) + + # Cache the initial block height. + # self._z_init = self._data.sensor("block_pos").data[2] + # self._z_success = self._z_init + 0.2 + + obs = self._compute_observation() + return obs, {} + + def step( + self, action: np.ndarray + ) -> Tuple[Dict[str, np.ndarray], float, bool, bool, Dict[str, Any]]: + """ + take a step in the environment. 
+ Params: + action: np.ndarray + + Returns: + observation: dict[str, np.ndarray], + reward: float, + done: bool, + truncated: bool, + info: dict[str, Any] + """ + x, y, z = action + + # Set the mocap position. + pos = self._data.mocap_pos[0].copy() + dpos = np.asarray([x, y, z]) * self._action_scale[0] + npos = np.clip(pos + dpos, *_CARTESIAN_BOUNDS) + self._data.mocap_pos[0] = npos + + # Set gripper grasp. + self._data.ctrl[self._gripper_ctrl_id] = 0 # Fully open position + + for _ in range(self._n_substeps): + tau = opspace( + model=self._model, + data=self._data, + site_id=self._pinch_site_id, + dof_ids=self._panda_dof_ids, + pos=self._data.mocap_pos[0], + ori=self._data.mocap_quat[0], + joint=_PANDA_HOME, + gravity_comp=True, + ) + self._data.ctrl[self._panda_ctrl_ids] = tau + mujoco.mj_step(self._model, self._data) + + obs = self._compute_observation() + rew = self._compute_reward() + terminated = self.time_limit_exceeded() + + return obs, rew, terminated, False, {} + + def render(self): + rendered_frames = [] + for cam_id in self.camera_id: + rendered_frames.append( + self._viewer.render(render_mode="rgb_array") + ) + return rendered_frames + + # Helper methods. 
+ + def _compute_observation(self) -> dict: + obs = {} + obs["state"] = {} + + tcp_pos = self._data.sensor("2f85/pinch_pos").data + obs["state"]["panda/tcp_pos"] = tcp_pos.astype(np.float32) + + tcp_vel = self._data.sensor("2f85/pinch_vel").data + obs["state"]["panda/tcp_vel"] = tcp_vel.astype(np.float32) + + gripper_pos = np.array( + self._data.ctrl[self._gripper_ctrl_id] / 255, dtype=np.float32 + ) + obs["state"]["panda/gripper_pos"] = gripper_pos + + if self.image_obs: + obs["images"] = {} + obs["images"]["front"], obs["images"]["wrist"] = self.render() + else: + block_pos = self._data.sensor("block_pos").data.astype(np.float32) + obs["state"]["block_pos"] = block_pos + + if self.render_mode == "human": + self._viewer.render(self.render_mode) + + return obs + + def _compute_reward(self) -> float: + # Get positions + block_pos = self._data.sensor("block_pos").data + tcp_pos = self._data.sensor("2f85/pinch_pos").data + + # Calculate distance + dist = np.linalg.norm(block_pos - tcp_pos) + + # Distance-based reward + r_close = np.exp(-20 * dist) + print(r_close) + + return r_close + + +if __name__ == "__main__": + # Create wrapped environment + env = PandaReachCubeGymEnv(render_mode="human") + env.reset() + for i in range(5000): + env.step(np.random.uniform(-1, 1, 3)) + env.render() + env.close() From bdac8733e14c5cadb4991e17d8cceaadcfeda999 Mon Sep 17 00:00:00 2001 From: ryanvanderstelt Date: Mon, 7 Jul 2025 17:26:28 -0500 Subject: [PATCH 013/141] Organization underway, not functional or tested --- .../data/fractal_symmetry_replay_buffer.py | 47 ++++++++++++++++++- 1 file changed, 46 insertions(+), 1 deletion(-) diff --git a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py index f0be46df..da88bcce 100644 --- a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py +++ b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py @@ -9,5 +9,50 @@ from gym.spaces 
import Box class FractalSymmetryReplayBuffer(ReplayBuffer): - def __init__(): + def __init__( + self, + observation_space: gym.Space, + action_space: gym.Space, + capacity: int, + split_method: str, + split_freq: int, # For time-based only + split_qs: int, + branch_method: str, + ): + self.split_method=split_method + self.split_qs=split_qs + self.split_freq=split_freq + self.branch_method=branch_method + # TODO + # Initialize passed-in variables as needed + # Add flags for variables passed to + # Function definition dependent on flags? + + super().__init__( + self, + observation_space=observation_space, + action_space=action_space, + capacity=capacity + ) + + # NOT TESTED + def transform(self, data_dict: DatasetDict, translation: np.array): + if not isinstance(data_dict, dict): + return + for k in data_dict.keys(): + if (self.placeholder): # if key is included in list provided + data_dict[k] += translation + else: + self.transform(data_dict[k], translation) + + def fractal_branch(self, data_dict: DatasetDict): + return + + def constant_branch(self, data_dict: DatasetDict): + return + + def linear_branch(self, data_dict: DatasetDict): return + + def insert(self, data_dict: DatasetDict): + return \ No newline at end of file From 8e6eed47819416c1dc15b12821127977f4522ff8 Mon Sep 17 00:00:00 2001 From: Cleiver Ivan Ruiz-Martinez Date: Mon, 7 Jul 2025 18:33:18 -0500 Subject: [PATCH 014/141] additional bash script for REACH --- examples/async_sac_state_sim/run_actor.sh | 8 +++++--- examples/async_sac_state_sim/run_learner.sh | 8 +++++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/examples/async_sac_state_sim/run_actor.sh b/examples/async_sac_state_sim/run_actor.sh index b5e73f5b..b8ba3c8d 100644 --- a/examples/async_sac_state_sim/run_actor.sh +++ b/examples/async_sac_state_sim/run_actor.sh @@ -1,10 +1,12 @@ export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ export XLA_PYTHON_CLIENT_MEM_FRACTION=.05 && \ +export SCRIPT_DIR=$(dirname "$(realpath "$0")") && 
\ +export ENV_NAME="PandaReachCube-v0" && \ python async_sac_state_sim.py "$@" \ --actor \ --render \ - --env PandaPickCube-v0 \ - --exp_name=serl_dev_sim_test \ + --env $ENV_NAME \ + --exp_name=serl-reach \ --seed 0 \ --random_steps 1000 \ --max_steps 1000000 \ @@ -13,6 +15,6 @@ python async_sac_state_sim.py "$@" \ --batch_size 256 \ --save_model True \ --checkpoint_period 10000 \ - --checkpoint_path "/home/student/robot/robot_ws/src/serl/examples/async_sac_state_sim/checkpoints" \ + --checkpoint_path "$SCRIPT_DIR/$ENV_NAME/checkpoints" \ #--debug # wandb is disabled when debug diff --git a/examples/async_sac_state_sim/run_learner.sh b/examples/async_sac_state_sim/run_learner.sh index 11ea66ef..9aead4a2 100644 --- a/examples/async_sac_state_sim/run_learner.sh +++ b/examples/async_sac_state_sim/run_learner.sh @@ -1,9 +1,11 @@ export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ export XLA_PYTHON_CLIENT_MEM_FRACTION=.05 && \ +export SCRIPT_DIR=$(dirname "$(realpath "$0")") && \ +export ENV_NAME="PandaReachCube-v0" && \ python async_sac_state_sim.py "$@" \ --learner \ - --env PandaPickCube-v0 \ - --exp_name=serl_dev_sim_test \ + --env $ENV_NAME \ + --exp_name=serl-reach \ --seed 0 \ --max_steps 1000000 \ --training_starts 1000 \ @@ -11,5 +13,5 @@ python async_sac_state_sim.py "$@" \ --batch_size 256 \ --save_model True \ --checkpoint_period 10000 \ - --checkpoint_path "/home/student/robot/robot_ws/src/serl/examples/async_sac_state_sim/checkpoints" \ + --checkpoint_path "$SCRIPT_DIR/$ENV_NAME/checkpoints" \ #--debug # wandb is disabled when debug From 358f26e6c9bd2bbaf913854b853621b942d3f7bb Mon Sep 17 00:00:00 2001 From: ryanvanderstelt Date: Tue, 8 Jul 2025 14:50:34 -0500 Subject: [PATCH 015/141] Fleshed out several functions, created TODO list for development --- .../data/fractal_symmetry_replay_buffer.py | 92 +++++++++++++++---- 1 file changed, 76 insertions(+), 16 deletions(-) diff --git a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py 
b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py index da88bcce..92da8ccd 100644 --- a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py +++ b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py @@ -14,19 +14,38 @@ def __init__( observation_space: gym.Space, action_space: gym.Space, capacity: int, + split_method: str, - split_freq: int, # For time-based only - split_qs: int, branch_method: str, + workspace_width: int, + depth: int = 1, + dendrites: int = None, + timesplit_freq: int = None, ): - self.split_method=split_method - self.split_qs=split_qs - self.split_freq=split_freq - self.branch_method=branch_method + + self.split_method=split_method # Determines method for when to change transformation number + self.branch_method=branch_method # Determines method for how many transforms to make for a given transition + self.workspace_width = workspace_width # Total 1D distance for transform calculations + self.depth=depth # Total number of layers + self.dendrites=dendrites # For fractal only + self.timesplit_freq=timesplit_freq # For time-based only + + # TODO # Initialize passed-in variables as needed # Add flags for variables passed to - # Function definition dependent on flags? 
+ # Add datastore class in datastore.py + # Create transform function + # Create at least one functional branch_method + # Create at least one functional split_method + # Create automated test.py with tests for current methods + # Create more methods with tests + # + # Future considerations: + # get_pos_dict_paths to return a list of paths to alter for more complex environments + # when dealing with x and y for transforms, consider two versions: + # radial trees at different angles can have strong fractal symmetry built in but will be more difficult to evenly space lowest level transforms + # grid-based will inherently evenly space lowest transforms, but fractal symmetry will be restricted to 9 dendrites (original plus 8) in order to conform super().__init__( self, @@ -35,24 +54,65 @@ def __init__( capacity=capacity ) - # NOT TESTED def transform(self, data_dict: DatasetDict, translation: np.array): - if not isinstance(data_dict, dict): - return - for k in data_dict.keys(): - if (self.placeholder): # if key is included in list provided - data_dict[k] += translation - else: - self.transform(data_dict[k], translation) + # TBD: + # return data_dict with positional arguments += translation + # return data_dict with positional arguments = translation + return def fractal_branch(self, data_dict: DatasetDict): + # return a new number of branches = dendrites ^ depth return def constant_branch(self, data_dict: DatasetDict): + # return current number of branches return def linear_branch(self, data_dict: DatasetDict): + # return a new number of branches = branches_count + n return + def time_split(self, data_dict: DatasetDict): + # return True when a set time has passed + return False + + def rel_pos_split(self, data_dict: DatasetDict): + # return True if there is a change of depth determined by relative position + return False + + def height_split(self, data_dict: DatasetDict): + # return True if there is a change of depth determined by height + return False + + def 
velocity_split(self, data_dict: DatasetDict): + # return True if there is a change of depth determined by velocity + return False + + def constant_split(self, data_dict: DatasetDict): + # return True every transition + return True + def insert(self, data_dict: DatasetDict): - return \ No newline at end of file + match self.split_method: + case "time": + return + case "rel_pos": + return + case "height": + return + case "vel": + print("whoop-dee-doo") + case _: + return + + match self.branch_method: + case "fractal": + self. + case "linear": + return + case "constant": + return + case _: + return + + From 7b3e345829fe835c653224de10cb58754eeca6d7 Mon Sep 17 00:00:00 2001 From: ryanvanderstelt Date: Tue, 8 Jul 2025 18:13:32 -0500 Subject: [PATCH 016/141] Insert function completed, requires testing --- .../data/fractal_symmetry_replay_buffer.py | 102 ++++++++++++++---- 1 file changed, 83 insertions(+), 19 deletions(-) diff --git a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py index 92da8ccd..c49be241 100644 --- a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py +++ b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py @@ -18,21 +18,29 @@ def __init__( split_method: str, branch_method: str, workspace_width: int, + + #Potentially temporary variables depth: int = 1, dendrites: int = None, timesplit_freq: int = None, + branch_count_rate_of_change: int = None, ): self.split_method=split_method # Determines method for when to change transformation number self.branch_method=branch_method # Determines method for how many transforms to make for a given transition self.workspace_width = workspace_width # Total 1D distance for transform calculations + + # Potentially temporary variables + self.current_branch_count = 1 # Current number of branches + self.current_depth = 1 # Current depth self.depth=depth # Total number of layers self.dendrites=dendrites # For 
fractal only self.timesplit_freq=timesplit_freq # For time-based only - + self.branch_count_rate_of_change=branch_count_rate_of_change # For linear only # TODO - # Initialize passed-in variables as needed + # Initialize passed-in variables as needed + # Temporary variables that are optional should be part of a *args or **kwargs # Add flags for variables passed to # Add datastore class in datastore.py # Create transform function @@ -54,36 +62,44 @@ def __init__( capacity=capacity ) + # REQUIRES TESTING def transform(self, data_dict: DatasetDict, translation: np.array): - # TBD: # return data_dict with positional arguments += translation - # return data_dict with positional arguments = translation - return + data_dict["observations"]["state"] += translation + data_dict["next_observations"]["state"] += translation + # REQUIRES TESTING def fractal_branch(self, data_dict: DatasetDict): # return a new number of branches = dendrites ^ depth - return + self.depth += 1 + return self.dendrites ** (self.depth - 1) + # REQUIRES TESTING def constant_branch(self, data_dict: DatasetDict): # return current number of branches - return + return self.current_branch_count + # REQUIRES TESTING def linear_branch(self, data_dict: DatasetDict): # return a new number of branches = branches_count + n - return + return self.current_branch_count + self.branch_count_rate_of_change + # UNFINISHED def time_split(self, data_dict: DatasetDict): # return True when a set time has passed return False + # UNFINISHED def rel_pos_split(self, data_dict: DatasetDict): # return True if there is a change of depth determined by relative position return False + # UNFINISHED def height_split(self, data_dict: DatasetDict): # return True if there is a change of depth determined by height return False + # UNFINISHED def velocity_split(self, data_dict: DatasetDict): # return True if there is a change of depth determined by velocity return False @@ -92,27 +108,75 @@ def constant_split(self, data_dict: DatasetDict): 
# return True every transition return True + # REQUIRES TESTING def insert(self, data_dict: DatasetDict): + # Choose when to change number of branches match self.split_method: case "time": - return + split = self.time_split case "rel_pos": - return + split = self.rel_pos_split case "height": - return - case "vel": - print("whoop-dee-doo") + split = self.height_split + case "velocity": + split = self.velocity_split case _: - return + raise ValueError("incorrect value passed to split_method") + # Choose how number of branches changes match self.branch_method: case "fractal": - self. + branch = self.fractal_branch case "linear": - return + branch = self.linear_branch case "constant": - return + branch = self.constant_branch case _: - return - + raise ValueError("incorrect value passed to branch_method") + + # Update number of branches if needed + if split(data_dict): + self.current_branch_count = branch(data_dict) + + #---------------------TEMPORARY SECTION TO BE IMPROVED------------------------# + # Save positions + rx = data_dict["observations"]["state"][0] + # ry = data_dict["observations"]["state"][1] + bx = data_dict["observations"]["state"][7] + # by = data_dict["observations"]["state"][8] + + rx2 = data_dict["next_observations"]["state"][0] + # ry2 = data_dict["next_observations"]["state"][1] + bx2 = data_dict["next_observations"]["state"][7] + # by2 = data_dict["next_observations"]["state"][8] + + # OR set to extreme for iterations + x = self.workspace_width / 2 + data_dict["observations"]["state"] -= np.array([x, 0, 0, 0, 0, 0, 0, x, 0, 0]) + data_dict["next_observations"]["state"] -= np.array([x, 0, 0, 0, 0, 0, 0, x, 0, 0]) + #-----------------------------------------------------------------------------# + + from_left = [] + from_center = [] + + # Initial insert of most-extreme branch + dx = self.workspace_width / (self.current_branch_count * 2) + super().insert(self.transform(data_dict, np.array([dx, 0, 0, 0, 0, 0, 0, dx, 0, 0]))) + dx = dx * 2 + + 
from_left.append(data_dict["observations"]["state"][0] - self.workspace_width) + from_center.append(data_dict["observations"]["state"][0] - rx) + + # Insert of rest of branches + for t in range(0, self.current_branch_count - 1): + super().insert(self.transform( + data_dict, + np.array([dx, 0, 0, 0, 0, 0, 0, dx, 0, 0]) + )) + from_left.append(data_dict["observations"]["state"][0] - self.workspace_width) + from_center.append(data_dict["observations"]["state"][0] - rx) + print(self.workspace_width) + print(from_left) + print(from_center) + \ No newline at end of file From 8971a9e97ff810b76e3adb410faed8d620474a67 Mon Sep 17 00:00:00 2001 From: ryanvanderstelt Date: Wed, 9 Jul 2025 17:23:42 -0500 Subject: [PATCH 017/141] Created test file for fsrb and tested insert() and fractal_branch() --- .../async_sac_state_sim.py | 2 +- .../data/fractal_symmetry_replay_buffer.py | 77 ++++++++++--------- serl_launcher/serl_launcher/data/fsrb_test.py | 72 +++++++++++++++++ 3 files changed, 114 insertions(+), 37 deletions(-) create mode 100644 serl_launcher/serl_launcher/data/fsrb_test.py diff --git a/examples/async_sac_state_sim/async_sac_state_sim.py b/examples/async_sac_state_sim/async_sac_state_sim.py index 27d43c3c..7d14b070 100644 --- a/examples/async_sac_state_sim/async_sac_state_sim.py +++ b/examples/async_sac_state_sim/async_sac_state_sim.py @@ -49,7 +49,7 @@ flags.DEFINE_integer("eval_period", 2000, "Evaluation period.") flags.DEFINE_integer("eval_n_trajs", 5, "Number of trajectories for evaluation.") -# flag to indicate if this is a leaner or a actor +# flag to indicate if this is a learner or a actor flags.DEFINE_boolean("learner", False, "Is this a learner or a trainer.") flags.DEFINE_boolean("actor", False, "Is this a learner or a trainer.") flags.DEFINE_boolean("render", False, "Render the environment.") diff --git a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py index 
c49be241..e6cae446 100644 --- a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py +++ b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py @@ -1,12 +1,7 @@ -import copy -from typing import Iterable, Optional, Tuple - import gym import numpy as np -from serl_launcher.data.dataset import DatasetDict, _sample +from serl_launcher.data.dataset import DatasetDict from serl_launcher.data.replay_buffer import ReplayBuffer -from flax.core import frozen_dict -from gym.spaces import Box class FractalSymmetryReplayBuffer(ReplayBuffer): def __init__( @@ -20,10 +15,10 @@ def __init__( workspace_width: int, #Potentially temporary variables - depth: int = 1, - dendrites: int = None, - timesplit_freq: int = None, - branch_count_rate_of_change: int = None, + depth: int, + dendrites: int, + timesplit_freq: int, + branch_count_rate_of_change: int, ): self.split_method=split_method # Determines method for when to change transformation number @@ -31,8 +26,9 @@ def __init__( self.workspace_width = workspace_width # Total 1D distance for transform calculations # Potentially temporary variables + self.transition_num = 0 self.current_branch_count = 1 # Current number of branches - self.current_depth = 1 # Current depth + self.current_depth = 0 # Current depth self.depth=depth # Total number of layers self.dendrites=dendrites # For fractal only self.timesplit_freq=timesplit_freq # For time-based only @@ -56,23 +52,22 @@ def __init__( # grid-based will inherently evenly space lowest transforms, but fractal symmetry will be restricted to 9 dendrites (original plus 8) in order to conform super().__init__( - self, observation_space=observation_space, action_space=action_space, - capacity=capacity + capacity=capacity, ) # REQUIRES TESTING def transform(self, data_dict: DatasetDict, translation: np.array): # return data_dict with positional arguments += translation - data_dict["observations"]["state"] += translation - data_dict["next_observations"]["state"] += 
translation + data_dict["observations"] += translation + data_dict["next_observations"] += translation # REQUIRES TESTING def fractal_branch(self, data_dict: DatasetDict): # return a new number of branches = dendrites ^ depth - self.depth += 1 - return self.dendrites ** (self.depth - 1) + self.current_depth += 1 + return self.dendrites ** (self.current_depth) # REQUIRES TESTING def constant_branch(self, data_dict: DatasetDict): @@ -84,6 +79,13 @@ def linear_branch(self, data_dict: DatasetDict): # return a new number of branches = branches_count + n return self.current_branch_count + self.branch_count_rate_of_change + # FOR TESTING ONLY + def test_split(self, data_dict: DatasetDict): + split = False + if self.transition_num % 2: + split = True + return split + # UNFINISHED def time_split(self, data_dict: DatasetDict): # return True when a set time has passed @@ -120,6 +122,8 @@ def insert(self, data_dict: DatasetDict): split = self.height_split case "velocity": split = self.velocity_split + case "test": + split = self.test_split case _: raise ValueError("incorrect value passed to split_method") @@ -134,26 +138,27 @@ def insert(self, data_dict: DatasetDict): case _: raise ValueError("incorrect value passed to branch_method") + self.transition_num += 1 # Update number of branches if needed if split(data_dict): self.current_branch_count = branch(data_dict) #---------------------TEMPORARY SECTION TO BE IMPROVED------------------------# # Save positions - rx = data_dict["observations"]["state"][0] + rx = data_dict["observations"][0] # ry = data_dict["observations"]["state"][1] - bx = data_dict["observations"]["state"][7] + bx = data_dict["observations"][7] # by = data_dict["observations"]["state"][8] - rx2 = data_dict["next_observations"]["state"][0] + rx2 = data_dict["next_observations"][0] # ry2 = data_dict["next_observations"]["state"][1] - bx2 = data_dict["next_observations"]["state"][7] + bx2 = data_dict["next_observations"][7] # by2 = 
data_dict["next_observations"]["state"][8] # OR set to extreme for iterations x = self.workspace_width / 2 - data_dict["observations"]["state"] -= np.array([x, 0, 0, 0, 0, 0, 0, x, 0, 0]) - data_dict["next_observations"]["state"] -= np.array([x, 0, 0, 0, 0, 0, 0, x, 0, 0]) + data_dict["observations"] -= np.array([x, 0, 0, 0, 0, 0, 0, x, 0, 0]) + data_dict["next_observations"] -= np.array([x, 0, 0, 0, 0, 0, 0, x, 0, 0]) #-----------------------------------------------------------------------------# from_left = [] @@ -161,22 +166,22 @@ def insert(self, data_dict: DatasetDict): # Initial insert of most-extreme branch dx = self.workspace_width / (self.current_branch_count * 2) - super().insert(self.transform(data_dict, np.array([dx, 0, 0, 0, 0, 0, 0, dx, 0, 0]))) + self.transform(data_dict, np.array([dx, 0, 0, 0, 0, 0, 0, dx, 0, 0])) + super().insert(data_dict) dx = dx * 2 - from_left.append(data_dict["observations"]["state"][0] - self.workspace_width) - from_center.append(data_dict["observations"]["state"][0] - rx) + from_left.append(data_dict["observations"][0] - self.workspace_width / 2) + from_center.append(data_dict["observations"][0] - rx) # Insert of rest of branches for t in range(0, self.current_branch_count - 1): - super().insert(self.transform( - data_dict, - np.array([dx, 0, 0, 0, 0, 0, 0, dx, 0, 0]) - )) - from_left.append(data_dict["observations"]["state"][0] - self.workspace_width) - from_center.append(data_dict["observations"]["state"][0] - rx) + self.transform(data_dict, np.array([dx, 0, 0, 0, 0, 0, 0, dx, 0, 0])) + super().insert(data_dict) + from_left.append(data_dict["observations"][0] - self.workspace_width / 2) + from_center.append(data_dict["observations"][0] - rx) - print(self.workspace_width) - print(from_left) - print(from_center) - \ No newline at end of file + print("\ntransition:\t", self.transition_num,"\ndepth:\t\t", self.current_depth, sep="\t") + print("workspace width:", self.workspace_width, sep="\t") + print("transforms from '0':", 
np.round(from_left, 2), sep="\t") + print("transforms from center:", np.round(from_center, 2), "\n", sep="\t") + return diff --git a/serl_launcher/serl_launcher/data/fsrb_test.py b/serl_launcher/serl_launcher/data/fsrb_test.py new file mode 100644 index 00000000..53820dcd --- /dev/null +++ b/serl_launcher/serl_launcher/data/fsrb_test.py @@ -0,0 +1,72 @@ +import gym.wrappers +import numpy as np +import gym +from serl_launcher.data.dataset import DatasetDict +from serl_launcher.data.fractal_symmetry_replay_buffer import FractalSymmetryReplayBuffer +from absl import app, flags +import franka_sim + +FLAGS = flags.FLAGS + +flags.DEFINE_integer("replay_buffer_capacity", 1000000, "Replay buffer capacity.") +flags.DEFINE_string("branch_method", "fractal", "placeholder") +flags.DEFINE_string("split_method", "test", "placeholder") +flags.DEFINE_integer("workspace_width", 50, "workspace width in centimeters") +flags.DEFINE_integer("depth", 4, "Total layers of depth") +flags.DEFINE_integer("dendrites", 3, "Dendrites for fractal branching") # Remember to set default to None +flags.DEFINE_integer("timesplit_freq", None, "Frequency of splits according to time") +flags.DEFINE_integer("branch_count_rate_of_change", None, "Rate of change for linear branching") + +def main(_): + + env = gym.make("PandaPickCube-v0") + env = gym.wrappers.FlattenObservation(env) + observation_space=env.observation_space + action_space=env.action_space + + buffer = FractalSymmetryReplayBuffer( + observation_space=observation_space, + action_space=action_space, + capacity=FLAGS.replay_buffer_capacity, + split_method=FLAGS.split_method, + branch_method=FLAGS.branch_method, + workspace_width=FLAGS.workspace_width, + depth=FLAGS.depth, + dendrites=FLAGS.dendrites, + timesplit_freq=FLAGS.timesplit_freq, + branch_count_rate_of_change=FLAGS.branch_count_rate_of_change, + ) + + obs = env.observation_space.sample() + actions = env.action_space.sample() + + actions = np.array([1, 2, 3, 4], dtype=np.float32) + + 
for idx in range(8): + obs = np.array([50, 0, 8-idx, 0, 0, 0, 0, 0, 0, 0], dtype=np.float32) + next_obs = obs - np.array([50, 0, 1, 0, 0, 0, 0, 0, 0, 0], dtype=np.float32) + rewards = np.float32(0) + # truncateds = False + masks=True + dones = False + if idx == 7: + rewards += 1 + masks = False + dones = True + transition = dict( + observations=obs, + next_observations=next_obs, + actions = actions, + rewards=rewards, + # truncateds=truncateds, + masks=masks, + dones=dones, + ) + buffer.insert(transition) + + print("finished\n") + + + +if __name__ == "__main__": + app.run(main) \ No newline at end of file From 7338bcb159203cd952b828895784045e0694e113 Mon Sep 17 00:00:00 2001 From: ryanvanderstelt Date: Thu, 10 Jul 2025 15:55:30 -0500 Subject: [PATCH 018/141] Updated test file and removed various bugs --- .../data/fractal_symmetry_replay_buffer.py | 54 +++----- serl_launcher/serl_launcher/data/fsrb_test.py | 127 ++++++++++++++---- 2 files changed, 120 insertions(+), 61 deletions(-) diff --git a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py index e6cae446..5b2834bb 100644 --- a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py +++ b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py @@ -26,7 +26,6 @@ def __init__( self.workspace_width = workspace_width # Total 1D distance for transform calculations # Potentially temporary variables - self.transition_num = 0 self.current_branch_count = 1 # Current number of branches self.current_depth = 0 # Current depth self.depth=depth # Total number of layers @@ -57,34 +56,36 @@ def __init__( capacity=capacity, ) - # REQUIRES TESTING def transform(self, data_dict: DatasetDict, translation: np.array): # return data_dict with positional arguments += translation data_dict["observations"] += translation data_dict["next_observations"] += translation + + # FOR TESTING ONLY + def test_branch(self): + 
self.current_branch_count = np.random.randint(2500) + while(not self.current_branch_count % 2): + self.current_branch_count = np.random.randint(2500) + return self.current_branch_count - # REQUIRES TESTING - def fractal_branch(self, data_dict: DatasetDict): + def fractal_branch(self): # return a new number of branches = dendrites ^ depth self.current_depth += 1 return self.dendrites ** (self.current_depth) # REQUIRES TESTING - def constant_branch(self, data_dict: DatasetDict): + def constant_branch(self): # return current number of branches return self.current_branch_count # REQUIRES TESTING - def linear_branch(self, data_dict: DatasetDict): + def linear_branch(self): # return a new number of branches = branches_count + n return self.current_branch_count + self.branch_count_rate_of_change # FOR TESTING ONLY def test_split(self, data_dict: DatasetDict): - split = False - if self.transition_num % 2: - split = True - return split + return True # UNFINISHED def time_split(self, data_dict: DatasetDict): @@ -135,53 +136,40 @@ def insert(self, data_dict: DatasetDict): branch = self.linear_branch case "constant": branch = self.constant_branch + case "test": + branch = self.test_branch case _: raise ValueError("incorrect value passed to branch_method") - self.transition_num += 1 # Update number of branches if needed if split(data_dict): - self.current_branch_count = branch(data_dict) + self.current_branch_count = branch() #---------------------TEMPORARY SECTION TO BE IMPROVED------------------------# # Save positions - rx = data_dict["observations"][0] + # rx = data_dict["observations"][0] # ry = data_dict["observations"]["state"][1] - bx = data_dict["observations"][7] + # bx = data_dict["observations"][7] # by = data_dict["observations"]["state"][8] - rx2 = data_dict["next_observations"][0] + # rx2 = data_dict["next_observations"][0] # ry2 = data_dict["next_observations"]["state"][1] - bx2 = data_dict["next_observations"][7] + # bx2 = data_dict["next_observations"][7] # 
by2 = data_dict["next_observations"]["state"][8] # OR set to extreme for iterations - x = self.workspace_width / 2 - data_dict["observations"] -= np.array([x, 0, 0, 0, 0, 0, 0, x, 0, 0]) - data_dict["next_observations"] -= np.array([x, 0, 0, 0, 0, 0, 0, x, 0, 0]) + x = -self.workspace_width / 2 + self.transform(data_dict, np.array([x, 0, 0, 0, 0, 0, 0, x, 0, 0])) #-----------------------------------------------------------------------------# - from_left = [] - from_center = [] - # Initial insert of most-extreme branch dx = self.workspace_width / (self.current_branch_count * 2) self.transform(data_dict, np.array([dx, 0, 0, 0, 0, 0, 0, dx, 0, 0])) super().insert(data_dict) dx = dx * 2 - from_left.append(data_dict["observations"][0] - self.workspace_width / 2) - from_center.append(data_dict["observations"][0] - rx) - # Insert of rest of branches for t in range(0, self.current_branch_count - 1): self.transform(data_dict, np.array([dx, 0, 0, 0, 0, 0, 0, dx, 0, 0])) super().insert(data_dict) - from_left.append(data_dict["observations"][0] - self.workspace_width / 2) - from_center.append(data_dict["observations"][0] - rx) - - print("\ntransition:\t", self.transition_num,"\ndepth:\t\t", self.current_depth, sep="\t") - print("workspace width:", self.workspace_width, sep="\t") - print("transforms from '0':", np.round(from_left, 2), sep="\t") - print("transforms from center:", np.round(from_center, 2), "\n", sep="\t") - return + diff --git a/serl_launcher/serl_launcher/data/fsrb_test.py b/serl_launcher/serl_launcher/data/fsrb_test.py index 53820dcd..6940e5ce 100644 --- a/serl_launcher/serl_launcher/data/fsrb_test.py +++ b/serl_launcher/serl_launcher/data/fsrb_test.py @@ -9,7 +9,7 @@ FLAGS = flags.FLAGS flags.DEFINE_integer("replay_buffer_capacity", 1000000, "Replay buffer capacity.") -flags.DEFINE_string("branch_method", "fractal", "placeholder") +flags.DEFINE_string("branch_method", "test", "placeholder") flags.DEFINE_string("split_method", "test", "placeholder") 
flags.DEFINE_integer("workspace_width", 50, "workspace width in centimeters") flags.DEFINE_integer("depth", 4, "Total layers of depth") @@ -19,6 +19,7 @@ def main(_): + # Initialize replay buffer env = gym.make("PandaPickCube-v0") env = gym.wrappers.FlattenObservation(env) observation_space=env.observation_space @@ -37,34 +38,104 @@ def main(_): branch_count_rate_of_change=FLAGS.branch_count_rate_of_change, ) - obs = env.observation_space.sample() - actions = env.action_space.sample() - - actions = np.array([1, 2, 3, 4], dtype=np.float32) - - for idx in range(8): - obs = np.array([50, 0, 8-idx, 0, 0, 0, 0, 0, 0, 0], dtype=np.float32) - next_obs = obs - np.array([50, 0, 1, 0, 0, 0, 0, 0, 0, 0], dtype=np.float32) - rewards = np.float32(0) - # truncateds = False - masks=True - dones = False - if idx == 7: - rewards += 1 - masks = False - dones = True - transition = dict( - observations=obs, - next_observations=next_obs, - actions = actions, - rewards=rewards, - # truncateds=truncateds, - masks=masks, - dones=dones, - ) - buffer.insert(transition) + observation, info = env.reset() + action = env.action_space.sample() + next_observation, reward, terminated, truncated, info = env.step(action) + data_dict = dict( + observations=observation, + next_observations=next_observation, + actions=action, + rewards=reward, + masks=False, + dones=truncated or terminated, + ) + + del env, observation, next_observation, action, reward, truncated, terminated, action_space, observation_space, info, _ + + # transform() test + + expected = np.copy(data_dict["observations"]) * 2 + buffer.transform(data_dict, np.copy(data_dict["observations"])) + result = data_dict["observations"] + assert np.array_equal(result, expected), f"\033[31mTEST FAILED\033[0m transform() test failed (expected {expected} but got {result})" + + del result, expected + + print("\033[32mTEST PASSED \033[0m transform() tests passed") + # branch() tests + + ## fractal + + for dendrites in range(1, 10, 2): + 
buffer.dendrites=dendrites + for depth in range(1, 6): + buffer.current_depth=depth - 1 + result = buffer.fractal_branch() + expected = dendrites ** depth + assert result == expected, f"\033[31mTEST FAILED\033[0m fractal_branch() test failed (expected {expected} but got {result})" + + del dendrites, depth, result, expected + + ## constant + + ## linear + + ## exponential + + print("\033[32mTEST PASSED \033[0m _branch() tests passed") + + # split() tests + + ## time + + ## height + + ## rel_pos + + ## velocity + + print("\033[32mTEST PASSED \033[0m _split() tests passed") + + # insert() tests + data_dict["observations"][0] = 0 + start = data_dict["observations"][0] - FLAGS.workspace_width / 2 + + for i in range(0,10): + buffer.insert(data_dict) + for idx in range(buffer.current_branch_count): + expected = FLAGS.workspace_width * (2 * idx + 1)/(2 * buffer.current_branch_count) + result = buffer.dataset_dict["observations"][idx][0] - start + assert result == expected, f"\033[31mTEST FAILED\033[0m insert() test failed (expected {expected} but got {result})" + print(f"{result} is correct!!!") + buffer._insert_index = 0 + + del i, idx, result, expected + + print("\033[32mTEST PASSED \033[0m insert() tests passed") + + # for idx in range(8): + # obs = np.array([50, 0, 8-idx, 0, 0, 0, 0, 0, 0, 0], dtype=np.float32) + # next_obs = obs - np.array([50, 0, 1, 0, 0, 0, 0, 0, 0, 0], dtype=np.float32) + # rewards = np.float32(0) + # # truncateds = False + # masks=True + # dones = False + # if idx == 7: + # rewards += 1 + # masks = False + # dones = True + # transition = dict( + # observations=obs, + # next_observations=next_obs, + # actions = actions, + # rewards=rewards, + # # truncateds=truncateds, + # masks=masks, + # dones=dones, + # ) + # buffer.insert(transition) - print("finished\n") + print("\nfinished!\n") From 836e40489c542cb9025700e3ff1ea9d383391dd8 Mon Sep 17 00:00:00 2001 From: ryanvanderstelt Date: Thu, 10 Jul 2025 19:12:47 -0500 Subject: [PATCH 019/141] Tackling 
how bad rounding errors are tomorrow --- .../data/fractal_symmetry_replay_buffer.py | 40 ++++++------------- serl_launcher/serl_launcher/data/fsrb_test.py | 4 +- 2 files changed, 14 insertions(+), 30 deletions(-) diff --git a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py index 5b2834bb..363a8e28 100644 --- a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py +++ b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py @@ -63,10 +63,11 @@ def transform(self, data_dict: DatasetDict, translation: np.array): # FOR TESTING ONLY def test_branch(self): - self.current_branch_count = np.random.randint(2500) - while(not self.current_branch_count % 2): - self.current_branch_count = np.random.randint(2500) - return self.current_branch_count + # self.current_branch_count = np.random.randint(2500) + # while(not self.current_branch_count % 2): + # self.current_branch_count = np.random.randint(2500) + # return self.current_branch_count + return 81 def fractal_branch(self): # return a new number of branches = dendrites ^ depth @@ -145,31 +146,14 @@ def insert(self, data_dict: DatasetDict): if split(data_dict): self.current_branch_count = branch() - #---------------------TEMPORARY SECTION TO BE IMPROVED------------------------# - # Save positions - # rx = data_dict["observations"][0] - # ry = data_dict["observations"]["state"][1] - # bx = data_dict["observations"][7] - # by = data_dict["observations"]["state"][8] - - # rx2 = data_dict["next_observations"][0] - # ry2 = data_dict["next_observations"]["state"][1] - # bx2 = data_dict["next_observations"][7] - # by2 = data_dict["next_observations"]["state"][8] - - # OR set to extreme for iterations - x = -self.workspace_width / 2 + # Initial transform of first branch + dx = self.workspace_width/self.current_branch_count + x = (-self.workspace_width + dx)/2 self.transform(data_dict, np.array([x, 0, 0, 0, 0, 0, 0, x, 0, 
0])) - #-----------------------------------------------------------------------------# - - # Initial insert of most-extreme branch - dx = self.workspace_width / (self.current_branch_count * 2) - self.transform(data_dict, np.array([dx, 0, 0, 0, 0, 0, 0, dx, 0, 0])) - super().insert(data_dict) - dx = dx * 2 - # Insert of rest of branches - for t in range(0, self.current_branch_count - 1): - self.transform(data_dict, np.array([dx, 0, 0, 0, 0, 0, 0, dx, 0, 0])) + # Insert of transformed branches + for t in range(0, self.current_branch_count): super().insert(data_dict) + self.transform(data_dict, np.array([dx, 0, 0, 0, 0, 0, 0, dx, 0, 0])) + diff --git a/serl_launcher/serl_launcher/data/fsrb_test.py b/serl_launcher/serl_launcher/data/fsrb_test.py index 6940e5ce..0059e156 100644 --- a/serl_launcher/serl_launcher/data/fsrb_test.py +++ b/serl_launcher/serl_launcher/data/fsrb_test.py @@ -9,9 +9,9 @@ FLAGS = flags.FLAGS flags.DEFINE_integer("replay_buffer_capacity", 1000000, "Replay buffer capacity.") -flags.DEFINE_string("branch_method", "test", "placeholder") +flags.DEFINE_string("branch_method", "constant", "placeholder") flags.DEFINE_string("split_method", "test", "placeholder") -flags.DEFINE_integer("workspace_width", 50, "workspace width in centimeters") +flags.DEFINE_float("workspace_width", 1, "workspace width in centimeters") flags.DEFINE_integer("depth", 4, "Total layers of depth") flags.DEFINE_integer("dendrites", 3, "Dendrites for fractal branching") # Remember to set default to None flags.DEFINE_integer("timesplit_freq", None, "Frequency of splits according to time") From 457f707fb9c885e491b75ab3339a1472776586c2 Mon Sep 17 00:00:00 2001 From: ryanvanderstelt Date: Mon, 14 Jul 2025 12:53:31 -0500 Subject: [PATCH 020/141] Finished basic test file --- serl_launcher/serl_launcher/data/fsrb_test.py | 91 +++++++++++-------- 1 file changed, 55 insertions(+), 36 deletions(-) diff --git a/serl_launcher/serl_launcher/data/fsrb_test.py 
b/serl_launcher/serl_launcher/data/fsrb_test.py index 0059e156..a5f05254 100644 --- a/serl_launcher/serl_launcher/data/fsrb_test.py +++ b/serl_launcher/serl_launcher/data/fsrb_test.py @@ -5,13 +5,14 @@ from serl_launcher.data.fractal_symmetry_replay_buffer import FractalSymmetryReplayBuffer from absl import app, flags import franka_sim +# import pandas as pd FLAGS = flags.FLAGS -flags.DEFINE_integer("replay_buffer_capacity", 1000000, "Replay buffer capacity.") -flags.DEFINE_string("branch_method", "constant", "placeholder") +flags.DEFINE_integer("replay_buffer_capacity", 10000000, "Replay buffer capacity.") +flags.DEFINE_string("branch_method", "test", "placeholder") flags.DEFINE_string("split_method", "test", "placeholder") -flags.DEFINE_float("workspace_width", 1, "workspace width in centimeters") +flags.DEFINE_float("workspace_width", 0.5, "workspace width in centimeters") flags.DEFINE_integer("depth", 4, "Total layers of depth") flags.DEFINE_integer("dendrites", 3, "Dendrites for fractal branching") # Remember to set default to None flags.DEFINE_integer("timesplit_freq", None, "Frequency of splits according to time") @@ -61,7 +62,7 @@ def main(_): del result, expected - print("\033[32mTEST PASSED \033[0m transform() tests passed") + print("\n\033[32mTEST PASSED \033[0m transform() tests passed") # branch() tests ## fractal @@ -97,43 +98,61 @@ def main(_): print("\033[32mTEST PASSED \033[0m _split() tests passed") # insert() tests + + # important_indices = np.array([0, 7]) + + # num_branches = 5000 + # workspace_options = 20 + + # data = dict( + # workspace_width = np.empty(shape=(num_branches//2 * workspace_options), dtype=np.float32), + # num_branches = np.empty(shape=(num_branches//2 * workspace_options), dtype=int), + # density = np.empty(shape=(num_branches//2 * workspace_options), dtype=np.float32) + # ) + + # for w in range(0, workspace_options): + # total_success = 0.0 + # for i in range(1,num_branches, 2): + # finished = False + # 
buffer.current_branch_count = i + # data_dict["observations"][important_indices] = 0 + # start = data_dict["observations"][important_indices[0]] - FLAGS.workspace_width / 2 + # buffer.insert(data_dict) + # for idx in range(buffer.current_branch_count): + # expected = round(FLAGS.workspace_width * (2 * idx + 1)/(2 * buffer.current_branch_count), 3) + # result = round(buffer.dataset_dict["observations"][buffer._insert_index - buffer.current_branch_count + idx][important_indices[0]] - start, 3) + # # assert result == expected, f"\033[31mTEST FAILED\033[0m insert() test failed (expected {expected} but got {result})" + # if result != expected: + # # print(f"At ww = {FLAGS.workspace_width}, maximum transforms = {i - 2}") + # # print(f"{i} failed at {idx}") + # finished = True + + # break + # if not finished: + # print(f"SUCCESS at {i} branches") + # total_success += 1 + # data["workspace_width"][i//2 + (num_branches//2) * w] = 0.05 * (w + 1) + # data["num_branches"][i//2 + (num_branches//2) * w] = i + # data["density"][i//2 + (num_branches//2) * w] = total_success/(i//2 + 1) + + # FLAGS.workspace_width = FLAGS.workspace_width + 0.05 + # buffer.workspace_width = FLAGS.workspace_width + # del i, idx, result, expected + + # df = pd.DataFrame(data) + # df.to_excel('output.xlsx', index=False) + data_dict["observations"][0] = 0 start = data_dict["observations"][0] - FLAGS.workspace_width / 2 - for i in range(0,10): - buffer.insert(data_dict) - for idx in range(buffer.current_branch_count): - expected = FLAGS.workspace_width * (2 * idx + 1)/(2 * buffer.current_branch_count) - result = buffer.dataset_dict["observations"][idx][0] - start - assert result == expected, f"\033[31mTEST FAILED\033[0m insert() test failed (expected {expected} but got {result})" - print(f"{result} is correct!!!") - buffer._insert_index = 0 - - del i, idx, result, expected + + buffer.insert(data_dict) + for idx in range(buffer.current_branch_count): + expected = round(FLAGS.workspace_width * (2 * idx + 
1)/(2 * buffer.current_branch_count), 3) + result = round(buffer.dataset_dict["observations"][idx][0] - start, 3) + assert result == expected, f"\033[31mTEST FAILED\033[0m insert() test failed (expected {expected} but got {result})" print("\033[32mTEST PASSED \033[0m insert() tests passed") - - # for idx in range(8): - # obs = np.array([50, 0, 8-idx, 0, 0, 0, 0, 0, 0, 0], dtype=np.float32) - # next_obs = obs - np.array([50, 0, 1, 0, 0, 0, 0, 0, 0, 0], dtype=np.float32) - # rewards = np.float32(0) - # # truncateds = False - # masks=True - # dones = False - # if idx == 7: - # rewards += 1 - # masks = False - # dones = True - # transition = dict( - # observations=obs, - # next_observations=next_obs, - # actions = actions, - # rewards=rewards, - # # truncateds=truncateds, - # masks=masks, - # dones=dones, - # ) - # buffer.insert(transition) print("\nfinished!\n") From ed940dd712df19b3443ac1db4d8184f512173a0f Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Mon, 14 Jul 2025 16:17:31 -0500 Subject: [PATCH 021/141] very basic spacing addition --- docs/real_franka.md | 0 docs/sim_quick_start.md | 0 2 files changed, 0 insertions(+), 0 deletions(-) mode change 100644 => 100755 docs/real_franka.md mode change 100644 => 100755 docs/sim_quick_start.md diff --git a/docs/real_franka.md b/docs/real_franka.md old mode 100644 new mode 100755 diff --git a/docs/sim_quick_start.md b/docs/sim_quick_start.md old mode 100644 new mode 100755 From 234ec1e8f65640038cc3048bf96db0da51fc274a Mon Sep 17 00:00:00 2001 From: ryanvanderstelt Date: Mon, 14 Jul 2025 17:24:40 -0500 Subject: [PATCH 022/141] prevent latest version of opencv-python for compatibility reasons --- serl_launcher/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/serl_launcher/requirements.txt b/serl_launcher/requirements.txt index 86f92e95..efb248aa 100644 --- a/serl_launcher/requirements.txt +++ b/serl_launcher/requirements.txt @@ -1,3 +1,4 @@ +opencv-python==4.11.0.86 gym >= 0.26 numpy>=1.24.3 flax>=0.8.0 
From efece64641c13e257ae1e2ad35a094ec84fc976d Mon Sep 17 00:00:00 2001 From: ryanvanderstelt Date: Tue, 15 Jul 2025 12:33:23 -0500 Subject: [PATCH 023/141] hooked up fractal_symmetry_replay_buffer in data_store.py and launcher.py as well as improving how variables are passed in --- .../serl_launcher/data/data_store.py | 65 +++++++- .../data/fractal_symmetry_replay_buffer.py | 143 ++++++++++-------- serl_launcher/serl_launcher/data/fsrb_test.py | 40 +++-- serl_launcher/serl_launcher/utils/launcher.py | 17 +++ 4 files changed, 185 insertions(+), 80 deletions(-) diff --git a/serl_launcher/serl_launcher/data/data_store.py b/serl_launcher/serl_launcher/data/data_store.py index 20e65813..54430dc6 100644 --- a/serl_launcher/serl_launcher/data/data_store.py +++ b/serl_launcher/serl_launcher/data/data_store.py @@ -7,6 +7,9 @@ from serl_launcher.data.memory_efficient_replay_buffer import ( MemoryEfficientReplayBuffer, ) +from serl_launcher.data.fractal_symmetry_replay_buffer import ( + FractalSymmetryReplayBuffer +) from agentlace.data.data_store import DataStoreBase @@ -112,8 +115,6 @@ def insert(self, data): RLDSStepType.TRUNCATION, }: self.step_type = RLDSStepType.RESTART - elif self.step_type == RLDSStepType.TRUNCATION: - self.step_type = RLDSStepType.RESTART elif not data["masks"]: # 0 is done, 1 is not done self.step_type = RLDSStepType.TERMINATION elif data["dones"]: @@ -143,6 +144,66 @@ def latest_data_id(self): def get_latest_data(self, from_id: int): raise NotImplementedError # TODO +class FractalSymmetryReplayBufferDataStore(FractalSymmetryReplayBuffer, DataStoreBase): + def __init__( + self, + observation_space: gym.Space, + action_space: gym.Space, + capacity: int, + branch_method: str, + split_method: str, + workspace_width: int, + rlds_logger: Optional[RLDSLogger] = None, + **kwargs: dict, + ): + FractalSymmetryReplayBuffer.__init__(self, observation_space, action_space, capacity, branch_method, split_method, workspace_width **kwargs) + 
DataStoreBase.__init__(self, capacity) + self._lock = Lock() + self._logger = None + + if rlds_logger: + self.step_type = RLDSStepType.TERMINATION # to init the state for restart + self._logger = rlds_logger + + # ensure thread safety + def insert(self, data): + with self._lock: + super(FractalSymmetryReplayBufferDataStore, self).insert(data) + + # TODO: Data logging currently does NOT WORK as shown if we want to log our transformed transitions + # add data to the rlds logger + if self._logger: + if self.step_type in { + RLDSStepType.TERMINATION, + RLDSStepType.TRUNCATION, + }: + self.step_type = RLDSStepType.RESTART + elif not data["masks"]: # 0 is done, 1 is not done + self.step_type = RLDSStepType.TERMINATION + elif data["dones"]: + self.step_type = RLDSStepType.TRUNCATION + else: + self.step_type = RLDSStepType.TRANSITION + + self._logger( + action=data["actions"], + obs=data["next_observations"], # TODO: check if this is correct + reward=data["rewards"], + step_type=self.step_type, + ) + + # ensure thread safety + def sample(self, *args, **kwargs): + with self._lock: + return super(FractalSymmetryReplayBufferDataStore, self).sample(*args, **kwargs) + + # NOTE: method for DataStoreBase + def latest_data_id(self): + return self._insert_index + + # NOTE: method for DataStoreBase + def get_latest_data(self, from_id: int): + raise NotImplementedError # TODO def populate_data_store( data_store: DataStoreBase, diff --git a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py index 363a8e28..1aa962b7 100644 --- a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py +++ b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py @@ -9,43 +9,83 @@ def __init__( observation_space: gym.Space, action_space: gym.Space, capacity: int, - - split_method: str, branch_method: str, + split_method: str, workspace_width: int, - - #Potentially temporary variables - depth: int, - 
dendrites: int, - timesplit_freq: int, - branch_count_rate_of_change: int, + **kwargs: dict ): + self.current_branch_count=1 - self.split_method=split_method # Determines method for when to change transformation number - self.branch_method=branch_method # Determines method for how many transforms to make for a given transition - self.workspace_width = workspace_width # Total 1D distance for transform calculations + match split_method: + case "time": + assert "timesplit_freq" in kwargs.keys(), "\033[31mERROR \033[0mtimesplit_freq must be defined for split_method \"time\"" + self.timesplit_freq=kwargs["timesplit_freq"] + del kwargs["timesplit_freq"] + self.split = self.time_split + case "rel_pos": + print("NOT IMPLEMENTED") + self.split = self.rel_pos_split + case "height": + print("NOT IMPLEMENTED") + self.split = self.height_split + case "velocity": + print("NOT IMPLEMENTED") + self.split = self.velocity_split + case "test": + self.split = self.test_split + case _: + raise ValueError("incorrect value passed to split_method") + + match branch_method: + case "fractal": + assert "depth" in kwargs.keys(), "\033[31mERROR \033[0mdepth must be defined for branch_method \"fractal\"" + self.depth=kwargs["depth"] + del kwargs["depth"] + + assert "dendrites" in kwargs.keys(), "\033[31mERROR \033[0mdendrites must be defined for branch_method \"fractal\"" + self.dendrites=kwargs["dendrites"] + del kwargs["dendrites"] + + self.branch = self.fractal_branch + + case "linear": + assert "branch_count_rate_of_change" in kwargs.keys(), "\033[31mERROR \033[0mbranch_count_rate_of_change must be defined for branch_method \"fractal\"" + self.branch_count_rate_of_change=kwargs["branch_count_rate_of_change"] + del kwargs["branch_count_rate_of_change"] + + assert "starting_branch_count" in kwargs.keys(), "\033[31mERROR \033[0mstarting_branch_count must be defined for branch_method \"constant\"" + self.current_branch_count = kwargs["starting_branch_count"] + del 
kwargs["starting_branch_count"] - # Potentially temporary variables - self.current_branch_count = 1 # Current number of branches - self.current_depth = 0 # Current depth - self.depth=depth # Total number of layers - self.dendrites=dendrites # For fractal only - self.timesplit_freq=timesplit_freq # For time-based only - self.branch_count_rate_of_change=branch_count_rate_of_change # For linear only + self.branch = self.linear_branch + + case "constant": + assert "starting_branch_count" in kwargs.keys(), "\033[31mERROR \033[0mstarting_branch_count must be defined for branch_method \"constant\"" + self.current_branch_count = kwargs["starting_branch_count"] + del kwargs["starting_branch_count"] + + self.branch = self.constant_branch + + case "test": + self.branch = self.test_branch + case _: + raise ValueError("incorrect value passed to branch_method") + + for i in kwargs.keys(): + print(f"\033[33mWARNING \033[0m {i} argument not used") + + + + self.workspace_width = workspace_width + self.current_depth = 0 + # TODO - # Initialize passed-in variables as needed - # Temporary variables that are optional should be part of a *args or **kwargs # Add flags for variables passed to - # Add datastore class in datastore.py - # Create transform function - # Create at least one functional branch_method - # Create at least one functional split_method - # Create automated test.py with tests for current methods # Create more methods with tests # # Future considerations: - # get_pos_dict_paths to return a list of paths to alter for more complex environments + # consider different strategy for better accuracy # when dealing with x and y for transforms, consider two versions: # radial trees at different angles can have strong fractal symmetry built in but will be more difficult to evenly space lowest level transforms # grid-based will inherently evenly space lowest transforms, but fractal symmetry will be restricted to 9 dendrites (original plus 8) in order to conform @@ -88,70 +128,41 
@@ def linear_branch(self): def test_split(self, data_dict: DatasetDict): return True - # UNFINISHED + # NOT IMPLEMENTED def time_split(self, data_dict: DatasetDict): # return True when a set time has passed - return False + raise NotImplementedError - # UNFINISHED + # NOT IMPLEMENTED def rel_pos_split(self, data_dict: DatasetDict): # return True if there is a change of depth determined by relative position - return False + raise NotImplementedError - # UNFINISHED + # NOT IMPLEMENTED def height_split(self, data_dict: DatasetDict): # return True if there is a change of depth determined by height - return False + raise NotImplementedError - # UNFINISHED + # NOT IMPLEMENTED def velocity_split(self, data_dict: DatasetDict): # return True if there is a change of depth determined by velocity - return False + raise NotImplementedError def constant_split(self, data_dict: DatasetDict): # return True every transition return True - # REQUIRES TESTING def insert(self, data_dict: DatasetDict): - # Choose when to change number of branches - match self.split_method: - case "time": - split = self.time_split - case "rel_pos": - split = self.rel_pos_split - case "height": - split = self.height_split - case "velocity": - split = self.velocity_split - case "test": - split = self.test_split - case _: - raise ValueError("incorrect value passed to split_method") - - # Choose how number of branches changes - match self.branch_method: - case "fractal": - branch = self.fractal_branch - case "linear": - branch = self.linear_branch - case "constant": - branch = self.constant_branch - case "test": - branch = self.test_branch - case _: - raise ValueError("incorrect value passed to branch_method") # Update number of branches if needed - if split(data_dict): - self.current_branch_count = branch() + if self.split(data_dict): + self.current_branch_count = self.branch() - # Initial transform of first branch + # Transform and insert branches dx = self.workspace_width/self.current_branch_count x = 
(-self.workspace_width + dx)/2 self.transform(data_dict, np.array([x, 0, 0, 0, 0, 0, 0, x, 0, 0])) - # Insert of transformed branches for t in range(0, self.current_branch_count): super().insert(data_dict) self.transform(data_dict, np.array([dx, 0, 0, 0, 0, 0, 0, dx, 0, 0])) diff --git a/serl_launcher/serl_launcher/data/fsrb_test.py b/serl_launcher/serl_launcher/data/fsrb_test.py index a5f05254..02d2a026 100644 --- a/serl_launcher/serl_launcher/data/fsrb_test.py +++ b/serl_launcher/serl_launcher/data/fsrb_test.py @@ -1,7 +1,7 @@ import gym.wrappers import numpy as np import gym -from serl_launcher.data.dataset import DatasetDict +from serl_launcher.utils.launcher import make_replay_buffer from serl_launcher.data.fractal_symmetry_replay_buffer import FractalSymmetryReplayBuffer from absl import app, flags import franka_sim @@ -9,7 +9,7 @@ FLAGS = flags.FLAGS -flags.DEFINE_integer("replay_buffer_capacity", 10000000, "Replay buffer capacity.") +flags.DEFINE_integer("capacity", 10000000, "Replay buffer capacity.") flags.DEFINE_string("branch_method", "test", "placeholder") flags.DEFINE_string("split_method", "test", "placeholder") flags.DEFINE_float("workspace_width", 0.5, "workspace width in centimeters") @@ -23,22 +23,38 @@ def main(_): # Initialize replay buffer env = gym.make("PandaPickCube-v0") env = gym.wrappers.FlattenObservation(env) - observation_space=env.observation_space - action_space=env.action_space buffer = FractalSymmetryReplayBuffer( - observation_space=observation_space, - action_space=action_space, - capacity=FLAGS.replay_buffer_capacity, + observation_space=env.observation_space, + action_space=env.action_space, + capacity=FLAGS.capacity, split_method=FLAGS.split_method, branch_method=FLAGS.branch_method, workspace_width=FLAGS.workspace_width, - depth=FLAGS.depth, - dendrites=FLAGS.dendrites, - timesplit_freq=FLAGS.timesplit_freq, - branch_count_rate_of_change=FLAGS.branch_count_rate_of_change, + kwargs={ + "depth": FLAGS.depth, + "dendrites": 
FLAGS.dendrites, + "timesplit_freq": FLAGS.timesplit_freq, + "branch_count_rate_of_change": FLAGS.branch_count_rate_of_change, + } ) + replay_buffer = make_replay_buffer( + env, + capacity=FLAGS.capacity, + split_method=FLAGS.split_method, + branch_method=FLAGS.branch_method, + workspace_width=FLAGS.workspace_width, + kwargs={ + "depth": FLAGS.depth, + "dendrites": FLAGS.dendrites, + "timesplit_freq": FLAGS.timesplit_freq, + "branch_count_rate_of_change": FLAGS.branch_count_rate_of_change, + } + ) + + + observation, info = env.reset() action = env.action_space.sample() next_observation, reward, terminated, truncated, info = env.step(action) @@ -51,7 +67,7 @@ def main(_): dones=truncated or terminated, ) - del env, observation, next_observation, action, reward, truncated, terminated, action_space, observation_space, info, _ + del env, observation, next_observation, action, reward, truncated, terminated, info, _ # transform() test diff --git a/serl_launcher/serl_launcher/utils/launcher.py b/serl_launcher/serl_launcher/utils/launcher.py index 782221eb..ad4326eb 100644 --- a/serl_launcher/serl_launcher/utils/launcher.py +++ b/serl_launcher/serl_launcher/utils/launcher.py @@ -18,6 +18,7 @@ from serl_launcher.data.data_store import ( MemoryEfficientReplayBufferDataStore, ReplayBufferDataStore, + FractalSymmetryReplayBufferDataStore, ) ############################################################################## @@ -206,6 +207,11 @@ def make_replay_buffer( image_keys: list = [], # used only type=="memory_efficient_replay_buffer" preload_rlds_path: Optional[str] = None, preload_data_transform: Optional[callable] = None, + branch_method: str = None, # used only type=="fractal_symmetry_replay_buffer" + split_method : str = None, # used only type=="fractal_symmetry_replay_buffer" + workspace_width : float = None, # used only type=="fractal_symmetry_replay_buffer" + + **kwargs: dict # used only type=="fractal_symmetry_replay_buffer" ): """ This is the high-level helper 
function to @@ -254,6 +260,17 @@ def make_replay_buffer( rlds_logger=rlds_logger, image_keys=image_keys, ) + elif type == "fractal_symmetry_replay_buffer": + replay_buffer = FractalSymmetryReplayBufferDataStore( + env.observation_space, + env.action_space, + capacity=capacity, + branch_method=branch_method, + split_method=split_method, + workspace_width=workspace_width, + rlds_logger=rlds_logger, + kwargs=kwargs, + ) else: raise ValueError(f"Unsupported replay_buffer_type: {type}") From a23330fd0c1d29a4cc14bad50f92c2c65b35ca24 Mon Sep 17 00:00:00 2001 From: ryanvanderstelt Date: Tue, 15 Jul 2025 14:56:54 -0500 Subject: [PATCH 024/141] code clean up --- .../serl_launcher/data/data_store.py | 2 +- .../data/fractal_symmetry_replay_buffer.py | 24 +++++++++------ serl_launcher/serl_launcher/data/fsrb_test.py | 29 +++++++++++-------- 3 files changed, 33 insertions(+), 22 deletions(-) diff --git a/serl_launcher/serl_launcher/data/data_store.py b/serl_launcher/serl_launcher/data/data_store.py index 54430dc6..1e1ff8e6 100644 --- a/serl_launcher/serl_launcher/data/data_store.py +++ b/serl_launcher/serl_launcher/data/data_store.py @@ -156,7 +156,7 @@ def __init__( rlds_logger: Optional[RLDSLogger] = None, **kwargs: dict, ): - FractalSymmetryReplayBuffer.__init__(self, observation_space, action_space, capacity, branch_method, split_method, workspace_width **kwargs) + FractalSymmetryReplayBuffer.__init__(self, observation_space, action_space, capacity, branch_method, split_method, workspace_width, **kwargs) DataStoreBase.__init__(self, capacity) self._lock = Lock() self._logger = None diff --git a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py index 1aa962b7..5ab843e9 100644 --- a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py +++ b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py @@ -16,9 +16,10 @@ def __init__( ): self.current_branch_count=1 + 
method_check = "split_method" match split_method: case "time": - assert "timesplit_freq" in kwargs.keys(), "\033[31mERROR \033[0mtimesplit_freq must be defined for split_method \"time\"" + assert "timesplit_freq" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "depth") self.timesplit_freq=kwargs["timesplit_freq"] del kwargs["timesplit_freq"] self.split = self.time_split @@ -35,32 +36,33 @@ def __init__( self.split = self.test_split case _: raise ValueError("incorrect value passed to split_method") - + + method_check = "branch_method" match branch_method: case "fractal": - assert "depth" in kwargs.keys(), "\033[31mERROR \033[0mdepth must be defined for branch_method \"fractal\"" + assert "depth" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "depth") self.depth=kwargs["depth"] del kwargs["depth"] - assert "dendrites" in kwargs.keys(), "\033[31mERROR \033[0mdendrites must be defined for branch_method \"fractal\"" + assert "dendrites" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "dendrites") self.dendrites=kwargs["dendrites"] del kwargs["dendrites"] self.branch = self.fractal_branch case "linear": - assert "branch_count_rate_of_change" in kwargs.keys(), "\033[31mERROR \033[0mbranch_count_rate_of_change must be defined for branch_method \"fractal\"" + assert "branch_count_rate_of_change" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "branch_count_rate_of_change") self.branch_count_rate_of_change=kwargs["branch_count_rate_of_change"] del kwargs["branch_count_rate_of_change"] - assert "starting_branch_count" in kwargs.keys(), "\033[31mERROR \033[0mstarting_branch_count must be defined for branch_method \"constant\"" + assert "starting_branch_count" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "starting_branch_count") self.current_branch_count = kwargs["starting_branch_count"] del kwargs["starting_branch_count"] self.branch = self.linear_branch case 
"constant": - assert "starting_branch_count" in kwargs.keys(), "\033[31mERROR \033[0mstarting_branch_count must be defined for branch_method \"constant\"" + assert "starting_branch_count" in kwargs.keys(), self._handle_bad_args_("branch_method", branch_method, "starting_branch_count") self.current_branch_count = kwargs["starting_branch_count"] del kwargs["starting_branch_count"] @@ -72,8 +74,8 @@ def __init__( case _: raise ValueError("incorrect value passed to branch_method") - for i in kwargs.keys(): - print(f"\033[33mWARNING \033[0m {i} argument not used") + for k in kwargs.keys(): + print(f"\033[33mWARNING \033[0m argument \"{k}\" not used") @@ -96,6 +98,10 @@ def __init__( capacity=capacity, ) + def _handle_bad_args_(type: str, method: str, arg: str) : + return f"\033[31mERROR: \033[0m{arg} must be defined for {type} \"{method}\"" + + def transform(self, data_dict: DatasetDict, translation: np.array): # return data_dict with positional arguments += translation data_dict["observations"] += translation diff --git a/serl_launcher/serl_launcher/data/fsrb_test.py b/serl_launcher/serl_launcher/data/fsrb_test.py index 02d2a026..6bce6d24 100644 --- a/serl_launcher/serl_launcher/data/fsrb_test.py +++ b/serl_launcher/serl_launcher/data/fsrb_test.py @@ -17,6 +17,7 @@ flags.DEFINE_integer("dendrites", 3, "Dendrites for fractal branching") # Remember to set default to None flags.DEFINE_integer("timesplit_freq", None, "Frequency of splits according to time") flags.DEFINE_integer("branch_count_rate_of_change", None, "Rate of change for linear branching") +flags.DEFINE_integer("starting_branch_count", 1, "Initial number of branches(sort of, TBD)") def main(_): @@ -31,12 +32,11 @@ def main(_): split_method=FLAGS.split_method, branch_method=FLAGS.branch_method, workspace_width=FLAGS.workspace_width, - kwargs={ - "depth": FLAGS.depth, - "dendrites": FLAGS.dendrites, - "timesplit_freq": FLAGS.timesplit_freq, - "branch_count_rate_of_change": FLAGS.branch_count_rate_of_change, - } 
+ depth=FLAGS.depth, + dendrites=FLAGS.dendrites, + timesplit_freq=FLAGS.timesplit_freq, + branch_count_rate_of_change=FLAGS.branch_count_rate_of_change, + starting_branch_count=FLAGS.starting_branch_count, ) replay_buffer = make_replay_buffer( @@ -45,12 +45,11 @@ def main(_): split_method=FLAGS.split_method, branch_method=FLAGS.branch_method, workspace_width=FLAGS.workspace_width, - kwargs={ - "depth": FLAGS.depth, - "dendrites": FLAGS.dendrites, - "timesplit_freq": FLAGS.timesplit_freq, - "branch_count_rate_of_change": FLAGS.branch_count_rate_of_change, - } + depth=FLAGS.depth, + dendrites=FLAGS.dendrites, + timesplit_freq=FLAGS.timesplit_freq, + branch_count_rate_of_change=FLAGS.branch_count_rate_of_change, + starting_branch_count=FLAGS.starting_branch_count, ) @@ -168,6 +167,12 @@ def main(_): result = round(buffer.dataset_dict["observations"][idx][0] - start, 3) assert result == expected, f"\033[31mTEST FAILED\033[0m insert() test failed (expected {expected} but got {result})" + replay_buffer.insert(data_dict) + for idx in range(buffer.current_branch_count): + expected = round(FLAGS.workspace_width * (2 * idx + 1)/(2 * buffer.current_branch_count), 3) + result = round(buffer.dataset_dict["observations"][idx][0] - start, 3) + assert result == expected, f"\033[31mTEST FAILED\033[0m insert() test failed (expected {expected} but got {result})" + print("\033[32mTEST PASSED \033[0m insert() tests passed") print("\nfinished!\n") From a108f6bac9ced36000af334e922d0cf89eb9adc5 Mon Sep 17 00:00:00 2001 From: ryanvanderstelt Date: Tue, 15 Jul 2025 16:44:33 -0500 Subject: [PATCH 025/141] ready for testing with real environment --- .../data/fractal_symmetry_replay_buffer.py | 51 ++++++++++------ serl_launcher/serl_launcher/data/fsrb_test.py | 59 ++++++++----------- serl_launcher/serl_launcher/utils/launcher.py | 2 +- 3 files changed, 59 insertions(+), 53 deletions(-) diff --git a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py 
b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py index 5ab843e9..4b44180d 100644 --- a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py +++ b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py @@ -12,16 +12,16 @@ def __init__( branch_method: str, split_method: str, workspace_width: int, - **kwargs: dict + kwargs: dict ): self.current_branch_count=1 method_check = "split_method" match split_method: case "time": - assert "timesplit_freq" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "depth") - self.timesplit_freq=kwargs["timesplit_freq"] - del kwargs["timesplit_freq"] + # assert "timesplit_freq" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "depth") + # self.timesplit_freq=kwargs["timesplit_freq"] + # del kwargs["timesplit_freq"] self.split = self.time_split case "rel_pos": print("NOT IMPLEMENTED") @@ -77,10 +77,11 @@ def __init__( for k in kwargs.keys(): print(f"\033[33mWARNING \033[0m argument \"{k}\" not used") - - + self.x_obs_idx = kwargs["x_obs_idx"] + # self.y_obs_idx = kwargs["y_obs_idx"] self.workspace_width = workspace_width - self.current_depth = 0 + self.current_depth = 1 + self.counter = 1 # TODO # Add flags for variables passed to @@ -102,10 +103,12 @@ def _handle_bad_args_(type: str, method: str, arg: str) : return f"\033[31mERROR: \033[0m{arg} must be defined for {type} \"{method}\"" - def transform(self, data_dict: DatasetDict, translation: np.array): + def transform(self, data_dict: DatasetDict, transform: np.array): # return data_dict with positional arguments += translation - data_dict["observations"] += translation - data_dict["next_observations"] += translation + assert data_dict["observations"].shape == transform.shape, "transform broke at observations" + assert data_dict["observations"].shape == transform.shape, "transform broke at next_observations" + data_dict["observations"] += transform + data_dict["next_observations"] += transform # FOR 
TESTING ONLY def test_branch(self): @@ -117,8 +120,11 @@ def test_branch(self): def fractal_branch(self): # return a new number of branches = dendrites ^ depth + temp = self.dendrites ** self.current_depth self.current_depth += 1 - return self.dendrites ** (self.current_depth) + if self.current_depth > self.depth: + return self.dendrites ** self.depth + return temp # REQUIRES TESTING def constant_branch(self): @@ -134,10 +140,14 @@ def linear_branch(self): def test_split(self, data_dict: DatasetDict): return True - # NOT IMPLEMENTED + # NOT IMPLEMENTED (HARDCODED) def time_split(self, data_dict: DatasetDict): # return True when a set time has passed - raise NotImplementedError + match self.counter % 12: + case 0: + return True + case _: + return False # NOT IMPLEMENTED def rel_pos_split(self, data_dict: DatasetDict): @@ -167,10 +177,17 @@ def insert(self, data_dict: DatasetDict): # Transform and insert branches dx = self.workspace_width/self.current_branch_count x = (-self.workspace_width + dx)/2 - self.transform(data_dict, np.array([x, 0, 0, 0, 0, 0, 0, x, 0, 0])) + transform = np.zeros_like(data_dict["observations"]) + transform[self.x_obs_idx] = x + self.transform(data_dict, transform) + transform[self.x_obs_idx] = dx for t in range(0, self.current_branch_count): super().insert(data_dict) - self.transform(data_dict, np.array([dx, 0, 0, 0, 0, 0, 0, dx, 0, 0])) - - + self.transform(data_dict, transform) + + self.counter += 1 + + if data_dict["dones"]: + self.counter = 1 + self.current_depth = 1 \ No newline at end of file diff --git a/serl_launcher/serl_launcher/data/fsrb_test.py b/serl_launcher/serl_launcher/data/fsrb_test.py index 6bce6d24..2e71bf01 100644 --- a/serl_launcher/serl_launcher/data/fsrb_test.py +++ b/serl_launcher/serl_launcher/data/fsrb_test.py @@ -10,8 +10,8 @@ FLAGS = flags.FLAGS flags.DEFINE_integer("capacity", 10000000, "Replay buffer capacity.") -flags.DEFINE_string("branch_method", "test", "placeholder") -flags.DEFINE_string("split_method", 
"test", "placeholder") +flags.DEFINE_string("branch_method", "fractal", "placeholder") +flags.DEFINE_string("split_method", "time", "placeholder") flags.DEFINE_float("workspace_width", 0.5, "workspace width in centimeters") flags.DEFINE_integer("depth", 4, "Total layers of depth") flags.DEFINE_integer("dendrites", 3, "Dendrites for fractal branching") # Remember to set default to None @@ -21,30 +21,21 @@ def main(_): + x_obs_idx = np.array([0, 7]) + y_obs_idx = np.array([1, 8]) + # Initialize replay buffer env = gym.make("PandaPickCube-v0") env = gym.wrappers.FlattenObservation(env) - buffer = FractalSymmetryReplayBuffer( - observation_space=env.observation_space, - action_space=env.action_space, - capacity=FLAGS.capacity, - split_method=FLAGS.split_method, - branch_method=FLAGS.branch_method, - workspace_width=FLAGS.workspace_width, - depth=FLAGS.depth, - dendrites=FLAGS.dendrites, - timesplit_freq=FLAGS.timesplit_freq, - branch_count_rate_of_change=FLAGS.branch_count_rate_of_change, - starting_branch_count=FLAGS.starting_branch_count, - ) - replay_buffer = make_replay_buffer( env, + type="fractal_symmetry_replay_buffer", capacity=FLAGS.capacity, split_method=FLAGS.split_method, branch_method=FLAGS.branch_method, workspace_width=FLAGS.workspace_width, + x_obs_idx=x_obs_idx, depth=FLAGS.depth, dendrites=FLAGS.dendrites, timesplit_freq=FLAGS.timesplit_freq, @@ -52,8 +43,6 @@ def main(_): starting_branch_count=FLAGS.starting_branch_count, ) - - observation, info = env.reset() action = env.action_space.sample() next_observation, reward, terminated, truncated, info = env.step(action) @@ -71,7 +60,7 @@ def main(_): # transform() test expected = np.copy(data_dict["observations"]) * 2 - buffer.transform(data_dict, np.copy(data_dict["observations"])) + replay_buffer.transform(data_dict, np.copy(data_dict["observations"])) result = data_dict["observations"] assert np.array_equal(result, expected), f"\033[31mTEST FAILED\033[0m transform() test failed (expected {expected} but 
got {result})" @@ -82,15 +71,15 @@ def main(_): ## fractal - for dendrites in range(1, 10, 2): - buffer.dendrites=dendrites - for depth in range(1, 6): - buffer.current_depth=depth - 1 - result = buffer.fractal_branch() - expected = dendrites ** depth - assert result == expected, f"\033[31mTEST FAILED\033[0m fractal_branch() test failed (expected {expected} but got {result})" + # for dendrites in range(1, 10, 2): + # replay_buffer.dendrites=dendrites + # for depth in range(1, 6): + # replay_buffer.current_depth=depth - 1 + # result = replay_buffer.fractal_branch() + # expected = dendrites ** depth + # assert result == expected, f"\033[31mTEST FAILED\033[0m fractal_branch() test failed (expected {expected} but got {result})" - del dendrites, depth, result, expected + # del dendrites, depth, result, expected ## constant @@ -161,16 +150,16 @@ def main(_): start = data_dict["observations"][0] - FLAGS.workspace_width / 2 - buffer.insert(data_dict) - for idx in range(buffer.current_branch_count): - expected = round(FLAGS.workspace_width * (2 * idx + 1)/(2 * buffer.current_branch_count), 3) - result = round(buffer.dataset_dict["observations"][idx][0] - start, 3) - assert result == expected, f"\033[31mTEST FAILED\033[0m insert() test failed (expected {expected} but got {result})" + # buffer.insert(data_dict) + # for idx in range(buffer.current_branch_count): + # expected = round(FLAGS.workspace_width * (2 * idx + 1)/(2 * buffer.current_branch_count), 3) + # result = round(buffer.dataset_dict["observations"][idx][0] - start, 3) + # assert result == expected, f"\033[31mTEST FAILED\033[0m insert() test failed (expected {expected} but got {result})" replay_buffer.insert(data_dict) - for idx in range(buffer.current_branch_count): - expected = round(FLAGS.workspace_width * (2 * idx + 1)/(2 * buffer.current_branch_count), 3) - result = round(buffer.dataset_dict["observations"][idx][0] - start, 3) + for idx in range(replay_buffer.current_branch_count): + expected = 
round(FLAGS.workspace_width * (2 * idx + 1)/(2 * replay_buffer.current_branch_count), 3) + result = round(replay_buffer.dataset_dict["observations"][idx][0] - start, 3) assert result == expected, f"\033[31mTEST FAILED\033[0m insert() test failed (expected {expected} but got {result})" print("\033[32mTEST PASSED \033[0m insert() tests passed") diff --git a/serl_launcher/serl_launcher/utils/launcher.py b/serl_launcher/serl_launcher/utils/launcher.py index ad4326eb..5eaab48b 100644 --- a/serl_launcher/serl_launcher/utils/launcher.py +++ b/serl_launcher/serl_launcher/utils/launcher.py @@ -221,7 +221,7 @@ def make_replay_buffer( - env: gym or gymasium environment - capacity: capacity of the replay buffer - rlds_logger_path: path to save RLDS logs - - type: support only for "replay_buffer" and "memory_efficient_replay_buffer" + - type: support only for "replay_buffer", "memory_efficient_replay_buffer", and "fractal_symmetry_replay_buffer" - image_keys: list of image keys, used only "memory_efficient_replay_buffer" - preload_rlds_path: path to preloaded RLDS trajectories - preload_data_transform: data transformation function for preloaded RLDS data From 0e6d65bca38914bdd4d3341b094f367e55d689b4 Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Tue, 15 Jul 2025 16:52:04 -0500 Subject: [PATCH 026/141] reach task using serl --- .../async_sac_state_sim.py | 10 ++++----- examples/async_sac_state_sim/run_actor.sh | 19 +++++++++++++---- examples/async_sac_state_sim/run_learner.sh | 21 +++++++++++++++---- 3 files changed, 37 insertions(+), 13 deletions(-) diff --git a/examples/async_sac_state_sim/async_sac_state_sim.py b/examples/async_sac_state_sim/async_sac_state_sim.py index 27d43c3c..8121d249 100644 --- a/examples/async_sac_state_sim/async_sac_state_sim.py +++ b/examples/async_sac_state_sim/async_sac_state_sim.py @@ -92,8 +92,8 @@ def update_params(params): client.recv_network_callback(update_params) eval_env = gym.make(FLAGS.env) - if FLAGS.env == "PandaPickCube-v0": - eval_env 
= gym.wrappers.FlattenObservation(eval_env) + #if FLAGS.env == "PandaPickCube-v0": + eval_env = gym.wrappers.FlattenObservation(eval_env) eval_env = RecordEpisodeStatistics(eval_env) obs, _ = env.reset() @@ -174,7 +174,7 @@ def learner(rng, agent: SACAgent, replay_buffer, replay_iterator): """ # set up wandb and logging wandb_logger = make_wandb_logger( - project="serl_dev", + project="serl_reach", description=FLAGS.exp_name or FLAGS.env, debug=FLAGS.debug, ) @@ -266,8 +266,8 @@ def main(_): else: env = gym.make(FLAGS.env) - if FLAGS.env == "PandaPickCube-v0": - env = gym.wrappers.FlattenObservation(env) + #if FLAGS.env == "PandaPickCube-v0": + env = gym.wrappers.FlattenObservation(env) rng, sampling_rng = jax.random.split(rng) agent: SACAgent = make_sac_agent( diff --git a/examples/async_sac_state_sim/run_actor.sh b/examples/async_sac_state_sim/run_actor.sh index b8ba3c8d..48c64354 100644 --- a/examples/async_sac_state_sim/run_actor.sh +++ b/examples/async_sac_state_sim/run_actor.sh @@ -1,7 +1,18 @@ export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ -export XLA_PYTHON_CLIENT_MEM_FRACTION=.05 && \ +export XLA_PYTHON_CLIENT_MEM_FRACTION=.5 && \ export SCRIPT_DIR=$(dirname "$(realpath "$0")") && \ export ENV_NAME="PandaReachCube-v0" && \ +export TIMESTAMP=$(date +"%m-%d-%Y-%H-%M-%S") && \ +export CHECKPOINT_DIR="$SCRIPT_DIR/$ENV_NAME_checkpoints/checkpoints-$TIMESTAMP" && \ + +# Create checkpoint directory if it doesn't exist +if [ ! -d "$CHECKPOINT_DIR" ]; then + echo "Creating checkpoint directory: $CHECKPOINT_DIR" + mkdir -p "$CHECKPOINT_DIR" || { + echo "Failed to create checkpoint directory!" 
>&2 + exit 1 + } +fi python async_sac_state_sim.py "$@" \ --actor \ --render \ @@ -9,12 +20,12 @@ python async_sac_state_sim.py "$@" \ --exp_name=serl-reach \ --seed 0 \ --random_steps 1000 \ - --max_steps 1000000 \ + --max_steps 100000 \ --training_starts 1000 \ --critic_actor_ratio 8 \ --batch_size 256 \ + --replay_buffer_capacity 100000 \ --save_model True \ --checkpoint_period 10000 \ - --checkpoint_path "$SCRIPT_DIR/$ENV_NAME/checkpoints" \ + --checkpoint_path "$CHECKPOINT_DIR" \ #--debug # wandb is disabled when debug - diff --git a/examples/async_sac_state_sim/run_learner.sh b/examples/async_sac_state_sim/run_learner.sh index 9aead4a2..ed3295d3 100644 --- a/examples/async_sac_state_sim/run_learner.sh +++ b/examples/async_sac_state_sim/run_learner.sh @@ -1,17 +1,30 @@ export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ -export XLA_PYTHON_CLIENT_MEM_FRACTION=.05 && \ +export XLA_PYTHON_CLIENT_MEM_FRACTION=.5 && \ export SCRIPT_DIR=$(dirname "$(realpath "$0")") && \ export ENV_NAME="PandaReachCube-v0" && \ +export TIMESTAMP=$(date +"%m-%d-%Y-%H-%M-%S") && \ +export CHECKPOINT_DIR="$SCRIPT_DIR/$ENV_NAME_checkpoints/checkpoints-$TIMESTAMP" && \ + +# Create checkpoint directory if it doesn't exist +if [ ! -d "$CHECKPOINT_DIR" ]; then + echo "Creating checkpoint directory: $CHECKPOINT_DIR" + mkdir -p "$CHECKPOINT_DIR" || { + echo "Failed to create checkpoint directory!" 
>&2 + exit 1 + } +fi + python async_sac_state_sim.py "$@" \ --learner \ --env $ENV_NAME \ --exp_name=serl-reach \ --seed 0 \ - --max_steps 1000000 \ + --max_steps 100000 \ --training_starts 1000 \ --critic_actor_ratio 8 \ --batch_size 256 \ + --replay_buffer_capacity 100000 \ --save_model True \ --checkpoint_period 10000 \ - --checkpoint_path "$SCRIPT_DIR/$ENV_NAME/checkpoints" \ - #--debug # wandb is disabled when debug + --checkpoint_path "$CHECKPOINT_DIR" \ + #--debug # wandb is disabled when debug \ No newline at end of file From 610f00c04c3a739201b8615763e5f9032ed3f982 Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Tue, 15 Jul 2025 17:01:10 -0500 Subject: [PATCH 027/141] peg insert changes --- examples/async_peg_insert_drq/README.md | 11 ------- examples/async_peg_insert_drq/record_demo.py | 2 +- examples/async_peg_insert_drq/run_actor.sh | 29 ++++++++++++++++--- examples/async_peg_insert_drq/run_learner.sh | 23 ++++++++++++--- .../async_sac_state_sim/.vscode/launch.json | 4 +-- .../franka_sim/envs/panda_reach_gym_env.py | 2 +- serl_launcher/requirements.txt | 8 +++-- .../franka_env/envs/peg_env/config.py | 12 ++++---- 8 files changed, 60 insertions(+), 31 deletions(-) delete mode 100644 examples/async_peg_insert_drq/README.md diff --git a/examples/async_peg_insert_drq/README.md b/examples/async_peg_insert_drq/README.md deleted file mode 100644 index 068ec121..00000000 --- a/examples/async_peg_insert_drq/README.md +++ /dev/null @@ -1,11 +0,0 @@ - - - - - - - - - -for record demo -make sure to connect to RealSense camera diff --git a/examples/async_peg_insert_drq/record_demo.py b/examples/async_peg_insert_drq/record_demo.py index 5a0fd02b..1840a5d3 100644 --- a/examples/async_peg_insert_drq/record_demo.py +++ b/examples/async_peg_insert_drq/record_demo.py @@ -32,7 +32,7 @@ transitions = [] success_count = 0 - success_needed = 20 + success_needed = 30 total_count = 0 pbar = tqdm(total=success_needed) diff --git a/examples/async_peg_insert_drq/run_actor.sh 
b/examples/async_peg_insert_drq/run_actor.sh index a21431bf..d11b8429 100644 --- a/examples/async_peg_insert_drq/run_actor.sh +++ b/examples/async_peg_insert_drq/run_actor.sh @@ -1,13 +1,34 @@ export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ -export XLA_PYTHON_CLIENT_MEM_FRACTION=.3 && \ +export XLA_PYTHON_CLIENT_MEM_FRACTION=.5 && \ +export SCRIPT_DIR=$(dirname "$(realpath "$0")") && \ +export ENV_NAME="FrankaPegInsert-Vision-v0" && \ +export TIMESTAMP=$(date +"%m-%d-%Y-%H-%M-%S") && \ +export CHECKPOINT_DIR="$SCRIPT_DIR/checkpoints/checkpoints-$TIMESTAMP" && \ +export CHECKPOINT_EVAL="/home/student/code/cleiver/serl/examples/async_peg_insert_drq/checkpoints/checkpoints-07-14-2025-23-15-59" && \ + + +# Create checkpoint directory if it doesn't exist +if [ ! -d "$CHECKPOINT_DIR" ]; then + echo "Creating checkpoint directory: $CHECKPOINT_DIR" + mkdir -p "$CHECKPOINT_DIR" || { + echo "Failed to create checkpoint directory!" >&2 + exit 1 + } +fi + python async_drq_randomized.py "$@" \ --actor \ --render \ - --env FrankaPegInsert-Vision-v0 \ - --exp_name=serl_dev_drq_rlpd10demos_peg_insert_random_resnet \ + --env $ENV_NAME \ + --exp_name=serl-peg-insert \ --max_steps 25000 \ --seed 0 \ --random_steps 0 \ --training_starts 200 \ --encoder_type resnet-pretrained \ - --demo_path peg_insert_20_demos_2025-04-14_15-36-51.pkl \ + --demo_path peg_insert_30_demos_2025-07-14_22-57-59.pkl \ + --checkpoint_period 1000 \ + --checkpoint_path "$CHECKPOINT_DIR" \ + # --eval_checkpoint_step=5000 \ + # --eval_n_trajs=5 \ + #--debug # wandb is disabled when debug diff --git a/examples/async_peg_insert_drq/run_learner.sh b/examples/async_peg_insert_drq/run_learner.sh index 0495db45..63435ada 100644 --- a/examples/async_peg_insert_drq/run_learner.sh +++ b/examples/async_peg_insert_drq/run_learner.sh @@ -1,9 +1,23 @@ export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ export XLA_PYTHON_CLIENT_MEM_FRACTION=.6 && \ +export SCRIPT_DIR=$(dirname "$(realpath "$0")") && \ +export 
ENV_NAME="FrankaPegInsert-Vision-v0" && \ +export TIMESTAMP=$(date +"%m-%d-%Y-%H-%M-%S") && \ +export CHECKPOINT_DIR="$SCRIPT_DIR/checkpoints/checkpoints-$TIMESTAMP" && \ + +# Create checkpoint directory if it doesn't exist +if [ ! -d "$CHECKPOINT_DIR" ]; then + echo "Creating checkpoint directory: $CHECKPOINT_DIR" + mkdir -p "$CHECKPOINT_DIR" || { + echo "Failed to create checkpoint directory!" >&2 + exit 1 + } +fi + python async_drq_randomized.py "$@" \ --learner \ - --env FrankaPegInsert-Vision-v0 \ - --exp_name=serl_dev_drq_rlpd10demos_peg_insert_random_resnet_097 \ + --env $ENV_NAME \ + --exp_name=serl-peg-insert \ --seed 0 \ --max_steps 25000 \ --random_steps 1000 \ @@ -12,6 +26,7 @@ python async_drq_randomized.py "$@" \ --batch_size 128 \ --eval_period 2000 \ --encoder_type resnet-pretrained \ - --demo_path peg_insert_20_demos_2025-04-14_15-36-51.pkl\ + --demo_path peg_insert_30_demos_2025-07-14_22-57-59.pkl\ --checkpoint_period 1000 \ - --checkpoint_path /home/student/robot/robot_ws/src/serl/examples/async_peg_insert_drq/checkpoints \ No newline at end of file + --checkpoint_path "$CHECKPOINT_DIR" \ + #--debug # wandb is disabled when debug \ No newline at end of file diff --git a/examples/async_sac_state_sim/.vscode/launch.json b/examples/async_sac_state_sim/.vscode/launch.json index 4010fde4..c71ff95e 100644 --- a/examples/async_sac_state_sim/.vscode/launch.json +++ b/examples/async_sac_state_sim/.vscode/launch.json @@ -12,7 +12,7 @@ // Command-line arguments matching run_learner.sh "args": [ "--learner", // REQUIRED: Indicates this is a learner instance - "--env", "PandaPickCube-v0", // Environment to use + "--env", "PandaReachCube-v0", // Environment to use "--exp_name=serl_dev_sim_test", // Experiment name for wandb logging "--seed", "0", // Random seed for reproducibility "--max_steps", "1000000", // Maximum training steps @@ -47,7 +47,7 @@ "args": [ "--actor", // REQUIRED: Indicates this is an actor instance "--render", // Enable rendering the 
environment - "--env", "PandaPickCube-v0", // Environment to use + "--env", "PandaReachCube-v0", // Environment to use "--exp_name=serl_dev_sim_test", // Experiment name for wandb logging "--seed", "0", // Random seed for reproducibility "--random_steps", "1000", // Number of random steps at beginning diff --git a/franka_sim/franka_sim/envs/panda_reach_gym_env.py b/franka_sim/franka_sim/envs/panda_reach_gym_env.py index 8d1ec251..03828dd8 100644 --- a/franka_sim/franka_sim/envs/panda_reach_gym_env.py +++ b/franka_sim/franka_sim/envs/panda_reach_gym_env.py @@ -266,7 +266,7 @@ def _compute_reward(self) -> float: # Distance-based reward r_close = np.exp(-20 * dist) - print(r_close) + r_close = np.clip(r_close, 0.0, 1.0) return r_close diff --git a/serl_launcher/requirements.txt b/serl_launcher/requirements.txt index 86f92e95..f9d530ef 100644 --- a/serl_launcher/requirements.txt +++ b/serl_launcher/requirements.txt @@ -1,6 +1,7 @@ +opencv-python<4.12.0.88 gym >= 0.26 numpy>=1.24.3 -flax>=0.8.0 +flax>=0.8.0, < 0.10.6 distrax>=0.1.2 ml_collections >= 0.1.0 tqdm >= 4.60.0 @@ -8,7 +9,7 @@ chex>=0.1.85 optax>=0.1.5 orbax-checkpoint>=0.5.10 absl-py >= 0.12.0 -scipy==1.11.4 +scipy>=1.11.4 wandb >= 0.12.14 tensorflow>=2.16.0 tensorflow_probability>=0.24.0 @@ -18,3 +19,6 @@ imageio >= 2.31.1 moviepy >= 1.0.3 pre-commit == 3.3.3 tensorflow_datasets >= 4.4.0 +jax==0.4.35 +jaxlib==0.4.34 +jaxlib==0.4.34 diff --git a/serl_robot_infra/franka_env/envs/peg_env/config.py b/serl_robot_infra/franka_env/envs/peg_env/config.py index 97f263c5..c027b5a9 100644 --- a/serl_robot_infra/franka_env/envs/peg_env/config.py +++ b/serl_robot_infra/franka_env/envs/peg_env/config.py @@ -11,12 +11,12 @@ class PegEnvConfig(DefaultEnvConfig): "wrist_2": "218622271526", } TARGET_POSE = np.array( - [0.5130075190601081, - -0.23745299669341197, - 0.07808493744892302, - 3.104175639241274, - -0.024134743891730315, - -3.1304483269500967, + [0.6179721091801964, + -0.08069386463219706, + 0.07628962570607248, + 
-3.1161359527499712, + 0.04124456532930543, + 1.5939026317635385, ] ) From a80b8187f78b441c55c91c0e7fbb95958c13f795 Mon Sep 17 00:00:00 2001 From: ryanvanderstelt Date: Tue, 15 Jul 2025 17:36:03 -0500 Subject: [PATCH 028/141] async_sac_state_sim uses fractal_replay_buffer --- examples/async_sac_state_sim/async_sac_state_sim.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/examples/async_sac_state_sim/async_sac_state_sim.py b/examples/async_sac_state_sim/async_sac_state_sim.py index 0892dde3..42453bed 100644 --- a/examples/async_sac_state_sim/async_sac_state_sim.py +++ b/examples/async_sac_state_sim/async_sac_state_sim.py @@ -288,7 +288,11 @@ def main(_): env, capacity=FLAGS.replay_buffer_capacity, rlds_logger_path=FLAGS.log_rlds_path, - type="replay_buffer", + type="fractal_symmetry_replay_buffer", + branch_method="test", + split_method="test", + workspace_width=0.5, + x_obs_idx=np.array([0,7]), preload_rlds_path=FLAGS.preload_rlds_path, ) replay_iterator = replay_buffer.get_iterator( From 729a6403a6e8ba6dff1c55172d616c94d3a23416 Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Tue, 15 Jul 2025 18:56:24 -0500 Subject: [PATCH 029/141] launch json file for debugging reach --- .gitignore | 2 +- .../async_sac_state_sim/.vscode/launch.json | 28 +++++++++---------- .../async_sac_state_sim.py | 4 +-- serl_launcher/serl_launcher/common/wandb.py | 2 +- 4 files changed, 18 insertions(+), 18 deletions(-) diff --git a/.gitignore b/.gitignore index bf70074e..672039d1 100644 --- a/.gitignore +++ b/.gitignore @@ -174,5 +174,5 @@ checkpoint wandb/ # VS Code settings -.vscode/ *.code-workspace + diff --git a/examples/async_sac_state_sim/.vscode/launch.json b/examples/async_sac_state_sim/.vscode/launch.json index c71ff95e..60496b7a 100644 --- a/examples/async_sac_state_sim/.vscode/launch.json +++ b/examples/async_sac_state_sim/.vscode/launch.json @@ -11,25 +11,22 @@ "program": "${workspaceFolder}/async_sac_state_sim.py", // Command-line arguments matching 
run_learner.sh "args": [ - "--learner", // REQUIRED: Indicates this is a learner instance - "--env", "PandaReachCube-v0", // Environment to use - "--exp_name=serl_dev_sim_test", // Experiment name for wandb logging + "--learner", // REQUIRED: Indicates this is a learner instance + "--env", "PandaReachCube-v0", // Environment to use + "--exp_name=SERL-ReachTask", // Experiment name for wandb logging "--seed", "0", // Random seed for reproducibility - "--max_steps", "1000000", // Maximum training steps + "--max_steps", "100000", // Maximum training steps "--training_starts", "1000", // Start training after buffer has this many samples "--critic_actor_ratio", "8", // Critic-to-actor update ratio "--batch_size", "256", // Training batch size - "--save_model", "True", // Whether to save model checkpoints - "--checkpoint_period", "10000", // How often to save checkpoints - "--checkpoint_path", "./checkpoints", // Where to save checkpoints - "--debug" // Debug mode (disables wandb) + "--replay_buffer_capacity", "100000", // Replay buffer capacity ], "console": "integratedTerminal", // Use integrated terminal to see output "justMyCode": false, // Allow stepping into libraries // Environment variables from run_learner.sh "env": { "XLA_PYTHON_CLIENT_PREALLOCATE": "false", // Don't preallocate JAX/XLA memory - "XLA_PYTHON_CLIENT_MEM_FRACTION": ".05" // Limit JAX memory usage to 5% + "XLA_PYTHON_CLIENT_MEM_FRACTION": ".5" // Limit JAX memory usage to 50% }, // Additional helpful debugging options "showReturnValue": true, // Show function return values @@ -46,19 +43,22 @@ // Command-line arguments matching run_actor.sh "args": [ "--actor", // REQUIRED: Indicates this is an actor instance - "--render", // Enable rendering the environment - "--env", "PandaReachCube-v0", // Environment to use - "--exp_name=serl_dev_sim_test", // Experiment name for wandb logging + "--env", "PandaReachCube-v0", // Environment to use + "--exp_name=SERL-ReachTask", // Experiment name for wandb 
logging "--seed", "0", // Random seed for reproducibility "--random_steps", "1000", // Number of random steps at beginning - "--debug" // Debug mode (disables wandb) + "--max_steps", "100000", // Maximum training steps + "--training_starts", "1000", // Start training after buffer has this many samples + "--critic_actor_ratio", "8", // Critic-to-actor update ratio + "--batch_size", "256", // Training batch size + "--replay_buffer_capacity", "100000", // Replay buffer capacity ], "console": "integratedTerminal", // Use integrated terminal to see output "justMyCode": false, // Allow stepping into libraries // Environment variables from run_actor.sh "env": { "XLA_PYTHON_CLIENT_PREALLOCATE": "false", // Don't preallocate JAX/XLA memory - "XLA_PYTHON_CLIENT_MEM_FRACTION": ".05" // Limit JAX memory usage to 5% + "XLA_PYTHON_CLIENT_MEM_FRACTION": ".5" // Limit JAX memory usage to 50% }, "showReturnValue": true, // Show function return values "purpose": ["debug-in-terminal"], // Run in terminal for better output diff --git a/examples/async_sac_state_sim/async_sac_state_sim.py b/examples/async_sac_state_sim/async_sac_state_sim.py index 8121d249..1e289bac 100644 --- a/examples/async_sac_state_sim/async_sac_state_sim.py +++ b/examples/async_sac_state_sim/async_sac_state_sim.py @@ -167,14 +167,14 @@ def update_params(params): ############################################################################## - + def learner(rng, agent: SACAgent, replay_buffer, replay_iterator): """ The learner loop, which runs when "--learner" is set to True. 
""" # set up wandb and logging wandb_logger = make_wandb_logger( - project="serl_reach", + project=FLAGS.exp_name, description=FLAGS.exp_name or FLAGS.env, debug=FLAGS.debug, ) diff --git a/serl_launcher/serl_launcher/common/wandb.py b/serl_launcher/serl_launcher/common/wandb.py index 2ef341aa..5bc0f432 100644 --- a/serl_launcher/serl_launcher/common/wandb.py +++ b/serl_launcher/serl_launcher/common/wandb.py @@ -45,7 +45,7 @@ def __init__( self.config = wandb_config if self.config.unique_identifier == "": self.config.unique_identifier = datetime.datetime.now().strftime( - "%Y%m%d_%H%M%S" + "%m-%d-%Y-%H-%M-%S" ) self.config.experiment_id = ( From 45e7ee228bf297ae4b7d9aec5e76a6ae8c4f90b5 Mon Sep 17 00:00:00 2001 From: ryanvanderstelt Date: Thu, 17 Jul 2025 14:09:36 -0500 Subject: [PATCH 030/141] fixed jax and jaxlib version errors (now should use 0.6.2) --- .vscode/launch.json | 15 +++++++++++++++ .../async_drq_randomized.py | 4 ++-- .../async_cable_route_drq/async_drq_randomized.py | 2 +- examples/async_drq_sim/async_drq_sim.py | 2 +- .../async_pcb_insert_drq/async_drq_randomized.py | 2 +- .../async_peg_insert_drq/async_drq_randomized.py | 2 +- .../async_sac_state_sim/async_sac_state_sim.py | 2 +- serl_launcher/requirements.txt | 3 --- .../serl_launcher/agents/continuous/sac.py | 4 ++-- serl_launcher/serl_launcher/common/common.py | 14 +++++++------- serl_launcher/serl_launcher/data/dataset.py | 2 +- .../serl_launcher/vision/data_augmentations.py | 8 ++++---- serl_launcher/serl_launcher/wrappers/chunking.py | 2 +- serl_launcher/serl_launcher/wrappers/remap.py | 2 +- 14 files changed, 38 insertions(+), 26 deletions(-) create mode 100644 .vscode/launch.json diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 00000000..6b76b4fa --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,15 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. 
+ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Python Debugger: Current File", + "type": "debugpy", + "request": "launch", + "program": "${file}", + "console": "integratedTerminal" + } + ] +} \ No newline at end of file diff --git a/examples/async_bin_relocation_fwbw_drq/async_drq_randomized.py b/examples/async_bin_relocation_fwbw_drq/async_drq_randomized.py index 78a8eee4..7669559f 100644 --- a/examples/async_bin_relocation_fwbw_drq/async_drq_randomized.py +++ b/examples/async_bin_relocation_fwbw_drq/async_drq_randomized.py @@ -485,7 +485,7 @@ def main(_): # replicate agent across devices # need the jnp.array to avoid a bug where device_put doesn't recognize primitives agent: DrQAgent = jax.device_put( - jax.tree_map(jnp.array, agent), sharding.replicate() + jax.tree.map(jnp.array, agent), sharding.replicate() ) agents[v] = agent else: @@ -500,7 +500,7 @@ def main(_): # replicate agent across devices # need the jnp.array to avoid a bug where device_put doesn't recognize primitives agent: DrQAgent = jax.device_put( - jax.tree_map(jnp.array, agent), sharding.replicate() + jax.tree.map(jnp.array, agent), sharding.replicate() ) if FLAGS.learner: diff --git a/examples/async_cable_route_drq/async_drq_randomized.py b/examples/async_cable_route_drq/async_drq_randomized.py index 90770fc4..3c15a441 100644 --- a/examples/async_cable_route_drq/async_drq_randomized.py +++ b/examples/async_cable_route_drq/async_drq_randomized.py @@ -370,7 +370,7 @@ def main(_): # replicate agent across devices # need the jnp.array to avoid a bug where device_put doesn't recognize primitives agent: DrQAgent = jax.device_put( - jax.tree_map(jnp.array, agent), sharding.replicate() + jax.tree.map(jnp.array, agent), sharding.replicate() ) if FLAGS.learner: diff --git a/examples/async_drq_sim/async_drq_sim.py b/examples/async_drq_sim/async_drq_sim.py index a84bc580..15a197d0 100644 --- 
a/examples/async_drq_sim/async_drq_sim.py +++ b/examples/async_drq_sim/async_drq_sim.py @@ -345,7 +345,7 @@ def main(_): # replicate agent across devices # need the jnp.array to avoid a bug where device_put doesn't recognize primitives agent: DrQAgent = jax.device_put( - jax.tree_map(jnp.array, agent), sharding.replicate() + jax.tree.map(jnp.array, agent), sharding.replicate() ) if FLAGS.learner: diff --git a/examples/async_pcb_insert_drq/async_drq_randomized.py b/examples/async_pcb_insert_drq/async_drq_randomized.py index 8248379e..5a1770ff 100644 --- a/examples/async_pcb_insert_drq/async_drq_randomized.py +++ b/examples/async_pcb_insert_drq/async_drq_randomized.py @@ -440,7 +440,7 @@ def main(_): # replicate agent across devices # need the jnp.array to avoid a bug where device_put doesn't recognize primitives agent: DrQAgent = jax.device_put( - jax.tree_map(jnp.array, agent), sharding.replicate() + jax.tree.map(jnp.array, agent), sharding.replicate() ) if FLAGS.learner: diff --git a/examples/async_peg_insert_drq/async_drq_randomized.py b/examples/async_peg_insert_drq/async_drq_randomized.py index 7b8ac779..2d089c96 100644 --- a/examples/async_peg_insert_drq/async_drq_randomized.py +++ b/examples/async_peg_insert_drq/async_drq_randomized.py @@ -348,7 +348,7 @@ def main(_): # replicate agent across devices # need the jnp.array to avoid a bug where device_put doesn't recognize primitives agent: DrQAgent = jax.device_put( - jax.tree_map(jnp.array, agent), sharding.replicate() + jax.tree.map(jnp.array, agent), sharding.replicate() ) if FLAGS.learner: diff --git a/examples/async_sac_state_sim/async_sac_state_sim.py b/examples/async_sac_state_sim/async_sac_state_sim.py index b6b0d728..314dcfdc 100644 --- a/examples/async_sac_state_sim/async_sac_state_sim.py +++ b/examples/async_sac_state_sim/async_sac_state_sim.py @@ -279,7 +279,7 @@ def main(_): # replicate agent across devices # need the jnp.array to avoid a bug where device_put doesn't recognize primitives agent: 
SACAgent = jax.device_put( - jax.tree_map(jnp.array, agent), sharding.replicate() + jax.tree.map(jnp.array, agent), sharding.replicate() ) if FLAGS.learner: diff --git a/serl_launcher/requirements.txt b/serl_launcher/requirements.txt index f9d530ef..9b92279f 100644 --- a/serl_launcher/requirements.txt +++ b/serl_launcher/requirements.txt @@ -19,6 +19,3 @@ imageio >= 2.31.1 moviepy >= 1.0.3 pre-commit == 3.3.3 tensorflow_datasets >= 4.4.0 -jax==0.4.35 -jaxlib==0.4.34 -jaxlib==0.4.34 diff --git a/serl_launcher/serl_launcher/agents/continuous/sac.py b/serl_launcher/serl_launcher/agents/continuous/sac.py index ca933db4..dd723924 100644 --- a/serl_launcher/serl_launcher/agents/continuous/sac.py +++ b/serl_launcher/serl_launcher/agents/continuous/sac.py @@ -575,11 +575,11 @@ def scan_body(carry: Tuple[SACAgent], data: Tuple[Batch]): def make_minibatch(data: jnp.ndarray): return jnp.reshape(data, (utd_ratio, minibatch_size) + data.shape[1:]) - minibatches = jax.tree_map(make_minibatch, batch) + minibatches = jax.tree.map(make_minibatch, batch) (agent,), critic_infos = jax.lax.scan(scan_body, (self,), (minibatches,)) - critic_infos = jax.tree_map(lambda x: jnp.mean(x, axis=0), critic_infos) + critic_infos = jax.tree.map(lambda x: jnp.mean(x, axis=0), critic_infos) del critic_infos["actor"] del critic_infos["temperature"] diff --git a/serl_launcher/serl_launcher/common/common.py b/serl_launcher/serl_launcher/common/common.py index a3a53e74..1e80fa06 100644 --- a/serl_launcher/serl_launcher/common/common.py +++ b/serl_launcher/serl_launcher/common/common.py @@ -22,7 +22,7 @@ def shard_batch(batch, sharding): batch: A pytree of arrays. sharding: A jax Sharding object with shape (num_devices,). 
""" - return jax.tree_map( + return jax.tree.map( lambda x: jax.device_put( x, sharding.reshape(sharding.shape[0], *((1,) * (x.ndim - 1))) ), @@ -115,7 +115,7 @@ class JaxRLTrainState(struct.PyTreeNode): @staticmethod def _tx_tree_map(*args, **kwargs): - return jax.tree_map( + return jax.tree.map( *args, is_leaf=lambda x: isinstance(x, optax.GradientTransformation), **kwargs, @@ -128,7 +128,7 @@ def target_update(self, tau: float) -> "JaxRLTrainState": new_target_params = tau * params + (1 - tau) * target_params """ - new_target_params = jax.tree_map( + new_target_params = jax.tree.map( lambda p, tp: p * tau + tp * (1 - tau), self.params, self.target_params ) return self.replace(target_params=new_target_params) @@ -158,7 +158,7 @@ def apply_gradients(self, *, grads: Any) -> "JaxRLTrainState": ) # apply all the updates additively - updates_acc = jax.tree_map( + updates_acc = jax.tree.map( lambda *xs: jnp.sum(jnp.array(xs), axis=0), *updates_flat ) new_params = optax.apply_updates(self.params, updates_acc) @@ -200,7 +200,7 @@ def apply_loss_fns( rngs = jax.tree_util.tree_unflatten(treedef, rngs) # compute gradients - grads_and_aux = jax.tree_map( + grads_and_aux = jax.tree.map( lambda loss_fn, rng: jax.grad(loss_fn, has_aux=has_aux)(self.params, rng), loss_fns, rngs, @@ -214,8 +214,8 @@ def apply_loss_fns( grads_and_aux = jax.lax.pmean(grads_and_aux, axis_name=pmap_axis) if has_aux: - grads = jax.tree_map(lambda _, x: x[0], loss_fns, grads_and_aux) - aux = jax.tree_map(lambda _, x: x[1], loss_fns, grads_and_aux) + grads = jax.tree.map(lambda _, x: x[0], loss_fns, grads_and_aux) + aux = jax.tree.map(lambda _, x: x[1], loss_fns, grads_and_aux) return self.apply_gradients(grads=grads), aux else: return self.apply_gradients(grads=grads_and_aux) diff --git a/serl_launcher/serl_launcher/data/dataset.py b/serl_launcher/serl_launcher/data/dataset.py index f47eae9a..2daba136 100644 --- a/serl_launcher/serl_launcher/data/dataset.py +++ 
b/serl_launcher/serl_launcher/data/dataset.py @@ -210,7 +210,7 @@ def _sample_jax(rng, src, max_indx: int): return ( rng, indx.max(), - jax.tree_map(lambda d: jnp.take(d, indx, axis=0), src), + jax.tree.map(lambda d: jnp.take(d, indx, axis=0), src), ) self._sample_jax = _sample_jax diff --git a/serl_launcher/serl_launcher/vision/data_augmentations.py b/serl_launcher/serl_launcher/vision/data_augmentations.py index 2c2440fa..f455bbf1 100644 --- a/serl_launcher/serl_launcher/vision/data_augmentations.py +++ b/serl_launcher/serl_launcher/vision/data_augmentations.py @@ -169,7 +169,7 @@ def hsv_to_rgb(h, s, v): def adjust_brightness(rgb_tuple, delta): - return jax.tree_map(lambda x: x + delta, rgb_tuple) + return jax.tree.map(lambda x: x + delta, rgb_tuple) def adjust_contrast(image, factor): @@ -177,7 +177,7 @@ def _adjust_contrast_channel(channel): mean = jnp.mean(channel, axis=(-2, -1), keepdims=True) return factor * (channel - mean) + mean - return jax.tree_map(_adjust_contrast_channel, image) + return jax.tree.map(_adjust_contrast_channel, image) def adjust_saturation(h, s, v, factor): @@ -256,7 +256,7 @@ def identity_fn(x, unused_rng, unused_param): def cond_fn(args, i): def clip(args): - return jax.tree_map(lambda arg: jnp.clip(arg, 0.0, 1.0), args) + return jax.tree.map(lambda arg: jnp.clip(arg, 0.0, 1.0), args) out = jax.lax.cond( should_apply & should_apply_color & (i == idx), @@ -275,7 +275,7 @@ def clip(args): random_hue_cond = _make_cond(_random_hue, idx=3) def _color_jitter(x): - rgb_tuple = tuple(jax.tree_map(jnp.squeeze, jnp.split(x, 3, axis=-1))) + rgb_tuple = tuple(jax.tree.map(jnp.squeeze, jnp.split(x, 3, axis=-1))) if shuffle: order = jax.random.permutation(perm_rng, jnp.arange(4, dtype=jnp.int32)) else: diff --git a/serl_launcher/serl_launcher/wrappers/chunking.py b/serl_launcher/serl_launcher/wrappers/chunking.py index 175c9a5f..404a31c5 100644 --- a/serl_launcher/serl_launcher/wrappers/chunking.py +++ 
b/serl_launcher/serl_launcher/wrappers/chunking.py @@ -9,7 +9,7 @@ def stack_obs(obs): dict_list = {k: [dic[k] for dic in obs] for k in obs[0]} - return jax.tree_map( + return jax.tree.map( lambda x: np.stack(x), dict_list, is_leaf=lambda x: isinstance(x, list) ) diff --git a/serl_launcher/serl_launcher/wrappers/remap.py b/serl_launcher/serl_launcher/wrappers/remap.py index 7acb2d93..1d724dc1 100644 --- a/serl_launcher/serl_launcher/wrappers/remap.py +++ b/serl_launcher/serl_launcher/wrappers/remap.py @@ -31,4 +31,4 @@ def __init__(self, env: gym.Env, new_structure: Any): raise TypeError(f"Unsupported type {type(new_structure)}") def observation(self, observation): - return jax.tree_map(lambda x: observation[x], self.new_structure) + return jax.tree.map(lambda x: observation[x], self.new_structure) From 5098d78467fdbb068f5e13a1df85e1419950cf16 Mon Sep 17 00:00:00 2001 From: Ryan Vander Stelt <87216059+ryanvanderstelt@users.noreply.github.com> Date: Fri, 18 Jul 2025 15:10:24 -0500 Subject: [PATCH 031/141] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 7f8db41f..c319e01b 100644 --- a/README.md +++ b/README.md @@ -53,7 +53,7 @@ We fixed a major issue in the intervention action frame. 
See release [v0.1.1](ht - For GPU: ```bash - pip install --upgrade "jax[cuda12_pip]==0.4.35" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html + pip install --upgrade "jax[cuda12]" ``` - For TPU From 653ed82c4eda0a322559e8ca5af6bb3185938245 Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Fri, 18 Jul 2025 15:13:51 -0500 Subject: [PATCH 032/141] phone files --- demos/franka_pick_place_demo_script.py | 407 +++++++++++++++++++++++++ demos/franka_reach_demo_script.py | 407 +++++++++++++++++++++++++ 2 files changed, 814 insertions(+) create mode 100755 demos/franka_pick_place_demo_script.py create mode 100755 demos/franka_reach_demo_script.py diff --git a/demos/franka_pick_place_demo_script.py b/demos/franka_pick_place_demo_script.py new file mode 100755 index 00000000..e5abbc11 --- /dev/null +++ b/demos/franka_pick_place_demo_script.py @@ -0,0 +1,407 @@ +""" +Scripted Controller for Franka FR3 Robot - Demonstration Data Generation + +This script implements a scripted controller for generating expert demonstration data +for Deep Reinforcement Learning (DRL) algorithms using the Franka FR3 robot in a +pick-and-place task. The generated data serves as bootstrapping demonstrations for +training RL agents with stable-baselines3. + +The controller uses a 4-phase hierarchical approach: +1. Approach Object (move gripper above object) +2. Grasp Object (move to object and close gripper) +3. Transport to Goal (move grasped object to target) +4. 
Maintain Position (hold final position) + +Output: Compressed NPZ file containing action, observation, and info sequences +""" +import os +import numpy as np +import gymnasium as gym +from gymnasium.wrappers import TimeLimit +from time import sleep +from panda_mujoco_gym.envs import FrankaPickAndPlaceEnv + +# Global variables to store episode data across all iterations +observations = [] # List storing observation sequences for each episode +actions = [] # List storing action sequences for each episode +rewards = [] # List storing reward sequences for each episode +infos = [] # List storing info dictionaries for each episode +terminateds = [] # List storing terminated flags for each episode +truncateds = [] # List storing truncated flags for each episode +dones = [] # List storing done flags (terminated or truncated) for each episode + +# Robot configuration +robot = 'franka' # Robot type used in the environment, can be 'franka' or 'fetch' + +# Weld constraint flag +weld_flag = True # Flag to activate weld constraint during pick-and-place + + +def activate_weld(env, constraint_name="grasp_weld"): + """ + Activate a weld constraint during pick portion of a demo + :param env: The environment containing the model + :param constraint_name: The name of the weld constraint to activate + :return: True if the weld was successfully activated, False if the constraint was not found + """ + + try: + # Activate the weld constraint + env.unwrapped.model.eq(constraint_name).active = 1 + print("Activated weld") + return True + + except KeyError: + print(f"Warning: Constraint '{constraint_name}' not found") + return False + +def deactivate_weld(env, constraint_name="grasp_weld"): + """ + Deactivate a weld constraint during place portion of a demo + :param env: The environment containing the model + :param constraint_name: The name of the weld constraint to deactivate + :return: True if the weld was successfully deactivated, False if the constraint was not + found + """ + + try: + # 
Deactivate the weld constraint + env.unwrapped.model.eq(constraint_name).active = 0 + print("Deactivated weld") + return True + + except KeyError: + print(f"Warning: Constraint '{constraint_name}' not found") + return False + +def main(): + """ + Orchestrates the data generation process by running multiple episodes + of the pick-and-place task. + + Creates environment, runs scripted episodes, and saves demonstration data + to compressed NPZ file for use with stable-baselines3. + """ + # Initialize Fetch pick-and-place environment + env = FrankaPickAndPlaceEnv(reward_type="sparse", render_mode="rgb_array") + env = TimeLimit(env, max_episode_steps=50) + + # Adjust physical settings + # env.model.opt.timestep = 0.001 # Smaller timestep for more accurate physics. Default is 0.002. + # env.model.opt.iterations = 100 # More solver iterations for better contact resolution. Default is 50. + + # Configuration parameters + initStateSpace = "random" # Initial state space configuration + + # Demos configs + attempted_demos = 1 # Number of demonstration episodes to generate--ADJUST THIS VALUE FOR MORE OR LESS DEMOS** + + num_demos = 0 # Counter for successful demonstration episodes + + # Reset environment to initial state - render for the first time. + obs, _ = env.reset() + print("Reset!") + + # Generate demonstration episodes + while len(actions) < attempted_demos: + obs,_ = env.reset() # Reset environment for new episode + print(f"We will run a total of: {attempted_demos} demos!!") + print("Demo: #", len(actions)+1) + + # Execute pick-and-place task + res = pick_and_place_demo(env, obs) + + # Print success message + if res: + num_demos += 1 + print("Episode completed successfully!") + print(f"Total successful demos: {num_demos}/{attempted_demos}") + + ## Write data to demos folder + # 1. Get the absolute path of this script + #script_path = os.path.abspath(__file__) + + # 2. 
Extract its directory + #script_dir = os.path.dirname(script_path) + script_dir = '/home/student/data/franka_baselines/demos/pick_n_place' # Assumes data folder in user directory. + + # 3. Create output filename with configuration details + fileName = "data_" + robot + fileName += "_" + initStateSpace + fileName += "_" + str(attempted_demos) + fileName += ".npz" + + # 3. Build a filename in that same directory + out_path = os.path.join(script_dir, fileName) + + # Save collected data to compressed numpy NPZ file + # Set acs,obs,info as keys in dict + np.savez_compressed(out_path, + acs = actions, + obs = observations, + rewards = rewards, + info = infos, + terminateds = terminateds, + truncateds = truncateds, + dones = dones) + + print(f"Data saved to {fileName}.") + +def pick_and_place_demo(env, lastObs): + """ + Executes a scripted pick-and-place sequence using a hierarchical approach. + + Implements 4-phase control strategy: + 1. Approach: Move gripper above object (3cm offset) + 2. Grasp: Move to object and close gripper + 3. Transport: Move grasped object to goal position + 4. Maintain: Hold position until episode ends + + Store observations, actions, and info in global lists for later replay buffer inclusion. 
+ + Args: + env: Gymnasium environment instance + lastObs: Last observation containing goal and object state information + - desired_goal: Target position for object placement + - observations: + ee_position[0:3], + ee_velocity[3:6], + fingers_width[6], + object_position[7:10], + object_rotation[10:13], + object_velp[13:16], + object_velr[16:19], + """ + + ## Init goal, current_pos, and object position from last observation + goal = np.zeros(3, dtype=np.float32) + current_pos = np.zeros(3, dtype=np.float32) + object_pos = np.zeros(3, dtype=np.float32) + object_rel_pos = np.zeros(3, dtype=np.float32) + fgr_pos = np.zeros(1, dtype=np.float32) + + # Initialize episode data collection + episodeObs = [] # Observations for this episode + episodeAcs = [] # Actions for this episode + episodeRews = [] # Rewards for this episode + episodeInfo = [] # Info for this episode + episodeTerminated = [] # Terminated flags for this episode + episodeTruncated = [] # Truncated flags for this episode + episodeDones = [] # Done flags (terminated or truncated) for this episode + + # Proportional control gain for action scaling -- empirically tuned + Kp = 8.0 + + # pre_pick_offset + pre_pick_offset = np.array([0,0,0.03], dtype=float) # Offset to approach object safely (3cm) + + # Error thresholds + error_threshold = 0.011 # Threshold for stopping condition (Xmm) + + finger_delta_fast = 0.05 # Action delta for fingers 5cm per step (will get clipped by controller)... more of a scalar. 
+ finger_delta_slow = 0.005 # Franka has a range from 0 to 4cm per finger + + ## Extract data + # Extract desired position from desired_goal dict + goal = lastObs["desired_goal"][0:3] + + # Current robot end-effector position from observation dict + current_pos = lastObs["observation"][0:3] + + # Current object position from observation dict: + object_pos = lastObs["observation"][7:10] + + # Relative position between end-effector and object + object_rel_pos = object_pos - current_pos + + ## Phase 1: Approach Object (Above) + # Create target position 3cm above the object. Use copy() method. + error = object_rel_pos.copy() + error+=pre_pick_offset # Move 3cm above object for safe approach. Fingers should still end up surrounding object. + + timeStep = 0 # Track total timesteps in episode + episodeObs.append(lastObs) + + # Phase 1: Move gripper to position above object + # Terminate when distance to above-object position < 5mm + print(f"----------------------------------------------- Phase 1: Approach Object -----------------------------------------------") + while np.linalg.norm(error) >= error_threshold and timeStep <= env._max_episode_steps: + env.render() # Visual feedback + + # Initialize action vector [x, y, z, gripper] + action = np.array([0., 0., 0., 0.]) + + # Proportional control with gain of 6 + # action = Kp * error + action[:3] = error * Kp + + # Open gripper for approach + action[ len(action)-1 ] = 0.05 + + # Unpack new Gymnasium step API + new_obs, reward, terminated, truncated, info = env.step(action) + done = terminated or truncated + timeStep += 1 + + # Store episode data + episodeAcs.append(action) + episodeInfo.append(info) + episodeRews.append(reward) + episodeObs.append(new_obs) + episodeTerminated.append(terminated) + episodeTruncated.append(truncated) + episodeDones.append(done) + + # Update state information + fgr_pos = new_obs["observation"][6] + current_pos = new_obs["observation"][0:3] + object_pos = new_obs['observation'][7:10] + error = 
(object_pos+pre_pick_offset) - current_pos # Error with regard to offset position + + # Print debug information + print( + f"Time Step: {timeStep}, Error: {np.linalg.norm(error):.4f}, " + f"Eff_pos: {np.array2string(current_pos, precision=3)}, " + f"obj_pos: {np.array2string(object_pos, precision=3)}, " + f"fgr_pos: {np.array2string(fgr_pos, precision=2)}, " + f"Error: {np.array2string(error, precision=3)}, " + f"Action: {np.array2string(action, precision=3)}" + ) + + # Phase 2: Descend Grasp Object + # Move gripper directly to object and close gripper + # Terminate when relative distance to object < 5mm + print(f"----------------------------------------------- Phase 2: Grip -----------------------------------------------") + error = object_pos - current_pos # remove offset + while (np.linalg.norm(error) >= error_threshold or fgr_pos>=0.39) and timeStep <= env._max_episode_steps: # Cube of width 4cm, each finger open to 2cm + env.render() + + # Initialize action vector [x, y, z, gripper] + action = np.array([0., 0., 0., 0.]) + + # Direct proportional control to object position + action[:3] = error * Kp + + # Close gripper to grasp object + action[len(action)-1] = -finger_delta_fast * 2 + + # Execute action and collect data + new_obs, reward, terminated, truncated, info = env.step(action) + done = terminated or truncated + timeStep += 1 + + # Store episode data + episodeObs.append(new_obs) + episodeRews.append(reward) + episodeAcs.append(action) + episodeInfo.append(info) + episodeTerminated.append(terminated) + episodeTruncated.append(truncated) + episodeDones.append(done) + + # Update state information + fgr_pos = new_obs["observation"][6] + current_pos = new_obs["observation"][0:3] + object_pos = new_obs['observation'][7:10] + error = object_pos - current_pos #- np.array([0.,0.,0.01]) # Grab lower + + # Print debug information + print( + f"Time Step: {timeStep}, Error: {np.linalg.norm(error):.4f}, " + f"Eff_pos: {np.array2string(current_pos, precision=3)}, " + 
f"obj_pos: {np.array2string(object_pos, precision=3)}, " + f"fgr_pos: {np.array2string(fgr_pos, precision=3)}, " + f"Error: {np.array2string(error, precision=3)}, " + f"Action: {np.array2string(action, precision=3)}" + ) + #sleep(0.5) # Optional: Slow down for better visualization + + # Phase 3: Transport to Goal + # Move grasped object to desired goal position + # Terminate when distance between object and goal < 1cm + print(f"----------------------------------------------- Phase 3: Transport to Goal -----------------------------------------------") + + # Weld activation + if weld_flag: + activate_weld(env, constraint_name="grasp_weld") + + # Set error between goal and hand assuming the object is grasped + gh_error = goal - current_pos # Error between goal and hand position + ho_error = object_pos - current_pos # Error between object and hand position + while np.linalg.norm(gh_error) >= 0.01 and timeStep <= env._max_episode_steps: + env.render() + + action = np.array([0., 0., 0., 0.]) + + # Proportional control toward goal position + action[:3] = gh_error[:3] * Kp + + # Maintain grip on object + #action[len(action)-1] = 0 + + # Execute action and collect data + new_obs, reward, terminated, truncated, info = env.step(action) + done = terminated or truncated + timeStep += 1 + + # Store episode data + episodeObs.append(new_obs) + episodeRews.append(reward) + episodeAcs.append(action) + episodeInfo.append(info) + episodeTerminated.append(terminated) + episodeTruncated.append(truncated) + episodeDones.append(done) + + # Update state information + fgr_pos = new_obs["observation"][6] + current_pos = new_obs["observation"][0:3] + object_pos = new_obs['observation'][7:10] + gh_error = goal - current_pos # Error between goal and hand position + ho_error = object_pos - current_pos # Error between object and hand position + + # Print debug information + print( + f"Time Step: {timeStep}, Error Norm: {np.linalg.norm(gh_error):.4f}, " + f"Eff_pos: {np.array2string(current_pos, 
precision=3)}, " + f"goal_pos: {np.array2string(goal, precision=3)}, " + f"fgr_pos: {np.array2string(fgr_pos, precision=2)}, " + f"Error: {np.array2string(gh_error, precision=3)}, " + f"Action: {np.array2string(action, precision=3)}" + ) + + sleep(0.5) # Optional: Slow down for better visualization + + ## Check for success and store episode data + gh_norm = np.linalg.norm(gh_error) + ho_nomr = np.linalg.norm(ho_error) + if gh_norm < error_threshold and ho_nomr < error_threshold: + + # Store complete episode data in global lists only if we succeeded (avoid bad demos) + actions.append(episodeAcs) + observations.append(episodeObs) + infos.append(episodeInfo) + rewards.append(episodeRews) + + # Optionally, also store the done/terminated/truncated flags globally if needed: + terminateds.append(episodeTerminated) + truncateds.append(episodeTruncated) + dones.append(episodeDones) + + # Deactivate weld constraint after successful pick + if weld_flag: + deactivate_weld(env, constraint_name="grasp_weld") + + # Close mujoco viewer + env.close() + + # Break out of the loop to start a new episode + return True + + # If we reach here, the episode was not successful + if weld_flag: + print("Failed to transport object to goal position. Deactivating weld.") + deactivate_weld(env, constraint_name="grasp_weld") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/demos/franka_reach_demo_script.py b/demos/franka_reach_demo_script.py new file mode 100755 index 00000000..e5abbc11 --- /dev/null +++ b/demos/franka_reach_demo_script.py @@ -0,0 +1,407 @@ +""" +Scripted Controller for Franka FR3 Robot - Demonstration Data Generation + +This script implements a scripted controller for generating expert demonstration data +for Deep Reinforcement Learning (DRL) algorithms using the Franka FR3 robot in a +pick-and-place task. The generated data serves as bootstrapping demonstrations for +training RL agents with stable-baselines3. 
+ +The controller uses a 4-phase hierarchical approach: +1. Approach Object (move gripper above object) +2. Grasp Object (move to object and close gripper) +3. Transport to Goal (move grasped object to target) +4. Maintain Position (hold final position) + +Output: Compressed NPZ file containing action, observation, and info sequences +""" +import os +import numpy as np +import gymnasium as gym +from gymnasium.wrappers import TimeLimit +from time import sleep +from panda_mujoco_gym.envs import FrankaPickAndPlaceEnv + +# Global variables to store episode data across all iterations +observations = [] # List storing observation sequences for each episode +actions = [] # List storing action sequences for each episode +rewards = [] # List storing reward sequences for each episode +infos = [] # List storing info dictionaries for each episode +terminateds = [] # List storing terminated flags for each episode +truncateds = [] # List storing truncated flags for each episode +dones = [] # List storing done flags (terminated or truncated) for each episode + +# Robot configuration +robot = 'franka' # Robot type used in the environment, can be 'franka' or 'fetch' + +# Weld constraint flag +weld_flag = True # Flag to activate weld constraint during pick-and-place + + +def activate_weld(env, constraint_name="grasp_weld"): + """ + Activate a weld constraint during pick portion of a demo + :param env: The environment containing the model + :param constraint_name: The name of the weld constraint to activate + :return: True if the weld was successfully activated, False if the constraint was not found + """ + + try: + # Activate the weld constraint + env.unwrapped.model.eq(constraint_name).active = 1 + print("Activated weld") + return True + + except KeyError: + print(f"Warning: Constraint '{constraint_name}' not found") + return False + +def deactivate_weld(env, constraint_name="grasp_weld"): + """ + Deactivate a weld constraint during place portion of a demo + :param env: The 
environment containing the model + :param constraint_name: The name of the weld constraint to deactivate + :return: True if the weld was successfully deactivated, False if the constraint was not + found + """ + + try: + # Deactivate the weld constraint + env.unwrapped.model.eq(constraint_name).active = 0 + print("Deactivated weld") + return True + + except KeyError: + print(f"Warning: Constraint '{constraint_name}' not found") + return False + +def main(): + """ + Orchestrates the data generation process by running multiple episodes + of the pick-and-place task. + + Creates environment, runs scripted episodes, and saves demonstration data + to compressed NPZ file for use with stable-baselines3. + """ + # Initialize Fetch pick-and-place environment + env = FrankaPickAndPlaceEnv(reward_type="sparse", render_mode="rgb_array") + env = TimeLimit(env, max_episode_steps=50) + + # Adjust physical settings + # env.model.opt.timestep = 0.001 # Smaller timestep for more accurate physics. Default is 0.002. + # env.model.opt.iterations = 100 # More solver iterations for better contact resolution. Default is 50. + + # Configuration parameters + initStateSpace = "random" # Initial state space configuration + + # Demos configs + attempted_demos = 1 # Number of demonstration episodes to generate--ADJUST THIS VALUE FOR MORE OR LESS DEMOS** + + num_demos = 0 # Counter for successful demonstration episodes + + # Reset environment to initial state - render for the first time. 
+ obs, _ = env.reset() + print("Reset!") + + # Generate demonstration episodes + while len(actions) < attempted_demos: + obs,_ = env.reset() # Reset environment for new episode + print(f"We will run a total of: {attempted_demos} demos!!") + print("Demo: #", len(actions)+1) + + # Execute pick-and-place task + res = pick_and_place_demo(env, obs) + + # Print success message + if res: + num_demos += 1 + print("Episode completed successfully!") + print(f"Total successful demos: {num_demos}/{attempted_demos}") + + ## Write data to demos folder + # 1. Get the absolute path of this script + #script_path = os.path.abspath(__file__) + + # 2. Extract its directory + #script_dir = os.path.dirname(script_path) + script_dir = '/home/student/data/franka_baselines/demos/pick_n_place' # Assumes data folder in user directory. + + # 3. Create output filename with configuration details + fileName = "data_" + robot + fileName += "_" + initStateSpace + fileName += "_" + str(attempted_demos) + fileName += ".npz" + + # 3. Build a filename in that same directory + out_path = os.path.join(script_dir, fileName) + + # Save collected data to compressed numpy NPZ file + # Set acs,obs,info as keys in dict + np.savez_compressed(out_path, + acs = actions, + obs = observations, + rewards = rewards, + info = infos, + terminateds = terminateds, + truncateds = truncateds, + dones = dones) + + print(f"Data saved to {fileName}.") + +def pick_and_place_demo(env, lastObs): + """ + Executes a scripted pick-and-place sequence using a hierarchical approach. + + Implements 4-phase control strategy: + 1. Approach: Move gripper above object (3cm offset) + 2. Grasp: Move to object and close gripper + 3. Transport: Move grasped object to goal position + 4. Maintain: Hold position until episode ends + + Store observations, actions, and info in global lists for later replay buffer inclusion. 
+ + Args: + env: Gymnasium environment instance + lastObs: Last observation containing goal and object state information + - desired_goal: Target position for object placement + - observations: + ee_position[0:3], + ee_velocity[3:6], + fingers_width[6], + object_position[7:10], + object_rotation[10:13], + object_velp[13:16], + object_velr[16:19], + """ + + ## Init goal, current_pos, and object position from last observation + goal = np.zeros(3, dtype=np.float32) + current_pos = np.zeros(3, dtype=np.float32) + object_pos = np.zeros(3, dtype=np.float32) + object_rel_pos = np.zeros(3, dtype=np.float32) + fgr_pos = np.zeros(1, dtype=np.float32) + + # Initialize episode data collection + episodeObs = [] # Observations for this episode + episodeAcs = [] # Actions for this episode + episodeRews = [] # Rewards for this episode + episodeInfo = [] # Info for this episode + episodeTerminated = [] # Terminated flags for this episode + episodeTruncated = [] # Truncated flags for this episode + episodeDones = [] # Done flags (terminated or truncated) for this episode + + # Proportional control gain for action scaling -- empirically tuned + Kp = 8.0 + + # pre_pick_offset + pre_pick_offset = np.array([0,0,0.03], dtype=float) # Offset to approach object safely (3cm) + + # Error thresholds + error_threshold = 0.011 # Threshold for stopping condition (Xmm) + + finger_delta_fast = 0.05 # Action delta for fingers 5cm per step (will get clipped by controller)... more of a scalar. 
+ finger_delta_slow = 0.005 # Franka has a range from 0 to 4cm per finger + + ## Extract data + # Extract desired position from desired_goal dict + goal = lastObs["desired_goal"][0:3] + + # Current robot end-effector position from observation dict + current_pos = lastObs["observation"][0:3] + + # Current object position from observation dict: + object_pos = lastObs["observation"][7:10] + + # Relative position between end-effector and object + object_rel_pos = object_pos - current_pos + + ## Phase 1: Approach Object (Above) + # Create target position 3cm above the object. Use copy() method. + error = object_rel_pos.copy() + error+=pre_pick_offset # Move 3cm above object for safe approach. Fingers should still end up surrounding object. + + timeStep = 0 # Track total timesteps in episode + episodeObs.append(lastObs) + + # Phase 1: Move gripper to position above object + # Terminate when distance to above-object position < 5mm + print(f"----------------------------------------------- Phase 1: Approach Object -----------------------------------------------") + while np.linalg.norm(error) >= error_threshold and timeStep <= env._max_episode_steps: + env.render() # Visual feedback + + # Initialize action vector [x, y, z, gripper] + action = np.array([0., 0., 0., 0.]) + + # Proportional control with gain of 6 + # action = Kp * error + action[:3] = error * Kp + + # Open gripper for approach + action[ len(action)-1 ] = 0.05 + + # Unpack new Gymnasium step API + new_obs, reward, terminated, truncated, info = env.step(action) + done = terminated or truncated + timeStep += 1 + + # Store episode data + episodeAcs.append(action) + episodeInfo.append(info) + episodeRews.append(reward) + episodeObs.append(new_obs) + episodeTerminated.append(terminated) + episodeTruncated.append(truncated) + episodeDones.append(done) + + # Update state information + fgr_pos = new_obs["observation"][6] + current_pos = new_obs["observation"][0:3] + object_pos = new_obs['observation'][7:10] + error = 
(object_pos+pre_pick_offset) - current_pos # Error with regard to offset position + + # Print debug information + print( + f"Time Step: {timeStep}, Error: {np.linalg.norm(error):.4f}, " + f"Eff_pos: {np.array2string(current_pos, precision=3)}, " + f"obj_pos: {np.array2string(object_pos, precision=3)}, " + f"fgr_pos: {np.array2string(fgr_pos, precision=2)}, " + f"Error: {np.array2string(error, precision=3)}, " + f"Action: {np.array2string(action, precision=3)}" + ) + + # Phase 2: Descend Grasp Object + # Move gripper directly to object and close gripper + # Terminate when relative distance to object < 5mm + print(f"----------------------------------------------- Phase 2: Grip -----------------------------------------------") + error = object_pos - current_pos # remove offset + while (np.linalg.norm(error) >= error_threshold or fgr_pos>=0.39) and timeStep <= env._max_episode_steps: # Cube of width 4cm, each finger open to 2cm + env.render() + + # Initialize action vector [x, y, z, gripper] + action = np.array([0., 0., 0., 0.]) + + # Direct proportional control to object position + action[:3] = error * Kp + + # Close gripper to grasp object + action[len(action)-1] = -finger_delta_fast * 2 + + # Execute action and collect data + new_obs, reward, terminated, truncated, info = env.step(action) + done = terminated or truncated + timeStep += 1 + + # Store episode data + episodeObs.append(new_obs) + episodeRews.append(reward) + episodeAcs.append(action) + episodeInfo.append(info) + episodeTerminated.append(terminated) + episodeTruncated.append(truncated) + episodeDones.append(done) + + # Update state information + fgr_pos = new_obs["observation"][6] + current_pos = new_obs["observation"][0:3] + object_pos = new_obs['observation'][7:10] + error = object_pos - current_pos #- np.array([0.,0.,0.01]) # Grab lower + + # Print debug information + print( + f"Time Step: {timeStep}, Error: {np.linalg.norm(error):.4f}, " + f"Eff_pos: {np.array2string(current_pos, precision=3)}, " + 
f"obj_pos: {np.array2string(object_pos, precision=3)}, " + f"fgr_pos: {np.array2string(fgr_pos, precision=3)}, " + f"Error: {np.array2string(error, precision=3)}, " + f"Action: {np.array2string(action, precision=3)}" + ) + #sleep(0.5) # Optional: Slow down for better visualization + + # Phase 3: Transport to Goal + # Move grasped object to desired goal position + # Terminate when distance between object and goal < 1cm + print(f"----------------------------------------------- Phase 3: Transport to Goal -----------------------------------------------") + + # Weld activation + if weld_flag: + activate_weld(env, constraint_name="grasp_weld") + + # Set error between goal and hand assuming the object is grasped + gh_error = goal - current_pos # Error between goal and hand position + ho_error = object_pos - current_pos # Error between object and hand position + while np.linalg.norm(gh_error) >= 0.01 and timeStep <= env._max_episode_steps: + env.render() + + action = np.array([0., 0., 0., 0.]) + + # Proportional control toward goal position + action[:3] = gh_error[:3] * Kp + + # Maintain grip on object + #action[len(action)-1] = 0 + + # Execute action and collect data + new_obs, reward, terminated, truncated, info = env.step(action) + done = terminated or truncated + timeStep += 1 + + # Store episode data + episodeObs.append(new_obs) + episodeRews.append(reward) + episodeAcs.append(action) + episodeInfo.append(info) + episodeTerminated.append(terminated) + episodeTruncated.append(truncated) + episodeDones.append(done) + + # Update state information + fgr_pos = new_obs["observation"][6] + current_pos = new_obs["observation"][0:3] + object_pos = new_obs['observation'][7:10] + gh_error = goal - current_pos # Error between goal and hand position + ho_error = object_pos - current_pos # Error between object and hand position + + # Print debug information + print( + f"Time Step: {timeStep}, Error Norm: {np.linalg.norm(gh_error):.4f}, " + f"Eff_pos: {np.array2string(current_pos, 
precision=3)}, " + f"goal_pos: {np.array2string(goal, precision=3)}, " + f"fgr_pos: {np.array2string(fgr_pos, precision=2)}, " + f"Error: {np.array2string(gh_error, precision=3)}, " + f"Action: {np.array2string(action, precision=3)}" + ) + + sleep(0.5) # Optional: Slow down for better visualization + + ## Check for success and store episode data + gh_norm = np.linalg.norm(gh_error) + ho_nomr = np.linalg.norm(ho_error) + if gh_norm < error_threshold and ho_nomr < error_threshold: + + # Store complete episode data in global lists only if we succeeded (avoid bad demos) + actions.append(episodeAcs) + observations.append(episodeObs) + infos.append(episodeInfo) + rewards.append(episodeRews) + + # Optionally, also store the done/terminated/truncated flags globally if needed: + terminateds.append(episodeTerminated) + truncateds.append(episodeTruncated) + dones.append(episodeDones) + + # Deactivate weld constraint after successful pick + if weld_flag: + deactivate_weld(env, constraint_name="grasp_weld") + + # Close mujoco viewer + env.close() + + # Break out of the loop to start a new episode + return True + + # If we reach here, the episode was not successful + if weld_flag: + print("Failed to transport object to goal position. Deactivating weld.") + deactivate_weld(env, constraint_name="grasp_weld") + +if __name__ == "__main__": + main() \ No newline at end of file From f669788fda50ca797def8b6ae98a351de32e9eaf Mon Sep 17 00:00:00 2001 From: Ryan Vander Stelt <87216059+ryanvanderstelt@users.noreply.github.com> Date: Fri, 18 Jul 2025 15:15:11 -0500 Subject: [PATCH 033/141] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index c319e01b..9803ab6c 100644 --- a/README.md +++ b/README.md @@ -53,7 +53,7 @@ We fixed a major issue in the intervention action frame. 
See release [v0.1.1](ht - For GPU: ```bash - pip install --upgrade "jax[cuda12]" + pip install --upgrade "jax[cuda12]==0.6.2" ``` - For TPU From cf135b75eb769386f3427f9f5ff4ae078176e008 Mon Sep 17 00:00:00 2001 From: ryanvanderstelt Date: Fri, 18 Jul 2025 15:25:34 -0500 Subject: [PATCH 034/141] Added y-dimension --- .../async_sac_state_sim.py | 1 + .../data/fractal_symmetry_replay_buffer.py | 72 ++++++++++++++++--- serl_launcher/serl_launcher/data/fsrb_test.py | 6 +- 3 files changed, 68 insertions(+), 11 deletions(-) diff --git a/examples/async_sac_state_sim/async_sac_state_sim.py b/examples/async_sac_state_sim/async_sac_state_sim.py index 314dcfdc..ef20e8d0 100644 --- a/examples/async_sac_state_sim/async_sac_state_sim.py +++ b/examples/async_sac_state_sim/async_sac_state_sim.py @@ -293,6 +293,7 @@ def main(_): split_method="test", workspace_width=0.5, x_obs_idx=np.array([0,7]), + y_obs_idx=np.array([1,8]), preload_rlds_path=FLAGS.preload_rlds_path, ) replay_iterator = replay_buffer.get_iterator( diff --git a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py index 4b44180d..860ab1b0 100644 --- a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py +++ b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py @@ -2,6 +2,7 @@ import numpy as np from serl_launcher.data.dataset import DatasetDict from serl_launcher.data.replay_buffer import ReplayBuffer +import copy class FractalSymmetryReplayBuffer(ReplayBuffer): def __init__( @@ -78,10 +79,11 @@ def __init__( print(f"\033[33mWARNING \033[0m argument \"{k}\" not used") self.x_obs_idx = kwargs["x_obs_idx"] - # self.y_obs_idx = kwargs["y_obs_idx"] + self.y_obs_idx = kwargs["y_obs_idx"] self.workspace_width = workspace_width self.current_depth = 1 self.counter = 1 + self.branch_index = None # TODO # Add flags for variables passed to @@ -116,7 +118,7 @@ def test_branch(self): # while(not self.current_branch_count 
% 2): # self.current_branch_count = np.random.randint(2500) # return self.current_branch_count - return 81 + return 3 def fractal_branch(self): # return a new number of branches = dendrites ^ depth @@ -172,22 +174,74 @@ def insert(self, data_dict: DatasetDict): # Update number of branches if needed if self.split(data_dict): + temp = self.current_branch_count self.current_branch_count = self.branch() + if temp != self.current_branch_count: + self.branch_index = np.empty(self.current_branch_count, dtype=np.float32) + constant = self.workspace_width/(2 * self.current_branch_count) + for i in range(0, self.current_branch_count): + self.branch_index[i] = (2 * i + 1) * constant - # Transform and insert branches - dx = self.workspace_width/self.current_branch_count - x = (-self.workspace_width + dx)/2 + # rb_origin = [data_dict["observations"][self.x_obs_idx[0]], data_dict["observations"][self.y_obs_idx[0]]] + # block_origin = [data_dict["observations"][self.x_obs_idx[1]], data_dict["observations"][self.y_obs_idx[1]]] + # rb_next_origin = [data_dict["next_observations"][self.x_obs_idx[0]], data_dict["next_observations"][self.y_obs_idx[0]]] + # block_next_origin = [data_dict["next_observations"][self.x_obs_idx[1]], data_dict["next_observations"][self.y_obs_idx[1]]] + + # Initialize to extreme x and y + x = -self.workspace_width/2 transform = np.zeros_like(data_dict["observations"]) transform[self.x_obs_idx] = x + transform[self.y_obs_idx] = x self.transform(data_dict, transform) - transform[self.x_obs_idx] = dx - for t in range(0, self.current_branch_count): - super().insert(data_dict) - self.transform(data_dict, transform) + # rb_pos_print = np.empty(shape=(self.current_branch_count, self.current_branch_count, 2), dtype=np.float32) + # block_pos_print = np.empty(shape=(self.current_branch_count, self.current_branch_count, 2), dtype=np.float32) + # rb_next_pos_print = np.empty(shape=(self.current_branch_count, self.current_branch_count, 2), dtype=np.float32) + # 
block_next_pos_print = np.empty(shape=(self.current_branch_count, self.current_branch_count, 2), dtype=np.float32) + # Transform and insert transitions (multiprocessing in the future) + for x in range(0, self.current_branch_count): + x_diff = self.branch_index[x] + transform[self.x_obs_idx] = x_diff + for y in range(0, self.current_branch_count): + new_data_dict = copy.deepcopy(data_dict) + y_diff = self.branch_index[y] + transform[self.y_obs_idx] = y_diff + self.transform(new_data_dict, transform) + + # keep info for debug + # rb_pos_print[y][x][0] = round(new_data_dict["observations"][self.x_obs_idx[0]], 2) + # rb_pos_print[y][x][1] = round(new_data_dict["observations"][self.y_obs_idx[0]], 2) + # block_pos_print[y][x][0] = round(new_data_dict["observations"][self.x_obs_idx[1]], 2) + # block_pos_print[y][x][1] = round(new_data_dict["observations"][self.y_obs_idx[1]], 2) + # rb_next_pos_print[y][x][0] = round(new_data_dict["next_observations"][self.x_obs_idx[0]], 2) + # rb_next_pos_print[y][x][1] = round(new_data_dict["next_observations"][self.y_obs_idx[0]], 2) + # block_next_pos_print[y][x][0] = round(new_data_dict["next_observations"][self.x_obs_idx[1]], 2) + # block_next_pos_print[y][x][1] = round(new_data_dict["next_observations"][self.y_obs_idx[1]], 2) + + # absolutely make sure nothing wrong is being changed + # assert data_dict["rewards"] == new_data_dict["rewards"] + # assert (data_dict["actions"] == new_data_dict["actions"]).all() + # assert data_dict["dones"] == new_data_dict["dones"] + # assert data_dict["masks"] == new_data_dict["masks"] + # for i in range(0, data_dict["observations"].size): + # if i in self.x_obs_idx or i in self.y_obs_idx: + # continue + # assert data_dict["observations"][i] == new_data_dict["observations"][i] + # assert data_dict["next_observations"][i] == new_data_dict["next_observations"][i] + + super().insert(new_data_dict) self.counter += 1 + # print(f"original x,y of hand before transition: {rb_origin}") + # print(f"transforms: 
\n{rb_pos_print}") + # print(f"original x,y of hand after transition: {rb_next_origin}") + # print(f"transforms: \n{rb_next_pos_print}") + # print(f"original x,y of block before transition: {block_origin}") + # print(f"transforms: \n{block_pos_print}") + # print(f"original x,y of block after transition: {block_next_origin}") + # print(f"transforms: \n{block_next_pos_print}") + if data_dict["dones"]: self.counter = 1 self.current_depth = 1 \ No newline at end of file diff --git a/serl_launcher/serl_launcher/data/fsrb_test.py b/serl_launcher/serl_launcher/data/fsrb_test.py index 2e71bf01..8ec2ad41 100644 --- a/serl_launcher/serl_launcher/data/fsrb_test.py +++ b/serl_launcher/serl_launcher/data/fsrb_test.py @@ -10,8 +10,8 @@ FLAGS = flags.FLAGS flags.DEFINE_integer("capacity", 10000000, "Replay buffer capacity.") -flags.DEFINE_string("branch_method", "fractal", "placeholder") -flags.DEFINE_string("split_method", "time", "placeholder") +flags.DEFINE_string("branch_method", "test", "placeholder") +flags.DEFINE_string("split_method", "test", "placeholder") flags.DEFINE_float("workspace_width", 0.5, "workspace width in centimeters") flags.DEFINE_integer("depth", 4, "Total layers of depth") flags.DEFINE_integer("dendrites", 3, "Dendrites for fractal branching") # Remember to set default to None @@ -36,6 +36,7 @@ def main(_): branch_method=FLAGS.branch_method, workspace_width=FLAGS.workspace_width, x_obs_idx=x_obs_idx, + y_obs_idx= y_obs_idx, depth=FLAGS.depth, dendrites=FLAGS.dendrites, timesplit_freq=FLAGS.timesplit_freq, @@ -147,6 +148,7 @@ def main(_): # df.to_excel('output.xlsx', index=False) data_dict["observations"][0] = 0 + data_dict["observations"][1] = 0 start = data_dict["observations"][0] - FLAGS.workspace_width / 2 From 05893d9124b54f10092ad486f3033c74c7529e27 Mon Sep 17 00:00:00 2001 From: ryanvanderstelt Date: Fri, 18 Jul 2025 17:14:06 -0500 Subject: [PATCH 035/141] created equivalence tests for fsrb (1 by 1) and rb --- .../fsrb_equivalence_test.py | 54 
+++++++++++++++++++ .../rb_equivalence_test.py | 49 +++++++++++++++++ examples/async_sac_state_sim/run_actor.sh | 2 +- examples/async_sac_state_sim/run_learner.sh | 4 +- .../data/fractal_symmetry_replay_buffer.py | 4 +- serl_launcher/serl_launcher/data/fsrb_test.py | 2 +- 6 files changed, 109 insertions(+), 6 deletions(-) create mode 100644 examples/async_sac_state_sim/fsrb_equivalence_test.py create mode 100644 examples/async_sac_state_sim/rb_equivalence_test.py diff --git a/examples/async_sac_state_sim/fsrb_equivalence_test.py b/examples/async_sac_state_sim/fsrb_equivalence_test.py new file mode 100644 index 00000000..a039228f --- /dev/null +++ b/examples/async_sac_state_sim/fsrb_equivalence_test.py @@ -0,0 +1,54 @@ +import gym.wrappers +import numpy as np +import gym +import jax +from serl_launcher.utils.launcher import ( + make_replay_buffer, + make_trainer_config +) +from agentlace.trainer import TrainerServer +from serl_launcher.data.replay_buffer import ReplayBuffer +from absl import app, flags +import franka_sim + +def main(_): + devices = jax.local_devices() + sharding = jax.sharding.PositionalSharding(devices) + env = gym.make("PandaReachCube-v0") + env = gym.wrappers.FlattenObservation(env) + + replay_buffer = make_replay_buffer( + env=env, + observation_space=env.observation_space, + action_space=env.action_space, + capacity=100000, + branch_method="test", + split_method="test", + workspace_width=0.5, + x_obs_idx=np.array([0, 7]), + y_obs_idx=np.array([1, 8]), + ) + + replay_iterator = replay_buffer.get_iterator( + sample_args={ + "batch_size": 256, + }, + device=sharding.replicate(), + ) + + def stats_callback(type: str, payload: dict) -> dict: + """Callback for when server receives stats request.""" + assert type == "send-stats", f"Invalid request type: {type}" + + return {} # not expecting a response + + server = TrainerServer(make_trainer_config(), request_callback=stats_callback) + server.register_data_store("actor_env", replay_buffer) + 
server.start(threaded=True) + + server.register_data_store("actor_env", replay_buffer) + + size = len(replay_buffer) + + batch = next(replay_iterator) + diff --git a/examples/async_sac_state_sim/rb_equivalence_test.py b/examples/async_sac_state_sim/rb_equivalence_test.py new file mode 100644 index 00000000..2dc425ff --- /dev/null +++ b/examples/async_sac_state_sim/rb_equivalence_test.py @@ -0,0 +1,49 @@ +import gym.wrappers +import numpy as np +import gym +import jax +from serl_launcher.utils.launcher import ( + make_replay_buffer, + make_trainer_config +) +from agentlace.trainer import TrainerServer +from serl_launcher.data.replay_buffer import ReplayBuffer +from absl import app, flags +import franka_sim + +def main(_): + devices = jax.local_devices() + sharding = jax.sharding.PositionalSharding(devices) + env = gym.make("PandaReachCube-v0") + env = gym.wrappers.FlattenObservation(env) + + replay_buffer = make_replay_buffer( + env=env, + observation_space=env.observation_space, + action_space=env.action_space, + capacity=100000, + ) + + replay_iterator = replay_buffer.get_iterator( + sample_args={ + "batch_size": 256, + }, + device=sharding.replicate(), + ) + + def stats_callback(type: str, payload: dict) -> dict: + """Callback for when server receives stats request.""" + assert type == "send-stats", f"Invalid request type: {type}" + + return {} # not expecting a response + + server = TrainerServer(make_trainer_config(), request_callback=stats_callback) + server.register_data_store("actor_env", replay_buffer) + server.start(threaded=True) + + server.register_data_store("actor_env", replay_buffer) + + size = len(replay_buffer) + + batch = next(replay_iterator) + diff --git a/examples/async_sac_state_sim/run_actor.sh b/examples/async_sac_state_sim/run_actor.sh index 48c64354..610374c0 100644 --- a/examples/async_sac_state_sim/run_actor.sh +++ b/examples/async_sac_state_sim/run_actor.sh @@ -17,7 +17,7 @@ python async_sac_state_sim.py "$@" \ --actor \ --render \ --env 
$ENV_NAME \ - --exp_name=serl-reach \ + --exp_name=serl-reach-testing \ --seed 0 \ --random_steps 1000 \ --max_steps 100000 \ diff --git a/examples/async_sac_state_sim/run_learner.sh b/examples/async_sac_state_sim/run_learner.sh index ed3295d3..fc9b48c8 100644 --- a/examples/async_sac_state_sim/run_learner.sh +++ b/examples/async_sac_state_sim/run_learner.sh @@ -17,7 +17,7 @@ fi python async_sac_state_sim.py "$@" \ --learner \ --env $ENV_NAME \ - --exp_name=serl-reach \ + --exp_name=serl-reach-testing \ --seed 0 \ --max_steps 100000 \ --training_starts 1000 \ @@ -27,4 +27,4 @@ python async_sac_state_sim.py "$@" \ --save_model True \ --checkpoint_period 10000 \ --checkpoint_path "$CHECKPOINT_DIR" \ - #--debug # wandb is disabled when debug \ No newline at end of file + #--debug # wandb is disabled when debug diff --git a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py index 860ab1b0..9732377d 100644 --- a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py +++ b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py @@ -83,7 +83,7 @@ def __init__( self.workspace_width = workspace_width self.current_depth = 1 self.counter = 1 - self.branch_index = None + self.branch_index = [workspace_width/2] # TODO # Add flags for variables passed to @@ -118,7 +118,7 @@ def test_branch(self): # while(not self.current_branch_count % 2): # self.current_branch_count = np.random.randint(2500) # return self.current_branch_count - return 3 + return 1 def fractal_branch(self): # return a new number of branches = dendrites ^ depth diff --git a/serl_launcher/serl_launcher/data/fsrb_test.py b/serl_launcher/serl_launcher/data/fsrb_test.py index 8ec2ad41..c6a00e68 100644 --- a/serl_launcher/serl_launcher/data/fsrb_test.py +++ b/serl_launcher/serl_launcher/data/fsrb_test.py @@ -25,7 +25,7 @@ def main(_): y_obs_idx = np.array([1, 8]) # Initialize replay buffer - env = 
gym.make("PandaPickCube-v0") + env = gym.make("PandaReachCube-v0") env = gym.wrappers.FlattenObservation(env) replay_buffer = make_replay_buffer( From dc31956ce997a2a93433b10c4e6b6737581a6991 Mon Sep 17 00:00:00 2001 From: ryanvanderstelt Date: Fri, 18 Jul 2025 17:21:59 -0500 Subject: [PATCH 036/141] update test file --- examples/async_sac_state_sim/fsrb_equivalence_test.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/async_sac_state_sim/fsrb_equivalence_test.py b/examples/async_sac_state_sim/fsrb_equivalence_test.py index a039228f..f752860a 100644 --- a/examples/async_sac_state_sim/fsrb_equivalence_test.py +++ b/examples/async_sac_state_sim/fsrb_equivalence_test.py @@ -52,3 +52,5 @@ def stats_callback(type: str, payload: dict) -> dict: batch = next(replay_iterator) +if __name__ == "__main__": + app.run(main) \ No newline at end of file From 002c0e0b2dc333079c5b8bf77afd6d7f5b5c2b0c Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Fri, 18 Jul 2025 22:07:54 -0500 Subject: [PATCH 037/141] minor comments. --- examples/async_sac_state_sim/async_sac_state_sim.py | 5 +++-- serl_launcher/serl_launcher/data/replay_buffer.py | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/async_sac_state_sim/async_sac_state_sim.py b/examples/async_sac_state_sim/async_sac_state_sim.py index ef20e8d0..17161e16 100644 --- a/examples/async_sac_state_sim/async_sac_state_sim.py +++ b/examples/async_sac_state_sim/async_sac_state_sim.py @@ -93,7 +93,7 @@ def update_params(params): eval_env = gym.make(FLAGS.env) #if FLAGS.env == "PandaPickCube-v0": - eval_env = gym.wrappers.FlattenObservation(eval_env) + eval_env = gym.wrappers.FlattenObservation(eval_env) ## Note!! eval_env = RecordEpisodeStatistics(eval_env) obs, _ = env.reset() @@ -125,7 +125,8 @@ def update_params(params): reward = np.asarray(reward, dtype=np.float32) running_return += reward - + + # Insert dict into actor's two deques... 
learner will read from these data_store.insert( dict( observations=obs, diff --git a/serl_launcher/serl_launcher/data/replay_buffer.py b/serl_launcher/serl_launcher/data/replay_buffer.py index f7d798a2..6dd29c73 100644 --- a/serl_launcher/serl_launcher/data/replay_buffer.py +++ b/serl_launcher/serl_launcher/data/replay_buffer.py @@ -76,7 +76,8 @@ def insert(self, data_dict: DatasetDict): def get_iterator(self, queue_size: int = 2, sample_args: dict = {}, device=None): # See https://flax.readthedocs.io/en/latest/_modules/flax/jax_utils.html#prefetch_to_device - # queue_size = 2 should be ok for one GPU. + + # queue_size = 2 should be ok for one GPU. See more at https://chatgpt.com/share/687af063-d6b0-8004-92b6-0e88b9c5f1e8 queue = collections.deque() def enqueue(n): From be72ea7c8fce93cda9c13d0a582edbe7d13d3fb4 Mon Sep 17 00:00:00 2001 From: ryanvanderstelt Date: Sat, 19 Jul 2025 13:37:33 -0500 Subject: [PATCH 038/141] Successful base functionality achieved for fsrb --- .../async_sac_state_sim.py | 4 +- .../fsrb_equivalence_test.py | 56 ------------------- .../rb_equivalence_test.py | 49 ---------------- examples/async_sac_state_sim/run_actor.sh | 2 +- .../data/fractal_symmetry_replay_buffer.py | 26 ++++++--- 5 files changed, 20 insertions(+), 117 deletions(-) delete mode 100644 examples/async_sac_state_sim/fsrb_equivalence_test.py delete mode 100644 examples/async_sac_state_sim/rb_equivalence_test.py diff --git a/examples/async_sac_state_sim/async_sac_state_sim.py b/examples/async_sac_state_sim/async_sac_state_sim.py index ef20e8d0..970af9a7 100644 --- a/examples/async_sac_state_sim/async_sac_state_sim.py +++ b/examples/async_sac_state_sim/async_sac_state_sim.py @@ -292,8 +292,8 @@ def main(_): branch_method="test", split_method="test", workspace_width=0.5, - x_obs_idx=np.array([0,7]), - y_obs_idx=np.array([1,8]), + x_obs_idx=np.array([0,4]), + y_obs_idx=np.array([1,5]), preload_rlds_path=FLAGS.preload_rlds_path, ) replay_iterator = replay_buffer.get_iterator( diff 
--git a/examples/async_sac_state_sim/fsrb_equivalence_test.py b/examples/async_sac_state_sim/fsrb_equivalence_test.py deleted file mode 100644 index f752860a..00000000 --- a/examples/async_sac_state_sim/fsrb_equivalence_test.py +++ /dev/null @@ -1,56 +0,0 @@ -import gym.wrappers -import numpy as np -import gym -import jax -from serl_launcher.utils.launcher import ( - make_replay_buffer, - make_trainer_config -) -from agentlace.trainer import TrainerServer -from serl_launcher.data.replay_buffer import ReplayBuffer -from absl import app, flags -import franka_sim - -def main(_): - devices = jax.local_devices() - sharding = jax.sharding.PositionalSharding(devices) - env = gym.make("PandaReachCube-v0") - env = gym.wrappers.FlattenObservation(env) - - replay_buffer = make_replay_buffer( - env=env, - observation_space=env.observation_space, - action_space=env.action_space, - capacity=100000, - branch_method="test", - split_method="test", - workspace_width=0.5, - x_obs_idx=np.array([0, 7]), - y_obs_idx=np.array([1, 8]), - ) - - replay_iterator = replay_buffer.get_iterator( - sample_args={ - "batch_size": 256, - }, - device=sharding.replicate(), - ) - - def stats_callback(type: str, payload: dict) -> dict: - """Callback for when server receives stats request.""" - assert type == "send-stats", f"Invalid request type: {type}" - - return {} # not expecting a response - - server = TrainerServer(make_trainer_config(), request_callback=stats_callback) - server.register_data_store("actor_env", replay_buffer) - server.start(threaded=True) - - server.register_data_store("actor_env", replay_buffer) - - size = len(replay_buffer) - - batch = next(replay_iterator) - -if __name__ == "__main__": - app.run(main) \ No newline at end of file diff --git a/examples/async_sac_state_sim/rb_equivalence_test.py b/examples/async_sac_state_sim/rb_equivalence_test.py deleted file mode 100644 index 2dc425ff..00000000 --- a/examples/async_sac_state_sim/rb_equivalence_test.py +++ /dev/null @@ -1,49 +0,0 
@@ -import gym.wrappers -import numpy as np -import gym -import jax -from serl_launcher.utils.launcher import ( - make_replay_buffer, - make_trainer_config -) -from agentlace.trainer import TrainerServer -from serl_launcher.data.replay_buffer import ReplayBuffer -from absl import app, flags -import franka_sim - -def main(_): - devices = jax.local_devices() - sharding = jax.sharding.PositionalSharding(devices) - env = gym.make("PandaReachCube-v0") - env = gym.wrappers.FlattenObservation(env) - - replay_buffer = make_replay_buffer( - env=env, - observation_space=env.observation_space, - action_space=env.action_space, - capacity=100000, - ) - - replay_iterator = replay_buffer.get_iterator( - sample_args={ - "batch_size": 256, - }, - device=sharding.replicate(), - ) - - def stats_callback(type: str, payload: dict) -> dict: - """Callback for when server receives stats request.""" - assert type == "send-stats", f"Invalid request type: {type}" - - return {} # not expecting a response - - server = TrainerServer(make_trainer_config(), request_callback=stats_callback) - server.register_data_store("actor_env", replay_buffer) - server.start(threaded=True) - - server.register_data_store("actor_env", replay_buffer) - - size = len(replay_buffer) - - batch = next(replay_iterator) - diff --git a/examples/async_sac_state_sim/run_actor.sh b/examples/async_sac_state_sim/run_actor.sh index 610374c0..3dab8d7d 100644 --- a/examples/async_sac_state_sim/run_actor.sh +++ b/examples/async_sac_state_sim/run_actor.sh @@ -15,7 +15,6 @@ if [ ! 
-d "$CHECKPOINT_DIR" ]; then fi python async_sac_state_sim.py "$@" \ --actor \ - --render \ --env $ENV_NAME \ --exp_name=serl-reach-testing \ --seed 0 \ @@ -28,4 +27,5 @@ python async_sac_state_sim.py "$@" \ --save_model True \ --checkpoint_period 10000 \ --checkpoint_path "$CHECKPOINT_DIR" \ + #--render \ #--debug # wandb is disabled when debug diff --git a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py index 9732377d..dd452c86 100644 --- a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py +++ b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py @@ -4,6 +4,8 @@ from serl_launcher.data.replay_buffer import ReplayBuffer import copy +import time + class FractalSymmetryReplayBuffer(ReplayBuffer): def __init__( self, @@ -118,7 +120,7 @@ def test_branch(self): # while(not self.current_branch_count % 2): # self.current_branch_count = np.random.randint(2500) # return self.current_branch_count - return 1 + return 9 def fractal_branch(self): # return a new number of branches = dendrites ^ depth @@ -170,8 +172,10 @@ def constant_split(self, data_dict: DatasetDict): # return True every transition return True - def insert(self, data_dict: DatasetDict): + def insert(self, data_dict_not: DatasetDict): + data_dict = copy.deepcopy(data_dict_not) + # start_time = time.time() # Update number of branches if needed if self.split(data_dict): temp = self.current_branch_count @@ -181,7 +185,6 @@ def insert(self, data_dict: DatasetDict): constant = self.workspace_width/(2 * self.current_branch_count) for i in range(0, self.current_branch_count): self.branch_index[i] = (2 * i + 1) * constant - # rb_origin = [data_dict["observations"][self.x_obs_idx[0]], data_dict["observations"][self.y_obs_idx[0]]] # block_origin = [data_dict["observations"][self.x_obs_idx[1]], data_dict["observations"][self.y_obs_idx[1]]] # rb_next_origin = 
[data_dict["next_observations"][self.x_obs_idx[0]], data_dict["next_observations"][self.y_obs_idx[0]]] @@ -224,13 +227,13 @@ def insert(self, data_dict: DatasetDict): # assert data_dict["dones"] == new_data_dict["dones"] # assert data_dict["masks"] == new_data_dict["masks"] # for i in range(0, data_dict["observations"].size): - # if i in self.x_obs_idx or i in self.y_obs_idx: - # continue - # assert data_dict["observations"][i] == new_data_dict["observations"][i] - # assert data_dict["next_observations"][i] == new_data_dict["next_observations"][i] + # if i in self.x_obs_idx or i in self.y_obs_idx: + # continue + # assert data_dict["observations"][i] == new_data_dict["observations"][i] + # assert data_dict["next_observations"][i] == new_data_dict["next_observations"][i] super().insert(new_data_dict) - + self.counter += 1 # print(f"original x,y of hand before transition: {rb_origin}") @@ -244,4 +247,9 @@ def insert(self, data_dict: DatasetDict): if data_dict["dones"]: self.counter = 1 - self.current_depth = 1 \ No newline at end of file + self.current_depth = 1 + # end_time = time.time() + # duration = end_time - start_time + # print(f"#{self._insert_index - 1} Duration: {duration:.4f}") + + \ No newline at end of file From 0ddcbd39fb84dcf1e9078083d048de15e87639aa Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Mon, 21 Jul 2025 17:55:40 -0500 Subject: [PATCH 039/141] Reach demo motion stable. - Modularized methods - Introduced PD-controller vs just P-controller. - Kp, Kv gains set at 16,20. 
--- demos/franka_reach_demo_script.py | 612 +++++++++++++++++------------- 1 file changed, 349 insertions(+), 263 deletions(-) diff --git a/demos/franka_reach_demo_script.py b/demos/franka_reach_demo_script.py index e5abbc11..ee4ef401 100755 --- a/demos/franka_reach_demo_script.py +++ b/demos/franka_reach_demo_script.py @@ -16,10 +16,11 @@ """ import os import numpy as np -import gymnasium as gym -from gymnasium.wrappers import TimeLimit -from time import sleep -from panda_mujoco_gym.envs import FrankaPickAndPlaceEnv +import gym +from time import sleep, perf_counter + +import franka_sim +import franka_sim.envs.panda_reach_gym_env as panda_reach_env # Global variables to store episode data across all iterations observations = [] # List storing observation sequences for each episode @@ -30,8 +31,13 @@ truncateds = [] # List storing truncated flags for each episode dones = [] # List storing done flags (terminated or truncated) for each episode +# Proportional and derivative control gain for action scaling -- empirically tuned +Kp = 16.0 +Kv = 20.0 + # Robot configuration robot = 'franka' # Robot type used in the environment, can be 'franka' or 'fetch' +task = 'reach' # Task type used in the environment, can be 'reach' or 'pick-and-place' # Weld constraint flag weld_flag = True # Flag to activate weld constraint during pick-and-place @@ -74,113 +80,118 @@ def deactivate_weld(env, constraint_name="grasp_weld"): print(f"Warning: Constraint '{constraint_name}' not found") return False -def main(): +def store_transition_data(episode_dict, new_obs, rewards, action, info, terminated, truncated, done): """ - Orchestrates the data generation process by running multiple episodes - of the pick-and-place task. - - Creates environment, runs scripted episodes, and saves demonstration data - to compressed NPZ file for use with stable-baselines3. + Store transition data in the episode dictionary. 
""" - # Initialize Fetch pick-and-place environment - env = FrankaPickAndPlaceEnv(reward_type="sparse", render_mode="rgb_array") - env = TimeLimit(env, max_episode_steps=50) - - # Adjust physical settings - # env.model.opt.timestep = 0.001 # Smaller timestep for more accurate physics. Default is 0.002. - # env.model.opt.iterations = 100 # More solver iterations for better contact resolution. Default is 50. - - # Configuration parameters - initStateSpace = "random" # Initial state space configuration + episode_dict["observations"].append(new_obs) + episode_dict["rewards"].append(rewards) + episode_dict["actions"].append(action) + episode_dict["infos"].append(info) + episode_dict["terminateds"].append(terminated) + episode_dict["truncateds"].append(truncated) + episode_dict["dones"].append(done) + +def store_episode_data(episode_data): + """ + Store complete episode data in global lists only if we succeeded (avoid bad demos). + """ + actions.append(episode_data["actions"]) + observations.append(episode_data["observations"]) + infos.append(episode_data["infos"]) + rewards.append(episode_data["rewards"]) - # Demos configs - attempted_demos = 1 # Number of demonstration episodes to generate--ADJUST THIS VALUE FOR MORE OR LESS DEMOS** + # Optionally, also store the done/terminated/truncated flags globally if needed: + terminateds.append(episode_data["terminateds"]) + truncateds.append(episode_data["truncateds"]) + dones.append(episode_data["dones"]) - num_demos = 0 # Counter for successful demonstration episodes - - # Reset environment to initial state - render for the first time. - obs, _ = env.reset() - print("Reset!") +def update_state_info(episode_data, time_step, dt, error): + """ + Update and return the current state information. 
Always get the latest entry with [-1] - # Generate demonstration episodes - while len(actions) < attempted_demos: - obs,_ = env.reset() # Reset environment for new episode - print(f"We will run a total of: {attempted_demos} demos!!") - print("Demo: #", len(actions)+1) - - # Execute pick-and-place task - res = pick_and_place_demo(env, obs) - - # Print success message - if res: - num_demos += 1 - print("Episode completed successfully!") - print(f"Total successful demos: {num_demos}/{attempted_demos}") + Args: + new_obs (dict): New observation dictionary containing the current state. + time_step (int): Current time step in the episode. + dt (float): Current time step in the episode. + error (np.ndarray): Current error vector between object and end-effector positions. + + Returns: + object_pos (np.ndarray): Current position of the object in the environment. + gripper_pos (float): Current position of the gripper. + current_pos (np.ndarray): Current position of the end-effector. + current_vel (np.ndarray): Current velocity of the end-effector. + """ + object_pos = episode_data["observations"][-1][0:3] # Block position + gripper_pos = episode_data["observations"][-1][3] # Gripper position + current_pos = episode_data["observations"][-1][4:7] # Panda/tcp position + current_vel = episode_data["observations"][-1][7:10] # Panda/t + + + # Print debug information + print( + f"Time Step: {time_step}, Error: {np.linalg.norm(error):.4f}, " + f"bot_pos: {np.array2string(current_pos, precision=3)}, " + f"obj_pos: {np.array2string(object_pos, precision=3)}, " + f"fgr_pos: {np.array2string(gripper_pos, precision=2)}, " + f"err: {np.array2string(error, precision=3)}, " + f"Action: {np.array2string(episode_data['actions'][-1], precision=3)}, " + f"dt: {dt: .4f}" + ) - ## Write data to demos folder - # 1. Get the absolute path of this script - #script_path = os.path.abspath(__file__) - - # 2. 
Extract its directory - #script_dir = os.path.dirname(script_path) - script_dir = '/home/student/data/franka_baselines/demos/pick_n_place' # Assumes data folder in user directory. - - # 3. Create output filename with configuration details - fileName = "data_" + robot - fileName += "_" + initStateSpace - fileName += "_" + str(attempted_demos) - fileName += ".npz" - - # 3. Build a filename in that same directory - out_path = os.path.join(script_dir, fileName) - # Save collected data to compressed numpy NPZ file - # Set acs,obs,info as keys in dict - np.savez_compressed(out_path, - acs = actions, - obs = observations, - rewards = rewards, - info = infos, - terminateds = terminateds, - truncateds = truncateds, - dones = dones) + return object_pos, gripper_pos, current_pos, current_vel - print(f"Data saved to {fileName}.") +def compute_error(object_pos, current_pos, prev_error, dt): + """Compute the error and its derivative between the object position and the current end-effector position. + Args: + object_pos (np.ndarray): The position of the object in the environment. + current_pos (np.ndarray): The current position of the end-effector. + prev_error (np.ndarray): The previous error value for derivative calculation. + dt (float): Time step for derivative calculation. + + Returns: + error (np.ndarray): The current error vector between the object and end-effector positions. + derror (np.ndarray): The derivative of the error vector. + """ + error = object_pos - current_pos # Calculate the error vector + derror = (error - prev_error) / dt # Calculate the derivative of the error vector -def pick_and_place_demo(env, lastObs): + prev_error = error.copy() # Update previous error for next iteration + return error, derror + +def demo(env, lastObs): """ - Executes a scripted pick-and-place sequence using a hierarchical approach. + Executes a scripted reach sequence using a hierarchical approach. - Implements 4-phase control strategy: + Implements 1-phase control strategy: 1. 
Approach: Move gripper above object (3cm offset) - 2. Grasp: Move to object and close gripper - 3. Transport: Move grasped object to goal position - 4. Maintain: Hold position until episode ends Store observations, actions, and info in global lists for later replay buffer inclusion. + + Gripper: + - The gripper in Mujoco ranges from a value of 0 to 0.4, where 0 is fully open and 0.4 is fully closed. Args: env: Gymnasium environment instance - lastObs: Last observation containing goal and object state information - - desired_goal: Target position for object placement + lastObs: Flattened observations set as object_pos, gripper_pos, panda/tcp_pos, panda/tcp_vel - observations: - ee_position[0:3], - ee_velocity[3:6], - fingers_width[6], - object_position[7:10], - object_rotation[10:13], - object_velp[13:16], - object_velr[16:19], + object_pos[0:3], + gripper_pos[3] + panda/tcp_pos[4:7], + panda/tcp_vel[7:10] + + Returns: """ ## Init goal, current_pos, and object position from last observation - goal = np.zeros(3, dtype=np.float32) - current_pos = np.zeros(3, dtype=np.float32) object_pos = np.zeros(3, dtype=np.float32) + current_pos = np.zeros(3, dtype=np.float32) + gripper_pos = np.zeros(1, dtype=np.float32) object_rel_pos = np.zeros(3, dtype=np.float32) - fgr_pos = np.zeros(1, dtype=np.float32) - # Initialize episode data collection + + # Initialize (single) episode data collection episodeObs = [] # Observations for this episode episodeAcs = [] # Actions for this episode episodeRews = [] # Rewards for this episode @@ -188,220 +199,295 @@ def pick_and_place_demo(env, lastObs): episodeTerminated = [] # Terminated flags for this episode episodeTruncated = [] # Truncated flags for this episode episodeDones = [] # Done flags (terminated or truncated) for this episode - - # Proportional control gain for action scaling -- empirically tuned - Kp = 8.0 - - # pre_pick_offset - pre_pick_offset = np.array([0,0,0.03], dtype=float) # Offset to approach object safely (3cm) + # 
Dictionary to store episode data + episode_data = { + "observations": episodeObs, + "actions": episodeAcs, + "rewards": episodeRews, + "infos": episodeInfo, + "terminateds": episodeTerminated, + "truncateds": episodeTruncated, + "dones": episodeDones + } + + # close gripper + fgr_pos = 0 + # Error thresholds - error_threshold = 0.011 # Threshold for stopping condition (Xmm) + error_threshold = 0.04 # Threshold for stopping condition (Xmm) finger_delta_fast = 0.05 # Action delta for fingers 5cm per step (will get clipped by controller)... more of a scalar. finger_delta_slow = 0.005 # Franka has a range from 0 to 4cm per finger - ## Extract data - # Extract desired position from desired_goal dict - goal = lastObs["desired_goal"][0:3] - - # Current robot end-effector position from observation dict - current_pos = lastObs["observation"][0:3] - - # Current object position from observation dict: - object_pos = lastObs["observation"][7:10] - + ## Extract data + object_pos = lastObs[0:3] # block pos + current_pos = lastObs[4:7] # panda/tcp_pos + # Relative position between end-effector and object - object_rel_pos = object_pos - current_pos - - ## Phase 1: Approach Object (Above) - # Create target position 3cm above the object. Use copy() method. - error = object_rel_pos.copy() - error+=pre_pick_offset # Move 3cm above object for safe approach. Fingers should still end up surrounding object. 
- - timeStep = 0 # Track total timesteps in episode - episodeObs.append(lastObs) + dt = env.unwrapped.model.opt.timestep # Mujoco time step + prev_error = np.zeros_like(object_pos) + error, derror = compute_error(object_pos, current_pos, prev_error, dt) + + time_step = 0 # Track total time_steps in episode + episodeObs.append(lastObs) # Store initial observation + + # Initialize previous time for dt calculation + prev_time = perf_counter() # Start time for dt calculation - # Phase 1: Move gripper to position above object - # Terminate when distance to above-object position < 5mm - print(f"----------------------------------------------- Phase 1: Approach Object -----------------------------------------------") - while np.linalg.norm(error) >= error_threshold and timeStep <= env._max_episode_steps: + # Phase 1: Reach + # Terminate when distance to above-object position < error_threshold + print(f"----------------------------------------------- Phase 1: Reach -----------------------------------------------") + while np.linalg.norm(error) >= error_threshold and time_step <= env.spec.max_episode_steps: env.render() # Visual feedback - # Initialize action vector [x, y, z, gripper] - action = np.array([0., 0., 0., 0.]) + # Record current time and compute dt + curr_time = perf_counter() + dt = curr_time - prev_time + prev_time = curr_time + + # Initialize action vector [x, y, z] + action = np.array([0., 0., 0.]) # Proportional control with gain of 6 - # action = Kp * error - action[:3] = error * Kp + action[:3] = error * Kp + derror * Kv + prev_error = error.copy() # Update previous error for next iteration - # Open gripper for approach - action[ len(action)-1 ] = 0.05 + # Keep gripper closed -- no need. only 3 dimensions of control + #action[ len(action)-1 ] = -finger_delta_fast # Maintain gripper closed. 
# Unpack new Gymnasium step API new_obs, reward, terminated, truncated, info = env.step(action) done = terminated or truncated - timeStep += 1 # Store episode data - episodeAcs.append(action) - episodeInfo.append(info) - episodeRews.append(reward) - episodeObs.append(new_obs) - episodeTerminated.append(terminated) - episodeTruncated.append(truncated) - episodeDones.append(done) - - # Update state information - fgr_pos = new_obs["observation"][6] - current_pos = new_obs["observation"][0:3] - object_pos = new_obs['observation'][7:10] - error = (object_pos+pre_pick_offset) - current_pos # Error with regard to offset position - - # Print debug information - print( - f"Time Step: {timeStep}, Error: {np.linalg.norm(error):.4f}, " - f"Eff_pos: {np.array2string(current_pos, precision=3)}, " - f"obj_pos: {np.array2string(object_pos, precision=3)}, " - f"fgr_pos: {np.array2string(fgr_pos, precision=2)}, " - f"Error: {np.array2string(error, precision=3)}, " - f"Action: {np.array2string(action, precision=3)}" - ) - - # Phase 2: Descend Grasp Object - # Move gripper directly to object and close gripper - # Terminate when relative distance to object < 5mm - print(f"----------------------------------------------- Phase 2: Grip -----------------------------------------------") - error = object_pos - current_pos # remove offset - while (np.linalg.norm(error) >= error_threshold or fgr_pos>=0.39) and timeStep <= env._max_episode_steps: # Cube of width 4cm, each finger open to 2cm - env.render() + store_transition_data(episode_data, new_obs, reward, action, info, terminated, truncated, done) + + # Update and print state information + object_pos,gripper_pos,cur_pos,cur_vel = update_state_info(episode_data, time_step, dt, error) + + # Update error for next iteration + error, derror = compute_error(object_pos, cur_pos, prev_error, dt) + + # Update time step + time_step += 1 + + # Sleep + sleep(0.5) # Optional: Slow down for better visualization + + + # # Phase 2: Descend Grasp Object + # 
# Move gripper directly to object and close gripper + # # Terminate when relative distance to object < 5mm + # print(f"----------------------------------------------- Phase 2: Grip -----------------------------------------------") + # error = object_pos - current_pos # remove offset + # while (np.linalg.norm(error) >= error_threshold or fgr_pos>=0.39) and time_step <= env._max_episode_steps: # Cube of width 4cm, each finger open to 2cm + # env.render() - # Initialize action vector [x, y, z, gripper] - action = np.array([0., 0., 0., 0.]) + # # Initialize action vector [x, y, z, gripper] + # action = np.array([0., 0., 0., 0.]) - # Direct proportional control to object position - action[:3] = error * Kp + # # Direct proportional control to object position + # action[:3] = error * Kp - # Close gripper to grasp object - action[len(action)-1] = -finger_delta_fast * 2 + # # Close gripper to grasp object + # action[len(action)-1] = -finger_delta_fast * 2 - # Execute action and collect data - new_obs, reward, terminated, truncated, info = env.step(action) - done = terminated or truncated - timeStep += 1 + # # Execute action and collect data + # new_obs, reward, terminated, truncated, info = env.step(action) + # done = terminated or truncated + # time_step += 1 - # Store episode data - episodeObs.append(new_obs) - episodeRews.append(reward) - episodeAcs.append(action) - episodeInfo.append(info) - episodeTerminated.append(terminated) - episodeTruncated.append(truncated) - episodeDones.append(done) - - # Update state information - fgr_pos = new_obs["observation"][6] - current_pos = new_obs["observation"][0:3] - object_pos = new_obs['observation'][7:10] - error = object_pos - current_pos #- np.array([0.,0.,0.01]) # Grab lower - - # Print debug information - print( - f"Time Step: {timeStep}, Error: {np.linalg.norm(error):.4f}, " - f"Eff_pos: {np.array2string(current_pos, precision=3)}, " - f"obj_pos: {np.array2string(object_pos, precision=3)}, " - f"fgr_pos: 
{np.array2string(fgr_pos, precision=3)}, " - f"Error: {np.array2string(error, precision=3)}, " - f"Action: {np.array2string(action, precision=3)}" - ) - #sleep(0.5) # Optional: Slow down for better visualization + # # Store episode data + # episodeObs.append(new_obs) + # episodeRews.append(reward) + # episodeAcs.append(action) + # episodeInfo.append(info) + # episodeTerminated.append(terminated) + # episodeTruncated.append(truncated) + # episodeDones.append(done) + + # # Update state information + # object_pos = new_obs['observation'][0:3] + # fgr_pos = new_obs["observation"][3] + # current_pos = new_obs["observation"][4:7] + # current_vel = new_obs["observation"][7:10] + + # error = object_pos - current_pos #- np.array([0.,0.,0.01]) # Grab lower + + # # Print debug information + # print( + # f"Time Step: {time_step}, Error: {np.linalg.norm(error):.4f}, " + # f"Eff_pos: {np.array2string(current_pos, precision=3)}, " + # f"obj_pos: {np.array2string(object_pos, precision=3)}, " + # f"fgr_pos: {np.array2string(fgr_pos, precision=3)}, " + # f"Error: {np.array2string(error, precision=3)}, " + # f"Action: {np.array2string(action, precision=3)}" + # ) + # #sleep(0.5) # Optional: Slow down for better visualization - # Phase 3: Transport to Goal - # Move grasped object to desired goal position - # Terminate when distance between object and goal < 1cm - print(f"----------------------------------------------- Phase 3: Transport to Goal -----------------------------------------------") + # # Phase 3: Transport to Goal + # # Move grasped object to desired goal position + # # Terminate when distance between object and goal < 1cm + # print(f"----------------------------------------------- Phase 3: Transport to Goal -----------------------------------------------") - # Weld activation - if weld_flag: - activate_weld(env, constraint_name="grasp_weld") + # # Weld activation + # if weld_flag: + # activate_weld(env, constraint_name="grasp_weld") + + # # Set error between goal and hand 
assuming the object is grasped - # Set error between goal and hand assuming the object is grasped - gh_error = goal - current_pos # Error between goal and hand position - ho_error = object_pos - current_pos # Error between object and hand position - while np.linalg.norm(gh_error) >= 0.01 and timeStep <= env._max_episode_steps: - env.render() + # ho_error = object_pos - current_pos # Error between object and hand position + # while np.linalg.norm(gh_error) >= 0.01 and time_step <= env._max_episode_steps: + # env.render() - action = np.array([0., 0., 0., 0.]) + # action = np.array([0., 0., 0., 0.]) - # Proportional control toward goal position - action[:3] = gh_error[:3] * Kp + # # Proportional control toward goal position + # action[:3] = gh_error[:3] * Kp - # Maintain grip on object - #action[len(action)-1] = 0 + # # Maintain grip on object + # #action[len(action)-1] = 0 - # Execute action and collect data - new_obs, reward, terminated, truncated, info = env.step(action) - done = terminated or truncated - timeStep += 1 + # # Execute action and collect data + # new_obs, reward, terminated, truncated, info = env.step(action) + # done = terminated or truncated + # time_step += 1 - # Store episode data - episodeObs.append(new_obs) - episodeRews.append(reward) - episodeAcs.append(action) - episodeInfo.append(info) - episodeTerminated.append(terminated) - episodeTruncated.append(truncated) - episodeDones.append(done) + # # Store episode data + # store_transition_data(new_obs, rewards, action, info, terminated, truncated, done) + - # Update state information - fgr_pos = new_obs["observation"][6] - current_pos = new_obs["observation"][0:3] - object_pos = new_obs['observation'][7:10] - gh_error = goal - current_pos # Error between goal and hand position - ho_error = object_pos - current_pos # Error between object and hand position - - # Print debug information - print( - f"Time Step: {timeStep}, Error Norm: {np.linalg.norm(gh_error):.4f}, " - f"Eff_pos: 
{np.array2string(current_pos, precision=3)}, " - f"goal_pos: {np.array2string(goal, precision=3)}, " - f"fgr_pos: {np.array2string(fgr_pos, precision=2)}, " - f"Error: {np.array2string(gh_error, precision=3)}, " - f"Action: {np.array2string(action, precision=3)}" - ) + # # Update state information + # fgr_pos = new_obs["observation"][6] + # current_pos = new_obs["observation"][0:3] + # object_pos = new_obs['observation'][7:10] + # gh_error = goal - current_pos # Error between goal and hand position + # ho_error = object_pos - current_pos # Error between object and hand position + + # # Print debug information + # print( + # f"Time Step: {time_step}, Error Norm: {np.linalg.norm(gh_error):.4f}, " + # f"Eff_pos: {np.array2string(current_pos, precision=3)}, " + # f"goal_pos: {np.array2string(goal, precision=3)}, " + # f"fgr_pos: {np.array2string(fgr_pos, precision=2)}, " + # f"Error: {np.array2string(gh_error, precision=3)}, " + # f"Action: {np.array2string(action, precision=3)}" + # ) - sleep(0.5) # Optional: Slow down for better visualization + # sleep(0.5) # Optional: Slow down for better visualization - ## Check for success and store episode data - gh_norm = np.linalg.norm(gh_error) - ho_nomr = np.linalg.norm(ho_error) - if gh_norm < error_threshold and ho_nomr < error_threshold: - - # Store complete episode data in global lists only if we succeeded (avoid bad demos) - actions.append(episodeAcs) - observations.append(episodeObs) - infos.append(episodeInfo) - rewards.append(episodeRews) - - # Optionally, also store the done/terminated/truncated flags globally if needed: - terminateds.append(episodeTerminated) - truncateds.append(episodeTruncated) - dones.append(episodeDones) - - # Deactivate weld constraint after successful pick - if weld_flag: - deactivate_weld(env, constraint_name="grasp_weld") - - # Close mujoco viewer - env.close() - - # Break out of the loop to start a new episode - return True + # ## Check for success and store episode data + # gh_norm = 
np.linalg.norm(gh_error) + # ho_nomr = np.linalg.norm(ho_error) + # if gh_norm < error_threshold and ho_nomr < error_threshold: - # If we reach here, the episode was not successful + + # Store complete episode data in global lists only if we succeeded (avoid bad demos) + store_episode_data(episode_data) + + # Deactivate weld constraint after successful pick if weld_flag: - print("Failed to transport object to goal position. Deactivating weld.") deactivate_weld(env, constraint_name="grasp_weld") + + # Close mujoco viewer + env.close() + + # Break out of the loop to start a new episode + return True + + # # If we reach here, the episode was not successful + # if weld_flag: + # print("Failed to transport object to goal position. Deactivating weld.") + # deactivate_weld(env, constraint_name="grasp_weld") + +def main(): + """ + Orchestrates the data generation process by running multiple episodes + of the task. + + Creates environment, runs scripted episodes, and saves demonstration data + to compressed NPZ. + + Arguments that can be configured with flags: + - env + - render + - num_demos + + """ + # Initialize the Panda environment. + env = gym.make("PandaReachCube-v0", render_mode="human") + env = gym.wrappers.FlattenObservation(env) + + # Adjust physical settings + # env.model.opt.time_step = 0.001 # Smaller time_step for more accurate physics. Default is 0.002. + # env.model.opt.iterations = 100 # More solver iterations for better contact resolution. Default is 50. + + # Configuration parameters + initStateSpace = "random" # Initial state space configuration + + # Demos configs + attempted_demos = 1 # Number of demonstration episodes to generate--ADJUST THIS VALUE FOR MORE OR LESS DEMOS** + + num_demos = 0 # Counter for successful demonstration episodes + # Reset environment to initial state - render for the first time. + obs, _ = env.reset() # For reach environment expect 10 observations: r_pos, r_vel, finger, object_pos. 
+ + # Adjust camera view for better visualization + viewer = env.unwrapped._viewer.viewer + if hasattr(viewer, 'cam'): + viewer.cam.lookat[:] = [0, 0, 0.1] # Center of robot (adjust as needed) + viewer.cam.distance = 3.0 # Camera distance + viewer.cam.azimuth = 180 # 0 = right, 90 = front, 180 = left + viewer.cam.elevation = -30 # Negative = above, positive = below + + print("Reset!") + + # Generate demonstration episodes + while len(actions) < attempted_demos: + obs,_ = env.reset() # Reset environment for new episode + + print(f"We will run a total of: {attempted_demos} demos!!") + print("Demo: #", len(actions)+1) + + # Execute pick-and-place task + res = demo(env, obs) + + # Print success message + if res: + num_demos += 1 + print("Episode completed successfully!") + print(f"Total successful demos: {num_demos}/{attempted_demos}") + + ## Write data to demos folder + script_dir = '/data/serl/demos' # Assumes mounted /data folder. + + # Create output filename with configuration details + fileName = "data_" + robot + "_" + task + fileName += "_" + initStateSpace + fileName += "_" + str(attempted_demos) + fileName += ".npz" + + # Build a filename in that same directory + out_path = os.path.join(script_dir, fileName) + + # Ensure the directory exists + os.makedirs(script_dir, exist_ok=True) + + # Save collected data to compressed numpy NPZ file + # Set acs,obs,info as keys in dict + np.savez_compressed(out_path, + acs = actions, + obs = observations, + rewards = rewards, + info = infos, + terminateds = terminateds, + truncateds = truncateds, + dones = dones) + + print(f"Data saved to {fileName}.") + print(f"Total successful demos: {num_demos}/{attempted_demos}") + if __name__ == "__main__": main() \ No newline at end of file From d39f9bf0de3d925e2830c5c7f998e07ec245f5eb Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Mon, 21 Jul 2025 18:10:17 -0500 Subject: [PATCH 040/141] Re-adjusted gains slightly. Removed weld operations they are not present currently in franka_sim. 
--- demos/franka_reach_demo_script.py | 191 +++++++----------------------- 1 file changed, 42 insertions(+), 149 deletions(-) diff --git a/demos/franka_reach_demo_script.py b/demos/franka_reach_demo_script.py index ee4ef401..14767eec 100755 --- a/demos/franka_reach_demo_script.py +++ b/demos/franka_reach_demo_script.py @@ -32,8 +32,8 @@ dones = [] # List storing done flags (terminated or truncated) for each episode # Proportional and derivative control gain for action scaling -- empirically tuned -Kp = 16.0 -Kv = 20.0 +Kp = 20.0 # Values between 20 and 24 seem to be somewhat stable for Kv = 24 +Kv = 24.0 # Robot configuration robot = 'franka' # Robot type used in the environment, can be 'franka' or 'fetch' @@ -43,42 +43,43 @@ weld_flag = True # Flag to activate weld constraint during pick-and-place -def activate_weld(env, constraint_name="grasp_weld"): - """ - Activate a weld constraint during pick portion of a demo - :param env: The environment containing the model - :param constraint_name: The name of the weld constraint to activate - :return: True if the weld was successfully activated, False if the constraint was not found - """ - - try: - # Activate the weld constraint - env.unwrapped.model.eq(constraint_name).active = 1 - print("Activated weld") - return True +# Franka sim environments do not have weld constraints like the franka_mujoco environments. 
+# def activate_weld(env, constraint_name="grasp_weld"): +# """ +# Activate a weld constraint during pick portion of a demo +# :param env: The environment containing the model +# :param constraint_name: The name of the weld constraint to activate +# :return: True if the weld was successfully activated, False if the constraint was not found +# """ + +# try: +# # Activate the weld constraint +# env.unwrapped.model.eq(constraint_name).active = 1 +# print("Activated weld") +# return True - except KeyError: - print(f"Warning: Constraint '{constraint_name}' not found") - return False - -def deactivate_weld(env, constraint_name="grasp_weld"): - """ - Deactivate a weld constraint during place portion of a demo - :param env: The environment containing the model - :param constraint_name: The name of the weld constraint to deactivate - :return: True if the weld was successfully deactivated, False if the constraint was not - found - """ +# except KeyError: +# print(f"Warning: Constraint '{constraint_name}' not found") +# return False + +# def deactivate_weld(env, constraint_name="grasp_weld"): +# """ +# Deactivate a weld constraint during place portion of a demo +# :param env: The environment containing the model +# :param constraint_name: The name of the weld constraint to deactivate +# :return: True if the weld was successfully deactivated, False if the constraint was not +# found +# """ - try: - # Deactivate the weld constraint - env.unwrapped.model.eq(constraint_name).active = 0 - print("Deactivated weld") - return True +# try: +# # Deactivate the weld constraint +# env.unwrapped.model.eq(constraint_name).active = 0 +# print("Deactivated weld") +# return True - except KeyError: - print(f"Warning: Constraint '{constraint_name}' not found") - return False +# except KeyError: +# print(f"Warning: Constraint '{constraint_name}' not found") +# return False def store_transition_data(episode_dict, new_obs, rewards, action, info, terminated, truncated, done): """ @@ -273,122 
+274,14 @@ def demo(env, lastObs): time_step += 1 # Sleep - sleep(0.5) # Optional: Slow down for better visualization - - - # # Phase 2: Descend Grasp Object - # # Move gripper directly to object and close gripper - # # Terminate when relative distance to object < 5mm - # print(f"----------------------------------------------- Phase 2: Grip -----------------------------------------------") - # error = object_pos - current_pos # remove offset - # while (np.linalg.norm(error) >= error_threshold or fgr_pos>=0.39) and time_step <= env._max_episode_steps: # Cube of width 4cm, each finger open to 2cm - # env.render() - - # # Initialize action vector [x, y, z, gripper] - # action = np.array([0., 0., 0., 0.]) - - # # Direct proportional control to object position - # action[:3] = error * Kp - - # # Close gripper to grasp object - # action[len(action)-1] = -finger_delta_fast * 2 - - # # Execute action and collect data - # new_obs, reward, terminated, truncated, info = env.step(action) - # done = terminated or truncated - # time_step += 1 - - # # Store episode data - # episodeObs.append(new_obs) - # episodeRews.append(reward) - # episodeAcs.append(action) - # episodeInfo.append(info) - # episodeTerminated.append(terminated) - # episodeTruncated.append(truncated) - # episodeDones.append(done) - - # # Update state information - # object_pos = new_obs['observation'][0:3] - # fgr_pos = new_obs["observation"][3] - # current_pos = new_obs["observation"][4:7] - # current_vel = new_obs["observation"][7:10] - - # error = object_pos - current_pos #- np.array([0.,0.,0.01]) # Grab lower - - # # Print debug information - # print( - # f"Time Step: {time_step}, Error: {np.linalg.norm(error):.4f}, " - # f"Eff_pos: {np.array2string(current_pos, precision=3)}, " - # f"obj_pos: {np.array2string(object_pos, precision=3)}, " - # f"fgr_pos: {np.array2string(fgr_pos, precision=3)}, " - # f"Error: {np.array2string(error, precision=3)}, " - # f"Action: {np.array2string(action, precision=3)}" - # ) - 
# #sleep(0.5) # Optional: Slow down for better visualization - - # # Phase 3: Transport to Goal - # # Move grasped object to desired goal position - # # Terminate when distance between object and goal < 1cm - # print(f"----------------------------------------------- Phase 3: Transport to Goal -----------------------------------------------") - - # # Weld activation - # if weld_flag: - # activate_weld(env, constraint_name="grasp_weld") - - # # Set error between goal and hand assuming the object is grasped - - # ho_error = object_pos - current_pos # Error between object and hand position - # while np.linalg.norm(gh_error) >= 0.01 and time_step <= env._max_episode_steps: - # env.render() - - # action = np.array([0., 0., 0., 0.]) - - # # Proportional control toward goal position - # action[:3] = gh_error[:3] * Kp - - # # Maintain grip on object - # #action[len(action)-1] = 0 - - # # Execute action and collect data - # new_obs, reward, terminated, truncated, info = env.step(action) - # done = terminated or truncated - # time_step += 1 - - # # Store episode data - # store_transition_data(new_obs, rewards, action, info, terminated, truncated, done) - - - # # Update state information - # fgr_pos = new_obs["observation"][6] - # current_pos = new_obs["observation"][0:3] - # object_pos = new_obs['observation'][7:10] - # gh_error = goal - current_pos # Error between goal and hand position - # ho_error = object_pos - current_pos # Error between object and hand position - - # # Print debug information - # print( - # f"Time Step: {time_step}, Error Norm: {np.linalg.norm(gh_error):.4f}, " - # f"Eff_pos: {np.array2string(current_pos, precision=3)}, " - # f"goal_pos: {np.array2string(goal, precision=3)}, " - # f"fgr_pos: {np.array2string(fgr_pos, precision=2)}, " - # f"Error: {np.array2string(gh_error, precision=3)}, " - # f"Action: {np.array2string(action, precision=3)}" - # ) - - # sleep(0.5) # Optional: Slow down for better visualization - - # ## Check for success and store 
episode data - # gh_norm = np.linalg.norm(gh_error) - # ho_nomr = np.linalg.norm(ho_error) - # if gh_norm < error_threshold and ho_nomr < error_threshold: - + sleep(0.25) # Optional: Slow down for better visualization # Store complete episode data in global lists only if we succeeded (avoid bad demos) store_episode_data(episode_data) - # Deactivate weld constraint after successful pick - if weld_flag: - deactivate_weld(env, constraint_name="grasp_weld") + # Deactivate weld constraint after successful pick -- franka_sim env does not have weld like franka_mujoco env. + # if weld_flag: + # deactivate_weld(env, constraint_name="grasp_weld") # Close mujoco viewer env.close() @@ -460,8 +353,8 @@ def main(): print("Episode completed successfully!") print(f"Total successful demos: {num_demos}/{attempted_demos}") - ## Write data to demos folder - script_dir = '/data/serl/demos' # Assumes mounted /data folder. + ## Write data to demos folder. Assumes mounted /data folder and internal data folder. + script_dir = '/data/data/serl/demos' # Create output filename with configuration details fileName = "data_" + robot + "_" + task From 0552bf449df40d516e672f642a1e0783506166a0 Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Tue, 22 Jul 2025 11:45:36 -0500 Subject: [PATCH 041/141] New testing scripts for fractal symmetry benchmarks. - testing.sh and tmux_script_launch.sh should be used together - should be made executable and called ./ - still incomplete and with a bug showing up in actor/learner py scripts. 
but close --- examples/async_sac_state_sim/testing.sh | 42 +++++++++++++++++++ .../async_sac_state_sim/tmux_script_launch.sh | 19 +++++++++ 2 files changed, 61 insertions(+) create mode 100755 examples/async_sac_state_sim/testing.sh create mode 100755 examples/async_sac_state_sim/tmux_script_launch.sh diff --git a/examples/async_sac_state_sim/testing.sh b/examples/async_sac_state_sim/testing.sh new file mode 100755 index 00000000..fb2072c9 --- /dev/null +++ b/examples/async_sac_state_sim/testing.sh @@ -0,0 +1,42 @@ +#!/bin/bash + +# Tmux options +tmux set-option -g history-limit 100000 + +# Exit on error: prevents silent failures and makes bugs easier to catch: +# -e: exit immediately on command failure +# -u: error on use of undefined variables +# -o pipefail: error if any command in a pipeline fails +set -euo pipefail + +# Safe IFS for word splitting. Protects against word-splitting bugs due to unexpected whitespace in filenames or variables. +IFS=$'\n\t' + +# === CONFIGURATION === +env="PandaReachCube-v0" +session="test" +window="0" + +# === CLEANUP ON EXIT === +# Kill only background jobs started by the script +trap 'jobs -p | xargs -r kill' EXIT INT TERM + +# === START TMUX SESSION === +if ! tmux has-session -t "$session" 2>/dev/null; then # suppress error output if session does not exist + tmux new-session -d -s "$session" -n main # start new detached tmux session named $session with init window named main +else + echo "Session $session already exists. Reusing..." +fi + +# === HELPER FUNCTION === +# Centralizes the logic for sending commands into tmux. Improves readability and makes it easy to add logging or future logic. +send_tmux_command() { + local cmd="$1" # Take 1st arg in func & store in var cmd. + tmux send-keys -t "$session:$window" "$cmd" C-m # send command to specific tmux window +} + +# === BASELINE RUNS === +# tmux_script_launch.sh must have ./ in front to tell tmux to execute the script in the current directory. 
+for seed in {1}; do # Loop + send_tmux_command "./tmux_script_launch.sh --env $env --name ${env}-baseline --type replay_buffer --seed $seed" +done diff --git a/examples/async_sac_state_sim/tmux_script_launch.sh b/examples/async_sac_state_sim/tmux_script_launch.sh new file mode 100755 index 00000000..eadc9921 --- /dev/null +++ b/examples/async_sac_state_sim/tmux_script_launch.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +EXAMPLE_DIR=${EXAMPLE_DIR:-"async_sac_state_sim"} +CONDA_ENV=${CONDA_ENV:-"serl"} + +echo "Running from $(pwd)" + +# This script assumes it is being run inside an existing tmux session from testing.sh +# Split the current window vertically +tmux split-window -v + +# Run actor in the top pane +tmux send-keys -t "$(tmux display-message -p '#S')":0.0 "conda activate $CONDA_ENV && bash run_actor.sh" C-m + +# Run learner in the bottom pane +tmux send-keys -t "$(tmux display-message -p '#S')":0.1 "conda activate $CONDA_ENV && bash run_learner.sh" C-m + +# Optionally, focus the upper pane again +tmux select-pane -t 0 From 0f146a6fe4d5ada94e62edcd403b92c7d60b144c Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Tue, 22 Jul 2025 11:48:24 -0500 Subject: [PATCH 042/141] simple comment --- franka_sim/franka_sim/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/franka_sim/franka_sim/__init__.py b/franka_sim/franka_sim/__init__.py index dab351ab..28b97c12 100644 --- a/franka_sim/franka_sim/__init__.py +++ b/franka_sim/franka_sim/__init__.py @@ -5,6 +5,7 @@ "GymRenderingSpec", ] +# Register environments from franka_sim.envs where specific classes are found in the PandaXXX.py scripts from gym.envs.registration import register register( @@ -20,7 +21,7 @@ ) register( id="PandaReachCube-v0", - entry_point="franka_sim.envs:PandaReachCubeGymEnv", + entry_point="franka_sim.envs:PandaReachCubeGymEnv", max_episode_steps=100, ) # register( From f0e40bab47cc5c06e9fe97792866b56f38e87a88 Mon Sep 17 00:00:00 2001 From: ryanvanderstelt Date: Tue, 22 Jul 2025 
16:30:49 -0500 Subject: [PATCH 043/141] Removed undesired features, fixed bugs, and created --- .../async_sac_state_sim.py | 21 +++- examples/async_sac_state_sim/run_actor.sh | 115 +++++++++++++++-- examples/async_sac_state_sim/run_learner.sh | 116 ++++++++++++++++-- examples/async_sac_state_sim/run_tests.sh | 30 +++++ examples/async_sac_state_sim/tmux_launch.sh | 100 ++++++++++++++- .../franka_sim/envs/panda_reach_gym_env.py | 5 +- .../data/fractal_symmetry_replay_buffer.py | 51 +++----- 7 files changed, 373 insertions(+), 65 deletions(-) create mode 100644 examples/async_sac_state_sim/run_tests.sh diff --git a/examples/async_sac_state_sim/async_sac_state_sim.py b/examples/async_sac_state_sim/async_sac_state_sim.py index e6bcb165..45311ab8 100644 --- a/examples/async_sac_state_sim/async_sac_state_sim.py +++ b/examples/async_sac_state_sim/async_sac_state_sim.py @@ -57,6 +57,18 @@ flags.DEFINE_integer("checkpoint_period", 0, "Period to save checkpoints.") flags.DEFINE_string("checkpoint_path", None, "Path to save checkpoints.") +# flags for replay buffer +flags.DEFINE_string("replay_buffer_type", "replay_buffer", "Which type of replay buffer to use") +flags.DEFINE_string("branch_method", None, "Method for how many branches to generate") +flags.DEFINE_string("split_method", "test", "Method for when to change number of branches generated") # Remember to default None +flags.DEFINE_float("workspace_width", 0.5, "Workspace width in centimeters") # Remember to default None +flags.DEFINE_integer("depth", None, "Total layers of depth") +flags.DEFINE_integer("dendrites", None, "Dendrites for fractal branching") +flags.DEFINE_integer("timesplit_freq", None, "Frequency of splits according to time") +flags.DEFINE_integer("branch_count_rate_of_change", None, "Rate of change for linear branching") +flags.DEFINE_integer("starting_branch_count", 1, "Initial number of branches") + + flags.DEFINE_boolean( "debug", False, "Debug mode." 
) # debug mode will disable wandb logging @@ -289,10 +301,11 @@ def main(_): env, capacity=FLAGS.replay_buffer_capacity, rlds_logger_path=FLAGS.log_rlds_path, - type="fractal_symmetry_replay_buffer", - branch_method="test", - split_method="test", - workspace_width=0.5, + type=FLAGS.replay_buffer_type, + branch_method=FLAGS.branch_method, + split_method=FLAGS.split_method, + starting_branch_count=FLAGS.starting_branch_count, + workspace_width=FLAGS.workspace_width, x_obs_idx=np.array([0,4]), y_obs_idx=np.array([1,5]), preload_rlds_path=FLAGS.preload_rlds_path, diff --git a/examples/async_sac_state_sim/run_actor.sh b/examples/async_sac_state_sim/run_actor.sh index 3dab8d7d..e97cca8c 100644 --- a/examples/async_sac_state_sim/run_actor.sh +++ b/examples/async_sac_state_sim/run_actor.sh @@ -1,9 +1,102 @@ +#!/bin/bash + +# Default variable values +seed=0 +env="PandaReachCube-v0" +branch="test" +type="fractal_symmetry_replay_buffer" +exp_name="untitled_experiment" +starting_branch_count=1 + +has_argument() { + [[ ("$1" == *=* && -n ${1#*=}) || ( ! -z "$2" && "$2" != -*) ]]; +} + +extract_argument() { + echo "${2:-${1#*=}}" +} + +# Function to handle options and arguments +handle_options() { + while [ $# -gt 0 ]; do + case $1 in + -e | --env*) + if ! has_argument $@; then + echo "Environment not specified" >&2 + exit 1 + fi + + env=$(extract_argument $@) + + shift + ;; + -s | --seed*) + if ! has_argument $@; then + echo "Seed needs a positive integer value" >&2 + exit 1 + fi + + seed=$(extract_argument $@) + + shift + ;; + -t | --type*) + if ! has_argument $@; then + echo "Type needs a string" >&2 + exit 1 + fi + + type=$(extract_argument $@) + + shift + ;; + -n | --name*) + if ! has_argument $@; then + echo "Name needs a string" >&2 + exit 1 + fi + + exp_name=$(extract_argument $@) + + shift + ;; + --starting_branch*) + if ! 
has_argument $@; then + echo "Starting_branch needs a positive odd integer" >&2 + exit 1 + fi + + starting_branch_count=$(extract_argument $@) + + shift + ;; + -b | --branch*) + if ! has_argument $@; then + echo "Branch needs a string" >&2 + exit 1 + fi + + branch=$(extract_argument $@) + + shift + ;; + *) + echo "Invalid option: $1" >&2 + exit 1 + ;; + esac + shift + done +} + +# Main script execution +handle_options "$@" + export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ export XLA_PYTHON_CLIENT_MEM_FRACTION=.5 && \ export SCRIPT_DIR=$(dirname "$(realpath "$0")") && \ -export ENV_NAME="PandaReachCube-v0" && \ export TIMESTAMP=$(date +"%m-%d-%Y-%H-%M-%S") && \ -export CHECKPOINT_DIR="$SCRIPT_DIR/$ENV_NAME_checkpoints/checkpoints-$TIMESTAMP" && \ +export CHECKPOINT_DIR="/data/fsrb_testing/checkpoints-$TIMESTAMP" && \ # Create checkpoint directory if it doesn't exist if [ ! -d "$CHECKPOINT_DIR" ]; then @@ -13,19 +106,23 @@ if [ ! -d "$CHECKPOINT_DIR" ]; then exit 1 } fi -python async_sac_state_sim.py "$@" \ + +python async_sac_state_sim.py \ --actor \ - --env $ENV_NAME \ - --exp_name=serl-reach-testing \ - --seed 0 \ + --env $env \ + --exp_name=$exp_name \ + --seed $seed \ + --replay_buffer_type $type \ + --branch_method $branch \ + --starting_branch_count $starting_branch_count \ --random_steps 1000 \ --max_steps 100000 \ --training_starts 1000 \ --critic_actor_ratio 8 \ --batch_size 256 \ - --replay_buffer_capacity 100000 \ + --replay_buffer_capacity 1000000 \ --save_model True \ - --checkpoint_period 10000 \ - --checkpoint_path "$CHECKPOINT_DIR" \ + # --checkpoint_period 10000 \ + # --checkpoint_path "$CHECKPOINT_DIR" \ #--render \ #--debug # wandb is disabled when debug diff --git a/examples/async_sac_state_sim/run_learner.sh b/examples/async_sac_state_sim/run_learner.sh index fc9b48c8..09809593 100644 --- a/examples/async_sac_state_sim/run_learner.sh +++ b/examples/async_sac_state_sim/run_learner.sh @@ -1,9 +1,102 @@ +#!/bin/bash + +# Default variable values 
+seed=0 +env="PandaReachCube-v0" +branch="test" +type="fractal_symmetry_replay_buffer" +exp_name="untitled_experiment" +starting_branch_count=1 + +has_argument() { + [[ ("$1" == *=* && -n ${1#*=}) || ( ! -z "$2" && "$2" != -*) ]]; +} + +extract_argument() { + echo "${2:-${1#*=}}" +} + +# Function to handle options and arguments +handle_options() { + while [ $# -gt 0 ]; do + case $1 in + -e | --env*) + if ! has_argument $@; then + echo "Environment not specified" >&2 + exit 1 + fi + + env=$(extract_argument $@) + + shift + ;; + -s | --seed*) + if ! has_argument $@; then + echo "Seed needs a positive integer value" >&2 + exit 1 + fi + + seed=$(extract_argument $@) + + shift + ;; + -t | --type*) + if ! has_argument $@; then + echo "Type needs a string" >&2 + exit 1 + fi + + type=$(extract_argument $@) + + shift + ;; + -n | --name*) + if ! has_argument $@; then + echo "Name needs a string" >&2 + exit 1 + fi + + exp_name=$(extract_argument $@) + + shift + ;; + --starting_branch*) + if ! has_argument $@; then + echo "Starting_branch needs a positive odd integer" >&2 + exit 1 + fi + + starting_branch_count=$(extract_argument $@) + + shift + ;; + -b | --branch*) + if ! has_argument $@; then + echo "Branch needs a string" >&2 + exit 1 + fi + + branch=$(extract_argument $@) + + shift + ;; + *) + echo "Invalid option: $1" >&2 + exit 1 + ;; + esac + shift + done +} + +# Main script execution +handle_options "$@" + export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ export XLA_PYTHON_CLIENT_MEM_FRACTION=.5 && \ export SCRIPT_DIR=$(dirname "$(realpath "$0")") && \ -export ENV_NAME="PandaReachCube-v0" && \ export TIMESTAMP=$(date +"%m-%d-%Y-%H-%M-%S") && \ -export CHECKPOINT_DIR="$SCRIPT_DIR/$ENV_NAME_checkpoints/checkpoints-$TIMESTAMP" && \ +export CHECKPOINT_DIR="/data/fsrb_testing/checkpoints-$TIMESTAMP" && \ # Create checkpoint directory if it doesn't exist if [ ! -d "$CHECKPOINT_DIR" ]; then @@ -14,17 +107,20 @@ if [ ! 
-d "$CHECKPOINT_DIR" ]; then } fi -python async_sac_state_sim.py "$@" \ +python async_sac_state_sim.py \ --learner \ - --env $ENV_NAME \ - --exp_name=serl-reach-testing \ - --seed 0 \ - --max_steps 100000 \ + --env $env \ + --exp_name=$exp_name \ + --seed $seed \ + --replay_buffer_type $type \ + --branch_method $branch \ + --starting_branch_count $starting_branch_count \ + --max_steps 1000 \ --training_starts 1000 \ --critic_actor_ratio 8 \ --batch_size 256 \ - --replay_buffer_capacity 100000 \ + --replay_buffer_capacity 1000000 \ --save_model True \ - --checkpoint_period 10000 \ - --checkpoint_path "$CHECKPOINT_DIR" \ + # --checkpoint_period 10000 \ + # --checkpoint_path "$CHECKPOINT_DIR" \ #--debug # wandb is disabled when debug diff --git a/examples/async_sac_state_sim/run_tests.sh b/examples/async_sac_state_sim/run_tests.sh new file mode 100644 index 00000000..9db3bb6d --- /dev/null +++ b/examples/async_sac_state_sim/run_tests.sh @@ -0,0 +1,30 @@ +#!/bin/bash + +env="PandaReachCube-v0" + +trap "kill 0" EXIT INT TERM + +# Baseline + +cmd="" +for i in $(seq 1 5); do + cmd+="bash -c 'tmux_launch.sh --env $env --name \'$env-baseline\' --type \'replay_buffer\' --seed $i'; " +done + + + +branch=("constant" "linear" "fractal" "OIO" "IOI" "inv_linear" "inv_fractal") +for b in "${branch[@]}"; do + # constant branching + if [ $b == "constant" ]; then + branch_factor=(1 3 9 27) + for bf in "${branch_factor[@]}"; do + b_num=$(( $bf*$bf )) + for seed in $(seq 1 5); do + cmd+="bash -c 'tmux_launch.sh --env $env --name \'$env-$branch-$b_num\' --type \'fractal_symmetry_replay_buffer\' --branch $b --starting_branch $bf --seed $seed'; " + done + done + fi +done +tmux new-session -d -s testing_session +tmux send-keys -t testing_session:0.0 $cmd C-m diff --git a/examples/async_sac_state_sim/tmux_launch.sh b/examples/async_sac_state_sim/tmux_launch.sh index 78ff94a8..6145f712 100644 --- a/examples/async_sac_state_sim/tmux_launch.sh +++ b/examples/async_sac_state_sim/tmux_launch.sh 
@@ -1,9 +1,101 @@ #!/bin/bash -EXAMPLE_DIR=${EXAMPLE_DIR:-"examples/async_sac_state_sim"} +# Default variable values +seed=0 +env="PandaReachCube-v0" +branch="test" +type="fractal_symmetry_replay_buffer" +name="untitled_experiment" +starting_branch=1 + +has_argument() { + [[ ("$1" == *=* && -n ${1#*=}) || ( ! -z "$2" && "$2" != -*) ]]; +} + +extract_argument() { + echo "${2:-${1#*=}}" +} + +# Function to handle options and arguments +handle_options() { + while [ $# -gt 0 ]; do + case $1 in + -e | --env*) + if ! has_argument $@; then + echo "Environment not specified" >&2 + exit 1 + fi + + env=$(extract_argument $@) + + shift + ;; + -s | --seed*) + if ! has_argument $@; then + echo "Seed needs a positive integer value" >&2 + exit 1 + fi + + seed=$(extract_argument $@) + + shift + ;; + -t | --type*) + if ! has_argument $@; then + echo "Type needs a string" >&2 + exit 1 + fi + + type=$(extract_argument $@) + + shift + ;; + -n | --name*) + if ! has_argument $@; then + echo "Name needs a string" >&2 + exit 1 + fi + + name=$(extract_argument $@) + + shift + ;; + --starting_branch*) + if ! has_argument $@; then + echo "Starting_branch needs a positive odd integer" >&2 + exit 1 + fi + + starting_branch=$(extract_argument $@) + + shift + ;; + -b | --branch*) + if ! 
has_argument $@; then + echo "Branch needs a string" >&2 + exit 1 + fi + + branch=$(extract_argument $@) + + shift + ;; + *) + echo "Invalid option: $1" >&2 + exit 1 + ;; + esac + shift + done +} + +# Main script execution +handle_options "$@" + +# EXAMPLE_DIR=${EXAMPLE_DIR:-"examples/async_sac_state_sim"} CONDA_ENV=${CONDA_ENV:-"serl"} -cd $EXAMPLE_DIR +# cd $EXAMPLE_DIR echo "Running from $(pwd)" # Create a new tmux session @@ -13,10 +105,10 @@ tmux new-session -d -s serl_session tmux split-window -v # Navigate to the activate the conda environment in the first pane -tmux send-keys -t serl_session:0.0 "conda activate $CONDA_ENV && bash run_actor.sh" C-m +tmux send-keys -t serl_session:0.0 "conda activate $CONDA_ENV && bash run_actor.sh --seed $seed --env $env --name $name --type $type --branch $branch --starting_branch $starting_branch" C-m # Navigate to the activate the conda environment in the second pane -tmux send-keys -t serl_session:0.1 "conda activate $CONDA_ENV && bash run_learner.sh" C-m +tmux send-keys -t serl_session:0.1 "conda activate $CONDA_ENV && bash run_learner.sh --seed $seed --env $env --name $name --type $type --branch $branch --starting_branch $starting_branch && tmux kill-session -t serl_session:0.0" C-m # Attach to the tmux session tmux attach-session -t serl_session diff --git a/franka_sim/franka_sim/envs/panda_reach_gym_env.py b/franka_sim/franka_sim/envs/panda_reach_gym_env.py index 03828dd8..c38e8d99 100644 --- a/franka_sim/franka_sim/envs/panda_reach_gym_env.py +++ b/franka_sim/franka_sim/envs/panda_reach_gym_env.py @@ -35,7 +35,7 @@ def __init__( physics_dt: float = 0.002, time_limit: float = 10.0, render_spec: GymRenderingSpec = GymRenderingSpec(), - render_mode: Literal["rgb_array", "human"] = "rgb_array", + render_mode: Literal["rgb_array", "human"] = None, image_obs: bool = False, ): self._action_scale = action_scale @@ -145,7 +145,8 @@ def __init__( height=960, camera_id=0 ) - self._viewer.render(self.render_mode) + if 
self.render_mode: + self._viewer.render(self.render_mode) def reset( self, seed=None, **kwargs diff --git a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py index dd452c86..f12856a5 100644 --- a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py +++ b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py @@ -22,19 +22,13 @@ def __init__( method_check = "split_method" match split_method: case "time": - # assert "timesplit_freq" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "depth") - # self.timesplit_freq=kwargs["timesplit_freq"] - # del kwargs["timesplit_freq"] - self.split = self.time_split - case "rel_pos": - print("NOT IMPLEMENTED") - self.split = self.rel_pos_split - case "height": - print("NOT IMPLEMENTED") - self.split = self.height_split - case "velocity": - print("NOT IMPLEMENTED") - self.split = self.velocity_split + assert "depth" in kwargs.keys(), self._handle_bad_args_(method_check, split_method, "depth") + self.max_depth=kwargs["depth"] + del kwargs["depth"] + + assert "max_steps" in kwargs.keys(), self._handle_bad_args_(method_check, split_method, "max_steps") + self.split = self.time_split + case "test": self.split = self.test_split case _: @@ -44,7 +38,7 @@ def __init__( match branch_method: case "fractal": assert "depth" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "depth") - self.depth=kwargs["depth"] + self.max_depth=kwargs["depth"] del kwargs["depth"] assert "dendrites" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "dendrites") @@ -76,9 +70,6 @@ def __init__( case _: raise ValueError("incorrect value passed to branch_method") - - for k in kwargs.keys(): - print(f"\033[33mWARNING \033[0m argument \"{k}\" not used") self.x_obs_idx = kwargs["x_obs_idx"] self.y_obs_idx = kwargs["y_obs_idx"] @@ -86,6 +77,9 @@ def __init__( self.current_depth = 1 self.counter = 1 
self.branch_index = [workspace_width/2] + + for k in kwargs.keys(): + print(f"\033[33mWARNING \033[0m argument \"{k}\" not used") # TODO # Add flags for variables passed to @@ -103,7 +97,7 @@ def __init__( capacity=capacity, ) - def _handle_bad_args_(type: str, method: str, arg: str) : + def _handle_bad_args_(self, type: str, method: str, arg: str) : return f"\033[31mERROR: \033[0m{arg} must be defined for {type} \"{method}\"" @@ -120,14 +114,14 @@ def test_branch(self): # while(not self.current_branch_count % 2): # self.current_branch_count = np.random.randint(2500) # return self.current_branch_count - return 9 + return 1 def fractal_branch(self): # return a new number of branches = dendrites ^ depth temp = self.dendrites ** self.current_depth self.current_depth += 1 - if self.current_depth > self.depth: - return self.dendrites ** self.depth + if self.current_depth > self.max_depth: + return self.dendrites ** self.max_depth return temp # REQUIRES TESTING @@ -153,21 +147,6 @@ def time_split(self, data_dict: DatasetDict): case _: return False - # NOT IMPLEMENTED - def rel_pos_split(self, data_dict: DatasetDict): - # return True if there is a change of depth determined by relative position - raise NotImplementedError - - # NOT IMPLEMENTED - def height_split(self, data_dict: DatasetDict): - # return True if there is a change of depth determined by height - raise NotImplementedError - - # NOT IMPLEMENTED - def velocity_split(self, data_dict: DatasetDict): - # return True if there is a change of depth determined by velocity - raise NotImplementedError - def constant_split(self, data_dict: DatasetDict): # return True every transition return True From a03c6f6b068ed592483c7788f4159dd1f4c2a885 Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Wed, 23 Jul 2025 12:54:29 -0500 Subject: [PATCH 044/141] General enhancements: - Improved index handling for environments - Updated how env.step() returns done based on the reward definition in this reach env. Used to always be False. 
- Path setting improvements to save .npz file. - Created a load_demo_test.py to assert written data can be read and parsed by episode and time step. --- demos/franka_reach_demo_script.py | 95 +++++++++---- demos/load_demo_test.py | 130 ++++++++++++++++++ .../franka_sim/envs/panda_reach_gym_env.py | 13 +- 3 files changed, 207 insertions(+), 31 deletions(-) create mode 100644 demos/load_demo_test.py diff --git a/demos/franka_reach_demo_script.py b/demos/franka_reach_demo_script.py index 14767eec..e632e00f 100755 --- a/demos/franka_reach_demo_script.py +++ b/demos/franka_reach_demo_script.py @@ -13,6 +13,8 @@ 4. Maintain Position (hold final position) Output: Compressed NPZ file containing action, observation, and info sequences + +TODO: Convert to a class-based structure for better modularity and reusability. """ import os import numpy as np @@ -32,13 +34,32 @@ dones = [] # List storing done flags (terminated or truncated) for each episode # Proportional and derivative control gain for action scaling -- empirically tuned -Kp = 20.0 # Values between 20 and 24 seem to be somewhat stable for Kv = 24 -Kv = 24.0 +Kp = 10.0 # Values between 20 and 24 seem to be somewhat stable for Kv = 24 +Kv = 10.0 + +# Number of demonstration episodes to generate +NUM_DEMOS = 20 # Robot configuration robot = 'franka' # Robot type used in the environment, can be 'franka' or 'fetch' task = 'reach' # Task type used in the environment, can be 'reach' or 'pick-and-place' +DEBUG = False + +if DEBUG: + _render_mode = 'human' # Render mode for the environment, can be 'human' or 'rgb_array' +else: + _render_mode = 'rgb_array' # Use 'rgb_array' for automated testing without GUI + + +# Indices for franka_sim reach environment observations +if robot == 'franka' and task == 'reach': + + opi = np.array([0, 3]) # Indices for object position in observation + gpi = np.array([3]) # Indices for gripper position in observation + rpi = np.array([4, 7]) # Indices for robot position in observation + rvi = 
np.array([7, 10]) # Indices for robot velocity in observation + # Weld constraint flag weld_flag = True # Flag to activate weld constraint during pick-and-place @@ -81,6 +102,26 @@ # print(f"Warning: Constraint '{constraint_name}' not found") # return False +def set_front_cam_view(env): + """ + Set the camera view to a front-facing perspective for better visualization. + + Args: + env: The environment instance containing the viewer. + + Returns: + viewer: The viewer with updated camera settings. + """ + viewer = env.unwrapped._viewer.viewer # Access the viewer from the environment + + if hasattr(viewer, 'cam'): + viewer.cam.lookat[:] = [0, 0, 0.1] # Center of robot (adjust as needed) + viewer.cam.distance = 3.0 # Camera distance + viewer.cam.azimuth = 180 # 0 = right, 90 = front, 180 = left + viewer.cam.elevation = -30 # Negative = above, positive = below + + return viewer + def store_transition_data(episode_dict, new_obs, rewards, action, info, terminated, truncated, done): """ Store transition data in the episode dictionary. @@ -123,10 +164,10 @@ def update_state_info(episode_data, time_step, dt, error): current_pos (np.ndarray): Current position of the end-effector. current_vel (np.ndarray): Current velocity of the end-effector. 
""" - object_pos = episode_data["observations"][-1][0:3] # Block position - gripper_pos = episode_data["observations"][-1][3] # Gripper position - current_pos = episode_data["observations"][-1][4:7] # Panda/tcp position - current_vel = episode_data["observations"][-1][7:10] # Panda/t + object_pos = episode_data["observations"][-1][ opi[0]:opi[1] ] # Block position + gripper_pos = episode_data["observations"][-1][ gpi[0] ] # Gripper position + current_pos = episode_data["observations"][-1][ rpi[0]:rpi[1] ] # Panda/tcp position + current_vel = episode_data["observations"][-1][ rvi[0]:rvi[1] ] # Panda/t # Print debug information @@ -216,14 +257,14 @@ def demo(env, lastObs): fgr_pos = 0 # Error thresholds - error_threshold = 0.04 # Threshold for stopping condition (Xmm) + error_threshold = 0.03 # Threshold for stopping condition (Xmm) finger_delta_fast = 0.05 # Action delta for fingers 5cm per step (will get clipped by controller)... more of a scalar. finger_delta_slow = 0.005 # Franka has a range from 0 to 4cm per finger ## Extract data - object_pos = lastObs[0:3] # block pos - current_pos = lastObs[4:7] # panda/tcp_pos + object_pos = lastObs[opi[0]:opi[1]] # block pos + current_pos = lastObs[rpi[0]:rpi[1]] # panda/tcp_pos # Relative position between end-effector and object dt = env.unwrapped.model.opt.timestep # Mujoco time step @@ -274,7 +315,8 @@ def demo(env, lastObs): time_step += 1 # Sleep - sleep(0.25) # Optional: Slow down for better visualization + if DEBUG: + sleep(0.25) # Activated when DEBUG is True for better visualization. 
# Store complete episode data in global lists only if we succeeded (avoid bad demos) store_episode_data(episode_data) @@ -283,9 +325,6 @@ def demo(env, lastObs): # if weld_flag: # deactivate_weld(env, constraint_name="grasp_weld") - # Close mujoco viewer - env.close() - # Break out of the loop to start a new episode return True @@ -305,11 +344,11 @@ def main(): Arguments that can be configured with flags: - env - render - - num_demos + - demo_ctr """ # Initialize the Panda environment. - env = gym.make("PandaReachCube-v0", render_mode="human") + env = gym.make("PandaReachCube-v0", render_mode=_render_mode) env = gym.wrappers.FlattenObservation(env) # Adjust physical settings @@ -320,28 +359,23 @@ def main(): initStateSpace = "random" # Initial state space configuration # Demos configs - attempted_demos = 1 # Number of demonstration episodes to generate--ADJUST THIS VALUE FOR MORE OR LESS DEMOS** + num_demos = NUM_DEMOS # Number of demonstration episodes to generate--ADJUST THIS VALUE FOR MORE OR LESS DEMOS** - num_demos = 0 # Counter for successful demonstration episodes + demo_ctr = 0 # Counter for successful demonstration episodes # Reset environment to initial state - render for the first time. obs, _ = env.reset() # For reach environment expect 10 observations: r_pos, r_vel, finger, object_pos. 
# Adjust camera view for better visualization - viewer = env.unwrapped._viewer.viewer - if hasattr(viewer, 'cam'): - viewer.cam.lookat[:] = [0, 0, 0.1] # Center of robot (adjust as needed) - viewer.cam.distance = 3.0 # Camera distance - viewer.cam.azimuth = 180 # 0 = right, 90 = front, 180 = left - viewer.cam.elevation = -30 # Negative = above, positive = below + viewer = set_front_cam_view(env) print("Reset!") # Generate demonstration episodes - while len(actions) < attempted_demos: + while len(actions) < num_demos: obs,_ = env.reset() # Reset environment for new episode - print(f"We will run a total of: {attempted_demos} demos!!") + print(f"We will run a total of: {num_demos} demos!!") print("Demo: #", len(actions)+1) # Execute pick-and-place task @@ -349,9 +383,12 @@ def main(): # Print success message if res: - num_demos += 1 + demo_ctr += 1 print("Episode completed successfully!") - print(f"Total successful demos: {num_demos}/{attempted_demos}") + print(f"Total successful demos: {demo_ctr}/{num_demos}") + + # Close the environment after all episodes are done + env.close() ## Write data to demos folder. Assumes mounted /data folder and internal data folder. 
script_dir = '/data/data/serl/demos' @@ -359,7 +396,7 @@ def main(): # Create output filename with configuration details fileName = "data_" + robot + "_" + task fileName += "_" + initStateSpace - fileName += "_" + str(attempted_demos) + fileName += "_" + str(num_demos) fileName += ".npz" # Build a filename in that same directory @@ -380,7 +417,7 @@ def main(): dones = dones) print(f"Data saved to {fileName}.") - print(f"Total successful demos: {num_demos}/{attempted_demos}") + print(f"Total successful demos: {demo_ctr}/{num_demos}") if __name__ == "__main__": main() \ No newline at end of file diff --git a/demos/load_demo_test.py b/demos/load_demo_test.py new file mode 100644 index 00000000..adac50eb --- /dev/null +++ b/demos/load_demo_test.py @@ -0,0 +1,130 @@ +# Updated and advanced train.py that includes logging, vectorized environments, and periodic recorded evaluations +import os +from pathlib import Path +import numpy as np + +def load_demos_to_her_buffer_gymnasium(demo_npz_path: str, combine_done: bool = True): + """ + Load a raw Gymnasium-style .npz of expert episodes into model.replay_buffer. + + demo_npz_path must contain at least these arrays: + - 'episodeObs' : shape (T+1, *obs_shape*), list of observations + - 'episodeAcs' : shape (T, *act_shape*), list of actions + - 'episodeRews' : shape (T,), list of rewards + - 'episodeTerminated' : shape (T,), list of terminated flags + - 'episodeTruncated' : shape (T,), list of truncated flags + - 'episodeInfo' : shape (T,), list of info dicts + + Parameters + ---------- + model : SAC + Your SAC+HER model, already instantiated (and with the correct HerReplayBuffer). + demo_npz_path : str + Path to the .npz file you saved from your demo collector. + combine_done : bool, default=True + If True, `done = terminated or truncated`. If False, `done = terminated` only. + """ + + # Load all demo data. 
Structure: var_name[num_demo][time_step][key if dict] = value + data = np.load(demo_npz_path, allow_pickle=True) + + obs_buffer = data['obs'] # length T+1 + act_buffer = data['acs'] # length T + rew_buffer = data['rewards'] # length T + term_buffer = data['terminateds'] # length T + trunc_buffer = data['truncateds'] # length T + info_buffer = data['info'] # length T + done_buffer = data['dones'] # length T, if available + + # Extract number of demonstrations + num_demos = obs_buffer.shape[0] + + # Extract rollout data for a single episode + for ep in range(num_demos): + ep_obs = obs_buffer[ep] # this is a length‐(T+1) array of dicts + ep_acts = act_buffer[ep] # length‐T array of actions + ep_rews = rew_buffer[ep] + ep_terms = term_buffer[ep] + ep_trunc = trunc_buffer[ep] + ep_done = done_buffer[ep] + ep_info = info_buffer[ep] # length‐T array of dicts + + # Length of episode: + T = len(ep_acts) + + # Extract single transitions from the episode data + for t in range(T): + # raw single‐step data: + obs_t = ep_obs[t] # dict[str, np.ndarray] (obs_dim,) + next_obs_t = ep_obs[t+1] + a_t = ep_acts[t] # np.ndarray (action_dim,) + r_t = float(ep_rews[t]) + done_t = bool(ep_done[t] or ep_terms[t] or ep_trunc[t]) + + # Rehydrate info dict and inject the timeout flag + raw_info = ep_info[t] # dict[str,Any] + if isinstance(raw_info, str): + import ast + info_t = ast.literal_eval(raw_info) + else: + info_t = raw_info.copy() + # Append truncated information to info_t + info_t["TimeLimit.truncated"] = bool(ep_trunc[t]) + + # **Add the required batch‐dimension** for n_envs=1 (necessary for defualt DummyVecEnv) + # obs_batch = {k: v[None, ...] for k, v in obs_t.items()} + # next_obs_batch = {k: v[None, ...] for k, v in next_obs_t.items()} + # action_batch = a_t[None, ...] 
# shape (1, action_dim) + # reward_batch = np.array([r_t]) # shape (1,) + # done_batch = np.array([done_t]) # shape (1,) + # infos_batch = [info_t] # length‐1 list + + # model.replay_buffer.add( + # obs = obs_batch, + # next_obs = next_obs_batch, + # action = action_batch, + # reward = reward_batch, + # done = done_batch, + # infos = infos_batch, + # ) + + print(f"Can load {num_demos} transitions successfullly from {demo_npz_path}." + f"(combine_done={combine_done}).") + +def get_demo_path(relative_path: str) -> str: + """ + Given a path relative to this script file, return + the absolute, normalized path as a string. + + Example: + # If your demos live at ../../../demos/data.npz + demo_file = get_demo_path("../../../demos/data_franka_random_10.npz") + """ + # 1) Resolve this script’s directory + script_dir = Path(__file__).resolve().parent + + # 2) Join with the user-supplied relative path and normalize + full_path = (script_dir / relative_path).resolve() + + return str(full_path) + + +def main(): + """ + Load demos + """ + + # Get abs path to demo file + script_dir = '/data/data/serl/demos' + default_file = 'data_franka_reach_random_20.npz' + + prompt = f"Please input the name of the file to load [{default_file}]: " + + file_name = input(prompt) or default_file + demo_file = os.path.join(script_dir, file_name) + + # Load the demo file into the HER buffer: mutated model.replay_buffer will persist. 
+ load_demos_to_her_buffer_gymnasium(demo_file, combine_done=True) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/franka_sim/franka_sim/envs/panda_reach_gym_env.py b/franka_sim/franka_sim/envs/panda_reach_gym_env.py index 03828dd8..44ad2667 100644 --- a/franka_sim/franka_sim/envs/panda_reach_gym_env.py +++ b/franka_sim/franka_sim/envs/panda_reach_gym_env.py @@ -215,7 +215,16 @@ def step( obs = self._compute_observation() rew = self._compute_reward() - terminated = self.time_limit_exceeded() + + # For reach environment, finger + if rew >= 0.55: + terminated = True + else: + # Check if the time limit is exceeded. + if self._time_limit is not None: + terminated = self.time_limit_exceeded() + else: + terminated = False return obs, rew, terminated, False, {} @@ -264,7 +273,7 @@ def _compute_reward(self) -> float: # Calculate distance dist = np.linalg.norm(block_pos - tcp_pos) - # Distance-based reward + # Distance-based reward. Note at norm of 0.015, reward will be 0.5 r_close = np.exp(-20 * dist) r_close = np.clip(r_close, 0.0, 1.0) From 680762347e5dcb3e76dfd850a8bd2e0450a983dd Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Wed, 23 Jul 2025 15:36:29 -0500 Subject: [PATCH 045/141] Imitate an insertion into data_store or queued data store --- demos/load_demo_test.py | 37 +++++++++++++++++-------------------- 1 file changed, 17 insertions(+), 20 deletions(-) diff --git a/demos/load_demo_test.py b/demos/load_demo_test.py index adac50eb..240a9dfa 100644 --- a/demos/load_demo_test.py +++ b/demos/load_demo_test.py @@ -3,7 +3,7 @@ from pathlib import Path import numpy as np -def load_demos_to_her_buffer_gymnasium(demo_npz_path: str, combine_done: bool = True): +def load_demos_to_her_buffer_gymnasium(data_store, demo_npz_path: str, combine_done: bool = True): """ Load a raw Gymnasium-style .npz of expert episodes into model.replay_buffer. 
@@ -17,8 +17,8 @@ def load_demos_to_her_buffer_gymnasium(demo_npz_path: str, combine_done: bool = Parameters ---------- - model : SAC - Your SAC+HER model, already instantiated (and with the correct HerReplayBuffer). + data_store : DataStore + The data store to which the demo transitions will be added. demo_npz_path : str Path to the .npz file you saved from your demo collector. combine_done : bool, default=True @@ -71,22 +71,17 @@ def load_demos_to_her_buffer_gymnasium(demo_npz_path: str, combine_done: bool = # Append truncated information to info_t info_t["TimeLimit.truncated"] = bool(ep_trunc[t]) - # **Add the required batch‐dimension** for n_envs=1 (necessary for defualt DummyVecEnv) - # obs_batch = {k: v[None, ...] for k, v in obs_t.items()} - # next_obs_batch = {k: v[None, ...] for k, v in next_obs_t.items()} - # action_batch = a_t[None, ...] # shape (1, action_dim) - # reward_batch = np.array([r_t]) # shape (1,) - # done_batch = np.array([done_t]) # shape (1,) - # infos_batch = [info_t] # length‐1 list - - # model.replay_buffer.add( - # obs = obs_batch, - # next_obs = next_obs_batch, - # action = action_batch, - # reward = reward_batch, - # done = done_batch, - # infos = infos_batch, - # ) + + data_store.insert( + dict( + observations=obs_t, + actions=a_t, + next_observations=next_obs_t, + rewards=r_t, + masks=1.0 - done_t, + dones=done_t + ) + ) print(f"Can load {num_demos} transitions successfullly from {demo_npz_path}." f"(combine_done={combine_done}).") @@ -124,7 +119,9 @@ def main(): demo_file = os.path.join(script_dir, file_name) # Load the demo file into the HER buffer: mutated model.replay_buffer will persist. 
- load_demos_to_her_buffer_gymnasium(demo_file, combine_done=True) + from agentlace.data.data_store import QueuedDataStore + data_store = QueuedDataStore(2000) + load_demos_to_her_buffer_gymnasium(data_store, demo_file, combine_done=True) if __name__ == "__main__": main() \ No newline at end of file From 2aeb8ccfcfe5b1691d04dfece5e0516517372143 Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Wed, 23 Jul 2025 15:53:56 -0500 Subject: [PATCH 046/141] Converted code into a class that can be imported for easy running in serl. --- demos/__init__.py | 0 demos/demoHandling.py | 121 ++++++++++++++++++++++++++++++++++++++++ demos/load_demo_test.py | 6 +- 3 files changed, 124 insertions(+), 3 deletions(-) create mode 100644 demos/__init__.py create mode 100644 demos/demoHandling.py diff --git a/demos/__init__.py b/demos/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/demos/demoHandling.py b/demos/demoHandling.py new file mode 100644 index 00000000..1d564e5c --- /dev/null +++ b/demos/demoHandling.py @@ -0,0 +1,121 @@ +import os +from pathlib import Path +import numpy as np +from agentlace.data.data_store import QueuedDataStore + +class DemoHandling: + """ + Encapsulates loading of Gymnasium-style demo .npz files + into a DataStore (e.g., QueuedDataStore) for use with HER/RL agents. + """ + def __init__( + self, + data_store: QueuedDataStore, + demo_dir: str = '/data/data/serl/demos' + ): + """ + Parameters + ---------- + data_store : QueuedDataStore + The replay buffer or datastore to insert transitions into. + demo_dir : str + Directory where demo .npz files live by default. + """ + self.data_store = data_store + self.demo_dir = demo_dir + + @staticmethod + def load_demos_to_her_buffer_gymnasium( + data_store, + demo_npz_path: str, + combine_done: bool = True + ): + """ + Load a raw Gymnasium-style .npz of expert episodes into data_store. 
+ + demo_npz_path must contain arrays named 'obs', 'acs', 'rewards', + 'terminateds', 'truncateds', 'info', and optionally 'dones'. + + If combine_done=True, done = terminated OR truncated OR done_buffer. + """ + data = np.load(demo_npz_path, allow_pickle=True) + + obs_buffer = data['obs'] # shape (N, T+1, ...) + act_buffer = data['acs'] # shape (N, T, ...) + rew_buffer = data['rewards'] # shape (N, T) + term_buffer = data['terminateds'] # shape (N, T) + trunc_buffer = data['truncateds'] # shape (N, T) + info_buffer = data['info'] # shape (N, T) + done_buffer = data.get('dones', term_buffer | trunc_buffer) + + num_demos = obs_buffer.shape[0] + + for ep in range(num_demos): + ep_obs = obs_buffer[ep] + ep_acts = act_buffer[ep] + ep_rews = rew_buffer[ep] + ep_terms = term_buffer[ep] + ep_trunc = trunc_buffer[ep] + ep_done = done_buffer[ep] + ep_info = info_buffer[ep] + + T = len(ep_acts) + for t in range(T): + obs_t = ep_obs[t] + next_obs_t = ep_obs[t+1] + a_t = ep_acts[t] + r_t = float(ep_rews[t]) + done_t = bool(ep_done[t] or ep_terms[t] or ep_trunc[t]) + + raw_info = ep_info[t] + if isinstance(raw_info, str): + import ast + info_t = ast.literal_eval(raw_info) + else: + info_t = raw_info.copy() + info_t["TimeLimit.truncated"] = bool(ep_trunc[t]) + + data_store.insert( + dict( + observations=obs_t, + actions=a_t, + next_observations=next_obs_t, + rewards=r_t, + masks=1.0 - done_t, + dones=done_t + ) + ) + + print(f"Loaded {num_demos} episodes from '{demo_npz_path}' " + f"(combine_done={combine_done}).") + + @staticmethod + def get_demo_path(relative_path: str) -> str: + """ + Resolve a path relative to this file's location. + """ + script_dir = Path(__file__).resolve().parent + return str((script_dir / relative_path).resolve()) + + + def run(self, default_file: str = 'data_franka_reach_random_20.npz'): + """ + Prompt for a file name (with a sensible default), then load it. 
+ """ + prompt = f"Please input the name of the file to load [{default_file}]: " + + file_name = input(prompt).strip() or default_file + demo_file = os.path.join(self.demo_dir, file_name) + + self.load_demos_to_her_buffer_gymnasium( + self.data_store, + demo_file, + combine_done=True # dones are truncated or terminated in this case. + ) + + +# if __name__ == "__main__": +# # create your datastore; here we use a QueuedDataStore with capacity 2000 +# ds = QueuedDataStore(2000) +# handler = DemoHandling(ds) +# handler.run() diff --git a/demos/load_demo_test.py b/demos/load_demo_test.py index 240a9dfa..8fd33aad 100644 --- a/demos/load_demo_test.py +++ b/demos/load_demo_test.py @@ -71,7 +71,7 @@ def load_demos_to_her_buffer_gymnasium(data_store, demo_npz_path: str, combine_d # Append truncated information to info_t info_t["TimeLimit.truncated"] = bool(ep_trunc[t]) - + # Enter transition into data_store or QueuedDataStore data_store.insert( dict( observations=obs_t, @@ -81,7 +81,7 @@ def load_demos_to_her_buffer_gymnasium(data_store, demo_npz_path: str, combine_d masks=1.0 - done_t, dones=done_t ) - ) + ) print(f"Can load {num_demos} transitions successfullly from {demo_npz_path}." f"(combine_done={combine_done}).") @@ -118,7 +118,7 @@ def main(): file_name = input(prompt) or default_file demo_file = os.path.join(script_dir, file_name) - # Load the demo file into the HER buffer: mutated model.replay_buffer will persist. + # Load the demo file into data_store as in async_sac_state.py from agentlace.data.data_store import QueuedDataStore data_store = QueuedDataStore(2000) load_demos_to_her_buffer_gymnasium(data_store, demo_file, combine_done=True) From 2c71b365494868cd3caffeece4bfec7510350cb5 Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Wed, 23 Jul 2025 17:51:30 -0500 Subject: [PATCH 047/141] Reformated the class to better fit with the functionality we want in async_sac_state. 
- franka_reach_demo_script.py: now keeps count of total transitions - demoHandling.py: constructor checks for directories/files and loads all data to member - method insert_data_to_buffer pushes releveant transition data to QueuedDataStore for actor --- demos/demoHandling.py | 120 ++++++++++-------- demos/franka_reach_demo_script.py | 9 +- .../async_sac_state_sim.py | 89 ++++++++----- 3 files changed, 126 insertions(+), 92 deletions(-) diff --git a/demos/demoHandling.py b/demos/demoHandling.py index 1d564e5c..99b75c6f 100644 --- a/demos/demoHandling.py +++ b/demos/demoHandling.py @@ -5,13 +5,33 @@ class DemoHandling: """ - Encapsulates loading of Gymnasium-style demo .npz files - into a DataStore (e.g., QueuedDataStore) for use with HER/RL agents. + Koads an .npz file containing demonstration data into a data object. + This class is designed to work with Gymnasium-style demonstration data + and is intended to be used with a QueuedDataStore or similar data store. + + The .npz file should contain the following arrays: + - 'obs' : shape (N, T+1, *obs_shape*), list of observations + - 'acs' : shape (N, T, *act_shape*), list of actions + - 'rewards' : shape (N, T), list of rewards + - 'terminateds' : shape (N, T), list of terminated flags + - 'truncateds' : shape (N, T), list of truncated flags + - 'info' : shape (N, T), list of info dicts + - 'dones' : shape (N, T), list of done flags (if available) + + Parameters + ---------- + data_store : QueuedDataStore + The data store to which the demo transitions will be added. + demo_dir : str + Directory where demo .npz files live by default. + file_name : str + Name of the demo file to load. If not provided, a default will be used. 
""" def __init__( self, data_store: QueuedDataStore, - demo_dir: str = '/data/data/serl/demos' + demo_dir: str = '/data/data/serl/demos', + file_name: str = 'data_franka_reach_random_20.npz' ): """ Parameters @@ -23,30 +43,36 @@ def __init__( """ self.data_store = data_store self.demo_dir = demo_dir + self.transition_ctr = 0 # Global counter for transitions across all episodes - @staticmethod - def load_demos_to_her_buffer_gymnasium( - data_store, - demo_npz_path: str, - combine_done: bool = True - ): + # Load the demo data from the .npz file + + # Check if the demo directory exists + if not os.path.exists(self.demo_dir): + raise FileNotFoundError(f"Demo directory '{self.demo_dir}' does not exist.") + + # Construct the full path to the demo file + self.demo_npz_path = os.path.join(self.demo_dir, file_name) + if not os.path.isfile(self.demo_npz_path): + raise FileNotFoundError(f"Demo file '{self.demo_npz_path}' does not exist.") + + # Load the .npz file + self.data = np.load(self.demo_npz_path, allow_pickle=True) + + def insert_data_to_buffer(self): """ Load a raw Gymnasium-style .npz of expert episodes into data_store. - - demo_npz_path must contain arrays named 'obs', 'acs', 'rewards', + The .npz file must contain arrays named 'obs', 'acs', 'rewards', 'terminateds', 'truncateds', 'info', and optionally 'dones'. - - If combine_done=True, done = terminated OR truncated OR done_buffer. """ - data = np.load(demo_npz_path, allow_pickle=True) - - obs_buffer = data['obs'] # shape (N, T+1, ...) - act_buffer = data['acs'] # shape (N, T, ...) - rew_buffer = data['rewards'] # shape (N, T) - term_buffer = data['terminateds'] # shape (N, T) - trunc_buffer = data['truncateds'] # shape (N, T) - info_buffer = data['info'] # shape (N, T) - done_buffer = data.get('dones', term_buffer | trunc_buffer) + + obs_buffer = self.data['obs'] # shape (N, T+1, ...) + act_buffer = self.data['acs'] # shape (N, T, ...) 
+ rew_buffer = self.data['rewards'] # shape (N, T) + term_buffer = self.data['terminateds'] # shape (N, T) + trunc_buffer = self.data['truncateds'] # shape (N, T) + info_buffer = self.data['info'] # shape (N, T) + done_buffer = self.data.get('dones', term_buffer | trunc_buffer) num_demos = obs_buffer.shape[0] @@ -75,7 +101,7 @@ def load_demos_to_her_buffer_gymnasium( info_t = raw_info.copy() info_t["TimeLimit.truncated"] = bool(ep_trunc[t]) - data_store.insert( + self.data_store.insert( dict( observations=obs_t, actions=a_t, @@ -86,36 +112,18 @@ def load_demos_to_her_buffer_gymnasium( ) ) - print(f"Loaded {num_demos} episodes from '{demo_npz_path}' " - f"(combine_done={combine_done}).") - - @staticmethod - def get_demo_path(relative_path: str) -> str: - """ - Resolve a path relative to this file's location. - """ - script_dir = Path(__file__).resolve().parent - return str((script_dir / relative_path).resolve()) - - - def run(self, default_file: str = 'data_franka_reach_random_20.npz'): - """ - Prompt for a file name (with a sensible default), then load it. - """ - prompt = f"Please input the name of the file to load [{default_file}]: " - - file_name = input(prompt).strip() or default_file - demo_file = os.path.join(self.demo_dir, file_name) - - self.load_demos_to_her_buffer_gymnasium( - self.data_store, - demo_file, - combine_done=True # dones are truncated or terminated in this case. 
- ) - - -# if __name__ == "__main__": -# # create your datastore; here we use a QueuedDataStore with capacity 2000 -# ds = QueuedDataStore(2000) -# handler = DemoHandling(ds) -# handler.run() + print(f"Loaded {num_demos} episodes from '{self.demo_npz_path}' ") + + +if __name__ == "__main__": + # create your datastore; here we use a QueuedDataStore with capacity 2000 + ds = QueuedDataStore(2000) + handler = DemoHandling(ds, + demo_dir='/data/data/serl/demos', + file_name='data_franka_reach_random_20.npz') + + # Idenitfy the total number of transitions in the datastore + print(f'We have {handler.transition_ctr} transitions in the datastore.') + + # Load the demo data into the data_store + handler.insert_data_to_buffer() diff --git a/demos/franka_reach_demo_script.py b/demos/franka_reach_demo_script.py index e632e00f..6d68bb9d 100755 --- a/demos/franka_reach_demo_script.py +++ b/demos/franka_reach_demo_script.py @@ -32,6 +32,7 @@ terminateds = [] # List storing terminated flags for each episode truncateds = [] # List storing truncated flags for each episode dones = [] # List storing done flags (terminated or truncated) for each episode +transition_ctr = 0 # Global counter for transitions across all episodes # Proportional and derivative control gain for action scaling -- empirically tuned Kp = 10.0 # Values between 20 and 24 seem to be somewhat stable for Kv = 24 @@ -124,8 +125,11 @@ def set_front_cam_view(env): def store_transition_data(episode_dict, new_obs, rewards, action, info, terminated, truncated, done): """ - Store transition data in the episode dictionary. + Store transition data in the episode dictionary and update global counter. 
""" + global transition_ctr + transition_ctr += 1 + episode_dict["observations"].append(new_obs) episode_dict["rewards"].append(rewards) episode_dict["actions"].append(action) @@ -414,7 +418,8 @@ def main(): info = infos, terminateds = terminateds, truncateds = truncateds, - dones = dones) + dones = dones, + transition_ctr = transition_ctr) print(f"Data saved to {fileName}.") print(f"Total successful demos: {demo_ctr}/{num_demos}") diff --git a/examples/async_sac_state_sim/async_sac_state_sim.py b/examples/async_sac_state_sim/async_sac_state_sim.py index ef20e8d0..258f7f0b 100644 --- a/examples/async_sac_state_sim/async_sac_state_sim.py +++ b/examples/async_sac_state_sim/async_sac_state_sim.py @@ -27,6 +27,8 @@ import franka_sim +from demos.demoHandling import DemoHandling + FLAGS = flags.FLAGS flags.DEFINE_string("env", "HalfCheetah-v4", "Name of environment.") @@ -65,6 +67,11 @@ flags.DEFINE_string("preload_rlds_path", None, "Path to preload RLDS data.") +# Load demonstation data +flags.DEFINE_boolean("load_demos", False, "Whether to load demo dataset.") +flags.DEFINE_string("demo_dir", "/data/data/serl/demos", "Path to demo dataset.") +flags.DEFINE_string("file_name", "data_franka_reach_random_20.npz", "Name of the demo file to load.") + def print_green(x): return print("\033[92m {}\033[00m".format(x)) @@ -105,42 +112,56 @@ def update_params(params): for step in tqdm.tqdm(range(FLAGS.max_steps), dynamic_ncols=True): timer.tick("total") - with timer.context("sample_actions"): - if step < FLAGS.random_steps: - actions = env.action_space.sample() - else: - sampling_rng, key = jax.random.split(sampling_rng) - actions = agent.sample_actions( - observations=jax.device_put(obs), - seed=key, - deterministic=False, - ) - actions = np.asarray(jax.device_get(actions)) - - # Step environment - with timer.context("step_env"): - - next_obs, reward, done, truncated, info = env.step(actions) - next_obs = np.asarray(next_obs, dtype=np.float32) - reward = np.asarray(reward, 
dtype=np.float32) - - running_return += reward - - data_store.insert( - dict( - observations=obs, - actions=actions, - next_observations=next_obs, - rewards=reward, - masks=1.0 - done, - dones=done or truncated, + # Load demos: handler.run will insert all transition demo data into the data store. + if FLAGS.load_demos: + with timer.context("sample and step into env"): + + # handler loads data + handler = DemoHandling(data_store, + demo_dir=FLAGS.demo_dir, + file_name=FLAGS.demo_dir) + + # Insert complete demonstration into the data store + print(f"Inserting {handler.data["transition_ctr"]} transitions into the data store.") + handler.insert_data_to_buffer() + + else: + with timer.context("sample_actions"): + if step < FLAGS.random_steps: + actions = env.action_space.sample() + else: + sampling_rng, key = jax.random.split(sampling_rng) + actions = agent.sample_actions( + observations=jax.device_put(obs), + seed=key, + deterministic=False, + ) + actions = np.asarray(jax.device_get(actions)) + + # Step environment + with timer.context("step_env"): + + next_obs, reward, done, truncated, info = env.step(actions) + next_obs = np.asarray(next_obs, dtype=np.float32) + reward = np.asarray(reward, dtype=np.float32) + + running_return += reward + + data_store.insert( + dict( + observations=obs, + actions=actions, + next_observations=next_obs, + rewards=reward, + masks=1.0 - done, + dones=done or truncated, + ) ) - ) - obs = next_obs - if done or truncated: - running_return = 0.0 - obs, _ = env.reset() + obs = next_obs + if done or truncated: + running_return = 0.0 + obs, _ = env.reset() if FLAGS.render: env.render() From dcaf5a56184fd603879ddf127d918bfa6657a4fe Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Wed, 23 Jul 2025 17:57:24 -0500 Subject: [PATCH 048/141] commented out main --- demos/demoHandling.py | 39 ++++++++++++++++----------------------- 1 file changed, 16 insertions(+), 23 deletions(-) diff --git a/demos/demoHandling.py b/demos/demoHandling.py index 
99b75c6f..be984192 100644 --- a/demos/demoHandling.py +++ b/demos/demoHandling.py @@ -92,38 +92,31 @@ def insert_data_to_buffer(self): a_t = ep_acts[t] r_t = float(ep_rews[t]) done_t = bool(ep_done[t] or ep_terms[t] or ep_trunc[t]) - - raw_info = ep_info[t] - if isinstance(raw_info, str): - import ast - info_t = ast.literal_eval(raw_info) - else: - info_t = raw_info.copy() - info_t["TimeLimit.truncated"] = bool(ep_trunc[t]) + # masks will be created right before insert below self.data_store.insert( dict( - observations=obs_t, - actions=a_t, + observations =obs_t, + actions =a_t, next_observations=next_obs_t, - rewards=r_t, - masks=1.0 - done_t, - dones=done_t + rewards =r_t, + masks =1.0 - done_t, + dones =done_t ) ) print(f"Loaded {num_demos} episodes from '{self.demo_npz_path}' ") -if __name__ == "__main__": - # create your datastore; here we use a QueuedDataStore with capacity 2000 - ds = QueuedDataStore(2000) - handler = DemoHandling(ds, - demo_dir='/data/data/serl/demos', - file_name='data_franka_reach_random_20.npz') +# if __name__ == "__main__": +# # create your datastore; here we use a QueuedDataStore with capacity 2000 +# ds = QueuedDataStore(2000) +# handler = DemoHandling(ds, +# demo_dir='/data/data/serl/demos', +# file_name='data_franka_reach_random_20.npz') - # Idenitfy the total number of transitions in the datastore - print(f'We have {handler.transition_ctr} transitions in the datastore.') +# # Idenitfy the total number of transitions in the datastore +# print(f'We have {handler.data["transition_ctr"]} transitions in the datastore.') - # Load the demo data into the data_store - handler.insert_data_to_buffer() +# # Load the demo data into the data_store +# handler.insert_data_to_buffer() From 11f7c0004f15566afb2b454294a06d97e08979e7 Mon Sep 17 00:00:00 2001 From: ryanvanderstelt Date: Wed, 23 Jul 2025 20:10:57 -0500 Subject: [PATCH 049/141] Added temporal awareness to fsrb with time_split function --- .../data/fractal_symmetry_replay_buffer.py | 147 
++++++------------ serl_launcher/serl_launcher/data/fsrb_test.py | 126 ++++----------- 2 files changed, 73 insertions(+), 200 deletions(-) diff --git a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py index f12856a5..ef5056a0 100644 --- a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py +++ b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py @@ -22,11 +22,15 @@ def __init__( method_check = "split_method" match split_method: case "time": - assert "depth" in kwargs.keys(), self._handle_bad_args_(method_check, split_method, "depth") - self.max_depth=kwargs["depth"] - del kwargs["depth"] + assert "max_depth" in kwargs.keys(), self._handle_bad_args_(method_check, split_method, "max_depth") + self.max_depth=kwargs["max_depth"] + del kwargs["max_depth"] assert "max_steps" in kwargs.keys(), self._handle_bad_args_(method_check, split_method, "max_steps") + self.max_steps=kwargs["max_steps"] + + assert "alpha" in kwargs.keys(), self._handle_bad_args_(method_check, split_method, "alpha") + self.alpha=kwargs["alpha"] self.split = self.time_split case "test": @@ -37,33 +41,33 @@ def __init__( method_check = "branch_method" match branch_method: case "fractal": - assert "depth" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "depth") - self.max_depth=kwargs["depth"] - del kwargs["depth"] + assert "max_depth" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "max_depth") + self.max_depth=kwargs["max_depth"] + del kwargs["max_depth"] - assert "dendrites" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "dendrites") - self.dendrites=kwargs["dendrites"] - del kwargs["dendrites"] + assert "branching_factor" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "branching_factor") + self.branching_factor=kwargs["branching_factor"] + del kwargs["branching_factor"] self.branch = 
self.fractal_branch - case "linear": - assert "branch_count_rate_of_change" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "branch_count_rate_of_change") - self.branch_count_rate_of_change=kwargs["branch_count_rate_of_change"] - del kwargs["branch_count_rate_of_change"] + # case "linear": + # assert "branch_count_rate_of_change" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "branch_count_rate_of_change") + # self.branch_count_rate_of_change=kwargs["branch_count_rate_of_change"] + # del kwargs["branch_count_rate_of_change"] - assert "starting_branch_count" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "starting_branch_count") - self.current_branch_count = kwargs["starting_branch_count"] - del kwargs["starting_branch_count"] + # assert "starting_branch_count" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "starting_branch_count") + # self.current_branch_count = kwargs["starting_branch_count"] + # del kwargs["starting_branch_count"] - self.branch = self.linear_branch + # self.branch = self.linear_branch - case "constant": - assert "starting_branch_count" in kwargs.keys(), self._handle_bad_args_("branch_method", branch_method, "starting_branch_count") - self.current_branch_count = kwargs["starting_branch_count"] - del kwargs["starting_branch_count"] + # case "constant": + # assert "starting_branch_count" in kwargs.keys(), self._handle_bad_args_("branch_method", branch_method, "starting_branch_count") + # self.current_branch_count = kwargs["starting_branch_count"] + # del kwargs["starting_branch_count"] - self.branch = self.constant_branch + # self.branch = self.constant_branch case "test": self.branch = self.test_branch @@ -74,22 +78,20 @@ def __init__( self.x_obs_idx = kwargs["x_obs_idx"] self.y_obs_idx = kwargs["y_obs_idx"] self.workspace_width = workspace_width + self.timestep = 0 self.current_depth = 1 - self.counter = 1 self.branch_index = [workspace_width/2] for k in 
kwargs.keys(): print(f"\033[33mWARNING \033[0m argument \"{k}\" not used") # TODO - # Add flags for variables passed to # Create more methods with tests # # Future considerations: - # consider different strategy for better accuracy # when dealing with x and y for transforms, consider two versions: # radial trees at different angles can have strong fractal symmetry built in but will be more difficult to evenly space lowest level transforms - # grid-based will inherently evenly space lowest transforms, but fractal symmetry will be restricted to 9 dendrites (original plus 8) in order to conform + # grid-based will inherently evenly space lowest transforms, but fractal symmetry will be restricted to 9 branching_factor (original plus 8) in order to conform super().__init__( observation_space=observation_space, @@ -103,58 +105,48 @@ def _handle_bad_args_(self, type: str, method: str, arg: str) : def transform(self, data_dict: DatasetDict, transform: np.array): # return data_dict with positional arguments += translation - assert data_dict["observations"].shape == transform.shape, "transform broke at observations" - assert data_dict["observations"].shape == transform.shape, "transform broke at next_observations" data_dict["observations"] += transform data_dict["next_observations"] += transform # FOR TESTING ONLY def test_branch(self): - # self.current_branch_count = np.random.randint(2500) - # while(not self.current_branch_count % 2): - # self.current_branch_count = np.random.randint(2500) - # return self.current_branch_count return 1 def fractal_branch(self): - # return a new number of branches = dendrites ^ depth - temp = self.dendrites ** self.current_depth + # return a new number of branches = branching_factor ^ depth + temp = self.branching_factor ** self.current_depth self.current_depth += 1 if self.current_depth > self.max_depth: - return self.dendrites ** self.max_depth + return self.branching_factor ** self.max_depth return temp # REQUIRES TESTING - def 
constant_branch(self): - # return current number of branches - return self.current_branch_count + # def constant_branch(self): + # # return current number of branches + # return self.current_branch_count # REQUIRES TESTING - def linear_branch(self): - # return a new number of branches = branches_count + n - return self.current_branch_count + self.branch_count_rate_of_change + # def linear_branch(self): + # # return a new number of branches = branches_count + n + # return self.current_branch_count + self.branch_count_rate_of_change # FOR TESTING ONLY def test_split(self, data_dict: DatasetDict): return True - # NOT IMPLEMENTED (HARDCODED) + # REQUIRES TESTING def time_split(self, data_dict: DatasetDict): - # return True when a set time has passed - match self.counter % 12: - case 0: - return True - case _: - return False + if self.timestep % (self.max_steps/self.max_depth) == 0: + return True + return False def constant_split(self, data_dict: DatasetDict): - # return True every transition return True def insert(self, data_dict_not: DatasetDict): data_dict = copy.deepcopy(data_dict_not) - # start_time = time.time() + # Update number of branches if needed if self.split(data_dict): temp = self.current_branch_count @@ -164,11 +156,7 @@ def insert(self, data_dict_not: DatasetDict): constant = self.workspace_width/(2 * self.current_branch_count) for i in range(0, self.current_branch_count): self.branch_index[i] = (2 * i + 1) * constant - # rb_origin = [data_dict["observations"][self.x_obs_idx[0]], data_dict["observations"][self.y_obs_idx[0]]] - # block_origin = [data_dict["observations"][self.x_obs_idx[1]], data_dict["observations"][self.y_obs_idx[1]]] - # rb_next_origin = [data_dict["next_observations"][self.x_obs_idx[0]], data_dict["next_observations"][self.y_obs_idx[0]]] - # block_next_origin = [data_dict["next_observations"][self.x_obs_idx[1]], data_dict["next_observations"][self.y_obs_idx[1]]] - + # Initialize to extreme x and y x = -self.workspace_width/2 transform 
= np.zeros_like(data_dict["observations"]) @@ -176,59 +164,16 @@ def insert(self, data_dict_not: DatasetDict): transform[self.y_obs_idx] = x self.transform(data_dict, transform) - # rb_pos_print = np.empty(shape=(self.current_branch_count, self.current_branch_count, 2), dtype=np.float32) - # block_pos_print = np.empty(shape=(self.current_branch_count, self.current_branch_count, 2), dtype=np.float32) - # rb_next_pos_print = np.empty(shape=(self.current_branch_count, self.current_branch_count, 2), dtype=np.float32) - # block_next_pos_print = np.empty(shape=(self.current_branch_count, self.current_branch_count, 2), dtype=np.float32) # Transform and insert transitions (multiprocessing in the future) for x in range(0, self.current_branch_count): - x_diff = self.branch_index[x] - transform[self.x_obs_idx] = x_diff + transform[self.x_obs_idx] = self.branch_index[x] for y in range(0, self.current_branch_count): new_data_dict = copy.deepcopy(data_dict) - y_diff = self.branch_index[y] - transform[self.y_obs_idx] = y_diff + transform[self.y_obs_idx] = self.branch_index[y] self.transform(new_data_dict, transform) - - # keep info for debug - # rb_pos_print[y][x][0] = round(new_data_dict["observations"][self.x_obs_idx[0]], 2) - # rb_pos_print[y][x][1] = round(new_data_dict["observations"][self.y_obs_idx[0]], 2) - # block_pos_print[y][x][0] = round(new_data_dict["observations"][self.x_obs_idx[1]], 2) - # block_pos_print[y][x][1] = round(new_data_dict["observations"][self.y_obs_idx[1]], 2) - # rb_next_pos_print[y][x][0] = round(new_data_dict["next_observations"][self.x_obs_idx[0]], 2) - # rb_next_pos_print[y][x][1] = round(new_data_dict["next_observations"][self.y_obs_idx[0]], 2) - # block_next_pos_print[y][x][0] = round(new_data_dict["next_observations"][self.x_obs_idx[1]], 2) - # block_next_pos_print[y][x][1] = round(new_data_dict["next_observations"][self.y_obs_idx[1]], 2) - - # absolutely make sure nothing wrong is being changed - # assert data_dict["rewards"] == 
new_data_dict["rewards"] - # assert (data_dict["actions"] == new_data_dict["actions"]).all() - # assert data_dict["dones"] == new_data_dict["dones"] - # assert data_dict["masks"] == new_data_dict["masks"] - # for i in range(0, data_dict["observations"].size): - # if i in self.x_obs_idx or i in self.y_obs_idx: - # continue - # assert data_dict["observations"][i] == new_data_dict["observations"][i] - # assert data_dict["next_observations"][i] == new_data_dict["next_observations"][i] - super().insert(new_data_dict) - self.counter += 1 - - # print(f"original x,y of hand before transition: {rb_origin}") - # print(f"transforms: \n{rb_pos_print}") - # print(f"original x,y of hand after transition: {rb_next_origin}") - # print(f"transforms: \n{rb_next_pos_print}") - # print(f"original x,y of block before transition: {block_origin}") - # print(f"transforms: \n{block_pos_print}") - # print(f"original x,y of block after transition: {block_next_origin}") - # print(f"transforms: \n{block_next_pos_print}") - if data_dict["dones"]: - self.counter = 1 self.current_depth = 1 - # end_time = time.time() - # duration = end_time - start_time - # print(f"#{self._insert_index - 1} Duration: {duration:.4f}") - + self.max_steps = int(self.timestep * self.alpha + self.max_steps * (1 - self.alpha)) \ No newline at end of file diff --git a/serl_launcher/serl_launcher/data/fsrb_test.py b/serl_launcher/serl_launcher/data/fsrb_test.py index c6a00e68..b48b0a91 100644 --- a/serl_launcher/serl_launcher/data/fsrb_test.py +++ b/serl_launcher/serl_launcher/data/fsrb_test.py @@ -10,19 +10,17 @@ FLAGS = flags.FLAGS flags.DEFINE_integer("capacity", 10000000, "Replay buffer capacity.") -flags.DEFINE_string("branch_method", "test", "placeholder") -flags.DEFINE_string("split_method", "test", "placeholder") -flags.DEFINE_float("workspace_width", 0.5, "workspace width in centimeters") -flags.DEFINE_integer("depth", 4, "Total layers of depth") -flags.DEFINE_integer("dendrites", 3, "Dendrites for fractal 
branching") # Remember to set default to None -flags.DEFINE_integer("timesplit_freq", None, "Frequency of splits according to time") -flags.DEFINE_integer("branch_count_rate_of_change", None, "Rate of change for linear branching") -flags.DEFINE_integer("starting_branch_count", 1, "Initial number of branches(sort of, TBD)") +flags.DEFINE_string("branch_method", "test", "Method for determining the number of transforms per dimension (x,y)") +flags.DEFINE_string("split_method", "test", "Method for determining whether to change the number of transforms per dimension (x,y)") +flags.DEFINE_float("workspace_width", 0.5, "workspace width in meters") +flags.DEFINE_integer("max_depth", 4, "Maximum level of depth") # For fractal_branch only +flags.DEFINE_integer("branching_factor", 3, "Rate of change of number of transforms per dimension (x,y)") # For fractal_branch only +flags.DEFINE_integer("starting_branch_count", 1, "Initial number of transforms per dimension (x,y)") # For constant_branch only def main(_): - x_obs_idx = np.array([0, 7]) - y_obs_idx = np.array([1, 8]) + x_obs_idx = np.array([0, 4]) + y_obs_idx = np.array([1, 5]) # Initialize replay buffer env = gym.make("PandaReachCube-v0") @@ -37,11 +35,8 @@ def main(_): workspace_width=FLAGS.workspace_width, x_obs_idx=x_obs_idx, y_obs_idx= y_obs_idx, - depth=FLAGS.depth, - dendrites=FLAGS.dendrites, - timesplit_freq=FLAGS.timesplit_freq, - branch_count_rate_of_change=FLAGS.branch_count_rate_of_change, - starting_branch_count=FLAGS.starting_branch_count, + max_depth=FLAGS.max_depth, + branching_factor=FLAGS.branching_factor, ) observation, info = env.reset() @@ -52,11 +47,11 @@ def main(_): next_observations=next_observation, actions=action, rewards=reward, - masks=False, + masks=not truncated and not terminated, dones=truncated or terminated, ) - del env, observation, next_observation, action, reward, truncated, terminated, info, _ + del env, observation, next_observation, action, reward, truncated, terminated, info, 
y_obs_idx, x_obs_idx, _ # transform() test @@ -67,103 +62,36 @@ def main(_): del result, expected - print("\n\033[32mTEST PASSED \033[0m transform() tests passed") + print("\n\033[32mTEST PASSED \033[0m transform() test passed") + # branch() tests ## fractal - # for dendrites in range(1, 10, 2): - # replay_buffer.dendrites=dendrites - # for depth in range(1, 6): - # replay_buffer.current_depth=depth - 1 - # result = replay_buffer.fractal_branch() - # expected = dendrites ** depth - # assert result == expected, f"\033[31mTEST FAILED\033[0m fractal_branch() test failed (expected {expected} but got {result})" - - # del dendrites, depth, result, expected - - ## constant + replay_buffer.max_depth = 4 + replay_buffer.branching_factor = 3 - ## linear + branches = replay_buffer.fractal_branch() + assert branches == 3 + branches = replay_buffer.fractal_branch() + assert branches == 9 + branches = replay_buffer.fractal_branch() + assert branches == 27 + branches = replay_buffer.fractal_branch() + assert branches == 81 + branches = replay_buffer.fractal_branch() + assert branches == 81 - ## exponential + del branches print("\033[32mTEST PASSED \033[0m _branch() tests passed") # split() tests - ## time - - ## height - - ## rel_pos - - ## velocity - print("\033[32mTEST PASSED \033[0m _split() tests passed") # insert() tests - # important_indices = np.array([0, 7]) - - # num_branches = 5000 - # workspace_options = 20 - - # data = dict( - # workspace_width = np.empty(shape=(num_branches//2 * workspace_options), dtype=np.float32), - # num_branches = np.empty(shape=(num_branches//2 * workspace_options), dtype=int), - # density = np.empty(shape=(num_branches//2 * workspace_options), dtype=np.float32) - # ) - - # for w in range(0, workspace_options): - # total_success = 0.0 - # for i in range(1,num_branches, 2): - # finished = False - # buffer.current_branch_count = i - # data_dict["observations"][important_indices] = 0 - # start = data_dict["observations"][important_indices[0]] - 
FLAGS.workspace_width / 2 - # buffer.insert(data_dict) - # for idx in range(buffer.current_branch_count): - # expected = round(FLAGS.workspace_width * (2 * idx + 1)/(2 * buffer.current_branch_count), 3) - # result = round(buffer.dataset_dict["observations"][buffer._insert_index - buffer.current_branch_count + idx][important_indices[0]] - start, 3) - # # assert result == expected, f"\033[31mTEST FAILED\033[0m insert() test failed (expected {expected} but got {result})" - # if result != expected: - # # print(f"At ww = {FLAGS.workspace_width}, maximum transforms = {i - 2}") - # # print(f"{i} failed at {idx}") - # finished = True - - # break - # if not finished: - # print(f"SUCCESS at {i} branches") - # total_success += 1 - # data["workspace_width"][i//2 + (num_branches//2) * w] = 0.05 * (w + 1) - # data["num_branches"][i//2 + (num_branches//2) * w] = i - # data["density"][i//2 + (num_branches//2) * w] = total_success/(i//2 + 1) - - # FLAGS.workspace_width = FLAGS.workspace_width + 0.05 - # buffer.workspace_width = FLAGS.workspace_width - # del i, idx, result, expected - - # df = pd.DataFrame(data) - # df.to_excel('output.xlsx', index=False) - - data_dict["observations"][0] = 0 - data_dict["observations"][1] = 0 - start = data_dict["observations"][0] - FLAGS.workspace_width / 2 - - - # buffer.insert(data_dict) - # for idx in range(buffer.current_branch_count): - # expected = round(FLAGS.workspace_width * (2 * idx + 1)/(2 * buffer.current_branch_count), 3) - # result = round(buffer.dataset_dict["observations"][idx][0] - start, 3) - # assert result == expected, f"\033[31mTEST FAILED\033[0m insert() test failed (expected {expected} but got {result})" - - replay_buffer.insert(data_dict) - for idx in range(replay_buffer.current_branch_count): - expected = round(FLAGS.workspace_width * (2 * idx + 1)/(2 * replay_buffer.current_branch_count), 3) - result = round(replay_buffer.dataset_dict["observations"][idx][0] - start, 3) - assert result == expected, f"\033[31mTEST 
FAILED\033[0m insert() test failed (expected {expected} but got {result})" - print("\033[32mTEST PASSED \033[0m insert() tests passed") print("\nfinished!\n") From 9f0701f5120815de650ebded9443727a3f13bcb0 Mon Sep 17 00:00:00 2001 From: ryanvanderstelt Date: Wed, 23 Jul 2025 20:18:06 -0500 Subject: [PATCH 050/141] fixed fractal test in fsrb_test.py --- serl_launcher/serl_launcher/data/fsrb_test.py | 29 +++++++++++-------- 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/serl_launcher/serl_launcher/data/fsrb_test.py b/serl_launcher/serl_launcher/data/fsrb_test.py index b48b0a91..558162c9 100644 --- a/serl_launcher/serl_launcher/data/fsrb_test.py +++ b/serl_launcher/serl_launcher/data/fsrb_test.py @@ -71,18 +71,23 @@ def main(_): replay_buffer.max_depth = 4 replay_buffer.branching_factor = 3 - branches = replay_buffer.fractal_branch() - assert branches == 3 - branches = replay_buffer.fractal_branch() - assert branches == 9 - branches = replay_buffer.fractal_branch() - assert branches == 27 - branches = replay_buffer.fractal_branch() - assert branches == 81 - branches = replay_buffer.fractal_branch() - assert branches == 81 - - del branches + result = replay_buffer.fractal_branch() + expected = 3 + assert result == expected, f"\033[31mTEST FAILED\033[0m fractal_branch() test failed (expected {expected} but got {result})" + result = replay_buffer.fractal_branch() + expected = 9 + assert result == expected, f"\033[31mTEST FAILED\033[0m fractal_branch() test failed (expected {expected} but got {result})" + result = replay_buffer.fractal_branch() + expected = 27 + assert result == expected, f"\033[31mTEST FAILED\033[0m fractal_branch() test failed (expected {expected} but got {result})" + result = replay_buffer.fractal_branch() + expected = 81 + assert result == expected, f"\033[31mTEST FAILED\033[0m fractal_branch() test failed (expected {expected} but got {result})" + result = replay_buffer.fractal_branch() + expected = 81 + assert result == expected, 
f"\033[31mTEST FAILED\033[0m fractal_branch() test failed (expected {expected} but got {result})" + + del result, expected print("\033[32mTEST PASSED \033[0m _branch() tests passed") From 27360ca4ffcbde3d6bbc4978cf5d1120a4cfd901 Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Wed, 23 Jul 2025 21:33:26 -0500 Subject: [PATCH 051/141] Refactor current_depth handling in FractalSymmetryReplayBuffer and enhance tests for fractal_branch and time_split methods --- .../data/fractal_symmetry_replay_buffer.py | 18 ++++--- serl_launcher/serl_launcher/data/fsrb_test.py | 47 ++++++++++++++++--- 2 files changed, 49 insertions(+), 16 deletions(-) diff --git a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py index ef5056a0..55188efd 100644 --- a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py +++ b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py @@ -79,7 +79,7 @@ def __init__( self.y_obs_idx = kwargs["y_obs_idx"] self.workspace_width = workspace_width self.timestep = 0 - self.current_depth = 1 + self.current_depth = 0 self.branch_index = [workspace_width/2] for k in kwargs.keys(): @@ -110,15 +110,12 @@ def transform(self, data_dict: DatasetDict, transform: np.array): # FOR TESTING ONLY def test_branch(self): + self.current_depth += 1 return 1 def fractal_branch(self): # return a new number of branches = branching_factor ^ depth - temp = self.branching_factor ** self.current_depth - self.current_depth += 1 - if self.current_depth > self.max_depth: - return self.branching_factor ** self.max_depth - return temp + return self.branching_factor ** self.current_depth # REQUIRES TESTING # def constant_branch(self): @@ -136,9 +133,10 @@ def test_split(self, data_dict: DatasetDict): # REQUIRES TESTING def time_split(self, data_dict: DatasetDict): - if self.timestep % (self.max_steps/self.max_depth) == 0: - return True - return False + if self.timestep % 
(self.max_steps/self.max_depth) or self.current_depth >= self.max_depth: + return False + self.current_depth += 1 + return True def constant_split(self, data_dict: DatasetDict): return True @@ -174,6 +172,6 @@ def insert(self, data_dict_not: DatasetDict): super().insert(new_data_dict) if data_dict["dones"]: - self.current_depth = 1 + self.current_depth = 0 self.max_steps = int(self.timestep * self.alpha + self.max_steps * (1 - self.alpha)) \ No newline at end of file diff --git a/serl_launcher/serl_launcher/data/fsrb_test.py b/serl_launcher/serl_launcher/data/fsrb_test.py index 558162c9..cff3e474 100644 --- a/serl_launcher/serl_launcher/data/fsrb_test.py +++ b/serl_launcher/serl_launcher/data/fsrb_test.py @@ -68,32 +68,67 @@ def main(_): ## fractal - replay_buffer.max_depth = 4 replay_buffer.branching_factor = 3 + replay_buffer.current_depth = 1 result = replay_buffer.fractal_branch() expected = 3 assert result == expected, f"\033[31mTEST FAILED\033[0m fractal_branch() test failed (expected {expected} but got {result})" + + replay_buffer.current_depth = 2 result = replay_buffer.fractal_branch() expected = 9 assert result == expected, f"\033[31mTEST FAILED\033[0m fractal_branch() test failed (expected {expected} but got {result})" + + replay_buffer.current_depth = 3 result = replay_buffer.fractal_branch() expected = 27 assert result == expected, f"\033[31mTEST FAILED\033[0m fractal_branch() test failed (expected {expected} but got {result})" + + replay_buffer.current_depth = 4 result = replay_buffer.fractal_branch() expected = 81 assert result == expected, f"\033[31mTEST FAILED\033[0m fractal_branch() test failed (expected {expected} but got {result})" - result = replay_buffer.fractal_branch() - expected = 81 - assert result == expected, f"\033[31mTEST FAILED\033[0m fractal_branch() test failed (expected {expected} but got {result})" + + replay_buffer.current_depth = 0 del result, expected - print("\033[32mTEST PASSED \033[0m _branch() tests passed") + 
print("\033[32mTEST PASSED \033[0m fractal_branch() tests passed") # split() tests + ## time + replay_buffer.max_steps = 100 + replay_buffer.max_depth = 4 + + replay_buffer.timestep = 0 + result = replay_buffer.time_split(data_dict) + expected = True + assert result == expected, f"\033[31mTEST FAILED\033[0m split() test failed (expected {expected} but got {result})" + + replay_buffer.timestep = 25 + result = replay_buffer.time_split(data_dict) + expected = True + assert result == expected, f"\033[31mTEST FAILED\033[0m split() test failed (expected {expected} but got {result})" + + replay_buffer.timestep = 50 + result = replay_buffer.time_split(data_dict) + expected = True + assert result == expected, f"\033[31mTEST FAILED\033[0m split() test failed (expected {expected} but got {result})" + + replay_buffer.timestep = 75 + result = replay_buffer.time_split(data_dict) + expected = True + assert result == expected, f"\033[31mTEST FAILED\033[0m split() test failed (expected {expected} but got {result})" + + + replay_buffer.timestep = 100 + result = replay_buffer.time_split(data_dict) + expected = False + assert result == expected, f"\033[31mTEST FAILED\033[0m split() test failed (expected {expected} but got {result})" + - print("\033[32mTEST PASSED \033[0m _split() tests passed") + print("\033[32mTEST PASSED \033[0m time_split() test passed") # insert() tests From a28a959acd38df5758115e00d64b4a5fb387ebf1 Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Thu, 24 Jul 2025 11:59:13 -0500 Subject: [PATCH 052/141] Scaled actions in reach script and changed async_sac_state to immediately use demo data in the replay buffer. - Script was producing actions that had dimensions in the 100s. After clipping, values were always \pm1. This would not allow the neural network to learn well from these numbers. An ACTION_MAX value was set to 10 to produce more realistic actions. TODO: analyze natural action dimensions in SERL and approximate as closely as possible. 
- In the franka_reach_demo_script, updated error_thrshold used to get gripper closer to object. Currently set at 0.01m. - This changes the reward values produced. Updated panda_reach_gym_env step function such that when the rewards is greater than some threshold we return a done that is True. Updated async_sac_state - increase size of actor QueuedDataStore based on the total number of transitions available in the demos. - set training_starts to 0 since we have good demo data to start with. --- demos/demoHandling.py | 27 ++++++----- demos/franka_reach_demo_script.py | 23 +++++++--- .../async_sac_state_sim.py | 46 ++++++++++++++----- .../franka_sim/envs/panda_reach_gym_env.py | 2 +- 4 files changed, 65 insertions(+), 33 deletions(-) diff --git a/demos/demoHandling.py b/demos/demoHandling.py index be984192..90c901c1 100644 --- a/demos/demoHandling.py +++ b/demos/demoHandling.py @@ -20,8 +20,6 @@ class DemoHandling: Parameters ---------- - data_store : QueuedDataStore - The data store to which the demo transitions will be added. demo_dir : str Directory where demo .npz files live by default. file_name : str @@ -29,19 +27,11 @@ class DemoHandling: """ def __init__( self, - data_store: QueuedDataStore, demo_dir: str = '/data/data/serl/demos', file_name: str = 'data_franka_reach_random_20.npz' ): - """ - Parameters - ---------- - data_store : QueuedDataStore - The replay buffer or datastore to insert transitions into. - demo_dir : str - Directory where demo .npz files live by default. - """ - self.data_store = data_store + + self.debug = False # Set to True for debugging purposes self.demo_dir = demo_dir self.transition_ctr = 0 # Global counter for transitions across all episodes @@ -59,7 +49,13 @@ def __init__( # Load the .npz file self.data = np.load(self.demo_npz_path, allow_pickle=True) - def insert_data_to_buffer(self): + def get_num_transitions(self): + """ + Returns the total number of transitions counted in the demo data. 
+ """ + return self.data["transition_ctr"] if "transition_ctr" in self.data else 0 + + def insert_data_to_buffer(self,data_store: QueuedDataStore): """ Load a raw Gymnasium-style .npz of expert episodes into data_store. The .npz file must contain arrays named 'obs', 'acs', 'rewards', @@ -94,7 +90,10 @@ def insert_data_to_buffer(self): done_t = bool(ep_done[t] or ep_terms[t] or ep_trunc[t]) # masks will be created right before insert below - self.data_store.insert( + if self.debug: + print(f"Demo {ep}, Step {t}: Obs={obs_t}, Action={a_t}, Reward={r_t}, Done={done_t}") + + data_store.insert( dict( observations =obs_t, actions =a_t, diff --git a/demos/franka_reach_demo_script.py b/demos/franka_reach_demo_script.py index 6d68bb9d..4404a1d5 100755 --- a/demos/franka_reach_demo_script.py +++ b/demos/franka_reach_demo_script.py @@ -34,10 +34,16 @@ dones = [] # List storing done flags (terminated or truncated) for each episode transition_ctr = 0 # Global counter for transitions across all episodes +#------------------------------------------------------------------------------------------- +## Key Config Variables +#------------------------------------------------------------------------------------------- # Proportional and derivative control gain for action scaling -- empirically tuned Kp = 10.0 # Values between 20 and 24 seem to be somewhat stable for Kv = 24 Kv = 10.0 +ACTION_MAX = 10 # Maximum action value for clipping actions +ERROR_THRESHOLD = 0.01 # Note!! When this number is changed, the way rewards are computed in the PandaReachCubeEnv.step() L220 must also be changed such that done=True only at the end of a successfull run. 
+ # Number of demonstration episodes to generate NUM_DEMOS = 20 @@ -45,14 +51,14 @@ robot = 'franka' # Robot type used in the environment, can be 'franka' or 'fetch' task = 'reach' # Task type used in the environment, can be 'reach' or 'pick-and-place' -DEBUG = False +# Debug mode for rendering and visualization +DEBUG = True if DEBUG: _render_mode = 'human' # Render mode for the environment, can be 'human' or 'rgb_array' else: _render_mode = 'rgb_array' # Use 'rgb_array' for automated testing without GUI - # Indices for franka_sim reach environment observations if robot == 'franka' and task == 'reach': @@ -64,7 +70,7 @@ # Weld constraint flag weld_flag = True # Flag to activate weld constraint during pick-and-place - +#------------------------------------------------------------------------------------------- # Franka sim environments do not have weld constraints like the franka_mujoco environments. # def activate_weld(env, constraint_name="grasp_weld"): # """ @@ -118,7 +124,7 @@ def set_front_cam_view(env): if hasattr(viewer, 'cam'): viewer.cam.lookat[:] = [0, 0, 0.1] # Center of robot (adjust as needed) viewer.cam.distance = 3.0 # Camera distance - viewer.cam.azimuth = 180 # 0 = right, 90 = front, 180 = left + viewer.cam.azimuth = 135 # 0 = right, 90 = front, 180 = left viewer.cam.elevation = -30 # Negative = above, positive = below return viewer @@ -261,7 +267,7 @@ def demo(env, lastObs): fgr_pos = 0 # Error thresholds - error_threshold = 0.03 # Threshold for stopping condition (Xmm) + error_threshold = ERROR_THRESHOLD # Threshold for stopping condition (Xmm) finger_delta_fast = 0.05 # Action delta for fingers 5cm per step (will get clipped by controller)... more of a scalar. 
finger_delta_slow = 0.005 # Franka has a range from 0 to 4cm per finger @@ -299,6 +305,9 @@ def demo(env, lastObs): action[:3] = error * Kp + derror * Kv prev_error = error.copy() # Update previous error for next iteration + # Clip action to prevent excessive movements + action = np.clip(action/ACTION_MAX, -0.1, 0.1) # + # Keep gripper closed -- no need. only 3 dimensions of control #action[ len(action)-1 ] = -finger_delta_fast # Maintain gripper closed. @@ -419,7 +428,9 @@ def main(): terminateds = terminateds, truncateds = truncateds, dones = dones, - transition_ctr = transition_ctr) + transition_ctr = transition_ctr, + num_demos = num_demos + ) print(f"Data saved to {fileName}.") print(f"Total successful demos: {demo_ctr}/{num_demos}") diff --git a/examples/async_sac_state_sim/async_sac_state_sim.py b/examples/async_sac_state_sim/async_sac_state_sim.py index 258f7f0b..d2469825 100644 --- a/examples/async_sac_state_sim/async_sac_state_sim.py +++ b/examples/async_sac_state_sim/async_sac_state_sim.py @@ -79,7 +79,7 @@ def print_green(x): ############################################################################## -def actor(agent: SACAgent, data_store, env, sampling_rng): +def actor(agent: SACAgent, data_store, env, sampling_rng, demos_handler=None): """ This is the actor loop, which runs when "--actor" is set to True. """ @@ -114,16 +114,11 @@ def update_params(params): # Load demos: handler.run will insert all transition demo data into the data store. 
if FLAGS.load_demos: - with timer.context("sample and step into env"): - - # handler loads data - handler = DemoHandling(data_store, - demo_dir=FLAGS.demo_dir, - file_name=FLAGS.demo_dir) - + with timer.context("sample and step into env with loaded demos"): + # Insert complete demonstration into the data store - print(f"Inserting {handler.data["transition_ctr"]} transitions into the data store.") - handler.insert_data_to_buffer() + print(f"Inserting {demos_handler.data["transition_ctr"]} transitions into the data store.") + demos_handler.insert_data_to_buffer(data_store) else: with timer.context("sample_actions"): @@ -303,6 +298,24 @@ def main(_): jax.tree.map(jnp.array, agent), sharding.replicate() ) + # Demo Data + if FLAGS.load_demos: + print_green("loading demo data") + # Create a handler for the demo data + demos_handler = DemoHandling( + demo_dir=FLAGS.demo_dir, + file_name=FLAGS.file_name, + ) + + # 1. Modify actor data_store size + # Extract number of demo transitions + demo_transitions = demos_handler.get_num_transitions() + qds_size = demo_transitions + 1000 # the queue size on the actor + + # 2. 
Modify training starts (since we have good data) + FLAGS.training_starts = 0 + + if FLAGS.learner: sampling_rng = jax.device_put(sampling_rng, device=sharding.replicate()) replay_buffer = make_replay_buffer( @@ -334,11 +347,20 @@ def main(_): elif FLAGS.actor: sampling_rng = jax.device_put(sampling_rng, sharding.replicate()) - data_store = QueuedDataStore(2000) # the queue size on the actor + + if FLAGS.load_demos: + print_green("loading demo data") + + # Create a data store for the actor + data_store = QueuedDataStore(qds_size) # the queue size on the actor + else: + print_green("no demo data, using empty data store") + # Create a data store for the actor + data_store = QueuedDataStore(2000) # the queue size on the actor # actor loop print_green("starting actor loop") - actor(agent, data_store, env, sampling_rng) + actor(agent, data_store, env, sampling_rng,demos_handler) else: raise NotImplementedError("Must be either a learner or an actor") diff --git a/franka_sim/franka_sim/envs/panda_reach_gym_env.py b/franka_sim/franka_sim/envs/panda_reach_gym_env.py index 44ad2667..f9b6baad 100644 --- a/franka_sim/franka_sim/envs/panda_reach_gym_env.py +++ b/franka_sim/franka_sim/envs/panda_reach_gym_env.py @@ -217,7 +217,7 @@ def step( rew = self._compute_reward() # For reach environment, finger - if rew >= 0.55: + if rew >= 0.82: # Demo ERROR_THRESHOLD @ 0.3->0.675; ERROR_THRESHOLD @ 0.2->0.55 terminated = True else: # Check if the time limit is exceeded. 
From c499cd6210db62a36a7c333fa254c3f96afa103b Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Thu, 24 Jul 2025 12:11:08 -0500 Subject: [PATCH 053/141] include demo flags --- examples/async_sac_state_sim/.vscode/launch.json | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/examples/async_sac_state_sim/.vscode/launch.json b/examples/async_sac_state_sim/.vscode/launch.json index 60496b7a..1faedff0 100644 --- a/examples/async_sac_state_sim/.vscode/launch.json +++ b/examples/async_sac_state_sim/.vscode/launch.json @@ -20,6 +20,11 @@ "--critic_actor_ratio", "8", // Critic-to-actor update ratio "--batch_size", "256", // Training batch size "--replay_buffer_capacity", "100000", // Replay buffer capacity + + // Demonstration data loading options + "--load_demos", // Load demo dataset + "--demo_dir", "/data/data/serl/demos/data_franka_reach_random", + "--file_name", "data_franka_reach_random_20.npz" // Name of the demo file to load ], "console": "integratedTerminal", // Use integrated terminal to see output "justMyCode": false, // Allow stepping into libraries From da91b853fed5ee89ef9a009c31bbf7f663cb911b Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Thu, 24 Jul 2025 12:12:31 -0500 Subject: [PATCH 054/141] created more robust and successful bash script from which more advanced testing can be done --- examples/async_sac_state_sim/testing.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/async_sac_state_sim/testing.sh b/examples/async_sac_state_sim/testing.sh index fb2072c9..43c7d512 100755 --- a/examples/async_sac_state_sim/testing.sh +++ b/examples/async_sac_state_sim/testing.sh @@ -37,6 +37,6 @@ send_tmux_command() { # === BASELINE RUNS === # tmux_script_launch.sh must have ./ in front to tell tmux to execute the script in the current directory. 
-for seed in {1}; do # Loop +for seed in {1..2}; do # Loop send_tmux_command "./tmux_script_launch.sh --env $env --name ${env}-baseline --type replay_buffer --seed $seed" done From 1e89d4c7b56e1700fcd1f9871524c972dcf2a7d7 Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Thu, 24 Jul 2025 14:16:07 -0500 Subject: [PATCH 055/141] hide viewer --- demos/franka_reach_demo_script.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/demos/franka_reach_demo_script.py b/demos/franka_reach_demo_script.py index 4404a1d5..4336752c 100755 --- a/demos/franka_reach_demo_script.py +++ b/demos/franka_reach_demo_script.py @@ -126,6 +126,9 @@ def set_front_cam_view(env): viewer.cam.distance = 3.0 # Camera distance viewer.cam.azimuth = 135 # 0 = right, 90 = front, 180 = left viewer.cam.elevation = -30 # Negative = above, positive = below + + # Hide menu + viewer._hide_overlay = True return viewer From 863ada5174791770b6239184d424e1c36aea4483 Mon Sep 17 00:00:00 2001 From: ryanvanderstelt Date: Thu, 24 Jul 2025 14:40:05 -0500 Subject: [PATCH 056/141] progress on automated testing --- examples/async_sac_state_sim/run_tests.py | 85 +++++++++++++++++++++++ 1 file changed, 85 insertions(+) create mode 100644 examples/async_sac_state_sim/run_tests.py diff --git a/examples/async_sac_state_sim/run_tests.py b/examples/async_sac_state_sim/run_tests.py new file mode 100644 index 00000000..46a73723 --- /dev/null +++ b/examples/async_sac_state_sim/run_tests.py @@ -0,0 +1,85 @@ + +import subprocess +from absl import app + +def actor(kwargs:dict): + script = "python async_sac_state_sim.py --actor" + for k, v in kwargs.items(): + script += f" --{k}" + script += f" {v}" + return script + +def learner(kwargs:dict): + script = "python async_sac_state_sim.py --learner" + for k, v in kwargs.items(): + script += f" --{k}" + script += f" {v}" + return script + +def main(_): + env = "PandaReachCube-v0" + random_steps = 1000 + max_steps_learner = 500 + max_steps_actor = max_steps_learner * 2 + training_starts 
= 1000 + critic_actor_ratio = 8 + batch_size = 256 + replay_buffer_capacity = 1000000 + save_model = True + + rb_args = { + "env": env, + "random_steps": random_steps, + "max_steps": max_steps_actor, + "training_starts": training_starts, + "critic_actor_ratio": critic_actor_ratio, + "batch_size": batch_size, + "replay_buffer_capacity": replay_buffer_capacity, + "save_model": save_model, + } + + subprocess.Popen(['tmux', 'new-session', '-s', '-d', 'serl_session'], shell=True) + + # baseline + rb_args["replay_buffer_type"] = "replay_buffer" + rb_args["exp_name"] = f"{env}-baseline" + for seed in range(0, 5): + rb_args["seed"] = seed + rb_args["max_steps"] = max_steps_actor + actor_process = subprocess.Popen(['tmux', 'send-keys', '-t', 'serl_session:0.0', actor(kwargs=rb_args), 'C-m'], shell=True) + + + actor_process.wait() + subprocess.run("tmux split-window -v", shell=True, check=True) + rb_args["max_steps"] = max_steps_learner + learner_process = subprocess.Popen(['tmux', 'send-keys', '-t', 'serl_session:0.1', learner(kwargs=rb_args), 'C-m'], shell=True) + subprocess.run("tmux attach-session -t serl_session", shell=True) + + learner_process.wait() + actor_process.terminate() + actor_process.wait() + + + # constant + rb_args["replay_buffer_type"] = "fractal_symmetry_replay_buffer" + rb_args["branch_method"] = "constant" + rb_args["split_method"] = "constant" + for b in [1, 3, 9, 27]: + rb_args["starting_branch_count"] = b + rb_args["exp_name"] = f"{env}-{b}x{b}" + for seed in range(0, 5): + rb_args["seed"] = seed + rb_args["max_steps"] = max_steps_actor + actor_process = subprocess.Popen(["tmux", "send-keys", "-t", "serl_session:0.0", actor(kwargs=rb_args)]) + + rb_args["max_steps"] = max_steps_learner + learner_process = subprocess.Popen(learner(kwargs=rb_args), shell=True) + + learner_process.wait() + actor_process.terminate() + actor_process.wait() + + +if __name__ == "__main__": + app.run(main) + \ No newline at end of file From 
fe292d97213dd8238d4c45d36bd6bc66a300d9e2 Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Thu, 24 Jul 2025 16:19:46 -0500 Subject: [PATCH 057/141] Modified suprocess script using functions and logging --- .../async_sac_state_sim/run_tests_funcs.py | 117 ++++++++++++++++++ 1 file changed, 117 insertions(+) create mode 100644 examples/async_sac_state_sim/run_tests_funcs.py diff --git a/examples/async_sac_state_sim/run_tests_funcs.py b/examples/async_sac_state_sim/run_tests_funcs.py new file mode 100644 index 00000000..5d814261 --- /dev/null +++ b/examples/async_sac_state_sim/run_tests_funcs.py @@ -0,0 +1,117 @@ +import subprocess +import os +from absl import app + +LOG_DIR = "serl_logs" +os.makedirs(LOG_DIR, exist_ok=True) + +def actor_cmd(kwargs: dict) -> str: + cmd = "python async_sac_state_sim.py --actor" + for k, v in kwargs.items(): + cmd += f" --{k} {v}" + return cmd + +def learner_cmd(kwargs: dict) -> str: + cmd = "python async_sac_state_sim.py --learner" + for k, v in kwargs.items(): + cmd += f" --{k} {v}" + return cmd + +def run_tmux_command(cmd: str): + return subprocess.run(cmd, shell=True, check=True) + +def send_tmux_command(target: str, command: str, env: str = "serl", workdir: str = "./"): + + full_command = ( + f"conda activate {env} && " + f"cd {workdir} && " + f"{command}" + ) + + subprocess.Popen( + f'tmux send-keys -t {target} "{full_command}" C-m', + shell=True + # stdout=subprocess.DEVNULL, + # stderr=subprocess.DEVNULL + ) + +def ensure_tmux_session(session_name: str): + result = subprocess.run( + ["tmux", "has-session", "-t", session_name] + # stdout=subprocess.DEVNULL, + # stderr=subprocess.DEVNULL + ) + if result.returncode != 0: + subprocess.run(["tmux", "new-session", "-d", "-s", session_name, "-n", "main"], check=True) + subprocess.run(["tmux", "split-window", "-v", "-t", f"{session_name}:0"], check=True) + +def main(_): + session = "serl_session" + ensure_tmux_session(session) + + env = "PandaReachCube-v0" + random_steps = 1000 + 
max_steps_learner = 10000 + max_steps_actor = max_steps_learner * 2 + training_starts = 1000 + critic_actor_ratio = 8 + batch_size = 256 + replay_buffer_capacity = 1000000 + save_model = True + + base_args = { + "env": env, + "random_steps": random_steps, + "training_starts": training_starts, + "critic_actor_ratio": critic_actor_ratio, + "batch_size": batch_size, + "replay_buffer_capacity": replay_buffer_capacity, + "save_model": save_model, + } + + # --- Baseline --- + print("Running baseline...") + args = base_args.copy() + args["replay_buffer_type"] = "replay_buffer" + args["exp_name"] = f"{env}-baseline" + + for seed in range(1): + args["seed"] = seed + args["max_steps"] = max_steps_actor + actor_log = os.path.join(LOG_DIR, f"actor_seed_{seed}.log") + learner_log = os.path.join(LOG_DIR, f"learner_seed_{seed}.log") + + send_tmux_command(f"{session}:0.0", f"{actor_cmd(args)}")# > {actor_log} 2>&1") # Uncomment to send to log files + args["max_steps"] = max_steps_learner + send_tmux_command(f"{session}:0.1", f"{learner_cmd(args)}")# > {learner_log} 2>&1") + + # Optional: Wait before next round + print(f"Launched baseline with seed={seed}") + + # --- Fractal Symmetry Replay Buffer --- + # print("Running fractal symmetry buffer variations...") + # args = base_args.copy() + # args["replay_buffer_type"] = "fractal_symmetry_replay_buffer" + # args["branch_method"] = "constant" + # args["split_method"] = "constant" + + # for b in [1, 3, 9, 27]: + # args["starting_branch_count"] = b + # args["exp_name"] = f"{env}-{b}x{b}" + # for seed in range(5): + # args["seed"] = seed + # args["max_steps"] = max_steps_actor + # actor_log = os.path.join(LOG_DIR, f"actor_b{b}_s{seed}.log") + # learner_log = os.path.join(LOG_DIR, f"learner_b{b}_s{seed}.log") + + # send_tmux_command(f"{session}:0.0", f"{actor_cmd(args)} > {actor_log} 2>&1") + # args["max_steps"] = max_steps_learner + # send_tmux_command(f"{session}:0.1", f"{learner_cmd(args)} > {learner_log} 2>&1") + + # print(f"Launched 
branch={b}, seed={seed}") + + # Optional: Attach manually to see progress + # run_tmux_command(f"tmux attach-session -t {session}") + +if __name__ == "__main__": + app.run(main) From f97c2917fec0f3ab4e81bb4deb4e74266ad287f2 Mon Sep 17 00:00:00 2001 From: ryanvanderstelt Date: Thu, 24 Jul 2025 16:31:17 -0500 Subject: [PATCH 058/141] update for run_tests.py --- examples/async_sac_state_sim/run_tests.py | 52 +++++++++++------------ 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/examples/async_sac_state_sim/run_tests.py b/examples/async_sac_state_sim/run_tests.py index 46a73723..dfdd7773 100644 --- a/examples/async_sac_state_sim/run_tests.py +++ b/examples/async_sac_state_sim/run_tests.py @@ -1,16 +1,17 @@ import subprocess from absl import app +import time def actor(kwargs:dict): - script = "python async_sac_state_sim.py --actor" + script = "python3 /home/ryan_vanderstelt/serl/examples/async_sac_state_sim/async_sac_state_sim.py --actor" for k, v in kwargs.items(): script += f" --{k}" script += f" {v}" return script def learner(kwargs:dict): - script = "python async_sac_state_sim.py --learner" + script = "python3 /home/ryan_vanderstelt/serl/examples/async_sac_state_sim/async_sac_state_sim.py --learner" for k, v in kwargs.items(): script += f" --{k}" script += f" {v}" @@ -19,7 +20,7 @@ def learner(kwargs:dict): def main(_): env = "PandaReachCube-v0" random_steps = 1000 - max_steps_learner = 500 + max_steps_learner = 10000 max_steps_actor = max_steps_learner * 2 training_starts = 1000 critic_actor_ratio = 8 @@ -38,7 +39,9 @@ def main(_): "save_model": save_model, } - subprocess.Popen(['tmux', 'new-session', '-s', '-d', 'serl_session'], shell=True) + subprocess.Popen(['tmux', 'new-session', '-d', '-s', 'serl_session'], shell=True) + time.sleep(10) + subprocess.run("tmux split-window -v", shell=True) # baseline rb_args["replay_buffer_type"] = "replay_buffer" @@ -46,38 +49,35 @@ def main(_): for seed in range(0, 5): rb_args["seed"] = seed 
rb_args["max_steps"] = max_steps_actor - actor_process = subprocess.Popen(['tmux', 'send-keys', '-t', 'serl_session:0.0', actor(kwargs=rb_args), 'C-m'], shell=True) + actor_process = subprocess.Popen(f"tmux send-keys -t serl_session:0.0 \"conda activate serl && {actor(kwargs=rb_args)}\" C-m", shell=True) - actor_process.wait() - subprocess.run("tmux split-window -v", shell=True, check=True) rb_args["max_steps"] = max_steps_learner - learner_process = subprocess.Popen(['tmux', 'send-keys', '-t', 'serl_session:0.1', learner(kwargs=rb_args), 'C-m'], shell=True) + learner_process = subprocess.Popen(f"tmux send-keys -t serl_session:0.1 \"conda activate serl && {learner(kwargs=rb_args)}\" C-m", shell=True) subprocess.run("tmux attach-session -t serl_session", shell=True) learner_process.wait() actor_process.terminate() actor_process.wait() - # constant - rb_args["replay_buffer_type"] = "fractal_symmetry_replay_buffer" - rb_args["branch_method"] = "constant" - rb_args["split_method"] = "constant" - for b in [1, 3, 9, 27]: - rb_args["starting_branch_count"] = b - rb_args["exp_name"] = f"{env}-{b}x{b}" - for seed in range(0, 5): - rb_args["seed"] = seed - rb_args["max_steps"] = max_steps_actor - actor_process = subprocess.Popen(["tmux", "send-keys", "-t", "serl_session:0.0", actor(kwargs=rb_args)]) - - rb_args["max_steps"] = max_steps_learner - learner_process = subprocess.Popen(learner(kwargs=rb_args), shell=True) - - learner_process.wait() - actor_process.terminate() - actor_process.wait() + # rb_args["replay_buffer_type"] = "fractal_symmetry_replay_buffer" + # rb_args["branch_method"] = "constant" + # rb_args["split_method"] = "constant" + # for b in [1, 3, 9, 27]: + # rb_args["starting_branch_count"] = b + # rb_args["exp_name"] = f"{env}-{b}x{b}" + # for seed in range(0, 5): + # rb_args["seed"] = seed + # rb_args["max_steps"] = max_steps_actor + # actor_process = subprocess.Popen(["tmux", "send-keys", "-t", "serl_session:0.0", actor(kwargs=rb_args)]) + + # 
rb_args["max_steps"] = max_steps_learner + # learner_process = subprocess.Popen(learner(kwargs=rb_args), shell=True) + + # learner_process.wait() + # actor_process.terminate() + # actor_process.wait() if __name__ == "__main__": From 159051dad54c74af8c666829a9d58dee8d33f857 Mon Sep 17 00:00:00 2001 From: ryanvanderstelt Date: Thu, 24 Jul 2025 16:44:49 -0500 Subject: [PATCH 059/141] reverted run scripts --- examples/async_sac_state_sim/run_actor.sh | 108 ++------------------ examples/async_sac_state_sim/run_learner.sh | 105 ++----------------- examples/async_sac_state_sim/tmux_launch.sh | 96 +---------------- 3 files changed, 17 insertions(+), 292 deletions(-) diff --git a/examples/async_sac_state_sim/run_actor.sh b/examples/async_sac_state_sim/run_actor.sh index e97cca8c..af173dba 100644 --- a/examples/async_sac_state_sim/run_actor.sh +++ b/examples/async_sac_state_sim/run_actor.sh @@ -1,97 +1,5 @@ #!/bin/bash -# Default variable values -seed=0 -env="PandaReachCube-v0" -branch="test" -type="fractal_symmetry_replay_buffer" -exp_name="untitled_experiment" -starting_branch_count=1 - -has_argument() { - [[ ("$1" == *=* && -n ${1#*=}) || ( ! -z "$2" && "$2" != -*) ]]; -} - -extract_argument() { - echo "${2:-${1#*=}}" -} - -# Function to handle options and arguments -handle_options() { - while [ $# -gt 0 ]; do - case $1 in - -e | --env*) - if ! has_argument $@; then - echo "Environment not specified" >&2 - exit 1 - fi - - env=$(extract_argument $@) - - shift - ;; - -s | --seed*) - if ! has_argument $@; then - echo "Seed needs a positive integer value" >&2 - exit 1 - fi - - seed=$(extract_argument $@) - - shift - ;; - -t | --type*) - if ! has_argument $@; then - echo "Type needs a string" >&2 - exit 1 - fi - - type=$(extract_argument $@) - - shift - ;; - -n | --name*) - if ! has_argument $@; then - echo "Name needs a string" >&2 - exit 1 - fi - - exp_name=$(extract_argument $@) - - shift - ;; - --starting_branch*) - if ! 
has_argument $@; then - echo "Starting_branch needs a positive odd integer" >&2 - exit 1 - fi - - starting_branch_count=$(extract_argument $@) - - shift - ;; - -b | --branch*) - if ! has_argument $@; then - echo "Branch needs a string" >&2 - exit 1 - fi - - branch=$(extract_argument $@) - - shift - ;; - *) - echo "Invalid option: $1" >&2 - exit 1 - ;; - esac - shift - done -} - -# Main script execution -handle_options "$@" - export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ export XLA_PYTHON_CLIENT_MEM_FRACTION=.5 && \ export SCRIPT_DIR=$(dirname "$(realpath "$0")") && \ @@ -109,14 +17,14 @@ fi python async_sac_state_sim.py \ --actor \ - --env $env \ - --exp_name=$exp_name \ - --seed $seed \ - --replay_buffer_type $type \ - --branch_method $branch \ - --starting_branch_count $starting_branch_count \ - --random_steps 1000 \ - --max_steps 100000 \ + --env PandaReachCube-v0 \ + --exp_name serl-reach \ + --seed 0 \ + --replay_buffer_type fractal_symmetry_replay_buffer \ + --branch_method test \ + --split_method test \ + --starting_branch_count 1 \ + --max_steps 1000 \ --training_starts 1000 \ --critic_actor_ratio 8 \ --batch_size 256 \ diff --git a/examples/async_sac_state_sim/run_learner.sh b/examples/async_sac_state_sim/run_learner.sh index 09809593..ee9d39a9 100644 --- a/examples/async_sac_state_sim/run_learner.sh +++ b/examples/async_sac_state_sim/run_learner.sh @@ -1,97 +1,5 @@ #!/bin/bash -# Default variable values -seed=0 -env="PandaReachCube-v0" -branch="test" -type="fractal_symmetry_replay_buffer" -exp_name="untitled_experiment" -starting_branch_count=1 - -has_argument() { - [[ ("$1" == *=* && -n ${1#*=}) || ( ! -z "$2" && "$2" != -*) ]]; -} - -extract_argument() { - echo "${2:-${1#*=}}" -} - -# Function to handle options and arguments -handle_options() { - while [ $# -gt 0 ]; do - case $1 in - -e | --env*) - if ! has_argument $@; then - echo "Environment not specified" >&2 - exit 1 - fi - - env=$(extract_argument $@) - - shift - ;; - -s | --seed*) - if ! 
has_argument $@; then - echo "Seed needs a positive integer value" >&2 - exit 1 - fi - - seed=$(extract_argument $@) - - shift - ;; - -t | --type*) - if ! has_argument $@; then - echo "Type needs a string" >&2 - exit 1 - fi - - type=$(extract_argument $@) - - shift - ;; - -n | --name*) - if ! has_argument $@; then - echo "Name needs a string" >&2 - exit 1 - fi - - exp_name=$(extract_argument $@) - - shift - ;; - --starting_branch*) - if ! has_argument $@; then - echo "Starting_branch needs a positive odd integer" >&2 - exit 1 - fi - - starting_branch_count=$(extract_argument $@) - - shift - ;; - -b | --branch*) - if ! has_argument $@; then - echo "Branch needs a string" >&2 - exit 1 - fi - - branch=$(extract_argument $@) - - shift - ;; - *) - echo "Invalid option: $1" >&2 - exit 1 - ;; - esac - shift - done -} - -# Main script execution -handle_options "$@" - export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ export XLA_PYTHON_CLIENT_MEM_FRACTION=.5 && \ export SCRIPT_DIR=$(dirname "$(realpath "$0")") && \ @@ -109,12 +17,13 @@ fi python async_sac_state_sim.py \ --learner \ - --env $env \ - --exp_name=$exp_name \ - --seed $seed \ - --replay_buffer_type $type \ - --branch_method $branch \ - --starting_branch_count $starting_branch_count \ + --env PandaReachCube-v0 \ + --exp_name serl-reach \ + --seed 0 \ + --replay_buffer_type fractal_symmetry_replay_buffer \ + --branch_method test \ + --split_method test \ + --starting_branch_count 1 \ --max_steps 1000 \ --training_starts 1000 \ --critic_actor_ratio 8 \ diff --git a/examples/async_sac_state_sim/tmux_launch.sh b/examples/async_sac_state_sim/tmux_launch.sh index 6145f712..6064bd29 100644 --- a/examples/async_sac_state_sim/tmux_launch.sh +++ b/examples/async_sac_state_sim/tmux_launch.sh @@ -1,97 +1,5 @@ #!/bin/bash -# Default variable values -seed=0 -env="PandaReachCube-v0" -branch="test" -type="fractal_symmetry_replay_buffer" -name="untitled_experiment" -starting_branch=1 - -has_argument() { - [[ ("$1" == *=* && -n 
${1#*=}) || ( ! -z "$2" && "$2" != -*) ]]; -} - -extract_argument() { - echo "${2:-${1#*=}}" -} - -# Function to handle options and arguments -handle_options() { - while [ $# -gt 0 ]; do - case $1 in - -e | --env*) - if ! has_argument $@; then - echo "Environment not specified" >&2 - exit 1 - fi - - env=$(extract_argument $@) - - shift - ;; - -s | --seed*) - if ! has_argument $@; then - echo "Seed needs a positive integer value" >&2 - exit 1 - fi - - seed=$(extract_argument $@) - - shift - ;; - -t | --type*) - if ! has_argument $@; then - echo "Type needs a string" >&2 - exit 1 - fi - - type=$(extract_argument $@) - - shift - ;; - -n | --name*) - if ! has_argument $@; then - echo "Name needs a string" >&2 - exit 1 - fi - - name=$(extract_argument $@) - - shift - ;; - --starting_branch*) - if ! has_argument $@; then - echo "Starting_branch needs a positive odd integer" >&2 - exit 1 - fi - - starting_branch=$(extract_argument $@) - - shift - ;; - -b | --branch*) - if ! has_argument $@; then - echo "Branch needs a string" >&2 - exit 1 - fi - - branch=$(extract_argument $@) - - shift - ;; - *) - echo "Invalid option: $1" >&2 - exit 1 - ;; - esac - shift - done -} - -# Main script execution -handle_options "$@" - # EXAMPLE_DIR=${EXAMPLE_DIR:-"examples/async_sac_state_sim"} CONDA_ENV=${CONDA_ENV:-"serl"} @@ -105,10 +13,10 @@ tmux new-session -d -s serl_session tmux split-window -v # Navigate to the activate the conda environment in the first pane -tmux send-keys -t serl_session:0.0 "conda activate $CONDA_ENV && bash run_actor.sh --seed $seed --env $env --name $name --type $type --branch $branch --starting_branch $starting_branch" C-m +tmux send-keys -t serl_session:0.0 "conda activate $CONDA_ENV && bash run_actor.sh" C-m # Navigate to the activate the conda environment in the second pane -tmux send-keys -t serl_session:0.1 "conda activate $CONDA_ENV && bash run_learner.sh --seed $seed --env $env --name $name --type $type --branch $branch --starting_branch 
$starting_branch && tmux kill-session -t serl_session:0.0" C-m +tmux send-keys -t serl_session:0.1 "conda activate $CONDA_ENV && bash run_learner.sh" C-m # Attach to the tmux session tmux attach-session -t serl_session From ec55fcb511e4d313110a5604b6f0fd612a77b54d Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Thu, 24 Jul 2025 16:59:58 -0500 Subject: [PATCH 060/141] test if relying on child actor/learner .sh files for conda might solve VRAM memory problems we are facing --- examples/async_sac_state_sim/run_tests_funcs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/async_sac_state_sim/run_tests_funcs.py b/examples/async_sac_state_sim/run_tests_funcs.py index 5d814261..2cc251fa 100644 --- a/examples/async_sac_state_sim/run_tests_funcs.py +++ b/examples/async_sac_state_sim/run_tests_funcs.py @@ -81,9 +81,9 @@ def main(_): actor_log = os.path.join(LOG_DIR, f"actor_seed_{seed}.log") learner_log = os.path.join(LOG_DIR, f"learner_seed_{seed}.log") - send_tmux_command(f"{session}:0.0", f"{actor_cmd(args)}")# > {actor_log} 2>&1") # Uncomment to send to log files + send_tmux_command(f"{session}:0.0", f"./run_actor.sh {' '.join(f'--{k} {v}' for k, v in args.items())}") #conda/jax/tmux mem conflict? 
, f"{actor_cmd(args)}")# > {actor_log} 2>&1") # Uncomment to send to log files args["max_steps"] = max_steps_learner - send_tmux_command(f"{session}:0.1", f"{learner_cmd(args)}")# > {learner_log} 2>&1") + send_tmux_command(f"{session}:0.1", f"./run_learner.sh {' '.join(f'--{k} {v}' for k, v in args.items())}")#mem conflict?, f"{learner_cmd(args)}")# > {learner_log} 2>&1") # Optional: Wait before next round print(f"Launched baseline with seed={seed}") From 53e0612ecead98a7271efd9e36a52a7f43cbf731 Mon Sep 17 00:00:00 2001 From: ryanvanderstelt Date: Thu, 24 Jul 2025 18:01:57 -0500 Subject: [PATCH 061/141] commit for consistency --- examples/async_sac_state_sim/run_actor.sh | 6 ++--- examples/async_sac_state_sim/run_learner.sh | 6 ++--- examples/async_sac_state_sim/run_tests.py | 8 +++--- examples/async_sac_state_sim/run_tests.sh | 30 +++------------------ 4 files changed, 14 insertions(+), 36 deletions(-) diff --git a/examples/async_sac_state_sim/run_actor.sh b/examples/async_sac_state_sim/run_actor.sh index af173dba..237a8022 100644 --- a/examples/async_sac_state_sim/run_actor.sh +++ b/examples/async_sac_state_sim/run_actor.sh @@ -20,15 +20,15 @@ python async_sac_state_sim.py \ --env PandaReachCube-v0 \ --exp_name serl-reach \ --seed 0 \ - --replay_buffer_type fractal_symmetry_replay_buffer \ + --replay_buffer_type replay_buffer \ --branch_method test \ --split_method test \ --starting_branch_count 1 \ - --max_steps 1000 \ + --max_steps 40000 \ --training_starts 1000 \ --critic_actor_ratio 8 \ --batch_size 256 \ - --replay_buffer_capacity 1000000 \ + --replay_buffer_capacity 100000 \ --save_model True \ # --checkpoint_period 10000 \ # --checkpoint_path "$CHECKPOINT_DIR" \ diff --git a/examples/async_sac_state_sim/run_learner.sh b/examples/async_sac_state_sim/run_learner.sh index ee9d39a9..a4a2210e 100644 --- a/examples/async_sac_state_sim/run_learner.sh +++ b/examples/async_sac_state_sim/run_learner.sh @@ -20,15 +20,15 @@ python async_sac_state_sim.py \ --env 
PandaReachCube-v0 \ --exp_name serl-reach \ --seed 0 \ - --replay_buffer_type fractal_symmetry_replay_buffer \ + --replay_buffer_type replay_buffer \ --branch_method test \ --split_method test \ --starting_branch_count 1 \ - --max_steps 1000 \ + --max_steps 20000 \ --training_starts 1000 \ --critic_actor_ratio 8 \ --batch_size 256 \ - --replay_buffer_capacity 1000000 \ + --replay_buffer_capacity 100000 \ --save_model True \ # --checkpoint_period 10000 \ # --checkpoint_path "$CHECKPOINT_DIR" \ diff --git a/examples/async_sac_state_sim/run_tests.py b/examples/async_sac_state_sim/run_tests.py index dfdd7773..4167eebc 100644 --- a/examples/async_sac_state_sim/run_tests.py +++ b/examples/async_sac_state_sim/run_tests.py @@ -39,8 +39,8 @@ def main(_): "save_model": save_model, } - subprocess.Popen(['tmux', 'new-session', '-d', '-s', 'serl_session'], shell=True) - time.sleep(10) + subprocess.run(['tmux', 'new-session', '-d', '-s', 'serl_session']) + # time.sleep(10) subprocess.run("tmux split-window -v", shell=True) # baseline @@ -49,11 +49,11 @@ def main(_): for seed in range(0, 5): rb_args["seed"] = seed rb_args["max_steps"] = max_steps_actor - actor_process = subprocess.Popen(f"tmux send-keys -t serl_session:0.0 \"conda activate serl && {actor(kwargs=rb_args)}\" C-m", shell=True) + actor_process = subprocess.Popen(f"tmux send-keys -t serl_session:0.0 \"conda run -vvv -n serl {actor(kwargs=rb_args)}\" C-m", shell=True) rb_args["max_steps"] = max_steps_learner - learner_process = subprocess.Popen(f"tmux send-keys -t serl_session:0.1 \"conda activate serl && {learner(kwargs=rb_args)}\" C-m", shell=True) + learner_process = subprocess.Popen(f"tmux send-keys -t serl_session:0.1 \"conda run -vvv -n serl '{learner(kwargs=rb_args)}'\" C-m", shell=True) subprocess.run("tmux attach-session -t serl_session", shell=True) learner_process.wait() diff --git a/examples/async_sac_state_sim/run_tests.sh b/examples/async_sac_state_sim/run_tests.sh index 9db3bb6d..0bd80620 100644 --- 
a/examples/async_sac_state_sim/run_tests.sh +++ b/examples/async_sac_state_sim/run_tests.sh @@ -1,30 +1,8 @@ #!/bin/bash -env="PandaReachCube-v0" +tmux new-session -d -s serl_session "python run_tests.py" -trap "kill 0" EXIT INT TERM +tmux attach-session -t serl_session -# Baseline - -cmd="" -for i in $(seq 1 5); do - cmd+="bash -c 'tmux_launch.sh --env $env --name \'$env-baseline\' --type \'replay_buffer\' --seed $i'; " -done - - - -branch=("constant" "linear" "fractal" "OIO" "IOI" "inv_linear" "inv_fractal") -for b in "${branch[@]}"; do - # constant branching - if [ $b == "constant" ]; then - branch_factor=(1 3 9 27) - for bf in "${branch_factor[@]}"; do - b_num=$(( $bf*$bf )) - for seed in $(seq 1 5); do - cmd+="bash -c 'tmux_launch.sh --env $env --name \'$env-$branch-$b_num\' --type \'fractal_symmetry_replay_buffer\' --branch $b --starting_branch $bf --seed $seed'; " - done - done - fi -done -tmux new-session -d -s testing_session -tmux send-keys -t testing_session:0.0 $cmd C-m +# kill the tmux session by running the following command +# tmux kill-session -t serl_session From 8e281734d54fc2d0781778c4c2db15c686f01371 Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Thu, 24 Jul 2025 18:02:59 -0500 Subject: [PATCH 062/141] Working automated script --- .../async_sac_state_sim/automate_actor.sh | 6 + .../async_sac_state_sim/automate_funcs.py | 117 ++++++++++++++++++ .../async_sac_state_sim/automate_learner.sh | 6 + 3 files changed, 129 insertions(+) create mode 100644 examples/async_sac_state_sim/automate_actor.sh create mode 100644 examples/async_sac_state_sim/automate_funcs.py create mode 100644 examples/async_sac_state_sim/automate_learner.sh diff --git a/examples/async_sac_state_sim/automate_actor.sh b/examples/async_sac_state_sim/automate_actor.sh new file mode 100644 index 00000000..d0ade5fc --- /dev/null +++ b/examples/async_sac_state_sim/automate_actor.sh @@ -0,0 +1,6 @@ +#!/bin/bash + +export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ +export 
XLA_PYTHON_CLIENT_MEM_FRACTION=.5 && \ + +python async_sac_state_sim.py --actor "$@" \ No newline at end of file diff --git a/examples/async_sac_state_sim/automate_funcs.py b/examples/async_sac_state_sim/automate_funcs.py new file mode 100644 index 00000000..ecc1a11f --- /dev/null +++ b/examples/async_sac_state_sim/automate_funcs.py @@ -0,0 +1,117 @@ +import subprocess +import os +from absl import app + +LOG_DIR = "serl_logs" +os.makedirs(LOG_DIR, exist_ok=True) + +def actor_cmd(kwargs: dict) -> str: + cmd = "python async_sac_state_sim.py --actor" + for k, v in kwargs.items(): + cmd += f" --{k} {v}" + return cmd + +def learner_cmd(kwargs: dict) -> str: + cmd = "python async_sac_state_sim.py --learner" + for k, v in kwargs.items(): + cmd += f" --{k} {v}" + return cmd + +def run_tmux_command(cmd: str): + return subprocess.run(cmd, shell=True, check=True) + +def send_tmux_command(target: str, command: str, env: str = "serl", workdir: str = "./"): + + full_command = ( + f"conda activate {env} && " + f"cd {workdir} && " + f"{command}" + ) + + subprocess.Popen( + f'tmux send-keys -t {target} "{full_command}" C-m', + shell=True + # stdout=subprocess.DEVNULL, + # stderr=subprocess.DEVNULL + ) + +def ensure_tmux_session(session_name: str): + result = subprocess.run( + ["tmux", "has-session", "-t", session_name] + # stdout=subprocess.DEVNULL, + # stderr=subprocess.DEVNULL + ) + if result.returncode != 0: + subprocess.run(["tmux", "new-session", "-d", "-s", session_name, "-n", "main"], check=True) + subprocess.run(["tmux", "split-window", "-v", "-t", f"{session_name}:0"], check=True) + +def main(_): + session = "serl_session" + ensure_tmux_session(session) + + env = "PandaReachCube-v0" + random_steps = 500 + max_steps = 30000 + training_starts = 500 + critic_actor_ratio = 8 + batch_size = 256 + replay_buffer_capacity = 100000 + save_model = True + + base_args = { + "env": env, + "random_steps": random_steps, + "training_starts": training_starts, + "critic_actor_ratio": 
critic_actor_ratio, + "batch_size": batch_size, + "replay_buffer_capacity": replay_buffer_capacity, + "save_model": save_model, + "max_steps": max_steps + } + + # --- Baseline --- + print("Running baseline...") + args = base_args.copy() + args["replay_buffer_type"] = "replay_buffer" + args["exp_name"] = f"{env}-baseline" + + for seed in range(1): + args["seed"] = seed + # args["max_steps"] = max_steps_actor + actor_log = os.path.join(LOG_DIR, f"actor_seed_{seed}.log") + learner_log = os.path.join(LOG_DIR, f"learner_seed_{seed}.log") + + send_tmux_command(f"{session}:0.0", f"./run_actor.sh {' '.join(f'--{k} {v}' for k, v in args.items())}") #conda/jax/tmux mem conflict? , f"{actor_cmd(args)}")# > {actor_log} 2>&1") # Uncomment to send to log files + # args["max_steps"] = max_steps_learner + send_tmux_command(f"{session}:0.1", f"./run_learner.sh {' '.join(f'--{k} {v}' for k, v in args.items())}")#mem conflict?, f"{learner_cmd(args)}")# > {learner_log} 2>&1") + + # Optional: Wait before next round + print(f"Launched baseline with seed={seed}") + + # --- Fractal Symmetry Replay Buffer --- + # print("Running fractal symmetry buffer variations...") + # args = base_args.copy() + # args["replay_buffer_type"] = "fractal_symmetry_replay_buffer" + # args["branch_method"] = "constant" + # args["split_method"] = "constant" + + # for b in [1, 3, 9, 27]: + # args["starting_branch_count"] = b + # args["exp_name"] = f"{env}-{b}x{b}" + # for seed in range(5): + # args["seed"] = seed + # args["max_steps"] = max_steps_actor + # actor_log = os.path.join(LOG_DIR, f"actor_b{b}_s{seed}.log") + # learner_log = os.path.join(LOG_DIR, f"learner_b{b}_s{seed}.log") + + # send_tmux_command(f"{session}:0.0", f"{actor_cmd(args)} > {actor_log} 2>&1") + # args["max_steps"] = max_steps_learner + # send_tmux_command(f"{session}:0.1", f"{learner_cmd(args)} > {learner_log} 2>&1") + + # print(f"Launched branch={b}, seed={seed}") + + # Optional: Attach manually to see progress + # run_tmux_command(f"tmux 
attach-session -t {session}") + +if __name__ == "__main__": + app.run(main) diff --git a/examples/async_sac_state_sim/automate_learner.sh b/examples/async_sac_state_sim/automate_learner.sh new file mode 100644 index 00000000..0e0b4be4 --- /dev/null +++ b/examples/async_sac_state_sim/automate_learner.sh @@ -0,0 +1,6 @@ +#!/bin/bash + +export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ +export XLA_PYTHON_CLIENT_MEM_FRACTION=.5 && \ + +python async_sac_state_sim.py --learner "$@" \ No newline at end of file From b8674e9fe9fefd7fe5f77cfd0e2daec35246cf35 Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Thu, 24 Jul 2025 18:06:57 -0500 Subject: [PATCH 063/141] changed for automate_funds.py --- .../async_sac_state_sim/run_tests_funcs.py | 117 ------------------ 1 file changed, 117 deletions(-) delete mode 100644 examples/async_sac_state_sim/run_tests_funcs.py diff --git a/examples/async_sac_state_sim/run_tests_funcs.py b/examples/async_sac_state_sim/run_tests_funcs.py deleted file mode 100644 index 2cc251fa..00000000 --- a/examples/async_sac_state_sim/run_tests_funcs.py +++ /dev/null @@ -1,117 +0,0 @@ -import subprocess -import os -from absl import app - -LOG_DIR = "serl_logs" -os.makedirs(LOG_DIR, exist_ok=True) - -def actor_cmd(kwargs: dict) -> str: - cmd = "python async_sac_state_sim.py --actor" - for k, v in kwargs.items(): - cmd += f" --{k} {v}" - return cmd - -def learner_cmd(kwargs: dict) -> str: - cmd = "python async_sac_state_sim.py --learner" - for k, v in kwargs.items(): - cmd += f" --{k} {v}" - return cmd - -def run_tmux_command(cmd: str): - return subprocess.run(cmd, shell=True, check=True) - -def send_tmux_command(target: str, command: str, env: str = "serl", workdir: str = "./"): - - full_command = ( - f"conda activate {env} && " - f"cd {workdir} && " - f"{command}" - ) - - subprocess.Popen( - f'tmux send-keys -t {target} "{full_command}" C-m', - shell=True - # stdout=subprocess.DEVNULL, - # stderr=subprocess.DEVNULL - ) - -def ensure_tmux_session(session_name: 
str): - result = subprocess.run( - ["tmux", "has-session", "-t", session_name] - # stdout=subprocess.DEVNULL, - # stderr=subprocess.DEVNULL - ) - if result.returncode != 0: - subprocess.run(["tmux", "new-session", "-d", "-s", session_name, "-n", "main"], check=True) - subprocess.run(["tmux", "split-window", "-v", "-t", f"{session_name}:0"], check=True) - -def main(_): - session = "serl_session" - ensure_tmux_session(session) - - env = "PandaReachCube-v0" - random_steps = 1000 - max_steps_learner = 10000 - max_steps_actor = max_steps_learner * 2 - training_starts = 1000 - critic_actor_ratio = 8 - batch_size = 256 - replay_buffer_capacity = 1000000 - save_model = True - - base_args = { - "env": env, - "random_steps": random_steps, - "training_starts": training_starts, - "critic_actor_ratio": critic_actor_ratio, - "batch_size": batch_size, - "replay_buffer_capacity": replay_buffer_capacity, - "save_model": save_model, - } - - # --- Baseline --- - print("Running baseline...") - args = base_args.copy() - args["replay_buffer_type"] = "replay_buffer" - args["exp_name"] = f"{env}-baseline" - - for seed in range(1): - args["seed"] = seed - args["max_steps"] = max_steps_actor - actor_log = os.path.join(LOG_DIR, f"actor_seed_{seed}.log") - learner_log = os.path.join(LOG_DIR, f"learner_seed_{seed}.log") - - send_tmux_command(f"{session}:0.0", f"./run_actor.sh {' '.join(f'--{k} {v}' for k, v in args.items())}") #conda/jax/tmux mem conflict? 
, f"{actor_cmd(args)}")# > {actor_log} 2>&1") # Uncomment to send to log files - args["max_steps"] = max_steps_learner - send_tmux_command(f"{session}:0.1", f"./run_learner.sh {' '.join(f'--{k} {v}' for k, v in args.items())}")#mem conflict?, f"{learner_cmd(args)}")# > {learner_log} 2>&1") - - # Optional: Wait before next round - print(f"Launched baseline with seed={seed}") - - # --- Fractal Symmetry Replay Buffer --- - # print("Running fractal symmetry buffer variations...") - # args = base_args.copy() - # args["replay_buffer_type"] = "fractal_symmetry_replay_buffer" - # args["branch_method"] = "constant" - # args["split_method"] = "constant" - - # for b in [1, 3, 9, 27]: - # args["starting_branch_count"] = b - # args["exp_name"] = f"{env}-{b}x{b}" - # for seed in range(5): - # args["seed"] = seed - # args["max_steps"] = max_steps_actor - # actor_log = os.path.join(LOG_DIR, f"actor_b{b}_s{seed}.log") - # learner_log = os.path.join(LOG_DIR, f"learner_b{b}_s{seed}.log") - - # send_tmux_command(f"{session}:0.0", f"{actor_cmd(args)} > {actor_log} 2>&1") - # args["max_steps"] = max_steps_learner - # send_tmux_command(f"{session}:0.1", f"{learner_cmd(args)} > {learner_log} 2>&1") - - # print(f"Launched branch={b}, seed={seed}") - - # Optional: Attach manually to see progress - # run_tmux_command(f"tmux attach-session -t {session}") - -if __name__ == "__main__": - app.run(main) From 30424ed9519e831b4cf346015e235c10cac76018 Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Thu, 24 Jul 2025 20:46:38 -0500 Subject: [PATCH 064/141] Add tests for insert() functionality in replay buffer --- serl_launcher/serl_launcher/data/fsrb_test.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/serl_launcher/serl_launcher/data/fsrb_test.py b/serl_launcher/serl_launcher/data/fsrb_test.py index cff3e474..df7f09b0 100644 --- a/serl_launcher/serl_launcher/data/fsrb_test.py +++ b/serl_launcher/serl_launcher/data/fsrb_test.py @@ -127,10 +127,29 @@ def main(_): expected = 
False assert result == expected, f"\033[31mTEST FAILED\033[0m split() test failed (expected {expected} but got {result})" + del result, expected print("\033[32mTEST PASSED \033[0m time_split() test passed") # insert() tests + # insert() tests + initial_size = len(replay_buffer.dataset_dict['observations'][0]) * replay_buffer._insert_index % len(replay_buffer.dataset_dict['observations']) + + replay_buffer.insert(data_dict) + final_size = len(replay_buffer.dataset_dict['observations'][0]) * replay_buffer._insert_index % len(replay_buffer.dataset_dict['observations']) + + result = final_size > initial_size + expected = True + assert result == expected, f"\033[31mTEST FAILED\033[0m insert() test failed (expected buffer size to increase from {initial_size} to {final_size})" + del result, expected, initial_size, final_size + + + + + print("\033[32mTEST PASSED \033[0m insert() tests passed") + + + print("\033[32mTEST PASSED \033[0m insert() tests passed") From 950162927066e1f63e07222ee301293c432b8881 Mon Sep 17 00:00:00 2001 From: ryanvanderstelt Date: Thu, 24 Jul 2025 22:55:22 -0500 Subject: [PATCH 065/141] almost functional automate_func.py --- .../async_sac_state_sim.py | 14 +- .../async_sac_state_sim/automate_funcs.py | 155 ++++++++++-------- examples/async_sac_state_sim/run_tests.py | 2 +- .../async_sac_state_sim/run_tests_funcs.py | 117 ------------- .../data/fractal_symmetry_replay_buffer.py | 17 +- 5 files changed, 99 insertions(+), 206 deletions(-) delete mode 100644 examples/async_sac_state_sim/run_tests_funcs.py diff --git a/examples/async_sac_state_sim/async_sac_state_sim.py b/examples/async_sac_state_sim/async_sac_state_sim.py index 45311ab8..2ad82311 100644 --- a/examples/async_sac_state_sim/async_sac_state_sim.py +++ b/examples/async_sac_state_sim/async_sac_state_sim.py @@ -59,14 +59,12 @@ # flags for replay buffer flags.DEFINE_string("replay_buffer_type", "replay_buffer", "Which type of replay buffer to use") -flags.DEFINE_string("branch_method", None, 
"Method for how many branches to generate") -flags.DEFINE_string("split_method", "test", "Method for when to change number of branches generated") # Remember to default None -flags.DEFINE_float("workspace_width", 0.5, "Workspace width in centimeters") # Remember to default None -flags.DEFINE_integer("depth", None, "Total layers of depth") -flags.DEFINE_integer("dendrites", None, "Dendrites for fractal branching") -flags.DEFINE_integer("timesplit_freq", None, "Frequency of splits according to time") -flags.DEFINE_integer("branch_count_rate_of_change", None, "Rate of change for linear branching") -flags.DEFINE_integer("starting_branch_count", 1, "Initial number of branches") +flags.DEFINE_string("branch_method", None, "Method for determining number of transforms per dimension (x,y)") +flags.DEFINE_string("split_method", None, "Method for determining whether to change the number of transforms per dimension (x,y)") # Remember to default None +flags.DEFINE_float("workspace_width", None, "Workspace width in meters") +flags.DEFINE_integer("max_depth", None, "Total layers of depth") +flags.DEFINE_integer("branching_factor", None, "Rate of change of number of transforms per dimension (x,y)") +flags.DEFINE_integer("starting_branch_count", None, "Initial number of transformations per dimension (x,y)") flags.DEFINE_boolean( diff --git a/examples/async_sac_state_sim/automate_funcs.py b/examples/async_sac_state_sim/automate_funcs.py index ecc1a11f..425e2d94 100644 --- a/examples/async_sac_state_sim/automate_funcs.py +++ b/examples/async_sac_state_sim/automate_funcs.py @@ -1,21 +1,13 @@ import subprocess import os +import psutil from absl import app +import copy LOG_DIR = "serl_logs" -os.makedirs(LOG_DIR, exist_ok=True) - -def actor_cmd(kwargs: dict) -> str: - cmd = "python async_sac_state_sim.py --actor" - for k, v in kwargs.items(): - cmd += f" --{k} {v}" - return cmd +SESSION = "serl_session" -def learner_cmd(kwargs: dict) -> str: - cmd = "python async_sac_state_sim.py 
--learner" - for k, v in kwargs.items(): - cmd += f" --{k} {v}" - return cmd +os.makedirs(LOG_DIR, exist_ok=True) def run_tmux_command(cmd: str): return subprocess.run(cmd, shell=True, check=True) @@ -28,7 +20,7 @@ def send_tmux_command(target: str, command: str, env: str = "serl", workdir: str f"{command}" ) - subprocess.Popen( + subprocess.run( f'tmux send-keys -t {target} "{full_command}" C-m', shell=True # stdout=subprocess.DEVNULL, @@ -45,73 +37,94 @@ def ensure_tmux_session(session_name: str): subprocess.run(["tmux", "new-session", "-d", "-s", session_name, "-n", "main"], check=True) subprocess.run(["tmux", "split-window", "-v", "-t", f"{session_name}:0"], check=True) +def getProcesses(name): + p1_pid = None + p2_pid = None + first = True + while True: + for proc in psutil.process_iter(["name"]): + if proc.info["name"] == name: + if first: + p1_pid = proc.pid + first = False + else: + p2_pid = proc.pid + break + if p1_pid and p2_pid: + if p1_pid > p2_pid: + temp = p2_pid + p2_pid = p1_pid + p1_pid = temp + p1 = psutil.Process(p1_pid) + p2 = psutil.Process(p2_pid) + return p1, p2 + +def run_test(args:dict): + + max_steps = args["max_steps"] + for seed in range(1): + args["seed"] = seed + actor_log = os.path.join(LOG_DIR, f"actor_seed_{seed}.log") + learner_log = os.path.join(LOG_DIR, f"learner_seed_{seed}.log") + + args["max_steps"] = max_steps * 10 + send_tmux_command(f"{SESSION}:0.0", f"bash ./automate_actor.sh {' '.join(f'--{k} {v}' for k, v in args.items())}") #conda/jax/tmux mem conflict? 
, f"{actor_cmd(args)}")# > {actor_log} 2>&1") # Uncomment to send to log files + args["max_steps"] = max_steps + send_tmux_command(f"{SESSION}:0.1", f"bash ./automate_learner.sh {' '.join(f'--{k} {v}' for k, v in args.items())}")#mem conflict?, f"{learner_cmd(args)}")# > {learner_log} 2>&1") + + actor, learner = getProcesses("async_sac_state") + + learner.wait() + actor.terminate() + actor.wait() + def main(_): - session = "serl_session" - ensure_tmux_session(session) + ensure_tmux_session(SESSION) env = "PandaReachCube-v0" - random_steps = 500 - max_steps = 30000 - training_starts = 500 - critic_actor_ratio = 8 - batch_size = 256 - replay_buffer_capacity = 100000 - save_model = True + base_args = { - "env": env, - "random_steps": random_steps, - "training_starts": training_starts, - "critic_actor_ratio": critic_actor_ratio, - "batch_size": batch_size, - "replay_buffer_capacity": replay_buffer_capacity, - "save_model": save_model, - "max_steps": max_steps + "env": "PandaReachCube-v0", + "random_steps": 1000, + "training_starts": 1000, + "critic_actor_ratio": 8, + "batch_size": 256, + "replay_buffer_capacity": 100000, + "save_model": True, + "max_steps": 2000 } # --- Baseline --- - print("Running baseline...") - args = base_args.copy() + args = copy.deepcopy(base_args) args["replay_buffer_type"] = "replay_buffer" - args["exp_name"] = f"{env}-baseline" - - for seed in range(1): - args["seed"] = seed - # args["max_steps"] = max_steps_actor - actor_log = os.path.join(LOG_DIR, f"actor_seed_{seed}.log") - learner_log = os.path.join(LOG_DIR, f"learner_seed_{seed}.log") - - send_tmux_command(f"{session}:0.0", f"./run_actor.sh {' '.join(f'--{k} {v}' for k, v in args.items())}") #conda/jax/tmux mem conflict? 
, f"{actor_cmd(args)}")# > {actor_log} 2>&1") # Uncomment to send to log files - # args["max_steps"] = max_steps_learner - send_tmux_command(f"{session}:0.1", f"./run_learner.sh {' '.join(f'--{k} {v}' for k, v in args.items())}")#mem conflict?, f"{learner_cmd(args)}")# > {learner_log} 2>&1") - - # Optional: Wait before next round - print(f"Launched baseline with seed={seed}") - - # --- Fractal Symmetry Replay Buffer --- - # print("Running fractal symmetry buffer variations...") - # args = base_args.copy() - # args["replay_buffer_type"] = "fractal_symmetry_replay_buffer" - # args["branch_method"] = "constant" - # args["split_method"] = "constant" - - # for b in [1, 3, 9, 27]: - # args["starting_branch_count"] = b - # args["exp_name"] = f"{env}-{b}x{b}" - # for seed in range(5): - # args["seed"] = seed - # args["max_steps"] = max_steps_actor - # actor_log = os.path.join(LOG_DIR, f"actor_b{b}_s{seed}.log") - # learner_log = os.path.join(LOG_DIR, f"learner_b{b}_s{seed}.log") - - # send_tmux_command(f"{session}:0.0", f"{actor_cmd(args)} > {actor_log} 2>&1") - # args["max_steps"] = max_steps_learner - # send_tmux_command(f"{session}:0.1", f"{learner_cmd(args)} > {learner_log} 2>&1") - - # print(f"Launched branch={b}, seed={seed}") - - # Optional: Attach manually to see progress - # run_tmux_command(f"tmux attach-session -t {session}") + args["exp_name"] = f"{args['env']}-baseline" + run_test(args) + + + # --- 1x1, 3x3, 9x9, 27x27 --- + args = copy.deepcopy(base_args) + args["replay_buffer_type"] = "fractal_symmetry_replay_buffer" + args["branch_method"] = "constant" + args["split_method"] = "constant" + + for b in (1, 3, 9, 27): + args["starting_branch_count"] = b + args["exp_name"] = f"{args['env']}-{b}x{b}" + run_test(args) + + # --- Fractal Expansion 3^4, 9^2 --- + args = copy.deepcopy(base_args) + args["replay_buffer_type"] = "fractal_symmetry_replay_buffer" + args["branch_method"] = "fractal" + args["split_method"] = "time" + args["alpha"] = 1 + + for b, d in ((3, 
4), (9, 2)): + args["branching_factor"] = b + args["max_depth"] = d + args["exp_name"] = f"{args['env']}-{b}^{d}" + run_test(args) if __name__ == "__main__": app.run(main) diff --git a/examples/async_sac_state_sim/run_tests.py b/examples/async_sac_state_sim/run_tests.py index 4167eebc..5acbaba3 100644 --- a/examples/async_sac_state_sim/run_tests.py +++ b/examples/async_sac_state_sim/run_tests.py @@ -20,7 +20,7 @@ def learner(kwargs:dict): def main(_): env = "PandaReachCube-v0" random_steps = 1000 - max_steps_learner = 10000 + max_steps_learner = 5000 max_steps_actor = max_steps_learner * 2 training_starts = 1000 critic_actor_ratio = 8 diff --git a/examples/async_sac_state_sim/run_tests_funcs.py b/examples/async_sac_state_sim/run_tests_funcs.py deleted file mode 100644 index 2cc251fa..00000000 --- a/examples/async_sac_state_sim/run_tests_funcs.py +++ /dev/null @@ -1,117 +0,0 @@ -import subprocess -import os -from absl import app - -LOG_DIR = "serl_logs" -os.makedirs(LOG_DIR, exist_ok=True) - -def actor_cmd(kwargs: dict) -> str: - cmd = "python async_sac_state_sim.py --actor" - for k, v in kwargs.items(): - cmd += f" --{k} {v}" - return cmd - -def learner_cmd(kwargs: dict) -> str: - cmd = "python async_sac_state_sim.py --learner" - for k, v in kwargs.items(): - cmd += f" --{k} {v}" - return cmd - -def run_tmux_command(cmd: str): - return subprocess.run(cmd, shell=True, check=True) - -def send_tmux_command(target: str, command: str, env: str = "serl", workdir: str = "./"): - - full_command = ( - f"conda activate {env} && " - f"cd {workdir} && " - f"{command}" - ) - - subprocess.Popen( - f'tmux send-keys -t {target} "{full_command}" C-m', - shell=True - # stdout=subprocess.DEVNULL, - # stderr=subprocess.DEVNULL - ) - -def ensure_tmux_session(session_name: str): - result = subprocess.run( - ["tmux", "has-session", "-t", session_name] - # stdout=subprocess.DEVNULL, - # stderr=subprocess.DEVNULL - ) - if result.returncode != 0: - subprocess.run(["tmux", "new-session", 
"-d", "-s", session_name, "-n", "main"], check=True) - subprocess.run(["tmux", "split-window", "-v", "-t", f"{session_name}:0"], check=True) - -def main(_): - session = "serl_session" - ensure_tmux_session(session) - - env = "PandaReachCube-v0" - random_steps = 1000 - max_steps_learner = 10000 - max_steps_actor = max_steps_learner * 2 - training_starts = 1000 - critic_actor_ratio = 8 - batch_size = 256 - replay_buffer_capacity = 1000000 - save_model = True - - base_args = { - "env": env, - "random_steps": random_steps, - "training_starts": training_starts, - "critic_actor_ratio": critic_actor_ratio, - "batch_size": batch_size, - "replay_buffer_capacity": replay_buffer_capacity, - "save_model": save_model, - } - - # --- Baseline --- - print("Running baseline...") - args = base_args.copy() - args["replay_buffer_type"] = "replay_buffer" - args["exp_name"] = f"{env}-baseline" - - for seed in range(1): - args["seed"] = seed - args["max_steps"] = max_steps_actor - actor_log = os.path.join(LOG_DIR, f"actor_seed_{seed}.log") - learner_log = os.path.join(LOG_DIR, f"learner_seed_{seed}.log") - - send_tmux_command(f"{session}:0.0", f"./run_actor.sh {' '.join(f'--{k} {v}' for k, v in args.items())}") #conda/jax/tmux mem conflict? 
, f"{actor_cmd(args)}")# > {actor_log} 2>&1") # Uncomment to send to log files - args["max_steps"] = max_steps_learner - send_tmux_command(f"{session}:0.1", f"./run_learner.sh {' '.join(f'--{k} {v}' for k, v in args.items())}")#mem conflict?, f"{learner_cmd(args)}")# > {learner_log} 2>&1") - - # Optional: Wait before next round - print(f"Launched baseline with seed={seed}") - - # --- Fractal Symmetry Replay Buffer --- - # print("Running fractal symmetry buffer variations...") - # args = base_args.copy() - # args["replay_buffer_type"] = "fractal_symmetry_replay_buffer" - # args["branch_method"] = "constant" - # args["split_method"] = "constant" - - # for b in [1, 3, 9, 27]: - # args["starting_branch_count"] = b - # args["exp_name"] = f"{env}-{b}x{b}" - # for seed in range(5): - # args["seed"] = seed - # args["max_steps"] = max_steps_actor - # actor_log = os.path.join(LOG_DIR, f"actor_b{b}_s{seed}.log") - # learner_log = os.path.join(LOG_DIR, f"learner_b{b}_s{seed}.log") - - # send_tmux_command(f"{session}:0.0", f"{actor_cmd(args)} > {actor_log} 2>&1") - # args["max_steps"] = max_steps_learner - # send_tmux_command(f"{session}:0.1", f"{learner_cmd(args)} > {learner_log} 2>&1") - - # print(f"Launched branch={b}, seed={seed}") - - # Optional: Attach manually to see progress - # run_tmux_command(f"tmux attach-session -t {session}") - -if __name__ == "__main__": - app.run(main) diff --git a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py index 55188efd..7b386155 100644 --- a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py +++ b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py @@ -62,12 +62,12 @@ def __init__( # self.branch = self.linear_branch - # case "constant": - # assert "starting_branch_count" in kwargs.keys(), self._handle_bad_args_("branch_method", branch_method, "starting_branch_count") - # self.current_branch_count = 
kwargs["starting_branch_count"] - # del kwargs["starting_branch_count"] + case "constant": + assert "starting_branch_count" in kwargs.keys(), self._handle_bad_args_("branch_method", branch_method, "starting_branch_count") + self.current_branch_count = kwargs["starting_branch_count"] + del kwargs["starting_branch_count"] - # self.branch = self.constant_branch + self.branch = self.constant_branch case "test": self.branch = self.test_branch @@ -117,10 +117,9 @@ def fractal_branch(self): # return a new number of branches = branching_factor ^ depth return self.branching_factor ** self.current_depth - # REQUIRES TESTING - # def constant_branch(self): - # # return current number of branches - # return self.current_branch_count + def constant_branch(self): + # return current number of branches + return self.current_branch_count # REQUIRES TESTING # def linear_branch(self): From 93928891dc4fb2d69e6d1959e1b6bc063d0bdbcd Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Thu, 24 Jul 2025 22:58:13 -0500 Subject: [PATCH 066/141] Needed to create another demos folder as a child in order for setup.py to locate it and enable it to be imported throughout. 
--- demos/__init__.py | 0 demos/demoHandling.py | 121 ------- demos/franka_pick_place_demo_script.py | 407 ----------------------- demos/franka_reach_demo_script.py | 442 ------------------------- demos/load_demo_test.py | 127 ------- 5 files changed, 1097 deletions(-) delete mode 100644 demos/__init__.py delete mode 100644 demos/demoHandling.py delete mode 100755 demos/franka_pick_place_demo_script.py delete mode 100755 demos/franka_reach_demo_script.py delete mode 100644 demos/load_demo_test.py diff --git a/demos/__init__.py b/demos/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/demos/demoHandling.py b/demos/demoHandling.py deleted file mode 100644 index 90c901c1..00000000 --- a/demos/demoHandling.py +++ /dev/null @@ -1,121 +0,0 @@ -import os -from pathlib import Path -import numpy as np -from agentlace.data.data_store import QueuedDataStore - -class DemoHandling: - """ - Koads an .npz file containing demonstration data into a data object. - This class is designed to work with Gymnasium-style demonstration data - and is intended to be used with a QueuedDataStore or similar data store. - - The .npz file should contain the following arrays: - - 'obs' : shape (N, T+1, *obs_shape*), list of observations - - 'acs' : shape (N, T, *act_shape*), list of actions - - 'rewards' : shape (N, T), list of rewards - - 'terminateds' : shape (N, T), list of terminated flags - - 'truncateds' : shape (N, T), list of truncated flags - - 'info' : shape (N, T), list of info dicts - - 'dones' : shape (N, T), list of done flags (if available) - - Parameters - ---------- - demo_dir : str - Directory where demo .npz files live by default. - file_name : str - Name of the demo file to load. If not provided, a default will be used. 
- """ - def __init__( - self, - demo_dir: str = '/data/data/serl/demos', - file_name: str = 'data_franka_reach_random_20.npz' - ): - - self.debug = False # Set to True for debugging purposes - self.demo_dir = demo_dir - self.transition_ctr = 0 # Global counter for transitions across all episodes - - # Load the demo data from the .npz file - - # Check if the demo directory exists - if not os.path.exists(self.demo_dir): - raise FileNotFoundError(f"Demo directory '{self.demo_dir}' does not exist.") - - # Construct the full path to the demo file - self.demo_npz_path = os.path.join(self.demo_dir, file_name) - if not os.path.isfile(self.demo_npz_path): - raise FileNotFoundError(f"Demo file '{self.demo_npz_path}' does not exist.") - - # Load the .npz file - self.data = np.load(self.demo_npz_path, allow_pickle=True) - - def get_num_transitions(self): - """ - Returns the total number of transitions counted in the demo data. - """ - return self.data["transition_ctr"] if "transition_ctr" in self.data else 0 - - def insert_data_to_buffer(self,data_store: QueuedDataStore): - """ - Load a raw Gymnasium-style .npz of expert episodes into data_store. - The .npz file must contain arrays named 'obs', 'acs', 'rewards', - 'terminateds', 'truncateds', 'info', and optionally 'dones'. - """ - - obs_buffer = self.data['obs'] # shape (N, T+1, ...) - act_buffer = self.data['acs'] # shape (N, T, ...) 
- rew_buffer = self.data['rewards'] # shape (N, T) - term_buffer = self.data['terminateds'] # shape (N, T) - trunc_buffer = self.data['truncateds'] # shape (N, T) - info_buffer = self.data['info'] # shape (N, T) - done_buffer = self.data.get('dones', term_buffer | trunc_buffer) - - num_demos = obs_buffer.shape[0] - - for ep in range(num_demos): - ep_obs = obs_buffer[ep] - ep_acts = act_buffer[ep] - ep_rews = rew_buffer[ep] - ep_terms = term_buffer[ep] - ep_trunc = trunc_buffer[ep] - ep_done = done_buffer[ep] - ep_info = info_buffer[ep] - - T = len(ep_acts) - for t in range(T): - obs_t = ep_obs[t] - next_obs_t = ep_obs[t+1] - a_t = ep_acts[t] - r_t = float(ep_rews[t]) - done_t = bool(ep_done[t] or ep_terms[t] or ep_trunc[t]) - # masks will be created right before insert below - - if self.debug: - print(f"Demo {ep}, Step {t}: Obs={obs_t}, Action={a_t}, Reward={r_t}, Done={done_t}") - - data_store.insert( - dict( - observations =obs_t, - actions =a_t, - next_observations=next_obs_t, - rewards =r_t, - masks =1.0 - done_t, - dones =done_t - ) - ) - - print(f"Loaded {num_demos} episodes from '{self.demo_npz_path}' ") - - -# if __name__ == "__main__": -# # create your datastore; here we use a QueuedDataStore with capacity 2000 -# ds = QueuedDataStore(2000) -# handler = DemoHandling(ds, -# demo_dir='/data/data/serl/demos', -# file_name='data_franka_reach_random_20.npz') - -# # Idenitfy the total number of transitions in the datastore -# print(f'We have {handler.data["transition_ctr"]} transitions in the datastore.') - -# # Load the demo data into the data_store -# handler.insert_data_to_buffer() diff --git a/demos/franka_pick_place_demo_script.py b/demos/franka_pick_place_demo_script.py deleted file mode 100755 index e5abbc11..00000000 --- a/demos/franka_pick_place_demo_script.py +++ /dev/null @@ -1,407 +0,0 @@ -""" -Scripted Controller for Franka FR3 Robot - Demonstration Data Generation - -This script implements a scripted controller for generating expert demonstration 
data -for Deep Reinforcement Learning (DRL) algorithms using the Franka FR3 robot in a -pick-and-place task. The generated data serves as bootstrapping demonstrations for -training RL agents with stable-baselines3. - -The controller uses a 4-phase hierarchical approach: -1. Approach Object (move gripper above object) -2. Grasp Object (move to object and close gripper) -3. Transport to Goal (move grasped object to target) -4. Maintain Position (hold final position) - -Output: Compressed NPZ file containing action, observation, and info sequences -""" -import os -import numpy as np -import gymnasium as gym -from gymnasium.wrappers import TimeLimit -from time import sleep -from panda_mujoco_gym.envs import FrankaPickAndPlaceEnv - -# Global variables to store episode data across all iterations -observations = [] # List storing observation sequences for each episode -actions = [] # List storing action sequences for each episode -rewards = [] # List storing reward sequences for each episode -infos = [] # List storing info dictionaries for each episode -terminateds = [] # List storing terminated flags for each episode -truncateds = [] # List storing truncated flags for each episode -dones = [] # List storing done flags (terminated or truncated) for each episode - -# Robot configuration -robot = 'franka' # Robot type used in the environment, can be 'franka' or 'fetch' - -# Weld constraint flag -weld_flag = True # Flag to activate weld constraint during pick-and-place - - -def activate_weld(env, constraint_name="grasp_weld"): - """ - Activate a weld constraint during pick portion of a demo - :param env: The environment containing the model - :param constraint_name: The name of the weld constraint to activate - :return: True if the weld was successfully activated, False if the constraint was not found - """ - - try: - # Activate the weld constraint - env.unwrapped.model.eq(constraint_name).active = 1 - print("Activated weld") - return True - - except KeyError: - 
print(f"Warning: Constraint '{constraint_name}' not found") - return False - -def deactivate_weld(env, constraint_name="grasp_weld"): - """ - Deactivate a weld constraint during place portion of a demo - :param env: The environment containing the model - :param constraint_name: The name of the weld constraint to deactivate - :return: True if the weld was successfully deactivated, False if the constraint was not - found - """ - - try: - # Deactivate the weld constraint - env.unwrapped.model.eq(constraint_name).active = 0 - print("Deactivated weld") - return True - - except KeyError: - print(f"Warning: Constraint '{constraint_name}' not found") - return False - -def main(): - """ - Orchestrates the data generation process by running multiple episodes - of the pick-and-place task. - - Creates environment, runs scripted episodes, and saves demonstration data - to compressed NPZ file for use with stable-baselines3. - """ - # Initialize Fetch pick-and-place environment - env = FrankaPickAndPlaceEnv(reward_type="sparse", render_mode="rgb_array") - env = TimeLimit(env, max_episode_steps=50) - - # Adjust physical settings - # env.model.opt.timestep = 0.001 # Smaller timestep for more accurate physics. Default is 0.002. - # env.model.opt.iterations = 100 # More solver iterations for better contact resolution. Default is 50. - - # Configuration parameters - initStateSpace = "random" # Initial state space configuration - - # Demos configs - attempted_demos = 1 # Number of demonstration episodes to generate--ADJUST THIS VALUE FOR MORE OR LESS DEMOS** - - num_demos = 0 # Counter for successful demonstration episodes - - # Reset environment to initial state - render for the first time. 
- obs, _ = env.reset() - print("Reset!") - - # Generate demonstration episodes - while len(actions) < attempted_demos: - obs,_ = env.reset() # Reset environment for new episode - print(f"We will run a total of: {attempted_demos} demos!!") - print("Demo: #", len(actions)+1) - - # Execute pick-and-place task - res = pick_and_place_demo(env, obs) - - # Print success message - if res: - num_demos += 1 - print("Episode completed successfully!") - print(f"Total successful demos: {num_demos}/{attempted_demos}") - - ## Write data to demos folder - # 1. Get the absolute path of this script - #script_path = os.path.abspath(__file__) - - # 2. Extract its directory - #script_dir = os.path.dirname(script_path) - script_dir = '/home/student/data/franka_baselines/demos/pick_n_place' # Assumes data folder in user directory. - - # 3. Create output filename with configuration details - fileName = "data_" + robot - fileName += "_" + initStateSpace - fileName += "_" + str(attempted_demos) - fileName += ".npz" - - # 3. Build a filename in that same directory - out_path = os.path.join(script_dir, fileName) - - # Save collected data to compressed numpy NPZ file - # Set acs,obs,info as keys in dict - np.savez_compressed(out_path, - acs = actions, - obs = observations, - rewards = rewards, - info = infos, - terminateds = terminateds, - truncateds = truncateds, - dones = dones) - - print(f"Data saved to {fileName}.") - -def pick_and_place_demo(env, lastObs): - """ - Executes a scripted pick-and-place sequence using a hierarchical approach. - - Implements 4-phase control strategy: - 1. Approach: Move gripper above object (3cm offset) - 2. Grasp: Move to object and close gripper - 3. Transport: Move grasped object to goal position - 4. Maintain: Hold position until episode ends - - Store observations, actions, and info in global lists for later replay buffer inclusion. 
- - Args: - env: Gymnasium environment instance - lastObs: Last observation containing goal and object state information - - desired_goal: Target position for object placement - - observations: - ee_position[0:3], - ee_velocity[3:6], - fingers_width[6], - object_position[7:10], - object_rotation[10:13], - object_velp[13:16], - object_velr[16:19], - """ - - ## Init goal, current_pos, and object position from last observation - goal = np.zeros(3, dtype=np.float32) - current_pos = np.zeros(3, dtype=np.float32) - object_pos = np.zeros(3, dtype=np.float32) - object_rel_pos = np.zeros(3, dtype=np.float32) - fgr_pos = np.zeros(1, dtype=np.float32) - - # Initialize episode data collection - episodeObs = [] # Observations for this episode - episodeAcs = [] # Actions for this episode - episodeRews = [] # Rewards for this episode - episodeInfo = [] # Info for this episode - episodeTerminated = [] # Terminated flags for this episode - episodeTruncated = [] # Truncated flags for this episode - episodeDones = [] # Done flags (terminated or truncated) for this episode - - # Proportional control gain for action scaling -- empirically tuned - Kp = 8.0 - - # pre_pick_offset - pre_pick_offset = np.array([0,0,0.03], dtype=float) # Offset to approach object safely (3cm) - - # Error thresholds - error_threshold = 0.011 # Threshold for stopping condition (Xmm) - - finger_delta_fast = 0.05 # Action delta for fingers 5cm per step (will get clipped by controller)... more of a scalar. 
- finger_delta_slow = 0.005 # Franka has a range from 0 to 4cm per finger - - ## Extract data - # Extract desired position from desired_goal dict - goal = lastObs["desired_goal"][0:3] - - # Current robot end-effector position from observation dict - current_pos = lastObs["observation"][0:3] - - # Current object position from observation dict: - object_pos = lastObs["observation"][7:10] - - # Relative position between end-effector and object - object_rel_pos = object_pos - current_pos - - ## Phase 1: Approach Object (Above) - # Create target position 3cm above the object. Use copy() method. - error = object_rel_pos.copy() - error+=pre_pick_offset # Move 3cm above object for safe approach. Fingers should still end up surrounding object. - - timeStep = 0 # Track total timesteps in episode - episodeObs.append(lastObs) - - # Phase 1: Move gripper to position above object - # Terminate when distance to above-object position < 5mm - print(f"----------------------------------------------- Phase 1: Approach Object -----------------------------------------------") - while np.linalg.norm(error) >= error_threshold and timeStep <= env._max_episode_steps: - env.render() # Visual feedback - - # Initialize action vector [x, y, z, gripper] - action = np.array([0., 0., 0., 0.]) - - # Proportional control with gain of 6 - # action = Kp * error - action[:3] = error * Kp - - # Open gripper for approach - action[ len(action)-1 ] = 0.05 - - # Unpack new Gymnasium step API - new_obs, reward, terminated, truncated, info = env.step(action) - done = terminated or truncated - timeStep += 1 - - # Store episode data - episodeAcs.append(action) - episodeInfo.append(info) - episodeRews.append(reward) - episodeObs.append(new_obs) - episodeTerminated.append(terminated) - episodeTruncated.append(truncated) - episodeDones.append(done) - - # Update state information - fgr_pos = new_obs["observation"][6] - current_pos = new_obs["observation"][0:3] - object_pos = new_obs['observation'][7:10] - error = 
(object_pos+pre_pick_offset) - current_pos # Error with regard to offset position - - # Print debug information - print( - f"Time Step: {timeStep}, Error: {np.linalg.norm(error):.4f}, " - f"Eff_pos: {np.array2string(current_pos, precision=3)}, " - f"obj_pos: {np.array2string(object_pos, precision=3)}, " - f"fgr_pos: {np.array2string(fgr_pos, precision=2)}, " - f"Error: {np.array2string(error, precision=3)}, " - f"Action: {np.array2string(action, precision=3)}" - ) - - # Phase 2: Descend Grasp Object - # Move gripper directly to object and close gripper - # Terminate when relative distance to object < 5mm - print(f"----------------------------------------------- Phase 2: Grip -----------------------------------------------") - error = object_pos - current_pos # remove offset - while (np.linalg.norm(error) >= error_threshold or fgr_pos>=0.39) and timeStep <= env._max_episode_steps: # Cube of width 4cm, each finger open to 2cm - env.render() - - # Initialize action vector [x, y, z, gripper] - action = np.array([0., 0., 0., 0.]) - - # Direct proportional control to object position - action[:3] = error * Kp - - # Close gripper to grasp object - action[len(action)-1] = -finger_delta_fast * 2 - - # Execute action and collect data - new_obs, reward, terminated, truncated, info = env.step(action) - done = terminated or truncated - timeStep += 1 - - # Store episode data - episodeObs.append(new_obs) - episodeRews.append(reward) - episodeAcs.append(action) - episodeInfo.append(info) - episodeTerminated.append(terminated) - episodeTruncated.append(truncated) - episodeDones.append(done) - - # Update state information - fgr_pos = new_obs["observation"][6] - current_pos = new_obs["observation"][0:3] - object_pos = new_obs['observation'][7:10] - error = object_pos - current_pos #- np.array([0.,0.,0.01]) # Grab lower - - # Print debug information - print( - f"Time Step: {timeStep}, Error: {np.linalg.norm(error):.4f}, " - f"Eff_pos: {np.array2string(current_pos, precision=3)}, " - 
f"obj_pos: {np.array2string(object_pos, precision=3)}, " - f"fgr_pos: {np.array2string(fgr_pos, precision=3)}, " - f"Error: {np.array2string(error, precision=3)}, " - f"Action: {np.array2string(action, precision=3)}" - ) - #sleep(0.5) # Optional: Slow down for better visualization - - # Phase 3: Transport to Goal - # Move grasped object to desired goal position - # Terminate when distance between object and goal < 1cm - print(f"----------------------------------------------- Phase 3: Transport to Goal -----------------------------------------------") - - # Weld activation - if weld_flag: - activate_weld(env, constraint_name="grasp_weld") - - # Set error between goal and hand assuming the object is grasped - gh_error = goal - current_pos # Error between goal and hand position - ho_error = object_pos - current_pos # Error between object and hand position - while np.linalg.norm(gh_error) >= 0.01 and timeStep <= env._max_episode_steps: - env.render() - - action = np.array([0., 0., 0., 0.]) - - # Proportional control toward goal position - action[:3] = gh_error[:3] * Kp - - # Maintain grip on object - #action[len(action)-1] = 0 - - # Execute action and collect data - new_obs, reward, terminated, truncated, info = env.step(action) - done = terminated or truncated - timeStep += 1 - - # Store episode data - episodeObs.append(new_obs) - episodeRews.append(reward) - episodeAcs.append(action) - episodeInfo.append(info) - episodeTerminated.append(terminated) - episodeTruncated.append(truncated) - episodeDones.append(done) - - # Update state information - fgr_pos = new_obs["observation"][6] - current_pos = new_obs["observation"][0:3] - object_pos = new_obs['observation'][7:10] - gh_error = goal - current_pos # Error between goal and hand position - ho_error = object_pos - current_pos # Error between object and hand position - - # Print debug information - print( - f"Time Step: {timeStep}, Error Norm: {np.linalg.norm(gh_error):.4f}, " - f"Eff_pos: {np.array2string(current_pos, 
precision=3)}, " - f"goal_pos: {np.array2string(goal, precision=3)}, " - f"fgr_pos: {np.array2string(fgr_pos, precision=2)}, " - f"Error: {np.array2string(gh_error, precision=3)}, " - f"Action: {np.array2string(action, precision=3)}" - ) - - sleep(0.5) # Optional: Slow down for better visualization - - ## Check for success and store episode data - gh_norm = np.linalg.norm(gh_error) - ho_nomr = np.linalg.norm(ho_error) - if gh_norm < error_threshold and ho_nomr < error_threshold: - - # Store complete episode data in global lists only if we succeeded (avoid bad demos) - actions.append(episodeAcs) - observations.append(episodeObs) - infos.append(episodeInfo) - rewards.append(episodeRews) - - # Optionally, also store the done/terminated/truncated flags globally if needed: - terminateds.append(episodeTerminated) - truncateds.append(episodeTruncated) - dones.append(episodeDones) - - # Deactivate weld constraint after successful pick - if weld_flag: - deactivate_weld(env, constraint_name="grasp_weld") - - # Close mujoco viewer - env.close() - - # Break out of the loop to start a new episode - return True - - # If we reach here, the episode was not successful - if weld_flag: - print("Failed to transport object to goal position. Deactivating weld.") - deactivate_weld(env, constraint_name="grasp_weld") - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/demos/franka_reach_demo_script.py b/demos/franka_reach_demo_script.py deleted file mode 100755 index 4336752c..00000000 --- a/demos/franka_reach_demo_script.py +++ /dev/null @@ -1,442 +0,0 @@ -""" -Scripted Controller for Franka FR3 Robot - Demonstration Data Generation - -This script implements a scripted controller for generating expert demonstration data -for Deep Reinforcement Learning (DRL) algorithms using the Franka FR3 robot in a -pick-and-place task. The generated data serves as bootstrapping demonstrations for -training RL agents with stable-baselines3. 
- -The controller uses a 4-phase hierarchical approach: -1. Approach Object (move gripper above object) -2. Grasp Object (move to object and close gripper) -3. Transport to Goal (move grasped object to target) -4. Maintain Position (hold final position) - -Output: Compressed NPZ file containing action, observation, and info sequences - -TODO: Convert to a class-based structure for better modularity and reusability. -""" -import os -import numpy as np -import gym -from time import sleep, perf_counter - -import franka_sim -import franka_sim.envs.panda_reach_gym_env as panda_reach_env - -# Global variables to store episode data across all iterations -observations = [] # List storing observation sequences for each episode -actions = [] # List storing action sequences for each episode -rewards = [] # List storing reward sequences for each episode -infos = [] # List storing info dictionaries for each episode -terminateds = [] # List storing terminated flags for each episode -truncateds = [] # List storing truncated flags for each episode -dones = [] # List storing done flags (terminated or truncated) for each episode -transition_ctr = 0 # Global counter for transitions across all episodes - -#------------------------------------------------------------------------------------------- -## Key Config Variables -#------------------------------------------------------------------------------------------- -# Proportional and derivative control gain for action scaling -- empirically tuned -Kp = 10.0 # Values between 20 and 24 seem to be somewhat stable for Kv = 24 -Kv = 10.0 - -ACTION_MAX = 10 # Maximum action value for clipping actions -ERROR_THRESHOLD = 0.01 # Note!! When this number is changed, the way rewards are computed in the PandaReachCubeEnv.step() L220 must also be changed such that done=True only at the end of a successfull run. 
- -# Number of demonstration episodes to generate -NUM_DEMOS = 20 - -# Robot configuration -robot = 'franka' # Robot type used in the environment, can be 'franka' or 'fetch' -task = 'reach' # Task type used in the environment, can be 'reach' or 'pick-and-place' - -# Debug mode for rendering and visualization -DEBUG = True - -if DEBUG: - _render_mode = 'human' # Render mode for the environment, can be 'human' or 'rgb_array' -else: - _render_mode = 'rgb_array' # Use 'rgb_array' for automated testing without GUI - -# Indices for franka_sim reach environment observations -if robot == 'franka' and task == 'reach': - - opi = np.array([0, 3]) # Indices for object position in observation - gpi = np.array([3]) # Indices for gripper position in observation - rpi = np.array([4, 7]) # Indices for robot position in observation - rvi = np.array([7, 10]) # Indices for robot velocity in observation - -# Weld constraint flag -weld_flag = True # Flag to activate weld constraint during pick-and-place - -#------------------------------------------------------------------------------------------- -# Franka sim environments do not have weld constraints like the franka_mujoco environments. 
-# def activate_weld(env, constraint_name="grasp_weld"): -# """ -# Activate a weld constraint during pick portion of a demo -# :param env: The environment containing the model -# :param constraint_name: The name of the weld constraint to activate -# :return: True if the weld was successfully activated, False if the constraint was not found -# """ - -# try: -# # Activate the weld constraint -# env.unwrapped.model.eq(constraint_name).active = 1 -# print("Activated weld") -# return True - -# except KeyError: -# print(f"Warning: Constraint '{constraint_name}' not found") -# return False - -# def deactivate_weld(env, constraint_name="grasp_weld"): -# """ -# Deactivate a weld constraint during place portion of a demo -# :param env: The environment containing the model -# :param constraint_name: The name of the weld constraint to deactivate -# :return: True if the weld was successfully deactivated, False if the constraint was not -# found -# """ - -# try: -# # Deactivate the weld constraint -# env.unwrapped.model.eq(constraint_name).active = 0 -# print("Deactivated weld") -# return True - -# except KeyError: -# print(f"Warning: Constraint '{constraint_name}' not found") -# return False - -def set_front_cam_view(env): - """ - Set the camera view to a front-facing perspective for better visualization. - - Args: - env: The environment instance containing the viewer. - - Returns: - viewer: The viewer with updated camera settings. 
- """ - viewer = env.unwrapped._viewer.viewer # Access the viewer from the environment - - if hasattr(viewer, 'cam'): - viewer.cam.lookat[:] = [0, 0, 0.1] # Center of robot (adjust as needed) - viewer.cam.distance = 3.0 # Camera distance - viewer.cam.azimuth = 135 # 0 = right, 90 = front, 180 = left - viewer.cam.elevation = -30 # Negative = above, positive = below - - # Hide menu - viewer._hide_overlay = True - - return viewer - -def store_transition_data(episode_dict, new_obs, rewards, action, info, terminated, truncated, done): - """ - Store transition data in the episode dictionary and update global counter. - """ - global transition_ctr - transition_ctr += 1 - - episode_dict["observations"].append(new_obs) - episode_dict["rewards"].append(rewards) - episode_dict["actions"].append(action) - episode_dict["infos"].append(info) - episode_dict["terminateds"].append(terminated) - episode_dict["truncateds"].append(truncated) - episode_dict["dones"].append(done) - -def store_episode_data(episode_data): - """ - Store complete episode data in global lists only if we succeeded (avoid bad demos). - """ - actions.append(episode_data["actions"]) - observations.append(episode_data["observations"]) - infos.append(episode_data["infos"]) - rewards.append(episode_data["rewards"]) - - # Optionally, also store the done/terminated/truncated flags globally if needed: - terminateds.append(episode_data["terminateds"]) - truncateds.append(episode_data["truncateds"]) - dones.append(episode_data["dones"]) - -def update_state_info(episode_data, time_step, dt, error): - """ - Update and return the current state information. Always get the latest entry with [-1] - - Args: - new_obs (dict): New observation dictionary containing the current state. - time_step (int): Current time step in the episode. - dt (float): Current time step in the episode. - error (np.ndarray): Current error vector between object and end-effector positions. 
- - Returns: - object_pos (np.ndarray): Current position of the object in the environment. - gripper_pos (float): Current position of the gripper. - current_pos (np.ndarray): Current position of the end-effector. - current_vel (np.ndarray): Current velocity of the end-effector. - """ - object_pos = episode_data["observations"][-1][ opi[0]:opi[1] ] # Block position - gripper_pos = episode_data["observations"][-1][ gpi[0] ] # Gripper position - current_pos = episode_data["observations"][-1][ rpi[0]:rpi[1] ] # Panda/tcp position - current_vel = episode_data["observations"][-1][ rvi[0]:rvi[1] ] # Panda/t - - - # Print debug information - print( - f"Time Step: {time_step}, Error: {np.linalg.norm(error):.4f}, " - f"bot_pos: {np.array2string(current_pos, precision=3)}, " - f"obj_pos: {np.array2string(object_pos, precision=3)}, " - f"fgr_pos: {np.array2string(gripper_pos, precision=2)}, " - f"err: {np.array2string(error, precision=3)}, " - f"Action: {np.array2string(episode_data['actions'][-1], precision=3)}, " - f"dt: {dt: .4f}" - ) - - - return object_pos, gripper_pos, current_pos, current_vel - -def compute_error(object_pos, current_pos, prev_error, dt): - """Compute the error and its derivative between the object position and the current end-effector position. - Args: - object_pos (np.ndarray): The position of the object in the environment. - current_pos (np.ndarray): The current position of the end-effector. - prev_error (np.ndarray): The previous error value for derivative calculation. - dt (float): Time step for derivative calculation. - - Returns: - error (np.ndarray): The current error vector between the object and end-effector positions. - derror (np.ndarray): The derivative of the error vector. 
- """ - error = object_pos - current_pos # Calculate the error vector - derror = (error - prev_error) / dt # Calculate the derivative of the error vector - - prev_error = error.copy() # Update previous error for next iteration - return error, derror - -def demo(env, lastObs): - """ - Executes a scripted reach sequence using a hierarchical approach. - - Implements 1-phase control strategy: - 1. Approach: Move gripper above object (3cm offset) - - Store observations, actions, and info in global lists for later replay buffer inclusion. - - Gripper: - - The gripper in Mujoco ranges from a value of 0 to 0.4, where 0 is fully open and 0.4 is fully closed. - - Args: - env: Gymnasium environment instance - lastObs: Flattened observations set as object_pos, gripper_pos, panda/tcp_pos, panda/tcp_vel - - observations: - object_pos[0:3], - gripper_pos[3] - panda/tcp_pos[4:7], - panda/tcp_vel[7:10] - - Returns: - """ - - ## Init goal, current_pos, and object position from last observation - object_pos = np.zeros(3, dtype=np.float32) - current_pos = np.zeros(3, dtype=np.float32) - gripper_pos = np.zeros(1, dtype=np.float32) - object_rel_pos = np.zeros(3, dtype=np.float32) - - - # Initialize (single) episode data collection - episodeObs = [] # Observations for this episode - episodeAcs = [] # Actions for this episode - episodeRews = [] # Rewards for this episode - episodeInfo = [] # Info for this episode - episodeTerminated = [] # Terminated flags for this episode - episodeTruncated = [] # Truncated flags for this episode - episodeDones = [] # Done flags (terminated or truncated) for this episode - - # Dictionary to store episode data - episode_data = { - "observations": episodeObs, - "actions": episodeAcs, - "rewards": episodeRews, - "infos": episodeInfo, - "terminateds": episodeTerminated, - "truncateds": episodeTruncated, - "dones": episodeDones - } - - # close gripper - fgr_pos = 0 - - # Error thresholds - error_threshold = ERROR_THRESHOLD # Threshold for stopping condition 
(Xmm) - - finger_delta_fast = 0.05 # Action delta for fingers 5cm per step (will get clipped by controller)... more of a scalar. - finger_delta_slow = 0.005 # Franka has a range from 0 to 4cm per finger - - ## Extract data - object_pos = lastObs[opi[0]:opi[1]] # block pos - current_pos = lastObs[rpi[0]:rpi[1]] # panda/tcp_pos - - # Relative position between end-effector and object - dt = env.unwrapped.model.opt.timestep # Mujoco time step - prev_error = np.zeros_like(object_pos) - error, derror = compute_error(object_pos, current_pos, prev_error, dt) - - time_step = 0 # Track total time_steps in episode - episodeObs.append(lastObs) # Store initial observation - - # Initialize previous time for dt calculation - prev_time = perf_counter() # Start time for dt calculation - - # Phase 1: Reach - # Terminate when distance to above-object position < error_threshold - print(f"----------------------------------------------- Phase 1: Reach -----------------------------------------------") - while np.linalg.norm(error) >= error_threshold and time_step <= env.spec.max_episode_steps: - env.render() # Visual feedback - - # Record current time and compute dt - curr_time = perf_counter() - dt = curr_time - prev_time - prev_time = curr_time - - # Initialize action vector [x, y, z] - action = np.array([0., 0., 0.]) - - # Proportional control with gain of 6 - action[:3] = error * Kp + derror * Kv - prev_error = error.copy() # Update previous error for next iteration - - # Clip action to prevent excessive movements - action = np.clip(action/ACTION_MAX, -0.1, 0.1) # - - # Keep gripper closed -- no need. only 3 dimensions of control - #action[ len(action)-1 ] = -finger_delta_fast # Maintain gripper closed. 
- - # Unpack new Gymnasium step API - new_obs, reward, terminated, truncated, info = env.step(action) - done = terminated or truncated - - # Store episode data - store_transition_data(episode_data, new_obs, reward, action, info, terminated, truncated, done) - - # Update and print state information - object_pos,gripper_pos,cur_pos,cur_vel = update_state_info(episode_data, time_step, dt, error) - - # Update error for next iteration - error, derror = compute_error(object_pos, cur_pos, prev_error, dt) - - # Update time step - time_step += 1 - - # Sleep - if DEBUG: - sleep(0.25) # Activated when DEBUG is True for better visualization. - - # Store complete episode data in global lists only if we succeeded (avoid bad demos) - store_episode_data(episode_data) - - # Deactivate weld constraint after successful pick -- franka_sim env does not have weld like franka_mujoco env. - # if weld_flag: - # deactivate_weld(env, constraint_name="grasp_weld") - - # Break out of the loop to start a new episode - return True - - # # If we reach here, the episode was not successful - # if weld_flag: - # print("Failed to transport object to goal position. Deactivating weld.") - # deactivate_weld(env, constraint_name="grasp_weld") - -def main(): - """ - Orchestrates the data generation process by running multiple episodes - of the task. - - Creates environment, runs scripted episodes, and saves demonstration data - to compressed NPZ. - - Arguments that can be configured with flags: - - env - - render - - demo_ctr - - """ - # Initialize the Panda environment. - env = gym.make("PandaReachCube-v0", render_mode=_render_mode) - env = gym.wrappers.FlattenObservation(env) - - # Adjust physical settings - # env.model.opt.time_step = 0.001 # Smaller time_step for more accurate physics. Default is 0.002. - # env.model.opt.iterations = 100 # More solver iterations for better contact resolution. Default is 50. 
- - # Configuration parameters - initStateSpace = "random" # Initial state space configuration - - # Demos configs - num_demos = NUM_DEMOS # Number of demonstration episodes to generate--ADJUST THIS VALUE FOR MORE OR LESS DEMOS** - - demo_ctr = 0 # Counter for successful demonstration episodes - - # Reset environment to initial state - render for the first time. - obs, _ = env.reset() # For reach environment expect 10 observations: r_pos, r_vel, finger, object_pos. - - # Adjust camera view for better visualization - viewer = set_front_cam_view(env) - - print("Reset!") - - # Generate demonstration episodes - while len(actions) < num_demos: - obs,_ = env.reset() # Reset environment for new episode - - print(f"We will run a total of: {num_demos} demos!!") - print("Demo: #", len(actions)+1) - - # Execute pick-and-place task - res = demo(env, obs) - - # Print success message - if res: - demo_ctr += 1 - print("Episode completed successfully!") - print(f"Total successful demos: {demo_ctr}/{num_demos}") - - # Close the environment after all episodes are done - env.close() - - ## Write data to demos folder. Assumes mounted /data folder and internal data folder. 
- script_dir = '/data/data/serl/demos' - - # Create output filename with configuration details - fileName = "data_" + robot + "_" + task - fileName += "_" + initStateSpace - fileName += "_" + str(num_demos) - fileName += ".npz" - - # Build a filename in that same directory - out_path = os.path.join(script_dir, fileName) - - # Ensure the directory exists - os.makedirs(script_dir, exist_ok=True) - - # Save collected data to compressed numpy NPZ file - # Set acs,obs,info as keys in dict - np.savez_compressed(out_path, - acs = actions, - obs = observations, - rewards = rewards, - info = infos, - terminateds = terminateds, - truncateds = truncateds, - dones = dones, - transition_ctr = transition_ctr, - num_demos = num_demos - ) - - print(f"Data saved to {fileName}.") - print(f"Total successful demos: {demo_ctr}/{num_demos}") - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/demos/load_demo_test.py b/demos/load_demo_test.py deleted file mode 100644 index 8fd33aad..00000000 --- a/demos/load_demo_test.py +++ /dev/null @@ -1,127 +0,0 @@ -# Updated and advanced train.py that includes logging, vectorized environments, and periodic recorded evaluations -import os -from pathlib import Path -import numpy as np - -def load_demos_to_her_buffer_gymnasium(data_store, demo_npz_path: str, combine_done: bool = True): - """ - Load a raw Gymnasium-style .npz of expert episodes into model.replay_buffer. - - demo_npz_path must contain at least these arrays: - - 'episodeObs' : shape (T+1, *obs_shape*), list of observations - - 'episodeAcs' : shape (T, *act_shape*), list of actions - - 'episodeRews' : shape (T,), list of rewards - - 'episodeTerminated' : shape (T,), list of terminated flags - - 'episodeTruncated' : shape (T,), list of truncated flags - - 'episodeInfo' : shape (T,), list of info dicts - - Parameters - ---------- - data_store : DataStore - The data store to which the demo transitions will be added. 
- demo_npz_path : str - Path to the .npz file you saved from your demo collector. - combine_done : bool, default=True - If True, `done = terminated or truncated`. If False, `done = terminated` only. - """ - - # Load all demo data. Structure: var_name[num_demo][time_step][key if dict] = value - data = np.load(demo_npz_path, allow_pickle=True) - - obs_buffer = data['obs'] # length T+1 - act_buffer = data['acs'] # length T - rew_buffer = data['rewards'] # length T - term_buffer = data['terminateds'] # length T - trunc_buffer = data['truncateds'] # length T - info_buffer = data['info'] # length T - done_buffer = data['dones'] # length T, if available - - # Extract number of demonstrations - num_demos = obs_buffer.shape[0] - - # Extract rollout data for a single episode - for ep in range(num_demos): - ep_obs = obs_buffer[ep] # this is a length‐(T+1) array of dicts - ep_acts = act_buffer[ep] # length‐T array of actions - ep_rews = rew_buffer[ep] - ep_terms = term_buffer[ep] - ep_trunc = trunc_buffer[ep] - ep_done = done_buffer[ep] - ep_info = info_buffer[ep] # length‐T array of dicts - - # Length of episode: - T = len(ep_acts) - - # Extract single transitions from the episode data - for t in range(T): - # raw single‐step data: - obs_t = ep_obs[t] # dict[str, np.ndarray] (obs_dim,) - next_obs_t = ep_obs[t+1] - a_t = ep_acts[t] # np.ndarray (action_dim,) - r_t = float(ep_rews[t]) - done_t = bool(ep_done[t] or ep_terms[t] or ep_trunc[t]) - - # Rehydrate info dict and inject the timeout flag - raw_info = ep_info[t] # dict[str,Any] - if isinstance(raw_info, str): - import ast - info_t = ast.literal_eval(raw_info) - else: - info_t = raw_info.copy() - # Append truncated information to info_t - info_t["TimeLimit.truncated"] = bool(ep_trunc[t]) - - # Enter transition into data_store or QueuedDataStore - data_store.insert( - dict( - observations=obs_t, - actions=a_t, - next_observations=next_obs_t, - rewards=r_t, - masks=1.0 - done_t, - dones=done_t - ) - ) - - print(f"Can load 
{num_demos} transitions successfullly from {demo_npz_path}." - f"(combine_done={combine_done}).") - -def get_demo_path(relative_path: str) -> str: - """ - Given a path relative to this script file, return - the absolute, normalized path as a string. - - Example: - # If your demos live at ../../../demos/data.npz - demo_file = get_demo_path("../../../demos/data_franka_random_10.npz") - """ - # 1) Resolve this script’s directory - script_dir = Path(__file__).resolve().parent - - # 2) Join with the user-supplied relative path and normalize - full_path = (script_dir / relative_path).resolve() - - return str(full_path) - - -def main(): - """ - Load demos - """ - - # Get abs path to demo file - script_dir = '/data/data/serl/demos' - default_file = 'data_franka_reach_random_20.npz' - - prompt = f"Please input the name of the file to load [{default_file}]: " - - file_name = input(prompt) or default_file - demo_file = os.path.join(script_dir, file_name) - - # Load the demo file into data_store as in async_sac_state.py - from agentlace.data.data_store import QueuedDataStore - data_store = QueuedDataStore(2000) - load_demos_to_her_buffer_gymnasium(data_store, demo_file, combine_done=True) - -if __name__ == "__main__": - main() \ No newline at end of file From 3af4401287520b9fff232c61ba1f715509700da9 Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Thu, 24 Jul 2025 22:59:22 -0500 Subject: [PATCH 067/141] Bug fix with dict key string. Use single quotes in f-string. 
--- examples/async_sac_state_sim/async_sac_state_sim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/async_sac_state_sim/async_sac_state_sim.py b/examples/async_sac_state_sim/async_sac_state_sim.py index a9ecb1cf..7dd9d2a6 100644 --- a/examples/async_sac_state_sim/async_sac_state_sim.py +++ b/examples/async_sac_state_sim/async_sac_state_sim.py @@ -129,7 +129,7 @@ def update_params(params): with timer.context("sample and step into env with loaded demos"): # Insert complete demonstration into the data store - print(f"Inserting {demos_handler.data["transition_ctr"]} transitions into the data store.") + print(f"Inserting {demos_handler.data['transition_ctr']} transitions into the data store.") demos_handler.insert_data_to_buffer(data_store) else: From a18e1d526a70652ea65215508dfaade1bf71ffff Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Thu, 24 Jul 2025 22:59:56 -0500 Subject: [PATCH 068/141] missed including child demo folder --- demos/demos/__init__.py | 0 demos/demos/demoHandling.py | 121 +++++ demos/demos/franka_pick_place_demo_script.py | 407 +++++++++++++++++ demos/demos/franka_reach_demo_script.py | 442 +++++++++++++++++++ demos/demos/load_demo_test.py | 127 ++++++ demos/setup.py | 7 + 6 files changed, 1104 insertions(+) create mode 100644 demos/demos/__init__.py create mode 100644 demos/demos/demoHandling.py create mode 100755 demos/demos/franka_pick_place_demo_script.py create mode 100755 demos/demos/franka_reach_demo_script.py create mode 100644 demos/demos/load_demo_test.py create mode 100644 demos/setup.py diff --git a/demos/demos/__init__.py b/demos/demos/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/demos/demos/demoHandling.py b/demos/demos/demoHandling.py new file mode 100644 index 00000000..90c901c1 --- /dev/null +++ b/demos/demos/demoHandling.py @@ -0,0 +1,121 @@ +import os +from pathlib import Path +import numpy as np +from agentlace.data.data_store import QueuedDataStore + +class DemoHandling: 
+ """ + Koads an .npz file containing demonstration data into a data object. + This class is designed to work with Gymnasium-style demonstration data + and is intended to be used with a QueuedDataStore or similar data store. + + The .npz file should contain the following arrays: + - 'obs' : shape (N, T+1, *obs_shape*), list of observations + - 'acs' : shape (N, T, *act_shape*), list of actions + - 'rewards' : shape (N, T), list of rewards + - 'terminateds' : shape (N, T), list of terminated flags + - 'truncateds' : shape (N, T), list of truncated flags + - 'info' : shape (N, T), list of info dicts + - 'dones' : shape (N, T), list of done flags (if available) + + Parameters + ---------- + demo_dir : str + Directory where demo .npz files live by default. + file_name : str + Name of the demo file to load. If not provided, a default will be used. + """ + def __init__( + self, + demo_dir: str = '/data/data/serl/demos', + file_name: str = 'data_franka_reach_random_20.npz' + ): + + self.debug = False # Set to True for debugging purposes + self.demo_dir = demo_dir + self.transition_ctr = 0 # Global counter for transitions across all episodes + + # Load the demo data from the .npz file + + # Check if the demo directory exists + if not os.path.exists(self.demo_dir): + raise FileNotFoundError(f"Demo directory '{self.demo_dir}' does not exist.") + + # Construct the full path to the demo file + self.demo_npz_path = os.path.join(self.demo_dir, file_name) + if not os.path.isfile(self.demo_npz_path): + raise FileNotFoundError(f"Demo file '{self.demo_npz_path}' does not exist.") + + # Load the .npz file + self.data = np.load(self.demo_npz_path, allow_pickle=True) + + def get_num_transitions(self): + """ + Returns the total number of transitions counted in the demo data. 
+ """ + return self.data["transition_ctr"] if "transition_ctr" in self.data else 0 + + def insert_data_to_buffer(self,data_store: QueuedDataStore): + """ + Load a raw Gymnasium-style .npz of expert episodes into data_store. + The .npz file must contain arrays named 'obs', 'acs', 'rewards', + 'terminateds', 'truncateds', 'info', and optionally 'dones'. + """ + + obs_buffer = self.data['obs'] # shape (N, T+1, ...) + act_buffer = self.data['acs'] # shape (N, T, ...) + rew_buffer = self.data['rewards'] # shape (N, T) + term_buffer = self.data['terminateds'] # shape (N, T) + trunc_buffer = self.data['truncateds'] # shape (N, T) + info_buffer = self.data['info'] # shape (N, T) + done_buffer = self.data.get('dones', term_buffer | trunc_buffer) + + num_demos = obs_buffer.shape[0] + + for ep in range(num_demos): + ep_obs = obs_buffer[ep] + ep_acts = act_buffer[ep] + ep_rews = rew_buffer[ep] + ep_terms = term_buffer[ep] + ep_trunc = trunc_buffer[ep] + ep_done = done_buffer[ep] + ep_info = info_buffer[ep] + + T = len(ep_acts) + for t in range(T): + obs_t = ep_obs[t] + next_obs_t = ep_obs[t+1] + a_t = ep_acts[t] + r_t = float(ep_rews[t]) + done_t = bool(ep_done[t] or ep_terms[t] or ep_trunc[t]) + # masks will be created right before insert below + + if self.debug: + print(f"Demo {ep}, Step {t}: Obs={obs_t}, Action={a_t}, Reward={r_t}, Done={done_t}") + + data_store.insert( + dict( + observations =obs_t, + actions =a_t, + next_observations=next_obs_t, + rewards =r_t, + masks =1.0 - done_t, + dones =done_t + ) + ) + + print(f"Loaded {num_demos} episodes from '{self.demo_npz_path}' ") + + +# if __name__ == "__main__": +# # create your datastore; here we use a QueuedDataStore with capacity 2000 +# ds = QueuedDataStore(2000) +# handler = DemoHandling(ds, +# demo_dir='/data/data/serl/demos', +# file_name='data_franka_reach_random_20.npz') + +# # Idenitfy the total number of transitions in the datastore +# print(f'We have {handler.data["transition_ctr"]} transitions in the 
datastore.') + +# # Load the demo data into the data_store +# handler.insert_data_to_buffer() diff --git a/demos/demos/franka_pick_place_demo_script.py b/demos/demos/franka_pick_place_demo_script.py new file mode 100755 index 00000000..e5abbc11 --- /dev/null +++ b/demos/demos/franka_pick_place_demo_script.py @@ -0,0 +1,407 @@ +""" +Scripted Controller for Franka FR3 Robot - Demonstration Data Generation + +This script implements a scripted controller for generating expert demonstration data +for Deep Reinforcement Learning (DRL) algorithms using the Franka FR3 robot in a +pick-and-place task. The generated data serves as bootstrapping demonstrations for +training RL agents with stable-baselines3. + +The controller uses a 4-phase hierarchical approach: +1. Approach Object (move gripper above object) +2. Grasp Object (move to object and close gripper) +3. Transport to Goal (move grasped object to target) +4. Maintain Position (hold final position) + +Output: Compressed NPZ file containing action, observation, and info sequences +""" +import os +import numpy as np +import gymnasium as gym +from gymnasium.wrappers import TimeLimit +from time import sleep +from panda_mujoco_gym.envs import FrankaPickAndPlaceEnv + +# Global variables to store episode data across all iterations +observations = [] # List storing observation sequences for each episode +actions = [] # List storing action sequences for each episode +rewards = [] # List storing reward sequences for each episode +infos = [] # List storing info dictionaries for each episode +terminateds = [] # List storing terminated flags for each episode +truncateds = [] # List storing truncated flags for each episode +dones = [] # List storing done flags (terminated or truncated) for each episode + +# Robot configuration +robot = 'franka' # Robot type used in the environment, can be 'franka' or 'fetch' + +# Weld constraint flag +weld_flag = True # Flag to activate weld constraint during pick-and-place + + +def 
activate_weld(env, constraint_name="grasp_weld"): + """ + Activate a weld constraint during pick portion of a demo + :param env: The environment containing the model + :param constraint_name: The name of the weld constraint to activate + :return: True if the weld was successfully activated, False if the constraint was not found + """ + + try: + # Activate the weld constraint + env.unwrapped.model.eq(constraint_name).active = 1 + print("Activated weld") + return True + + except KeyError: + print(f"Warning: Constraint '{constraint_name}' not found") + return False + +def deactivate_weld(env, constraint_name="grasp_weld"): + """ + Deactivate a weld constraint during place portion of a demo + :param env: The environment containing the model + :param constraint_name: The name of the weld constraint to deactivate + :return: True if the weld was successfully deactivated, False if the constraint was not + found + """ + + try: + # Deactivate the weld constraint + env.unwrapped.model.eq(constraint_name).active = 0 + print("Deactivated weld") + return True + + except KeyError: + print(f"Warning: Constraint '{constraint_name}' not found") + return False + +def main(): + """ + Orchestrates the data generation process by running multiple episodes + of the pick-and-place task. + + Creates environment, runs scripted episodes, and saves demonstration data + to compressed NPZ file for use with stable-baselines3. + """ + # Initialize Fetch pick-and-place environment + env = FrankaPickAndPlaceEnv(reward_type="sparse", render_mode="rgb_array") + env = TimeLimit(env, max_episode_steps=50) + + # Adjust physical settings + # env.model.opt.timestep = 0.001 # Smaller timestep for more accurate physics. Default is 0.002. + # env.model.opt.iterations = 100 # More solver iterations for better contact resolution. Default is 50. 
+ + # Configuration parameters + initStateSpace = "random" # Initial state space configuration + + # Demos configs + attempted_demos = 1 # Number of demonstration episodes to generate--ADJUST THIS VALUE FOR MORE OR LESS DEMOS** + + num_demos = 0 # Counter for successful demonstration episodes + + # Reset environment to initial state - render for the first time. + obs, _ = env.reset() + print("Reset!") + + # Generate demonstration episodes + while len(actions) < attempted_demos: + obs,_ = env.reset() # Reset environment for new episode + print(f"We will run a total of: {attempted_demos} demos!!") + print("Demo: #", len(actions)+1) + + # Execute pick-and-place task + res = pick_and_place_demo(env, obs) + + # Print success message + if res: + num_demos += 1 + print("Episode completed successfully!") + print(f"Total successful demos: {num_demos}/{attempted_demos}") + + ## Write data to demos folder + # 1. Get the absolute path of this script + #script_path = os.path.abspath(__file__) + + # 2. Extract its directory + #script_dir = os.path.dirname(script_path) + script_dir = '/home/student/data/franka_baselines/demos/pick_n_place' # Assumes data folder in user directory. + + # 3. Create output filename with configuration details + fileName = "data_" + robot + fileName += "_" + initStateSpace + fileName += "_" + str(attempted_demos) + fileName += ".npz" + + # 3. Build a filename in that same directory + out_path = os.path.join(script_dir, fileName) + + # Save collected data to compressed numpy NPZ file + # Set acs,obs,info as keys in dict + np.savez_compressed(out_path, + acs = actions, + obs = observations, + rewards = rewards, + info = infos, + terminateds = terminateds, + truncateds = truncateds, + dones = dones) + + print(f"Data saved to {fileName}.") + +def pick_and_place_demo(env, lastObs): + """ + Executes a scripted pick-and-place sequence using a hierarchical approach. + + Implements 4-phase control strategy: + 1. 
Approach: Move gripper above object (3cm offset) + 2. Grasp: Move to object and close gripper + 3. Transport: Move grasped object to goal position + 4. Maintain: Hold position until episode ends + + Store observations, actions, and info in global lists for later replay buffer inclusion. + + Args: + env: Gymnasium environment instance + lastObs: Last observation containing goal and object state information + - desired_goal: Target position for object placement + - observations: + ee_position[0:3], + ee_velocity[3:6], + fingers_width[6], + object_position[7:10], + object_rotation[10:13], + object_velp[13:16], + object_velr[16:19], + """ + + ## Init goal, current_pos, and object position from last observation + goal = np.zeros(3, dtype=np.float32) + current_pos = np.zeros(3, dtype=np.float32) + object_pos = np.zeros(3, dtype=np.float32) + object_rel_pos = np.zeros(3, dtype=np.float32) + fgr_pos = np.zeros(1, dtype=np.float32) + + # Initialize episode data collection + episodeObs = [] # Observations for this episode + episodeAcs = [] # Actions for this episode + episodeRews = [] # Rewards for this episode + episodeInfo = [] # Info for this episode + episodeTerminated = [] # Terminated flags for this episode + episodeTruncated = [] # Truncated flags for this episode + episodeDones = [] # Done flags (terminated or truncated) for this episode + + # Proportional control gain for action scaling -- empirically tuned + Kp = 8.0 + + # pre_pick_offset + pre_pick_offset = np.array([0,0,0.03], dtype=float) # Offset to approach object safely (3cm) + + # Error thresholds + error_threshold = 0.011 # Threshold for stopping condition (Xmm) + + finger_delta_fast = 0.05 # Action delta for fingers 5cm per step (will get clipped by controller)... more of a scalar. 
+ finger_delta_slow = 0.005 # Franka has a range from 0 to 4cm per finger + + ## Extract data + # Extract desired position from desired_goal dict + goal = lastObs["desired_goal"][0:3] + + # Current robot end-effector position from observation dict + current_pos = lastObs["observation"][0:3] + + # Current object position from observation dict: + object_pos = lastObs["observation"][7:10] + + # Relative position between end-effector and object + object_rel_pos = object_pos - current_pos + + ## Phase 1: Approach Object (Above) + # Create target position 3cm above the object. Use copy() method. + error = object_rel_pos.copy() + error+=pre_pick_offset # Move 3cm above object for safe approach. Fingers should still end up surrounding object. + + timeStep = 0 # Track total timesteps in episode + episodeObs.append(lastObs) + + # Phase 1: Move gripper to position above object + # Terminate when distance to above-object position < 5mm + print(f"----------------------------------------------- Phase 1: Approach Object -----------------------------------------------") + while np.linalg.norm(error) >= error_threshold and timeStep <= env._max_episode_steps: + env.render() # Visual feedback + + # Initialize action vector [x, y, z, gripper] + action = np.array([0., 0., 0., 0.]) + + # Proportional control with gain of 6 + # action = Kp * error + action[:3] = error * Kp + + # Open gripper for approach + action[ len(action)-1 ] = 0.05 + + # Unpack new Gymnasium step API + new_obs, reward, terminated, truncated, info = env.step(action) + done = terminated or truncated + timeStep += 1 + + # Store episode data + episodeAcs.append(action) + episodeInfo.append(info) + episodeRews.append(reward) + episodeObs.append(new_obs) + episodeTerminated.append(terminated) + episodeTruncated.append(truncated) + episodeDones.append(done) + + # Update state information + fgr_pos = new_obs["observation"][6] + current_pos = new_obs["observation"][0:3] + object_pos = new_obs['observation'][7:10] + error = 
(object_pos+pre_pick_offset) - current_pos # Error with regard to offset position + + # Print debug information + print( + f"Time Step: {timeStep}, Error: {np.linalg.norm(error):.4f}, " + f"Eff_pos: {np.array2string(current_pos, precision=3)}, " + f"obj_pos: {np.array2string(object_pos, precision=3)}, " + f"fgr_pos: {np.array2string(fgr_pos, precision=2)}, " + f"Error: {np.array2string(error, precision=3)}, " + f"Action: {np.array2string(action, precision=3)}" + ) + + # Phase 2: Descend Grasp Object + # Move gripper directly to object and close gripper + # Terminate when relative distance to object < 5mm + print(f"----------------------------------------------- Phase 2: Grip -----------------------------------------------") + error = object_pos - current_pos # remove offset + while (np.linalg.norm(error) >= error_threshold or fgr_pos>=0.39) and timeStep <= env._max_episode_steps: # Cube of width 4cm, each finger open to 2cm + env.render() + + # Initialize action vector [x, y, z, gripper] + action = np.array([0., 0., 0., 0.]) + + # Direct proportional control to object position + action[:3] = error * Kp + + # Close gripper to grasp object + action[len(action)-1] = -finger_delta_fast * 2 + + # Execute action and collect data + new_obs, reward, terminated, truncated, info = env.step(action) + done = terminated or truncated + timeStep += 1 + + # Store episode data + episodeObs.append(new_obs) + episodeRews.append(reward) + episodeAcs.append(action) + episodeInfo.append(info) + episodeTerminated.append(terminated) + episodeTruncated.append(truncated) + episodeDones.append(done) + + # Update state information + fgr_pos = new_obs["observation"][6] + current_pos = new_obs["observation"][0:3] + object_pos = new_obs['observation'][7:10] + error = object_pos - current_pos #- np.array([0.,0.,0.01]) # Grab lower + + # Print debug information + print( + f"Time Step: {timeStep}, Error: {np.linalg.norm(error):.4f}, " + f"Eff_pos: {np.array2string(current_pos, precision=3)}, " + 
f"obj_pos: {np.array2string(object_pos, precision=3)}, " + f"fgr_pos: {np.array2string(fgr_pos, precision=3)}, " + f"Error: {np.array2string(error, precision=3)}, " + f"Action: {np.array2string(action, precision=3)}" + ) + #sleep(0.5) # Optional: Slow down for better visualization + + # Phase 3: Transport to Goal + # Move grasped object to desired goal position + # Terminate when distance between object and goal < 1cm + print(f"----------------------------------------------- Phase 3: Transport to Goal -----------------------------------------------") + + # Weld activation + if weld_flag: + activate_weld(env, constraint_name="grasp_weld") + + # Set error between goal and hand assuming the object is grasped + gh_error = goal - current_pos # Error between goal and hand position + ho_error = object_pos - current_pos # Error between object and hand position + while np.linalg.norm(gh_error) >= 0.01 and timeStep <= env._max_episode_steps: + env.render() + + action = np.array([0., 0., 0., 0.]) + + # Proportional control toward goal position + action[:3] = gh_error[:3] * Kp + + # Maintain grip on object + #action[len(action)-1] = 0 + + # Execute action and collect data + new_obs, reward, terminated, truncated, info = env.step(action) + done = terminated or truncated + timeStep += 1 + + # Store episode data + episodeObs.append(new_obs) + episodeRews.append(reward) + episodeAcs.append(action) + episodeInfo.append(info) + episodeTerminated.append(terminated) + episodeTruncated.append(truncated) + episodeDones.append(done) + + # Update state information + fgr_pos = new_obs["observation"][6] + current_pos = new_obs["observation"][0:3] + object_pos = new_obs['observation'][7:10] + gh_error = goal - current_pos # Error between goal and hand position + ho_error = object_pos - current_pos # Error between object and hand position + + # Print debug information + print( + f"Time Step: {timeStep}, Error Norm: {np.linalg.norm(gh_error):.4f}, " + f"Eff_pos: {np.array2string(current_pos, 
precision=3)}, " + f"goal_pos: {np.array2string(goal, precision=3)}, " + f"fgr_pos: {np.array2string(fgr_pos, precision=2)}, " + f"Error: {np.array2string(gh_error, precision=3)}, " + f"Action: {np.array2string(action, precision=3)}" + ) + + sleep(0.5) # Optional: Slow down for better visualization + + ## Check for success and store episode data + gh_norm = np.linalg.norm(gh_error) + ho_nomr = np.linalg.norm(ho_error) + if gh_norm < error_threshold and ho_nomr < error_threshold: + + # Store complete episode data in global lists only if we succeeded (avoid bad demos) + actions.append(episodeAcs) + observations.append(episodeObs) + infos.append(episodeInfo) + rewards.append(episodeRews) + + # Optionally, also store the done/terminated/truncated flags globally if needed: + terminateds.append(episodeTerminated) + truncateds.append(episodeTruncated) + dones.append(episodeDones) + + # Deactivate weld constraint after successful pick + if weld_flag: + deactivate_weld(env, constraint_name="grasp_weld") + + # Close mujoco viewer + env.close() + + # Break out of the loop to start a new episode + return True + + # If we reach here, the episode was not successful + if weld_flag: + print("Failed to transport object to goal position. Deactivating weld.") + deactivate_weld(env, constraint_name="grasp_weld") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/demos/demos/franka_reach_demo_script.py b/demos/demos/franka_reach_demo_script.py new file mode 100755 index 00000000..43952329 --- /dev/null +++ b/demos/demos/franka_reach_demo_script.py @@ -0,0 +1,442 @@ +""" +Scripted Controller for Franka FR3 Robot - Demonstration Data Generation + +This script implements a scripted controller for generating expert demonstration data +for Deep Reinforcement Learning (DRL) algorithms using the Franka FR3 robot in a +pick-and-place task. The generated data serves as bootstrapping demonstrations for +training RL agents with stable-baselines3. 
+ +The controller uses a 4-phase hierarchical approach: +1. Approach Object (move gripper above object) +2. Grasp Object (move to object and close gripper) +3. Transport to Goal (move grasped object to target) +4. Maintain Position (hold final position) + +Output: Compressed NPZ file containing action, observation, and info sequences + +TODO: Convert to a class-based structure for better modularity and reusability. +""" +import os +import numpy as np +import gym +from time import sleep, perf_counter + +import franka_sim +import franka_sim.envs.panda_reach_gym_env as panda_reach_env + +# Global variables to store episode data across all iterations +observations = [] # List storing observation sequences for each episode +actions = [] # List storing action sequences for each episode +rewards = [] # List storing reward sequences for each episode +infos = [] # List storing info dictionaries for each episode +terminateds = [] # List storing terminated flags for each episode +truncateds = [] # List storing truncated flags for each episode +dones = [] # List storing done flags (terminated or truncated) for each episode +transition_ctr = 0 # Global counter for transitions across all episodes + +#------------------------------------------------------------------------------------------- +## Key Config Variables +#------------------------------------------------------------------------------------------- +# Proportional and derivative control gain for action scaling -- empirically tuned +Kp = 10.0 # Values between 20 and 24 seem to be somewhat stable for Kv = 24 +Kv = 10.0 + +ACTION_MAX = 10 # Maximum action value for clipping actions +ERROR_THRESHOLD = 0.01 # Note!! When this number is changed, the way rewards are computed in the PandaReachCubeEnv.step() L220 must also be changed such that done=True only at the end of a successfull run. 
+ +# Number of demonstration episodes to generate +NUM_DEMOS = 20 + +# Robot configuration +robot = 'franka' # Robot type used in the environment, can be 'franka' or 'fetch' +task = 'reach' # Task type used in the environment, can be 'reach' or 'pick-and-place' + +# Debug mode for rendering and visualization +DEBUG = False + +if DEBUG: + _render_mode = 'human' # Render mode for the environment, can be 'human' or 'rgb_array' +else: + _render_mode = 'rgb_array' # Use 'rgb_array' for automated testing without GUI + +# Indices for franka_sim reach environment observations +if robot == 'franka' and task == 'reach': + + opi = np.array([0, 3]) # Indices for object position in observation + gpi = np.array([3]) # Indices for gripper position in observation + rpi = np.array([4, 7]) # Indices for robot position in observation + rvi = np.array([7, 10]) # Indices for robot velocity in observation + +# Weld constraint flag +weld_flag = True # Flag to activate weld constraint during pick-and-place + +#------------------------------------------------------------------------------------------- +# Franka sim environments do not have weld constraints like the franka_mujoco environments. 
+# def activate_weld(env, constraint_name="grasp_weld"): +# """ +# Activate a weld constraint during pick portion of a demo +# :param env: The environment containing the model +# :param constraint_name: The name of the weld constraint to activate +# :return: True if the weld was successfully activated, False if the constraint was not found +# """ + +# try: +# # Activate the weld constraint +# env.unwrapped.model.eq(constraint_name).active = 1 +# print("Activated weld") +# return True + +# except KeyError: +# print(f"Warning: Constraint '{constraint_name}' not found") +# return False + +# def deactivate_weld(env, constraint_name="grasp_weld"): +# """ +# Deactivate a weld constraint during place portion of a demo +# :param env: The environment containing the model +# :param constraint_name: The name of the weld constraint to deactivate +# :return: True if the weld was successfully deactivated, False if the constraint was not +# found +# """ + +# try: +# # Deactivate the weld constraint +# env.unwrapped.model.eq(constraint_name).active = 0 +# print("Deactivated weld") +# return True + +# except KeyError: +# print(f"Warning: Constraint '{constraint_name}' not found") +# return False + +def set_front_cam_view(env): + """ + Set the camera view to a front-facing perspective for better visualization. + + Args: + env: The environment instance containing the viewer. + + Returns: + viewer: The viewer with updated camera settings. 
+ """ + viewer = env.unwrapped._viewer.viewer # Access the viewer from the environment + + if hasattr(viewer, 'cam'): + viewer.cam.lookat[:] = [0, 0, 0.1] # Center of robot (adjust as needed) + viewer.cam.distance = 3.0 # Camera distance + viewer.cam.azimuth = 135 # 0 = right, 90 = front, 180 = left + viewer.cam.elevation = -30 # Negative = above, positive = below + + # Hide menu + viewer._hide_overlay = True + + return viewer + +def store_transition_data(episode_dict, new_obs, rewards, action, info, terminated, truncated, done): + """ + Store transition data in the episode dictionary and update global counter. + """ + global transition_ctr + transition_ctr += 1 + + episode_dict["observations"].append(new_obs) + episode_dict["rewards"].append(rewards) + episode_dict["actions"].append(action) + episode_dict["infos"].append(info) + episode_dict["terminateds"].append(terminated) + episode_dict["truncateds"].append(truncated) + episode_dict["dones"].append(done) + +def store_episode_data(episode_data): + """ + Store complete episode data in global lists only if we succeeded (avoid bad demos). + """ + actions.append(episode_data["actions"]) + observations.append(episode_data["observations"]) + infos.append(episode_data["infos"]) + rewards.append(episode_data["rewards"]) + + # Optionally, also store the done/terminated/truncated flags globally if needed: + terminateds.append(episode_data["terminateds"]) + truncateds.append(episode_data["truncateds"]) + dones.append(episode_data["dones"]) + +def update_state_info(episode_data, time_step, dt, error): + """ + Update and return the current state information. Always get the latest entry with [-1] + + Args: + new_obs (dict): New observation dictionary containing the current state. + time_step (int): Current time step in the episode. + dt (float): Current time step in the episode. + error (np.ndarray): Current error vector between object and end-effector positions. 
+ + Returns: + object_pos (np.ndarray): Current position of the object in the environment. + gripper_pos (float): Current position of the gripper. + current_pos (np.ndarray): Current position of the end-effector. + current_vel (np.ndarray): Current velocity of the end-effector. + """ + object_pos = episode_data["observations"][-1][ opi[0]:opi[1] ] # Block position + gripper_pos = episode_data["observations"][-1][ gpi[0] ] # Gripper position + current_pos = episode_data["observations"][-1][ rpi[0]:rpi[1] ] # Panda/tcp position + current_vel = episode_data["observations"][-1][ rvi[0]:rvi[1] ] # Panda/t + + + # Print debug information + print( + f"Time Step: {time_step}, Error: {np.linalg.norm(error):.4f}, " + f"bot_pos: {np.array2string(current_pos, precision=3)}, " + f"obj_pos: {np.array2string(object_pos, precision=3)}, " + f"fgr_pos: {np.array2string(gripper_pos, precision=2)}, " + f"err: {np.array2string(error, precision=3)}, " + f"Action: {np.array2string(episode_data['actions'][-1], precision=3)}, " + f"dt: {dt: .4f}" + ) + + + return object_pos, gripper_pos, current_pos, current_vel + +def compute_error(object_pos, current_pos, prev_error, dt): + """Compute the error and its derivative between the object position and the current end-effector position. + Args: + object_pos (np.ndarray): The position of the object in the environment. + current_pos (np.ndarray): The current position of the end-effector. + prev_error (np.ndarray): The previous error value for derivative calculation. + dt (float): Time step for derivative calculation. + + Returns: + error (np.ndarray): The current error vector between the object and end-effector positions. + derror (np.ndarray): The derivative of the error vector. 
+ """ + error = object_pos - current_pos # Calculate the error vector + derror = (error - prev_error) / dt # Calculate the derivative of the error vector + + prev_error = error.copy() # Update previous error for next iteration + return error, derror + +def demo(env, lastObs): + """ + Executes a scripted reach sequence using a hierarchical approach. + + Implements 1-phase control strategy: + 1. Approach: Move gripper above object (3cm offset) + + Store observations, actions, and info in global lists for later replay buffer inclusion. + + Gripper: + - The gripper in Mujoco ranges from a value of 0 to 0.4, where 0 is fully open and 0.4 is fully closed. + + Args: + env: Gymnasium environment instance + lastObs: Flattened observations set as object_pos, gripper_pos, panda/tcp_pos, panda/tcp_vel + - observations: + object_pos[0:3], + gripper_pos[3] + panda/tcp_pos[4:7], + panda/tcp_vel[7:10] + + Returns: + """ + + ## Init goal, current_pos, and object position from last observation + object_pos = np.zeros(3, dtype=np.float32) + current_pos = np.zeros(3, dtype=np.float32) + gripper_pos = np.zeros(1, dtype=np.float32) + object_rel_pos = np.zeros(3, dtype=np.float32) + + + # Initialize (single) episode data collection + episodeObs = [] # Observations for this episode + episodeAcs = [] # Actions for this episode + episodeRews = [] # Rewards for this episode + episodeInfo = [] # Info for this episode + episodeTerminated = [] # Terminated flags for this episode + episodeTruncated = [] # Truncated flags for this episode + episodeDones = [] # Done flags (terminated or truncated) for this episode + + # Dictionary to store episode data + episode_data = { + "observations": episodeObs, + "actions": episodeAcs, + "rewards": episodeRews, + "infos": episodeInfo, + "terminateds": episodeTerminated, + "truncateds": episodeTruncated, + "dones": episodeDones + } + + # close gripper + fgr_pos = 0 + + # Error thresholds + error_threshold = ERROR_THRESHOLD # Threshold for stopping condition 
(Xmm) + + finger_delta_fast = 0.05 # Action delta for fingers 5cm per step (will get clipped by controller)... more of a scalar. + finger_delta_slow = 0.005 # Franka has a range from 0 to 4cm per finger + + ## Extract data + object_pos = lastObs[opi[0]:opi[1]] # block pos + current_pos = lastObs[rpi[0]:rpi[1]] # panda/tcp_pos + + # Relative position between end-effector and object + dt = env.unwrapped.model.opt.timestep # Mujoco time step + prev_error = np.zeros_like(object_pos) + error, derror = compute_error(object_pos, current_pos, prev_error, dt) + + time_step = 0 # Track total time_steps in episode + episodeObs.append(lastObs) # Store initial observation + + # Initialize previous time for dt calculation + prev_time = perf_counter() # Start time for dt calculation + + # Phase 1: Reach + # Terminate when distance to above-object position < error_threshold + print(f"----------------------------------------------- Phase 1: Reach -----------------------------------------------") + while np.linalg.norm(error) >= error_threshold and time_step <= env.spec.max_episode_steps: + env.render() # Visual feedback + + # Record current time and compute dt + curr_time = perf_counter() + dt = curr_time - prev_time + prev_time = curr_time + + # Initialize action vector [x, y, z] + action = np.array([0., 0., 0.]) + + # Proportional control with gain of 6 + action[:3] = error * Kp + derror * Kv + prev_error = error.copy() # Update previous error for next iteration + + # Clip action to prevent excessive movements + action = np.clip(action/ACTION_MAX, -0.1, 0.1) # + + # Keep gripper closed -- no need. only 3 dimensions of control + #action[ len(action)-1 ] = -finger_delta_fast # Maintain gripper closed. 
+ + # Unpack new Gymnasium step API + new_obs, reward, terminated, truncated, info = env.step(action) + done = terminated or truncated + + # Store episode data + store_transition_data(episode_data, new_obs, reward, action, info, terminated, truncated, done) + + # Update and print state information + object_pos,gripper_pos,cur_pos,cur_vel = update_state_info(episode_data, time_step, dt, error) + + # Update error for next iteration + error, derror = compute_error(object_pos, cur_pos, prev_error, dt) + + # Update time step + time_step += 1 + + # Sleep + if DEBUG: + sleep(0.25) # Activated when DEBUG is True for better visualization. + + # Store complete episode data in global lists only if we succeeded (avoid bad demos) + store_episode_data(episode_data) + + # Deactivate weld constraint after successful pick -- franka_sim env does not have weld like franka_mujoco env. + # if weld_flag: + # deactivate_weld(env, constraint_name="grasp_weld") + + # Break out of the loop to start a new episode + return True + + # # If we reach here, the episode was not successful + # if weld_flag: + # print("Failed to transport object to goal position. Deactivating weld.") + # deactivate_weld(env, constraint_name="grasp_weld") + +def main(): + """ + Orchestrates the data generation process by running multiple episodes + of the task. + + Creates environment, runs scripted episodes, and saves demonstration data + to compressed NPZ. + + Arguments that can be configured with flags: + - env + - render + - demo_ctr + + """ + # Initialize the Panda environment. + env = gym.make("PandaReachCube-v0", render_mode=_render_mode) + env = gym.wrappers.FlattenObservation(env) + + # Adjust physical settings + # env.model.opt.time_step = 0.001 # Smaller time_step for more accurate physics. Default is 0.002. + # env.model.opt.iterations = 100 # More solver iterations for better contact resolution. Default is 50. 
+ + # Configuration parameters + initStateSpace = "random" # Initial state space configuration + + # Demos configs + num_demos = NUM_DEMOS # Number of demonstration episodes to generate--ADJUST THIS VALUE FOR MORE OR LESS DEMOS** + + demo_ctr = 0 # Counter for successful demonstration episodes + + # Reset environment to initial state - render for the first time. + obs, _ = env.reset() # For reach environment expect 10 observations: r_pos, r_vel, finger, object_pos. + + # Adjust camera view for better visualization + viewer = set_front_cam_view(env) + + print("Reset!") + + # Generate demonstration episodes + while len(actions) < num_demos: + obs,_ = env.reset() # Reset environment for new episode + + print(f"We will run a total of: {num_demos} demos!!") + print("Demo: #", len(actions)+1) + + # Execute pick-and-place task + res = demo(env, obs) + + # Print success message + if res: + demo_ctr += 1 + print("Episode completed successfully!") + print(f"Total successful demos: {demo_ctr}/{num_demos}") + + # Close the environment after all episodes are done + env.close() + + ## Write data to demos folder. Assumes mounted /data folder and internal data folder. 
+ script_dir = '/data/data/serl/demos' + + # Create output filename with configuration details + fileName = "data_" + robot + "_" + task + fileName += "_" + initStateSpace + fileName += "_" + str(num_demos) + fileName += ".npz" + + # Build a filename in that same directory + out_path = os.path.join(script_dir, fileName) + + # Ensure the directory exists + os.makedirs(script_dir, exist_ok=True) + + # Save collected data to compressed numpy NPZ file + # Set acs,obs,info as keys in dict + np.savez_compressed(out_path, + acs = actions, + obs = observations, + rewards = rewards, + info = infos, + terminateds = terminateds, + truncateds = truncateds, + dones = dones, + transition_ctr = transition_ctr, + num_demos = num_demos + ) + + print(f"Data saved to {fileName}.") + print(f"Total successful demos: {demo_ctr}/{num_demos}") + +if __name__ == "__main__": + main() diff --git a/demos/demos/load_demo_test.py b/demos/demos/load_demo_test.py new file mode 100644 index 00000000..8fd33aad --- /dev/null +++ b/demos/demos/load_demo_test.py @@ -0,0 +1,127 @@ +# Updated and advanced train.py that includes logging, vectorized environments, and periodic recorded evaluations +import os +from pathlib import Path +import numpy as np + +def load_demos_to_her_buffer_gymnasium(data_store, demo_npz_path: str, combine_done: bool = True): + """ + Load a raw Gymnasium-style .npz of expert episodes into model.replay_buffer. + + demo_npz_path must contain at least these arrays: + - 'episodeObs' : shape (T+1, *obs_shape*), list of observations + - 'episodeAcs' : shape (T, *act_shape*), list of actions + - 'episodeRews' : shape (T,), list of rewards + - 'episodeTerminated' : shape (T,), list of terminated flags + - 'episodeTruncated' : shape (T,), list of truncated flags + - 'episodeInfo' : shape (T,), list of info dicts + + Parameters + ---------- + data_store : DataStore + The data store to which the demo transitions will be added. 
+ demo_npz_path : str + Path to the .npz file you saved from your demo collector. + combine_done : bool, default=True + If True, `done = terminated or truncated`. If False, `done = terminated` only. + """ + + # Load all demo data. Structure: var_name[num_demo][time_step][key if dict] = value + data = np.load(demo_npz_path, allow_pickle=True) + + obs_buffer = data['obs'] # length T+1 + act_buffer = data['acs'] # length T + rew_buffer = data['rewards'] # length T + term_buffer = data['terminateds'] # length T + trunc_buffer = data['truncateds'] # length T + info_buffer = data['info'] # length T + done_buffer = data['dones'] # length T, if available + + # Extract number of demonstrations + num_demos = obs_buffer.shape[0] + + # Extract rollout data for a single episode + for ep in range(num_demos): + ep_obs = obs_buffer[ep] # this is a length‐(T+1) array of dicts + ep_acts = act_buffer[ep] # length‐T array of actions + ep_rews = rew_buffer[ep] + ep_terms = term_buffer[ep] + ep_trunc = trunc_buffer[ep] + ep_done = done_buffer[ep] + ep_info = info_buffer[ep] # length‐T array of dicts + + # Length of episode: + T = len(ep_acts) + + # Extract single transitions from the episode data + for t in range(T): + # raw single‐step data: + obs_t = ep_obs[t] # dict[str, np.ndarray] (obs_dim,) + next_obs_t = ep_obs[t+1] + a_t = ep_acts[t] # np.ndarray (action_dim,) + r_t = float(ep_rews[t]) + done_t = bool(ep_done[t] or ep_terms[t] or ep_trunc[t]) + + # Rehydrate info dict and inject the timeout flag + raw_info = ep_info[t] # dict[str,Any] + if isinstance(raw_info, str): + import ast + info_t = ast.literal_eval(raw_info) + else: + info_t = raw_info.copy() + # Append truncated information to info_t + info_t["TimeLimit.truncated"] = bool(ep_trunc[t]) + + # Enter transition into data_store or QueuedDataStore + data_store.insert( + dict( + observations=obs_t, + actions=a_t, + next_observations=next_obs_t, + rewards=r_t, + masks=1.0 - done_t, + dones=done_t + ) + ) + + print(f"Can load 
{num_demos} transitions successfullly from {demo_npz_path}." + f"(combine_done={combine_done}).") + +def get_demo_path(relative_path: str) -> str: + """ + Given a path relative to this script file, return + the absolute, normalized path as a string. + + Example: + # If your demos live at ../../../demos/data.npz + demo_file = get_demo_path("../../../demos/data_franka_random_10.npz") + """ + # 1) Resolve this script’s directory + script_dir = Path(__file__).resolve().parent + + # 2) Join with the user-supplied relative path and normalize + full_path = (script_dir / relative_path).resolve() + + return str(full_path) + + +def main(): + """ + Load demos + """ + + # Get abs path to demo file + script_dir = '/data/data/serl/demos' + default_file = 'data_franka_reach_random_20.npz' + + prompt = f"Please input the name of the file to load [{default_file}]: " + + file_name = input(prompt) or default_file + demo_file = os.path.join(script_dir, file_name) + + # Load the demo file into data_store as in async_sac_state.py + from agentlace.data.data_store import QueuedDataStore + data_store = QueuedDataStore(2000) + load_demos_to_her_buffer_gymnasium(data_store, demo_file, combine_done=True) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/demos/setup.py b/demos/setup.py new file mode 100644 index 00000000..d6259a7e --- /dev/null +++ b/demos/setup.py @@ -0,0 +1,7 @@ +from setuptools import setup, find_packages + +setup( + name="demos", + version="0.1", + packages=find_packages(), +) From fd6111c213fec5ad032b668b8783d68ab656898d Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Thu, 24 Jul 2025 23:00:41 -0500 Subject: [PATCH 069/141] Reverted these to standard files with demo flags inside --- examples/async_sac_state_sim/run_actor.sh | 129 ++------------------ examples/async_sac_state_sim/run_learner.sh | 129 ++------------------ examples/async_sac_state_sim/tmux_launch.sh | 100 +-------------- 3 files changed, 21 insertions(+), 337 deletions(-) diff 
--git a/examples/async_sac_state_sim/run_actor.sh b/examples/async_sac_state_sim/run_actor.sh index e97cca8c..806a71d1 100644 --- a/examples/async_sac_state_sim/run_actor.sh +++ b/examples/async_sac_state_sim/run_actor.sh @@ -1,128 +1,15 @@ #!/bin/bash - -# Default variable values -seed=0 -env="PandaReachCube-v0" -branch="test" -type="fractal_symmetry_replay_buffer" -exp_name="untitled_experiment" -starting_branch_count=1 - -has_argument() { - [[ ("$1" == *=* && -n ${1#*=}) || ( ! -z "$2" && "$2" != -*) ]]; -} - -extract_argument() { - echo "${2:-${1#*=}}" -} - -# Function to handle options and arguments -handle_options() { - while [ $# -gt 0 ]; do - case $1 in - -e | --env*) - if ! has_argument $@; then - echo "Environment not specified" >&2 - exit 1 - fi - - env=$(extract_argument $@) - - shift - ;; - -s | --seed*) - if ! has_argument $@; then - echo "Seed needs a positive integer value" >&2 - exit 1 - fi - - seed=$(extract_argument $@) - - shift - ;; - -t | --type*) - if ! has_argument $@; then - echo "Type needs a string" >&2 - exit 1 - fi - - type=$(extract_argument $@) - - shift - ;; - -n | --name*) - if ! has_argument $@; then - echo "Name needs a string" >&2 - exit 1 - fi - - exp_name=$(extract_argument $@) - - shift - ;; - --starting_branch*) - if ! has_argument $@; then - echo "Starting_branch needs a positive odd integer" >&2 - exit 1 - fi - - starting_branch_count=$(extract_argument $@) - - shift - ;; - -b | --branch*) - if ! 
has_argument $@; then - echo "Branch needs a string" >&2 - exit 1 - fi - - branch=$(extract_argument $@) - - shift - ;; - *) - echo "Invalid option: $1" >&2 - exit 1 - ;; - esac - shift - done -} - -# Main script execution -handle_options "$@" - export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ export XLA_PYTHON_CLIENT_MEM_FRACTION=.5 && \ -export SCRIPT_DIR=$(dirname "$(realpath "$0")") && \ -export TIMESTAMP=$(date +"%m-%d-%Y-%H-%M-%S") && \ -export CHECKPOINT_DIR="/data/fsrb_testing/checkpoints-$TIMESTAMP" && \ - -# Create checkpoint directory if it doesn't exist -if [ ! -d "$CHECKPOINT_DIR" ]; then - echo "Creating checkpoint directory: $CHECKPOINT_DIR" - mkdir -p "$CHECKPOINT_DIR" || { - echo "Failed to create checkpoint directory!" >&2 - exit 1 - } -fi -python async_sac_state_sim.py \ +python async_sac_state_sim.py "$@" \ --actor \ - --env $env \ - --exp_name=$exp_name \ - --seed $seed \ - --replay_buffer_type $type \ - --branch_method $branch \ - --starting_branch_count $starting_branch_count \ + --env PandaPickCube-v0 \ + --exp_name=serl_reach \ + --seed 0 \ --random_steps 1000 \ - --max_steps 100000 \ - --training_starts 1000 \ - --critic_actor_ratio 8 \ - --batch_size 256 \ - --replay_buffer_capacity 1000000 \ - --save_model True \ - # --checkpoint_period 10000 \ - # --checkpoint_path "$CHECKPOINT_DIR" \ - #--render \ + --load_demos \ + --demo_dir "/data/data/serl/demos" \ + --file_name "data_franka_reach_random_20.npz" #--debug # wandb is disabled when debug + #--render diff --git a/examples/async_sac_state_sim/run_learner.sh b/examples/async_sac_state_sim/run_learner.sh index 09809593..b75a0841 100644 --- a/examples/async_sac_state_sim/run_learner.sh +++ b/examples/async_sac_state_sim/run_learner.sh @@ -1,126 +1,15 @@ #!/bin/bash - -# Default variable values -seed=0 -env="PandaReachCube-v0" -branch="test" -type="fractal_symmetry_replay_buffer" -exp_name="untitled_experiment" -starting_branch_count=1 - -has_argument() { - [[ ("$1" == *=* && -n ${1#*=}) || 
( ! -z "$2" && "$2" != -*) ]]; -} - -extract_argument() { - echo "${2:-${1#*=}}" -} - -# Function to handle options and arguments -handle_options() { - while [ $# -gt 0 ]; do - case $1 in - -e | --env*) - if ! has_argument $@; then - echo "Environment not specified" >&2 - exit 1 - fi - - env=$(extract_argument $@) - - shift - ;; - -s | --seed*) - if ! has_argument $@; then - echo "Seed needs a positive integer value" >&2 - exit 1 - fi - - seed=$(extract_argument $@) - - shift - ;; - -t | --type*) - if ! has_argument $@; then - echo "Type needs a string" >&2 - exit 1 - fi - - type=$(extract_argument $@) - - shift - ;; - -n | --name*) - if ! has_argument $@; then - echo "Name needs a string" >&2 - exit 1 - fi - - exp_name=$(extract_argument $@) - - shift - ;; - --starting_branch*) - if ! has_argument $@; then - echo "Starting_branch needs a positive odd integer" >&2 - exit 1 - fi - - starting_branch_count=$(extract_argument $@) - - shift - ;; - -b | --branch*) - if ! has_argument $@; then - echo "Branch needs a string" >&2 - exit 1 - fi - - branch=$(extract_argument $@) - - shift - ;; - *) - echo "Invalid option: $1" >&2 - exit 1 - ;; - esac - shift - done -} - -# Main script execution -handle_options "$@" - export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ export XLA_PYTHON_CLIENT_MEM_FRACTION=.5 && \ -export SCRIPT_DIR=$(dirname "$(realpath "$0")") && \ -export TIMESTAMP=$(date +"%m-%d-%Y-%H-%M-%S") && \ -export CHECKPOINT_DIR="/data/fsrb_testing/checkpoints-$TIMESTAMP" && \ - -# Create checkpoint directory if it doesn't exist -if [ ! -d "$CHECKPOINT_DIR" ]; then - echo "Creating checkpoint directory: $CHECKPOINT_DIR" - mkdir -p "$CHECKPOINT_DIR" || { - echo "Failed to create checkpoint directory!" 
>&2 - exit 1 - } -fi -python async_sac_state_sim.py \ +python async_sac_state_sim.py "$@" \ --learner \ - --env $env \ - --exp_name=$exp_name \ - --seed $seed \ - --replay_buffer_type $type \ - --branch_method $branch \ - --starting_branch_count $starting_branch_count \ - --max_steps 1000 \ - --training_starts 1000 \ - --critic_actor_ratio 8 \ - --batch_size 256 \ - --replay_buffer_capacity 1000000 \ - --save_model True \ - # --checkpoint_period 10000 \ - # --checkpoint_path "$CHECKPOINT_DIR" \ + --env PandaPickCube-v0 \ + --exp_name=serl_reach \ + --seed 0 \ + --random_steps 1000 \ + --load_demos \ + --demo_dir "/data/data/serl/demos" \ + --file_name "data_franka_reach_random_20.npz" #--debug # wandb is disabled when debug + #--render diff --git a/examples/async_sac_state_sim/tmux_launch.sh b/examples/async_sac_state_sim/tmux_launch.sh index 6145f712..78ff94a8 100644 --- a/examples/async_sac_state_sim/tmux_launch.sh +++ b/examples/async_sac_state_sim/tmux_launch.sh @@ -1,101 +1,9 @@ #!/bin/bash -# Default variable values -seed=0 -env="PandaReachCube-v0" -branch="test" -type="fractal_symmetry_replay_buffer" -name="untitled_experiment" -starting_branch=1 - -has_argument() { - [[ ("$1" == *=* && -n ${1#*=}) || ( ! -z "$2" && "$2" != -*) ]]; -} - -extract_argument() { - echo "${2:-${1#*=}}" -} - -# Function to handle options and arguments -handle_options() { - while [ $# -gt 0 ]; do - case $1 in - -e | --env*) - if ! has_argument $@; then - echo "Environment not specified" >&2 - exit 1 - fi - - env=$(extract_argument $@) - - shift - ;; - -s | --seed*) - if ! has_argument $@; then - echo "Seed needs a positive integer value" >&2 - exit 1 - fi - - seed=$(extract_argument $@) - - shift - ;; - -t | --type*) - if ! has_argument $@; then - echo "Type needs a string" >&2 - exit 1 - fi - - type=$(extract_argument $@) - - shift - ;; - -n | --name*) - if ! 
has_argument $@; then - echo "Name needs a string" >&2 - exit 1 - fi - - name=$(extract_argument $@) - - shift - ;; - --starting_branch*) - if ! has_argument $@; then - echo "Starting_branch needs a positive odd integer" >&2 - exit 1 - fi - - starting_branch=$(extract_argument $@) - - shift - ;; - -b | --branch*) - if ! has_argument $@; then - echo "Branch needs a string" >&2 - exit 1 - fi - - branch=$(extract_argument $@) - - shift - ;; - *) - echo "Invalid option: $1" >&2 - exit 1 - ;; - esac - shift - done -} - -# Main script execution -handle_options "$@" - -# EXAMPLE_DIR=${EXAMPLE_DIR:-"examples/async_sac_state_sim"} +EXAMPLE_DIR=${EXAMPLE_DIR:-"examples/async_sac_state_sim"} CONDA_ENV=${CONDA_ENV:-"serl"} -# cd $EXAMPLE_DIR +cd $EXAMPLE_DIR echo "Running from $(pwd)" # Create a new tmux session @@ -105,10 +13,10 @@ tmux new-session -d -s serl_session tmux split-window -v # Navigate to the activate the conda environment in the first pane -tmux send-keys -t serl_session:0.0 "conda activate $CONDA_ENV && bash run_actor.sh --seed $seed --env $env --name $name --type $type --branch $branch --starting_branch $starting_branch" C-m +tmux send-keys -t serl_session:0.0 "conda activate $CONDA_ENV && bash run_actor.sh" C-m # Navigate to the activate the conda environment in the second pane -tmux send-keys -t serl_session:0.1 "conda activate $CONDA_ENV && bash run_learner.sh --seed $seed --env $env --name $name --type $type --branch $branch --starting_branch $starting_branch && tmux kill-session -t serl_session:0.0" C-m +tmux send-keys -t serl_session:0.1 "conda activate $CONDA_ENV && bash run_learner.sh" C-m # Attach to the tmux session tmux attach-session -t serl_session From b38b60e6ac9d1e5576c9a4bc7aa4603fe1cf63ed Mon Sep 17 00:00:00 2001 From: ryanvanderstelt Date: Fri, 25 Jul 2025 02:09:37 -0500 Subject: [PATCH 070/141] hardly functional automated testing --- .../async_sac_state_sim.py | 7 +++- .../async_sac_state_sim/automate_funcs.py | 34 ++++++++++++------- 
examples/async_sac_state_sim/run_actor.sh | 10 ++++-- examples/async_sac_state_sim/run_learner.sh | 12 ++++--- .../data/fractal_symmetry_replay_buffer.py | 29 +++++++++++----- 5 files changed, 63 insertions(+), 29 deletions(-) diff --git a/examples/async_sac_state_sim/async_sac_state_sim.py b/examples/async_sac_state_sim/async_sac_state_sim.py index 2ad82311..6748fc0f 100644 --- a/examples/async_sac_state_sim/async_sac_state_sim.py +++ b/examples/async_sac_state_sim/async_sac_state_sim.py @@ -62,9 +62,10 @@ flags.DEFINE_string("branch_method", None, "Method for determining number of transforms per dimension (x,y)") flags.DEFINE_string("split_method", None, "Method for determining whether to change the number of transforms per dimension (x,y)") # Remember to default None flags.DEFINE_float("workspace_width", None, "Workspace width in meters") -flags.DEFINE_integer("max_depth", None, "Total layers of depth") +flags.DEFINE_integer("max_depth", None, "Maximum layers of depth") flags.DEFINE_integer("branching_factor", None, "Rate of change of number of transforms per dimension (x,y)") flags.DEFINE_integer("starting_branch_count", None, "Initial number of transformations per dimension (x,y)") +flags.DEFINE_integer("alpha", None, "placeholder") flags.DEFINE_boolean( @@ -307,6 +308,10 @@ def main(_): x_obs_idx=np.array([0,4]), y_obs_idx=np.array([1,5]), preload_rlds_path=FLAGS.preload_rlds_path, + alpha=FLAGS.alpha, + branching_factor=FLAGS.branching_factor, + max_depth=FLAGS.max_depth, + max_traj_length=FLAGS.max_traj_length, ) replay_iterator = replay_buffer.get_iterator( sample_args={ diff --git a/examples/async_sac_state_sim/automate_funcs.py b/examples/async_sac_state_sim/automate_funcs.py index 425e2d94..09b0b106 100644 --- a/examples/async_sac_state_sim/automate_funcs.py +++ b/examples/async_sac_state_sim/automate_funcs.py @@ -50,11 +50,7 @@ def getProcesses(name): else: p2_pid = proc.pid break - if p1_pid and p2_pid: - if p1_pid > p2_pid: - temp = p2_pid - p2_pid = 
p1_pid - p1_pid = temp + if p1_pid and p2_pid and p1_pid != p2_pid: p1 = psutil.Process(p1_pid) p2 = psutil.Process(p2_pid) return p1, p2 @@ -67,16 +63,29 @@ def run_test(args:dict): actor_log = os.path.join(LOG_DIR, f"actor_seed_{seed}.log") learner_log = os.path.join(LOG_DIR, f"learner_seed_{seed}.log") - args["max_steps"] = max_steps * 10 + args["max_steps"] = max_steps * 10000 send_tmux_command(f"{SESSION}:0.0", f"bash ./automate_actor.sh {' '.join(f'--{k} {v}' for k, v in args.items())}") #conda/jax/tmux mem conflict? , f"{actor_cmd(args)}")# > {actor_log} 2>&1") # Uncomment to send to log files args["max_steps"] = max_steps send_tmux_command(f"{SESSION}:0.1", f"bash ./automate_learner.sh {' '.join(f'--{k} {v}' for k, v in args.items())}")#mem conflict?, f"{learner_cmd(args)}")# > {learner_log} 2>&1") - actor, learner = getProcesses("async_sac_state") + p1, p2 = getProcesses("async_sac_state") - learner.wait() - actor.terminate() - actor.wait() + while (True): + if p1.is_running(): + if p2.is_running(): + continue + else: + p1.terminate() + p1.wait() + break + else: + if p2.is_running(): + p2.terminate() + p2.wait() + break + break + + def main(_): ensure_tmux_session(SESSION) @@ -90,9 +99,10 @@ def main(_): "training_starts": 1000, "critic_actor_ratio": 8, "batch_size": 256, - "replay_buffer_capacity": 100000, + "replay_buffer_capacity": 1000000, "save_model": True, - "max_steps": 2000 + "max_steps": 50000, + "workspace_width": 0.5 } # --- Baseline --- diff --git a/examples/async_sac_state_sim/run_actor.sh b/examples/async_sac_state_sim/run_actor.sh index 237a8022..6572f6eb 100644 --- a/examples/async_sac_state_sim/run_actor.sh +++ b/examples/async_sac_state_sim/run_actor.sh @@ -20,15 +20,19 @@ python async_sac_state_sim.py \ --env PandaReachCube-v0 \ --exp_name serl-reach \ --seed 0 \ - --replay_buffer_type replay_buffer \ - --branch_method test \ - --split_method test \ + --replay_buffer_type fractal_symmetry_replay_buffer \ + --branch_method fractal \ + 
--split_method time \ --starting_branch_count 1 \ --max_steps 40000 \ --training_starts 1000 \ --critic_actor_ratio 8 \ --batch_size 256 \ --replay_buffer_capacity 100000 \ + --alpha 1 \ + --max_depth 4 \ + --branching_factor 3 \ + --workspace_width 5 \ --save_model True \ # --checkpoint_period 10000 \ # --checkpoint_path "$CHECKPOINT_DIR" \ diff --git a/examples/async_sac_state_sim/run_learner.sh b/examples/async_sac_state_sim/run_learner.sh index a4a2210e..7039e29e 100644 --- a/examples/async_sac_state_sim/run_learner.sh +++ b/examples/async_sac_state_sim/run_learner.sh @@ -20,15 +20,19 @@ python async_sac_state_sim.py \ --env PandaReachCube-v0 \ --exp_name serl-reach \ --seed 0 \ - --replay_buffer_type replay_buffer \ - --branch_method test \ - --split_method test \ + --replay_buffer_type fractal_symmetry_replay_buffer \ + --branch_method fractal \ + --split_method time \ --starting_branch_count 1 \ - --max_steps 20000 \ + --max_steps 40000 \ --training_starts 1000 \ --critic_actor_ratio 8 \ --batch_size 256 \ --replay_buffer_capacity 100000 \ + --alpha 1 \ + --max_depth 4 \ + --branching_factor 3 \ + --workspace_width 0.5 \ --save_model True \ # --checkpoint_period 10000 \ # --checkpoint_path "$CHECKPOINT_DIR" \ diff --git a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py index 7b386155..efc36050 100644 --- a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py +++ b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py @@ -3,6 +3,7 @@ from serl_launcher.data.dataset import DatasetDict from serl_launcher.data.replay_buffer import ReplayBuffer import copy +from datetime import datetime import time @@ -18,6 +19,7 @@ def __init__( kwargs: dict ): self.current_branch_count=1 + self.update_max_traj_length = False method_check = "split_method" match split_method: @@ -26,13 +28,17 @@ def __init__( self.max_depth=kwargs["max_depth"] del kwargs["max_depth"] 
- assert "max_steps" in kwargs.keys(), self._handle_bad_args_(method_check, split_method, "max_steps") - self.max_steps=kwargs["max_steps"] + assert "max_traj_length" in kwargs.keys(), self._handle_bad_args_(method_check, split_method, "max_traj_length") + self.max_traj_length=kwargs["max_traj_length"] assert "alpha" in kwargs.keys(), self._handle_bad_args_(method_check, split_method, "alpha") self.alpha=kwargs["alpha"] + update_max_traj_length = True self.split = self.time_split + case "constant": + self.split = self.constant_split + case "test": self.split = self.test_split case _: @@ -41,9 +47,6 @@ def __init__( method_check = "branch_method" match branch_method: case "fractal": - assert "max_depth" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "max_depth") - self.max_depth=kwargs["max_depth"] - del kwargs["max_depth"] assert "branching_factor" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "branching_factor") self.branching_factor=kwargs["branching_factor"] @@ -80,7 +83,11 @@ def __init__( self.workspace_width = workspace_width self.timestep = 0 self.current_depth = 0 - self.branch_index = [workspace_width/2] + + self.branch_index = np.empty(self.current_branch_count, dtype=np.float32) + constant = self.workspace_width/(2 * self.current_branch_count) + for i in range(0, self.current_branch_count): + self.branch_index[i] = (2 * i + 1) * constant for k in kwargs.keys(): print(f"\033[33mWARNING \033[0m argument \"{k}\" not used") @@ -132,7 +139,7 @@ def test_split(self, data_dict: DatasetDict): # REQUIRES TESTING def time_split(self, data_dict: DatasetDict): - if self.timestep % (self.max_steps/self.max_depth) or self.current_depth >= self.max_depth: + if self.timestep % (self.max_traj_length//self.max_depth) or self.current_depth >= self.max_depth: return False self.current_depth += 1 return True @@ -141,6 +148,7 @@ def constant_split(self, data_dict: DatasetDict): return True def insert(self, data_dict_not: 
DatasetDict): + start = datetime.now() data_dict = copy.deepcopy(data_dict_not) @@ -170,7 +178,10 @@ def insert(self, data_dict_not: DatasetDict): self.transform(new_data_dict, transform) super().insert(new_data_dict) + self.timestep += 1 if data_dict["dones"]: self.current_depth = 0 - self.max_steps = int(self.timestep * self.alpha + self.max_steps * (1 - self.alpha)) - \ No newline at end of file + if self.update_max_traj_length: + self.max_traj_length = int(self.timestep * self.alpha + self.max_traj_length * (1 - self.alpha)) + self.timestep = 0 + print(f"{datetime.now() - start}") \ No newline at end of file From 5400702c5a18709403d4f03b2f5913ae4fafd52a Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Fri, 25 Jul 2025 10:52:55 -0500 Subject: [PATCH 071/141] cleaning up test scripts as we have made advances there in fractal branch --- examples/async_sac_state_sim/run_actor.sh | 13 ------ examples/async_sac_state_sim/run_learner.sh | 13 ------ examples/async_sac_state_sim/run_tests.sh | 8 ---- examples/async_sac_state_sim/testing.sh | 42 ------------------- .../async_sac_state_sim/tmux_script_launch.sh | 19 --------- 5 files changed, 95 deletions(-) delete mode 100644 examples/async_sac_state_sim/run_tests.sh delete mode 100755 examples/async_sac_state_sim/testing.sh delete mode 100755 examples/async_sac_state_sim/tmux_script_launch.sh diff --git a/examples/async_sac_state_sim/run_actor.sh b/examples/async_sac_state_sim/run_actor.sh index 237a8022..38f68399 100644 --- a/examples/async_sac_state_sim/run_actor.sh +++ b/examples/async_sac_state_sim/run_actor.sh @@ -1,19 +1,6 @@ #!/bin/bash - export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ export XLA_PYTHON_CLIENT_MEM_FRACTION=.5 && \ -export SCRIPT_DIR=$(dirname "$(realpath "$0")") && \ -export TIMESTAMP=$(date +"%m-%d-%Y-%H-%M-%S") && \ -export CHECKPOINT_DIR="/data/fsrb_testing/checkpoints-$TIMESTAMP" && \ - -# Create checkpoint directory if it doesn't exist -if [ ! 
-d "$CHECKPOINT_DIR" ]; then - echo "Creating checkpoint directory: $CHECKPOINT_DIR" - mkdir -p "$CHECKPOINT_DIR" || { - echo "Failed to create checkpoint directory!" >&2 - exit 1 - } -fi python async_sac_state_sim.py \ --actor \ diff --git a/examples/async_sac_state_sim/run_learner.sh b/examples/async_sac_state_sim/run_learner.sh index a4a2210e..0ea4efc6 100644 --- a/examples/async_sac_state_sim/run_learner.sh +++ b/examples/async_sac_state_sim/run_learner.sh @@ -1,19 +1,6 @@ #!/bin/bash - export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ export XLA_PYTHON_CLIENT_MEM_FRACTION=.5 && \ -export SCRIPT_DIR=$(dirname "$(realpath "$0")") && \ -export TIMESTAMP=$(date +"%m-%d-%Y-%H-%M-%S") && \ -export CHECKPOINT_DIR="/data/fsrb_testing/checkpoints-$TIMESTAMP" && \ - -# Create checkpoint directory if it doesn't exist -if [ ! -d "$CHECKPOINT_DIR" ]; then - echo "Creating checkpoint directory: $CHECKPOINT_DIR" - mkdir -p "$CHECKPOINT_DIR" || { - echo "Failed to create checkpoint directory!" >&2 - exit 1 - } -fi python async_sac_state_sim.py \ --learner \ diff --git a/examples/async_sac_state_sim/run_tests.sh b/examples/async_sac_state_sim/run_tests.sh deleted file mode 100644 index 0bd80620..00000000 --- a/examples/async_sac_state_sim/run_tests.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash - -tmux new-session -d -s serl_session "python run_tests.py" - -tmux attach-session -t serl_session - -# kill the tmux session by running the following command -# tmux kill-session -t serl_session diff --git a/examples/async_sac_state_sim/testing.sh b/examples/async_sac_state_sim/testing.sh deleted file mode 100755 index fb2072c9..00000000 --- a/examples/async_sac_state_sim/testing.sh +++ /dev/null @@ -1,42 +0,0 @@ -#!/bin/bash - -# Tmux options -tmux set-option -g history-limit 100000 - -# Exit on error: prevents silent failures and makes bugs easier to catch: -# -e: exit immediately on command failure -# -u: error on use of undefined variables -# -o pipefail: error if any command in a 
pipeline fails -set -euo pipefail - -# Safe IFS for word splitting. Protects against word-splitting bugs due to unexpected whitespace in filenames or variables. -IFS=$'\n\t' - -# === CONFIGURATION === -env="PandaReachCube-v0" -session="test" -window="0" - -# === CLEANUP ON EXIT === -# Kill only background jobs started by the script -trap 'jobs -p | xargs -r kill' EXIT INT TERM - -# === START TMUX SESSION === -if ! tmux has-session -t "$session" 2>/dev/null; then # suppress error output if session does not exist - tmux new-session -d -s "$session" -n main # start new detached tmux session named $session with init window named main -else - echo "Session $session already exists. Reusing..." -fi - -# === HELPER FUNCTION === -# Centralizes the logic for sending commands into tmux. Improves readability and makes it easy to add logging or future logic. -send_tmux_command() { - local cmd="$1" # Take 1st arg in func & store in var cmd. - tmux send-keys -t "$session:$window" "$cmd" C-m # send command to specific tmux window -} - -# === BASELINE RUNS === -# tmux_script_launch.sh must have ./ in front to tell tmux to execute the script in the current directory. 
-for seed in {1}; do # Loop - send_tmux_command "./tmux_script_launch.sh --env $env --name ${env}-baseline --type replay_buffer --seed $seed" -done diff --git a/examples/async_sac_state_sim/tmux_script_launch.sh b/examples/async_sac_state_sim/tmux_script_launch.sh deleted file mode 100755 index eadc9921..00000000 --- a/examples/async_sac_state_sim/tmux_script_launch.sh +++ /dev/null @@ -1,19 +0,0 @@ -#!/bin/bash - -EXAMPLE_DIR=${EXAMPLE_DIR:-"async_sac_state_sim"} -CONDA_ENV=${CONDA_ENV:-"serl"} - -echo "Running from $(pwd)" - -# This script assumes it is being run inside an existing tmux session from testing.sh -# Split the current window vertically -tmux split-window -v - -# Run actor in the top pane -tmux send-keys -t "$(tmux display-message -p '#S')":0.0 "conda activate $CONDA_ENV && bash run_actor.sh" C-m - -# Run learner in the bottom pane -tmux send-keys -t "$(tmux display-message -p '#S')":0.1 "conda activate $CONDA_ENV && bash run_learner.sh" C-m - -# Optionally, focus the upper pane again -tmux select-pane -t 0 From 9b62012da329627916c30563f3278123ddafa436 Mon Sep 17 00:00:00 2001 From: ryanvanderstelt Date: Fri, 25 Jul 2025 13:26:31 -0500 Subject: [PATCH 072/141] improved testing, fixed fsrb bugs --- examples/async_sac_state_sim/automation.sh | 8 ++ examples/async_sac_state_sim/run_actor.sh | 23 +++-- examples/async_sac_state_sim/run_learner.sh | 23 +++-- examples/async_sac_state_sim/tmux_launch.sh | 4 +- .../data/fractal_symmetry_replay_buffer.py | 17 ++-- serl_launcher/serl_launcher/data/fsrb_test.py | 91 +++++++++++++------ 6 files changed, 106 insertions(+), 60 deletions(-) create mode 100644 examples/async_sac_state_sim/automation.sh diff --git a/examples/async_sac_state_sim/automation.sh b/examples/async_sac_state_sim/automation.sh new file mode 100644 index 00000000..39164dc0 --- /dev/null +++ b/examples/async_sac_state_sim/automation.sh @@ -0,0 +1,8 @@ + +tmux new-session -d -s serl_session + +tmux split-window -h -t serl_session + +tmux 
split-window -v -t serl_session:0.1 + +tmux send-keys -t serl_session:0.0 "conda activate serl && python automate_funcs.py" C-m diff --git a/examples/async_sac_state_sim/run_actor.sh b/examples/async_sac_state_sim/run_actor.sh index 6572f6eb..f00797a7 100644 --- a/examples/async_sac_state_sim/run_actor.sh +++ b/examples/async_sac_state_sim/run_actor.sh @@ -15,25 +15,24 @@ if [ ! -d "$CHECKPOINT_DIR" ]; then } fi -python async_sac_state_sim.py \ +python async_sac_state_sim.py "$@"\ --actor \ --env PandaReachCube-v0 \ - --exp_name serl-reach \ - --seed 0 \ - --replay_buffer_type fractal_symmetry_replay_buffer \ - --branch_method fractal \ - --split_method time \ - --starting_branch_count 1 \ - --max_steps 40000 \ + --exp_name reach-baseline \ + --replay_buffer_type replay_buffer \ + --max_steps 80000 \ --training_starts 1000 \ --critic_actor_ratio 8 \ --batch_size 256 \ --replay_buffer_capacity 100000 \ - --alpha 1 \ - --max_depth 4 \ - --branching_factor 3 \ - --workspace_width 5 \ --save_model True \ + # --branch_method fractal \ + # --split_method time \ + # --starting_branch_count 1 \ + # --alpha 1 \ + # --max_depth 4 \ + # --branching_factor 3 \ + # --workspace_width 5 \ # --checkpoint_period 10000 \ # --checkpoint_path "$CHECKPOINT_DIR" \ #--render \ diff --git a/examples/async_sac_state_sim/run_learner.sh b/examples/async_sac_state_sim/run_learner.sh index 7039e29e..6b85e282 100644 --- a/examples/async_sac_state_sim/run_learner.sh +++ b/examples/async_sac_state_sim/run_learner.sh @@ -15,25 +15,24 @@ if [ ! 
-d "$CHECKPOINT_DIR" ]; then } fi -python async_sac_state_sim.py \ +python async_sac_state_sim.py "$@"\ --learner \ --env PandaReachCube-v0 \ - --exp_name serl-reach \ - --seed 0 \ - --replay_buffer_type fractal_symmetry_replay_buffer \ - --branch_method fractal \ - --split_method time \ - --starting_branch_count 1 \ - --max_steps 40000 \ + --exp_name reach-baseline \ + --replay_buffer_type replay_buffer \ + --max_steps 80000 \ --training_starts 1000 \ --critic_actor_ratio 8 \ --batch_size 256 \ --replay_buffer_capacity 100000 \ - --alpha 1 \ - --max_depth 4 \ - --branching_factor 3 \ - --workspace_width 0.5 \ --save_model True \ + # --branch_method fractal \ + # --split_method time \ + # --starting_branch_count 1 \ + # --alpha 1 \ + # --max_depth 4 \ + # --branching_factor 3 \ + # --workspace_width 5 \ # --checkpoint_period 10000 \ # --checkpoint_path "$CHECKPOINT_DIR" \ #--debug # wandb is disabled when debug diff --git a/examples/async_sac_state_sim/tmux_launch.sh b/examples/async_sac_state_sim/tmux_launch.sh index 6064bd29..dff31d54 100644 --- a/examples/async_sac_state_sim/tmux_launch.sh +++ b/examples/async_sac_state_sim/tmux_launch.sh @@ -13,10 +13,10 @@ tmux new-session -d -s serl_session tmux split-window -v # Navigate to the activate the conda environment in the first pane -tmux send-keys -t serl_session:0.0 "conda activate $CONDA_ENV && bash run_actor.sh" C-m +tmux send-keys -t serl_session:0.0 "conda activate $CONDA_ENV && bash run_actor.sh '$@'" C-m # Navigate to the activate the conda environment in the second pane -tmux send-keys -t serl_session:0.1 "conda activate $CONDA_ENV && bash run_learner.sh" C-m +tmux send-keys -t serl_session:0.1 "conda activate $CONDA_ENV && bash run_learner.sh '$@'" C-m # Attach to the tmux session tmux attach-session -t serl_session diff --git a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py index efc36050..ad1f3ee7 100644 --- 
a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py +++ b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py @@ -3,10 +3,8 @@ from serl_launcher.data.dataset import DatasetDict from serl_launcher.data.replay_buffer import ReplayBuffer import copy -from datetime import datetime - -import time +from datetime import datetime as dt class FractalSymmetryReplayBuffer(ReplayBuffer): def __init__( self, @@ -148,8 +146,8 @@ def constant_split(self, data_dict: DatasetDict): return True def insert(self, data_dict_not: DatasetDict): - start = datetime.now() - + + sector_1 = dt.now() data_dict = copy.deepcopy(data_dict_not) # Update number of branches if needed @@ -162,13 +160,14 @@ def insert(self, data_dict_not: DatasetDict): for i in range(0, self.current_branch_count): self.branch_index[i] = (2 * i + 1) * constant + sector_2 = dt.now() # Initialize to extreme x and y x = -self.workspace_width/2 transform = np.zeros_like(data_dict["observations"]) transform[self.x_obs_idx] = x transform[self.y_obs_idx] = x self.transform(data_dict, transform) - + sector_3 = dt.now() # Transform and insert transitions (multiprocessing in the future) for x in range(0, self.current_branch_count): transform[self.x_obs_idx] = self.branch_index[x] @@ -177,11 +176,13 @@ def insert(self, data_dict_not: DatasetDict): transform[self.y_obs_idx] = self.branch_index[y] self.transform(new_data_dict, transform) super().insert(new_data_dict) - + sector_4 = dt.now() self.timestep += 1 if data_dict["dones"]: self.current_depth = 0 if self.update_max_traj_length: self.max_traj_length = int(self.timestep * self.alpha + self.max_traj_length * (1 - self.alpha)) self.timestep = 0 - print(f"{datetime.now() - start}") \ No newline at end of file + + finish = dt.now() + print(f"Splits: {(sector_2 - sector_1).total_seconds():.5f} : {(sector_3 - sector_2).total_seconds():.5f} : {(sector_4 - sector_3).total_seconds():.5f} : {(finish - sector_4).total_seconds():.5f}\nLaptime: {(finish - 
sector_1).total_seconds():.5f}") \ No newline at end of file diff --git a/serl_launcher/serl_launcher/data/fsrb_test.py b/serl_launcher/serl_launcher/data/fsrb_test.py index df7f09b0..78a868ed 100644 --- a/serl_launcher/serl_launcher/data/fsrb_test.py +++ b/serl_launcher/serl_launcher/data/fsrb_test.py @@ -2,26 +2,13 @@ import numpy as np import gym from serl_launcher.utils.launcher import make_replay_buffer -from serl_launcher.data.fractal_symmetry_replay_buffer import FractalSymmetryReplayBuffer -from absl import app, flags +from absl import app import franka_sim +from datetime import datetime # import pandas as pd -FLAGS = flags.FLAGS - -flags.DEFINE_integer("capacity", 10000000, "Replay buffer capacity.") -flags.DEFINE_string("branch_method", "test", "Method for determining the number of transforms per dimension (x,y)") -flags.DEFINE_string("split_method", "test", "Method for determining whether to change the number of transforms per dimension (x,y)") -flags.DEFINE_float("workspace_width", 0.5, "workspace width in meters") -flags.DEFINE_integer("max_depth", 4, "Maximum level of depth") # For fractal_branch only -flags.DEFINE_integer("branching_factor", 3, "Rate of change of number of transforms per dimension (x,y)") # For fractal_branch only -flags.DEFINE_integer("starting_branch_count", 1, "Initial number of transforms per dimension (x,y)") # For constant_branch only - def main(_): - x_obs_idx = np.array([0, 4]) - y_obs_idx = np.array([1, 5]) - # Initialize replay buffer env = gym.make("PandaReachCube-v0") env = gym.wrappers.FlattenObservation(env) @@ -29,16 +16,19 @@ def main(_): replay_buffer = make_replay_buffer( env, type="fractal_symmetry_replay_buffer", - capacity=FLAGS.capacity, - split_method=FLAGS.split_method, - branch_method=FLAGS.branch_method, - workspace_width=FLAGS.workspace_width, - x_obs_idx=x_obs_idx, - y_obs_idx= y_obs_idx, - max_depth=FLAGS.max_depth, - branching_factor=FLAGS.branching_factor, + capacity=1000000, + split_method="time", + 
branch_method="fractal", + workspace_width=0.5, + x_obs_idx=np.array([0, 4]), + y_obs_idx= np.array([1, 5]), + alpha=1, + max_traj_length=20, + max_depth=4, + branching_factor=3, + starting_branch_count=1, ) - + observation, info = env.reset() action = env.action_space.sample() next_observation, reward, terminated, truncated, info = env.step(action) @@ -51,8 +41,57 @@ def main(_): dones=truncated or terminated, ) - del env, observation, next_observation, action, reward, truncated, terminated, info, y_obs_idx, x_obs_idx, _ - + del env, observation, next_observation, action, reward, truncated, terminated, info, _ + + # temp testing + + # replay_buffer.branch =replay_buffer.constant_branch + # replay_buffer.split =replay_buffer.constant_split + + # for b in (1, 3, 9, 27): + # avg = 0 + # high = 0 + # low = 10 + # replay_buffer.starting_branch_count = b + # for _ in range(50): + # start=datetime.now() + # replay_buffer.insert(data_dict) + # duration=(datetime.now() - start).total_seconds() + # avg += duration + # if duration < low: + # low = duration + # if duration > high: + # high = duration + + # avg = avg/50 + # print(f"Test: {b}x{b}\n\nhigh: {high:.5f}\nlow: {low:.5f}\naverage: {avg:.5f}\n\n") + + # replay_buffer.branch =replay_buffer.fractal_branch + # replay_buffer.split =replay_buffer.time_split + + # avg = 0 + # high = 0 + # low = 10 + # for b, d in ((3,4),(9,2)): + # for a in (0.25, 0.5, 0.75, 1): + # avg = 0 + # high = 0 + # low = 10 + # replay_buffer.branching_factor=b + # replay_buffer.max_depth=d + # replay_buffer.current_depth=0 + # for i in range(20): + # start=datetime.now() + # replay_buffer.insert(data_dict) + # duration=(datetime.now() - start).total_seconds() + # avg += duration + # if duration < low: + # low = duration + # if duration > high: + # high = duration + + # avg = avg/50 + # print(f"Test: {b}^{d} at alpha={a}\n\nhigh: {high:.5f}\nlow: {low:.5f}\naverage: {avg:.5f}\n\n") # transform() test expected = np.copy(data_dict["observations"]) * 2 
From 2cd298f0c2380a79b97014a977461904f59e0c3e Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Fri, 25 Jul 2025 13:46:06 -0500 Subject: [PATCH 073/141] Set conda once only at the beginning. - remove conda activattion from sending tmux commands --- .../async_sac_state_sim/automate_funcs.py | 60 +++++++++++++++---- 1 file changed, 48 insertions(+), 12 deletions(-) diff --git a/examples/async_sac_state_sim/automate_funcs.py b/examples/async_sac_state_sim/automate_funcs.py index 425e2d94..0c4ab4d3 100644 --- a/examples/async_sac_state_sim/automate_funcs.py +++ b/examples/async_sac_state_sim/automate_funcs.py @@ -12,17 +12,46 @@ def run_tmux_command(cmd: str): return subprocess.run(cmd, shell=True, check=True) -def send_tmux_command(target: str, command: str, env: str = "serl", workdir: str = "./"): - - full_command = ( +def set_conda_env(session_name: str, + env: str = "serl", + workdir: str = "./"): + """ + Set the conda environment for the current shell in a tmux session. + Call as `set_conda_env(f"{SESSION}:0.0", workdir="./", env="serl")` + """ + + cmd = ( f"conda activate {env} && " - f"cd {workdir} && " + f"cd {workdir}" + ) + + subprocess.run( + f'tmux send-keys -t {session_name} "{cmd}" C-m', + shell=True + # stdout=subprocess.DEVNULL, + # stderr=subprocess.DEVNULL + ) + +def send_tmux_command(target: str, command: str): + """ + Send a command to a specific tmux target session as: + (f"{SESSION}:0.0", f"bash ./automate_actor.sh") + Args: + - target: str - The tmux target session (e.g., "serl_session:0.0") + - command: str - The command to run in the tmux session + + Returns: + - None + + """ + full_command = ( f"{command}" ) subprocess.run( f'tmux send-keys -t {target} "{full_command}" C-m', - shell=True + shell=True, + #check=True # stdout=subprocess.DEVNULL, # stderr=subprocess.DEVNULL ) @@ -34,9 +63,14 @@ def ensure_tmux_session(session_name: str): # stderr=subprocess.DEVNULL ) if result.returncode != 0: + # Session does not exist, create it and split the 
window subprocess.run(["tmux", "new-session", "-d", "-s", session_name, "-n", "main"], check=True) subprocess.run(["tmux", "split-window", "-v", "-t", f"{session_name}:0"], check=True) + # Set the conda environment in both panes once at the beginning + set_conda_env(f"{SESSION}:0.0", workdir="./", env="serl") + set_conda_env(f"{SESSION}:0.1", workdir="./", env="serl") + def getProcesses(name): p1_pid = None p2_pid = None @@ -62,7 +96,8 @@ def getProcesses(name): def run_test(args:dict): max_steps = args["max_steps"] - for seed in range(1): + for seed in range(2): + args["seed"] = seed actor_log = os.path.join(LOG_DIR, f"actor_seed_{seed}.log") learner_log = os.path.join(LOG_DIR, f"learner_seed_{seed}.log") @@ -80,9 +115,9 @@ def run_test(args:dict): def main(_): ensure_tmux_session(SESSION) + set_conda_env env = "PandaReachCube-v0" - base_args = { "env": "PandaReachCube-v0", @@ -120,11 +155,12 @@ def main(_): args["split_method"] = "time" args["alpha"] = 1 - for b, d in ((3, 4), (9, 2)): - args["branching_factor"] = b - args["max_depth"] = d - args["exp_name"] = f"{args['env']}-{b}^{d}" - run_test(args) + for b, d in ((3, 2), (3, 3)): + for a in (0.25, 0.5, 0.75, 1): + args["branching_factor"] = b + args["max_depth"] = d + args["exp_name"] = f"{args['env']}-{b}^{d}-alpha-{a}" + run_test(args) if __name__ == "__main__": app.run(main) From c901bada581ac55ebb3ab2b48412025f832936a7 Mon Sep 17 00:00:00 2001 From: Cleiver Ivan Ruiz-Martinez Date: Fri, 25 Jul 2025 21:08:54 -0500 Subject: [PATCH 074/141] Wrong env used in training. 
--- franka_sim/franka_sim/envs/panda_reach_gym_env.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/franka_sim/franka_sim/envs/panda_reach_gym_env.py b/franka_sim/franka_sim/envs/panda_reach_gym_env.py index e930a85a..904ec277 100644 --- a/franka_sim/franka_sim/envs/panda_reach_gym_env.py +++ b/franka_sim/franka_sim/envs/panda_reach_gym_env.py @@ -216,17 +216,7 @@ def step( obs = self._compute_observation() rew = self._compute_reward() - - # For reach environment, finger - if rew >= 0.55: - terminated = True - else: - # Check if the time limit is exceeded. - if self._time_limit is not None: - terminated = self.time_limit_exceeded() - else: - terminated = False - + terminated = self.time_limit_exceeded() return obs, rew, terminated, False, {} def render(self): From bd5d65dd22e7ff1143c1644245ba0d80eaa2bdf0 Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Mon, 28 Jul 2025 14:13:16 -0500 Subject: [PATCH 075/141] General Enhancements - More robust execution of data generation and data reading. - PandaReachGymEnv will only change done/rewards interactions if a franka_reach_demo string is passed on to a class member demo which is by default None. 
--- demos/demos/demoHandling.py | 61 ++++++++++++++----- demos/demos/franka_reach_demo_script.py | 51 ++++++++++------ .../franka_sim/envs/panda_reach_gym_env.py | 11 +++- 3 files changed, 85 insertions(+), 38 deletions(-) diff --git a/demos/demos/demoHandling.py b/demos/demos/demoHandling.py index 90c901c1..c379d293 100644 --- a/demos/demos/demoHandling.py +++ b/demos/demos/demoHandling.py @@ -27,11 +27,12 @@ class DemoHandling: """ def __init__( self, + ds: QueuedDataStore, demo_dir: str = '/data/data/serl/demos', file_name: str = 'data_franka_reach_random_20.npz' ): - self.debug = False # Set to True for debugging purposes + self.debug = True # Set to True for debugging purposes self.demo_dir = demo_dir self.transition_ctr = 0 # Global counter for transitions across all episodes @@ -54,12 +55,28 @@ def get_num_transitions(self): Returns the total number of transitions counted in the demo data. """ return self.data["transition_ctr"] if "transition_ctr" in self.data else 0 + + def get_num_demos(self): + """ + Returns the total number of demonstrations in the demo data. + """ + return self.data["num_demos"] if "num_demos" in self.data else 0 def insert_data_to_buffer(self,data_store: QueuedDataStore): """ Load a raw Gymnasium-style .npz of expert episodes into data_store. The .npz file must contain arrays named 'obs', 'acs', 'rewards', 'terminateds', 'truncateds', 'info', and optionally 'dones'. + Each episode is processed, and transitions are inserted into the data_store. + Inserted transitions in data store will remain in the data_store as pointers + + Parameters + ---------- + data_store : QueuedDataStore + + Returns + ------- + None """ obs_buffer = self.data['obs'] # shape (N, T+1, ...) 
@@ -68,10 +85,18 @@ def insert_data_to_buffer(self,data_store: QueuedDataStore): term_buffer = self.data['terminateds'] # shape (N, T) trunc_buffer = self.data['truncateds'] # shape (N, T) info_buffer = self.data['info'] # shape (N, T) - done_buffer = self.data.get('dones', term_buffer | trunc_buffer) + done_buffer = self.data['dones'] # shape (N, T) #.get('dones', term_buffer | trunc_buffer) + + num_demos = self.get_num_demos() + if num_demos == 0: + raise ValueError("No demonstrations found in the provided .npz file.") + + num_transitions = self.get_num_transitions() + if num_transitions == 0: + raise ValueError("No transitions found in the provided .npz file.") - num_demos = obs_buffer.shape[0] + # Extract the number of episodes and transitions for ep in range(num_demos): ep_obs = obs_buffer[ep] ep_acts = act_buffer[ep] @@ -91,7 +116,13 @@ def insert_data_to_buffer(self,data_store: QueuedDataStore): # masks will be created right before insert below if self.debug: - print(f"Demo {ep}, Step {t}: Obs={obs_t}, Action={a_t}, Reward={r_t}, Done={done_t}") + np.set_printoptions(precision=3, suppress=True) + + print(f"Demo {ep:2}, Step {t:3} \n " + f"Obs: [{obs_t[0]:.2f} {obs_t[1]:.2f} {obs_t[2]:.2f}] \n " + f"Action: [{a_t[0]:.2f} {a_t[1]:.2f} {a_t[2]:.2f}] \n " + f"Reward: {r_t:.2f} \n " + f"Done: {done_t}") data_store.insert( dict( @@ -104,18 +135,18 @@ def insert_data_to_buffer(self,data_store: QueuedDataStore): ) ) - print(f"Loaded {num_demos} episodes from '{self.demo_npz_path}' ") + print(f"Loaded a total of {num_transitions} from {num_demos} episodes from '{self.demo_npz_path}' ") -# if __name__ == "__main__": -# # create your datastore; here we use a QueuedDataStore with capacity 2000 -# ds = QueuedDataStore(2000) -# handler = DemoHandling(ds, -# demo_dir='/data/data/serl/demos', -# file_name='data_franka_reach_random_20.npz') +if __name__ == "__main__": + # create your datastore; here we use a QueuedDataStore with capacity 2000 + ds = QueuedDataStore(2000) + 
handler = DemoHandling(ds, + demo_dir='/data/data/serl/demos', + file_name='data_franka_reach_random_20.npz') -# # Idenitfy the total number of transitions in the datastore -# print(f'We have {handler.data["transition_ctr"]} transitions in the datastore.') + # Idenitfy the total number of transitions in the datastore + print(f'We have {handler.data["transition_ctr"]} transitions in the datastore.') -# # Load the demo data into the data_store -# handler.insert_data_to_buffer() + # Load the demo data into the data_store + handler.insert_data_to_buffer(ds) diff --git a/demos/demos/franka_reach_demo_script.py b/demos/demos/franka_reach_demo_script.py index 43952329..f97848a4 100755 --- a/demos/demos/franka_reach_demo_script.py +++ b/demos/demos/franka_reach_demo_script.py @@ -4,7 +4,11 @@ This script implements a scripted controller for generating expert demonstration data for Deep Reinforcement Learning (DRL) algorithms using the Franka FR3 robot in a pick-and-place task. The generated data serves as bootstrapping demonstrations for -training RL agents with stable-baselines3. +training RL agents with your desired algorithm. + +Note that different error thresholds lead to different rewards and in turn done values. +Adjust carefully. Depending on the the controller and clipping scaling (ACTION_MAX) settings, you may get very different behaviors. +The current program clips actions and leads to small increments that allows to more precise movements and close the error. The controller uses a 4-phase hierarchical approach: 1. Approach Object (move gripper above object) @@ -20,6 +24,7 @@ import numpy as np import gym from time import sleep, perf_counter +from datetime import datetime import franka_sim import franka_sim.envs.panda_reach_gym_env as panda_reach_env @@ -42,7 +47,7 @@ Kv = 10.0 ACTION_MAX = 10 # Maximum action value for clipping actions -ERROR_THRESHOLD = 0.01 # Note!! 
When this number is changed, the way rewards are computed in the PandaReachCubeEnv.step() L220 must also be changed such that done=True only at the end of a successfull run. +ERROR_THRESHOLD = 0.008 # Note!! When this number is changed, the way rewards are computed in the PandaReachCubeEnv.step() L220 must also be changed such that done=True only at the end of a successfull run. # Number of demonstration episodes to generate NUM_DEMOS = 20 @@ -161,7 +166,7 @@ def store_episode_data(episode_data): truncateds.append(episode_data["truncateds"]) dones.append(episode_data["dones"]) -def update_state_info(episode_data, time_step, dt, error): +def update_state_info(episode_data, time_step, dt, error, reward): """ Update and return the current state information. Always get the latest entry with [-1] @@ -170,6 +175,7 @@ def update_state_info(episode_data, time_step, dt, error): time_step (int): Current time step in the episode. dt (float): Current time step in the episode. error (np.ndarray): Current error vector between object and end-effector positions. + reward (float): Current reward value for the action taken. Returns: object_pos (np.ndarray): Current position of the object in the environment. 
@@ -185,13 +191,13 @@ def update_state_info(episode_data, time_step, dt, error): # Print debug information print( - f"Time Step: {time_step}, Error: {np.linalg.norm(error):.4f}, " + f"Step: {time_step}, ErrNorm: {np.linalg.norm(error):.4f}, " f"bot_pos: {np.array2string(current_pos, precision=3)}, " f"obj_pos: {np.array2string(object_pos, precision=3)}, " f"fgr_pos: {np.array2string(gripper_pos, precision=2)}, " f"err: {np.array2string(error, precision=3)}, " f"Action: {np.array2string(episode_data['actions'][-1], precision=3)}, " - f"dt: {dt: .4f}" + f"dt: {dt:.4f}, reward: {reward:.3f}" ) @@ -322,7 +328,7 @@ def demo(env, lastObs): store_transition_data(episode_data, new_obs, reward, action, info, terminated, truncated, done) # Update and print state information - object_pos,gripper_pos,cur_pos,cur_vel = update_state_info(episode_data, time_step, dt, error) + object_pos,gripper_pos,cur_pos,cur_vel = update_state_info(episode_data, time_step, dt, error,reward) # Update error for next iteration error, derror = compute_error(object_pos, cur_pos, prev_error, dt) @@ -331,8 +337,8 @@ def demo(env, lastObs): time_step += 1 # Sleep - if DEBUG: - sleep(0.25) # Activated when DEBUG is True for better visualization. + #if DEBUG: + sleep(0.25) # Activated when DEBUG is True for better visualization. 
# Store complete episode data in global lists only if we succeeded (avoid bad demos) store_episode_data(episode_data) @@ -413,6 +419,10 @@ def main(): fileName = "data_" + robot + "_" + task fileName += "_" + initStateSpace fileName += "_" + str(num_demos) + + # Add timestamp to filename for uniqueness + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + fileName += "_" + timestamp fileName += ".npz" # Build a filename in that same directory @@ -422,18 +432,19 @@ def main(): os.makedirs(script_dir, exist_ok=True) # Save collected data to compressed numpy NPZ file - # Set acs,obs,info as keys in dict - np.savez_compressed(out_path, - acs = actions, - obs = observations, - rewards = rewards, - info = infos, - terminateds = terminateds, - truncateds = truncateds, - dones = dones, - transition_ctr = transition_ctr, - num_demos = num_demos - ) + # Set acs,obs,info as keys in dict and values as np.arrays of type objects. This allows you to handle different lengths. + np.savez_compressed( + out_path, + acs=np.array(actions, dtype=object), + obs=np.array(observations, dtype=object), + rewards=np.array(rewards, dtype=object), + info=np.array(infos, dtype=object), + terminateds=np.array(terminateds, dtype=object), + truncateds=np.array(truncateds, dtype=object), + dones=np.array(dones, dtype=object), + transition_ctr=transition_ctr, + num_demos=num_demos + ) print(f"Data saved to {fileName}.") print(f"Total successful demos: {demo_ctr}/{num_demos}") diff --git a/franka_sim/franka_sim/envs/panda_reach_gym_env.py b/franka_sim/franka_sim/envs/panda_reach_gym_env.py index fc311b87..b1980d7c 100644 --- a/franka_sim/franka_sim/envs/panda_reach_gym_env.py +++ b/franka_sim/franka_sim/envs/panda_reach_gym_env.py @@ -37,6 +37,7 @@ def __init__( render_spec: GymRenderingSpec = GymRenderingSpec(), render_mode: Literal["rgb_array", "human"] = None, image_obs: bool = False, + demo: str = "None", ): self._action_scale = action_scale @@ -148,6 +149,8 @@ def __init__( if 
self.render_mode: self._viewer.render(self.render_mode) + self.demo = demo + def reset( self, seed=None, **kwargs ) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]: @@ -217,9 +220,11 @@ def step( obs = self._compute_observation() rew = self._compute_reward() - # For reach environment, finger - if rew >= 0.82: # Demo ERROR_THRESHOLD @ 0.3->0.675; ERROR_THRESHOLD @ 0.2->0.55 - terminated = True + # For demo reach environment, finger --- THIS SHOULD NEVER BE MERGED TO OTHER BRANCHES AFFECTING REGULAR USE OF ENVIRONMENTS. ONLY FOR DEMOS. + # IF ACCIDENTALLY MERGED, IT WILL REDUCE PERFORMANCE OF THE AGENT. + if self.demo == "franka_reach_demo": + if rew >= 0.85: # Demo ERROR_THRESHOLD @ 0.3->0.675; ERROR_THRESHOLD @ 0.2->0.55 + terminated = True else: # Check if the time limit is exceeded. if self._time_limit is not None: From f0b4df0bbd04084059c298716e4c49b11c3d6b92 Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Mon, 28 Jul 2025 14:55:40 -0500 Subject: [PATCH 076/141] Buf fix. Datastore should not be required at instantiation time. Only at insertion time. 
--- demos/demos/demoHandling.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/demos/demos/demoHandling.py b/demos/demos/demoHandling.py index c379d293..b0f36d36 100644 --- a/demos/demos/demoHandling.py +++ b/demos/demos/demoHandling.py @@ -27,7 +27,6 @@ class DemoHandling: """ def __init__( self, - ds: QueuedDataStore, demo_dir: str = '/data/data/serl/demos', file_name: str = 'data_franka_reach_random_20.npz' ): @@ -138,15 +137,16 @@ def insert_data_to_buffer(self,data_store: QueuedDataStore): print(f"Loaded a total of {num_transitions} from {num_demos} episodes from '{self.demo_npz_path}' ") -if __name__ == "__main__": - # create your datastore; here we use a QueuedDataStore with capacity 2000 - ds = QueuedDataStore(2000) - handler = DemoHandling(ds, - demo_dir='/data/data/serl/demos', - file_name='data_franka_reach_random_20.npz') +# if __name__ == "__main__": +# # Instantiate a DemoHandling object +# handler = DemoHandling(demo_dir='/data/data/serl/demos', +# file_name='data_franka_reach_random_20.npz') - # Idenitfy the total number of transitions in the datastore - print(f'We have {handler.data["transition_ctr"]} transitions in the datastore.') +# # Idenitfy the total number of transitions in the datastore +# print(f'We have {handler.data["transition_ctr"]} transitions in the datastore.') - # Load the demo data into the data_store - handler.insert_data_to_buffer(ds) +# # Simulate SERL's datastore creation w/ capacity 2000 +# ds = QueuedDataStore(2000) + +# # Insert the demo data into the datastore +# handler.insert_data_to_buffer(ds) From 98657d82ecb4fcdb60782fbb32bf8aa8085675f9 Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Mon, 28 Jul 2025 16:46:34 -0500 Subject: [PATCH 077/141] Fractal Contraction defined and tested. 
--- .../data/fractal_symmetry_replay_buffer.py | 13 +++++ serl_launcher/serl_launcher/data/fsrb_test.py | 47 ++++++++++++++++--- 2 files changed, 53 insertions(+), 7 deletions(-) diff --git a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py index 55188efd..93f955d1 100644 --- a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py +++ b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py @@ -49,6 +49,10 @@ def __init__( self.branching_factor=kwargs["branching_factor"] del kwargs["branching_factor"] + assert "start_num" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "start_num") + self.start_num=kwargs["start_num"] + del kwargs["start_num"] + self.branch = self.fractal_branch # case "linear": @@ -81,6 +85,7 @@ def __init__( self.timestep = 0 self.current_depth = 0 self.branch_index = [workspace_width/2] + self.start_num = kwargs["start_num"] for k in kwargs.keys(): print(f"\033[33mWARNING \033[0m argument \"{k}\" not used") @@ -117,6 +122,14 @@ def fractal_branch(self): # return a new number of branches = branching_factor ^ depth return self.branching_factor ** self.current_depth + def fractal_contraction(self): + # return + b = self.branching_factor + d = self.current_depth + top_b = self.start_num + + return int( top_b/( b**(d-1) )) + # REQUIRES TESTING # def constant_branch(self): # # return current number of branches diff --git a/serl_launcher/serl_launcher/data/fsrb_test.py b/serl_launcher/serl_launcher/data/fsrb_test.py index df7f09b0..31e6cf1f 100644 --- a/serl_launcher/serl_launcher/data/fsrb_test.py +++ b/serl_launcher/serl_launcher/data/fsrb_test.py @@ -16,6 +16,7 @@ flags.DEFINE_integer("max_depth", 4, "Maximum level of depth") # For fractal_branch only flags.DEFINE_integer("branching_factor", 3, "Rate of change of number of transforms per dimension (x,y)") # For fractal_branch only 
flags.DEFINE_integer("starting_branch_count", 1, "Initial number of transforms per dimension (x,y)") # For constant_branch only +flags.DEFINE_integer("start_num",81, "Initila number of branch on the first depth") # For Fractal Cntraction def main(_): @@ -37,6 +38,8 @@ def main(_): y_obs_idx= y_obs_idx, max_depth=FLAGS.max_depth, branching_factor=FLAGS.branching_factor, + start_num = FLAGS.start_num + ) observation, info = env.reset() @@ -66,7 +69,7 @@ def main(_): # branch() tests - ## fractal + ## fractal -Associative Expansions replay_buffer.branching_factor = 3 @@ -96,6 +99,42 @@ def main(_): print("\033[32mTEST PASSED \033[0m fractal_branch() tests passed") + + # fractal contraction + replay_buffer.branching_factor = 3 + + replay_buffer.current_depth = 1 + result = replay_buffer.fractal_contraction() + expected = 81 + assert result == expected, f"\033[31mTEST FAILED\033[0m fractal_branch() test failed (expected {expected} but got {result})" + + replay_buffer.current_depth = 2 + result = replay_buffer.fractal_contraction() + expected = 27 + assert result == expected, f"\033[31mTEST FAILED\033[0m fractal_branch() test failed (expected {expected} but got {result})" + + replay_buffer.current_depth = 3 + result = replay_buffer.fractal_contraction() + expected = 9 + assert result == expected, f"\033[31mTEST FAILED\033[0m fractal_branch() test failed (expected {expected} but got {result})" + + replay_buffer.current_depth = 4 + result = replay_buffer.fractal_contraction() + expected = 3 + assert result == expected, f"\033[31mTEST FAILED\033[0m fractal_branch() test failed (expected {expected} but got {result})" + + replay_buffer.current_depth = 0 + + del result, expected + + print("\033[32mTEST PASSED \033[0m fractal_contraction() tests passed") + + + + + + + # split() tests ## time replay_buffer.max_steps = 100 @@ -144,14 +183,8 @@ def main(_): del result, expected, initial_size, final_size - - print("\033[32mTEST PASSED \033[0m insert() tests passed") - - - - 
print("\033[32mTEST PASSED \033[0m insert() tests passed") print("\nfinished!\n") From 6add31a647a4a237c41d7c78dd480caca16aff57 Mon Sep 17 00:00:00 2001 From: crosen77 Date: Mon, 28 Jul 2025 17:53:07 -0500 Subject: [PATCH 078/141] omit time measurement printouts --- .../serl_launcher/data/fractal_symmetry_replay_buffer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py index ad1f3ee7..7e1ba93d 100644 --- a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py +++ b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py @@ -185,4 +185,4 @@ def insert(self, data_dict_not: DatasetDict): self.timestep = 0 finish = dt.now() - print(f"Splits: {(sector_2 - sector_1).total_seconds():.5f} : {(sector_3 - sector_2).total_seconds():.5f} : {(sector_4 - sector_3).total_seconds():.5f} : {(finish - sector_4).total_seconds():.5f}\nLaptime: {(finish - sector_1).total_seconds():.5f}") \ No newline at end of file + #print(f"Splits: {(sector_2 - sector_1).total_seconds():.5f} : {(sector_3 - sector_2).total_seconds():.5f} : {(sector_4 - sector_3).total_seconds():.5f} : {(finish - sector_4).total_seconds():.5f}\nLaptime: {(finish - sector_1).total_seconds():.5f}") From fd7093b63a6dd9ad89caad70f56af680d01c6ba8 Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Tue, 29 Jul 2025 14:00:18 -0500 Subject: [PATCH 079/141] Fixed bugs - Seamlessly run between demos and non demos - Properly consider and create the qds_size for QueuedDataStore in any case. 
--- demos/demos/demoHandling.py | 6 +- .../async_sac_state_sim.py | 57 +++++++++++-------- 2 files changed, 36 insertions(+), 27 deletions(-) diff --git a/demos/demos/demoHandling.py b/demos/demos/demoHandling.py index b0f36d36..d3f3ece0 100644 --- a/demos/demos/demoHandling.py +++ b/demos/demos/demoHandling.py @@ -31,7 +31,7 @@ def __init__( file_name: str = 'data_franka_reach_random_20.npz' ): - self.debug = True # Set to True for debugging purposes + self.debug = False # Set to True for debugging purposes self.demo_dir = demo_dir self.transition_ctr = 0 # Global counter for transitions across all episodes @@ -53,13 +53,13 @@ def get_num_transitions(self): """ Returns the total number of transitions counted in the demo data. """ - return self.data["transition_ctr"] if "transition_ctr" in self.data else 0 + return int(self.data["transition_ctr"]) if "transition_ctr" in self.data else 0 def get_num_demos(self): """ Returns the total number of demonstrations in the demo data. """ - return self.data["num_demos"] if "num_demos" in self.data else 0 + return int(self.data["num_demos"]) if "num_demos" in self.data else 0 def insert_data_to_buffer(self,data_store: QueuedDataStore): """ diff --git a/examples/async_sac_state_sim/async_sac_state_sim.py b/examples/async_sac_state_sim/async_sac_state_sim.py index b0edc9ac..274fc660 100644 --- a/examples/async_sac_state_sim/async_sac_state_sim.py +++ b/examples/async_sac_state_sim/async_sac_state_sim.py @@ -120,29 +120,30 @@ def update_params(params): # training loop timer = Timer() running_return = 0.0 + + # Load demos: handler.run will insert all transition demo data into the data store. 
+ if FLAGS.load_demos: + with timer.context("sample and step into env with loaded demos"): + + # Insert complete demonstration into the data store + print(f"Inserting {demos_handler.data['transition_ctr']} transitions into the data store.") + demos_handler.insert_data_to_buffer(data_store) + FLAGS.random_steps = 0 # Set random steps to 0 since we have demo data + # For subsequent steps, sample actions from the agent for step in tqdm.tqdm(range(FLAGS.max_steps), dynamic_ncols=True): timer.tick("total") - # Load demos: handler.run will insert all transition demo data into the data store. - if FLAGS.load_demos: - with timer.context("sample and step into env with loaded demos"): - - # Insert complete demonstration into the data store - print(f"Inserting {demos_handler.data['transition_ctr']} transitions into the data store.") - demos_handler.insert_data_to_buffer(data_store) - - else: - with timer.context("sample_actions"): - if step < FLAGS.random_steps: - actions = env.action_space.sample() - else: - sampling_rng, key = jax.random.split(sampling_rng) - actions = agent.sample_actions( - observations=jax.device_put(obs), - seed=key, - deterministic=False, - ) - actions = np.asarray(jax.device_get(actions)) + with timer.context("sample_actions"): + if step < FLAGS.random_steps: + actions = env.action_space.sample() + else: + sampling_rng, key = jax.random.split(sampling_rng) + actions = agent.sample_actions( + observations=jax.device_put(obs), + seed=key, + deterministic=False, + ) + actions = np.asarray(jax.device_get(actions)) # Step environment with timer.context("step_env"): @@ -311,7 +312,7 @@ def main(_): # Demo Data if FLAGS.load_demos: - print_green("loading demo data") + print_green("Setting demo parameters") # Create a handler for the demo data demos_handler = DemoHandling( demo_dir=FLAGS.demo_dir, @@ -321,10 +322,18 @@ def main(_): # 1. 
Modify actor data_store size # Extract number of demo transitions demo_transitions = demos_handler.get_num_transitions() - qds_size = demo_transitions + 1000 # the queue size on the actor + + if demo_transitions > 2000: + qds_size = demo_transitions + 1000 # Increment the queue size on the actor + else: + qds_size = 2000 # the original queue size on the actor # 2. Modify training starts (since we have good data) - FLAGS.training_starts = 0 + FLAGS.training_starts = 1 + + else: + demos_handler = None + qds_size = 2000 # the original queue size on the actor if FLAGS.learner: @@ -376,7 +385,7 @@ def main(_): # actor loop print_green("starting actor loop") - actor(agent, data_store, env, sampling_rng,demos_handler) + actor(agent, data_store, env, sampling_rng, demos_handler) else: raise NotImplementedError("Must be either a learner or an actor") From c84f08b79b24651c5dd0593b11647700269ccebf Mon Sep 17 00:00:00 2001 From: CleiverRuiz-LU Date: Tue, 29 Jul 2025 21:31:20 -0500 Subject: [PATCH 080/141] Complete Fractal Contraction - new branch method for contraction - Completed Test with FSRB - Added FLAGs to Async_Sac for replayBuffer --- .../async_sac_state_sim.py | 29 +++++---- examples/async_sac_state_sim/run_actor.sh | 32 +++++----- examples/async_sac_state_sim/run_learner.sh | 30 ++++----- .../data/fractal_symmetry_replay_buffer.py | 62 +++++++++++++------ serl_launcher/serl_launcher/data/fsrb_test.py | 11 ++-- 5 files changed, 100 insertions(+), 64 deletions(-) diff --git a/examples/async_sac_state_sim/async_sac_state_sim.py b/examples/async_sac_state_sim/async_sac_state_sim.py index 6748fc0f..2358122a 100644 --- a/examples/async_sac_state_sim/async_sac_state_sim.py +++ b/examples/async_sac_state_sim/async_sac_state_sim.py @@ -59,13 +59,18 @@ # flags for replay buffer flags.DEFINE_string("replay_buffer_type", "replay_buffer", "Which type of replay buffer to use") -flags.DEFINE_string("branch_method", None, "Method for determining number of transforms per dimension 
(x,y)") -flags.DEFINE_string("split_method", None, "Method for determining whether to change the number of transforms per dimension (x,y)") # Remember to default None -flags.DEFINE_float("workspace_width", None, "Workspace width in meters") -flags.DEFINE_integer("max_depth", None, "Maximum layers of depth") -flags.DEFINE_integer("branching_factor", None, "Rate of change of number of transforms per dimension (x,y)") -flags.DEFINE_integer("starting_branch_count", None, "Initial number of transformations per dimension (x,y)") -flags.DEFINE_integer("alpha", None, "placeholder") +flags.DEFINE_string("branch_method", None, "Method for how many branches to generate") +flags.DEFINE_string("split_method", None, "Method for when to change number of branches generated") # Remember to default None +flags.DEFINE_float("workspace_width", 0.5, "Workspace width in centimeters") # Remember to default None +flags.DEFINE_integer("max_depth",None,"max_depth") +flags.DEFINE_integer("depth", None, "Total layers of depth") +flags.DEFINE_integer("dendrites", None, "Dendrites for fractal branching") +flags.DEFINE_integer("timesplit_freq", None, "Frequency of splits according to time") +flags.DEFINE_integer("branch_count_rate_of_change", None, "Rate of change for linear branching") +flags.DEFINE_integer("starting_branch_count", 1, "Initial number of branches") +flags.DEFINE_integer("branching_factor", None, "Rate of change of number of transforms per dimension (x,y)") # For fractal_branch and fractal_contraction +flags.DEFINE_integer("start_num",None, "Initila number of branch on the first depth") # For Fractal Cntraction +flags.DEFINE_integer("alpha",None,"alpha value") flags.DEFINE_boolean( @@ -193,7 +198,6 @@ def learner(rng, agent: SACAgent, replay_buffer, replay_iterator): # To track the step in the training loop update_steps = 0 - def stats_callback(type: str, payload: dict) -> dict: """Callback for when server receives stats request.""" assert type == "send-stats", f"Invalid request 
type: {type}" @@ -303,15 +307,16 @@ def main(_): type=FLAGS.replay_buffer_type, branch_method=FLAGS.branch_method, split_method=FLAGS.split_method, + branching_factor=FLAGS.branching_factor, starting_branch_count=FLAGS.starting_branch_count, workspace_width=FLAGS.workspace_width, + max_traj_length=FLAGS.max_traj_length, x_obs_idx=np.array([0,4]), y_obs_idx=np.array([1,5]), preload_rlds_path=FLAGS.preload_rlds_path, - alpha=FLAGS.alpha, - branching_factor=FLAGS.branching_factor, - max_depth=FLAGS.max_depth, - max_traj_length=FLAGS.max_traj_length, + start_num = FLAGS.start_num, + max_depth = FLAGS.max_depth, + alpha = FLAGS.alpha, ) replay_iterator = replay_buffer.get_iterator( sample_args={ diff --git a/examples/async_sac_state_sim/run_actor.sh b/examples/async_sac_state_sim/run_actor.sh index f00797a7..9b31bfaa 100644 --- a/examples/async_sac_state_sim/run_actor.sh +++ b/examples/async_sac_state_sim/run_actor.sh @@ -15,24 +15,28 @@ if [ ! -d "$CHECKPOINT_DIR" ]; then } fi -python async_sac_state_sim.py "$@"\ +python async_sac_state_sim.py \ --actor \ --env PandaReachCube-v0 \ - --exp_name reach-baseline \ - --replay_buffer_type replay_buffer \ - --max_steps 80000 \ + --exp_name PandaReachCube-v0_state_sim_3D_con-81-3-batch_2048_replay_8M_utd_32 \ + --seed 0 \ + --replay_buffer_type fractal_symmetry_replay_buffer \ + --max_steps 300_000 \ --training_starts 1000 \ - --critic_actor_ratio 8 \ - --batch_size 256 \ - --replay_buffer_capacity 100000 \ + --critic_actor_ratio 32 \ + --batch_size 2048 \ + --replay_buffer_capacity 8_000_000 \ --save_model True \ - # --branch_method fractal \ - # --split_method time \ - # --starting_branch_count 1 \ - # --alpha 1 \ - # --max_depth 4 \ - # --branching_factor 3 \ - # --workspace_width 5 \ + --branch_method contraction \ + --split_method time \ + --max_traj_length 100 \ + --max_depth 4 \ + --start_num 81 \ + --alpha 1 \ + --max_depth 4 \ + --branching_factor 3 \ + --workspace_width 0.5 \ + --render # --checkpoint_period 10000 \ 
# --checkpoint_path "$CHECKPOINT_DIR" \ #--render \ diff --git a/examples/async_sac_state_sim/run_learner.sh b/examples/async_sac_state_sim/run_learner.sh index 6b85e282..dc0047b1 100644 --- a/examples/async_sac_state_sim/run_learner.sh +++ b/examples/async_sac_state_sim/run_learner.sh @@ -18,21 +18,23 @@ fi python async_sac_state_sim.py "$@"\ --learner \ --env PandaReachCube-v0 \ - --exp_name reach-baseline \ - --replay_buffer_type replay_buffer \ - --max_steps 80000 \ + --exp_name PandaReachCube-v0_state_sim_3D_con-81-3-batch_2048_replay_8M_utd_32 \ + --replay_buffer_type fractal_symmetry_replay_buffer \ + --max_steps 30_000 \ --training_starts 1000 \ - --critic_actor_ratio 8 \ - --batch_size 256 \ - --replay_buffer_capacity 100000 \ + --critic_actor_ratio 32 \ + --batch_size 2048 \ + --replay_buffer_capacity 8_000_000 \ --save_model True \ - # --branch_method fractal \ - # --split_method time \ - # --starting_branch_count 1 \ - # --alpha 1 \ - # --max_depth 4 \ - # --branching_factor 3 \ - # --workspace_width 5 \ + --branch_method contraction \ + --split_method time \ + --starting_branch_count 1 \ + --max_traj_length 100 \ + --start_num 81 \ + --alpha 1 \ + --max_depth 4 \ + --branching_factor 3 \ + --workspace_width 0.5 \ # --checkpoint_period 10000 \ # --checkpoint_path "$CHECKPOINT_DIR" \ - #--debug # wandb is disabled when debug + #--debug # wandb is disabled when debug \ No newline at end of file diff --git a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py index acb3f3de..5bd0c2af 100644 --- a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py +++ b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py @@ -24,7 +24,6 @@ def __init__( case "time": assert "max_depth" in kwargs.keys(), self._handle_bad_args_(method_check, split_method, "max_depth") self.max_depth=kwargs["max_depth"] - del kwargs["max_depth"] assert "max_traj_length" in 
kwargs.keys(), self._handle_bad_args_(method_check, split_method, "max_traj_length") self.max_traj_length=kwargs["max_traj_length"] @@ -45,34 +44,47 @@ def __init__( method_check = "branch_method" match branch_method: case "fractal": + assert "max_depth" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "max_depth") + self.max_depth=kwargs["max_depth"] assert "branching_factor" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "branching_factor") self.branching_factor=kwargs["branching_factor"] - del kwargs["branching_factor"] + + assert "start_num" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "start_num") + self.start_num=kwargs["start_num"] self.branch = self.fractal_branch + case "contraction": + assert "start_num" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "start_num") + self.start_num=kwargs["start_num"] - # case "linear": - # assert "branch_count_rate_of_change" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "branch_count_rate_of_change") - # self.branch_count_rate_of_change=kwargs["branch_count_rate_of_change"] - # del kwargs["branch_count_rate_of_change"] + # Add branching_factor for contraction method + assert "branching_factor" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "branching_factor") + self.branching_factor = kwargs["branching_factor"] - # assert "starting_branch_count" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "starting_branch_count") - # self.current_branch_count = kwargs["starting_branch_count"] - # del kwargs["starting_branch_count"] + self.branch = self.fractal_contraction + case "linear": + raise NotImplementedError("linear branch method is not yet implemented") + # TODO: Implement constant branch method + # assert "branch_count_rate_of_change" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "branch_count_rate_of_change") + # 
self.branch_count_rate_of_change=kwargs["branch_count_rate_of_change"] + # del kwargs["branch_count_rate_of_change"] - # self.branch = self.linear_branch + # assert "starting_branch_count" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "starting_branch_count") + # self.current_branch_count = kwargs["starting_branch_count"] + # del kwargs["starting_branch_count"] + # self.branch = self.linear_branch case "constant": - assert "starting_branch_count" in kwargs.keys(), self._handle_bad_args_("branch_method", branch_method, "starting_branch_count") - self.current_branch_count = kwargs["starting_branch_count"] - del kwargs["starting_branch_count"] + raise NotImplementedError("constant branch method is not yet implemented") + # TODO: Implement constant branch method + # assert "starting_branch_count" in kwargs.keys(), self._handle_bad_args_("branch_method", branch_method, "starting_branch_count") + # self.current_branch_count = kwargs["starting_branch_count"] + # del kwargs["starting_branch_count"] self.branch = self.constant_branch - case "test": self.branch = self.test_branch - case _: raise ValueError("incorrect value passed to branch_method") @@ -81,6 +93,7 @@ def __init__( self.workspace_width = workspace_width self.timestep = 0 self.current_depth = 0 + self.branch_index = [workspace_width/2] self.start_num = kwargs["start_num"] self.branch_index = np.empty(self.current_branch_count, dtype=np.float32) @@ -123,9 +136,18 @@ def fractal_branch(self): # return a new number of branches = branching_factor ^ depth return self.branching_factor ** self.current_depth - def constant_branch(self): - # return current number of branches - return self.current_branch_count + def fractal_contraction(self): + # return + b = self.branching_factor + d = self.current_depth + top_b = self.start_num + + return int( top_b/( b**(d-1) )) + + # REQUIRES TESTING + # def constant_branch(self): + # # return current number of branches + # return self.current_branch_count # 
REQUIRES TESTING # def linear_branch(self): @@ -185,5 +207,5 @@ def insert(self, data_dict_not: DatasetDict): self.max_traj_length = int(self.timestep * self.alpha + self.max_traj_length * (1 - self.alpha)) self.timestep = 0 - finish = dt.now() - print(f"Splits: {(sector_2 - sector_1).total_seconds():.5f} : {(sector_3 - sector_2).total_seconds():.5f} : {(sector_4 - sector_3).total_seconds():.5f} : {(finish - sector_4).total_seconds():.5f}\nLaptime: {(finish - sector_1).total_seconds():.5f}") \ No newline at end of file + # finish = dt.now() + # print(f"Splits: {(sector_2 - sector_1).total_seconds():.5f} : {(sector_3 - sector_2).total_seconds():.5f} : {(sector_4 - sector_3).total_seconds():.5f} : {(finish - sector_4).total_seconds():.5f}\nLaptime: {(finish - sector_1).total_seconds():.5f}") \ No newline at end of file diff --git a/serl_launcher/serl_launcher/data/fsrb_test.py b/serl_launcher/serl_launcher/data/fsrb_test.py index 31e6cf1f..58518566 100644 --- a/serl_launcher/serl_launcher/data/fsrb_test.py +++ b/serl_launcher/serl_launcher/data/fsrb_test.py @@ -10,13 +10,15 @@ FLAGS = flags.FLAGS flags.DEFINE_integer("capacity", 10000000, "Replay buffer capacity.") -flags.DEFINE_string("branch_method", "test", "Method for determining the number of transforms per dimension (x,y)") -flags.DEFINE_string("split_method", "test", "Method for determining whether to change the number of transforms per dimension (x,y)") +flags.DEFINE_string("branch_method", "contraction", "Method for determining the number of transforms per dimension (x,y)") +flags.DEFINE_string("split_method", "time", "Method for determining whether to change the number of transforms per dimension (x,y)") flags.DEFINE_float("workspace_width", 0.5, "workspace width in meters") flags.DEFINE_integer("max_depth", 4, "Maximum level of depth") # For fractal_branch only +flags.DEFINE_integer("max_steps",20000,"Maximum steps") flags.DEFINE_integer("branching_factor", 3, "Rate of change of number of transforms per 
dimension (x,y)") # For fractal_branch only flags.DEFINE_integer("starting_branch_count", 1, "Initial number of transforms per dimension (x,y)") # For constant_branch only flags.DEFINE_integer("start_num",81, "Initila number of branch on the first depth") # For Fractal Cntraction +flags.DEFINE_integer("alpha",1,"alpha value") def main(_): @@ -37,9 +39,10 @@ def main(_): x_obs_idx=x_obs_idx, y_obs_idx= y_obs_idx, max_depth=FLAGS.max_depth, + max_traj_length = 150, branching_factor=FLAGS.branching_factor, - start_num = FLAGS.start_num - + start_num = FLAGS.start_num, + alpha = FLAGS.alpha ) observation, info = env.reset() From 65987e17be53d9ff314283a003909ca84fe264eb Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Wed, 30 Jul 2025 13:18:24 -0500 Subject: [PATCH 081/141] some options for demos --- .../async_sac_state_sim/.vscode/launch.json | 21 ++++++++++++------- examples/async_sac_state_sim/run_actor.sh | 16 ++++++++------ examples/async_sac_state_sim/run_learner.sh | 16 ++++++++------ 3 files changed, 33 insertions(+), 20 deletions(-) diff --git a/examples/async_sac_state_sim/.vscode/launch.json b/examples/async_sac_state_sim/.vscode/launch.json index 1faedff0..30f2242a 100644 --- a/examples/async_sac_state_sim/.vscode/launch.json +++ b/examples/async_sac_state_sim/.vscode/launch.json @@ -13,18 +13,18 @@ "args": [ "--learner", // REQUIRED: Indicates this is a learner instance "--env", "PandaReachCube-v0", // Environment to use - "--exp_name=SERL-ReachTask", // Experiment name for wandb logging + "--exp_name=PandaReachCube-v0_async_sac_state_demos", // Experiment name for wandb logging "--seed", "0", // Random seed for reproducibility - "--max_steps", "100000", // Maximum training steps - "--training_starts", "1000", // Start training after buffer has this many samples + "--max_steps", "50000 ", // Maximum training steps + "--training_starts", "1000", // Start training after buffer has this many samples "--critic_actor_ratio", "8", // Critic-to-actor update ratio 
"--batch_size", "256", // Training batch size "--replay_buffer_capacity", "100000", // Replay buffer capacity // Demonstration data loading options - "--load_demos", // Load demo dataset - "--demo_dir", "/data/data/serl/demos/data_franka_reach_random", - "--file_name", "data_franka_reach_random_20.npz" // Name of the demo file to load + // "--load_demos", // Load demo dataset + // "--demo_dir", "/data/data/serl/demos", + // "--file_name", "data_franka_reach_random_20.npz" // Name of the demo file to load ], "console": "integratedTerminal", // Use integrated terminal to see output "justMyCode": false, // Allow stepping into libraries @@ -49,14 +49,19 @@ "args": [ "--actor", // REQUIRED: Indicates this is an actor instance "--env", "PandaReachCube-v0", // Environment to use - "--exp_name=SERL-ReachTask", // Experiment name for wandb logging + "--exp_name=PandaReachCube-v0_async_sac_state_demos", // Experiment name for wandb logging "--seed", "0", // Random seed for reproducibility "--random_steps", "1000", // Number of random steps at beginning - "--max_steps", "100000", // Maximum training steps + "--max_steps", "80000", // Maximum training steps "--training_starts", "1000", // Start training after buffer has this many samples "--critic_actor_ratio", "8", // Critic-to-actor update ratio "--batch_size", "256", // Training batch size "--replay_buffer_capacity", "100000", // Replay buffer capacity + + // Demonstration data loading options + // "--load_demos", // Load demo dataset + // "--demo_dir", "/data/data/serl/demos", + // "--file_name", "data_franka_reach_random_20.npz" // Name of the demo file to load ], "console": "integratedTerminal", // Use integrated terminal to see output "justMyCode": false, // Allow stepping into libraries diff --git a/examples/async_sac_state_sim/run_actor.sh b/examples/async_sac_state_sim/run_actor.sh index b8ce3b1b..4b8e2dae 100644 --- a/examples/async_sac_state_sim/run_actor.sh +++ b/examples/async_sac_state_sim/run_actor.sh @@ -1,18 
+1,22 @@ #!/bin/bash export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ -export XLA_PYTHON_CLIENT_MEM_FRACTION=.5 && \ +export XLA_PYTHON_CLIENT_MEM_FRACTION=.8 && \ python async_sac_state_sim.py "$@"\ --actor \ --env PandaReachCube-v0 \ - --exp_name reach-baseline \ + --exp_name PandaReachCube-v0_async_sac_state_demos_5_replay_1M_batch_2048_utd_32 \ --replay_buffer_type replay_buffer \ - --max_steps 80000 \ + --max_steps 500_000 \ --training_starts 1000 \ - --critic_actor_ratio 8 \ - --batch_size 256 \ - --replay_buffer_capacity 100000 \ + --random_steps 1000 \ + --critic_actor_ratio 32 \ + --batch_size 2048 \ + --replay_buffer_capacity 1_000_000 \ --save_model True \ + --load_demos \ + --demo_dir /data/data/serl/demos \ + --file_name data_franka_reach_random_5_2.npz \ # --branch_method fractal \ # --split_method time \ # --starting_branch_count 1 \ diff --git a/examples/async_sac_state_sim/run_learner.sh b/examples/async_sac_state_sim/run_learner.sh index bf34e3d1..c6786c8e 100644 --- a/examples/async_sac_state_sim/run_learner.sh +++ b/examples/async_sac_state_sim/run_learner.sh @@ -1,18 +1,22 @@ #!/bin/bash export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ -export XLA_PYTHON_CLIENT_MEM_FRACTION=.5 && \ +export XLA_PYTHON_CLIENT_MEM_FRACTION=.2 && \ python async_sac_state_sim.py "$@"\ --learner \ --env PandaReachCube-v0 \ - --exp_name reach-baseline \ + --exp_name PandaReachCube-v0_async_sac_state_demos_5_replay_1M_batch_2048_utd_32 \ --replay_buffer_type replay_buffer \ - --max_steps 80000 \ + --max_steps 50_000 \ --training_starts 1000 \ - --critic_actor_ratio 8 \ - --batch_size 256 \ - --replay_buffer_capacity 100000 \ + --random_steps 1000 \ + --critic_actor_ratio 32 \ + --batch_size 2048 \ + --replay_buffer_capacity 1_000_000 \ --save_model True \ + --load_demos \ + --demo_dir /data/data/serl/demos \ + --file_name data_franka_reach_random_5_2.npz \ # --branch_method fractal \ # --split_method time \ # --starting_branch_count 1 \ From 
0ea596b749a78795a6fe99b95e489d12d581a66e Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Wed, 30 Jul 2025 13:29:34 -0500 Subject: [PATCH 082/141] letting sh files to have demo info commented out --- examples/async_sac_state_sim/run_actor.sh | 51 +++++++++++++-------- examples/async_sac_state_sim/run_learner.sh | 6 +-- 2 files changed, 36 insertions(+), 21 deletions(-) diff --git a/examples/async_sac_state_sim/run_actor.sh b/examples/async_sac_state_sim/run_actor.sh index 4b8e2dae..64caea08 100644 --- a/examples/async_sac_state_sim/run_actor.sh +++ b/examples/async_sac_state_sim/run_actor.sh @@ -1,31 +1,46 @@ #!/bin/bash export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ -export XLA_PYTHON_CLIENT_MEM_FRACTION=.8 && \ +export XLA_PYTHON_CLIENT_MEM_FRACTION=.5 && \ +export SCRIPT_DIR=$(dirname "$(realpath "$0")") && \ +export TIMESTAMP=$(date +"%m-%d-%Y-%H-%M-%S") && \ +export CHECKPOINT_DIR="/data/fsrb_testing/checkpoints-$TIMESTAMP" && \ -python async_sac_state_sim.py "$@"\ +# Create checkpoint directory if it doesn't exist +if [ ! -d "$CHECKPOINT_DIR" ]; then + echo "Creating checkpoint directory: $CHECKPOINT_DIR" + mkdir -p "$CHECKPOINT_DIR" || { + echo "Failed to create checkpoint directory!" 
>&2 + exit 1 + } +fi + +python async_sac_state_sim.py \ --actor \ --env PandaReachCube-v0 \ - --exp_name PandaReachCube-v0_async_sac_state_demos_5_replay_1M_batch_2048_utd_32 \ - --replay_buffer_type replay_buffer \ - --max_steps 500_000 \ + --exp_name PandaReachCube-v0_state_sim_3D_con-81-3-batch_2048_replay_8M_utd_32 \ + --seed 0 \ + --replay_buffer_type fractal_symmetry_replay_buffer \ + --max_steps 300_000 \ --training_starts 1000 \ - --random_steps 1000 \ --critic_actor_ratio 32 \ --batch_size 2048 \ - --replay_buffer_capacity 1_000_000 \ + --replay_buffer_capacity 8_000_000 \ --save_model True \ - --load_demos \ - --demo_dir /data/data/serl/demos \ - --file_name data_franka_reach_random_5_2.npz \ - # --branch_method fractal \ - # --split_method time \ - # --starting_branch_count 1 \ - # --alpha 1 \ - # --max_depth 4 \ - # --branching_factor 3 \ - # --workspace_width 5 \ + # --load_demos \ + # --demo_dir /data/data/serl/demos \ + # --file_name data_franka_reach_random_5_2.npz \ + --branch_method contraction \ + --split_method time \ + --max_traj_length 100 \ + --max_depth 4 \ + --start_num 81 \ + --alpha 1 \ + --max_depth 4 \ + --branching_factor 3 \ + --workspace_width 0.5 \ + --render # --checkpoint_period 10000 \ # --checkpoint_path "$CHECKPOINT_DIR" \ #--render \ #--debug # wandb is disabled when debug - #--render + #--render \ No newline at end of file diff --git a/examples/async_sac_state_sim/run_learner.sh b/examples/async_sac_state_sim/run_learner.sh index e12bf016..8b302299 100644 --- a/examples/async_sac_state_sim/run_learner.sh +++ b/examples/async_sac_state_sim/run_learner.sh @@ -14,9 +14,9 @@ python async_sac_state_sim.py "$@"\ --batch_size 2048 \ --replay_buffer_capacity 1_000_000 \ --save_model True \ - --load_demos \ - --demo_dir /data/data/serl/demos \ - --file_name data_franka_reach_random_5_2.npz \ + # --load_demos \ + # --demo_dir /data/data/serl/demos \ + # --file_name data_franka_reach_random_5_2.npz \ --branch_method contraction \ 
--split_method time \ --starting_branch_count 1 \ From e0e06333232b0aeaaa11e21504e73fdbdc2c1aa7 Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Wed, 30 Jul 2025 14:16:11 -0500 Subject: [PATCH 083/141] update sh files to work with fractal_symmetry_replay_buffer --- examples/async_sac_state_sim/run_actor.sh | 34 +++++++++--------- examples/async_sac_state_sim/run_learner.sh | 40 +++++++++++++-------- 2 files changed, 42 insertions(+), 32 deletions(-) diff --git a/examples/async_sac_state_sim/run_actor.sh b/examples/async_sac_state_sim/run_actor.sh index 64caea08..35c3ba3e 100644 --- a/examples/async_sac_state_sim/run_actor.sh +++ b/examples/async_sac_state_sim/run_actor.sh @@ -6,18 +6,18 @@ export TIMESTAMP=$(date +"%m-%d-%Y-%H-%M-%S") && \ export CHECKPOINT_DIR="/data/fsrb_testing/checkpoints-$TIMESTAMP" && \ # Create checkpoint directory if it doesn't exist -if [ ! -d "$CHECKPOINT_DIR" ]; then - echo "Creating checkpoint directory: $CHECKPOINT_DIR" - mkdir -p "$CHECKPOINT_DIR" || { - echo "Failed to create checkpoint directory!" >&2 - exit 1 - } -fi +# if [ ! -d "$CHECKPOINT_DIR" ]; then +# echo "Creating checkpoint directory: $CHECKPOINT_DIR" +# mkdir -p "$CHECKPOINT_DIR" || { +# echo "Failed to create checkpoint directory!" 
>&2 +# exit 1 +# } +# fi python async_sac_state_sim.py \ --actor \ --env PandaReachCube-v0 \ - --exp_name PandaReachCube-v0_state_sim_3D_con-81-3-batch_2048_replay_8M_utd_32 \ + --exp_name PandaReachCube-v0_state_sim_demos_27^1_batch_2048_replay_8M_utd_32 \ --seed 0 \ --replay_buffer_type fractal_symmetry_replay_buffer \ --max_steps 300_000 \ @@ -26,21 +26,19 @@ python async_sac_state_sim.py \ --batch_size 2048 \ --replay_buffer_capacity 8_000_000 \ --save_model True \ - # --load_demos \ - # --demo_dir /data/data/serl/demos \ - # --file_name data_franka_reach_random_5_2.npz \ - --branch_method contraction \ - --split_method time \ - --max_traj_length 100 \ + --load_demos \ + --demo_dir /data/data/serl/demos \ + --file_name data_franka_reach_random_5_2.npz \ + --branch_method constant \ + --split_method constant \ + # --max_traj_length 100 \ --max_depth 4 \ - --start_num 81 \ + # --start_num 81 \ --alpha 1 \ - --max_depth 4 \ + # --max_depth 4 \ --branching_factor 3 \ --workspace_width 0.5 \ - --render # --checkpoint_period 10000 \ # --checkpoint_path "$CHECKPOINT_DIR" \ #--render \ #--debug # wandb is disabled when debug - #--render \ No newline at end of file diff --git a/examples/async_sac_state_sim/run_learner.sh b/examples/async_sac_state_sim/run_learner.sh index 8b302299..bfdad420 100644 --- a/examples/async_sac_state_sim/run_learner.sh +++ b/examples/async_sac_state_sim/run_learner.sh @@ -1,31 +1,43 @@ #!/bin/bash export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ -export XLA_PYTHON_CLIENT_MEM_FRACTION=.2 && \ +export XLA_PYTHON_CLIENT_MEM_FRACTION=.5 && \ +export SCRIPT_DIR=$(dirname "$(realpath "$0")") && \ +export TIMESTAMP=$(date +"%m-%d-%Y-%H-%M-%S") && \ +export CHECKPOINT_DIR="/data/fsrb_testing/checkpoints-$TIMESTAMP" && \ + +# Create checkpoint directory if it doesn't exist +# if [ ! -d "$CHECKPOINT_DIR" ]; then +# echo "Creating checkpoint directory: $CHECKPOINT_DIR" +# mkdir -p "$CHECKPOINT_DIR" || { +# echo "Failed to create checkpoint directory!" 
>&2 +# exit 1 +# } +# fi python async_sac_state_sim.py "$@"\ --learner \ --env PandaReachCube-v0 \ - --exp_name PandaReachCube-v0_async_sac_state_demos_5_replay_1M_batch_2048_utd_32 \ + --exp_name PandaReachCube-v0_state_sim_demos_27^1_batch_2048_replay_8M_utd_32 \ --replay_buffer_type replay_buffer \ --max_steps 50_000 \ --training_starts 1000 \ --random_steps 1000 \ --critic_actor_ratio 32 \ --batch_size 2048 \ - --replay_buffer_capacity 1_000_000 \ + --replay_buffer_capacity 8_000_000 \ --save_model True \ - # --load_demos \ - # --demo_dir /data/data/serl/demos \ - # --file_name data_franka_reach_random_5_2.npz \ - --branch_method contraction \ - --split_method time \ - --starting_branch_count 1 \ - --max_traj_length 100 \ - --start_num 81 \ - --alpha 1 \ - --max_depth 4 \ - --branching_factor 3 \ + --load_demos \ + --demo_dir /data/data/serl/demos \ + --file_name data_franka_reach_random_5_2.npz \ + --branch_method constant \ + --split_method constant \ --workspace_width 0.5 \ + --starting_branch_count 1 \ + # --max_traj_length 100 \ + # --start_num 81 \ + # --alpha 1 \ + # --max_depth 4 \ + # --branching_factor 3 \ # --checkpoint_period 10000 \ # --checkpoint_path "$CHECKPOINT_DIR" \ #--debug # wandb is disabled when debug From 794caf50e8861b76ad636fb1666091e8298556d5 Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Wed, 30 Jul 2025 15:04:12 -0500 Subject: [PATCH 084/141] Undoing comments on constant methods for branch/split. 
--- .../data/fractal_symmetry_replay_buffer.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py index 5bd0c2af..d98346d0 100644 --- a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py +++ b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py @@ -76,11 +76,9 @@ def __init__( # self.branch = self.linear_branch case "constant": - raise NotImplementedError("constant branch method is not yet implemented") - # TODO: Implement constant branch method - # assert "starting_branch_count" in kwargs.keys(), self._handle_bad_args_("branch_method", branch_method, "starting_branch_count") - # self.current_branch_count = kwargs["starting_branch_count"] - # del kwargs["starting_branch_count"] + assert "starting_branch_count" in kwargs.keys(), self._handle_bad_args_("branch_method", branch_method, "starting_branch_count") + self.current_branch_count = kwargs["starting_branch_count"] + del kwargs["starting_branch_count"] self.branch = self.constant_branch case "test": @@ -208,4 +206,4 @@ def insert(self, data_dict_not: DatasetDict): self.timestep = 0 # finish = dt.now() - # print(f"Splits: {(sector_2 - sector_1).total_seconds():.5f} : {(sector_3 - sector_2).total_seconds():.5f} : {(sector_4 - sector_3).total_seconds():.5f} : {(finish - sector_4).total_seconds():.5f}\nLaptime: {(finish - sector_1).total_seconds():.5f}") \ No newline at end of file + # print(f"Splits: {(sector_2 - sector_1).total_seconds():.5f} : {(sector_3 - sector_2).total_seconds():.5f} : {(sector_4 - sector_3).total_seconds():.5f} : {(finish - sector_4).total_seconds():.5f}\nLaptime: {(finish - sector_1).total_seconds():.5f}") From 692682cacefdb648263b5b82d326f2c6ff4afe1e Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Wed, 30 Jul 2025 16:05:00 -0500 Subject: [PATCH 085/141] Updating .sh files and launch.json to have 
fractal + demo options to run/debug --- .../async_sac_state_sim/.vscode/launch.json | 91 +++++++++++-------- examples/async_sac_state_sim/run_actor.sh | 20 ++-- examples/async_sac_state_sim/run_learner.sh | 16 ++-- 3 files changed, 70 insertions(+), 57 deletions(-) diff --git a/examples/async_sac_state_sim/.vscode/launch.json b/examples/async_sac_state_sim/.vscode/launch.json index 30f2242a..24546022 100644 --- a/examples/async_sac_state_sim/.vscode/launch.json +++ b/examples/async_sac_state_sim/.vscode/launch.json @@ -11,32 +11,39 @@ "program": "${workspaceFolder}/async_sac_state_sim.py", // Command-line arguments matching run_learner.sh "args": [ - "--learner", // REQUIRED: Indicates this is a learner instance - "--env", "PandaReachCube-v0", // Environment to use - "--exp_name=PandaReachCube-v0_async_sac_state_demos", // Experiment name for wandb logging - "--seed", "0", // Random seed for reproducibility - "--max_steps", "50000 ", // Maximum training steps - "--training_starts", "1000", // Start training after buffer has this many samples - "--critic_actor_ratio", "8", // Critic-to-actor update ratio - "--batch_size", "256", // Training batch size - "--replay_buffer_capacity", "100000", // Replay buffer capacity + "--learner", // REQUIRED: Indicates this is a learner instance + "--env", "PandaReachCube-v0", // Environment to use + "--exp_name=PandaReachCube-v0_state_sim_3d_demos_5_const_27^1_batch_256_replay_1M_utd_8", // Experiment name for wandb logging + "--seed", "0", // Random seed for reproducibility + "--random_steps", "1_000", // Number of random steps at beginning + "--max_steps", "50_000 ", // Maximum training steps + "--training_starts", "1_000", // Start training after buffer has this many samples + "--critic_actor_ratio", "8", // Critic-to-actor update ratio + "--batch_size", "256", // Training batch size + "--replay_buffer_capacity", "1_000_000", // Replay buffer capacity + + // Fractal + "--replay_buffer_type", "fractal_symmetry_replay_buffer", // 
Use fractal symmetry replay buffer + "--branch_method", "constant", + "--split_method", "constant", + "--workspace_width", "0.5", // Demonstration data loading options - // "--load_demos", // Load demo dataset - // "--demo_dir", "/data/data/serl/demos", - // "--file_name", "data_franka_reach_random_20.npz" // Name of the demo file to load + "--load_demos", // Load demo dataset + "--demo_dir", "/data/data/serl/demos", + "--file_name", "data_franka_reach_random_5_2.npz" // Name of the demo file to load ], - "console": "integratedTerminal", // Use integrated terminal to see output - "justMyCode": false, // Allow stepping into libraries + "console": "integratedTerminal", // Use integrated terminal to see output + "justMyCode": false, // Allow stepping into libraries // Environment variables from run_learner.sh "env": { - "XLA_PYTHON_CLIENT_PREALLOCATE": "false", // Don't preallocate JAX/XLA memory - "XLA_PYTHON_CLIENT_MEM_FRACTION": ".5" // Limit JAX memory usage to 50% + "XLA_PYTHON_CLIENT_PREALLOCATE": "false", // Don't preallocate JAX/XLA memory + "XLA_PYTHON_CLIENT_MEM_FRACTION": ".5" // Limit JAX memory usage to 50% }, // Additional helpful debugging options - "showReturnValue": true, // Show function return values - "purpose": ["debug-in-terminal"], // Run in terminal for better output - "cwd": "${workspaceFolder}" // Set working directory explicitly + "showReturnValue": true, // Show function return values + "purpose": ["debug-in-terminal"], // Run in terminal for better output + "cwd": "${workspaceFolder}" // Set working directory explicitly }, // ACTOR { @@ -47,32 +54,38 @@ "program": "${workspaceFolder}/async_sac_state_sim.py", // Command-line arguments matching run_actor.sh "args": [ - "--actor", // REQUIRED: Indicates this is an actor instance - "--env", "PandaReachCube-v0", // Environment to use - "--exp_name=PandaReachCube-v0_async_sac_state_demos", // Experiment name for wandb logging - "--seed", "0", // Random seed for reproducibility - 
"--random_steps", "1000", // Number of random steps at beginning - "--max_steps", "80000", // Maximum training steps - "--training_starts", "1000", // Start training after buffer has this many samples - "--critic_actor_ratio", "8", // Critic-to-actor update ratio - "--batch_size", "256", // Training batch size - "--replay_buffer_capacity", "100000", // Replay buffer capacity + "--actor", // REQUIRED: Indicates this is an actor instance + "--env", "PandaReachCube-v0", // Environment to use + "--exp_name=PandaReachCube-v0_state_sim_3d_demos_5_const_27^1_batch_256_replay_1M_utd_8", // Experiment name for wandb logging + "--seed", "0", // Random seed for reproducibility + "--random_steps", "1_000", // Number of random steps at beginning + "--max_steps", "500_000", // Maximum training steps + "--training_starts", "1_000", // Start training after buffer has this many samples + "--critic_actor_ratio", "8", // Critic-to-actor update ratio + "--batch_size", "256", // Training batch size + "--replay_buffer_capacity", "1_000_000", // Replay buffer capacity + // Fractal + "--replay_buffer_type", "fractal_symmetry_replay_buffer", // Use fractal symmetry replay buffer + "--branch_method", "constant", + "--split_method", "constant", + "--workspace_width", "0.5", + // Demonstration data loading options - // "--load_demos", // Load demo dataset - // "--demo_dir", "/data/data/serl/demos", - // "--file_name", "data_franka_reach_random_20.npz" // Name of the demo file to load + "--load_demos", // Load demo dataset + "--demo_dir", "/data/data/serl/demos", + "--file_name", "data_franka_reach_random_5_2.npz" // Name of the demo file to load ], - "console": "integratedTerminal", // Use integrated terminal to see output - "justMyCode": false, // Allow stepping into libraries + "console": "integratedTerminal", // Use integrated terminal to see output + "justMyCode": false, // Allow stepping into libraries // Environment variables from run_actor.sh "env": { - "XLA_PYTHON_CLIENT_PREALLOCATE": 
"false", // Don't preallocate JAX/XLA memory - "XLA_PYTHON_CLIENT_MEM_FRACTION": ".5" // Limit JAX memory usage to 50% + "XLA_PYTHON_CLIENT_PREALLOCATE": "false", // Don't preallocate JAX/XLA memory + "XLA_PYTHON_CLIENT_MEM_FRACTION": ".5" // Limit JAX memory usage to 50% }, - "showReturnValue": true, // Show function return values - "purpose": ["debug-in-terminal"], // Run in terminal for better output - "cwd": "${workspaceFolder}" // Set working directory explicitly + "showReturnValue": true, // Show function return values + "purpose": ["debug-in-terminal"], // Run in terminal for better output + "cwd": "${workspaceFolder}" // Set working directory explicitly } ], // Compound configurations to launch multiple configurations together diff --git a/examples/async_sac_state_sim/run_actor.sh b/examples/async_sac_state_sim/run_actor.sh index 35c3ba3e..816899bd 100644 --- a/examples/async_sac_state_sim/run_actor.sh +++ b/examples/async_sac_state_sim/run_actor.sh @@ -17,7 +17,7 @@ export CHECKPOINT_DIR="/data/fsrb_testing/checkpoints-$TIMESTAMP" && \ python async_sac_state_sim.py \ --actor \ --env PandaReachCube-v0 \ - --exp_name PandaReachCube-v0_state_sim_demos_27^1_batch_2048_replay_8M_utd_32 \ + --exp_name PandaReachCube-v0_state_sim_demos_5_const_27^1_batch_256_replay_1M_utd_8 \ --seed 0 \ --replay_buffer_type fractal_symmetry_replay_buffer \ --max_steps 300_000 \ @@ -26,19 +26,19 @@ python async_sac_state_sim.py \ --batch_size 2048 \ --replay_buffer_capacity 8_000_000 \ --save_model True \ - --load_demos \ - --demo_dir /data/data/serl/demos \ - --file_name data_franka_reach_random_5_2.npz \ --branch_method constant \ --split_method constant \ + --workspace_width 0.5 \ + --load_demos \ + --demo_dir /data/data/serl/demos \ + --file_name data_franka_reach_random_5_2.npz \ + # --starting_branch_count 1 \ # --max_traj_length 100 \ - --max_depth 4 \ - # --start_num 81 \ - --alpha 1 \ # --max_depth 4 \ - --branching_factor 3 \ - --workspace_width 0.5 \ + # --start_num 81 
\ + # --alpha 1 \ + # --branching_factor 3 \ # --checkpoint_period 10000 \ # --checkpoint_path "$CHECKPOINT_DIR" \ #--render \ - #--debug # wandb is disabled when debug + #--debug # wandb is disabled when debug \ No newline at end of file diff --git a/examples/async_sac_state_sim/run_learner.sh b/examples/async_sac_state_sim/run_learner.sh index bfdad420..43a65005 100644 --- a/examples/async_sac_state_sim/run_learner.sh +++ b/examples/async_sac_state_sim/run_learner.sh @@ -17,8 +17,8 @@ export CHECKPOINT_DIR="/data/fsrb_testing/checkpoints-$TIMESTAMP" && \ python async_sac_state_sim.py "$@"\ --learner \ --env PandaReachCube-v0 \ - --exp_name PandaReachCube-v0_state_sim_demos_27^1_batch_2048_replay_8M_utd_32 \ - --replay_buffer_type replay_buffer \ + --exp_name PandaReachCube-v0_state_sim_3d_demos_5_const_27^1_batch_256_replay_1M_utd_8 \ + --replay_buffer_type fractal_symmetry_replay_buffer \ --max_steps 50_000 \ --training_starts 1000 \ --random_steps 1000 \ @@ -26,19 +26,19 @@ python async_sac_state_sim.py "$@"\ --batch_size 2048 \ --replay_buffer_capacity 8_000_000 \ --save_model True \ - --load_demos \ - --demo_dir /data/data/serl/demos \ - --file_name data_franka_reach_random_5_2.npz \ --branch_method constant \ --split_method constant \ --workspace_width 0.5 \ - --starting_branch_count 1 \ + --load_demos \ + --demo_dir /data/data/serl/demos \ + --file_name data_franka_reach_random_5_2.npz \ + # --starting_branch_count 1 \ # --max_traj_length 100 \ + # --max_depth 4 \ # --start_num 81 \ # --alpha 1 \ - # --max_depth 4 \ # --branching_factor 3 \ # --checkpoint_period 10000 \ # --checkpoint_path "$CHECKPOINT_DIR" \ #--debug # wandb is disabled when debug - #--render + #--render \ No newline at end of file From bdfb6f6f8623be88cf68c240995ba322dcf1e5c0 Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Wed, 30 Jul 2025 16:53:02 -0500 Subject: [PATCH 086/141] Bug fixing for demos + fractals --- examples/async_sac_state_sim/.vscode/launch.json | 2 ++ 
examples/async_sac_state_sim/run_actor.sh | 2 +- examples/async_sac_state_sim/run_learner.sh | 2 +- .../serl_launcher/data/fractal_symmetry_replay_buffer.py | 6 +++--- 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/examples/async_sac_state_sim/.vscode/launch.json b/examples/async_sac_state_sim/.vscode/launch.json index 24546022..0372d1e7 100644 --- a/examples/async_sac_state_sim/.vscode/launch.json +++ b/examples/async_sac_state_sim/.vscode/launch.json @@ -26,6 +26,7 @@ "--replay_buffer_type", "fractal_symmetry_replay_buffer", // Use fractal symmetry replay buffer "--branch_method", "constant", "--split_method", "constant", + "--starting_branch_count", "27", // Start with 27 branches "--workspace_width", "0.5", // Demonstration data loading options @@ -69,6 +70,7 @@ "--replay_buffer_type", "fractal_symmetry_replay_buffer", // Use fractal symmetry replay buffer "--branch_method", "constant", "--split_method", "constant", + "--starting_branch_count", "27", // Start with 27 branches "--workspace_width", "0.5", // Demonstration data loading options diff --git a/examples/async_sac_state_sim/run_actor.sh b/examples/async_sac_state_sim/run_actor.sh index 816899bd..f352472b 100644 --- a/examples/async_sac_state_sim/run_actor.sh +++ b/examples/async_sac_state_sim/run_actor.sh @@ -28,11 +28,11 @@ python async_sac_state_sim.py \ --save_model True \ --branch_method constant \ --split_method constant \ + --starting_branch_count 27 \ --workspace_width 0.5 \ --load_demos \ --demo_dir /data/data/serl/demos \ --file_name data_franka_reach_random_5_2.npz \ - # --starting_branch_count 1 \ # --max_traj_length 100 \ # --max_depth 4 \ # --start_num 81 \ diff --git a/examples/async_sac_state_sim/run_learner.sh b/examples/async_sac_state_sim/run_learner.sh index 43a65005..86421fda 100644 --- a/examples/async_sac_state_sim/run_learner.sh +++ b/examples/async_sac_state_sim/run_learner.sh @@ -28,11 +28,11 @@ python async_sac_state_sim.py "$@"\ --save_model True \ 
--branch_method constant \ --split_method constant \ + --starting_branch_count 27 \ --workspace_width 0.5 \ --load_demos \ --demo_dir /data/data/serl/demos \ --file_name data_franka_reach_random_5_2.npz \ - # --starting_branch_count 1 \ # --max_traj_length 100 \ # --max_depth 4 \ # --start_num 81 \ diff --git a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py index d98346d0..0df57387 100644 --- a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py +++ b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py @@ -143,9 +143,9 @@ def fractal_contraction(self): return int( top_b/( b**(d-1) )) # REQUIRES TESTING - # def constant_branch(self): - # # return current number of branches - # return self.current_branch_count + def constant_branch(self): + # return current number of branches + return self.current_branch_count # REQUIRES TESTING # def linear_branch(self): From 592f80efbc38c042b4baed886ef6b32ea5f090b3 Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Fri, 1 Aug 2025 09:34:09 -0500 Subject: [PATCH 087/141] Workspace Width Method - Modify density of workspace with depth. - Three methods: constant, increase, and decrease. - These changes have been tested in fsrb and by running training on them. - Only need to pass in workspace_width_method as a flag (constant by default). - Uses starting workspace_width value provided. - Changed fractal_replay_buffer prototype. Needs workspace_width_method explicitly. 
--- .../async_sac_state_sim/.vscode/launch.json | 4 + .../async_sac_state_sim.py | 13 +- examples/async_sac_state_sim/run_actor.sh | 37 ++-- examples/async_sac_state_sim/run_learner.sh | 35 ++-- .../serl_launcher/data/data_store.py | 3 +- .../data/fractal_symmetry_replay_buffer.py | 166 +++++++++++++++--- serl_launcher/serl_launcher/data/fsrb_test.py | 103 +++++++++-- serl_launcher/serl_launcher/utils/launcher.py | 3 +- 8 files changed, 300 insertions(+), 64 deletions(-) diff --git a/examples/async_sac_state_sim/.vscode/launch.json b/examples/async_sac_state_sim/.vscode/launch.json index 60496b7a..16a95b29 100644 --- a/examples/async_sac_state_sim/.vscode/launch.json +++ b/examples/async_sac_state_sim/.vscode/launch.json @@ -20,6 +20,8 @@ "--critic_actor_ratio", "8", // Critic-to-actor update ratio "--batch_size", "256", // Training batch size "--replay_buffer_capacity", "100000", // Replay buffer capacity + "--start_num", "81", // Fractal Contraction starting with 81 branches + "--workspace_width_method", "decrease",// Workspace Width Method to increase density ], "console": "integratedTerminal", // Use integrated terminal to see output "justMyCode": false, // Allow stepping into libraries @@ -52,6 +54,8 @@ "--critic_actor_ratio", "8", // Critic-to-actor update ratio "--batch_size", "256", // Training batch size "--replay_buffer_capacity", "100000", // Replay buffer capacity + "--start_num", "81", // Fractal Contraction starting with 81 branches + "--workspace_width_method", "decrease",// Workspace Width Method to increase density ], "console": "integratedTerminal", // Use integrated terminal to see output "justMyCode": false, // Allow stepping into libraries diff --git a/examples/async_sac_state_sim/async_sac_state_sim.py b/examples/async_sac_state_sim/async_sac_state_sim.py index 2358122a..7b12f020 100644 --- a/examples/async_sac_state_sim/async_sac_state_sim.py +++ b/examples/async_sac_state_sim/async_sac_state_sim.py @@ -69,14 +69,20 @@ 
flags.DEFINE_integer("branch_count_rate_of_change", None, "Rate of change for linear branching") flags.DEFINE_integer("starting_branch_count", 1, "Initial number of branches") flags.DEFINE_integer("branching_factor", None, "Rate of change of number of transforms per dimension (x,y)") # For fractal_branch and fractal_contraction -flags.DEFINE_integer("start_num",None, "Initila number of branch on the first depth") # For Fractal Cntraction flags.DEFINE_integer("alpha",None,"alpha value") +# Contraction +flags.DEFINE_integer("start_num",81, "Initial number of branch on the first depth") # For Fractal Cntraction +# Density Workspace width +flags.DEFINE_string("workspace_width_method",'increase', 'Controls workspace width dimensions configurations') + +# Debug flags.DEFINE_boolean( "debug", False, "Debug mode." ) # debug mode will disable wandb logging +# Logging flags.DEFINE_string("log_rlds_path", None, "Path to save RLDS logs.") flags.DEFINE_string("preload_rlds_path", None, "Path to preload RLDS data.") @@ -314,9 +320,12 @@ def main(_): x_obs_idx=np.array([0,4]), y_obs_idx=np.array([1,5]), preload_rlds_path=FLAGS.preload_rlds_path, - start_num = FLAGS.start_num, max_depth = FLAGS.max_depth, alpha = FLAGS.alpha, + # Contraction + start_num = FLAGS.start_num, + # Workspace Width Density + workspace_width_method=FLAGS.workspace_width_method ) replay_iterator = replay_buffer.get_iterator( sample_args={ diff --git a/examples/async_sac_state_sim/run_actor.sh b/examples/async_sac_state_sim/run_actor.sh index 1635f2dc..4cb8a598 100644 --- a/examples/async_sac_state_sim/run_actor.sh +++ b/examples/async_sac_state_sim/run_actor.sh @@ -2,29 +2,42 @@ export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ export XLA_PYTHON_CLIENT_MEM_FRACTION=.5 && \ +export SCRIPT_DIR=$(dirname "$(realpath "$0")") && \ +export TIMESTAMP=$(date +"%m-%d-%Y-%H-%M-%S") && \ +export CHECKPOINT_DIR="/data/fsrb_testing/checkpoints-$TIMESTAMP" && \ + +# Create checkpoint directory if it doesn't exist +# if [ 
! -d "$CHECKPOINT_DIR" ]; then +# echo "Creating checkpoint directory: $CHECKPOINT_DIR" +# mkdir -p "$CHECKPOINT_DIR" || { +# echo "Failed to create checkpoint directory!" >&2 +# exit 1 +# } +# fi + python async_sac_state_sim.py \ --actor \ --env PandaReachCube-v0 \ - --exp_name PandaReachCube-v0_state_sim_3D_con-81-3-batch_2048_replay_8M_utd_32 \ - --seed 0 \ + --exp_name PandaReachCube-v0_state_sim_3D_con-81-9-batch_256_replay_1M_utd_8_ww_const \ --replay_buffer_type fractal_symmetry_replay_buffer \ + --seed 0 \ --max_steps 300_000 \ --training_starts 1000 \ - --critic_actor_ratio 32 \ - --batch_size 2048 \ - --replay_buffer_capacity 8_000_000 \ + --critic_actor_ratio 8 \ + --batch_size 256 \ + --replay_buffer_capacity 1_000_000 \ --save_model True \ --branch_method contraction \ --split_method time \ --max_traj_length 100 \ - --max_depth 4 \ - --start_num 81 \ --alpha 1 \ - --max_depth 4 \ + --max_depth 2 \ --branching_factor 3 \ --workspace_width 0.5 \ - --render - # --checkpoint_period 10000 \ - # --checkpoint_path "$CHECKPOINT_DIR" \ + --start_num 81 \ + --workspace_width_method constant \ + #--debug \ + #--render + #--checkpoint_period 10000 \ + #--checkpoint_path "$CHECKPOINT_DIR" \ #--render \ - #--debug # wandb is disabled when debug diff --git a/examples/async_sac_state_sim/run_learner.sh b/examples/async_sac_state_sim/run_learner.sh index 74b470fe..9ba22679 100644 --- a/examples/async_sac_state_sim/run_learner.sh +++ b/examples/async_sac_state_sim/run_learner.sh @@ -2,26 +2,41 @@ export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ export XLA_PYTHON_CLIENT_MEM_FRACTION=.5 && \ +export SCRIPT_DIR=$(dirname "$(realpath "$0")") && \ +export TIMESTAMP=$(date +"%m-%d-%Y-%H-%M-%S") && \ +export CHECKPOINT_DIR="/data/fsrb_testing/checkpoints-$TIMESTAMP" && \ + +# Create checkpoint directory if it doesn't exist +# if [ ! 
-d "$CHECKPOINT_DIR" ]; then +# echo "Creating checkpoint directory: $CHECKPOINT_DIR" +# mkdir -p "$CHECKPOINT_DIR" || { +# echo "Failed to create checkpoint directory!" >&2 +# exit 1 +# } +# fi + python async_sac_state_sim.py "$@"\ --learner \ --env PandaReachCube-v0 \ - --exp_name PandaReachCube-v0_state_sim_3D_con-81-3-batch_2048_replay_8M_utd_32 \ + --exp_name PandaReachCube-v0_state_sim_3D_con-81-9-batch_256_replay_1M_utd_8_ww_const \ --replay_buffer_type fractal_symmetry_replay_buffer \ + --seed 0 \ --max_steps 30_000 \ --training_starts 1000 \ - --critic_actor_ratio 32 \ - --batch_size 2048 \ - --replay_buffer_capacity 8_000_000 \ + --critic_actor_ratio 8 \ + --batch_size 256 \ + --replay_buffer_capacity 1_000_000 \ --save_model True \ --branch_method contraction \ --split_method time \ - --starting_branch_count 1 \ --max_traj_length 100 \ - --start_num 81 \ --alpha 1 \ - --max_depth 4 \ + --max_depth 2 \ --branching_factor 3 \ --workspace_width 0.5 \ - # --checkpoint_period 10000 \ - # --checkpoint_path "$CHECKPOINT_DIR" \ - #--debug # wandb is disabled when debug \ No newline at end of file + --start_num 81 \ + --workspace_width_method constant \ + #--debug \ + #--render \ + #--checkpoint_period 10000 \ + #--checkpoint_path "$CHECKPOINT_DIR" \ diff --git a/serl_launcher/serl_launcher/data/data_store.py b/serl_launcher/serl_launcher/data/data_store.py index 1e1ff8e6..c144434c 100644 --- a/serl_launcher/serl_launcher/data/data_store.py +++ b/serl_launcher/serl_launcher/data/data_store.py @@ -153,10 +153,11 @@ def __init__( branch_method: str, split_method: str, workspace_width: int, + workspace_width_method: str, rlds_logger: Optional[RLDSLogger] = None, **kwargs: dict, ): - FractalSymmetryReplayBuffer.__init__(self, observation_space, action_space, capacity, branch_method, split_method, workspace_width, **kwargs) + FractalSymmetryReplayBuffer.__init__(self, observation_space, action_space, capacity, branch_method, split_method, workspace_width, 
workspace_width_method,**kwargs) DataStoreBase.__init__(self, capacity) self._lock = Lock() self._logger = None diff --git a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py index d98346d0..458b3144 100644 --- a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py +++ b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py @@ -14,11 +14,17 @@ def __init__( branch_method: str, split_method: str, workspace_width: int, + workspace_width_method: str, kwargs: dict ): + + # Initialize values + self.debug_time = False self.current_branch_count=1 self.update_max_traj_length = False + self.workspace_width = workspace_width + # Set correct split function in self.split pointer. method_check = "split_method" match split_method: case "time": @@ -41,6 +47,7 @@ def __init__( case _: raise ValueError("incorrect value passed to split_method") + # Set correct branch function in self.branch pointer. 
method_check = "branch_method" match branch_method: case "fractal": @@ -54,6 +61,7 @@ def __init__( self.start_num=kwargs["start_num"] self.branch = self.fractal_branch + case "contraction": assert "start_num" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "start_num") self.start_num=kwargs["start_num"] @@ -63,42 +71,69 @@ def __init__( self.branching_factor = kwargs["branching_factor"] self.branch = self.fractal_contraction + case "linear": raise NotImplementedError("linear branch method is not yet implemented") - # TODO: Implement constant branch method - # assert "branch_count_rate_of_change" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "branch_count_rate_of_change") - # self.branch_count_rate_of_change=kwargs["branch_count_rate_of_change"] - # del kwargs["branch_count_rate_of_change"] - - # assert "starting_branch_count" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "starting_branch_count") - # self.current_branch_count = kwargs["starting_branch_count"] - # del kwargs["starting_branch_count"] - # self.branch = self.linear_branch + case "constant": assert "starting_branch_count" in kwargs.keys(), self._handle_bad_args_("branch_method", branch_method, "starting_branch_count") self.current_branch_count = kwargs["starting_branch_count"] del kwargs["starting_branch_count"] self.branch = self.constant_branch + case "test": self.branch = self.test_branch + case _: raise ValueError("incorrect value passed to branch_method") - + + # Set correct workspace_width function in self.workspace_width_method + match workspace_width_method: + + case "constant": + if workspace_width is None: + raise ValueError("workspace_width must be defined for constant workspace width method") + self.get_workspace_width = self.ww_constant + + case "decrease": + if workspace_width is None: + raise ValueError("workspace_width must be defined for constant workspace width method") + self.get_workspace_width = self.ww_decrease + 
+ case "increase": + if workspace_width is None: + raise ValueError("workspace_width must be defined for constant workspace width method") + self.get_workspace_width = self.ww_increase + + case _: + raise ValueError("incorrect value passed to workspace_width_method") + + #--------------------------------------------------------------------------------------------------------------- + # Set the idx value (changes depending on environment/wrapper) of the x and y observations and next_observations self.x_obs_idx = kwargs["x_obs_idx"] self.y_obs_idx = kwargs["y_obs_idx"] - self.workspace_width = workspace_width + + # Set initial fractal config values self.timestep = 0 self.current_depth = 0 - self.branch_index = [workspace_width/2] - self.start_num = kwargs["start_num"] + self.branch_index = [workspace_width/2] # TODO: is this a correct initialization? + self.start_num = kwargs["start_num"] # Used for fractal contractions + ## TODO: this is done in insert. Are we repeating work? + # Create branch index to control how to set (x,y) offsets of different branches self.branch_index = np.empty(self.current_branch_count, dtype=np.float32) + + ## Set spacing for transformations (see more: https://www.notion.so/Fractal-Symmetry-Characterization-225cd3402f1a80948509ec86f0b6ee5e?source=copy_link#22bcd3402f1a8054b97ff0ab387923f7) + # Set a constant value useful in the computation a standard offset between branches constant = self.workspace_width/(2 * self.current_branch_count) + + # Compute the translation offsets from left edge according to branch index: for i in range(0, self.current_branch_count): self.branch_index[i] = (2 * i + 1) * constant + ## Warn about unused kwargs for k in kwargs.keys(): print(f"\033[33mWARNING \033[0m argument \"{k}\" not used") @@ -110,31 +145,67 @@ def __init__( # radial trees at different angles can have strong fractal symmetry built in but will be more difficult to evenly space lowest level transforms # grid-based will inherently evenly space 
lowest transforms, but fractal symmetry will be restricted to 9 branching_factor (original plus 8) in order to conform + # Init replay buffer class super().__init__( observation_space=observation_space, action_space=action_space, capacity=capacity, ) + # Missing arguments def _handle_bad_args_(self, type: str, method: str, arg: str) : return f"\033[31mERROR: \033[0m{arg} must be defined for {type} \"{method}\"" + # Update (x,y) observations with translational transformation def transform(self, data_dict: DatasetDict, transform: np.array): # return data_dict with positional arguments += translation data_dict["observations"] += transform data_dict["next_observations"] += transform + #--- BRANCH METHODS --- # FOR TESTING ONLY def test_branch(self): self.current_depth += 1 return 1 def fractal_branch(self): + ''' + Computes the number of branches for the current depth using an exponential growth rule. + + This method implements a "fractal branching" strategy, where the number of branches + increases exponentially with depth. At each depth `d`, the number of branches is calculated as: + + num_branches = branching_factor ** current_depth + + where: + - branching_factor: The base number of branches at each split. + - current_depth: The current depth in the fractal tree (self.current_depth). + + Returns: + int: The computed number of branches for the current depth. + ''' # return a new number of branches = branching_factor ^ depth return self.branching_factor ** self.current_depth def fractal_contraction(self): + ''' + Computes the number of branches for the current depth using a contraction rule. + + This method implements a "fractal contraction" branching strategy, where the number + of branches decreases exponentially with depth. At each depth `d`, the number of branches + is calculated as: + + num_branches = start_num / (branching_factor ** (d - 1)) + + where: + - start_num: The initial number of branches at depth 1. 
+ - branching_factor: The factor by which the number of branches contracts at each depth. + - d: The current depth (self.current_depth). + + Returns: + int: The computed number of branches for the current depth. + ''' # return b = self.branching_factor d = self.current_depth @@ -142,16 +213,21 @@ def fractal_contraction(self): return int( top_b/( b**(d-1) )) - # REQUIRES TESTING - # def constant_branch(self): - # # return current number of branches - # return self.current_branch_count + def constant_branch(self): + ''' + Used to create pure translations with no further branching. + self.current_branch_count used to set the total number of transformations. + ''' + # return current number of branches + return self.current_branch_count # REQUIRES TESTING # def linear_branch(self): # # return a new number of branches = branches_count + n # return self.current_branch_count + self.branch_count_rate_of_change + + #---- SPLIT METHODS --- # FOR TESTING ONLY def test_split(self, data_dict: DatasetDict): return True @@ -165,39 +241,78 @@ def time_split(self, data_dict: DatasetDict): def constant_split(self, data_dict: DatasetDict): return True + + # Workspace Width Methods + def ww_constant(self): + return self.workspace_width + def ww_increase(self): + ''' + Increase workspace width with depth by 5cm. Lower density. + ''' + return self.workspace_width + self.current_depth*0.05 + + def ww_decrease(self): + ''' + Decrease workspace width with depth by 5cm. Higher density. + ''' + return self.workspace_width - self.current_depth*0.05 + + #--------------------------------------------- + # Now perform fractal transformations. + # Currently iterating through x,y loop (slow). + # TODO: multigpu processing. 
+ def insert(self, data_dict_not: DatasetDict): - sector_1 = dt.now() + # Time + sector_1 = dt.now() if self.debug_time else None data_dict = copy.deepcopy(data_dict_not) # Update number of branches if needed if self.split(data_dict): + + # Get your current branch count temp = self.current_branch_count + + # Get new current branch count based on depth self.current_branch_count = self.branch() + + # When we enter into a new depth and new branch count update the transformations if temp != self.current_branch_count: + + ## Set spacing for transformations (see more: https://www.notion.so/Fractal-Symmetry-Characterization-225cd3402f1a80948509ec86f0b6ee5e?source=copy_link#22bcd3402f1a8054b97ff0ab387923f7) self.branch_index = np.empty(self.current_branch_count, dtype=np.float32) - constant = self.workspace_width/(2 * self.current_branch_count) + constant = self.get_workspace_width()/(2 * self.current_branch_count) + for i in range(0, self.current_branch_count): self.branch_index[i] = (2 * i + 1) * constant - sector_2 = dt.now() # Initialize to extreme x and y - x = -self.workspace_width/2 + sector_2 = dt.now() if self.debug_time else None + + x = -self.get_workspace_width()/2 transform = np.zeros_like(data_dict["observations"]) transform[self.x_obs_idx] = x transform[self.y_obs_idx] = x self.transform(data_dict, transform) - sector_3 = dt.now() + # Transform and insert transitions (multiprocessing in the future) + sector_3 = dt.now() if self.debug_time else None + for x in range(0, self.current_branch_count): + + # Set offset in x-direction transform[self.x_obs_idx] = self.branch_index[x] + for y in range(0, self.current_branch_count): new_data_dict = copy.deepcopy(data_dict) transform[self.y_obs_idx] = self.branch_index[y] self.transform(new_data_dict, transform) super().insert(new_data_dict) - sector_4 = dt.now() + + # Reset current_depth, timestep, and max_traj_length + sector_4 = dt.now() if self.debug_time else None self.timestep += 1 if data_dict["dones"]: 
self.current_depth = 0 @@ -205,5 +320,6 @@ def insert(self, data_dict_not: DatasetDict): self.max_traj_length = int(self.timestep * self.alpha + self.max_traj_length * (1 - self.alpha)) self.timestep = 0 - # finish = dt.now() - # print(f"Splits: {(sector_2 - sector_1).total_seconds():.5f} : {(sector_3 - sector_2).total_seconds():.5f} : {(sector_4 - sector_3).total_seconds():.5f} : {(finish - sector_4).total_seconds():.5f}\nLaptime: {(finish - sector_1).total_seconds():.5f}") + if self.debug_time: + finish = dt.now() + print(f"Splits: {(sector_2 - sector_1).total_seconds():.5f} : {(sector_3 - sector_2).total_seconds():.5f} : {(sector_4 - sector_3).total_seconds():.5f} : {(finish - sector_4).total_seconds():.5f}\nLaptime: {(finish - sector_1).total_seconds():.5f}") \ No newline at end of file diff --git a/serl_launcher/serl_launcher/data/fsrb_test.py b/serl_launcher/serl_launcher/data/fsrb_test.py index 58518566..9429a3b2 100644 --- a/serl_launcher/serl_launcher/data/fsrb_test.py +++ b/serl_launcher/serl_launcher/data/fsrb_test.py @@ -17,8 +17,11 @@ flags.DEFINE_integer("max_steps",20000,"Maximum steps") flags.DEFINE_integer("branching_factor", 3, "Rate of change of number of transforms per dimension (x,y)") # For fractal_branch only flags.DEFINE_integer("starting_branch_count", 1, "Initial number of transforms per dimension (x,y)") # For constant_branch only -flags.DEFINE_integer("start_num",81, "Initila number of branch on the first depth") # For Fractal Cntraction flags.DEFINE_integer("alpha",1,"alpha value") +# Contraction +flags.DEFINE_integer("start_num",81, "Initila number of branch on the first depth") # For Fractal Cntraction +# Density Workspace width +flags.DEFINE_string("workspace_width_method",'increase', 'Controls workspace width dimensions configurations') def main(_): @@ -39,10 +42,11 @@ def main(_): x_obs_idx=x_obs_idx, y_obs_idx= y_obs_idx, max_depth=FLAGS.max_depth, - max_traj_length = 150, + max_traj_length = 100, 
branching_factor=FLAGS.branching_factor, start_num = FLAGS.start_num, - alpha = FLAGS.alpha + alpha = FLAGS.alpha, + workspace_width_method=FLAGS.workspace_width_method ) observation, info = env.reset() @@ -72,8 +76,9 @@ def main(_): # branch() tests - ## fractal -Associative Expansions - + #------------------------------------------------------------------- + # Fractal Associative Expansions + #------------------------------------------------------------------- replay_buffer.branching_factor = 3 replay_buffer.current_depth = 1 @@ -102,8 +107,9 @@ def main(_): print("\033[32mTEST PASSED \033[0m fractal_branch() tests passed") - - # fractal contraction + #------------------------------------------------------------------- + # Fractal Associative Contractions + #------------------------------------------------------------------- replay_buffer.branching_factor = 3 replay_buffer.current_depth = 1 @@ -132,13 +138,10 @@ def main(_): print("\033[32mTEST PASSED \033[0m fractal_contraction() tests passed") - - - - - - + #------------------------------------------------------------------- # split() tests + #------------------------------------------------------------------- + ## time replay_buffer.max_steps = 100 replay_buffer.max_depth = 4 @@ -188,6 +191,80 @@ def main(_): print("\033[32mTEST PASSED \033[0m insert() tests passed") + #------------------------------------------------------------------- + # Fractal Expansions with workspace_width_modification + #------------------------------------------------------------------- + print('\nWorkspace width tests....') + + replay_buffer.branching_factor = 3 + replay_buffer.current_depth = 1 + + if FLAGS.workspace_width_method == 'constant': + result = replay_buffer.get_workspace_width() + expected = FLAGS.workspace_width + assert result == expected, f"\033[31mTEST FAILED\033[0m get_workspace_width() test failed (expected {expected} but got {result})" + + elif FLAGS.workspace_width_method == 'decrease': + result = 
replay_buffer.get_workspace_width() + expected = FLAGS.workspace_width - 0.05*replay_buffer.current_depth + assert result == expected, f"\033[31mTEST FAILED\033[0m get_workspace_width() test failed (expected {expected} but got {result})" + + elif FLAGS.workspace_width_method == 'increase': + result = replay_buffer.get_workspace_width() + expected = FLAGS.workspace_width + 0.05*replay_buffer.current_depth + assert result == expected, f"\033[31mTEST FAILED\033[0m get_workspace_width() test failed (expected {expected} but got {result})" + + else: + raise NameError('There is no workspace width method with that name.') + + + replay_buffer.current_depth = 2 + + if FLAGS.workspace_width_method == 'constant': + result = replay_buffer.get_workspace_width() + expected = FLAGS.workspace_width + assert result == expected, f"\033[31mTEST FAILED\033[0m get_workspace_width() test failed (expected {expected} but got {result})" + + elif FLAGS.workspace_width_method == 'decrease': + result = replay_buffer.get_workspace_width() + expected = FLAGS.workspace_width - 0.05*replay_buffer.current_depth + assert result == expected, f"\033[31mTEST FAILED\033[0m get_workspace_width() test failed (expected {expected} but got {result})" + + elif FLAGS.workspace_width_method == 'increase': + result = replay_buffer.get_workspace_width() + expected = FLAGS.workspace_width + 0.05*replay_buffer.current_depth + assert result == expected, f"\033[31mTEST FAILED\033[0m get_workspace_width() test failed (expected {expected} but got {result})" + + else: + raise NameError('There is no workspace width method with that name.') + + + replay_buffer.current_depth = 3 + + if FLAGS.workspace_width_method == 'constant': + result = replay_buffer.get_workspace_width() + expected = FLAGS.workspace_width + assert result == expected, f"\033[31mTEST FAILED\033[0m get_workspace_width() test failed (expected {expected} but got {result})" + + elif FLAGS.workspace_width_method == 'decrease': + result = 
replay_buffer.get_workspace_width() + expected = FLAGS.workspace_width - 0.05*replay_buffer.current_depth + assert result == expected, f"\033[31mTEST FAILED\033[0m get_workspace_width() test failed (expected {expected} but got {result})" + + elif FLAGS.workspace_width_method == 'increase': + result = replay_buffer.get_workspace_width() + expected = FLAGS.workspace_width + 0.05*replay_buffer.current_depth + assert result == expected, f"\033[31mTEST FAILED\033[0m get_workspace_width() test failed (expected {expected} but got {result})" + + else: + raise NameError('There is no workspace width method with that name.') + + replay_buffer.current_depth = 0 + + del result, expected + + print("\n\033[32mTEST PASSED \033[0m workspace_width_method() test passed") + print("\nfinished!\n") diff --git a/serl_launcher/serl_launcher/utils/launcher.py b/serl_launcher/serl_launcher/utils/launcher.py index 5eaab48b..4fe2654d 100644 --- a/serl_launcher/serl_launcher/utils/launcher.py +++ b/serl_launcher/serl_launcher/utils/launcher.py @@ -210,7 +210,7 @@ def make_replay_buffer( branch_method: str = None, # used only type=="fractal_symmetry_replay_buffer" split_method : str = None, # used only type=="fractal_symmetry_replay_buffer" workspace_width : float = None, # used only type=="fractal_symmetry_replay_buffer" - + workspace_width_method : str = None, # used only type=="fractal_symmetry_replay_buffer" **kwargs: dict # used only type=="fractal_symmetry_replay_buffer" ): """ @@ -268,6 +268,7 @@ def make_replay_buffer( branch_method=branch_method, split_method=split_method, workspace_width=workspace_width, + workspace_width_method=workspace_width_method, rlds_logger=rlds_logger, kwargs=kwargs, ) From d760188876af7fd4f3e233bbc29fec102b0743cb Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Fri, 1 Aug 2025 10:49:35 -0500 Subject: [PATCH 088/141] Updated readme + comments --- README.md | 6 ++++++ demos/demos/demoHandling.py | 13 +++++++++---- 2 files changed, 15 insertions(+), 4 
deletions(-) diff --git a/README.md b/README.md index 9803ab6c..1e399f51 100644 --- a/README.md +++ b/README.md @@ -84,6 +84,12 @@ We fixed a major issue in the intervention action frame. See release [v0.1.1](ht pip install -e . ``` +6. **Install the demos** + ```bash + cd demos + pip install -e . + ``` + ## Overview and Code Structure SERL provides a set of common libraries for users to train RL policies for robotic manipulation tasks. The main structure of running the RL experiments involves having an actor node and a learner node, both of which interact with the robot gym environment. Both nodes run asynchronously, with data being sent from the actor to the learner node via the network using [agentlace](https://github.com/youliangtan/agentlace). The learner will periodically synchronize the policy with the actor. This design provides flexibility for parallel training and inference. diff --git a/demos/demos/demoHandling.py b/demos/demos/demoHandling.py index d3f3ece0..04a5e974 100644 --- a/demos/demos/demoHandling.py +++ b/demos/demos/demoHandling.py @@ -67,7 +67,10 @@ def insert_data_to_buffer(self,data_store: QueuedDataStore): The .npz file must contain arrays named 'obs', 'acs', 'rewards', 'terminateds', 'truncateds', 'info', and optionally 'dones'. Each episode is processed, and transitions are inserted into the data_store. - Inserted transitions in data store will remain in the data_store as pointers + Inserted transitions in data store will remain in the data_store as pointers. 
+ + ***Note*** + Need to insert obs and acs in the same way as async_sac_state via jax Parameters ---------- @@ -107,11 +110,12 @@ def insert_data_to_buffer(self,data_store: QueuedDataStore): T = len(ep_acts) for t in range(T): - obs_t = ep_obs[t] - next_obs_t = ep_obs[t+1] - a_t = ep_acts[t] + obs_t = np.asarray(ep_obs[t], dtype=np.float32) + next_obs_t = np.asarray(ep_obs[t+1], dtype=np.float32) + a_t = np.asarray(ep_acts[t], dtype=np.float32) r_t = float(ep_rews[t]) done_t = bool(ep_done[t] or ep_terms[t] or ep_trunc[t]) + #info_t = ep_info[t] # masks will be created right before insert below if self.debug: @@ -123,6 +127,7 @@ def insert_data_to_buffer(self,data_store: QueuedDataStore): f"Reward: {r_t:.2f} \n " f"Done: {done_t}") + # Insert using SERLs data_store/ReplayBuffer insert mechanism directly. data_store.insert( dict( observations =obs_t, From 8c09c2ccb8cc9b573aec8747055d7db716b9f344 Mon Sep 17 00:00:00 2001 From: crosen77 Date: Fri, 1 Aug 2025 16:05:04 -0500 Subject: [PATCH 089/141] first attempt at disassociated --- .../async_sac_state_sim/.vscode/launch.json | 16 +++-- .../async_sac_state_sim.py | 10 ++- .../data/fractal_symmetry_replay_buffer.py | 63 +++++++++++++++++++ 3 files changed, 82 insertions(+), 7 deletions(-) diff --git a/examples/async_sac_state_sim/.vscode/launch.json b/examples/async_sac_state_sim/.vscode/launch.json index 16a95b29..6c8f450a 100644 --- a/examples/async_sac_state_sim/.vscode/launch.json +++ b/examples/async_sac_state_sim/.vscode/launch.json @@ -13,15 +13,17 @@ "args": [ "--learner", // REQUIRED: Indicates this is a learner instance "--env", "PandaReachCube-v0", // Environment to use - "--exp_name=SERL-ReachTask", // Experiment name for wandb logging + "--exp_name", "disassociated_testing", // Experiment name for wandb logging "--seed", "0", // Random seed for reproducibility "--max_steps", "100000", // Maximum training steps "--training_starts", "1000", // Start training after buffer has this many samples 
"--critic_actor_ratio", "8", // Critic-to-actor update ratio "--batch_size", "256", // Training batch size - "--replay_buffer_capacity", "100000", // Replay buffer capacity - "--start_num", "81", // Fractal Contraction starting with 81 branches - "--workspace_width_method", "decrease",// Workspace Width Method to increase density + "--replay_buffer_capacity", "1_000_000", // Replay buffer capacity + "--disassociated_type", "octahedron", // Type of disassociated test to perform + "--min_branch_count", "3", // Minimum branch count for disassociated testing + "--max_branch_count", "9", // Maximum branch count for disassociated testing + "--num_depth_sectors", "13" // Desired number of sectors to divide rollouts into for branch count splitting ], "console": "integratedTerminal", // Use integrated terminal to see output "justMyCode": false, // Allow stepping into libraries @@ -54,8 +56,10 @@ "--critic_actor_ratio", "8", // Critic-to-actor update ratio "--batch_size", "256", // Training batch size "--replay_buffer_capacity", "100000", // Replay buffer capacity - "--start_num", "81", // Fractal Contraction starting with 81 branches - "--workspace_width_method", "decrease",// Workspace Width Method to increase density + "--disassociated_type", "octahedron", // Type of disassociated test to perform + "--min_branch_count", "3", // Minimum branch count for disassociated testing + "--max_branch_count", "9", // Maximum branch count for disassociated testing + "--num_depth_sectors", "13" // Desired number of sectors to divide rollouts into for branch count splitting ], "console": "integratedTerminal", // Use integrated terminal to see output "justMyCode": false, // Allow stepping into libraries diff --git a/examples/async_sac_state_sim/async_sac_state_sim.py b/examples/async_sac_state_sim/async_sac_state_sim.py index 7b12f020..b10c534a 100644 --- a/examples/async_sac_state_sim/async_sac_state_sim.py +++ b/examples/async_sac_state_sim/async_sac_state_sim.py @@ -74,8 +74,16 @@ # 
Contraction flags.DEFINE_integer("start_num",81, "Initial number of branch on the first depth") # For Fractal Cntraction +# Disassociated +flags.DEFINE_enum("disassociated_type", "octahedron", ["octahedron", "hourglass"], + "Type of disassociated fracal rollout. Octahedron: expand from min to max then contract to min," + + " Hourglass: Contract from max to min then expand to max") +flags.DEFINE_integer("min_branch_count", 1, "Minimum number of branches for disassociated fractal rollout") +flags.DEFINE_integer("max_branch_count", 1, "Maximum number of branches for disassociated fractal rollout") +flags.DEFINE_integer("num_depth_sectors", 1, "Desired number of sectors to divide rollout into for branch count splitting") + # Density Workspace width -flags.DEFINE_string("workspace_width_method",'increase', 'Controls workspace width dimensions configurations') +flags.DEFINE_string("workspace_width_method", 'constant', 'Controls workspace width dimensions configurations') # Debug flags.DEFINE_boolean( diff --git a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py index 458b3144..8c68925c 100644 --- a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py +++ b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py @@ -42,6 +42,24 @@ def __init__( case "constant": self.split = self.constant_split + case "disassociated": + assert "max_traj_length" in kwargs.keys(), self._handle_bad_args_(method_check, split_method, "max_traj_length") + self.max_traj_length=kwargs["max_traj_length"] + update_max_traj_length = True + + assert "alpha" in kwargs.keys(), self._handle_bad_args_(method_check, split_method, "alpha") + self.alpha=kwargs["alpha"] + + assert "num_depth_sectors" in kwargs.keys(), self._handle_bad_args_(method_check, split_method, "num_depth_sectors") + self.num_depth_sectors=kwargs["num_depth_sectors"] + + self.steps_per_depth = 
int(np.floor(self.max_traj_length/self.num_depth_sectors)) # does this need to be np.float32? + + self.transition_counter = 0 + + self.split = self.disassociated_split + + case "test": self.split = self.test_split case _: @@ -75,6 +93,26 @@ def __init__( case "linear": raise NotImplementedError("linear branch method is not yet implemented") # self.branch = self.linear_branch + + case "disassociated": + assert "min_branch_count" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "min_branch_count)") + self.min_branch_count = kwargs["min_branch_count"] + + assert "max_branch_count" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "max_branch_count)") + self.max_branch_count = kwargs["max_branch_count"] + + if self.min_branch_count > self.max_branch_count: + raise ValueError(f"min_branch_count: {self.min_branch_count} is larger than max_branch_count: {self.max_branch_count}. Max should be larger than min.") + + assert "disassociated_type" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "disassociated_type") + if kwargs["disassociated_type"] == "hourglass": + self.disassociated_type = "hourglass" + self.current_branch_count = self.max_branch_count + elif kwargs["disassociated_type"] == "octahedron": + self.current_branch_count = "octahedron" + self.starting_branch_count = self.min_branch_count + + self.branch = self.disassociated_branch case "constant": assert "starting_branch_count" in kwargs.keys(), self._handle_bad_args_("branch_method", branch_method, "starting_branch_count") @@ -221,6 +259,21 @@ def constant_branch(self): # return current number of branches return self.current_branch_count + def disassociated_branch(self): + ''' + Used to create branches for disassociated fractal methods. 
+ self.min_branch_count specifies the mininum branch count desired during the fractal rollout + self.max_branch_count specifies the maximum branch count desired during the fractal rollout + self.disassociated_type specifies whether to expand and then contract or to contract and then expand + self.steps_per_depth specifies the number of timesteps to take before splitting + (calculated indirectly via self.max_traj_length / self.num_depth_sectors) + self.num_depth_sectors specifies the number of sectors the rollout should be divided into for even splitting + ''' + if self.disassociated_type == "hourglass": + return int((self.max_branch_count - self.min_branch_count)/(self.num_depth_sectors/2) * np.abs(self.current_depth - (self.num_depth_sectors/2)) + self.min_branch_count) + elif self.disassociated_type == "octahedron": + return int((self.min_branch_count - self.max_branch_count)/(self.num_depth_sectors/2) * np.abs(self.current_depth - (self.num_depth_sectors/2)) + self.max_branch_count) + # REQUIRES TESTING # def linear_branch(self): # # return a new number of branches = branches_count + n @@ -241,6 +294,15 @@ def time_split(self, data_dict: DatasetDict): def constant_split(self, data_dict: DatasetDict): return True + + def disassociated_split(self, data_dict: DatasetDict): + if self.transition_counter >= self.steps_per_depth: + self.current_depth += 1 + self.transition_counter = 0 + return True + else: + self.transition_counter += 1 + return False # Workspace Width Methods def ww_constant(self): @@ -318,6 +380,7 @@ def insert(self, data_dict_not: DatasetDict): self.current_depth = 0 if self.update_max_traj_length: self.max_traj_length = int(self.timestep * self.alpha + self.max_traj_length * (1 - self.alpha)) + self.steps_per_depth = int(np.floor(self.max_traj_length/self.num_depth_sectors)) # does this need to be np.float32? 
self.timestep = 0 if self.debug_time: From b3dfcc3c954575f2aa38b453c92e6e7b2669511f Mon Sep 17 00:00:00 2001 From: crosen77 Date: Fri, 1 Aug 2025 17:39:48 -0500 Subject: [PATCH 090/141] made some debugging changes. Now it will run and learn. still have to check that it expands and contracts correctly for octohedron, haven't tested hourglass yet. --- .../async_sac_state_sim/.vscode/launch.json | 20 +++++++++++++++---- .../async_sac_state_sim.py | 7 ++++++- .../data/fractal_symmetry_replay_buffer.py | 9 ++++----- 3 files changed, 26 insertions(+), 10 deletions(-) diff --git a/examples/async_sac_state_sim/.vscode/launch.json b/examples/async_sac_state_sim/.vscode/launch.json index 6c8f450a..1f62625e 100644 --- a/examples/async_sac_state_sim/.vscode/launch.json +++ b/examples/async_sac_state_sim/.vscode/launch.json @@ -2,6 +2,9 @@ // VSCode debug configuration version "version": "0.2.0", "configurations": [ + + + // LEARNER { // Configuration for the RL agent learner component @@ -13,17 +16,22 @@ "args": [ "--learner", // REQUIRED: Indicates this is a learner instance "--env", "PandaReachCube-v0", // Environment to use - "--exp_name", "disassociated_testing", // Experiment name for wandb logging + "--exp_name=SERL-ReachTask", // Experiment name for wandb logging "--seed", "0", // Random seed for reproducibility + "--random_steps", "1000", // Number of random steps at beginning "--max_steps", "100000", // Maximum training steps + "--replay_buffer_type","fractal_symmetry_replay_buffer", // reply buffer type "--training_starts", "1000", // Start training after buffer has this many samples "--critic_actor_ratio", "8", // Critic-to-actor update ratio "--batch_size", "256", // Training batch size - "--replay_buffer_capacity", "1_000_000", // Replay buffer capacity + "--replay_buffer_capacity", "100000", // Replay buffer capacity + "--branch_method", "disassociated", // branch method type + "--split_method", "disassociated", // split method type "--disassociated_type", 
"octahedron", // Type of disassociated test to perform "--min_branch_count", "3", // Minimum branch count for disassociated testing "--max_branch_count", "9", // Maximum branch count for disassociated testing - "--num_depth_sectors", "13" // Desired number of sectors to divide rollouts into for branch count splitting + "--num_depth_sectors", "13", // Desired number of sectors to divide rollouts into for branch count splitting + "--alpha", "1" // alpha ], "console": "integratedTerminal", // Use integrated terminal to see output "justMyCode": false, // Allow stepping into libraries @@ -52,14 +60,18 @@ "--seed", "0", // Random seed for reproducibility "--random_steps", "1000", // Number of random steps at beginning "--max_steps", "100000", // Maximum training steps + "--replay_buffer_type","fractal_symmetry_replay_buffer", // reply buffer type "--training_starts", "1000", // Start training after buffer has this many samples "--critic_actor_ratio", "8", // Critic-to-actor update ratio "--batch_size", "256", // Training batch size "--replay_buffer_capacity", "100000", // Replay buffer capacity + "--branch_method", "disassociated", // branch method type + "--split_method", "disassociated", // split method type "--disassociated_type", "octahedron", // Type of disassociated test to perform "--min_branch_count", "3", // Minimum branch count for disassociated testing "--max_branch_count", "9", // Maximum branch count for disassociated testing - "--num_depth_sectors", "13" // Desired number of sectors to divide rollouts into for branch count splitting + "--num_depth_sectors", "13", // Desired number of sectors to divide rollouts into for branch count splitting + "--alpha", "1" // alpha ], "console": "integratedTerminal", // Use integrated terminal to see output "justMyCode": false, // Allow stepping into libraries diff --git a/examples/async_sac_state_sim/async_sac_state_sim.py b/examples/async_sac_state_sim/async_sac_state_sim.py index b10c534a..478f0d22 100644 --- 
a/examples/async_sac_state_sim/async_sac_state_sim.py +++ b/examples/async_sac_state_sim/async_sac_state_sim.py @@ -333,7 +333,12 @@ def main(_): # Contraction start_num = FLAGS.start_num, # Workspace Width Density - workspace_width_method=FLAGS.workspace_width_method + workspace_width_method=FLAGS.workspace_width_method, + # Disassociated + disassociated_type=FLAGS.disassociated_type, + min_branch_count=FLAGS.min_branch_count, + max_branch_count=FLAGS.max_branch_count, + num_depth_sectors=FLAGS.num_depth_sectors ) replay_iterator = replay_buffer.get_iterator( sample_args={ diff --git a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py index 8c68925c..c825bad6 100644 --- a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py +++ b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py @@ -45,7 +45,7 @@ def __init__( case "disassociated": assert "max_traj_length" in kwargs.keys(), self._handle_bad_args_(method_check, split_method, "max_traj_length") self.max_traj_length=kwargs["max_traj_length"] - update_max_traj_length = True + self.update_max_traj_length = True assert "alpha" in kwargs.keys(), self._handle_bad_args_(method_check, split_method, "alpha") self.alpha=kwargs["alpha"] @@ -106,12 +106,11 @@ def __init__( assert "disassociated_type" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "disassociated_type") if kwargs["disassociated_type"] == "hourglass": - self.disassociated_type = "hourglass" self.current_branch_count = self.max_branch_count elif kwargs["disassociated_type"] == "octahedron": - self.current_branch_count = "octahedron" - self.starting_branch_count = self.min_branch_count - + self.current_branch_count = self.min_branch_count + self.disassociated_type = kwargs["disassociated_type"] + self.branch = self.disassociated_branch case "constant": From 776c7688d2be8991665076a74e4dbc98d9a5e813 Mon Sep 17 00:00:00 2001 From: 
Juan Rojas Date: Fri, 1 Aug 2025 18:07:34 -0500 Subject: [PATCH 091/141] First incomplete draft of parallel fractal replay buffer. --- ...fractal_symmetry_replay_buffer_parallel.py | 426 ++++++++++++++++++ 1 file changed, 426 insertions(+) create mode 100644 serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer_parallel.py diff --git a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer_parallel.py b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer_parallel.py new file mode 100644 index 00000000..496e691a --- /dev/null +++ b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer_parallel.py @@ -0,0 +1,426 @@ +# Generic support +import copy +from datetime import datetime as dt + +# RL Envs +import gym + +# Data +from serl_launcher.data.dataset import DatasetDict +from serl_launcher.data.replay_buffer import ReplayBuffer + +# Optimization of matrix operations +import numpy as np +import jax +import jax.numpy as jnp +from jax import lax + + +class FractalSymmetryReplayBuffer(ReplayBuffer): + def __init__( + self, + observation_space: gym.Space, + action_space: gym.Space, + capacity: int, + branch_method: str, + split_method: str, + workspace_width: int, + workspace_width_method: str, + kwargs: dict + ): + + # Initialize values + self.debug_time = False + self.current_branch_count=1 + self.update_max_traj_length = False + self.workspace_width = workspace_width + self.current_branch_count = 0 + self.num_transforms = 0 + + # Set correct split function in self.split pointer. 
+ method_check = "split_method" + match split_method: + case "time": + assert "max_depth" in kwargs.keys(), self._handle_bad_args_(method_check, split_method, "max_depth") + self.max_depth=kwargs["max_depth"] + + assert "max_traj_length" in kwargs.keys(), self._handle_bad_args_(method_check, split_method, "max_traj_length") + self.max_traj_length=kwargs["max_traj_length"] + + assert "alpha" in kwargs.keys(), self._handle_bad_args_(method_check, split_method, "alpha") + self.alpha=kwargs["alpha"] + update_max_traj_length = True + self.split = self.time_split + + case "constant": + self.split = self.constant_split + + case "test": + self.split = self.test_split + case _: + raise ValueError("incorrect value passed to split_method") + + # Set correct branch function in self.branch pointer. + method_check = "branch_method" + match branch_method: + case "fractal": + assert "max_depth" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "max_depth") + self.max_depth=kwargs["max_depth"] + + assert "branching_factor" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "branching_factor") + self.branching_factor=kwargs["branching_factor"] + + assert "start_num" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "start_num") + self.start_num=kwargs["start_num"] + + self.branch = self.fractal_branch + + case "contraction": + assert "start_num" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "start_num") + self.start_num=kwargs["start_num"] + + # Add branching_factor for contraction method + assert "branching_factor" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "branching_factor") + self.branching_factor = kwargs["branching_factor"] + + self.branch = self.fractal_contraction + + case "linear": + raise NotImplementedError("linear branch method is not yet implemented") + # self.branch = self.linear_branch + + case "constant": + assert "starting_branch_count" in kwargs.keys(), 
self._handle_bad_args_("branch_method", branch_method, "starting_branch_count") + self.current_branch_count = kwargs["starting_branch_count"] + del kwargs["starting_branch_count"] + + self.branch = self.constant_branch + + case "test": + self.branch = self.test_branch + + case _: + raise ValueError("incorrect value passed to branch_method") + + # Set correct workspace_width function in self.workspace_width_method + match workspace_width_method: + + case "constant": + if workspace_width is None: + raise ValueError("workspace_width must be defined for constant workspace width method") + self.get_workspace_width = self.ww_constant + + case "decrease": + if workspace_width is None: + raise ValueError("workspace_width must be defined for constant workspace width method") + self.get_workspace_width = self.ww_decrease + + case "increase": + if workspace_width is None: + raise ValueError("workspace_width must be defined for constant workspace width method") + self.get_workspace_width = self.ww_increase + + case _: + raise ValueError("incorrect value passed to workspace_width_method") + + #--------------------------------------------------------------------------------------------------------------- + # Set the idx value (changes depending on environment/wrapper) of the x and y observations and next_observations + self.x_obs_idx = kwargs["x_obs_idx"] + self.y_obs_idx = kwargs["y_obs_idx"] + + # Set initial fractal config values + self.timestep = 0 + self.current_depth = 0 + self.branch_index = [workspace_width/2] # TODO: is this a correct initialization? + self.start_num = kwargs["start_num"] # Used for fractal contractions + + ## TODO: this is done in insert. Are we repeating work? 
+ # Create branch index to control how to set (x,y) offsets of different branches + self.branch_index = np.empty(self.current_branch_count, dtype=np.float32) + + ## Set spacing for transformations (see more: https://www.notion.so/Fractal-Symmetry-Characterization-225cd3402f1a80948509ec86f0b6ee5e?source=copy_link#22bcd3402f1a8054b97ff0ab387923f7) + # Set a constant value useful in the computation a standard offset between branches + constant = self.workspace_width/(2 * self.current_branch_count) + + # Compute the translation offsets from left edge according to branch index: + for i in range(0, self.current_branch_count): + self.branch_index[i] = (2 * i + 1) * constant + + ## Warn about unused kwargs + for k in kwargs.keys(): + print(f"\033[33mWARNING \033[0m argument \"{k}\" not used") + + # TODO + # Create more methods with tests + # + # Future considerations: + # when dealing with x and y for transforms, consider two versions: + # radial trees at different angles can have strong fractal symmetry built in but will be more difficult to evenly space lowest level transforms + # grid-based will inherently evenly space lowest transforms, but fractal symmetry will be restricted to 9 branching_factor (original plus 8) in order to conform + + # Init replay buffer class + super().__init__( + observation_space=observation_space, + action_space=action_space, + capacity=capacity, + ) + + # Missing arguments + def _handle_bad_args_(self, type: str, method: str, arg: str) : + return f"\033[31mERROR: \033[0m{arg} must be defined for {type} \"{method}\"" + + + # Update (x,y) observations with translational transformation + def transform(self, data_dict: DatasetDict, transform: np.array): + # return data_dict with positional arguments += translation + data_dict["observations"] += transform + data_dict["next_observations"] += transform + + #--- BRANCH METHODS --- + # FOR TESTING ONLY + def test_branch(self): + self.current_depth += 1 + return 1 + + def fractal_branch(self): + ''' + 
Computes the number of branches for the current depth using an exponential growth rule. + + This method implements a "fractal branching" strategy, where the number of branches + increases exponentially with depth. At each depth `d`, the number of branches is calculated as: + + num_branches = branching_factor ** current_depth + + where: + - branching_factor: The base number of branches at each split. + - current_depth: The current depth in the fractal tree (self.current_depth). + + Returns: + int: The computed number of branches for the current depth. + ''' + # return a new number of branches = branching_factor ^ depth + return self.branching_factor ** self.current_depth + + def fractal_contraction(self): + ''' + Computes the number of branches for the current depth using a contraction rule. + + This method implements a "fractal contraction" branching strategy, where the number + of branches decreases exponentially with depth. At each depth `d`, the number of branches + is calculated as: + + num_branches = start_num / (branching_factor ** (d - 1)) + + where: + - start_num: The initial number of branches at depth 1. + - branching_factor: The factor by which the number of branches contracts at each depth. + - d: The current depth (self.current_depth). + + Returns: + int: The computed number of branches for the current depth. + ''' + # return + b = self.branching_factor + d = self.current_depth + top_b = self.start_num + + return int( top_b/( b**(d-1) )) + + def constant_branch(self): + ''' + Used to create pure translations with no further branching. + self.current_branch_count used to set the total number of transformations. 
+ ''' + # return current number of branches + return self.current_branch_count + + # REQUIRES TESTING + # def linear_branch(self): + # # return a new number of branches = branches_count + n + # return self.current_branch_count + self.branch_count_rate_of_change + + + #---- SPLIT METHODS --- + # FOR TESTING ONLY + def test_split(self, data_dict: DatasetDict): + return True + + # REQUIRES TESTING + def time_split(self, data_dict: DatasetDict): + if self.timestep % (self.max_traj_length//self.max_depth) or self.current_depth >= self.max_depth: + return False + self.current_depth += 1 + return True + + def constant_split(self, data_dict: DatasetDict): + return True + + # Workspace Width Methods + def ww_constant(self): + return self.workspace_width + + def ww_increase(self): + ''' + Increase workspace width with depth by 5cm. Lower density. + ''' + return self.workspace_width + self.current_depth*0.05 + + def ww_decrease(self): + ''' + Decrease workspace width with depth by 5cm. Higher density. + ''' + return self.workspace_width - self.current_depth*0.05 + + #--------------------------------------------- + + def compute_transformation(self, data_dict, num_transforms): + ''' + + ''' + + # Generate observation matrices + o_m, no_m = self.generate_obs_mats(data_dict, num_transforms) + + # Transform o_m and no_m using fractal equation (use lax.scan inside for incremental change) + del_o_m = lax.scan( self.compute_delta, o_m ) + del_no_m = lax.scan( self.compute_deltas, no_m) + + return del_o_m,del_no_m + + + def compute_deltas(self,mat,type='grid'): # add type as jnp array + ''' + Given a type of fractal framework perform the operation + ''' + + if type == 'grid': + + # Iterate through matrix cols + idx = jnp.arange(self.num_transforms) + + # Convert idx into arrays (i,j) that tell us the x,y order for deltas + i,j = jnp.divmod( idx, self.current_branch_count ) + + # 1. Get the boundary of the workspace environment + edge = self.get_workspace_width()/2. + + # 2. 
Get array of translation deltas: ( (2i+1)/2b ) * ww + constant = self.get_workspace_width() / (2. * self.current_branch_count) + x_offset = ( (2 * i) + 1) * constant + y_offset = ( (2 * j) + 1) * constant + + # 3. Enter for block and robot x,y positions for all transformations + # from jax import lax + + # m_updated = lax.scatter(m, + # indices=r[:, None], + # updates=new_rows, + # dimension_numbers=lax.ScatterDimensionNumbers( + # update_window_dims=(1,), + # inserted_window_dims=(0,), + # scatter_dims_to_operand_dims=(0,) + # ), + # indices_are_sorted=False, + # unique_indices=True +) + + + else: + raise TypeError('transform_matrix: no correct fractal framework type was given...') + + def generate_obs_mats(self, data_dict: DatasetDict, num_transforms: int, type='grid'): + ''' + Generate observations matrices, whose cols repeat according to the correct number of transformations + + Args: + - data_dict (DatasetDict): + - num_transforms (int): computes the total number of transformations for the grid stategy + - type (str): grid or radial fractal framework + + Return: + - o_m: observation matrix + - no_m: next observation matrix + ''' + + # Convert observations to a vector + o_col = jnp.array(data_dict["observations"], type=jnp.float32) + no_col = jnp.array(data_dict["next_observations"], type=jnp.float32) + + # Expand observations column vector to a matrix using tile + o_m = jnp.tile(o_col, (1,num_transforms)) + no_m = jnp.tile(no_col, (1,num_transforms)) + + return o_m, no_m + + def insert_mat_replay_buffer(self, del_o, del_no, new_data_dict: DatasetDict): + # Convert back to dict by replacing data_dict's observations with new ones by the columns and insert into replay buffer one-by-one. + # TODO: Is there a batch_insert? + + # Is this the fast way to create these dicts from matrix cols? 
+ for col in range( del_o.shape[1] ): + new_data_dict['observations'] = np.array( del_o[:, col],type=np.float32) + new_data_dict['next_observations'] = np.array( del_no[:, col],type=np.float32) + + # Insert each col's new info + super().insert(new_data_dict) + + def insert(self, data_dict: DatasetDict): + ''' + + Note: + If observations for all time-steps are available, then a batch_transform + ''' + + # Time + sector_1 = dt.now() if self.debug_time else None + + # Update number of branches if needed + if self.split(data_dict): + + # Get the current branch count + temp = self.current_branch_count + + # Get new current branch count based on depth + self.current_branch_count = self.branch() + + # When we enter into a new depth and new branch count update the transformations + if temp != self.current_branch_count: + + # Get the number of transformations + self.num_transforms = self.current_branch_count**2 + + # Compute transformation deltas for observations + self.delta_o_m, self.delta_no_m = self.compute_transformation(data_dict, self.num_transforms) + + # Add observations with deltas + transf_o_m = o_m + self.delta_o_m + transf_no_m = no_m + self.delta_no_m + + insert_mat_replay_buffer(transf_o_m, transf_no_m) + + #else?? 
+ + # Generate observation matrices with same + o_m, no_m = self.generate_obs_mats(data_dict, self.num_transforms) + + # Add observations with deltas + transf_o_m = o_m + self.delta_o_m + transf_no_m = no_m + self.delta_no_m + + # Insert to replay buffer + insert_mat_replay_buffer(transf_o_m, transf_no_m) + + + # Reset current_depth, timestep, and max_traj_length + sector_4 = dt.now() if self.debug_time else None + + self.timestep += 1 + if data_dict["dones"]: + self.current_depth = 0 + if self.update_max_traj_length: + self.max_traj_length = int(self.timestep * self.alpha + self.max_traj_length * (1 - self.alpha)) + self.timestep = 0 + + if self.debug_time: + finish = dt.now() + print(f"Splits: {(sector_2 - sector_1).total_seconds():.5f} : {(sector_3 - sector_2).total_seconds():.5f} : {(sector_4 - sector_3).total_seconds():.5f} : {(finish - sector_4).total_seconds():.5f}\nLaptime: {(finish - sector_1).total_seconds():.5f}") \ No newline at end of file From d947657fe108e49655afaa3598908e8b804e07a8 Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Sun, 3 Aug 2025 22:42:31 -0500 Subject: [PATCH 092/141] First draft of optimized/parallelized fractal_symmetry_replay_buffer_parallel class. Using jax/jit and setting all transforms as matrices and element sums. 
--- ...fractal_symmetry_replay_buffer_parallel.py | 72 ++++++++++--------- 1 file changed, 39 insertions(+), 33 deletions(-) diff --git a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer_parallel.py b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer_parallel.py index 496e691a..466ac1b4 100644 --- a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer_parallel.py +++ b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer_parallel.py @@ -12,9 +12,8 @@ # Optimization of matrix operations import numpy as np import jax +from jax import jit import jax.numpy as jnp -from jax import lax - class FractalSymmetryReplayBuffer(ReplayBuffer): def __init__( @@ -169,7 +168,6 @@ def __init__( def _handle_bad_args_(self, type: str, method: str, arg: str) : return f"\033[31mERROR: \033[0m{arg} must be defined for {type} \"{method}\"" - # Update (x,y) observations with translational transformation def transform(self, data_dict: DatasetDict, transform: np.array): # return data_dict with positional arguments += translation @@ -272,20 +270,20 @@ def ww_decrease(self): return self.workspace_width - self.current_depth*0.05 #--------------------------------------------- - + @jit def compute_transformation(self, data_dict, num_transforms): ''' - + Add fractal transformations (translations to matrix of observations) ''' # Generate observation matrices o_m, no_m = self.generate_obs_mats(data_dict, num_transforms) # Transform o_m and no_m using fractal equation (use lax.scan inside for incremental change) - del_o_m = lax.scan( self.compute_delta, o_m ) - del_no_m = lax.scan( self.compute_deltas, no_m) + trans_o_m = self.compute_deltas( o_m ) + trans_no_m = self.compute_deltas( no_m) - return del_o_m,del_no_m + return trans_o_m,trans_no_m def compute_deltas(self,mat,type='grid'): # add type as jnp array @@ -304,30 +302,23 @@ def compute_deltas(self,mat,type='grid'): # add type as jnp array # 1. 
Get the boundary of the workspace environment edge = self.get_workspace_width()/2. - # 2. Get array of translation deltas: ( (2i+1)/2b ) * ww + # 2. Get array of translation deltas for x,y offsets: ( (2i+1)/2b ) * ww constant = self.get_workspace_width() / (2. * self.current_branch_count) x_offset = ( (2 * i) + 1) * constant y_offset = ( (2 * j) + 1) * constant - # 3. Enter for block and robot x,y positions for all transformations - # from jax import lax - - # m_updated = lax.scatter(m, - # indices=r[:, None], - # updates=new_rows, - # dimension_numbers=lax.ScatterDimensionNumbers( - # update_window_dims=(1,), - # inserted_window_dims=(0,), - # scatter_dims_to_operand_dims=(0,) - # ), - # indices_are_sorted=False, - # unique_indices=True -) - - + # 3. Generate the matrix grid of transformations by element-wise sum of deltas on the matrix of env. observations. Use self.x_obs_idx & y_obs_idx. + # Perform separate updates + trans_mat = ( + mat + .at[self.x_obs_idx].add(x_offset) # add x_offset to each of the two x-rows + .at[self.y_obs_idx].add(y_offset) # add y_offset to each of the two y-rows + ) else: raise TypeError('transform_matrix: no correct fractal framework type was given...') + return trans_mat + def generate_obs_mats(self, data_dict: DatasetDict, num_transforms: int, type='grid'): ''' Generate observations matrices, whose cols repeat according to the correct number of transformations @@ -390,13 +381,9 @@ def insert(self, data_dict: DatasetDict): self.num_transforms = self.current_branch_count**2 # Compute transformation deltas for observations - self.delta_o_m, self.delta_no_m = self.compute_transformation(data_dict, self.num_transforms) - - # Add observations with deltas - transf_o_m = o_m + self.delta_o_m - transf_no_m = no_m + self.delta_no_m + trans_o_m, trans_no_m = self.compute_transformation(data_dict, self.num_transforms) - insert_mat_replay_buffer(transf_o_m, transf_no_m) + insert_mat_replay_buffer(trans_o_m, trans_no_m) #else?? 
@@ -408,7 +395,7 @@ def insert(self, data_dict: DatasetDict): transf_no_m = no_m + self.delta_no_m # Insert to replay buffer - insert_mat_replay_buffer(transf_o_m, transf_no_m) + batch_insert(transf_o_m, transf_no_m) # Reset current_depth, timestep, and max_traj_length @@ -423,4 +410,23 @@ def insert(self, data_dict: DatasetDict): if self.debug_time: finish = dt.now() - print(f"Splits: {(sector_2 - sector_1).total_seconds():.5f} : {(sector_3 - sector_2).total_seconds():.5f} : {(sector_4 - sector_3).total_seconds():.5f} : {(finish - sector_4).total_seconds():.5f}\nLaptime: {(finish - sector_1).total_seconds():.5f}") \ No newline at end of file + print(f"Splits: {(sector_2 - sector_1).total_seconds():.5f} : {(sector_3 - sector_2).total_seconds():.5f} : {(sector_4 - sector_3).total_seconds():.5f} : {(finish - sector_4).total_seconds():.5f}\nLaptime: {(finish - sector_1).total_seconds():.5f}") + + def batch_insert(self, batch: DatasetDict): + """ + batch is a dict of NumPy arrays, each of shape (B, *field_shape). + We insert B new entries in one go. + """ + B = batch["observations"].shape[0] + # wrap‐around indices + idxs = (np.arange(self._insert_index, + self._insert_index + B) % self._capacity) + + # for each key, write the entire batch into self.dataset_dict + for k, array in batch.items(): + # e.g. self.dataset_dict["observations"][idxs] = batch["observations"] + self.dataset_dict[k][idxs] = array + + # advance pointers + self._insert_index = idxs[-1] + 1 + self._size = min(self._size + B, self._capacity) From a6eae2f4dfc1d6a2dbfcad70ff48cd6d81d62830 Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Mon, 4 Aug 2025 10:48:44 -0500 Subject: [PATCH 093/141] draft 2. improved batch class. still no testing. 
--- ...fractal_symmetry_replay_buffer_parallel.py | 114 +++++++++++++----- 1 file changed, 83 insertions(+), 31 deletions(-) diff --git a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer_parallel.py b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer_parallel.py index 466ac1b4..c15efce2 100644 --- a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer_parallel.py +++ b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer_parallel.py @@ -273,7 +273,21 @@ def ww_decrease(self): @jit def compute_transformation(self, data_dict, num_transforms): ''' - Add fractal transformations (translations to matrix of observations) + Add fractal transformations (translations) to matrix of observations in jitted form. + - data_dict of transitions comes in, a matrix of size (num_transforms,num_transforms) comes out. + - compute_deltas will create a matrix of (x,y) transformations and add them to the expanded observations. + - results for observations and next_observations saved in trans_o_m and trans_no_m. + + Only this parent function needs to be jitted. Helper functions below do not need since they are never called independently. 
+ + Params: + - data_dict (DatasetDict) - dictionary of translations + - num_transforms (int) - number of transformations + + Returns: + - trans_o_m (jnp.ndarray) - this is a (num_transform,num_transform) matrix + - trans_no_m (jnp.ndarray) - this is a (num_transform,num_transform) matrix + ''' # Generate observation matrices @@ -285,7 +299,6 @@ def compute_transformation(self, data_dict, num_transforms): return trans_o_m,trans_no_m - def compute_deltas(self,mat,type='grid'): # add type as jnp array ''' Given a type of fractal framework perform the operation @@ -374,7 +387,7 @@ def insert(self, data_dict: DatasetDict): # Get new current branch count based on depth self.current_branch_count = self.branch() - # When we enter into a new depth and new branch count update the transformations + # New depth -> new branch count. Hence update the transformations: if temp != self.current_branch_count: # Get the number of transformations @@ -383,20 +396,15 @@ def insert(self, data_dict: DatasetDict): # Compute transformation deltas for observations trans_o_m, trans_no_m = self.compute_transformation(data_dict, self.num_transforms) - insert_mat_replay_buffer(trans_o_m, trans_no_m) + # Insert to replay buffer + batch_insert(trans_o_m, trans_no_m) #else?? 
- - # Generate observation matrices with same - o_m, no_m = self.generate_obs_mats(data_dict, self.num_transforms) - - # Add observations with deltas - transf_o_m = o_m + self.delta_o_m - transf_no_m = no_m + self.delta_no_m - - # Insert to replay buffer - batch_insert(transf_o_m, transf_no_m) - + # Compute transformation deltas for observations + trans_o_m, trans_no_m = self.compute_transformation(data_dict, self.num_transforms) + + # Insert as batch into replay buffer + batch_insert(trans_o_m, trans_no_m, data_dict) # Reset current_depth, timestep, and max_traj_length sector_4 = dt.now() if self.debug_time else None @@ -412,21 +420,65 @@ def insert(self, data_dict: DatasetDict): finish = dt.now() print(f"Splits: {(sector_2 - sector_1).total_seconds():.5f} : {(sector_3 - sector_2).total_seconds():.5f} : {(sector_4 - sector_3).total_seconds():.5f} : {(finish - sector_4).total_seconds():.5f}\nLaptime: {(finish - sector_1).total_seconds():.5f}") - def batch_insert(self, batch: DatasetDict): + def batch_insert(self, trans_o_m: jnp.ndarray, trans_no_m: jnp.ndarray, data_dict: DatasetDict) -> None: """ - batch is a dict of NumPy arrays, each of shape (B, *field_shape). - We insert B new entries in one go. + Inserts N data-augmented transitions into self.dataset_dict all at once. + + Args: + data_dict: original single transition dict with keys + - "observations" shape (obs_dim,) + - "next_observations" shape (obs_dim,) + - "actions" shape (act_dim,) + - "rewards" scalar or shape () + - "dones" scalar bool + - "masks" scalar or shape () + trans_o_m: jnp.ndarray (obs_dim, N) — each column is a new obs + trans_no_m: jnp.ndarray (obs_dim, N) — each column is new next_obs + + Returns: + None. Mutates self.dataset_dict and advances the insert pointer. 
""" - B = batch["observations"].shape[0] - # wrap‐around indices - idxs = (np.arange(self._insert_index, - self._insert_index + B) % self._capacity) - - # for each key, write the entire batch into self.dataset_dict - for k, array in batch.items(): - # e.g. self.dataset_dict["observations"][idxs] = batch["observations"] - self.dataset_dict[k][idxs] = array - - # advance pointers - self._insert_index = idxs[-1] + 1 - self._size = min(self._size + B, self._capacity) + + # Extract number of transformations directly from the data + N = trans_o_m.shape[1] + + # Expand observations and next observations + # Transpose matrix of observations and next_observations as a numpy array: trans_o_m: (obs_dim, N) → we want (N, obs_dim) in NumPy + obs_batch = np.array(trans_o_m.T) # shape (N, obs_dim) + next_obs_batch= np.array(trans_no_m.T) # shape (N, obs_dim) + + # Expand actions: + actions = np.array(data_dict["actions"], dtype=np.float32) + + # If shape (act_dim,), broadcast to (N, act_dim) + actions_batch = np.broadcast_to(actions, (N,) + actions.shape) + + # Rewards/dones/masks batches + rewards = float(data_dict["rewards"]) + dones = bool(data_dict["dones"]) + masks = bool(data_dict["masks"]) # check type + + rewards_batch = np.full((N,), rewards, dtype=np.float32) + dones_batch = np.full((N,), dones, dtype=bool) + masks_batch = np.full((N,), masks, dtype=np.float32) + + batch = { + "observations": obs_batch, + "next_observations": next_obs_batch, + "actions": actions_batch, + "rewards": rewards_batch, + "dones": dones_batch, + "masks": masks_batch, + } + + # Create appropriate indeces for circular buffer. 
Consider overflow: + idxs = (self._insert_index + np.arange(N)) % self._capacity + + # One‐shot write into each field + for key, arr in batch.items(): # arr has shape (N, *field_shape) + self.dataset_dict[key][idxs] = arr + + # Update pointers + self._insert_index = int(idxs[-1]) + 1 + self._size = min(self._size + N, self._capacity) + \ No newline at end of file From 6b1bad94fa13fd87008f170a9b1da7a76a4d16db Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Mon, 4 Aug 2025 11:21:20 -0500 Subject: [PATCH 094/141] Modified FractalSymmetryReplayBufferParallel so that it subclasses FractalSymmetryReplayBuffer. Also added necessary files in the data_store.py for FractalSymmetryReplayBufferParallel. Still needs testing. --- .../serl_launcher/data/data_store.py | 67 ++++- ...fractal_symmetry_replay_buffer_parallel.py | 248 ++---------------- 2 files changed, 94 insertions(+), 221 deletions(-) diff --git a/serl_launcher/serl_launcher/data/data_store.py b/serl_launcher/serl_launcher/data/data_store.py index c144434c..f43d126b 100644 --- a/serl_launcher/serl_launcher/data/data_store.py +++ b/serl_launcher/serl_launcher/data/data_store.py @@ -11,6 +11,10 @@ FractalSymmetryReplayBuffer ) +from serl_launcher.data.fractal_symmetry_replay_buffer_parallel import ( + FractalSymmetryReplayBufferParallel +) + from agentlace.data.data_store import DataStoreBase from typing import List, Optional, TypeVar @@ -206,6 +210,68 @@ def latest_data_id(self): def get_latest_data(self, from_id: int): raise NotImplementedError # TODO +class FractalSymmetryReplayBufferParallelDataStore(FractalSymmetryReplayBufferParallel, DataStoreBase): + def __init__( + self, + observation_space: gym.Space, + action_space: gym.Space, + capacity: int, + branch_method: str, + split_method: str, + workspace_width: int, + workspace_width_method: str, + rlds_logger: Optional[RLDSLogger] = None, + **kwargs: dict, + ): + FractalSymmetryReplayBufferParallel.__init__(self, observation_space, action_space, capacity, 
branch_method, split_method, workspace_width, workspace_width_method,**kwargs) + DataStoreBase.__init__(self, capacity) + self._lock = Lock() + self._logger = None + + if rlds_logger: + self.step_type = RLDSStepType.TERMINATION # to init the state for restart + self._logger = rlds_logger + + # ensure thread safety + def insert(self, data): + with self._lock: + super(FractalSymmetryReplayBufferParallelDataStore, self).insert(data) + + # TODO: Data logging currently does NOT WORK as shown if we want to log our transformed transitions + # add data to the rlds logger + if self._logger: + if self.step_type in { + RLDSStepType.TERMINATION, + RLDSStepType.TRUNCATION, + }: + self.step_type = RLDSStepType.RESTART + elif not data["masks"]: # 0 is done, 1 is not done + self.step_type = RLDSStepType.TERMINATION + elif data["dones"]: + self.step_type = RLDSStepType.TRUNCATION + else: + self.step_type = RLDSStepType.TRANSITION + + self._logger( + action=data["actions"], + obs=data["next_observations"], # TODO: check if this is correct + reward=data["rewards"], + step_type=self.step_type, + ) + + # ensure thread safety + def sample(self, *args, **kwargs): + with self._lock: + return super(FractalSymmetryReplayBufferParallelDataStore, self).sample(*args, **kwargs) + + # NOTE: method for DataStoreBase + def latest_data_id(self): + return self._insert_index + + # NOTE: method for DataStoreBase + def get_latest_data(self, from_id: int): + raise NotImplementedError # TODO + def populate_data_store( data_store: DataStoreBase, demos_path: str, @@ -224,7 +290,6 @@ def populate_data_store( print(f"Loaded {len(data_store)} transitions.") return data_store - def populate_data_store_with_z_axis_only( data_store: DataStoreBase, demos_path: str, diff --git a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer_parallel.py b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer_parallel.py index c15efce2..4e03ac3e 100644 --- 
a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer_parallel.py +++ b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer_parallel.py @@ -7,7 +7,7 @@ # Data from serl_launcher.data.dataset import DatasetDict -from serl_launcher.data.replay_buffer import ReplayBuffer +from serl_launcher.data.fractal_symmetry_replay_buffer import FractalSymmetryReplayBuffer # Optimization of matrix operations import numpy as np @@ -15,7 +15,10 @@ from jax import jit import jax.numpy as jnp -class FractalSymmetryReplayBuffer(ReplayBuffer): +class FractalSymmetryReplayBufferParallel(FractalSymmetryReplayBuffer): + ''' + Override .insert(), add batch_insert(), override __init__ if needed, etc. + ''' def __init__( self, observation_space: gym.Space, @@ -27,123 +30,29 @@ def __init__( workspace_width_method: str, kwargs: dict ): + + # parallel-specific setup... + super().__init__( + observation_space=observation_space, + action_space=action_space, + capacity=capacity, + branch_method=branch_method, + split_method=split_method, + workspace_width=workspace_width, + workspace_width_method=workspace_width_method, + kwargs=kwargs, + ) - # Initialize values - self.debug_time = False - self.current_branch_count=1 - self.update_max_traj_length = False - self.workspace_width = workspace_width - self.current_branch_count = 0 - self.num_transforms = 0 - - # Set correct split function in self.split pointer. 
- method_check = "split_method" - match split_method: - case "time": - assert "max_depth" in kwargs.keys(), self._handle_bad_args_(method_check, split_method, "max_depth") - self.max_depth=kwargs["max_depth"] - - assert "max_traj_length" in kwargs.keys(), self._handle_bad_args_(method_check, split_method, "max_traj_length") - self.max_traj_length=kwargs["max_traj_length"] - - assert "alpha" in kwargs.keys(), self._handle_bad_args_(method_check, split_method, "alpha") - self.alpha=kwargs["alpha"] - update_max_traj_length = True - self.split = self.time_split - - case "constant": - self.split = self.constant_split - - case "test": - self.split = self.test_split - case _: - raise ValueError("incorrect value passed to split_method") - - # Set correct branch function in self.branch pointer. - method_check = "branch_method" - match branch_method: - case "fractal": - assert "max_depth" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "max_depth") - self.max_depth=kwargs["max_depth"] - - assert "branching_factor" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "branching_factor") - self.branching_factor=kwargs["branching_factor"] - - assert "start_num" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "start_num") - self.start_num=kwargs["start_num"] - - self.branch = self.fractal_branch - - case "contraction": - assert "start_num" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "start_num") - self.start_num=kwargs["start_num"] - - # Add branching_factor for contraction method - assert "branching_factor" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "branching_factor") - self.branching_factor = kwargs["branching_factor"] - - self.branch = self.fractal_contraction - - case "linear": - raise NotImplementedError("linear branch method is not yet implemented") - # self.branch = self.linear_branch - - case "constant": - assert "starting_branch_count" in kwargs.keys(), 
self._handle_bad_args_("branch_method", branch_method, "starting_branch_count") - self.current_branch_count = kwargs["starting_branch_count"] - del kwargs["starting_branch_count"] - - self.branch = self.constant_branch - - case "test": - self.branch = self.test_branch - - case _: - raise ValueError("incorrect value passed to branch_method") - - # Set correct workspace_width function in self.workspace_width_method - match workspace_width_method: - - case "constant": - if workspace_width is None: - raise ValueError("workspace_width must be defined for constant workspace width method") - self.get_workspace_width = self.ww_constant - - case "decrease": - if workspace_width is None: - raise ValueError("workspace_width must be defined for constant workspace width method") - self.get_workspace_width = self.ww_decrease - - case "increase": - if workspace_width is None: - raise ValueError("workspace_width must be defined for constant workspace width method") - self.get_workspace_width = self.ww_increase - - case _: - raise ValueError("incorrect value passed to workspace_width_method") - - #--------------------------------------------------------------------------------------------------------------- - # Set the idx value (changes depending on environment/wrapper) of the x and y observations and next_observations - self.x_obs_idx = kwargs["x_obs_idx"] - self.y_obs_idx = kwargs["y_obs_idx"] - - # Set initial fractal config values - self.timestep = 0 - self.current_depth = 0 - self.branch_index = [workspace_width/2] # TODO: is this a correct initialization? - self.start_num = kwargs["start_num"] # Used for fractal contractions - - ## TODO: this is done in insert. Are we repeating work? 
# Create branch index to control how to set (x,y) offsets of different branches - self.branch_index = np.empty(self.current_branch_count, dtype=np.float32) + self.branch_index_parallel = np.empty(self.current_branch_count, dtype=np.float32) ## Set spacing for transformations (see more: https://www.notion.so/Fractal-Symmetry-Characterization-225cd3402f1a80948509ec86f0b6ee5e?source=copy_link#22bcd3402f1a8054b97ff0ab387923f7) # Set a constant value useful in the computation a standard offset between branches - constant = self.workspace_width/(2 * self.current_branch_count) + constant_parallel = self.workspace_width/(2 * self.current_branch_count) # Compute the translation offsets from left edge according to branch index: for i in range(0, self.current_branch_count): - self.branch_index[i] = (2 * i + 1) * constant + self.branch_index_parallel[i] = (2 * i + 1) * constant_parallel ## Warn about unused kwargs for k in kwargs.keys(): @@ -164,111 +73,6 @@ def __init__( capacity=capacity, ) - # Missing arguments - def _handle_bad_args_(self, type: str, method: str, arg: str) : - return f"\033[31mERROR: \033[0m{arg} must be defined for {type} \"{method}\"" - - # Update (x,y) observations with translational transformation - def transform(self, data_dict: DatasetDict, transform: np.array): - # return data_dict with positional arguments += translation - data_dict["observations"] += transform - data_dict["next_observations"] += transform - - #--- BRANCH METHODS --- - # FOR TESTING ONLY - def test_branch(self): - self.current_depth += 1 - return 1 - - def fractal_branch(self): - ''' - Computes the number of branches for the current depth using an exponential growth rule. - - This method implements a "fractal branching" strategy, where the number of branches - increases exponentially with depth. At each depth `d`, the number of branches is calculated as: - - num_branches = branching_factor ** current_depth - - where: - - branching_factor: The base number of branches at each split. 
- - current_depth: The current depth in the fractal tree (self.current_depth). - - Returns: - int: The computed number of branches for the current depth. - ''' - # return a new number of branches = branching_factor ^ depth - return self.branching_factor ** self.current_depth - - def fractal_contraction(self): - ''' - Computes the number of branches for the current depth using a contraction rule. - - This method implements a "fractal contraction" branching strategy, where the number - of branches decreases exponentially with depth. At each depth `d`, the number of branches - is calculated as: - - num_branches = start_num / (branching_factor ** (d - 1)) - - where: - - start_num: The initial number of branches at depth 1. - - branching_factor: The factor by which the number of branches contracts at each depth. - - d: The current depth (self.current_depth). - - Returns: - int: The computed number of branches for the current depth. - ''' - # return - b = self.branching_factor - d = self.current_depth - top_b = self.start_num - - return int( top_b/( b**(d-1) )) - - def constant_branch(self): - ''' - Used to create pure translations with no further branching. - self.current_branch_count used to set the total number of transformations. 
- ''' - # return current number of branches - return self.current_branch_count - - # REQUIRES TESTING - # def linear_branch(self): - # # return a new number of branches = branches_count + n - # return self.current_branch_count + self.branch_count_rate_of_change - - - #---- SPLIT METHODS --- - # FOR TESTING ONLY - def test_split(self, data_dict: DatasetDict): - return True - - # REQUIRES TESTING - def time_split(self, data_dict: DatasetDict): - if self.timestep % (self.max_traj_length//self.max_depth) or self.current_depth >= self.max_depth: - return False - self.current_depth += 1 - return True - - def constant_split(self, data_dict: DatasetDict): - return True - - # Workspace Width Methods - def ww_constant(self): - return self.workspace_width - - def ww_increase(self): - ''' - Increase workspace width with depth by 5cm. Lower density. - ''' - return self.workspace_width + self.current_depth*0.05 - - def ww_decrease(self): - ''' - Decrease workspace width with depth by 5cm. Higher density. - ''' - return self.workspace_width - self.current_depth*0.05 - #--------------------------------------------- @jit def compute_transformation(self, data_dict, num_transforms): @@ -316,9 +120,9 @@ def compute_deltas(self,mat,type='grid'): # add type as jnp array edge = self.get_workspace_width()/2. # 2. Get array of translation deltas for x,y offsets: ( (2i+1)/2b ) * ww - constant = self.get_workspace_width() / (2. * self.current_branch_count) - x_offset = ( (2 * i) + 1) * constant - y_offset = ( (2 * j) + 1) * constant + constant_parallel = self.get_workspace_width() / (2. * self.current_branch_count) + x_offset = ( (2 * i) + 1) * constant_parallel + y_offset = ( (2 * j) + 1) * constant_parallel # 3. Generate the matrix grid of transformations by element-wise sum of deltas on the matrix of env. observations. Use self.x_obs_idx & y_obs_idx. 
# Perform separate updates @@ -394,16 +198,20 @@ def insert(self, data_dict: DatasetDict): self.num_transforms = self.current_branch_count**2 # Compute transformation deltas for observations + sector_2 = dt.now() if self.debug_time else None trans_o_m, trans_no_m = self.compute_transformation(data_dict, self.num_transforms) # Insert to replay buffer - batch_insert(trans_o_m, trans_no_m) + sector_3 = dt.now() if self.debug_time else None + batch_insert(trans_o_m, trans_no_m, data_dict) #else?? # Compute transformation deltas for observations + sector_2 = dt.now() if self.debug_time else None trans_o_m, trans_no_m = self.compute_transformation(data_dict, self.num_transforms) # Insert as batch into replay buffer + sector_3 = dt.now() if self.debug_time else None batch_insert(trans_o_m, trans_no_m, data_dict) # Reset current_depth, timestep, and max_traj_length From 23aa69f68792b1590572dd1ca9b9eab1c774f66a Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Mon, 4 Aug 2025 11:33:36 -0500 Subject: [PATCH 095/141] Updated ancillary files to call the function appropriately. - async_sac_state_sim - launch.json - launcher.py Still needs testing. 
--- .../async_sac_state_sim/.vscode/launch.json | 2 +- .../async_sac_state_sim.py | 36 +++++++++---------- serl_launcher/serl_launcher/utils/launcher.py | 14 ++++++++ 3 files changed, 33 insertions(+), 19 deletions(-) diff --git a/examples/async_sac_state_sim/.vscode/launch.json b/examples/async_sac_state_sim/.vscode/launch.json index cce68381..45071ac9 100644 --- a/examples/async_sac_state_sim/.vscode/launch.json +++ b/examples/async_sac_state_sim/.vscode/launch.json @@ -27,7 +27,7 @@ "--replay_buffer_capacity", "1_000_000", // Replay buffer capacity // Fractal - "--replay_buffer_type", "fractal_symmetry_replay_buffer", // Use fractal symmetry replay buffer + "--replay_buffer_type", "fractal_symmetry_replay_buffer_parallel", // Use fractal symmetry replay buffer "--branch_method", "constant", "--split_method", "constant", "--starting_branch_count", "27", // Start with 27 branches diff --git a/examples/async_sac_state_sim/async_sac_state_sim.py b/examples/async_sac_state_sim/async_sac_state_sim.py index 971f46f3..cb76d9c6 100644 --- a/examples/async_sac_state_sim/async_sac_state_sim.py +++ b/examples/async_sac_state_sim/async_sac_state_sim.py @@ -358,29 +358,29 @@ def main(_): sampling_rng = jax.device_put(sampling_rng, device=sharding.replicate()) replay_buffer = make_replay_buffer( env, - capacity=FLAGS.replay_buffer_capacity, - rlds_logger_path=FLAGS.log_rlds_path, - type=FLAGS.replay_buffer_type, - branch_method=FLAGS.branch_method, - split_method=FLAGS.split_method, - branching_factor=FLAGS.branching_factor, + capacity =FLAGS.replay_buffer_capacity, + rlds_logger_path =FLAGS.log_rlds_path, + type =FLAGS.replay_buffer_type, + branch_method =FLAGS.branch_method, + split_method =FLAGS.split_method, + branching_factor =FLAGS.branching_factor, starting_branch_count=FLAGS.starting_branch_count, - workspace_width=FLAGS.workspace_width, - max_traj_length=FLAGS.max_traj_length, - x_obs_idx=np.array([0,4]), - y_obs_idx=np.array([1,5]), - 
preload_rlds_path=FLAGS.preload_rlds_path, - max_depth = FLAGS.max_depth, - alpha = FLAGS.alpha, + workspace_width =FLAGS.workspace_width, + max_traj_length =FLAGS.max_traj_length, + x_obs_idx =np.array([0,4]), + y_obs_idx =np.array([1,5]), + preload_rlds_path =FLAGS.preload_rlds_path, + max_depth =FLAGS.max_depth, + alpha =FLAGS.alpha, # Contraction - start_num = FLAGS.start_num, + start_num =FLAGS.start_num, # Workspace Width Density workspace_width_method=FLAGS.workspace_width_method, # Disassociated - disassociated_type=FLAGS.disassociated_type, - min_branch_count=FLAGS.min_branch_count, - max_branch_count=FLAGS.max_branch_count, - num_depth_sectors=FLAGS.num_depth_sectors + disassociated_type =FLAGS.disassociated_type, + min_branch_count =FLAGS.min_branch_count, + max_branch_count =FLAGS.max_branch_count, + num_depth_sectors =FLAGS.num_depth_sectors ) replay_iterator = replay_buffer.get_iterator( sample_args={ diff --git a/serl_launcher/serl_launcher/utils/launcher.py b/serl_launcher/serl_launcher/utils/launcher.py index 4fe2654d..03227c66 100644 --- a/serl_launcher/serl_launcher/utils/launcher.py +++ b/serl_launcher/serl_launcher/utils/launcher.py @@ -19,6 +19,7 @@ MemoryEfficientReplayBufferDataStore, ReplayBufferDataStore, FractalSymmetryReplayBufferDataStore, + FractalSymmetryReplayBufferParallelDataStore, ) ############################################################################## @@ -272,6 +273,19 @@ def make_replay_buffer( rlds_logger=rlds_logger, kwargs=kwargs, ) + elif type == "fractal_symmetry_replay_buffer_parallel": + replay_buffer = FractalSymmetryReplayBufferParallelDataStore( + env.observation_space, + env.action_space, + capacity=capacity, + branch_method=branch_method, + split_method=split_method, + workspace_width=workspace_width, + workspace_width_method=workspace_width_method, + rlds_logger=rlds_logger, + kwargs=kwargs, + ) + else: raise ValueError(f"Unsupported replay_buffer_type: {type}") From 5585112aed7098b17c9055430219a45b0bc57640 
Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Mon, 4 Aug 2025 11:35:28 -0500 Subject: [PATCH 096/141] updated shell files. --- examples/async_sac_state_sim/run_actor.sh | 2 +- examples/async_sac_state_sim/run_learner.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/async_sac_state_sim/run_actor.sh b/examples/async_sac_state_sim/run_actor.sh index cc2d90a7..c2a72ea3 100644 --- a/examples/async_sac_state_sim/run_actor.sh +++ b/examples/async_sac_state_sim/run_actor.sh @@ -19,7 +19,7 @@ python async_sac_state_sim.py \ --env PandaReachCube-v0 \ --exp_name PandaReachCube-v0_state_sim_demos_5_const_27^1_batch_256_replay_1M_utd_8 \ --seed 0 \ - --replay_buffer_type fractal_symmetry_replay_buffer \ + --replay_buffer_type fractal_symmetry_replay_buffer_parallel \ --seed 0 \ --max_steps 300_000 \ --training_starts 1000 \ diff --git a/examples/async_sac_state_sim/run_learner.sh b/examples/async_sac_state_sim/run_learner.sh index 86421fda..c90659be 100644 --- a/examples/async_sac_state_sim/run_learner.sh +++ b/examples/async_sac_state_sim/run_learner.sh @@ -18,7 +18,7 @@ python async_sac_state_sim.py "$@"\ --learner \ --env PandaReachCube-v0 \ --exp_name PandaReachCube-v0_state_sim_3d_demos_5_const_27^1_batch_256_replay_1M_utd_8 \ - --replay_buffer_type fractal_symmetry_replay_buffer \ + --replay_buffer_type fractal_symmetry_replay_buffer_parallel \ --max_steps 50_000 \ --training_starts 1000 \ --random_steps 1000 \ From b04e871f470f7fed67bb0e8b6f7ee59126176cc6 Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Mon, 4 Aug 2025 11:39:00 -0500 Subject: [PATCH 097/141] small change to test 3 translations with parallel --- .../#async_sac_state_sim.py# | 401 ++++++++++++++++++ examples/async_sac_state_sim/run_actor.sh | 16 +- examples/async_sac_state_sim/run_learner.sh | 20 +- 3 files changed, 419 insertions(+), 18 deletions(-) create mode 100644 examples/async_sac_state_sim/#async_sac_state_sim.py# diff --git 
a/examples/async_sac_state_sim/#async_sac_state_sim.py# b/examples/async_sac_state_sim/#async_sac_state_sim.py# new file mode 100644 index 00000000..50ff52f0 --- /dev/null +++ b/examples/async_sac_state_sim/#async_sac_state_sim.py# @@ -0,0 +1,401 @@ +#!/usr/bin/env python3 + +import time +from functools import partial + +import gym +import jax +import jax.numpy as jnp +import numpy as np +import tqdm +from absl import app, flags +from flax.training import checkpoints + +from agentlace.data.data_store import QueuedDataStore +from agentlace.trainer import TrainerClient, TrainerServer +from serl_launcher.utils.launcher import ( + make_sac_agent, + make_trainer_config, + make_wandb_logger, + make_replay_buffer, +) + +from gym.wrappers.record_episode_statistics import RecordEpisodeStatistics +from serl_launcher.agents.continuous.sac import SACAgent +from serl_launcher.common.evaluation import evaluate +from serl_launcher.utils.timer_utils import Timer + +import franka_sim + +from demos.demoHandling import DemoHandling + +FLAGS = flags.FLAGS + +flags.DEFINE_string("env", "HalfCheetah-v4", "Name of environment.") +flags.DEFINE_string("agent", "sac", "Name of agent.") +flags.DEFINE_string("exp_name", None, "Name of the experiment for wandb logging.") +flags.DEFINE_integer("max_traj_length", 100, "Maximum length of trajectory.") +flags.DEFINE_integer("seed", 42, "Random seed.") +flags.DEFINE_bool("save_model", False, "Whether to save model.") +flags.DEFINE_integer("batch_size", 256, "Batch size.") +flags.DEFINE_integer("critic_actor_ratio", 8, "critic to actor update ratio.") + +flags.DEFINE_integer("max_steps", 1000000, "Maximum number of training steps.") +flags.DEFINE_integer("replay_buffer_capacity", 1000000, "Replay buffer capacity.") + +flags.DEFINE_integer("random_steps", 300, "Sample random actions for this many steps.") +flags.DEFINE_integer("training_starts", 300, "Training starts after this step.") +flags.DEFINE_integer("steps_per_update", 30, "Number of steps 
per update the server.") + +flags.DEFINE_integer("log_period", 10, "Logging period.") +flags.DEFINE_integer("eval_period", 2000, "Evaluation period.") +flags.DEFINE_integer("eval_n_trajs", 5, "Number of trajectories for evaluation.") + +# flag to indicate if this is a learner or a actor +flags.DEFINE_boolean("learner", False, "Is this a learner or a trainer.") +flags.DEFINE_boolean("actor", False, "Is this a learner or a trainer.") +flags.DEFINE_boolean("render", False, "Render the environment.") +flags.DEFINE_string("ip", "localhost", "IP address of the learner.") +flags.DEFINE_integer("checkpoint_period", 0, "Period to save checkpoints.") +flags.DEFINE_string("checkpoint_path", None, "Path to save checkpoints.") + +# flags for replay buffer +flags.DEFINE_string("replay_buffer_type", "replay_buffer", "Which type of replay buffer to use") +flags.DEFINE_string("branch_method", None, "Method for how many branches to generate") +flags.DEFINE_string("split_method", None, "Method for when to change number of branches generated") # Remember to default None +flags.DEFINE_float("workspace_width", 0.5, "Workspace width in centimeters") # Remember to default None +flags.DEFINE_integer("max_depth",None,"max_depth") +flags.DEFINE_integer("depth", None, "Total layers of depth") +flags.DEFINE_integer("dendrites", None, "Dendrites for fractal branching") +flags.DEFINE_integer("timesplit_freq", None, "Frequency of splits according to time") +flags.DEFINE_integer("branch_count_rate_of_change", None, "Rate of change for linear branching") +flags.DEFINE_integer("starting_branch_count", 1, "Initial number of branches") +flags.DEFINE_integer("branching_factor", None, "Rate of change of number of transforms per dimension (x,y)") # For fractal_branch and fractal_contraction +flags.DEFINE_integer("start_num",None, "Initila number of branch on the first depth") # For Fractal Cntraction +flags.DEFINE_integer("alpha",None,"alpha value") + + +flags.DEFINE_boolean( + "debug", False, "Debug 
mode." +) # debug mode will disable wandb logging + +flags.DEFINE_string("log_rlds_path", None, "Path to save RLDS logs.") +flags.DEFINE_string("preload_rlds_path", None, "Path to preload RLDS data.") + + +# Load demonstation data +flags.DEFINE_boolean("load_demos", False, "Whether to load demo dataset.") +flags.DEFINE_string("demo_dir", "/data/data/serl/demos", "Path to demo dataset.") +flags.DEFINE_string("file_name", "data_franka_reach_random_20.npz", "Name of the demo file to load.") + +def print_green(x): + return print("\033[92m {}\033[00m".format(x)) + + +############################################################################## + + +def actor(agent: SACAgent, data_store, env, sampling_rng, demos_handler=None): + """ + This is the actor loop, which runs when "--actor" is set to True. + """ + client = TrainerClient( + "actor_env", + FLAGS.ip, + make_trainer_config(), + data_store, + wait_for_server=True, + ) + + # Function to update the agent with new params + def update_params(params): + nonlocal agent + agent = agent.replace(state=agent.state.replace(params=params)) + + client.recv_network_callback(update_params) + + eval_env = gym.make(FLAGS.env) + #if FLAGS.env == "PandaPickCube-v0": + eval_env = gym.wrappers.FlattenObservation(eval_env) ## Note!! + eval_env = RecordEpisodeStatistics(eval_env) + + obs, _ = env.reset() + done = False + + # training loop + timer = Timer() + running_return = 0.0 + + # Load demos: handler.run will insert all transition demo data into the data store. 
+ if FLAGS.load_demos: + with timer.context("sample and step into env with loaded demos"): + + # Insert complete demonstration into the data store via demo Handling class + print(f"Inserting {demos_handler.data['transition_ctr']} transitions into the data store.") + demos_handler.insert_data_to_buffer(data_store) + FLAGS.random_steps = 0 # Set random steps to 0 since we have demo data + # For subsequent steps, sample actions from the agent + for step in tqdm.tqdm(range(FLAGS.max_steps), dynamic_ncols=True): + timer.tick("total") + + with timer.context("sample_actions"): + if step < FLAGS.random_steps: + actions = env.action_space.sample() + else: + # When interacting with agent, put obs into jax device and convert output to np.ndarrar + sampling_rng, key = jax.random.split(sampling_rng) + actions = agent.sample_actions( + observations=jax.device_put(obs), + seed=key, + deterministic=False, + ) + actions = np.asarray(jax.device_get(actions)) + + # Step environment + with timer.context("step_env"): + + next_obs, reward, done, truncated, info = env.step(actions) + next_obs = np.asarray(next_obs, dtype=np.float32) + reward = np.asarray(reward, dtype=np.float32) + + running_return += reward + + data_store.insert( + dict( + observations=obs, + actions=actions, + next_observations=next_obs, + rewards=reward, + masks=1.0 - done, + dones=done or truncated, + ) + ) + + obs = next_obs + if done or truncated: + running_return = 0.0 + obs, _ = env.reset() + + if FLAGS.render: + env.render() + + if step % FLAGS.steps_per_update == 0: + client.update() + + if step % FLAGS.eval_period == 0: + with timer.context("eval"): + evaluate_info = evaluate( + policy_fn=partial(agent.sample_actions, argmax=True), + env=eval_env, + num_episodes=FLAGS.eval_n_trajs, + ) + stats = {"eval": evaluate_info} + client.request("send-stats", stats) + + timer.tock("total") + + if step % FLAGS.log_period == 0: + stats = {"timer": timer.get_average_times()} + client.request("send-stats", stats) + + 
+############################################################################## + + +def learner(rng, agent: SACAgent, replay_buffer, replay_iterator): + """ + The learner loop, which runs when "--learner" is set to True. + """ + # set up wandb and logging + wandb_logger = make_wandb_logger( + project=FLAGS.exp_name, + description=FLAGS.exp_name or FLAGS.env, + debug=FLAGS.debug, + ) + + # To track the step in the training loop + update_steps = 0 + def stats_callback(type: str, payload: dict) -> dict: + """Callback for when server receives stats request.""" + assert type == "send-stats", f"Invalid request type: {type}" + if wandb_logger is not None: + wandb_logger.log(payload, step=update_steps) + return {} # not expecting a response + + # Create server + server = TrainerServer(make_trainer_config(), request_callback=stats_callback) + server.register_data_store("actor_env", replay_buffer) + server.start(threaded=True) + + # Loop to wait until replay_buffer is filled + pbar = tqdm.tqdm( + total=FLAGS.training_starts, + initial=len(replay_buffer), + desc="Filling up replay buffer", + position=0, + leave=True, + ) + while len(replay_buffer) < FLAGS.training_starts: + pbar.update(len(replay_buffer) - pbar.n) # Update progress bar + time.sleep(1) + pbar.update(len(replay_buffer) - pbar.n) # Update progress bar + pbar.close() + + # send the initial network to the actor + server.publish_network(agent.state.params) + print_green("sent initial network to actor") + + # wait till the replay buffer is filled with enough data + timer = Timer() + + # show replay buffer progress bar during training + pbar = tqdm.tqdm( + total=FLAGS.replay_buffer_capacity, + initial=len(replay_buffer), + desc="replay buffer", + ) + + for step in tqdm.tqdm(range(FLAGS.max_steps), dynamic_ncols=True, desc="learner"): + # Train the networks + with timer.context("sample_replay_buffer"): + batch = next(replay_iterator) + + with timer.context("train"): + agent, update_info = agent.update_high_utd(batch, 
utd_ratio=FLAGS.critic_actor_ratio) + agent = jax.block_until_ready(agent) + + # publish the updated network + server.publish_network(agent.state.params) + + if update_steps % FLAGS.log_period == 0 and wandb_logger: + wandb_logger.log(update_info, step=update_steps) + wandb_logger.log({"timer": timer.get_average_times()}, step=update_steps) + + if FLAGS.checkpoint_period and update_steps % FLAGS.checkpoint_period == 0: + assert FLAGS.checkpoint_path is not None + checkpoints.save_checkpoint( + FLAGS.checkpoint_path, agent.state, step=update_steps, keep=20 + ) + + pbar.update(len(replay_buffer) - pbar.n) # update replay buffer bar + update_steps += 1 + + +############################################################################## + + +def main(_): + devices = jax.local_devices() + num_devices = len(devices) + sharding = jax.sharding.PositionalSharding(devices) + assert FLAGS.batch_size % num_devices == 0 + + # seed + rng = jax.random.PRNGKey(FLAGS.seed) + + # create env and load dataset + if FLAGS.render: + env = gym.make(FLAGS.env, render_mode="human") + else: + env = gym.make(FLAGS.env) + + #if FLAGS.env == "PandaPickCube-v0": + env = gym.wrappers.FlattenObservation(env) + + rng, sampling_rng = jax.random.split(rng) + agent: SACAgent = make_sac_agent( + seed=FLAGS.seed, + sample_obs=env.observation_space.sample(), + sample_action=env.action_space.sample(), + ) + + # replicate agent across devices + # need the jnp.array to avoid a bug where device_put doesn't recognize primitives + agent: SACAgent = jax.device_put( + jax.tree.map(jnp.array, agent), sharding.replicate() + ) + + # Demo Data + if FLAGS.load_demos: + print_green("Setting demo parameters") + # Create a handler for the demo data + demos_handler = DemoHandling( + demo_dir=FLAGS.demo_dir, + file_name=FLAGS.file_name, + ) + + # 1. 
Modify actor data_store size + # Extract number of demo transitions + demo_transitions = demos_handler.get_num_transitions() + + if demo_transitions > 2000: + qds_size = demo_transitions + 1000 # Increment the queue size on the actor + else: + qds_size = 2000 # the original queue size on the actor + + # 2. Modify training starts (since we have good data) + FLAGS.training_starts = 1 + + else: + demos_handler = None + qds_size = 2000 # the original queue size on the actor + + + if FLAGS.learner: + sampling_rng = jax.device_put(sampling_rng, device=sharding.replicate()) + replay_buffer = make_replay_buffer( + env, + capacity=FLAGS.replay_buffer_capacity, + rlds_logger_path=FLAGS.log_rlds_path, + type=FLAGS.replay_buffer_type, + branch_method=FLAGS.branch_method, + split_method=FLAGS.split_method, + branching_factor=FLAGS.branching_factor, + starting_branch_count=FLAGS.starting_branch_count, + workspace_width=FLAGS.workspace_width, + max_traj_length=FLAGS.max_traj_length, + x_obs_idx=np.array([0,4]), + y_obs_idx=np.array([1,5]), + preload_rlds_path=FLAGS.preload_rlds_path, + start_num = FLAGS.start_num, + max_depth = FLAGS.max_depth, + alpha = FLAGS.alpha, + ) + replay_iterator = replay_buffer.get_iterator( + sample_args={ + "batch_size": FLAGS.batch_size * FLAGS.critic_actor_ratio, + }, + device=sharding.replicate(), + ) + # learner loop + print_green("starting learner loop") + learner( + sampling_rng, + agent, + replay_buffer, + replay_iterator=replay_iterator, + ) + + elif FLAGS.actor: + sampling_rng = jax.device_put(sampling_rng, sharding.replicate()) + + if FLAGS.load_demos: + print_green("loading demo data") + + # Create a data store for the actor + data_store = QueuedDataStore(qds_size) # the queue size on the actor + else: + print_green("no demo data, using empty data store") + # Create a data store for the actor + data_store = QueuedDataStore(2000) # the queue size on the actor + + # actor loop + print_green("starting actor loop") + actor(agent, data_store, 
env, sampling_rng, demos_handler) + + else: + raise NotImplementedError("Must be either a learner or an actor") + + +if __name__ == "__main__": + app.run(main) diff --git a/examples/async_sac_state_sim/run_actor.sh b/examples/async_sac_state_sim/run_actor.sh index c2a72ea3..bc1ead2e 100644 --- a/examples/async_sac_state_sim/run_actor.sh +++ b/examples/async_sac_state_sim/run_actor.sh @@ -17,7 +17,7 @@ export CHECKPOINT_DIR="/data/fsrb_testing/checkpoints-$TIMESTAMP" && \ python async_sac_state_sim.py \ --actor \ --env PandaReachCube-v0 \ - --exp_name PandaReachCube-v0_state_sim_demos_5_const_27^1_batch_256_replay_1M_utd_8 \ + --exp_name PandaReachCube-v0_state_sim_3d_parallel_const_3^1_batch_256_replay_1M_utd_8 \ --seed 0 \ --replay_buffer_type fractal_symmetry_replay_buffer_parallel \ --seed 0 \ @@ -29,17 +29,17 @@ python async_sac_state_sim.py \ --save_model True \ --branch_method constant \ --split_method constant \ - --starting_branch_count 27 \ + --starting_branch_count 3 \ --workspace_width 0.5 \ - --load_demos \ - --demo_dir /data/data/serl/demos \ - --file_name data_franka_reach_random_5_2.npz \ - # --max_traj_length 100 \ + # --load_demos \ + # --demo_dir /data/data/serl/demos \ + # --file_name data_franka_reach_random_5_2.npz \ + # # --max_traj_length 100 \ # --max_depth 4 \ # --start_num 81 \ - # --alpha 1 \ + --alpha 1 \ # --branching_factor 3 \ # --checkpoint_period 10000 \ # --checkpoint_path "$CHECKPOINT_DIR" \ #--render \ - #--debug # wandb is disabled when debug \ No newline at end of file + --debug # wandb is disabled when debug \ No newline at end of file diff --git a/examples/async_sac_state_sim/run_learner.sh b/examples/async_sac_state_sim/run_learner.sh index c90659be..47fabcf9 100644 --- a/examples/async_sac_state_sim/run_learner.sh +++ b/examples/async_sac_state_sim/run_learner.sh @@ -17,28 +17,28 @@ export CHECKPOINT_DIR="/data/fsrb_testing/checkpoints-$TIMESTAMP" && \ python async_sac_state_sim.py "$@"\ --learner \ --env PandaReachCube-v0 
\ - --exp_name PandaReachCube-v0_state_sim_3d_demos_5_const_27^1_batch_256_replay_1M_utd_8 \ + --exp_name PandaReachCube-v0_state_sim_3d_parallel_const_3^1_batch_256_replay_1M_utd_8 \ --replay_buffer_type fractal_symmetry_replay_buffer_parallel \ --max_steps 50_000 \ --training_starts 1000 \ --random_steps 1000 \ - --critic_actor_ratio 32 \ - --batch_size 2048 \ - --replay_buffer_capacity 8_000_000 \ + --critic_actor_ratio 8 \ + --batch_size 256 \ + --replay_buffer_capacity 1_000_000 \ --save_model True \ --branch_method constant \ --split_method constant \ - --starting_branch_count 27 \ + --starting_branch_count 3 \ --workspace_width 0.5 \ - --load_demos \ - --demo_dir /data/data/serl/demos \ - --file_name data_franka_reach_random_5_2.npz \ + # --load_demos \ + # --demo_dir /data/data/serl/demos \ + # --file_name data_franka_reach_random_5_2.npz \ # --max_traj_length 100 \ # --max_depth 4 \ # --start_num 81 \ - # --alpha 1 \ + --alpha 1 \ # --branching_factor 3 \ # --checkpoint_period 10000 \ # --checkpoint_path "$CHECKPOINT_DIR" \ - #--debug # wandb is disabled when debug + --debug # wandb is disabled when debug #--render \ No newline at end of file From 84b19c40d0e742827700bf0a91d3975d9ea9ad77 Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Mon, 4 Aug 2025 11:41:30 -0500 Subject: [PATCH 098/141] forgot to update launch.json with parallel friendly parameters --- .../async_sac_state_sim/.vscode/launch.json | 54 +++++++++---------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/examples/async_sac_state_sim/.vscode/launch.json b/examples/async_sac_state_sim/.vscode/launch.json index 45071ac9..a40a96dc 100644 --- a/examples/async_sac_state_sim/.vscode/launch.json +++ b/examples/async_sac_state_sim/.vscode/launch.json @@ -17,7 +17,7 @@ "--learner", // REQUIRED: Indicates this is a learner instance "--env", "PandaReachCube-v0", // Environment to use - "--exp_name=PandaReachCube-v0_state_sim_3d_demos_5_const_27^1_batch_256_replay_1M_utd_8", // Experiment 
name for wandb logging + "--exp_name PandaReachCube-v0_state_sim_3d_parallel_const_3^1_batch_256_replay_1M_utd_8", // Experiment name for wandb logging "--seed", "0", // Random seed for reproducibility "--random_steps", "1_000", // Number of random steps at beginning "--max_steps", "50_000 ", // Maximum training steps @@ -30,21 +30,21 @@ "--replay_buffer_type", "fractal_symmetry_replay_buffer_parallel", // Use fractal symmetry replay buffer "--branch_method", "constant", "--split_method", "constant", - "--starting_branch_count", "27", // Start with 27 branches + "--starting_branch_count", "3", // Start with 27 branches "--workspace_width", "0.5", // Demonstration data loading options - "--load_demos", // Load demo dataset - "--demo_dir", "/data/data/serl/demos", - "--file_name", "data_franka_reach_random_5_2.npz", // Name of the demo file to load + // "--load_demos", // Load demo dataset + // "--demo_dir", "/data/data/serl/demos", + // "--file_name", "data_franka_reach_random_5_2.npz", // Name of the demo file to load // Dissasociated Fractals - "--branch_method", "disassociated", // branch method type - "--split_method", "disassociated", // split method type - "--disassociated_type", "octahedron", // Type of disassociated test to perform - "--min_branch_count", "3", // Minimum branch count for disassociated testing - "--max_branch_count", "9", // Maximum branch count for disassociated testing - "--num_depth_sectors", "13", // Desired number of sectors to divide rollouts into for branch count splitting + // "--branch_method", "disassociated", // branch method type + // "--split_method", "disassociated", // split method type + // "--disassociated_type", "octahedron", // Type of disassociated test to perform + // "--min_branch_count", "3", // Minimum branch count for disassociated testing + // "--max_branch_count", "9", // Maximum branch count for disassociated testing + // "--num_depth_sectors", "13", // Desired number of sectors to divide rollouts into for branch 
count splitting "--alpha", "1" // alpha ], @@ -71,34 +71,34 @@ "args": [ "--actor", // REQUIRED: Indicates this is an actor instance "--env", "PandaReachCube-v0", // Environment to use - "--exp_name=PandaReachCube-v0_state_sim_3d_demos_5_const_27^1_batch_256_replay_1M_utd_8", // Experiment name for wandb logging + "--exp_name PandaReachCube-v0_state_sim_3d_parallel_const_3^1_batch_256_replay_1M_utd_8", // Experiment name for wandb logging "--seed", "0", // Random seed for reproducibility "--random_steps", "1_000", // Number of random steps at beginning - "--max_steps", "500_000", // Maximum training steps + "--max_steps", "500_000 ", // Maximum training steps "--training_starts", "1_000", // Start training after buffer has this many samples "--critic_actor_ratio", "8", // Critic-to-actor update ratio "--batch_size", "256", // Training batch size "--replay_buffer_capacity", "1_000_000", // Replay buffer capacity // Fractal - "--replay_buffer_type", "fractal_symmetry_replay_buffer", // Use fractal symmetry replay buffer + "--replay_buffer_type", "fractal_symmetry_replay_buffer_parallel", // Use fractal symmetry replay buffer "--branch_method", "constant", "--split_method", "constant", - "--starting_branch_count", "27", // Start with 27 branches - "--workspace_width", "0.5", - + "--starting_branch_count", "3", // Start with 27 branches + "--workspace_width", "0.5", + // Demonstration data loading options - "--load_demos", // Load demo dataset - "--demo_dir", "/data/data/serl/demos", - "--file_name", "data_franka_reach_random_5_2.npz", // Name of the demo file to load + // "--load_demos", // Load demo dataset + // "--demo_dir", "/data/data/serl/demos", + // "--file_name", "data_franka_reach_random_5_2.npz", // Name of the demo file to load - // Disassociated Fractals - "--branch_method", "disassociated", // branch method type - "--split_method", "disassociated", // split method type - "--disassociated_type", "octahedron", // Type of disassociated test to perform - 
"--min_branch_count", "3", // Minimum branch count for disassociated testing - "--max_branch_count", "9", // Maximum branch count for disassociated testing - "--num_depth_sectors", "13", // Desired number of sectors to divide rollouts into for branch count splitting + // Dissasociated Fractals + // "--branch_method", "disassociated", // branch method type + // "--split_method", "disassociated", // split method type + // "--disassociated_type", "octahedron", // Type of disassociated test to perform + // "--min_branch_count", "3", // Minimum branch count for disassociated testing + // "--max_branch_count", "9", // Maximum branch count for disassociated testing + // "--num_depth_sectors", "13", // Desired number of sectors to divide rollouts into for branch count splitting "--alpha", "1" // alpha ], From 654e8f56ccd2c27c4e3d0565e43ef6930e277cca Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Mon, 4 Aug 2025 12:32:18 -0500 Subject: [PATCH 099/141] Bug fixes - Some bug fixes related to jit declaration of compute_transformation with static arguments self and num_transforms. 
--- .../async_sac_state_sim/.vscode/launch.json | 4 +- ...fractal_symmetry_replay_buffer_parallel.py | 38 +++---------------- 2 files changed, 8 insertions(+), 34 deletions(-) diff --git a/examples/async_sac_state_sim/.vscode/launch.json b/examples/async_sac_state_sim/.vscode/launch.json index a40a96dc..a1768ccb 100644 --- a/examples/async_sac_state_sim/.vscode/launch.json +++ b/examples/async_sac_state_sim/.vscode/launch.json @@ -17,7 +17,7 @@ "--learner", // REQUIRED: Indicates this is a learner instance "--env", "PandaReachCube-v0", // Environment to use - "--exp_name PandaReachCube-v0_state_sim_3d_parallel_const_3^1_batch_256_replay_1M_utd_8", // Experiment name for wandb logging + "--exp_name", "PandaReachCube-v0_state_sim_3d_parallel_const_3^1_batch_256_replay_1M_utd_8", // Experiment name for wandb logging "--seed", "0", // Random seed for reproducibility "--random_steps", "1_000", // Number of random steps at beginning "--max_steps", "50_000 ", // Maximum training steps @@ -71,7 +71,7 @@ "args": [ "--actor", // REQUIRED: Indicates this is an actor instance "--env", "PandaReachCube-v0", // Environment to use - "--exp_name PandaReachCube-v0_state_sim_3d_parallel_const_3^1_batch_256_replay_1M_utd_8", // Experiment name for wandb logging + "--exp_name", "PandaReachCube-v0_state_sim_3d_parallel_const_3^1_batch_256_replay_1M_utd_8", // Experiment name for wandb logging "--seed", "0", // Random seed for reproducibility "--random_steps", "1_000", // Number of random steps at beginning "--max_steps", "500_000 ", // Maximum training steps diff --git a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer_parallel.py b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer_parallel.py index 4e03ac3e..6ecca136 100644 --- a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer_parallel.py +++ b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer_parallel.py @@ -14,6 +14,7 @@ import jax from jax import jit import jax.numpy as jnp 
+from functools import partial class FractalSymmetryReplayBufferParallel(FractalSymmetryReplayBuffer): ''' @@ -43,38 +44,11 @@ def __init__( kwargs=kwargs, ) - # Create branch index to control how to set (x,y) offsets of different branches - self.branch_index_parallel = np.empty(self.current_branch_count, dtype=np.float32) - - ## Set spacing for transformations (see more: https://www.notion.so/Fractal-Symmetry-Characterization-225cd3402f1a80948509ec86f0b6ee5e?source=copy_link#22bcd3402f1a8054b97ff0ab387923f7) - # Set a constant value useful in the computation a standard offset between branches - constant_parallel = self.workspace_width/(2 * self.current_branch_count) - - # Compute the translation offsets from left edge according to branch index: - for i in range(0, self.current_branch_count): - self.branch_index_parallel[i] = (2 * i + 1) * constant_parallel - - ## Warn about unused kwargs - for k in kwargs.keys(): - print(f"\033[33mWARNING \033[0m argument \"{k}\" not used") - - # TODO - # Create more methods with tests - # - # Future considerations: - # when dealing with x and y for transforms, consider two versions: - # radial trees at different angles can have strong fractal symmetry built in but will be more difficult to evenly space lowest level transforms - # grid-based will inherently evenly space lowest transforms, but fractal symmetry will be restricted to 9 branching_factor (original plus 8) in order to conform - - # Init replay buffer class - super().__init__( - observation_space=observation_space, - action_space=action_space, - capacity=capacity, - ) + # Set the total number of transformations to the number of transformations squared for the grid method. May need to change for radial method later. 
+ self.num_transforms = self.current_branch_count ** 2 #--------------------------------------------- - @jit + @partial(jax.jit, static_argnums=(0, 2)) # states that self and num_transforms are not array like and do not need to be traced and need to be marked as static. def compute_transformation(self, data_dict, num_transforms): ''' Add fractal transformations (translations) to matrix of observations in jitted form. @@ -151,8 +125,8 @@ def generate_obs_mats(self, data_dict: DatasetDict, num_transforms: int, type='g ''' # Convert observations to a vector - o_col = jnp.array(data_dict["observations"], type=jnp.float32) - no_col = jnp.array(data_dict["next_observations"], type=jnp.float32) + o_col = jnp.array(data_dict["observations"], dtype=jnp.float32) + no_col = jnp.array(data_dict["next_observations"], dtype=jnp.float32) # Expand observations column vector to a matrix using tile o_m = jnp.tile(o_col, (1,num_transforms)) From dafd77f64e1e9d0169e04ed807d78238da870796 Mon Sep 17 00:00:00 2001 From: ryanvanderstelt Date: Mon, 4 Aug 2025 22:57:48 -0500 Subject: [PATCH 100/141] Parallelized insert function including batch insert. 
Created _handle_methods_ and _handle_method_arg_ --- .../serl_launcher/data/data_store.py | 72 +--- .../data/fractal_symmetry_replay_buffer.py | 319 ++++++++---------- ...fractal_symmetry_replay_buffer_parallel.py | 266 --------------- serl_launcher/serl_launcher/data/fsrb_test.py | 27 +- .../serl_launcher/data/replay_buffer.py | 25 +- serl_launcher/serl_launcher/utils/launcher.py | 17 +- 6 files changed, 175 insertions(+), 551 deletions(-) delete mode 100644 serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer_parallel.py diff --git a/serl_launcher/serl_launcher/data/data_store.py b/serl_launcher/serl_launcher/data/data_store.py index f43d126b..b22de984 100644 --- a/serl_launcher/serl_launcher/data/data_store.py +++ b/serl_launcher/serl_launcher/data/data_store.py @@ -11,10 +11,6 @@ FractalSymmetryReplayBuffer ) -from serl_launcher.data.fractal_symmetry_replay_buffer_parallel import ( - FractalSymmetryReplayBufferParallel -) - from agentlace.data.data_store import DataStoreBase from typing import List, Optional, TypeVar @@ -154,14 +150,16 @@ def __init__( observation_space: gym.Space, action_space: gym.Space, capacity: int, + workspace_width: int, + x_obs_idx, + y_obs_idx, branch_method: str, split_method: str, - workspace_width: int, workspace_width_method: str, rlds_logger: Optional[RLDSLogger] = None, **kwargs: dict, ): - FractalSymmetryReplayBuffer.__init__(self, observation_space, action_space, capacity, branch_method, split_method, workspace_width, workspace_width_method,**kwargs) + FractalSymmetryReplayBuffer.__init__(self, observation_space, action_space, capacity, workspace_width, x_obs_idx, y_obs_idx, branch_method, split_method, workspace_width_method,**kwargs) DataStoreBase.__init__(self, capacity) self._lock = Lock() self._logger = None @@ -210,68 +208,6 @@ def latest_data_id(self): def get_latest_data(self, from_id: int): raise NotImplementedError # TODO -class 
FractalSymmetryReplayBufferParallelDataStore(FractalSymmetryReplayBufferParallel, DataStoreBase): - def __init__( - self, - observation_space: gym.Space, - action_space: gym.Space, - capacity: int, - branch_method: str, - split_method: str, - workspace_width: int, - workspace_width_method: str, - rlds_logger: Optional[RLDSLogger] = None, - **kwargs: dict, - ): - FractalSymmetryReplayBufferParallel.__init__(self, observation_space, action_space, capacity, branch_method, split_method, workspace_width, workspace_width_method,**kwargs) - DataStoreBase.__init__(self, capacity) - self._lock = Lock() - self._logger = None - - if rlds_logger: - self.step_type = RLDSStepType.TERMINATION # to init the state for restart - self._logger = rlds_logger - - # ensure thread safety - def insert(self, data): - with self._lock: - super(FractalSymmetryReplayBufferParallelDataStore, self).insert(data) - - # TODO: Data logging currently does NOT WORK as shown if we want to log our transformed transitions - # add data to the rlds logger - if self._logger: - if self.step_type in { - RLDSStepType.TERMINATION, - RLDSStepType.TRUNCATION, - }: - self.step_type = RLDSStepType.RESTART - elif not data["masks"]: # 0 is done, 1 is not done - self.step_type = RLDSStepType.TERMINATION - elif data["dones"]: - self.step_type = RLDSStepType.TRUNCATION - else: - self.step_type = RLDSStepType.TRANSITION - - self._logger( - action=data["actions"], - obs=data["next_observations"], # TODO: check if this is correct - reward=data["rewards"], - step_type=self.step_type, - ) - - # ensure thread safety - def sample(self, *args, **kwargs): - with self._lock: - return super(FractalSymmetryReplayBufferParallelDataStore, self).sample(*args, **kwargs) - - # NOTE: method for DataStoreBase - def latest_data_id(self): - return self._insert_index - - # NOTE: method for DataStoreBase - def get_latest_data(self, from_id: int): - raise NotImplementedError # TODO - def populate_data_store( data_store: DataStoreBase, 
demos_path: str, diff --git a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py index dbdf247e..0ce437ea 100644 --- a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py +++ b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py @@ -5,88 +5,99 @@ import copy from datetime import datetime as dt + class FractalSymmetryReplayBuffer(ReplayBuffer): def __init__( self, observation_space: gym.Space, action_space: gym.Space, capacity: int, + workspace_width: int, + x_obs_idx : np.ndarray, + y_obs_idx : np.ndarray, branch_method: str, split_method: str, - workspace_width: int, workspace_width_method: str, kwargs: dict ): # Initialize values - self.debug_time = False - self.current_branch_count=1 + self.debug_time = True + self.current_branch_count = 1 self.update_max_traj_length = False self.workspace_width = workspace_width + + # Set the idx value (changes depending on environment/wrapper) of the x and y observations and next_observations + self.x_obs_idx = x_obs_idx + self.y_obs_idx = y_obs_idx - # Set correct split function in self.split pointer. 
- method_check = "split_method" - match split_method: - case "time": - assert "max_depth" in kwargs.keys(), self._handle_bad_args_(method_check, split_method, "max_depth") - self.max_depth=kwargs["max_depth"] - - assert "max_traj_length" in kwargs.keys(), self._handle_bad_args_(method_check, split_method, "max_traj_length") - self.max_traj_length=kwargs["max_traj_length"] + # Set initial fractal config values + self.timestep = 0 + self.current_depth = 0 - assert "alpha" in kwargs.keys(), self._handle_bad_args_(method_check, split_method, "alpha") - self.alpha=kwargs["alpha"] - update_max_traj_length = True - self.split = self.time_split + self.split_method = split_method + self.branch_method = branch_method + self.workspace_width_method = workspace_width_method - case "constant": - self.split = self.constant_split + self._handle_methods_(kwargs) - case "disassociated": - assert "max_traj_length" in kwargs.keys(), self._handle_bad_args_(method_check, split_method, "max_traj_length") - self.max_traj_length=kwargs["max_traj_length"] - self.update_max_traj_length = True + ## Warn about unused kwargs + for k in kwargs.keys(): + print(f"\033[33mWARNING \033[0m argument \"{k}\" not used") + + # TODO + # Create more methods with tests + # + # Future considerations: + # when dealing with x and y for transforms, consider two versions: + # radial trees at different angles can have strong fractal symmetry built in but will be more difficult to evenly space lowest level transforms + # grid-based will inherently evenly space lowest transforms, but fractal symmetry will be restricted to 9 branching_factor (original plus 8) in order to conform - assert "alpha" in kwargs.keys(), self._handle_bad_args_(method_check, split_method, "alpha") - self.alpha=kwargs["alpha"] + # Init replay buffer class + super().__init__( + observation_space=observation_space, + action_space=action_space, + capacity=capacity, + ) - assert "num_depth_sectors" in kwargs.keys(), 
self._handle_bad_args_(method_check, split_method, "num_depth_sectors") - self.num_depth_sectors=kwargs["num_depth_sectors"] + self.generate_transform_deltas() - self.steps_per_depth = int(np.floor(self.max_traj_length/self.num_depth_sectors)) # does this need to be np.float32? + def _handle_method_arg_(self, value, method_type, method, kwargs): + assert value in kwargs.keys(), f"\033[31mERROR: \033[0m{value} must be defined for {method_type} \"{method}\"" + setattr(FractalSymmetryReplayBuffer, value, kwargs[value]) + del kwargs[value] - self.transition_counter = 0 + def _handle_methods_(self, kwargs): - self.split = self.disassociated_split + # Initialize split_method + match self.split_method: + case "time": + self._handle_method_arg_("max_depth", "split_method", self.split_method, kwargs) + self._handle_method_arg_("max_traj_length", "split_method", self.split_method, kwargs) + self._handle_method_arg_("alpha", "split_method", self.split_method, kwargs) + self.update_max_traj_length = True + self.split = self.time_split + case "constant": + self.split = self.constant_split + case "test": self.split = self.test_split case _: raise ValueError("incorrect value passed to split_method") - # Set correct branch function in self.branch pointer. 
- method_check = "branch_method" - match branch_method: + # Initialize branch_method + match self.branch_method: case "fractal": - assert "max_depth" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "max_depth") - self.max_depth=kwargs["max_depth"] - - assert "branching_factor" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "branching_factor") - self.branching_factor=kwargs["branching_factor"] - - assert "start_num" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "start_num") - self.start_num=kwargs["start_num"] + self._handle_method_arg_("max_depth", "branch_method", self.branch_method, kwargs) + self._handle_method_arg_("branching_factor", "branch_method", self.branch_method, kwargs) self.branch = self.fractal_branch case "contraction": - assert "start_num" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "start_num") - self.start_num=kwargs["start_num"] - - # Add branching_factor for contraction method - assert "branching_factor" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "branching_factor") - self.branching_factor = kwargs["branching_factor"] + self._handle_method_arg_("max_depth", "branch_method", self.branch_method, kwargs) + self._handle_method_arg_("branching_factor", "branch_method", self.branch_method, kwargs) self.branch = self.fractal_contraction @@ -95,28 +106,26 @@ def __init__( # self.branch = self.linear_branch case "disassociated": - assert "min_branch_count" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "min_branch_count)") - self.min_branch_count = kwargs["min_branch_count"] - - assert "max_branch_count" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "max_branch_count)") - self.max_branch_count = kwargs["max_branch_count"] + self._handle_method_arg_("min_branch_count", "branch_method", self.branch_method, kwargs) + self._handle_method_arg_("max_branch_count", "branch_method", 
self.branch_method, kwargs) if self.min_branch_count > self.max_branch_count: - raise ValueError(f"min_branch_count: {self.min_branch_count} is larger than max_branch_count: {self.max_branch_count}. Max should be larger than min.") - - assert "disassociated_type" in kwargs.keys(), self._handle_bad_args_(method_check, branch_method, "disassociated_type") - if kwargs["disassociated_type"] == "hourglass": - self.current_branch_count = self.max_branch_count - elif kwargs["disassociated_type"] == "octahedron": - self.current_branch_count = self.min_branch_count - self.disassociated_type = kwargs["disassociated_type"] + raise ValueError(f"min_branch_count ({self.min_branch_count}) is larger than max_branch_count ({self.max_branch_count})") + + match kwargs["disassociated_type"]: + case "hourglass": + self.starting_branch_count = self.max_branch_count + case "octahedron": + self.starting_branch_count = self.min_branch_count + case _: + raise ValueError(f"incorrect value passed to disassociated_type") + self.disassociated_type = kwargs["disassociated_type"] + del kwargs["disassociated_type"] self.branch = self.disassociated_branch case "constant": - assert "starting_branch_count" in kwargs.keys(), self._handle_bad_args_("branch_method", branch_method, "starting_branch_count") - self.current_branch_count = kwargs["starting_branch_count"] - del kwargs["starting_branch_count"] + self._handle_method_arg_("starting_branch_count", "branch_method", self.branch_method, kwargs) self.branch = self.constant_branch @@ -126,80 +135,51 @@ def __init__( case _: raise ValueError("incorrect value passed to branch_method") - # Set correct workspace_width function in self.workspace_width_method - match workspace_width_method: - - case "constant": - if workspace_width is None: - raise ValueError("workspace_width must be defined for constant workspace width method") - self.get_workspace_width = self.ww_constant - - case "decrease": - if workspace_width is None: - raise 
ValueError("workspace_width must be defined for constant workspace width method") - self.get_workspace_width = self.ww_decrease - - case "increase": - if workspace_width is None: - raise ValueError("workspace_width must be defined for constant workspace width method") - self.get_workspace_width = self.ww_increase + # Initialize workspace_width_method + # match self.workspace_width_method: - case _: - raise ValueError("incorrect value passed to workspace_width_method") + # case "constant": + # if self.workspace_width is None: + # raise ValueError("workspace_width must be defined for constant workspace width method") + # self.get_workspace_width = self.ww_constant + + # case "decrease": + # if self.workspace_width is None: + # raise ValueError("workspace_width must be defined for constant workspace width method") + # self.get_workspace_width = self.ww_decrease + + # case "increase": + # if self.workspace_width is None: + # raise ValueError("workspace_width must be defined for constant workspace width method") + # self.get_workspace_width = self.ww_increase - #--------------------------------------------------------------------------------------------------------------- - # Set the idx value (changes depending on environment/wrapper) of the x and y observations and next_observations - self.x_obs_idx = kwargs["x_obs_idx"] - self.y_obs_idx = kwargs["y_obs_idx"] - - # Set initial fractal config values - self.timestep = 0 - self.current_depth = 0 - self.branch_index = [workspace_width/2] # TODO: is this a correct initialization? 
- self.start_num = kwargs["start_num"] # Used for fractal contractions + # case _: + # raise ValueError("incorrect value passed to workspace_width_method") + + if self.starting_branch_count: + self.current_branch_count = self.starting_branch_count + delattr(FractalSymmetryReplayBuffer, "starting_branch_count") + + def generate_transform_deltas(self): + + obs_size = self.dataset_dict["observations"].shape[1] + total_branches = self.current_branch_count ** 2 - ## TODO: this is done in insert. Are we repeating work? - # Create branch index to control how to set (x,y) offsets of different branches - self.branch_index = np.empty(self.current_branch_count, dtype=np.float32) + self.transform_deltas = np.zeros(shape=(total_branches, obs_size), dtype=np.float32) - ## Set spacing for transformations (see more: https://www.notion.so/Fractal-Symmetry-Characterization-225cd3402f1a80948509ec86f0b6ee5e?source=copy_link#22bcd3402f1a8054b97ff0ab387923f7) - # Set a constant value useful in the computation a standard offset between branches - constant = self.workspace_width/(2 * self.current_branch_count) + idx = np.arange(total_branches) + x_deltas, y_deltas = np.divmod(idx, self.current_branch_count) - # Compute the translation offsets from left edge according to branch index: - for i in range(0, self.current_branch_count): - self.branch_index[i] = (2 * i + 1) * constant + x_deltas = (2 * x_deltas + 1) * self.workspace_width / (2 * self.current_branch_count) + y_deltas = (2 * y_deltas + 1) * self.workspace_width / (2 * self.current_branch_count) + x_deltas = np.repeat(x_deltas, self.x_obs_idx.size) + y_deltas = np.repeat(y_deltas, self.y_obs_idx.size) + x_deltas = np.reshape(x_deltas, (total_branches, self.x_obs_idx.size)) + y_deltas = np.reshape(y_deltas, (total_branches, self.y_obs_idx.size)) - ## Warn about unused kwargs - for k in kwargs.keys(): - print(f"\033[33mWARNING \033[0m argument \"{k}\" not used") - - # TODO - # Create more methods with tests - # - # Future 
considerations: - # when dealing with x and y for transforms, consider two versions: - # radial trees at different angles can have strong fractal symmetry built in but will be more difficult to evenly space lowest level transforms - # grid-based will inherently evenly space lowest transforms, but fractal symmetry will be restricted to 9 branching_factor (original plus 8) in order to conform - - # Init replay buffer class - super().__init__( - observation_space=observation_space, - action_space=action_space, - capacity=capacity, - ) - - # Missing arguments - def _handle_bad_args_(self, type: str, method: str, arg: str) : - return f"\033[31mERROR: \033[0m{arg} must be defined for {type} \"{method}\"" + self.transform_deltas[:, self.x_obs_idx] = x_deltas + self.transform_deltas[:, self.y_obs_idx] = y_deltas - - # Update (x,y) observations with translational transformation - def transform(self, data_dict: DatasetDict, transform: np.array): - # return data_dict with positional arguments += translation - data_dict["observations"] += transform - data_dict["next_observations"] += transform - #--- BRANCH METHODS --- # FOR TESTING ONLY def test_branch(self): @@ -243,12 +223,8 @@ def fractal_contraction(self): Returns: int: The computed number of branches for the current depth. 
''' - # return - b = self.branching_factor - d = self.current_depth - top_b = self.start_num - return int( top_b/( b**(d-1) )) + return self.branching_factor ** (self.max_depth - self.current_depth + 1) def constant_branch(self): ''' @@ -269,14 +245,14 @@ def disassociated_branch(self): self.num_depth_sectors specifies the number of sectors the rollout should be divided into for even splitting ''' if self.disassociated_type == "hourglass": - return int((self.max_branch_count - self.min_branch_count)/(self.num_depth_sectors/2) * np.abs(self.current_depth - (self.num_depth_sectors/2)) + self.min_branch_count) + return int((self.max_branch_count - self.min_branch_count)/(self.max_depth/2) * np.abs(self.current_depth - (self.max_depth/2)) + self.min_branch_count) elif self.disassociated_type == "octahedron": - return int((self.min_branch_count - self.max_branch_count)/(self.num_depth_sectors/2) * np.abs(self.current_depth - (self.num_depth_sectors/2)) + self.max_branch_count) + return int((self.min_branch_count - self.max_branch_count)/(self.max_depth/2) * np.abs(self.current_depth - (self.max_depth/2)) + self.max_branch_count) # REQUIRES TESTING - # def linear_branch(self): - # # return a new number of branches = branches_count + n - # return self.current_branch_count + self.branch_count_rate_of_change + def linear_branch(self): + # return a new number of branches = branches_count + n + return self.current_branch_count + self.branch_count_rate_of_change #---- SPLIT METHODS --- @@ -293,15 +269,6 @@ def time_split(self, data_dict: DatasetDict): def constant_split(self, data_dict: DatasetDict): return True - - def disassociated_split(self, data_dict: DatasetDict): - if self.transition_counter >= self.steps_per_depth: - self.current_depth += 1 - self.transition_counter = 0 - return True - else: - self.transition_counter += 1 - return False # Workspace Width Methods def ww_constant(self): @@ -311,13 +278,13 @@ def ww_increase(self): ''' Increase workspace width with depth 
by 5cm. Lower density. ''' - return self.workspace_width + self.current_depth*0.05 + return self.workspace_width + 0.05 def ww_decrease(self): ''' Decrease workspace width with depth by 5cm. Higher density. ''' - return self.workspace_width - self.current_depth*0.05 + return self.workspace_width - 0.05 #--------------------------------------------- # Now perform fractal transformations. @@ -332,54 +299,48 @@ def insert(self, data_dict_not: DatasetDict): # Update number of branches if needed if self.split(data_dict): - - # Get your current branch count temp = self.current_branch_count - - # Get new current branch count based on depth self.current_branch_count = self.branch() - - # When we enter into a new depth and new branch count update the transformations + # self.workspace_width = self.get_workspace_width() + # Update transform_deltas if needed if temp != self.current_branch_count: - - ## Set spacing for transformations (see more: https://www.notion.so/Fractal-Symmetry-Characterization-225cd3402f1a80948509ec86f0b6ee5e?source=copy_link#22bcd3402f1a8054b97ff0ab387923f7) - self.branch_index = np.empty(self.current_branch_count, dtype=np.float32) - constant = self.get_workspace_width()/(2 * self.current_branch_count) - - for i in range(0, self.current_branch_count): - self.branch_index[i] = (2 * i + 1) * constant + self.generate_transform_deltas() # Initialize to extreme x and y sector_2 = dt.now() if self.debug_time else None - x = -self.get_workspace_width()/2 - transform = np.zeros_like(data_dict["observations"]) - transform[self.x_obs_idx] = x - transform[self.y_obs_idx] = x - self.transform(data_dict, transform) + base_diff = -self.workspace_width/2 + data_dict["observations"][self.x_obs_idx] += base_diff + data_dict["observations"][self.y_obs_idx] += base_diff + data_dict["next_observations"][self.x_obs_idx] += base_diff + data_dict["next_observations"][self.y_obs_idx] += base_diff # Transform and insert transitions (multiprocessing in the future) sector_3 = 
dt.now() if self.debug_time else None - for x in range(0, self.current_branch_count): + num_transforms = self.current_branch_count ** 2 + obs_batch = np.tile(data_dict["observations"], (num_transforms, 1)) + next_obs_batch = np.tile(data_dict["next_observations"], (num_transforms, 1)) - # Set offset in x-direction - transform[self.x_obs_idx] = self.branch_index[x] + obs_batch += self.transform_deltas + next_obs_batch += self.transform_deltas + + data_dict["observations"] = obs_batch + data_dict["next_observations"] = next_obs_batch + data_dict["actions"] = np.tile(data_dict["actions"], (num_transforms, 1)) + data_dict["rewards"] = np.tile(data_dict["rewards"], num_transforms) + data_dict["masks"] = np.tile(data_dict["masks"], num_transforms) + data_dict["dones"] = np.tile(data_dict["dones"], num_transforms) + + super().insert(data_dict, batch_size=num_transforms) - for y in range(0, self.current_branch_count): - new_data_dict = copy.deepcopy(data_dict) - transform[self.y_obs_idx] = self.branch_index[y] - self.transform(new_data_dict, transform) - super().insert(new_data_dict) - # Reset current_depth, timestep, and max_traj_length sector_4 = dt.now() if self.debug_time else None self.timestep += 1 - if data_dict["dones"]: + if data_dict["dones"][0]: self.current_depth = 0 if self.update_max_traj_length: self.max_traj_length = int(self.timestep * self.alpha + self.max_traj_length * (1 - self.alpha)) - self.steps_per_depth = int(np.floor(self.max_traj_length/self.num_depth_sectors)) # does this need to be np.float32? 
self.timestep = 0 diff --git a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer_parallel.py b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer_parallel.py deleted file mode 100644 index 6ecca136..00000000 --- a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer_parallel.py +++ /dev/null @@ -1,266 +0,0 @@ -# Generic support -import copy -from datetime import datetime as dt - -# RL Envs -import gym - -# Data -from serl_launcher.data.dataset import DatasetDict -from serl_launcher.data.fractal_symmetry_replay_buffer import FractalSymmetryReplayBuffer - -# Optimization of matrix operations -import numpy as np -import jax -from jax import jit -import jax.numpy as jnp -from functools import partial - -class FractalSymmetryReplayBufferParallel(FractalSymmetryReplayBuffer): - ''' - Override .insert(), add batch_insert(), override __init__ if needed, etc. - ''' - def __init__( - self, - observation_space: gym.Space, - action_space: gym.Space, - capacity: int, - branch_method: str, - split_method: str, - workspace_width: int, - workspace_width_method: str, - kwargs: dict - ): - - # parallel-specific setup... - super().__init__( - observation_space=observation_space, - action_space=action_space, - capacity=capacity, - branch_method=branch_method, - split_method=split_method, - workspace_width=workspace_width, - workspace_width_method=workspace_width_method, - kwargs=kwargs, - ) - - # Set the total number of transformations to the number of transformations squared for the grid method. May need to change for radial method later. - self.num_transforms = self.current_branch_count ** 2 - - #--------------------------------------------- - @partial(jax.jit, static_argnums=(0, 2)) # states that self and num_transforms are not array like and do not need to be traced and need to be marked as static. 
- def compute_transformation(self, data_dict, num_transforms): - ''' - Add fractal transformations (translations) to matrix of observations in jitted form. - - data_dict of transitions comes in, a matrix of size (num_transforms,num_transforms) comes out. - - compute_deltas will create a matrix of (x,y) transformations and add them to the expanded observations. - - results for observations and next_observations saved in trans_o_m and trans_no_m. - - Only this parent function needs to be jitted. Helper functions below do not need since they are never called independently. - - Params: - - data_dict (DatasetDict) - dictionary of translations - - num_transforms (int) - number of transformations - - Returns: - - trans_o_m (jnp.ndarray) - this is a (num_transform,num_transform) matrix - - trans_no_m (jnp.ndarray) - this is a (num_transform,num_transform) matrix - - ''' - - # Generate observation matrices - o_m, no_m = self.generate_obs_mats(data_dict, num_transforms) - - # Transform o_m and no_m using fractal equation (use lax.scan inside for incremental change) - trans_o_m = self.compute_deltas( o_m ) - trans_no_m = self.compute_deltas( no_m) - - return trans_o_m,trans_no_m - - def compute_deltas(self,mat,type='grid'): # add type as jnp array - ''' - Given a type of fractal framework perform the operation - ''' - - if type == 'grid': - - # Iterate through matrix cols - idx = jnp.arange(self.num_transforms) - - # Convert idx into arrays (i,j) that tell us the x,y order for deltas - i,j = jnp.divmod( idx, self.current_branch_count ) - - # 1. Get the boundary of the workspace environment - edge = self.get_workspace_width()/2. - - # 2. Get array of translation deltas for x,y offsets: ( (2i+1)/2b ) * ww - constant_parallel = self.get_workspace_width() / (2. * self.current_branch_count) - x_offset = ( (2 * i) + 1) * constant_parallel - y_offset = ( (2 * j) + 1) * constant_parallel - - # 3. 
Generate the matrix grid of transformations by element-wise sum of deltas on the matrix of env. observations. Use self.x_obs_idx & y_obs_idx. - # Perform separate updates - trans_mat = ( - mat - .at[self.x_obs_idx].add(x_offset) # add x_offset to each of the two x-rows - .at[self.y_obs_idx].add(y_offset) # add y_offset to each of the two y-rows - ) - else: - raise TypeError('transform_matrix: no correct fractal framework type was given...') - - return trans_mat - - def generate_obs_mats(self, data_dict: DatasetDict, num_transforms: int, type='grid'): - ''' - Generate observations matrices, whose cols repeat according to the correct number of transformations - - Args: - - data_dict (DatasetDict): - - num_transforms (int): computes the total number of transformations for the grid stategy - - type (str): grid or radial fractal framework - - Return: - - o_m: observation matrix - - no_m: next observation matrix - ''' - - # Convert observations to a vector - o_col = jnp.array(data_dict["observations"], dtype=jnp.float32) - no_col = jnp.array(data_dict["next_observations"], dtype=jnp.float32) - - # Expand observations column vector to a matrix using tile - o_m = jnp.tile(o_col, (1,num_transforms)) - no_m = jnp.tile(no_col, (1,num_transforms)) - - return o_m, no_m - - def insert_mat_replay_buffer(self, del_o, del_no, new_data_dict: DatasetDict): - # Convert back to dict by replacing data_dict's observations with new ones by the columns and insert into replay buffer one-by-one. - # TODO: Is there a batch_insert? - - # Is this the fast way to create these dicts from matrix cols? 
- for col in range( del_o.shape[1] ): - new_data_dict['observations'] = np.array( del_o[:, col],type=np.float32) - new_data_dict['next_observations'] = np.array( del_no[:, col],type=np.float32) - - # Insert each col's new info - super().insert(new_data_dict) - - def insert(self, data_dict: DatasetDict): - ''' - - Note: - If observations for all time-steps are available, then a batch_transform - ''' - - # Time - sector_1 = dt.now() if self.debug_time else None - - # Update number of branches if needed - if self.split(data_dict): - - # Get the current branch count - temp = self.current_branch_count - - # Get new current branch count based on depth - self.current_branch_count = self.branch() - - # New depth -> new branch count. Hence update the transformations: - if temp != self.current_branch_count: - - # Get the number of transformations - self.num_transforms = self.current_branch_count**2 - - # Compute transformation deltas for observations - sector_2 = dt.now() if self.debug_time else None - trans_o_m, trans_no_m = self.compute_transformation(data_dict, self.num_transforms) - - # Insert to replay buffer - sector_3 = dt.now() if self.debug_time else None - batch_insert(trans_o_m, trans_no_m, data_dict) - - #else?? 
- # Compute transformation deltas for observations - sector_2 = dt.now() if self.debug_time else None - trans_o_m, trans_no_m = self.compute_transformation(data_dict, self.num_transforms) - - # Insert as batch into replay buffer - sector_3 = dt.now() if self.debug_time else None - batch_insert(trans_o_m, trans_no_m, data_dict) - - # Reset current_depth, timestep, and max_traj_length - sector_4 = dt.now() if self.debug_time else None - - self.timestep += 1 - if data_dict["dones"]: - self.current_depth = 0 - if self.update_max_traj_length: - self.max_traj_length = int(self.timestep * self.alpha + self.max_traj_length * (1 - self.alpha)) - self.timestep = 0 - - if self.debug_time: - finish = dt.now() - print(f"Splits: {(sector_2 - sector_1).total_seconds():.5f} : {(sector_3 - sector_2).total_seconds():.5f} : {(sector_4 - sector_3).total_seconds():.5f} : {(finish - sector_4).total_seconds():.5f}\nLaptime: {(finish - sector_1).total_seconds():.5f}") - - def batch_insert(self, trans_o_m: jnp.ndarray, trans_no_m: jnp.ndarray, data_dict: DatasetDict) -> None: - """ - Inserts N data-augmented transitions into self.dataset_dict all at once. - - Args: - data_dict: original single transition dict with keys - - "observations" shape (obs_dim,) - - "next_observations" shape (obs_dim,) - - "actions" shape (act_dim,) - - "rewards" scalar or shape () - - "dones" scalar bool - - "masks" scalar or shape () - trans_o_m: jnp.ndarray (obs_dim, N) — each column is a new obs - trans_no_m: jnp.ndarray (obs_dim, N) — each column is new next_obs - - Returns: - None. Mutates self.dataset_dict and advances the insert pointer. 
- """ - - # Extract number of transformations directly from the data - N = trans_o_m.shape[1] - - # Expand observations and next observations - # Transpose matrix of observations and next_observations as a numpy array: trans_o_m: (obs_dim, N) → we want (N, obs_dim) in NumPy - obs_batch = np.array(trans_o_m.T) # shape (N, obs_dim) - next_obs_batch= np.array(trans_no_m.T) # shape (N, obs_dim) - - # Expand actions: - actions = np.array(data_dict["actions"], dtype=np.float32) - - # If shape (act_dim,), broadcast to (N, act_dim) - actions_batch = np.broadcast_to(actions, (N,) + actions.shape) - - # Rewards/dones/masks batches - rewards = float(data_dict["rewards"]) - dones = bool(data_dict["dones"]) - masks = bool(data_dict["masks"]) # check type - - rewards_batch = np.full((N,), rewards, dtype=np.float32) - dones_batch = np.full((N,), dones, dtype=bool) - masks_batch = np.full((N,), masks, dtype=np.float32) - - batch = { - "observations": obs_batch, - "next_observations": next_obs_batch, - "actions": actions_batch, - "rewards": rewards_batch, - "dones": dones_batch, - "masks": masks_batch, - } - - # Create appropriate indeces for circular buffer. 
Consider overflow: - idxs = (self._insert_index + np.arange(N)) % self._capacity - - # One‐shot write into each field - for key, arr in batch.items(): # arr has shape (N, *field_shape) - self.dataset_dict[key][idxs] = arr - - # Update pointers - self._insert_index = int(idxs[-1]) + 1 - self._size = min(self._size + N, self._capacity) - \ No newline at end of file diff --git a/serl_launcher/serl_launcher/data/fsrb_test.py b/serl_launcher/serl_launcher/data/fsrb_test.py index 9429a3b2..e51a5c0f 100644 --- a/serl_launcher/serl_launcher/data/fsrb_test.py +++ b/serl_launcher/serl_launcher/data/fsrb_test.py @@ -9,17 +9,15 @@ FLAGS = flags.FLAGS -flags.DEFINE_integer("capacity", 10000000, "Replay buffer capacity.") -flags.DEFINE_string("branch_method", "contraction", "Method for determining the number of transforms per dimension (x,y)") -flags.DEFINE_string("split_method", "time", "Method for determining whether to change the number of transforms per dimension (x,y)") +flags.DEFINE_integer("capacity", 1000000, "Replay buffer capacity.") +flags.DEFINE_string("branch_method", "constant", "Method for determining the number of transforms per dimension (x,y)") +flags.DEFINE_string("split_method", "constant", "Method for determining whether to change the number of transforms per dimension (x,y)") flags.DEFINE_float("workspace_width", 0.5, "workspace width in meters") flags.DEFINE_integer("max_depth", 4, "Maximum level of depth") # For fractal_branch only -flags.DEFINE_integer("max_steps",20000,"Maximum steps") +flags.DEFINE_integer("max_steps",100,"Maximum steps") flags.DEFINE_integer("branching_factor", 3, "Rate of change of number of transforms per dimension (x,y)") # For fractal_branch only -flags.DEFINE_integer("starting_branch_count", 1, "Initial number of transforms per dimension (x,y)") # For constant_branch only +flags.DEFINE_integer("starting_branch_count", 1000, "Initial number of transforms per dimension (x,y)") # For constant_branch only 
flags.DEFINE_integer("alpha",1,"alpha value") -# Contraction -flags.DEFINE_integer("start_num",81, "Initila number of branch on the first depth") # For Fractal Cntraction # Density Workspace width flags.DEFINE_string("workspace_width_method",'increase', 'Controls workspace width dimensions configurations') @@ -44,14 +42,16 @@ def main(_): max_depth=FLAGS.max_depth, max_traj_length = 100, branching_factor=FLAGS.branching_factor, - start_num = FLAGS.start_num, alpha = FLAGS.alpha, + starting_branch_count = FLAGS.starting_branch_count, workspace_width_method=FLAGS.workspace_width_method ) observation, info = env.reset() + observation = np.zeros_like(observation) action = env.action_space.sample() next_observation, reward, terminated, truncated, info = env.step(action) + next_observation = np.ones_like(observation) data_dict = dict( observations=observation, next_observations=next_observation, @@ -63,16 +63,7 @@ def main(_): del env, observation, next_observation, action, reward, truncated, terminated, info, y_obs_idx, x_obs_idx, _ - # transform() test - - expected = np.copy(data_dict["observations"]) * 2 - replay_buffer.transform(data_dict, np.copy(data_dict["observations"])) - result = data_dict["observations"] - assert np.array_equal(result, expected), f"\033[31mTEST FAILED\033[0m transform() test failed (expected {expected} but got {result})" - - del result, expected - - print("\n\033[32mTEST PASSED \033[0m transform() test passed") + replay_buffer.insert(data_dict) # branch() tests diff --git a/serl_launcher/serl_launcher/data/replay_buffer.py b/serl_launcher/serl_launcher/data/replay_buffer.py index 6dd29c73..0a1ab414 100644 --- a/serl_launcher/serl_launcher/data/replay_buffer.py +++ b/serl_launcher/serl_launcher/data/replay_buffer.py @@ -22,17 +22,24 @@ def _init_replay_dict( def _insert_recursively( - dataset_dict: DatasetDict, data_dict: DatasetDict, insert_index: int + dataset_dict: DatasetDict, data_dict: DatasetDict, insert_index: int, capacity: int, 
batch_size: int = None, ): if isinstance(dataset_dict, np.ndarray): - dataset_dict[insert_index] = data_dict + if batch_size: + if insert_index + batch_size > capacity: + dataset_dict[insert_index:capacity] = data_dict[0:(capacity - insert_index)] + dataset_dict[0:(insert_index + batch_size - capacity)] = data_dict[(capacity - insert_index):batch_size] + else: + dataset_dict[insert_index:(insert_index + batch_size)] = data_dict + else: + dataset_dict[insert_index] = data_dict elif isinstance(dataset_dict, dict): assert dataset_dict.keys() == data_dict.keys(), ( dataset_dict.keys(), data_dict.keys(), ) for k in dataset_dict.keys(): - _insert_recursively(dataset_dict[k], data_dict[k], insert_index) + _insert_recursively(dataset_dict[k], data_dict[k], insert_index, capacity, batch_size) else: raise TypeError() @@ -68,11 +75,15 @@ def __init__( def __len__(self) -> int: return self._size - def insert(self, data_dict: DatasetDict): - _insert_recursively(self.dataset_dict, data_dict, self._insert_index) + def insert(self, data_dict: DatasetDict, batch_size : int = None): + _insert_recursively(self.dataset_dict, data_dict, self._insert_index, self._capacity, batch_size) - self._insert_index = (self._insert_index + 1) % self._capacity - self._size = min(self._size + 1, self._capacity) + if batch_size: + self._insert_index = (self._insert_index + batch_size) % self._capacity + self._size = min(self._size + batch_size, self._capacity) + else: + self._insert_index = (self._insert_index + 1) % self._capacity + self._size = min(self._size + 1, self._capacity) def get_iterator(self, queue_size: int = 2, sample_args: dict = {}, device=None): # See https://flax.readthedocs.io/en/latest/_modules/flax/jax_utils.html#prefetch_to_device diff --git a/serl_launcher/serl_launcher/utils/launcher.py b/serl_launcher/serl_launcher/utils/launcher.py index 03227c66..ce2d4539 100644 --- a/serl_launcher/serl_launcher/utils/launcher.py +++ b/serl_launcher/serl_launcher/utils/launcher.py @@ -19,7 
+19,6 @@ MemoryEfficientReplayBufferDataStore, ReplayBufferDataStore, FractalSymmetryReplayBufferDataStore, - FractalSymmetryReplayBufferParallelDataStore, ) ############################################################################## @@ -212,6 +211,8 @@ def make_replay_buffer( split_method : str = None, # used only type=="fractal_symmetry_replay_buffer" workspace_width : float = None, # used only type=="fractal_symmetry_replay_buffer" workspace_width_method : str = None, # used only type=="fractal_symmetry_replay_buffer" + x_obs_idx = None, + y_obs_idx = None, **kwargs: dict # used only type=="fractal_symmetry_replay_buffer" ): """ @@ -270,18 +271,8 @@ def make_replay_buffer( split_method=split_method, workspace_width=workspace_width, workspace_width_method=workspace_width_method, - rlds_logger=rlds_logger, - kwargs=kwargs, - ) - elif type == "fractal_symmetry_replay_buffer_parallel": - replay_buffer = FractalSymmetryReplayBufferParallelDataStore( - env.observation_space, - env.action_space, - capacity=capacity, - branch_method=branch_method, - split_method=split_method, - workspace_width=workspace_width, - workspace_width_method=workspace_width_method, + x_obs_idx=x_obs_idx, + y_obs_idx=y_obs_idx, rlds_logger=rlds_logger, kwargs=kwargs, ) From ded80ef868b837a9b9229630ffc71fb4ec67f3b5 Mon Sep 17 00:00:00 2001 From: ryanvanderstelt Date: Tue, 5 Aug 2025 18:54:05 -0500 Subject: [PATCH 101/141] Overhaul of automated testing. 
Removed multiple unneeded files and flags/features for FractalSymmetryReplayBuffer --- .../#async_sac_state_sim.py# | 401 ------------------ .../async_sac_state_sim.py | 71 ++-- .../async_sac_state_sim/automate_funcs.py | 176 -------- .../async_sac_state_sim/automate_learner.sh | 6 - .../async_sac_state_sim/automated_tests.sh | 139 ++++++ ...ate_actor.sh => automated_tests_helper.sh} | 2 +- examples/async_sac_state_sim/automation.sh | 8 - examples/async_sac_state_sim/run_actor.sh | 21 +- examples/async_sac_state_sim/run_learner.sh | 13 +- examples/async_sac_state_sim/run_tests.py | 85 ---- examples/async_sac_state_sim/tmux_launch.sh | 2 +- .../async_sac_state_sim/tmux_launch_tests.sh | 18 + serl_launcher/serl_launcher/common/wandb.py | 1 + .../data/fractal_symmetry_replay_buffer.py | 126 ++---- serl_launcher/serl_launcher/utils/launcher.py | 2 + 15 files changed, 239 insertions(+), 832 deletions(-) delete mode 100644 examples/async_sac_state_sim/#async_sac_state_sim.py# delete mode 100644 examples/async_sac_state_sim/automate_funcs.py delete mode 100644 examples/async_sac_state_sim/automate_learner.sh create mode 100644 examples/async_sac_state_sim/automated_tests.sh rename examples/async_sac_state_sim/{automate_actor.sh => automated_tests_helper.sh} (72%) delete mode 100644 examples/async_sac_state_sim/automation.sh delete mode 100644 examples/async_sac_state_sim/run_tests.py create mode 100644 examples/async_sac_state_sim/tmux_launch_tests.sh diff --git a/examples/async_sac_state_sim/#async_sac_state_sim.py# b/examples/async_sac_state_sim/#async_sac_state_sim.py# deleted file mode 100644 index 50ff52f0..00000000 --- a/examples/async_sac_state_sim/#async_sac_state_sim.py# +++ /dev/null @@ -1,401 +0,0 @@ -#!/usr/bin/env python3 - -import time -from functools import partial - -import gym -import jax -import jax.numpy as jnp -import numpy as np -import tqdm -from absl import app, flags -from flax.training import checkpoints - -from agentlace.data.data_store import 
QueuedDataStore -from agentlace.trainer import TrainerClient, TrainerServer -from serl_launcher.utils.launcher import ( - make_sac_agent, - make_trainer_config, - make_wandb_logger, - make_replay_buffer, -) - -from gym.wrappers.record_episode_statistics import RecordEpisodeStatistics -from serl_launcher.agents.continuous.sac import SACAgent -from serl_launcher.common.evaluation import evaluate -from serl_launcher.utils.timer_utils import Timer - -import franka_sim - -from demos.demoHandling import DemoHandling - -FLAGS = flags.FLAGS - -flags.DEFINE_string("env", "HalfCheetah-v4", "Name of environment.") -flags.DEFINE_string("agent", "sac", "Name of agent.") -flags.DEFINE_string("exp_name", None, "Name of the experiment for wandb logging.") -flags.DEFINE_integer("max_traj_length", 100, "Maximum length of trajectory.") -flags.DEFINE_integer("seed", 42, "Random seed.") -flags.DEFINE_bool("save_model", False, "Whether to save model.") -flags.DEFINE_integer("batch_size", 256, "Batch size.") -flags.DEFINE_integer("critic_actor_ratio", 8, "critic to actor update ratio.") - -flags.DEFINE_integer("max_steps", 1000000, "Maximum number of training steps.") -flags.DEFINE_integer("replay_buffer_capacity", 1000000, "Replay buffer capacity.") - -flags.DEFINE_integer("random_steps", 300, "Sample random actions for this many steps.") -flags.DEFINE_integer("training_starts", 300, "Training starts after this step.") -flags.DEFINE_integer("steps_per_update", 30, "Number of steps per update the server.") - -flags.DEFINE_integer("log_period", 10, "Logging period.") -flags.DEFINE_integer("eval_period", 2000, "Evaluation period.") -flags.DEFINE_integer("eval_n_trajs", 5, "Number of trajectories for evaluation.") - -# flag to indicate if this is a learner or a actor -flags.DEFINE_boolean("learner", False, "Is this a learner or a trainer.") -flags.DEFINE_boolean("actor", False, "Is this a learner or a trainer.") -flags.DEFINE_boolean("render", False, "Render the environment.") 
-flags.DEFINE_string("ip", "localhost", "IP address of the learner.") -flags.DEFINE_integer("checkpoint_period", 0, "Period to save checkpoints.") -flags.DEFINE_string("checkpoint_path", None, "Path to save checkpoints.") - -# flags for replay buffer -flags.DEFINE_string("replay_buffer_type", "replay_buffer", "Which type of replay buffer to use") -flags.DEFINE_string("branch_method", None, "Method for how many branches to generate") -flags.DEFINE_string("split_method", None, "Method for when to change number of branches generated") # Remember to default None -flags.DEFINE_float("workspace_width", 0.5, "Workspace width in centimeters") # Remember to default None -flags.DEFINE_integer("max_depth",None,"max_depth") -flags.DEFINE_integer("depth", None, "Total layers of depth") -flags.DEFINE_integer("dendrites", None, "Dendrites for fractal branching") -flags.DEFINE_integer("timesplit_freq", None, "Frequency of splits according to time") -flags.DEFINE_integer("branch_count_rate_of_change", None, "Rate of change for linear branching") -flags.DEFINE_integer("starting_branch_count", 1, "Initial number of branches") -flags.DEFINE_integer("branching_factor", None, "Rate of change of number of transforms per dimension (x,y)") # For fractal_branch and fractal_contraction -flags.DEFINE_integer("start_num",None, "Initila number of branch on the first depth") # For Fractal Cntraction -flags.DEFINE_integer("alpha",None,"alpha value") - - -flags.DEFINE_boolean( - "debug", False, "Debug mode." 
-) # debug mode will disable wandb logging - -flags.DEFINE_string("log_rlds_path", None, "Path to save RLDS logs.") -flags.DEFINE_string("preload_rlds_path", None, "Path to preload RLDS data.") - - -# Load demonstation data -flags.DEFINE_boolean("load_demos", False, "Whether to load demo dataset.") -flags.DEFINE_string("demo_dir", "/data/data/serl/demos", "Path to demo dataset.") -flags.DEFINE_string("file_name", "data_franka_reach_random_20.npz", "Name of the demo file to load.") - -def print_green(x): - return print("\033[92m {}\033[00m".format(x)) - - -############################################################################## - - -def actor(agent: SACAgent, data_store, env, sampling_rng, demos_handler=None): - """ - This is the actor loop, which runs when "--actor" is set to True. - """ - client = TrainerClient( - "actor_env", - FLAGS.ip, - make_trainer_config(), - data_store, - wait_for_server=True, - ) - - # Function to update the agent with new params - def update_params(params): - nonlocal agent - agent = agent.replace(state=agent.state.replace(params=params)) - - client.recv_network_callback(update_params) - - eval_env = gym.make(FLAGS.env) - #if FLAGS.env == "PandaPickCube-v0": - eval_env = gym.wrappers.FlattenObservation(eval_env) ## Note!! - eval_env = RecordEpisodeStatistics(eval_env) - - obs, _ = env.reset() - done = False - - # training loop - timer = Timer() - running_return = 0.0 - - # Load demos: handler.run will insert all transition demo data into the data store. 
- if FLAGS.load_demos: - with timer.context("sample and step into env with loaded demos"): - - # Insert complete demonstration into the data store via demo Handling class - print(f"Inserting {demos_handler.data['transition_ctr']} transitions into the data store.") - demos_handler.insert_data_to_buffer(data_store) - FLAGS.random_steps = 0 # Set random steps to 0 since we have demo data - # For subsequent steps, sample actions from the agent - for step in tqdm.tqdm(range(FLAGS.max_steps), dynamic_ncols=True): - timer.tick("total") - - with timer.context("sample_actions"): - if step < FLAGS.random_steps: - actions = env.action_space.sample() - else: - # When interacting with agent, put obs into jax device and convert output to np.ndarrar - sampling_rng, key = jax.random.split(sampling_rng) - actions = agent.sample_actions( - observations=jax.device_put(obs), - seed=key, - deterministic=False, - ) - actions = np.asarray(jax.device_get(actions)) - - # Step environment - with timer.context("step_env"): - - next_obs, reward, done, truncated, info = env.step(actions) - next_obs = np.asarray(next_obs, dtype=np.float32) - reward = np.asarray(reward, dtype=np.float32) - - running_return += reward - - data_store.insert( - dict( - observations=obs, - actions=actions, - next_observations=next_obs, - rewards=reward, - masks=1.0 - done, - dones=done or truncated, - ) - ) - - obs = next_obs - if done or truncated: - running_return = 0.0 - obs, _ = env.reset() - - if FLAGS.render: - env.render() - - if step % FLAGS.steps_per_update == 0: - client.update() - - if step % FLAGS.eval_period == 0: - with timer.context("eval"): - evaluate_info = evaluate( - policy_fn=partial(agent.sample_actions, argmax=True), - env=eval_env, - num_episodes=FLAGS.eval_n_trajs, - ) - stats = {"eval": evaluate_info} - client.request("send-stats", stats) - - timer.tock("total") - - if step % FLAGS.log_period == 0: - stats = {"timer": timer.get_average_times()} - client.request("send-stats", stats) - - 
-############################################################################## - - -def learner(rng, agent: SACAgent, replay_buffer, replay_iterator): - """ - The learner loop, which runs when "--learner" is set to True. - """ - # set up wandb and logging - wandb_logger = make_wandb_logger( - project=FLAGS.exp_name, - description=FLAGS.exp_name or FLAGS.env, - debug=FLAGS.debug, - ) - - # To track the step in the training loop - update_steps = 0 - def stats_callback(type: str, payload: dict) -> dict: - """Callback for when server receives stats request.""" - assert type == "send-stats", f"Invalid request type: {type}" - if wandb_logger is not None: - wandb_logger.log(payload, step=update_steps) - return {} # not expecting a response - - # Create server - server = TrainerServer(make_trainer_config(), request_callback=stats_callback) - server.register_data_store("actor_env", replay_buffer) - server.start(threaded=True) - - # Loop to wait until replay_buffer is filled - pbar = tqdm.tqdm( - total=FLAGS.training_starts, - initial=len(replay_buffer), - desc="Filling up replay buffer", - position=0, - leave=True, - ) - while len(replay_buffer) < FLAGS.training_starts: - pbar.update(len(replay_buffer) - pbar.n) # Update progress bar - time.sleep(1) - pbar.update(len(replay_buffer) - pbar.n) # Update progress bar - pbar.close() - - # send the initial network to the actor - server.publish_network(agent.state.params) - print_green("sent initial network to actor") - - # wait till the replay buffer is filled with enough data - timer = Timer() - - # show replay buffer progress bar during training - pbar = tqdm.tqdm( - total=FLAGS.replay_buffer_capacity, - initial=len(replay_buffer), - desc="replay buffer", - ) - - for step in tqdm.tqdm(range(FLAGS.max_steps), dynamic_ncols=True, desc="learner"): - # Train the networks - with timer.context("sample_replay_buffer"): - batch = next(replay_iterator) - - with timer.context("train"): - agent, update_info = agent.update_high_utd(batch, 
utd_ratio=FLAGS.critic_actor_ratio) - agent = jax.block_until_ready(agent) - - # publish the updated network - server.publish_network(agent.state.params) - - if update_steps % FLAGS.log_period == 0 and wandb_logger: - wandb_logger.log(update_info, step=update_steps) - wandb_logger.log({"timer": timer.get_average_times()}, step=update_steps) - - if FLAGS.checkpoint_period and update_steps % FLAGS.checkpoint_period == 0: - assert FLAGS.checkpoint_path is not None - checkpoints.save_checkpoint( - FLAGS.checkpoint_path, agent.state, step=update_steps, keep=20 - ) - - pbar.update(len(replay_buffer) - pbar.n) # update replay buffer bar - update_steps += 1 - - -############################################################################## - - -def main(_): - devices = jax.local_devices() - num_devices = len(devices) - sharding = jax.sharding.PositionalSharding(devices) - assert FLAGS.batch_size % num_devices == 0 - - # seed - rng = jax.random.PRNGKey(FLAGS.seed) - - # create env and load dataset - if FLAGS.render: - env = gym.make(FLAGS.env, render_mode="human") - else: - env = gym.make(FLAGS.env) - - #if FLAGS.env == "PandaPickCube-v0": - env = gym.wrappers.FlattenObservation(env) - - rng, sampling_rng = jax.random.split(rng) - agent: SACAgent = make_sac_agent( - seed=FLAGS.seed, - sample_obs=env.observation_space.sample(), - sample_action=env.action_space.sample(), - ) - - # replicate agent across devices - # need the jnp.array to avoid a bug where device_put doesn't recognize primitives - agent: SACAgent = jax.device_put( - jax.tree.map(jnp.array, agent), sharding.replicate() - ) - - # Demo Data - if FLAGS.load_demos: - print_green("Setting demo parameters") - # Create a handler for the demo data - demos_handler = DemoHandling( - demo_dir=FLAGS.demo_dir, - file_name=FLAGS.file_name, - ) - - # 1. 
Modify actor data_store size - # Extract number of demo transitions - demo_transitions = demos_handler.get_num_transitions() - - if demo_transitions > 2000: - qds_size = demo_transitions + 1000 # Increment the queue size on the actor - else: - qds_size = 2000 # the original queue size on the actor - - # 2. Modify training starts (since we have good data) - FLAGS.training_starts = 1 - - else: - demos_handler = None - qds_size = 2000 # the original queue size on the actor - - - if FLAGS.learner: - sampling_rng = jax.device_put(sampling_rng, device=sharding.replicate()) - replay_buffer = make_replay_buffer( - env, - capacity=FLAGS.replay_buffer_capacity, - rlds_logger_path=FLAGS.log_rlds_path, - type=FLAGS.replay_buffer_type, - branch_method=FLAGS.branch_method, - split_method=FLAGS.split_method, - branching_factor=FLAGS.branching_factor, - starting_branch_count=FLAGS.starting_branch_count, - workspace_width=FLAGS.workspace_width, - max_traj_length=FLAGS.max_traj_length, - x_obs_idx=np.array([0,4]), - y_obs_idx=np.array([1,5]), - preload_rlds_path=FLAGS.preload_rlds_path, - start_num = FLAGS.start_num, - max_depth = FLAGS.max_depth, - alpha = FLAGS.alpha, - ) - replay_iterator = replay_buffer.get_iterator( - sample_args={ - "batch_size": FLAGS.batch_size * FLAGS.critic_actor_ratio, - }, - device=sharding.replicate(), - ) - # learner loop - print_green("starting learner loop") - learner( - sampling_rng, - agent, - replay_buffer, - replay_iterator=replay_iterator, - ) - - elif FLAGS.actor: - sampling_rng = jax.device_put(sampling_rng, sharding.replicate()) - - if FLAGS.load_demos: - print_green("loading demo data") - - # Create a data store for the actor - data_store = QueuedDataStore(qds_size) # the queue size on the actor - else: - print_green("no demo data, using empty data store") - # Create a data store for the actor - data_store = QueuedDataStore(2000) # the queue size on the actor - - # actor loop - print_green("starting actor loop") - actor(agent, data_store, 
env, sampling_rng, demos_handler) - - else: - raise NotImplementedError("Must be either a learner or an actor") - - -if __name__ == "__main__": - app.run(main) diff --git a/examples/async_sac_state_sim/async_sac_state_sim.py b/examples/async_sac_state_sim/async_sac_state_sim.py index cb76d9c6..116a3e1c 100644 --- a/examples/async_sac_state_sim/async_sac_state_sim.py +++ b/examples/async_sac_state_sim/async_sac_state_sim.py @@ -34,6 +34,7 @@ flags.DEFINE_string("env", "HalfCheetah-v4", "Name of environment.") flags.DEFINE_string("agent", "sac", "Name of agent.") flags.DEFINE_string("exp_name", None, "Name of the experiment for wandb logging.") +flags.DEFINE_string("run_name", None, "Name of run for wandb logging") flags.DEFINE_integer("max_traj_length", 100, "Maximum length of trajectory.") flags.DEFINE_integer("seed", 42, "Random seed.") flags.DEFINE_bool("save_model", False, "Whether to save model.") @@ -60,32 +61,19 @@ flags.DEFINE_string("checkpoint_path", None, "Path to save checkpoints.") # flags for replay buffer -flags.DEFINE_string("replay_buffer_type", "replay_buffer", "Which type of replay buffer to use") +flags.DEFINE_string("replay_buffer_type", "replay_buffer", "Which replay buffer to use") flags.DEFINE_string("branch_method", None, "Method for how many branches to generate") -flags.DEFINE_string("split_method", None, "Method for when to change number of branches generated") # Remember to default None -flags.DEFINE_float("workspace_width", 0.5, "Workspace width in centimeters") # Remember to default None -flags.DEFINE_integer("max_depth",None,"max_depth") -flags.DEFINE_integer("depth", None, "Total layers of depth") -flags.DEFINE_integer("dendrites", None, "Dendrites for fractal branching") -flags.DEFINE_integer("timesplit_freq", None, "Frequency of splits according to time") -flags.DEFINE_integer("branch_count_rate_of_change", None, "Rate of change for linear branching") -flags.DEFINE_integer("starting_branch_count", 1, "Initial number of branches") 
-flags.DEFINE_integer("branching_factor", None, "Rate of change of number of transforms per dimension (x,y)") # For fractal_branch and fractal_contraction +flags.DEFINE_string("split_method", None, "Method for when to change number of branches generated") +flags.DEFINE_float("workspace_width", 0.5, "Workspace width in meters") +flags.DEFINE_integer("max_depth",None,"Maximum layers of depth") +flags.DEFINE_integer("starting_branch_count", None, "Initial number of branches") +flags.DEFINE_integer("branching_factor", None, "Rate of change of branches per dimension (x,y)") # For fractal_branch and fractal_contraction flags.DEFINE_integer("alpha",None,"alpha value") - -# Contraction -flags.DEFINE_integer("start_num",81, "Initial number of branch on the first depth") # For Fractal Cntraction - -# Disassociated -flags.DEFINE_enum("disassociated_type", "octahedron", ["octahedron", "hourglass"], +flags.DEFINE_enum("disassociated_type", None, ["octahedron", "hourglass"], "Type of disassociated fracal rollout. 
Octahedron: expand from min to max then contract to min," + " Hourglass: Contract from max to min then expand to max") -flags.DEFINE_integer("min_branch_count", 1, "Minimum number of branches for disassociated fractal rollout") -flags.DEFINE_integer("max_branch_count", 1, "Maximum number of branches for disassociated fractal rollout") -flags.DEFINE_integer("num_depth_sectors", 1, "Desired number of sectors to divide rollout into for branch count splitting") - -# Density Workspace width -flags.DEFINE_string("workspace_width_method", 'constant', 'Controls workspace width dimensions configurations') +flags.DEFINE_integer("min_branch_count", None, "Minimum number of branches for disassociated fractal rollout") +flags.DEFINE_integer("max_branch_count", None, "Maximum number of branches for disassociated fractal rollout") # Debug flags.DEFINE_boolean( @@ -222,6 +210,7 @@ def learner(rng, agent: SACAgent, replay_buffer, replay_iterator): # set up wandb and logging wandb_logger = make_wandb_logger( project=FLAGS.exp_name, + name=FLAGS.run_name, description=FLAGS.exp_name or FLAGS.env, debug=FLAGS.debug, ) @@ -358,29 +347,23 @@ def main(_): sampling_rng = jax.device_put(sampling_rng, device=sharding.replicate()) replay_buffer = make_replay_buffer( env, - capacity =FLAGS.replay_buffer_capacity, - rlds_logger_path =FLAGS.log_rlds_path, - type =FLAGS.replay_buffer_type, - branch_method =FLAGS.branch_method, - split_method =FLAGS.split_method, - branching_factor =FLAGS.branching_factor, + capacity=FLAGS.replay_buffer_capacity, + rlds_logger_path=FLAGS.log_rlds_path, + type=FLAGS.replay_buffer_type, + branch_method=FLAGS.branch_method, + split_method=FLAGS.split_method, + branching_factor=FLAGS.branching_factor, starting_branch_count=FLAGS.starting_branch_count, - workspace_width =FLAGS.workspace_width, - max_traj_length =FLAGS.max_traj_length, - x_obs_idx =np.array([0,4]), - y_obs_idx =np.array([1,5]), - preload_rlds_path =FLAGS.preload_rlds_path, - max_depth =FLAGS.max_depth, 
- alpha =FLAGS.alpha, - # Contraction - start_num =FLAGS.start_num, - # Workspace Width Density - workspace_width_method=FLAGS.workspace_width_method, - # Disassociated - disassociated_type =FLAGS.disassociated_type, - min_branch_count =FLAGS.min_branch_count, - max_branch_count =FLAGS.max_branch_count, - num_depth_sectors =FLAGS.num_depth_sectors + workspace_width=FLAGS.workspace_width, + max_traj_length=FLAGS.max_traj_length, + x_obs_idx=np.array([0,4]), + y_obs_idx=np.array([1,5]), + preload_rlds_path=FLAGS.preload_rlds_path, + max_depth=FLAGS.max_depth, + alpha=FLAGS.alpha, + disassociated_type=FLAGS.disassociated_type, + min_branch_count=FLAGS.min_branch_count, + max_branch_count=FLAGS.max_branch_count, ) replay_iterator = replay_buffer.get_iterator( sample_args={ diff --git a/examples/async_sac_state_sim/automate_funcs.py b/examples/async_sac_state_sim/automate_funcs.py deleted file mode 100644 index 12a46058..00000000 --- a/examples/async_sac_state_sim/automate_funcs.py +++ /dev/null @@ -1,176 +0,0 @@ -import subprocess -import os -import psutil -from absl import app -import copy - -LOG_DIR = "serl_logs" -SESSION = "serl_session" - -os.makedirs(LOG_DIR, exist_ok=True) - -def run_tmux_command(cmd: str): - return subprocess.run(cmd, shell=True, check=True) - -def set_conda_env(session_name: str, - env: str = "serl", - workdir: str = "./"): - """ - Set the conda environment for the current shell in a tmux session. 
- Call as `set_conda_env(f"{SESSION}:0.0", workdir="./", env="serl")` - """ - - cmd = ( - f"conda activate {env} && " - f"cd {workdir}" - ) - - subprocess.run( - f'tmux send-keys -t {session_name} "{cmd}" C-m', - shell=True - # stdout=subprocess.DEVNULL, - # stderr=subprocess.DEVNULL - ) - -def send_tmux_command(target: str, command: str): - """ - Send a command to a specific tmux target session as: - (f"{SESSION}:0.0", f"bash ./automate_actor.sh") - Args: - - target: str - The tmux target session (e.g., "serl_session:0.0") - - command: str - The command to run in the tmux session - - Returns: - - None - - """ - full_command = ( - f"{command}" - ) - - subprocess.run( - f'tmux send-keys -t {target} "{full_command}" C-m', - shell=True, - #check=True - # stdout=subprocess.DEVNULL, - # stderr=subprocess.DEVNULL - ) - -def ensure_tmux_session(session_name: str): - result = subprocess.run( - ["tmux", "has-session", "-t", session_name] - # stdout=subprocess.DEVNULL, - # stderr=subprocess.DEVNULL - ) - if result.returncode != 0: - # Session does not exist, create it and split the window - subprocess.run(["tmux", "new-session", "-d", "-s", session_name, "-n", "main"], check=True) - subprocess.run(["tmux", "split-window", "-v", "-t", f"{session_name}:0"], check=True) - - # Set the conda environment in both panes once at the beginning - set_conda_env(f"{SESSION}:0.0", workdir="./", env="serl") - set_conda_env(f"{SESSION}:0.1", workdir="./", env="serl") - -def getProcesses(name): - p1_pid = None - p2_pid = None - first = True - while True: - for proc in psutil.process_iter(["name"]): - if proc.info["name"] == name: - if first: - p1_pid = proc.pid - first = False - else: - p2_pid = proc.pid - break - if p1_pid and p2_pid and p1_pid != p2_pid: - p1 = psutil.Process(p1_pid) - p2 = psutil.Process(p2_pid) - return p1, p2 - -def run_test(args:dict): - - max_steps = args["max_steps"] - for seed in range(2): - - args["seed"] = seed - actor_log = os.path.join(LOG_DIR, 
f"actor_seed_{seed}.log") - learner_log = os.path.join(LOG_DIR, f"learner_seed_{seed}.log") - - args["max_steps"] = max_steps * 10000 - send_tmux_command(f"{SESSION}:0.0", f"bash ./automate_actor.sh {' '.join(f'--{k} {v}' for k, v in args.items())}") #conda/jax/tmux mem conflict? , f"{actor_cmd(args)}")# > {actor_log} 2>&1") # Uncomment to send to log files - args["max_steps"] = max_steps - send_tmux_command(f"{SESSION}:0.1", f"bash ./automate_learner.sh {' '.join(f'--{k} {v}' for k, v in args.items())}")#mem conflict?, f"{learner_cmd(args)}")# > {learner_log} 2>&1") - - p1, p2 = getProcesses("async_sac_state") - - while (True): - if p1.is_running(): - if p2.is_running(): - continue - else: - p1.terminate() - p1.wait() - break - else: - if p2.is_running(): - p2.terminate() - p2.wait() - break - break - - - -def main(_): - ensure_tmux_session(SESSION) - set_conda_env - - env = "PandaReachCube-v0" - - base_args = { - "env": "PandaReachCube-v0", - "random_steps": 1000, - "training_starts": 1000, - "critic_actor_ratio": 8, - "batch_size": 256, - "replay_buffer_capacity": 1000000, - "save_model": True, - "max_steps": 50000, - "workspace_width": 0.5 - } - - # --- Baseline --- - args = copy.deepcopy(base_args) - args["replay_buffer_type"] = "replay_buffer" - args["exp_name"] = f"{args['env']}-baseline" - run_test(args) - - - # --- 1x1, 3x3, 9x9, 27x27 --- - args = copy.deepcopy(base_args) - args["replay_buffer_type"] = "fractal_symmetry_replay_buffer" - args["branch_method"] = "constant" - args["split_method"] = "constant" - - for b in (1, 3, 9, 27): - args["starting_branch_count"] = b - args["exp_name"] = f"{args['env']}-{b}x{b}" - run_test(args) - - # --- Fractal Expansion 3^4, 9^2 --- - args = copy.deepcopy(base_args) - args["replay_buffer_type"] = "fractal_symmetry_replay_buffer" - args["branch_method"] = "fractal" - args["split_method"] = "time" - args["alpha"] = 1 - - for b, d in ((3, 2), (3, 3)): - for a in (0.25, 0.5, 0.75, 1): - args["branching_factor"] = b - 
args["max_depth"] = d - args["exp_name"] = f"{args['env']}-{b}^{d}-alpha-{a}" - run_test(args) - -if __name__ == "__main__": - app.run(main) diff --git a/examples/async_sac_state_sim/automate_learner.sh b/examples/async_sac_state_sim/automate_learner.sh deleted file mode 100644 index 0e0b4be4..00000000 --- a/examples/async_sac_state_sim/automate_learner.sh +++ /dev/null @@ -1,6 +0,0 @@ -#!/bin/bash - -export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ -export XLA_PYTHON_CLIENT_MEM_FRACTION=.5 && \ - -python async_sac_state_sim.py --learner "$@" \ No newline at end of file diff --git a/examples/async_sac_state_sim/automated_tests.sh b/examples/async_sac_state_sim/automated_tests.sh new file mode 100644 index 00000000..ea245be6 --- /dev/null +++ b/examples/async_sac_state_sim/automated_tests.sh @@ -0,0 +1,139 @@ +#!/bin/bash + +tmux set-option -t serl_session:0.1 remain-on-exit on +tmux set-option -t serl_session:0.2 remain-on-exit on + +TEST="async_sac_state_sim.py" +CONDA_ENV="serl" +ENV="PandaReachCube-v0" +MAX_STEPS=25000 +TRAINING_STARTS=1000 +RANDOM_STEPS=1000 +CRITIC_ACTOR_RATIO=8 +EXP_NAME="MASS-TESTING-$ENV-max-steps-$MAX_STEPS" + +BASE_ARGS="--env $ENV --exp_name $EXP_NAME --training_starts $TRAINING_STARTS --random_steps $RANDOM_STEPS --critic_actor_ratio $CRITIC_ACTOR_RATIO" + +echo "$BASE_ARGS" +for seed in 1 2 3 +do + for batch_size in 256 + do + for replay_buffer_capacity in 1000000 + do + # Baseline + ARGS="--run_name baseline-batch-size-$batch_size-capacity-$replay_buffer_capacity --replay_buffer_type replay_buffer --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --seed $seed" + + tmux respawn-pane -k -t serl_session:0.1 + tmux respawn-pane -k -t serl_session:0.2 + tmux send-keys -t serl_session:0.1 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --actor --max_steps $((MAX_STEPS * 2)) $BASE_ARGS $ARGS" C-m + tmux send-keys -t serl_session:0.2 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --learner 
--max_steps $MAX_STEPS $BASE_ARGS $ARGS" C-m "exit" C-m + + while ! tmux capture-pane -t serl_session:0.2 -p | grep "Pane is dead" > /dev/null; + do + sleep 1 + echo "I'm waiting :)" + done + echo "I'm DONE waiting >:(" + + REPLAY_BUFFER_TYPE="fractal_symmetry_replay_buffer" + for workspace_width in 0.5 + do + for starting_branch_count in 1 3 9 27 81 + do + # Constant + ARGS="--run_name constant-$starting_branch_count^1-workspace_width-$workspace_width-batch-size-$batch_size-capacity-$replay_buffer_capacity --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --seed $seed --workspace_width $workspace_width --branch_method 'constant' --starting_branch_count $starting_branch_count" + + tmux respawn-pane -k -t serl_session:0.1 + tmux respawn-pane -k -t serl_session:0.2 + tmux send-keys -t serl_session:0.1 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --actor --max_steps $((MAX_STEPS * 2)) $BASE_ARGS $ARGS" C-m + tmux send-keys -t serl_session:0.2 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --learner --max_steps $MAX_STEPS $BASE_ARGS $ARGS" C-m "exit" C-m + while ! 
tmux capture-pane -t serl_session:0.2 -p | grep "Pane is dead" > /dev/null; + do + sleep 1 + echo "I'm waiting :)" + done + echo "I'm DONE waiting >:(" + done + + for alpha in 0.9 + do + for branching_factor in 3 9 + do + for max_depth in 2 4 + do + # Fractal Expansion + ARGS="--run_name fractal_expansion-$branching_factor^$max_depth-alpha-$alpha-workspace_width-$workspace_width-batch-size-$batch_size-capacity-$replay_buffer_capacity --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --seed $seed --workspace_width $workspace_width --branch_method 'fractal' --alpha $alpha --branching_factor $branching_factor --max_depth $max_depth" + + tmux respawn-pane -k -t serl_session:0.1 + tmux respawn-pane -k -t serl_session:0.2 + tmux send-keys -t serl_session:0.1 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --actor --max_steps $((MAX_STEPS * 2)) $BASE_ARGS $ARGS" C-m + tmux send-keys -t serl_session:0.2 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --learner --max_steps $MAX_STEPS $BASE_ARGS $ARGS" C-m "exit" C-m + while ! 
tmux capture-pane -t serl_session:0.2 -p | grep "Pane is dead" > /dev/null; + do + sleep 1 + echo "I'm waiting :)" + done + echo "I'm DONE waiting >:(" + + # Fractal Contraction + ARGS="--run_name fractal_contraction-$branching_factor^$max_depth-alpha-$alpha-workspace_width-$workspace_width-batch-size-$batch_size-capacity-$replay_buffer_capacity --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --seed $seed --workspace_width $workspace_width --branch_method 'contraction' --alpha $alpha --branching_factor $branching_factor --max_depth $max_depth" + + tmux respawn-pane -k -t serl_session:0.1 + tmux respawn-pane -k -t serl_session:0.2 + tmux send-keys -t serl_session:0.1 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --actor --max_steps $((MAX_STEPS * 2)) $BASE_ARGS $ARGS" C-m + tmux send-keys -t serl_session:0.2 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --learner --max_steps $MAX_STEPS $BASE_ARGS $ARGS" C-m "exit" C-m + + while ! 
tmux capture-pane -t serl_session:0.2 -p | grep "Pane is dead" > /dev/null; + do + sleep 1 + echo "I'm waiting :)" + done + echo "I'm DONE waiting >:(" + done + done + for min_branch_count in 1 3 9 + do + for max_branch_count in 3 9 27 + do + if [ $min_branch_count -ge $max_branch_count ]; then + continue + fi + # Disassociative (Hourglass) + ARGS="--run_name disassociative-hourglass-$min_branch_count:$max_branch_count-alpha-$alpha-workspace_width-$workspace_width-batch-size-$batch_size-capacity-$replay_buffer_capacity --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --seed $seed --workspace_width $workspace_width --branch_method 'disassociated' --min_branch_count $min_branch_count --max_branch_count $max_branch_count --disassociated_type 'hourglass' --alpha $alpha" + + tmux respawn-pane -k -t serl_session:0.1 + tmux respawn-pane -k -t serl_session:0.2 + tmux send-keys -t serl_session:0.1 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --actor --max_steps $((MAX_STEPS * 2)) $BASE_ARGS $ARGS" C-m + tmux send-keys -t serl_session:0.2 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --learner --max_steps $MAX_STEPS $BASE_ARGS $ARGS" C-m "exit" C-m + + while ! 
tmux capture-pane -t serl_session:0.2 -p | grep "Pane is dead" > /dev/null;
+                    do
+                        sleep 1
+                        echo "I'm waiting :)"
+                    done
+                    echo "I'm DONE waiting >:("
+
+                    # Disassociative (Octahedron)
+                    ARGS="--run_name disassociative-octahedron-$min_branch_count:$max_branch_count-alpha-$alpha-workspace_width-$workspace_width-batch-size-$batch_size-capacity-$replay_buffer_capacity --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --seed $seed --workspace_width $workspace_width --branch_method 'disassociated' --min_branch_count $min_branch_count --max_branch_count $max_branch_count --disassociated_type 'octahedron' --alpha $alpha"
+
+                    tmux respawn-pane -k -t serl_session:0.1
+                    tmux respawn-pane -k -t serl_session:0.2
+                    tmux send-keys -t serl_session:0.1 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --actor --max_steps $((MAX_STEPS * 2)) $BASE_ARGS $ARGS" C-m
+                    tmux send-keys -t serl_session:0.2 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --learner --max_steps $MAX_STEPS $BASE_ARGS $ARGS" C-m "exit" C-m
+
+                    while ! 
tmux capture-pane -t serl_session:0.2 -p | grep "Pane is dead" > /dev/null; + do + sleep 1 + echo "I'm waiting :)" + done + echo "I'm DONE waiting >:(" + done + done + done + done + done + done +done + +tmux kill-session -t serl_session diff --git a/examples/async_sac_state_sim/automate_actor.sh b/examples/async_sac_state_sim/automated_tests_helper.sh similarity index 72% rename from examples/async_sac_state_sim/automate_actor.sh rename to examples/async_sac_state_sim/automated_tests_helper.sh index d0ade5fc..90702a2f 100644 --- a/examples/async_sac_state_sim/automate_actor.sh +++ b/examples/async_sac_state_sim/automated_tests_helper.sh @@ -3,4 +3,4 @@ export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ export XLA_PYTHON_CLIENT_MEM_FRACTION=.5 && \ -python async_sac_state_sim.py --actor "$@" \ No newline at end of file +python async_sac_state_sim.py "$@" \ No newline at end of file diff --git a/examples/async_sac_state_sim/automation.sh b/examples/async_sac_state_sim/automation.sh deleted file mode 100644 index 39164dc0..00000000 --- a/examples/async_sac_state_sim/automation.sh +++ /dev/null @@ -1,8 +0,0 @@ - -tmux new-session -d -s serl_session - -tmux split-window -h -t serl_session - -tmux split-window -v -t serl_session:0.1 - -tmux send-keys -t serl_session:0.0 "conda activate serl && python automate_funcs.py" C-m diff --git a/examples/async_sac_state_sim/run_actor.sh b/examples/async_sac_state_sim/run_actor.sh index bc1ead2e..35df30e6 100644 --- a/examples/async_sac_state_sim/run_actor.sh +++ b/examples/async_sac_state_sim/run_actor.sh @@ -17,12 +17,12 @@ export CHECKPOINT_DIR="/data/fsrb_testing/checkpoints-$TIMESTAMP" && \ python async_sac_state_sim.py \ --actor \ --env PandaReachCube-v0 \ - --exp_name PandaReachCube-v0_state_sim_3d_parallel_const_3^1_batch_256_replay_1M_utd_8 \ - --seed 0 \ - --replay_buffer_type fractal_symmetry_replay_buffer_parallel \ - --seed 0 \ - --max_steps 300_000 \ + --exp_name this_is_a_fake_test_experiment \ + --run_name 
this_is_a_custom_run_name \ + --replay_buffer_type fractal_symmetry_replay_buffer \ + --max_steps 50_000 \ --training_starts 1000 \ + --random_steps 1000 \ --critic_actor_ratio 8 \ --batch_size 256 \ --replay_buffer_capacity 1_000_000 \ @@ -31,15 +31,14 @@ python async_sac_state_sim.py \ --split_method constant \ --starting_branch_count 3 \ --workspace_width 0.5 \ + --alpha 1 \ + # --debug # wandb is disabled when debug # --load_demos \ # --demo_dir /data/data/serl/demos \ # --file_name data_franka_reach_random_5_2.npz \ - # # --max_traj_length 100 \ - # --max_depth 4 \ - # --start_num 81 \ - --alpha 1 \ + # --max_traj_length 100 \ + # --max_depth 4 \ # --branching_factor 3 \ # --checkpoint_period 10000 \ # --checkpoint_path "$CHECKPOINT_DIR" \ - #--render \ - --debug # wandb is disabled when debug \ No newline at end of file + #--render \ No newline at end of file diff --git a/examples/async_sac_state_sim/run_learner.sh b/examples/async_sac_state_sim/run_learner.sh index 47fabcf9..7415620f 100644 --- a/examples/async_sac_state_sim/run_learner.sh +++ b/examples/async_sac_state_sim/run_learner.sh @@ -17,8 +17,9 @@ export CHECKPOINT_DIR="/data/fsrb_testing/checkpoints-$TIMESTAMP" && \ python async_sac_state_sim.py "$@"\ --learner \ --env PandaReachCube-v0 \ - --exp_name PandaReachCube-v0_state_sim_3d_parallel_const_3^1_batch_256_replay_1M_utd_8 \ - --replay_buffer_type fractal_symmetry_replay_buffer_parallel \ + --exp_name this_is_a_fake_test_experiment \ + --run_name this_is_a_custom_run_name \ + --replay_buffer_type fractal_symmetry_replay_buffer \ --max_steps 50_000 \ --training_starts 1000 \ --random_steps 1000 \ @@ -30,15 +31,13 @@ python async_sac_state_sim.py "$@"\ --split_method constant \ --starting_branch_count 3 \ --workspace_width 0.5 \ + --alpha 1 \ + # --debug # wandb is disabled when debug # --load_demos \ # --demo_dir /data/data/serl/demos \ # --file_name data_franka_reach_random_5_2.npz \ # --max_traj_length 100 \ # --max_depth 4 \ - # --start_num 81 
\ - --alpha 1 \ # --branching_factor 3 \ # --checkpoint_period 10000 \ - # --checkpoint_path "$CHECKPOINT_DIR" \ - --debug # wandb is disabled when debug - #--render \ No newline at end of file + # --checkpoint_path "$CHECKPOINT_DIR" \ \ No newline at end of file diff --git a/examples/async_sac_state_sim/run_tests.py b/examples/async_sac_state_sim/run_tests.py deleted file mode 100644 index 5acbaba3..00000000 --- a/examples/async_sac_state_sim/run_tests.py +++ /dev/null @@ -1,85 +0,0 @@ - -import subprocess -from absl import app -import time - -def actor(kwargs:dict): - script = "python3 /home/ryan_vanderstelt/serl/examples/async_sac_state_sim/async_sac_state_sim.py --actor" - for k, v in kwargs.items(): - script += f" --{k}" - script += f" {v}" - return script - -def learner(kwargs:dict): - script = "python3 /home/ryan_vanderstelt/serl/examples/async_sac_state_sim/async_sac_state_sim.py --learner" - for k, v in kwargs.items(): - script += f" --{k}" - script += f" {v}" - return script - -def main(_): - env = "PandaReachCube-v0" - random_steps = 1000 - max_steps_learner = 5000 - max_steps_actor = max_steps_learner * 2 - training_starts = 1000 - critic_actor_ratio = 8 - batch_size = 256 - replay_buffer_capacity = 1000000 - save_model = True - - rb_args = { - "env": env, - "random_steps": random_steps, - "max_steps": max_steps_actor, - "training_starts": training_starts, - "critic_actor_ratio": critic_actor_ratio, - "batch_size": batch_size, - "replay_buffer_capacity": replay_buffer_capacity, - "save_model": save_model, - } - - subprocess.run(['tmux', 'new-session', '-d', '-s', 'serl_session']) - # time.sleep(10) - subprocess.run("tmux split-window -v", shell=True) - - # baseline - rb_args["replay_buffer_type"] = "replay_buffer" - rb_args["exp_name"] = f"{env}-baseline" - for seed in range(0, 5): - rb_args["seed"] = seed - rb_args["max_steps"] = max_steps_actor - actor_process = subprocess.Popen(f"tmux send-keys -t serl_session:0.0 \"conda run -vvv -n serl 
{actor(kwargs=rb_args)}\" C-m", shell=True) - - - rb_args["max_steps"] = max_steps_learner - learner_process = subprocess.Popen(f"tmux send-keys -t serl_session:0.1 \"conda run -vvv -n serl '{learner(kwargs=rb_args)}'\" C-m", shell=True) - subprocess.run("tmux attach-session -t serl_session", shell=True) - - learner_process.wait() - actor_process.terminate() - actor_process.wait() - - # constant - # rb_args["replay_buffer_type"] = "fractal_symmetry_replay_buffer" - # rb_args["branch_method"] = "constant" - # rb_args["split_method"] = "constant" - # for b in [1, 3, 9, 27]: - # rb_args["starting_branch_count"] = b - # rb_args["exp_name"] = f"{env}-{b}x{b}" - # for seed in range(0, 5): - # rb_args["seed"] = seed - # rb_args["max_steps"] = max_steps_actor - # actor_process = subprocess.Popen(["tmux", "send-keys", "-t", "serl_session:0.0", actor(kwargs=rb_args)]) - - # rb_args["max_steps"] = max_steps_learner - # learner_process = subprocess.Popen(learner(kwargs=rb_args), shell=True) - - # learner_process.wait() - # actor_process.terminate() - # actor_process.wait() - - -if __name__ == "__main__": - app.run(main) - \ No newline at end of file diff --git a/examples/async_sac_state_sim/tmux_launch.sh b/examples/async_sac_state_sim/tmux_launch.sh index 728d6e52..dff31d54 100644 --- a/examples/async_sac_state_sim/tmux_launch.sh +++ b/examples/async_sac_state_sim/tmux_launch.sh @@ -3,7 +3,7 @@ # EXAMPLE_DIR=${EXAMPLE_DIR:-"examples/async_sac_state_sim"} CONDA_ENV=${CONDA_ENV:-"serl"} -cd $EXAMPLE_DIR +# cd $EXAMPLE_DIR echo "Running from $(pwd)" # Create a new tmux session diff --git a/examples/async_sac_state_sim/tmux_launch_tests.sh b/examples/async_sac_state_sim/tmux_launch_tests.sh new file mode 100644 index 00000000..55a9dc31 --- /dev/null +++ b/examples/async_sac_state_sim/tmux_launch_tests.sh @@ -0,0 +1,18 @@ +#!/bin/bash + + +# Create a new tmux session +tmux new-session -d -s serl_session + +# Split the window horizontally +tmux split-window -v +tmux split-pane -h 
-t serl_session:0.1 + +# Navigate to the activate the conda environment in the first pane +tmux send-keys -t serl_session:0.0 "bash automated_tests.sh" C-m + +# Attach to the tmux session +tmux attach-session -t serl_session + +# kill the tmux session by running the following command +# tmux kill-session -t serl_session diff --git a/serl_launcher/serl_launcher/common/wandb.py b/serl_launcher/serl_launcher/common/wandb.py index 5bc0f432..d01ebf05 100644 --- a/serl_launcher/serl_launcher/common/wandb.py +++ b/serl_launcher/serl_launcher/common/wandb.py @@ -70,6 +70,7 @@ def __init__( self.run = wandb.init( config=self._variant, project=self.config.project, + name=self.config.name, entity=self.config.entity, group=self.config.group, tags=self.config.tag, diff --git a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py index 0ce437ea..a071a58e 100644 --- a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py +++ b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py @@ -68,42 +68,30 @@ def _handle_method_arg_(self, value, method_type, method, kwargs): del kwargs[value] def _handle_methods_(self, kwargs): - - # Initialize split_method - match self.split_method: - case "time": - self._handle_method_arg_("max_depth", "split_method", self.split_method, kwargs) - self._handle_method_arg_("max_traj_length", "split_method", self.split_method, kwargs) - self._handle_method_arg_("alpha", "split_method", self.split_method, kwargs) - - self.update_max_traj_length = True - self.split = self.time_split - - case "constant": - self.split = self.constant_split - - case "test": - self.split = self.test_split - case _: - raise ValueError("incorrect value passed to split_method") # Initialize branch_method match self.branch_method: + case "fractal": self._handle_method_arg_("max_depth", "branch_method", self.branch_method, kwargs) self._handle_method_arg_("branching_factor", 
"branch_method", self.branch_method, kwargs) self.branch = self.fractal_branch + if not self.split_method: + self.split_method = "time" case "contraction": self._handle_method_arg_("max_depth", "branch_method", self.branch_method, kwargs) self._handle_method_arg_("branching_factor", "branch_method", self.branch_method, kwargs) self.branch = self.fractal_contraction + if not self.split_method: + self.split_method = "time" case "linear": raise NotImplementedError("linear branch method is not yet implemented") # self.branch = self.linear_branch + case "disassociated": self._handle_method_arg_("min_branch_count", "branch_method", self.branch_method, kwargs) @@ -123,38 +111,36 @@ def _handle_methods_(self, kwargs): self.disassociated_type = kwargs["disassociated_type"] del kwargs["disassociated_type"] self.branch = self.disassociated_branch + if not self.split_method: + self.split_method = "time" case "constant": self._handle_method_arg_("starting_branch_count", "branch_method", self.branch_method, kwargs) self.branch = self.constant_branch - - case "test": - self.branch = self.test_branch + if not self.split_method: + self.split_method = "never" case _: raise ValueError("incorrect value passed to branch_method") - # Initialize workspace_width_method - # match self.workspace_width_method: - - # case "constant": - # if self.workspace_width is None: - # raise ValueError("workspace_width must be defined for constant workspace width method") - # self.get_workspace_width = self.ww_constant - - # case "decrease": - # if self.workspace_width is None: - # raise ValueError("workspace_width must be defined for constant workspace width method") - # self.get_workspace_width = self.ww_decrease - - # case "increase": - # if self.workspace_width is None: - # raise ValueError("workspace_width must be defined for constant workspace width method") - # self.get_workspace_width = self.ww_increase + match self.split_method: + case "time": + self._handle_method_arg_("max_depth", 
"split_method", self.split_method, kwargs) + self._handle_method_arg_("max_traj_length", "split_method", self.split_method, kwargs) + self._handle_method_arg_("alpha", "split_method", self.split_method, kwargs) + + self.update_max_traj_length = True + self.split = self.time_split + + case "constant": + self.split = self.constant_split - # case _: - # raise ValueError("incorrect value passed to workspace_width_method") + case "never": + self.split = self.never_split + + case _: + raise ValueError("incorrect value passed to split_method") if self.starting_branch_count: self.current_branch_count = self.starting_branch_count @@ -180,12 +166,6 @@ def generate_transform_deltas(self): self.transform_deltas[:, self.x_obs_idx] = x_deltas self.transform_deltas[:, self.y_obs_idx] = y_deltas - #--- BRANCH METHODS --- - # FOR TESTING ONLY - def test_branch(self): - self.current_depth += 1 - return 1 - def fractal_branch(self): ''' Computes the number of branches for the current depth using an exponential growth rule. 
@@ -249,75 +229,43 @@ def disassociated_branch(self): elif self.disassociated_type == "octahedron": return int((self.min_branch_count - self.max_branch_count)/(self.max_depth/2) * np.abs(self.current_depth - (self.max_depth/2)) + self.max_branch_count) - # REQUIRES TESTING def linear_branch(self): # return a new number of branches = branches_count + n - return self.current_branch_count + self.branch_count_rate_of_change - - - #---- SPLIT METHODS --- - # FOR TESTING ONLY - def test_split(self, data_dict: DatasetDict): - return True + return self.current_branch_count + self.branching_factor - # REQUIRES TESTING def time_split(self, data_dict: DatasetDict): if self.timestep % (self.max_traj_length//self.max_depth) or self.current_depth >= self.max_depth: return False self.current_depth += 1 - return True - + return True + def constant_split(self, data_dict: DatasetDict): + self.current_depth += 1 return True - - # Workspace Width Methods - def ww_constant(self): - return self.workspace_width - def ww_increase(self): - ''' - Increase workspace width with depth by 5cm. Lower density. - ''' - return self.workspace_width + 0.05 + def never_split(self, data_dict: DatasetDict): + return False - def ww_decrease(self): - ''' - Decrease workspace width with depth by 5cm. Higher density. - ''' - return self.workspace_width - 0.05 - - #--------------------------------------------- - # Now perform fractal transformations. - # Currently iterating through x,y loop (slow). - # TODO: multigpu processing. 
- def insert(self, data_dict_not: DatasetDict): - # Time - sector_1 = dt.now() if self.debug_time else None data_dict = copy.deepcopy(data_dict_not) # Update number of branches if needed if self.split(data_dict): temp = self.current_branch_count self.current_branch_count = self.branch() - # self.workspace_width = self.get_workspace_width() # Update transform_deltas if needed if temp != self.current_branch_count: self.generate_transform_deltas() # Initialize to extreme x and y - sector_2 = dt.now() if self.debug_time else None - base_diff = -self.workspace_width/2 data_dict["observations"][self.x_obs_idx] += base_diff data_dict["observations"][self.y_obs_idx] += base_diff data_dict["next_observations"][self.x_obs_idx] += base_diff data_dict["next_observations"][self.y_obs_idx] += base_diff - # Transform and insert transitions (multiprocessing in the future) - sector_3 = dt.now() if self.debug_time else None - + # Transform and insert transitions num_transforms = self.current_branch_count ** 2 obs_batch = np.tile(data_dict["observations"], (num_transforms, 1)) next_obs_batch = np.tile(data_dict["next_observations"], (num_transforms, 1)) @@ -335,15 +283,9 @@ def insert(self, data_dict_not: DatasetDict): super().insert(data_dict, batch_size=num_transforms) # Reset current_depth, timestep, and max_traj_length - sector_4 = dt.now() if self.debug_time else None self.timestep += 1 if data_dict["dones"][0]: self.current_depth = 0 if self.update_max_traj_length: self.max_traj_length = int(self.timestep * self.alpha + self.max_traj_length * (1 - self.alpha)) - self.timestep = 0 - - - if self.debug_time: - finish = dt.now() - print(f"Splits: {(sector_2 - sector_1).total_seconds():.5f} : {(sector_3 - sector_2).total_seconds():.5f} : {(sector_4 - sector_3).total_seconds():.5f} : {(finish - sector_4).total_seconds():.5f}\nLaptime: {(finish - sector_1).total_seconds():.5f}") + self.timestep = 0 \ No newline at end of file diff --git a/serl_launcher/serl_launcher/utils/launcher.py 
b/serl_launcher/serl_launcher/utils/launcher.py index ce2d4539..0a25490c 100644 --- a/serl_launcher/serl_launcher/utils/launcher.py +++ b/serl_launcher/serl_launcher/utils/launcher.py @@ -180,6 +180,7 @@ def make_trainer_config(port_number: int = 5488, broadcast_port: int = 5489): def make_wandb_logger( project: str = "agentlace", + name: str = "placeholder_run_name", description: str = "serl_launcher", debug: bool = False, ): @@ -187,6 +188,7 @@ def make_wandb_logger( wandb_config.update( { "project": project, + "name": name, "exp_descriptor": description, "tag": description, } From 2f30c1851cfd17b4ce99fcdc41edf7bbef580324 Mon Sep 17 00:00:00 2001 From: ryanvanderstelt Date: Tue, 5 Aug 2025 19:26:09 -0500 Subject: [PATCH 102/141] Fixed alpha type --- examples/async_sac_state_sim/async_sac_state_sim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/async_sac_state_sim/async_sac_state_sim.py b/examples/async_sac_state_sim/async_sac_state_sim.py index 116a3e1c..bf9f159e 100644 --- a/examples/async_sac_state_sim/async_sac_state_sim.py +++ b/examples/async_sac_state_sim/async_sac_state_sim.py @@ -68,7 +68,7 @@ flags.DEFINE_integer("max_depth",None,"Maximum layers of depth") flags.DEFINE_integer("starting_branch_count", None, "Initial number of branches") flags.DEFINE_integer("branching_factor", None, "Rate of change of branches per dimension (x,y)") # For fractal_branch and fractal_contraction -flags.DEFINE_integer("alpha",None,"alpha value") +flags.DEFINE_float("alpha",None,"alpha value") flags.DEFINE_enum("disassociated_type", None, ["octahedron", "hourglass"], "Type of disassociated fracal rollout. 
Octahedron: expand from min to max then contract to min," + " Hourglass: Contract from max to min then expand to max") From 227480debffed8dbb281b7639503cebf319805ea Mon Sep 17 00:00:00 2001 From: ryanvanderstelt Date: Thu, 7 Aug 2025 16:10:10 -0500 Subject: [PATCH 103/141] Added automated testing with parallel run support --- .../async_sac_state_sim.py | 14 +- .../async_sac_state_sim/automated_tests.sh | 215 +++++++++--------- .../async_sac_state_sim/tmux_launch_tests.sh | 21 +- serl_launcher/serl_launcher/common/wandb.py | 11 +- .../data/fractal_symmetry_replay_buffer.py | 5 +- serl_launcher/serl_launcher/utils/launcher.py | 4 + 6 files changed, 150 insertions(+), 120 deletions(-) diff --git a/examples/async_sac_state_sim/async_sac_state_sim.py b/examples/async_sac_state_sim/async_sac_state_sim.py index bf9f159e..b5309f56 100644 --- a/examples/async_sac_state_sim/async_sac_state_sim.py +++ b/examples/async_sac_state_sim/async_sac_state_sim.py @@ -40,6 +40,10 @@ flags.DEFINE_bool("save_model", False, "Whether to save model.") flags.DEFINE_integer("batch_size", 256, "Batch size.") flags.DEFINE_integer("critic_actor_ratio", 8, "critic to actor update ratio.") +flags.DEFINE_integer("port_number", 5488, "Port for server") +flags.DEFINE_integer("broadcast_port", 5489, "Port for server") +flags.DEFINE_boolean("wandb_offline", False, "Save locally to be synced with 'wandb sync ") +flags.DEFINE_string("wandb_output_dir", None, "Where to save local wandb files") flags.DEFINE_integer("max_steps", 1000000, "Maximum number of training steps.") flags.DEFINE_integer("replay_buffer_capacity", 1000000, "Replay buffer capacity.") @@ -76,9 +80,7 @@ flags.DEFINE_integer("max_branch_count", None, "Maximum number of branches for disassociated fractal rollout") # Debug -flags.DEFINE_boolean( - "debug", False, "Debug mode." 
-) # debug mode will disable wandb logging +flags.DEFINE_boolean("debug", False, "Debug mode.") # debug mode will disable wandb logging # Logging flags.DEFINE_string("log_rlds_path", None, "Path to save RLDS logs.") @@ -104,7 +106,7 @@ def actor(agent: SACAgent, data_store, env, sampling_rng, demos_handler=None): client = TrainerClient( "actor_env", FLAGS.ip, - make_trainer_config(), + make_trainer_config(port_number=FLAGS.port_number, broadcast_port=FLAGS.broadcast_port), data_store, wait_for_server=True, ) @@ -212,7 +214,9 @@ def learner(rng, agent: SACAgent, replay_buffer, replay_iterator): project=FLAGS.exp_name, name=FLAGS.run_name, description=FLAGS.exp_name or FLAGS.env, + wandb_output_dir=FLAGS.wandb_output_dir, debug=FLAGS.debug, + offline=FLAGS.wandb_offline, ) # To track the step in the training loop @@ -225,7 +229,7 @@ def stats_callback(type: str, payload: dict) -> dict: return {} # not expecting a response # Create server - server = TrainerServer(make_trainer_config(), request_callback=stats_callback) + server = TrainerServer(make_trainer_config(port_number=FLAGS.port_number, broadcast_port=FLAGS.broadcast_port), request_callback=stats_callback) server.register_data_store("actor_env", replay_buffer) server.start(threaded=True) diff --git a/examples/async_sac_state_sim/automated_tests.sh b/examples/async_sac_state_sim/automated_tests.sh index ea245be6..af888550 100644 --- a/examples/async_sac_state_sim/automated_tests.sh +++ b/examples/async_sac_state_sim/automated_tests.sh @@ -1,8 +1,9 @@ #!/bin/bash -tmux set-option -t serl_session:0.1 remain-on-exit on -tmux set-option -t serl_session:0.2 remain-on-exit on - +SEED=$1 +PORT_NUMBER=$2 +BROADCAST_PORT=$3 +WANDB_OUTPUT_DIR=~/wandb_logs TEST="async_sac_state_sim.py" CONDA_ENV="serl" ENV="PandaReachCube-v0" @@ -10,125 +11,129 @@ MAX_STEPS=25000 TRAINING_STARTS=1000 RANDOM_STEPS=1000 CRITIC_ACTOR_RATIO=8 -EXP_NAME="MASS-TESTING-$ENV-max-steps-$MAX_STEPS" +EXP_NAME="TESTING-PARALLEL-$ENV" -BASE_ARGS="--env 
$ENV --exp_name $EXP_NAME --training_starts $TRAINING_STARTS --random_steps $RANDOM_STEPS --critic_actor_ratio $CRITIC_ACTOR_RATIO" +BASE_ARGS="--env $ENV --exp_name $EXP_NAME --wandb_output_dir $WANDB_OUTPUT_DIR --wandb_offline --training_starts $TRAINING_STARTS --random_steps $RANDOM_STEPS --critic_actor_ratio $CRITIC_ACTOR_RATIO --port_number $PORT_NUMBER --broadcast_port $BROADCAST_PORT --seed $SEED" echo "$BASE_ARGS" -for seed in 1 2 3 + +for batch_size in 256 do - for batch_size in 256 + for replay_buffer_capacity in 1000000 do - for replay_buffer_capacity in 1000000 - do - # Baseline - ARGS="--run_name baseline-batch-size-$batch_size-capacity-$replay_buffer_capacity --replay_buffer_type replay_buffer --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --seed $seed" - - tmux respawn-pane -k -t serl_session:0.1 - tmux respawn-pane -k -t serl_session:0.2 - tmux send-keys -t serl_session:0.1 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --actor --max_steps $((MAX_STEPS * 2)) $BASE_ARGS $ARGS" C-m - tmux send-keys -t serl_session:0.2 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --learner --max_steps $MAX_STEPS $BASE_ARGS $ARGS" C-m "exit" C-m + # Baseline + ARGS="--run_name baseline-batch-size-$batch_size-capacity-$replay_buffer_capacity --replay_buffer_type replay_buffer --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity" - while ! tmux capture-pane -t serl_session:0.2 -p | grep "Pane is dead" > /dev/null; + # tmux respawn-pane -k -t serl_session:$SEED.1 + # tmux respawn-pane -k -t serl_session:$SEED.2 + # tmux send-keys -t serl_session:$SEED.1 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --actor --max_steps $((MAX_STEPS * 2)) $BASE_ARGS $ARGS" C-m + # tmux send-keys -t serl_session:$SEED.2 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --learner --max_steps $MAX_STEPS $BASE_ARGS $ARGS" C-m "exit" C-m + + # while ! 
tmux capture-pane -t serl_session:$SEED.2 -p | grep "Pane is dead" > /dev/null; + # do + # sleep 1 + # echo "I'm waiting :)" + # done + + echo "I'm DONE waiting >:(" + + REPLAY_BUFFER_TYPE="fractal_symmetry_replay_buffer" + for workspace_width in 0.5 + do + for starting_branch_count in $(seq 3 2 81) do - sleep 1 - echo "I'm waiting :)" - done - echo "I'm DONE waiting >:(" - - REPLAY_BUFFER_TYPE="fractal_symmetry_replay_buffer" - for workspace_width in 0.5 - do - for starting_branch_count in 1 3 9 27 81 - do - # Constant - ARGS="--run_name constant-$starting_branch_count^1-workspace_width-$workspace_width-batch-size-$batch_size-capacity-$replay_buffer_capacity --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --seed $seed --workspace_width $workspace_width --branch_method 'constant' --starting_branch_count $starting_branch_count" + # Constant + ARGS="--run_name constant-$starting_branch_count^1-workspace_width-$workspace_width-batch-size-$batch_size-capacity-$replay_buffer_capacity --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --workspace_width $workspace_width --branch_method 'constant' --starting_branch_count $starting_branch_count" - tmux respawn-pane -k -t serl_session:0.1 - tmux respawn-pane -k -t serl_session:0.2 - tmux send-keys -t serl_session:0.1 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --actor --max_steps $((MAX_STEPS * 2)) $BASE_ARGS $ARGS" C-m - tmux send-keys -t serl_session:0.2 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --learner --max_steps $MAX_STEPS $BASE_ARGS $ARGS" C-m "exit" C-m - while ! 
tmux capture-pane -t serl_session:0.2 -p | grep "Pane is dead" > /dev/null; - do - sleep 1 - echo "I'm waiting :)" - done - echo "I'm DONE waiting >:(" + tmux respawn-pane -k -t serl_session:$SEED.1 + tmux respawn-pane -k -t serl_session:$SEED.2 + tmux send-keys -t serl_session:$SEED.1 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --actor --max_steps $((MAX_STEPS * 2)) $BASE_ARGS $ARGS" C-m + tmux send-keys -t serl_session:$SEED.2 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --learner --max_steps $MAX_STEPS $BASE_ARGS $ARGS" C-m "exit" C-m + + while ! tmux capture-pane -t serl_session:$SEED.2 -p | grep "Pane is dead" > /dev/null; + do + sleep 1 + echo "I'm waiting :)" done + + echo "I'm DONE waiting >:(" + done - for alpha in 0.9 + for alpha in 0.9 + do + for branching_factor in 3 9 do - for branching_factor in 3 9 + for max_depth in 2 4 do - for max_depth in 2 4 - do - # Fractal Expansion - ARGS="--run_name fractal_expansion-$branching_factor^$max_depth-alpha-$alpha-workspace_width-$workspace_width-batch-size-$batch_size-capacity-$replay_buffer_capacity --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --seed $seed --workspace_width $workspace_width --branch_method 'fractal' --alpha $alpha --branching_factor $branching_factor --max_depth $max_depth" - - tmux respawn-pane -k -t serl_session:0.1 - tmux respawn-pane -k -t serl_session:0.2 - tmux send-keys -t serl_session:0.1 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --actor --max_steps $((MAX_STEPS * 2)) $BASE_ARGS $ARGS" C-m - tmux send-keys -t serl_session:0.2 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --learner --max_steps $MAX_STEPS $BASE_ARGS $ARGS" C-m "exit" C-m - while ! 
tmux capture-pane -t serl_session:0.2 -p | grep "Pane is dead" > /dev/null; - do - sleep 1 - echo "I'm waiting :)" - done - echo "I'm DONE waiting >:(" + if [[ $(( max_depth * branching_factor )) -eq 36 ]]; then + continue + fi + # Fractal Expansion + ARGS="--run_name fractal_expansion-$branching_factor^$max_depth-alpha-$alpha-workspace_width-$workspace_width-batch-size-$batch_size-capacity-$replay_buffer_capacity --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --workspace_width $workspace_width --branch_method 'fractal' --alpha $alpha --branching_factor $branching_factor --max_depth $max_depth" + + # tmux respawn-pane -k -t serl_session:$SEED.1 + # tmux respawn-pane -k -t serl_session:$SEED.2 + # tmux send-keys -t serl_session:$SEED.1 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --actor --max_steps $((MAX_STEPS * 2)) $BASE_ARGS $ARGS" C-m + # tmux send-keys -t serl_session:$SEED.2 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --learner --max_steps $MAX_STEPS $BASE_ARGS $ARGS" C-m "exit" C-m + # while ! 
tmux capture-pane -t serl_session:$SEED.2 -p | grep "Pane is dead" > /dev/null; + # do + # sleep 1 + # echo "I'm waiting :)" + # done + # echo "I'm DONE waiting >:(" - # Fractal Contraction - ARGS="--run_name fractal_contraction-$branching_factor^$max_depth-alpha-$alpha-workspace_width-$workspace_width-batch-size-$batch_size-capacity-$replay_buffer_capacity --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --seed $seed --workspace_width $workspace_width --branch_method 'contraction' --alpha $alpha --branching_factor $branching_factor --max_depth $max_depth" + # Fractal Contraction + ARGS="--run_name fractal_contraction-$branching_factor^$max_depth-alpha-$alpha-workspace_width-$workspace_width-batch-size-$batch_size-capacity-$replay_buffer_capacity --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --workspace_width $workspace_width --branch_method 'contraction' --alpha $alpha --branching_factor $branching_factor --max_depth $max_depth" - tmux respawn-pane -k -t serl_session:0.1 - tmux respawn-pane -k -t serl_session:0.2 - tmux send-keys -t serl_session:0.1 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --actor --max_steps $((MAX_STEPS * 2)) $BASE_ARGS $ARGS" C-m - tmux send-keys -t serl_session:0.2 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --learner --max_steps $MAX_STEPS $BASE_ARGS $ARGS" C-m "exit" C-m - - while ! 
tmux capture-pane -t serl_session:0.2 -p | grep "Pane is dead" > /dev/null; - do - sleep 1 - echo "I'm waiting :)" - done - echo "I'm DONE waiting >:(" - done + # tmux respawn-pane -k -t serl_session:$SEED.1 + # tmux respawn-pane -k -t serl_session:$SEED.2 + # tmux send-keys -t serl_session:$SEED.1 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --actor --max_steps $((MAX_STEPS * 2)) $BASE_ARGS $ARGS" C-m + # tmux send-keys -t serl_session:$SEED.2 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --learner --max_steps $MAX_STEPS $BASE_ARGS $ARGS" C-m "exit" C-m + + # while ! tmux capture-pane -t serl_session:$SEED.2 -p | grep "Pane is dead" > /dev/null; + # do + # sleep 1 + # echo "I'm waiting :)" + # done + # echo "I'm DONE waiting >:(" done - for min_branch_count in 1 3 9 + done + for min_branch_count in 1 3 9 + do + for max_branch_count in 3 9 27 do - for max_branch_count in 3 9 27 - do - if [ $min_branch_count -ge $max_branch_count ]; then - continue - fi - # Disassociative (Hourglass) - ARGS="--run_name disassociative-hourglass-$min_branch_count:$max_branch_count-alpha-$alpha-workspace_width-$workspace_width-batch-size-$batch_size-capacity-$replay_buffer_capacity --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --seed $seed --workspace_width $workspace_width --branch_method 'disassociated' --min_branch_count $min_branch_count --max_branch_count $max_branch_count --disassociated_type 'hourglass' --alpha $alpha" - - tmux respawn-pane -k -t serl_session:0.1 - tmux respawn-pane -k -t serl_session:0.2 - tmux send-keys -t serl_session:0.1 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --actor --max_steps $((MAX_STEPS * 2)) $BASE_ARGS $ARGS" C-m - tmux send-keys -t serl_session:0.2 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --learner --max_steps $MAX_STEPS $BASE_ARGS $ARGS" C-m "exit" C-m - - while ! 
tmux capture-pane -t serl_session:0.2 -p | grep "Pane is dead" > /dev/null; - do - sleep 1 - echo "I'm waiting :)" - done - echo "I'm DONE waiting >:(" + if [ $min_branch_count -ge $max_branch_count ]; then + continue + fi + # Disassociative (Hourglass) + ARGS="--run_name disassociative-hourglass-$min_branch_count:$max_branch_count-alpha-$alpha-workspace_width-$workspace_width-batch-size-$batch_size-capacity-$replay_buffer_capacity --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --workspace_width $workspace_width --branch_method 'disassociated' --min_branch_count $min_branch_count --max_branch_count $max_branch_count --disassociated_type 'hourglass' --alpha $alpha" + + # tmux respawn-pane -k -t serl_session:$SEED.1 + # tmux respawn-pane -k -t serl_session:$SEED.2 + # tmux send-keys -t serl_session:$SEED.1 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --actor --max_steps $((MAX_STEPS * 2)) $BASE_ARGS $ARGS" C-m + # tmux send-keys -t serl_session:$SEED.2 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --learner --max_steps $MAX_STEPS $BASE_ARGS $ARGS" C-m "exit" C-m + + # while ! 
tmux capture-pane -t serl_session:$SEED.2 -p | grep "Pane is dead" > /dev/null; + # do + # sleep 1 + # echo "I'm waiting :)" + # done + # echo "I'm DONE waiting >:(" - # Disassociative (Octahedron) - ARGS="--run_name disassociative-hourglass-$min_branch_count:$max_branch_count-alpha-$alpha-workspace_width-$workspace_width-batch-size-$batch_size-capacity-$replay_buffer_capacity --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --seed $seed --workspace_width $workspace_width --branch_method 'disassociated' --min_branch_count $min_branch_count --max_branch_count $max_branch_count --disassociated_type 'octahedron' --alpha $alpha" + # Disassociative (Octahedron) + ARGS="--run_name disassociative-hourglass-$min_branch_count:$max_branch_count-alpha-$alpha-workspace_width-$workspace_width-batch-size-$batch_size-capacity-$replay_buffer_capacity --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --workspace_width $workspace_width --branch_method 'disassociated' --min_branch_count $min_branch_count --max_branch_count $max_branch_count --disassociated_type 'octahedron' --alpha $alpha" - tmux respawn-pane -k -t serl_session:0.1 - tmux respawn-pane -k -t serl_session:0.2 - tmux send-keys -t serl_session:0.1 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --actor --max_steps $((MAX_STEPS * 2)) $BASE_ARGS $ARGS" C-m - tmux send-keys -t serl_session:0.2 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --learner --max_steps $MAX_STEPS $BASE_ARGS $ARGS" C-m "exit" C-m - - while ! 
tmux capture-pane -t serl_session:0.2 -p | grep "Pane is dead" > /dev/null; - do - sleep 1 - echo "I'm waiting :)" - done - echo "I'm DONE waiting >:(" - done + # tmux respawn-pane -k -t serl_session:$SEED.1 + # tmux respawn-pane -k -t serl_session:$SEED.2 + # tmux send-keys -t serl_session:$SEED.1 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --actor --max_steps $((MAX_STEPS * 2)) $BASE_ARGS $ARGS" C-m + # tmux send-keys -t serl_session:$SEED.2 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --learner --max_steps $MAX_STEPS $BASE_ARGS $ARGS" C-m "exit" C-m + + # while ! tmux capture-pane -t serl_session:$SEED.2 -p | grep "Pane is dead" > /dev/null; + # do + # sleep 1 + # echo "I'm waiting :)" + # done + # echo "I'm DONE waiting >:(" done done done @@ -136,4 +141,4 @@ do done done -tmux kill-session -t serl_session +tmux kill-window -t serl_session:$SEED diff --git a/examples/async_sac_state_sim/tmux_launch_tests.sh b/examples/async_sac_state_sim/tmux_launch_tests.sh index 55a9dc31..7ea5db17 100644 --- a/examples/async_sac_state_sim/tmux_launch_tests.sh +++ b/examples/async_sac_state_sim/tmux_launch_tests.sh @@ -3,14 +3,25 @@ # Create a new tmux session tmux new-session -d -s serl_session +tmux setw -g remain-on-exit on -# Split the window horizontally -tmux split-window -v -tmux split-pane -h -t serl_session:0.1 +for SEED in {1..5} +do + # New window + tmux new-window -t serl_session -n $SEED + # Split the window horizontally + tmux split-window -v + tmux split-pane -h -t serl_session:$SEED.1 -# Navigate to the activate the conda environment in the first pane -tmux send-keys -t serl_session:0.0 "bash automated_tests.sh" C-m + # Find open ports for TrainerServer + OPEN_PORTS=$( comm -23 <(seq 49152 65535 | sort) <(ss -Htan | awk '{print $4}' | cut -d':' -f2 | sort -u) | shuf | head -n 2 ) + PORTS=( $OPEN_PORTS ) + # Navigate to the activate the conda environment in the first pane + tmux send-keys -t serl_session:$SEED.0 "bash 
automated_tests.sh $SEED ${PORTS[0]} ${PORTS[1]}" C-m + echo ${PORTS[0]} + echo ${PORTS[1]} +done # Attach to the tmux session tmux attach-session -t serl_session diff --git a/serl_launcher/serl_launcher/common/wandb.py b/serl_launcher/serl_launcher/common/wandb.py index d01ebf05..4b27cb9e 100644 --- a/serl_launcher/serl_launcher/common/wandb.py +++ b/serl_launcher/serl_launcher/common/wandb.py @@ -6,6 +6,7 @@ import absl.flags as flags import ml_collections import wandb +import uuid def _recursive_flatten_dict(d: dict): @@ -41,12 +42,13 @@ def __init__( variant, wandb_output_dir=None, debug=False, + offline=False, ): self.config = wandb_config if self.config.unique_identifier == "": self.config.unique_identifier = datetime.datetime.now().strftime( - "%m-%d-%Y-%H-%M-%S" - ) + "%m-%d-%Y" + ) + str(uuid.uuid1()) self.config.experiment_id = ( self.experiment_id @@ -65,7 +67,10 @@ def __init__( if debug: mode = "disabled" else: - mode = "online" + if offline: + mode = "offline" + else: + mode = "online" self.run = wandb.init( config=self._variant, diff --git a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py index a071a58e..e718165d 100644 --- a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py +++ b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py @@ -63,6 +63,8 @@ def __init__( self.generate_transform_deltas() def _handle_method_arg_(self, value, method_type, method, kwargs): + if hasattr(self, value): + return assert value in kwargs.keys(), f"\033[31mERROR: \033[0m{value} must be defined for {method_type} \"{method}\"" setattr(FractalSymmetryReplayBuffer, value, kwargs[value]) del kwargs[value] @@ -142,9 +144,8 @@ def _handle_methods_(self, kwargs): case _: raise ValueError("incorrect value passed to split_method") - if self.starting_branch_count: + if hasattr(self, "starting_branch_count"): self.current_branch_count = self.starting_branch_count 
- delattr(FractalSymmetryReplayBuffer, "starting_branch_count") def generate_transform_deltas(self): diff --git a/serl_launcher/serl_launcher/utils/launcher.py b/serl_launcher/serl_launcher/utils/launcher.py index 0a25490c..fbed0c0a 100644 --- a/serl_launcher/serl_launcher/utils/launcher.py +++ b/serl_launcher/serl_launcher/utils/launcher.py @@ -182,7 +182,9 @@ def make_wandb_logger( project: str = "agentlace", name: str = "placeholder_run_name", description: str = "serl_launcher", + wandb_output_dir: str = None, debug: bool = False, + offline: bool = False, ): wandb_config = WandBLogger.get_default_config() wandb_config.update( @@ -196,7 +198,9 @@ def make_wandb_logger( wandb_logger = WandBLogger( wandb_config=wandb_config, variant={}, + wandb_output_dir=wandb_output_dir, debug=debug, + offline=offline ) return wandb_logger From 6f7a7a4d799284aeb38984e7f0b82f77d0da7a22 Mon Sep 17 00:00:00 2001 From: ryanvanderstelt Date: Fri, 8 Aug 2025 15:11:59 -0500 Subject: [PATCH 104/141] testing --- .../async_sac_state_sim/automated_tests.sh | 86 ++++++++++--------- .../async_sac_state_sim/tmux_launch_tests.sh | 1 + 2 files changed, 47 insertions(+), 40 deletions(-) diff --git a/examples/async_sac_state_sim/automated_tests.sh b/examples/async_sac_state_sim/automated_tests.sh index af888550..895b2d02 100644 --- a/examples/async_sac_state_sim/automated_tests.sh +++ b/examples/async_sac_state_sim/automated_tests.sh @@ -11,7 +11,7 @@ MAX_STEPS=25000 TRAINING_STARTS=1000 RANDOM_STEPS=1000 CRITIC_ACTOR_RATIO=8 -EXP_NAME="TESTING-PARALLEL-$ENV" +EXP_NAME="PARALLEL-CONSTANT-$ENV" BASE_ARGS="--env $ENV --exp_name $EXP_NAME --wandb_output_dir $WANDB_OUTPUT_DIR --wandb_offline --training_starts $TRAINING_STARTS --random_steps $RANDOM_STEPS --critic_actor_ratio $CRITIC_ACTOR_RATIO --port_number $PORT_NUMBER --broadcast_port $BROADCAST_PORT --seed $SEED" @@ -19,32 +19,34 @@ echo "$BASE_ARGS" for batch_size in 256 do - for replay_buffer_capacity in 1000000 + for replay_buffer_capacity in 
1000000 4000000 do # Baseline ARGS="--run_name baseline-batch-size-$batch_size-capacity-$replay_buffer_capacity --replay_buffer_type replay_buffer --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity" - # tmux respawn-pane -k -t serl_session:$SEED.1 - # tmux respawn-pane -k -t serl_session:$SEED.2 - # tmux send-keys -t serl_session:$SEED.1 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --actor --max_steps $((MAX_STEPS * 2)) $BASE_ARGS $ARGS" C-m - # tmux send-keys -t serl_session:$SEED.2 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --learner --max_steps $MAX_STEPS $BASE_ARGS $ARGS" C-m "exit" C-m + echo "Running baseline with args: $ARGS" + tmux respawn-pane -k -t serl_session:$SEED.1 + tmux respawn-pane -k -t serl_session:$SEED.2 + tmux send-keys -t serl_session:$SEED.1 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --actor --max_steps $((MAX_STEPS * 2)) $BASE_ARGS $ARGS" C-m + tmux send-keys -t serl_session:$SEED.2 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --learner --max_steps $MAX_STEPS $BASE_ARGS $ARGS" C-m "exit" C-m - # while ! tmux capture-pane -t serl_session:$SEED.2 -p | grep "Pane is dead" > /dev/null; - # do - # sleep 1 - # echo "I'm waiting :)" - # done + while ! tmux capture-pane -t serl_session:$SEED.2 -p | grep "Pane is dead" > /dev/null; + do + sleep 1 + + done - echo "I'm DONE waiting >:(" + echo "Finished!" 
REPLAY_BUFFER_TYPE="fractal_symmetry_replay_buffer" - for workspace_width in 0.5 + for workspace_width in 0.5 1 do - for starting_branch_count in $(seq 3 2 81) + for starting_branch_count in 11 41 81 do # Constant ARGS="--run_name constant-$starting_branch_count^1-workspace_width-$workspace_width-batch-size-$batch_size-capacity-$replay_buffer_capacity --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --workspace_width $workspace_width --branch_method 'constant' --starting_branch_count $starting_branch_count" + echo "Running constant with args: $ARGS" tmux respawn-pane -k -t serl_session:$SEED.1 tmux respawn-pane -k -t serl_session:$SEED.2 tmux send-keys -t serl_session:$SEED.1 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --actor --max_steps $((MAX_STEPS * 2)) $BASE_ARGS $ARGS" C-m @@ -53,10 +55,10 @@ do while ! tmux capture-pane -t serl_session:$SEED.2 -p | grep "Pane is dead" > /dev/null; do sleep 1 - echo "I'm waiting :)" + done - echo "I'm DONE waiting >:(" + echo "Finished!" done for alpha in 0.9 @@ -71,6 +73,7 @@ do # Fractal Expansion ARGS="--run_name fractal_expansion-$branching_factor^$max_depth-alpha-$alpha-workspace_width-$workspace_width-batch-size-$batch_size-capacity-$replay_buffer_capacity --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --workspace_width $workspace_width --branch_method 'fractal' --alpha $alpha --branching_factor $branching_factor --max_depth $max_depth" + # echo "Running fractal expansion with args: $ARGS" # tmux respawn-pane -k -t serl_session:$SEED.1 # tmux respawn-pane -k -t serl_session:$SEED.2 # tmux send-keys -t serl_session:$SEED.1 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --actor --max_steps $((MAX_STEPS * 2)) $BASE_ARGS $ARGS" C-m @@ -78,13 +81,14 @@ do # while ! 
tmux capture-pane -t serl_session:$SEED.2 -p | grep "Pane is dead" > /dev/null; # do # sleep 1 - # echo "I'm waiting :)" + # # done - # echo "I'm DONE waiting >:(" + # echo "Finished!" # Fractal Contraction ARGS="--run_name fractal_contraction-$branching_factor^$max_depth-alpha-$alpha-workspace_width-$workspace_width-batch-size-$batch_size-capacity-$replay_buffer_capacity --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --workspace_width $workspace_width --branch_method 'contraction' --alpha $alpha --branching_factor $branching_factor --max_depth $max_depth" + # echo "Running fractal contraction with args: $ARGS" # tmux respawn-pane -k -t serl_session:$SEED.1 # tmux respawn-pane -k -t serl_session:$SEED.2 # tmux send-keys -t serl_session:$SEED.1 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --actor --max_steps $((MAX_STEPS * 2)) $BASE_ARGS $ARGS" C-m @@ -93,9 +97,9 @@ do # while ! tmux capture-pane -t serl_session:$SEED.2 -p | grep "Pane is dead" > /dev/null; # do # sleep 1 - # echo "I'm waiting :)" + # # done - # echo "I'm DONE waiting >:(" + # echo "Finished!" 
done done for min_branch_count in 1 3 9 @@ -108,32 +112,34 @@ do # Disassociative (Hourglass) ARGS="--run_name disassociative-hourglass-$min_branch_count:$max_branch_count-alpha-$alpha-workspace_width-$workspace_width-batch-size-$batch_size-capacity-$replay_buffer_capacity --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --workspace_width $workspace_width --branch_method 'disassociated' --min_branch_count $min_branch_count --max_branch_count $max_branch_count --disassociated_type 'hourglass' --alpha $alpha" - # tmux respawn-pane -k -t serl_session:$SEED.1 - # tmux respawn-pane -k -t serl_session:$SEED.2 - # tmux send-keys -t serl_session:$SEED.1 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --actor --max_steps $((MAX_STEPS * 2)) $BASE_ARGS $ARGS" C-m - # tmux send-keys -t serl_session:$SEED.2 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --learner --max_steps $MAX_STEPS $BASE_ARGS $ARGS" C-m "exit" C-m + # echo "Running hourglass with args: $ARGS" + # tmux respawn-pane -k -t serl_session:$SEED.1 + # tmux respawn-pane -k -t serl_session:$SEED.2 + # tmux send-keys -t serl_session:$SEED.1 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --actor --max_steps $((MAX_STEPS * 2)) $BASE_ARGS $ARGS" C-m + # tmux send-keys -t serl_session:$SEED.2 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --learner --max_steps $MAX_STEPS $BASE_ARGS $ARGS" C-m "exit" C-m - # while ! tmux capture-pane -t serl_session:$SEED.2 -p | grep "Pane is dead" > /dev/null; - # do - # sleep 1 - # echo "I'm waiting :)" - # done - # echo "I'm DONE waiting >:(" + # while ! tmux capture-pane -t serl_session:$SEED.2 -p | grep "Pane is dead" > /dev/null; + # do + # sleep 1 + + # done + # echo "Finished!" 
# Disassociative (Octahedron) ARGS="--run_name disassociative-hourglass-$min_branch_count:$max_branch_count-alpha-$alpha-workspace_width-$workspace_width-batch-size-$batch_size-capacity-$replay_buffer_capacity --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --workspace_width $workspace_width --branch_method 'disassociated' --min_branch_count $min_branch_count --max_branch_count $max_branch_count --disassociated_type 'octahedron' --alpha $alpha" - # tmux respawn-pane -k -t serl_session:$SEED.1 - # tmux respawn-pane -k -t serl_session:$SEED.2 - # tmux send-keys -t serl_session:$SEED.1 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --actor --max_steps $((MAX_STEPS * 2)) $BASE_ARGS $ARGS" C-m - # tmux send-keys -t serl_session:$SEED.2 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --learner --max_steps $MAX_STEPS $BASE_ARGS $ARGS" C-m "exit" C-m + # echo "Running octahedron with args: $ARGS" + # tmux respawn-pane -k -t serl_session:$SEED.1 + # tmux respawn-pane -k -t serl_session:$SEED.2 + # tmux send-keys -t serl_session:$SEED.1 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --actor --max_steps $((MAX_STEPS * 2)) $BASE_ARGS $ARGS" C-m + # tmux send-keys -t serl_session:$SEED.2 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --learner --max_steps $MAX_STEPS $BASE_ARGS $ARGS" C-m "exit" C-m - # while ! tmux capture-pane -t serl_session:$SEED.2 -p | grep "Pane is dead" > /dev/null; - # do - # sleep 1 - # echo "I'm waiting :)" - # done - # echo "I'm DONE waiting >:(" + # while ! tmux capture-pane -t serl_session:$SEED.2 -p | grep "Pane is dead" > /dev/null; + # do + # sleep 1 + + # done + # echo "Finished!" 
done done done diff --git a/examples/async_sac_state_sim/tmux_launch_tests.sh b/examples/async_sac_state_sim/tmux_launch_tests.sh index 7ea5db17..9db6815f 100644 --- a/examples/async_sac_state_sim/tmux_launch_tests.sh +++ b/examples/async_sac_state_sim/tmux_launch_tests.sh @@ -23,6 +23,7 @@ do done # Attach to the tmux session +tmux kill-window -t serl_sessoin:0 tmux attach-session -t serl_session # kill the tmux session by running the following command From ee086184738c864380a11b81a791961eab49dd36 Mon Sep 17 00:00:00 2001 From: ryanvanderstelt Date: Tue, 12 Aug 2025 09:47:14 -0500 Subject: [PATCH 105/141] improved test files --- .../async_sac_state_sim/automated_tests.sh | 246 ++++++++---------- .../async_sac_state_sim/tmux_launch_tests.sh | 8 +- .../data/fractal_symmetry_replay_buffer.py | 5 +- 3 files changed, 112 insertions(+), 147 deletions(-) diff --git a/examples/async_sac_state_sim/automated_tests.sh b/examples/async_sac_state_sim/automated_tests.sh index 895b2d02..577b13a5 100644 --- a/examples/async_sac_state_sim/automated_tests.sh +++ b/examples/async_sac_state_sim/automated_tests.sh @@ -1,8 +1,6 @@ #!/bin/bash SEED=$1 -PORT_NUMBER=$2 -BROADCAST_PORT=$3 WANDB_OUTPUT_DIR=~/wandb_logs TEST="async_sac_state_sim.py" CONDA_ENV="serl" @@ -11,140 +9,114 @@ MAX_STEPS=25000 TRAINING_STARTS=1000 RANDOM_STEPS=1000 CRITIC_ACTOR_RATIO=8 -EXP_NAME="PARALLEL-CONSTANT-$ENV" - -BASE_ARGS="--env $ENV --exp_name $EXP_NAME --wandb_output_dir $WANDB_OUTPUT_DIR --wandb_offline --training_starts $TRAINING_STARTS --random_steps $RANDOM_STEPS --critic_actor_ratio $CRITIC_ACTOR_RATIO --port_number $PORT_NUMBER --broadcast_port $BROADCAST_PORT --seed $SEED" - -echo "$BASE_ARGS" - -for batch_size in 256 -do - for replay_buffer_capacity in 1000000 4000000 - do - # Baseline - ARGS="--run_name baseline-batch-size-$batch_size-capacity-$replay_buffer_capacity --replay_buffer_type replay_buffer --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity" - - echo "Running 
baseline with args: $ARGS" - tmux respawn-pane -k -t serl_session:$SEED.1 - tmux respawn-pane -k -t serl_session:$SEED.2 - tmux send-keys -t serl_session:$SEED.1 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --actor --max_steps $((MAX_STEPS * 2)) $BASE_ARGS $ARGS" C-m - tmux send-keys -t serl_session:$SEED.2 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --learner --max_steps $MAX_STEPS $BASE_ARGS $ARGS" C-m "exit" C-m - - while ! tmux capture-pane -t serl_session:$SEED.2 -p | grep "Pane is dead" > /dev/null; - do - sleep 1 - - done - - echo "Finished!" - - REPLAY_BUFFER_TYPE="fractal_symmetry_replay_buffer" - for workspace_width in 0.5 1 - do - for starting_branch_count in 11 41 81 - do - # Constant - ARGS="--run_name constant-$starting_branch_count^1-workspace_width-$workspace_width-batch-size-$batch_size-capacity-$replay_buffer_capacity --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --workspace_width $workspace_width --branch_method 'constant' --starting_branch_count $starting_branch_count" - - echo "Running constant with args: $ARGS" - tmux respawn-pane -k -t serl_session:$SEED.1 - tmux respawn-pane -k -t serl_session:$SEED.2 - tmux send-keys -t serl_session:$SEED.1 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --actor --max_steps $((MAX_STEPS * 2)) $BASE_ARGS $ARGS" C-m - tmux send-keys -t serl_session:$SEED.2 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --learner --max_steps $MAX_STEPS $BASE_ARGS $ARGS" C-m "exit" C-m - - while ! tmux capture-pane -t serl_session:$SEED.2 -p | grep "Pane is dead" > /dev/null; - do - sleep 1 - - done - - echo "Finished!" 
- done - - for alpha in 0.9 - do - for branching_factor in 3 9 - do - for max_depth in 2 4 - do - if [[ $(( max_depth * branching_factor )) -eq 36 ]]; then - continue - fi - # Fractal Expansion - ARGS="--run_name fractal_expansion-$branching_factor^$max_depth-alpha-$alpha-workspace_width-$workspace_width-batch-size-$batch_size-capacity-$replay_buffer_capacity --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --workspace_width $workspace_width --branch_method 'fractal' --alpha $alpha --branching_factor $branching_factor --max_depth $max_depth" - - # echo "Running fractal expansion with args: $ARGS" - # tmux respawn-pane -k -t serl_session:$SEED.1 - # tmux respawn-pane -k -t serl_session:$SEED.2 - # tmux send-keys -t serl_session:$SEED.1 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --actor --max_steps $((MAX_STEPS * 2)) $BASE_ARGS $ARGS" C-m - # tmux send-keys -t serl_session:$SEED.2 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --learner --max_steps $MAX_STEPS $BASE_ARGS $ARGS" C-m "exit" C-m - # while ! tmux capture-pane -t serl_session:$SEED.2 -p | grep "Pane is dead" > /dev/null; - # do - # sleep 1 - # - # done - # echo "Finished!" 
- - # Fractal Contraction - ARGS="--run_name fractal_contraction-$branching_factor^$max_depth-alpha-$alpha-workspace_width-$workspace_width-batch-size-$batch_size-capacity-$replay_buffer_capacity --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --workspace_width $workspace_width --branch_method 'contraction' --alpha $alpha --branching_factor $branching_factor --max_depth $max_depth" - - # echo "Running fractal contraction with args: $ARGS" - # tmux respawn-pane -k -t serl_session:$SEED.1 - # tmux respawn-pane -k -t serl_session:$SEED.2 - # tmux send-keys -t serl_session:$SEED.1 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --actor --max_steps $((MAX_STEPS * 2)) $BASE_ARGS $ARGS" C-m - # tmux send-keys -t serl_session:$SEED.2 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --learner --max_steps $MAX_STEPS $BASE_ARGS $ARGS" C-m "exit" C-m - - # while ! tmux capture-pane -t serl_session:$SEED.2 -p | grep "Pane is dead" > /dev/null; - # do - # sleep 1 - # - # done - # echo "Finished!" 
- done - done - for min_branch_count in 1 3 9 - do - for max_branch_count in 3 9 27 - do - if [ $min_branch_count -ge $max_branch_count ]; then - continue - fi - # Disassociative (Hourglass) - ARGS="--run_name disassociative-hourglass-$min_branch_count:$max_branch_count-alpha-$alpha-workspace_width-$workspace_width-batch-size-$batch_size-capacity-$replay_buffer_capacity --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --workspace_width $workspace_width --branch_method 'disassociated' --min_branch_count $min_branch_count --max_branch_count $max_branch_count --disassociated_type 'hourglass' --alpha $alpha" - - # echo "Running hourglass with args: $ARGS" - # tmux respawn-pane -k -t serl_session:$SEED.1 - # tmux respawn-pane -k -t serl_session:$SEED.2 - # tmux send-keys -t serl_session:$SEED.1 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --actor --max_steps $((MAX_STEPS * 2)) $BASE_ARGS $ARGS" C-m - # tmux send-keys -t serl_session:$SEED.2 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --learner --max_steps $MAX_STEPS $BASE_ARGS $ARGS" C-m "exit" C-m - - # while ! tmux capture-pane -t serl_session:$SEED.2 -p | grep "Pane is dead" > /dev/null; - # do - # sleep 1 - - # done - # echo "Finished!" 
- - # Disassociative (Octahedron) - ARGS="--run_name disassociative-hourglass-$min_branch_count:$max_branch_count-alpha-$alpha-workspace_width-$workspace_width-batch-size-$batch_size-capacity-$replay_buffer_capacity --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --workspace_width $workspace_width --branch_method 'disassociated' --min_branch_count $min_branch_count --max_branch_count $max_branch_count --disassociated_type 'octahedron' --alpha $alpha" - - # echo "Running octahedron with args: $ARGS" - # tmux respawn-pane -k -t serl_session:$SEED.1 - # tmux respawn-pane -k -t serl_session:$SEED.2 - # tmux send-keys -t serl_session:$SEED.1 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --actor --max_steps $((MAX_STEPS * 2)) $BASE_ARGS $ARGS" C-m - # tmux send-keys -t serl_session:$SEED.2 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --learner --max_steps $MAX_STEPS $BASE_ARGS $ARGS" C-m "exit" C-m - - # while ! tmux capture-pane -t serl_session:$SEED.2 -p | grep "Pane is dead" > /dev/null; - # do - # sleep 1 - - # done - # echo "Finished!" 
- done - done - done - done +EXP_NAME="SCALING-REPLAY-BUFFER-CAPACITY-WITH-BRANCHES-$ENV" +REPLAY_BUFFER_TYPE="fractal_symmetry_replay_buffer" + +BASE_ARGS="--env $ENV --exp_name $EXP_NAME --wandb_output_dir $WANDB_OUTPUT_DIR --wandb_offline --training_starts $TRAINING_STARTS --random_steps $RANDOM_STEPS --critic_actor_ratio $CRITIC_ACTOR_RATIO --seed $SEED" +ARGS="" + +function run_test { + + OPEN_PORTS=$( comm -23 <(seq 49152 65535 | sort) <(ss -Htan | awk '{print $4}' | cut -d':' -f2 | sort -u) | shuf | head -n 2 ) + PORTS=( $OPEN_PORTS ) + PORT_NUMBER=${PORTS[0]} + BROADCAST_PORT=${PORTS[1]} + + ARGS+=" --port_number $PORT_NUMBER --broadcast_port $BROADCAST_PORT" + + echo "Running constant with args: $ARGS" + tmux respawn-pane -k -t serl_session:$SEED.1 + tmux respawn-pane -k -t serl_session:$SEED.2 + tmux send-keys -t serl_session:$SEED.1 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --actor --max_steps $((MAX_STEPS * 2)) $BASE_ARGS $ARGS" C-m + tmux send-keys -t serl_session:$SEED.2 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --learner --max_steps $MAX_STEPS $BASE_ARGS $ARGS" C-m "exit" C-m + + # Wait for actor or learner to finish + while ! tmux capture-pane -t serl_session:$SEED.2 -p | grep "Pane is dead" > /dev/null || ! tmux capture-pane -t serl_session:$SEED.1 -p | grep "Pane is dead" > /dev/null ! tmux capture-pane -t serl_session:$SEED.2 -p | grep "logout" > /dev/null || ! tmux capture-pane -t serl_session:$SEED.1 -p | grep "logout" > /dev/null; + do + sleep 1 done -done + echo "Finished!" 
+ +} + +# # BASELINE TESTING +# for batch_size in 256 +# do +# for replay_buffer_capacity in 1000 2000 1000000 +# do +# ARGS="--run_name baseline-batch-size-$batch_size-capacity-$replay_buffer_capacity --replay_buffer_type replay_buffer --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity" +# run_test +# done +# done + +# # CONSTANT TESTING +# for starting_branch_count in 1 2 +# do +# for batch_size in 256 +# do +# for workspace_width in 0.5 +# do +# for replay_buffer_capacity in $(( starting_branch_count * starting_branch_count * TRAINING_STARTS )) $(( starting_branch_count * starting_branch_count * TRAINING_STARTS * 2 )) $(( starting_branch_count * starting_branch_count * TRAINING_STARTS * 1000 )) +# do +# ARGS="--run_name constant-$starting_branch_count^1-workspace_width-$workspace_width-batch-size-$batch_size-capacity-$replay_buffer_capacity --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --workspace_width $workspace_width --branch_method 'constant' --starting_branch_count $starting_branch_count" +# run_test +# done +# done +# done +# done + +# # FRACTAL TESTING +# for batch_size in 256 +# do +# for replay_buffer_capacity in 1000000 +# do +# for workspace_width in 0.5 +# do +# for alpha in 0.9 +# do +# for branching_factor in 3 9 +# do +# for max_depth in 2 4 +# do +# # Fractal Expansion +# ARGS="--run_name fractal_expansion-$branching_factor^$max_depth-alpha-$alpha-workspace_width-$workspace_width-batch-size-$batch_size-capacity-$replay_buffer_capacity --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --workspace_width $workspace_width --branch_method 'fractal' --alpha $alpha --branching_factor $branching_factor --max_depth $max_depth" +# run_test + +# # Fractal Contraction +# ARGS="--run_name 
fractal_contraction-$branching_factor^$max_depth-alpha-$alpha-workspace_width-$workspace_width-batch-size-$batch_size-capacity-$replay_buffer_capacity --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --workspace_width $workspace_width --branch_method 'contraction' --alpha $alpha --branching_factor $branching_factor --max_depth $max_depth" +# run_test +# done +# done +# done +# done +# done +# done + +# # DISASSOCIATIVE TESTING +# for batch_size in 256 +# do +# for replay_buffer_capacity in 1000000 +# do +# for workspace_width in 0.5 +# do +# for alpha in 0.9 +# do +# for min_branch_count in 1 3 9 +# do +# for max_branch_count in 3 9 27 +# do +# # Disassociative (Hourglass) +# ARGS="--run_name disassociative-hourglass-$min_branch_count:$max_branch_count-alpha-$alpha-workspace_width-$workspace_width-batch-size-$batch_size-capacity-$replay_buffer_capacity --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --workspace_width $workspace_width --branch_method 'disassociated' --min_branch_count $min_branch_count --max_branch_count $max_branch_count --disassociated_type 'hourglass' --alpha $alpha" +# run_test + +# # Disassociative (Octahedron) +# ARGS="--run_name disassociative-hourglass-$min_branch_count:$max_branch_count-alpha-$alpha-workspace_width-$workspace_width-batch-size-$batch_size-capacity-$replay_buffer_capacity --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --workspace_width $workspace_width --branch_method 'disassociated' --min_branch_count $min_branch_count --max_branch_count $max_branch_count --disassociated_type 'octahedron' --alpha $alpha" + +# done +# done +# done +# done +# done +# done -tmux kill-window -t serl_session:$SEED +tmux kill-window -t serl_session:$SEED \ No newline at end of file diff --git a/examples/async_sac_state_sim/tmux_launch_tests.sh 
b/examples/async_sac_state_sim/tmux_launch_tests.sh index 9db6815f..dbd2fc31 100644 --- a/examples/async_sac_state_sim/tmux_launch_tests.sh +++ b/examples/async_sac_state_sim/tmux_launch_tests.sh @@ -13,17 +13,11 @@ do tmux split-window -v tmux split-pane -h -t serl_session:$SEED.1 - # Find open ports for TrainerServer - OPEN_PORTS=$( comm -23 <(seq 49152 65535 | sort) <(ss -Htan | awk '{print $4}' | cut -d':' -f2 | sort -u) | shuf | head -n 2 ) - PORTS=( $OPEN_PORTS ) # Navigate to the activate the conda environment in the first pane - tmux send-keys -t serl_session:$SEED.0 "bash automated_tests.sh $SEED ${PORTS[0]} ${PORTS[1]}" C-m - echo ${PORTS[0]} - echo ${PORTS[1]} + tmux send-keys -t serl_session:$SEED.0 "bash automated_tests.sh $SEED" C-m done # Attach to the tmux session -tmux kill-window -t serl_sessoin:0 tmux attach-session -t serl_session # kill the tmux session by running the following command diff --git a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py index e718165d..770713aa 100644 --- a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py +++ b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py @@ -73,7 +73,6 @@ def _handle_methods_(self, kwargs): # Initialize branch_method match self.branch_method: - case "fractal": self._handle_method_arg_("max_depth", "branch_method", self.branch_method, kwargs) self._handle_method_arg_("branching_factor", "branch_method", self.branch_method, kwargs) @@ -247,9 +246,9 @@ def constant_split(self, data_dict: DatasetDict): def never_split(self, data_dict: DatasetDict): return False - def insert(self, data_dict_not: DatasetDict): + def insert(self, data: DatasetDict): - data_dict = copy.deepcopy(data_dict_not) + data_dict = copy.deepcopy(data) # Update number of branches if needed if self.split(data_dict): From d88b11f0355227cf7f8322c87e890bb9919f9941 Mon Sep 17 00:00:00 2001 From: ryanvanderstelt 
Date: Wed, 13 Aug 2025 13:25:11 -0500 Subject: [PATCH 106/141] Created sparse env for reach --- franka_sim/franka_sim/__init__.py | 5 + franka_sim/franka_sim/envs/__init__.py | 1 + .../franka_sim/envs/panda_pick_gym_env.py | 4 +- .../envs/panda_reach_sparse_gym_env.py | 281 ++++++++++++++++++ 4 files changed, 289 insertions(+), 2 deletions(-) create mode 100644 franka_sim/franka_sim/envs/panda_reach_sparse_gym_env.py diff --git a/franka_sim/franka_sim/__init__.py b/franka_sim/franka_sim/__init__.py index dab351ab..ef353f2a 100644 --- a/franka_sim/franka_sim/__init__.py +++ b/franka_sim/franka_sim/__init__.py @@ -23,6 +23,11 @@ entry_point="franka_sim.envs:PandaReachCubeGymEnv", max_episode_steps=100, ) +register( + id="PandaReachSparseCube-v0", + entry_point="franka_sim.envs:PandaReachSparseCubeGymEnv", + max_episode_steps=100, +) # register( # id="PandaReachCubeVision-v0", # entry_point="franka_sim.envs:PandaReachCubeGymEnv", diff --git a/franka_sim/franka_sim/envs/__init__.py b/franka_sim/franka_sim/envs/__init__.py index e3c90e39..a5e8c7c5 100644 --- a/franka_sim/franka_sim/envs/__init__.py +++ b/franka_sim/franka_sim/envs/__init__.py @@ -4,4 +4,5 @@ __all__ = [ "PandaPickCubeGymEnv", "PandaReachCubeGymEnv", + "PandaReachSparseCubeGymEnv" ] diff --git a/franka_sim/franka_sim/envs/panda_pick_gym_env.py b/franka_sim/franka_sim/envs/panda_pick_gym_env.py index c3d3324e..dcf4783f 100644 --- a/franka_sim/franka_sim/envs/panda_pick_gym_env.py +++ b/franka_sim/franka_sim/envs/panda_pick_gym_env.py @@ -144,8 +144,8 @@ def __init__( self._viewer = MujocoRenderer( self.model, self.data, - width=960, - height=960, + width=128, + height=128, camera_id=0 ) self._viewer.render(self.render_mode) diff --git a/franka_sim/franka_sim/envs/panda_reach_sparse_gym_env.py b/franka_sim/franka_sim/envs/panda_reach_sparse_gym_env.py new file mode 100644 index 00000000..560d8bd7 --- /dev/null +++ b/franka_sim/franka_sim/envs/panda_reach_sparse_gym_env.py @@ -0,0 +1,281 @@ +from pathlib 
import Path +from typing import Any, Literal, Tuple, Dict + +import gym +import mujoco +import numpy as np +from gym import spaces + +try: + import mujoco_py +except ImportError as e: + MUJOCO_PY_IMPORT_ERROR = e +else: + MUJOCO_PY_IMPORT_ERROR = None + +from franka_sim.controllers import opspace +from franka_sim.mujoco_gym_env import GymRenderingSpec, MujocoGymEnv +from gym.envs.registration import register + +_HERE = Path(__file__).parent +_XML_PATH = _HERE / "xmls" / "arena.xml" +_PANDA_HOME = np.asarray((0, -0.785, 0, -2.35, 0, 1.57, np.pi / 4)) +_CARTESIAN_BOUNDS = np.asarray([[0.2, -0.3, 0], [0.6, 0.3, 0.5]]) +_SAMPLING_BOUNDS = np.asarray([[0.25, -0.25], [0.55, 0.25]]) + + +class PandaReachSparseCubeGymEnv(MujocoGymEnv): + metadata = {"render_modes": ["rgb_array", "human"]} + + def __init__( + self, + action_scale: np.ndarray = np.asarray([0.1, 1]), + seed: int = 0, + control_dt: float = 0.02, + physics_dt: float = 0.002, + time_limit: float = 10.0, + render_spec: GymRenderingSpec = GymRenderingSpec(), + render_mode: Literal["rgb_array", "human"] = "rgb_array", + image_obs: bool = False, + ): + self._action_scale = action_scale + + super().__init__( + xml_path=_XML_PATH, + seed=seed, + control_dt=control_dt, + physics_dt=physics_dt, + time_limit=time_limit, + render_spec=render_spec, + ) + + self.metadata = { + "render_modes": [ + "human", + "rgb_array", + ], + "render_fps": int(np.round(1.0 / self.control_dt)), + } + + self.render_mode = render_mode + self.camera_id = (0, 1) + self.image_obs = image_obs + + # Caching. 
+ self._panda_dof_ids = np.asarray( + [self._model.joint(f"joint{i}").id for i in range(1, 8)] + ) + self._panda_ctrl_ids = np.asarray( + [self._model.actuator(f"actuator{i}").id for i in range(1, 8)] + ) + self._gripper_ctrl_id = self._model.actuator("fingers_actuator").id + self._pinch_site_id = self._model.site("pinch").id + self._block_z = self._model.geom("block").size[2] + + self.observation_space = gym.spaces.Dict( + { + "state": gym.spaces.Dict( + { + "panda/tcp_pos": spaces.Box( + -np.inf, np.inf, shape=(3,), dtype=np.float32 + ), + "panda/tcp_vel": spaces.Box( + -np.inf, np.inf, shape=(3,), dtype=np.float32 + ), + "panda/gripper_pos": spaces.Box( + -np.inf, np.inf, shape=(1,), dtype=np.float32 + ), + "block_pos": spaces.Box( + -np.inf, np.inf, shape=(3,), dtype=np.float32 + ), + } + ), + } + ) + + if self.image_obs: + self.observation_space = gym.spaces.Dict( + { + "state": gym.spaces.Dict( + { + "panda/tcp_pos": spaces.Box( + -np.inf, np.inf, shape=(3,), dtype=np.float32 + ), + "panda/tcp_vel": spaces.Box( + -np.inf, np.inf, shape=(3,), dtype=np.float32 + ), + "panda/gripper_pos": spaces.Box( + -np.inf, np.inf, shape=(1,), dtype=np.float32 + ), + } + ), + "images": gym.spaces.Dict( + { + "front": gym.spaces.Box( + low=0, + high=255, + shape=(render_spec.height, render_spec.width, 3), + dtype=np.uint8, + ), + "wrist": gym.spaces.Box( + low=0, + high=255, + shape=(render_spec.height, render_spec.width, 3), + dtype=np.uint8, + ), + } + ), + } + ) + + self.action_space = gym.spaces.Box( + low=np.asarray([-1.0, -1.0, -1.0]), + high=np.asarray([1.0, 1.0, 1.0]), + dtype=np.float32, + ) + + # NOTE: gymnasium is used here since MujocoRenderer is not available in gym. 
It + # is possible to add a similar viewer feature with gym, but that can be a future TODO + from gymnasium.envs.mujoco.mujoco_rendering import MujocoRenderer + + self._viewer = MujocoRenderer( + self.model, + self.data, + width=960, + height=960, + camera_id=0 + ) + self._viewer.render(self.render_mode) + + def reset( + self, seed=None, **kwargs + ) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]: + """Reset the environment.""" + mujoco.mj_resetData(self._model, self._data) + + # Reset arm to home position. + self._data.qpos[self._panda_dof_ids] = _PANDA_HOME + mujoco.mj_forward(self._model, self._data) + + # Reset mocap body to home position. + tcp_pos = self._data.sensor("2f85/pinch_pos").data + self._data.mocap_pos[0] = tcp_pos + + # Sample a new block position. + block_xy = np.random.uniform(*_SAMPLING_BOUNDS) + self._data.jnt("block").qpos[:3] = (*block_xy, self._block_z) + mujoco.mj_forward(self._model, self._data) + + # Cache the initial block height. + # self._z_init = self._data.sensor("block_pos").data[2] + # self._z_success = self._z_init + 0.2 + + obs = self._compute_observation() + return obs, {} + + def step( + self, action: np.ndarray + ) -> Tuple[Dict[str, np.ndarray], float, bool, bool, Dict[str, Any]]: + """ + take a step in the environment. + Params: + action: np.ndarray + + Returns: + observation: dict[str, np.ndarray], + reward: float, + done: bool, + truncated: bool, + info: dict[str, Any] + """ + x, y, z = action + + # Set the mocap position. + pos = self._data.mocap_pos[0].copy() + dpos = np.asarray([x, y, z]) * self._action_scale[0] + npos = np.clip(pos + dpos, *_CARTESIAN_BOUNDS) + self._data.mocap_pos[0] = npos + + # Set gripper grasp. 
+ self._data.ctrl[self._gripper_ctrl_id] = 0 # Fully open position + + for _ in range(self._n_substeps): + tau = opspace( + model=self._model, + data=self._data, + site_id=self._pinch_site_id, + dof_ids=self._panda_dof_ids, + pos=self._data.mocap_pos[0], + ori=self._data.mocap_quat[0], + joint=_PANDA_HOME, + gravity_comp=True, + ) + self._data.ctrl[self._panda_ctrl_ids] = tau + mujoco.mj_step(self._model, self._data) + + obs = self._compute_observation() + rew = self._compute_reward() + terminated = self.time_limit_exceeded() + + return obs, rew, terminated, False, {} + + def render(self): + rendered_frames = [] + for cam_id in self.camera_id: + rendered_frames.append( + self._viewer.render(render_mode="rgb_array") + ) + return rendered_frames + + # Helper methods. + + def _compute_observation(self) -> dict: + obs = {} + obs["state"] = {} + + tcp_pos = self._data.sensor("2f85/pinch_pos").data + obs["state"]["panda/tcp_pos"] = tcp_pos.astype(np.float32) + + tcp_vel = self._data.sensor("2f85/pinch_vel").data + obs["state"]["panda/tcp_vel"] = tcp_vel.astype(np.float32) + + gripper_pos = np.array( + self._data.ctrl[self._gripper_ctrl_id] / 255, dtype=np.float32 + ) + obs["state"]["panda/gripper_pos"] = gripper_pos + + if self.image_obs: + obs["images"] = {} + obs["images"]["front"], obs["images"]["wrist"] = self.render() + else: + block_pos = self._data.sensor("block_pos").data.astype(np.float32) + obs["state"]["block_pos"] = block_pos + + if self.render_mode == "human": + self._viewer.render(self.render_mode) + + return obs + + def _compute_reward(self) -> float: + # Get positions + block_pos = self._data.sensor("block_pos").data + tcp_pos = self._data.sensor("2f85/pinch_pos").data + + # Calculate distance + dist = np.linalg.norm(block_pos - tcp_pos) + + # Distance-based reward + r_close = np.exp(-20 * dist) + r_close = np.clip(r_close, 0.0, 1.0) + + return (r_close > 0.95) + + +if __name__ == "__main__": + # Create wrapped environment + env = 
PandaReachSparseCubeGymEnv(render_mode="human") + env.reset() + for i in range(5000): + env.step(np.random.uniform(-1, 1, 3)) + env.render() + env.close() From a04447f6921c96eb225fab2c0ca3c20f974ff336 Mon Sep 17 00:00:00 2001 From: ryanvanderstelt Date: Wed, 13 Aug 2025 19:14:36 -0500 Subject: [PATCH 107/141] Added optional images to fractal_symmetry_replay_buffer --- .../async_sac_state_sim/automated_tests.sh | 46 ++++---- .../franka_sim/envs/panda_pick_gym_env.py | 4 +- .../serl_launcher/data/data_store.py | 4 +- .../data/fractal_symmetry_replay_buffer.py | 106 +++++++++++++----- serl_launcher/serl_launcher/data/fsrb_test.py | 19 ++-- serl_launcher/serl_launcher/utils/launcher.py | 2 +- 6 files changed, 123 insertions(+), 58 deletions(-) diff --git a/examples/async_sac_state_sim/automated_tests.sh b/examples/async_sac_state_sim/automated_tests.sh index 577b13a5..eb22ffb8 100644 --- a/examples/async_sac_state_sim/automated_tests.sh +++ b/examples/async_sac_state_sim/automated_tests.sh @@ -12,7 +12,7 @@ CRITIC_ACTOR_RATIO=8 EXP_NAME="SCALING-REPLAY-BUFFER-CAPACITY-WITH-BRANCHES-$ENV" REPLAY_BUFFER_TYPE="fractal_symmetry_replay_buffer" -BASE_ARGS="--env $ENV --exp_name $EXP_NAME --wandb_output_dir $WANDB_OUTPUT_DIR --wandb_offline --training_starts $TRAINING_STARTS --random_steps $RANDOM_STEPS --critic_actor_ratio $CRITIC_ACTOR_RATIO --seed $SEED" +BASE_ARGS="--env $ENV --exp_name $EXP_NAME --wandb_output_dir $WANDB_OUTPUT_DIR --wandb_offline --training_starts $TRAINING_STARTS --random_steps $RANDOM_STEPS --seed $SEED" ARGS="" function run_test { @@ -39,32 +39,38 @@ function run_test { } -# # BASELINE TESTING -# for batch_size in 256 -# do -# for replay_buffer_capacity in 1000 2000 1000000 -# do -# ARGS="--run_name baseline-batch-size-$batch_size-capacity-$replay_buffer_capacity --replay_buffer_type replay_buffer --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity" -# run_test -# done -# done - -# # CONSTANT TESTING -# for starting_branch_count 
in 1 2 +# BASELINE TESTING +# for CRITIC_ACTOR_RATIO in 8 32 # do -# for batch_size in 256 +# for batch_size in 256 2048 # do -# for workspace_width in 0.5 +# for replay_buffer_capacity in 1000 1000000 # do -# for replay_buffer_capacity in $(( starting_branch_count * starting_branch_count * TRAINING_STARTS )) $(( starting_branch_count * starting_branch_count * TRAINING_STARTS * 2 )) $(( starting_branch_count * starting_branch_count * TRAINING_STARTS * 1000 )) -# do -# ARGS="--run_name constant-$starting_branch_count^1-workspace_width-$workspace_width-batch-size-$batch_size-capacity-$replay_buffer_capacity --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --workspace_width $workspace_width --branch_method 'constant' --starting_branch_count $starting_branch_count" -# run_test -# done +# ARGS="--run_name baseline --critic_actor_ratio $CRITIC_ACTOR_RATIO --replay_buffer_type replay_buffer --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity" +# run_test # done # done # done +# CONSTANT TESTING +for CRITIC_ACTOR_RATIO in 8 32 +do + for starting_branch_count in 3 9 27 81 243 + do + for batch_size in 256 2048 + do + for workspace_width in 0.5 + do + for replay_buffer_capacity in $(( starting_branch_count * starting_branch_count * TRAINING_STARTS )) $(( starting_branch_count * starting_branch_count * TRAINING_STARTS * 1000 )) + do + ARGS="--run_name constant-$starting_branch_count^1 --critic_actor_ratio $CRITIC_ACTOR_RATIO --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --workspace_width $workspace_width --branch_method 'constant' --starting_branch_count $starting_branch_count" + run_test + done + done + done + done +done + # # FRACTAL TESTING # for batch_size in 256 # do diff --git a/franka_sim/franka_sim/envs/panda_pick_gym_env.py b/franka_sim/franka_sim/envs/panda_pick_gym_env.py index c3d3324e..dcf4783f 100644 --- 
a/franka_sim/franka_sim/envs/panda_pick_gym_env.py +++ b/franka_sim/franka_sim/envs/panda_pick_gym_env.py @@ -144,8 +144,8 @@ def __init__( self._viewer = MujocoRenderer( self.model, self.data, - width=960, - height=960, + width=128, + height=128, camera_id=0 ) self._viewer.render(self.render_mode) diff --git a/serl_launcher/serl_launcher/data/data_store.py b/serl_launcher/serl_launcher/data/data_store.py index b22de984..a6f074ee 100644 --- a/serl_launcher/serl_launcher/data/data_store.py +++ b/serl_launcher/serl_launcher/data/data_store.py @@ -155,11 +155,11 @@ def __init__( y_obs_idx, branch_method: str, split_method: str, - workspace_width_method: str, rlds_logger: Optional[RLDSLogger] = None, + image_keys: Iterable[str] = None, **kwargs: dict, ): - FractalSymmetryReplayBuffer.__init__(self, observation_space, action_space, capacity, workspace_width, x_obs_idx, y_obs_idx, branch_method, split_method, workspace_width_method,**kwargs) + FractalSymmetryReplayBuffer.__init__(self, observation_space, action_space, capacity, workspace_width, x_obs_idx, y_obs_idx, branch_method, split_method, image_keys, **kwargs) DataStoreBase.__init__(self, capacity) self._lock = Lock() self._logger = None diff --git a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py index 770713aa..4c990298 100644 --- a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py +++ b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py @@ -17,8 +17,8 @@ def __init__( y_obs_idx : np.ndarray, branch_method: str, split_method: str, - workspace_width_method: str, - kwargs: dict + img_keys: list, + kwargs: dict, ): # Initialize values @@ -26,6 +26,8 @@ def __init__( self.current_branch_count = 1 self.update_max_traj_length = False self.workspace_width = workspace_width + self.img_keys = img_keys + self._img_insert_index_ = 0 # Set the idx value (changes depending on environment/wrapper) of the x 
and y observations and next_observations self.x_obs_idx = x_obs_idx @@ -37,21 +39,22 @@ def __init__( self.split_method = split_method self.branch_method = branch_method - self.workspace_width_method = workspace_width_method self._handle_methods_(kwargs) - ## Warn about unused kwargs + # Warn about unused kwargs for k in kwargs.keys(): print(f"\033[33mWARNING \033[0m argument \"{k}\" not used") - # TODO - # Create more methods with tests - # - # Future considerations: - # when dealing with x and y for transforms, consider two versions: - # radial trees at different angles can have strong fractal symmetry built in but will be more difficult to evenly space lowest level transforms - # grid-based will inherently evenly space lowest transforms, but fractal symmetry will be restricted to 9 branching_factor (original plus 8) in order to conform + # Account for images + img_buffer_size = (capacity + self.expected_branches - 1) // self.expected_branches + 1 + self.img_buffer = {} + for k in img_keys: + temp = np.empty(shape=observation_space.spaces[k].shape, dtype=observation_space.spaces[k].dtype) + temp = temp[np.newaxis, :] + self.img_buffer[k] = np.repeat(temp, img_buffer_size, axis=0) + + observation_space.spaces[k] = gym.spaces.Box(low=float('-inf'), high=float('inf'), shape=(1,), dtype=np.int32) # Init replay buffer class super().__init__( @@ -60,6 +63,10 @@ def __init__( capacity=capacity, ) + for k in img_keys: + self.dataset_dict["observations"][k] = self.dataset_dict["observations"][k].flatten() + self.dataset_dict["next_observations"][k] = self.dataset_dict["next_observations"][k].flatten() + self.generate_transform_deltas() def _handle_method_arg_(self, value, method_type, method, kwargs): @@ -80,6 +87,7 @@ def _handle_methods_(self, kwargs): self.branch = self.fractal_branch if not self.split_method: self.split_method = "time" + self.expected_branches = self.branching_factor ** self.max_depth case "contraction": self._handle_method_arg_("max_depth", 
"branch_method", self.branch_method, kwargs) @@ -88,6 +96,7 @@ def _handle_methods_(self, kwargs): self.branch = self.fractal_contraction if not self.split_method: self.split_method = "time" + self.expected_branches = self.branching_factor ** self.max_depth case "linear": raise NotImplementedError("linear branch method is not yet implemented") @@ -114,6 +123,7 @@ def _handle_methods_(self, kwargs): self.branch = self.disassociated_branch if not self.split_method: self.split_method = "time" + self.expected_branches = self.max_branch_count case "constant": self._handle_method_arg_("starting_branch_count", "branch_method", self.branch_method, kwargs) @@ -121,6 +131,7 @@ def _handle_methods_(self, kwargs): self.branch = self.constant_branch if not self.split_method: self.split_method = "never" + self.expected_branches = self.starting_branch_count case _: raise ValueError("incorrect value passed to branch_method") @@ -148,7 +159,11 @@ def _handle_methods_(self, kwargs): def generate_transform_deltas(self): - obs_size = self.dataset_dict["observations"].shape[1] + obs_state = self.dataset_dict["observations"] + if self.img_keys: + obs_state = self.dataset_dict["observations"]["state"] + + obs_size = len(obs_state[0]) total_branches = self.current_branch_count ** 2 self.transform_deltas = np.zeros(shape=(total_branches, obs_size), dtype=np.float32) @@ -245,11 +260,23 @@ def constant_split(self, data_dict: DatasetDict): def never_split(self, data_dict: DatasetDict): return False - + + def insert_images(self, observation: dict): + for k in self.img_keys: + self.img_buffer[k][self._img_insert_index_] = observation[k] + self._img_insert_index_ += 1 + def insert(self, data: DatasetDict): data_dict = copy.deepcopy(data) + obs_state = data_dict["observations"] + action = data_dict["actions"] + n_obs_state = data_dict["observations"] + reward = data_dict["rewards"] + mask = data_dict["masks"] + done = data_dict["dones"] + # Update number of branches if needed if 
self.split(data_dict): temp = self.current_branch_count @@ -258,33 +285,60 @@ def insert(self, data: DatasetDict): if temp != self.current_branch_count: self.generate_transform_deltas() + # Insert images if needed + if self.img_keys: + if self.timestep == 0: + self.insert_images(data_dict["observations"]) + self.insert_images(data_dict["next_observations"]) + obs_state = data_dict["observations"]["state"] + n_obs_state = data_dict["next_observations"]["state"] + for k in self.img_keys: + data_dict["observations"][k] = (self._img_insert_index_ - 2) % len(self.img_buffer[k]) + data_dict["next_observations"][k] = (self._img_insert_index_ - 1) % len(self.img_buffer[k]) + # Initialize to extreme x and y base_diff = -self.workspace_width/2 - data_dict["observations"][self.x_obs_idx] += base_diff - data_dict["observations"][self.y_obs_idx] += base_diff - data_dict["next_observations"][self.x_obs_idx] += base_diff - data_dict["next_observations"][self.y_obs_idx] += base_diff + obs_state[self.x_obs_idx] += base_diff + obs_state[self.y_obs_idx] += base_diff + n_obs_state[self.x_obs_idx] += base_diff + n_obs_state[self.y_obs_idx] += base_diff # Transform and insert transitions num_transforms = self.current_branch_count ** 2 - obs_batch = np.tile(data_dict["observations"], (num_transforms, 1)) - next_obs_batch = np.tile(data_dict["next_observations"], (num_transforms, 1)) + obs_batch = np.tile(obs_state, (num_transforms, 1)) + next_obs_batch = np.tile(n_obs_state, (num_transforms, 1)) obs_batch += self.transform_deltas next_obs_batch += self.transform_deltas - data_dict["observations"] = obs_batch - data_dict["next_observations"] = next_obs_batch - data_dict["actions"] = np.tile(data_dict["actions"], (num_transforms, 1)) - data_dict["rewards"] = np.tile(data_dict["rewards"], num_transforms) - data_dict["masks"] = np.tile(data_dict["masks"], num_transforms) - data_dict["dones"] = np.tile(data_dict["dones"], num_transforms) + obs_state = obs_batch + n_obs_state = next_obs_batch 
+ action = np.tile(action, (num_transforms, 1)) + reward = np.tile(reward, num_transforms) + mask = np.tile(mask, num_transforms) + done = np.tile(done, num_transforms) + + for k in self.img_keys: + data_dict["observations"][k] = np.tile(data_dict["observations"][k], num_transforms) + data_dict["next_observations"][k] = np.tile(data_dict["next_observations"][k], num_transforms) + + if self.img_keys: + data_dict["observations"]["state"] = obs_state + data_dict["next_observations"]["state"] = n_obs_state + else: + data_dict["observations"] = obs_state + data_dict["next_observations"] = n_obs_state + + data_dict["actions"] = action + data_dict["rewards"] = reward + data_dict["masks"] = mask + data_dict["dones"] = done super().insert(data_dict, batch_size=num_transforms) # Reset current_depth, timestep, and max_traj_length self.timestep += 1 - if data_dict["dones"][0]: + if done[0]: self.current_depth = 0 if self.update_max_traj_length: self.max_traj_length = int(self.timestep * self.alpha + self.max_traj_length * (1 - self.alpha)) diff --git a/serl_launcher/serl_launcher/data/fsrb_test.py b/serl_launcher/serl_launcher/data/fsrb_test.py index e51a5c0f..46245bc9 100644 --- a/serl_launcher/serl_launcher/data/fsrb_test.py +++ b/serl_launcher/serl_launcher/data/fsrb_test.py @@ -3,20 +3,22 @@ import gym from serl_launcher.utils.launcher import make_replay_buffer from serl_launcher.data.fractal_symmetry_replay_buffer import FractalSymmetryReplayBuffer +from serl_launcher.wrappers.serl_obs_wrappers import SERLObsWrapper +from serl_launcher.wrappers.chunking import ChunkingWrapper from absl import app, flags import franka_sim # import pandas as pd FLAGS = flags.FLAGS -flags.DEFINE_integer("capacity", 1000000, "Replay buffer capacity.") +flags.DEFINE_integer("capacity", 10000, "Replay buffer capacity.") flags.DEFINE_string("branch_method", "constant", "Method for determining the number of transforms per dimension (x,y)") flags.DEFINE_string("split_method", "constant", "Method 
for determining whether to change the number of transforms per dimension (x,y)") flags.DEFINE_float("workspace_width", 0.5, "workspace width in meters") flags.DEFINE_integer("max_depth", 4, "Maximum level of depth") # For fractal_branch only flags.DEFINE_integer("max_steps",100,"Maximum steps") flags.DEFINE_integer("branching_factor", 3, "Rate of change of number of transforms per dimension (x,y)") # For fractal_branch only -flags.DEFINE_integer("starting_branch_count", 1000, "Initial number of transforms per dimension (x,y)") # For constant_branch only +flags.DEFINE_integer("starting_branch_count", 10, "Initial number of transforms per dimension (x,y)") # For constant_branch only flags.DEFINE_integer("alpha",1,"alpha value") # Density Workspace width flags.DEFINE_string("workspace_width_method",'increase', 'Controls workspace width dimensions configurations') @@ -27,8 +29,11 @@ def main(_): y_obs_idx = np.array([1, 5]) # Initialize replay buffer - env = gym.make("PandaReachCube-v0") - env = gym.wrappers.FlattenObservation(env) + env = gym.make("PandaPickCubeVision-v0") + env = SERLObsWrapper(env) + # env = gym.wrappers.FlattenObservation(env) + + image_keys = [key for key in env.observation_space.keys() if key != "state"] replay_buffer = make_replay_buffer( env, @@ -39,19 +44,19 @@ def main(_): workspace_width=FLAGS.workspace_width, x_obs_idx=x_obs_idx, y_obs_idx= y_obs_idx, + image_keys=image_keys, max_depth=FLAGS.max_depth, max_traj_length = 100, branching_factor=FLAGS.branching_factor, alpha = FLAGS.alpha, starting_branch_count = FLAGS.starting_branch_count, - workspace_width_method=FLAGS.workspace_width_method ) observation, info = env.reset() - observation = np.zeros_like(observation) + action = env.action_space.sample() next_observation, reward, terminated, truncated, info = env.step(action) - next_observation = np.ones_like(observation) + data_dict = dict( observations=observation, next_observations=next_observation, diff --git 
a/serl_launcher/serl_launcher/utils/launcher.py b/serl_launcher/serl_launcher/utils/launcher.py index fbed0c0a..699a0968 100644 --- a/serl_launcher/serl_launcher/utils/launcher.py +++ b/serl_launcher/serl_launcher/utils/launcher.py @@ -276,10 +276,10 @@ def make_replay_buffer( branch_method=branch_method, split_method=split_method, workspace_width=workspace_width, - workspace_width_method=workspace_width_method, x_obs_idx=x_obs_idx, y_obs_idx=y_obs_idx, rlds_logger=rlds_logger, + image_keys=image_keys, kwargs=kwargs, ) From efb19e3053d7eb8f8fd07ec1a6f6df8ce57a9217 Mon Sep 17 00:00:00 2001 From: ryanvanderstelt Date: Wed, 13 Aug 2025 19:21:54 -0500 Subject: [PATCH 108/141] Improved tests --- examples/async_sac_state_sim/tmux_launch_tests.sh | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/examples/async_sac_state_sim/tmux_launch_tests.sh b/examples/async_sac_state_sim/tmux_launch_tests.sh index dbd2fc31..bc70df96 100644 --- a/examples/async_sac_state_sim/tmux_launch_tests.sh +++ b/examples/async_sac_state_sim/tmux_launch_tests.sh @@ -2,10 +2,14 @@ # Create a new tmux session +if tmux has-session -t serl_session; then + tmux kill-server -t serl_session +fi + tmux new-session -d -s serl_session tmux setw -g remain-on-exit on -for SEED in {1..5} +for SEED in 1 do # New window tmux new-window -t serl_session -n $SEED @@ -18,7 +22,7 @@ do done # Attach to the tmux session -tmux attach-session -t serl_session +# tmux attach-session -t serl_session # kill the tmux session by running the following command # tmux kill-session -t serl_session From c6986a5ead4f312f8e37a654de883e027152ba99 Mon Sep 17 00:00:00 2001 From: ryanvanderstelt Date: Wed, 13 Aug 2025 19:30:35 -0500 Subject: [PATCH 109/141] added offline to testing options --- .../async_sac_state_sim/automated_tests.sh | 27 ++++++++++--------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/examples/async_sac_state_sim/automated_tests.sh 
b/examples/async_sac_state_sim/automated_tests.sh index eb22ffb8..489e7405 100644 --- a/examples/async_sac_state_sim/automated_tests.sh +++ b/examples/async_sac_state_sim/automated_tests.sh @@ -11,8 +11,9 @@ RANDOM_STEPS=1000 CRITIC_ACTOR_RATIO=8 EXP_NAME="SCALING-REPLAY-BUFFER-CAPACITY-WITH-BRANCHES-$ENV" REPLAY_BUFFER_TYPE="fractal_symmetry_replay_buffer" +OFFLINE=False -BASE_ARGS="--env $ENV --exp_name $EXP_NAME --wandb_output_dir $WANDB_OUTPUT_DIR --wandb_offline --training_starts $TRAINING_STARTS --random_steps $RANDOM_STEPS --seed $SEED" +BASE_ARGS="--env $ENV --exp_name $EXP_NAME --wandb_output_dir $WANDB_OUTPUT_DIR --wandb_offline $OFFLINE --training_starts $TRAINING_STARTS --random_steps $RANDOM_STEPS --seed $SEED" ARGS="" function run_test { @@ -39,18 +40,18 @@ function run_test { } -# BASELINE TESTING -# for CRITIC_ACTOR_RATIO in 8 32 -# do -# for batch_size in 256 2048 -# do -# for replay_buffer_capacity in 1000 1000000 -# do -# ARGS="--run_name baseline --critic_actor_ratio $CRITIC_ACTOR_RATIO --replay_buffer_type replay_buffer --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity" -# run_test -# done -# done -# done +BASELINE TESTING +for CRITIC_ACTOR_RATIO in 8 32 +do + for batch_size in 256 2048 + do + for replay_buffer_capacity in 1000 1000000 + do + ARGS="--run_name baseline --critic_actor_ratio $CRITIC_ACTOR_RATIO --replay_buffer_type replay_buffer --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity" + run_test + done + done +done # CONSTANT TESTING for CRITIC_ACTOR_RATIO in 8 32 From 1531be9204f43305d1f89d0a7a5e09e1d5561243 Mon Sep 17 00:00:00 2001 From: ryanvanderstelt Date: Wed, 13 Aug 2025 20:02:16 -0500 Subject: [PATCH 110/141] test changes --- examples/async_sac_state_sim/automated_tests.sh | 2 +- examples/async_sac_state_sim/tmux_launch_tests.sh | 6 +----- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/examples/async_sac_state_sim/automated_tests.sh 
b/examples/async_sac_state_sim/automated_tests.sh index 489e7405..a05bcf7f 100644 --- a/examples/async_sac_state_sim/automated_tests.sh +++ b/examples/async_sac_state_sim/automated_tests.sh @@ -9,7 +9,7 @@ MAX_STEPS=25000 TRAINING_STARTS=1000 RANDOM_STEPS=1000 CRITIC_ACTOR_RATIO=8 -EXP_NAME="SCALING-REPLAY-BUFFER-CAPACITY-WITH-BRANCHES-$ENV" +EXP_NAME="BATCH_SIZE-CRITIC_ACTOR_RATIO-TEST-$ENV" REPLAY_BUFFER_TYPE="fractal_symmetry_replay_buffer" OFFLINE=False diff --git a/examples/async_sac_state_sim/tmux_launch_tests.sh b/examples/async_sac_state_sim/tmux_launch_tests.sh index bc70df96..2b06c1ee 100644 --- a/examples/async_sac_state_sim/tmux_launch_tests.sh +++ b/examples/async_sac_state_sim/tmux_launch_tests.sh @@ -2,10 +2,6 @@ # Create a new tmux session -if tmux has-session -t serl_session; then - tmux kill-server -t serl_session -fi - tmux new-session -d -s serl_session tmux setw -g remain-on-exit on @@ -22,7 +18,7 @@ do done # Attach to the tmux session -# tmux attach-session -t serl_session +tmux attach-session -t serl_session # kill the tmux session by running the following command # tmux kill-session -t serl_session From c9a0941679d4471fc9a14d85e8becbbb0edbe125 Mon Sep 17 00:00:00 2001 From: ryanvanderstelt Date: Thu, 14 Aug 2025 21:20:36 -0500 Subject: [PATCH 111/141] begun functionality for images for fractal_symmetry_replay_buffer --- .../async_sac_state_sim/automated_tests.sh | 11 +- .../data/fractal_symmetry_replay_buffer.py | 138 +++++++++++++----- serl_launcher/serl_launcher/data/fsrb_test.py | 14 +- 3 files changed, 112 insertions(+), 51 deletions(-) diff --git a/examples/async_sac_state_sim/automated_tests.sh b/examples/async_sac_state_sim/automated_tests.sh index a05bcf7f..3224edc3 100644 --- a/examples/async_sac_state_sim/automated_tests.sh +++ b/examples/async_sac_state_sim/automated_tests.sh @@ -11,9 +11,8 @@ RANDOM_STEPS=1000 CRITIC_ACTOR_RATIO=8 EXP_NAME="BATCH_SIZE-CRITIC_ACTOR_RATIO-TEST-$ENV" 
REPLAY_BUFFER_TYPE="fractal_symmetry_replay_buffer" -OFFLINE=False -BASE_ARGS="--env $ENV --exp_name $EXP_NAME --wandb_output_dir $WANDB_OUTPUT_DIR --wandb_offline $OFFLINE --training_starts $TRAINING_STARTS --random_steps $RANDOM_STEPS --seed $SEED" +BASE_ARGS="--env $ENV --exp_name $EXP_NAME --wandb_output_dir $WANDB_OUTPUT_DIR --wandb_offline --training_starts $TRAINING_STARTS --random_steps $RANDOM_STEPS --seed $SEED" ARGS="" function run_test { @@ -28,11 +27,11 @@ function run_test { echo "Running constant with args: $ARGS" tmux respawn-pane -k -t serl_session:$SEED.1 tmux respawn-pane -k -t serl_session:$SEED.2 - tmux send-keys -t serl_session:$SEED.1 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --actor --max_steps $((MAX_STEPS * 2)) $BASE_ARGS $ARGS" C-m + tmux send-keys -t serl_session:$SEED.1 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --actor --max_steps 2000000000 $BASE_ARGS $ARGS" C-m tmux send-keys -t serl_session:$SEED.2 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --learner --max_steps $MAX_STEPS $BASE_ARGS $ARGS" C-m "exit" C-m # Wait for actor or learner to finish - while ! tmux capture-pane -t serl_session:$SEED.2 -p | grep "Pane is dead" > /dev/null || ! tmux capture-pane -t serl_session:$SEED.1 -p | grep "Pane is dead" > /dev/null ! tmux capture-pane -t serl_session:$SEED.2 -p | grep "logout" > /dev/null || ! tmux capture-pane -t serl_session:$SEED.1 -p | grep "logout" > /dev/null; + while ! 
tmux capture-pane -t serl_session:$SEED.2 -p | grep "logout" > /dev/null; do sleep 1 done @@ -40,10 +39,10 @@ function run_test { } -BASELINE TESTING +# BASELINE TESTING for CRITIC_ACTOR_RATIO in 8 32 do - for batch_size in 256 2048 + for batch_size in 2048 do for replay_buffer_capacity in 1000 1000000 do diff --git a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py index 4c990298..fde11e8c 100644 --- a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py +++ b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py @@ -270,13 +270,6 @@ def insert(self, data: DatasetDict): data_dict = copy.deepcopy(data) - obs_state = data_dict["observations"] - action = data_dict["actions"] - n_obs_state = data_dict["observations"] - reward = data_dict["rewards"] - mask = data_dict["masks"] - done = data_dict["dones"] - # Update number of branches if needed if self.split(data_dict): temp = self.current_branch_count @@ -284,62 +277,131 @@ def insert(self, data: DatasetDict): # Update transform_deltas if needed if temp != self.current_branch_count: self.generate_transform_deltas() - + # Insert images if needed if self.img_keys: if self.timestep == 0: self.insert_images(data_dict["observations"]) self.insert_images(data_dict["next_observations"]) - obs_state = data_dict["observations"]["state"] - n_obs_state = data_dict["next_observations"]["state"] for k in self.img_keys: data_dict["observations"][k] = (self._img_insert_index_ - 2) % len(self.img_buffer[k]) data_dict["next_observations"][k] = (self._img_insert_index_ - 1) % len(self.img_buffer[k]) - + # Initialize to extreme x and y base_diff = -self.workspace_width/2 - obs_state[self.x_obs_idx] += base_diff - obs_state[self.y_obs_idx] += base_diff - n_obs_state[self.x_obs_idx] += base_diff - n_obs_state[self.y_obs_idx] += base_diff + data_dict["observations"][self.x_obs_idx] += base_diff + 
data_dict["observations"][self.y_obs_idx] += base_diff + data_dict["next_observations"][self.x_obs_idx] += base_diff + data_dict["next_observations"][self.y_obs_idx] += base_diff # Transform and insert transitions num_transforms = self.current_branch_count ** 2 - obs_batch = np.tile(obs_state, (num_transforms, 1)) - next_obs_batch = np.tile(n_obs_state, (num_transforms, 1)) + obs_batch = np.tile(data_dict["observations"], (num_transforms, 1)) + next_obs_batch = np.tile(data_dict["next_observations"], (num_transforms, 1)) obs_batch += self.transform_deltas next_obs_batch += self.transform_deltas - obs_state = obs_batch - n_obs_state = next_obs_batch - action = np.tile(action, (num_transforms, 1)) - reward = np.tile(reward, num_transforms) - mask = np.tile(mask, num_transforms) - done = np.tile(done, num_transforms) - - for k in self.img_keys: - data_dict["observations"][k] = np.tile(data_dict["observations"][k], num_transforms) - data_dict["next_observations"][k] = np.tile(data_dict["next_observations"][k], num_transforms) - if self.img_keys: - data_dict["observations"]["state"] = obs_state - data_dict["next_observations"]["state"] = n_obs_state + data_dict["observations"]["state"] = obs_batch + data_dict["next_observations"]["state"] = next_obs_batch + for k in self.img_keys: + data_dict["observations"][k] = np.tile(data_dict["observations"][k], num_transforms) + data_dict["next_observations"][k] = np.tile(data_dict["next_observations"][k], num_transforms) + else: - data_dict["observations"] = obs_state - data_dict["next_observations"] = n_obs_state + data_dict["observations"] = obs_batch + data_dict["next_observations"] = next_obs_batch - data_dict["actions"] = action - data_dict["rewards"] = reward - data_dict["masks"] = mask - data_dict["dones"] = done + data_dict["actions"] = np.tile(data_dict["actions"], (num_transforms, 1)) + data_dict["rewards"] = np.tile(data_dict["rewards"], num_transforms) + data_dict["masks"] = np.tile(data_dict["masks"], num_transforms) 
+ data_dict["dones"] = np.tile(data_dict["dones"], num_transforms) super().insert(data_dict, batch_size=num_transforms) # Reset current_depth, timestep, and max_traj_length self.timestep += 1 - if done[0]: + if data_dict["dones"][0]: self.current_depth = 0 if self.update_max_traj_length: self.max_traj_length = int(self.timestep * self.alpha + self.max_traj_length * (1 - self.alpha)) - self.timestep = 0 \ No newline at end of file + self.timestep = 0 + + # def insert(self, data: DatasetDict): + + # data_dict = copy.deepcopy(data) + + # obs_state = data_dict["observations"] + # action = data_dict["actions"] + # n_obs_state = data_dict["observations"] + # reward = data_dict["rewards"] + # mask = data_dict["masks"] + # done = data_dict["dones"] + + # # Update number of branches if needed + # if self.split(data_dict): + # temp = self.current_branch_count + # self.current_branch_count = self.branch() + # # Update transform_deltas if needed + # if temp != self.current_branch_count: + # self.generate_transform_deltas() + + # # Insert images if needed + # if self.img_keys: + # if self.timestep == 0: + # self.insert_images(data_dict["observations"]) + # self.insert_images(data_dict["next_observations"]) + # obs_state = data_dict["observations"]["state"] + # n_obs_state = data_dict["next_observations"]["state"] + # for k in self.img_keys: + # data_dict["observations"][k] = (self._img_insert_index_ - 2) % len(self.img_buffer[k]) + # data_dict["next_observations"][k] = (self._img_insert_index_ - 1) % len(self.img_buffer[k]) + + # # Initialize to extreme x and y + # base_diff = -self.workspace_width/2 + # obs_state[self.x_obs_idx] += base_diff + # obs_state[self.y_obs_idx] += base_diff + # n_obs_state[self.x_obs_idx] += base_diff + # n_obs_state[self.y_obs_idx] += base_diff + + # # Transform and insert transitions + # num_transforms = self.current_branch_count ** 2 + # obs_batch = np.tile(obs_state, (num_transforms, 1)) + # next_obs_batch = np.tile(n_obs_state, (num_transforms, 
1)) + + # obs_batch += self.transform_deltas + # next_obs_batch += self.transform_deltas + + # obs_state = obs_batch + # n_obs_state = next_obs_batch + # action = np.tile(action, (num_transforms, 1)) + # reward = np.tile(reward, num_transforms) + # mask = np.tile(mask, num_transforms) + # done = np.tile(done, num_transforms) + + # for k in self.img_keys: + # data_dict["observations"][k] = np.tile(data_dict["observations"][k], num_transforms) + # data_dict["next_observations"][k] = np.tile(data_dict["next_observations"][k], num_transforms) + + # if self.img_keys: + # data_dict["observations"]["state"] = obs_state + # data_dict["next_observations"]["state"] = n_obs_state + # else: + # data_dict["observations"] = obs_state + # data_dict["next_observations"] = n_obs_state + + # data_dict["actions"] = action + # data_dict["rewards"] = reward + # data_dict["masks"] = mask + # data_dict["dones"] = done + + # super().insert(data_dict, batch_size=num_transforms) + + # # Reset current_depth, timestep, and max_traj_length + # self.timestep += 1 + # if done[0]: + # self.current_depth = 0 + # if self.update_max_traj_length: + # self.max_traj_length = int(self.timestep * self.alpha + self.max_traj_length * (1 - self.alpha)) + # self.timestep = 0 \ No newline at end of file diff --git a/serl_launcher/serl_launcher/data/fsrb_test.py b/serl_launcher/serl_launcher/data/fsrb_test.py index 46245bc9..5772f799 100644 --- a/serl_launcher/serl_launcher/data/fsrb_test.py +++ b/serl_launcher/serl_launcher/data/fsrb_test.py @@ -13,12 +13,12 @@ flags.DEFINE_integer("capacity", 10000, "Replay buffer capacity.") flags.DEFINE_string("branch_method", "constant", "Method for determining the number of transforms per dimension (x,y)") -flags.DEFINE_string("split_method", "constant", "Method for determining whether to change the number of transforms per dimension (x,y)") +flags.DEFINE_string("split_method", "never", "Method for determining whether to change the number of transforms per dimension 
(x,y)") flags.DEFINE_float("workspace_width", 0.5, "workspace width in meters") flags.DEFINE_integer("max_depth", 4, "Maximum level of depth") # For fractal_branch only flags.DEFINE_integer("max_steps",100,"Maximum steps") flags.DEFINE_integer("branching_factor", 3, "Rate of change of number of transforms per dimension (x,y)") # For fractal_branch only -flags.DEFINE_integer("starting_branch_count", 10, "Initial number of transforms per dimension (x,y)") # For constant_branch only +flags.DEFINE_integer("starting_branch_count", 3, "Initial number of transforms per dimension (x,y)") # For constant_branch only flags.DEFINE_integer("alpha",1,"alpha value") # Density Workspace width flags.DEFINE_string("workspace_width_method",'increase', 'Controls workspace width dimensions configurations') @@ -29,11 +29,11 @@ def main(_): y_obs_idx = np.array([1, 5]) # Initialize replay buffer - env = gym.make("PandaPickCubeVision-v0") - env = SERLObsWrapper(env) - # env = gym.wrappers.FlattenObservation(env) + env = gym.make("PandaReachCube-v0") + # env = SERLObsWrapper(env) + env = gym.wrappers.FlattenObservation(env) - image_keys = [key for key in env.observation_space.keys() if key != "state"] + # image_keys = [key for key in env.observation_space.keys() if key != "state"] replay_buffer = make_replay_buffer( env, @@ -44,7 +44,7 @@ def main(_): workspace_width=FLAGS.workspace_width, x_obs_idx=x_obs_idx, y_obs_idx= y_obs_idx, - image_keys=image_keys, + # image_keys=image_keys, max_depth=FLAGS.max_depth, max_traj_length = 100, branching_factor=FLAGS.branching_factor, From 84bbca44f62c001f5268dad7ff1d26106fb087fe Mon Sep 17 00:00:00 2001 From: ryanvanderstelt Date: Fri, 15 Aug 2025 12:05:03 -0500 Subject: [PATCH 112/141] testing changes --- .../async_sac_state_sim/automated_tests.sh | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/examples/async_sac_state_sim/automated_tests.sh b/examples/async_sac_state_sim/automated_tests.sh index 
3224edc3..7fccc41b 100644 --- a/examples/async_sac_state_sim/automated_tests.sh +++ b/examples/async_sac_state_sim/automated_tests.sh @@ -9,7 +9,7 @@ MAX_STEPS=25000 TRAINING_STARTS=1000 RANDOM_STEPS=1000 CRITIC_ACTOR_RATIO=8 -EXP_NAME="BATCH_SIZE-CRITIC_ACTOR_RATIO-TEST-$ENV" +EXP_NAME="BATCH_SIZE-VS-BRANCHES-WALL-TIME-TEST-$ENV" REPLAY_BUFFER_TYPE="fractal_symmetry_replay_buffer" BASE_ARGS="--env $ENV --exp_name $EXP_NAME --wandb_output_dir $WANDB_OUTPUT_DIR --wandb_offline --training_starts $TRAINING_STARTS --random_steps $RANDOM_STEPS --seed $SEED" @@ -30,7 +30,7 @@ function run_test { tmux send-keys -t serl_session:$SEED.1 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --actor --max_steps 2000000000 $BASE_ARGS $ARGS" C-m tmux send-keys -t serl_session:$SEED.2 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --learner --max_steps $MAX_STEPS $BASE_ARGS $ARGS" C-m "exit" C-m - # Wait for actor or learner to finish + # Wait for learner to finish while ! tmux capture-pane -t serl_session:$SEED.2 -p | grep "logout" > /dev/null; do sleep 1 @@ -40,11 +40,11 @@ function run_test { } # BASELINE TESTING -for CRITIC_ACTOR_RATIO in 8 32 +for CRITIC_ACTOR_RATIO in 8 16 32 do - for batch_size in 2048 + for batch_size in 256 512 1024 2048 do - for replay_buffer_capacity in 1000 1000000 + for replay_buffer_capacity in 1000000 do ARGS="--run_name baseline --critic_actor_ratio $CRITIC_ACTOR_RATIO --replay_buffer_type replay_buffer --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity" run_test @@ -53,15 +53,15 @@ do done # CONSTANT TESTING -for CRITIC_ACTOR_RATIO in 8 32 +for CRITIC_ACTOR_RATIO in 8 16 32 do - for starting_branch_count in 3 9 27 81 243 + for starting_branch_count in 2 8 32 do - for batch_size in 256 2048 + for batch_size in 256 512 1024 2048 do for workspace_width in 0.5 do - for replay_buffer_capacity in $(( starting_branch_count * starting_branch_count * TRAINING_STARTS )) $(( starting_branch_count * 
starting_branch_count * TRAINING_STARTS * 1000 )) + for replay_buffer_capacity in $(( starting_branch_count * starting_branch_count * 50000 )) do ARGS="--run_name constant-$starting_branch_count^1 --critic_actor_ratio $CRITIC_ACTOR_RATIO --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --workspace_width $workspace_width --branch_method 'constant' --starting_branch_count $starting_branch_count" run_test @@ -125,4 +125,4 @@ done # done # done -tmux kill-window -t serl_session:$SEED \ No newline at end of file +tmux kill-window -t serl_session:$SEED From 1dbb0443f04470a698d7f27475ff4fc3e5a21ed6 Mon Sep 17 00:00:00 2001 From: ryanvanderstelt Date: Fri, 15 Aug 2025 13:34:03 -0500 Subject: [PATCH 113/141] fractal_symmetry_replay_buffer with images finished, awaiting demos for testing --- .../data/fractal_symmetry_replay_buffer.py | 151 +++++++----------- serl_launcher/serl_launcher/data/fsrb_test.py | 10 +- 2 files changed, 63 insertions(+), 98 deletions(-) diff --git a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py index fde11e8c..ebbfbbc8 100644 --- a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py +++ b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py @@ -47,8 +47,10 @@ def __init__( print(f"\033[33mWARNING \033[0m argument \"{k}\" not used") # Account for images - img_buffer_size = (capacity + self.expected_branches - 1) // self.expected_branches + 1 - self.img_buffer = {} + if self.img_keys: + img_buffer_size = (capacity + self.expected_branches - 1) // self.expected_branches + 1 + self.img_buffer = {} + self.insert = self.insert_with_images for k in img_keys: temp = np.empty(shape=observation_space.spaces[k].shape, dtype=observation_space.spaces[k].dtype) temp = temp[np.newaxis, :] @@ -278,15 +280,6 @@ def insert(self, data: DatasetDict): if temp != self.current_branch_count: 
self.generate_transform_deltas() - # Insert images if needed - if self.img_keys: - if self.timestep == 0: - self.insert_images(data_dict["observations"]) - self.insert_images(data_dict["next_observations"]) - for k in self.img_keys: - data_dict["observations"][k] = (self._img_insert_index_ - 2) % len(self.img_buffer[k]) - data_dict["next_observations"][k] = (self._img_insert_index_ - 1) % len(self.img_buffer[k]) - # Initialize to extreme x and y base_diff = -self.workspace_width/2 data_dict["observations"][self.x_obs_idx] += base_diff @@ -302,17 +295,9 @@ def insert(self, data: DatasetDict): obs_batch += self.transform_deltas next_obs_batch += self.transform_deltas - if self.img_keys: - data_dict["observations"]["state"] = obs_batch - data_dict["next_observations"]["state"] = next_obs_batch - for k in self.img_keys: - data_dict["observations"][k] = np.tile(data_dict["observations"][k], num_transforms) - data_dict["next_observations"][k] = np.tile(data_dict["next_observations"][k], num_transforms) - - else: - data_dict["observations"] = obs_batch - data_dict["next_observations"] = next_obs_batch + data_dict["observations"] = obs_batch + data_dict["next_observations"] = next_obs_batch data_dict["actions"] = np.tile(data_dict["actions"], (num_transforms, 1)) data_dict["rewards"] = np.tile(data_dict["rewards"], num_transforms) data_dict["masks"] = np.tile(data_dict["masks"], num_transforms) @@ -328,80 +313,60 @@ def insert(self, data: DatasetDict): self.max_traj_length = int(self.timestep * self.alpha + self.max_traj_length * (1 - self.alpha)) self.timestep = 0 - # def insert(self, data: DatasetDict): - - # data_dict = copy.deepcopy(data) + def insert_with_images(self, data: DatasetDict): - # obs_state = data_dict["observations"] - # action = data_dict["actions"] - # n_obs_state = data_dict["observations"] - # reward = data_dict["rewards"] - # mask = data_dict["masks"] - # done = data_dict["dones"] + data_dict = copy.deepcopy(data) - # # Update number of branches if 
needed - # if self.split(data_dict): - # temp = self.current_branch_count - # self.current_branch_count = self.branch() - # # Update transform_deltas if needed - # if temp != self.current_branch_count: - # self.generate_transform_deltas() + # Update number of branches if needed + if self.split(data_dict): + temp = self.current_branch_count + self.current_branch_count = self.branch() + # Update transform_deltas if needed + if temp != self.current_branch_count: + self.generate_transform_deltas() - # # Insert images if needed - # if self.img_keys: - # if self.timestep == 0: - # self.insert_images(data_dict["observations"]) - # self.insert_images(data_dict["next_observations"]) - # obs_state = data_dict["observations"]["state"] - # n_obs_state = data_dict["next_observations"]["state"] - # for k in self.img_keys: - # data_dict["observations"][k] = (self._img_insert_index_ - 2) % len(self.img_buffer[k]) - # data_dict["next_observations"][k] = (self._img_insert_index_ - 1) % len(self.img_buffer[k]) + # Insert images + if self.timestep == 0: + self.insert_images(data_dict["observations"]) + self.insert_images(data_dict["next_observations"]) + + for k in self.img_keys: + data_dict["observations"][k] = (self._img_insert_index_ - 2) % len(self.img_buffer[k]) + data_dict["next_observations"][k] = (self._img_insert_index_ - 1) % len(self.img_buffer[k]) - # # Initialize to extreme x and y - # base_diff = -self.workspace_width/2 - # obs_state[self.x_obs_idx] += base_diff - # obs_state[self.y_obs_idx] += base_diff - # n_obs_state[self.x_obs_idx] += base_diff - # n_obs_state[self.y_obs_idx] += base_diff - - # # Transform and insert transitions - # num_transforms = self.current_branch_count ** 2 - # obs_batch = np.tile(obs_state, (num_transforms, 1)) - # next_obs_batch = np.tile(n_obs_state, (num_transforms, 1)) - - # obs_batch += self.transform_deltas - # next_obs_batch += self.transform_deltas - - # obs_state = obs_batch - # n_obs_state = next_obs_batch - # action = 
np.tile(action, (num_transforms, 1)) - # reward = np.tile(reward, num_transforms) - # mask = np.tile(mask, num_transforms) - # done = np.tile(done, num_transforms) - - # for k in self.img_keys: - # data_dict["observations"][k] = np.tile(data_dict["observations"][k], num_transforms) - # data_dict["next_observations"][k] = np.tile(data_dict["next_observations"][k], num_transforms) - - # if self.img_keys: - # data_dict["observations"]["state"] = obs_state - # data_dict["next_observations"]["state"] = n_obs_state - # else: - # data_dict["observations"] = obs_state - # data_dict["next_observations"] = n_obs_state + # Initialize to extreme x and y + base_diff = -self.workspace_width/2 + data_dict["observations"]["state"][self.x_obs_idx] += base_diff + data_dict["observations"]["state"][self.y_obs_idx] += base_diff + data_dict["next_observations"]["state"][self.x_obs_idx] += base_diff + data_dict["next_observations"]["state"][self.y_obs_idx] += base_diff - # data_dict["actions"] = action - # data_dict["rewards"] = reward - # data_dict["masks"] = mask - # data_dict["dones"] = done - - # super().insert(data_dict, batch_size=num_transforms) - - # # Reset current_depth, timestep, and max_traj_length - # self.timestep += 1 - # if done[0]: - # self.current_depth = 0 - # if self.update_max_traj_length: - # self.max_traj_length = int(self.timestep * self.alpha + self.max_traj_length * (1 - self.alpha)) - # self.timestep = 0 \ No newline at end of file + # Transform and insert transitions + num_transforms = self.current_branch_count ** 2 + obs_batch = np.tile(data_dict["observations"]["state"], (num_transforms, 1)) + next_obs_batch = np.tile(data_dict["next_observations"]["state"], (num_transforms, 1)) + + obs_batch += self.transform_deltas + next_obs_batch += self.transform_deltas + + + data_dict["observations"]["state"] = obs_batch + data_dict["next_observations"]["state"] = next_obs_batch + data_dict["actions"] = np.tile(data_dict["actions"], (num_transforms, 1)) + 
data_dict["rewards"] = np.tile(data_dict["rewards"], num_transforms) + data_dict["masks"] = np.tile(data_dict["masks"], num_transforms) + data_dict["dones"] = np.tile(data_dict["dones"], num_transforms) + + for k in self.img_keys: + data_dict["observations"][k] = np.tile(data_dict["observations"][k], num_transforms) + data_dict["next_observations"][k] = np.tile(data_dict["next_observations"][k], num_transforms) + + super().insert(data_dict, batch_size=num_transforms) + + # Reset current_depth, timestep, and max_traj_length + self.timestep += 1 + if data_dict["dones"][0]: + self.current_depth = 0 + if self.update_max_traj_length: + self.max_traj_length = int(self.timestep * self.alpha + self.max_traj_length * (1 - self.alpha)) + self.timestep = 0 \ No newline at end of file diff --git a/serl_launcher/serl_launcher/data/fsrb_test.py b/serl_launcher/serl_launcher/data/fsrb_test.py index 5772f799..4dc2bcea 100644 --- a/serl_launcher/serl_launcher/data/fsrb_test.py +++ b/serl_launcher/serl_launcher/data/fsrb_test.py @@ -29,11 +29,11 @@ def main(_): y_obs_idx = np.array([1, 5]) # Initialize replay buffer - env = gym.make("PandaReachCube-v0") - # env = SERLObsWrapper(env) - env = gym.wrappers.FlattenObservation(env) + env = gym.make("PandaPickCubeVision-v0") + env = SERLObsWrapper(env) + # env = gym.wrappers.FlattenObservation(env) - # image_keys = [key for key in env.observation_space.keys() if key != "state"] + image_keys = [key for key in env.observation_space.keys() if key != "state"] replay_buffer = make_replay_buffer( env, @@ -44,7 +44,7 @@ def main(_): workspace_width=FLAGS.workspace_width, x_obs_idx=x_obs_idx, y_obs_idx= y_obs_idx, - # image_keys=image_keys, + image_keys=image_keys, max_depth=FLAGS.max_depth, max_traj_length = 100, branching_factor=FLAGS.branching_factor, From a828d5018ae247fd96b707276dbe998ee9f39ef1 Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Fri, 15 Aug 2025 15:12:10 -0500 Subject: [PATCH 114/141] compiling code pieces to set things up --- 
demos/demos/franka_reach_drq_demo_script.py | 72 +++++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 demos/demos/franka_reach_drq_demo_script.py diff --git a/demos/demos/franka_reach_drq_demo_script.py b/demos/demos/franka_reach_drq_demo_script.py new file mode 100644 index 00000000..0ac266d9 --- /dev/null +++ b/demos/demos/franka_reach_drq_demo_script.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python3 + +import time +import numpy as np +from absl import app, flags, logging + +import gym +from oxe_envlogger.envlogger import AutoOXEEnvLogger + +FLAGS = flags.FLAGS + +flags.DEFINE_integer('num_episodes', 10, 'Number of episodes to log.') +flags.DEFINE_string('output_dir', 'datasets/', + 'Path in a filesystem to record trajectories.') +flags.DEFINE_string('env_name', 'CartPole-v1', 'Name of the environment.') +flags.DEFINE_boolean('enable_envlogger', False, 'Enable envlogger.') + + +############################################################################## + + +def main(unused_argv): + logging.info(f'Creating gym environment...') + + env = gym.make(FLAGS.env_name) + logging.info(f'Done creating {FLAGS.env_name} environment.') + + if FLAGS.enable_envlogger: + env = AutoOXEEnvLogger( + env=env, + dataset_name=FLAGS.env_name, + directory=FLAGS.output_dir, + ) + + logging.info('Training an agent for %r episodes...', FLAGS.num_episodes) + + for i in range(FLAGS.num_episodes): + + # example to log custom metadata during new episode + if FLAGS.enable_envlogger: + env.set_episode_metadata({ + "language_embedding": np.random.random((5,)).astype(np.float32) + }) + env.set_step_metadata({"timestamp": time.time()}) + + logging.info('episode %r', i) + env.reset() + terminated = False + truncated = False + + while not (terminated or truncated): + action = env.action_space.sample() + + # example to log custom step metadata + if FLAGS.enable_envlogger: + env.set_step_metadata({"timestamp": time.time()}) + + return_step = env.step(action) + + # NOTE: to handle 
gym.Env.step() return value change in gym 0.26 + if len(return_step) == 5: + obs, reward, terminated, truncated, info = return_step + else: + obs, reward, terminated, info = return_step + truncated = False + + logging.info( + 'Done training a random agent for %r episodes.', FLAGS.num_episodes) + + +if __name__ == '__main__': + app.run(main) \ No newline at end of file From 4173234365964b225680f9a387f10bad97ccf697 Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Fri, 15 Aug 2025 15:16:13 -0500 Subject: [PATCH 115/141] adding oxe_logger dependency --- demos/requirements.txt | 1 + 1 file changed, 1 insertion(+) create mode 100644 demos/requirements.txt diff --git a/demos/requirements.txt b/demos/requirements.txt new file mode 100644 index 00000000..4a733bc8 --- /dev/null +++ b/demos/requirements.txt @@ -0,0 +1 @@ +git+https://github.com/rail-berkeley/oxe_envlogger.git@main From 18e6ea3d8e25a0c7687e3a275fde922354d8cf47 Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Fri, 15 Aug 2025 15:26:11 -0500 Subject: [PATCH 116/141] commented out some lines that i think are not needed --- demos/demos/franka_reach_drq_demo_script.py | 247 +++++++++++++++++++- 1 file changed, 246 insertions(+), 1 deletion(-) diff --git a/demos/demos/franka_reach_drq_demo_script.py b/demos/demos/franka_reach_drq_demo_script.py index 0ac266d9..0661dbce 100644 --- a/demos/demos/franka_reach_drq_demo_script.py +++ b/demos/demos/franka_reach_drq_demo_script.py @@ -1,13 +1,51 @@ #!/usr/bin/env python3 import time +from time import sleep, perf_counter +from datetime import datetime import numpy as np from absl import app, flags, logging import gym from oxe_envlogger.envlogger import AutoOXEEnvLogger -FLAGS = flags.FLAGS +import numpy as np +from absl import app, flags + +import os + +from typing import Any, Dict, Optional +import pickle as pkl +import gym +from gym.wrappers.record_episode_statistics import RecordEpisodeStatistics + +from serl_launcher.agents.continuous.drq import DrQAgent +from 
serl_launcher.common.evaluation import evaluate +from serl_launcher.utils.timer_utils import Timer +from serl_launcher.utils.train_utils import concat_batches + +from serl_launcher.wrappers.serl_obs_wrappers import SERLObsWrapper + +import franka_sim + +flags.DEFINE_string("env", "PandaPickCubeVision-v0", "Name of environment.") +flags.DEFINE_string("agent", "drq", "Name of agent.") +flags.DEFINE_string("exp_name", None, "Name of the experiment for wandb logging.") +flags.DEFINE_integer("max_traj_length", 1000, "Maximum length of trajectory.") +flags.DEFINE_integer("seed", 42, "Random seed.") + +# flag to indicate if this is a leaner or a actor +flags.DEFINE_string("ip", "localhost", "IP address of the learner.") +# "small" is a 4 layer convnet, "resnet" and "mobilenet" are frozen with pretrained weights +flags.DEFINE_string("encoder_type", "resnet-pretrained", "Encoder type.") +flags.DEFINE_string("demo_path", None, "Path to the demo data.") + +flags.DEFINE_boolean( + "debug", False, "Debug mode." +) # debug mode will disable wandb logging + +flags.DEFINE_string("log_rlds_path", None, "Path to save RLDS logs.") +flags.DEFINE_string("preload_rlds_path", None, "Path to preload RLDS data.") flags.DEFINE_integer('num_episodes', 10, 'Number of episodes to log.') flags.DEFINE_string('output_dir', 'datasets/', @@ -15,16 +53,223 @@ flags.DEFINE_string('env_name', 'CartPole-v1', 'Name of the environment.') flags.DEFINE_boolean('enable_envlogger', False, 'Enable envlogger.') +FLAGS = flags.FLAGS + +#------------------------------------------------------------------------------------------- +## Key Config Variables +#------------------------------------------------------------------------------------------- +# Proportional and derivative control gain for action scaling -- empirically tuned +Kp = 10.0 # Values between 20 and 24 seem to be somewhat stable for Kv = 24 +Kv = 10.0 + +ACTION_MAX = 10 # Maximum action value for clipping actions +ERROR_THRESHOLD = 0.008 # Note!! 
When this number is changed, the way rewards are computed in the PandaReachCubeEnv.step() L220 must also be changed such that done=True only at the end of a successfull run. + +# Number of demonstration episodes to generate +NUM_DEMOS = 20 + +# Robot configuration +robot = 'franka' # Robot type used in the environment, can be 'franka' or 'fetch' +task = 'reach' # Task type used in the environment, can be 'reach' or 'pick-and-place' + +# Debug mode for rendering and visualization +DEBUG = False + +if DEBUG: + _render_mode = 'human' # Render mode for the environment, can be 'human' or 'rgb_array' +else: + _render_mode = 'rgb_array' # Use 'rgb_array' for automated testing without GUI + +############################################################################## +def set_front_cam_view(env): + """ + Set the camera view to a front-facing perspective for better visualization. + + Args: + env: The environment instance containing the viewer. + + Returns: + viewer: The viewer with updated camera settings. + """ + viewer = env.unwrapped._viewer.viewer # Access the viewer from the environment + + if hasattr(viewer, 'cam'): + viewer.cam.lookat[:] = [0, 0, 0.1] # Center of robot (adjust as needed) + viewer.cam.distance = 3.0 # Camera distance + viewer.cam.azimuth = 135 # 0 = right, 90 = front, 180 = left + viewer.cam.elevation = -30 # Negative = above, positive = below + + # Hide menu + viewer._hide_overlay = True + + return viewer ############################################################################## +def compute_error(object_pos, current_pos, prev_error, dt): + """Compute the error and its derivative between the object position and the current end-effector position. + Args: + object_pos (np.ndarray): The position of the object in the environment. + current_pos (np.ndarray): The current position of the end-effector. + prev_error (np.ndarray): The previous error value for derivative calculation. + dt (float): Time step for derivative calculation. 
+ + Returns: + error (np.ndarray): The current error vector between the object and end-effector positions. + derror (np.ndarray): The derivative of the error vector. + """ + error = object_pos - current_pos # Calculate the error vector + derror = (error - prev_error) / dt # Calculate the derivative of the error vector + + prev_error = error.copy() # Update previous error for next iteration + return error, derror + +def demo(env, lastObs): + """ + Executes a scripted reach sequence using a hierarchical approach. + + Implements 1-phase control strategy: + 1. Approach: Move gripper above object (3cm offset) + Store observations, actions, and info in global lists for later replay buffer inclusion. + + Gripper: + - The gripper in Mujoco ranges from a value of 0 to 0.4, where 0 is fully open and 0.4 is fully closed. + + Args: + env: Gymnasium environment instance + lastObs: Flattened observations set as object_pos, gripper_pos, panda/tcp_pos, panda/tcp_vel + - observations: + object_pos[0:3], + gripper_pos[3] + panda/tcp_pos[4:7], + panda/tcp_vel[7:10] + + Returns: + """ + + ## Init goal, current_pos, and object position from last observation + object_pos = np.zeros(3, dtype=np.float32) + current_pos = np.zeros(3, dtype=np.float32) + gripper_pos = np.zeros(1, dtype=np.float32) + object_rel_pos = np.zeros(3, dtype=np.float32) + + + # Initialize (single) episode data collection + # episodeObs = [] # Observations for this episode + # episodeAcs = [] # Actions for this episode + # episodeRews = [] # Rewards for this episode + # episodeInfo = [] # Info for this episode + # episodeTerminated = [] # Terminated flags for this episode + # episodeTruncated = [] # Truncated flags for this episode + # episodeDones = [] # Done flags (terminated or truncated) for this episode + + # Dictionary to store episode data + # episode_data = { + # "observations": episodeObs, + # "actions": episodeAcs, + # "rewards": episodeRews, + # "infos": episodeInfo, + # "terminateds": episodeTerminated, + 
# "truncateds": episodeTruncated, + # "dones": episodeDones + # } + + # close gripper + fgr_pos = 0 + + # Error thresholds + error_threshold = ERROR_THRESHOLD # Threshold for stopping condition (Xmm) + + finger_delta_fast = 0.05 # Action delta for fingers 5cm per step (will get clipped by controller)... more of a scalar. + finger_delta_slow = 0.005 # Franka has a range from 0 to 4cm per finger + + ## Extract data + object_pos = lastObs[opi[0]:opi[1]] # block pos + current_pos = lastObs[rpi[0]:rpi[1]] # panda/tcp_pos + + # Relative position between end-effector and object + dt = env.unwrapped.model.opt.timestep # Mujoco time step + prev_error = np.zeros_like(object_pos) + error, derror = compute_error(object_pos, current_pos, prev_error, dt) + + time_step = 0 # Track total time_steps in episode + episodeObs.append(lastObs) # Store initial observation + + # Initialize previous time for dt calculation + prev_time = perf_counter() # Start time for dt calculation + + # Phase 1: Reach + # Terminate when distance to above-object position < error_threshold + print(f"----------------------------------------------- Phase 1: Reach -----------------------------------------------") + while np.linalg.norm(error) >= error_threshold and time_step <= env.spec.max_episode_steps: + env.render() # Visual feedback + + # Record current time and compute dt + curr_time = perf_counter() + dt = curr_time - prev_time + prev_time = curr_time + + # Initialize action vector [x, y, z] + action = np.array([0., 0., 0.]) + + # Proportional control with gain of 6 + action[:3] = error * Kp + derror * Kv + prev_error = error.copy() # Update previous error for next iteration + + # Clip action to prevent excessive movements + action = np.clip(action/ACTION_MAX, -0.1, 0.1) # + + # Keep gripper closed -- no need. only 3 dimensions of control + #action[ len(action)-1 ] = -finger_delta_fast # Maintain gripper closed. 
+ + # Unpack new Gymnasium step API + # new_obs, reward, terminated, truncated, info = env.step(action) + # done = terminated or truncated + + # Store episode data + #store_transition_data(episode_data, new_obs, reward, action, info, terminated, truncated, done) + + # Update and print state information + object_pos,gripper_pos,cur_pos,cur_vel = update_state_info(episode_data, time_step, dt, error,reward) + + # Update error for next iteration + # error, derror = compute_error(object_pos, cur_pos, prev_error, dt) + + # Update time step + time_step += 1 + + # Sleep + #if DEBUG: + sleep(0.25) # Activated when DEBUG is True for better visualization. + + # Deactivate weld constraint after successful pick -- franka_sim env does not have weld like franka_mujoco env. + # if weld_flag: + # deactivate_weld(env, constraint_name="grasp_weld") + + # Break out of the loop to start a new episode + return action + + # # If we reach here, the episode was not successful + # if weld_flag: + # print("Failed to transport object to goal position. Deactivating weld.") + # deactivate_weld(env, constraint_name="grasp_weld") + +############################################################################## def main(unused_argv): logging.info(f'Creating gym environment...') env = gym.make(FLAGS.env_name) + + if FLAGS.env == "PandaPickCube-v0": + env = gym.wrappers.FlattenObservation(env) + if FLAGS.env == "PandaPickCubeVision-v0": + env = SERLObsWrapper(env) + #env = ChunkingWrapper(env, obs_horizon=1, act_exec_horizon=None) + logging.info(f'Done creating {FLAGS.env_name} environment.') + if FLAGS.enable_envlogger: env = AutoOXEEnvLogger( env=env, From 8178bdfa66f2041332f282575ee32ea7de9b4f2e Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Mon, 18 Aug 2025 17:00:41 -0500 Subject: [PATCH 117/141] set demo call to simply return action, but block_pos is not available. 
may need hack in env.reset() to a fixed position --- demos/demos/franka_reach_drq_demo_script.py | 85 ++++++++++----------- 1 file changed, 39 insertions(+), 46 deletions(-) diff --git a/demos/demos/franka_reach_drq_demo_script.py b/demos/demos/franka_reach_drq_demo_script.py index 0661dbce..d4ed6449 100644 --- a/demos/demos/franka_reach_drq_demo_script.py +++ b/demos/demos/franka_reach_drq_demo_script.py @@ -25,10 +25,12 @@ from serl_launcher.utils.train_utils import concat_batches from serl_launcher.wrappers.serl_obs_wrappers import SERLObsWrapper +from serl_launcher.wrappers.chunking import ChunkingWrapper import franka_sim flags.DEFINE_string("env", "PandaPickCubeVision-v0", "Name of environment.") + flags.DEFINE_string("agent", "drq", "Name of agent.") flags.DEFINE_string("exp_name", None, "Name of the experiment for wandb logging.") flags.DEFINE_integer("max_traj_length", 1000, "Maximum length of trajectory.") @@ -50,7 +52,6 @@ flags.DEFINE_integer('num_episodes', 10, 'Number of episodes to log.') flags.DEFINE_string('output_dir', 'datasets/', 'Path in a filesystem to record trajectories.') -flags.DEFINE_string('env_name', 'CartPole-v1', 'Name of the environment.') flags.DEFINE_boolean('enable_envlogger', False, 'Enable envlogger.') FLAGS = flags.FLAGS @@ -73,13 +74,22 @@ task = 'reach' # Task type used in the environment, can be 'reach' or 'pick-and-place' # Debug mode for rendering and visualization -DEBUG = False +DEBUG = True if DEBUG: _render_mode = 'human' # Render mode for the environment, can be 'human' or 'rgb_array' else: _render_mode = 'rgb_array' # Use 'rgb_array' for automated testing without GUI +# Indices for franka_sim reach environment observations +if robot == 'franka' and task == 'reach': + + opi = np.array([0, 3]) # Indices for object position in observation + gpi = np.array([3]) # Indices for gripper position in observation + rpi = np.array([4, 7]) # Indices for robot position in observation + rvi = np.array([7, 10]) # Indices for robot 
velocity in observation + + ############################################################################## def set_front_cam_view(env): """ @@ -124,7 +134,7 @@ def compute_error(object_pos, current_pos, prev_error, dt): return error, derror -def demo(env, lastObs): +def demo(env, lastObs, prev_error, prev_time): """ Executes a scripted reach sequence using a hierarchical approach. @@ -152,28 +162,7 @@ def demo(env, lastObs): object_pos = np.zeros(3, dtype=np.float32) current_pos = np.zeros(3, dtype=np.float32) gripper_pos = np.zeros(1, dtype=np.float32) - object_rel_pos = np.zeros(3, dtype=np.float32) - - - # Initialize (single) episode data collection - # episodeObs = [] # Observations for this episode - # episodeAcs = [] # Actions for this episode - # episodeRews = [] # Rewards for this episode - # episodeInfo = [] # Info for this episode - # episodeTerminated = [] # Terminated flags for this episode - # episodeTruncated = [] # Truncated flags for this episode - # episodeDones = [] # Done flags (terminated or truncated) for this episode - - # Dictionary to store episode data - # episode_data = { - # "observations": episodeObs, - # "actions": episodeAcs, - # "rewards": episodeRews, - # "infos": episodeInfo, - # "terminateds": episodeTerminated, - # "truncateds": episodeTruncated, - # "dones": episodeDones - # } + object_rel_pos = np.zeros(3, dtype=np.float32) # close gripper fgr_pos = 0 @@ -193,29 +182,30 @@ def demo(env, lastObs): prev_error = np.zeros_like(object_pos) error, derror = compute_error(object_pos, current_pos, prev_error, dt) - time_step = 0 # Track total time_steps in episode - episodeObs.append(lastObs) # Store initial observation - - # Initialize previous time for dt calculation - prev_time = perf_counter() # Start time for dt calculation - + # time_step = 0 # Track total time_steps in episode + # episodeObs.append(lastObs) # Store initial observation + # Phase 1: Reach # Terminate when distance to above-object position < error_threshold - 
print(f"----------------------------------------------- Phase 1: Reach -----------------------------------------------") - while np.linalg.norm(error) >= error_threshold and time_step <= env.spec.max_episode_steps: - env.render() # Visual feedback + # print(f"----------------------------------------------- Phase 1: Reach -----------------------------------------------") + while np.linalg.norm(error) >= error_threshold: + # env.render() # Visual feedback # Record current time and compute dt curr_time = perf_counter() dt = curr_time - prev_time - prev_time = curr_time # Initialize action vector [x, y, z] action = np.array([0., 0., 0.]) + + # Update error for next iteration + error, derror = compute_error(object_pos, current_pos, prev_error, dt) + # Proportional control with gain of 6 action[:3] = error * Kp + derror * Kv prev_error = error.copy() # Update previous error for next iteration + # Clip action to prevent excessive movements action = np.clip(action/ACTION_MAX, -0.1, 0.1) # @@ -231,13 +221,11 @@ def demo(env, lastObs): #store_transition_data(episode_data, new_obs, reward, action, info, terminated, truncated, done) # Update and print state information - object_pos,gripper_pos,cur_pos,cur_vel = update_state_info(episode_data, time_step, dt, error,reward) + # object_pos,gripper_pos,cur_pos,cur_vel = update_state_info(episode_data, time_step, dt, error,reward) - # Update error for next iteration - # error, derror = compute_error(object_pos, cur_pos, prev_error, dt) # Update time step - time_step += 1 + # time_step += 1 # Sleep #if DEBUG: @@ -248,7 +236,7 @@ def demo(env, lastObs): # deactivate_weld(env, constraint_name="grasp_weld") # Break out of the loop to start a new episode - return action + return action, prev_error, curr_time # # If we reach here, the episode was not successful # if weld_flag: @@ -259,21 +247,21 @@ def demo(env, lastObs): def main(unused_argv): logging.info(f'Creating gym environment...') - env = gym.make(FLAGS.env_name) + env = 
gym.make(FLAGS.env, render_mode=_render_mode) if FLAGS.env == "PandaPickCube-v0": env = gym.wrappers.FlattenObservation(env) if FLAGS.env == "PandaPickCubeVision-v0": env = SERLObsWrapper(env) - #env = ChunkingWrapper(env, obs_horizon=1, act_exec_horizon=None) + env = ChunkingWrapper(env, obs_horizon=1, act_exec_horizon=None) - logging.info(f'Done creating {FLAGS.env_name} environment.') + logging.info(f'Done creating {FLAGS.env} environment.') if FLAGS.enable_envlogger: env = AutoOXEEnvLogger( env=env, - dataset_name=FLAGS.env_name, + dataset_name=FLAGS.env, directory=FLAGS.output_dir, ) @@ -289,12 +277,17 @@ def main(unused_argv): env.set_step_metadata({"timestamp": time.time()}) logging.info('episode %r', i) - env.reset() + + # Start a new episode + obs,_ = env.reset() terminated = False truncated = False + curr_time = perf_counter() while not (terminated or truncated): - action = env.action_space.sample() + + # Get action from the demo function + action,prev_error,curr_time = demo(env,obs,prev_error, curr_time) # example to log custom step metadata if FLAGS.enable_envlogger: From 4db34a4d4b880979be6d2bc56dc4ac4a81121618 Mon Sep 17 00:00:00 2001 From: ryanvanderstelt Date: Tue, 19 Aug 2025 18:15:37 -0500 Subject: [PATCH 118/141] testing --- .../async_sac_state_sim/automated_tests.sh | 89 ++++++++++--------- .../async_sac_state_sim/tmux_launch_tests.sh | 18 ++-- 2 files changed, 54 insertions(+), 53 deletions(-) diff --git a/examples/async_sac_state_sim/automated_tests.sh b/examples/async_sac_state_sim/automated_tests.sh index 7fccc41b..b439120b 100644 --- a/examples/async_sac_state_sim/automated_tests.sh +++ b/examples/async_sac_state_sim/automated_tests.sh @@ -1,70 +1,75 @@ #!/bin/bash -SEED=$1 +SEEDS=$1 WANDB_OUTPUT_DIR=~/wandb_logs TEST="async_sac_state_sim.py" CONDA_ENV="serl" ENV="PandaReachCube-v0" -MAX_STEPS=25000 +MAX_STEPS=15000 TRAINING_STARTS=1000 RANDOM_STEPS=1000 CRITIC_ACTOR_RATIO=8 -EXP_NAME="BATCH_SIZE-VS-BRANCHES-WALL-TIME-TEST-$ENV" 
+EXP_NAME="27^1-OLD-VERSION-$ENV" REPLAY_BUFFER_TYPE="fractal_symmetry_replay_buffer" -BASE_ARGS="--env $ENV --exp_name $EXP_NAME --wandb_output_dir $WANDB_OUTPUT_DIR --wandb_offline --training_starts $TRAINING_STARTS --random_steps $RANDOM_STEPS --seed $SEED" +BASE_ARGS="--env $ENV --exp_name $EXP_NAME --wandb_output_dir $WANDB_OUTPUT_DIR --training_starts $TRAINING_STARTS --random_steps $RANDOM_STEPS" ARGS="" function run_test { - OPEN_PORTS=$( comm -23 <(seq 49152 65535 | sort) <(ss -Htan | awk '{print $4}' | cut -d':' -f2 | sort -u) | shuf | head -n 2 ) - PORTS=( $OPEN_PORTS ) - PORT_NUMBER=${PORTS[0]} - BROADCAST_PORT=${PORTS[1]} - - ARGS+=" --port_number $PORT_NUMBER --broadcast_port $BROADCAST_PORT" - - echo "Running constant with args: $ARGS" - tmux respawn-pane -k -t serl_session:$SEED.1 - tmux respawn-pane -k -t serl_session:$SEED.2 - tmux send-keys -t serl_session:$SEED.1 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --actor --max_steps 2000000000 $BASE_ARGS $ARGS" C-m - tmux send-keys -t serl_session:$SEED.2 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --learner --max_steps $MAX_STEPS $BASE_ARGS $ARGS" C-m "exit" C-m - - # Wait for learner to finish - while ! tmux capture-pane -t serl_session:$SEED.2 -p | grep "logout" > /dev/null; - do - sleep 1 - done - echo "Finished!" 
- -} - -# BASELINE TESTING -for CRITIC_ACTOR_RATIO in 8 16 32 -do - for batch_size in 256 512 1024 2048 + for seed in {1..$SEEDS} do - for replay_buffer_capacity in 1000000 - do - ARGS="--run_name baseline --critic_actor_ratio $CRITIC_ACTOR_RATIO --replay_buffer_type replay_buffer --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity" - run_test + OPEN_PORTS=$( comm -23 <(seq 49152 65535 | sort) <(ss -Htan | awk '{print $4}' | cut -d':' -f2 | sort -u) | shuf | head -n 2 ) + PORTS=( $OPEN_PORTS ) + PORT_NUMBER=${PORTS[0]} + BROADCAST_PORT=${PORTS[1]} + + ARGS+=" --port_number $PORT_NUMBER --broadcast_port $BROADCAST_PORT" + + echo "Running constant with args: $ARGS" + tmux respawn-pane -k -t serl_session:0.1 + tmux respawn-pane -k -t serl_session:0.2 + tmux send-keys -t serl_session:0.1 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --actor --max_steps 2000000000 --seed $seed $BASE_ARGS $ARGS" C-m + tmux send-keys -t serl_session:0.2 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --learner --max_steps $MAX_STEPS --seed $seed $BASE_ARGS $ARGS" C-m "exit" C-m + + # Wait for learner to finish + while ! tmux capture-pane -t serl_session:0.2 -p | grep "logout" > /dev/null; + do + sleep 1 done + echo "Finished!" 
done -done +} + +# # BASELINE TESTING +# for CRITIC_ACTOR_RATIO in 8 16 32 +# do +# for batch_size in 256 512 1024 2048 +# do +# for replay_buffer_capacity in 1000000 +# do +# ARGS="--run_name baseline --critic_actor_ratio $CRITIC_ACTOR_RATIO --replay_buffer_type replay_buffer --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity" +# run_test +# done +# done +# done # CONSTANT TESTING -for CRITIC_ACTOR_RATIO in 8 16 32 +for steps_per_update in 1 30 100 do - for starting_branch_count in 2 8 32 + for CRITIC_ACTOR_RATIO in 8 do - for batch_size in 256 512 1024 2048 + for starting_branch_count in 27 do - for workspace_width in 0.5 + for batch_size in 256 do - for replay_buffer_capacity in $(( starting_branch_count * starting_branch_count * 50000 )) + for workspace_width in 0.5 do - ARGS="--run_name constant-$starting_branch_count^1 --critic_actor_ratio $CRITIC_ACTOR_RATIO --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --workspace_width $workspace_width --branch_method 'constant' --starting_branch_count $starting_branch_count" - run_test + for replay_buffer_capacity in 1000000 + do + ARGS="--run_name constant-$starting_branch_count^1 --steps_per_update $steps_per_update --critic_actor_ratio $CRITIC_ACTOR_RATIO --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --workspace_width $workspace_width --branch_method 'constant' --starting_branch_count $starting_branch_count" + run_test + done done done done diff --git a/examples/async_sac_state_sim/tmux_launch_tests.sh b/examples/async_sac_state_sim/tmux_launch_tests.sh index 2b06c1ee..80d47f46 100644 --- a/examples/async_sac_state_sim/tmux_launch_tests.sh +++ b/examples/async_sac_state_sim/tmux_launch_tests.sh @@ -1,22 +1,18 @@ #!/bin/bash - +SEEDS=5 # Create a new tmux session tmux new-session -d -s serl_session tmux setw -g remain-on-exit on -for SEED in 1 -do - # New window 
- tmux new-window -t serl_session -n $SEED - # Split the window horizontally - tmux split-window -v - tmux split-pane -h -t serl_session:$SEED.1 +# Split the window horizontally +tmux split-window -v +tmux split-pane -h -t serl_session:0.1 + +# Navigate to the activate the conda environment in the first pane +tmux send-keys -t serl_session:0.0 "bash automated_tests.sh $SEEDS" C-m - # Navigate to the activate the conda environment in the first pane - tmux send-keys -t serl_session:$SEED.0 "bash automated_tests.sh $SEED" C-m -done # Attach to the tmux session tmux attach-session -t serl_session From 53d4d8e8c731c2e285752b07d455f8ad036e1771 Mon Sep 17 00:00:00 2001 From: ryanvanderstelt Date: Tue, 19 Aug 2025 19:01:49 -0500 Subject: [PATCH 119/141] testing --- .../async_sac_state_sim/automated_tests.sh | 57 +++++++++---------- 1 file changed, 27 insertions(+), 30 deletions(-) diff --git a/examples/async_sac_state_sim/automated_tests.sh b/examples/async_sac_state_sim/automated_tests.sh index b439120b..30766bef 100644 --- a/examples/async_sac_state_sim/automated_tests.sh +++ b/examples/async_sac_state_sim/automated_tests.sh @@ -5,11 +5,11 @@ WANDB_OUTPUT_DIR=~/wandb_logs TEST="async_sac_state_sim.py" CONDA_ENV="serl" ENV="PandaReachCube-v0" -MAX_STEPS=15000 +MAX_STEPS=25000 TRAINING_STARTS=1000 RANDOM_STEPS=1000 CRITIC_ACTOR_RATIO=8 -EXP_NAME="27^1-OLD-VERSION-$ENV" +EXP_NAME="GENERAL-RETESTING-$ENV" REPLAY_BUFFER_TYPE="fractal_symmetry_replay_buffer" BASE_ARGS="--env $ENV --exp_name $EXP_NAME --wandb_output_dir $WANDB_OUTPUT_DIR --training_starts $TRAINING_STARTS --random_steps $RANDOM_STEPS" @@ -17,14 +17,14 @@ ARGS="" function run_test { - for seed in {1..$SEEDS} + for seed in $(seq 1 1 $SEEDS) do - OPEN_PORTS=$( comm -23 <(seq 49152 65535 | sort) <(ss -Htan | awk '{print $4}' | cut -d':' -f2 | sort -u) | shuf | head -n 2 ) - PORTS=( $OPEN_PORTS ) - PORT_NUMBER=${PORTS[0]} - BROADCAST_PORT=${PORTS[1]} + # OPEN_PORTS=$( comm -23 <(seq 49152 65535 | sort) <(ss -Htan | 
awk '{print $4}' | cut -d':' -f2 | sort -u) | shuf | head -n 2 ) + # PORTS=( $OPEN_PORTS ) + # PORT_NUMBER=${PORTS[0]} + # BROADCAST_PORT=${PORTS[1]} - ARGS+=" --port_number $PORT_NUMBER --broadcast_port $BROADCAST_PORT" + # ARGS+=" --port_number $PORT_NUMBER --broadcast_port $BROADCAST_PORT" echo "Running constant with args: $ARGS" tmux respawn-pane -k -t serl_session:0.1 @@ -41,35 +41,32 @@ function run_test { done } -# # BASELINE TESTING -# for CRITIC_ACTOR_RATIO in 8 16 32 -# do -# for batch_size in 256 512 1024 2048 -# do -# for replay_buffer_capacity in 1000000 -# do -# ARGS="--run_name baseline --critic_actor_ratio $CRITIC_ACTOR_RATIO --replay_buffer_type replay_buffer --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity" -# run_test -# done -# done -# done +# BASELINE TESTING +for CRITIC_ACTOR_RATIO in 8 +do + for batch_size in 256 2048 + do + for replay_buffer_capacity in 1000 1000000 + do + ARGS="--run_name baseline --critic_actor_ratio $CRITIC_ACTOR_RATIO --replay_buffer_type replay_buffer --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity" + run_test + done + done +done # CONSTANT TESTING -for steps_per_update in 1 30 100 +for CRITIC_ACTOR_RATIO in 8 do - for CRITIC_ACTOR_RATIO in 8 + for starting_branch_count in 2 8 64 do - for starting_branch_count in 27 + for batch_size in 256 do - for batch_size in 256 + for workspace_width in 10 1 .1 do - for workspace_width in 0.5 + for replay_buffer_capacity in $((1000 * $starting_branch_count * $starting_branch_count)) $((1000000 * $starting_branch_count * $starting_branch_count)) do - for replay_buffer_capacity in 1000000 - do - ARGS="--run_name constant-$starting_branch_count^1 --steps_per_update $steps_per_update --critic_actor_ratio $CRITIC_ACTOR_RATIO --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --workspace_width $workspace_width --branch_method 'constant' --starting_branch_count 
$starting_branch_count" - run_test - done + ARGS="--run_name constant-$starting_branch_count^1 --steps_per_update $steps_per_update --critic_actor_ratio $CRITIC_ACTOR_RATIO --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --workspace_width $workspace_width --branch_method 'constant' --starting_branch_count $starting_branch_count" + run_test done done done From 1d3138cc0b8242db5ca03cc04fc1c9fb27fb0d7c Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Thu, 21 Aug 2025 12:26:32 -0500 Subject: [PATCH 120/141] v0.2.0 - Complete Redo. - No block_pos info was available, missed it. Cannot attempt to do PID as before. - Created teleop with keyboard. - Control is okay, but having overshooting. TODO: clip. - Hard to see. Will need to have controls for mujoco angle as well. --- demos/demos/franka_reach_drq_demo_script.py | 300 ++++++++------------ 1 file changed, 126 insertions(+), 174 deletions(-) diff --git a/demos/demos/franka_reach_drq_demo_script.py b/demos/demos/franka_reach_drq_demo_script.py index d4ed6449..0df49a25 100644 --- a/demos/demos/franka_reach_drq_demo_script.py +++ b/demos/demos/franka_reach_drq_demo_script.py @@ -1,8 +1,8 @@ #!/usr/bin/env python3 import time -from time import sleep, perf_counter -from datetime import datetime +# from time import perf_counter + import numpy as np from absl import app, flags, logging @@ -14,82 +14,63 @@ import os -from typing import Any, Dict, Optional import pickle as pkl import gym -from gym.wrappers.record_episode_statistics import RecordEpisodeStatistics - -from serl_launcher.agents.continuous.drq import DrQAgent -from serl_launcher.common.evaluation import evaluate -from serl_launcher.utils.timer_utils import Timer -from serl_launcher.utils.train_utils import concat_batches from serl_launcher.wrappers.serl_obs_wrappers import SERLObsWrapper from serl_launcher.wrappers.chunking import ChunkingWrapper import franka_sim +# Teleoperation imports +import sys, 
select, termios, tty + +#------------------------------------------------------------------------------------------- +# Flags +#------------------------------------------------------------------------------------------- flags.DEFINE_string("env", "PandaPickCubeVision-v0", "Name of environment.") -flags.DEFINE_string("agent", "drq", "Name of agent.") +#flags.DEFINE_string("agent", "drq", "Name of agent.") flags.DEFINE_string("exp_name", None, "Name of the experiment for wandb logging.") -flags.DEFINE_integer("max_traj_length", 1000, "Maximum length of trajectory.") +flags.DEFINE_integer("max_traj_length", 100, "Maximum length of trajectory.") flags.DEFINE_integer("seed", 42, "Random seed.") # flag to indicate if this is a leaner or a actor -flags.DEFINE_string("ip", "localhost", "IP address of the learner.") +#flags.DEFINE_string("ip", "localhost", "IP address of the learner.") # "small" is a 4 layer convnet, "resnet" and "mobilenet" are frozen with pretrained weights flags.DEFINE_string("encoder_type", "resnet-pretrained", "Encoder type.") -flags.DEFINE_string("demo_path", None, "Path to the demo data.") +#flags.DEFINE_string("demo_path", None, "Path to the demo data.") -flags.DEFINE_boolean( - "debug", False, "Debug mode." -) # debug mode will disable wandb logging +flags.DEFINE_boolean("debug", True, "Debug mode.") # debug mode will disable wandb logging -flags.DEFINE_string("log_rlds_path", None, "Path to save RLDS logs.") -flags.DEFINE_string("preload_rlds_path", None, "Path to preload RLDS data.") +flags.DEFINE_string("log_rlds_path", "/data/data/serl/demos/franka_reach_drq_demo_script", "Path to save RLDS logs.") +#flags.DEFINE_string("preload_rlds_path", None, "Path to preload RLDS data.") +flags.DEFINE_string("output_dir", "/data/data/serl/demos/franka_reach_drq_demo_script", + "Directory to save the output data. 
This is where the RLDS logs will be saved.") + +flags.DEFINE_integer("num_demos", 10, "Number of episodes to log.") -flags.DEFINE_integer('num_episodes', 10, 'Number of episodes to log.') -flags.DEFINE_string('output_dir', 'datasets/', - 'Path in a filesystem to record trajectories.') -flags.DEFINE_boolean('enable_envlogger', False, 'Enable envlogger.') +flags.DEFINE_boolean("enable_envlogger", True, "Enable envlogger.") + +# Keyboard teleoperation +flags.DEFINE_string("teleop_mode", "keyboard", "Teleoperation mode: 'keyboard' or 'spacemouse'.") FLAGS = flags.FLAGS #------------------------------------------------------------------------------------------- -## Key Config Variables +## Telop Config Variables #------------------------------------------------------------------------------------------- -# Proportional and derivative control gain for action scaling -- empirically tuned -Kp = 10.0 # Values between 20 and 24 seem to be somewhat stable for Kv = 24 -Kv = 10.0 - -ACTION_MAX = 10 # Maximum action value for clipping actions -ERROR_THRESHOLD = 0.008 # Note!! When this number is changed, the way rewards are computed in the PandaReachCubeEnv.step() L220 must also be changed such that done=True only at the end of a successfull run. 
- -# Number of demonstration episodes to generate -NUM_DEMOS = 20 - -# Robot configuration -robot = 'franka' # Robot type used in the environment, can be 'franka' or 'fetch' -task = 'reach' # Task type used in the environment, can be 'reach' or 'pick-and-place' - -# Debug mode for rendering and visualization -DEBUG = True - -if DEBUG: - _render_mode = 'human' # Render mode for the environment, can be 'human' or 'rgb_array' -else: - _render_mode = 'rgb_array' # Use 'rgb_array' for automated testing without GUI - -# Indices for franka_sim reach environment observations -if robot == 'franka' and task == 'reach': - - opi = np.array([0, 3]) # Indices for object position in observation - gpi = np.array([3]) # Indices for gripper position in observation - rpi = np.array([4, 7]) # Indices for robot position in observation - rvi = np.array([7, 10]) # Indices for robot velocity in observation - - +# Bind, xyz, gripper vals to keys +moveBindings = { + 'i':(1,0,0,0), + ',':(-1,0,0,0), + 'j':(0,1,0,0), + 'l':(0,-1,0,0), + 'u':(0,0,0,1), + 'o':(0,0,0,-1), + 'm':(0,0,1,0), + '.':(0,0,-1,0), + } ############################################################################## def set_front_cam_view(env): """ @@ -106,7 +87,7 @@ def set_front_cam_view(env): if hasattr(viewer, 'cam'): viewer.cam.lookat[:] = [0, 0, 0.1] # Center of robot (adjust as needed) viewer.cam.distance = 3.0 # Camera distance - viewer.cam.azimuth = 135 # 0 = right, 90 = front, 180 = left + viewer.cam.azimuth = 155 # 0 = right, 90 = front, 180 = left viewer.cam.elevation = -30 # Negative = above, positive = below # Hide menu @@ -114,150 +95,120 @@ def set_front_cam_view(env): return viewer -############################################################################## -def compute_error(object_pos, current_pos, prev_error, dt): - """Compute the error and its derivative between the object position and the current end-effector position. - Args: - object_pos (np.ndarray): The position of the object in the environment. 
- current_pos (np.ndarray): The current position of the end-effector. - prev_error (np.ndarray): The previous error value for derivative calculation. - dt (float): Time step for derivative calculation. - - Returns: - error (np.ndarray): The current error vector between the object and end-effector positions. - derror (np.ndarray): The derivative of the error vector. +def getKey(settings): """ - error = object_pos - current_pos # Calculate the error vector - derror = (error - prev_error) / dt # Calculate the derivative of the error vector - - prev_error = error.copy() # Update previous error for next iteration - return error, derror + Waits briefly for a keypress and returns the pressed key. + Parameters + ---------- + settings : list + Original terminal settings so they can be restored after reading. -def demo(env, lastObs, prev_error, prev_time): + Returns + ------- + key : str + The key pressed by the user, or '' (empty string) if no key was pressed. """ - Executes a scripted reach sequence using a hierarchical approach. - - Implements 1-phase control strategy: - 1. Approach: Move gripper above object (3cm offset) + # Put terminal into raw mode so keypress is captured instantly + tty.setraw(sys.stdin.fileno()) - Store observations, actions, and info in global lists for later replay buffer inclusion. + # Wait for human to input action, we will not provide a timeout so it is blocking. + rlist, _, _ = select.select([sys.stdin], [], []) # select.select(rlist, wlist, xlist[, timeout]) - Gripper: - - The gripper in Mujoco ranges from a value of 0 to 0.4, where 0 is fully open and 0.4 is fully closed. 
- - Args: - env: Gymnasium environment instance - lastObs: Flattened observations set as object_pos, gripper_pos, panda/tcp_pos, panda/tcp_vel - - observations: - object_pos[0:3], - gripper_pos[3] - panda/tcp_pos[4:7], - panda/tcp_vel[7:10] + if rlist: + # Read exactly one character if a key is pressed + key = sys.stdin.read(1) + else: + key = '' - Returns: + # Restore terminal to original settings + termios.tcsetattr(sys.stdin, termios.TCSADRAIN, settings) + return key + +def get_kb_demo_action(speed=0.1): """ + Reads keyboard input and maps it to a 3D action vector for robot control. + + The function uses non-blocking keyboard input to allow interactive + teleoperation. Keys are mapped to directions in Cartesian space: + - 'i' : +x (forward) + - ',' : -x (backward) + - 'j' : +y (left) + - 'l' : -y (right) + - 'm' : +z (up) + - '.' : -z (down) + - 'k' : stop (zero vector) + + Parameters + ---------- + speed : float, optional + Step size for each key press (default 0.2). + + Returns + ------- + np.ndarray + Action vector of shape (3,), where each entry corresponds to + [x, y, z] translation command. Example: [0.2, 0.0, 0.0]. + """ + # Save current terminal settings so we can restore later + settings = termios.tcgetattr(sys.stdin) - ## Init goal, current_pos, and object position from last observation - object_pos = np.zeros(3, dtype=np.float32) - current_pos = np.zeros(3, dtype=np.float32) - gripper_pos = np.zeros(1, dtype=np.float32) - object_rel_pos = np.zeros(3, dtype=np.float32) - - # close gripper - fgr_pos = 0 - - # Error thresholds - error_threshold = ERROR_THRESHOLD # Threshold for stopping condition (Xmm) - - finger_delta_fast = 0.05 # Action delta for fingers 5cm per step (will get clipped by controller)... more of a scalar. 
- finger_delta_slow = 0.005 # Franka has a range from 0 to 4cm per finger - - ## Extract data - object_pos = lastObs[opi[0]:opi[1]] # block pos - current_pos = lastObs[rpi[0]:rpi[1]] # panda/tcp_pos - - # Relative position between end-effector and object - dt = env.unwrapped.model.opt.timestep # Mujoco time step - prev_error = np.zeros_like(object_pos) - error, derror = compute_error(object_pos, current_pos, prev_error, dt) - - # time_step = 0 # Track total time_steps in episode - # episodeObs.append(lastObs) # Store initial observation - - # Phase 1: Reach - # Terminate when distance to above-object position < error_threshold - # print(f"----------------------------------------------- Phase 1: Reach -----------------------------------------------") - while np.linalg.norm(error) >= error_threshold: - # env.render() # Visual feedback - - # Record current time and compute dt - curr_time = perf_counter() - dt = curr_time - prev_time + # Initialize action as a zero vector (no movement) + action = np.zeros(4, dtype=float) - # Initialize action vector [x, y, z] - action = np.array([0., 0., 0.]) + try: + # Capture the pressed key + key = getKey(settings) - # Update error for next iteration - error, derror = compute_error(object_pos, current_pos, prev_error, dt) + if key in moveBindings: + # Lookup (x, y, z) direction and scale by speed + dx, dy, dz, g= moveBindings[key] + action = np.array([dx, dy, dz, g], dtype=float) * speed - - # Proportional control with gain of 6 - action[:3] = error * Kp + derror * Kv - prev_error = error.copy() # Update previous error for next iteration - - - # Clip action to prevent excessive movements - action = np.clip(action/ACTION_MAX, -0.1, 0.1) # + elif key == 'k': + # 'k' means stop → zero vector + action = np.zeros(3, dtype=float) - # Keep gripper closed -- no need. only 3 dimensions of control - #action[ len(action)-1 ] = -finger_delta_fast # Maintain gripper closed. 
- - # Unpack new Gymnasium step API - # new_obs, reward, terminated, truncated, info = env.step(action) - # done = terminated or truncated - - # Store episode data - #store_transition_data(episode_data, new_obs, reward, action, info, terminated, truncated, done) - - # Update and print state information - # object_pos,gripper_pos,cur_pos,cur_vel = update_state_info(episode_data, time_step, dt, error,reward) - - - # Update time step - # time_step += 1 + elif key == '\x03': # CTRL-C + raise KeyboardInterrupt - # Sleep - #if DEBUG: - sleep(0.25) # Activated when DEBUG is True for better visualization. + finally: + # Restore terminal even if something goes wrong + termios.tcsetattr(sys.stdin, termios.TCSADRAIN, settings) - # Deactivate weld constraint after successful pick -- franka_sim env does not have weld like franka_mujoco env. - # if weld_flag: - # deactivate_weld(env, constraint_name="grasp_weld") + return action - # Break out of the loop to start a new episode - return action, prev_error, curr_time - - # # If we reach here, the episode was not successful - # if weld_flag: - # print("Failed to transport object to goal position. 
Deactivating weld.") - # deactivate_weld(env, constraint_name="grasp_weld") - ############################################################################## def main(unused_argv): logging.info(f'Creating gym environment...') + # Render mode configuration based on debug flag + if FLAGS.debug: + _render_mode = 'human' # Render mode for the environment, can be 'human' or 'rgb_array' + else: + _render_mode = 'rgb_array' # Use 'rgb_array' for automated testing without GUI + + # Create the environment with the specified render mode and wrappers env = gym.make(FLAGS.env, render_mode=_render_mode) if FLAGS.env == "PandaPickCube-v0": env = gym.wrappers.FlattenObservation(env) + if FLAGS.env == "PandaPickCubeVision-v0": env = SERLObsWrapper(env) env = ChunkingWrapper(env, obs_horizon=1, act_exec_horizon=None) logging.info(f'Done creating {FLAGS.env} environment.') + # Set the camera view to a front-facing perspective + if hasattr(env.unwrapped, '_viewer'): + viewer = set_front_cam_view(env) + if viewer: + logging.info('Camera view set to front-facing perspective.') + else: + logging.warning('Failed to set camera view. Viewer not available.') + # If envlogger is enabled, wrap the environment with AutoOXEEnvLogger to log episodes if FLAGS.enable_envlogger: env = AutoOXEEnvLogger( env=env, @@ -265,11 +216,13 @@ def main(unused_argv): directory=FLAGS.output_dir, ) - logging.info('Training an agent for %r episodes...', FLAGS.num_episodes) + logging.info('Training an agent for %r episodes...', FLAGS.num_demos) - for i in range(FLAGS.num_episodes): + #--- LOOP DEMOS/EPISODES --- + # Loop through the number of demos specified by the user to record demonstrations + for i in range(FLAGS.num_demos): - # example to log custom metadata during new episode + # Log custom metadata during new episode: language embeddings randomly. 
if FLAGS.enable_envlogger: env.set_episode_metadata({ "language_embedding": np.random.random((5,)).astype(np.float32) @@ -279,15 +232,14 @@ def main(unused_argv): logging.info('episode %r', i) # Start a new episode - obs,_ = env.reset() + obs = env.reset()[0] terminated = False truncated = False - curr_time = perf_counter() while not (terminated or truncated): # Get action from the demo function - action,prev_error,curr_time = demo(env,obs,prev_error, curr_time) + action = get_kb_demo_action() # example to log custom step metadata if FLAGS.enable_envlogger: From 0e52c8e9fc128a7532336e6c42548673e538a508 Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Thu, 21 Aug 2025 16:36:41 -0500 Subject: [PATCH 121/141] - Do not incude env.close() --- demos/demos/franka_reach_drq_demo_script.py | 26 +++++++++++++++------ demos/oxe_envlogger | 1 + 2 files changed, 20 insertions(+), 7 deletions(-) create mode 160000 demos/oxe_envlogger diff --git a/demos/demos/franka_reach_drq_demo_script.py b/demos/demos/franka_reach_drq_demo_script.py index 0df49a25..b2e7dcac 100644 --- a/demos/demos/franka_reach_drq_demo_script.py +++ b/demos/demos/franka_reach_drq_demo_script.py @@ -32,7 +32,7 @@ #flags.DEFINE_string("agent", "drq", "Name of agent.") flags.DEFINE_string("exp_name", None, "Name of the experiment for wandb logging.") -flags.DEFINE_integer("max_traj_length", 100, "Maximum length of trajectory.") +flags.DEFINE_integer("max_traj_length", 200, "Maximum length of trajectory.") flags.DEFINE_integer("seed", 42, "Random seed.") # flag to indicate if this is a leaner or a actor @@ -48,7 +48,7 @@ flags.DEFINE_string("output_dir", "/data/data/serl/demos/franka_reach_drq_demo_script", "Directory to save the output data. 
This is where the RLDS logs will be saved.") -flags.DEFINE_integer("num_demos", 10, "Number of episodes to log.") +flags.DEFINE_integer("num_demos", 2, "Number of episodes to log.") flags.DEFINE_boolean("enable_envlogger", True, "Enable envlogger.") @@ -60,6 +60,8 @@ #------------------------------------------------------------------------------------------- ## Telop Config Variables #------------------------------------------------------------------------------------------- +ACTION_MAX = 10 # Maximum action value for clipping actions + # Bind, xyz, gripper vals to keys moveBindings = { 'i':(1,0,0,0), @@ -167,15 +169,18 @@ def get_kb_demo_action(speed=0.1): elif key == 'k': # 'k' means stop → zero vector - action = np.zeros(3, dtype=float) + action = np.zeros(4, dtype=float) elif key == '\x03': # CTRL-C raise KeyboardInterrupt - + finally: # Restore terminal even if something goes wrong termios.tcsetattr(sys.stdin, termios.TCSADRAIN, settings) + # Clip action values to prevent excessive commands + action = np.clip(action, -ACTION_MAX, ACTION_MAX) + return action ############################################################################## @@ -232,10 +237,13 @@ def main(unused_argv): logging.info('episode %r', i) # Start a new episode - obs = env.reset()[0] + env.reset() terminated = False truncated = False + step = 0 + + # Termination occurs when the hand reaches the target or the maximum trajectory length is reached while not (terminated or truncated): # Get action from the demo function @@ -254,9 +262,13 @@ def main(unused_argv): obs, reward, terminated, info = return_step truncated = False - logging.info( - 'Done training a random agent for %r episodes.', FLAGS.num_episodes) + print(f" step: {step}", f"reward: {reward:.3f}\n") + step += 1 + logging.info( + 'Done training a random agent for %r episodes.', FLAGS.num_demos) + + #env.close() # it seems async_drq_sim does not use close() dm_env method error. 
if __name__ == '__main__': app.run(main) \ No newline at end of file diff --git a/demos/oxe_envlogger b/demos/oxe_envlogger new file mode 160000 index 00000000..de3c48bc --- /dev/null +++ b/demos/oxe_envlogger @@ -0,0 +1 @@ +Subproject commit de3c48bcf094ebba350ce0ba183efca4478a501a From 134a211e2677eb328a0e8e38b4c0544acd278d08 Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Thu, 21 Aug 2025 22:26:22 -0500 Subject: [PATCH 122/141] v.0.3.0 - Set original image resolution to (960,960,3) for easy viewing. - Need to resize these images in the SERLObsWrapper such that image specs match those expected by DRQ as well as ensure they are of type uint8. - resize function added that can work with cv2 or PIL. - Call to observation returns state and resized images. Lots of checks here. - Can record demos for any number of episodes. - Keyboard bindings use the following keys: moveBindings = { 'i':(1,0,0,0), ',':(-1,0,0,0), 'j':(0,1,0,0), 'l':(0,-1,0,0), 'u':(0,0,0,1), 'o':(0,0,0,-1), 'm':(0,0,1,0), '.':(0,0,-1,0), } Note that after commanding actions the robot has some momentum that current controller does not cancel well. You will need practice. - Data gets stored to data customized folders that are checked for existance or created. 
--- demos/demos/franka_reach_drq_demo_script.py | 124 ++++++++------ franka_sim/franka_sim/envs/xmls/arena.xml | 2 +- .../wrappers/serl_obs_wrappers.py | 158 ++++++++++++++++-- 3 files changed, 222 insertions(+), 62 deletions(-) mode change 100644 => 100755 demos/demos/franka_reach_drq_demo_script.py diff --git a/demos/demos/franka_reach_drq_demo_script.py b/demos/demos/franka_reach_drq_demo_script.py old mode 100644 new mode 100755 index b2e7dcac..06f730a1 --- a/demos/demos/franka_reach_drq_demo_script.py +++ b/demos/demos/franka_reach_drq_demo_script.py @@ -1,25 +1,21 @@ #!/usr/bin/env python3 - +import os import time -# from time import perf_counter +from datetime import datetime import numpy as np -from absl import app, flags, logging -import gym +# Logging +from absl import app, flags, logging from oxe_envlogger.envlogger import AutoOXEEnvLogger -import numpy as np -from absl import app, flags - -import os - -import pickle as pkl +# DRL import gym - +import mujoco from serl_launcher.wrappers.serl_obs_wrappers import SERLObsWrapper from serl_launcher.wrappers.chunking import ChunkingWrapper +# Needed to create the franka environment import franka_sim # Teleoperation imports @@ -29,30 +25,14 @@ # Flags #------------------------------------------------------------------------------------------- flags.DEFINE_string("env", "PandaPickCubeVision-v0", "Name of environment.") - -#flags.DEFINE_string("agent", "drq", "Name of agent.") flags.DEFINE_string("exp_name", None, "Name of the experiment for wandb logging.") flags.DEFINE_integer("max_traj_length", 200, "Maximum length of trajectory.") -flags.DEFINE_integer("seed", 42, "Random seed.") - -# flag to indicate if this is a leaner or a actor -#flags.DEFINE_string("ip", "localhost", "IP address of the learner.") -# "small" is a 4 layer convnet, "resnet" and "mobilenet" are frozen with pretrained weights -flags.DEFINE_string("encoder_type", "resnet-pretrained", "Encoder type.") -#flags.DEFINE_string("demo_path", None, 
"Path to the demo data.") - flags.DEFINE_boolean("debug", True, "Debug mode.") # debug mode will disable wandb logging - -flags.DEFINE_string("log_rlds_path", "/data/data/serl/demos/franka_reach_drq_demo_script", "Path to save RLDS logs.") #flags.DEFINE_string("preload_rlds_path", None, "Path to preload RLDS data.") flags.DEFINE_string("output_dir", "/data/data/serl/demos/franka_reach_drq_demo_script", - "Directory to save the output data. This is where the RLDS logs will be saved.") - + "Directory to save the output data. This is where the RLDS logs will be saved.") flags.DEFINE_integer("num_demos", 2, "Number of episodes to log.") - flags.DEFINE_boolean("enable_envlogger", True, "Enable envlogger.") - -# Keyboard teleoperation flags.DEFINE_string("teleop_mode", "keyboard", "Teleoperation mode: 'keyboard' or 'spacemouse'.") FLAGS = flags.FLAGS @@ -74,28 +54,43 @@ '.':(0,0,-1,0), } ############################################################################## -def set_front_cam_view(env): +def ensure_dir_exists(): """ - Set the camera view to a front-facing perspective for better visualization. - - Args: - env: The environment instance containing the viewer. - - Returns: - viewer: The viewer with updated camera settings. + Ensure that the output directory exists. If it does not exist, create it. + + Returns + ------- + out_path : str + The path to the output directory. 
""" - viewer = env.unwrapped._viewer.viewer # Access the viewer from the environment + # Create a timestamped directory for saving outputs + robot = "franka" + task = "reach" + initStateSpace = "random" # "fixed" or "random" + + # Customize the path + script_dir = FLAGS.output_dir + + # Create output filename with configuration details + fileName = "data_" + robot + "_" + task + fileName += "_" + initStateSpace + fileName += "_" + str(FLAGS.num_demos) - if hasattr(viewer, 'cam'): - viewer.cam.lookat[:] = [0, 0, 0.1] # Center of robot (adjust as needed) - viewer.cam.distance = 3.0 # Camera distance - viewer.cam.azimuth = 155 # 0 = right, 90 = front, 180 = left - viewer.cam.elevation = -30 # Negative = above, positive = below + # Add timestamp to filename for uniqueness + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + fileName += "_" + timestamp - # Hide menu - viewer._hide_overlay = True - - return viewer + # Build a filename in that same directory + out_path = os.path.join(script_dir, fileName) + + # Ensure the directory exists + if not os.path.exists(out_path): + os.makedirs(script_dir, exist_ok=True) + logging.info(f"Created output directory at: {out_path}") + else: + logging.info(f"Output directory already exists at: {out_path}") + + return out_path def getKey(settings): """ @@ -183,6 +178,29 @@ def get_kb_demo_action(speed=0.1): return action +def set_front_cam_view(env): + """ + Set the camera view to a front-facing perspective for better visualization. + + Args: + env: The environment instance containing the viewer. + + Returns: + viewer: The viewer with updated camera settings. 
+ """ + viewer = env.unwrapped._viewer.viewer # Access the viewer from the environment + + if hasattr(viewer, 'cam'): + viewer.cam.lookat[:] = [0, 0, 0.1] # Center of robot (adjust as needed) + viewer.cam.distance = 2.0 # Camera distance + viewer.cam.azimuth = 155 # 0 = right, 90 = front, 180 = left + viewer.cam.elevation = -30 # Negative = above, positive = below + + # Hide menu + viewer._hide_overlay = True + + return viewer + ############################################################################## def main(unused_argv): logging.info(f'Creating gym environment...') @@ -200,12 +218,17 @@ def main(unused_argv): env = gym.wrappers.FlattenObservation(env) if FLAGS.env == "PandaPickCubeVision-v0": - env = SERLObsWrapper(env) + # env = SERLObsWrapper(env) + env = SERLObsWrapper( + env, + target_hw=(128, 128), + img_dtype=np.uint8, # or np.float32 + normalize=False, # True if using float32 in [0,1] + ) env = ChunkingWrapper(env, obs_horizon=1, act_exec_horizon=None) logging.info(f'Done creating {FLAGS.env} environment.') - # Set the camera view to a front-facing perspective if hasattr(env.unwrapped, '_viewer'): viewer = set_front_cam_view(env) if viewer: @@ -215,10 +238,13 @@ def main(unused_argv): # If envlogger is enabled, wrap the environment with AutoOXEEnvLogger to log episodes if FLAGS.enable_envlogger: + + out_path = ensure_dir_exists() + env = AutoOXEEnvLogger( env=env, dataset_name=FLAGS.env, - directory=FLAGS.output_dir, + directory=out_path, ) logging.info('Training an agent for %r episodes...', FLAGS.num_demos) diff --git a/franka_sim/franka_sim/envs/xmls/arena.xml b/franka_sim/franka_sim/envs/xmls/arena.xml index e8b69cd9..5c766831 100644 --- a/franka_sim/franka_sim/envs/xmls/arena.xml +++ b/franka_sim/franka_sim/envs/xmls/arena.xml @@ -19,7 +19,7 @@ - + diff --git a/serl_launcher/serl_launcher/wrappers/serl_obs_wrappers.py b/serl_launcher/serl_launcher/wrappers/serl_obs_wrappers.py index 41c169f9..c70b9527 100644 --- 
a/serl_launcher/serl_launcher/wrappers/serl_obs_wrappers.py +++ b/serl_launcher/serl_launcher/wrappers/serl_obs_wrappers.py @@ -2,24 +2,158 @@ from gym.spaces import flatten_space, flatten +# Optional resizers +import numpy as np +try: + import cv2 + _HAS_CV2 = True +except Exception: + _HAS_CV2 = False + +try: + from PIL import Image + _HAS_PIL = True +except Exception: + _HAS_PIL = False + +def _resize_hwc(img: np.ndarray, hw: tuple[int, int]) -> np.ndarray: + """Resize HxWxC image to (H', W', C) without changing dtype/range.""" + H, W = hw + if _HAS_CV2: + # cv2 wants (W, H) + return cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA) + if _HAS_PIL: + pil = Image.fromarray(img if img.dtype == np.uint8 else np.clip(img, 0, 255).astype(np.uint8)) + pil = pil.resize((W, H), resample=Image.Resampling.BILINEAR) + out = np.asarray(pil) + if img.dtype != np.uint8: + # If original was float, map back to float [0,1] + out = out.astype(np.float32) / 255.0 + return out + # Pure NumPy (simple nearest neighbor) + y_idx = (np.linspace(0, img.shape[0] - 1, H)).astype(np.int32) + x_idx = (np.linspace(0, img.shape[1] - 1, W)).astype(np.int32) + return img[y_idx][:, x_idx] + class SERLObsWrapper(gym.ObservationWrapper): """ This observation wrapper treat the observation space as a dictionary of a flattened state space and the images. 
""" - def __init__(self, env): + def __init__( + self, + env, + target_hw=(128, 128), # (H, W) for resized images + img_dtype=np.uint8, # np.uint8 for [0..255], or np.float32 for [0..1] + normalize=False, # if True and img_dtype=float32, scale to [0,1] + image_parent_key="images", # where images live in the original obs dict + ): super().__init__(env) - self.observation_space = gym.spaces.Dict( - { - "state": flatten_space(self.env.observation_space["state"]), - **(self.env.observation_space["images"]), - } - ) + assert isinstance(self.env.observation_space, gym.spaces.Dict), \ + "Expected Dict observation_space with keys {'state', 'images'}" + + # ---- Build new observation_space ---- + base_space = self.env.observation_space + assert "state" in base_space.spaces, "Missing 'state' in observation_space" + assert image_parent_key in base_space.spaces, f"Missing '{image_parent_key}' in observation_space" + img_space_dict = base_space.spaces[image_parent_key] + assert isinstance(img_space_dict, gym.spaces.Dict), \ + f"'{image_parent_key}' must be a Dict of image spaces" + + # Flattened state space + state_space = flatten_space(base_space.spaces["state"]) + + + # Image spaces (resized) + H, W = target_hw + image_spaces = {} + for k, sp in img_space_dict.spaces.items(): + # Assume HWC input; preserve channel count + if hasattr(sp, "shape") and sp.shape is not None: + if len(sp.shape) != 3: + raise ValueError(f"Image space '{k}' must be HxWxC; got shape {sp.shape}") + C = sp.shape[-1] + else: + raise ValueError(f"Image space '{k}' missing shape") + + if img_dtype == np.uint8: + low, high = 0, 255 + elif img_dtype == np.float32: + low, high = 0.0, 1.0 if normalize else float(getattr(sp, "high", 1.0)) + else: + raise ValueError("img_dtype must be np.uint8 or np.float32") + + image_spaces[k] = gym.spaces.Box( + low=low, + high=high, + shape=(H, W, C), + dtype=img_dtype, + ) + + # Final Dict space: {'state': ..., 'front': Box(...), 'wrist': Box(...), ...} + 
self.observation_space = gym.spaces.Dict({ + "state": state_space, + **image_spaces + }) + # self.observation_space = gym.spaces.Dict( + # { + # "state": flatten_space(self.env.observation_space["state"]), + # **(self.env.observation_space["images"]), + # } + # ) + + # Store config + self._target_hw = target_hw + self._img_dtype = img_dtype + self._normalize = normalize + self._image_parent_key = image_parent_key + + # def observation(self, obs): + # obs = { + # "state": flatten(self.env.observation_space["state"], obs["state"]), + # **(obs["images"]), + # } + # return obs def observation(self, obs): - obs = { - "state": flatten(self.env.observation_space["state"], obs["state"]), - **(obs["images"]), - } - return obs + # Flatten state using original (pre-flatten) state space definition + flat_state = flatten(self.env.observation_space.spaces["state"], obs["state"]) + + # Pull original images dict + imgs = obs[self._image_parent_key] + + # Resize & cast each image to match observation_space spec + out = {"state": flat_state} + for k, sp in self.observation_space.spaces.items(): + if k == "state": + continue + img = imgs[k] + # Ensure HWC + if img.ndim != 3: + raise ValueError(f"Image '{k}' must be HxWxC; got shape {img.shape}") + + # If float32 images in [0,1] but we want uint8, scale up before resize for best quality + want_uint8 = (self._img_dtype == np.uint8) + if want_uint8: + if img.dtype != np.uint8: + # Assume 0..1 range; if 0..255 float, clip and cast + img = np.clip(img, 0.0, 1.0) if img.max() <= 1.0 else np.clip(img/255.0, 0.0, 1.0) + img = (img * 255.0 + 0.5).astype(np.uint8) + resized = _resize_hwc(img, self._target_hw).astype(np.uint8) + else: + # float32 output + if img.dtype == np.uint8: + if self._normalize: + img = img.astype(np.float32) / 255.0 + else: + img = img.astype(np.float32) # keep 0..255 range if you really want that + else: + img = img.astype(np.float32) + if self._normalize and img.max() > 1.0: + img = img / 255.0 + resized = 
_resize_hwc(img, self._target_hw).astype(np.float32) + + out[k] = resized + + return out \ No newline at end of file From de3d3a1b1a778086384ef6cf5613a1a90551b97b Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Thu, 21 Aug 2025 22:32:06 -0500 Subject: [PATCH 123/141] improved docstrings --- .../wrappers/serl_obs_wrappers.py | 32 +++++++++++++++++-- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/serl_launcher/serl_launcher/wrappers/serl_obs_wrappers.py b/serl_launcher/serl_launcher/wrappers/serl_obs_wrappers.py index c70b9527..0c789762 100644 --- a/serl_launcher/serl_launcher/wrappers/serl_obs_wrappers.py +++ b/serl_launcher/serl_launcher/wrappers/serl_obs_wrappers.py @@ -17,11 +17,23 @@ _HAS_PIL = False def _resize_hwc(img: np.ndarray, hw: tuple[int, int]) -> np.ndarray: - """Resize HxWxC image to (H', W', C) without changing dtype/range.""" + """ + Resize HxWxC image to (H', W', C) without changing dtype/range. + Supports cv2, PIL, or pure NumPy resizing. + If img is float32, it will be scaled to [0,1] if not already in that range. + If img is uint8, it will be resized as-is. + + Args: + img (np.ndarray): Input image in HxWxC format. + hw (tuple[int, int]): Target height and width (H', W'). + Returns: + np.ndarray: Resized image in H'xW'xC format. 
+ """ H, W = hw if _HAS_CV2: # cv2 wants (W, H) return cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA) + if _HAS_PIL: pil = Image.fromarray(img if img.dtype == np.uint8 else np.clip(img, 0, 255).astype(np.uint8)) pil = pil.resize((W, H), resample=Image.Resampling.BILINEAR) @@ -30,6 +42,7 @@ def _resize_hwc(img: np.ndarray, hw: tuple[int, int]) -> np.ndarray: # If original was float, map back to float [0,1] out = out.astype(np.float32) / 255.0 return out + # Pure NumPy (simple nearest neighbor) y_idx = (np.linspace(0, img.shape[0] - 1, H)).astype(np.int32) x_idx = (np.linspace(0, img.shape[1] - 1, W)).astype(np.int32) @@ -37,8 +50,21 @@ def _resize_hwc(img: np.ndarray, hw: tuple[int, int]) -> np.ndarray: class SERLObsWrapper(gym.ObservationWrapper): """ - This observation wrapper treat the observation space as a dictionary - of a flattened state space and the images. + Observation wrapper for SERL environments. + Flattens the 'state' space and resizes images to a target height and width. + Supports both uint8 and float32 images, with optional normalization. + The observation space is a Dict with 'state' and resized image spaces. + + Args: + env (gym.Env): The environment to wrap. + target_hw (tuple[int, int]): Target height and width for resized images. + img_dtype (np.dtype): Data type for images, either np.uint8 or np.float32. + normalize (bool): If True, scales float32 images to [0,1]. + image_parent_key (str): Key in the observation dict where images are stored. + + Defaults to "images". + Returns: + gym.spaces.Dict: The new observation space with flattened state and resized images. 
""" def __init__( From e7d677cbc6a10fd94cb8f7423fafdf77f5d7a040 Mon Sep 17 00:00:00 2001 From: ryanvanderstelt Date: Mon, 1 Sep 2025 17:07:40 -0500 Subject: [PATCH 124/141] finished images for FSRB --- .../data/fractal_symmetry_replay_buffer.py | 250 +++++++++++------- serl_launcher/serl_launcher/data/fsrb_test.py | 50 +++- 2 files changed, 202 insertions(+), 98 deletions(-) diff --git a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py index ebbfbbc8..a55ef3a8 100644 --- a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py +++ b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py @@ -1,10 +1,11 @@ +import copy +from typing import Iterable, Optional + import gym import numpy as np -from serl_launcher.data.dataset import DatasetDict +from serl_launcher.data.dataset import DatasetDict, _sample from serl_launcher.data.replay_buffer import ReplayBuffer -import copy - -from datetime import datetime as dt +from flax.core import frozen_dict class FractalSymmetryReplayBuffer(ReplayBuffer): def __init__( @@ -47,28 +48,33 @@ def __init__( print(f"\033[33mWARNING \033[0m argument \"{k}\" not used") # Account for images + self._num_stack = None + next_observation_space = None if self.img_keys: - img_buffer_size = (capacity + self.expected_branches - 1) // self.expected_branches + 1 self.img_buffer = {} - self.insert = self.insert_with_images - for k in img_keys: - temp = np.empty(shape=observation_space.spaces[k].shape, dtype=observation_space.spaces[k].dtype) - temp = temp[np.newaxis, :] - self.img_buffer[k] = np.repeat(temp, img_buffer_size, axis=0) + next_observation_space_dict = copy.deepcopy(observation_space.spaces) + for k in img_keys: + img_obs_space = observation_space.spaces[k] + if self._num_stack is None: + self._num_stack = img_obs_space.shape[0] + img_buffer_size = ((self.expected_branches + capacity - 1) // self.expected_branches) * 
(self._num_stack + 1) + buffer_shape = list(img_obs_space.shape[1:]) + buffer_shape.insert(0, img_buffer_size) + self.img_buffer[k] = np.empty(buffer_shape, img_obs_space.dtype) + + observation_space.spaces[k] = gym.spaces.Box(low=float('-inf'), high=float('inf'), shape=(), dtype=np.int32) + next_observation_space_dict.pop(k) + next_observation_space = gym.spaces.Dict(next_observation_space_dict) - observation_space.spaces[k] = gym.spaces.Box(low=float('-inf'), high=float('inf'), shape=(1,), dtype=np.int32) # Init replay buffer class super().__init__( observation_space=observation_space, + next_observation_space=next_observation_space, action_space=action_space, capacity=capacity, ) - for k in img_keys: - self.dataset_dict["observations"][k] = self.dataset_dict["observations"][k].flatten() - self.dataset_dict["next_observations"][k] = self.dataset_dict["next_observations"][k].flatten() - self.generate_transform_deltas() def _handle_method_arg_(self, value, method_type, method, kwargs): @@ -89,7 +95,7 @@ def _handle_methods_(self, kwargs): self.branch = self.fractal_branch if not self.split_method: self.split_method = "time" - self.expected_branches = self.branching_factor ** self.max_depth + self.expected_branches = (self.branching_factor ** self.max_depth) ** 2 case "contraction": self._handle_method_arg_("max_depth", "branch_method", self.branch_method, kwargs) @@ -98,7 +104,7 @@ def _handle_methods_(self, kwargs): self.branch = self.fractal_contraction if not self.split_method: self.split_method = "time" - self.expected_branches = self.branching_factor ** self.max_depth + self.expected_branches = (self.branching_factor ** self.max_depth) ** 2 case "linear": raise NotImplementedError("linear branch method is not yet implemented") @@ -125,7 +131,7 @@ def _handle_methods_(self, kwargs): self.branch = self.disassociated_branch if not self.split_method: self.split_method = "time" - self.expected_branches = self.max_branch_count + self.expected_branches = 
self.max_branch_count ** 2 case "constant": self._handle_method_arg_("starting_branch_count", "branch_method", self.branch_method, kwargs) @@ -133,7 +139,7 @@ def _handle_methods_(self, kwargs): self.branch = self.constant_branch if not self.split_method: self.split_method = "never" - self.expected_branches = self.starting_branch_count + self.expected_branches = self.starting_branch_count ** 2 case _: raise ValueError("incorrect value passed to branch_method") @@ -165,7 +171,7 @@ def generate_transform_deltas(self): if self.img_keys: obs_state = self.dataset_dict["observations"]["state"] - obs_size = len(obs_state[0]) + obs_size = obs_state.shape[-1] total_branches = self.current_branch_count ** 2 self.transform_deltas = np.zeros(shape=(total_branches, obs_size), dtype=np.float32) @@ -180,8 +186,12 @@ def generate_transform_deltas(self): x_deltas = np.reshape(x_deltas, (total_branches, self.x_obs_idx.size)) y_deltas = np.reshape(y_deltas, (total_branches, self.y_obs_idx.size)) - self.transform_deltas[:, self.x_obs_idx] = x_deltas - self.transform_deltas[:, self.y_obs_idx] = y_deltas + self.transform_deltas[..., self.x_obs_idx] = x_deltas + self.transform_deltas[..., self.y_obs_idx] = y_deltas + + if self._num_stack: + self.transform_deltas = np.expand_dims(self.transform_deltas, axis=1) + self.transform_deltas = np.repeat(self.transform_deltas, self._num_stack, axis=1) def fractal_branch(self): ''' @@ -265,13 +275,28 @@ def never_split(self, data_dict: DatasetDict): def insert_images(self, observation: dict): for k in self.img_keys: - self.img_buffer[k][self._img_insert_index_] = observation[k] + if self._num_stack: + self.img_buffer[k][self._img_insert_index_] = observation[k][0, ...] 
+ else: + self.img_buffer[k][self._img_insert_index_] = observation[k] self._img_insert_index_ += 1 def insert(self, data: DatasetDict): data_dict = copy.deepcopy(data) + if self.img_keys: + obs = data_dict["observations"]["state"] + n_obs = data_dict["next_observations"]["state"] + else: + obs = data_dict["observations"] + n_obs = data_dict["next_observations"] + + actions = data_dict["actions"] + rewards = data_dict["rewards"] + masks = data_dict["masks"] + dones = data_dict["dones"] + # Update number of branches if needed if self.split(data_dict): temp = self.current_branch_count @@ -282,26 +307,50 @@ def insert(self, data: DatasetDict): # Initialize to extreme x and y base_diff = -self.workspace_width/2 - data_dict["observations"][self.x_obs_idx] += base_diff - data_dict["observations"][self.y_obs_idx] += base_diff - data_dict["next_observations"][self.x_obs_idx] += base_diff - data_dict["next_observations"][self.y_obs_idx] += base_diff + obs[..., self.x_obs_idx] += base_diff + obs[..., self.y_obs_idx] += base_diff + n_obs[..., self.x_obs_idx] += base_diff + n_obs[..., self.y_obs_idx] += base_diff - # Transform and insert transitions + # Transform transitions num_transforms = self.current_branch_count ** 2 - obs_batch = np.tile(data_dict["observations"], (num_transforms, 1)) - next_obs_batch = np.tile(data_dict["next_observations"], (num_transforms, 1)) - obs_batch += self.transform_deltas - next_obs_batch += self.transform_deltas + obs_shape = np.ones(len(obs.shape) + 1, dtype=int) + obs_shape[0] = num_transforms + obs = np.tile(obs, obs_shape) + n_obs = np.tile(n_obs, obs_shape) + actions = np.tile(actions, (num_transforms, 1)) + rewards = np.tile(rewards, num_transforms) + masks = np.tile(masks, num_transforms) + dones = np.tile(dones, num_transforms) - - data_dict["observations"] = obs_batch - data_dict["next_observations"] = next_obs_batch - data_dict["actions"] = np.tile(data_dict["actions"], (num_transforms, 1)) - data_dict["rewards"] = 
np.tile(data_dict["rewards"], num_transforms) - data_dict["masks"] = np.tile(data_dict["masks"], num_transforms) - data_dict["dones"] = np.tile(data_dict["dones"], num_transforms) + obs += self.transform_deltas + n_obs += self.transform_deltas + + # Insert images + if self.img_keys: + if self.timestep == 0: + for i in range(self._num_stack): + self.insert_images(data_dict["observations"]) + self.insert_images(data_dict["next_observations"]) + + for k in self.img_keys: + data_dict["observations"][k] = (self._img_insert_index_ - 1) % len(self.img_buffer[k]) + data_dict["observations"][k] = np.tile(data_dict["observations"][k], num_transforms) + data_dict["next_observations"].pop(k) + + # Pack back into dictionary and insert + if self.img_keys: + data_dict["observations"]["state"] = obs + data_dict["next_observations"]["state"] = n_obs + else: + data_dict["observations"] = obs + data_dict["next_observations"] = n_obs + + data_dict["actions"] = actions + data_dict["rewards"] = rewards + data_dict["masks"] = masks + data_dict["dones"] = dones super().insert(data_dict, batch_size=num_transforms) @@ -312,61 +361,80 @@ def insert(self, data: DatasetDict): if self.update_max_traj_length: self.max_traj_length = int(self.timestep * self.alpha + self.max_traj_length * (1 - self.alpha)) self.timestep = 0 + + def sample( + self, batch_size: int, keys: Optional[Iterable[str]] = None, indx: Optional[np.ndarray] = None, pack_obs_and_next_obs: bool = False, + ) -> frozen_dict.FrozenDict: + """Samples from the replay buffer. + + Args: + batch_size: Minibatch size. + keys: Keys to sample. + indx: Take indices instead of sampling. + pack_obs_and_next_obs: whether to pack img and next_img into one image. + It's useful when they have overlapping frames. 
- def insert_with_images(self, data: DatasetDict): - - data_dict = copy.deepcopy(data) - - # Update number of branches if needed - if self.split(data_dict): - temp = self.current_branch_count - self.current_branch_count = self.branch() - # Update transform_deltas if needed - if temp != self.current_branch_count: - self.generate_transform_deltas() + Returns: + A frozen dictionary. + """ + # If no images, sample normally + if not self.img_keys: + return super().sample(batch_size, keys, indx) - # Insert images - if self.timestep == 0: - self.insert_images(data_dict["observations"]) - self.insert_images(data_dict["next_observations"]) - + # Generate random indexes for sampling + if indx is None: + if hasattr(self.np_random, "integers"): + indx = self.np_random.integers(len(self), size=batch_size) + else: + indx = self.np_random.randint(len(self), size=batch_size) + + for i in range(batch_size): + while not self._is_correct_index[indx[i]]: + if hasattr(self.np_random, "integers"): + indx[i] = self.np_random.integers(len(self)) + else: + indx[i] = self.np_random.randint(len(self)) + else: + raise NotImplementedError() + + # Sample w/o images + if keys is None: + keys = self.dataset_dict.keys() + else: + assert "observations" in keys + + keys = list(keys) + keys.remove("observations") + + batch = super().sample(batch_size, keys, indx) + batch = batch.unfreeze() + + obs_keys = self.dataset_dict["observations"].keys() + obs_keys = list(obs_keys) for k in self.img_keys: - data_dict["observations"][k] = (self._img_insert_index_ - 2) % len(self.img_buffer[k]) - data_dict["next_observations"][k] = (self._img_insert_index_ - 1) % len(self.img_buffer[k]) - - # Initialize to extreme x and y - base_diff = -self.workspace_width/2 - data_dict["observations"]["state"][self.x_obs_idx] += base_diff - data_dict["observations"]["state"][self.y_obs_idx] += base_diff - data_dict["next_observations"]["state"][self.x_obs_idx] += base_diff - 
data_dict["next_observations"]["state"][self.y_obs_idx] += base_diff - - # Transform and insert transitions - num_transforms = self.current_branch_count ** 2 - obs_batch = np.tile(data_dict["observations"]["state"], (num_transforms, 1)) - next_obs_batch = np.tile(data_dict["next_observations"]["state"], (num_transforms, 1)) + obs_keys.remove(k) - obs_batch += self.transform_deltas - next_obs_batch += self.transform_deltas - - - data_dict["observations"]["state"] = obs_batch - data_dict["next_observations"]["state"] = next_obs_batch - data_dict["actions"] = np.tile(data_dict["actions"], (num_transforms, 1)) - data_dict["rewards"] = np.tile(data_dict["rewards"], num_transforms) - data_dict["masks"] = np.tile(data_dict["masks"], num_transforms) - data_dict["dones"] = np.tile(data_dict["dones"], num_transforms) + batch["observations"] = {} + for k in obs_keys: + batch["observations"][k] = _sample( + self.dataset_dict["observations"][k], indx + ) + # Sample images for k in self.img_keys: - data_dict["observations"][k] = np.tile(data_dict["observations"][k], num_transforms) - data_dict["next_observations"][k] = np.tile(data_dict["next_observations"][k], num_transforms) - - super().insert(data_dict, batch_size=num_transforms) - - # Reset current_depth, timestep, and max_traj_length - self.timestep += 1 - if data_dict["dones"][0]: - self.current_depth = 0 - if self.update_max_traj_length: - self.max_traj_length = int(self.timestep * self.alpha + self.max_traj_length * (1 - self.alpha)) - self.timestep = 0 \ No newline at end of file + obs_imgs = self.img_buffer[k] + obs_imgs = np.lib.stride_tricks.sliding_window_view( + obs_imgs, self._num_stack + 1, axis=0 + ) + obs_imgs = obs_imgs[self.dataset_dict["observations"][k][indx] - self._num_stack] + # transpose from (B, H, W, C, T) to (B, T, H, W, C) to follow jaxrl_m convention + obs_imgs = obs_imgs.transpose((0, 4, 1, 2, 3)) + + if pack_obs_and_next_obs: + batch["observations"][k] = obs_imgs + else: + 
batch["observations"][k] = obs_imgs[:, :-1, ...] + if "next_observations" in keys: + batch["next_observations"][k] = obs_imgs[:, 1:, ...] + + return frozen_dict.freeze(batch) diff --git a/serl_launcher/serl_launcher/data/fsrb_test.py b/serl_launcher/serl_launcher/data/fsrb_test.py index 4dc2bcea..f6b9e256 100644 --- a/serl_launcher/serl_launcher/data/fsrb_test.py +++ b/serl_launcher/serl_launcher/data/fsrb_test.py @@ -11,14 +11,14 @@ FLAGS = flags.FLAGS -flags.DEFINE_integer("capacity", 10000, "Replay buffer capacity.") +flags.DEFINE_integer("capacity", 10, "Replay buffer capacity.") flags.DEFINE_string("branch_method", "constant", "Method for determining the number of transforms per dimension (x,y)") flags.DEFINE_string("split_method", "never", "Method for determining whether to change the number of transforms per dimension (x,y)") flags.DEFINE_float("workspace_width", 0.5, "workspace width in meters") flags.DEFINE_integer("max_depth", 4, "Maximum level of depth") # For fractal_branch only flags.DEFINE_integer("max_steps",100,"Maximum steps") flags.DEFINE_integer("branching_factor", 3, "Rate of change of number of transforms per dimension (x,y)") # For fractal_branch only -flags.DEFINE_integer("starting_branch_count", 3, "Initial number of transforms per dimension (x,y)") # For constant_branch only +flags.DEFINE_integer("starting_branch_count", 1, "Initial number of transforms per dimension (x,y)") # For constant_branch only flags.DEFINE_integer("alpha",1,"alpha value") # Density Workspace width flags.DEFINE_string("workspace_width_method",'increase', 'Controls workspace width dimensions configurations') @@ -31,10 +31,21 @@ def main(_): # Initialize replay buffer env = gym.make("PandaPickCubeVision-v0") env = SERLObsWrapper(env) + env = ChunkingWrapper(env, obs_horizon=3, act_exec_horizon=None) + + # env = gym.make("PandaReachCube-v0") # env = gym.wrappers.FlattenObservation(env) image_keys = [key for key in env.observation_space.keys() if key != "state"] + 
their_buffer = make_replay_buffer( + env, + type="memory_efficient_replay_buffer", + capacity=FLAGS.capacity, + image_keys=image_keys, + + ) + replay_buffer = make_replay_buffer( env, type="fractal_symmetry_replay_buffer", @@ -45,18 +56,26 @@ def main(_): x_obs_idx=x_obs_idx, y_obs_idx= y_obs_idx, image_keys=image_keys, - max_depth=FLAGS.max_depth, + # max_depth=FLAGS.max_depth, max_traj_length = 100, - branching_factor=FLAGS.branching_factor, - alpha = FLAGS.alpha, + # branching_factor=FLAGS.branching_factor, + # alpha = FLAGS.alpha, starting_branch_count = FLAGS.starting_branch_count, ) observation, info = env.reset() - action = env.action_space.sample() next_observation, reward, terminated, truncated, info = env.step(action) + # observation = np.zeros_like(observation) + for k in observation.keys(): + observation[k] = np.zeros_like(observation[k]) + next_observation[k] = np.ones_like(next_observation[k]) + + action = np.ones_like(action) + # next_observation = np.ones_like(next_observation) + reward = 1 + data_dict = dict( observations=observation, next_observations=next_observation, @@ -68,7 +87,24 @@ def main(_): del env, observation, next_observation, action, reward, truncated, terminated, info, y_obs_idx, x_obs_idx, _ - replay_buffer.insert(data_dict) + for i in range(6): + + replay_buffer.insert(data_dict) + their_buffer.insert(data_dict) + assert(replay_buffer.dataset_dict["observations"]["state"][i % 10].all() == their_buffer.dataset_dict["observations"]["state"][(i + 3) % 10].all()) + assert(replay_buffer.dataset_dict["next_observations"]["state"][i % 10].all() == their_buffer.dataset_dict["next_observations"]["state"][(i + 3) % 10].all()) + assert(replay_buffer.img_buffer["front"][replay_buffer.dataset_dict["observations"]["front"][i % 10]].all() == their_buffer.dataset_dict["observations"]["front"][(i + 3) % 10].all()) + + data_dict["observations"]["state"] += 1 + data_dict["next_observations"]["state"] += 1 + for k in image_keys: + 
data_dict["observations"][k] += 1 + data_dict["next_observations"][k] += 1 + + replay_buffer.sample(batch_size=3, indx=np.array([2,3,4])) + their_buffer.sample(batch_size=3, indx=np.array([5,6,7])) + + # branch() tests From 0819935a96505d7433571144742dae8cab6707fc Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Mon, 1 Sep 2025 17:44:42 -0500 Subject: [PATCH 125/141] Updated franka_rearch_drq_demo_script to be able to move the camera as well using: a,d,w,s,q,e to change azimuth, elevation, and distance. - code checks first for cam if so, then checks for action. if no cam, then just checks for action. --- demos/demos/franka_reach_drq_demo_script.py | 107 ++++++++++++++------ 1 file changed, 74 insertions(+), 33 deletions(-) diff --git a/demos/demos/franka_reach_drq_demo_script.py b/demos/demos/franka_reach_drq_demo_script.py index 06f730a1..41506845 100755 --- a/demos/demos/franka_reach_drq_demo_script.py +++ b/demos/demos/franka_reach_drq_demo_script.py @@ -53,44 +53,74 @@ 'm':(0,0,1,0), '.':(0,0,-1,0), } + +# Extend bindings to include camera controls +camBindings = { + 'a': ("azimuth", -5), # rotate left + 'd': ("azimuth", 5), # rotate right + 'w': ("elevation", 2), # tilt up + 's': ("elevation", -2), # tilt down + 'q': ("distance", -0.1),# zoom in + 'e': ("distance", 0.1), # zoom out +} + + +def update_camera(viewer,key): + """ + Update the camera view based on keyboard input. Assumes higher level function has checked for the existance of key in camBindings. + Controls: + 'a' : rotate left + 'd' : rotate right + 'w' : tilt up + 's' : tilt down + 'q' : zoom in + 'e' : zoom out + """ + if hasattr(viewer, 'cam'): + # Get current camera parameters + attr, delta = camBindings[key] + val = getattr(viewer.cam, attr) + + setattr(viewer.cam, attr, val + delta) + + + ############################################################################## def ensure_dir_exists(): """ - Ensure that the output directory exists. If it does not exist, create it. 
+ For oxe_envlogger + RLDS compatibility, data must be written in the following format: + /data/data/serl/demos/franka_reach_drq_demo_script/ + └── session_20250821_222412/ + └── PandaPickCubeVision-v0/ + └── 0.1.0/ + dataset_info.json + features.json + PandaPickCubeVision-v0-train.tfrecord-00000-of-00001 + + We can have a base path with customized sessions inside. + Inside each session we have: env-version-files + Returns ------- out_path : str The path to the output directory. """ - # Create a timestamped directory for saving outputs - robot = "franka" - task = "reach" - initStateSpace = "random" # "fixed" or "random" - # Customize the path - script_dir = FLAGS.output_dir + root = FLAGS.output_dir + session = datetime.now().strftime("session_%Y%m%d_%H%M%S") + session_root = os.path.join(root, session, '_num_demos_', str(FLAGS.num_demos)) + + # Dataset details + dataset_name = FLAGS.env + version = "0.1.0" # PArt of RLDS format. needed. # Create output filename with configuration details - fileName = "data_" + robot + "_" + task - fileName += "_" + initStateSpace - fileName += "_" + str(FLAGS.num_demos) + dataset_dir = os.path.join(session_root, dataset_name, version) + os.makedirs(dataset_dir, exist_ok=True) + logging.info(f"TFDS builder dir: {dataset_dir}") - # Add timestamp to filename for uniqueness - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - fileName += "_" + timestamp - - # Build a filename in that same directory - out_path = os.path.join(script_dir, fileName) - - # Ensure the directory exists - if not os.path.exists(out_path): - os.makedirs(script_dir, exist_ok=True) - logging.info(f"Created output directory at: {out_path}") - else: - logging.info(f"Output directory already exists at: {out_path}") - - return out_path + return dataset_dir def getKey(settings): """ @@ -122,9 +152,11 @@ def getKey(settings): termios.tcsetattr(sys.stdin, termios.TCSADRAIN, settings) return key -def get_kb_demo_action(speed=0.1): +def 
get_kb_demo_action(env,speed=0.1): """ - Reads keyboard input and maps it to a 3D action vector for robot control. + Reads keyboard input and maps it to a 3D action vector for robot control or camera action. + TODO: currently can only read one key at a time. Needs to be extended to read multiple keys to handle both. + Otherwise, none actions are still considered steps in the loop. The function uses non-blocking keyboard input to allow interactive teleoperation. Keys are mapped to directions in Cartesian space: @@ -157,7 +189,13 @@ def get_kb_demo_action(speed=0.1): # Capture the pressed key key = getKey(settings) - if key in moveBindings: + # Check keys for camera first, if so, update camera and get another key for action + if key in camBindings: + if hasattr(env.unwrapped, "_viewer"): + update_camera(env.unwrapped._viewer.viewer,key) + key = getKey(settings) # get another key for action + + elif key in moveBindings: # Lookup (x, y, z) direction and scale by speed dx, dy, dz, g= moveBindings[key] action = np.array([dx, dy, dz, g], dtype=float) * speed @@ -239,12 +277,13 @@ def main(unused_argv): # If envlogger is enabled, wrap the environment with AutoOXEEnvLogger to log episodes if FLAGS.enable_envlogger: - out_path = ensure_dir_exists() + dataset_dir = ensure_dir_exists() env = AutoOXEEnvLogger( env=env, dataset_name=FLAGS.env, - directory=out_path, + directory=dataset_dir, + #split_name="train", # "train", "test", or "validation" ) logging.info('Training an agent for %r episodes...', FLAGS.num_demos) @@ -255,6 +294,8 @@ def main(unused_argv): # Log custom metadata during new episode: language embeddings randomly. if FLAGS.enable_envlogger: + # The "language_embedding" is a standard field used in robotics datasets (like the OXE format that envlogger creates) to store a numerical representation of a natural language instruction for an episode. + # How to Reconcile it with 5 Random Numbers: The five random numbers are just placeholder data. 
This script is a demonstration and doesn't involve a real language model. env.set_episode_metadata({ "language_embedding": np.random.random((5,)).astype(np.float32) }) @@ -270,14 +311,14 @@ def main(unused_argv): step = 0 # Termination occurs when the hand reaches the target or the maximum trajectory length is reached - while not (terminated or truncated): + while not (terminated or truncated): # Get action from the demo function - action = get_kb_demo_action() + action = get_kb_demo_action(env) # example to log custom step metadata if FLAGS.enable_envlogger: - env.set_step_metadata({"timestamp": time.time()}) + env.set_step_metadata({"timestamp": np.float32(time.time())}) return_step = env.step(action) From ee1107d2fc5e013f918d6859e0cfff3ba7c58466 Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Tue, 2 Sep 2025 10:49:25 -0500 Subject: [PATCH 126/141] v.0.2.0 - Working version in debug mode - correctly updates OXEEnvLogger generated files to be read by TFDS tfds.builder_from_directory in launcher.py - When running outside debugger still get an exception that I need to work out, but can be used to move forward. 
--- demos/demos/franka_reach_drq_demo_script.py | 242 +++++++++++++++++++- demos/requirements.txt | 5 +- 2 files changed, 235 insertions(+), 12 deletions(-) diff --git a/demos/demos/franka_reach_drq_demo_script.py b/demos/demos/franka_reach_drq_demo_script.py index 41506845..f9049fb0 100755 --- a/demos/demos/franka_reach_drq_demo_script.py +++ b/demos/demos/franka_reach_drq_demo_script.py @@ -21,12 +21,17 @@ # Teleoperation imports import sys, select, termios, tty +# RLDS/TFDS +import json, glob, inspect +import tensorflow as tf +import tensorflow_datasets as tfds +from tensorflow_datasets import folder_dataset #------------------------------------------------------------------------------------------- # Flags #------------------------------------------------------------------------------------------- flags.DEFINE_string("env", "PandaPickCubeVision-v0", "Name of environment.") flags.DEFINE_string("exp_name", None, "Name of the experiment for wandb logging.") -flags.DEFINE_integer("max_traj_length", 200, "Maximum length of trajectory.") +flags.DEFINE_integer("max_traj_length", 100, "Maximum length of trajectory.") flags.DEFINE_boolean("debug", True, "Debug mode.") # debug mode will disable wandb logging #flags.DEFINE_string("preload_rlds_path", None, "Path to preload RLDS data.") flags.DEFINE_string("output_dir", "/data/data/serl/demos/franka_reach_drq_demo_script", @@ -64,7 +69,6 @@ 'e': ("distance", 0.1), # zoom out } - def update_camera(viewer,key): """ Update the camera view based on keyboard input. Assumes higher level function has checked for the existance of key in camBindings. @@ -83,9 +87,211 @@ def update_camera(viewer,key): setattr(viewer.cam, attr, val + delta) - +def close_logger_and_env(env): + """ + Best-effort shutdown: + 1) close embedded envloggers/writers if present + 2) close dm_env (if you have a DeepMind-style env) + 3) close Gym/Gymnasium env + 4) close the Mujoco viewer + Also walks through common wrapper attributes (env, _env, unwrapped). 
+ """ + import logging + + visited = set() + + def _safe_close(obj, label=""): + if obj is None or id(obj) in visited: + return + visited.add(id(obj)) + + # 1) Close common logger/writer attributes first (flush TFRecord) + for name in ("_envlogger", "envlogger", "logger", "_logger", "writer"): + try: + logger_obj = getattr(obj, name, None) + if logger_obj is not None and hasattr(logger_obj, "close"): + logger_obj.close() + except Exception as e: + logging.warning(f"close_logger_and_env: {label}.{name}.close() raised {e!r}") + + # 2) Close dm_env if present + try: + dm = getattr(obj, "dm_env", None) + if dm is not None and hasattr(dm, "close"): + dm.close() + except Exception as e: + logging.warning(f"close_logger_and_env: {label}.dm_env.close() raised {e!r}") + + # 3) Close Gym/Gymnasium env + try: + if hasattr(obj, "close"): + obj.close() + except Exception as e: + logging.warning(f"close_logger_and_env: {label}.close() raised {e!r}") + + # 4) Close Mujoco viewer if accessible + try: + # Some stacks keep viewer at env._viewer.viewer; others just env._viewer + viewer = None + vwrap = getattr(obj, "_viewer", None) + if vwrap is not None: + viewer = getattr(vwrap, "viewer", vwrap) + if viewer is not None and hasattr(viewer, "close"): + viewer.close() + except Exception as e: + logging.warning(f"close_logger_and_env: {label} viewer close raised {e!r}") + + # Recurse into common wrapper links + for child_name in ("env", "_env", "unwrapped", "environment", "base_env"): + child = getattr(obj, child_name, None) + if child is not None and child is not obj: + _safe_close(child, f"{label}.{child_name}" if label else child_name) + + _safe_close(env, "env") ############################################################################## +# ----------------------------------------------------------------------------- +# Finalize TFDS metadata so TFDS can load the split without manual edits. 
+# ----------------------------------------------------------------------------- +# def finalize_tfds_metadata(builder_dir: str): +# """ +# Make the TFDS builder dir loadable by writing numShards/shardLengths. +# Works with new/old TFDS; prints helpful diagnostics if filenames don't match. +# """ +# import os, json, glob, inspect +# import tensorflow as tf +# import tensorflow_datasets as tfds +# try: +# from tensorflow_datasets import folder_dataset +# except Exception: +# from tensorflow_datasets.core import folder_dataset + +# info_path = os.path.join(builder_dir, "dataset_info.json") +# if not os.path.exists(info_path): +# raise FileNotFoundError(f"Missing dataset_info.json in {builder_dir}") + +# with open(info_path) as f: +# info = json.load(f) + +# dataset_name = info["name"] # e.g., PandaPickCubeVision-v0 +# file_format = info.get("fileFormat", "tfrecord") # "tfrecord" +# # Use the exact template written by the logger +# template_str = info["splits"][0]["filepathTemplate"] + +# # Construct a ShardedFileTemplate object TFDS uses to discover shard files. This object encodes the directory, dataset name, suffix, and the template expression. +# template = tfds.core.ShardedFileTemplate( +# data_dir=builder_dir, +# dataset_name=dataset_name, +# filetype_suffix=file_format, +# template=template_str, +# ) + +# # Quick diagnostics so mismatches are obvious +# print(f"[finalize] builder_dir: {builder_dir}") +# print(f"[finalize] template: {template_str} (dataset={dataset_name}, suffix={file_format})") +# print("[finalize] files found: ", sorted(glob.glob(os.path.join(builder_dir, f"*.{file_format}*")))) + +# # Attempt the canonical TFDS way: compute per-split shard stats by scanning files that match the template. Different TFDS versions call the directory argument out_dir or data_dir, so it inspects the signature and passes the right name. 
+# try: +# sig = inspect.signature(folder_dataset.compute_split_info).parameters +# if "out_dir" in sig: +# # Compute the split info (num shards, num examples,...). See more at https://www.tensorflow.org/datasets/api_docs/python/tfds/folder_dataset/compute_split_info +# split_infos = folder_dataset.compute_split_info(out_dir=builder_dir, filename_template=template) +# else: +# split_infos = folder_dataset.compute_split_info(data_dir=builder_dir, filename_template=template) + +# sig_w = inspect.signature(folder_dataset.write_metadata).parameters +# kwargs = dict(split_infos=split_infos, filename_template=template) +# if "out_dir" in sig_w: +# kwargs["out_dir"] = builder_dir +# else: +# kwargs["data_dir"] = builder_dir +# if "features" in sig_w: +# kwargs["features"] = None +# folder_dataset.write_metadata(**kwargs) + +# except ValueError as e: +# # Common cause: filenames don’t match the template, or shards not flushed yet. +# print(f"[finalize] TFDS compute_split_info failed: {e}") +# shard_paths = sorted(glob.glob(os.path.join(builder_dir, f"{dataset_name}-*.{file_format}-*"))) +# if not shard_paths: +# # last resort: any .tfrecord* under the dir +# shard_paths = sorted(glob.glob(os.path.join(builder_dir, f"*.{file_format}*"))) +# if not shard_paths: +# raise ValueError( +# f"No {file_format} shards found in {builder_dir}. " +# f"Expected filenames like '{dataset_name}-train.{file_format}-00000' " +# f"per template '{template_str}'." 
+# ) from e + +# # Count records per shard and patch dataset_info.json automatically +# shard_lengths = [] +# for p in shard_paths: +# n = sum(1 for _ in tf.data.TFRecordDataset(p)) +# shard_lengths.append(n) + +# # Write lengths for each split that uses the template (here, 'train') +# for s in info["splits"]: +# # Only add lengths for splits that use this file template +# if s.get("filepathTemplate") == template_str: +# s["numShards"] = len(shard_paths) +# s["shardLengths"] = shard_lengths + +# with open(info_path, "w") as f: +# json.dump(info, f, indent=2) +# print(f"[finalize] wrote shardLengths={shard_lengths} into {info_path}") + +# # Sanity +# b = tfds.builder_from_directory(builder_dir) +# print("[finalize] splits:", {k: v.num_examples for k, v in b.info.splits.items()}) + +def finalize_tfds_metadata_beamless(builder_dir: str): + """ + Beam-free finalize: count TFRecord examples per shard and write + numShards/shardLengths into dataset_info.json so TFDS will load. + """ + import os, json, glob + import tensorflow as tf + + info_path = os.path.join(builder_dir, "dataset_info.json") + if not os.path.exists(info_path): + raise FileNotFoundError(f"Missing dataset_info.json in {builder_dir}") + + with open(info_path) as f: + info = json.load(f) + + ds_name = info["name"] # e.g. "PandaPickCubeVision-v0" + file_fmt = info.get("fileFormat", "tfrecord") + tmpl_str = info["splits"][0]["filepathTemplate"] + + # Prefer strict pattern "-.-" + shard_paths = sorted(glob.glob(os.path.join(builder_dir, f"{ds_name}-*.{file_fmt}-*"))) + if not shard_paths: + # Fallback to any tfrecord-like file + shard_paths = sorted(glob.glob(os.path.join(builder_dir, f"*.{file_fmt}*"))) + if not shard_paths: + raise FileNotFoundError( + f"No {file_fmt} shards found in {builder_dir}. " + f"Expected like '{ds_name}-train.{file_fmt}-00000' per template '{tmpl_str}'." 
+ ) + + # Count episodes (1 Example = 1 episode with envlogger/RLDS) + shard_lengths = [sum(1 for _ in tf.data.TFRecordDataset(p)) for p in shard_paths] + + # Write lengths for each split using this template + for s in info["splits"]: + if s.get("filepathTemplate") == tmpl_str: + s["numShards"] = len(shard_paths) + s["shardLengths"] = shard_lengths + # Re-write dataset_info.json with updated shard info + with open(info_path, "w") as f: + json.dump(info, f, indent=2) + + # Sanity log + import tensorflow_datasets as tfds + b = tfds.builder_from_directory(builder_dir) + print("[finalize] splits:", {k: v.num_examples for k, v in b.info.splits.items()}) + def ensure_dir_exists(): """ For oxe_envlogger + RLDS compatibility, data must be written in the following format: @@ -109,7 +315,7 @@ def ensure_dir_exists(): # Customize the path root = FLAGS.output_dir session = datetime.now().strftime("session_%Y%m%d_%H%M%S") - session_root = os.path.join(root, session, '_num_demos_', str(FLAGS.num_demos)) + session_root = os.path.join(root, f"{session}_num_demos_{FLAGS.num_demos}") # Dataset details dataset_name = FLAGS.env @@ -274,7 +480,9 @@ def main(unused_argv): else: logging.warning('Failed to set camera view. 
Viewer not available.') - # If envlogger is enabled, wrap the environment with AutoOXEEnvLogger to log episodes + # Wrap with oxe_envlogger to record demos + dataset_dir = None + session_root = None if FLAGS.enable_envlogger: dataset_dir = ensure_dir_exists() @@ -285,8 +493,7 @@ def main(unused_argv): directory=dataset_dir, #split_name="train", # "train", "test", or "validation" ) - - logging.info('Training an agent for %r episodes...', FLAGS.num_demos) + logging.info('Recording %r demos...', FLAGS.num_demos) #--- LOOP DEMOS/EPISODES --- # Loop through the number of demos specified by the user to record demonstrations @@ -332,10 +539,23 @@ def main(unused_argv): print(f" step: {step}", f"reward: {reward:.3f}\n") step += 1 - logging.info( - 'Done training a random agent for %r episodes.', FLAGS.num_demos) - - #env.close() # it seems async_drq_sim does not use close() dm_env method error. + logging.info('Done recording %r demos.', FLAGS.num_demos) + + # Finalize TFDS metadata so SERL/TFDS can load the split + if FLAGS.enable_envlogger and dataset_dir is not None: + + # Close the environment to flush/write data. AutoOXEEnvLogger implements dm_env.close() + # which flushes data to disk. This is important to ensure all data is written before finalizing metadata. + # If you skip this step, some data may not be written and the dataset may be incomplete. + logging.info("Closing environment to flush data to disk...") + + # closes logger(s) + dm_env + gym + viewer + close_logger_and_env(env) + + # Scan files and write metadata + finalize_tfds_metadata_beamless(dataset_dir) + + # Note: async_drq_sim will read from the dataset_dir you printed above. 
if __name__ == '__main__': app.run(main) \ No newline at end of file diff --git a/demos/requirements.txt b/demos/requirements.txt index 4a733bc8..b20b6e1d 100644 --- a/demos/requirements.txt +++ b/demos/requirements.txt @@ -1 +1,4 @@ -git+https://github.com/rail-berkeley/oxe_envlogger.git@main +tensorflow-metadata==1.17.2 +apache-beam==2.67.0 +protobuf>=4.21.6,<4.22 + From 178b800619d12c3f7ad61bed51691ad5eb56f331 Mon Sep 17 00:00:00 2001 From: ryanvanderstelt Date: Mon, 8 Sep 2025 11:43:33 -0500 Subject: [PATCH 127/141] finished sparse reach --- demos/demos/franka_reach_drq_demo_script.py | 6 +- examples/async_drq_sim/async_drq_sim.py | 74 ++++++++-- examples/async_drq_sim/automated_tests.sh | 130 ++++++++++++++++++ .../async_drq_sim/automated_tests_helper.sh | 6 + examples/async_drq_sim/run_actor.sh | 3 +- examples/async_drq_sim/run_learner.sh | 6 +- examples/async_drq_sim/tmux_launch_tests.sh | 20 +++ .../async_sac_state_sim.py | 17 ++- franka_sim/franka_sim/__init__.py | 2 +- franka_sim/franka_sim/envs/__init__.py | 1 + .../envs/panda_reach_sparse_gym_env.py | 19 ++- .../networks/reward_classifier.py | 2 +- .../serl_launcher/utils/train_utils.py | 2 +- 13 files changed, 259 insertions(+), 29 deletions(-) create mode 100644 examples/async_drq_sim/automated_tests.sh create mode 100644 examples/async_drq_sim/automated_tests_helper.sh create mode 100644 examples/async_drq_sim/tmux_launch_tests.sh diff --git a/demos/demos/franka_reach_drq_demo_script.py b/demos/demos/franka_reach_drq_demo_script.py index f9049fb0..0fb0f261 100755 --- a/demos/demos/franka_reach_drq_demo_script.py +++ b/demos/demos/franka_reach_drq_demo_script.py @@ -29,12 +29,12 @@ #------------------------------------------------------------------------------------------- # Flags #------------------------------------------------------------------------------------------- -flags.DEFINE_string("env", "PandaPickCubeVision-v0", "Name of environment.") +flags.DEFINE_string("env", 
"PandaReachSparseCube-v0", "Name of environment.") flags.DEFINE_string("exp_name", None, "Name of the experiment for wandb logging.") flags.DEFINE_integer("max_traj_length", 100, "Maximum length of trajectory.") flags.DEFINE_boolean("debug", True, "Debug mode.") # debug mode will disable wandb logging #flags.DEFINE_string("preload_rlds_path", None, "Path to preload RLDS data.") -flags.DEFINE_string("output_dir", "/data/data/serl/demos/franka_reach_drq_demo_script", +flags.DEFINE_string("output_dir", "~/serl/demos/franka_reach_drq_demo_script", "Directory to save the output data. This is where the RLDS logs will be saved.") flags.DEFINE_integer("num_demos", 2, "Number of episodes to log.") flags.DEFINE_boolean("enable_envlogger", True, "Enable envlogger.") @@ -465,7 +465,7 @@ def main(unused_argv): # env = SERLObsWrapper(env) env = SERLObsWrapper( env, - target_hw=(128, 128), + target_hw=(960, 960), img_dtype=np.uint8, # or np.float32 normalize=False, # True if using float32 in [0,1] ) diff --git a/examples/async_drq_sim/async_drq_sim.py b/examples/async_drq_sim/async_drq_sim.py index 15a197d0..5dc119be 100644 --- a/examples/async_drq_sim/async_drq_sim.py +++ b/examples/async_drq_sim/async_drq_sim.py @@ -41,6 +41,7 @@ flags.DEFINE_string("env", "PandaPickCubeVision-v0", "Name of environment.") flags.DEFINE_string("agent", "drq", "Name of agent.") flags.DEFINE_string("exp_name", None, "Name of the experiment for wandb logging.") +flags.DEFINE_string("run_name", None, "Name of run for wandb logging") flags.DEFINE_integer("max_traj_length", 1000, "Maximum length of trajectory.") flags.DEFINE_integer("seed", 42, "Random seed.") flags.DEFINE_bool("save_model", False, "Whether to save model.") @@ -69,6 +70,21 @@ flags.DEFINE_integer("checkpoint_period", 0, "Period to save checkpoints.") flags.DEFINE_string("checkpoint_path", None, "Path to save checkpoints.") +# flags for replay buffer +flags.DEFINE_string("replay_buffer_type", "memory_efficient_replay_buffer", "Which 
replay buffer to use") +flags.DEFINE_string("branch_method", None, "Method for how many branches to generate") +flags.DEFINE_string("split_method", None, "Method for when to change number of branches generated") +flags.DEFINE_float("workspace_width", 0.5, "Workspace width in meters") +flags.DEFINE_integer("max_depth",None,"Maximum layers of depth") +flags.DEFINE_integer("starting_branch_count", None, "Initial number of branches") +flags.DEFINE_integer("branching_factor", None, "Rate of change of branches per dimension (x,y)") # For fractal_branch and fractal_contraction +flags.DEFINE_float("alpha",None,"alpha value") +flags.DEFINE_enum("disassociated_type", None, ["octahedron", "hourglass"], + "Type of disassociated fracal rollout. Octahedron: expand from min to max then contract to min," + + " Hourglass: Contract from max to min then expand to max") +flags.DEFINE_integer("min_branch_count", None, "Minimum number of branches for disassociated fractal rollout") +flags.DEFINE_integer("max_branch_count", None, "Maximum number of branches for disassociated fractal rollout") + flags.DEFINE_boolean( "debug", False, "Debug mode." 
) # debug mode will disable wandb logging @@ -76,6 +92,11 @@ flags.DEFINE_string("log_rlds_path", None, "Path to save RLDS logs.") flags.DEFINE_string("preload_rlds_path", None, "Path to preload RLDS data.") +# Load demonstation data +flags.DEFINE_boolean("load_demos", False, "Whether to load demo dataset.") +flags.DEFINE_string("demo_dir", "/data/data/serl/demos", "Path to demo dataset.") +flags.DEFINE_string("file_name", "data_franka_reach_random_20.npz", "Name of the demo file to load.") + devices = jax.local_devices() num_devices = len(devices) sharding = jax.sharding.PositionalSharding(devices) @@ -108,7 +129,7 @@ def update_params(params): client.recv_network_callback(update_params) eval_env = gym.make(FLAGS.env) - if FLAGS.env == "PandaPickCubeVision-v0": + if "Vision" in FLAGS.env: eval_env = SERLObsWrapper(eval_env) eval_env = ChunkingWrapper(eval_env, obs_horizon=1, act_exec_horizon=None) eval_env = RecordEpisodeStatistics(eval_env) @@ -191,9 +212,12 @@ def learner( """ # set up wandb and logging wandb_logger = make_wandb_logger( - project="serl_dev", + project=FLAGS.exp_name, + name=FLAGS.run_name, description=FLAGS.exp_name or FLAGS.env, + # wandb_output_dir=FLAGS.wandb_output_dir, debug=FLAGS.debug, + # offline=FLAGS.wandb_offline, ) # To track the step in the training loop @@ -325,11 +349,17 @@ def main(_): else: env = gym.make(FLAGS.env) - if FLAGS.env == "PandaPickCube-v0": - env = gym.wrappers.FlattenObservation(env) - if FLAGS.env == "PandaPickCubeVision-v0": + if FLAGS.env in {"PandaPickCube-v0", "PandaReachCube-v0", "PandaPickSparseCube-v0", "PandaReachSparseCube-v0", "PandaPickCubeVision-v0", "PandaReachCubeVision-v0", "PandaPickSparseCubeVision-v0", "PandaReachSparseCubeVision-v0"}: + x_obs_idx=np.array([0,4]) + y_obs_idx=np.array([1,5]) + else: + raise NotImplementedError(f"Unknown observation layout for {FLAGS.env}") + + if "Vision" in FLAGS.env: env = SERLObsWrapper(env) env = ChunkingWrapper(env, obs_horizon=1, act_exec_horizon=None) + 
else: + env = gym.wrappers.FlattenObservation(env) image_keys = [key for key in env.observation_space.keys() if key != "state"] @@ -354,7 +384,21 @@ def main(_): env, capacity=FLAGS.replay_buffer_capacity, rlds_logger_path=FLAGS.log_rlds_path, - type="memory_efficient_replay_buffer", + type=FLAGS.replay_buffer_type, + branch_method=FLAGS.branch_method, + split_method=FLAGS.split_method, + branching_factor=FLAGS.branching_factor, + starting_branch_count=FLAGS.starting_branch_count, + workspace_width=FLAGS.workspace_width, + max_traj_length=FLAGS.max_traj_length, + x_obs_idx=x_obs_idx, + y_obs_idx=y_obs_idx, + preload_rlds_path=FLAGS.preload_rlds_path, + max_depth=FLAGS.max_depth, + alpha=FLAGS.alpha, + disassociated_type=FLAGS.disassociated_type, + min_branch_count=FLAGS.min_branch_count, + max_branch_count=FLAGS.max_branch_count, image_keys=image_keys, ) @@ -375,9 +419,23 @@ def preload_data_transform(data, metadata) -> Optional[Dict[str, Any]]: demo_buffer = make_replay_buffer( env, capacity=FLAGS.replay_buffer_capacity, - type="memory_efficient_replay_buffer", - image_keys=image_keys, + rlds_logger_path=FLAGS.log_rlds_path, + type=FLAGS.replay_buffer_type, + branch_method=FLAGS.branch_method, + split_method=FLAGS.split_method, + branching_factor=FLAGS.branching_factor, + starting_branch_count=FLAGS.starting_branch_count, + workspace_width=FLAGS.workspace_width, + max_traj_length=FLAGS.max_traj_length, + x_obs_idx=x_obs_idx, + y_obs_idx=y_obs_idx, preload_rlds_path=FLAGS.preload_rlds_path, + max_depth=FLAGS.max_depth, + alpha=FLAGS.alpha, + disassociated_type=FLAGS.disassociated_type, + min_branch_count=FLAGS.min_branch_count, + max_branch_count=FLAGS.max_branch_count, + image_keys=image_keys, preload_data_transform=preload_data_transform, ) diff --git a/examples/async_drq_sim/automated_tests.sh b/examples/async_drq_sim/automated_tests.sh new file mode 100644 index 00000000..42a53225 --- /dev/null +++ b/examples/async_drq_sim/automated_tests.sh @@ -0,0 +1,130 @@ 
+#!/bin/bash + +SEEDS=$1 +# WANDB_OUTPUT_DIR=~/wandb_logs +TEST="async_sac_state_sim.py" +CONDA_ENV="serl" +ENV="PandaPickCubeVision-v0" +MAX_STEPS=1000000 +TRAINING_STARTS=1000 +RANDOM_STEPS=1000 +CRITIC_ACTOR_RATIO=8 +EXP_NAME="GENERAL-RETESTING-$ENV" +REPLAY_BUFFER_TYPE="fractal_symmetry_replay_buffer" + +BASE_ARGS="--env $ENV --exp_name $EXP_NAME --training_starts $TRAINING_STARTS --random_steps $RANDOM_STEPS" +ARGS="" + +function run_test { + + for seed in $(seq 1 1 $SEEDS) + do + # OPEN_PORTS=$( comm -23 <(seq 49152 65535 | sort) <(ss -Htan | awk '{print $4}' | cut -d':' -f2 | sort -u) | shuf | head -n 2 ) + # PORTS=( $OPEN_PORTS ) + # PORT_NUMBER=${PORTS[0]} + # BROADCAST_PORT=${PORTS[1]} + + # ARGS+=" --port_number $PORT_NUMBER --broadcast_port $BROADCAST_PORT" + + echo "Running constant with args: $ARGS" + tmux respawn-pane -k -t serl_session:0.1 + tmux respawn-pane -k -t serl_session:0.2 + tmux send-keys -t serl_session:0.1 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --actor --max_steps 2000000000 --seed $seed $BASE_ARGS $ARGS" C-m + tmux send-keys -t serl_session:0.2 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --learner --max_steps $MAX_STEPS --seed $seed $BASE_ARGS $ARGS" C-m "exit" C-m + + # Wait for learner to finish + while ! tmux capture-pane -t serl_session:0.2 -p | grep "logout" > /dev/null; + do + sleep 1 + done + echo "Finished!" 
+ done +} + +# BASELINE TESTING +for CRITIC_ACTOR_RATIO in 8 +do + for batch_size in 256 + do + for replay_buffer_capacity in 1000000 + do + ARGS="--run_name baseline --critic_actor_ratio $CRITIC_ACTOR_RATIO --replay_buffer_type replay_buffer --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity" + run_test + done + done +done + +# CONSTANT TESTING +for CRITIC_ACTOR_RATIO in 8 +do + for starting_branch_count in 27 + do + for batch_size in 256 + do + for workspace_width in 0.5 + do + for replay_buffer_capacity in 1000000 + do + ARGS="--run_name constant-$starting_branch_count^1 --steps_per_update $steps_per_update --critic_actor_ratio $CRITIC_ACTOR_RATIO --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --workspace_width $workspace_width --branch_method 'constant' --starting_branch_count $starting_branch_count" + run_test + done + done + done + done +done + +# # FRACTAL TESTING +# for batch_size in 256 +# do +# for replay_buffer_capacity in 1000000 +# do +# for workspace_width in 0.5 +# do +# for alpha in 0.9 +# do +# for branching_factor in 3 9 +# do +# for max_depth in 2 4 +# do +# # Fractal Expansion +# ARGS="--run_name fractal_expansion-$branching_factor^$max_depth-alpha-$alpha-workspace_width-$workspace_width-batch-size-$batch_size-capacity-$replay_buffer_capacity --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --workspace_width $workspace_width --branch_method 'fractal' --alpha $alpha --branching_factor $branching_factor --max_depth $max_depth" +# run_test + +# # Fractal Contraction +# ARGS="--run_name fractal_contraction-$branching_factor^$max_depth-alpha-$alpha-workspace_width-$workspace_width-batch-size-$batch_size-capacity-$replay_buffer_capacity --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --workspace_width $workspace_width 
--branch_method 'contraction' --alpha $alpha --branching_factor $branching_factor --max_depth $max_depth" +# run_test +# done +# done +# done +# done +# done +# done + +# # DISASSOCIATIVE TESTING +# for batch_size in 256 +# do +# for replay_buffer_capacity in 1000000 +# do +# for workspace_width in 0.5 +# do +# for alpha in 0.9 +# do +# for min_branch_count in 1 3 9 +# do +# for max_branch_count in 3 9 27 +# do +# # Disassociative (Hourglass) +# ARGS="--run_name disassociative-hourglass-$min_branch_count:$max_branch_count-alpha-$alpha-workspace_width-$workspace_width-batch-size-$batch_size-capacity-$replay_buffer_capacity --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --workspace_width $workspace_width --branch_method 'disassociated' --min_branch_count $min_branch_count --max_branch_count $max_branch_count --disassociated_type 'hourglass' --alpha $alpha" +# run_test + +# # Disassociative (Octahedron) +# ARGS="--run_name disassociative-hourglass-$min_branch_count:$max_branch_count-alpha-$alpha-workspace_width-$workspace_width-batch-size-$batch_size-capacity-$replay_buffer_capacity --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --workspace_width $workspace_width --branch_method 'disassociated' --min_branch_count $min_branch_count --max_branch_count $max_branch_count --disassociated_type 'octahedron' --alpha $alpha" + +# done +# done +# done +# done +# done +# done + +tmux kill-window -t serl_session:$SEED diff --git a/examples/async_drq_sim/automated_tests_helper.sh b/examples/async_drq_sim/automated_tests_helper.sh new file mode 100644 index 00000000..4f6df155 --- /dev/null +++ b/examples/async_drq_sim/automated_tests_helper.sh @@ -0,0 +1,6 @@ +#!/bin/bash + +export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ +export XLA_PYTHON_CLIENT_MEM_FRACTION=.5 && \ + +python async_drq_sim.py "$@" \ No newline at end of file diff --git 
a/examples/async_drq_sim/run_actor.sh b/examples/async_drq_sim/run_actor.sh index 52fcfc41..ebdb0514 100644 --- a/examples/async_drq_sim/run_actor.sh +++ b/examples/async_drq_sim/run_actor.sh @@ -1,8 +1,7 @@ export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ -export XLA_PYTHON_CLIENT_MEM_FRACTION=.1 && \ +export XLA_PYTHON_CLIENT_MEM_FRACTION=.5 && \ python async_drq_sim.py "$@" \ --actor \ - --render \ --exp_name=serl_dev_drq_sim_test_resnet \ --seed 0 \ --random_steps 1000 \ diff --git a/examples/async_drq_sim/run_learner.sh b/examples/async_drq_sim/run_learner.sh index d36331c5..49589428 100644 --- a/examples/async_drq_sim/run_learner.sh +++ b/examples/async_drq_sim/run_learner.sh @@ -1,5 +1,5 @@ export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ -export XLA_PYTHON_CLIENT_MEM_FRACTION=.2 && \ +export XLA_PYTHON_CLIENT_MEM_FRACTION=.5 && \ python async_drq_sim.py "$@" \ --learner \ --exp_name=serl_dev_drq_sim_test_resnet \ @@ -7,5 +7,5 @@ python async_drq_sim.py "$@" \ --training_starts 1000 \ --critic_actor_ratio 4 \ --encoder_type resnet-pretrained \ - --demo_path franka_lift_cube_image_20_trajs.pkl \ - --debug # wandb is disabled when debug +# --demo_path franka_lift_cube_image_20_trajs.pkl \ +# --debug # wandb is disabled when debug diff --git a/examples/async_drq_sim/tmux_launch_tests.sh b/examples/async_drq_sim/tmux_launch_tests.sh new file mode 100644 index 00000000..aa7a6a57 --- /dev/null +++ b/examples/async_drq_sim/tmux_launch_tests.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +SEEDS=1 +# Create a new tmux session +tmux new-session -d -s serl_session +tmux setw -g remain-on-exit on + +# Split the window horizontally +tmux split-window -v +tmux split-pane -h -t serl_session:0.1 + +# Navigate to the activate the conda environment in the first pane +tmux send-keys -t serl_session:0.0 "bash automated_tests.sh $SEEDS" C-m + + +# Attach to the tmux session +tmux attach-session -t serl_session + +# kill the tmux session by running the following command +# tmux kill-session -t 
serl_session diff --git a/examples/async_sac_state_sim/async_sac_state_sim.py b/examples/async_sac_state_sim/async_sac_state_sim.py index b5309f56..ad55f04a 100644 --- a/examples/async_sac_state_sim/async_sac_state_sim.py +++ b/examples/async_sac_state_sim/async_sac_state_sim.py @@ -214,9 +214,9 @@ def learner(rng, agent: SACAgent, replay_buffer, replay_iterator): project=FLAGS.exp_name, name=FLAGS.run_name, description=FLAGS.exp_name or FLAGS.env, - wandb_output_dir=FLAGS.wandb_output_dir, + # wandb_output_dir=FLAGS.wandb_output_dir, debug=FLAGS.debug, - offline=FLAGS.wandb_offline, + # offline=FLAGS.wandb_offline, ) # To track the step in the training loop @@ -304,8 +304,13 @@ def main(_): env = gym.make(FLAGS.env, render_mode="human") else: env = gym.make(FLAGS.env) - - #if FLAGS.env == "PandaPickCube-v0": + + if FLAGS.env in {"PandaPickCube-v0", "PandaReachCube-v0", "PandaPickSparseCube-v0", "PandaReachSparseCube-v0"}: + x_obs_idx=np.array([0,4]) + y_obs_idx=np.array([1,5]) + else: + raise NotImplementedError(f"Unknown observation layout for {FLAGS.env}") + env = gym.wrappers.FlattenObservation(env) rng, sampling_rng = jax.random.split(rng) @@ -360,8 +365,8 @@ def main(_): starting_branch_count=FLAGS.starting_branch_count, workspace_width=FLAGS.workspace_width, max_traj_length=FLAGS.max_traj_length, - x_obs_idx=np.array([0,4]), - y_obs_idx=np.array([1,5]), + x_obs_idx=x_obs_idx, + y_obs_idx=y_obs_idx, preload_rlds_path=FLAGS.preload_rlds_path, max_depth=FLAGS.max_depth, alpha=FLAGS.alpha, diff --git a/franka_sim/franka_sim/__init__.py b/franka_sim/franka_sim/__init__.py index 93f6a2f3..3a3cb9dc 100644 --- a/franka_sim/franka_sim/__init__.py +++ b/franka_sim/franka_sim/__init__.py @@ -34,5 +34,5 @@ # entry_point="franka_sim.envs:PandaReachCubeGymEnv", # max_episode_steps=100, # kwargs={"image_obs": True}, -# ) +# ) diff --git a/franka_sim/franka_sim/envs/__init__.py b/franka_sim/franka_sim/envs/__init__.py index a5e8c7c5..8315e983 100644 --- 
a/franka_sim/franka_sim/envs/__init__.py +++ b/franka_sim/franka_sim/envs/__init__.py @@ -1,5 +1,6 @@ from franka_sim.envs.panda_pick_gym_env import PandaPickCubeGymEnv from franka_sim.envs.panda_reach_gym_env import PandaReachCubeGymEnv +from franka_sim.envs.panda_reach_sparse_gym_env import PandaReachSparseCubeGymEnv __all__ = [ "PandaPickCubeGymEnv", diff --git a/franka_sim/franka_sim/envs/panda_reach_sparse_gym_env.py b/franka_sim/franka_sim/envs/panda_reach_sparse_gym_env.py index 560d8bd7..832b89d5 100644 --- a/franka_sim/franka_sim/envs/panda_reach_sparse_gym_env.py +++ b/franka_sim/franka_sim/envs/panda_reach_sparse_gym_env.py @@ -37,6 +37,7 @@ def __init__( render_spec: GymRenderingSpec = GymRenderingSpec(), render_mode: Literal["rgb_array", "human"] = "rgb_array", image_obs: bool = False, + demo: str = "None", ): self._action_scale = action_scale @@ -188,7 +189,7 @@ def step( truncated: bool, info: dict[str, Any] """ - x, y, z = action + x, y, z, _ = action # Set the mocap position. pos = self._data.mocap_pos[0].copy() @@ -215,7 +216,17 @@ def step( obs = self._compute_observation() rew = self._compute_reward() - terminated = self.time_limit_exceeded() + + # # For demo reach environment, finger --- THIS SHOULD NEVER BE MERGED TO OTHER BRANCHES AFFECTING REGULAR USE OF ENVIRONMENTS. ONLY FOR DEMOS. + # # IF ACCIDENTALLY MERGED, IT WILL REDUCE PERFORMANCE OF THE AGENT. + # if self.demo == "franka_reach_demo": + # if rew >= 0.85: # Demo ERROR_THRESHOLD @ 0.3->0.675; ERROR_THRESHOLD @ 0.2->0.55 + # terminated = True + # else: + # Check if the time limit is exceeded. + terminated = (rew >= 0.85) + if self._time_limit is not None: + terminated = terminated or self.time_limit_exceeded() return obs, rew, terminated, False, {} @@ -264,11 +275,11 @@ def _compute_reward(self) -> float: # Calculate distance dist = np.linalg.norm(block_pos - tcp_pos) - # Distance-based reward + # Distance-based reward. 
Note at norm of 0.015, reward will be 0.5 r_close = np.exp(-20 * dist) r_close = np.clip(r_close, 0.0, 1.0) - return (r_close > 0.95) + return (r_close > 0.85) if __name__ == "__main__": diff --git a/serl_launcher/serl_launcher/networks/reward_classifier.py b/serl_launcher/serl_launcher/networks/reward_classifier.py index b05e196c..6b60f17c 100644 --- a/serl_launcher/serl_launcher/networks/reward_classifier.py +++ b/serl_launcher/serl_launcher/networks/reward_classifier.py @@ -67,7 +67,7 @@ def create_classifier( with open(pretrained_encoder_path, "rb") as f: encoder_params = pkl.load(f) - param_count = sum(x.size for x in jax.tree_leaves(encoder_params)) + param_count = sum(x.size for x in jax.tree.leaves(encoder_params)) print( f"Loaded {param_count/1e6}M parameters from ResNet-10 pretrained on ImageNet-1K" ) diff --git a/serl_launcher/serl_launcher/utils/train_utils.py b/serl_launcher/serl_launcher/utils/train_utils.py index 31037317..dd8816a7 100644 --- a/serl_launcher/serl_launcher/utils/train_utils.py +++ b/serl_launcher/serl_launcher/utils/train_utils.py @@ -108,7 +108,7 @@ def load_resnet10_params(agent, image_keys=("image",), public=True): with open(file_path, "rb") as f: encoder_params = pkl.load(f) - param_count = sum(x.size for x in jax.tree_leaves(encoder_params)) + param_count = sum(x.size for x in jax.tree.leaves(encoder_params)) print( f"Loaded {param_count/1e6}M parameters from ResNet-10 pretrained on ImageNet-1K" ) From 808b7e75211eff86acc76ab8d97e8ca8e62c84d8 Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Tue, 9 Sep 2025 09:56:30 -0500 Subject: [PATCH 128/141] - Cleaning up code. - Changing name of env to sparse reach version. - In demos branch, fixing jax.tree.leaves. 
--- demos/demos/franka_reach_drq_demo_script.py | 195 +++++------------- examples/async_drq_sim/.vscode/launch.json | 169 +++++++++++++++ examples/async_drq_sim/async_drq_sim.py | 19 +- franka_sim/franka_sim/__init__.py | 5 + .../networks/reward_classifier.py | 2 +- serl_launcher/serl_launcher/utils/launcher.py | 7 + .../serl_launcher/utils/train_utils.py | 2 +- 7 files changed, 249 insertions(+), 150 deletions(-) create mode 100644 examples/async_drq_sim/.vscode/launch.json diff --git a/demos/demos/franka_reach_drq_demo_script.py b/demos/demos/franka_reach_drq_demo_script.py index f9049fb0..a1708e6e 100755 --- a/demos/demos/franka_reach_drq_demo_script.py +++ b/demos/demos/franka_reach_drq_demo_script.py @@ -36,7 +36,7 @@ #flags.DEFINE_string("preload_rlds_path", None, "Path to preload RLDS data.") flags.DEFINE_string("output_dir", "/data/data/serl/demos/franka_reach_drq_demo_script", "Directory to save the output data. This is where the RLDS logs will be saved.") -flags.DEFINE_integer("num_demos", 2, "Number of episodes to log.") +flags.DEFINE_integer("num_demos", 1, "Number of episodes to log.") flags.DEFINE_boolean("enable_envlogger", True, "Enable envlogger.") flags.DEFINE_string("teleop_mode", "keyboard", "Teleoperation mode: 'keyboard' or 'spacemouse'.") @@ -149,102 +149,6 @@ def _safe_close(obj, label=""): _safe_close(env, "env") -############################################################################## -# ----------------------------------------------------------------------------- -# Finalize TFDS metadata so TFDS can load the split without manual edits. -# ----------------------------------------------------------------------------- -# def finalize_tfds_metadata(builder_dir: str): -# """ -# Make the TFDS builder dir loadable by writing numShards/shardLengths. -# Works with new/old TFDS; prints helpful diagnostics if filenames don't match. 
-# """ -# import os, json, glob, inspect -# import tensorflow as tf -# import tensorflow_datasets as tfds -# try: -# from tensorflow_datasets import folder_dataset -# except Exception: -# from tensorflow_datasets.core import folder_dataset - -# info_path = os.path.join(builder_dir, "dataset_info.json") -# if not os.path.exists(info_path): -# raise FileNotFoundError(f"Missing dataset_info.json in {builder_dir}") - -# with open(info_path) as f: -# info = json.load(f) - -# dataset_name = info["name"] # e.g., PandaPickCubeVision-v0 -# file_format = info.get("fileFormat", "tfrecord") # "tfrecord" -# # Use the exact template written by the logger -# template_str = info["splits"][0]["filepathTemplate"] - -# # Construct a ShardedFileTemplate object TFDS uses to discover shard files. This object encodes the directory, dataset name, suffix, and the template expression. -# template = tfds.core.ShardedFileTemplate( -# data_dir=builder_dir, -# dataset_name=dataset_name, -# filetype_suffix=file_format, -# template=template_str, -# ) - -# # Quick diagnostics so mismatches are obvious -# print(f"[finalize] builder_dir: {builder_dir}") -# print(f"[finalize] template: {template_str} (dataset={dataset_name}, suffix={file_format})") -# print("[finalize] files found: ", sorted(glob.glob(os.path.join(builder_dir, f"*.{file_format}*")))) - -# # Attempt the canonical TFDS way: compute per-split shard stats by scanning files that match the template. Different TFDS versions call the directory argument out_dir or data_dir, so it inspects the signature and passes the right name. -# try: -# sig = inspect.signature(folder_dataset.compute_split_info).parameters -# if "out_dir" in sig: -# # Compute the split info (num shards, num examples,...). 
See more at https://www.tensorflow.org/datasets/api_docs/python/tfds/folder_dataset/compute_split_info -# split_infos = folder_dataset.compute_split_info(out_dir=builder_dir, filename_template=template) -# else: -# split_infos = folder_dataset.compute_split_info(data_dir=builder_dir, filename_template=template) - -# sig_w = inspect.signature(folder_dataset.write_metadata).parameters -# kwargs = dict(split_infos=split_infos, filename_template=template) -# if "out_dir" in sig_w: -# kwargs["out_dir"] = builder_dir -# else: -# kwargs["data_dir"] = builder_dir -# if "features" in sig_w: -# kwargs["features"] = None -# folder_dataset.write_metadata(**kwargs) - -# except ValueError as e: -# # Common cause: filenames don’t match the template, or shards not flushed yet. -# print(f"[finalize] TFDS compute_split_info failed: {e}") -# shard_paths = sorted(glob.glob(os.path.join(builder_dir, f"{dataset_name}-*.{file_format}-*"))) -# if not shard_paths: -# # last resort: any .tfrecord* under the dir -# shard_paths = sorted(glob.glob(os.path.join(builder_dir, f"*.{file_format}*"))) -# if not shard_paths: -# raise ValueError( -# f"No {file_format} shards found in {builder_dir}. " -# f"Expected filenames like '{dataset_name}-train.{file_format}-00000' " -# f"per template '{template_str}'." 
-# ) from e - -# # Count records per shard and patch dataset_info.json automatically -# shard_lengths = [] -# for p in shard_paths: -# n = sum(1 for _ in tf.data.TFRecordDataset(p)) -# shard_lengths.append(n) - -# # Write lengths for each split that uses the template (here, 'train') -# for s in info["splits"]: -# # Only add lengths for splits that use this file template -# if s.get("filepathTemplate") == template_str: -# s["numShards"] = len(shard_paths) -# s["shardLengths"] = shard_lengths - -# with open(info_path, "w") as f: -# json.dump(info, f, indent=2) -# print(f"[finalize] wrote shardLengths={shard_lengths} into {info_path}") - -# # Sanity -# b = tfds.builder_from_directory(builder_dir) -# print("[finalize] splits:", {k: v.num_examples for k, v in b.info.splits.items()}) - def finalize_tfds_metadata_beamless(builder_dir: str): """ Beam-free finalize: count TFRecord examples per shard and write @@ -462,7 +366,6 @@ def main(unused_argv): env = gym.wrappers.FlattenObservation(env) if FLAGS.env == "PandaPickCubeVision-v0": - # env = SERLObsWrapper(env) env = SERLObsWrapper( env, target_hw=(128, 128), @@ -473,6 +376,7 @@ def main(unused_argv): logging.info(f'Done creating {FLAGS.env} environment.') + # Set camera to front view if viewer is available if hasattr(env.unwrapped, '_viewer'): viewer = set_front_cam_view(env) if viewer: @@ -487,6 +391,7 @@ def main(unused_argv): dataset_dir = ensure_dir_exists() + # Will save as many episodes as possible into files of 200MB each by default. env = AutoOXEEnvLogger( env=env, dataset_name=FLAGS.env, @@ -497,63 +402,65 @@ def main(unused_argv): #--- LOOP DEMOS/EPISODES --- # Loop through the number of demos specified by the user to record demonstrations - for i in range(FLAGS.num_demos): - - # Log custom metadata during new episode: language embeddings randomly. 
- if FLAGS.enable_envlogger: - # The "language_embedding" is a standard field used in robotics datasets (like the OXE format that envlogger creates) to store a numerical representation of a natural language instruction for an episode. - # How to Reconcile it with 5 Random Numbers: The five random numbers are just placeholder data. This script is a demonstration and doesn't involve a real language model. - env.set_episode_metadata({ - "language_embedding": np.random.random((5,)).astype(np.float32) - }) - env.set_step_metadata({"timestamp": time.time()}) - - logging.info('episode %r', i) - - # Start a new episode - env.reset() - terminated = False - truncated = False - - step = 0 + try: + for i in range(FLAGS.num_demos): - # Termination occurs when the hand reaches the target or the maximum trajectory length is reached - while not (terminated or truncated): + # Log custom metadata during new episode: language embeddings randomly. + if FLAGS.enable_envlogger: + # The "language_embedding" is a standard field used in robotics datasets (like the OXE format that envlogger creates) to store a numerical representation of a natural language instruction for an episode. + # How to Reconcile it with 5 Random Numbers: The five random numbers are just placeholder data. This script is a demonstration and doesn't involve a real language model. 
+ env.set_episode_metadata({ + "language_embedding": np.random.random((5,)).astype(np.float32) + }) + env.set_step_metadata({"timestamp": time.time()}) + + logging.info('episode %r', i) - # Get action from the demo function - action = get_kb_demo_action(env) + # Start a new episode + env.reset() + terminated = False + truncated = False - # example to log custom step metadata - if FLAGS.enable_envlogger: - env.set_step_metadata({"timestamp": np.float32(time.time())}) + step = 0 + + # Termination occurs when the hand reaches the target or the maximum trajectory length is reached + while not (terminated or truncated): + + # Get action from the demo function + action = get_kb_demo_action(env) - return_step = env.step(action) + # example to log custom step metadata + if FLAGS.enable_envlogger: + env.set_step_metadata({"timestamp": np.float32(time.time())}) - # NOTE: to handle gym.Env.step() return value change in gym 0.26 - if len(return_step) == 5: - obs, reward, terminated, truncated, info = return_step - else: - obs, reward, terminated, info = return_step - truncated = False + return_step = env.step(action) - print(f" step: {step}", f"reward: {reward:.3f}\n") - step += 1 + # NOTE: to handle gym.Env.step() return value change in gym 0.26 + if len(return_step) == 5: + obs, reward, terminated, truncated, info = return_step + else: + obs, reward, terminated, info = return_step + truncated = False - logging.info('Done recording %r demos.', FLAGS.num_demos) + print(f" step: {step}", f"reward: {reward:.3f}\n") + step += 1 - # Finalize TFDS metadata so SERL/TFDS can load the split - if FLAGS.enable_envlogger and dataset_dir is not None: + logging.info('Done recording %r demos.', FLAGS.num_demos) + + finally: + # Finalize TFDS metadata so SERL/TFDS can load the split + if FLAGS.enable_envlogger and dataset_dir is not None: - # Close the environment to flush/write data. AutoOXEEnvLogger implements dm_env.close() - # which flushes data to disk. 
This is important to ensure all data is written before finalizing metadata. - # If you skip this step, some data may not be written and the dataset may be incomplete. - logging.info("Closing environment to flush data to disk...") + # Close the environment to flush/write data. AutoOXEEnvLogger implements dm_env.close() + # which flushes data to disk. This is important to ensure all data is written before finalizing metadata. + # If you skip this step, some data may not be written and the dataset may be incomplete. + logging.info("Closing environment to flush data to disk...") - # closes logger(s) + dm_env + gym + viewer - close_logger_and_env(env) + # closes logger(s) + dm_env + gym + viewer + close_logger_and_env(env) - # Scan files and write metadata - finalize_tfds_metadata_beamless(dataset_dir) + # Scan files and write metadata + finalize_tfds_metadata_beamless(dataset_dir) # Note: async_drq_sim will read from the dataset_dir you printed above. diff --git a/examples/async_drq_sim/.vscode/launch.json b/examples/async_drq_sim/.vscode/launch.json new file mode 100644 index 00000000..af352dfa --- /dev/null +++ b/examples/async_drq_sim/.vscode/launch.json @@ -0,0 +1,169 @@ +{ + // VSCode debug configuration version + "version": "0.2.0", + "configurations": [ + // LEARNER + { + // Configuration for the RL agent learner component + "name": "Python: Learner", + "type": "debugpy", + "request": "launch", + "program": "${workspaceFolder}/async_drq_sim.py", + // Command-line arguments matching run_learner.sh + "args": [ + "--render", + "--env", "PandaReachSparseCubeGymEnv-v0", // Environment to use + "--agent", "drq", // Agent type (drq or sac) + "--exp_name", "PandaPickCubeVision-v0_sparse_001", // Experiment name for wandb logging + "--max_traj_length", "100", // Max episode length/ Max episode length + "--seed", "42", // Random seed for reproducibility + // "--save_model", // Save model checkpoints + "--batch_size", "256", // Training batch size + "--critic_actor_ratio", 
"8", // Critic-to-actor update ratio + "--max_steps", "50_000 ", // Maximum training steps + "--replay_buffer_capacity", "200_000", // Replay buffer capacity + "--random_steps", "300", // Number of random steps at beginning + "--training_starts", "300", // Start training after buffer has this many samples + "--steps_per_update", "30", // Number of env steps per update + "--learner", // REQUIRED: Indicates this is a learner instance + "--encoder_type", "resnet-pretrained", // Use pixel-based observations + + // Fractal + // "--replay_buffer_type", "fractal_symmetry_replay_buffer_parallel", // Use fractal symmetry replay buffer + // "--branch_method", "constant", + // "--split_method", "constant", + // "--starting_branch_count", "3", // Start with 27 branches + // "--workspace_width", "0.5", + + // Demonstration data loading options + // "--load_demos", // Load demo dataset + // "--demo_dir", "/data/data/serl/demos", + // "--file_name", "data_franka_reach_random_5_2.npz", // Name of the demo file to load + "--preload_rlds_path", "/data/data/serl/demos/franka_reach_drq_demo_script/session_20250902_105400_num_demos_2/PandaPickCubeVision-v0/0.1.0", // Preload RLDS dataset for faster loading + + // Dissasociated Fractals + // "--branch_method", "disassociated", // branch method type + // "--split_method", "disassociated", // split method type + // "--disassociated_type", "octahedron", // Type of disassociated test to perform + // "--min_branch_count", "3", // Minimum branch count for disassociated testing + // "--max_branch_count", "9", // Maximum branch count for disassociated testing + // "--num_depth_sectors", "13", // Desired number of sectors to divide rollouts into for branch count splitting + //"--alpha", "1" // alpha + + ], + "console": "integratedTerminal", // Use integrated terminal to see output + "justMyCode": false, // Allow stepping into libraries + // Environment variables from run_learner.sh + "env": { + "XLA_PYTHON_CLIENT_PREALLOCATE": "false", // Don't 
preallocate JAX/XLA memory + "XLA_PYTHON_CLIENT_MEM_FRACTION": ".5" // Limit JAX memory usage to 50% + }, + // Additional helpful debugging options + "showReturnValue": true, // Show function return values + "purpose": ["debug-in-terminal"], // Run in terminal for better output + "cwd": "${workspaceFolder}" // Set working directory explicitly + }, + // ACTOR + { + // Configuration for the RL agent actor component + "name": "Python: Actor", + "type": "debugpy", + "request": "launch", + "program": "${workspaceFolder}/async_drq_sim.py", + // Command-line arguments matching run_actor.sh + "args": [ + "--render", + "--env", "PandaReachSparseCubeGymEnv-v0", // Environment to use + "--agent", "drq", // Agent type (drq or sac) + "--exp_name", "PandaPickCubeVision-v0_sparse_001", // Experiment name for wandb logging + "--max_traj_length", "100", // Max episode length/ Max episode length + "--seed", "42", // Random seed for reproducibility + // "--save_model", // Save model checkpoints + "--batch_size", "256", // Training batch size + "--critic_actor_ratio", "8", // Critic-to-actor update ratio + "--max_steps", "50_000", // Maximum training steps + "--replay_buffer_capacity", "200_000", // Replay buffer capacity + "--random_steps", "300", // Number of random steps at beginning + "--training_starts", "300", // Start training after buffer has this many samples + "--steps_per_update", "30", // Number of env steps per update + "--actor", // REQUIRED: Indicates this is a learner instance + "--encoder_type", "resnet-pretrained", // Use pixel-based observations + + // Fractal + // "--replay_buffer_type", "fractal_symmetry_replay_buffer_parallel", // Use fractal symmetry replay buffer + // "--branch_method", "constant", + // "--split_method", "constant", + // "--starting_branch_count", "3", // Start with 27 branches + // "--workspace_width", "0.5", + + // Demonstration data loading options + // "--load_demos", // Load demo dataset + // "--demo_dir", "/data/data/serl/demos", + // 
"--file_name", "data_franka_reach_random_5_2.npz", // Name of the demo file to load + "--preload_rlds_path", "/data/data/serl/demos/franka_reach_drq_demo_script/session_20250902_105400_num_demos_2/PandaPickCubeVision-v0/0.1.0", // Preload RLDS dataset for faster loading + + // Dissasociated Fractals + // "--branch_method", "disassociated", // branch method type + // "--split_method", "disassociated", // split method type + // "--disassociated_type", "octahedron", // Type of disassociated test to perform + // "--min_branch_count", "3", // Minimum branch count for disassociated testing + // "--max_branch_count", "9", // Maximum branch count for disassociated testing + // "--num_depth_sectors", "13", // Desired number of sectors to divide rollouts into for branch count splitting + //"--alpha", "1" // alpha + + ], + "console": "integratedTerminal", // Use integrated terminal to see output + "justMyCode": false, // Allow stepping into libraries + // Environment variables from run_actor.sh + "env": { + "XLA_PYTHON_CLIENT_PREALLOCATE": "false", // Don't preallocate JAX/XLA memory + "XLA_PYTHON_CLIENT_MEM_FRACTION": ".5" // Limit JAX memory usage to 50% + }, + "showReturnValue": true, // Show function return values + "purpose": ["debug-in-terminal"], // Run in terminal for better output + "cwd": "${workspaceFolder}" // Set working directory explicitly + } + ], + // Compound configurations to launch multiple configurations together + "compounds": [ + { + // Launch both learner and actor at the same time + "name": "Learner + Actor", + "configurations": ["Python: Learner", "Python: Actor"], + // Note: Actor will connect to learner via TrainerClient, assuming localhost IP + // Both will run in separate debug sessions with independent controls + } + ], + + /* + DEBUGGING TIPS: + + - IMPORTANT: You must select either Learner or Actor configuration when debugging + (the NotImplementedError occurs if neither --learner nor --actor flag is specified) + + - If you get an error about 
'utd_ratio', use the "Python: Learner (Fix utd_ratio)" configuration + which adds this missing parameter + + - Set breakpoints in learner() or actor() functions to step through the main training loops + + - Key places to set breakpoints: + * In actor(): near the action sampling logic (step < FLAGS.random_steps) + * In learner(): where agent.update_high_utd() is called + * Server/client communication points (client.update(), server.publish_network()) + + - For memory issues: Watch replay buffer size growth with breakpoints in data_store.insert() + + - JAX issues: Set breakpoints after jax.device_put() calls to ensure proper device placement + + - The "utd_ratio" parameter seems to be used in the learner function but isn't defined + in the FLAGS. Use the special configuration or add a --utd_ratio flag to fix. + + DEBUG WORKFLOW: + + 1. Start with "Python: Learner (Fix utd_ratio)" configuration + 2. Set breakpoints at key sections you want to monitor + 3. Run the debugger and observe variable values at each step + 4. Once learner is properly running, launch the Actor in a separate instance + 5. 
Watch for communication between the two + */ +} \ No newline at end of file diff --git a/examples/async_drq_sim/async_drq_sim.py b/examples/async_drq_sim/async_drq_sim.py index 15a197d0..50728ecc 100644 --- a/examples/async_drq_sim/async_drq_sim.py +++ b/examples/async_drq_sim/async_drq_sim.py @@ -38,7 +38,7 @@ FLAGS = flags.FLAGS -flags.DEFINE_string("env", "PandaPickCubeVision-v0", "Name of environment.") +flags.DEFINE_string("env", "PandaReachSparseCubeGymEnv-v0", "Name of environment.") flags.DEFINE_string("agent", "drq", "Name of agent.") flags.DEFINE_string("exp_name", None, "Name of the experiment for wandb logging.") flags.DEFINE_integer("max_traj_length", 1000, "Maximum length of trajectory.") @@ -108,7 +108,7 @@ def update_params(params): client.recv_network_callback(update_params) eval_env = gym.make(FLAGS.env) - if FLAGS.env == "PandaPickCubeVision-v0": + if FLAGS.env == "PandaReachSparseCubeGymEnv-v0": eval_env = SERLObsWrapper(eval_env) eval_env = ChunkingWrapper(eval_env, obs_horizon=1, act_exec_horizon=None) eval_env = RecordEpisodeStatistics(eval_env) @@ -327,7 +327,7 @@ def main(_): if FLAGS.env == "PandaPickCube-v0": env = gym.wrappers.FlattenObservation(env) - if FLAGS.env == "PandaPickCubeVision-v0": + if FLAGS.env == "PandaReachSparseCubeGymEnv-v0": env = SERLObsWrapper(env) env = ChunkingWrapper(env, obs_horizon=1, act_exec_horizon=None) @@ -370,6 +370,17 @@ def preload_data_transform(data, metadata) -> Optional[Dict[str, Any]]: # NOTE: Create your own custom data transform function here if you # are loading this via with --preload_rlds_path with tf rlds data # This default does nothing + # See: https://colab.research.google.com/github/google-research/rlds/blob/main/rlds/examples/rlds_tutorial.ipynb#scrollTo=X1KXM8IGecRO + # https://www.tensorflow.org/guide/data + # https://github.com/google-research/rlds/blob/main/docs/transformations.md + # Batch: rlds.transformations.batch 
(https://colab.research.google.com/github/google-research/rlds/blob/main/rlds/examples/rlds_tutorial.ipynb#scrollTo=TGT3YfzFOrBm) + # Reverb: rlds.transformations.pattern_map (https://colab.research.google.com/github/google-research/rlds/blob/main/rlds/examples/rlds_dataset_patterns.ipynb ) + # Nested data set manipulation: rlds.transformations.episode_length/.sum_dataset/.final_step/.map_nested_steps + # Concatenation: rlds.transformations.concatenate / .concat_if_terminal (https://colab.research.google.com/github/google-research/rlds/blob/main/rlds/examples/rlds_examples.ipynb#scrollTo=pWNhxwJzOUJv) + # Stats: rlds.transformations.mean_and_std (https://colab.research.google.com/github/google-research/rlds/blob/main/rlds/examples/rlds_tutorial.ipynb#scrollTo=Z0TITfo_4oZr) + # Truncation: rlds.transformations.truncate_after_condition + # Alignment: rlds.transformations.shift_keys + # Zero Init: rlds.transformations.zeros_from_spec return data demo_buffer = make_replay_buffer( @@ -377,7 +388,7 @@ def preload_data_transform(data, metadata) -> Optional[Dict[str, Any]]: capacity=FLAGS.replay_buffer_capacity, type="memory_efficient_replay_buffer", image_keys=image_keys, - preload_rlds_path=FLAGS.preload_rlds_path, + preload_rlds_path=FLAGS.preload_rlds_path, # Folder that contains the dataset_info.json preload_data_transform=preload_data_transform, ) diff --git a/franka_sim/franka_sim/__init__.py b/franka_sim/franka_sim/__init__.py index 28b97c12..93f6a2f3 100644 --- a/franka_sim/franka_sim/__init__.py +++ b/franka_sim/franka_sim/__init__.py @@ -24,6 +24,11 @@ entry_point="franka_sim.envs:PandaReachCubeGymEnv", max_episode_steps=100, ) +register( + id="PandaReachSparseCube-v0", + entry_point="franka_sim.envs:PandaReachSparseCubeGymEnv", + max_episode_steps=100, +) # register( # id="PandaReachCubeVision-v0", # entry_point="franka_sim.envs:PandaReachCubeGymEnv", diff --git a/serl_launcher/serl_launcher/networks/reward_classifier.py 
b/serl_launcher/serl_launcher/networks/reward_classifier.py index b05e196c..6b60f17c 100644 --- a/serl_launcher/serl_launcher/networks/reward_classifier.py +++ b/serl_launcher/serl_launcher/networks/reward_classifier.py @@ -67,7 +67,7 @@ def create_classifier( with open(pretrained_encoder_path, "rb") as f: encoder_params = pkl.load(f) - param_count = sum(x.size for x in jax.tree_leaves(encoder_params)) + param_count = sum(x.size for x in jax.tree.leaves(encoder_params)) print( f"Loaded {param_count/1e6}M parameters from ResNet-10 pretrained on ImageNet-1K" ) diff --git a/serl_launcher/serl_launcher/utils/launcher.py b/serl_launcher/serl_launcher/utils/launcher.py index fbed0c0a..fb0b4440 100644 --- a/serl_launcher/serl_launcher/utils/launcher.py +++ b/serl_launcher/serl_launcher/utils/launcher.py @@ -286,6 +286,13 @@ def make_replay_buffer( else: raise ValueError(f"Unsupported replay_buffer_type: {type}") + # Load RLDS or oxe_envlogger recroded data with tfds.builder_from_directory. + # Choose number of episodes by passing split="train[:N%]" or split="test[:N%]" + # or ds = tfds.builder_from_directory(builder_dir).as_dataset(split="train").take(5) + # See more details: https://www.tensorflow.org/datasets/splits + # + # It's also possible to filter specirfic episodes, i.e. 
by time: + # ds = ds.filter(lambda ep: tf.strings.regex_full_match(ep['some/session_id'], "20250821_222412")) if preload_rlds_path: print(f" - Preloaded {preload_rlds_path} to replay buffer") dataset = tfds.builder_from_directory(preload_rlds_path).as_dataset(split="all") diff --git a/serl_launcher/serl_launcher/utils/train_utils.py b/serl_launcher/serl_launcher/utils/train_utils.py index 31037317..dd8816a7 100644 --- a/serl_launcher/serl_launcher/utils/train_utils.py +++ b/serl_launcher/serl_launcher/utils/train_utils.py @@ -108,7 +108,7 @@ def load_resnet10_params(agent, image_keys=("image",), public=True): with open(file_path, "rb") as f: encoder_params = pkl.load(f) - param_count = sum(x.size for x in jax.tree_leaves(encoder_params)) + param_count = sum(x.size for x in jax.tree.leaves(encoder_params)) print( f"Loaded {param_count/1e6}M parameters from ResNet-10 pretrained on ImageNet-1K" ) From 9b6e8a60b5023073179151c27b87008ab5d5bdb5 Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Tue, 9 Sep 2025 11:07:03 -0500 Subject: [PATCH 129/141] - Number of max steps adjustement. 
--- demos/demos/franka_reach_drq_demo_script.py | 8 ++++---- franka_sim/franka_sim/__init__.py | 2 +- franka_sim/franka_sim/envs/panda_reach_sparse_gym_env.py | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/demos/demos/franka_reach_drq_demo_script.py b/demos/demos/franka_reach_drq_demo_script.py index 0352a3f8..b052f5b7 100755 --- a/demos/demos/franka_reach_drq_demo_script.py +++ b/demos/demos/franka_reach_drq_demo_script.py @@ -31,12 +31,12 @@ #------------------------------------------------------------------------------------------- flags.DEFINE_string("env", "PandaReachSparseCube-v0", "Name of environment.") flags.DEFINE_string("exp_name", None, "Name of the experiment for wandb logging.") -flags.DEFINE_integer("max_traj_length", 100, "Maximum length of trajectory.") +flags.DEFINE_integer("max_traj_length", 120, "Maximum length of trajectory.") flags.DEFINE_boolean("debug", True, "Debug mode.") # debug mode will disable wandb logging #flags.DEFINE_string("preload_rlds_path", None, "Path to preload RLDS data.") -flags.DEFINE_string("output_dir", "~/serl/demos/franka_reach_drq_demo_script", +flags.DEFINE_string("output_dir", "/data/data/serl/demos/franka_reach_drq_demo_script", "Directory to save the output data. 
This is where the RLDS logs will be saved.") -flags.DEFINE_integer("num_demos", 1, "Number of episodes to log.") +flags.DEFINE_integer("num_demos", 2, "Number of episodes to log.") flags.DEFINE_boolean("enable_envlogger", True, "Enable envlogger.") flags.DEFINE_string("teleop_mode", "keyboard", "Teleoperation mode: 'keyboard' or 'spacemouse'.") @@ -365,7 +365,7 @@ def main(unused_argv): if FLAGS.env == "PandaPickCube-v0": env = gym.wrappers.FlattenObservation(env) - if FLAGS.env == "PandaPickCubeVision-v0": + if FLAGS.env == "PandaReachSparseCube-v0": env = SERLObsWrapper( env, target_hw=(960, 960), diff --git a/franka_sim/franka_sim/__init__.py b/franka_sim/franka_sim/__init__.py index 3a3cb9dc..38cb7ba4 100644 --- a/franka_sim/franka_sim/__init__.py +++ b/franka_sim/franka_sim/__init__.py @@ -27,7 +27,7 @@ register( id="PandaReachSparseCube-v0", entry_point="franka_sim.envs:PandaReachSparseCubeGymEnv", - max_episode_steps=100, + max_episode_steps=120, ) # register( # id="PandaReachCubeVision-v0", diff --git a/franka_sim/franka_sim/envs/panda_reach_sparse_gym_env.py b/franka_sim/franka_sim/envs/panda_reach_sparse_gym_env.py index 832b89d5..e86335d8 100644 --- a/franka_sim/franka_sim/envs/panda_reach_sparse_gym_env.py +++ b/franka_sim/franka_sim/envs/panda_reach_sparse_gym_env.py @@ -36,7 +36,7 @@ def __init__( time_limit: float = 10.0, render_spec: GymRenderingSpec = GymRenderingSpec(), render_mode: Literal["rgb_array", "human"] = "rgb_array", - image_obs: bool = False, + image_obs: bool = True, demo: str = "None", ): self._action_scale = action_scale From 78b84ccdd094775f42109969a9fb959a2f441fe5 Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Fri, 12 Sep 2025 12:59:15 -0500 Subject: [PATCH 130/141] Corrected sparse environment definition. - Reward is 1 for success cases based on the error norm proximity. - Slight change to step function. - Increased max-steps to 200 for this environment. 
--- franka_sim/franka_sim/__init__.py | 2 +- .../envs/panda_reach_sparse_gym_env.py | 16 ++++++++++++---- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/franka_sim/franka_sim/__init__.py b/franka_sim/franka_sim/__init__.py index 38cb7ba4..1a56890e 100644 --- a/franka_sim/franka_sim/__init__.py +++ b/franka_sim/franka_sim/__init__.py @@ -27,7 +27,7 @@ register( id="PandaReachSparseCube-v0", entry_point="franka_sim.envs:PandaReachSparseCubeGymEnv", - max_episode_steps=120, + max_episode_steps=200, ) # register( # id="PandaReachCubeVision-v0", diff --git a/franka_sim/franka_sim/envs/panda_reach_sparse_gym_env.py b/franka_sim/franka_sim/envs/panda_reach_sparse_gym_env.py index e86335d8..016afd8e 100644 --- a/franka_sim/franka_sim/envs/panda_reach_sparse_gym_env.py +++ b/franka_sim/franka_sim/envs/panda_reach_sparse_gym_env.py @@ -214,7 +214,10 @@ def step( self._data.ctrl[self._panda_ctrl_ids] = tau mujoco.mj_step(self._model, self._data) + # Compute observation. obs = self._compute_observation() + + # For sparse reward we return 1 if the task is achieved, else 0. rew = self._compute_reward() # # For demo reach environment, finger --- THIS SHOULD NEVER BE MERGED TO OTHER BRANCHES AFFECTING REGULAR USE OF ENVIRONMENTS. ONLY FOR DEMOS. @@ -224,7 +227,10 @@ def step( # terminated = True # else: # Check if the time limit is exceeded. - terminated = (rew >= 0.85) + + # Episode is terminated if the reward is 1.0 (i.e. the task is achieved). + terminated = (rew == 1.0) + if self._time_limit is not None: terminated = terminated or self.time_limit_exceeded() @@ -276,10 +282,12 @@ def _compute_reward(self) -> float: dist = np.linalg.norm(block_pos - tcp_pos) # Distance-based reward. 
Note at norm of 0.015, reward will be 0.5 - r_close = np.exp(-20 * dist) - r_close = np.clip(r_close, 0.0, 1.0) + if dist < 0.015: + reward = 1.0 + else: + reward = 0.0 - return (r_close > 0.85) + return reward if __name__ == "__main__": From b53280921405aee19b1becb94a9e0fdfcd3dd88d Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Sun, 14 Sep 2025 22:10:02 -0500 Subject: [PATCH 131/141] =?UTF-8?q?This=20file=20maintains=20latest=20chan?= =?UTF-8?q?ges=20where=202=20or=20more=20demos=20are=20recorded=20using=20?= =?UTF-8?q?a=20format=20similar=20to=20the=20following:=20=E2=94=94?= =?UTF-8?q?=E2=94=80=E2=94=80=20franka=5Freach=5Fdrq=5Fdemo=5Fscript=20=20?= =?UTF-8?q?=20=20=20=E2=94=9C=E2=94=80=E2=94=80=20session=5F20250914=5F210?= =?UTF-8?q?112=5Fnum=5Fdemos=5F2=20=20=20=20=20=E2=94=82=C2=A0=C2=A0=20?= =?UTF-8?q?=E2=94=94=E2=94=80=E2=94=80=20PandaReachSparseCube-v0=20=20=20?= =?UTF-8?q?=20=20=E2=94=82=C2=A0=C2=A0=20=20=20=20=20=E2=94=94=E2=94=80?= =?UTF-8?q?=E2=94=80=200.1.0=20=20=20=20=20=E2=94=82=C2=A0=C2=A0=20=20=20?= =?UTF-8?q?=20=20=20=20=20=20=E2=94=9C=E2=94=80=E2=94=80=20dataset=5Finfo.?= =?UTF-8?q?json=20=20=20=20=20=E2=94=82=C2=A0=C2=A0=20=20=20=20=20=20=20?= =?UTF-8?q?=20=20=E2=94=9C=E2=94=80=E2=94=80=20features.json=20=20=20=20?= =?UTF-8?q?=20=E2=94=82=C2=A0=C2=A0=20=20=20=20=20=20=20=20=20=E2=94=94?= =?UTF-8?q?=E2=94=80=E2=94=80=20PandaReachSparseCube-v0-train.tfrecord-000?= =?UTF-8?q?00?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Succeeded in recording 2,5,10 demos for Panda Reach Sparse with Cube. 
--- demos/demos/franka_reach_drq_demo_script.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) mode change 100755 => 100644 demos/demos/franka_reach_drq_demo_script.py diff --git a/demos/demos/franka_reach_drq_demo_script.py b/demos/demos/franka_reach_drq_demo_script.py old mode 100755 new mode 100644 index b052f5b7..1856fbdf --- a/demos/demos/franka_reach_drq_demo_script.py +++ b/demos/demos/franka_reach_drq_demo_script.py @@ -31,12 +31,12 @@ #------------------------------------------------------------------------------------------- flags.DEFINE_string("env", "PandaReachSparseCube-v0", "Name of environment.") flags.DEFINE_string("exp_name", None, "Name of the experiment for wandb logging.") -flags.DEFINE_integer("max_traj_length", 120, "Maximum length of trajectory.") +flags.DEFINE_integer("max_traj_length", 200, "Maximum length of trajectory.") flags.DEFINE_boolean("debug", True, "Debug mode.") # debug mode will disable wandb logging #flags.DEFINE_string("preload_rlds_path", None, "Path to preload RLDS data.") flags.DEFINE_string("output_dir", "/data/data/serl/demos/franka_reach_drq_demo_script", "Directory to save the output data. This is where the RLDS logs will be saved.") -flags.DEFINE_integer("num_demos", 2, "Number of episodes to log.") +flags.DEFINE_integer("num_demos", 10, "Number of episodes to log.") flags.DEFINE_boolean("enable_envlogger", True, "Enable envlogger.") flags.DEFINE_string("teleop_mode", "keyboard", "Teleoperation mode: 'keyboard' or 'spacemouse'.") @@ -164,7 +164,7 @@ def finalize_tfds_metadata_beamless(builder_dir: str): with open(info_path) as f: info = json.load(f) - ds_name = info["name"] # e.g. "PandaPickCubeVision-v0" + ds_name = info["name"] # e.g. 
"PandaReachSparseCube-v0" file_fmt = info.get("fileFormat", "tfrecord") tmpl_str = info["splits"][0]["filepathTemplate"] @@ -201,11 +201,11 @@ def ensure_dir_exists(): For oxe_envlogger + RLDS compatibility, data must be written in the following format: /data/data/serl/demos/franka_reach_drq_demo_script/ └── session_20250821_222412/ - └── PandaPickCubeVision-v0/ + └── PandaReachSparseCube-v0/ └── 0.1.0/ dataset_info.json features.json - PandaPickCubeVision-v0-train.tfrecord-00000-of-00001 + PandaReachSparseCube-v0-train.tfrecord-00000-of-00001 We can have a base path with customized sessions inside. Inside each session we have: env-version-files @@ -262,7 +262,7 @@ def getKey(settings): termios.tcsetattr(sys.stdin, termios.TCSADRAIN, settings) return key -def get_kb_demo_action(env,speed=0.1): +def get_kb_demo_action(env,speed=0.05): """ Reads keyboard input and maps it to a 3D action vector for robot control or camera action. TODO: currently can only read one key at a time. Needs to be extended to read multiple keys to handle both. 
@@ -368,7 +368,7 @@ def main(unused_argv): if FLAGS.env == "PandaReachSparseCube-v0": env = SERLObsWrapper( env, - target_hw=(960, 960), + target_hw=(128, 128), img_dtype=np.uint8, # or np.float32 normalize=False, # True if using float32 in [0,1] ) From 285fa977c60017d95f67a96ceac86bcbb8aef927 Mon Sep 17 00:00:00 2001 From: ryanvanderstelt Date: Tue, 16 Sep 2025 10:34:13 -0500 Subject: [PATCH 132/141] small changes --- examples/async_drq_sim/async_drq_sim.py | 5 -- examples/async_drq_sim/automated_tests.sh | 96 ++++------------------- 2 files changed, 15 insertions(+), 86 deletions(-) diff --git a/examples/async_drq_sim/async_drq_sim.py b/examples/async_drq_sim/async_drq_sim.py index b894e4ae..b1968b18 100644 --- a/examples/async_drq_sim/async_drq_sim.py +++ b/examples/async_drq_sim/async_drq_sim.py @@ -92,11 +92,6 @@ flags.DEFINE_string("log_rlds_path", None, "Path to save RLDS logs.") flags.DEFINE_string("preload_rlds_path", None, "Path to preload RLDS data.") -# Load demonstation data -flags.DEFINE_boolean("load_demos", False, "Whether to load demo dataset.") -flags.DEFINE_string("demo_dir", "/data/data/serl/demos", "Path to demo dataset.") -flags.DEFINE_string("file_name", "data_franka_reach_random_20.npz", "Name of the demo file to load.") - devices = jax.local_devices() num_devices = len(devices) sharding = jax.sharding.PositionalSharding(devices) diff --git a/examples/async_drq_sim/automated_tests.sh b/examples/async_drq_sim/automated_tests.sh index 42a53225..a1a32da2 100644 --- a/examples/async_drq_sim/automated_tests.sh +++ b/examples/async_drq_sim/automated_tests.sh @@ -4,15 +4,15 @@ SEEDS=$1 # WANDB_OUTPUT_DIR=~/wandb_logs TEST="async_sac_state_sim.py" CONDA_ENV="serl" -ENV="PandaPickCubeVision-v0" -MAX_STEPS=1000000 -TRAINING_STARTS=1000 -RANDOM_STEPS=1000 +ENV="PandaReachSparseCube-v0" +MAX_STEPS=100000 +TRAINING_STARTS=0 +RANDOM_STEPS=0 CRITIC_ACTOR_RATIO=8 -EXP_NAME="GENERAL-RETESTING-$ENV" +EXP_NAME="FIRST-TESTS-$ENV" 
REPLAY_BUFFER_TYPE="fractal_symmetry_replay_buffer" - -BASE_ARGS="--env $ENV --exp_name $EXP_NAME --training_starts $TRAINING_STARTS --random_steps $RANDOM_STEPS" +PRELOAD_RLDS="/data/data/serl/demos/franka_reach_drq_demo_script/10_demos_session_202500914_213515" +BASE_ARGS="--env $ENV --exp_name $EXP_NAME --training_starts $TRAINING_STARTS --random_steps $RANDOM_STEPS --preload_rlds_path $PRELOAD_RLDS" ARGS="" function run_test { @@ -42,89 +42,23 @@ function run_test { } # BASELINE TESTING -for CRITIC_ACTOR_RATIO in 8 +for replay_buffer_capacity in 1000000 do - for batch_size in 256 - do - for replay_buffer_capacity in 1000000 - do - ARGS="--run_name baseline --critic_actor_ratio $CRITIC_ACTOR_RATIO --replay_buffer_type replay_buffer --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity" - run_test - done - done + ARGS="--run_name baseline --critic_actor_ratio $CRITIC_ACTOR_RATIO --replay_buffer_type replay_buffer --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity" + run_test done # CONSTANT TESTING -for CRITIC_ACTOR_RATIO in 8 +for starting_branch_count in 1 27 do - for starting_branch_count in 27 + for workspace_width in 0.5 do - for batch_size in 256 + for replay_buffer_capacity in 1000000 do - for workspace_width in 0.5 - do - for replay_buffer_capacity in 1000000 - do - ARGS="--run_name constant-$starting_branch_count^1 --steps_per_update $steps_per_update --critic_actor_ratio $CRITIC_ACTOR_RATIO --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --workspace_width $workspace_width --branch_method 'constant' --starting_branch_count $starting_branch_count" - run_test - done - done + ARGS="--run_name constant-$starting_branch_count^1 --steps_per_update $steps_per_update --critic_actor_ratio $CRITIC_ACTOR_RATIO --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --workspace_width $workspace_width 
--branch_method 'constant' --starting_branch_count $starting_branch_count" + run_test done done done -# # FRACTAL TESTING -# for batch_size in 256 -# do -# for replay_buffer_capacity in 1000000 -# do -# for workspace_width in 0.5 -# do -# for alpha in 0.9 -# do -# for branching_factor in 3 9 -# do -# for max_depth in 2 4 -# do -# # Fractal Expansion -# ARGS="--run_name fractal_expansion-$branching_factor^$max_depth-alpha-$alpha-workspace_width-$workspace_width-batch-size-$batch_size-capacity-$replay_buffer_capacity --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --workspace_width $workspace_width --branch_method 'fractal' --alpha $alpha --branching_factor $branching_factor --max_depth $max_depth" -# run_test - -# # Fractal Contraction -# ARGS="--run_name fractal_contraction-$branching_factor^$max_depth-alpha-$alpha-workspace_width-$workspace_width-batch-size-$batch_size-capacity-$replay_buffer_capacity --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --workspace_width $workspace_width --branch_method 'contraction' --alpha $alpha --branching_factor $branching_factor --max_depth $max_depth" -# run_test -# done -# done -# done -# done -# done -# done - -# # DISASSOCIATIVE TESTING -# for batch_size in 256 -# do -# for replay_buffer_capacity in 1000000 -# do -# for workspace_width in 0.5 -# do -# for alpha in 0.9 -# do -# for min_branch_count in 1 3 9 -# do -# for max_branch_count in 3 9 27 -# do -# # Disassociative (Hourglass) -# ARGS="--run_name disassociative-hourglass-$min_branch_count:$max_branch_count-alpha-$alpha-workspace_width-$workspace_width-batch-size-$batch_size-capacity-$replay_buffer_capacity --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --workspace_width $workspace_width --branch_method 'disassociated' --min_branch_count $min_branch_count 
--max_branch_count $max_branch_count --disassociated_type 'hourglass' --alpha $alpha" -# run_test - -# # Disassociative (Octahedron) -# ARGS="--run_name disassociative-hourglass-$min_branch_count:$max_branch_count-alpha-$alpha-workspace_width-$workspace_width-batch-size-$batch_size-capacity-$replay_buffer_capacity --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --workspace_width $workspace_width --branch_method 'disassociated' --min_branch_count $min_branch_count --max_branch_count $max_branch_count --disassociated_type 'octahedron' --alpha $alpha" - -# done -# done -# done -# done -# done -# done - tmux kill-window -t serl_session:$SEED From b1d09524dc54473b993e99df1c1281ed7e8f0576 Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Tue, 16 Sep 2025 10:35:11 -0500 Subject: [PATCH 133/141] update naming scheme for rlds demos to :#_num_demos_session_yyyymmdd_hhmmss --- demos/demos/franka_reach_drq_demo_script.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/demos/demos/franka_reach_drq_demo_script.py b/demos/demos/franka_reach_drq_demo_script.py index 1856fbdf..40f1df04 100644 --- a/demos/demos/franka_reach_drq_demo_script.py +++ b/demos/demos/franka_reach_drq_demo_script.py @@ -36,7 +36,7 @@ #flags.DEFINE_string("preload_rlds_path", None, "Path to preload RLDS data.") flags.DEFINE_string("output_dir", "/data/data/serl/demos/franka_reach_drq_demo_script", "Directory to save the output data. 
This is where the RLDS logs will be saved.") -flags.DEFINE_integer("num_demos", 10, "Number of episodes to log.") +flags.DEFINE_integer("num_demos", 2, "Number of episodes to log.") flags.DEFINE_boolean("enable_envlogger", True, "Enable envlogger.") flags.DEFINE_string("teleop_mode", "keyboard", "Teleoperation mode: 'keyboard' or 'spacemouse'.") @@ -219,7 +219,7 @@ def ensure_dir_exists(): # Customize the path root = FLAGS.output_dir session = datetime.now().strftime("session_%Y%m%d_%H%M%S") - session_root = os.path.join(root, f"{session}_num_demos_{FLAGS.num_demos}") + session_root = os.path.join(root, f"{FLAGS.num_demos}_demos_{session}") # Dataset details dataset_name = FLAGS.env @@ -262,7 +262,7 @@ def getKey(settings): termios.tcsetattr(sys.stdin, termios.TCSADRAIN, settings) return key -def get_kb_demo_action(env,speed=0.05): +def get_kb_demo_action(env,speed=0.075): """ Reads keyboard input and maps it to a 3D action vector for robot control or camera action. TODO: currently can only read one key at a time. Needs to be extended to read multiple keys to handle both. 
From 01b5ab6363c43a605bb31ce1ea2b6272cbfa845e Mon Sep 17 00:00:00 2001 From: ryanvanderstelt Date: Tue, 16 Sep 2025 10:36:49 -0500 Subject: [PATCH 134/141] fix for wrappers drq_sim --- examples/async_drq_sim/async_drq_sim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/async_drq_sim/async_drq_sim.py b/examples/async_drq_sim/async_drq_sim.py index b1968b18..37af0fb5 100644 --- a/examples/async_drq_sim/async_drq_sim.py +++ b/examples/async_drq_sim/async_drq_sim.py @@ -350,7 +350,7 @@ def main(_): else: raise NotImplementedError(f"Unknown observation layout for {FLAGS.env}") - if "Vision" in FLAGS.env: + if FLAGS.env == "PandaReachSparseCube-v0": env = SERLObsWrapper(env) env = ChunkingWrapper(env, obs_horizon=1, act_exec_horizon=None) else: From daded7d2cd943b34f6d5ef37be40519a99ac985f Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Tue, 16 Sep 2025 10:40:01 -0500 Subject: [PATCH 135/141] Updates to launch.json regarding rlds reach-sparse --- examples/async_drq_sim/.vscode/launch.json | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/examples/async_drq_sim/.vscode/launch.json b/examples/async_drq_sim/.vscode/launch.json index af352dfa..4e3fbf20 100644 --- a/examples/async_drq_sim/.vscode/launch.json +++ b/examples/async_drq_sim/.vscode/launch.json @@ -12,17 +12,17 @@ // Command-line arguments matching run_learner.sh "args": [ "--render", - "--env", "PandaReachSparseCubeGymEnv-v0", // Environment to use + "--env", "PandaReachSparseCube-v0", // Environment to use "--agent", "drq", // Agent type (drq or sac) - "--exp_name", "PandaPickCubeVision-v0_sparse_001", // Experiment name for wandb logging + "--exp_name", "PandaReachCubeVision-v0_001", // Experiment name for wandb logging "--max_traj_length", "100", // Max episode length/ Max episode length "--seed", "42", // Random seed for reproducibility - // "--save_model", // Save model checkpoints + // "--save_model", // Save model checkpoints "--batch_size", 
"256", // Training batch size - "--critic_actor_ratio", "8", // Critic-to-actor update ratio - "--max_steps", "50_000 ", // Maximum training steps + "--critic_actor_ratio", "4", // Critic-to-actor update ratio + "--max_steps", "100_000 ", // Maximum training steps "--replay_buffer_capacity", "200_000", // Replay buffer capacity - "--random_steps", "300", // Number of random steps at beginning + "--random_steps", "0", // Number of random steps at beginning "--training_starts", "300", // Start training after buffer has this many samples "--steps_per_update", "30", // Number of env steps per update "--learner", // REQUIRED: Indicates this is a learner instance @@ -39,7 +39,7 @@ // "--load_demos", // Load demo dataset // "--demo_dir", "/data/data/serl/demos", // "--file_name", "data_franka_reach_random_5_2.npz", // Name of the demo file to load - "--preload_rlds_path", "/data/data/serl/demos/franka_reach_drq_demo_script/session_20250902_105400_num_demos_2/PandaPickCubeVision-v0/0.1.0", // Preload RLDS dataset for faster loading + "--preload_rlds_path", "/media/bison/hdd/data/serl/demos/franka_reach_drq_demo_script/2_demos_session_20250916_100650/PandaReachSparseCube-v0/0.1.0", // Preload RLDS dataset for faster loading // Dissasociated Fractals // "--branch_method", "disassociated", // branch method type @@ -73,7 +73,7 @@ // Command-line arguments matching run_actor.sh "args": [ "--render", - "--env", "PandaReachSparseCubeGymEnv-v0", // Environment to use + "--env", "PandaReachSparseCube-v0", // Environment to use "--agent", "drq", // Agent type (drq or sac) "--exp_name", "PandaPickCubeVision-v0_sparse_001", // Experiment name for wandb logging "--max_traj_length", "100", // Max episode length/ Max episode length @@ -100,7 +100,7 @@ // "--load_demos", // Load demo dataset // "--demo_dir", "/data/data/serl/demos", // "--file_name", "data_franka_reach_random_5_2.npz", // Name of the demo file to load - "--preload_rlds_path", 
"/data/data/serl/demos/franka_reach_drq_demo_script/session_20250902_105400_num_demos_2/PandaPickCubeVision-v0/0.1.0", // Preload RLDS dataset for faster loading + "--preload_rlds_path", "/media/bison/hdd/data/serl/demos/franka_reach_drq_demo_script/2_demos_session_20250916_100650/PandaReachSparseCube-v0/0.1.0", // Preload RLDS dataset for faster loading // Dissasociated Fractals // "--branch_method", "disassociated", // branch method type From d73649562930943e5d354dd02b936c07fa72d15c Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Thu, 18 Sep 2025 09:45:48 -0500 Subject: [PATCH 136/141] Need oxe_envlogger as a git package. --- demos/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/demos/requirements.txt b/demos/requirements.txt index b20b6e1d..febfe029 100644 --- a/demos/requirements.txt +++ b/demos/requirements.txt @@ -1,4 +1,4 @@ tensorflow-metadata==1.17.2 apache-beam==2.67.0 protobuf>=4.21.6,<4.22 - +git+https://github.com/rail-berkeley/oxe_envlogger.git@main?? 
From 282b53cc8a150a2b27d3b4f840a5afab16415c2c Mon Sep 17 00:00:00 2001 From: ryanvanderstelt Date: Fri, 19 Sep 2025 11:24:00 -0500 Subject: [PATCH 137/141] added 4 dimensions to reach-sparse and begun work conserving memory --- examples/async_drq_sim/automated_tests.sh | 21 +++++++++---------- .../async_drq_sim/automated_tests_helper.sh | 1 + examples/async_drq_sim/run_actor.sh | 12 ++++------- examples/async_drq_sim/run_learner.sh | 14 ++++--------- .../envs/panda_reach_sparse_gym_env.py | 8 +++---- 5 files changed, 23 insertions(+), 33 deletions(-) diff --git a/examples/async_drq_sim/automated_tests.sh b/examples/async_drq_sim/automated_tests.sh index a1a32da2..261a1201 100644 --- a/examples/async_drq_sim/automated_tests.sh +++ b/examples/async_drq_sim/automated_tests.sh @@ -5,14 +5,13 @@ SEEDS=$1 TEST="async_sac_state_sim.py" CONDA_ENV="serl" ENV="PandaReachSparseCube-v0" -MAX_STEPS=100000 +MAX_STEPS=1000 TRAINING_STARTS=0 RANDOM_STEPS=0 -CRITIC_ACTOR_RATIO=8 EXP_NAME="FIRST-TESTS-$ENV" REPLAY_BUFFER_TYPE="fractal_symmetry_replay_buffer" -PRELOAD_RLDS="/data/data/serl/demos/franka_reach_drq_demo_script/10_demos_session_202500914_213515" -BASE_ARGS="--env $ENV --exp_name $EXP_NAME --training_starts $TRAINING_STARTS --random_steps $RANDOM_STEPS --preload_rlds_path $PRELOAD_RLDS" +PRELOAD_RLDS="/data/data/serl/demos/franka_reach_drq_demo_script/10_demos_session_202500914_213515/PandaReachSparseCube-v0/0.1.0" +BASE_ARGS="--env $ENV --exp_name $EXP_NAME --training_starts $TRAINING_STARTS --random_steps $RANDOM_STEPS --preload_rlds_path $PRELOAD_RLDS --encoder_type resnet-pretrained" ARGS="" function run_test { @@ -29,22 +28,22 @@ function run_test { echo "Running constant with args: $ARGS" tmux respawn-pane -k -t serl_session:0.1 tmux respawn-pane -k -t serl_session:0.2 - tmux send-keys -t serl_session:0.1 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --actor --max_steps 2000000000 --seed $seed $BASE_ARGS $ARGS" C-m - tmux send-keys -t serl_session:0.2 
"conda activate $CONDA_ENV && bash automated_tests_helper.sh --learner --max_steps $MAX_STEPS --seed $seed $BASE_ARGS $ARGS" C-m "exit" C-m + tmux send-keys -t serl_session:0.1 "conda activate $CONDA_ENV && bash run_actor.sh --max_steps 2000000000 --seed $seed $BASE_ARGS $ARGS" C-m + tmux send-keys -t serl_session:0.2 "conda activate $CONDA_ENV && bash run_learner.sh --max_steps $MAX_STEPS --seed $seed $BASE_ARGS $ARGS" C-m "exit" C-m # Wait for learner to finish while ! tmux capture-pane -t serl_session:0.2 -p | grep "logout" > /dev/null; do - sleep 1 + sleep 100 done echo "Finished!" done } # BASELINE TESTING -for replay_buffer_capacity in 1000000 +for replay_buffer_capacity in 200000 do - ARGS="--run_name baseline --critic_actor_ratio $CRITIC_ACTOR_RATIO --replay_buffer_type replay_buffer --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity" + ARGS="--run_name baseline --replay_buffer_type memory_efficient_replay_buffer --replay_buffer_capacity $replay_buffer_capacity" run_test done @@ -53,9 +52,9 @@ for starting_branch_count in 1 27 do for workspace_width in 0.5 do - for replay_buffer_capacity in 1000000 + for replay_buffer_capacity in 200000 do - ARGS="--run_name constant-$starting_branch_count^1 --steps_per_update $steps_per_update --critic_actor_ratio $CRITIC_ACTOR_RATIO --replay_buffer_type $REPLAY_BUFFER_TYPE --batch_size $batch_size --replay_buffer_capacity $replay_buffer_capacity --workspace_width $workspace_width --branch_method 'constant' --starting_branch_count $starting_branch_count" + ARGS="--run_name constant-$starting_branch_count^1 --replay_buffer_type $REPLAY_BUFFER_TYPE --replay_buffer_capacity $replay_buffer_capacity --workspace_width $workspace_width --branch_method 'constant' --starting_branch_count $starting_branch_count" run_test done done diff --git a/examples/async_drq_sim/automated_tests_helper.sh b/examples/async_drq_sim/automated_tests_helper.sh index 4f6df155..3c02ecb6 100644 --- 
a/examples/async_drq_sim/automated_tests_helper.sh +++ b/examples/async_drq_sim/automated_tests_helper.sh @@ -2,5 +2,6 @@ export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ export XLA_PYTHON_CLIENT_MEM_FRACTION=.5 && \ +export MUJOCO_GL=egl python async_drq_sim.py "$@" \ No newline at end of file diff --git a/examples/async_drq_sim/run_actor.sh b/examples/async_drq_sim/run_actor.sh index ebdb0514..86f8fff3 100644 --- a/examples/async_drq_sim/run_actor.sh +++ b/examples/async_drq_sim/run_actor.sh @@ -1,9 +1,5 @@ export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ -export XLA_PYTHON_CLIENT_MEM_FRACTION=.5 && \ -python async_drq_sim.py "$@" \ - --actor \ - --exp_name=serl_dev_drq_sim_test_resnet \ - --seed 0 \ - --random_steps 1000 \ - --encoder_type resnet-pretrained \ - --debug +export XLA_PYTHON_CLIENT_MEM_FRACTION=.2 && \ +export MUJOCO_GL=egl && \ + +python async_drq_sim.py --actor "$@" diff --git a/examples/async_drq_sim/run_learner.sh b/examples/async_drq_sim/run_learner.sh index 49589428..d96ec089 100644 --- a/examples/async_drq_sim/run_learner.sh +++ b/examples/async_drq_sim/run_learner.sh @@ -1,11 +1,5 @@ export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ -export XLA_PYTHON_CLIENT_MEM_FRACTION=.5 && \ -python async_drq_sim.py "$@" \ - --learner \ - --exp_name=serl_dev_drq_sim_test_resnet \ - --seed 0 \ - --training_starts 1000 \ - --critic_actor_ratio 4 \ - --encoder_type resnet-pretrained \ -# --demo_path franka_lift_cube_image_20_trajs.pkl \ -# --debug # wandb is disabled when debug +export XLA_PYTHON_CLIENT_MEM_FRACTION=.4 && \ +export MUJOCO_GL=egl && \ + +python async_drq_sim.py --learner "$@" diff --git a/franka_sim/franka_sim/envs/panda_reach_sparse_gym_env.py b/franka_sim/franka_sim/envs/panda_reach_sparse_gym_env.py index 016afd8e..416f4e24 100644 --- a/franka_sim/franka_sim/envs/panda_reach_sparse_gym_env.py +++ b/franka_sim/franka_sim/envs/panda_reach_sparse_gym_env.py @@ -130,8 +130,8 @@ def __init__( ) self.action_space = gym.spaces.Box( - 
low=np.asarray([-1.0, -1.0, -1.0]), - high=np.asarray([1.0, 1.0, 1.0]), + low=np.asarray([-1.0, -1.0, -1.0, -1.0]), + high=np.asarray([1.0, 1.0, 1.0, 1.0]), dtype=np.float32, ) @@ -142,8 +142,8 @@ def __init__( self._viewer = MujocoRenderer( self.model, self.data, - width=960, - height=960, + width=128, + height=128, camera_id=0 ) self._viewer.render(self.render_mode) From 8687adaf9a3606fcb68b4626127e215579b27170 Mon Sep 17 00:00:00 2001 From: ryanvanderstelt Date: Tue, 23 Sep 2025 12:27:16 -0500 Subject: [PATCH 138/141] better memory management for testing (also batch size down to 128) --- examples/async_drq_sim/automated_tests.sh | 13 +++++++------ examples/async_drq_sim/run_actor.sh | 1 + examples/async_drq_sim/run_learner.sh | 1 + 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/examples/async_drq_sim/automated_tests.sh b/examples/async_drq_sim/automated_tests.sh index 261a1201..07a58f48 100644 --- a/examples/async_drq_sim/automated_tests.sh +++ b/examples/async_drq_sim/automated_tests.sh @@ -5,13 +5,14 @@ SEEDS=$1 TEST="async_sac_state_sim.py" CONDA_ENV="serl" ENV="PandaReachSparseCube-v0" -MAX_STEPS=1000 -TRAINING_STARTS=0 -RANDOM_STEPS=0 +MAX_STEPS=1000000 +TRAINING_STARTS=1000 +RANDOM_STEPS=1000 +BATCH_SIZE=128 EXP_NAME="FIRST-TESTS-$ENV" REPLAY_BUFFER_TYPE="fractal_symmetry_replay_buffer" PRELOAD_RLDS="/data/data/serl/demos/franka_reach_drq_demo_script/10_demos_session_202500914_213515/PandaReachSparseCube-v0/0.1.0" -BASE_ARGS="--env $ENV --exp_name $EXP_NAME --training_starts $TRAINING_STARTS --random_steps $RANDOM_STEPS --preload_rlds_path $PRELOAD_RLDS --encoder_type resnet-pretrained" +BASE_ARGS="--env $ENV --exp_name $EXP_NAME --training_starts $TRAINING_STARTS --random_steps $RANDOM_STEPS --batch_size $BATCH_SIZE --preload_rlds_path $PRELOAD_RLDS --encoder_type resnet-pretrained" ARGS="" function run_test { @@ -41,7 +42,7 @@ function run_test { } # BASELINE TESTING -for replay_buffer_capacity in 200000 +for replay_buffer_capacity in 
1000000 do ARGS="--run_name baseline --replay_buffer_type memory_efficient_replay_buffer --replay_buffer_capacity $replay_buffer_capacity" run_test @@ -52,7 +53,7 @@ for starting_branch_count in 1 27 do for workspace_width in 0.5 do - for replay_buffer_capacity in 200000 + for replay_buffer_capacity in 1000000 do ARGS="--run_name constant-$starting_branch_count^1 --replay_buffer_type $REPLAY_BUFFER_TYPE --replay_buffer_capacity $replay_buffer_capacity --workspace_width $workspace_width --branch_method 'constant' --starting_branch_count $starting_branch_count" run_test diff --git a/examples/async_drq_sim/run_actor.sh b/examples/async_drq_sim/run_actor.sh index 86f8fff3..ece2d80e 100644 --- a/examples/async_drq_sim/run_actor.sh +++ b/examples/async_drq_sim/run_actor.sh @@ -1,5 +1,6 @@ export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ export XLA_PYTHON_CLIENT_MEM_FRACTION=.2 && \ export MUJOCO_GL=egl && \ +export TF_GPU_ALLOCATOR=cuda_malloc_async && \ python async_drq_sim.py --actor "$@" diff --git a/examples/async_drq_sim/run_learner.sh b/examples/async_drq_sim/run_learner.sh index d96ec089..48deba7e 100644 --- a/examples/async_drq_sim/run_learner.sh +++ b/examples/async_drq_sim/run_learner.sh @@ -1,5 +1,6 @@ export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ export XLA_PYTHON_CLIENT_MEM_FRACTION=.4 && \ export MUJOCO_GL=egl && \ +export TF_GPU_ALLOCATOR=cuda_malloc_async && \ python async_drq_sim.py --learner "$@" From 53c85daaee284d6686829d89c263da0c03f20e0e Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Tue, 30 Sep 2025 10:50:55 -0500 Subject: [PATCH 139/141] Moving towards a pick-and-place rlds demo recording system with cube vision. 
--- .../franka_pick_n_place_drq_demo_script.py | 511 ++++++++++++++++++ demos/demos/franka_reach_drq_demo_script.py | 3 + franka_sim/franka_sim/__init__.py | 8 +- .../envs/panda_reach_sparse_gym_env.py | 13 +- 4 files changed, 523 insertions(+), 12 deletions(-) create mode 100644 demos/demos/franka_pick_n_place_drq_demo_script.py diff --git a/demos/demos/franka_pick_n_place_drq_demo_script.py b/demos/demos/franka_pick_n_place_drq_demo_script.py new file mode 100644 index 00000000..3e1a8458 --- /dev/null +++ b/demos/demos/franka_pick_n_place_drq_demo_script.py @@ -0,0 +1,511 @@ +#!/usr/bin/env python3 +import os +import time +from datetime import datetime + +import numpy as np + +# Logging +from absl import app, flags, logging +from oxe_envlogger.envlogger import AutoOXEEnvLogger + +# DRL +import gym +import mujoco +from serl_launcher.wrappers.serl_obs_wrappers import SERLObsWrapper +from serl_launcher.wrappers.chunking import ChunkingWrapper + +# Needed to create the franka environment +import franka_sim + +# Teleoperation imports +import sys, select, termios, tty + +# RLDS/TFDS +import json, glob, inspect +import tensorflow as tf +import tensorflow_datasets as tfds +from tensorflow_datasets import folder_dataset +#------------------------------------------------------------------------------------------- +# Flags +#------------------------------------------------------------------------------------------- +flags.DEFINE_string("env", "PandaPickCubeVision-v0", "Name of environment.") +flags.DEFINE_string("exp_name", None, "Name of the experiment for wandb logging.") +flags.DEFINE_integer("max_traj_length", 200, "Maximum length of trajectory.") +flags.DEFINE_boolean("debug", True, "Debug mode.") # debug mode will disable wandb logging +#flags.DEFINE_string("preload_rlds_path", None, "Path to preload RLDS data.") +flags.DEFINE_string("output_dir", "/data/data/serl/demos/franka_reach_drq_demo_script", + "Directory to save the output data. 
This is where the RLDS logs will be saved.") +flags.DEFINE_integer("num_demos", 2, "Number of episodes to log.") +flags.DEFINE_boolean("enable_envlogger", True, "Enable envlogger.") +flags.DEFINE_string("teleop_mode", "keyboard", "Teleoperation mode: 'keyboard' or 'spacemouse'.") + +FLAGS = flags.FLAGS + +#------------------------------------------------------------------------------------------- +## Telop Config Variables +#------------------------------------------------------------------------------------------- +ACTION_MAX = 10 # Maximum action value for clipping actions + +# Bind, xyz, gripper vals to keys +moveBindings = { + 'i':(1,0,0,0), + ',':(-1,0,0,0), + 'j':(0,1,0,0), + 'l':(0,-1,0,0), + 'u':(0,0,0,1), + 'o':(0,0,0,-1), + 'm':(0,0,1,0), + '.':(0,0,-1,0), + 'g':(0,0,0,0.1), # open gripper + 'h':(0,0,0,-0.1), # close gripper + + } + +# Extend bindings to include camera controls +camBindings = { + 'a': ("azimuth", -5), # rotate left + 'd': ("azimuth", 5), # rotate right + 'w': ("elevation", 2), # tilt up + 's': ("elevation", -2), # tilt down + 'q': ("distance", -0.1),# zoom in + 'e': ("distance", 0.1), # zoom out +} + +def activate_weld(env, constraint_name="grasp_weld"): + """ + Activate a weld constraint during pick portion of a demo + :param env: The environment containing the model + :param constraint_name: The name of the weld constraint to activate + :return: True if the weld was successfully activated, False if the constraint was not found + """ + + try: + # Activate the weld constraint + env.unwrapped.model.eq(constraint_name).active = 1 + print("Activated weld") + return True + + except KeyError: + print(f"Warning: Constraint '{constraint_name}' not found") + return False + +def deactivate_weld(env, constraint_name="grasp_weld"): + """ + Deactivate a weld constraint during place portion of a demo + :param env: The environment containing the model + :param constraint_name: The name of the weld constraint to deactivate + :return: True if the weld 
was successfully deactivated, False if the constraint was not + found + """ + + try: + # Deactivate the weld constraint + env.unwrapped.model.eq(constraint_name).active = 0 + print("Deactivated weld") + return True + + except KeyError: + print(f"Warning: Constraint '{constraint_name}' not found") + return False + +def update_camera(viewer,key): + """ + Update the camera view based on keyboard input. Assumes higher level function has checked for the existance of key in camBindings. + Controls: + 'a' : rotate left + 'd' : rotate right + 'w' : tilt up + 's' : tilt down + 'q' : zoom in + 'e' : zoom out + """ + if hasattr(viewer, 'cam'): + # Get current camera parameters + attr, delta = camBindings[key] + val = getattr(viewer.cam, attr) + + setattr(viewer.cam, attr, val + delta) + +def close_logger_and_env(env): + """ + Best-effort shutdown: + 1) close embedded envloggers/writers if present + 2) close dm_env (if you have a DeepMind-style env) + 3) close Gym/Gymnasium env + 4) close the Mujoco viewer + Also walks through common wrapper attributes (env, _env, unwrapped). 
+ """ + import logging + + visited = set() + + def _safe_close(obj, label=""): + if obj is None or id(obj) in visited: + return + visited.add(id(obj)) + + # 1) Close common logger/writer attributes first (flush TFRecord) + for name in ("_envlogger", "envlogger", "logger", "_logger", "writer"): + try: + logger_obj = getattr(obj, name, None) + if logger_obj is not None and hasattr(logger_obj, "close"): + logger_obj.close() + except Exception as e: + logging.warning(f"close_logger_and_env: {label}.{name}.close() raised {e!r}") + + # 2) Close dm_env if present + try: + dm = getattr(obj, "dm_env", None) + if dm is not None and hasattr(dm, "close"): + dm.close() + except Exception as e: + logging.warning(f"close_logger_and_env: {label}.dm_env.close() raised {e!r}") + + # 3) Close Gym/Gymnasium env + try: + if hasattr(obj, "close"): + obj.close() + except Exception as e: + logging.warning(f"close_logger_and_env: {label}.close() raised {e!r}") + + # 4) Close Mujoco viewer if accessible + try: + # Some stacks keep viewer at env._viewer.viewer; others just env._viewer + viewer = None + vwrap = getattr(obj, "_viewer", None) + if vwrap is not None: + viewer = getattr(vwrap, "viewer", vwrap) + if viewer is not None and hasattr(viewer, "close"): + viewer.close() + except Exception as e: + logging.warning(f"close_logger_and_env: {label} viewer close raised {e!r}") + + # Recurse into common wrapper links + for child_name in ("env", "_env", "unwrapped", "environment", "base_env"): + child = getattr(obj, child_name, None) + if child is not None and child is not obj: + _safe_close(child, f"{label}.{child_name}" if label else child_name) + + _safe_close(env, "env") + +def finalize_tfds_metadata_beamless(builder_dir: str): + """ + Beam-free finalize: count TFRecord examples per shard and write + numShards/shardLengths into dataset_info.json so TFDS will load. 
+ """ + import os, json, glob + import tensorflow as tf + + info_path = os.path.join(builder_dir, "dataset_info.json") + if not os.path.exists(info_path): + raise FileNotFoundError(f"Missing dataset_info.json in {builder_dir}") + + with open(info_path) as f: + info = json.load(f) + + ds_name = info["name"] # e.g. "PandaReachSparseCube-v0" + file_fmt = info.get("fileFormat", "tfrecord") + tmpl_str = info["splits"][0]["filepathTemplate"] + + # Prefer strict pattern "-.-" + shard_paths = sorted(glob.glob(os.path.join(builder_dir, f"{ds_name}-*.{file_fmt}-*"))) + if not shard_paths: + # Fallback to any tfrecord-like file + shard_paths = sorted(glob.glob(os.path.join(builder_dir, f"*.{file_fmt}*"))) + if not shard_paths: + raise FileNotFoundError( + f"No {file_fmt} shards found in {builder_dir}. " + f"Expected like '{ds_name}-train.{file_fmt}-00000' per template '{tmpl_str}'." + ) + + # Count episodes (1 Example = 1 episode with envlogger/RLDS) + shard_lengths = [sum(1 for _ in tf.data.TFRecordDataset(p)) for p in shard_paths] + + # Write lengths for each split using this template + for s in info["splits"]: + if s.get("filepathTemplate") == tmpl_str: + s["numShards"] = len(shard_paths) + s["shardLengths"] = shard_lengths + # Re-write dataset_info.json with updated shard info + with open(info_path, "w") as f: + json.dump(info, f, indent=2) + + # Sanity log + import tensorflow_datasets as tfds + b = tfds.builder_from_directory(builder_dir) + print("[finalize] splits:", {k: v.num_examples for k, v in b.info.splits.items()}) + +def ensure_dir_exists(): + """ + For oxe_envlogger + RLDS compatibility, data must be written in the following format: + /data/data/serl/demos/franka_reach_drq_demo_script/ + └── session_20250821_222412/ + └── PandaReachSparseCube-v0/ + └── 0.1.0/ + dataset_info.json + features.json + PandaReachSparseCube-v0-train.tfrecord-00000-of-00001 + + We can have a base path with customized sessions inside. 
+ Inside each session we have: env-version-files + + + Returns + ------- + out_path : str + The path to the output directory. + """ + # Customize the path + root = FLAGS.output_dir + session = datetime.now().strftime("session_%Y%m%d_%H%M%S") + session_root = os.path.join(root, f"{FLAGS.num_demos}_demos_{session}") + #session_root = os.path.join(root, f"{session}_num_demos_{FLAGS.num_demos}") + + # Dataset details + dataset_name = FLAGS.env + version = "0.1.0" # PArt of RLDS format. needed. + + # Create output filename with configuration details + dataset_dir = os.path.join(session_root, dataset_name, version) + os.makedirs(dataset_dir, exist_ok=True) + logging.info(f"TFDS builder dir: {dataset_dir}") + + return dataset_dir + +def getKey(settings): + """ + Waits briefly for a keypress and returns the pressed key. + + Parameters + ---------- + settings : list + Original terminal settings so they can be restored after reading. + + Returns + ------- + key : str + The key pressed by the user, or '' (empty string) if no key was pressed. + """ + # Put terminal into raw mode so keypress is captured instantly + tty.setraw(sys.stdin.fileno()) + + # Wait for human to input action, we will not provide a timeout so it is blocking. + rlist, _, _ = select.select([sys.stdin], [], []) # select.select(rlist, wlist, xlist[, timeout]) + + if rlist: + # Read exactly one character if a key is pressed + key = sys.stdin.read(1) + else: + key = '' + + # Restore terminal to original settings + termios.tcsetattr(sys.stdin, termios.TCSADRAIN, settings) + return key + +def get_kb_demo_action(env,speed=0.075): + """ + Reads keyboard input and maps it to a 3D action vector for robot control or camera action. + TODO: currently can only read one key at a time. Needs to be extended to read multiple keys to handle both. + Otherwise, none actions are still considered steps in the loop. + + The function uses non-blocking keyboard input to allow interactive + teleoperation. 
Keys are mapped to directions in Cartesian space:
+    - 'i' : +x (forward)
+    - ',' : -x (backward)
+    - 'j' : +y (left)
+    - 'l' : -y (right)
+    - 'm' : +z (up)
+    - '.' : -z (down)
+    - 'k' : stop (zero vector)
+
+    Parameters
+    ----------
+    speed : float, optional
+        Step size for each key press (default 0.075).
+
+    Returns
+    -------
+    np.ndarray
+        Action vector of shape (4,), where each entry corresponds to a
+        [x, y, z, gripper] command. Example: [0.075, 0.0, 0.0, 0.0].
+    """
+    # Save current terminal settings so we can restore later
+    settings = termios.tcgetattr(sys.stdin)
+
+    # Initialize action as a zero vector (no movement)
+    action = np.zeros(4, dtype=float)
+
+    try:
+        # Capture the pressed key
+        key = getKey(settings)
+
+        # Check keys for camera first, if so, update camera and get another key for action
+        if key in camBindings:
+            if hasattr(env.unwrapped, "_viewer"):
+                update_camera(env.unwrapped._viewer.viewer,key)
+            key = getKey(settings) # get another key for action
+
+        elif key in moveBindings:
+            # Lookup (x, y, z) direction and scale by speed
+            dx, dy, dz, g= moveBindings[key]
+            action = np.array([dx, dy, dz, g], dtype=float) * speed
+
+        elif key == 'k':
+            # 'k' means stop → zero vector
+            action = np.zeros(4, dtype=float)
+
+        elif key == '\x03': # CTRL-C
+            raise KeyboardInterrupt
+
+    finally:
+        # Restore terminal even if something goes wrong
+        termios.tcsetattr(sys.stdin, termios.TCSADRAIN, settings)
+
+    # Clip action values to prevent excessive commands
+    action = np.clip(action, -ACTION_MAX, ACTION_MAX)
+
+    return action
+
+def set_front_cam_view(env):
+    """
+    Set the camera view to a front-facing perspective for better visualization.
+
+    Args:
+        env: The environment instance containing the viewer.
+
+    Returns:
+        viewer: The viewer with updated camera settings. 
+ """ + viewer = env.unwrapped._viewer.viewer # Access the viewer from the environment + + if hasattr(viewer, 'cam'): + viewer.cam.lookat[:] = [0, 0, 0.1] # Center of robot (adjust as needed) + viewer.cam.distance = 2.0 # Camera distance + viewer.cam.azimuth = 155 # 0 = right, 90 = front, 180 = left + viewer.cam.elevation = -30 # Negative = above, positive = below + + # Hide menu + viewer._hide_overlay = True + + return viewer + +############################################################################## +def main(unused_argv): + logging.info(f'Creating gym environment...') + + # Render mode configuration based on debug flag + if FLAGS.debug: + _render_mode = 'human' # Render mode for the environment, can be 'human' or 'rgb_array' + else: + _render_mode = 'rgb_array' # Use 'rgb_array' for automated testing without GUI + + # Create the environment with the specified render mode and wrappers + env = gym.make(FLAGS.env, render_mode=_render_mode) + + if FLAGS.env == "PandaPickCube-v0": + env = gym.wrappers.FlattenObservation(env) + + if FLAGS.env == "PandaReachSparseCube-v0" or FLAGS.env == "PandaPickCubeVision-v0": + env = SERLObsWrapper( + env, + target_hw=(128, 128), + img_dtype=np.uint8, # or np.float32 + normalize=False, # True if using float32 in [0,1] + ) + env = ChunkingWrapper(env, obs_horizon=1, act_exec_horizon=None) + + logging.info(f'Done creating {FLAGS.env} environment.') + + # Set camera to front view if viewer is available + if hasattr(env.unwrapped, '_viewer'): + viewer = set_front_cam_view(env) + if viewer: + logging.info('Camera view set to front-facing perspective.') + else: + logging.warning('Failed to set camera view. Viewer not available.') + + # Wrap with oxe_envlogger to record demos + dataset_dir = None + session_root = None + if FLAGS.enable_envlogger: + + dataset_dir = ensure_dir_exists() + + # Will save as many episodes as possible into files of 200MB each by default. 
+ env = AutoOXEEnvLogger( + env=env, + dataset_name=FLAGS.env, + directory=dataset_dir, + #split_name="train", # "train", "test", or "validation" + ) + logging.info('Recording %r demos...', FLAGS.num_demos) + + #--- LOOP DEMOS/EPISODES --- + # Loop through the number of demos specified by the user to record demonstrations + try: + for i in range(FLAGS.num_demos): + + # Log custom metadata during new episode: language embeddings randomly. + if FLAGS.enable_envlogger: + # The "language_embedding" is a standard field used in robotics datasets (like the OXE format that envlogger creates) to store a numerical representation of a natural language instruction for an episode. + # How to Reconcile it with 5 Random Numbers: The five random numbers are just placeholder data. This script is a demonstration and doesn't involve a real language model. + env.set_episode_metadata({ + "language_embedding": np.random.random((5,)).astype(np.float32) + }) + env.set_step_metadata({"timestamp": time.time()}) + + logging.info('episode %r', i) + + # Start a new episode + env.reset() + terminated = False + truncated = False + + step = 0 + + # Termination occurs when the hand reaches the target or the maximum trajectory length is reached + while not (terminated or truncated): + + # Get action from the demo function + action = get_kb_demo_action(env) + + # example to log custom step metadata + if FLAGS.enable_envlogger: + env.set_step_metadata({"timestamp": np.float32(time.time())}) + + return_step = env.step(action) + + # NOTE: to handle gym.Env.step() return value change in gym 0.26 + if len(return_step) == 5: + obs, reward, terminated, truncated, info = return_step + else: + obs, reward, terminated, info = return_step + truncated = False + + print(f" step: {step}", f"reward: {reward:.3f}\n") + step += 1 + + logging.info('Done recording %r demos.', FLAGS.num_demos) + + finally: + # Finalize TFDS metadata so SERL/TFDS can load the split + if FLAGS.enable_envlogger and dataset_dir is not 
None: + + # Close the environment to flush/write data. AutoOXEEnvLogger implements dm_env.close() + # which flushes data to disk. This is important to ensure all data is written before finalizing metadata. + # If you skip this step, some data may not be written and the dataset may be incomplete. + logging.info("Closing environment to flush data to disk...") + + # closes logger(s) + dm_env + gym + viewer + # env.unwrapped.unwrapped.unwrapped.close() # close mujoco viewer + # env.env.env.env.close() + close_logger_and_env(env) + + # Scan files and write metadata + finalize_tfds_metadata_beamless(dataset_dir) + + # Note: async_drq_sim will read from the dataset_dir you printed above. + +if __name__ == '__main__': + app.run(main) \ No newline at end of file diff --git a/demos/demos/franka_reach_drq_demo_script.py b/demos/demos/franka_reach_drq_demo_script.py index 40f1df04..b793cda3 100644 --- a/demos/demos/franka_reach_drq_demo_script.py +++ b/demos/demos/franka_reach_drq_demo_script.py @@ -220,6 +220,7 @@ def ensure_dir_exists(): root = FLAGS.output_dir session = datetime.now().strftime("session_%Y%m%d_%H%M%S") session_root = os.path.join(root, f"{FLAGS.num_demos}_demos_{session}") + #session_root = os.path.join(root, f"{session}_num_demos_{FLAGS.num_demos}") # Dataset details dataset_name = FLAGS.env @@ -457,6 +458,8 @@ def main(unused_argv): logging.info("Closing environment to flush data to disk...") # closes logger(s) + dm_env + gym + viewer + # env.unwrapped.unwrapped.unwrapped.close() # close mujoco viewer + # env.env.env.env.close() close_logger_and_env(env) # Scan files and write metadata diff --git a/franka_sim/franka_sim/__init__.py b/franka_sim/franka_sim/__init__.py index 1a56890e..ff6ecc2b 100644 --- a/franka_sim/franka_sim/__init__.py +++ b/franka_sim/franka_sim/__init__.py @@ -16,7 +16,7 @@ register( id="PandaPickCubeVision-v0", entry_point="franka_sim.envs:PandaPickCubeGymEnv", - max_episode_steps=100, + max_episode_steps=200, kwargs={"image_obs": 
True}, ) register( @@ -29,10 +29,4 @@ entry_point="franka_sim.envs:PandaReachSparseCubeGymEnv", max_episode_steps=200, ) -# register( -# id="PandaReachCubeVision-v0", -# entry_point="franka_sim.envs:PandaReachCubeGymEnv", -# max_episode_steps=100, -# kwargs={"image_obs": True}, -# ) diff --git a/franka_sim/franka_sim/envs/panda_reach_sparse_gym_env.py b/franka_sim/franka_sim/envs/panda_reach_sparse_gym_env.py index 016afd8e..61d7295c 100644 --- a/franka_sim/franka_sim/envs/panda_reach_sparse_gym_env.py +++ b/franka_sim/franka_sim/envs/panda_reach_sparse_gym_env.py @@ -130,8 +130,8 @@ def __init__( ) self.action_space = gym.spaces.Box( - low=np.asarray([-1.0, -1.0, -1.0]), - high=np.asarray([1.0, 1.0, 1.0]), + low=np.asarray([-1.0, -1.0, -1.0, -1.0]), + high=np.asarray([1.0, 1.0, 1.0, 1.0]), dtype=np.float32, ) @@ -189,7 +189,7 @@ def step( truncated: bool, info: dict[str, Any] """ - x, y, z, _ = action + x, y, z, grasp = action # Set the mocap position. pos = self._data.mocap_pos[0].copy() @@ -198,7 +198,10 @@ def step( self._data.mocap_pos[0] = npos # Set gripper grasp. - self._data.ctrl[self._gripper_ctrl_id] = 0 # Fully open position + g = self._data.ctrl[self._gripper_ctrl_id] / 255 + dg = grasp * self._action_scale[1] + ng = np.clip(g + dg, 0.0, 1.0) + self._data.ctrl[self._gripper_ctrl_id] = ng * 255 for _ in range(self._n_substeps): tau = opspace( @@ -295,6 +298,6 @@ def _compute_reward(self) -> float: env = PandaReachSparseCubeGymEnv(render_mode="human") env.reset() for i in range(5000): - env.step(np.random.uniform(-1, 1, 3)) + env.step(np.random.uniform(-1, 1, 4)) env.render() env.close() From 9f2880931c45fe1737ab159d1a28939b8bd8e1fc Mon Sep 17 00:00:00 2001 From: Juan Rojas Date: Tue, 30 Sep 2025 10:51:22 -0500 Subject: [PATCH 140/141] Moving towards a pick-and-place rlds demo recording system with cube vision. 
--- franka_sim/franka_sim/envs/panda_pick_gym_env.py | 4 ++-- .../franka_sim/envs/panda_reach_sparse_gym_env.py | 11 ++++++----- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/franka_sim/franka_sim/envs/panda_pick_gym_env.py b/franka_sim/franka_sim/envs/panda_pick_gym_env.py index dcf4783f..c3d3324e 100644 --- a/franka_sim/franka_sim/envs/panda_pick_gym_env.py +++ b/franka_sim/franka_sim/envs/panda_pick_gym_env.py @@ -144,8 +144,8 @@ def __init__( self._viewer = MujocoRenderer( self.model, self.data, - width=128, - height=128, + width=960, + height=960, camera_id=0 ) self._viewer.render(self.render_mode) diff --git a/franka_sim/franka_sim/envs/panda_reach_sparse_gym_env.py b/franka_sim/franka_sim/envs/panda_reach_sparse_gym_env.py index 61d7295c..80c586ad 100644 --- a/franka_sim/franka_sim/envs/panda_reach_sparse_gym_env.py +++ b/franka_sim/franka_sim/envs/panda_reach_sparse_gym_env.py @@ -189,7 +189,7 @@ def step( truncated: bool, info: dict[str, Any] """ - x, y, z, grasp = action + x, y, z, _ = action # Set the mocap position. pos = self._data.mocap_pos[0].copy() @@ -198,10 +198,11 @@ def step( self._data.mocap_pos[0] = npos # Set gripper grasp. - g = self._data.ctrl[self._gripper_ctrl_id] / 255 - dg = grasp * self._action_scale[1] - ng = np.clip(g + dg, 0.0, 1.0) - self._data.ctrl[self._gripper_ctrl_id] = ng * 255 + self._data.ctrl[self._gripper_ctrl_id] = 0 # Keep gripper closed. 
+ # g = self._data.ctrl[self._gripper_ctrl_id] / 255 + # dg = grasp * self._action_scale[1] + # ng = np.clip(g + dg, 0.0, 1.0) + # self._data.ctrl[self._gripper_ctrl_id] = ng * 255 for _ in range(self._n_substeps): tau = opspace( From 5cb7419addd4d4a149b8a5fbe11be5c32d548f31 Mon Sep 17 00:00:00 2001 From: ryanvanderstelt Date: Tue, 6 Jan 2026 18:08:51 -0600 Subject: [PATCH 141/141] fixed sampling bug --- .../serl_launcher/data/fractal_symmetry_replay_buffer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py index a55ef3a8..e3ccf68c 100644 --- a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py +++ b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py @@ -389,7 +389,7 @@ def sample( indx = self.np_random.randint(len(self), size=batch_size) for i in range(batch_size): - while not self._is_correct_index[indx[i]]: + while indx[i] >= self._size: if hasattr(self.np_random, "integers"): indx[i] = self.np_random.integers(len(self)) else: