Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
bfa4b7f
Added Kuka env to start testing SERL
omeym Nov 5, 2024
c67fad6
Merge pull request #1 from omeym/omeym/kuka-env
omeym Nov 5, 2024
1a8ef43
Made changes in SERL for kuka execution
omeym Nov 6, 2024
b15aaca
Merge pull request #2 from omeym/omeym/robot-testing
omeym Nov 6, 2024
7aed98d
Completed Testing and running the training loop
rysabh Nov 12, 2024
4cd36b3
Completed setting up the training pipeline
rysabh Nov 13, 2024
4c327df
Merge pull request #3 from omeym/omeym/serl-training
omeym Nov 13, 2024
4e74bee
Pushing latest changes for benchmarking trials
omeym Jan 30, 2025
8182483
Removing the stale data_collector folder
nikitasarawgi Feb 7, 2025
bbaffae
Squashed 'data_collector/' content from commit 7d3b98f
nikitasarawgi Feb 7, 2025
8147c53
Merge commit 'bbaffae892d4bf99bcf6998e61655da060a2d90f' as 'data_coll…
nikitasarawgi Feb 7, 2025
4744ba1
Undoing the incorrect nesting of data_collector
nikitasarawgi Feb 7, 2025
ad9e61f
Squashed 'data_collector/' content from commit b8263e9
nikitasarawgi Feb 7, 2025
3c5c5ee
Merge commit 'ad9e61f5b0c5ddf0980c6f985aeea7c7313d0c37' as 'data_coll…
nikitasarawgi Feb 7, 2025
13f087d
Updating lbr FRI submodule to a new one
nikitasarawgi Mar 5, 2025
8a07d89
updating lbr_fri_ros2_stack submodule in gitmodules
nikitasarawgi Mar 5, 2025
e0a207f
Updating submodule for fri_idl to rros-lab
nikitasarawgi Mar 5, 2025
e34521f
Updating .gitmodules for lbr_fri_idl
nikitasarawgi Mar 5, 2025
2ef56c7
remove broken subtree reference
nikitasarawgi Mar 5, 2025
c8fb537
Added new submodule
nikitasarawgi Mar 5, 2025
a1c3add
Updating the data_collector submodule
nikitasarawgi Mar 5, 2025
ce21ae9
Adding serl files for kuka
nikitasarawgi Mar 5, 2025
f3ca745
Updating the data_collector submodule with changes to pkl file for SERL
nikitasarawgi Mar 13, 2025
af0cb6b
Updating submodule for pickle protocol changes
nikitasarawgi Mar 13, 2025
fce667e
Mostly changes to logging
nikitasarawgi Mar 13, 2025
c3687b1
Merge branch 'omeym/serl-training' of github.com:omeym/serl-rros into…
nikitasarawgi Mar 13, 2025
67d1875
Updating gitignore
nikitasarawgi Mar 18, 2025
3227fd5
Updating submodule data_collector with fixes
nikitasarawgi Mar 18, 2025
30eabfb
New system changes + IMPORTANT reward calc changes
nikitasarawgi Mar 18, 2025
4de15a9
project sync 8/31
nikitasarawgi Aug 31, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,29 @@ MUJOCO_LOG.TXT
_METADATA
checkpoint
wandb/
.vscode/*
!.vscode/settings.json
!.vscode/tasks.json
!.vscode/launch.json
!.vscode/extensions.json
*.code-workspace

# Local History for Visual Studio Code
.history/

# Built Visual Studio Code Extensions
*.vsix

realsense-ros/
.vscode/
data_collector/data_collector/**/*.csv
data_collector/data_collector/**/*.jpg
data_collector/data_collector/**/*.jpeg
data_collector/data_collector/**/*.png

.~lock.*

examples/async_peg_insert_drq/checkpoints*/**
examples/async_peg_insert_drq/log*/**
examples/async_peg_insert_drq/demo_data*/**/**
examples/async_peg_insert_drq/experiments*/**/**
15 changes: 15 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
[submodule "fri"]
path = fri
url = git@github.com:lbr-stack/fri.git
[submodule "lbr_fri_idl"]
path = lbr_fri_idl
url = https://github.com/RROS-Lab/serl_lbr_fri_idl.git
branch=serl-kuka
[submodule "lbr_fri_ros2_stack"]
path = lbr_fri_ros2_stack
url = https://github.com/RROS-Lab/serl_lbr_fri_ros2_stack.git
branch=serl-kuka
[submodule "data_collector"]
path = data_collector
url = https://github.com/nikitasarawgi/expert_data_collector.git
branch = data-collector-only
1 change: 1 addition & 0 deletions data_collector
Submodule data_collector added at b511fd
121 changes: 95 additions & 26 deletions examples/async_peg_insert_drq/async_drq_randomized.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,18 @@
#!/usr/bin/env python3

import sys
import os
import getpass

username = getpass.getuser()
env = "serl"
print("Ensure to change anaconda3 or miniconda3 and change your environment name")
print("Conda Environment name is: ", env)
sys.path.append(
"/home/" + username + "/anaconda3/envs/" + env + "/lib/python3.10/site-packages"
)
import rclpy
import rclpy.duration
from rclpy.executors import MultiThreadedExecutor
import time
from functools import partial
import jax
Expand All @@ -26,27 +39,47 @@
make_trainer_config,
make_wandb_logger,
)
from kuka_server.robot_interface import RobotInterfaceNode
from serl_launcher.data.data_store import MemoryEfficientReplayBufferDataStore
from serl_launcher.wrappers.serl_obs_wrappers import SERLObsWrapper
from franka_env.envs.relative_env import RelativeFrame
from franka_env.envs.wrappers import (
GripperCloseEnv,
SpacemouseIntervention,

import os

MAIN_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
sys.path.append(MAIN_DIR)
from serl_robot_infra.kuka_env.envs.relative_env import RelativeFrame
from serl_robot_infra.kuka_env.envs.wrappers import (
Quat2EulerWrapper,
)

import franka_env
import serl_robot_infra.franka_env
import serl_robot_infra.kuka_env

FLAGS = flags.FLAGS
# sys.argv = sys.argv[:1]

# # `app.run` calls `sys.exit`
# try:
# app.run(lambda argv: None)
# except:
# pass

flags.DEFINE_string("env", "FrankaEnv-Vision-v0", "Name of environment.")
FLAGS = flags.FLAGS
flags.DEFINE_string("env", "KukaPegInsert-Vision-v0", "Name of environment.")
flags.DEFINE_string("agent", "drq", "Name of agent.")
flags.DEFINE_string("exp_name", None, "Name of the experiment for wandb logging.")
flags.DEFINE_integer("max_traj_length", 100, "Maximum length of trajectory.")

## TODO :: NISARA :: change max trajectory length to at least 300 (or see what the new data wants)
flags.DEFINE_integer("max_traj_length", 200, "Maximum length of trajectory.")

flags.DEFINE_integer("seed", 42, "Random seed.")
flags.DEFINE_bool("save_model", False, "Whether to save model.")
flags.DEFINE_integer("critic_actor_ratio", 4, "critic to actor update ratio.")
flags.DEFINE_boolean(
"load_checkpoint", False, "Whether to start from previous checkpoint or not."
)
flags.DEFINE_string("load_checkpoint_path", None, "Checkpoint to start training from.")

flags.DEFINE_integer("batch_size", 256, "Batch size for training the policy.")
flags.DEFINE_integer("max_steps", 1000000, "Maximum number of training steps.")
flags.DEFINE_integer("replay_buffer_capacity", 200000, "Replay buffer capacity.")

Expand All @@ -56,6 +89,7 @@

flags.DEFINE_integer("log_period", 10, "Logging period.")
flags.DEFINE_integer("eval_period", 2000, "Evaluation period.")
flags.DEFINE_integer("loaded_checkpoint_step", 1000, "Loaded Checkpoint Step.")

# flag to indicate if this is a leaner or a actor
flags.DEFINE_boolean("learner", False, "Is this a learner or a trainer.")
Expand Down Expand Up @@ -94,6 +128,7 @@ def actor(agent: DrQAgent, data_store, env, sampling_rng):
This is the actor loop, which runs when "--actor" is set to True.
"""
if FLAGS.eval_checkpoint_step:
print("Evaluating the policy")
success_counter = 0
time_list = []

Expand All @@ -104,14 +139,17 @@ def actor(agent: DrQAgent, data_store, env, sampling_rng):
)
agent = agent.replace(state=ckpt)

for episode in range(FLAGS.eval_n_trajs):
for episode in tqdm.tqdm(range(FLAGS.eval_n_trajs), "Evaluation Trajectory"):
obs, _ = env.reset()
done = False
start_time = time.time()
while not done:
sampling_rng, key = jax.random.split(sampling_rng)
actions = agent.sample_actions(
observations=jax.device_put(obs),
argmax=True,
# argmax=True,
seed=key,
# deterministic=False,
)
actions = np.asarray(jax.device_get(actions))

Expand Down Expand Up @@ -151,10 +189,11 @@ def update_params(params):
done = False

# training loop
print("Entering Training Loop")
timer = Timer()
running_return = 0.0

for step in tqdm.tqdm(range(FLAGS.max_steps), dynamic_ncols=True):
for step in tqdm.tqdm(range(FLAGS.max_steps), dynamic_ncols=True, desc="actor"):
timer.tick("total")

with timer.context("sample_actions"):
Expand All @@ -168,6 +207,8 @@ def update_params(params):
deterministic=False,
)
actions = np.asarray(jax.device_get(actions))
# nisara : Comment
# print("Actions in Training Loop: ", actions)

# Step environment
with timer.context("step_env"):
Expand Down Expand Up @@ -215,8 +256,9 @@ def learner(rng, agent: DrQAgent, replay_buffer, demo_buffer):
The learner loop, which runs when "--learner" is set to True.
"""
# set up wandb and logging
# wandb_logger = None
wandb_logger = make_wandb_logger(
project="serl_dev",
project="serl_testing",
description=FLAGS.exp_name or FLAGS.env,
debug=FLAGS.debug,
)
Expand Down Expand Up @@ -304,7 +346,7 @@ def stats_callback(type: str, payload: dict) -> dict:
if FLAGS.checkpoint_period and update_steps % FLAGS.checkpoint_period == 0:
assert FLAGS.checkpoint_path is not None
checkpoints.save_checkpoint(
FLAGS.checkpoint_path, agent.state, step=update_steps, keep=100
FLAGS.checkpoint_path, agent.state, step=update_steps, keep=200
)

update_steps += 1
Expand All @@ -314,28 +356,42 @@ def stats_callback(type: str, payload: dict) -> dict:


def main(_):
rclpy.init(args=_)
print("ROS2 initialized.")
robot_interface_node = RobotInterfaceNode()
executor = MultiThreadedExecutor()
executor.add_node(robot_interface_node)
robot_interface_node.get_logger().info("Robot interface node started.")
# executor.spin_once()

assert FLAGS.batch_size % num_devices == 0
# seed
rng = jax.random.PRNGKey(FLAGS.seed)

print("Initializing environment")
# create env and load dataset
print("Value of learner: ", FLAGS.learner)

env = gym.make(
FLAGS.env,
fake_env=FLAGS.learner,
save_video=FLAGS.eval_checkpoint_step,
robot_interface_node=robot_interface_node,
)
env = GripperCloseEnv(env)
if FLAGS.actor:
env = SpacemouseIntervention(env)
print("Environment initialized")
# env = GripperCloseEnv(env)
# if FLAGS.actor:
# env = SpacemouseIntervention(env)
env = RelativeFrame(env)
env = Quat2EulerWrapper(env)
env = SERLObsWrapper(env)
env = ChunkingWrapper(env, obs_horizon=1, act_exec_horizon=None)
env = RecordEpisodeStatistics(env)
print("Environment wrapped")

image_keys = [key for key in env.observation_space.keys() if key != "state"]

rng, sampling_rng = jax.random.split(rng)
print("Creating agent")
agent: DrQAgent = make_drq_agent(
seed=FLAGS.seed,
sample_obs=env.observation_space.sample(),
Expand All @@ -344,44 +400,55 @@ def main(_):
encoder_type=FLAGS.encoder_type,
)

if FLAGS.load_checkpoint:
print(f"Loading Checkpoint from Previous Run:{FLAGS.load_checkpoint_path}")

ckpt = checkpoints.restore_checkpoint(
FLAGS.load_checkpoint_path,
agent.state,
step=FLAGS.loaded_checkpoint_step,
)
agent = agent.replace(state=ckpt)

print("Agent created")
# replicate agent across devices
# need the jnp.array to avoid a bug where device_put doesn't recognize primitives
print("Replicating agent")
agent: DrQAgent = jax.device_put(
jax.tree_map(jnp.array, agent), sharding.replicate()
)

if FLAGS.learner:
print("Learner code executed")
sampling_rng = jax.device_put(sampling_rng, device=sharding.replicate())
replay_buffer = MemoryEfficientReplayBufferDataStore(
env.observation_space,
env.action_space,
capacity=FLAGS.replay_buffer_capacity,
image_keys=image_keys,
)
print("Replay buffer initialized")
print("Checkpoint Path: ", FLAGS.checkpoint_path)
demo_buffer = MemoryEfficientReplayBufferDataStore(
env.observation_space,
env.action_space,
capacity=10000,
image_keys=image_keys,
)
import pickle as pkl
import pickle5 as pkl

with open(FLAGS.demo_path, "rb") as f:
trajs = pkl.load(f)
trajs = pkl.load(f, fix_imports=True)
for traj in trajs:
demo_buffer.insert(traj)
print(f"demo buffer size: {len(demo_buffer)}")

# learner loop
print_green("starting learner loop")
learner(
sampling_rng,
agent,
replay_buffer,
demo_buffer=demo_buffer,
)
learner(sampling_rng, agent=agent, replay_buffer=replay_buffer, demo_buffer=demo_buffer)

elif FLAGS.actor:
print("Initializing Actor Node")
sampling_rng = jax.device_put(sampling_rng, sharding.replicate())
data_store = QueuedDataStore(2000) # the queue size on the actor

Expand All @@ -392,6 +459,8 @@ def main(_):
else:
raise NotImplementedError("Must be either a learner or an actor")

rclpy.shutdown()


if __name__ == "__main__":
app.run(main)
Binary file not shown.
19 changes: 13 additions & 6 deletions examples/async_peg_insert_drq/run_actor.sh
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
export XLA_PYTHON_CLIENT_PREALLOCATE=false && \
export XLA_PYTHON_CLIENT_MEM_FRACTION=.1 && \
python async_drq_randomized.py "$@" \
--actor \
export XLA_PYTHON_CLIENT_ALLOCATOR=platform &&\
python3 async_drq_randomized.py "$@" \
--actor 1\
--render \
--env FrankaPegInsert-Vision-v0 \
--env KukaPegInsert-Vision-v0 \
--exp_name=serl_dev_drq_rlpd10demos_peg_insert_random_resnet \
--seed 0 \
--random_steps 0 \
--training_starts 200 \
--random_steps 200 \
--training_starts 300 \
--encoder_type resnet-pretrained \
--demo_path peg_insert_20_demos_2023-12-25_16-13-25.pkl \
--demo_path /home/rp/SERL/src/examples/async_peg_insert_drq/rect_peg/rect_peg_skips_sdf_action6_surface_relaxed.pkl \
--eval_checkpoint_step 16500 \
--loaded_checkpoint_step 16500 \
--eval_n_trajs 11 \
--load_checkpoint_path /home/rp/SERL/src/examples/async_peg_insert_drq/rect_peg/checkpoints/ \
--checkpoint_path /home/rp/SERL/src/examples/async_peg_insert_drq/rect_peg/checkpoints_new/ \
--load_checkpoint 1 \
21 changes: 13 additions & 8 deletions examples/async_peg_insert_drq/run_learner.sh
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
export XLA_PYTHON_CLIENT_PREALLOCATE=false && \
export XLA_PYTHON_CLIENT_MEM_FRACTION=.2 && \
python async_drq_randomized.py "$@" \
--learner \
--env FrankaPegInsert-Vision-v0 \
export XLA_PYTHON_CLIENT_ALLOCATOR=platform &&\
python3 async_drq_randomized.py "$@" \
--learner 1\
--env KukaPegInsert-Vision-v0 \
--exp_name=serl_dev_drq_rlpd10demos_peg_insert_random_resnet_097 \
--seed 0 \
--random_steps 1000 \
--training_starts 200 \
--random_steps 200 \
--training_starts 300 \
--critic_actor_ratio 4 \
--batch_size 256 \
--eval_period 2000 \
--encoder_type resnet-pretrained \
--demo_path peg_insert_20_demos_2023-12-25_16-13-25.pkl \
--checkpoint_period 1000 \
--checkpoint_path /home/undergrad/code/serl_dev/examples/async_peg_insert_drq/5x5_20degs_20demos_rand_peg_insert_097
--demo_path /home/rp/SERL/src/examples/async_peg_insert_drq/rect_peg/rect_peg_skips_sdf_action6_surface_relaxed.pkl \
--checkpoint_period 300 \
--loaded_checkpoint_step 1200 \
--load_checkpoint_path /home/rp/SERL/src/examples/async_peg_insert_drq/rect_peg/checkpoints/ \
--checkpoint_path /home/rp/SERL/src/examples/async_peg_insert_drq/rect_peg/checkpoints_new/ \
--load_checkpoint 1 \

2 changes: 1 addition & 1 deletion examples/async_sac_state_sim/tmux_launch.sh
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/bin/bash

EXAMPLE_DIR=${EXAMPLE_DIR:-"examples/async_sac_state_sim"}
CONDA_ENV=${CONDA_ENV:-"serl"}
CONDA_ENV=${CONDA_ENV:-"serl-docker"}

cd $EXAMPLE_DIR
echo "Running from $(pwd)"
Expand Down
1 change: 1 addition & 0 deletions fri
Submodule fri added at 581194
Empty file.
Loading