Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 24 additions & 5 deletions acme/wrappers/frame_stacking.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,25 @@

from acme import types
from acme.wrappers import base
import enum
import dm_env
from dm_env import specs as dm_env_specs
import numpy as np
import tree


class FillBehavior(enum.Enum):
"""Class to enumerate available options for frame-stacking behavior at episode starts."""
ZEROS = 'zeros'
FIRST = 'first'


class FrameStackingWrapper(base.EnvironmentWrapper):
"""Wrapper that stacks observations along a new final axis."""

def __init__(self, environment: dm_env.Environment, num_frames: int = 4,
flatten: bool = False):
flatten: bool = False,
fill_behavior: FillBehavior = FillBehavior.ZEROS):
"""Initializes a new FrameStackingWrapper.

Args:
Expand All @@ -39,7 +47,8 @@ def __init__(self, environment: dm_env.Environment, num_frames: int = 4,
self._environment = environment
original_spec = self._environment.observation_spec()
self._stackers = tree.map_structure(
lambda _: FrameStacker(num_frames=num_frames, flatten=flatten),
lambda _: FrameStacker(
num_frames=num_frames, flatten=flatten, fill_behavior=fill_behavior),
self._environment.observation_spec())
self._observation_spec = tree.map_structure(
lambda stacker, spec: stacker.update_spec(spec),
Expand All @@ -65,9 +74,16 @@ def observation_spec(self) -> types.NestedSpec:
class FrameStacker:
"""Simple class for frame-stacking observations."""

def __init__(self, num_frames: int, flatten: bool = False):
def __init__(self,
num_frames: int,
flatten: bool = False,
fill_behavior: FillBehavior = FillBehavior.ZEROS
):
self._num_frames = num_frames
self._flatten = flatten
if not isinstance(fill_behavior, FillBehavior):
raise TypeError("Expect fill_behavior to be an FillBehavior enum")
self._fill_behavior = fill_behavior
self.reset()

@property
Expand All @@ -80,8 +96,11 @@ def reset(self):
def step(self, frame: np.ndarray) -> np.ndarray:
"""Append frame to stack and return the stack."""
if not self._stack:
# Fill stack with blank frames if empty.
self._stack.extend([np.zeros_like(frame)] * (self._num_frames - 1))
if self._fill_behavior == FillBehavior.ZEROS:
# Fill stack with blank frames if empty.
self._stack.extend([np.zeros_like(frame)] * (self._num_frames - 1))
elif self._fill_behavior == FillBehavior.FIRST:
self._stack.extend([frame] * (self._num_frames - 1))
self._stack.append(frame)
stacked_frames = np.stack(self._stack, axis=-1)

Expand Down
14 changes: 14 additions & 0 deletions acme/wrappers/frame_stacking_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from absl.testing import absltest

from acme.wrappers import frame_stacking

class FakeNonZeroObservationEnvironment(fakes.ContinuousEnvironment):
"""Fake environment with non-zero observations."""
Expand Down Expand Up @@ -76,6 +77,19 @@ def test_second_reset(self):
timestep = env.reset()
self.assertTrue(np.all(timestep.observation[..., 0] == 0))

def test_fill_behavior(self):
original_env = FakeNonZeroObservationEnvironment()
env = wrappers.FrameStackingWrapper(
original_env, 2,
fill_behavior=frame_stacking.FillBehavior.FIRST,
)
action_spec = env.action_spec()

env.reset()
env.step(action_spec.generate_value())
timestep = env.reset()
self.assertTrue(np.all(timestep.observation == 1))


if __name__ == '__main__':
absltest.main()