diff --git a/acme/agents/tf/mcts/acting.py b/acme/agents/tf/mcts/acting.py
index 887d7c25e5..3a8673efbf 100644
--- a/acme/agents/tf/mcts/acting.py
+++ b/acme/agents/tf/mcts/acting.py
@@ -22,6 +22,7 @@
 from acme.agents.tf.mcts import models
 from acme.agents.tf.mcts import search
 from acme.agents.tf.mcts import types
+from acme.tf import utils as tf2_utils
 from acme.tf import variable_utils as tf2_variable_utils
 
 import dm_env
@@ -66,7 +67,9 @@ def __init__(
   def _forward(
       self, observation: types.Observation) -> Tuple[types.Probs, types.Value]:
     """Performs a forward pass of the policy-value network."""
-    logits, value = self._network(tf.expand_dims(observation, axis=0))
+    # Add a batch dimension, handling nested observation structures.
+    batched_observation = tf2_utils.add_batch_dim(observation)
+    logits, value = self._network(batched_observation)
 
     # Convert to numpy & take softmax.
     logits = logits.numpy().squeeze(axis=0)
diff --git a/acme/agents/tf/mcts/types.py b/acme/agents/tf/mcts/types.py
index 93d8d8cd11..3dcaf7479c 100644
--- a/acme/agents/tf/mcts/types.py
+++ b/acme/agents/tf/mcts/types.py
@@ -14,7 +14,7 @@
 """Type aliases and assumptions that are specific to the MCTS agent."""
 
-from typing import Callable, Tuple, Union
+from typing import Any, Callable, Tuple, Union
 
 import numpy as np
 
 # pylint: disable=invalid-name
@@ -22,8 +22,9 @@
 # Assumption: actions are scalar and discrete (integral).
 Action = Union[int, np.int32, np.int64]
 
-# Assumption: observations are array-like.
-Observation = np.ndarray
+# Observations can be array-like or nested structures (e.g., dicts, tuples).
+# This allows MCTS to work with complex observation spaces.
+Observation = Any
 
 # Assumption: rewards and discounts are scalar.
 Reward = Union[float, np.float32, np.float64]