From 701da103e8b9c12922540037dabb057644081fb2 Mon Sep 17 00:00:00 2001
From: natinew77-creator
Date: Fri, 19 Dec 2025 21:19:40 -0500
Subject: [PATCH] Fix dependencies, improve EpisodeAdder docs, and enable decaying epsilon

---
 acme/adders/reverb/episode.py     | 21 ++++++++++++++++++++-
 acme/tf/networks/legal_actions.py |  8 +++++---
 setup.py                          | 18 +++++++++---------
 3 files changed, 34 insertions(+), 13 deletions(-)

diff --git a/acme/adders/reverb/episode.py b/acme/adders/reverb/episode.py
index 2b5f644857..6da5f1984a 100644
--- a/acme/adders/reverb/episode.py
+++ b/acme/adders/reverb/episode.py
@@ -34,7 +34,26 @@
 
 
 class EpisodeAdder(base.ReverbAdder):
-  """Adder which adds entire episodes as trajectories."""
+  """Adder which adds entire episodes as trajectories.
+
+  This adder accumulates all steps of an episode and inserts them as a single
+  trajectory item into Reverb at the end of the episode. It is useful for
+  algorithms that require full episodes (e.g., for offline learning or MCTS).
+
+  Args:
+    client: The Reverb client to use for data insertion.
+    max_sequence_length: The maximum length of an episode. Attempting to add
+      steps beyond this length raises a ValueError. If padding_fn is provided,
+      episodes shorter than this will be padded to this length.
+    delta_encoded: Whether to use delta encoding for the trajectory.
+    priority_fns: A mapping from table names to priority functions.
+    max_in_flight_items: The maximum number of items allowed to be in flight
+      (being sent to Reverb) at the same time.
+    padding_fn: An optional callable that takes a shape and dtype and returns
+      a zero-filled (or otherwise equivalent 'empty') array of that shape and
+      dtype. If provided, episodes shorter than max_sequence_length will be
+      padded.
+  """
 
   def __init__(
       self,
diff --git a/acme/tf/networks/legal_actions.py b/acme/tf/networks/legal_actions.py
index 84273905bf..72f0cb2e01 100644
--- a/acme/tf/networks/legal_actions.py
+++ b/acme/tf/networks/legal_actions.py
@@ -65,21 +65,20 @@ def __call__(self, inputs: open_spiel_wrapper.OLT) -> tf.Tensor:
     return outputs
 
-# FIXME: Add functionality to support decaying epsilon parameter.
 # FIXME: This is a modified version of trfl's epsilon_greedy() which
 # incorporates code from the bug fix described here
 # https://github.com/deepmind/trfl/pull/28
 class EpsilonGreedy(snt.Module):
   """Computes an epsilon-greedy distribution over actions.
 
   This policy does the following:
 
   - With probability 1 - epsilon, take the action corresponding to the highest
   action value, breaking ties uniformly at random.
   - With probability epsilon, take an action uniformly at random.
   """
 
   def __init__(self,
-               epsilon: Union[tf.Tensor, float],
+               epsilon: Union[tf.Tensor, float, tf.Variable],
                threshold: float,
                name: str = 'EpsilonGreedy'):
     """Initialize the policy.
@@ -95,7 +94,10 @@ def __init__(self,
       policy.
     """
     super().__init__(name=name)
-    self._epsilon = tf.Variable(epsilon, trainable=False)
+    if isinstance(epsilon, tf.Variable):
+      self._epsilon = epsilon
+    else:
+      self._epsilon = tf.Variable(epsilon, trainable=False)
     self._threshold = threshold
 
   def __call__(self, action_values: tf.Tensor) -> tfd.Categorical:
diff --git a/setup.py b/setup.py
index bd46a62cb2..7836f5f850 100755
--- a/setup.py
+++ b/setup.py
@@ -37,11 +37,11 @@
 # sure this constraint is upheld.
tensorflow = [ - 'tensorflow==2.8.0', - 'tensorflow_probability==0.15.0', - 'tensorflow_datasets==4.6.0', - 'dm-reverb==0.7.2', - 'dm-launchpad==0.5.2', + 'tensorflow>=2.8.0', + 'tensorflow_probability>=0.15.0', + 'tensorflow_datasets>=4.6.0', + 'dm-reverb>=0.7.2', + 'dm-launchpad>=0.5.2', ] core_requirements = [ @@ -54,8 +54,8 @@ ] jax_requirements = [ - 'jax==0.4.3', - 'jaxlib==0.4.3', + 'jax>=0.4.3', + 'jaxlib>=0.4.3', 'chex', 'dm-haiku', 'flax', @@ -77,9 +77,9 @@ 'atari-py', 'bsuite', 'dm-control', - 'gym==0.25.0', + 'gym>=0.25.0,<0.26.0', 'gym[atari]', - 'pygame==2.1.0', + 'pygame>=2.1.0', 'rlds', ]
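
Reviewer note (sketch, not part of the patch): minimal usage of the
EpisodeAdder documented above, assuming a running Reverb server reachable
through `client` and a dm_env-style environment. The function name
write_episode and the max_sequence_length value are illustrative:

import dm_env
import reverb
from acme.adders import reverb as adders

def write_episode(client: reverb.Client,
                  environment: dm_env.Environment,
                  policy) -> None:
  """Runs one episode and inserts it into Reverb as a single trajectory."""
  adder = adders.EpisodeAdder(client=client, max_sequence_length=1000)
  timestep = environment.reset()
  adder.add_first(timestep)
  while not timestep.last():
    action = policy(timestep.observation)
    timestep = environment.step(action)
    # The accumulated episode is written to Reverb on the terminal step.
    adder.add(action, timestep)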
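
Reviewer note (sketch, not part of the patch): with epsilon accepted as a
tf.Variable, exploration can now be decayed by the caller without rebuilding
the policy. A minimal sketch assuming the EpsilonGreedy constructor as
modified above; the threshold value and the linear decay schedule are
illustrative assumptions:

import tensorflow as tf
from acme.tf.networks import legal_actions

# Caller-owned variable: the policy reads it, the caller decays it.
epsilon = tf.Variable(1.0, trainable=False)
policy = legal_actions.EpsilonGreedy(epsilon=epsilon, threshold=-1e8)

# Action values with the last action masked to a large negative value,
# i.e. below the (illustrative) legality threshold.
q_values = tf.constant([[0.1, 0.9, -1e9]])

for step in range(100):
  action = policy(q_values).sample()
  # Hypothetical linear decay from 1.0 toward a floor of 0.1.
  epsilon.assign(tf.maximum(0.1, 1.0 - 0.009 * step))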