diff --git a/acme/adders/reverb/episode.py b/acme/adders/reverb/episode.py
index 2b5f644857..6da5f1984a 100644
--- a/acme/adders/reverb/episode.py
+++ b/acme/adders/reverb/episode.py
@@ -34,7 +34,26 @@ class EpisodeAdder(base.ReverbAdder):
-  """Adder which adds entire episodes as trajectories."""
+  """Adder which adds entire episodes as trajectories.
+
+  This adder accumulates all steps of an episode and inserts them as a single
+  trajectory item into Reverb at the end of the episode. It is useful for
+  algorithms that require full episodes (e.g., for offline learning or MCTS).
+
+  Args:
+    client: The Reverb client to use for data insertion.
+    max_sequence_length: The maximum length of an episode. Episodes longer
+      than this will raise a ValueError. If padding_fn is provided, episodes
+      shorter than this will be padded to this length.
+    delta_encoded: Whether to use delta encoding for the trajectory.
+    priority_fns: A mapping from table names to priority functions.
+    max_in_flight_items: The maximum number of items allowed to be in flight
+      (being sent to Reverb) at the same time.
+    padding_fn: An optional callable that takes a shape and dtype and returns
+      a zero-filled (or otherwise equivalent 'empty') array of that shape and
+      dtype. If provided, episodes shorter than max_sequence_length will be
+      padded.
+  """
 
   def __init__(
       self,
diff --git a/acme/tf/losses/distributional.py b/acme/tf/losses/distributional.py
index 54c0560c92..124470e378 100644
--- a/acme/tf/losses/distributional.py
+++ b/acme/tf/losses/distributional.py
@@ -18,10 +18,19 @@
 import tensorflow as tf
 
 
+def _validate_distribution(dist, name):
+  if not hasattr(dist, 'values') or not hasattr(dist, 'logits'):
+    raise TypeError(
+        f"Argument '{name}' must be a distribution with 'values' and 'logits' "
+        f"properties (e.g. from DiscreteValuedHead), but got {type(dist)}.")
+
+
 def categorical(q_tm1: networks.DiscreteValuedDistribution, r_t: tf.Tensor,
                 d_t: tf.Tensor,
                 q_t: networks.DiscreteValuedDistribution) -> tf.Tensor:
   """Implements the Categorical Distributional TD(0)-learning loss."""
+  _validate_distribution(q_tm1, 'q_tm1')
+  _validate_distribution(q_t, 'q_t')
   z_t = tf.reshape(r_t, (-1, 1)) + tf.reshape(d_t, (-1, 1)) * q_t.values
   p_t = tf.nn.softmax(q_t.logits)
diff --git a/acme/tf/networks/legal_actions.py b/acme/tf/networks/legal_actions.py
index 84273905bf..72f0cb2e01 100644
--- a/acme/tf/networks/legal_actions.py
+++ b/acme/tf/networks/legal_actions.py
@@ -65,21 +65,8 @@ def __call__(self, inputs: open_spiel_wrapper.OLT) -> tf.Tensor:
     return outputs
 
-# FIXME: Add functionality to support decaying epsilon parameter.
-# FIXME: This is a modified version of trfl's epsilon_greedy() which
-# incorporates code from the bug fix described here
-# https://github.com/deepmind/trfl/pull/28
-class EpsilonGreedy(snt.Module):
-  """Computes an epsilon-greedy distribution over actions.
-
-  This policy does the following:
-
-  - With probability 1 - epsilon, take the action corresponding to the highest
-  action value, breaking ties uniformly at random.
-  - With probability epsilon, take an action uniformly at random.
-  """
-
   def __init__(self,
-               epsilon: Union[tf.Tensor, float],
+               epsilon: Union[tf.Tensor, float, tf.Variable],
                threshold: float,
                name: str = 'EpsilonGreedy'):
     """Initialize the policy.
@@ -95,7 +82,10 @@ def __init__(self,
         policy.
     """
     super().__init__(name=name)
-    self._epsilon = tf.Variable(epsilon, trainable=False)
+    if isinstance(epsilon, tf.Variable):
+      self._epsilon = epsilon
+    else:
+      self._epsilon = tf.Variable(epsilon, trainable=False)
     self._threshold = threshold
 
   def __call__(self, action_values: tf.Tensor) -> tfd.Categorical:
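
# Usage sketch (not part of the diff): with the relaxed `epsilon` annotation and
# the isinstance check above, a caller can hand the policy a shared tf.Variable
# and decay it externally during training. The threshold value and the linear
# schedule below are purely illustrative, and this assumes EpsilonGreedy remains
# importable from acme.tf.networks.legal_actions.
import tensorflow as tf

from acme.tf.networks.legal_actions import EpsilonGreedy

epsilon = tf.Variable(1.0, trainable=False, dtype=tf.float32)
policy = EpsilonGreedy(epsilon=epsilon, threshold=0.0)

def anneal_epsilon(step: int,
                   initial: float = 1.0,
                   final: float = 0.05,
                   decay_steps: int = 100_000) -> None:
  """Linearly anneals the shared epsilon variable in place."""
  fraction = min(step / decay_steps, 1.0)
  epsilon.assign(initial + fraction * (final - initial))

# Because the module stores the passed tf.Variable itself (not a copy),
# assignments made by anneal_epsilon are seen by the policy on its next call.
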
""" super().__init__(name=name) - self._epsilon = tf.Variable(epsilon, trainable=False) + if isinstance(epsilon, tf.Variable): + self._epsilon = epsilon + else: + self._epsilon = tf.Variable(epsilon, trainable=False) self._threshold = threshold def __call__(self, action_values: tf.Tensor) -> tfd.Categorical: diff --git a/setup.py b/setup.py index bd46a62cb2..7836f5f850 100755 --- a/setup.py +++ b/setup.py @@ -37,11 +37,11 @@ # sure this constraint is upheld. tensorflow = [ - 'tensorflow==2.8.0', - 'tensorflow_probability==0.15.0', - 'tensorflow_datasets==4.6.0', - 'dm-reverb==0.7.2', - 'dm-launchpad==0.5.2', + 'tensorflow>=2.8.0', + 'tensorflow_probability>=0.15.0', + 'tensorflow_datasets>=4.6.0', + 'dm-reverb>=0.7.2', + 'dm-launchpad>=0.5.2', ] core_requirements = [ @@ -54,8 +54,8 @@ ] jax_requirements = [ - 'jax==0.4.3', - 'jaxlib==0.4.3', + 'jax>=0.4.3', + 'jaxlib>=0.4.3', 'chex', 'dm-haiku', 'flax', @@ -77,9 +77,9 @@ 'atari-py', 'bsuite', 'dm-control', - 'gym==0.25.0', + 'gym>=0.25.0,<0.26.0', 'gym[atari]', - 'pygame==2.1.0', + 'pygame>=2.1.0', 'rlds', ]