3 changes: 2 additions & 1 deletion README.md
@@ -34,7 +34,8 @@ pip install \
tensorflow==2.6.0 \
keras==2.6.0 \
tf-agents==0.11.0rc0 \
tqdm==4.62.2
tqdm==4.62.2 \
wandb==0.12.7
```

(Optional): For Mujoco support, see [`docs/mujoco_setup.md`](docs/mujoco_setup.md). Recommended to skip it
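The install block now pins `wandb==0.12.7` alongside the existing dependencies. As a quick sanity check that an environment matches the pins, a minimal sketch (pip distribution names assumed to match those in the README):

```python
# Verify the pinned dependencies from the README install block.
from importlib.metadata import PackageNotFoundError, version

PINS = {
    "tensorflow": "2.6.0",
    "keras": "2.6.0",
    "tf-agents": "0.11.0rc0",
    "tqdm": "4.62.2",
    "wandb": "0.12.7",
}

for name, expected in PINS.items():
    try:
        installed = version(name)
    except PackageNotFoundError:
        print(f"{name}: not installed (expected {expected})")
        continue
    marker = "ok" if installed == expected else f"mismatch (expected {expected})"
    print(f"{name}: {installed} [{marker}]")
```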
11 changes: 9 additions & 2 deletions data/policy_eval.py
@@ -31,6 +31,7 @@
from ibc.environments.collect.utils import get_oracle as get_oracle_module
from ibc.environments.particle import particle # pylint: disable=unused-import
from ibc.environments.particle import particle_oracles
from ibc.ibc.utils import make_video as video_module
from tf_agents.drivers import py_driver
from tf_agents.environments import suite_gym
from tf_agents.environments import wrappers
@@ -135,7 +136,6 @@ def evaluate(num_episodes,
if static_policy:
video_path = os.path.join(video_path, static_policy, 'vid.mp4')


if saved_model_path and static_policy:
raise ValueError(
'Only pass in either a `saved_model_path` or a `static_policy`.')
@@ -187,6 +187,9 @@
py_mode=True,
compress_image=True))

if video:
env = video_module.make_video_env(env, video_path)

driver = py_driver.PyDriver(env, policy, observers, max_episodes=num_episodes)
time_step = env.reset()
initial_policy_state = policy.get_initial_state(1)
@@ -209,10 +212,12 @@ def main(_):
raise ValueError(
'A dataset_path must be provided when replicas are specified.')
dataset_split_path = os.path.splitext(flags.FLAGS.dataset_path)
output_split_path = os.path.splitext(flags.FLAGS.output_path)
context = multiprocessing.get_context()

for i in range(flags.FLAGS.replicas):
dataset_path = dataset_split_path[0] + '_%d' % i + dataset_split_path[1]
output_path = output_split_path[0] + '_%d' % i + output_split_path[1]
kwargs = dict(
num_episodes=flags.FLAGS.num_episodes,
task=flags.FLAGS.task,
@@ -223,7 +228,9 @@
checkpoint_path=flags.FLAGS.checkpoint_path,
static_policy=flags.FLAGS.policy,
dataset_path=dataset_path,
history_length=flags.FLAGS.history_length
history_length=flags.FLAGS.history_length,
video=flags.FLAGS.video,
output_path=output_path,
)
job = context.Process(target=evaluate, kwargs=kwargs)
job.start()
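The new `output_split_path` handling mirrors the existing `dataset_split_path` logic: `os.path.splitext` separates the extension (if any) so each replica's index can be spliced in before it, giving every worker its own dataset and output file. A minimal sketch using the paths from `show_vid_example.sh` further down:

```python
import os

# With --output_path=/tmp/ibc_tmp/vid and --replicas=2, each replica i
# writes to a distinct path with "_i" appended before the extension.
output_split_path = os.path.splitext("/tmp/ibc_tmp/vid")

for i in range(2):
    output_path = output_split_path[0] + "_%d" % i + output_split_path[1]
    print(output_path)
# /tmp/ibc_tmp/vid_0
# /tmp/ibc_tmp/vid_1
```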
59 changes: 42 additions & 17 deletions ibc/agents/ibc_agent.py
@@ -36,6 +36,13 @@
from tf_agents.utils import nest_utils


def add_tensor_summaries(d, prefix, xs):
# Much simpler version of generate_tensor_summaries.
d[f"{prefix}.min"] = tf.reduce_min(xs)
d[f"{prefix}.avg"] = tf.reduce_mean(xs)
d[f"{prefix}.max"] = tf.reduce_max(xs)


@gin.configurable
class ImplicitBCAgent(base_agent.BehavioralCloningAgent):
"""TFAgent, implementing implicit behavioral cloning."""
@@ -257,25 +264,19 @@ def _loss(self,
if grad_loss is not None:
losses_dict['grad_loss'] = tf.reduce_mean(grad_loss)
if self._compute_mse:
losses_dict['mse_counter_examples'] = tf.reduce_mean(
mse_counter_examples)
add_tensor_summaries(losses_dict, "mse_counter_examples", mse_counter_examples)

opt_dict = dict()
if chain_data is not None and chain_data.energies is not None:
energies = chain_data.energies
opt_dict['overall_energies_avg'] = tf.reduce_mean(energies)
first_energies = energies[0]
opt_dict['first_energies_avg'] = tf.reduce_mean(first_energies)
final_energies = energies[-1]
opt_dict['final_energies_avg'] = tf.reduce_mean(final_energies)

if chain_data is not None and chain_data.grad_norms is not None:
grad_norms = chain_data.grad_norms
opt_dict['overall_grad_norms_avg'] = tf.reduce_mean(grad_norms)
first_grad_norms = grad_norms[0]
opt_dict['first_grad_norms_avg'] = tf.reduce_mean(first_grad_norms)
final_grad_norms = grad_norms[-1]
opt_dict['final_grad_norms_avg'] = tf.reduce_mean(final_grad_norms)
self._log_energy_info(
opt_dict,
observations,
expanded_actions,
fmt="EnergyStats/{}_pos")
self._log_energy_info(
opt_dict,
observations,
counter_example_actions,
fmt="EnergyStats/{}_neg")

losses_dict.update(opt_dict)

@@ -337,6 +338,30 @@ def _compute_ebm_loss(

return per_example_loss, debug_dict

def _log_energy_info(
self,
opt_dict,
observations,
actions,
*,
fmt):
assert not self._late_fusion # TODO(eric): Support?
B, K, _ = actions.shape
reshape_actions = tf.reshape(actions, ((B * K, -1)))
tiled_obs = nest_utils.tile_batch(observations, K)
de_dact, energies = mcmc.gradient_wrt_act(
self.cloning_network,
tiled_obs,
reshape_actions,
training=False,
network_state=(),
tfa_step_type=(),
apply_exp=False,
obs_encoding=None)
grad_norms = mcmc.compute_grad_norm(self._grad_norm_type, de_dact)
add_tensor_summaries(opt_dict, fmt.format("energies"), energies)
add_tensor_summaries(opt_dict, fmt.format("grad_norms"), grad_norms)

def _make_counter_example_actions(
self,
observations, # B x obs_spec
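The `add_tensor_summaries` helper replaces the previous single `reduce_mean` scalars with min/avg/max triples keyed by a prefix, and `_log_energy_info` feeds it the energies and gradient norms for both the positive (`EnergyStats/{}_pos`) and counter-example (`EnergyStats/{}_neg`) actions. A small standalone sketch of the keys it produces, reusing the helper as defined above:

```python
import tensorflow as tf


def add_tensor_summaries(d, prefix, xs):
  # Much simpler version of generate_tensor_summaries: only min / mean / max.
  d[f"{prefix}.min"] = tf.reduce_min(xs)
  d[f"{prefix}.avg"] = tf.reduce_mean(xs)
  d[f"{prefix}.max"] = tf.reduce_max(xs)


opt_dict = {}
energies = tf.constant([[0.1, 0.4], [0.9, 0.2]])  # toy energy values
add_tensor_summaries(opt_dict, "EnergyStats/energies_pos", energies)
print(sorted(opt_dict))
# ['EnergyStats/energies_pos.avg', 'EnergyStats/energies_pos.max',
#  'EnergyStats/energies_pos.min']
```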
51 changes: 51 additions & 0 deletions ibc/configs/particle/mlp_ebm_dfo.gin
@@ -0,0 +1,51 @@
# coding=utf-8
# Copyright 2021 The Reach ML Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

train_eval.root_dir = '/tmp/ibc_logs/mlp_ebm_dfo'
train_eval.loss_type = 'ebm' # specifies we are training ebm.
train_eval.network = 'MLPEBM'
train_eval.batch_size = 512
ImplicitBCAgent.num_counter_examples = 8
train_eval.num_iterations = 100000
train_eval.replay_capacity = 10000
train_eval.eval_interval = 10000
train_eval.eval_episodes = 20
train_eval.learning_rate = 1e-3
train_eval.goal_tolerance = 0.02
train_eval.seed = 0
train_eval.sequence_length = 2
train_eval.dataset_eval_fraction = 0.0

IbcPolicy.num_action_samples = 512
train_eval.uniform_boundary_buffer = 0.05
get_normalizers.nested_obs = True # Particle has nested
get_normalizers.num_samples = 5000
compute_dataset_statistics.min_max_actions = True

IbcPolicy.use_dfo = True
IbcPolicy.use_langevin = False
IbcPolicy.optimize_again = False

# Configs for cloning net.
# MLPEBM.layers = 'ResNetPreActivation'
MLPEBM.layers = 'MLPDropout'
MLPEBM.width = 256
MLPEBM.depth = 2
MLPEBM.rate = 0.0
MLPEBM.dense_layer_type = 'regular'
MLPEBM.activation = 'relu'
ImplicitBCAgent.compute_mse = True
ImplicitBCAgent.add_grad_penalty = False
ResNetLayer.normalizer = None
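This file is consumed through gin by `train_eval.py`: the `--gin_file` is parsed and `--gin_bindings` overrides are applied on top (later bindings win), which is how `run_mlp_ebm_dfo.sh` below injects the dataset path and particle dimensionality. A rough sketch of that loading pattern, assuming it runs in a context where the ibc configurables have already been imported and registered:

```python
import gin

# Parse the base config, then apply command-line style overrides on top
# (later bindings win). skip_unknown mirrors train_eval.py's tolerance for
# configurables that are not needed in a given run.
gin.parse_config_files_and_bindings(
    ["ibc/ibc/configs/particle/mlp_ebm_dfo.gin"],
    bindings=[
        "train_eval.dataset_path='ibc/data/particle/2d_oracle_particle*.tfrecord'",
        "ParticleEnv.n_dim=2",
    ],
    skip_unknown=True,
)
print(gin.config_str())  # parsed bindings, one per line
```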
6 changes: 3 additions & 3 deletions ibc/configs/particle/mlp_ebm_langevin.gin
@@ -25,7 +25,7 @@ train_eval.eval_episodes = 20
train_eval.learning_rate = 1e-3
train_eval.goal_tolerance = 0.02
train_eval.seed = 0
train_eval.sequence_length = 2
train_eval.sequence_length = 1
train_eval.dataset_eval_fraction = 0.0

IbcPolicy.num_action_samples = 512
@@ -39,7 +39,7 @@ IbcPolicy.use_langevin = True
IbcPolicy.optimize_again = False

# Configs for cloning net.
MLPEBM.layers = 'ResNetPreActivation'
MLPEBM.layers = 'MLPDropout'
MLPEBM.width = 256
MLPEBM.depth = 2
MLPEBM.rate = 0.0
@@ -54,7 +54,7 @@ grad_penalty.only_apply_final_grad_penalty = True # Note:
# we actually get slightly better results with this as False,
# however we get an OOM with it as False. TODO(peteflorence,oars):
# investigate if can lessen memory.
ImplicitBCAgent.run_full_chain_under_gradient = True
ImplicitBCAgent.run_full_chain_under_gradient = False

langevin_actions_given_obs.num_iterations = 100
ImplicitBCAgent.fraction_langevin_samples = 1.0
13 changes: 13 additions & 0 deletions ibc/configs/particle/run_mlp_ebm_dfo.sh
@@ -0,0 +1,13 @@
#!/bin/bash

## Use "N" of the N-d environment as the arg

python3 ibc/ibc/train_eval.py -- \
--alsologtostderr \
--gin_file=ibc/ibc/configs/particle/mlp_ebm_dfo.gin \
--task=PARTICLE \
--tag=dfo \
--add_time=True \
--gin_bindings="train_eval.dataset_path='ibc/data/particle/$1d_oracle_particle*.tfrecord'" \
--gin_bindings="ParticleEnv.n_dim=$1" \
--video
14 changes: 14 additions & 0 deletions ibc/configs/particle/show_vid_example.sh
@@ -0,0 +1,14 @@
#!/bin/bash

set -eu

python3 ibc/data/policy_eval.py -- \
--alsologtostderr \
--num_episodes=3 \
--replicas=2 \
--policy=particle_green_then_blue \
--task=PARTICLE \
--use_image_obs=False \
--dataset_path=/tmp/ibc_tmp/data \
--output_path=/tmp/ibc_tmp/vid \
--video
10 changes: 10 additions & 0 deletions ibc/train_eval.py
@@ -44,6 +44,7 @@
from tf_agents.train.utils import strategy_utils
from tf_agents.train.utils import train_utils
from tf_agents.utils import common
import wandb

flags.DEFINE_string('tag', None,
'Tag for the experiment. Appended to the root_dir.')
@@ -377,6 +378,15 @@ def main(_):
# eval working. Remove it once we do.
skip_unknown=True)

wandb.init(
project="google-research-ibc",
sync_tensorboard=True,
)
# Print operative gin config to stdout so wandb can intercept.
# (it'd be nice for gin to provide a flat/nested dictionary of values so they
# can be used via wandb's aggregation...)
print(gin.config.config_str())

# For TPU, FLAGS.tpu will be set with a TPU address and FLAGS.use_gpu
# will be False.
# For GPU, FLAGS.tpu will be None and FLAGS.use_gpu will be True.
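On the parenthetical above about wanting a flat dictionary of gin values for wandb's aggregation: one workaround (an assumption on my part, not part of this change) is to parse the config string into key/value pairs and push them into `wandb.config`. A best-effort sketch that flattens simple `name.param = value` lines; macros, `@references`, and multi-line bindings fall back to raw strings or are skipped:

```python
import ast
import gin
import wandb


def gin_config_to_dict(config_str):
  """Best-effort flattening of `name.param = value` lines into a dict."""
  flat = {}
  for line in config_str.splitlines():
    line = line.strip()
    if not line or line.startswith("#") or "=" not in line:
      continue
    key, _, value = line.partition("=")
    try:
      flat[key.strip()] = ast.literal_eval(value.strip())
    except (ValueError, SyntaxError):
      flat[key.strip()] = value.strip()  # e.g. @references, %macros
  return flat


# After gin files/bindings have been parsed, e.g. inside main():
wandb.init(project="google-research-ibc", sync_tensorboard=True)
wandb.config.update(gin_config_to_dict(gin.config.config_str()))
```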
39 changes: 39 additions & 0 deletions ibc/utils/debug.py
@@ -0,0 +1,39 @@
# From anzu stuff.

from contextlib import contextmanager
import pdb
import sys
import traceback


@contextmanager
def launch_pdb_on_exception():
"""
Provides a context that will launch interactive pdb console automatically
if an exception is raised.

Example usage with @iex decorator below:

@iex
def my_bad_function():
x = 1
assert False

my_bad_function()
# Should bring up debugger at `assert` statement.
"""
# Adapted from:
# https://github.com/gotcha/ipdb/blob/fc83b4f5f/ipdb/__main__.py#L219-L232

try:
yield
except Exception:
traceback.print_exc()
_, _, tb = sys.exc_info()
pdb.post_mortem(tb)
# Resume original execution.
raise


# See docs for `launch_pdb_on_exception()`.
iex = launch_pdb_on_exception()
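Besides the `@iex` decorator form shown in the docstring, the helper also works as a plain context manager around any block. A tiny sketch; the import path is an assumption based on the repo's `ibc.ibc.*` import style (e.g. `from ibc.ibc.utils import make_video` elsewhere):

```python
# Import path assumes the checkout layout used by the rest of the repo.
from ibc.ibc.utils.debug import launch_pdb_on_exception

with launch_pdb_on_exception():
  # Any exception here prints its traceback, drops into an interactive
  # pdb post-mortem, and is then re-raised once the session ends.
  assert False, "demo failure"
```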
19 changes: 11 additions & 8 deletions ibc/utils/make_video.py
@@ -25,19 +25,22 @@
from tf_agents.drivers import py_driver


def make_video_env(env, video_path):
# Use default control freq for d4rl envs, which don't have a control_frequency
# attr.
control_frequency = getattr(env, 'control_frequency', 30)
video_env = mp4_video_wrapper.Mp4VideoWrapper(
env, control_frequency, frame_interval=1, video_filepath=video_path)
video_env.batch_size = getattr(env, "batch_size", 1)
return video_env


def make_video(agent, env, root_dir, step, strategy):
"""Creates a video of a single rollout from the current policy."""
policy = strategy_policy.StrategyPyTFEagerPolicy(
agent.policy, strategy=strategy)
video_path = os.path.join(root_dir, 'videos', 'ttl=7d', 'vid_%d.mp4' % step)
if not hasattr(env, 'control_frequency'):
# Use this control freq for d4rl envs, which don't have a control_frequency
# attr.
control_frequency = 30
else:
control_frequency = env.control_frequency
video_env = mp4_video_wrapper.Mp4VideoWrapper(
env, control_frequency, frame_interval=1, video_filepath=video_path)
video_env = make_video_env(env, video_path)
driver = py_driver.PyDriver(video_env, policy, observers=[], max_episodes=1)
time_step = video_env.reset()
initial_policy_state = policy.get_initial_state(1)