from typing import Callable, Generator, Optional, Tuple, Union, NamedTuple

import numpy as np
import torch as th
import jax.numpy as jnp
from gymnasium import spaces
# TODO : see later how to enable DictRolloutBuffer
from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
from stable_baselines3.common.vec_env import VecNormalize
from stable_baselines3.common.env_util import make_vec_env


# TODO : add type aliases for the NamedTuple
class LSTMStates(NamedTuple):
    """LSTM states for the actor (pi) and the critic (vf).

    Each field is a (hidden_state, cell_state) tuple of arrays.
    """

    pi: Tuple
    vf: Tuple


# TODO : replaced th.Tensor with jnp.ndarray but might not be true
# (some are still torch Tensors because they are used in other sb3 fns)
class RecurrentRolloutBufferSamples(NamedTuple):
    """One minibatch of rollout data, including the recurrent extras
    (dones and LSTM states) needed to re-run actor and critic."""

    observations: jnp.ndarray
    actions: jnp.ndarray
    old_values: jnp.ndarray
    old_log_prob: jnp.ndarray
    advantages: jnp.ndarray
    returns: jnp.ndarray
    dones: jnp.ndarray
    lstm_states: LSTMStates


class RecurrentRolloutBuffer(RolloutBuffer):
    """
    Rollout buffer that also stores the LSTM cell and hidden states.

    :param buffer_size: Max number of element in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param lstm_state_buffer_shape: Shape of the buffer that will collect lstm states
        (n_steps, lstm.num_layers, n_envs, lstm.hidden_size)
    :param device: PyTorch device
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
        Equivalent to classic advantage when set to 1.
    :param gamma: Discount factor
    :param n_envs: Number of parallel environments
    """

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        # renamed this because hidden_state_shape was confusing
        lstm_state_buffer_shape: Tuple[int, int, int],
        device: Union[th.device, str] = "auto",
        gae_lambda: float = 1,
        gamma: float = 0.99,
        n_envs: int = 1,
    ):
        # Must be set before super().__init__, which calls reset() and
        # therefore reads self.hidden_state_shape.
        self.hidden_state_shape = lstm_state_buffer_shape
        self.seq_start_indices, self.seq_end_indices = None, None
        super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs)

    # TODO : remove dones because episode starts are already in the buffer
    def reset(self) -> None:
        """Reset the buffer and allocate the extra dones/LSTM state storage."""
        super().reset()
        self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.hidden_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32)
        self.cell_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32)
        self.hidden_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32)
        self.cell_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32)

    # TODO : remove dones because episode starts are already in the buffer
    def add(self, *args, dones, lstm_states, **kwargs) -> None:
        """
        Store one transition plus the recurrent extras.

        :param dones: Episode termination flags, one per environment
        :param lstm_states: ``LSTMStates`` named tuple; each of ``pi``/``vf``
            is a (hidden_state, cell_state) pair
        """
        self.dones[self.pos] = np.array(dones)
        self.hidden_states_pi[self.pos] = np.array(lstm_states.pi[0])
        self.cell_states_pi[self.pos] = np.array(lstm_states.pi[1])
        self.hidden_states_vf[self.pos] = np.array(lstm_states.vf[0])
        self.cell_states_vf[self.pos] = np.array(lstm_states.vf[1])

        super().add(*args, **kwargs)

    def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBufferSamples, None, None]:
        """Yield minibatches of size ``batch_size`` (one full batch if ``None``)."""
        assert self.full, "Rollout buffer must be full before sampling from it"

        # Prepare the data
        if not self.generator_ready:
            # hidden_state_shape = (self.n_steps, lstm.num_layers, self.n_envs, lstm.hidden_size)
            # swap first to (self.n_steps, self.n_envs, lstm.num_layers, lstm.hidden_size)
            # NOTE(review): with a 3D state shape (n_steps, n_envs, hidden_size), as used in the
            # notebook demo, this swap exchanges the env and hidden axes instead — confirm the
            # intended layout before relying on the flattened states.
            for tensor in ["hidden_states_pi", "cell_states_pi", "hidden_states_vf", "cell_states_vf"]:
                self.__dict__[tensor] = self.__dict__[tensor].swapaxes(1, 2)

            # Flatten but keep the per-env sequence order:
            # 1. (n_steps, n_envs, *tensor_shape) -> (n_envs, n_steps, *tensor_shape)
            # 2. (n_envs, n_steps, *tensor_shape) -> (n_envs * n_steps, *tensor_shape)
            for tensor in [
                "observations",
                "actions",
                "values",
                "log_probs",
                "advantages",
                "returns",
                "dones",
                "hidden_states_pi",
                "cell_states_pi",
                "hidden_states_vf",
                "cell_states_vf",
                "episode_starts",
            ]:
                self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
            self.generator_ready = True

        # Return everything, don't create minibatches
        if batch_size is None:
            batch_size = self.buffer_size * self.n_envs

        # TODO : use the indices to conserve temporal order in the batch data during updates;
        # the easiest way is probably to ensure n_steps is a multiple of batch_size.
        indices = np.arange(self.buffer_size * self.n_envs)

        start_idx = 0
        while start_idx < self.buffer_size * self.n_envs:
            yield self._get_samples(indices[start_idx : start_idx + batch_size])
            start_idx += batch_size

    def _get_samples(
        self,
        batch_inds: np.ndarray,
        env: Optional[VecNormalize] = None,
    ) -> RecurrentRolloutBufferSamples:
        """Build one minibatch of samples; ``env`` is currently unused."""
        # Convert each state array individually and rebuild the named tuple.
        # Passing the whole LSTMStates through to_torch (as map(self.to_torch, ...)
        # would do) stacks it into a single (2, 2, batch, hidden) tensor and
        # silently loses the pi/vf (hidden, cell) structure.
        lstm_states = LSTMStates(
            pi=(
                self.to_torch(self.hidden_states_pi[batch_inds]),
                self.to_torch(self.cell_states_pi[batch_inds]),
            ),
            vf=(
                self.to_torch(self.hidden_states_vf[batch_inds]),
                self.to_torch(self.cell_states_vf[batch_inds]),
            ),
        )

        data = (
            self.observations[batch_inds],
            self.actions[batch_inds],
            self.values[batch_inds].flatten(),
            self.log_probs[batch_inds].flatten(),
            self.advantages[batch_inds].flatten(),
            self.returns[batch_inds].flatten(),
            self.dones[batch_inds],
        )
        return RecurrentRolloutBufferSamples(*map(self.to_torch, data), lstm_states)


# --- Notebook demo setup (originally separate cells) ---
n_envs = 8
n_steps = 128
batch_size = 32
buffer_size = n_envs * n_steps
gamma = 0.99
n_epochs = 4
gae_lambda = 0.95
hidden_size = 64
lstm_state_buffer_shape = (n_steps, n_envs, 64)

env_id = "CartPole-v1"
vec_env = make_vec_env(env_id, n_envs=n_envs)

rollout_buffer = RecurrentRolloutBuffer(
    n_steps,
    vec_env.observation_space,
    vec_env.action_space,
    gamma=gamma,
    gae_lambda=gae_lambda,
    n_envs=n_envs,
    lstm_state_buffer_shape=lstm_state_buffer_shape,
    device="cpu",
)
+ "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [], + "source": [ + "for i in range(n_steps):\n", + "\n", + " lstm_states = (\n", + " np.full((n_envs, hidden_size), i, dtype=np.float32),\n", + " np.full((n_envs, hidden_size), i, dtype=np.float32),\n", + " )\n", + "\n", + " lstm_states = LSTMStates(pi=lstm_states, vf=lstm_states)\n", + "\n", + " act = np.array([i + 0.1 * idx for idx in range(n_envs)]).reshape(-1, 1)\n", + "\n", + " rollout_buffer.add(\n", + " obs=np.full((n_envs, 4), i, dtype=np.float32),\n", + " action=act,\n", + " # action=np.full((n_envs, 1), i, dtype=np.float32),\n", + " reward=np.full((n_envs, ), i, dtype=np.float32),\n", + " episode_start=np.zeros((n_envs,), dtype=np.float32),\n", + " value=th.ones((n_envs, 1), dtype=th.float32),\n", + " log_prob=th.ones((n_envs, ), dtype=th.float32),\n", + " dones=np.zeros((n_envs,), dtype=np.float32),\n", + " lstm_states=lstm_states,\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "rollout_buffer.actions.shape = (128, 8, 1)\n", + "rollout_buffer.actions[0] = array([[0. ],\n", + " [0.1],\n", + " [0.2],\n", + " [0.3],\n", + " [0.4],\n", + " [0.5],\n", + " [0.6],\n", + " [0.7]], dtype=float32)\n", + "rollout_buffer.actions[1] = array([[1. ],\n", + " [1.1],\n", + " [1.2],\n", + " [1.3],\n", + " [1.4],\n", + " [1.5],\n", + " [1.6],\n", + " [1.7]], dtype=float32)\n", + "rollout_buffer.actions[127] = array([[127. 
],\n", + " [127.1],\n", + " [127.2],\n", + " [127.3],\n", + " [127.4],\n", + " [127.5],\n", + " [127.6],\n", + " [127.7]], dtype=float32)\n" + ] + } + ], + "source": [ + "print(f\"{rollout_buffer.actions.shape = }\")\n", + "print(f\"{rollout_buffer.actions[0] = }\")\n", + "print(f\"{rollout_buffer.actions[1] = }\")\n", + "print(f\"{rollout_buffer.actions[127] = }\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Ok in fact at each iteration I need to differentiate between the current idx (i) and the idx of the environment (n_envs). How can I do that ? \n", + "\n", + "- pass environment idx as the first digit of the number and then \n" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "rollout_data.actions = tensor([[ 0.],\n", + " [ 1.],\n", + " [ 2.],\n", + " [ 3.],\n", + " [ 4.],\n", + " [ 5.],\n", + " [ 6.],\n", + " [ 7.],\n", + " [ 8.],\n", + " [ 9.],\n", + " [10.],\n", + " [11.],\n", + " [12.],\n", + " [13.],\n", + " [14.],\n", + " [15.],\n", + " [16.],\n", + " [17.],\n", + " [18.],\n", + " [19.],\n", + " [20.],\n", + " [21.],\n", + " [22.],\n", + " [23.],\n", + " [24.],\n", + " [25.],\n", + " [26.],\n", + " [27.],\n", + " [28.],\n", + " [29.],\n", + " [30.],\n", + " [31.]])\n", + "rollout_data.actions = tensor([[32.],\n", + " [33.],\n", + " [34.],\n", + " [35.],\n", + " [36.],\n", + " [37.],\n", + " [38.],\n", + " [39.],\n", + " [40.],\n", + " [41.],\n", + " [42.],\n", + " [43.],\n", + " [44.],\n", + " [45.],\n", + " [46.],\n", + " [47.],\n", + " [48.],\n", + " [49.],\n", + " [50.],\n", + " [51.],\n", + " [52.],\n", + " [53.],\n", + " [54.],\n", + " [55.],\n", + " [56.],\n", + " [57.],\n", + " [58.],\n", + " [59.],\n", + " [60.],\n", + " [61.],\n", + " [62.],\n", + " [63.]])\n", + "rollout_data.actions = tensor([[64.],\n", + " [65.],\n", + " [66.],\n", + " [67.],\n", + " [68.],\n", + " [69.],\n", + " [70.],\n", + " 
[71.],\n", + " [72.],\n", + " [73.],\n", + " [74.],\n", + " [75.],\n", + " [76.],\n", + " [77.],\n", + " [78.],\n", + " [79.],\n", + " [80.],\n", + " [81.],\n", + " [82.],\n", + " [83.],\n", + " [84.],\n", + " [85.],\n", + " [86.],\n", + " [87.],\n", + " [88.],\n", + " [89.],\n", + " [90.],\n", + " [91.],\n", + " [92.],\n", + " [93.],\n", + " [94.],\n", + " [95.]])\n", + "rollout_data.actions = tensor([[ 96.],\n", + " [ 97.],\n", + " [ 98.],\n", + " [ 99.],\n", + " [100.],\n", + " [101.],\n", + " [102.],\n", + " [103.],\n", + " [104.],\n", + " [105.],\n", + " [106.],\n", + " [107.],\n", + " [108.],\n", + " [109.],\n", + " [110.],\n", + " [111.],\n", + " [112.],\n", + " [113.],\n", + " [114.],\n", + " [115.],\n", + " [116.],\n", + " [117.],\n", + " [118.],\n", + " [119.],\n", + " [120.],\n", + " [121.],\n", + " [122.],\n", + " [123.],\n", + " [124.],\n", + " [125.],\n", + " [126.],\n", + " [127.]])\n", + "rollout_data.actions = tensor([[ 0.1000],\n", + " [ 1.1000],\n", + " [ 2.1000],\n", + " [ 3.1000],\n", + " [ 4.1000],\n", + " [ 5.1000],\n", + " [ 6.1000],\n", + " [ 7.1000],\n", + " [ 8.1000],\n", + " [ 9.1000],\n", + " [10.1000],\n", + " [11.1000],\n", + " [12.1000],\n", + " [13.1000],\n", + " [14.1000],\n", + " [15.1000],\n", + " [16.1000],\n", + " [17.1000],\n", + " [18.1000],\n", + " [19.1000],\n", + " [20.1000],\n", + " [21.1000],\n", + " [22.1000],\n", + " [23.1000],\n", + " [24.1000],\n", + " [25.1000],\n", + " [26.1000],\n", + " [27.1000],\n", + " [28.1000],\n", + " [29.1000],\n", + " [30.1000],\n", + " [31.1000]])\n", + "rollout_data.actions = tensor([[32.1000],\n", + " [33.1000],\n", + " [34.1000],\n", + " [35.1000],\n", + " [36.1000],\n", + " [37.1000],\n", + " [38.1000],\n", + " [39.1000],\n", + " [40.1000],\n", + " [41.1000],\n", + " [42.1000],\n", + " [43.1000],\n", + " [44.1000],\n", + " [45.1000],\n", + " [46.1000],\n", + " [47.1000],\n", + " [48.1000],\n", + " [49.1000],\n", + " [50.1000],\n", + " [51.1000],\n", + " [52.1000],\n", + " 
[53.1000],\n", + " [54.1000],\n", + " [55.1000],\n", + " [56.1000],\n", + " [57.1000],\n", + " [58.1000],\n", + " [59.1000],\n", + " [60.1000],\n", + " [61.1000],\n", + " [62.1000],\n", + " [63.1000]])\n", + "rollout_data.actions = tensor([[64.1000],\n", + " [65.1000],\n", + " [66.1000],\n", + " [67.1000],\n", + " [68.1000],\n", + " [69.1000],\n", + " [70.1000],\n", + " [71.1000],\n", + " [72.1000],\n", + " [73.1000],\n", + " [74.1000],\n", + " [75.1000],\n", + " [76.1000],\n", + " [77.1000],\n", + " [78.1000],\n", + " [79.1000],\n", + " [80.1000],\n", + " [81.1000],\n", + " [82.1000],\n", + " [83.1000],\n", + " [84.1000],\n", + " [85.1000],\n", + " [86.1000],\n", + " [87.1000],\n", + " [88.1000],\n", + " [89.1000],\n", + " [90.1000],\n", + " [91.1000],\n", + " [92.1000],\n", + " [93.1000],\n", + " [94.1000],\n", + " [95.1000]])\n", + "rollout_data.actions = tensor([[ 96.1000],\n", + " [ 97.1000],\n", + " [ 98.1000],\n", + " [ 99.1000],\n", + " [100.1000],\n", + " [101.1000],\n", + " [102.1000],\n", + " [103.1000],\n", + " [104.1000],\n", + " [105.1000],\n", + " [106.1000],\n", + " [107.1000],\n", + " [108.1000],\n", + " [109.1000],\n", + " [110.1000],\n", + " [111.1000],\n", + " [112.1000],\n", + " [113.1000],\n", + " [114.1000],\n", + " [115.1000],\n", + " [116.1000],\n", + " [117.1000],\n", + " [118.1000],\n", + " [119.1000],\n", + " [120.1000],\n", + " [121.1000],\n", + " [122.1000],\n", + " [123.1000],\n", + " [124.1000],\n", + " [125.1000],\n", + " [126.1000],\n", + " [127.1000]])\n", + "rollout_data.actions = tensor([[ 0.2000],\n", + " [ 1.2000],\n", + " [ 2.2000],\n", + " [ 3.2000],\n", + " [ 4.2000],\n", + " [ 5.2000],\n", + " [ 6.2000],\n", + " [ 7.2000],\n", + " [ 8.2000],\n", + " [ 9.2000],\n", + " [10.2000],\n", + " [11.2000],\n", + " [12.2000],\n", + " [13.2000],\n", + " [14.2000],\n", + " [15.2000],\n", + " [16.2000],\n", + " [17.2000],\n", + " [18.2000],\n", + " [19.2000],\n", + " [20.2000],\n", + " [21.2000],\n", + " [22.2000],\n", + " 
[23.2000],\n", + " [24.2000],\n", + " [25.2000],\n", + " [26.2000],\n", + " [27.2000],\n", + " [28.2000],\n", + " [29.2000],\n", + " [30.2000],\n", + " [31.2000]])\n", + "rollout_data.actions = tensor([[32.2000],\n", + " [33.2000],\n", + " [34.2000],\n", + " [35.2000],\n", + " [36.2000],\n", + " [37.2000],\n", + " [38.2000],\n", + " [39.2000],\n", + " [40.2000],\n", + " [41.2000],\n", + " [42.2000],\n", + " [43.2000],\n", + " [44.2000],\n", + " [45.2000],\n", + " [46.2000],\n", + " [47.2000],\n", + " [48.2000],\n", + " [49.2000],\n", + " [50.2000],\n", + " [51.2000],\n", + " [52.2000],\n", + " [53.2000],\n", + " [54.2000],\n", + " [55.2000],\n", + " [56.2000],\n", + " [57.2000],\n", + " [58.2000],\n", + " [59.2000],\n", + " [60.2000],\n", + " [61.2000],\n", + " [62.2000],\n", + " [63.2000]])\n", + "rollout_data.actions = tensor([[64.2000],\n", + " [65.2000],\n", + " [66.2000],\n", + " [67.2000],\n", + " [68.2000],\n", + " [69.2000],\n", + " [70.2000],\n", + " [71.2000],\n", + " [72.2000],\n", + " [73.2000],\n", + " [74.2000],\n", + " [75.2000],\n", + " [76.2000],\n", + " [77.2000],\n", + " [78.2000],\n", + " [79.2000],\n", + " [80.2000],\n", + " [81.2000],\n", + " [82.2000],\n", + " [83.2000],\n", + " [84.2000],\n", + " [85.2000],\n", + " [86.2000],\n", + " [87.2000],\n", + " [88.2000],\n", + " [89.2000],\n", + " [90.2000],\n", + " [91.2000],\n", + " [92.2000],\n", + " [93.2000],\n", + " [94.2000],\n", + " [95.2000]])\n", + "rollout_data.actions = tensor([[ 96.2000],\n", + " [ 97.2000],\n", + " [ 98.2000],\n", + " [ 99.2000],\n", + " [100.2000],\n", + " [101.2000],\n", + " [102.2000],\n", + " [103.2000],\n", + " [104.2000],\n", + " [105.2000],\n", + " [106.2000],\n", + " [107.2000],\n", + " [108.2000],\n", + " [109.2000],\n", + " [110.2000],\n", + " [111.2000],\n", + " [112.2000],\n", + " [113.2000],\n", + " [114.2000],\n", + " [115.2000],\n", + " [116.2000],\n", + " [117.2000],\n", + " [118.2000],\n", + " [119.2000],\n", + " [120.2000],\n", + " [121.2000],\n", + " 
[122.2000],\n", + " [123.2000],\n", + " [124.2000],\n", + " [125.2000],\n", + " [126.2000],\n", + " [127.2000]])\n", + "rollout_data.actions = tensor([[ 0.3000],\n", + " [ 1.3000],\n", + " [ 2.3000],\n", + " [ 3.3000],\n", + " [ 4.3000],\n", + " [ 5.3000],\n", + " [ 6.3000],\n", + " [ 7.3000],\n", + " [ 8.3000],\n", + " [ 9.3000],\n", + " [10.3000],\n", + " [11.3000],\n", + " [12.3000],\n", + " [13.3000],\n", + " [14.3000],\n", + " [15.3000],\n", + " [16.3000],\n", + " [17.3000],\n", + " [18.3000],\n", + " [19.3000],\n", + " [20.3000],\n", + " [21.3000],\n", + " [22.3000],\n", + " [23.3000],\n", + " [24.3000],\n", + " [25.3000],\n", + " [26.3000],\n", + " [27.3000],\n", + " [28.3000],\n", + " [29.3000],\n", + " [30.3000],\n", + " [31.3000]])\n", + "rollout_data.actions = tensor([[32.3000],\n", + " [33.3000],\n", + " [34.3000],\n", + " [35.3000],\n", + " [36.3000],\n", + " [37.3000],\n", + " [38.3000],\n", + " [39.3000],\n", + " [40.3000],\n", + " [41.3000],\n", + " [42.3000],\n", + " [43.3000],\n", + " [44.3000],\n", + " [45.3000],\n", + " [46.3000],\n", + " [47.3000],\n", + " [48.3000],\n", + " [49.3000],\n", + " [50.3000],\n", + " [51.3000],\n", + " [52.3000],\n", + " [53.3000],\n", + " [54.3000],\n", + " [55.3000],\n", + " [56.3000],\n", + " [57.3000],\n", + " [58.3000],\n", + " [59.3000],\n", + " [60.3000],\n", + " [61.3000],\n", + " [62.3000],\n", + " [63.3000]])\n", + "rollout_data.actions = tensor([[64.3000],\n", + " [65.3000],\n", + " [66.3000],\n", + " [67.3000],\n", + " [68.3000],\n", + " [69.3000],\n", + " [70.3000],\n", + " [71.3000],\n", + " [72.3000],\n", + " [73.3000],\n", + " [74.3000],\n", + " [75.3000],\n", + " [76.3000],\n", + " [77.3000],\n", + " [78.3000],\n", + " [79.3000],\n", + " [80.3000],\n", + " [81.3000],\n", + " [82.3000],\n", + " [83.3000],\n", + " [84.3000],\n", + " [85.3000],\n", + " [86.3000],\n", + " [87.3000],\n", + " [88.3000],\n", + " [89.3000],\n", + " [90.3000],\n", + " [91.3000],\n", + " [92.3000],\n", + " [93.3000],\n", + " 
[94.3000],\n", + " [95.3000]])\n", + "rollout_data.actions = tensor([[ 96.3000],\n", + " [ 97.3000],\n", + " [ 98.3000],\n", + " [ 99.3000],\n", + " [100.3000],\n", + " [101.3000],\n", + " [102.3000],\n", + " [103.3000],\n", + " [104.3000],\n", + " [105.3000],\n", + " [106.3000],\n", + " [107.3000],\n", + " [108.3000],\n", + " [109.3000],\n", + " [110.3000],\n", + " [111.3000],\n", + " [112.3000],\n", + " [113.3000],\n", + " [114.3000],\n", + " [115.3000],\n", + " [116.3000],\n", + " [117.3000],\n", + " [118.3000],\n", + " [119.3000],\n", + " [120.3000],\n", + " [121.3000],\n", + " [122.3000],\n", + " [123.3000],\n", + " [124.3000],\n", + " [125.3000],\n", + " [126.3000],\n", + " [127.3000]])\n", + "rollout_data.actions = tensor([[ 0.4000],\n", + " [ 1.4000],\n", + " [ 2.4000],\n", + " [ 3.4000],\n", + " [ 4.4000],\n", + " [ 5.4000],\n", + " [ 6.4000],\n", + " [ 7.4000],\n", + " [ 8.4000],\n", + " [ 9.4000],\n", + " [10.4000],\n", + " [11.4000],\n", + " [12.4000],\n", + " [13.4000],\n", + " [14.4000],\n", + " [15.4000],\n", + " [16.4000],\n", + " [17.4000],\n", + " [18.4000],\n", + " [19.4000],\n", + " [20.4000],\n", + " [21.4000],\n", + " [22.4000],\n", + " [23.4000],\n", + " [24.4000],\n", + " [25.4000],\n", + " [26.4000],\n", + " [27.4000],\n", + " [28.4000],\n", + " [29.4000],\n", + " [30.4000],\n", + " [31.4000]])\n", + "rollout_data.actions = tensor([[32.4000],\n", + " [33.4000],\n", + " [34.4000],\n", + " [35.4000],\n", + " [36.4000],\n", + " [37.4000],\n", + " [38.4000],\n", + " [39.4000],\n", + " [40.4000],\n", + " [41.4000],\n", + " [42.4000],\n", + " [43.4000],\n", + " [44.4000],\n", + " [45.4000],\n", + " [46.4000],\n", + " [47.4000],\n", + " [48.4000],\n", + " [49.4000],\n", + " [50.4000],\n", + " [51.4000],\n", + " [52.4000],\n", + " [53.4000],\n", + " [54.4000],\n", + " [55.4000],\n", + " [56.4000],\n", + " [57.4000],\n", + " [58.4000],\n", + " [59.4000],\n", + " [60.4000],\n", + " [61.4000],\n", + " [62.4000],\n", + " [63.4000]])\n", + 
"rollout_data.actions = tensor([[64.4000],\n", + " [65.4000],\n", + " [66.4000],\n", + " [67.4000],\n", + " [68.4000],\n", + " [69.4000],\n", + " [70.4000],\n", + " [71.4000],\n", + " [72.4000],\n", + " [73.4000],\n", + " [74.4000],\n", + " [75.4000],\n", + " [76.4000],\n", + " [77.4000],\n", + " [78.4000],\n", + " [79.4000],\n", + " [80.4000],\n", + " [81.4000],\n", + " [82.4000],\n", + " [83.4000],\n", + " [84.4000],\n", + " [85.4000],\n", + " [86.4000],\n", + " [87.4000],\n", + " [88.4000],\n", + " [89.4000],\n", + " [90.4000],\n", + " [91.4000],\n", + " [92.4000],\n", + " [93.4000],\n", + " [94.4000],\n", + " [95.4000]])\n", + "rollout_data.actions = tensor([[ 96.4000],\n", + " [ 97.4000],\n", + " [ 98.4000],\n", + " [ 99.4000],\n", + " [100.4000],\n", + " [101.4000],\n", + " [102.4000],\n", + " [103.4000],\n", + " [104.4000],\n", + " [105.4000],\n", + " [106.4000],\n", + " [107.4000],\n", + " [108.4000],\n", + " [109.4000],\n", + " [110.4000],\n", + " [111.4000],\n", + " [112.4000],\n", + " [113.4000],\n", + " [114.4000],\n", + " [115.4000],\n", + " [116.4000],\n", + " [117.4000],\n", + " [118.4000],\n", + " [119.4000],\n", + " [120.4000],\n", + " [121.4000],\n", + " [122.4000],\n", + " [123.4000],\n", + " [124.4000],\n", + " [125.4000],\n", + " [126.4000],\n", + " [127.4000]])\n", + "rollout_data.actions = tensor([[ 0.5000],\n", + " [ 1.5000],\n", + " [ 2.5000],\n", + " [ 3.5000],\n", + " [ 4.5000],\n", + " [ 5.5000],\n", + " [ 6.5000],\n", + " [ 7.5000],\n", + " [ 8.5000],\n", + " [ 9.5000],\n", + " [10.5000],\n", + " [11.5000],\n", + " [12.5000],\n", + " [13.5000],\n", + " [14.5000],\n", + " [15.5000],\n", + " [16.5000],\n", + " [17.5000],\n", + " [18.5000],\n", + " [19.5000],\n", + " [20.5000],\n", + " [21.5000],\n", + " [22.5000],\n", + " [23.5000],\n", + " [24.5000],\n", + " [25.5000],\n", + " [26.5000],\n", + " [27.5000],\n", + " [28.5000],\n", + " [29.5000],\n", + " [30.5000],\n", + " [31.5000]])\n", + "rollout_data.actions = tensor([[32.5000],\n", + " 
[33.5000],\n", + " [34.5000],\n", + " [35.5000],\n", + " [36.5000],\n", + " [37.5000],\n", + " [38.5000],\n", + " [39.5000],\n", + " [40.5000],\n", + " [41.5000],\n", + " [42.5000],\n", + " [43.5000],\n", + " [44.5000],\n", + " [45.5000],\n", + " [46.5000],\n", + " [47.5000],\n", + " [48.5000],\n", + " [49.5000],\n", + " [50.5000],\n", + " [51.5000],\n", + " [52.5000],\n", + " [53.5000],\n", + " [54.5000],\n", + " [55.5000],\n", + " [56.5000],\n", + " [57.5000],\n", + " [58.5000],\n", + " [59.5000],\n", + " [60.5000],\n", + " [61.5000],\n", + " [62.5000],\n", + " [63.5000]])\n", + "rollout_data.actions = tensor([[64.5000],\n", + " [65.5000],\n", + " [66.5000],\n", + " [67.5000],\n", + " [68.5000],\n", + " [69.5000],\n", + " [70.5000],\n", + " [71.5000],\n", + " [72.5000],\n", + " [73.5000],\n", + " [74.5000],\n", + " [75.5000],\n", + " [76.5000],\n", + " [77.5000],\n", + " [78.5000],\n", + " [79.5000],\n", + " [80.5000],\n", + " [81.5000],\n", + " [82.5000],\n", + " [83.5000],\n", + " [84.5000],\n", + " [85.5000],\n", + " [86.5000],\n", + " [87.5000],\n", + " [88.5000],\n", + " [89.5000],\n", + " [90.5000],\n", + " [91.5000],\n", + " [92.5000],\n", + " [93.5000],\n", + " [94.5000],\n", + " [95.5000]])\n", + "rollout_data.actions = tensor([[ 96.5000],\n", + " [ 97.5000],\n", + " [ 98.5000],\n", + " [ 99.5000],\n", + " [100.5000],\n", + " [101.5000],\n", + " [102.5000],\n", + " [103.5000],\n", + " [104.5000],\n", + " [105.5000],\n", + " [106.5000],\n", + " [107.5000],\n", + " [108.5000],\n", + " [109.5000],\n", + " [110.5000],\n", + " [111.5000],\n", + " [112.5000],\n", + " [113.5000],\n", + " [114.5000],\n", + " [115.5000],\n", + " [116.5000],\n", + " [117.5000],\n", + " [118.5000],\n", + " [119.5000],\n", + " [120.5000],\n", + " [121.5000],\n", + " [122.5000],\n", + " [123.5000],\n", + " [124.5000],\n", + " [125.5000],\n", + " [126.5000],\n", + " [127.5000]])\n", + "rollout_data.actions = tensor([[ 0.6000],\n", + " [ 1.6000],\n", + " [ 2.6000],\n", + " [ 
3.6000],\n", + " [ 4.6000],\n", + " [ 5.6000],\n", + " [ 6.6000],\n", + " [ 7.6000],\n", + " [ 8.6000],\n", + " [ 9.6000],\n", + " [10.6000],\n", + " [11.6000],\n", + " [12.6000],\n", + " [13.6000],\n", + " [14.6000],\n", + " [15.6000],\n", + " [16.6000],\n", + " [17.6000],\n", + " [18.6000],\n", + " [19.6000],\n", + " [20.6000],\n", + " [21.6000],\n", + " [22.6000],\n", + " [23.6000],\n", + " [24.6000],\n", + " [25.6000],\n", + " [26.6000],\n", + " [27.6000],\n", + " [28.6000],\n", + " [29.6000],\n", + " [30.6000],\n", + " [31.6000]])\n", + "rollout_data.actions = tensor([[32.6000],\n", + " [33.6000],\n", + " [34.6000],\n", + " [35.6000],\n", + " [36.6000],\n", + " [37.6000],\n", + " [38.6000],\n", + " [39.6000],\n", + " [40.6000],\n", + " [41.6000],\n", + " [42.6000],\n", + " [43.6000],\n", + " [44.6000],\n", + " [45.6000],\n", + " [46.6000],\n", + " [47.6000],\n", + " [48.6000],\n", + " [49.6000],\n", + " [50.6000],\n", + " [51.6000],\n", + " [52.6000],\n", + " [53.6000],\n", + " [54.6000],\n", + " [55.6000],\n", + " [56.6000],\n", + " [57.6000],\n", + " [58.6000],\n", + " [59.6000],\n", + " [60.6000],\n", + " [61.6000],\n", + " [62.6000],\n", + " [63.6000]])\n", + "rollout_data.actions = tensor([[64.6000],\n", + " [65.6000],\n", + " [66.6000],\n", + " [67.6000],\n", + " [68.6000],\n", + " [69.6000],\n", + " [70.6000],\n", + " [71.6000],\n", + " [72.6000],\n", + " [73.6000],\n", + " [74.6000],\n", + " [75.6000],\n", + " [76.6000],\n", + " [77.6000],\n", + " [78.6000],\n", + " [79.6000],\n", + " [80.6000],\n", + " [81.6000],\n", + " [82.6000],\n", + " [83.6000],\n", + " [84.6000],\n", + " [85.6000],\n", + " [86.6000],\n", + " [87.6000],\n", + " [88.6000],\n", + " [89.6000],\n", + " [90.6000],\n", + " [91.6000],\n", + " [92.6000],\n", + " [93.6000],\n", + " [94.6000],\n", + " [95.6000]])\n", + "rollout_data.actions = tensor([[ 96.6000],\n", + " [ 97.6000],\n", + " [ 98.6000],\n", + " [ 99.6000],\n", + " [100.6000],\n", + " [101.6000],\n", + " [102.6000],\n", + " 
[103.6000],\n", + " [104.6000],\n", + " [105.6000],\n", + " [106.6000],\n", + " [107.6000],\n", + " [108.6000],\n", + " [109.6000],\n", + " [110.6000],\n", + " [111.6000],\n", + " [112.6000],\n", + " [113.6000],\n", + " [114.6000],\n", + " [115.6000],\n", + " [116.6000],\n", + " [117.6000],\n", + " [118.6000],\n", + " [119.6000],\n", + " [120.6000],\n", + " [121.6000],\n", + " [122.6000],\n", + " [123.6000],\n", + " [124.6000],\n", + " [125.6000],\n", + " [126.6000],\n", + " [127.6000]])\n", + "rollout_data.actions = tensor([[ 0.7000],\n", + " [ 1.7000],\n", + " [ 2.7000],\n", + " [ 3.7000],\n", + " [ 4.7000],\n", + " [ 5.7000],\n", + " [ 6.7000],\n", + " [ 7.7000],\n", + " [ 8.7000],\n", + " [ 9.7000],\n", + " [10.7000],\n", + " [11.7000],\n", + " [12.7000],\n", + " [13.7000],\n", + " [14.7000],\n", + " [15.7000],\n", + " [16.7000],\n", + " [17.7000],\n", + " [18.7000],\n", + " [19.7000],\n", + " [20.7000],\n", + " [21.7000],\n", + " [22.7000],\n", + " [23.7000],\n", + " [24.7000],\n", + " [25.7000],\n", + " [26.7000],\n", + " [27.7000],\n", + " [28.7000],\n", + " [29.7000],\n", + " [30.7000],\n", + " [31.7000]])\n", + "rollout_data.actions = tensor([[32.7000],\n", + " [33.7000],\n", + " [34.7000],\n", + " [35.7000],\n", + " [36.7000],\n", + " [37.7000],\n", + " [38.7000],\n", + " [39.7000],\n", + " [40.7000],\n", + " [41.7000],\n", + " [42.7000],\n", + " [43.7000],\n", + " [44.7000],\n", + " [45.7000],\n", + " [46.7000],\n", + " [47.7000],\n", + " [48.7000],\n", + " [49.7000],\n", + " [50.7000],\n", + " [51.7000],\n", + " [52.7000],\n", + " [53.7000],\n", + " [54.7000],\n", + " [55.7000],\n", + " [56.7000],\n", + " [57.7000],\n", + " [58.7000],\n", + " [59.7000],\n", + " [60.7000],\n", + " [61.7000],\n", + " [62.7000],\n", + " [63.7000]])\n", + "rollout_data.actions = tensor([[64.7000],\n", + " [65.7000],\n", + " [66.7000],\n", + " [67.7000],\n", + " [68.7000],\n", + " [69.7000],\n", + " [70.7000],\n", + " [71.7000],\n", + " [72.7000],\n", + " [73.7000],\n", + " 
# Sanity-check minibatching: drain the generator once and count the batches.
count = 0
batch_iter = rollout_buffer.get(batch_size)  # type: ignore[attr-defined]
while True:
    try:
        rollout_data = next(batch_iter)
    except StopIteration:
        break
    print(f"{rollout_data.actions = }")
    count += 1

print(f"{count = }")
"rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "count = 32\n" + ] + } + ], + "source": [ + "# train for n_epochs epochs\n", + "# for 
j in range(n_epochs):\n", + "count = 0\n", + "for rollout_data in rollout_buffer.get(batch_size): # type: ignore[attr-defined]\n", + " count += 1\n", + " print(f\"{rollout_data.lstm_states[0][0].shape = }\")\n", + "\n", + "print(f\"{count = }\")\n", + "\n", + "# lstm_states_pi = (\n", + "# rollout_data.lstm_states[0][0].numpy().reshape(batch_size, hidden_state_size),\n", + "# rollout_data.lstm_states[0][1].numpy().reshape(batch_size, hidden_state_size)\n", + "# )\n", + "\n", + "# lstm_states_vf = (\n", + "# rollout_data.lstm_states[1][0].numpy().reshape(batch_size, hidden_state_size),\n", + "# rollout_data.lstm_states[1][1].numpy().reshape(batch_size, hidden_state_size)\n", + "# )" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/sbx/__init__.py b/sbx/__init__.py index c2762bc..2376e5a 100644 --- a/sbx/__init__.py +++ b/sbx/__init__.py @@ -4,6 +4,7 @@ from sbx.ddpg import DDPG from sbx.dqn import DQN from sbx.ppo import PPO +from sbx.recurrent_ppo import RecurrentPPO from sbx.sac import SAC from sbx.td3 import TD3 from sbx.tqc import TQC @@ -26,6 +27,7 @@ def DroQ(*args, **kwargs): "DDPG", "DQN", "PPO", + "RecurrentPPO", "SAC", "TD3", "TQC", diff --git a/sbx/common/recurrent.py b/sbx/common/recurrent.py new file mode 100644 index 0000000..16e47dd --- /dev/null +++ b/sbx/common/recurrent.py @@ -0,0 +1,154 @@ +from typing import Callable, Generator, Optional, Tuple, Union, NamedTuple + +import numpy as np +import torch as th +import jax.numpy as jnp +from gymnasium import spaces +# TODO : see later how to enable DictRolloutBuffer +from stable_baselines3.common.buffers import 
DictRolloutBuffer, RolloutBuffer +from stable_baselines3.common.vec_env import VecNormalize + +# TODO : add type aliases for the NamedTuple +class LSTMStates(NamedTuple): + pi: Tuple + vf: Tuple + +# TODO : Replaced th.Tensor with jnp.ndarray but might not be true (some as still th Tensors because used in other sb3 fns) +# Added lstm states but also dones because they are used in actor and critic +class RecurrentRolloutBufferSamples(NamedTuple): + observations: jnp.ndarray + actions: jnp.ndarray + old_values: jnp.ndarray + old_log_prob: jnp.ndarray + advantages: jnp.ndarray + returns: jnp.ndarray + dones: jnp.ndarray + lstm_states: LSTMStates + +# Add a recurrent buffer that also takes care of the lstm states and dones flags +class RecurrentRolloutBuffer(RolloutBuffer): + """ + Rollout buffer that also stores the LSTM cell and hidden states. + + :param buffer_size: Max number of element in the buffer + :param observation_space: Observation space + :param action_space: Action space + :param hidden_state_shape: Shape of the buffer that will collect lstm states + (n_steps, lstm.num_layers, n_envs, lstm.hidden_size) + :param device: PyTorch device + :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator + Equivalent to classic advantage when set to 1. 
+ :param gamma: Discount factor + :param n_envs: Number of parallel environments + """ + + def __init__( + self, + buffer_size: int, + observation_space: spaces.Space, + action_space: spaces.Space, + # renamed this because I found hidden_state_shape confusing + lstm_state_buffer_shape: Tuple[int, int, int], + device: Union[th.device, str] = "auto", + gae_lambda: float = 1, + gamma: float = 0.99, + n_envs: int = 1, + ): + self.hidden_state_shape = lstm_state_buffer_shape + self.seq_start_indices, self.seq_end_indices = None, None + super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs) + + def reset(self): + super().reset() + # also add the dones and all lstm states + self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.hidden_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32) + self.cell_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32) + self.hidden_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32) + self.cell_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32) + + def add(self, *args, dones, lstm_states, **kwargs) -> None: + """ + :param hidden_states: LSTM cell and hidden state + """ + self.dones[self.pos] = np.array(dones) + self.hidden_states_pi[self.pos] = np.array(lstm_states.pi[0]) + self.cell_states_pi[self.pos] = np.array(lstm_states.pi[1]) + self.hidden_states_vf[self.pos] = np.array(lstm_states.vf[0]) + self.cell_states_vf[self.pos] = np.array(lstm_states.vf[1]) + + super().add(*args, **kwargs) + + def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBufferSamples, None, None]: + assert self.full, "Rollout buffer must be full before sampling from it" + + # Prepare the data + if not self.generator_ready: + # hidden_state_shape = (self.n_steps, lstm.num_layers, self.n_envs, lstm.hidden_size) + # swap first to (self.n_steps, self.n_envs, lstm.num_layers, lstm.hidden_size) + for tensor in ["hidden_states_pi", 
"cell_states_pi", "hidden_states_vf", "cell_states_vf"]: + self.__dict__[tensor] = self.__dict__[tensor].swapaxes(1, 2) + + # flatten but keep the sequence order + # 1. (n_steps, n_envs, *tensor_shape) -> (n_envs, n_steps, *tensor_shape) + # 2. (n_envs, n_steps, *tensor_shape) -> (n_envs * n_steps, *tensor_shape) + for tensor in [ + "observations", + "actions", + "values", + "log_probs", + "advantages", + "returns", + "dones", + "hidden_states_pi", + "cell_states_pi", + "hidden_states_vf", + "cell_states_vf", + "episode_starts", + ]: + self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) + self.generator_ready = True + + # Return everything, don't create minibatches + if batch_size is None: + batch_size = self.buffer_size * self.n_envs + + # TODO : See how to effectively use the indices to conserve temporal order in the batch data during updates + # TODO : I think the easisest way is to ensure the n_steps is a multiple of batch_size + # TODO : But still need to be fixed at the moment (I just made sure the returned shape was right) + indices = np.arange(self.buffer_size * self.n_envs) + + start_idx = 0 + while start_idx < self.buffer_size * self.n_envs: + batch_inds = indices[start_idx : start_idx + batch_size] + yield self._get_samples(batch_inds) + start_idx += batch_size + + # return the lstm states as an LSTMStates tuple + def _get_samples( + self, + batch_inds: np.ndarray, + env: Optional[VecNormalize] = None, + ) -> RecurrentRolloutBufferSamples: + + lstm_states_pi = ( + self.hidden_states_pi[batch_inds], + self.cell_states_pi[batch_inds] + ) + + lstm_states_vf = ( + self.hidden_states_vf[batch_inds], + self.cell_states_vf[batch_inds] + ) + + data = ( + self.observations[batch_inds], + self.actions[batch_inds], + self.values[batch_inds].flatten(), + self.log_probs[batch_inds].flatten(), + self.advantages[batch_inds].flatten(), + self.returns[batch_inds].flatten(), + self.dones[batch_inds], + LSTMStates(pi=lstm_states_pi, vf=lstm_states_vf) + ) 
+ return RecurrentRolloutBufferSamples(*tuple(map(self.to_torch, data))) diff --git a/sbx/recurrent_ppo/__init__.py b/sbx/recurrent_ppo/__init__.py new file mode 100644 index 0000000..d7cd5b7 --- /dev/null +++ b/sbx/recurrent_ppo/__init__.py @@ -0,0 +1,3 @@ +from sbx.recurrent_ppo.recurrent_ppo import RecurrentPPO + +__all__ = ["RecurrentPPO"] diff --git a/sbx/recurrent_ppo/policies.py b/sbx/recurrent_ppo/policies.py new file mode 100644 index 0000000..79e7c94 --- /dev/null +++ b/sbx/recurrent_ppo/policies.py @@ -0,0 +1,399 @@ +import functools + +from dataclasses import field +from typing import Any, Callable, Dict, List, Optional, Sequence, Union, Tuple + +import flax.linen as nn +import gymnasium as gym +import jax +import jax.numpy as jnp +import numpy as np +import optax +import tensorflow_probability.substrates.jax as tfp +from flax.linen.initializers import constant +from flax.training.train_state import TrainState +from gymnasium import spaces +from stable_baselines3.common.type_aliases import Schedule + +from sbx.common.policies import BaseJaxPolicy +from sbx.common.recurrent import LSTMStates + +tfd = tfp.distributions + + +# Added a ScanLSTM Module that automatically handles the reset of LSTM states +# inspired from the ScanRNN in purejaxrl : https://github.com/luchris429/purejaxrl/blob/main/purejaxrl/ppo_rnn.py +class ScanLSTM(nn.Module): + @functools.partial( + nn.scan, + variable_broadcast='params', + in_axes=0, + out_axes=0, + split_rngs={'params': False} + ) + @nn.compact + def __call__(self, lstm_states, obs_and_resets): + # pass the pi and vf lstm states, as well as the obs and the resets + obs, resets = obs_and_resets + hidden_state, cell_state = lstm_states + + # create new lstm states to replace the old ones if reset is True + batch_size, hidden_size = hidden_state.shape + reset_lstm_states = self.initialize_carry(batch_size, hidden_size) + + # handle the reset of the hidden lstm states + hidden_state = jnp.where( + resets[:, np.newaxis], + 
reset_lstm_states[0], + hidden_state + ) + # handle the reset of the cell lstm states + cell_state = jnp.where( + resets[:, np.newaxis], + reset_lstm_states[1], + cell_state + ) + + lstm_states = (hidden_state, cell_state) + hidden_size = lstm_states[0].shape[-1] + + new_lstm_states, output = nn.LSTMCell(features=hidden_size)(lstm_states, obs) + return new_lstm_states, output + + @staticmethod + def initialize_carry(batch_size, hidden_size): + # Returns a tuple of lstm states (hidden and cell states) + return nn.LSTMCell(features=hidden_size).initialize_carry( + rng=jax.random.PRNGKey(0), input_shape=(batch_size, hidden_size) + ) + +# Add ScanLSTM as first element of the Critic architecture +class Critic(nn.Module): + n_units: int = 256 + activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.tanh + + # return hidden state + val + @nn.compact + def __call__(self, lstm_states, obs_dones) -> jnp.ndarray: + lstm_states, out = ScanLSTM()(lstm_states, obs_dones) + x = nn.Dense(self.n_units)(out) + x = self.activation_fn(x) + x = nn.Dense(self.n_units)(x) + x = self.activation_fn(x) + x = nn.Dense(1)(x) + return lstm_states, x + +# Add ScanLSTM as first element of the Actor architecture +class Actor(nn.Module): + action_dim: int + n_units: int = 256 + log_std_init: float = 0.0 + activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.tanh + # For Discrete, MultiDiscrete and MultiBinary actions + num_discrete_choices: Optional[Union[int, Sequence[int]]] = None + # For MultiDiscrete + max_num_choices: int = 0 + split_indices: np.ndarray = field(default_factory=lambda: np.array([])) + + def get_std(self) -> jnp.ndarray: + # Make it work with gSDE + return jnp.array(0.0) + + def __post_init__(self) -> None: + # For MultiDiscrete + if isinstance(self.num_discrete_choices, np.ndarray): + self.max_num_choices = max(self.num_discrete_choices) + # np.cumsum(...) 
gives the correct indices at which to split the flatten logits + self.split_indices = np.cumsum(self.num_discrete_choices[:-1]) + super().__post_init__() + + # return hidden state + action dist + @nn.compact + def __call__(self, hidden, obs_dones) -> tfd.Distribution: # type: ignore[name-defined] + hidden, out = ScanLSTM()(hidden, obs_dones) + out = jnp.squeeze(out, axis=0) + + x = nn.Dense(self.n_units)(out) + x = self.activation_fn(x) + x = nn.Dense(self.n_units)(x) + x = self.activation_fn(x) + action_logits = nn.Dense(self.action_dim)(x) + if self.num_discrete_choices is None: + # Continuous actions + log_std = self.param("log_std", constant(self.log_std_init), (self.action_dim,)) + dist = tfd.MultivariateNormalDiag(loc=action_logits, scale_diag=jnp.exp(log_std)) + elif isinstance(self.num_discrete_choices, int): + # Discrete actions + dist = tfd.Categorical(logits=action_logits) + else: + raise ValueError("Invalid action space. Only Discrete and Continuous are supported at the moment.") + return hidden, dist + + +# TODO Later : at the moment custom net_architectures are not supported for the LSTM +class RecurrentPPOPolicy(BaseJaxPolicy): + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, + ortho_init: bool = False, + log_std_init: float = 0.0, + activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.tanh, + use_sde: bool = False, + # Note: most gSDE parameters are not used + # this is to keep API consistent with SB3 + use_expln: bool = False, + clip_mean: float = 2.0, + features_extractor_class=None, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Callable[..., optax.GradientTransformation] = optax.adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + share_features_extractor: bool = False, + ): + if optimizer_kwargs is None: + # Small values to 
avoid NaN in Adam optimizer + optimizer_kwargs = {} + if optimizer_class == optax.adam: + optimizer_kwargs["eps"] = 1e-5 + + super().__init__( + observation_space, + action_space, + features_extractor_class, + features_extractor_kwargs, + optimizer_class=optimizer_class, + optimizer_kwargs=optimizer_kwargs, + squash_output=True, + ) + self.log_std_init = log_std_init + self.activation_fn = activation_fn + if net_arch is not None: + if isinstance(net_arch, list): + self.n_units = net_arch[0] + else: + assert isinstance(net_arch, dict) + self.n_units = net_arch["pi"][0] + else: + self.n_units = 64 + # self.use_sde = use_sde + + self.key = self.noise_key = jax.random.PRNGKey(0) + + def build(self, key: jax.Array, lr_schedule: Schedule, max_grad_norm: float) -> jax.Array: + key, actor_key, vf_key = jax.random.split(key, 3) + # Keep a key for the actor + key, self.key = jax.random.split(key, 2) + # Initialize noise + self.reset_noise() + + if isinstance(self.action_space, spaces.Box): + actor_kwargs = { + "action_dim": int(np.prod(self.action_space.shape)), + } + elif isinstance(self.action_space, spaces.Discrete): + actor_kwargs = { + "action_dim": int(self.action_space.n), + "num_discrete_choices": int(self.action_space.n), + } + elif isinstance(self.action_space, spaces.MultiDiscrete): + assert self.action_space.nvec.ndim == 1, ( + f"Only one-dimensional MultiDiscrete action spaces are supported, " + f"but found MultiDiscrete({(self.action_space.nvec).tolist()})." + ) + actor_kwargs = { + "action_dim": int(np.sum(self.action_space.nvec)), + "num_discrete_choices": self.action_space.nvec, # type: ignore[dict-item] + } + elif isinstance(self.action_space, spaces.MultiBinary): + assert isinstance(self.action_space.n, int), ( + f"Multi-dimensional MultiBinary({self.action_space.n}) action space is not supported. " + "You can flatten it instead." + ) + # Handle binary action spaces as discrete action spaces with two choices. 
+ actor_kwargs = { + "action_dim": 2 * self.action_space.n, + "num_discrete_choices": 2 * np.ones(self.action_space.n, dtype=int), + } + else: + raise NotImplementedError(f"{self.action_space}") + + + self.actor = Actor( + n_units=self.n_units, + log_std_init=self.log_std_init, + activation_fn=self.activation_fn, + **actor_kwargs, # type: ignore[arg-type] + ) + + # Initialize a dummy input for the LSTM layer (obs, dones) + init_obs = jnp.array([self.observation_space.sample()]) + init_dones = jnp.zeros((init_obs.shape[0],)) + # at the moment use this trick of adding a dimension to obs and dones to pass them to the LSTM + init_x = (init_obs[np.newaxis, :], init_dones[np.newaxis, :]) + + # hardcode the number of envs to 1 for the initialization of the lstm states + n_envs = 1 + self.hidden_size = self.n_units + init_lstm_states = ScanLSTM.initialize_carry(n_envs, self.hidden_size) + + # Hack to make gSDE work without modifying internal SB3 code + self.actor.reset_noise = self.reset_noise + + # pass the init lstm states as argument to the actor train state + self.actor_state = TrainState.create( + apply_fn=self.actor.apply, + params=self.actor.init(actor_key, init_lstm_states, init_x), + tx=optax.chain( + optax.clip_by_global_norm(max_grad_norm), + self.optimizer_class( + learning_rate=lr_schedule(1), # type: ignore[call-arg] + **self.optimizer_kwargs, # , eps=1e-5 + ), + ), + ) + + self.vf = Critic(n_units=self.n_units, activation_fn=self.activation_fn) + + # pass the init lstm states as argument to the critic train state + self.vf_state = TrainState.create( + apply_fn=self.vf.apply, + params=self.vf.init({"params": vf_key}, init_lstm_states, init_x), + tx=optax.chain( + optax.clip_by_global_norm(max_grad_norm), + self.optimizer_class( + learning_rate=lr_schedule(1), # type: ignore[call-arg] + **self.optimizer_kwargs, # , eps=1e-5 + ), + ), + ) + + self.actor.apply = jax.jit(self.actor.apply) # type: ignore[method-assign] + self.vf.apply = jax.jit(self.vf.apply) # 
type: ignore[method-assign] + + return key + + def reset_noise(self, batch_size: int = 1) -> None: + """ + Sample new weights for the exploration matrix, when using gSDE. + """ + self.key, self.noise_key = jax.random.split(self.key, 2) + + def forward(self, obs: np.ndarray, lstm_states, deterministic: bool = False, key = None) -> np.ndarray: + return self._predict(obs, deterministic=deterministic) + + # Overrided the _predict function with a new one taking the lstm states as arguments + def _predict( + self, + observation: np.ndarray, + lstm_states: LSTMStates, + episode_start: np.ndarray, + deterministic: bool = False + ) -> Tuple[np.ndarray, LSTMStates]: + # TODO : could do a helper fn to transform the obs, dones and return lstm states and action / value + # because it is used in several parts of the code and quite verbose + lstm_in = (observation[np.newaxis, :], episode_start[np.newaxis, :]) + new_pi_lstm_states, dist = self.actor_state.apply_fn(self.actor_state.params, lstm_states.pi, lstm_in) + + if deterministic: + actions = dist.mode() + else: + actions = dist.sample(seed=self.noise_key) + + # Trick to use gSDE: repeat sampled noise by using the same noise key + # if not self.use_sde: + # self.reset_noise() + + self.reset_noise() + + # add the new actor and old critic lstm states to the lstm states tuple + lstm_states = LSTMStates( + pi=new_pi_lstm_states, + vf=lstm_states.vf + ) + + return actions, lstm_states + + # Overrided the predict function with a new one taking the lstm states as arguments + def predict( + self, + observation: Union[np.ndarray, Dict[str, np.ndarray]], + lstm_states: Optional[Tuple[jnp.ndarray, ...]] = None, + episode_start: Optional[np.ndarray] = None, + deterministic: bool = False, + ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: + # Switch to eval mode (this affects batch norm / dropout) + self.set_training_mode(False) + + # TODO : see if still need that + # observation, vectorized_env = self.obs_to_tensor(observation) + 
+ if isinstance(observation, dict): + n_envs = observation[next(iter(observation.keys()))].shape[0] + else: + n_envs = observation.shape[0] + # state : (n_layers, n_envs, dim) + if lstm_states is None: + # Initialize hidden states to zeros + init_lstm_states = ScanLSTM.initialize_carry(n_envs, self.hidden_size) + lstm_states = LSTMStates( + pi=init_lstm_states, + vf=init_lstm_states + ) + + if episode_start is None: + episode_start = jnp.array([False for _ in range(n_envs)]) + + actions, lstm_states = self._predict( + observation, lstm_states=lstm_states, episode_start=episode_start, deterministic=deterministic + ) + + # Convert to numpy + actions = np.array(actions) + + if isinstance(self.action_space, spaces.Box): + if self.squash_output: + # Rescale to proper domain when using squashing + actions = self.unscale_action(actions) + else: + # Actions could be on arbitrary scale, so clip the actions to avoid + # out of bound error (e.g. if sampling from a Gaussian distribution) + actions = np.clip(actions, self.action_space.low, self.action_space.high) + + # TODO : see if still need that + # Remove batch dimension if needed + # if not vectorized_env: + # actions = actions.squeeze(axis=0) + + return actions, lstm_states + + + # Added the lstm states to the predict_all method (maybe also the dones but I don't remember) + def predict_all(self, observation: np.ndarray, done, lstm_states, key: jax.Array) -> np.ndarray: + return self._predict_all(self.actor_state, self.vf_state, observation, done, lstm_states, key) + + @staticmethod + @jax.jit + def _predict_all(actor_state, vf_state, observations, dones, lstm_states, key): + # separate the lstm states for the actor and the critic, and prepare the input for the lstm + act_lstm_states, vf_lstm_states = lstm_states + lstm_in = (observations[np.newaxis, :], dones[np.newaxis, :]) + + # pass the actor lstm states and the input to the actor + act_lstm_states, dist = actor_state.apply_fn(actor_state.params, act_lstm_states, 
lstm_in) + actions = dist.sample(seed=key) + log_probs = dist.log_prob(actions) + + # pass the critic lstm states and the input to the critic + vf_lstm_states, values = vf_state.apply_fn(vf_state.params, vf_lstm_states, lstm_in) + values = values.flatten() + + # add the actor and critic lstm states to the lstm states tuple + lstm_states = LSTMStates( + pi=act_lstm_states, + vf=vf_lstm_states + ) + + return actions, log_probs, values, lstm_states diff --git a/sbx/recurrent_ppo/recurrent_ppo.py b/sbx/recurrent_ppo/recurrent_ppo.py new file mode 100644 index 0000000..9e1820d --- /dev/null +++ b/sbx/recurrent_ppo/recurrent_ppo.py @@ -0,0 +1,558 @@ +import warnings +from functools import partial +from copy import deepcopy +from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union + +import jax +import jax.numpy as jnp +import numpy as np +import torch as th +import gymnasium as gym +from flax.training.train_state import TrainState +from gymnasium import spaces + +from stable_baselines3.common.callbacks import BaseCallback +from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule +from stable_baselines3.common.utils import explained_variance, get_schedule_fn +from stable_baselines3.common.vec_env import VecEnv + +from sbx.common.on_policy_algorithm import OnPolicyAlgorithmJax +from sbx.common.recurrent import RecurrentRolloutBuffer, LSTMStates +from sbx.recurrent_ppo.policies import RecurrentPPOPolicy as PPOPolicy +from sbx.recurrent_ppo.policies import ScanLSTM + +RPPOSelf = TypeVar("RPPOSelf", bound="RecurrentPPO") + + +class RecurrentPPO(OnPolicyAlgorithmJax): + # TODO : Update documentation + """ + Proximal Policy Optimization algorithm (PPO) (clip version) + + Paper: https://arxiv.org/abs/1707.06347 + Code: This implementation borrows code from OpenAI Spinning Up (https://github.com/openai/spinningup/) + https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail and + Stable Baselines (PPO2 from 
https://github.com/hill-a/stable-baselines) + + Introduction to PPO: https://spinningup.openai.com/en/latest/algorithms/ppo.html + + :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...) + :param env: The environment to learn from (if registered in Gym, can be str) + :param learning_rate: The learning rate, it can be a function + of the current progress remaining (from 1 to 0) + :param n_steps: The number of steps to run for each environment per update + (i.e. rollout buffer size is n_steps * n_envs where n_envs is number of environment copies running in parallel) + NOTE: n_steps * n_envs must be greater than 1 (because of the advantage normalization) + See https://github.com/pytorch/pytorch/issues/29372 + :param batch_size: Minibatch size + :param n_epochs: Number of epoch when optimizing the surrogate loss + :param gamma: Discount factor + :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator + :param clip_range: Clipping parameter, it can be a function of the current progress + remaining (from 1 to 0). + :param clip_range_vf: Clipping parameter for the value function, + it can be a function of the current progress remaining (from 1 to 0). + This is a parameter specific to the OpenAI implementation. If None is passed (default), + no clipping will be done on the value function. + IMPORTANT: this clipping depends on the reward scaling. 
+ :param normalize_advantage: Whether to normalize or not the advantage + :param ent_coef: Entropy coefficient for the loss calculation + :param vf_coef: Value function coefficient for the loss calculation + :param max_grad_norm: The maximum value for the gradient clipping + :param use_sde: Whether to use generalized State Dependent Exploration (gSDE) + instead of action noise exploration (default: False) + :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE + Default: -1 (only sample at the beginning of the rollout) + :param target_kl: Limit the KL divergence between updates, + because the clipping is not enough to prevent large update + see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213) + By default, there is no limit on the kl div. + :param tensorboard_log: the log location for tensorboard (if None, no logging) + :param policy_kwargs: additional arguments to be passed to the policy on creation + :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for + debug messages + :param seed: Seed for the pseudo random generators + :param device: Device (cpu, cuda, ...) on which the code should be run. + Setting it to auto, the code will be run on the GPU if possible. 
+ :param _init_setup_model: Whether or not to build the network at the creation of the instance + """ + + policy_aliases: ClassVar[Dict[str, Type[PPOPolicy]]] = { # type: ignore[assignment] + "MlpPolicy": PPOPolicy, + # "CnnPolicy": ActorCriticCnnPolicy, + # "MultiInputPolicy": MultiInputActorCriticPolicy, + } + policy: PPOPolicy # type: ignore[assignment] + + def __init__( + self, + policy: Union[str, Type[PPOPolicy]], + env: Union[GymEnv, str], + learning_rate: Union[float, Schedule] = 3e-4, + n_steps: int = 2048, + batch_size: int = 64, + n_epochs: int = 10, + gamma: float = 0.99, + gae_lambda: float = 0.95, + clip_range: Union[float, Schedule] = 0.2, + clip_range_vf: Union[None, float, Schedule] = None, + normalize_advantage: bool = True, + ent_coef: float = 0.0, + vf_coef: float = 0.5, + max_grad_norm: float = 0.5, + use_sde: bool = False, + sde_sample_freq: int = -1, + target_kl: Optional[float] = None, + tensorboard_log: Optional[str] = None, + policy_kwargs: Optional[Dict[str, Any]] = None, + verbose: int = 0, + seed: Optional[int] = None, + device: str = "auto", + _init_setup_model: bool = True, + ): + super().__init__( + policy, + env, + learning_rate=learning_rate, + n_steps=n_steps, + gamma=gamma, + gae_lambda=gae_lambda, + ent_coef=ent_coef, + vf_coef=vf_coef, + max_grad_norm=max_grad_norm, + # Note: gSDE is not properly implemented, + use_sde=use_sde, + sde_sample_freq=sde_sample_freq, + tensorboard_log=tensorboard_log, + policy_kwargs=policy_kwargs, + verbose=verbose, + device=device, + seed=seed, + _init_setup_model=False, + supported_action_spaces=( + spaces.Box, + spaces.Discrete, + spaces.MultiDiscrete, + spaces.MultiBinary, + ), + ) + + # Sanity check, otherwise it will lead to noisy gradient and NaN + # because of the advantage normalization + if normalize_advantage: + assert ( + batch_size > 1 + ), "`batch_size` must be greater than 1. 
See https://github.com/DLR-RM/stable-baselines3/issues/440" + + if self.env is not None: + # Check that `n_steps * n_envs > 1` to avoid NaN + # when doing advantage normalization + buffer_size = self.env.num_envs * self.n_steps + assert buffer_size > 1 or ( + not normalize_advantage + ), f"`n_steps * n_envs` must be greater than 1. Currently n_steps={self.n_steps} and n_envs={self.env.num_envs}" + # Check that the rollout buffer size is a multiple of the mini-batch size + untruncated_batches = buffer_size // batch_size + if buffer_size % batch_size > 0: + warnings.warn( + f"You have specified a mini-batch size of {batch_size}," + f" but because the `RolloutBuffer` is of size `n_steps * n_envs = {buffer_size}`," + f" after every {untruncated_batches} untruncated mini-batches," + f" there will be a truncated mini-batch of size {buffer_size % batch_size}\n" + f"We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.\n" + f"Info: (n_steps={self.n_steps} and n_envs={self.env.num_envs})" + ) + + self.batch_size = batch_size + self.n_epochs = n_epochs + self.clip_range = clip_range + self.clip_range_vf = clip_range_vf + self.normalize_advantage = normalize_advantage + self.target_kl = target_kl + + if _init_setup_model: + self._setup_model() + + def _setup_model(self) -> None: + # super()._setup_model() + self._setup_lr_schedule() + self.set_random_seed(self.seed) + + if not hasattr(self, "policy") or self.policy is None: # type: ignore[has-type] + self.policy = self.policy_class( # type: ignore[assignment] + self.observation_space, + self.action_space, + self.lr_schedule, + **self.policy_kwargs, + ) + + self.key = self.policy.build(self.key, self.lr_schedule, self.max_grad_norm) + + # TODO : what is ent_key ? 
+ self.key, ent_key = jax.random.split(self.key, 2) + + self.actor = self.policy.actor + self.vf = self.policy.vf + + # added the lstm states hidden_size (used to initialize the lstm states and the replay buffer) + # at the moment just match the n_units of the policy + # TODO : change this to enable more complex architectures (fix the n_lstm layers to 1 for now) + self.hidden_state_size = self.policy.actor.n_units + num_lstm_layers = 1 + lstm_state_buffer_shape = (self.n_steps, num_lstm_layers, self.n_envs, self.hidden_state_size) + + # use dummy lstm states to init the pi and the vf states + cell_hidden_lstm_states = ScanLSTM.initialize_carry(self.n_envs, self.hidden_state_size) + # add them to the global LSTMStates + init_lstm_states = LSTMStates( + pi=cell_hidden_lstm_states, + vf=cell_hidden_lstm_states, + ) + # update the last lstm states (like in sb3 contrib) + self._last_lstm_states = init_lstm_states + + # Initialize the rollout buffer (it also encompasses the dones now as well as the lstm states) + self.rollout_buffer = RecurrentRolloutBuffer( + self.n_steps, + self.observation_space, + self.action_space, + gamma=self.gamma, + gae_lambda=self.gae_lambda, + n_envs=self.n_envs, + lstm_state_buffer_shape=lstm_state_buffer_shape, + device="cpu", + ) + + # Initialize schedules for policy/value clipping + self.clip_range_schedule = get_schedule_fn(self.clip_range) + # if self.clip_range_vf is not None: + # if isinstance(self.clip_range_vf, (float, int)): + # assert self.clip_range_vf > 0, "`clip_range_vf` must be positive, " "pass `None` to deactivate vf clipping" + # + # self.clip_range_vf = get_schedule_fn(self.clip_range_vf) + + + def collect_rollouts( + self, + env: VecEnv, + callback: BaseCallback, + rollout_buffer: RecurrentRolloutBuffer, + n_rollout_steps: int, + ) -> bool: + """ + Collect experiences using the current policy and fill a ``RolloutBuffer``. 
+ The term rollout here refers to the model-free notion and should not + be used with the concept of rollout used in model-based RL or planning. + + :param env: The training environment + :param callback: Callback that will be called at each step + (and at the beginning and end of the rollout) + :param rollout_buffer: Buffer to fill with rollouts + :param n_rollout_steps: Number of experiences to collect per environment + :return: True if function returned with at least `n_rollout_steps` + collected, False if callback terminated rollout prematurely. + """ + assert self._last_obs is not None, "No previous observation was provided" # type: ignore[has-type] + # Switch to eval mode (this affects batch norm / dropout) + + n_steps = 0 + rollout_buffer.reset() + # Sample new weights for the state dependent exploration + if self.use_sde: + self.policy.reset_noise() + + callback.on_rollout_start() + + # copied that from sb3 contrib + lstm_states = deepcopy(self._last_lstm_states) + dones = self._last_episode_starts + + while n_steps < n_rollout_steps: + if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0: + # Sample a new noise matrix + self.policy.reset_noise() + + if not self.use_sde or isinstance(self.action_space, gym.spaces.Discrete): + # Always sample new stochastic action + self.policy.reset_noise() + + obs_tensor, _ = self.policy.prepare_obs(self._last_obs) # type: ignore[has-type] + # use the predict_all method with the lstm states and the dones + actions, log_probs, values, lstm_states = self.policy.predict_all(obs_tensor, dones, lstm_states, self.policy.noise_key) + + actions = np.array(actions) + log_probs = np.array(log_probs) + values = np.array(values) + + # Rescale and perform action + clipped_actions = actions + # Clip the actions to avoid out of bound error + if isinstance(self.action_space, gym.spaces.Box): + clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high) + + new_obs, rewards, dones, 
infos = env.step(clipped_actions) + + self.num_timesteps += env.num_envs + + # Give access to local variables + callback.update_locals(locals()) + if callback.on_step() is False: + return False + + self._update_info_buffer(infos) + n_steps += 1 + + if isinstance(self.action_space, gym.spaces.Discrete): + # Reshape in case of discrete action + actions = actions.reshape(-1, 1) + + # will be used to boostrap with the value function if need (need to to a critic pass) + vf_lstm_states = lstm_states.vf + lstm_in = (obs_tensor[np.newaxis, :], dones[np.newaxis, :]) + + # Handle timeout by bootstraping with value function + # see GitHub issue #633 + # TODO : See how we can handle the lstm states here (bc we iterate on the dones) + for idx, done in enumerate(dones): + if ( + done + and infos[idx].get("terminal_observation") is not None + and infos[idx].get("TimeLimit.truncated", False) + ): + terminal_obs = self.policy.prepare_obs(infos[idx]["terminal_observation"])[0] + + # TODO Normally should only give the obs and dones for current idx + # TODO Should maybe pre-compute the lstm states and values before and then just iterate over the idx when needed + # TODO This is surely slowing everything for no reason (and seems false) + vf_lstm_states, values = self.vf.apply( + self.policy.vf_state.params, + vf_lstm_states, + lstm_in + ) + terminal_value = values.flatten().item() + rewards[idx] += self.gamma * terminal_value + + # add the dones and the lstm states to the rollout buffer + rollout_buffer.add( + self._last_obs, # type: ignore + actions, + rewards, + self._last_episode_starts, # type: ignore + th.as_tensor(values), + th.as_tensor(log_probs), + dones=dones, + lstm_states=self._last_lstm_states, + ) + + # update the lstm states for the next iteration + self._last_obs = new_obs # type: ignore[assignment] + self._last_episode_starts = dones + self._last_lstm_states = lstm_states + + # Get the last values when the rollout ends to compute the advantages + vf_lstm_states = 
lstm_states.vf + lstm_in = (self.policy.prepare_obs(new_obs)[0][np.newaxis, :], dones[np.newaxis, :]) + vf_lstm_states, values = self.vf.apply( + self.policy.vf_state.params, + vf_lstm_states, + lstm_in + ) + values = np.array(values).flatten() + + rollout_buffer.compute_returns_and_advantage(last_values=th.as_tensor(values), dones=dones) + + callback.on_rollout_end() + + return True + + @staticmethod + @partial(jax.jit, static_argnames=["normalize_advantage"]) + def _one_update( + actor_state: TrainState, + vf_state: TrainState, + lstm_states: LSTMStates, + observations: np.ndarray, + dones: np.ndarray, + actions: np.ndarray, + advantages: np.ndarray, + returns: np.ndarray, + old_log_prob: np.ndarray, + clip_range: float, + ent_coef: float, + vf_coef: float, + normalize_advantage: bool = True, + ): + # Normalize advantage + # Normalization does not make sense if mini batchsize == 1, see GH issue #325 + if normalize_advantage and len(advantages) > 1: + advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) + + # TODO : something weird here because the params argument isn't used and only actor_state.params instead + def actor_loss(params): + # TODO : see why I need to flatten dones here (otherwise error in the shapes given to the lstm) + lstm_in = (observations[np.newaxis, :], dones.flatten()[np.newaxis, :]) + act_lstm_states, _ = lstm_states + + # TODO + # act_lstm_states, dist = actor_state.apply_fn(actor_state.params, act_lstm_states, lstm_in) + act_lstm_states, dist = actor_state.apply_fn(params, act_lstm_states, lstm_in) + log_prob = dist.log_prob(actions) + entropy = dist.entropy() + + # ratio between old and new policy, should be one at the first iteration + ratio = jnp.exp(log_prob - old_log_prob) + # clipped surrogate loss + policy_loss_1 = advantages * ratio + policy_loss_2 = advantages * jnp.clip(ratio, 1 - clip_range, 1 + clip_range) + policy_loss = -jnp.minimum(policy_loss_1, policy_loss_2).mean() + + # Entropy loss favor exploration 
+ # Approximate entropy when no analytical form + # entropy_loss = -jnp.mean(-log_prob) + # analytical form + entropy_loss = -jnp.mean(entropy) + + total_policy_loss = policy_loss + ent_coef * entropy_loss + return total_policy_loss + + pg_loss_value, pg_grads = jax.value_and_grad(actor_loss, has_aux=False)(actor_state.params) + actor_state = actor_state.apply_gradients(grads=pg_grads) + + # TODO : same observation as above + def critic_loss(params): + lstm_in = (observations[np.newaxis, :], dones.flatten()[np.newaxis, :]) + _, vf_lstm_states = lstm_states + # Value loss using the TD(gae_lambda) target + # TODO + # vf_lstm_states, values = vf_state.apply_fn(vf_state.params, vf_lstm_states, lstm_in) + vf_lstm_states, values = vf_state.apply_fn(params, vf_lstm_states, lstm_in) + vf_values = values.flatten() + return ((returns - vf_values) ** 2).mean() + + vf_loss_value, vf_grads = jax.value_and_grad(critic_loss, has_aux=False)(vf_state.params) + vf_state = vf_state.apply_gradients(grads=vf_grads) + + # loss = policy_loss + ent_coef * entropy_loss + vf_coef * value_loss + return (actor_state, vf_state), (pg_loss_value, vf_loss_value) + + def train(self) -> None: + """ + Update policy using the currently gathered rollout buffer. 
+ """ + # Update optimizer learning rate + # self._update_learning_rate(self.policy.optimizer) + # Compute current clip range + clip_range = self.clip_range_schedule(self._current_progress_remaining) + + # train for n_epochs epochs + for _ in range(self.n_epochs): + # JIT only one update + for rollout_data in self.rollout_buffer.get(self.batch_size): # type: ignore[attr-defined] + if isinstance(self.action_space, spaces.Discrete): + # Convert discrete action from float to int + actions = rollout_data.actions.flatten().numpy().astype(np.int32) + else: + actions = rollout_data.actions.numpy() + + # TODO : fix the values of the lstm states (at the moment they dot not follow the right temporal order I think) + # TODO : also fix this mechanism where I need to reshape the lstm states here + lstm_states_pi = ( + rollout_data.lstm_states[0][0].numpy().reshape(self.batch_size, self.hidden_state_size), + rollout_data.lstm_states[0][1].numpy().reshape(self.batch_size, self.hidden_state_size) + ) + + lstm_states_vf = ( + rollout_data.lstm_states[1][0].numpy().reshape(self.batch_size, self.hidden_state_size), + rollout_data.lstm_states[1][1].numpy().reshape(self.batch_size, self.hidden_state_size) + ) + + lstm_states = LSTMStates( + pi=lstm_states_pi, + vf=lstm_states_vf, + ) + + (self.policy.actor_state, self.policy.vf_state), (pg_loss, value_loss) = self._one_update( + actor_state=self.policy.actor_state, + vf_state=self.policy.vf_state, + observations=rollout_data.observations.numpy(), + actions=actions, + # added the dones here + dones=rollout_data.dones.numpy(), + advantages=rollout_data.advantages.numpy(), + returns=rollout_data.returns.numpy(), + old_log_prob=rollout_data.old_log_prob.numpy(), + # added the lstm states here + lstm_states=lstm_states, + clip_range=clip_range, + ent_coef=self.ent_coef, + vf_coef=self.vf_coef, + normalize_advantage=self.normalize_advantage, + ) + + self._n_updates += self.n_epochs + explained_var = explained_variance( + 
self.rollout_buffer.values.flatten(), # type: ignore[attr-defined] + self.rollout_buffer.returns.flatten(), # type: ignore[attr-defined] + ) + + # Logs + # self.logger.record("train/entropy_loss", np.mean(entropy_losses)) + # self.logger.record("train/policy_gradient_loss", np.mean(pg_losses)) + # TODO: use mean instead of one point + self.logger.record("train/value_loss", value_loss.item()) + # self.logger.record("train/approx_kl", np.mean(approx_kl_divs)) + # self.logger.record("train/clip_fraction", np.mean(clip_fractions)) + self.logger.record("train/pg_loss", pg_loss.item()) + self.logger.record("train/explained_variance", explained_var) + # if hasattr(self.policy, "log_std"): + # self.logger.record("train/std", th.exp(self.policy.log_std).mean().item()) + + self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") + self.logger.record("train/clip_range", clip_range) + # if self.clip_range_vf is not None: + # self.logger.record("train/clip_range_vf", clip_range_vf) + + def learn( + self: RPPOSelf, + total_timesteps: int, + callback: MaybeCallback = None, + log_interval: int = 1, + tb_log_name: str = "PPO", + reset_num_timesteps: bool = True, + progress_bar: bool = False, + ) -> RPPOSelf: + return super().learn( + total_timesteps=total_timesteps, + callback=callback, + log_interval=log_interval, + tb_log_name=tb_log_name, + reset_num_timesteps=reset_num_timesteps, + progress_bar=progress_bar, + ) + + +if __name__ == "__main__": + import gymnasium as gym + from stable_baselines3.common.env_util import make_vec_env + + n_steps = 128 + batch_size = 32 + train_steps = 10_000 + test_steps = 10 + n_envs = 4 + env_id = "CartPole-v1" + + # create vec env and train algo + vec_env = make_vec_env(env_id, n_envs=n_envs) + model = RecurrentPPO("MlpPolicy", vec_env, n_steps=n_steps, batch_size=batch_size, verbose=1) + model.learn(total_timesteps=train_steps, progress_bar=True) + + vec_env = model.get_env() + obs = vec_env.reset() + lstm_states = None + + 
for _ in range(test_steps):
    action, lstm_states = model.predict(obs, state=lstm_states, deterministic=True)
    obs, reward, done, info = vec_env.step(action)

vec_env.close()


# ===== test.ipynb (reconstructed from the mangled notebook diff) =====
from functools import partial
from typing import Callable, Generator, Optional, Tuple, Union, NamedTuple

import numpy as np
import torch as th
import jax.numpy as jnp
from gymnasium import spaces
from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
from stable_baselines3.common.vec_env import VecNormalize


class LSTMStates(NamedTuple):
    """(hidden, cell) LSTM carries for the actor (pi) and the critic (vf)."""
    pi: Tuple
    vf: Tuple


class RecurrentRolloutBufferSamples(NamedTuple):
    # NOTE(review): fields are annotated jnp.ndarray but in practice hold
    # torch tensors after ``to_torch`` -- TODO reconcile the annotations.
    observations: jnp.ndarray
    actions: jnp.ndarray
    old_values: jnp.ndarray
    old_log_prob: jnp.ndarray
    advantages: jnp.ndarray
    returns: jnp.ndarray
    dones: jnp.ndarray
    lstm_states: LSTMStates


class RecurrentRolloutBuffer(RolloutBuffer):
    """
    Rollout buffer that also stores the dones and the LSTM cell/hidden states.

    :param buffer_size: Max number of elements in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param lstm_state_buffer_shape: Shape of the LSTM-state storage, either
        ``(n_steps, num_layers, n_envs, hidden_size)`` or
        ``(n_steps, n_envs, hidden_size)`` for a single layer
    :param device: PyTorch device
    :param gae_lambda: Bias/variance trade-off factor for GAE
        (equivalent to the classic advantage when set to 1)
    :param gamma: Discount factor
    :param n_envs: Number of parallel environments
    """

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        lstm_state_buffer_shape: Tuple[int, ...],
        device: Union[th.device, str] = "auto",
        gae_lambda: float = 1,
        gamma: float = 0.99,
        n_envs: int = 1,
    ):
        # "lstm_state_buffer_shape" externally, kept as hidden_state_shape
        # internally for parity with sb3-contrib.
        self.hidden_state_shape = lstm_state_buffer_shape
        self.seq_start_indices, self.seq_end_indices = None, None
        super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs)

    def reset(self):
        """Reset the buffer and allocate the dones + LSTM-state storage."""
        super().reset()
        self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.hidden_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32)
        self.cell_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32)
        self.hidden_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32)
        self.cell_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32)

    def add(self, *args, dones, lstm_states, **kwargs) -> None:
        """
        Store one transition plus its dones and LSTM states.

        :param dones: per-env episode termination flags
        :param lstm_states: ``LSTMStates`` carry recorded before the step;
            each entry is a (hidden, cell) pair
        """
        self.hidden_states_pi[self.pos] = np.array(lstm_states[0][0])
        self.cell_states_pi[self.pos] = np.array(lstm_states[0][1])
        self.hidden_states_vf[self.pos] = np.array(lstm_states[1][0])
        self.cell_states_vf[self.pos] = np.array(lstm_states[1][1])
        self.dones[self.pos] = np.array(dones)

        super().add(*args, **kwargs)

    def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBufferSamples, None, None]:
        assert self.full, "Rollout buffer must be full before sampling from it"

        # Prepare the data
        if not self.generator_ready:
            # BUGFIX: the (1, 2) swap turns (n_steps, num_layers, n_envs, hidden)
            # into (n_steps, n_envs, num_layers, hidden) and is only valid for
            # the 4-D layout. This notebook allocates 3-D (n_steps, n_envs,
            # hidden) states, for which the swap scrambled the env/feature axes.
            if self.hidden_states_pi.ndim == 4:
                for tensor in ["hidden_states_pi", "cell_states_pi", "hidden_states_vf", "cell_states_vf"]:
                    self.__dict__[tensor] = self.__dict__[tensor].swapaxes(1, 2)

            # Flatten while keeping the per-env sequence order:
            # (n_steps, n_envs, *shape) -> (n_envs * n_steps, *shape)
            for tensor in [
                "observations",
                "actions",
                "values",
                "log_probs",
                "advantages",
                "returns",
                "dones",
                "hidden_states_pi",
                "cell_states_pi",
                "hidden_states_vf",
                "cell_states_vf",
                "episode_starts",
            ]:
                self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
            self.generator_ready = True

        total = self.buffer_size * self.n_envs
        # Return everything, don't create minibatches
        if batch_size is None:
            batch_size = total

        # TODO: sampling strategy is sequential, not shuffled; minibatches that
        # do not divide the buffer evenly yield a ragged final batch.
        indices = np.arange(total)

        start_idx = 0
        while start_idx < total:
            yield self._get_samples(indices[start_idx : start_idx + batch_size])
            start_idx += batch_size

    def _get_samples(
        self,
        batch_inds: np.ndarray,
        env: Optional[VecNormalize] = None,
    ) -> RecurrentRolloutBufferSamples:
        lstm_states_pi = (
            self.hidden_states_pi[batch_inds],
            self.cell_states_pi[batch_inds],
        )
        lstm_states_vf = (
            self.hidden_states_vf[batch_inds],
            self.cell_states_vf[batch_inds],
        )

        data = (
            self.observations[batch_inds],
            self.actions[batch_inds],
            self.values[batch_inds].flatten(),
            self.log_probs[batch_inds].flatten(),
            self.advantages[batch_inds].flatten(),
            self.returns[batch_inds].flatten(),
            self.dones[batch_inds],
            LSTMStates(pi=lstm_states_pi, vf=lstm_states_vf),
        )
        return RecurrentRolloutBufferSamples(*tuple(map(self.to_torch, data)))


import jax
import jax.numpy as jnp
import flax.linen as nn

import functools
import gymnasium as gym
from stable_baselines3.common.env_util import make_vec_env

n_envs = 8
n_steps = 128
batch_size = 32
buffer_size = n_envs * n_steps
gamma = 0.99
gae_lambda = 0.95
hidden_size = 64
# Use the named constant instead of a magic 64 (single-layer 3-D layout).
lstm_state_buffer_shape = (n_steps, n_envs, hidden_size)

env_id = "CartPole-v1"
vec_env = make_vec_env(env_id, n_envs=n_envs)

rollout_buffer = RecurrentRolloutBuffer(
    n_steps,
    vec_env.observation_space,
    vec_env.action_space,
    gamma=gamma,
    gae_lambda=gae_lambda,
    n_envs=n_envs,
    lstm_state_buffer_shape=lstm_state_buffer_shape,
    device="cpu",
)

rollout_buffer.cell_states_pi.shape  # (128, 8, 64)
rollout_buffer.actions.shape  # (128, 8, 1)
rollout_buffer.observations.shape  # (128, 8, 4)


def split_into_batches(a, b, c, batch_size):
    """
    Split three lock-step arrays into equally sized batches along axis 0.

    :param a, b, c: arrays sharing the same leading (time) dimension
    :param batch_size: size of each batch; must divide the leading dimension
    :return: three lists of arrays, one batch list per input
    :raises AssertionError: if the leading dimension is not a multiple of batch_size
    """
    n_steps = a.shape[0]
    assert n_steps % batch_size == 0, "n_steps must be a multiple of batch_size"

    num_batches = n_steps // batch_size

    a_batches = np.split(a, num_batches, axis=0)
    b_batches = np.split(b, num_batches, axis=0)
    c_batches = np.split(c, num_batches, axis=0)

    return a_batches, b_batches, c_batches


# Example usage
n_steps = 128
n_envs = 8
hidden_size = 64
batch_size = 32

a = np.zeros((n_steps, n_envs, 1))
b = np.zeros((n_steps, n_envs, 4))
c = np.zeros((n_steps, n_envs, hidden_size))

a_batches, b_batches, c_batches = split_into_batches(a, b, c, batch_size)

# Check the shapes of the batches
for i in range(len(a_batches)):
    print(f"Batch {i+1}: a shape = {a_batches[i].shape}, b shape = {b_batches[i].shape}, c shape = {c_batches[i].shape}")

a = np.zeros((n_steps, n_envs, 1))
b = np.zeros((n_steps, n_envs, 4))
c = np.zeros((n_steps, n_envs, hidden_size))

# Mark the buffer as full so .get() can be exercised without a real rollout.
rollout_buffer.full = True


def debug_get(buffer, batch_size):
    """
    Print the name and shape of every field of every minibatch yielded
    by ``buffer.get(batch_size)`` and return the collected samples.
    """
    # BUGFIX: buffer.get() returns a *generator* of RecurrentRolloutBufferSamples
    # (a NamedTuple), not a dict -- calling .items() on it raised
    # AttributeError (see the captured traceback in the original notebook).
    samples = list(buffer.get(batch_size))
    for sample in samples:
        for name, value in sample._asdict().items():
            # lstm_states is a nested tuple and has no .shape of its own
            shape = getattr(value, "shape", None)
            print(f"Name: {name}, Shape: {shape}")
    return samples


debug_get(rollout_buffer, batch_size=batch_size)

for rollout_data in rollout_buffer.get(batch_size):
    print(rollout_data.actions.shape)
# ===== test.py (reconstructed) =====
import argparse
import gymnasium as gym

from stable_baselines3.common.env_util import make_vec_env
from sbx import PPO, RecurrentPPO


def main():
    """Smoke-test: train PPO or RecurrentPPO on CartPole, then run a short eval loop."""
    # BUGFIX: the original assigned algo = "ppo" and immediately overwrote it
    # with "rppo"; keep a single assignment and document the toggle.
    algo = "rppo"  # switch to "ppo" to test the non-recurrent baseline
    ALGO = RecurrentPPO if algo == "rppo" else PPO
    print(f"Using {algo}")

    # BUGFIX: the original set n_steps/batch_size/n_envs twice in a row;
    # only the second block took effect, so keep those effective values.
    n_steps = 64
    batch_size = 16
    train_steps = 20_000
    test_steps = 10
    n_envs = 2

    env_id = "CartPole-v1"

    # Create vec env and train the algorithm
    vec_env = make_vec_env(env_id, n_envs=n_envs)
    model = ALGO("MlpPolicy", vec_env, n_steps=n_steps, batch_size=batch_size, verbose=1)
    model.learn(total_timesteps=train_steps, progress_bar=True)

    # Check that the trained model can act deterministically
    vec_env = model.get_env()
    obs = vec_env.reset()
    for _ in range(test_steps):
        action, _states = model.predict(obs, deterministic=True)
        obs, reward, done, info = vec_env.step(action)

    vec_env.close()


if __name__ == "__main__":
    main()