From 892402ac7f2967c05983a703e44789cf5e71e02e Mon Sep 17 00:00:00 2001 From: natinew77-creator Date: Sun, 7 Dec 2025 19:07:51 -0500 Subject: [PATCH] Support nested observation structures in MCTS agent The MCTSActor._forward() method previously hard-coded tf.expand_dims() directly on the observation, which only works for array-like observations (np.ndarray). This prevented using nested structures (dicts, tuples) as observations. Changes: - Modified acting.py to use tf2_utils.add_batch_dim() which internally uses tree.map_structure() to apply tf.expand_dims to each leaf of the observation structure - Updated types.py Observation type from np.ndarray to Any to allow nested structures This follows the pattern used elsewhere in the codebase (see tf/utils.py) and allows MCTS to work with environments that have complex observation spaces. Fixes #341 --- acme/agents/tf/mcts/acting.py | 5 ++++- acme/agents/tf/mcts/types.py | 7 ++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/acme/agents/tf/mcts/acting.py b/acme/agents/tf/mcts/acting.py index 887d7c25e5..3a8673efbf 100644 --- a/acme/agents/tf/mcts/acting.py +++ b/acme/agents/tf/mcts/acting.py @@ -22,6 +22,7 @@ from acme.agents.tf.mcts import models from acme.agents.tf.mcts import search from acme.agents.tf.mcts import types +from acme.tf import utils as tf2_utils from acme.tf import variable_utils as tf2_variable_utils import dm_env @@ -66,7 +67,9 @@ def __init__( def _forward( self, observation: types.Observation) -> Tuple[types.Probs, types.Value]: """Performs a forward pass of the policy-value network.""" - logits, value = self._network(tf.expand_dims(observation, axis=0)) + # Use tree.map_structure to support nested observation structures. + batched_observation = tf2_utils.add_batch_dim(observation) + logits, value = self._network(batched_observation) # Convert to numpy & take softmax. logits = logits.numpy().squeeze(axis=0) diff --git a/acme/agents/tf/mcts/types.py b/acme/agents/tf/mcts/types.py index 93d8d8cd11..3dcaf7479c 100644 --- a/acme/agents/tf/mcts/types.py +++ b/acme/agents/tf/mcts/types.py @@ -14,7 +14,7 @@ """Type aliases and assumptions that are specific to the MCTS agent.""" -from typing import Callable, Tuple, Union +from typing import Any, Callable, Tuple, Union import numpy as np # pylint: disable=invalid-name @@ -22,8 +22,9 @@ # Assumption: actions are scalar and discrete (integral). Action = Union[int, np.int32, np.int64] -# Assumption: observations are array-like. -Observation = np.ndarray +# Observations can be array-like or nested structures (e.g., dicts, tuples). +# This allows MCTS to work with environments that have complex observation spaces. +Observation = Any # Assumption: rewards and discounts are scalar. Reward = Union[float, np.float32, np.float64]