Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion acme/agents/tf/mcts/acting.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from acme.agents.tf.mcts import models
from acme.agents.tf.mcts import search
from acme.agents.tf.mcts import types
from acme.tf import utils as tf2_utils
from acme.tf import variable_utils as tf2_variable_utils

import dm_env
Expand Down Expand Up @@ -66,7 +67,9 @@ def __init__(
def _forward(
self, observation: types.Observation) -> Tuple[types.Probs, types.Value]:
"""Performs a forward pass of the policy-value network."""
logits, value = self._network(tf.expand_dims(observation, axis=0))
# Use tree.map_structure to support nested observation structures.
batched_observation = tf2_utils.add_batch_dim(observation)
logits, value = self._network(batched_observation)

# Convert to numpy & take softmax.
logits = logits.numpy().squeeze(axis=0)
Expand Down
7 changes: 4 additions & 3 deletions acme/agents/tf/mcts/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,17 @@

"""Type aliases and assumptions that are specific to the MCTS agent."""

from typing import Callable, Tuple, Union
from typing import Any, Callable, Tuple, Union
import numpy as np

# pylint: disable=invalid-name

# Assumption: actions are scalar and discrete (integral).
Action = Union[int, np.int32, np.int64]

# Assumption: observations are array-like.
Observation = np.ndarray
# Observations can be array-like or nested structures (e.g., dicts, tuples).
# This allows MCTS to work with environments that have complex observation spaces.
Observation = Any

# Assumption: rewards and discounts are scalar.
Reward = Union[float, np.float32, np.float64]
Expand Down