From 6a50cbf1d7c88461f4514f6766d6189b7bc6f7b6 Mon Sep 17 00:00:00 2001 From: corentinlger Date: Mon, 26 Aug 2024 14:38:44 +0200 Subject: [PATCH 1/9] Add first rnn elements in ppo_rnn networks --- sbx/__init__.py | 2 + sbx/r_ppo/__init__.py | 3 + sbx/r_ppo/policies.py | 326 +++++++++++++++++++++++++++++++++++++ sbx/r_ppo/r_ppo.py | 367 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 698 insertions(+) create mode 100644 sbx/r_ppo/__init__.py create mode 100644 sbx/r_ppo/policies.py create mode 100644 sbx/r_ppo/r_ppo.py diff --git a/sbx/__init__.py b/sbx/__init__.py index a7c13bc..9f194ad 100644 --- a/sbx/__init__.py +++ b/sbx/__init__.py @@ -4,6 +4,7 @@ from sbx.ddpg import DDPG from sbx.dqn import DQN from sbx.ppo import PPO +from sbx.r_ppo import RPPO from sbx.sac import SAC from sbx.td3 import TD3 from sbx.tqc import TQC @@ -27,6 +28,7 @@ def DroQ(*args, **kwargs): "DDPG", "DQN", "PPO", + "RPPO" "SAC", "TD3", "TQC", diff --git a/sbx/r_ppo/__init__.py b/sbx/r_ppo/__init__.py new file mode 100644 index 0000000..ed2e4a9 --- /dev/null +++ b/sbx/r_ppo/__init__.py @@ -0,0 +1,3 @@ +from sbx.r_ppo.r_ppo import RPPO + +__all__ = ["RPPO"] diff --git a/sbx/r_ppo/policies.py b/sbx/r_ppo/policies.py new file mode 100644 index 0000000..978a0b9 --- /dev/null +++ b/sbx/r_ppo/policies.py @@ -0,0 +1,326 @@ +import functools + +from dataclasses import field +from typing import Any, Callable, Dict, List, Optional, Sequence, Union + +import flax.linen as nn +import gymnasium as gym +import jax +import jax.numpy as jnp +import numpy as np +import optax +import tensorflow_probability.substrates.jax as tfp +# DONE : Added orthogonal to the imports +from flax.linen.initializers import constant, orthogonal +from flax.training.train_state import TrainState +from gymnasium import spaces +from stable_baselines3.common.type_aliases import Schedule + +from sbx.common.policies import BaseJaxPolicy, Flatten + +tfd = tfp.distributions + + +# TODO : Add LSTM class as a ScanRNN Module 
# NOTE(review): reconstructed from a whitespace-mangled patch chunk; formatting restored.
# Recurrent actor/critic modules in the spirit of PureJaxRL's ppo_rnn:
# https://github.com/luchris429/purejaxrl/blob/main/purejaxrl/ppo_rnn.py
# TODO: an LSTM variant replaces the GRU cell in a later patch of this series.
class ScanRNN(nn.Module):
    """A GRU cell scanned over the time axis.

    The carry is zeroed wherever the per-step ``resets`` flag is set, so
    hidden state never leaks across episode boundaries.
    """

    @functools.partial(
        nn.scan,
        variable_broadcast="params",
        in_axes=0,
        out_axes=0,
        split_rngs={"params": False},
    )
    @nn.compact
    def __call__(self, carry, x):
        rnn_state = carry
        ins, resets = x
        # Zero the recurrent state of environments whose episode just ended.
        # NOTE(review): assumes ins is (batch, features) with features equal to
        # the hidden size (true here because the embedding dim matches) —
        # TODO confirm if embedding size and hidden size ever diverge.
        rnn_state = jnp.where(
            resets[:, np.newaxis],
            self.initialize_carry(ins.shape[0], ins.shape[1]),
            rnn_state,
        )
        # Bug fix: the GRU carry is (batch, hidden); the original read
        # rnn_state[0].shape[0] (a batch row), which only worked by accident.
        hidden_size = rnn_state.shape[-1]
        new_rnn_state, out = nn.GRUCell(features=hidden_size)(rnn_state, ins)
        return new_rnn_state, out

    @staticmethod
    def initialize_carry(batch_size, hidden_size):
        # A dummy PRNG key is fine: the default carry init function is all zeros.
        return nn.GRUCell(features=hidden_size).initialize_carry(
            rng=jax.random.PRNGKey(0), input_shape=(batch_size, hidden_size)
        )


class Critic(nn.Module):
    """Recurrent value function: obs embedding -> GRU -> MLP -> scalar value."""

    n_units: int = 256
    activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.tanh

    @nn.compact
    def __call__(self, hidden, x):
        obs, dones = x
        # Embed observations before the RNN, as in purejaxrl.
        # TODO: replace the hard-coded embedding size (64) with a parameter.
        embedding = nn.Dense(64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(obs)
        embedding = nn.relu(embedding)

        hidden, out = ScanRNN()(hidden, (embedding, dones))
        x = nn.Dense(self.n_units)(out)
        x = self.activation_fn(x)
        x = nn.Dense(self.n_units)(x)
        x = self.activation_fn(x)
        # Return the updated recurrent state alongside the value estimate.
        return hidden, nn.Dense(1)(x)


class Actor(nn.Module):
    """Recurrent policy head: obs embedding -> GRU -> MLP -> action distribution."""

    action_dim: int
    n_units: int = 256
    log_std_init: float = 0.0
    activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.tanh
    # For Discrete, MultiDiscrete and MultiBinary actions
    num_discrete_choices: Optional[Union[int, Sequence[int]]] = None
    # For MultiDiscrete
    max_num_choices: int = 0
    split_indices: np.ndarray = field(default_factory=lambda: np.array([]))

    def get_std(self) -> jnp.ndarray:
        # Make it work with gSDE
        return jnp.array(0.0)

    def __post_init__(self) -> None:
        # For MultiDiscrete
        if isinstance(self.num_discrete_choices, np.ndarray):
            self.max_num_choices = max(self.num_discrete_choices)
            # np.cumsum(...) gives the indices at which to split the flat logits
            self.split_indices = np.cumsum(self.num_discrete_choices[:-1])
        super().__post_init__()

    # Returns (updated hidden state, action distribution)
    @nn.compact
    def __call__(self, hidden, x):  # type: ignore[name-defined]
        obs, dones = x
        # Embed observations before the RNN, as in purejaxrl.
        embedding = nn.Dense(64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(obs)
        embedding = nn.relu(embedding)

        hidden, out = ScanRNN()(hidden, (embedding, dones))
        x = nn.Dense(self.n_units)(out)
        x = self.activation_fn(x)
        x = nn.Dense(self.n_units)(x)
        x = self.activation_fn(x)
        action_logits = nn.Dense(self.action_dim)(x)
        if self.num_discrete_choices is None:
            # Continuous actions
            log_std = self.param("log_std", constant(self.log_std_init), (self.action_dim,))
            dist = tfd.MultivariateNormalDiag(loc=action_logits, scale_diag=jnp.exp(log_std))
        elif isinstance(self.num_discrete_choices, int):
            dist = tfd.Categorical(logits=action_logits)
        else:
            # MultiDiscrete: split the flat logits into one group per sub-action,
            # then pad each group to max_num_choices with -inf so that invalid
            # actions get probability 0 (required by tfp.distributions.Categorical).
            split_logits = jnp.split(action_logits, self.split_indices, axis=1)
            logits_padded = jnp.stack(
                [
                    jnp.pad(
                        logit,
                        # logit is (batch_size, n): only pad after dim=1
                        pad_width=((0, 0), (0, self.max_num_choices - logit.shape[1])),
                        constant_values=-np.inf,
                    )
                    for logit in split_logits
                ],
                axis=1,
            )
            dist = tfp.distributions.Independent(
                tfp.distributions.Categorical(logits=logits_padded), reinterpreted_batch_ndims=1
            )
        return hidden, dist


class RPPOPolicy(BaseJaxPolicy):
    """Recurrent actor-critic policy for RPPO (MLP embedding + GRU features)."""

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Schedule,
        net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
        ortho_init: bool = False,
        log_std_init: float = 0.0,
        activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.tanh,
        use_sde: bool = False,
        # Note: most gSDE parameters are not used;
        # this is to keep the API consistent with SB3
        use_expln: bool = False,
        clip_mean: float = 2.0,
        features_extractor_class=None,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        normalize_images: bool = True,
        optimizer_class: Callable[..., optax.GradientTransformation] = optax.adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
        share_features_extractor: bool = False,
    ):
        if optimizer_kwargs is None:
            # Small eps value to avoid NaN in the Adam optimizer
            optimizer_kwargs = {}
            if optimizer_class == optax.adam:
                optimizer_kwargs["eps"] = 1e-5

        super().__init__(
            observation_space,
            action_space,
            features_extractor_class,
            features_extractor_kwargs,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
            squash_output=True,
        )
        self.log_std_init = log_std_init
        self.activation_fn = activation_fn
        if net_arch is not None:
            if isinstance(net_arch, list):
                self.n_units = net_arch[0]
            else:
                assert isinstance(net_arch, dict)
                self.n_units = net_arch["pi"][0]
        else:
            self.n_units = 64
        self.use_sde = use_sde
+ + self.key = self.noise_key = jax.random.PRNGKey(0) + + def build(self, key: jax.Array, lr_schedule: Schedule, max_grad_norm: float) -> jax.Array: + key, actor_key, vf_key = jax.random.split(key, 4) + # Keep a key for the actor + key, self.key = jax.random.split(key, 2) + # Initialize noise + self.reset_noise() + + if isinstance(self.action_space, spaces.Box): + actor_kwargs = { + "action_dim": int(np.prod(self.action_space.shape)), + } + elif isinstance(self.action_space, spaces.Discrete): + actor_kwargs = { + "action_dim": int(self.action_space.n), + "num_discrete_choices": int(self.action_space.n), + } + elif isinstance(self.action_space, spaces.MultiDiscrete): + assert self.action_space.nvec.ndim == 1, ( + f"Only one-dimensional MultiDiscrete action spaces are supported, " + f"but found MultiDiscrete({(self.action_space.nvec).tolist()})." + ) + actor_kwargs = { + "action_dim": int(np.sum(self.action_space.nvec)), + "num_discrete_choices": self.action_space.nvec, # type: ignore[dict-item] + } + elif isinstance(self.action_space, spaces.MultiBinary): + assert isinstance(self.action_space.n, int), ( + f"Multi-dimensional MultiBinary({self.action_space.n}) action space is not supported. " + "You can flatten it instead." + ) + # Handle binary action spaces as discrete action spaces with two choices. 
+ actor_kwargs = { + "action_dim": 2 * self.action_space.n, + "num_discrete_choices": 2 * np.ones(self.action_space.n, dtype=int), + } + else: + raise NotImplementedError(f"{self.action_space}") + + + self.actor = Actor( + n_units=self.n_units, + log_std_init=self.log_std_init, + activation_fn=self.activation_fn, + **actor_kwargs, # type: ignore[arg-type] + ) + + # Initialize a dummy x input (obs, dones) + init_obs = jnp.array([self.observation_space.sample()]) + # create an array of dones to create the good x (obs, dones) + init_dones = jnp.zeros((init_obs.shape[0],)) + init_x = (init_obs[np.newaxis, :], init_dones[np.newaxis, :]) + + # TODO : See how to get the actual batch size (the number of vectorized envs) + batch_size = 1 + # give same hidden size than n_units so constant shapes in the layers + hidden_size = self.n_units + init_hstate = ScanRNN.initialize_carry(batch_size, hidden_size) + + # Hack to make gSDE work without modifying internal SB3 code + self.actor.reset_noise = self.reset_noise + + self.actor_state = TrainState.create( + apply_fn=self.actor.apply, + params=self.actor.init(actor_key, init_hstate, init_x), + tx=optax.chain( + optax.clip_by_global_norm(max_grad_norm), + self.optimizer_class( + learning_rate=lr_schedule(1), # type: ignore[call-arg] + **self.optimizer_kwargs, # , eps=1e-5 + ), + ), + ) + + self.vf = Critic(n_units=self.n_units, activation_fn=self.activation_fn) + + self.vf_state = TrainState.create( + apply_fn=self.vf.apply, + # TODO : Why difference w params of actor state + params=self.vf.init({"params": vf_key}, init_hstate, init_x), + tx=optax.chain( + optax.clip_by_global_norm(max_grad_norm), + self.optimizer_class( + learning_rate=lr_schedule(1), # type: ignore[call-arg] + **self.optimizer_kwargs, # , eps=1e-5 + ), + ), + ) + + self.actor.apply = jax.jit(self.actor.apply) # type: ignore[method-assign] + self.vf.apply = jax.jit(self.vf.apply) # type: ignore[method-assign] + + return key + + def reset_noise(self, batch_size: 
int = 1) -> None: + """ + Sample new weights for the exploration matrix, when using gSDE. + """ + self.key, self.noise_key = jax.random.split(self.key, 2) + + def forward(self, obs: np.ndarray, deterministic: bool = False) -> np.ndarray: + return self._predict(obs, deterministic=deterministic) + + # TODO : Add the lstm state to the thing ? Maybe not here + def _predict(self, observation: np.ndarray, deterministic: bool = False) -> np.ndarray: # type: ignore[override] + if deterministic: + return BaseJaxPolicy.select_action(self.actor_state, observation) + # Trick to use gSDE: repeat sampled noise by using the same noise key + if not self.use_sde: + self.reset_noise() + return BaseJaxPolicy.sample_action(self.actor_state, observation, self.noise_key) + + def predict_all(self, observation: np.ndarray, key: jax.Array) -> np.ndarray: + return self._predict_all(self.actor_state, self.vf_state, observation, key) + + @staticmethod + @jax.jit + def _predict_all(actor_state, vf_state, obervations, key): + dist = actor_state.apply_fn(actor_state.params, obervations) + actions = dist.sample(seed=key) + log_probs = dist.log_prob(actions) + values = vf_state.apply_fn(vf_state.params, obervations).flatten() + return actions, log_probs, values diff --git a/sbx/r_ppo/r_ppo.py b/sbx/r_ppo/r_ppo.py new file mode 100644 index 0000000..c6ca16a --- /dev/null +++ b/sbx/r_ppo/r_ppo.py @@ -0,0 +1,367 @@ +import warnings +from functools import partial +from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union + +import jax +import jax.numpy as jnp +import numpy as np +from flax.training.train_state import TrainState +from gymnasium import spaces +from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule +from stable_baselines3.common.utils import explained_variance, get_schedule_fn + +from sbx.common.on_policy_algorithm import OnPolicyAlgorithmJax +# TODO : Fix this import +from sbx.r_ppo.policies import RPPOPolicy as PPOPolicy + +RPPOSelf = 
class RPPO(OnPolicyAlgorithmJax):
    """
    Recurrent Proximal Policy Optimization (PPO, clip version, with a GRU actor/critic).

    Paper: https://arxiv.org/abs/1707.06347
    Introduction to PPO: https://spinningup.openai.com/en/latest/algorithms/ppo.html

    :param policy: The policy model to use (MlpPolicy, ...)
    :param env: The environment to learn from (if registered in Gym, can be str)
    :param learning_rate: The learning rate, it can be a function
        of the current progress remaining (from 1 to 0)
    :param n_steps: The number of steps to run for each environment per update
        (rollout buffer size is n_steps * n_envs)
    :param batch_size: Minibatch size
    :param n_epochs: Number of epochs when optimizing the surrogate loss
    :param gamma: Discount factor
    :param gae_lambda: Factor for trade-off of bias vs variance for GAE
    :param clip_range: Clipping parameter, it can be a function of the current
        progress remaining (from 1 to 0)
    :param clip_range_vf: Clipping parameter for the value function (None: no clipping).
        IMPORTANT: this clipping depends on the reward scaling.
    :param normalize_advantage: Whether to normalize or not the advantage
    :param ent_coef: Entropy coefficient for the loss calculation
    :param vf_coef: Value function coefficient for the loss calculation
    :param max_grad_norm: The maximum value for the gradient clipping
    :param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
    :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
    :param target_kl: Limit the KL divergence between updates (None: no limit)
    :param tensorboard_log: the log location for tensorboard (if None, no logging)
    :param policy_kwargs: additional arguments to be passed to the policy on creation
    :param verbose: Verbosity level: 0 no output, 1 info, 2 debug
    :param seed: Seed for the pseudo random generators
    :param device: Device on which the code should be run ("auto": GPU if possible)
    :param _init_setup_model: Whether or not to build the network at the creation
        of the instance
    """

    policy_aliases: ClassVar[Dict[str, Type[PPOPolicy]]] = {  # type: ignore[assignment]
        "MlpPolicy": PPOPolicy,
        # "CnnPolicy": ActorCriticCnnPolicy,
        # "MultiInputPolicy": MultiInputActorCriticPolicy,
    }
    policy: PPOPolicy  # type: ignore[assignment]

    def __init__(
        self,
        policy: Union[str, Type[PPOPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 3e-4,
        n_steps: int = 2048,
        batch_size: int = 64,
        n_epochs: int = 10,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_range: Union[float, Schedule] = 0.2,
        clip_range_vf: Union[None, float, Schedule] = None,
        normalize_advantage: bool = True,
        ent_coef: float = 0.0,
        vf_coef: float = 0.5,
        max_grad_norm: float = 0.5,
        use_sde: bool = False,
        sde_sample_freq: int = -1,
        target_kl: Optional[float] = None,
        tensorboard_log: Optional[str] = None,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: str = "auto",
        _init_setup_model: bool = True,
    ):
        super().__init__(
            policy,
            env,
            learning_rate=learning_rate,
            n_steps=n_steps,
            gamma=gamma,
            gae_lambda=gae_lambda,
            ent_coef=ent_coef,
            vf_coef=vf_coef,
            max_grad_norm=max_grad_norm,
            # Note: gSDE is not properly implemented,
            use_sde=use_sde,
            sde_sample_freq=sde_sample_freq,
            tensorboard_log=tensorboard_log,
            policy_kwargs=policy_kwargs,
            verbose=verbose,
            device=device,
            seed=seed,
            _init_setup_model=False,
            supported_action_spaces=(
                spaces.Box,
                spaces.Discrete,
                spaces.MultiDiscrete,
                spaces.MultiBinary,
            ),
        )

        # Sanity check, otherwise it will lead to noisy gradient and NaN
        # because of the advantage normalization
        if normalize_advantage:
            assert (
                batch_size > 1
            ), "`batch_size` must be greater than 1. See https://github.com/DLR-RM/stable-baselines3/issues/440"

        if self.env is not None:
            # Check that `n_steps * n_envs > 1` to avoid NaN
            # when doing advantage normalization
            buffer_size = self.env.num_envs * self.n_steps
            assert buffer_size > 1 or (
                not normalize_advantage
            ), f"`n_steps * n_envs` must be greater than 1. Currently n_steps={self.n_steps} and n_envs={self.env.num_envs}"
            # Check that the rollout buffer size is a multiple of the mini-batch size
            untruncated_batches = buffer_size // batch_size
            if buffer_size % batch_size > 0:
                warnings.warn(
                    f"You have specified a mini-batch size of {batch_size},"
                    f" but because the `RolloutBuffer` is of size `n_steps * n_envs = {buffer_size}`,"
                    f" after every {untruncated_batches} untruncated mini-batches,"
                    f" there will be a truncated mini-batch of size {buffer_size % batch_size}\n"
                    f"We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.\n"
                    f"Info: (n_steps={self.n_steps} and n_envs={self.env.num_envs})"
                )

        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.clip_range = clip_range
        self.clip_range_vf = clip_range_vf
        self.normalize_advantage = normalize_advantage
        self.target_kl = target_kl

        if _init_setup_model:
            self._setup_model()

    def _setup_model(self) -> None:
        """Build the policy networks and the clip-range schedule."""
        super()._setup_model()

        if not hasattr(self, "policy") or self.policy is None:  # type: ignore[has-type]
            self.policy = self.policy_class(  # type: ignore[assignment]
                self.observation_space,
                self.action_space,
                self.lr_schedule,
                **self.policy_kwargs,
            )

            self.key = self.policy.build(self.key, self.lr_schedule, self.max_grad_norm)

            self.key, ent_key = jax.random.split(self.key, 2)

            self.actor = self.policy.actor
            self.vf = self.policy.vf

        # Initialize schedules for policy/value clipping
        self.clip_range_schedule = get_schedule_fn(self.clip_range)
        # TODO: value-function clipping (clip_range_vf) is not implemented yet.

    @staticmethod
    @partial(jax.jit, static_argnames=["normalize_advantage"])
    def _one_update(
        actor_state: TrainState,
        vf_state: TrainState,
        observations: np.ndarray,
        actions: np.ndarray,
        advantages: np.ndarray,
        returns: np.ndarray,
        old_log_prob: np.ndarray,
        clip_range: float,
        ent_coef: float,
        vf_coef: float,
        normalize_advantage: bool = True,
        # Reserved for a future recurrent-state optimizer; currently unused.
        # Bug fix: the original made this a required positional parameter but
        # train() never passed it, so every update raised a TypeError.
        lstm_train_state: Optional[TrainState] = None,
    ):
        """Perform one gradient step on one mini-batch for actor and critic.

        :return: ((new_actor_state, new_vf_state), (policy_loss, value_loss))
        """
        # Normalization does not make sense if mini batchsize == 1, see GH issue #325
        if normalize_advantage and len(advantages) > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        def actor_loss(params):
            dist = actor_state.apply_fn(params, observations)
            log_prob = dist.log_prob(actions)
            entropy = dist.entropy()

            # Ratio between old and new policy, should be one at the first iteration
            ratio = jnp.exp(log_prob - old_log_prob)
            # Clipped surrogate loss
            policy_loss_1 = advantages * ratio
            policy_loss_2 = advantages * jnp.clip(ratio, 1 - clip_range, 1 + clip_range)
            policy_loss = -jnp.minimum(policy_loss_1, policy_loss_2).mean()

            # Entropy loss favors exploration (analytical form)
            entropy_loss = -jnp.mean(entropy)

            return policy_loss + ent_coef * entropy_loss

        pg_loss_value, pg_grads = jax.value_and_grad(actor_loss, has_aux=False)(actor_state.params)
        actor_state = actor_state.apply_gradients(grads=pg_grads)

        def critic_loss(params):
            # Value loss using the TD(gae_lambda) target
            vf_values = vf_state.apply_fn(params, observations).flatten()
            return ((returns - vf_values) ** 2).mean()

        vf_loss_value, vf_grads = jax.value_and_grad(critic_loss, has_aux=False)(vf_state.params)
        # Bug fix: the original applied vf_grads twice (two apply_gradients
        # calls), doubling every critic update; it also computed
        # `pg_grads + vf_grads`, which is not a valid pytree operation.
        vf_state = vf_state.apply_gradients(grads=vf_grads)
        # TODO: define and apply a loss for the recurrent (lstm) train state.

        return (actor_state, vf_state), (pg_loss_value, vf_loss_value)

    def train(self) -> None:
        """Update policy using the currently gathered rollout buffer."""
        # Compute current clip range
        clip_range = self.clip_range_schedule(self._current_progress_remaining)

        # train for n_epochs epochs
        for _ in range(self.n_epochs):
            # JIT only one update
            # TODO: the buffer should preserve sequence order for the RNN
            # (no random permutations across time).
            for rollout_data in self.rollout_buffer.get(self.batch_size):  # type: ignore[attr-defined]
                if isinstance(self.action_space, spaces.Discrete):
                    # Convert discrete action from float to int
                    actions = rollout_data.actions.flatten().numpy().astype(np.int32)
                else:
                    actions = rollout_data.actions.numpy()

                (self.policy.actor_state, self.policy.vf_state), (pg_loss, value_loss) = self._one_update(
                    actor_state=self.policy.actor_state,
                    vf_state=self.policy.vf_state,
                    observations=rollout_data.observations.numpy(),
                    actions=actions,
                    advantages=rollout_data.advantages.numpy(),
                    returns=rollout_data.returns.numpy(),
                    old_log_prob=rollout_data.old_log_prob.numpy(),
                    clip_range=clip_range,
                    ent_coef=self.ent_coef,
                    vf_coef=self.vf_coef,
                    normalize_advantage=self.normalize_advantage,
                )

        self._n_updates += self.n_epochs
        explained_var = explained_variance(
            self.rollout_buffer.values.flatten(),  # type: ignore[attr-defined]
            self.rollout_buffer.returns.flatten(),  # type: ignore[attr-defined]
        )

        # Logs
        # TODO: log the mean over mini-batches instead of the last value only.
        self.logger.record("train/value_loss", value_loss.item())
        self.logger.record("train/pg_loss", pg_loss.item())
        self.logger.record("train/explained_variance", explained_var)
        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/clip_range", clip_range)

    def learn(
        self: RPPOSelf,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 1,
        tb_log_name: str = "PPO",
        reset_num_timesteps: bool = True,
        progress_bar: bool = False,
    ) -> RPPOSelf:
        return super().learn(
            total_timesteps=total_timesteps,
            callback=callback,
            log_interval=log_interval,
            tb_log_name=tb_log_name,
            reset_num_timesteps=reset_num_timesteps,
            progress_bar=progress_bar,
        )


if __name__ == "__main__":
    # Quick smoke test comparing PPO and RPPO on CartPole.
    import gymnasium as gym

    from sbx import PPO

    n_steps = 2048
    batch_size = 32
    train_steps = 5_000
    env = gym.make("CartPole-v1")

    model = PPO("MlpPolicy", env, n_steps=n_steps, batch_size=batch_size, verbose=1)
    vec_env = model.get_env()
    print(f"\n{vec_env = }")
    obs = vec_env.reset()
    print(f"{obs = }")
    print(f"{obs.shape = }")

    model = RPPO("MlpPolicy", env, n_steps=n_steps, batch_size=batch_size, verbose=1)
    model.learn(total_timesteps=train_steps, progress_bar=True)

    vec_env = model.get_env()
    print(f"\n{vec_env = }")
    obs = vec_env.reset()
    test_steps = 10
+ for _ in range(test_steps): + # vec_env.render() + action, _states = model.predict(obs, deterministic=True) + print(f"\n{action.shape = }") + print(f"{obs.shape = }") + print(f"{_states = }") + obs, reward, done, info = vec_env.step(action) + + vec_env.close() \ No newline at end of file From c4d189111b18d2cfaae0fae2f9737f7de00055bd Mon Sep 17 00:00:00 2001 From: corentinlger Date: Wed, 16 Oct 2024 17:06:06 +0200 Subject: [PATCH 2/9] Rename files and replace gru by lstm --- sbx/__init__.py | 4 +- sbx/r_ppo/__init__.py | 3 - sbx/recurrent_ppo/__init__.py | 3 + sbx/{r_ppo => recurrent_ppo}/policies.py | 111 +++++------ .../recurrent_ppo.py} | 176 ++++++++++++++++-- 5 files changed, 220 insertions(+), 77 deletions(-) delete mode 100644 sbx/r_ppo/__init__.py create mode 100644 sbx/recurrent_ppo/__init__.py rename sbx/{r_ppo => recurrent_ppo}/policies.py (77%) rename sbx/{r_ppo/r_ppo.py => recurrent_ppo/recurrent_ppo.py} (68%) diff --git a/sbx/__init__.py b/sbx/__init__.py index 9f194ad..591a6e5 100644 --- a/sbx/__init__.py +++ b/sbx/__init__.py @@ -4,7 +4,7 @@ from sbx.ddpg import DDPG from sbx.dqn import DQN from sbx.ppo import PPO -from sbx.r_ppo import RPPO +from sbx.recurrent_ppo import RecurrentPPO from sbx.sac import SAC from sbx.td3 import TD3 from sbx.tqc import TQC @@ -28,7 +28,7 @@ def DroQ(*args, **kwargs): "DDPG", "DQN", "PPO", - "RPPO" + "RecurrentPPO" "SAC", "TD3", "TQC", diff --git a/sbx/r_ppo/__init__.py b/sbx/r_ppo/__init__.py deleted file mode 100644 index ed2e4a9..0000000 --- a/sbx/r_ppo/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from sbx.r_ppo.r_ppo import RPPO - -__all__ = ["RPPO"] diff --git a/sbx/recurrent_ppo/__init__.py b/sbx/recurrent_ppo/__init__.py new file mode 100644 index 0000000..d7cd5b7 --- /dev/null +++ b/sbx/recurrent_ppo/__init__.py @@ -0,0 +1,3 @@ +from sbx.recurrent_ppo.recurrent_ppo import RecurrentPPO + +__all__ = ["RecurrentPPO"] diff --git a/sbx/r_ppo/policies.py b/sbx/recurrent_ppo/policies.py similarity index 77% rename 
from sbx/r_ppo/policies.py rename to sbx/recurrent_ppo/policies.py index 978a0b9..3910d0d 100644 --- a/sbx/r_ppo/policies.py +++ b/sbx/recurrent_ppo/policies.py @@ -21,35 +21,45 @@ tfd = tfp.distributions -# TODO : Add LSTM class as a ScanRNN Module (see PureJaxRL) code from https://github.com/luchris429/purejaxrl/blob/main/purejaxrl/ppo_rnn.py -# TODO : at the moment take exactly the same model with GruCell + embedding space in the actor and critic before giving obs to the RNN -class ScanRNN(nn.Module): +class ScanLSTM(nn.Module): @functools.partial( nn.scan, - variable_broadcast="params", + variable_broadcast='params', in_axes=0, out_axes=0, - split_rngs={"params": False}, + split_rngs={'params': False} ) @nn.compact - def __call__(self, carry, x): - rnn_state = carry - ins, resets = x - # Handle the reset logic of rnn states here - lstm_states = jnp.where( + def __call__(self, lstm_states, inputs_and_resets): + input, resets = inputs_and_resets + hidden_state, cell_state = lstm_states + # create new lstm states to replace the old ones if reset is True + reset_lstm_states = self.initialize_carry(hidden_state.shape[0], hidden_state.shape[1]) + + # handle the reset of the hidden lstm states + hidden_state = jnp.where( resets[:, np.newaxis], - self.initialize_carry(ins.shape[0], ins.shape[1]), - rnn_state + reset_lstm_states[0], + hidden_state ) - hidden_size = rnn_state[0].shape[0] - new_lstm_states, out = nn.GRUCell(features=hidden_size)(lstm_states, ins) - return new_lstm_states, out + # handle the reset of the cell lstm states + cell_state = jnp.where( + resets[:, np.newaxis], + reset_lstm_states[1], + cell_state + ) + + lstm_states = (hidden_state, cell_state) + hidden_size = lstm_states[0].shape[-1] + + new_lstm_states, output = nn.LSTMCell(features=hidden_size)(lstm_states, input) + return new_lstm_states, output @staticmethod def initialize_carry(batch_size, hidden_size): - # like in purejaxrl, use a dummy key because default state init fn is just zeros - 
return nn.GRUCell(features=hidden_size).initialize_carry( + # Return the tuple of hidden and cell states as a tuple + return nn.LSTMCell(features=hidden_size).initialize_carry( rng=jax.random.PRNGKey(0), input_shape=(batch_size, hidden_size) ) @@ -60,23 +70,14 @@ class Critic(nn.Module): # return hidden state + val @nn.compact - def __call__(self, hidden, x) -> jnp.ndarray: - # Add embedding like in purejaxrl atm - obs, dones = x - # TODO : replace hardcoded 64 later - embedding = nn.Dense( - 64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0) - )(obs) - embedding = nn.relu(embedding) - - rnn_in = (embedding, dones) - hidden, out = ScanRNN()(hidden, rnn_in) + def __call__(self, lstm_states, obs_dones) -> jnp.ndarray: + lstm_states, out = ScanLSTM()(lstm_states, obs_dones) x = nn.Dense(self.n_units)(out) x = self.activation_fn(x) x = nn.Dense(self.n_units)(x) x = self.activation_fn(x) x = nn.Dense(1)(x) - return hidden, x + return lstm_states, x # Add scanned lstm in the actor class Actor(nn.Module): @@ -102,19 +103,12 @@ def __post_init__(self) -> None: self.split_indices = np.cumsum(self.num_discrete_choices[:-1]) super().__post_init__() - # return hidden state + dist + # return hidden state + action dist @nn.compact - def __call__(self, hidden, x: jnp.ndarray) -> tfd.Distribution: # type: ignore[name-defined] + def __call__(self, hidden, obs_dones) -> tfd.Distribution: # type: ignore[name-defined] # Add embedding like in purejaxrl atm - obs, dones = x - embedding = nn.Dense( - 64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0) - )(obs) - embedding = nn.relu(embedding) - - rnn_in = (embedding, dones) - hidden, out = ScanRNN()(hidden, rnn_in) + hidden, out = ScanLSTM()(hidden, obs_dones) x = nn.Dense(self.n_units)(out) x = self.activation_fn(x) x = nn.Dense(self.n_units)(x) @@ -151,7 +145,7 @@ def __call__(self, hidden, x: jnp.ndarray) -> tfd.Distribution: # type: ignore[ return hidden, dist -class RPPOPolicy(BaseJaxPolicy): +class 
RecurrentPPOPolicy(BaseJaxPolicy): def __init__( self, observation_space: gym.spaces.Space, @@ -203,7 +197,7 @@ def __init__( self.key = self.noise_key = jax.random.PRNGKey(0) def build(self, key: jax.Array, lr_schedule: Schedule, max_grad_norm: float) -> jax.Array: - key, actor_key, vf_key = jax.random.split(key, 4) + key, actor_key, vf_key = jax.random.split(key, 3) # Keep a key for the actor key, self.key = jax.random.split(key, 2) # Initialize noise @@ -254,18 +248,17 @@ def build(self, key: jax.Array, lr_schedule: Schedule, max_grad_norm: float) -> init_dones = jnp.zeros((init_obs.shape[0],)) init_x = (init_obs[np.newaxis, :], init_dones[np.newaxis, :]) - # TODO : See how to get the actual batch size (the number of vectorized envs) - batch_size = 1 - # give same hidden size than n_units so constant shapes in the layers + # TODO : HERE HARD CODE THE NUMBER OF ENVS (but find a way to see how to actually get it from recurrent_ppo model) + n_envs = 1 hidden_size = self.n_units - init_hstate = ScanRNN.initialize_carry(batch_size, hidden_size) + init_lstm_states = ScanLSTM.initialize_carry(n_envs, hidden_size) # Hack to make gSDE work without modifying internal SB3 code self.actor.reset_noise = self.reset_noise self.actor_state = TrainState.create( apply_fn=self.actor.apply, - params=self.actor.init(actor_key, init_hstate, init_x), + params=self.actor.init(actor_key, init_lstm_states, init_x), tx=optax.chain( optax.clip_by_global_norm(max_grad_norm), self.optimizer_class( @@ -280,7 +273,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, max_grad_norm: float) -> self.vf_state = TrainState.create( apply_fn=self.vf.apply, # TODO : Why difference w params of actor state - params=self.vf.init({"params": vf_key}, init_hstate, init_x), + params=self.vf.init({"params": vf_key}, init_lstm_states, init_x), tx=optax.chain( optax.clip_by_global_norm(max_grad_norm), self.optimizer_class( @@ -301,26 +294,34 @@ def reset_noise(self, batch_size: int = 1) -> None: """ 
self.key, self.noise_key = jax.random.split(self.key, 2) - def forward(self, obs: np.ndarray, deterministic: bool = False) -> np.ndarray: + def forward(self, obs: np.ndarray, lstm_states, deterministic: bool = False) -> np.ndarray: return self._predict(obs, deterministic=deterministic) # TODO : Add the lstm state to the thing ? Maybe not here - def _predict(self, observation: np.ndarray, deterministic: bool = False) -> np.ndarray: # type: ignore[override] + def _predict(self, observation: np.ndarray, lstm_states, deterministic: bool = False) -> np.ndarray: # type: ignore[override] if deterministic: + # TODO : include lstm_states here return BaseJaxPolicy.select_action(self.actor_state, observation) # Trick to use gSDE: repeat sampled noise by using the same noise key if not self.use_sde: self.reset_noise() + # TODO : include lstm state here return BaseJaxPolicy.sample_action(self.actor_state, observation, self.noise_key) - def predict_all(self, observation: np.ndarray, key: jax.Array) -> np.ndarray: - return self._predict_all(self.actor_state, self.vf_state, observation, key) + def predict_all(self, observation: np.ndarray, done, lstm_states, key: jax.Array) -> np.ndarray: + return self._predict_all(self.actor_state, self.vf_state, observation, done, lstm_states, key) @staticmethod @jax.jit - def _predict_all(actor_state, vf_state, obervations, key): - dist = actor_state.apply_fn(actor_state.params, obervations) + def _predict_all(actor_state, vf_state, observations, dones, lstm_states, key): + # TODO : check if really need to add this dimension to obs and dones + ac_in = (observations[np.newaxis, :], dones[np.newaxis, :]) + # actor pass + act_lstm_states, dist = actor_state.apply_fn(actor_state.params, lstm_states, ac_in) actions = dist.sample(seed=key) log_probs = dist.log_prob(actions) - values = vf_state.apply_fn(vf_state.params, obervations).flatten() - return actions, log_probs, values + # value pass + vf_lstm_states, values = 
vf_state.apply_fn(vf_state.params, lstm_states, ac_in) + values = values.flatten() + lstm_states = (act_lstm_states, vf_lstm_states) + return actions, log_probs, values, lstm_states diff --git a/sbx/r_ppo/r_ppo.py b/sbx/recurrent_ppo/recurrent_ppo.py similarity index 68% rename from sbx/r_ppo/r_ppo.py rename to sbx/recurrent_ppo/recurrent_ppo.py index c6ca16a..dffde36 100644 --- a/sbx/r_ppo/r_ppo.py +++ b/sbx/recurrent_ppo/recurrent_ppo.py @@ -1,23 +1,31 @@ import warnings +from copy import deepcopy from functools import partial from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union import jax import jax.numpy as jnp import numpy as np +import torch as th +import gymnasium as gym from flax.training.train_state import TrainState from gymnasium import spaces +# from stable_baselines3.common.buffers import RolloutBuffer +from stable_baselines3.common.callbacks import BaseCallback from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import explained_variance, get_schedule_fn +from stable_baselines3.common.vec_env import VecEnv from sbx.common.on_policy_algorithm import OnPolicyAlgorithmJax +from sbx.common.recurrent import RecurrentRolloutBuffer # TODO : Fix this import -from sbx.r_ppo.policies import RPPOPolicy as PPOPolicy +from sbx.recurrent_ppo.policies import RecurrentPPOPolicy as PPOPolicy +from sbx.recurrent_ppo.policies import ScanLSTM -RPPOSelf = TypeVar("RPPOSelf", bound="RPPO") +RPPOSelf = TypeVar("RPPOSelf", bound="RecurrentPPO") -class RPPO(OnPolicyAlgorithmJax): +class RecurrentPPO(OnPolicyAlgorithmJax): """ Proximal Policy Optimization algorithm (PPO) (clip version) @@ -165,8 +173,11 @@ def __init__( if _init_setup_model: self._setup_model() + # TODO : Update the setup model function to add the lstm info ... 
(maybe not necessary because all in actor and value nets) def _setup_model(self) -> None: - super()._setup_model() + # super()._setup_model() + self._setup_lr_schedule() + self.set_random_seed(self.seed) if not hasattr(self, "policy") or self.policy is None: # type: ignore[has-type] self.policy = self.policy_class( # type: ignore[assignment] @@ -178,11 +189,31 @@ def _setup_model(self) -> None: self.key = self.policy.build(self.key, self.lr_schedule, self.max_grad_norm) + # TODO : what is ent_key ? self.key, ent_key = jax.random.split(self.key, 2) self.actor = self.policy.actor self.vf = self.policy.vf + hidden_state_shape = self.policy.actor.n_units + # TODO : create the last lstm states (dummy atm) --> should surely use the init carry method + lstm_states = ScanLSTM.initialize_carry(self.n_envs, hidden_state_shape) # (1, 64) because 1 env + self._last_lstm_states = lstm_states + # self._last_lstm_states = jnp.zeros((2, hidden_state_shape)) # (2, 64) because two lstm states of 1 env --> should surely be of shape (2, 1, 64) + + # TODO : Make this a recurrent rollout buffer --> add lstm states in it + self.rollout_buffer = RecurrentRolloutBuffer( + self.n_steps, + self.observation_space, + self.action_space, + gamma=self.gamma, + gae_lambda=self.gae_lambda, + n_envs=self.n_envs, + # TODO : Add the good hidden state shape + hidden_state_shape=hidden_state_shape, + device="cpu", + ) + # Initialize schedules for policy/value clipping self.clip_range_schedule = get_schedule_fn(self.clip_range) # if self.clip_range_vf is not None: @@ -191,6 +222,128 @@ def _setup_model(self) -> None: # # self.clip_range_vf = get_schedule_fn(self.clip_range_vf) + + def collect_rollouts( + self, + env: VecEnv, + callback: BaseCallback, + rollout_buffer: RecurrentRolloutBuffer, + n_rollout_steps: int, + ) -> bool: + """ + Collect experiences using the current policy and fill a ``RolloutBuffer``. 
+ The term rollout here refers to the model-free notion and should not + be used with the concept of rollout used in model-based RL or planning. + + :param env: The training environment + :param callback: Callback that will be called at each step + (and at the beginning and end of the rollout) + :param rollout_buffer: Buffer to fill with rollouts + :param n_rollout_steps: Number of experiences to collect per environment + :return: True if function returned with at least `n_rollout_steps` + collected, False if callback terminated rollout prematurely. + """ + assert self._last_obs is not None, "No previous observation was provided" # type: ignore[has-type] + # Switch to eval mode (this affects batch norm / dropout) + + n_steps = 0 + rollout_buffer.reset() + # Sample new weights for the state dependent exploration + if self.use_sde: + self.policy.reset_noise() + + callback.on_rollout_start() + + # TODO : initialize the dones and lstm states + lstm_states = deepcopy(self._last_lstm_states) + dones = jnp.zeros(self.n_envs) # Check it is the right shape for dones + + while n_steps < n_rollout_steps: + if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0: + # Sample a new noise matrix + self.policy.reset_noise() + + if not self.use_sde or isinstance(self.action_space, gym.spaces.Discrete): + # Always sample new stochastic action + self.policy.reset_noise() + + obs_tensor, _ = self.policy.prepare_obs(self._last_obs) # type: ignore[has-type] + # TODO : check why I get a wrong shape for actions (1, 2), wheras I should only have 1 action here because only 1 env + actions, log_probs, values, lstm_states = self.policy.predict_all(obs_tensor, dones, lstm_states, self.policy.noise_key) + + actions = np.array(actions) + log_probs = np.array(log_probs) + values = np.array(values) + + # Rescale and perform action + clipped_actions = actions + # Clip the actions to avoid out of bound error + if isinstance(self.action_space, gym.spaces.Box): + 
clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high) + + new_obs, rewards, dones, infos = env.step(clipped_actions) + + self.num_timesteps += env.num_envs + + # Give access to local variables + callback.update_locals(locals()) + if callback.on_step() is False: + return False + + self._update_info_buffer(infos) + n_steps += 1 + + if isinstance(self.action_space, gym.spaces.Discrete): + # Reshape in case of discrete action + actions = actions.reshape(-1, 1) + + # Handle timeout by bootstraping with value function + # see GitHub issue #633 + for idx, done in enumerate(dones): + if ( + done + and infos[idx].get("terminal_observation") is not None + and infos[idx].get("TimeLimit.truncated", False) + ): + terminal_obs = self.policy.prepare_obs(infos[idx]["terminal_observation"])[0] + terminal_value = np.array( + self.vf.apply( # type: ignore[union-attr] + self.policy.vf_state.params, + # TODO : might need to also give the lstm_states and the dones here + terminal_obs, + ).flatten() + ).item() + rewards[idx] += self.gamma * terminal_value + + rollout_buffer.add( + self._last_obs, # type: ignore + actions, + rewards, + self._last_episode_starts, # type: ignore + th.as_tensor(values), + th.as_tensor(log_probs), + lstm_states=self._last_lstm_states, # Should it be last lstm states ? 
or lstm states + ) + + + self._last_obs = new_obs # type: ignore[assignment] + self._last_episode_starts = dones + self._last_lstm_states = lstm_states + + # TODO : Compute the value by also giving the dones and the lstm states values + values = np.array( + self.vf.apply( # type: ignore[union-attr] + self.policy.vf_state.params, + self.policy.prepare_obs(new_obs)[0], # type: ignore[arg-type] + ).flatten() + ) + + rollout_buffer.compute_returns_and_advantage(last_values=th.as_tensor(values), dones=dones) + + callback.on_rollout_end() + + return True + # TODO : use lstm train state @staticmethod @partial(jax.jit, static_argnames=["normalize_advantage"]) @@ -246,8 +399,6 @@ def critic_loss(params): vf_loss_value, vf_grads = jax.value_and_grad(critic_loss, has_aux=False)(vf_state.params) vf_state = vf_state.apply_gradients(grads=vf_grads) - # TODO ? What should be lstm loss ?? Atm just give as a loss the sum of losses for actor and critic - lstm_grads = pg_grads + vf_grads vf_state = vf_state.apply_gradients(grads=vf_grads) # loss = policy_loss + ent_coef * entropy_loss + vf_coef * value_loss @@ -319,6 +470,7 @@ def learn( reset_num_timesteps: bool = True, progress_bar: bool = False, ) -> RPPOSelf: + # TODO : stuck because need to use a custom replay buffer now --> See how it is done in Sb3-Contrib return super().learn( total_timesteps=total_timesteps, callback=callback, @@ -333,7 +485,6 @@ def learn( import gymnasium as gym from sbx import PPO - # env = gym.make("CartPole-v1", render_mode="human") n_steps = 2048 batch_size = 32 train_steps = 5_000 @@ -341,27 +492,18 @@ def learn( model = PPO("MlpPolicy", env, n_steps=n_steps, batch_size=batch_size, verbose=1) vec_env = model.get_env() - print("") - print(f"{vec_env = }") obs = vec_env.reset() - print(f"{obs = }") - print(f"{obs.shape = }") - model = RPPO("MlpPolicy", env, n_steps=n_steps, batch_size=batch_size, verbose=1) + model = RecurrentPPO("MlpPolicy", env, n_steps=n_steps, batch_size=batch_size, verbose=1) 
model.learn(total_timesteps=train_steps, progress_bar=True) vec_env = model.get_env() - print(f"\n{vec_env = }") - print("AA") obs = vec_env.reset() test_steps = 10 for _ in range(test_steps): # vec_env.render() action, _states = model.predict(obs, deterministic=True) - print(f"\n{action.shape = }") - print(f"{obs.shape = }") - print(f"{_states = }") obs, reward, done, info = vec_env.step(action) vec_env.close() \ No newline at end of file From 58c9e03d6c9ca361ea43d666b0077f41d48978b0 Mon Sep 17 00:00:00 2001 From: corentinlger Date: Wed, 16 Oct 2024 19:50:17 +0200 Subject: [PATCH 3/9] Fix actor output shape and add recurrent rollout buffer --- sbx/common/recurrent.py | 164 +++++++++++++++++++++++++++++ sbx/ppo/policies.py | 6 +- sbx/recurrent_ppo/policies.py | 17 +-- sbx/recurrent_ppo/recurrent_ppo.py | 20 ++-- 4 files changed, 191 insertions(+), 16 deletions(-) create mode 100644 sbx/common/recurrent.py diff --git a/sbx/common/recurrent.py b/sbx/common/recurrent.py new file mode 100644 index 0000000..a690d1e --- /dev/null +++ b/sbx/common/recurrent.py @@ -0,0 +1,164 @@ +from functools import partial +from typing import Callable, Generator, Optional, Tuple, Union, NamedTuple + +import numpy as np +import torch as th +import jax.numpy as jnp +from gymnasium import spaces +from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer +from stable_baselines3.common.vec_env import VecNormalize + +# class LSTMStates(NamedTuple): +# pi: Tuple[th.Tensor, ...] +# vf: Tuple[th.Tensor, ...] 
+ +# TODO : see if I add jax info +class LSTMStates(NamedTuple): + pi: Tuple + vf: Tuple + +# Replaced th.Tensor with jnp.ndarray +class RecurrentRolloutBufferSamples(NamedTuple): + observations: jnp.ndarray + actions: jnp.ndarray + old_values: jnp.ndarray + old_log_prob: jnp.ndarray + advantages: jnp.ndarray + returns: jnp.ndarray + lstm_states: LSTMStates + +class RecurrentRolloutBuffer(RolloutBuffer): + """ + Rollout buffer that also stores the LSTM cell and hidden states. + + :param buffer_size: Max number of element in the buffer + :param observation_space: Observation space + :param action_space: Action space + :param hidden_state_shape: Shape of the buffer that will collect lstm states + (n_steps, lstm.num_layers, n_envs, lstm.hidden_size) + :param device: PyTorch device + :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator + Equivalent to classic advantage when set to 1. + :param gamma: Discount factor + :param n_envs: Number of parallel environments + """ + + def __init__( + self, + buffer_size: int, + observation_space: spaces.Space, + action_space: spaces.Space, + # renamed this because I found hidden_state_shape confusing + lstm_state_buffer_shape: Tuple[int, int, int], + device: Union[th.device, str] = "auto", + gae_lambda: float = 1, + gamma: float = 0.99, + n_envs: int = 1, + ): + # TODO : see if I rename this in all the code + self.hidden_state_shape = lstm_state_buffer_shape + self.seq_start_indices, self.seq_end_indices = None, None + super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs) + + def reset(self): + super().reset() + self.hidden_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32) + self.cell_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32) + self.hidden_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32) + self.cell_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32) + + # def add(self, *args, 
lstm_states: LSTMStates, **kwargs) -> None: + # """ + # :param hidden_states: LSTM cell and hidden state + # """ + # # TODO : at the moment doesn't work because I didn't create a named tuple for lstm states + # self.hidden_states_pi[self.pos] = np.array(lstm_states.pi[0].cpu().numpy()) + # self.cell_states_pi[self.pos] = np.array(lstm_states.pi[1].cpu().numpy()) + # self.hidden_states_vf[self.pos] = np.array(lstm_states.vf[0].cpu().numpy()) + # self.cell_states_vf[self.pos] = np.array(lstm_states.vf[1].cpu().numpy()) + + def add(self, *args, lstm_states, **kwargs) -> None: + """ + :param hidden_states: LSTM cell and hidden state + """ + # TODO : at the moment doesn't work because I didn't create a named tuple for lstm states + print(lstm_states[0][0]) + print(np.array(lstm_states[0][0])) + print(np.array(lstm_states[0][0]).shape) + print(self.hidden_states_pi[self.pos].shape) + self.hidden_states_pi[self.pos] = np.array(lstm_states[0][0]) + self.cell_states_pi[self.pos] = np.array(lstm_states[0][1]) + self.hidden_states_vf[self.pos] = np.array(lstm_states[1][0]) + self.cell_states_vf[self.pos] = np.array(lstm_states[1][1]) + + super().add(*args, **kwargs) + + def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBufferSamples, None, None]: + assert self.full, "Rollout buffer must be full before sampling from it" + + # Prepare the data + if not self.generator_ready: + # hidden_state_shape = (self.n_steps, lstm.num_layers, self.n_envs, lstm.hidden_size) + # swap first to (self.n_steps, self.n_envs, lstm.num_layers, lstm.hidden_size) + for tensor in ["hidden_states_pi", "cell_states_pi", "hidden_states_vf", "cell_states_vf"]: + self.__dict__[tensor] = self.__dict__[tensor].swapaxes(1, 2) + + # flatten but keep the sequence order + # 1. (n_steps, n_envs, *tensor_shape) -> (n_envs, n_steps, *tensor_shape) + # 2. 
(n_envs, n_steps, *tensor_shape) -> (n_envs * n_steps, *tensor_shape) + for tensor in [ + "observations", + "actions", + "values", + "log_probs", + "advantages", + "returns", + "hidden_states_pi", + "cell_states_pi", + "hidden_states_vf", + "cell_states_vf", + "episode_starts", + ]: + self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) + self.generator_ready = True + + # Return everything, don't create minibatches + if batch_size is None: + batch_size = self.buffer_size * self.n_envs + + # Sampling strategy that doesn't allow any mini batch size (must be a product of n_envs) + indices = np.arange(self.buffer_size * self.n_envs) + + start_idx = 0 + while start_idx < self.buffer_size * self.n_envs: + batch_inds = indices[start_idx : start_idx + batch_size] + yield self._get_samples(batch_inds) + start_idx += batch_size + + + def _get_samples( + self, + batch_inds: np.ndarray, + env: Optional[VecNormalize] = None, + ) -> RecurrentRolloutBufferSamples: + + lstm_states_pi = ( + self.hidden_states_pi[batch_inds], + self.cell_states_pi[batch_inds] + ) + + lstm_states_vf = ( + self.hidden_states_vf[batch_inds], + self.cell_states_vf[batch_inds] + ) + + data = ( + self.observations[batch_inds], + self.actions[batch_inds], + self.values[batch_inds].flatten(), + self.log_probs[batch_inds].flatten(), + self.advantages[batch_inds].flatten(), + self.returns[batch_inds].flatten(), + LSTMStates(pi=lstm_states_pi, vf=lstm_states_vf) + ) + return RecurrentRolloutBufferSamples(*tuple(map(self.to_torch, data))) diff --git a/sbx/ppo/policies.py b/sbx/ppo/policies.py index 54915c8..ca4e386 100644 --- a/sbx/ppo/policies.py +++ b/sbx/ppo/policies.py @@ -248,9 +248,9 @@ def predict_all(self, observation: np.ndarray, key: jax.Array) -> np.ndarray: @staticmethod @jax.jit - def _predict_all(actor_state, vf_state, obervations, key): - dist = actor_state.apply_fn(actor_state.params, obervations) + def _predict_all(actor_state, vf_state, observations, key): + dist = 
actor_state.apply_fn(actor_state.params, observations) actions = dist.sample(seed=key) log_probs = dist.log_prob(actions) - values = vf_state.apply_fn(vf_state.params, obervations).flatten() + values = vf_state.apply_fn(vf_state.params, observations).flatten() return actions, log_probs, values diff --git a/sbx/recurrent_ppo/policies.py b/sbx/recurrent_ppo/policies.py index 3910d0d..d38eeed 100644 --- a/sbx/recurrent_ppo/policies.py +++ b/sbx/recurrent_ppo/policies.py @@ -109,6 +109,9 @@ def __call__(self, hidden, obs_dones) -> tfd.Distribution: # type: ignore[name- # Add embedding like in purejaxrl atm hidden, out = ScanLSTM()(hidden, obs_dones) + # TODO : check if that still works well (had a problem with a new axis=0 that shouldn't be there) + out = jnp.squeeze(out, axis=0) + x = nn.Dense(self.n_units)(out) x = self.activation_fn(x) x = nn.Dense(self.n_units)(x) @@ -297,10 +300,10 @@ def reset_noise(self, batch_size: int = 1) -> None: def forward(self, obs: np.ndarray, lstm_states, deterministic: bool = False) -> np.ndarray: return self._predict(obs, deterministic=deterministic) - # TODO : Add the lstm state to the thing ? 
Maybe not here + # TODO : Add the lstm state to the _predict_method def _predict(self, observation: np.ndarray, lstm_states, deterministic: bool = False) -> np.ndarray: # type: ignore[override] if deterministic: - # TODO : include lstm_states here + # TODO : pass the lstm state here (see how to do it cleanly) return BaseJaxPolicy.select_action(self.actor_state, observation) # Trick to use gSDE: repeat sampled noise by using the same noise key if not self.use_sde: @@ -314,14 +317,16 @@ def predict_all(self, observation: np.ndarray, done, lstm_states, key: jax.Array @staticmethod @jax.jit def _predict_all(actor_state, vf_state, observations, dones, lstm_states, key): + # get the lstm states for the actor and the critic + act_lstm_states, vf_lstm_states = lstm_states + # TODO : check if really need to add this dimension to obs and dones ac_in = (observations[np.newaxis, :], dones[np.newaxis, :]) - # actor pass - act_lstm_states, dist = actor_state.apply_fn(actor_state.params, lstm_states, ac_in) + act_lstm_states, dist = actor_state.apply_fn(actor_state.params, act_lstm_states, ac_in) actions = dist.sample(seed=key) log_probs = dist.log_prob(actions) - # value pass - vf_lstm_states, values = vf_state.apply_fn(vf_state.params, lstm_states, ac_in) + + vf_lstm_states, values = vf_state.apply_fn(vf_state.params, vf_lstm_states, ac_in) values = values.flatten() lstm_states = (act_lstm_states, vf_lstm_states) return actions, log_probs, values, lstm_states diff --git a/sbx/recurrent_ppo/recurrent_ppo.py b/sbx/recurrent_ppo/recurrent_ppo.py index dffde36..6b6bf06 100644 --- a/sbx/recurrent_ppo/recurrent_ppo.py +++ b/sbx/recurrent_ppo/recurrent_ppo.py @@ -17,7 +17,7 @@ from stable_baselines3.common.vec_env import VecEnv from sbx.common.on_policy_algorithm import OnPolicyAlgorithmJax -from sbx.common.recurrent import RecurrentRolloutBuffer +from sbx.common.recurrent import RecurrentRolloutBuffer, LSTMStates # TODO : Fix this import from sbx.recurrent_ppo.policies import 
RecurrentPPOPolicy as PPOPolicy from sbx.recurrent_ppo.policies import ScanLSTM @@ -173,7 +173,6 @@ def __init__( if _init_setup_model: self._setup_model() - # TODO : Update the setup model function to add the lstm info ... (maybe not necessary because all in actor and value nets) def _setup_model(self) -> None: # super()._setup_model() self._setup_lr_schedule() @@ -196,10 +195,17 @@ def _setup_model(self) -> None: self.vf = self.policy.vf hidden_state_shape = self.policy.actor.n_units - # TODO : create the last lstm states (dummy atm) --> should surely use the init carry method - lstm_states = ScanLSTM.initialize_carry(self.n_envs, hidden_state_shape) # (1, 64) because 1 env - self._last_lstm_states = lstm_states - # self._last_lstm_states = jnp.zeros((2, hidden_state_shape)) # (2, 64) because two lstm states of 1 env --> should surely be of shape (2, 1, 64) + # init one lstm state + lstm_states = ScanLSTM.initialize_carry(self.n_envs, hidden_state_shape) + # use it to initialize the lstm states of the actor and the critic + # TODO : check if I need to do a copy or not + tuple_lstm_states = LSTMStates( + pi=lstm_states, + vf=lstm_states, + ) + self._last_lstm_states = tuple_lstm_states + + lstm_state_buffer_shape = (self.n_steps, self.n_envs, hidden_state_shape) # TODO : Make this a recurrent rollout buffer --> add lstm states in it self.rollout_buffer = RecurrentRolloutBuffer( @@ -210,7 +216,7 @@ def _setup_model(self) -> None: gae_lambda=self.gae_lambda, n_envs=self.n_envs, # TODO : Add the good hidden state shape - hidden_state_shape=hidden_state_shape, + lstm_state_buffer_shape=lstm_state_buffer_shape, device="cpu", ) From 6ea0afcde3fa654f765e6c4d0cacc80987a2cda6 Mon Sep 17 00:00:00 2001 From: corentinlger Date: Thu, 17 Oct 2024 19:40:33 +0200 Subject: [PATCH 4/9] Fix errors in collect rollouts --- sbx/common/recurrent.py | 20 ++++--- sbx/recurrent_ppo/policies.py | 15 ++--- sbx/recurrent_ppo/recurrent_ppo.py | 92 ++++++++++++++++++++---------- 3 files 
changed, 79 insertions(+), 48 deletions(-) diff --git a/sbx/common/recurrent.py b/sbx/common/recurrent.py index a690d1e..6d191b4 100644 --- a/sbx/common/recurrent.py +++ b/sbx/common/recurrent.py @@ -17,7 +17,8 @@ class LSTMStates(NamedTuple): pi: Tuple vf: Tuple -# Replaced th.Tensor with jnp.ndarray +# TODO : Replaced th.Tensor with jnp.ndarray but might not be true (some as still th Tensors because used in other sb3 functions) +# Added lstm states but also dones because they are used in actor and critic class RecurrentRolloutBufferSamples(NamedTuple): observations: jnp.ndarray actions: jnp.ndarray @@ -25,6 +26,7 @@ class RecurrentRolloutBufferSamples(NamedTuple): old_log_prob: jnp.ndarray advantages: jnp.ndarray returns: jnp.ndarray + dones: jnp.ndarray lstm_states: LSTMStates class RecurrentRolloutBuffer(RolloutBuffer): @@ -43,7 +45,7 @@ class RecurrentRolloutBuffer(RolloutBuffer): :param n_envs: Number of parallel environments """ - def __init__( + def __init__( self, buffer_size: int, observation_space: spaces.Space, @@ -62,6 +64,7 @@ def __init__( def reset(self): super().reset() + self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) self.hidden_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32) self.cell_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32) self.hidden_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32) @@ -77,19 +80,16 @@ def reset(self): # self.hidden_states_vf[self.pos] = np.array(lstm_states.vf[0].cpu().numpy()) # self.cell_states_vf[self.pos] = np.array(lstm_states.vf[1].cpu().numpy()) - def add(self, *args, lstm_states, **kwargs) -> None: + def add(self, *args, dones, lstm_states, **kwargs) -> None: """ :param hidden_states: LSTM cell and hidden state """ # TODO : at the moment doesn't work because I didn't create a named tuple for lstm states - print(lstm_states[0][0]) - print(np.array(lstm_states[0][0])) - print(np.array(lstm_states[0][0]).shape) - 
print(self.hidden_states_pi[self.pos].shape) self.hidden_states_pi[self.pos] = np.array(lstm_states[0][0]) self.cell_states_pi[self.pos] = np.array(lstm_states[0][1]) self.hidden_states_vf[self.pos] = np.array(lstm_states[1][0]) self.cell_states_vf[self.pos] = np.array(lstm_states[1][1]) + self.dones[self.pos] = np.array(dones) super().add(*args, **kwargs) @@ -113,6 +113,7 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBuf "log_probs", "advantages", "returns", + "dones", "hidden_states_pi", "cell_states_pi", "hidden_states_vf", @@ -126,7 +127,8 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBuf if batch_size is None: batch_size = self.buffer_size * self.n_envs - # Sampling strategy that doesn't allow any mini batch size (must be a product of n_envs) + # TODO : Check if this works well + # TODO : Sampling strategy that doesn't allow any mini batch size (must be a multiple of n_envs) indices = np.arange(self.buffer_size * self.n_envs) start_idx = 0 @@ -159,6 +161,8 @@ def _get_samples( self.log_probs[batch_inds].flatten(), self.advantages[batch_inds].flatten(), self.returns[batch_inds].flatten(), + # TODO : Check that + self.dones[batch_inds], LSTMStates(pi=lstm_states_pi, vf=lstm_states_vf) ) return RecurrentRolloutBufferSamples(*tuple(map(self.to_torch, data))) diff --git a/sbx/recurrent_ppo/policies.py b/sbx/recurrent_ppo/policies.py index d38eeed..a1998ac 100644 --- a/sbx/recurrent_ppo/policies.py +++ b/sbx/recurrent_ppo/policies.py @@ -10,8 +10,7 @@ import numpy as np import optax import tensorflow_probability.substrates.jax as tfp -# DONE : Added orthogonal to the imports -from flax.linen.initializers import constant, orthogonal +from flax.linen.initializers import constant from flax.training.train_state import TrainState from gymnasium import spaces from stable_baselines3.common.type_aliases import Schedule @@ -106,10 +105,7 @@ def __post_init__(self) -> None: # return hidden state + action dist 
@nn.compact def __call__(self, hidden, obs_dones) -> tfd.Distribution: # type: ignore[name-defined] - # Add embedding like in purejaxrl atm - hidden, out = ScanLSTM()(hidden, obs_dones) - # TODO : check if that still works well (had a problem with a new axis=0 that shouldn't be there) out = jnp.squeeze(out, axis=0) x = nn.Dense(self.n_units)(out) @@ -251,7 +247,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, max_grad_norm: float) -> init_dones = jnp.zeros((init_obs.shape[0],)) init_x = (init_obs[np.newaxis, :], init_dones[np.newaxis, :]) - # TODO : HERE HARD CODE THE NUMBER OF ENVS (but find a way to see how to actually get it from recurrent_ppo model) + # hardcode the number of envs to 1 for the initialization of the lstm states n_envs = 1 hidden_size = self.n_units init_lstm_states = ScanLSTM.initialize_carry(n_envs, hidden_size) @@ -320,13 +316,12 @@ def _predict_all(actor_state, vf_state, observations, dones, lstm_states, key): # get the lstm states for the actor and the critic act_lstm_states, vf_lstm_states = lstm_states - # TODO : check if really need to add this dimension to obs and dones - ac_in = (observations[np.newaxis, :], dones[np.newaxis, :]) - act_lstm_states, dist = actor_state.apply_fn(actor_state.params, act_lstm_states, ac_in) + lstm_in = (observations[np.newaxis, :], dones[np.newaxis, :]) + act_lstm_states, dist = actor_state.apply_fn(actor_state.params, act_lstm_states, lstm_in) actions = dist.sample(seed=key) log_probs = dist.log_prob(actions) - vf_lstm_states, values = vf_state.apply_fn(vf_state.params, vf_lstm_states, ac_in) + vf_lstm_states, values = vf_state.apply_fn(vf_state.params, vf_lstm_states, lstm_in) values = values.flatten() lstm_states = (act_lstm_states, vf_lstm_states) return actions, log_probs, values, lstm_states diff --git a/sbx/recurrent_ppo/recurrent_ppo.py b/sbx/recurrent_ppo/recurrent_ppo.py index 6b6bf06..f3b8911 100644 --- a/sbx/recurrent_ppo/recurrent_ppo.py +++ b/sbx/recurrent_ppo/recurrent_ppo.py @@ 
-207,7 +207,7 @@ def _setup_model(self) -> None: lstm_state_buffer_shape = (self.n_steps, self.n_envs, hidden_state_shape) - # TODO : Make this a recurrent rollout buffer --> add lstm states in it + # TODO : Check why buffer size is n_steps instead of buffer size self.rollout_buffer = RecurrentRolloutBuffer( self.n_steps, self.observation_space, @@ -260,9 +260,8 @@ def collect_rollouts( callback.on_rollout_start() - # TODO : initialize the dones and lstm states lstm_states = deepcopy(self._last_lstm_states) - dones = jnp.zeros(self.n_envs) # Check it is the right shape for dones + dones = self._last_episode_starts while n_steps < n_rollout_steps: if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0: @@ -274,7 +273,6 @@ def collect_rollouts( self.policy.reset_noise() obs_tensor, _ = self.policy.prepare_obs(self._last_obs) # type: ignore[has-type] - # TODO : check why I get a wrong shape for actions (1, 2), wheras I should only have 1 action here because only 1 env actions, log_probs, values, lstm_states = self.policy.predict_all(obs_tensor, dones, lstm_states, self.policy.noise_key) actions = np.array(actions) @@ -302,9 +300,14 @@ def collect_rollouts( if isinstance(self.action_space, gym.spaces.Discrete): # Reshape in case of discrete action actions = actions.reshape(-1, 1) + + # will be used to boostreap with the value function if need (need to to a critic pass) + vf_lstm_states = lstm_states[0] + lstm_in = (obs_tensor[np.newaxis, :], dones[np.newaxis, :]) # Handle timeout by bootstraping with value function # see GitHub issue #633 + # TODO : See how we can handle the lstm states here (bc we iterate on the dones) for idx, done in enumerate(dones): if ( done @@ -312,13 +315,15 @@ def collect_rollouts( and infos[idx].get("TimeLimit.truncated", False) ): terminal_obs = self.policy.prepare_obs(infos[idx]["terminal_observation"])[0] - terminal_value = np.array( - self.vf.apply( # type: ignore[union-attr] - self.policy.vf_state.params, - 
# TODO : might need to also give the lstm_states and the dones here - terminal_obs, - ).flatten() - ).item() + + # TODO Normally should only give the obs and dones for current idx + # TODO Should maybe pre-compute this before and then just iterate over the idx when needed + vf_lstm_states, values = self.vf.apply( + self.policy.vf_state.params, + vf_lstm_states, + lstm_in + ) + terminal_value = values.flatten().item() rewards[idx] += self.gamma * terminal_value rollout_buffer.add( @@ -326,23 +331,27 @@ def collect_rollouts( actions, rewards, self._last_episode_starts, # type: ignore + # TODO : Let the th.Tensors because other sb3 functions that depend on it later th.as_tensor(values), th.as_tensor(log_probs), + dones=dones, lstm_states=self._last_lstm_states, # Should it be last lstm states ? or lstm states ) - self._last_obs = new_obs # type: ignore[assignment] self._last_episode_starts = dones self._last_lstm_states = lstm_states - # TODO : Compute the value by also giving the dones and the lstm states values - values = np.array( - self.vf.apply( # type: ignore[union-attr] - self.policy.vf_state.params, - self.policy.prepare_obs(new_obs)[0], # type: ignore[arg-type] - ).flatten() - ) + # Compute the last values when the rollout ends to compute the advantage + vf_lstm_states = lstm_states[0] + lstm_in = (self.policy.prepare_obs(new_obs)[0][np.newaxis, :], dones[np.newaxis, :]) + + vf_lstm_states, values = self.vf.apply( + self.policy.vf_state.params, + vf_lstm_states, + lstm_in + ) + values = np.array(values).flatten() rollout_buffer.compute_returns_and_advantage(last_values=th.as_tensor(values), dones=dones) @@ -350,14 +359,15 @@ def collect_rollouts( return True - # TODO : use lstm train state + # TODO : Unjit the function to debug it @staticmethod - @partial(jax.jit, static_argnames=["normalize_advantage"]) + # @partial(jax.jit, static_argnames=["normalize_advantage"]) def _one_update( actor_state: TrainState, vf_state: TrainState, - lstm_train_state: 
TrainState, + lstm_states: LSTMStates, observations: np.ndarray, + dones: np.ndarray, actions: np.ndarray, advantages: np.ndarray, returns: np.ndarray, @@ -371,10 +381,14 @@ def _one_update( # Normalization does not make sense if mini batchsize == 1, see GH issue #325 if normalize_advantage and len(advantages) > 1: advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) - - # TODO : Maybe adding an lstm inside is easier for the gradients + + # TODO : something weird here because the params aren't used and only actor_state.params instead def actor_loss(params): - dist = actor_state.apply_fn(params, observations) + lstm_in = (observations[np.newaxis, :], dones[np.newaxis, :]) + act_lstm_states, _ = lstm_states + + act_lstm_states, dist = actor_state.apply_fn(actor_state.params, act_lstm_states, lstm_in) + # dist = actor_state.apply_fn(params, observations) log_prob = dist.log_prob(actions) entropy = dist.entropy() @@ -397,16 +411,18 @@ def actor_loss(params): pg_loss_value, pg_grads = jax.value_and_grad(actor_loss, has_aux=False)(actor_state.params) actor_state = actor_state.apply_gradients(grads=pg_grads) - def critic_loss(params): + def critic_loss(params): + lstm_in = (observations[np.newaxis, :], dones[np.newaxis, :]) + _, vf_lstm_states = lstm_states # Value loss using the TD(gae_lambda) target - vf_values = vf_state.apply_fn(params, observations).flatten() + vf_lstm_states, values = vf_state.apply_fn(vf_state.params, vf_lstm_states, lstm_in) + vf_values = values.flatten() + # vf_values = vf_state.apply_fn(params, observations).flatten() return ((returns - vf_values) ** 2).mean() vf_loss_value, vf_grads = jax.value_and_grad(critic_loss, has_aux=False)(vf_state.params) vf_state = vf_state.apply_gradients(grads=vf_grads) - vf_state = vf_state.apply_gradients(grads=vf_grads) - # loss = policy_loss + ent_coef * entropy_loss + vf_coef * value_loss return (actor_state, vf_state), (pg_loss_value, vf_loss_value) @@ -422,7 +438,7 @@ def train(self) -> 
None: # train for n_epochs epochs for _ in range(self.n_epochs): # JIT only one update - # TODO : Fix the buffer here because we don't want to do permutations in it + # TODO : Fix the recurrent buffer here because we don't want to do permutations in it to get our observations, returns, advantages, etc. for rollout_data in self.rollout_buffer.get(self.batch_size): # type: ignore[attr-defined] if isinstance(self.action_space, spaces.Discrete): # Convert discrete action from float to int @@ -430,14 +446,30 @@ def train(self) -> None: else: actions = rollout_data.actions.numpy() + # TODO : 32 is the batch size + # TODO : 8 is the n_envs + # TODO : goal is to have a batch_size that is a multiple of n_envs (good here) + # TODO : assert it + # TODO : transform the LSTMStates to give them the correct shape --> + # TODO : tuple[pi, vf] + # TODO : with pi and vf tuple[hidden_state, cell_state] + # TODO : hidden_state.shape = (batch_size, hidden_size) = (32, 64) + # TODO : is normally of size during rollouts (n_envs, hidden_size) = (8, 64) + dones = rollout_data.dones.numpy() + lstm_states = LSTMStates( + pi=rollout_data.lstm_states[0].numpy(), + vf=rollout_data.lstm_states[1].numpy(), + ) (self.policy.actor_state, self.policy.vf_state), (pg_loss, value_loss) = self._one_update( actor_state=self.policy.actor_state, vf_state=self.policy.vf_state, observations=rollout_data.observations.numpy(), actions=actions, + dones=dones, advantages=rollout_data.advantages.numpy(), returns=rollout_data.returns.numpy(), old_log_prob=rollout_data.old_log_prob.numpy(), + lstm_states=lstm_states, clip_range=clip_range, ent_coef=self.ent_coef, vf_coef=self.vf_coef, From 8c16259b5bedddd5c4a180b382da2e3d905253dd Mon Sep 17 00:00:00 2001 From: corentinlger Date: Fri, 18 Oct 2024 17:55:45 +0200 Subject: [PATCH 5/9] First runnable version of LSTM-PPO --- sbx/common/recurrent.py | 25 ++------ sbx/recurrent_ppo/policies.py | 8 +-- sbx/recurrent_ppo/recurrent_ppo.py | 97 ++++++++++++++---------------- 3 
files changed, 55 insertions(+), 75 deletions(-) diff --git a/sbx/common/recurrent.py b/sbx/common/recurrent.py index 6d191b4..5141c80 100644 --- a/sbx/common/recurrent.py +++ b/sbx/common/recurrent.py @@ -1,18 +1,14 @@ -from functools import partial from typing import Callable, Generator, Optional, Tuple, Union, NamedTuple import numpy as np import torch as th import jax.numpy as jnp from gymnasium import spaces +# TODO : see later how to enable DictRolloutBuffer from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer from stable_baselines3.common.vec_env import VecNormalize -# class LSTMStates(NamedTuple): -# pi: Tuple[th.Tensor, ...] -# vf: Tuple[th.Tensor, ...] - -# TODO : see if I add jax info +# TODO : see if I add type aliases for the NamedTuple class LSTMStates(NamedTuple): pi: Tuple vf: Tuple @@ -70,21 +66,11 @@ def reset(self): self.hidden_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32) self.cell_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32) - # def add(self, *args, lstm_states: LSTMStates, **kwargs) -> None: - # """ - # :param hidden_states: LSTM cell and hidden state - # """ - # # TODO : at the moment doesn't work because I didn't create a named tuple for lstm states - # self.hidden_states_pi[self.pos] = np.array(lstm_states.pi[0].cpu().numpy()) - # self.cell_states_pi[self.pos] = np.array(lstm_states.pi[1].cpu().numpy()) - # self.hidden_states_vf[self.pos] = np.array(lstm_states.vf[0].cpu().numpy()) - # self.cell_states_vf[self.pos] = np.array(lstm_states.vf[1].cpu().numpy()) - def add(self, *args, dones, lstm_states, **kwargs) -> None: """ :param hidden_states: LSTM cell and hidden state """ - # TODO : at the moment doesn't work because I didn't create a named tuple for lstm states + # TODO :Replace idx [0] and [1] by named tuples (pi and vf) self.hidden_states_pi[self.pos] = np.array(lstm_states[0][0]) self.cell_states_pi[self.pos] = np.array(lstm_states[0][1]) self.hidden_states_vf[self.pos] = 
np.array(lstm_states[1][0]) @@ -127,8 +113,8 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBuf if batch_size is None: batch_size = self.buffer_size * self.n_envs - # TODO : Check if this works well - # TODO : Sampling strategy that doesn't allow any mini batch size (must be a multiple of n_envs) + # TODO : See how to effectively use the indices to conserve temporal order in the batch data during updates + # TODO : I think the easisest way is to ensure the n_steps is a multiple of batch_size indices = np.arange(self.buffer_size * self.n_envs) start_idx = 0 @@ -161,7 +147,6 @@ def _get_samples( self.log_probs[batch_inds].flatten(), self.advantages[batch_inds].flatten(), self.returns[batch_inds].flatten(), - # TODO : Check that self.dones[batch_inds], LSTMStates(pi=lstm_states_pi, vf=lstm_states_vf) ) diff --git a/sbx/recurrent_ppo/policies.py b/sbx/recurrent_ppo/policies.py index a1998ac..8bb4c79 100644 --- a/sbx/recurrent_ppo/policies.py +++ b/sbx/recurrent_ppo/policies.py @@ -271,7 +271,6 @@ def build(self, key: jax.Array, lr_schedule: Schedule, max_grad_norm: float) -> self.vf_state = TrainState.create( apply_fn=self.vf.apply, - # TODO : Why difference w params of actor state params=self.vf.init({"params": vf_key}, init_lstm_states, init_x), tx=optax.chain( optax.clip_by_global_norm(max_grad_norm), @@ -296,15 +295,16 @@ def reset_noise(self, batch_size: int = 1) -> None: def forward(self, obs: np.ndarray, lstm_states, deterministic: bool = False) -> np.ndarray: return self._predict(obs, deterministic=deterministic) - # TODO : Add the lstm state to the _predict_method + # TODO : Add the lstm state to the _predict_method (Might also need to return them) + # Like in this recurrent ppo ex in sb3 contrib : https://sb3-contrib.readthedocs.io/en/master/modules/ppo_recurrent.html def _predict(self, observation: np.ndarray, lstm_states, deterministic: bool = False) -> np.ndarray: # type: ignore[override] if deterministic: - # TODO : pass 
the lstm state here (see how to do it cleanly) + # TODO : pass the lstm state here (see how to do it cleanly because uses a function from parent class) return BaseJaxPolicy.select_action(self.actor_state, observation) # Trick to use gSDE: repeat sampled noise by using the same noise key if not self.use_sde: self.reset_noise() - # TODO : include lstm state here + # TODO : also include lstm state here return BaseJaxPolicy.sample_action(self.actor_state, observation, self.noise_key) def predict_all(self, observation: np.ndarray, done, lstm_states, key: jax.Array) -> np.ndarray: diff --git a/sbx/recurrent_ppo/recurrent_ppo.py b/sbx/recurrent_ppo/recurrent_ppo.py index f3b8911..bbe4461 100644 --- a/sbx/recurrent_ppo/recurrent_ppo.py +++ b/sbx/recurrent_ppo/recurrent_ppo.py @@ -1,6 +1,6 @@ import warnings -from copy import deepcopy from functools import partial +from copy import deepcopy from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union import jax @@ -10,7 +10,7 @@ import gymnasium as gym from flax.training.train_state import TrainState from gymnasium import spaces -# from stable_baselines3.common.buffers import RolloutBuffer + from stable_baselines3.common.callbacks import BaseCallback from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import explained_variance, get_schedule_fn @@ -18,7 +18,6 @@ from sbx.common.on_policy_algorithm import OnPolicyAlgorithmJax from sbx.common.recurrent import RecurrentRolloutBuffer, LSTMStates -# TODO : Fix this import from sbx.recurrent_ppo.policies import RecurrentPPOPolicy as PPOPolicy from sbx.recurrent_ppo.policies import ScanLSTM @@ -194,20 +193,22 @@ def _setup_model(self) -> None: self.actor = self.policy.actor self.vf = self.policy.vf - hidden_state_shape = self.policy.actor.n_units - # init one lstm state - lstm_states = ScanLSTM.initialize_carry(self.n_envs, hidden_state_shape) - # use it to initialize the lstm states of the actor and the 
critic - # TODO : check if I need to do a copy or not - tuple_lstm_states = LSTMStates( + # added the lstm states hidden_size + self.hidden_state_size = self.policy.actor.n_units + + # TODO : change this hardcoded value and see how to add more layers in the policy + num_lstm_layers = 1 + + # use a dummy lstm state to init the + lstm_states = ScanLSTM.initialize_carry(self.n_envs, self.hidden_state_size) + init_lstm_states = LSTMStates( pi=lstm_states, vf=lstm_states, ) - self._last_lstm_states = tuple_lstm_states + self._last_lstm_states = init_lstm_states - lstm_state_buffer_shape = (self.n_steps, self.n_envs, hidden_state_shape) + lstm_state_buffer_shape = (self.n_steps, num_lstm_layers, self.n_envs, self.hidden_state_size) - # TODO : Check why buffer size is n_steps instead of buffer size self.rollout_buffer = RecurrentRolloutBuffer( self.n_steps, self.observation_space, @@ -215,7 +216,6 @@ def _setup_model(self) -> None: gamma=self.gamma, gae_lambda=self.gae_lambda, n_envs=self.n_envs, - # TODO : Add the good hidden state shape lstm_state_buffer_shape=lstm_state_buffer_shape, device="cpu", ) @@ -318,6 +318,7 @@ def collect_rollouts( # TODO Normally should only give the obs and dones for current idx # TODO Should maybe pre-compute this before and then just iterate over the idx when needed + # TODO This is surely slowing everything for no reason vf_lstm_states, values = self.vf.apply( self.policy.vf_state.params, vf_lstm_states, @@ -331,11 +332,10 @@ def collect_rollouts( actions, rewards, self._last_episode_starts, # type: ignore - # TODO : Let the th.Tensors because other sb3 functions that depend on it later th.as_tensor(values), th.as_tensor(log_probs), dones=dones, - lstm_states=self._last_lstm_states, # Should it be last lstm states ? 
or lstm states + lstm_states=self._last_lstm_states, ) self._last_obs = new_obs # type: ignore[assignment] @@ -359,9 +359,8 @@ def collect_rollouts( return True - # TODO : Unjit the function to debug it @staticmethod - # @partial(jax.jit, static_argnames=["normalize_advantage"]) + @partial(jax.jit, static_argnames=["normalize_advantage"]) def _one_update( actor_state: TrainState, vf_state: TrainState, @@ -384,7 +383,8 @@ def _one_update( # TODO : something weird here because the params aren't used and only actor_state.params instead def actor_loss(params): - lstm_in = (observations[np.newaxis, :], dones[np.newaxis, :]) + # TODO : see why I need to flatten dones here (otherwise error in the shapes given to the lstm) + lstm_in = (observations[np.newaxis, :], dones.flatten()[np.newaxis, :]) act_lstm_states, _ = lstm_states act_lstm_states, dist = actor_state.apply_fn(actor_state.params, act_lstm_states, lstm_in) @@ -412,7 +412,7 @@ def actor_loss(params): actor_state = actor_state.apply_gradients(grads=pg_grads) def critic_loss(params): - lstm_in = (observations[np.newaxis, :], dones[np.newaxis, :]) + lstm_in = (observations[np.newaxis, :], dones.flatten()[np.newaxis, :]) _, vf_lstm_states = lstm_states # Value loss using the TD(gae_lambda) target vf_lstm_states, values = vf_state.apply_fn(vf_state.params, vf_lstm_states, lstm_in) @@ -438,7 +438,6 @@ def train(self) -> None: # train for n_epochs epochs for _ in range(self.n_epochs): # JIT only one update - # TODO : Fix the recurrent buffer here because we don't want to do permutations in it to get our observations, returns, advantages, etc. 
for rollout_data in self.rollout_buffer.get(self.batch_size): # type: ignore[attr-defined] if isinstance(self.action_space, spaces.Discrete): # Convert discrete action from float to int @@ -446,20 +445,27 @@ def train(self) -> None: else: actions = rollout_data.actions.numpy() - # TODO : 32 is the batch size - # TODO : 8 is the n_envs - # TODO : goal is to have a batch_size that is a multiple of n_envs (good here) - # TODO : assert it - # TODO : transform the LSTMStates to give them the correct shape --> - # TODO : tuple[pi, vf] - # TODO : with pi and vf tuple[hidden_state, cell_state] - # TODO : hidden_state.shape = (batch_size, hidden_size) = (32, 64) - # TODO : is normally of size during rollouts (n_envs, hidden_size) = (8, 64) dones = rollout_data.dones.numpy() + + # TODO : fix this reshape somewhere else + # in sb3 contrib, shape = (n_steps, n_lstm_layers, n_envs, hidden_size) + # here same shape in the rollout buffer + # but give a shape of (batch_size, hidden_size) to the lstm layer + lstm_states_pi = ( + rollout_data.lstm_states[0][0].numpy().reshape(self.batch_size, self.hidden_state_size), + rollout_data.lstm_states[0][1].numpy().reshape(self.batch_size, self.hidden_state_size) + ) + + lstm_states_vf = ( + rollout_data.lstm_states[1][0].numpy().reshape(self.batch_size, self.hidden_state_size), + rollout_data.lstm_states[1][1].numpy().reshape(self.batch_size, self.hidden_state_size) + ) + lstm_states = LSTMStates( - pi=rollout_data.lstm_states[0].numpy(), - vf=rollout_data.lstm_states[1].numpy(), + pi=lstm_states_pi, + vf=lstm_states_vf, ) + (self.policy.actor_state, self.policy.vf_state), (pg_loss, value_loss) = self._one_update( actor_state=self.policy.actor_state, vf_state=self.policy.vf_state, @@ -508,7 +514,6 @@ def learn( reset_num_timesteps: bool = True, progress_bar: bool = False, ) -> RPPOSelf: - # TODO : stuck because need to use a custom replay buffer now --> See how it is done in Sb3-Contrib return super().learn( 
total_timesteps=total_timesteps, callback=callback, @@ -521,27 +526,17 @@ def learn( if __name__ == "__main__": import gymnasium as gym - from sbx import PPO + from stable_baselines3.common.env_util import make_vec_env - n_steps = 2048 - batch_size = 32 - train_steps = 5_000 - env = gym.make("CartPole-v1") + n_steps = 128 + batch_size = 32 + train_steps = 20_000 + n_envs = 4 + env_id = "CartPole-v1" - model = PPO("MlpPolicy", env, n_steps=n_steps, batch_size=batch_size, verbose=1) - vec_env = model.get_env() - obs = vec_env.reset() - - - model = RecurrentPPO("MlpPolicy", env, n_steps=n_steps, batch_size=batch_size, verbose=1) + # create vec env and train algo + vec_env = make_vec_env(env_id, n_envs=n_envs) + model = RecurrentPPO("MlpPolicy", vec_env, n_steps=n_steps, batch_size=batch_size, verbose=1) model.learn(total_timesteps=train_steps, progress_bar=True) - vec_env = model.get_env() - obs = vec_env.reset() - test_steps = 10 - for _ in range(test_steps): - # vec_env.render() - action, _states = model.predict(obs, deterministic=True) - obs, reward, done, info = vec_env.step(action) - vec_env.close() \ No newline at end of file From 0195e8c7c989dce932194a69086aea8577706e3c Mon Sep 17 00:00:00 2001 From: corentinlger Date: Sun, 20 Oct 2024 10:10:16 +0200 Subject: [PATCH 6/9] Fix typos and update comments --- sbx/common/recurrent.py | 19 ++++----- sbx/recurrent_ppo/policies.py | 34 +++++++++++----- sbx/recurrent_ppo/recurrent_ppo.py | 63 +++++++++++++++++------------- 3 files changed, 71 insertions(+), 45 deletions(-) diff --git a/sbx/common/recurrent.py b/sbx/common/recurrent.py index 5141c80..16e47dd 100644 --- a/sbx/common/recurrent.py +++ b/sbx/common/recurrent.py @@ -8,12 +8,12 @@ from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer from stable_baselines3.common.vec_env import VecNormalize -# TODO : see if I add type aliases for the NamedTuple +# TODO : add type aliases for the NamedTuple class LSTMStates(NamedTuple): pi: Tuple vf: 
Tuple -# TODO : Replaced th.Tensor with jnp.ndarray but might not be true (some as still th Tensors because used in other sb3 functions) +# TODO : Replaced th.Tensor with jnp.ndarray but might not be true (some as still th Tensors because used in other sb3 fns) # Added lstm states but also dones because they are used in actor and critic class RecurrentRolloutBufferSamples(NamedTuple): observations: jnp.ndarray @@ -25,6 +25,7 @@ class RecurrentRolloutBufferSamples(NamedTuple): dones: jnp.ndarray lstm_states: LSTMStates +# Add a recurrent buffer that also takes care of the lstm states and dones flags class RecurrentRolloutBuffer(RolloutBuffer): """ Rollout buffer that also stores the LSTM cell and hidden states. @@ -53,13 +54,13 @@ def __init__( gamma: float = 0.99, n_envs: int = 1, ): - # TODO : see if I rename this in all the code self.hidden_state_shape = lstm_state_buffer_shape self.seq_start_indices, self.seq_end_indices = None, None super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs) def reset(self): super().reset() + # also add the dones and all lstm states self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) self.hidden_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32) self.cell_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32) @@ -70,12 +71,11 @@ def add(self, *args, dones, lstm_states, **kwargs) -> None: """ :param hidden_states: LSTM cell and hidden state """ - # TODO :Replace idx [0] and [1] by named tuples (pi and vf) - self.hidden_states_pi[self.pos] = np.array(lstm_states[0][0]) - self.cell_states_pi[self.pos] = np.array(lstm_states[0][1]) - self.hidden_states_vf[self.pos] = np.array(lstm_states[1][0]) - self.cell_states_vf[self.pos] = np.array(lstm_states[1][1]) self.dones[self.pos] = np.array(dones) + self.hidden_states_pi[self.pos] = np.array(lstm_states.pi[0]) + self.cell_states_pi[self.pos] = np.array(lstm_states.pi[1]) + self.hidden_states_vf[self.pos] = 
np.array(lstm_states.vf[0]) + self.cell_states_vf[self.pos] = np.array(lstm_states.vf[1]) super().add(*args, **kwargs) @@ -115,6 +115,7 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBuf # TODO : See how to effectively use the indices to conserve temporal order in the batch data during updates # TODO : I think the easisest way is to ensure the n_steps is a multiple of batch_size + # TODO : But still need to be fixed at the moment (I just made sure the returned shape was right) indices = np.arange(self.buffer_size * self.n_envs) start_idx = 0 @@ -123,7 +124,7 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBuf yield self._get_samples(batch_inds) start_idx += batch_size - + # return the lstm states as an LSTMStates tuple def _get_samples( self, batch_inds: np.ndarray, diff --git a/sbx/recurrent_ppo/policies.py b/sbx/recurrent_ppo/policies.py index 8bb4c79..afc790b 100644 --- a/sbx/recurrent_ppo/policies.py +++ b/sbx/recurrent_ppo/policies.py @@ -16,10 +16,13 @@ from stable_baselines3.common.type_aliases import Schedule from sbx.common.policies import BaseJaxPolicy, Flatten +from sbx.common.recurrent import LSTMStates tfd = tfp.distributions +# Added a ScanLSTM Module that automatically handles the reset of LSTM states +# inspired from the ScanRNN in purejaxrl : https://github.com/luchris429/purejaxrl/blob/main/purejaxrl/ppo_rnn.py class ScanLSTM(nn.Module): @functools.partial( nn.scan, @@ -30,8 +33,10 @@ class ScanLSTM(nn.Module): ) @nn.compact def __call__(self, lstm_states, inputs_and_resets): + # pass the pi and vf lstm states, as well as the obs and the resets input, resets = inputs_and_resets hidden_state, cell_state = lstm_states + # create new lstm states to replace the old ones if reset is True reset_lstm_states = self.initialize_carry(hidden_state.shape[0], hidden_state.shape[1]) @@ -54,15 +59,14 @@ def __call__(self, lstm_states, inputs_and_resets): new_lstm_states, output = 
nn.LSTMCell(features=hidden_size)(lstm_states, input) return new_lstm_states, output - @staticmethod def initialize_carry(batch_size, hidden_size): - # Return the tuple of hidden and cell states as a tuple + # Returns a tuple of lstm states (hidden and cell states) return nn.LSTMCell(features=hidden_size).initialize_carry( rng=jax.random.PRNGKey(0), input_shape=(batch_size, hidden_size) ) -# Add scanned rnn in the critic +# Add ScanLSTM as first element of the Critic architecture class Critic(nn.Module): n_units: int = 256 activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.tanh @@ -78,7 +82,7 @@ def __call__(self, lstm_states, obs_dones) -> jnp.ndarray: x = nn.Dense(1)(x) return lstm_states, x -# Add scanned lstm in the actor +# Add ScanLSTM as first element of the Actor architecture class Actor(nn.Module): action_dim: int n_units: int = 256 @@ -144,6 +148,7 @@ def __call__(self, hidden, obs_dones) -> tfd.Distribution: # type: ignore[name- return hidden, dist +# TODO Later : at the moment custom net_architectures are not supported for the LSTM class RecurrentPPOPolicy(BaseJaxPolicy): def __init__( self, @@ -241,10 +246,10 @@ def build(self, key: jax.Array, lr_schedule: Schedule, max_grad_norm: float) -> **actor_kwargs, # type: ignore[arg-type] ) - # Initialize a dummy x input (obs, dones) + # Initialize a dummy input for the LSTM layer (obs, dones) init_obs = jnp.array([self.observation_space.sample()]) - # create an array of dones to create the good x (obs, dones) init_dones = jnp.zeros((init_obs.shape[0],)) + # at the moment use this trick of adding a dimension to obs and dones to pass them to the LSTM init_x = (init_obs[np.newaxis, :], init_dones[np.newaxis, :]) # hardcode the number of envs to 1 for the initialization of the lstm states @@ -255,6 +260,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, max_grad_norm: float) -> # Hack to make gSDE work without modifying internal SB3 code self.actor.reset_noise = self.reset_noise + # pass the init 
lstm states as argument to the actor train state self.actor_state = TrainState.create( apply_fn=self.actor.apply, params=self.actor.init(actor_key, init_lstm_states, init_x), @@ -269,6 +275,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, max_grad_norm: float) -> self.vf = Critic(n_units=self.n_units, activation_fn=self.activation_fn) + # pass the init lstm states as argument to the critic train state self.vf_state = TrainState.create( apply_fn=self.vf.apply, params=self.vf.init({"params": vf_key}, init_lstm_states, init_x), @@ -307,21 +314,30 @@ def _predict(self, observation: np.ndarray, lstm_states, deterministic: bool = F # TODO : also include lstm state here return BaseJaxPolicy.sample_action(self.actor_state, observation, self.noise_key) + # Added the lstm states to the predict_all method (maybe also the dones but I don't remember) def predict_all(self, observation: np.ndarray, done, lstm_states, key: jax.Array) -> np.ndarray: return self._predict_all(self.actor_state, self.vf_state, observation, done, lstm_states, key) @staticmethod @jax.jit def _predict_all(actor_state, vf_state, observations, dones, lstm_states, key): - # get the lstm states for the actor and the critic + # separate the lstm states for the actor and the critic, and prepare the input for the lstm act_lstm_states, vf_lstm_states = lstm_states - lstm_in = (observations[np.newaxis, :], dones[np.newaxis, :]) + + # pass the actor lstm states and the input to the actor act_lstm_states, dist = actor_state.apply_fn(actor_state.params, act_lstm_states, lstm_in) actions = dist.sample(seed=key) log_probs = dist.log_prob(actions) + # pass the critic lstm states and the input to the critic vf_lstm_states, values = vf_state.apply_fn(vf_state.params, vf_lstm_states, lstm_in) values = values.flatten() - lstm_states = (act_lstm_states, vf_lstm_states) + + # add the actor and critic lstm states to the lstm states tuple + lstm_states = LSTMStates( + pi=act_lstm_states, + vf=vf_lstm_states + ) + 
return actions, log_probs, values, lstm_states diff --git a/sbx/recurrent_ppo/recurrent_ppo.py b/sbx/recurrent_ppo/recurrent_ppo.py index bbe4461..3dcb272 100644 --- a/sbx/recurrent_ppo/recurrent_ppo.py +++ b/sbx/recurrent_ppo/recurrent_ppo.py @@ -25,6 +25,7 @@ class RecurrentPPO(OnPolicyAlgorithmJax): + # TODO : Update documentation """ Proximal Policy Optimization algorithm (PPO) (clip version) @@ -193,22 +194,24 @@ def _setup_model(self) -> None: self.actor = self.policy.actor self.vf = self.policy.vf - # added the lstm states hidden_size + # added the lstm states hidden_size (used to initialize the lstm states and the replay buffer) + # at the moment just match the n_units of the policy + # TODO : change this to enable more complex architectures (fix the n_lstm layers to 1 for now) self.hidden_state_size = self.policy.actor.n_units - - # TODO : change this hardcoded value and see how to add more layers in the policy num_lstm_layers = 1 + lstm_state_buffer_shape = (self.n_steps, num_lstm_layers, self.n_envs, self.hidden_state_size) - # use a dummy lstm state to init the - lstm_states = ScanLSTM.initialize_carry(self.n_envs, self.hidden_state_size) + # use dummy lstm states to init the pi and the vf states + cell_hidden_lstm_states = ScanLSTM.initialize_carry(self.n_envs, self.hidden_state_size) + # add them to the global LSTMStates init_lstm_states = LSTMStates( - pi=lstm_states, - vf=lstm_states, + pi=cell_hidden_lstm_states, + vf=cell_hidden_lstm_states, ) + # update the last lstm states (like in sb3 contrib) self._last_lstm_states = init_lstm_states - lstm_state_buffer_shape = (self.n_steps, num_lstm_layers, self.n_envs, self.hidden_state_size) - + # Initialize the rollout buffer (it also encompasses the dones now as well as the lstm states) self.rollout_buffer = RecurrentRolloutBuffer( self.n_steps, self.observation_space, @@ -260,6 +263,7 @@ def collect_rollouts( callback.on_rollout_start() + # copied that from sb3 contrib lstm_states = 
deepcopy(self._last_lstm_states) dones = self._last_episode_starts @@ -273,6 +277,7 @@ def collect_rollouts( self.policy.reset_noise() obs_tensor, _ = self.policy.prepare_obs(self._last_obs) # type: ignore[has-type] + # use the predict_all method with the lstm states and the dones actions, log_probs, values, lstm_states = self.policy.predict_all(obs_tensor, dones, lstm_states, self.policy.noise_key) actions = np.array(actions) @@ -301,8 +306,8 @@ def collect_rollouts( # Reshape in case of discrete action actions = actions.reshape(-1, 1) - # will be used to boostreap with the value function if need (need to to a critic pass) - vf_lstm_states = lstm_states[0] + # will be used to boostrap with the value function if need (need to to a critic pass) + vf_lstm_states = lstm_states.vf lstm_in = (obs_tensor[np.newaxis, :], dones[np.newaxis, :]) # Handle timeout by bootstraping with value function @@ -317,8 +322,8 @@ def collect_rollouts( terminal_obs = self.policy.prepare_obs(infos[idx]["terminal_observation"])[0] # TODO Normally should only give the obs and dones for current idx - # TODO Should maybe pre-compute this before and then just iterate over the idx when needed - # TODO This is surely slowing everything for no reason + # TODO Should maybe pre-compute the lstm states and values before and then just iterate over the idx when needed + # TODO This is surely slowing everything for no reason (and seems false) vf_lstm_states, values = self.vf.apply( self.policy.vf_state.params, vf_lstm_states, @@ -327,6 +332,7 @@ def collect_rollouts( terminal_value = values.flatten().item() rewards[idx] += self.gamma * terminal_value + # add the dones and the lstm states to the rollout buffer rollout_buffer.add( self._last_obs, # type: ignore actions, @@ -338,14 +344,14 @@ def collect_rollouts( lstm_states=self._last_lstm_states, ) + # update the lstm states for the next iteration self._last_obs = new_obs # type: ignore[assignment] self._last_episode_starts = dones 
self._last_lstm_states = lstm_states - # Compute the last values when the rollout ends to compute the advantage - vf_lstm_states = lstm_states[0] + # Get the last values when the rollout ends to compute the advantages + vf_lstm_states = lstm_states.vf lstm_in = (self.policy.prepare_obs(new_obs)[0][np.newaxis, :], dones[np.newaxis, :]) - vf_lstm_states, values = self.vf.apply( self.policy.vf_state.params, vf_lstm_states, @@ -381,14 +387,13 @@ def _one_update( if normalize_advantage and len(advantages) > 1: advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) - # TODO : something weird here because the params aren't used and only actor_state.params instead + # TODO : something weird here because the params argument isn't used and only actor_state.params instead def actor_loss(params): # TODO : see why I need to flatten dones here (otherwise error in the shapes given to the lstm) lstm_in = (observations[np.newaxis, :], dones.flatten()[np.newaxis, :]) act_lstm_states, _ = lstm_states act_lstm_states, dist = actor_state.apply_fn(actor_state.params, act_lstm_states, lstm_in) - # dist = actor_state.apply_fn(params, observations) log_prob = dist.log_prob(actions) entropy = dist.entropy() @@ -411,13 +416,13 @@ def actor_loss(params): pg_loss_value, pg_grads = jax.value_and_grad(actor_loss, has_aux=False)(actor_state.params) actor_state = actor_state.apply_gradients(grads=pg_grads) + # TODO : same observation as above def critic_loss(params): lstm_in = (observations[np.newaxis, :], dones.flatten()[np.newaxis, :]) _, vf_lstm_states = lstm_states # Value loss using the TD(gae_lambda) target vf_lstm_states, values = vf_state.apply_fn(vf_state.params, vf_lstm_states, lstm_in) vf_values = values.flatten() - # vf_values = vf_state.apply_fn(params, observations).flatten() return ((returns - vf_values) ** 2).mean() vf_loss_value, vf_grads = jax.value_and_grad(critic_loss, has_aux=False)(vf_state.params) @@ -445,12 +450,8 @@ def train(self) -> None: else: actions 
= rollout_data.actions.numpy() - dones = rollout_data.dones.numpy() - - # TODO : fix this reshape somewhere else - # in sb3 contrib, shape = (n_steps, n_lstm_layers, n_envs, hidden_size) - # here same shape in the rollout buffer - # but give a shape of (batch_size, hidden_size) to the lstm layer + # TODO : fix the values of the lstm states (at the moment they dot not follow the right temporal order I think) + # TODO : also fix this mechanism where I need to reshape the lstm states here lstm_states_pi = ( rollout_data.lstm_states[0][0].numpy().reshape(self.batch_size, self.hidden_state_size), rollout_data.lstm_states[0][1].numpy().reshape(self.batch_size, self.hidden_state_size) @@ -471,10 +472,12 @@ def train(self) -> None: vf_state=self.policy.vf_state, observations=rollout_data.observations.numpy(), actions=actions, - dones=dones, + # added the dones here + dones=rollout_data.dones.numpy(), advantages=rollout_data.advantages.numpy(), returns=rollout_data.returns.numpy(), old_log_prob=rollout_data.old_log_prob.numpy(), + # added the lstm states here lstm_states=lstm_states, clip_range=clip_range, ent_coef=self.ent_coef, @@ -530,7 +533,7 @@ def learn( n_steps = 128 batch_size = 32 - train_steps = 20_000 + train_steps = 10_000 n_envs = 4 env_id = "CartPole-v1" @@ -539,4 +542,10 @@ def learn( model = RecurrentPPO("MlpPolicy", vec_env, n_steps=n_steps, batch_size=batch_size, verbose=1) model.learn(total_timesteps=train_steps, progress_bar=True) + # vec_env = model.get_env() + # obs = vec_env.reset() + # for _ in range(10): + # action, lstm_states = model.predict(obs, deterministic=True) + # obs, reward, done, info = vec_env.step(action) + vec_env.close() \ No newline at end of file From f520f47e34cedbdfee8588668bbc4614c5e02978 Mon Sep 17 00:00:00 2001 From: corentinlger Date: Thu, 24 Oct 2024 17:33:38 +0200 Subject: [PATCH 7/9] Implement predict method --- sbx/recurrent_ppo/policies.py | 96 ++++++++++++++++++++++++++---- sbx/recurrent_ppo/recurrent_ppo.py | 15 +++-- 2 
files changed, 93 insertions(+), 18 deletions(-) diff --git a/sbx/recurrent_ppo/policies.py b/sbx/recurrent_ppo/policies.py index afc790b..89708a2 100644 --- a/sbx/recurrent_ppo/policies.py +++ b/sbx/recurrent_ppo/policies.py @@ -1,7 +1,7 @@ import functools from dataclasses import field -from typing import Any, Callable, Dict, List, Optional, Sequence, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Union, Tuple import flax.linen as nn import gymnasium as gym @@ -15,7 +15,7 @@ from gymnasium import spaces from stable_baselines3.common.type_aliases import Schedule -from sbx.common.policies import BaseJaxPolicy, Flatten +from sbx.common.policies import BaseJaxPolicy from sbx.common.recurrent import LSTMStates tfd = tfp.distributions @@ -254,8 +254,8 @@ def build(self, key: jax.Array, lr_schedule: Schedule, max_grad_norm: float) -> # hardcode the number of envs to 1 for the initialization of the lstm states n_envs = 1 - hidden_size = self.n_units - init_lstm_states = ScanLSTM.initialize_carry(n_envs, hidden_size) + self.hidden_size = self.n_units + init_lstm_states = ScanLSTM.initialize_carry(n_envs, self.hidden_size) # Hack to make gSDE work without modifying internal SB3 code self.actor.reset_noise = self.reset_noise @@ -299,20 +299,92 @@ def reset_noise(self, batch_size: int = 1) -> None: """ self.key, self.noise_key = jax.random.split(self.key, 2) - def forward(self, obs: np.ndarray, lstm_states, deterministic: bool = False) -> np.ndarray: + def forward(self, obs: np.ndarray, lstm_states, deterministic: bool = False, key = None) -> np.ndarray: return self._predict(obs, deterministic=deterministic) - # TODO : Add the lstm state to the _predict_method (Might also need to return them) - # Like in this recurrent ppo ex in sb3 contrib : https://sb3-contrib.readthedocs.io/en/master/modules/ppo_recurrent.html - def _predict(self, observation: np.ndarray, lstm_states, deterministic: bool = False) -> np.ndarray: # type: ignore[override] + # 
Overrided the _predict function with a new one taking the lstm states as arguments + def _predict( + self, + observation: np.ndarray, + lstm_states: LSTMStates, + episode_start: np.ndarray, + deterministic: bool = False + ) -> Tuple[np.ndarray, LSTMStates]: + # TODO : could do a helper fn to transform the obs, dones and return lstm states and action / value + # because it is used in several parts of the code and quite verbose + lstm_in = (observation[np.newaxis, :], episode_start[np.newaxis, :]) + new_pi_lstm_states, dist = self.actor_state.apply_fn(self.actor_state.params, lstm_states.pi, lstm_in) + if deterministic: - # TODO : pass the lstm state here (see how to do it cleanly because uses a function from parent class) - return BaseJaxPolicy.select_action(self.actor_state, observation) + actions = dist.mode() + else: + actions = dist.sample(seed=self.noise_key) + # Trick to use gSDE: repeat sampled noise by using the same noise key if not self.use_sde: self.reset_noise() - # TODO : also include lstm state here - return BaseJaxPolicy.sample_action(self.actor_state, observation, self.noise_key) + + # add the new actor and old critic lstm states to the lstm states tuple + lstm_states = LSTMStates( + pi=new_pi_lstm_states, + vf=lstm_states.vf + ) + + return actions, lstm_states + + # Overrided the predict function with a new one taking the lstm states as arguments + def predict( + self, + observation: Union[np.ndarray, Dict[str, np.ndarray]], + lstm_states: Optional[Tuple[jnp.ndarray, ...]] = None, + episode_start: Optional[np.ndarray] = None, + deterministic: bool = False, + ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: + # Switch to eval mode (this affects batch norm / dropout) + self.set_training_mode(False) + + # TODO : see if still need that + # observation, vectorized_env = self.obs_to_tensor(observation) + + if isinstance(observation, dict): + n_envs = observation[next(iter(observation.keys()))].shape[0] + else: + n_envs = observation.shape[0] + # 
state : (n_layers, n_envs, dim) + if lstm_states is None: + # Initialize hidden states to zeros + init_lstm_states = ScanLSTM.initialize_carry(n_envs, self.hidden_size) + lstm_states = LSTMStates( + pi=init_lstm_states, + vf=init_lstm_states + ) + + if episode_start is None: + episode_start = jnp.array([False for _ in range(n_envs)]) + + actions, lsmt_states = self._predict( + observation, lstm_states=lstm_states, episode_start=episode_start, deterministic=deterministic + ) + + # Convert to numpy + actions = np.array(actions) + + if isinstance(self.action_space, spaces.Box): + if self.squash_output: + # Rescale to proper domain when using squashing + actions = self.unscale_action(actions) + else: + # Actions could be on arbitrary scale, so clip the actions to avoid + # out of bound error (e.g. if sampling from a Gaussian distribution) + actions = np.clip(actions, self.action_space.low, self.action_space.high) + + # TODO : see if still need that + # Remove batch dimension if needed + # if not vectorized_env: + # actions = actions.squeeze(axis=0) + + return actions, lsmt_states + # Added the lstm states to the predict_all method (maybe also the dones but I don't remember) def predict_all(self, observation: np.ndarray, done, lstm_states, key: jax.Array) -> np.ndarray: diff --git a/sbx/recurrent_ppo/recurrent_ppo.py b/sbx/recurrent_ppo/recurrent_ppo.py index 3dcb272..0fa65a9 100644 --- a/sbx/recurrent_ppo/recurrent_ppo.py +++ b/sbx/recurrent_ppo/recurrent_ppo.py @@ -534,6 +534,7 @@ def learn( n_steps = 128 batch_size = 32 train_steps = 10_000 + test_steps = 10 n_envs = 4 env_id = "CartPole-v1" @@ -542,10 +543,12 @@ def learn( model = RecurrentPPO("MlpPolicy", vec_env, n_steps=n_steps, batch_size=batch_size, verbose=1) model.learn(total_timesteps=train_steps, progress_bar=True) - # vec_env = model.get_env() - # obs = vec_env.reset() - # for _ in range(10): - # action, lstm_states = model.predict(obs, deterministic=True) - # obs, reward, done, info = vec_env.step(action) 
+ vec_env = model.get_env() + obs = vec_env.reset() + lstm_states = None - vec_env.close() \ No newline at end of file + for _ in range(test_steps): + action, lstm_states = model.predict(obs, state=lstm_states, deterministic=True) + obs, reward, done, info = vec_env.step(action) + + vec_env.close() From 8f0b95f106b6c59cb54e499ff58dd9857b9ea278 Mon Sep 17 00:00:00 2001 From: corentinlger Date: Fri, 1 Nov 2024 22:28:07 +0100 Subject: [PATCH 8/9] Clean and simplify code --- sbx/recurrent_ppo/policies.py | 40 +++++++++--------------------- sbx/recurrent_ppo/recurrent_ppo.py | 8 ++++-- 2 files changed, 18 insertions(+), 30 deletions(-) diff --git a/sbx/recurrent_ppo/policies.py b/sbx/recurrent_ppo/policies.py index 89708a2..79e7c94 100644 --- a/sbx/recurrent_ppo/policies.py +++ b/sbx/recurrent_ppo/policies.py @@ -32,13 +32,14 @@ class ScanLSTM(nn.Module): split_rngs={'params': False} ) @nn.compact - def __call__(self, lstm_states, inputs_and_resets): + def __call__(self, lstm_states, obs_and_resets): # pass the pi and vf lstm states, as well as the obs and the resets - input, resets = inputs_and_resets + obs, resets = obs_and_resets hidden_state, cell_state = lstm_states # create new lstm states to replace the old ones if reset is True - reset_lstm_states = self.initialize_carry(hidden_state.shape[0], hidden_state.shape[1]) + batch_size, hidden_size = hidden_state.shape + reset_lstm_states = self.initialize_carry(batch_size, hidden_size) # handle the reset of the hidden lstm states hidden_state = jnp.where( @@ -56,7 +57,7 @@ def __call__(self, lstm_states, inputs_and_resets): lstm_states = (hidden_state, cell_state) hidden_size = lstm_states[0].shape[-1] - new_lstm_states, output = nn.LSTMCell(features=hidden_size)(lstm_states, input) + new_lstm_states, output = nn.LSTMCell(features=hidden_size)(lstm_states, obs) return new_lstm_states, output @staticmethod @@ -122,29 +123,10 @@ def __call__(self, hidden, obs_dones) -> tfd.Distribution: # type: ignore[name- log_std = 
self.param("log_std", constant(self.log_std_init), (self.action_dim,)) dist = tfd.MultivariateNormalDiag(loc=action_logits, scale_diag=jnp.exp(log_std)) elif isinstance(self.num_discrete_choices, int): + # Discrete actions dist = tfd.Categorical(logits=action_logits) else: - # Split action_logits = (batch_size, total_choices=sum(self.num_discrete_choices)) - action_logits = jnp.split(action_logits, self.split_indices, axis=1) - # Pad to the maximum number of choices (required by tfp.distributions.Categorical). - # Pad by -inf, so that the probability of these invalid actions is 0. - logits_padded = jnp.stack( - [ - jnp.pad( - logit, - # logit is of shape (batch_size, n) - # only pad after dim=1, to max_num_choices - n - # pad_width=((before_dim_0, after_0), (before_dim_1, after_1)) - pad_width=((0, 0), (0, self.max_num_choices - logit.shape[1])), - constant_values=-np.inf, - ) - for logit in action_logits - ], - axis=1, - ) - dist = tfp.distributions.Independent( - tfp.distributions.Categorical(logits=logits_padded), reinterpreted_batch_ndims=1 - ) + raise ValueError("Invalid action space. 
Only Discrete and Continuous are supported at the moment.") return hidden, dist @@ -196,7 +178,7 @@ def __init__( self.n_units = net_arch["pi"][0] else: self.n_units = 64 - self.use_sde = use_sde + # self.use_sde = use_sde self.key = self.noise_key = jax.random.PRNGKey(0) @@ -321,8 +303,10 @@ def _predict( actions = dist.sample(seed=self.noise_key) # Trick to use gSDE: repeat sampled noise by using the same noise key - if not self.use_sde: - self.reset_noise() + # if not self.use_sde: + # self.reset_noise() + + self.reset_noise() # add the new actor and old critic lstm states to the lstm states tuple lstm_states = LSTMStates( diff --git a/sbx/recurrent_ppo/recurrent_ppo.py b/sbx/recurrent_ppo/recurrent_ppo.py index 0fa65a9..9e1820d 100644 --- a/sbx/recurrent_ppo/recurrent_ppo.py +++ b/sbx/recurrent_ppo/recurrent_ppo.py @@ -393,7 +393,9 @@ def actor_loss(params): lstm_in = (observations[np.newaxis, :], dones.flatten()[np.newaxis, :]) act_lstm_states, _ = lstm_states - act_lstm_states, dist = actor_state.apply_fn(actor_state.params, act_lstm_states, lstm_in) + # TODO + # act_lstm_states, dist = actor_state.apply_fn(actor_state.params, act_lstm_states, lstm_in) + act_lstm_states, dist = actor_state.apply_fn(params, act_lstm_states, lstm_in) log_prob = dist.log_prob(actions) entropy = dist.entropy() @@ -421,7 +423,9 @@ def critic_loss(params): lstm_in = (observations[np.newaxis, :], dones.flatten()[np.newaxis, :]) _, vf_lstm_states = lstm_states # Value loss using the TD(gae_lambda) target - vf_lstm_states, values = vf_state.apply_fn(vf_state.params, vf_lstm_states, lstm_in) + # TODO + # vf_lstm_states, values = vf_state.apply_fn(vf_state.params, vf_lstm_states, lstm_in) + vf_lstm_states, values = vf_state.apply_fn(params, vf_lstm_states, lstm_in) vf_values = values.flatten() return ((returns - vf_values) ** 2).mean() From fb52332765d04d3639d589b2c12f1a95f32000f3 Mon Sep 17 00:00:00 2001 From: corentinlger Date: Mon, 3 Mar 2025 11:51:19 +0100 Subject: [PATCH 9/9] add 
testing files --- recurrent_buffer.ipynb | 1434 ++++++++++++++++++++++++++++++++++++++++ test.ipynb | 480 ++++++++++++++ test.py | 43 ++ 3 files changed, 1957 insertions(+) create mode 100644 recurrent_buffer.ipynb create mode 100644 test.ipynb create mode 100644 test.py diff --git a/recurrent_buffer.ipynb b/recurrent_buffer.ipynb new file mode 100644 index 0000000..01e9264 --- /dev/null +++ b/recurrent_buffer.ipynb @@ -0,0 +1,1434 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Callable, Generator, Optional, Tuple, Union, NamedTuple\n", + "\n", + "import numpy as np\n", + "import torch as th\n", + "import jax.numpy as jnp\n", + "from gymnasium import spaces\n", + "# TODO : see later how to enable DictRolloutBuffer\n", + "from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer\n", + "from stable_baselines3.common.vec_env import VecNormalize\n", + "from stable_baselines3.common.env_util import make_vec_env\n", + "\n", + "\n", + "# TODO : add type aliases for the NamedTuple\n", + "class LSTMStates(NamedTuple):\n", + " pi: Tuple\n", + " vf: Tuple\n", + "\n", + "# TODO : Replaced th.Tensor with jnp.ndarray but might not be true (some as still th Tensors because used in other sb3 fns)\n", + "# Added lstm states but also dones because they are used in actor and critic\n", + "class RecurrentRolloutBufferSamples(NamedTuple):\n", + " observations: jnp.ndarray\n", + " actions: jnp.ndarray\n", + " old_values: jnp.ndarray\n", + " old_log_prob: jnp.ndarray\n", + " advantages: jnp.ndarray\n", + " returns: jnp.ndarray\n", + " dones: jnp.ndarray\n", + " lstm_states: LSTMStates\n", + "\n", + "# Add a recurrent buffer that also takes care of the lstm states and dones flags\n", + "class RecurrentRolloutBuffer(RolloutBuffer):\n", + " \"\"\"\n", + " Rollout buffer that also stores the LSTM cell and hidden states.\n", + "\n", + " :param buffer_size: Max number of element in 
the buffer\n", + " :param observation_space: Observation space\n", + " :param action_space: Action space\n", + " :param hidden_state_shape: Shape of the buffer that will collect lstm states\n", + " (n_steps, lstm.num_layers, n_envs, lstm.hidden_size)\n", + " :param device: PyTorch device\n", + " :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator\n", + " Equivalent to classic advantage when set to 1.\n", + " :param gamma: Discount factor\n", + " :param n_envs: Number of parallel environments\n", + " \"\"\"\n", + "\n", + " def __init__( \n", + " self,\n", + " buffer_size: int,\n", + " observation_space: spaces.Space,\n", + " action_space: spaces.Space,\n", + " # renamed this because I found hidden_state_shape confusing\n", + " lstm_state_buffer_shape: Tuple[int, int, int],\n", + " device: Union[th.device, str] = \"auto\",\n", + " gae_lambda: float = 1,\n", + " gamma: float = 0.99,\n", + " n_envs: int = 1,\n", + " ): \n", + " self.hidden_state_shape = lstm_state_buffer_shape\n", + " self.seq_start_indices, self.seq_end_indices = None, None\n", + " super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs)\n", + "\n", + " # TODO : remove dones because already episode starts in the buffer\n", + " def reset(self):\n", + " super().reset()\n", + " # also add the dones and all lstm states\n", + " self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)\n", + " self.hidden_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32)\n", + " self.cell_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32)\n", + " self.hidden_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32)\n", + " self.cell_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32)\n", + "\n", + " # TODO : remove dones because already episode starts in the buffer\n", + " def add(self, *args, dones, lstm_states, **kwargs) -> None:\n", + " \"\"\"\n", + " :param hidden_states: LSTM 
cell and hidden state\n", + " \"\"\"\n", + " self.dones[self.pos] = np.array(dones)\n", + " self.hidden_states_pi[self.pos] = np.array(lstm_states.pi[0])\n", + " self.cell_states_pi[self.pos] = np.array(lstm_states.pi[1])\n", + " self.hidden_states_vf[self.pos] = np.array(lstm_states.vf[0])\n", + " self.cell_states_vf[self.pos] = np.array(lstm_states.vf[1])\n", + "\n", + " super().add(*args, **kwargs)\n", + "\n", + " def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBufferSamples, None, None]:\n", + " assert self.full, \"Rollout buffer must be full before sampling from it\"\n", + "\n", + " # Prepare the data\n", + " if not self.generator_ready:\n", + " # hidden_state_shape = (self.n_steps, lstm.num_layers, self.n_envs, lstm.hidden_size)\n", + " # swap first to (self.n_steps, self.n_envs, lstm.num_layers, lstm.hidden_size)\n", + " for tensor in [\"hidden_states_pi\", \"cell_states_pi\", \"hidden_states_vf\", \"cell_states_vf\"]:\n", + " self.__dict__[tensor] = self.__dict__[tensor].swapaxes(1, 2)\n", + "\n", + " # flatten but keep the sequence order\n", + " # 1. (n_steps, n_envs, *tensor_shape) -> (n_envs, n_steps, *tensor_shape)\n", + " # 2. 
(n_envs, n_steps, *tensor_shape) -> (n_envs * n_steps, *tensor_shape)\n", + " for tensor in [\n", + " \"observations\",\n", + " \"actions\",\n", + " \"values\",\n", + " \"log_probs\",\n", + " \"advantages\",\n", + " \"returns\",\n", + " \"dones\",\n", + " \"hidden_states_pi\",\n", + " \"cell_states_pi\",\n", + " \"hidden_states_vf\",\n", + " \"cell_states_vf\",\n", + " \"episode_starts\",\n", + " ]:\n", + " self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])\n", + " self.generator_ready = True\n", + "\n", + " # Return everything, don't create minibatches\n", + " if batch_size is None:\n", + " batch_size = self.buffer_size * self.n_envs\n", + "\n", + " # TODO : See how to effectively use the indices to conserve temporal order in the batch data during updates\n", + " # TODO : I think the easisest way is to ensure the n_steps is a multiple of batch_size\n", + " # TODO : But still need to be fixed at the moment (I just made sure the returned shape was right)\n", + " indices = np.arange(self.buffer_size * self.n_envs)\n", + "\n", + " start_idx = 0\n", + " while start_idx < self.buffer_size * self.n_envs:\n", + " batch_inds = indices[start_idx : start_idx + batch_size]\n", + " yield self._get_samples(batch_inds)\n", + " start_idx += batch_size\n", + "\n", + " # return the lstm states as an LSTMStates tuple\n", + " def _get_samples(\n", + " self,\n", + " batch_inds: np.ndarray,\n", + " env: Optional[VecNormalize] = None,\n", + " ) -> RecurrentRolloutBufferSamples:\n", + " \n", + " lstm_states_pi = (\n", + " self.hidden_states_pi[batch_inds],\n", + " self.cell_states_pi[batch_inds]\n", + " )\n", + "\n", + " lstm_states_vf = (\n", + " self.hidden_states_vf[batch_inds],\n", + " self.cell_states_vf[batch_inds]\n", + " )\n", + "\n", + " data = (\n", + " self.observations[batch_inds],\n", + " self.actions[batch_inds],\n", + " self.values[batch_inds].flatten(),\n", + " self.log_probs[batch_inds].flatten(),\n", + " self.advantages[batch_inds].flatten(),\n", + " 
self.returns[batch_inds].flatten(),\n", + " self.dones[batch_inds],\n", + " LSTMStates(pi=lstm_states_pi, vf=lstm_states_vf)\n", + " )\n", + " return RecurrentRolloutBufferSamples(*tuple(map(self.to_torch, data)))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "n_envs = 8\n", + "n_steps = 128\n", + "batch_size = 32\n", + "buffer_size = n_envs * n_steps\n", + "gamma = 0.99\n", + "n_epochs = 4\n", + "gae_lambda = 0.95\n", + "hidden_size = 64\n", + "lstm_state_buffer_shape = (n_steps, n_envs, 64)\n", + "\n", + "env_id = \"CartPole-v1\"\n", + "vec_env = make_vec_env(env_id, n_envs=n_envs)" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [], + "source": [ + "rollout_buffer = RecurrentRolloutBuffer(\n", + " n_steps,\n", + " vec_env.observation_space,\n", + " vec_env.action_space,\n", + " gamma=gamma,\n", + " gae_lambda=gae_lambda,\n", + " n_envs=n_envs,\n", + " lstm_state_buffer_shape=lstm_state_buffer_shape,\n", + " device=\"cpu\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [], + "source": [ + "for i in range(n_steps):\n", + "\n", + " lstm_states = (\n", + " np.full((n_envs, hidden_size), i, dtype=np.float32),\n", + " np.full((n_envs, hidden_size), i, dtype=np.float32),\n", + " )\n", + "\n", + " lstm_states = LSTMStates(pi=lstm_states, vf=lstm_states)\n", + "\n", + " act = np.array([i + 0.1 * idx for idx in range(n_envs)]).reshape(-1, 1)\n", + "\n", + " rollout_buffer.add(\n", + " obs=np.full((n_envs, 4), i, dtype=np.float32),\n", + " action=act,\n", + " # action=np.full((n_envs, 1), i, dtype=np.float32),\n", + " reward=np.full((n_envs, ), i, dtype=np.float32),\n", + " episode_start=np.zeros((n_envs,), dtype=np.float32),\n", + " value=th.ones((n_envs, 1), dtype=th.float32),\n", + " log_prob=th.ones((n_envs, ), dtype=th.float32),\n", + " dones=np.zeros((n_envs,), dtype=np.float32),\n", + " 
lstm_states=lstm_states,\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "rollout_buffer.actions.shape = (128, 8, 1)\n", + "rollout_buffer.actions[0] = array([[0. ],\n", + " [0.1],\n", + " [0.2],\n", + " [0.3],\n", + " [0.4],\n", + " [0.5],\n", + " [0.6],\n", + " [0.7]], dtype=float32)\n", + "rollout_buffer.actions[1] = array([[1. ],\n", + " [1.1],\n", + " [1.2],\n", + " [1.3],\n", + " [1.4],\n", + " [1.5],\n", + " [1.6],\n", + " [1.7]], dtype=float32)\n", + "rollout_buffer.actions[127] = array([[127. ],\n", + " [127.1],\n", + " [127.2],\n", + " [127.3],\n", + " [127.4],\n", + " [127.5],\n", + " [127.6],\n", + " [127.7]], dtype=float32)\n" + ] + } + ], + "source": [ + "print(f\"{rollout_buffer.actions.shape = }\")\n", + "print(f\"{rollout_buffer.actions[0] = }\")\n", + "print(f\"{rollout_buffer.actions[1] = }\")\n", + "print(f\"{rollout_buffer.actions[127] = }\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Ok in fact at each iteration I need to differentiate between the current idx (i) and the idx of the environment (n_envs). How can I do that ? 
\n", + "\n", + "- pass environment idx as the first digit of the number and then \n" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "rollout_data.actions = tensor([[ 0.],\n", + " [ 1.],\n", + " [ 2.],\n", + " [ 3.],\n", + " [ 4.],\n", + " [ 5.],\n", + " [ 6.],\n", + " [ 7.],\n", + " [ 8.],\n", + " [ 9.],\n", + " [10.],\n", + " [11.],\n", + " [12.],\n", + " [13.],\n", + " [14.],\n", + " [15.],\n", + " [16.],\n", + " [17.],\n", + " [18.],\n", + " [19.],\n", + " [20.],\n", + " [21.],\n", + " [22.],\n", + " [23.],\n", + " [24.],\n", + " [25.],\n", + " [26.],\n", + " [27.],\n", + " [28.],\n", + " [29.],\n", + " [30.],\n", + " [31.]])\n", + "rollout_data.actions = tensor([[32.],\n", + " [33.],\n", + " [34.],\n", + " [35.],\n", + " [36.],\n", + " [37.],\n", + " [38.],\n", + " [39.],\n", + " [40.],\n", + " [41.],\n", + " [42.],\n", + " [43.],\n", + " [44.],\n", + " [45.],\n", + " [46.],\n", + " [47.],\n", + " [48.],\n", + " [49.],\n", + " [50.],\n", + " [51.],\n", + " [52.],\n", + " [53.],\n", + " [54.],\n", + " [55.],\n", + " [56.],\n", + " [57.],\n", + " [58.],\n", + " [59.],\n", + " [60.],\n", + " [61.],\n", + " [62.],\n", + " [63.]])\n", + "rollout_data.actions = tensor([[64.],\n", + " [65.],\n", + " [66.],\n", + " [67.],\n", + " [68.],\n", + " [69.],\n", + " [70.],\n", + " [71.],\n", + " [72.],\n", + " [73.],\n", + " [74.],\n", + " [75.],\n", + " [76.],\n", + " [77.],\n", + " [78.],\n", + " [79.],\n", + " [80.],\n", + " [81.],\n", + " [82.],\n", + " [83.],\n", + " [84.],\n", + " [85.],\n", + " [86.],\n", + " [87.],\n", + " [88.],\n", + " [89.],\n", + " [90.],\n", + " [91.],\n", + " [92.],\n", + " [93.],\n", + " [94.],\n", + " [95.]])\n", + "rollout_data.actions = tensor([[ 96.],\n", + " [ 97.],\n", + " [ 98.],\n", + " [ 99.],\n", + " [100.],\n", + " [101.],\n", + " [102.],\n", + " [103.],\n", + " [104.],\n", + " [105.],\n", + " [106.],\n", + " [107.],\n", + 
" [108.],\n", + " [109.],\n", + " [110.],\n", + " [111.],\n", + " [112.],\n", + " [113.],\n", + " [114.],\n", + " [115.],\n", + " [116.],\n", + " [117.],\n", + " [118.],\n", + " [119.],\n", + " [120.],\n", + " [121.],\n", + " [122.],\n", + " [123.],\n", + " [124.],\n", + " [125.],\n", + " [126.],\n", + " [127.]])\n", + "rollout_data.actions = tensor([[ 0.1000],\n", + " [ 1.1000],\n", + " [ 2.1000],\n", + " [ 3.1000],\n", + " [ 4.1000],\n", + " [ 5.1000],\n", + " [ 6.1000],\n", + " [ 7.1000],\n", + " [ 8.1000],\n", + " [ 9.1000],\n", + " [10.1000],\n", + " [11.1000],\n", + " [12.1000],\n", + " [13.1000],\n", + " [14.1000],\n", + " [15.1000],\n", + " [16.1000],\n", + " [17.1000],\n", + " [18.1000],\n", + " [19.1000],\n", + " [20.1000],\n", + " [21.1000],\n", + " [22.1000],\n", + " [23.1000],\n", + " [24.1000],\n", + " [25.1000],\n", + " [26.1000],\n", + " [27.1000],\n", + " [28.1000],\n", + " [29.1000],\n", + " [30.1000],\n", + " [31.1000]])\n", + "rollout_data.actions = tensor([[32.1000],\n", + " [33.1000],\n", + " [34.1000],\n", + " [35.1000],\n", + " [36.1000],\n", + " [37.1000],\n", + " [38.1000],\n", + " [39.1000],\n", + " [40.1000],\n", + " [41.1000],\n", + " [42.1000],\n", + " [43.1000],\n", + " [44.1000],\n", + " [45.1000],\n", + " [46.1000],\n", + " [47.1000],\n", + " [48.1000],\n", + " [49.1000],\n", + " [50.1000],\n", + " [51.1000],\n", + " [52.1000],\n", + " [53.1000],\n", + " [54.1000],\n", + " [55.1000],\n", + " [56.1000],\n", + " [57.1000],\n", + " [58.1000],\n", + " [59.1000],\n", + " [60.1000],\n", + " [61.1000],\n", + " [62.1000],\n", + " [63.1000]])\n", + "rollout_data.actions = tensor([[64.1000],\n", + " [65.1000],\n", + " [66.1000],\n", + " [67.1000],\n", + " [68.1000],\n", + " [69.1000],\n", + " [70.1000],\n", + " [71.1000],\n", + " [72.1000],\n", + " [73.1000],\n", + " [74.1000],\n", + " [75.1000],\n", + " [76.1000],\n", + " [77.1000],\n", + " [78.1000],\n", + " [79.1000],\n", + " [80.1000],\n", + " [81.1000],\n", + " [82.1000],\n", + " 
[83.1000],\n", + " [84.1000],\n", + " [85.1000],\n", + " [86.1000],\n", + " [87.1000],\n", + " [88.1000],\n", + " [89.1000],\n", + " [90.1000],\n", + " [91.1000],\n", + " [92.1000],\n", + " [93.1000],\n", + " [94.1000],\n", + " [95.1000]])\n", + "rollout_data.actions = tensor([[ 96.1000],\n", + " [ 97.1000],\n", + " [ 98.1000],\n", + " [ 99.1000],\n", + " [100.1000],\n", + " [101.1000],\n", + " [102.1000],\n", + " [103.1000],\n", + " [104.1000],\n", + " [105.1000],\n", + " [106.1000],\n", + " [107.1000],\n", + " [108.1000],\n", + " [109.1000],\n", + " [110.1000],\n", + " [111.1000],\n", + " [112.1000],\n", + " [113.1000],\n", + " [114.1000],\n", + " [115.1000],\n", + " [116.1000],\n", + " [117.1000],\n", + " [118.1000],\n", + " [119.1000],\n", + " [120.1000],\n", + " [121.1000],\n", + " [122.1000],\n", + " [123.1000],\n", + " [124.1000],\n", + " [125.1000],\n", + " [126.1000],\n", + " [127.1000]])\n", + "rollout_data.actions = tensor([[ 0.2000],\n", + " [ 1.2000],\n", + " [ 2.2000],\n", + " [ 3.2000],\n", + " [ 4.2000],\n", + " [ 5.2000],\n", + " [ 6.2000],\n", + " [ 7.2000],\n", + " [ 8.2000],\n", + " [ 9.2000],\n", + " [10.2000],\n", + " [11.2000],\n", + " [12.2000],\n", + " [13.2000],\n", + " [14.2000],\n", + " [15.2000],\n", + " [16.2000],\n", + " [17.2000],\n", + " [18.2000],\n", + " [19.2000],\n", + " [20.2000],\n", + " [21.2000],\n", + " [22.2000],\n", + " [23.2000],\n", + " [24.2000],\n", + " [25.2000],\n", + " [26.2000],\n", + " [27.2000],\n", + " [28.2000],\n", + " [29.2000],\n", + " [30.2000],\n", + " [31.2000]])\n", + "rollout_data.actions = tensor([[32.2000],\n", + " [33.2000],\n", + " [34.2000],\n", + " [35.2000],\n", + " [36.2000],\n", + " [37.2000],\n", + " [38.2000],\n", + " [39.2000],\n", + " [40.2000],\n", + " [41.2000],\n", + " [42.2000],\n", + " [43.2000],\n", + " [44.2000],\n", + " [45.2000],\n", + " [46.2000],\n", + " [47.2000],\n", + " [48.2000],\n", + " [49.2000],\n", + " [50.2000],\n", + " [51.2000],\n", + " [52.2000],\n", + " 
[53.2000],\n", + " [54.2000],\n", + " [55.2000],\n", + " [56.2000],\n", + " [57.2000],\n", + " [58.2000],\n", + " [59.2000],\n", + " [60.2000],\n", + " [61.2000],\n", + " [62.2000],\n", + " [63.2000]])\n", + "rollout_data.actions = tensor([[64.2000],\n", + " [65.2000],\n", + " [66.2000],\n", + " [67.2000],\n", + " [68.2000],\n", + " [69.2000],\n", + " [70.2000],\n", + " [71.2000],\n", + " [72.2000],\n", + " [73.2000],\n", + " [74.2000],\n", + " [75.2000],\n", + " [76.2000],\n", + " [77.2000],\n", + " [78.2000],\n", + " [79.2000],\n", + " [80.2000],\n", + " [81.2000],\n", + " [82.2000],\n", + " [83.2000],\n", + " [84.2000],\n", + " [85.2000],\n", + " [86.2000],\n", + " [87.2000],\n", + " [88.2000],\n", + " [89.2000],\n", + " [90.2000],\n", + " [91.2000],\n", + " [92.2000],\n", + " [93.2000],\n", + " [94.2000],\n", + " [95.2000]])\n", + "rollout_data.actions = tensor([[ 96.2000],\n", + " [ 97.2000],\n", + " [ 98.2000],\n", + " [ 99.2000],\n", + " [100.2000],\n", + " [101.2000],\n", + " [102.2000],\n", + " [103.2000],\n", + " [104.2000],\n", + " [105.2000],\n", + " [106.2000],\n", + " [107.2000],\n", + " [108.2000],\n", + " [109.2000],\n", + " [110.2000],\n", + " [111.2000],\n", + " [112.2000],\n", + " [113.2000],\n", + " [114.2000],\n", + " [115.2000],\n", + " [116.2000],\n", + " [117.2000],\n", + " [118.2000],\n", + " [119.2000],\n", + " [120.2000],\n", + " [121.2000],\n", + " [122.2000],\n", + " [123.2000],\n", + " [124.2000],\n", + " [125.2000],\n", + " [126.2000],\n", + " [127.2000]])\n", + "rollout_data.actions = tensor([[ 0.3000],\n", + " [ 1.3000],\n", + " [ 2.3000],\n", + " [ 3.3000],\n", + " [ 4.3000],\n", + " [ 5.3000],\n", + " [ 6.3000],\n", + " [ 7.3000],\n", + " [ 8.3000],\n", + " [ 9.3000],\n", + " [10.3000],\n", + " [11.3000],\n", + " [12.3000],\n", + " [13.3000],\n", + " [14.3000],\n", + " [15.3000],\n", + " [16.3000],\n", + " [17.3000],\n", + " [18.3000],\n", + " [19.3000],\n", + " [20.3000],\n", + " [21.3000],\n", + " [22.3000],\n", + " 
[23.3000],\n", + " [24.3000],\n", + " [25.3000],\n", + " [26.3000],\n", + " [27.3000],\n", + " [28.3000],\n", + " [29.3000],\n", + " [30.3000],\n", + " [31.3000]])\n", + "rollout_data.actions = tensor([[32.3000],\n", + " [33.3000],\n", + " [34.3000],\n", + " [35.3000],\n", + " [36.3000],\n", + " [37.3000],\n", + " [38.3000],\n", + " [39.3000],\n", + " [40.3000],\n", + " [41.3000],\n", + " [42.3000],\n", + " [43.3000],\n", + " [44.3000],\n", + " [45.3000],\n", + " [46.3000],\n", + " [47.3000],\n", + " [48.3000],\n", + " [49.3000],\n", + " [50.3000],\n", + " [51.3000],\n", + " [52.3000],\n", + " [53.3000],\n", + " [54.3000],\n", + " [55.3000],\n", + " [56.3000],\n", + " [57.3000],\n", + " [58.3000],\n", + " [59.3000],\n", + " [60.3000],\n", + " [61.3000],\n", + " [62.3000],\n", + " [63.3000]])\n", + "rollout_data.actions = tensor([[64.3000],\n", + " [65.3000],\n", + " [66.3000],\n", + " [67.3000],\n", + " [68.3000],\n", + " [69.3000],\n", + " [70.3000],\n", + " [71.3000],\n", + " [72.3000],\n", + " [73.3000],\n", + " [74.3000],\n", + " [75.3000],\n", + " [76.3000],\n", + " [77.3000],\n", + " [78.3000],\n", + " [79.3000],\n", + " [80.3000],\n", + " [81.3000],\n", + " [82.3000],\n", + " [83.3000],\n", + " [84.3000],\n", + " [85.3000],\n", + " [86.3000],\n", + " [87.3000],\n", + " [88.3000],\n", + " [89.3000],\n", + " [90.3000],\n", + " [91.3000],\n", + " [92.3000],\n", + " [93.3000],\n", + " [94.3000],\n", + " [95.3000]])\n", + "rollout_data.actions = tensor([[ 96.3000],\n", + " [ 97.3000],\n", + " [ 98.3000],\n", + " [ 99.3000],\n", + " [100.3000],\n", + " [101.3000],\n", + " [102.3000],\n", + " [103.3000],\n", + " [104.3000],\n", + " [105.3000],\n", + " [106.3000],\n", + " [107.3000],\n", + " [108.3000],\n", + " [109.3000],\n", + " [110.3000],\n", + " [111.3000],\n", + " [112.3000],\n", + " [113.3000],\n", + " [114.3000],\n", + " [115.3000],\n", + " [116.3000],\n", + " [117.3000],\n", + " [118.3000],\n", + " [119.3000],\n", + " [120.3000],\n", + " [121.3000],\n", + " 
[122.3000],\n", + " [123.3000],\n", + " [124.3000],\n", + " [125.3000],\n", + " [126.3000],\n", + " [127.3000]])\n", + "rollout_data.actions = tensor([[ 0.4000],\n", + " [ 1.4000],\n", + " [ 2.4000],\n", + " [ 3.4000],\n", + " [ 4.4000],\n", + " [ 5.4000],\n", + " [ 6.4000],\n", + " [ 7.4000],\n", + " [ 8.4000],\n", + " [ 9.4000],\n", + " [10.4000],\n", + " [11.4000],\n", + " [12.4000],\n", + " [13.4000],\n", + " [14.4000],\n", + " [15.4000],\n", + " [16.4000],\n", + " [17.4000],\n", + " [18.4000],\n", + " [19.4000],\n", + " [20.4000],\n", + " [21.4000],\n", + " [22.4000],\n", + " [23.4000],\n", + " [24.4000],\n", + " [25.4000],\n", + " [26.4000],\n", + " [27.4000],\n", + " [28.4000],\n", + " [29.4000],\n", + " [30.4000],\n", + " [31.4000]])\n", + "rollout_data.actions = tensor([[32.4000],\n", + " [33.4000],\n", + " [34.4000],\n", + " [35.4000],\n", + " [36.4000],\n", + " [37.4000],\n", + " [38.4000],\n", + " [39.4000],\n", + " [40.4000],\n", + " [41.4000],\n", + " [42.4000],\n", + " [43.4000],\n", + " [44.4000],\n", + " [45.4000],\n", + " [46.4000],\n", + " [47.4000],\n", + " [48.4000],\n", + " [49.4000],\n", + " [50.4000],\n", + " [51.4000],\n", + " [52.4000],\n", + " [53.4000],\n", + " [54.4000],\n", + " [55.4000],\n", + " [56.4000],\n", + " [57.4000],\n", + " [58.4000],\n", + " [59.4000],\n", + " [60.4000],\n", + " [61.4000],\n", + " [62.4000],\n", + " [63.4000]])\n", + "rollout_data.actions = tensor([[64.4000],\n", + " [65.4000],\n", + " [66.4000],\n", + " [67.4000],\n", + " [68.4000],\n", + " [69.4000],\n", + " [70.4000],\n", + " [71.4000],\n", + " [72.4000],\n", + " [73.4000],\n", + " [74.4000],\n", + " [75.4000],\n", + " [76.4000],\n", + " [77.4000],\n", + " [78.4000],\n", + " [79.4000],\n", + " [80.4000],\n", + " [81.4000],\n", + " [82.4000],\n", + " [83.4000],\n", + " [84.4000],\n", + " [85.4000],\n", + " [86.4000],\n", + " [87.4000],\n", + " [88.4000],\n", + " [89.4000],\n", + " [90.4000],\n", + " [91.4000],\n", + " [92.4000],\n", + " [93.4000],\n", + " 
[94.4000],\n", + " [95.4000]])\n", + "rollout_data.actions = tensor([[ 96.4000],\n", + " [ 97.4000],\n", + " [ 98.4000],\n", + " [ 99.4000],\n", + " [100.4000],\n", + " [101.4000],\n", + " [102.4000],\n", + " [103.4000],\n", + " [104.4000],\n", + " [105.4000],\n", + " [106.4000],\n", + " [107.4000],\n", + " [108.4000],\n", + " [109.4000],\n", + " [110.4000],\n", + " [111.4000],\n", + " [112.4000],\n", + " [113.4000],\n", + " [114.4000],\n", + " [115.4000],\n", + " [116.4000],\n", + " [117.4000],\n", + " [118.4000],\n", + " [119.4000],\n", + " [120.4000],\n", + " [121.4000],\n", + " [122.4000],\n", + " [123.4000],\n", + " [124.4000],\n", + " [125.4000],\n", + " [126.4000],\n", + " [127.4000]])\n", + "rollout_data.actions = tensor([[ 0.5000],\n", + " [ 1.5000],\n", + " [ 2.5000],\n", + " [ 3.5000],\n", + " [ 4.5000],\n", + " [ 5.5000],\n", + " [ 6.5000],\n", + " [ 7.5000],\n", + " [ 8.5000],\n", + " [ 9.5000],\n", + " [10.5000],\n", + " [11.5000],\n", + " [12.5000],\n", + " [13.5000],\n", + " [14.5000],\n", + " [15.5000],\n", + " [16.5000],\n", + " [17.5000],\n", + " [18.5000],\n", + " [19.5000],\n", + " [20.5000],\n", + " [21.5000],\n", + " [22.5000],\n", + " [23.5000],\n", + " [24.5000],\n", + " [25.5000],\n", + " [26.5000],\n", + " [27.5000],\n", + " [28.5000],\n", + " [29.5000],\n", + " [30.5000],\n", + " [31.5000]])\n", + "rollout_data.actions = tensor([[32.5000],\n", + " [33.5000],\n", + " [34.5000],\n", + " [35.5000],\n", + " [36.5000],\n", + " [37.5000],\n", + " [38.5000],\n", + " [39.5000],\n", + " [40.5000],\n", + " [41.5000],\n", + " [42.5000],\n", + " [43.5000],\n", + " [44.5000],\n", + " [45.5000],\n", + " [46.5000],\n", + " [47.5000],\n", + " [48.5000],\n", + " [49.5000],\n", + " [50.5000],\n", + " [51.5000],\n", + " [52.5000],\n", + " [53.5000],\n", + " [54.5000],\n", + " [55.5000],\n", + " [56.5000],\n", + " [57.5000],\n", + " [58.5000],\n", + " [59.5000],\n", + " [60.5000],\n", + " [61.5000],\n", + " [62.5000],\n", + " [63.5000]])\n", + 
"rollout_data.actions = tensor([[64.5000],\n", + " [65.5000],\n", + " [66.5000],\n", + " [67.5000],\n", + " [68.5000],\n", + " [69.5000],\n", + " [70.5000],\n", + " [71.5000],\n", + " [72.5000],\n", + " [73.5000],\n", + " [74.5000],\n", + " [75.5000],\n", + " [76.5000],\n", + " [77.5000],\n", + " [78.5000],\n", + " [79.5000],\n", + " [80.5000],\n", + " [81.5000],\n", + " [82.5000],\n", + " [83.5000],\n", + " [84.5000],\n", + " [85.5000],\n", + " [86.5000],\n", + " [87.5000],\n", + " [88.5000],\n", + " [89.5000],\n", + " [90.5000],\n", + " [91.5000],\n", + " [92.5000],\n", + " [93.5000],\n", + " [94.5000],\n", + " [95.5000]])\n", + "rollout_data.actions = tensor([[ 96.5000],\n", + " [ 97.5000],\n", + " [ 98.5000],\n", + " [ 99.5000],\n", + " [100.5000],\n", + " [101.5000],\n", + " [102.5000],\n", + " [103.5000],\n", + " [104.5000],\n", + " [105.5000],\n", + " [106.5000],\n", + " [107.5000],\n", + " [108.5000],\n", + " [109.5000],\n", + " [110.5000],\n", + " [111.5000],\n", + " [112.5000],\n", + " [113.5000],\n", + " [114.5000],\n", + " [115.5000],\n", + " [116.5000],\n", + " [117.5000],\n", + " [118.5000],\n", + " [119.5000],\n", + " [120.5000],\n", + " [121.5000],\n", + " [122.5000],\n", + " [123.5000],\n", + " [124.5000],\n", + " [125.5000],\n", + " [126.5000],\n", + " [127.5000]])\n", + "rollout_data.actions = tensor([[ 0.6000],\n", + " [ 1.6000],\n", + " [ 2.6000],\n", + " [ 3.6000],\n", + " [ 4.6000],\n", + " [ 5.6000],\n", + " [ 6.6000],\n", + " [ 7.6000],\n", + " [ 8.6000],\n", + " [ 9.6000],\n", + " [10.6000],\n", + " [11.6000],\n", + " [12.6000],\n", + " [13.6000],\n", + " [14.6000],\n", + " [15.6000],\n", + " [16.6000],\n", + " [17.6000],\n", + " [18.6000],\n", + " [19.6000],\n", + " [20.6000],\n", + " [21.6000],\n", + " [22.6000],\n", + " [23.6000],\n", + " [24.6000],\n", + " [25.6000],\n", + " [26.6000],\n", + " [27.6000],\n", + " [28.6000],\n", + " [29.6000],\n", + " [30.6000],\n", + " [31.6000]])\n", + "rollout_data.actions = tensor([[32.6000],\n", + " 
[33.6000],\n", + " [34.6000],\n", + " [35.6000],\n", + " [36.6000],\n", + " [37.6000],\n", + " [38.6000],\n", + " [39.6000],\n", + " [40.6000],\n", + " [41.6000],\n", + " [42.6000],\n", + " [43.6000],\n", + " [44.6000],\n", + " [45.6000],\n", + " [46.6000],\n", + " [47.6000],\n", + " [48.6000],\n", + " [49.6000],\n", + " [50.6000],\n", + " [51.6000],\n", + " [52.6000],\n", + " [53.6000],\n", + " [54.6000],\n", + " [55.6000],\n", + " [56.6000],\n", + " [57.6000],\n", + " [58.6000],\n", + " [59.6000],\n", + " [60.6000],\n", + " [61.6000],\n", + " [62.6000],\n", + " [63.6000]])\n", + "rollout_data.actions = tensor([[64.6000],\n", + " [65.6000],\n", + " [66.6000],\n", + " [67.6000],\n", + " [68.6000],\n", + " [69.6000],\n", + " [70.6000],\n", + " [71.6000],\n", + " [72.6000],\n", + " [73.6000],\n", + " [74.6000],\n", + " [75.6000],\n", + " [76.6000],\n", + " [77.6000],\n", + " [78.6000],\n", + " [79.6000],\n", + " [80.6000],\n", + " [81.6000],\n", + " [82.6000],\n", + " [83.6000],\n", + " [84.6000],\n", + " [85.6000],\n", + " [86.6000],\n", + " [87.6000],\n", + " [88.6000],\n", + " [89.6000],\n", + " [90.6000],\n", + " [91.6000],\n", + " [92.6000],\n", + " [93.6000],\n", + " [94.6000],\n", + " [95.6000]])\n", + "rollout_data.actions = tensor([[ 96.6000],\n", + " [ 97.6000],\n", + " [ 98.6000],\n", + " [ 99.6000],\n", + " [100.6000],\n", + " [101.6000],\n", + " [102.6000],\n", + " [103.6000],\n", + " [104.6000],\n", + " [105.6000],\n", + " [106.6000],\n", + " [107.6000],\n", + " [108.6000],\n", + " [109.6000],\n", + " [110.6000],\n", + " [111.6000],\n", + " [112.6000],\n", + " [113.6000],\n", + " [114.6000],\n", + " [115.6000],\n", + " [116.6000],\n", + " [117.6000],\n", + " [118.6000],\n", + " [119.6000],\n", + " [120.6000],\n", + " [121.6000],\n", + " [122.6000],\n", + " [123.6000],\n", + " [124.6000],\n", + " [125.6000],\n", + " [126.6000],\n", + " [127.6000]])\n", + "rollout_data.actions = tensor([[ 0.7000],\n", + " [ 1.7000],\n", + " [ 2.7000],\n", + " [ 
3.7000],\n", + " [ 4.7000],\n", + " [ 5.7000],\n", + " [ 6.7000],\n", + " [ 7.7000],\n", + " [ 8.7000],\n", + " [ 9.7000],\n", + " [10.7000],\n", + " [11.7000],\n", + " [12.7000],\n", + " [13.7000],\n", + " [14.7000],\n", + " [15.7000],\n", + " [16.7000],\n", + " [17.7000],\n", + " [18.7000],\n", + " [19.7000],\n", + " [20.7000],\n", + " [21.7000],\n", + " [22.7000],\n", + " [23.7000],\n", + " [24.7000],\n", + " [25.7000],\n", + " [26.7000],\n", + " [27.7000],\n", + " [28.7000],\n", + " [29.7000],\n", + " [30.7000],\n", + " [31.7000]])\n", + "rollout_data.actions = tensor([[32.7000],\n", + " [33.7000],\n", + " [34.7000],\n", + " [35.7000],\n", + " [36.7000],\n", + " [37.7000],\n", + " [38.7000],\n", + " [39.7000],\n", + " [40.7000],\n", + " [41.7000],\n", + " [42.7000],\n", + " [43.7000],\n", + " [44.7000],\n", + " [45.7000],\n", + " [46.7000],\n", + " [47.7000],\n", + " [48.7000],\n", + " [49.7000],\n", + " [50.7000],\n", + " [51.7000],\n", + " [52.7000],\n", + " [53.7000],\n", + " [54.7000],\n", + " [55.7000],\n", + " [56.7000],\n", + " [57.7000],\n", + " [58.7000],\n", + " [59.7000],\n", + " [60.7000],\n", + " [61.7000],\n", + " [62.7000],\n", + " [63.7000]])\n", + "rollout_data.actions = tensor([[64.7000],\n", + " [65.7000],\n", + " [66.7000],\n", + " [67.7000],\n", + " [68.7000],\n", + " [69.7000],\n", + " [70.7000],\n", + " [71.7000],\n", + " [72.7000],\n", + " [73.7000],\n", + " [74.7000],\n", + " [75.7000],\n", + " [76.7000],\n", + " [77.7000],\n", + " [78.7000],\n", + " [79.7000],\n", + " [80.7000],\n", + " [81.7000],\n", + " [82.7000],\n", + " [83.7000],\n", + " [84.7000],\n", + " [85.7000],\n", + " [86.7000],\n", + " [87.7000],\n", + " [88.7000],\n", + " [89.7000],\n", + " [90.7000],\n", + " [91.7000],\n", + " [92.7000],\n", + " [93.7000],\n", + " [94.7000],\n", + " [95.7000]])\n", + "rollout_data.actions = tensor([[ 96.7000],\n", + " [ 97.7000],\n", + " [ 98.7000],\n", + " [ 99.7000],\n", + " [100.7000],\n", + " [101.7000],\n", + " [102.7000],\n", + " 
[103.7000],\n", + " [104.7000],\n", + " [105.7000],\n", + " [106.7000],\n", + " [107.7000],\n", + " [108.7000],\n", + " [109.7000],\n", + " [110.7000],\n", + " [111.7000],\n", + " [112.7000],\n", + " [113.7000],\n", + " [114.7000],\n", + " [115.7000],\n", + " [116.7000],\n", + " [117.7000],\n", + " [118.7000],\n", + " [119.7000],\n", + " [120.7000],\n", + " [121.7000],\n", + " [122.7000],\n", + " [123.7000],\n", + " [124.7000],\n", + " [125.7000],\n", + " [126.7000],\n", + " [127.7000]])\n", + "count = 32\n" + ] + } + ], + "source": [ + "# train for n_epochs epochs\n", + "# for j in range(n_epochs):\n", + "count = 0\n", + "for rollout_data in rollout_buffer.get(batch_size): # type: ignore[attr-defined]\n", + " print(f\"{rollout_data.actions = }\")\n", + " count += 1\n", + "\n", + "print(f\"{count = }\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "- Seems like the batches are quite ok with the custom actions I give \n", + "- Now need to check if it is also the case for the lstm components" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + 
"rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])\n", + "count = 32\n" + ] + } + ], + "source": [ + "# train for n_epochs epochs\n", + "# for j in range(n_epochs):\n", + "count = 0\n", + "for rollout_data in rollout_buffer.get(batch_size): # type: ignore[attr-defined]\n", + " count += 1\n", + " print(f\"{rollout_data.lstm_states[0][0].shape = }\")\n", + "\n", + "print(f\"{count = }\")\n", + "\n", + "# lstm_states_pi = (\n", + "# rollout_data.lstm_states[0][0].numpy().reshape(batch_size, hidden_state_size),\n", + "# rollout_data.lstm_states[0][1].numpy().reshape(batch_size, hidden_state_size)\n", + "# )\n", + "\n", + "# lstm_states_vf = (\n", + "# rollout_data.lstm_states[1][0].numpy().reshape(batch_size, 
hidden_state_size),\n", + "# rollout_data.lstm_states[1][1].numpy().reshape(batch_size, hidden_state_size)\n", + "# )" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/test.ipynb b/test.ipynb new file mode 100644 index 0000000..bc605e9 --- /dev/null +++ b/test.ipynb @@ -0,0 +1,480 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from functools import partial\n", + "from typing import Callable, Generator, Optional, Tuple, Union, NamedTuple\n", + "\n", + "import numpy as np\n", + "import torch as th\n", + "import jax.numpy as jnp\n", + "from gymnasium import spaces\n", + "from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer\n", + "from stable_baselines3.common.vec_env import VecNormalize\n", + "\n", + "# TODO : see if I add jax info\n", + "class LSTMStates(NamedTuple):\n", + " pi: Tuple\n", + " vf: Tuple\n", + "\n", + "# TODO : Replaced th.Tensor with jnp.ndarray but might not be true (some as still th Tensors because used in other sb3 functions)\n", + "# Added lstm states but also dones because they are used in actor and critic\n", + "class RecurrentRolloutBufferSamples(NamedTuple):\n", + " observations: jnp.ndarray\n", + " actions: jnp.ndarray\n", + " old_values: jnp.ndarray\n", + " old_log_prob: jnp.ndarray\n", + " advantages: jnp.ndarray\n", + " returns: jnp.ndarray\n", + " dones: jnp.ndarray\n", + " lstm_states: LSTMStates" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "class 
RecurrentRolloutBuffer(RolloutBuffer):\n", + " \"\"\"\n", + " Rollout buffer that also stores the LSTM cell and hidden states.\n", + "\n", + " :param buffer_size: Max number of element in the buffer\n", + " :param observation_space: Observation space\n", + " :param action_space: Action space\n", + " :param hidden_state_shape: Shape of the buffer that will collect lstm states\n", + " (n_steps, lstm.num_layers, n_envs, lstm.hidden_size)\n", + " :param device: PyTorch device\n", + " :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator\n", + " Equivalent to classic advantage when set to 1.\n", + " :param gamma: Discount factor\n", + " :param n_envs: Number of parallel environments\n", + " \"\"\"\n", + "\n", + " def __init__( \n", + " self,\n", + " buffer_size: int,\n", + " observation_space: spaces.Space,\n", + " action_space: spaces.Space,\n", + " # renamed this because I found hidden_state_shape confusing\n", + " lstm_state_buffer_shape: Tuple[int, int, int],\n", + " device: Union[th.device, str] = \"auto\",\n", + " gae_lambda: float = 1,\n", + " gamma: float = 0.99,\n", + " n_envs: int = 1,\n", + " ): \n", + " # TODO : see if I rename this in all the code\n", + " self.hidden_state_shape = lstm_state_buffer_shape\n", + " self.seq_start_indices, self.seq_end_indices = None, None\n", + " super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs)\n", + "\n", + " def reset(self):\n", + " super().reset()\n", + " self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)\n", + " self.hidden_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32)\n", + " self.cell_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32)\n", + " self.hidden_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32)\n", + " self.cell_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32)\n", + "\n", + " # def add(self, *args, lstm_states: LSTMStates, **kwargs) -> 
None:\n", + " # \"\"\"\n", + " # :param hidden_states: LSTM cell and hidden state\n", + " # \"\"\"\n", + " # # TODO : at the moment doesn't work because I didn't create a named tuple for lstm states\n", + " # self.hidden_states_pi[self.pos] = np.array(lstm_states.pi[0].cpu().numpy())\n", + " # self.cell_states_pi[self.pos] = np.array(lstm_states.pi[1].cpu().numpy())\n", + " # self.hidden_states_vf[self.pos] = np.array(lstm_states.vf[0].cpu().numpy())\n", + " # self.cell_states_vf[self.pos] = np.array(lstm_states.vf[1].cpu().numpy())\n", + "\n", + " def add(self, *args, dones, lstm_states, **kwargs) -> None:\n", + " \"\"\"\n", + " :param hidden_states: LSTM cell and hidden state\n", + " \"\"\"\n", + " # TODO : at the moment doesn't work because I didn't create a named tuple for lstm states\n", + " self.hidden_states_pi[self.pos] = np.array(lstm_states[0][0])\n", + " self.cell_states_pi[self.pos] = np.array(lstm_states[0][1])\n", + " self.hidden_states_vf[self.pos] = np.array(lstm_states[1][0])\n", + " self.cell_states_vf[self.pos] = np.array(lstm_states[1][1])\n", + " self.dones[self.pos] = np.array(dones)\n", + "\n", + " super().add(*args, **kwargs)\n", + "\n", + " def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBufferSamples, None, None]:\n", + " assert self.full, \"Rollout buffer must be full before sampling from it\"\n", + "\n", + " # Prepare the data\n", + " if not self.generator_ready:\n", + " # hidden_state_shape = (self.n_steps, lstm.num_layers, self.n_envs, lstm.hidden_size)\n", + " # swap first to (self.n_steps, self.n_envs, lstm.num_layers, lstm.hidden_size)\n", + " for tensor in [\"hidden_states_pi\", \"cell_states_pi\", \"hidden_states_vf\", \"cell_states_vf\"]:\n", + " self.__dict__[tensor] = self.__dict__[tensor].swapaxes(1, 2)\n", + "\n", + " # flatten but keep the sequence order\n", + " # 1. (n_steps, n_envs, *tensor_shape) -> (n_envs, n_steps, *tensor_shape)\n", + " # 2. 
(n_envs, n_steps, *tensor_shape) -> (n_envs * n_steps, *tensor_shape)\n", + " for tensor in [\n", + " \"observations\",\n", + " \"actions\",\n", + " \"values\",\n", + " \"log_probs\",\n", + " \"advantages\",\n", + " \"returns\",\n", + " \"dones\",\n", + " \"hidden_states_pi\",\n", + " \"cell_states_pi\",\n", + " \"hidden_states_vf\",\n", + " \"cell_states_vf\",\n", + " \"episode_starts\",\n", + " ]:\n", + " self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])\n", + " self.generator_ready = True\n", + "\n", + " # Return everything, don't create minibatches\n", + " if batch_size is None:\n", + " batch_size = self.buffer_size * self.n_envs\n", + "\n", + " # TODO : Check if this works well \n", + " # TODO : Sampling strategy that doesn't allow any mini batch size (must be a multiple of n_envs)\n", + " indices = np.arange(self.buffer_size * self.n_envs)\n", + "\n", + " start_idx = 0\n", + " while start_idx < self.buffer_size * self.n_envs:\n", + " batch_inds = indices[start_idx : start_idx + batch_size]\n", + " yield self._get_samples(batch_inds)\n", + " start_idx += batch_size\n", + "\n", + "\n", + " def _get_samples(\n", + " self,\n", + " batch_inds: np.ndarray,\n", + " env: Optional[VecNormalize] = None,\n", + " ) -> RecurrentRolloutBufferSamples:\n", + " \n", + " lstm_states_pi = (\n", + " self.hidden_states_pi[batch_inds],\n", + " self.cell_states_pi[batch_inds]\n", + " )\n", + "\n", + " lstm_states_vf = (\n", + " self.hidden_states_vf[batch_inds],\n", + " self.cell_states_vf[batch_inds]\n", + " )\n", + "\n", + " data = (\n", + " self.observations[batch_inds],\n", + " self.actions[batch_inds],\n", + " self.values[batch_inds].flatten(),\n", + " self.log_probs[batch_inds].flatten(),\n", + " self.advantages[batch_inds].flatten(),\n", + " self.returns[batch_inds].flatten(),\n", + " # TODO : Check that\n", + " self.dones[batch_inds],\n", + " LSTMStates(pi=lstm_states_pi, vf=lstm_states_vf)\n", + " )\n", + " return 
RecurrentRolloutBufferSamples(*tuple(map(self.to_torch, data)))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "import flax.linen as nn \n", + "\n", + "import functools\n", + "import numpy as np\n", + "import gymnasium as gym\n", + "from stable_baselines3.common.env_util import make_vec_env\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "n_envs = 8\n", + "n_steps = 128\n", + "batch_size = 32\n", + "buffer_size = n_envs * n_steps\n", + "gamma = 0.99\n", + "gae_lambda = 0.95\n", + "hidden_size = 64\n", + "lstm_state_buffer_shape = (n_steps, n_envs, 64)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "env_id = \"CartPole-v1\"\n", + "vec_env = make_vec_env(env_id, n_envs=n_envs)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "rollout_buffer = RecurrentRolloutBuffer(\n", + " n_steps,\n", + " vec_env.observation_space,\n", + " vec_env.action_space,\n", + " gamma=gamma,\n", + " gae_lambda=gae_lambda,\n", + " n_envs=n_envs,\n", + " lstm_state_buffer_shape=lstm_state_buffer_shape,\n", + " device=\"cpu\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(128, 8, 64)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "rollout_buffer.cell_states_pi.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(128, 8, 1)" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "rollout_buffer.actions.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": 
[ + { + "data": { + "text/plain": [ + "(128, 8, 4)" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "rollout_buffer.observations.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Batch 1: a shape = (32, 8, 1), b shape = (32, 8, 4), c shape = (32, 8, 64)\n", + "Batch 2: a shape = (32, 8, 1), b shape = (32, 8, 4), c shape = (32, 8, 64)\n", + "Batch 3: a shape = (32, 8, 1), b shape = (32, 8, 4), c shape = (32, 8, 64)\n", + "Batch 4: a shape = (32, 8, 1), b shape = (32, 8, 4), c shape = (32, 8, 64)\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "\n", + "def split_into_batches(a, b, c, batch_size):\n", + " n_steps = a.shape[0]\n", + " assert n_steps % batch_size == 0, \"n_steps must be a multiple of batch_size\"\n", + " \n", + " num_batches = n_steps // batch_size\n", + " \n", + " a_batches = np.split(a, num_batches, axis=0)\n", + " b_batches = np.split(b, num_batches, axis=0)\n", + " c_batches = np.split(c, num_batches, axis=0)\n", + " \n", + " return a_batches, b_batches, c_batches\n", + "\n", + "# Example usage\n", + "n_steps = 128\n", + "n_envs = 8\n", + "hidden_size = 64\n", + "batch_size = 32\n", + "\n", + "a = np.zeros((n_steps, n_envs, 1))\n", + "b = np.zeros((n_steps, n_envs, 4))\n", + "c = np.zeros((n_steps, n_envs, hidden_size))\n", + "\n", + "a_batches, b_batches, c_batches = split_into_batches(a, b, c, batch_size)\n", + "\n", + "# Check the shapes of the batches\n", + "for i in range(len(a_batches)):\n", + " print(f\"Batch {i+1}: a shape = {a_batches[i].shape}, b shape = {b_batches[i].shape}, c shape = {c_batches[i].shape}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "a = np.zeros((n_steps, n_envs, 1))\n", + "b = np.zeros((n_steps, n_envs, 4))\n", + "c = np.zeros((n_steps, n_envs, hidden_size))" + ] 
+ }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [ + "rollout_buffer.full = True" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [], + "source": [ + "def debug_get(buffer: RecurrentRolloutBuffer, batch_size):\n", + " data = buffer.get(batch_size)\n", + "\n", + " # Print the name and shape of each item in the data\n", + " for name, value in data.items():\n", + " print(f\"Name: {name}, Shape: {value.shape}\")\n", + "\n", + " return data" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "ename": "AttributeError", + "evalue": "'generator' object has no attribute 'items'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[28], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mdebug_get\u001b[49m\u001b[43m(\u001b[49m\u001b[43mrollout_buffer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbatch_size\u001b[49m\u001b[43m)\u001b[49m\n", + "Cell \u001b[0;32mIn[27], line 5\u001b[0m, in \u001b[0;36mdebug_get\u001b[0;34m(buffer, batch_size)\u001b[0m\n\u001b[1;32m 2\u001b[0m data \u001b[38;5;241m=\u001b[39m buffer\u001b[38;5;241m.\u001b[39mget(batch_size)\n\u001b[1;32m 4\u001b[0m \u001b[38;5;66;03m# Print the name and shape of each item in the data\u001b[39;00m\n\u001b[0;32m----> 5\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m name, value \u001b[38;5;129;01min\u001b[39;00m \u001b[43mdata\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mitems\u001b[49m():\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mName: 
\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, Shape: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mvalue\u001b[38;5;241m.\u001b[39mshape\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 8\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m data\n", + "\u001b[0;31mAttributeError\u001b[0m: 'generator' object has no attribute 'items'" + ] + } + ], + "source": [ + "debug_get(rollout_buffer, batch_size=batch_size)" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([32, 1])\n", + "torch.Size([32, 1])\n", + "torch.Size([32, 1])\n", + "torch.Size([32, 1])\n", + "torch.Size([32, 1])\n", + "torch.Size([32, 1])\n", + "torch.Size([32, 1])\n", + "torch.Size([32, 1])\n", + "torch.Size([32, 1])\n", + "torch.Size([32, 1])\n", + "torch.Size([32, 1])\n", + "torch.Size([32, 1])\n", + "torch.Size([32, 1])\n", + "torch.Size([32, 1])\n", + "torch.Size([32, 1])\n", + "torch.Size([32, 1])\n", + "torch.Size([32, 1])\n", + "torch.Size([32, 1])\n", + "torch.Size([32, 1])\n", + "torch.Size([32, 1])\n", + "torch.Size([32, 1])\n", + "torch.Size([32, 1])\n", + "torch.Size([32, 1])\n", + "torch.Size([32, 1])\n", + "torch.Size([32, 1])\n", + "torch.Size([32, 1])\n", + "torch.Size([32, 1])\n", + "torch.Size([32, 1])\n", + "torch.Size([32, 1])\n", + "torch.Size([32, 1])\n", + "torch.Size([32, 1])\n", + "torch.Size([32, 1])\n" + ] + } + ], + "source": [ + "for rollout_data in rollout_buffer.get(batch_size):\n", + " print(rollout_data.actions.shape)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": 
"3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/test.py b/test.py new file mode 100644 index 0000000..9066d5f --- /dev/null +++ b/test.py @@ -0,0 +1,43 @@ +import argparse +import gymnasium as gym + +from stable_baselines3.common.env_util import make_vec_env +from sbx import PPO, RecurrentPPO + + +def main(): + algo = "ppo" + algo = "rppo" + ALGO = RecurrentPPO if algo == "rppo" else PPO + print(f"Using {algo}") + + n_steps = 128 + batch_size = 32 + train_steps = 20_000 + test_steps = 10 + n_envs = 8 + + n_steps = 64 + batch_size = 16 + train_steps = 20_000 + test_steps = 10 + n_envs = 2 + + env_id = "CartPole-v1" + + # create vec env and train algo + vec_env = make_vec_env(env_id, n_envs=n_envs) + model = ALGO("MlpPolicy", vec_env, n_steps=n_steps, batch_size=batch_size, verbose=1) + model.learn(total_timesteps=train_steps, progress_bar=True) + + # test if trained algo works + vec_env = model.get_env() + obs = vec_env.reset() + for _ in range(test_steps): + action, _states = model.predict(obs, deterministic=True) + obs, reward, done, info = vec_env.step(action) + + vec_env.close() + +if __name__ == "__main__": + main() \ No newline at end of file