diff --git a/acme/wrappers/frame_stacking.py b/acme/wrappers/frame_stacking.py
index ce06dd71d7..1312788af7 100644
--- a/acme/wrappers/frame_stacking.py
+++ b/acme/wrappers/frame_stacking.py
@@ -18,17 +18,29 @@
 
 from acme import types
 from acme.wrappers import base
+import enum
 import dm_env
 from dm_env import specs as dm_env_specs
 import numpy as np
 import tree
 
 
+class FillBehavior(enum.Enum):
+  """Options for how an empty frame stack is filled at episode starts.
+
+  ZEROS: fill with zero-valued frames shaped like the first frame.
+  FIRST: fill by repeating the first frame.
+  """
+  ZEROS = 'zeros'
+  FIRST = 'first'
+
+
 class FrameStackingWrapper(base.EnvironmentWrapper):
   """Wrapper that stacks observations along a new final axis."""
 
   def __init__(self, environment: dm_env.Environment, num_frames: int = 4,
-               flatten: bool = False):
+               flatten: bool = False,
+               fill_behavior: FillBehavior = FillBehavior.ZEROS):
     """Initializes a new FrameStackingWrapper.
 
     Args:
@@ -39,7 +51,8 @@ def __init__(self, environment: dm_env.Environment, num_frames: int = 4,
     self._environment = environment
     original_spec = self._environment.observation_spec()
     self._stackers = tree.map_structure(
-        lambda _: FrameStacker(num_frames=num_frames, flatten=flatten),
+        lambda _: FrameStacker(
+            num_frames=num_frames, flatten=flatten, fill_behavior=fill_behavior),
         self._environment.observation_spec())
     self._observation_spec = tree.map_structure(
         lambda stacker, spec: stacker.update_spec(spec),
@@ -65,9 +78,15 @@ def observation_spec(self) -> types.NestedSpec:
 class FrameStacker:
   """Simple class for frame-stacking observations."""
 
-  def __init__(self, num_frames: int, flatten: bool = False):
+  def __init__(self,
+               num_frames: int,
+               flatten: bool = False,
+               fill_behavior: FillBehavior = FillBehavior.ZEROS):
     self._num_frames = num_frames
     self._flatten = flatten
+    if not isinstance(fill_behavior, FillBehavior):
+      raise TypeError("Expected fill_behavior to be a FillBehavior enum")
+    self._fill_behavior = fill_behavior
     self.reset()
 
   @property
@@ -80,8 +99,12 @@ def reset(self):
   def step(self, frame: np.ndarray) -> np.ndarray:
     """Append frame to stack and return the stack."""
     if not self._stack:
-      # Fill stack with blank frames if empty.
-      self._stack.extend([np.zeros_like(frame)] * (self._num_frames - 1))
+      if self._fill_behavior is FillBehavior.ZEROS:
+        # Fill stack with blank frames if empty.
+        self._stack.extend([np.zeros_like(frame)] * (self._num_frames - 1))
+      elif self._fill_behavior is FillBehavior.FIRST:
+        # Fill stack by repeating the first frame if empty.
+        self._stack.extend([frame] * (self._num_frames - 1))
     self._stack.append(frame)
     stacked_frames = np.stack(self._stack, axis=-1)
 
diff --git a/acme/wrappers/frame_stacking_test.py b/acme/wrappers/frame_stacking_test.py
index ff21f47e2e..ff25ae9ba8 100644
--- a/acme/wrappers/frame_stacking_test.py
+++ b/acme/wrappers/frame_stacking_test.py
@@ -21,6 +21,7 @@
 
 from absl.testing import absltest
 
+from acme.wrappers import frame_stacking
 
 class FakeNonZeroObservationEnvironment(fakes.ContinuousEnvironment):
   """Fake environment with non-zero observations."""
@@ -76,6 +77,20 @@ def test_second_reset(self):
     timestep = env.reset()
     self.assertTrue(np.all(timestep.observation[..., 0] == 0))
 
+  def test_fill_behavior(self):
+    """With FIRST fill, the stacked observation after a reset is all ones."""
+    original_env = FakeNonZeroObservationEnvironment()
+    env = wrappers.FrameStackingWrapper(
+        original_env, 2,
+        fill_behavior=frame_stacking.FillBehavior.FIRST,
+    )
+    action_spec = env.action_spec()
+
+    env.reset()
+    env.step(action_spec.generate_value())
+    timestep = env.reset()
+    self.assertTrue(np.all(timestep.observation == 1))
+
 
 if __name__ == '__main__':
   absltest.main()