diff --git a/README.md b/README.md index 1b1ebc7..a2e5273 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,8 @@ pip install \ tensorflow==2.6.0 \ keras==2.6.0 \ tf-agents==0.11.0rc0 \ - tqdm==4.62.2 + tqdm==4.62.2 \ + wandb==0.12.7 ``` (Optional): For Mujoco support, see [`docs/mujoco_setup.md`](docs/mujoco_setup.md). Recommended to skip it diff --git a/data/policy_eval.py b/data/policy_eval.py index 324bfa4..1afe4f0 100644 --- a/data/policy_eval.py +++ b/data/policy_eval.py @@ -31,6 +31,7 @@ from ibc.environments.collect.utils import get_oracle as get_oracle_module from ibc.environments.particle import particle # pylint: disable=unused-import from ibc.environments.particle import particle_oracles +from ibc.ibc.utils import make_video as video_module from tf_agents.drivers import py_driver from tf_agents.environments import suite_gym from tf_agents.environments import wrappers @@ -135,7 +136,6 @@ def evaluate(num_episodes, if static_policy: video_path = os.path.join(video_path, static_policy, 'vid.mp4') - if saved_model_path and static_policy: raise ValueError( 'Only pass in either a `saved_model_path` or a `static_policy`.') @@ -187,6 +187,9 @@ def evaluate(num_episodes, py_mode=True, compress_image=True)) + if video: + env = video_module.make_video_env(env, video_path) + driver = py_driver.PyDriver(env, policy, observers, max_episodes=num_episodes) time_step = env.reset() initial_policy_state = policy.get_initial_state(1) @@ -209,10 +212,12 @@ def main(_): raise ValueError( 'A dataset_path must be provided when replicas are specified.') dataset_split_path = os.path.splitext(flags.FLAGS.dataset_path) + output_split_path = os.path.splitext(flags.FLAGS.output_path) context = multiprocessing.get_context() for i in range(flags.FLAGS.replicas): dataset_path = dataset_split_path[0] + '_%d' % i + dataset_split_path[1] + output_path = output_split_path[0] + '_%d' % i + output_split_path[1] kwargs = dict( num_episodes=flags.FLAGS.num_episodes, task=flags.FLAGS.task, @@ -223,7 +228,9 @@ def main(_): checkpoint_path=flags.FLAGS.checkpoint_path, static_policy=flags.FLAGS.policy, dataset_path=dataset_path, - history_length=flags.FLAGS.history_length + history_length=flags.FLAGS.history_length, + video=flags.FLAGS.video, + output_path=output_path, ) job = context.Process(target=evaluate, kwargs=kwargs) job.start() diff --git a/ibc/agents/ibc_agent.py b/ibc/agents/ibc_agent.py index 85c546a..3c9e538 100644 --- a/ibc/agents/ibc_agent.py +++ b/ibc/agents/ibc_agent.py @@ -36,6 +36,13 @@ from tf_agents.utils import nest_utils +def add_tensor_summaries(d, prefix, xs): + # Much simpler version of generate_tensor_summaries. 
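+  # Stores scalar min/avg/max of `xs` in `d` under "<prefix>.min",
+  # "<prefix>.avg", and "<prefix>.max".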
+ d[f"{prefix}.min"] = tf.reduce_min(xs) + d[f"{prefix}.avg"] = tf.reduce_mean(xs) + d[f"{prefix}.max"] = tf.reduce_max(xs) + + @gin.configurable class ImplicitBCAgent(base_agent.BehavioralCloningAgent): """TFAgent, implementing implicit behavioral cloning.""" @@ -257,25 +264,19 @@ def _loss(self, if grad_loss is not None: losses_dict['grad_loss'] = tf.reduce_mean(grad_loss) if self._compute_mse: - losses_dict['mse_counter_examples'] = tf.reduce_mean( - mse_counter_examples) + add_tensor_summaries(losses_dict, "mse_counter_examples", mse_counter_examples) opt_dict = dict() - if chain_data is not None and chain_data.energies is not None: - energies = chain_data.energies - opt_dict['overall_energies_avg'] = tf.reduce_mean(energies) - first_energies = energies[0] - opt_dict['first_energies_avg'] = tf.reduce_mean(first_energies) - final_energies = energies[-1] - opt_dict['final_energies_avg'] = tf.reduce_mean(final_energies) - - if chain_data is not None and chain_data.grad_norms is not None: - grad_norms = chain_data.grad_norms - opt_dict['overall_grad_norms_avg'] = tf.reduce_mean(grad_norms) - first_grad_norms = grad_norms[0] - opt_dict['first_grad_norms_avg'] = tf.reduce_mean(first_grad_norms) - final_grad_norms = grad_norms[-1] - opt_dict['final_grad_norms_avg'] = tf.reduce_mean(final_grad_norms) + self._log_energy_info( + opt_dict, + observations, + expanded_actions, + fmt="EnergyStats/{}_pos") + self._log_energy_info( + opt_dict, + observations, + counter_example_actions, + fmt="EnergyStats/{}_neg") losses_dict.update(opt_dict) @@ -337,6 +338,30 @@ def _compute_ebm_loss( return per_example_loss, debug_dict + def _log_energy_info( + self, + opt_dict, + observations, + actions, + *, + fmt): + assert not self._late_fusion # TODO(eric): Support? + B, K, _ = actions.shape + reshape_actions = tf.reshape(actions, ((B * K, -1))) + tiled_obs = nest_utils.tile_batch(observations, K) + de_dact, energies = mcmc.gradient_wrt_act( + self.cloning_network, + tiled_obs, + reshape_actions, + training=False, + network_state=(), + tfa_step_type=(), + apply_exp=False, + obs_encoding=None) + grad_norms = mcmc.compute_grad_norm(self._grad_norm_type, de_dact) + add_tensor_summaries(opt_dict, fmt.format("energies"), energies) + add_tensor_summaries(opt_dict, fmt.format("grad_norms"), grad_norms) + def _make_counter_example_actions( self, observations, # B x obs_spec diff --git a/ibc/configs/particle/mlp_ebm_dfo.gin b/ibc/configs/particle/mlp_ebm_dfo.gin new file mode 100644 index 0000000..2e94906 --- /dev/null +++ b/ibc/configs/particle/mlp_ebm_dfo.gin @@ -0,0 +1,51 @@ +# coding=utf-8 +# Copyright 2021 The Reach ML Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +train_eval.root_dir = '/tmp/ibc_logs/mlp_ebm_dfo' +train_eval.loss_type = 'ebm' # specifies we are training ebm. 
+train_eval.network = 'MLPEBM' +train_eval.batch_size = 512 +ImplicitBCAgent.num_counter_examples = 8 +train_eval.num_iterations = 100000 +train_eval.replay_capacity = 10000 +train_eval.eval_interval = 10000 +train_eval.eval_episodes = 20 +train_eval.learning_rate = 1e-3 +train_eval.goal_tolerance = 0.02 +train_eval.seed = 0 +train_eval.sequence_length = 2 +train_eval.dataset_eval_fraction = 0.0 + +IbcPolicy.num_action_samples = 512 +train_eval.uniform_boundary_buffer = 0.05 +get_normalizers.nested_obs = True # Particle has nested +get_normalizers.num_samples = 5000 +compute_dataset_statistics.min_max_actions = True + +IbcPolicy.use_dfo = True +IbcPolicy.use_langevin = False +IbcPolicy.optimize_again = False + +# Configs for cloning net. +# MLPEBM.layers = 'ResNetPreActivation' +MLPEBM.layers = 'MLPDropout' +MLPEBM.width = 256 +MLPEBM.depth = 2 +MLPEBM.rate = 0.0 +MLPEBM.dense_layer_type = 'regular' +MLPEBM.activation = 'relu' +ImplicitBCAgent.compute_mse = True +ImplicitBCAgent.add_grad_penalty = False +ResNetLayer.normalizer = None diff --git a/ibc/configs/particle/mlp_ebm_langevin.gin b/ibc/configs/particle/mlp_ebm_langevin.gin index 9bdbdb2..5d3ce71 100644 --- a/ibc/configs/particle/mlp_ebm_langevin.gin +++ b/ibc/configs/particle/mlp_ebm_langevin.gin @@ -25,7 +25,7 @@ train_eval.eval_episodes = 20 train_eval.learning_rate = 1e-3 train_eval.goal_tolerance = 0.02 train_eval.seed = 0 -train_eval.sequence_length = 2 +train_eval.sequence_length = 1 train_eval.dataset_eval_fraction = 0.0 IbcPolicy.num_action_samples = 512 @@ -39,7 +39,7 @@ IbcPolicy.use_langevin = True IbcPolicy.optimize_again = False # Configs for cloning net. -MLPEBM.layers = 'ResNetPreActivation' +MLPEBM.layers = 'MLPDropout' MLPEBM.width = 256 MLPEBM.depth = 2 MLPEBM.rate = 0.0 @@ -54,7 +54,7 @@ grad_penalty.only_apply_final_grad_penalty = True # Note: # we actually get slightly better results with this as False, # however we get an OOM with it as False. TODO(peteflorence,oars): # investigate if can lessen memory. 
-ImplicitBCAgent.run_full_chain_under_gradient = True +ImplicitBCAgent.run_full_chain_under_gradient = False langevin_actions_given_obs.num_iterations = 100 ImplicitBCAgent.fraction_langevin_samples = 1.0 diff --git a/ibc/configs/particle/run_mlp_ebm_dfo.sh b/ibc/configs/particle/run_mlp_ebm_dfo.sh new file mode 100755 index 0000000..23063b5 --- /dev/null +++ b/ibc/configs/particle/run_mlp_ebm_dfo.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +## Use "N" of the N-d environment as the arg + +python3 ibc/ibc/train_eval.py -- \ + --alsologtostderr \ + --gin_file=ibc/ibc/configs/particle/mlp_ebm_dfo.gin \ + --task=PARTICLE \ + --tag=dfo \ + --add_time=True \ + --gin_bindings="train_eval.dataset_path='ibc/data/particle/$1d_oracle_particle*.tfrecord'" \ + --gin_bindings="ParticleEnv.n_dim=$1" \ + --video diff --git a/ibc/configs/particle/show_vid_example.sh b/ibc/configs/particle/show_vid_example.sh new file mode 100755 index 0000000..5591191 --- /dev/null +++ b/ibc/configs/particle/show_vid_example.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +set -eu + +python3 ibc/data/policy_eval.py -- \ + --alsologtostderr \ + --num_episodes=3 \ + --replicas=2 \ + --policy=particle_green_then_blue \ + --task=PARTICLE \ + --use_image_obs=False \ + --dataset_path=/tmp/ibc_tmp/data \ + --output_path=/tmp/ibc_tmp/vid \ + --video diff --git a/ibc/train_eval.py b/ibc/train_eval.py index 71e9225..9ca9374 100644 --- a/ibc/train_eval.py +++ b/ibc/train_eval.py @@ -44,6 +44,7 @@ from tf_agents.train.utils import strategy_utils from tf_agents.train.utils import train_utils from tf_agents.utils import common +import wandb flags.DEFINE_string('tag', None, 'Tag for the experiment. Appended to the root_dir.') @@ -377,6 +378,15 @@ def main(_): # eval working. Remove it once we do. skip_unknown=True) + wandb.init( + project="google-research-ibc", + sync_tensorboard=True, + ) + # Print operative gin config to stdout so wandb can intercept. + # (it'd be nice for gin to provide a flat/nested dictionary of values so they + # can be used via wandb's aggregation...) + print(gin.config.config_str()) + # For TPU, FLAGS.tpu will be set with a TPU address and FLAGS.use_gpu # will be False. # For GPU, FLAGS.tpu will be None and FLAGS.use_gpu will be True. diff --git a/ibc/utils/debug.py b/ibc/utils/debug.py new file mode 100644 index 0000000..a88d818 --- /dev/null +++ b/ibc/utils/debug.py @@ -0,0 +1,39 @@ +# From anzu stuff. + +from contextlib import contextmanager +import pdb +import sys +import traceback + + +@contextmanager +def launch_pdb_on_exception(): + """ + Provides a context that will launch interactive pdb console automatically + if an exception is raised. + + Example usage with @iex decorator below: + + @iex + def my_bad_function(): + x = 1 + assert False + + my_bad_function() + # Should bring up debugger at `assert` statement. + """ + # Adapted from: + # https://github.com/gotcha/ipdb/blob/fc83b4f5f/ipdb/__main__.py#L219-L232 + + try: + yield + except Exception: + traceback.print_exc() + _, _, tb = sys.exc_info() + pdb.post_mortem(tb) + # Resume original execution. + raise + + +# See docs for `launch_pdb_on_exception()`. +iex = launch_pdb_on_exception() diff --git a/ibc/utils/make_video.py b/ibc/utils/make_video.py index 331cd57..d6eb6f7 100644 --- a/ibc/utils/make_video.py +++ b/ibc/utils/make_video.py @@ -25,19 +25,22 @@ from tf_agents.drivers import py_driver +def make_video_env(env, video_path): + # Use default control freq for d4rl envs, which don't have a control_frequency + # attr. 
+ control_frequency = getattr(env, 'control_frequency', 30) + video_env = mp4_video_wrapper.Mp4VideoWrapper( + env, control_frequency, frame_interval=1, video_filepath=video_path) + video_env.batch_size = getattr(env, "batch_size", 1) + return video_env + + def make_video(agent, env, root_dir, step, strategy): """Creates a video of a single rollout from the current policy.""" policy = strategy_policy.StrategyPyTFEagerPolicy( agent.policy, strategy=strategy) video_path = os.path.join(root_dir, 'videos', 'ttl=7d', 'vid_%d.mp4' % step) - if not hasattr(env, 'control_frequency'): - # Use this control freq for d4rl envs, which don't have a control_frequency - # attr. - control_frequency = 30 - else: - control_frequency = env.control_frequency - video_env = mp4_video_wrapper.Mp4VideoWrapper( - env, control_frequency, frame_interval=1, video_filepath=video_path) + video_env = make_video_env(env, video_path) driver = py_driver.PyDriver(video_env, policy, observers=[], max_episodes=1) time_step = video_env.reset() initial_policy_state = policy.get_initial_state(1) diff --git a/requirements.freeze.txt b/requirements.freeze.txt new file mode 100644 index 0000000..b20ea55 --- /dev/null +++ b/requirements.freeze.txt @@ -0,0 +1,85 @@ +absl-py==0.12.0 +astunparse==1.6.3 +backcall==0.2.0 +cachetools==4.2.4 +certifi==2021.10.8 +charset-normalizer==2.0.7 +clang==5.0 +click==8.0.3 +cloudpickle==2.0.0 +configparser==5.2.0 +cycler==0.11.0 +decorator==5.1.0 +dm-tree==0.1.6 +docker-pycreds==0.4.0 +flatbuffers==1.12 +gast==0.4.0 +gin-config==0.4.0 +gitdb==4.0.9 +GitPython==3.1.24 +google-auth==2.3.3 +google-auth-oauthlib==0.4.6 +google-pasta==0.2.0 +grpcio==1.42.0 +gym==0.21.0 +h5py==3.1.0 +idna==3.3 +importlib-metadata==4.8.2 +ipython==7.29.0 +jedi==0.18.1 +keras==2.6.0 +Keras-Preprocessing==1.1.2 +kiwisolver==1.3.2 +Markdown==3.3.6 +matplotlib==3.4.3 +matplotlib-inline==0.1.3 +mediapy==1.0.3 +numpy==1.19.5 +oauthlib==3.1.1 +opencv-python==4.5.3.56 +opt-einsum==3.3.0 +parso==0.8.2 +pathtools==0.1.2 +pexpect==4.8.0 +pickleshare==0.7.5 +Pillow==8.4.0 +pkg_resources==0.0.0 +promise==2.3 +prompt-toolkit==3.0.22 +protobuf==3.19.1 +psutil==5.8.0 +ptyprocess==0.7.0 +pyasn1==0.4.8 +pyasn1-modules==0.2.8 +pybullet==3.1.6 +Pygments==2.10.0 +pyparsing==3.0.6 +python-dateutil==2.8.2 +PyYAML==6.0 +requests==2.26.0 +requests-oauthlib==1.3.0 +rsa==4.7.2 +scipy==1.7.1 +sentry-sdk==1.5.0 +shortuuid==1.0.8 +six==1.15.0 +smmap==5.0.0 +subprocess32==3.5.4 +tensorboard==2.7.0 +tensorboard-data-server==0.6.1 +tensorboard-plugin-wit==1.8.0 +tensorflow==2.6.0 +tensorflow-estimator==2.7.0 +tensorflow-probability==0.14.1 +termcolor==1.1.0 +tf-agents==0.11.0rc0 +tqdm==4.62.2 +traitlets==5.1.1 +typing-extensions==3.7.4.3 +urllib3==1.26.7 +wandb==0.12.7 +wcwidth==0.2.5 +Werkzeug==2.0.2 +wrapt==1.12.1 +yaspin==2.1.0 +zipp==3.6.0