From 19099ae023296b5cdf2d63220581fe5a4f7ea395 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Fri, 12 Jul 2024 13:41:58 +0200 Subject: [PATCH 01/15] Add PER implementation --- Makefile | 2 +- sbx/__init__.py | 2 + sbx/common/type_aliases.py | 4 +- sbx/dqn/dqn.py | 5 + sbx/per_dqn/__init__.py | 3 + sbx/per_dqn/per_dqn.py | 215 +++++++++++++++++++++++++++++++++++++ sbx/version.txt | 2 +- tests/test_run.py | 13 +-- 8 files changed, 237 insertions(+), 9 deletions(-) create mode 100644 sbx/per_dqn/__init__.py create mode 100644 sbx/per_dqn/per_dqn.py diff --git a/Makefile b/Makefile index 0177d5a..a8bdbf4 100644 --- a/Makefile +++ b/Makefile @@ -14,7 +14,7 @@ lint: # see https://www.flake8rules.com/ ruff check ${LINT_PATHS} --select=E9,F63,F7,F82 --output-format=full # exit-zero treats all errors as warnings. - ruff check ${LINT_PATHS} --exit-zero + ruff check ${LINT_PATHS} --exit-zero --output-format=concise format: # Sort imports diff --git a/sbx/__init__.py b/sbx/__init__.py index a7c13bc..62dd9be 100644 --- a/sbx/__init__.py +++ b/sbx/__init__.py @@ -3,6 +3,7 @@ from sbx.crossq import CrossQ from sbx.ddpg import DDPG from sbx.dqn import DQN +from sbx.per_dqn import PERDQN from sbx.ppo import PPO from sbx.sac import SAC from sbx.td3 import TD3 @@ -26,6 +27,7 @@ def DroQ(*args, **kwargs): "CrossQ", "DDPG", "DQN", + "PERDQN", "PPO", "SAC", "TD3", diff --git a/sbx/common/type_aliases.py b/sbx/common/type_aliases.py index f014a1f..bf23356 100644 --- a/sbx/common/type_aliases.py +++ b/sbx/common/type_aliases.py @@ -1,4 +1,4 @@ -from typing import NamedTuple +from typing import NamedTuple, Optional, Union import flax import numpy as np @@ -19,3 +19,5 @@ class ReplayBufferSamplesNp(NamedTuple): next_observations: np.ndarray dones: np.ndarray rewards: np.ndarray + weights: Union[np.ndarray, float] = 1.0 + leaf_nodes_indices: Optional[np.ndarray] = None diff --git a/sbx/dqn/dqn.py b/sbx/dqn/dqn.py index 852b823..06eb40e 100644 --- a/sbx/dqn/dqn.py +++ b/sbx/dqn/dqn.py @@ 
-6,6 +6,7 @@ import jax.numpy as jnp import numpy as np import optax +from stable_baselines3.common.buffers import ReplayBuffer from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import get_linear_fn @@ -41,6 +42,8 @@ def __init__( # max_grad_norm: float = 10, train_freq: Union[int, Tuple[int, str]] = 4, gradient_steps: int = 1, + replay_buffer_class: Optional[Type[ReplayBuffer]] = None, + replay_buffer_kwargs: Optional[Dict[str, Any]] = None, tensorboard_log: Optional[str] = None, policy_kwargs: Optional[Dict[str, Any]] = None, verbose: int = 0, @@ -59,6 +62,8 @@ def __init__( gamma=gamma, train_freq=train_freq, gradient_steps=gradient_steps, + replay_buffer_class=replay_buffer_class, + replay_buffer_kwargs=replay_buffer_kwargs, policy_kwargs=policy_kwargs, tensorboard_log=tensorboard_log, verbose=verbose, diff --git a/sbx/per_dqn/__init__.py b/sbx/per_dqn/__init__.py new file mode 100644 index 0000000..188179e --- /dev/null +++ b/sbx/per_dqn/__init__.py @@ -0,0 +1,3 @@ +from sbx.per_dqn.per_dqn import PERDQN + +__all__ = ["PERDQN"] diff --git a/sbx/per_dqn/per_dqn.py b/sbx/per_dqn/per_dqn.py new file mode 100644 index 0000000..dab07e3 --- /dev/null +++ b/sbx/per_dqn/per_dqn.py @@ -0,0 +1,215 @@ +from typing import Any, ClassVar, Dict, Optional, Tuple, Type, Union + +import jax +import jax.numpy as jnp +import numpy as np +import optax +from stable_baselines3.common.prioritized_replay_buffer import PrioritizedReplayBuffer +from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule + +from sbx.common.type_aliases import ReplayBufferSamplesNp, RLTrainState +from sbx.dqn import DQN +from sbx.dqn.policies import CNNPolicy, DQNPolicy + + +class PERDQN(DQN): + """ + DQN with Prioritized Experience Replay (PER). 
+ """ + + policy_aliases: ClassVar[Dict[str, Type[DQNPolicy]]] = { # type: ignore[assignment] + "MlpPolicy": DQNPolicy, + "CnnPolicy": CNNPolicy, + } + # Linear schedule will be defined in `_setup_model()` + exploration_schedule: Schedule + policy: DQNPolicy + + def __init__( + self, + policy, + env: Union[GymEnv, str], + learning_rate: Union[float, Schedule] = 1e-4, + buffer_size: int = 1_000_000, # 1e6 + learning_starts: int = 100, + batch_size: int = 32, + tau: float = 1.0, + gamma: float = 0.99, + target_update_interval: int = 1000, + exploration_fraction: float = 0.1, + exploration_initial_eps: float = 1.0, + exploration_final_eps: float = 0.05, + optimize_memory_usage: bool = False, # Note: unused but to match SB3 API + # max_grad_norm: float = 10, + train_freq: Union[int, Tuple[int, str]] = 4, + gradient_steps: int = 1, + # replay_buffer_class: Optional[Type[ReplayBuffer]] = None, + replay_buffer_kwargs: Optional[Dict[str, Any]] = None, + tensorboard_log: Optional[str] = None, + policy_kwargs: Optional[Dict[str, Any]] = None, + verbose: int = 0, + seed: Optional[int] = None, + device: str = "auto", + _init_setup_model: bool = True, + ) -> None: + super().__init__( + policy=policy, + env=env, + learning_rate=learning_rate, + buffer_size=buffer_size, + learning_starts=learning_starts, + batch_size=batch_size, + tau=tau, + gamma=gamma, + target_update_interval=target_update_interval, + exploration_fraction=exploration_fraction, + exploration_initial_eps=exploration_initial_eps, + exploration_final_eps=exploration_final_eps, + optimize_memory_usage=optimize_memory_usage, + train_freq=train_freq, + gradient_steps=gradient_steps, + replay_buffer_class=PrioritizedReplayBuffer, + replay_buffer_kwargs=replay_buffer_kwargs, + policy_kwargs=policy_kwargs, + tensorboard_log=tensorboard_log, + verbose=verbose, + seed=seed, + _init_setup_model=_init_setup_model, + ) + + def learn( + self, + total_timesteps: int, + callback: MaybeCallback = None, + log_interval: int = 4, + 
tb_log_name: str = "PERDQN", + reset_num_timesteps: bool = True, + progress_bar: bool = False, + ): + return super().learn( + total_timesteps=total_timesteps, + callback=callback, + log_interval=log_interval, + tb_log_name=tb_log_name, + reset_num_timesteps=reset_num_timesteps, + progress_bar=progress_bar, + ) + + def train(self, batch_size, gradient_steps): + # Sample all at once for efficiency (so we can jit the for loop) + data = self.replay_buffer.sample(batch_size * gradient_steps, env=self._vec_normalize_env) + # Convert to numpy + data = ReplayBufferSamplesNp( + data.observations.numpy(), + # Convert to int64 + data.actions.long().numpy(), + data.next_observations.numpy(), + data.dones.numpy().flatten(), + data.rewards.numpy().flatten(), + data.weights.numpy().flatten(), + data.leaf_nodes_indices, + ) + # Pre compute the slice indices + # otherwise jax will complain + indices = jnp.arange(len(data.dones)).reshape(gradient_steps, batch_size) + + update_carry = { + "qf_state": self.policy.qf_state, + "gamma": self.gamma, + "data": data, + "indices": indices, + "info": { + "critic_loss": jnp.array([0.0]), + "qf_mean_value": jnp.array([0.0]), + "td_error": jnp.zeros_like(data.rewards), + }, + } + + # jit the loop similar to https://github.com/Howuhh/sac-n-jax + # we use scan to be able to play with unroll parameter + update_carry, _ = jax.lax.scan( + self._train, + update_carry, + indices, + unroll=1, + ) + + self.policy.qf_state = update_carry["qf_state"] + qf_loss_value = update_carry["info"]["critic_loss"] + qf_mean_value = update_carry["info"]["qf_mean_value"] / gradient_steps + td_error = update_carry["info"]["td_error"] + + # Update priorities, they will be proportional to the td error + # Note: compared to the original implementation, we update + # the priorities after all the gradient steps + self.replay_buffer.update_priorities(data.leaf_nodes_indices, td_error, self._current_progress_remaining) + + self._n_updates += gradient_steps + 
self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") + self.logger.record("train/critic_loss", qf_loss_value.item()) + self.logger.record("train/qf_mean_value", qf_mean_value.item()) + + @staticmethod + @jax.jit + def update_qnetwork( + gamma: float, + qf_state: RLTrainState, + observations: np.ndarray, + replay_actions: np.ndarray, + next_observations: np.ndarray, + rewards: np.ndarray, + dones: np.ndarray, + sampling_weights: np.ndarray, + ): + # Compute the next Q-values using the target network + qf_next_values = qf_state.apply_fn(qf_state.target_params, next_observations) + + # Follow greedy policy: use the one with the highest value + next_q_values = qf_next_values.max(axis=1) + # Avoid potential broadcast issue + next_q_values = next_q_values.reshape(-1, 1) + + # shape is (batch_size, 1) + target_q_values = rewards.reshape(-1, 1) + (1 - dones.reshape(-1, 1)) * gamma * next_q_values + + # Special case when using PrioritizedReplayBuffer (PER) + def weighted_huber_loss(params): + # Get current Q-values estimates + current_q_values = qf_state.apply_fn(params, observations) + # Retrieve the q-values for the actions from the replay buffer + current_q_values = jnp.take_along_axis(current_q_values, replay_actions, axis=1) + # TD error in absolute value, to update priorities + td_error = jnp.abs(current_q_values - target_q_values) + # Weighted Huber loss using importance sampling weights + loss = (sampling_weights * optax.huber_loss(current_q_values, target_q_values)).mean() + return loss, (current_q_values.mean(), td_error.flatten()) + + (qf_loss_value, (qf_mean_value, td_error)), grads = jax.value_and_grad(weighted_huber_loss, has_aux=True)( + qf_state.params + ) + qf_state = qf_state.apply_gradients(grads=grads) + + return qf_state, (qf_loss_value, qf_mean_value, td_error) + + @staticmethod + @jax.jit + def _train(carry, indices): + data = carry["data"] + + qf_state, (qf_loss_value, qf_mean_value, td_error) = PERDQN.update_qnetwork( + 
carry["gamma"], + carry["qf_state"], + observations=data.observations[indices], + replay_actions=data.actions[indices], + next_observations=data.next_observations[indices], + rewards=data.rewards[indices], + dones=data.dones[indices], + sampling_weights=data.weights[indices], + ) + + carry["qf_state"] = qf_state + carry["info"]["critic_loss"] += qf_loss_value + carry["info"]["qf_mean_value"] += qf_mean_value + carry["info"]["td_error"] = carry["info"]["td_error"].at[indices].set(td_error) + + return carry, None diff --git a/sbx/version.txt b/sbx/version.txt index c5523bd..6633391 100644 --- a/sbx/version.txt +++ b/sbx/version.txt @@ -1 +1 @@ -0.17.0 +0.18.0 diff --git a/tests/test_run.py b/tests/test_run.py index 18d6dec..b7272c7 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -8,7 +8,7 @@ from stable_baselines3.common.envs import BitFlippingEnv from stable_baselines3.common.evaluation import evaluate_policy -from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ, DroQ +from sbx import DDPG, DQN, PERDQN, PPO, SAC, TD3, TQC, CrossQ, DroQ def check_save_load(model, model_class, tmp_path): @@ -116,9 +116,9 @@ def test_dropout(model_class): model.learn(110) -@pytest.mark.parametrize("model_class", [SAC, TD3, DDPG, DQN, CrossQ]) +@pytest.mark.parametrize("model_class", [SAC, TD3, DDPG, DQN, PERDQN, CrossQ]) def test_policy_kwargs(model_class) -> None: - env_id = "CartPole-v1" if model_class == DQN else "Pendulum-v1" + env_id = "CartPole-v1" if model_class in [DQN, PERDQN] else "Pendulum-v1" model = model_class( "MlpPolicy", @@ -147,8 +147,9 @@ def test_ppo(tmp_path, env_id: str) -> None: check_save_load(model, PPO, tmp_path) -def test_dqn(tmp_path) -> None: - model = DQN( +@pytest.mark.parametrize("model_class", [DQN, PERDQN]) +def test_dqn(tmp_path, model_class) -> None: + model = model_class( "MlpPolicy", "CartPole-v1", verbose=1, @@ -156,7 +157,7 @@ def test_dqn(tmp_path) -> None: target_update_interval=10, ) model.learn(128) - check_save_load(model, DQN, 
tmp_path) + check_save_load(model, model_class, tmp_path) @pytest.mark.parametrize("replay_buffer_class", [None, HerReplayBuffer]) From 473e264097f66733f7185234192519a77a7e5ded Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Fri, 12 Jul 2024 17:48:36 +0200 Subject: [PATCH 02/15] Convert SumTree to jax --- sbx/common/prioritized_replay_buffer.py | 346 ++++++++++++++++++++++++ sbx/per_dqn/per_dqn.py | 2 +- 2 files changed, 347 insertions(+), 1 deletion(-) create mode 100644 sbx/common/prioritized_replay_buffer.py diff --git a/sbx/common/prioritized_replay_buffer.py b/sbx/common/prioritized_replay_buffer.py new file mode 100644 index 0000000..cf77313 --- /dev/null +++ b/sbx/common/prioritized_replay_buffer.py @@ -0,0 +1,346 @@ +# from functools import partial +from typing import Any, Dict, List, Optional, Tuple, Union + +import jax +import jax.numpy as jnp +import numpy as np +import torch as th +from gymnasium import spaces +from stable_baselines3.common.buffers import ReplayBuffer +from stable_baselines3.common.type_aliases import ReplayBufferSamples +from stable_baselines3.common.utils import get_linear_fn +from stable_baselines3.common.vec_env.vec_normalize import VecNormalize + + +class SumTree: + """ + SumTree: a binary tree data structure where the parent's value is the sum of its children. 
+ """ + + def __init__(self, buffer_size: int, rng_key: jax.Array): + self.buffer_size = buffer_size + self.tree = jnp.zeros(2 * buffer_size - 1) + self.data = jnp.zeros(buffer_size, dtype=jnp.float32) + self.size = 0 + self.key = rng_key + + @staticmethod + # TODO: try forcing on cpu + # partial(jax.jit, backend="cpu") + @jax.jit + def _add( + tree: jnp.ndarray, + data: jnp.ndarray, + size: int, + capacity: int, + priority: float, + new_data: int, + ) -> Tuple[jnp.ndarray, jnp.ndarray, int]: + index = size + capacity - 1 + data = data.at[size].set(new_data) + tree = SumTree._update(tree, index, priority) + size += 1 + return tree, data, size + + def add(self, priority: float, new_data: int) -> None: + """ + Add a new transition with priority value, + it adds a new leaf node and update cumulative sum. + + :param priority: Priority value. + :param new_data: Data for the new leaf node, storing transition index + in the case of the prioritized replay buffer. + """ + self.tree, self.data, self.size = self._add(self.tree, self.data, self.size, self.buffer_size, priority, new_data) + + @staticmethod + @jax.jit + def _update(tree: jnp.ndarray, index: int, priority: float) -> jnp.ndarray: + change = priority - tree[index] + tree = tree.at[index].set(priority) + tree = SumTree._propagate(tree, index, change) + return tree + + def update(self, leaf_node_idx: int, priority: float) -> None: + self.tree = self._update(self.tree, leaf_node_idx, priority) + + @staticmethod + @jax.jit + def _propagate(tree: jnp.ndarray, index: int, change: float) -> jnp.ndarray: + def cond_fun(val) -> bool: + idx, _, _ = val + return idx > 0 + + def body_fun(val) -> Tuple[int, float, jnp.ndarray]: + idx, change, tree = val + parent = (idx - 1) // 2 + tree = tree.at[parent].add(change) + return parent, change, tree + + _, _, tree = jax.lax.while_loop(cond_fun, body_fun, (index, change, tree)) + return tree + + @property + def total_sum(self) -> float: + return self.tree[0].item() + + @staticmethod 
+ @jax.jit + def _get( + tree: jnp.ndarray, + data: jnp.ndarray, + capacity: int, + priority_sum: float, + ) -> Tuple[int, jnp.ndarray, jnp.ndarray]: + index = SumTree._retrieve(tree, priority_sum) + data_index = index - capacity + 1 + return index, tree[index], data[data_index] + + def get(self, cumulative_sum: float) -> Tuple[int, float, int]: + """ + Get a leaf node index, its priority value and transition index by cumulative_sum value. + + :param cumulative_sum: Cumulative sum value. + :return: Leaf node index, its priority value and transition index. + """ + leaf_tree_index, priority, transition_index = self._get(self.tree, self.data, self.buffer_size, cumulative_sum) + return leaf_tree_index, priority.item(), transition_index.item() + + @staticmethod + @jax.jit + def _retrieve(tree: jnp.ndarray, priority_sum: float) -> int: + def cond_fun(args) -> bool: + idx, _ = args + left = 2 * idx + 1 + return left < len(tree) + + def body_fun(args) -> Tuple[int, float]: + idx, priority_sum = args + left = 2 * idx + 1 + right = left + 1 + + def left_branch(_) -> Tuple[int, float]: + return left, priority_sum + + def right_branch(_) -> Tuple[int, float]: + return right, priority_sum - tree[left] + + idx, priority_sum = jax.lax.cond(priority_sum <= tree[left], left_branch, right_branch, None) + return idx, priority_sum + + index, _ = jax.lax.while_loop(cond_fun, body_fun, (0, priority_sum)) + return index + + # FIXME: not working yet + # @staticmethod + # def _stratified_sampling( + # tree: jnp.ndarray, + # data: jnp.ndarray, + # buffer_size: int, + # batch_size: int, + # total_sum: float, + # rng_key: jax.Array, + # ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jax.Array]: + # sample_indices = jnp.zeros(batch_size, dtype=jnp.uint32) + # priorities = jnp.zeros((batch_size, 1)) + # leaf_nodes_indices = jnp.zeros(batch_size, dtype=jnp.uint32) + + # # Using jax.lax.fori_loop to parallelize the sampling + # def body_fun(batch_idx, args): + # sample_indices, priorities, 
leaf_nodes_indices, rng_key = args + # segment_size = total_sum / batch_size + # start, end = segment_size * batch_idx, segment_size * (batch_idx + 1) + # cumulative_sum = jax.random.uniform(rng_key, minval=start, maxval=end) + + # leaf_node_idx, priority, sample_idx = SumTree._get(tree, data, buffer_size, cumulative_sum) + + # leaf_nodes_indices = leaf_nodes_indices.at[batch_idx].set(leaf_node_idx) + # priorities = priorities.at[batch_idx].set(priority.item()) + # sample_indices = sample_indices.at[batch_idx].set(sample_idx.item()) + # return sample_indices, priorities, leaf_nodes_indices, rng_key + + # sample_indices, priorities, leaf_nodes_indices, rng_key = jax.lax.fori_loop( + # 0, batch_size, body_fun, (sample_indices, priorities, leaf_nodes_indices, rng_key) + # ) + # return sample_indices, priorities, leaf_nodes_indices, rng_key + + # def stratified_sampling(self, batch_size: int) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + # sample_indices, priorities, leaf_nodes_indices, self.key = self._stratified_sampling( + # self.tree, + # self.data, + # self.buffer_size, + # batch_size, + # self.total_sum, + # self.key, + # ) + # return sample_indices, priorities, leaf_nodes_indices + + +class PrioritizedReplayBuffer(ReplayBuffer): + """ + Prioritized Replay Buffer (proportional priorities version). + Paper: https://arxiv.org/abs/1511.05952 + This code is inspired by: https://github.com/Howuhh/prioritized_experience_replay + + :param buffer_size: Max number of element in the buffer + :param observation_space: Observation space + :param action_space: Action space + :param device: PyTorch device + :param n_envs: Number of parallel environments + :param alpha: How much prioritization is used (0 - no prioritization aka uniform case, 1 - full prioritization) + :param beta: To what degree to use importance weights (0 - no corrections, 1 - full correction) + :param final_beta: Value of beta at the end of training. 
+ Linear annealing is used to interpolate between initial value of beta and final beta. + :param min_priority: Minimum priority, prevents zero probabilities, so that all samples + always have a non-zero probability to be sampled. + """ + + def __init__( + self, + buffer_size: int, + observation_space: spaces.Space, + action_space: spaces.Space, + device: Union[th.device, str] = "auto", + n_envs: int = 1, + alpha: float = 0.5, + beta: float = 0.4, + final_beta: float = 1.0, + optimize_memory_usage: bool = False, + min_priority: float = 1e-6, + ): + super().__init__(buffer_size, observation_space, action_space, device, n_envs) + + assert optimize_memory_usage is False, "PrioritizedReplayBuffer doesn't support optimize_memory_usage=True" + + self.min_priority = min_priority + self.alpha = alpha + self.max_priority = self.min_priority # priority for new samples, init as eps + # Track the training progress remaining (from 1 to 0) + # this is used to update beta + self._current_progress_remaining = 1.0 + self.inital_beta = beta + self.final_beta = final_beta + self.beta_schedule = get_linear_fn( + self.inital_beta, + self.final_beta, + end_fraction=1.0, + ) + # SumTree: data structure to store priorities + self.tree = SumTree(buffer_size=buffer_size, rng_key=jax.random.PRNGKey(0)) + + @property + def beta(self) -> float: + # Linear schedule + return self.beta_schedule(self._current_progress_remaining) + + def add( + self, + obs: np.ndarray, + next_obs: np.ndarray, + action: np.ndarray, + reward: np.ndarray, + done: np.ndarray, + infos: List[Dict[str, Any]], + ) -> None: + """ + Add a new transition to the buffer. + + :param obs: Starting observation of the transition to be stored. + :param next_obs: Destination observation of the transition to be stored. + :param action: Action performed in the transition to be stored. + :param reward: Reward received in the transition to be stored. + :param done: Whether the episode was finished after the transition to be stored. 
+ :param infos: Eventual information given by the environment. + """ + # store transition index with maximum priority in sum tree + self.tree.add(self.max_priority, self.pos) + + # store transition in the buffer + super().add(obs, next_obs, action, reward, done, infos) + + def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayBufferSamples: + """ + Sample elements from the prioritized replay buffer. + + :param batch_size: Number of elements to sample + :param env: associated gym VecEnv + to normalize the observations/rewards when sampling + :return: a batch of sampled experiences from the buffer. + """ + assert self.buffer_size >= batch_size, "The buffer contains less samples than the batch size requires." + + leaf_nodes_indices = np.zeros(batch_size, dtype=np.uint32) + priorities = np.zeros((batch_size, 1)) + sample_indices = np.zeros(batch_size, dtype=np.uint32) + + # To sample a minibatch of size k, the range [0, total_sum] is divided equally into k ranges. + # Next, a value is uniformly sampled from each range. Finally the transitions that correspond + # to each of these sampled values are retrieved from the tree.
+ segment_size = self.tree.total_sum / batch_size + for batch_idx in range(batch_size): + # extremes of the current segment + start, end = segment_size * batch_idx, segment_size * (batch_idx + 1) + + # uniformly sample a value from the current segment + cumulative_sum = np.random.uniform(start, end) + + # leaf_node_idx is an index of a sample in the tree, needed further to update priorities + # sample_idx is a sample index in buffer, needed further to sample actual transitions + leaf_node_idx, priority, sample_idx = self.tree.get(cumulative_sum) + + leaf_nodes_indices[batch_idx] = leaf_node_idx + priorities[batch_idx] = priority + sample_indices[batch_idx] = sample_idx + + # sample_indices, priorities, leaf_nodes_indices = self.tree.stratified_sampling(batch_size) + + # probability of sampling transition i as P(i) = p_i^alpha / \sum_{k} p_k^alpha + # where p_i > 0 is the priority of transition i. + probs = priorities / self.tree.total_sum + + # Importance sampling weights. + # All weights w_i were scaled so that max_i w_i = 1.
+ weights = (self.size() * probs + 1e-7) ** -self.beta + weights = weights / weights.max() + + # TODO: add proper support for multi env + # env_indices = np.random.randint(0, high=self.n_envs, size=(batch_size,)) + env_indices = np.zeros(batch_size, dtype=np.uint32) + + if self.optimize_memory_usage: + next_obs = self._normalize_obs(self.observations[(sample_indices + 1) % self.buffer_size, env_indices, :], env) + else: + next_obs = self._normalize_obs(self.next_observations[sample_indices, env_indices, :], env) + + batch = ( + self._normalize_obs(self.observations[sample_indices, env_indices, :], env), + self.actions[sample_indices, env_indices, :], + next_obs, + self.dones[sample_indices], + self.rewards[sample_indices], + weights, + ) + return ReplayBufferSamples(*tuple(map(self.to_torch, batch)), leaf_nodes_indices) # type: ignore[arg-type,call-arg] + + def update_priorities(self, leaf_nodes_indices: np.ndarray, td_errors: np.ndarray, progress_remaining: float) -> None: + """ + Update transition priorities. + + :param leaf_nodes_indices: Indices for the leaf nodes to update + (corresponding to the transitions) + :param td_errors: New priorities, td error in the case of + proportional prioritized replay buffer. + :param progress_remaining: Current progress remaining (starts from 1 and ends at 0) + to linearly anneal beta from its start value to 1.0 at the end of training + """ + # Update beta schedule + self._current_progress_remaining = progress_remaining + + for leaf_node_idx, td_error in zip(leaf_nodes_indices, td_errors): + # Proportional prioritization priority = (abs(td_error) + eps) ^ alpha + # where eps is a small positive constant that prevents the edge-case of transitions not being + # revisited once their error is zero.
(Section 3.3) + priority = (abs(td_error) + self.min_priority) ** self.alpha + self.tree.update(leaf_node_idx, priority) + # Update max priority for new samples + self.max_priority = max(self.max_priority, priority) diff --git a/sbx/per_dqn/per_dqn.py b/sbx/per_dqn/per_dqn.py index dab07e3..d45299b 100644 --- a/sbx/per_dqn/per_dqn.py +++ b/sbx/per_dqn/per_dqn.py @@ -4,9 +4,9 @@ import jax.numpy as jnp import numpy as np import optax -from stable_baselines3.common.prioritized_replay_buffer import PrioritizedReplayBuffer from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule +from sbx.common.prioritized_replay_buffer import PrioritizedReplayBuffer from sbx.common.type_aliases import ReplayBufferSamplesNp, RLTrainState from sbx.dqn import DQN from sbx.dqn.policies import CNNPolicy, DQNPolicy From 79e271026a14ed88cf5cb37b1079549ce495b476 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Fri, 12 Jul 2024 18:23:34 +0200 Subject: [PATCH 03/15] Try to add batch update --- sbx/common/prioritized_replay_buffer.py | 155 ++++++++++++------------ 1 file changed, 76 insertions(+), 79 deletions(-) diff --git a/sbx/common/prioritized_replay_buffer.py b/sbx/common/prioritized_replay_buffer.py index cf77313..c2d2477 100644 --- a/sbx/common/prioritized_replay_buffer.py +++ b/sbx/common/prioritized_replay_buffer.py @@ -1,4 +1,4 @@ -# from functools import partial +from functools import partial from typing import Any, Dict, List, Optional, Tuple, Union import jax @@ -17,12 +17,11 @@ class SumTree: SumTree: a binary tree data structure where the parent's value is the sum of its children. 
""" - def __init__(self, buffer_size: int, rng_key: jax.Array): + def __init__(self, buffer_size: int): self.buffer_size = buffer_size self.tree = jnp.zeros(2 * buffer_size - 1) self.data = jnp.zeros(buffer_size, dtype=jnp.float32) self.size = 0 - self.key = rng_key @staticmethod # TODO: try forcing on cpu @@ -131,49 +130,73 @@ def right_branch(_) -> Tuple[int, float]: index, _ = jax.lax.while_loop(cond_fun, body_fun, (0, priority_sum)) return index - # FIXME: not working yet - # @staticmethod - # def _stratified_sampling( - # tree: jnp.ndarray, - # data: jnp.ndarray, - # buffer_size: int, - # batch_size: int, - # total_sum: float, - # rng_key: jax.Array, - # ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jax.Array]: - # sample_indices = jnp.zeros(batch_size, dtype=jnp.uint32) - # priorities = jnp.zeros((batch_size, 1)) - # leaf_nodes_indices = jnp.zeros(batch_size, dtype=jnp.uint32) - - # # Using jax.lax.fori_loop to parallelize the sampling - # def body_fun(batch_idx, args): - # sample_indices, priorities, leaf_nodes_indices, rng_key = args - # segment_size = total_sum / batch_size - # start, end = segment_size * batch_idx, segment_size * (batch_idx + 1) - # cumulative_sum = jax.random.uniform(rng_key, minval=start, maxval=end) - - # leaf_node_idx, priority, sample_idx = SumTree._get(tree, data, buffer_size, cumulative_sum) - - # leaf_nodes_indices = leaf_nodes_indices.at[batch_idx].set(leaf_node_idx) - # priorities = priorities.at[batch_idx].set(priority.item()) - # sample_indices = sample_indices.at[batch_idx].set(sample_idx.item()) - # return sample_indices, priorities, leaf_nodes_indices, rng_key - - # sample_indices, priorities, leaf_nodes_indices, rng_key = jax.lax.fori_loop( - # 0, batch_size, body_fun, (sample_indices, priorities, leaf_nodes_indices, rng_key) - # ) - # return sample_indices, priorities, leaf_nodes_indices, rng_key - - # def stratified_sampling(self, batch_size: int) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: - # sample_indices, 
priorities, leaf_nodes_indices, self.key = self._stratified_sampling( - # self.tree, - # self.data, - # self.buffer_size, - # batch_size, - # self.total_sum, - # self.key, - # ) - # return sample_indices, priorities, leaf_nodes_indices + @staticmethod + @jax.jit + def _batch_update( + tree: jnp.ndarray, + leaf_nodes_indices: jnp.ndarray, + priorities: jnp.ndarray, + ) -> jnp.ndarray: + for leaf_node_idx, priority in zip(leaf_nodes_indices, priorities): + tree = SumTree._update(tree, leaf_node_idx, priority) + return tree + + def batch_update(self, leaf_nodes_indices: np.ndarray, priorities: np.ndarray) -> None: + """ + Batch update transition priorities. + + :param leaf_nodes_indices: Indices for the leaf nodes to update + (correponding to the transitions) + :param priorities: New priorities, td error in the case of + proportional prioritized replay buffer. + """ + self.tree = self._batch_update(self.tree, leaf_nodes_indices, priorities) + + partial(jax.jit, backend="cpu", static_argnums=(4,)) + @staticmethod + def _stratified_sampling( + tree: jnp.ndarray, + data: jnp.ndarray, + capacity: int, + cumulative_sums: jnp.ndarray, + batch_size: int, + ): + leaf_nodes_indices = jnp.zeros(batch_size, dtype=jnp.uint32) + priorities = jnp.zeros(batch_size, dtype=jnp.float32) + sample_indices = jnp.zeros(batch_size, dtype=jnp.uint32) + + # Using jax.lax.fori_loop to avoid the need for a static loop + def body_fun(i, val): + leaf_nodes_indices, priorities, sample_indices, cumulative_sums = val + leaf_node_idx, priority, transition_index = SumTree._get(tree, data, capacity, cumulative_sums[i]) + leaf_nodes_indices = leaf_nodes_indices.at[i].set(leaf_node_idx) + priorities = priorities.at[i].set(priority) + sample_indices = sample_indices.at[i].set(transition_index) + return leaf_nodes_indices, priorities, sample_indices, cumulative_sums + + leaf_nodes_indices, priorities, sample_indices, _ = jax.lax.fori_loop( + 0, + batch_size, + body_fun, + (leaf_nodes_indices, priorities, 
sample_indices, cumulative_sums), + ) + return leaf_nodes_indices, priorities, sample_indices + + def stratified_sampling(self, batch_size: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Batch stratified sampling of transitions. + + :param batch_size: Number of transitions to sample + :return: Tuple of leaf nodes indices, priorities and sample indices. + """ + segment_size = self.total_sum / batch_size + starts = np.arange(batch_size) * segment_size + ends = starts + segment_size + cumulative_sums = np.random.uniform(starts, ends) + leaf_nodes_indices, priorities, sample_indices = self._stratified_sampling( + self.tree, self.data, self.buffer_size, cumulative_sums, batch_size + ) + return leaf_nodes_indices, priorities, sample_indices class PrioritizedReplayBuffer(ReplayBuffer): @@ -226,7 +249,7 @@ def __init__( end_fraction=1.0, ) # SumTree: data structure to store priorities - self.tree = SumTree(buffer_size=buffer_size, rng_key=jax.random.PRNGKey(0)) + self.tree = SumTree(buffer_size=buffer_size) @property def beta(self) -> float: @@ -269,30 +292,7 @@ def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayB """ assert self.buffer_size >= batch_size, "The buffer contains less samples than the batch size requires." - leaf_nodes_indices = np.zeros(batch_size, dtype=np.uint32) - priorities = np.zeros((batch_size, 1)) - sample_indices = np.zeros(batch_size, dtype=np.uint32) - - # To sample a minibatch of size k, the range [0, total_sum] is divided equally into k ranges. - # Next, a value is uniformly sampled from each range. Finally the transitions that correspond - # to each of these sampled values are retrieved from the tree. 
- segment_size = self.tree.total_sum / batch_size - for batch_idx in range(batch_size): - # extremes of the current segment - start, end = segment_size * batch_idx, segment_size * (batch_idx + 1) - - # uniformely sample a value from the current segment - cumulative_sum = np.random.uniform(start, end) - - # leaf_node_idx is a index of a sample in the tree, needed further to update priorities - # sample_idx is a sample index in buffer, needed further to sample actual transitions - leaf_node_idx, priority, sample_idx = self.tree.get(cumulative_sum) - - leaf_nodes_indices[batch_idx] = leaf_node_idx - priorities[batch_idx] = priority - sample_indices[batch_idx] = sample_idx - - # sample_indices, priorities, leaf_nodes_indices = self.tree.stratified_sampling(batch_size) + sample_indices, priorities, leaf_nodes_indices = self.tree.stratified_sampling(batch_size) # probability of sampling transition i as P(i) = p_i^alpha / \sum_{k} p_k^alpha # where p_i > 0 is the priority of transition i. @@ -336,11 +336,8 @@ def update_priorities(self, leaf_nodes_indices: np.ndarray, td_errors: np.ndarra # Update beta schedule self._current_progress_remaining = progress_remaining - for leaf_node_idx, td_error in zip(leaf_nodes_indices, td_errors): - # Proportional prioritization priority = (abs(td_error) + eps) ^ alpha - # where eps is a small positive constant that prevents the edge-case of transitions not being - # revisited once their error is zero. 
(Section 3.3) - priority = (abs(td_error) + self.min_priority) ** self.alpha - self.tree.update(leaf_node_idx, priority) - # Update max priority for new samples - self.max_priority = max(self.max_priority, priority) + # Batch update + priorities = (np.abs(td_errors) + self.min_priority) ** self.alpha + self.tree.batch_update(leaf_nodes_indices, priorities) + # Update max priority for new samples + self.max_priority = max(self.max_priority, priorities.max()) From 20bc97dae4621899225d6feef480bf19838a6445 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Fri, 12 Jul 2024 18:25:20 +0200 Subject: [PATCH 04/15] Revert batch sampling --- sbx/common/prioritized_replay_buffer.py | 78 +++++++++---------------- 1 file changed, 28 insertions(+), 50 deletions(-) diff --git a/sbx/common/prioritized_replay_buffer.py b/sbx/common/prioritized_replay_buffer.py index c2d2477..1e7b6e7 100644 --- a/sbx/common/prioritized_replay_buffer.py +++ b/sbx/common/prioritized_replay_buffer.py @@ -1,4 +1,4 @@ -from functools import partial +# from functools import partial from typing import Any, Dict, List, Optional, Tuple, Union import jax @@ -17,11 +17,12 @@ class SumTree: SumTree: a binary tree data structure where the parent's value is the sum of its children. 
""" - def __init__(self, buffer_size: int): + def __init__(self, buffer_size: int, rng_key: jax.Array): self.buffer_size = buffer_size self.tree = jnp.zeros(2 * buffer_size - 1) self.data = jnp.zeros(buffer_size, dtype=jnp.float32) self.size = 0 + self.key = rng_key @staticmethod # TODO: try forcing on cpu @@ -152,52 +153,6 @@ def batch_update(self, leaf_nodes_indices: np.ndarray, priorities: np.ndarray) - """ self.tree = self._batch_update(self.tree, leaf_nodes_indices, priorities) - partial(jax.jit, backend="cpu", static_argnums=(4,)) - @staticmethod - def _stratified_sampling( - tree: jnp.ndarray, - data: jnp.ndarray, - capacity: int, - cumulative_sums: jnp.ndarray, - batch_size: int, - ): - leaf_nodes_indices = jnp.zeros(batch_size, dtype=jnp.uint32) - priorities = jnp.zeros(batch_size, dtype=jnp.float32) - sample_indices = jnp.zeros(batch_size, dtype=jnp.uint32) - - # Using jax.lax.fori_loop to avoid the need for a static loop - def body_fun(i, val): - leaf_nodes_indices, priorities, sample_indices, cumulative_sums = val - leaf_node_idx, priority, transition_index = SumTree._get(tree, data, capacity, cumulative_sums[i]) - leaf_nodes_indices = leaf_nodes_indices.at[i].set(leaf_node_idx) - priorities = priorities.at[i].set(priority) - sample_indices = sample_indices.at[i].set(transition_index) - return leaf_nodes_indices, priorities, sample_indices, cumulative_sums - - leaf_nodes_indices, priorities, sample_indices, _ = jax.lax.fori_loop( - 0, - batch_size, - body_fun, - (leaf_nodes_indices, priorities, sample_indices, cumulative_sums), - ) - return leaf_nodes_indices, priorities, sample_indices - - def stratified_sampling(self, batch_size: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - """ - Batch stratified sampling of transitions. - - :param batch_size: Number of transitions to sample - :return: Tuple of leaf nodes indices, priorities and sample indices. 
- """ - segment_size = self.total_sum / batch_size - starts = np.arange(batch_size) * segment_size - ends = starts + segment_size - cumulative_sums = np.random.uniform(starts, ends) - leaf_nodes_indices, priorities, sample_indices = self._stratified_sampling( - self.tree, self.data, self.buffer_size, cumulative_sums, batch_size - ) - return leaf_nodes_indices, priorities, sample_indices - class PrioritizedReplayBuffer(ReplayBuffer): """ @@ -249,7 +204,7 @@ def __init__( end_fraction=1.0, ) # SumTree: data structure to store priorities - self.tree = SumTree(buffer_size=buffer_size) + self.tree = SumTree(buffer_size=buffer_size, rng_key=jax.random.PRNGKey(0)) @property def beta(self) -> float: @@ -292,7 +247,30 @@ def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayB """ assert self.buffer_size >= batch_size, "The buffer contains less samples than the batch size requires." - sample_indices, priorities, leaf_nodes_indices = self.tree.stratified_sampling(batch_size) + leaf_nodes_indices = np.zeros(batch_size, dtype=np.uint32) + priorities = np.zeros((batch_size, 1)) + sample_indices = np.zeros(batch_size, dtype=np.uint32) + + # To sample a minibatch of size k, the range [0, total_sum] is divided equally into k ranges. + # Next, a value is uniformly sampled from each range. Finally the transitions that correspond + # to each of these sampled values are retrieved from the tree. 
+ segment_size = self.tree.total_sum / batch_size + for batch_idx in range(batch_size): + # extremes of the current segment + start, end = segment_size * batch_idx, segment_size * (batch_idx + 1) + + # uniformely sample a value from the current segment + cumulative_sum = np.random.uniform(start, end) + + # leaf_node_idx is a index of a sample in the tree, needed further to update priorities + # sample_idx is a sample index in buffer, needed further to sample actual transitions + leaf_node_idx, priority, sample_idx = self.tree.get(cumulative_sum) + + leaf_nodes_indices[batch_idx] = leaf_node_idx + priorities[batch_idx] = priority + sample_indices[batch_idx] = sample_idx + + # sample_indices, priorities, leaf_nodes_indices = self.tree.stratified_sampling(batch_size) # probability of sampling transition i as P(i) = p_i^alpha / \sum_{k} p_k^alpha # where p_i > 0 is the priority of transition i. From b643dc13ee493d88028ee3bf34f7179af136c65f Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 15 Jul 2024 22:17:21 +0200 Subject: [PATCH 05/15] Add more type hints --- sbx/dqn/dqn.py | 2 +- sbx/per_dqn/per_dqn.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sbx/dqn/dqn.py b/sbx/dqn/dqn.py index 06eb40e..cbc5fea 100644 --- a/sbx/dqn/dqn.py +++ b/sbx/dqn/dqn.py @@ -135,7 +135,7 @@ def learn( progress_bar=progress_bar, ) - def train(self, batch_size, gradient_steps): + def train(self, batch_size: int, gradient_steps: int) -> None: # Sample all at once for efficiency (so we can jit the for loop) data = self.replay_buffer.sample(batch_size * gradient_steps, env=self._vec_normalize_env) # Convert to numpy diff --git a/sbx/per_dqn/per_dqn.py b/sbx/per_dqn/per_dqn.py index d45299b..7f5ae9a 100644 --- a/sbx/per_dqn/per_dqn.py +++ b/sbx/per_dqn/per_dqn.py @@ -95,7 +95,7 @@ def learn( progress_bar=progress_bar, ) - def train(self, batch_size, gradient_steps): + def train(self, batch_size: int, gradient_steps: int) -> None: # Sample all at once for 
efficiency (so we can jit the for loop) data = self.replay_buffer.sample(batch_size * gradient_steps, env=self._vec_normalize_env) # Convert to numpy From 07d6f27c871b0f7bfcf15c01382061ba8a9923e8 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 15 Jul 2024 22:17:29 +0200 Subject: [PATCH 06/15] Use segment tree implementation from SB2 --- sbx/common/prioritized_replay_buffer.py | 419 ++++++++++++++---------- 1 file changed, 240 insertions(+), 179 deletions(-) diff --git a/sbx/common/prioritized_replay_buffer.py b/sbx/common/prioritized_replay_buffer.py index 1e7b6e7..ab81231 100644 --- a/sbx/common/prioritized_replay_buffer.py +++ b/sbx/common/prioritized_replay_buffer.py @@ -1,8 +1,14 @@ -# from functools import partial -from typing import Any, Dict, List, Optional, Tuple, Union +""" +Segment tree implementation taken from Stable Baselines 2: +https://github.com/hill-a/stable-baselines/blob/master/stable_baselines/common/segment_tree.py + +Notable differences: +- This implementation uses numpy arrays to store the values (faster initialization) +- We don't use a special function to have unique indices (no significant performance difference found) +""" + +from typing import Any, Callable, Dict, List, Optional, Union -import jax -import jax.numpy as jnp import numpy as np import torch as th from gymnasium import spaces @@ -12,146 +18,186 @@ from stable_baselines3.common.vec_env.vec_normalize import VecNormalize -class SumTree: +class SegmentTree: + def __init__(self, capacity: int, reduce_op: Callable, neutral_element: float) -> None: + """ + Build a Segment Tree data structure. + + https://en.wikipedia.org/wiki/Segment_tree + + Can be used as regular array that supports Index arrays, but with two + important differences: + + a) setting item's value is slightly slower. + It is O(log capacity) instead of O(1). 
+ b) user has access to an efficient ( O(log segment size) ) + `reduce` operation which reduces `operation` over + a contiguous subsequence of items in the array. + + :param capacity: Total size of the array - must be a power of two. + :param reduce_op: Operation for combining elements (eg. sum, max) must form a + mathematical group together with the set of possible values for array elements (i.e. be associative) + :param neutral_element: Neutral element for the operation above. eg. float('-inf') for max and 0 for sum. + """ + assert capacity > 0 and capacity & (capacity - 1) == 0, f"Capacity must be positive and a power of 2, not {capacity}" + self._capacity = capacity + self._values = np.full(2 * capacity, neutral_element) + self._reduce_op = reduce_op + self.neutral_element = neutral_element + + def _reduce_helper(self, start: int, end: int, node: int, node_start: int, node_end: int) -> float: + """ + Query the value of the segment tree for the given range + + :param start: start of the range + :param end: end of the range + :param node: current node in the segment tree + :param node_start: start of the range represented by the current node + :param node_end: end of the range represented by the current node + :return: result of reducing ``self.reduce_op`` over the specified range of array elements. + """ + if start == node_start and end == node_end: + return self._values[node] + mid = (node_start + node_end) // 2 + if end <= mid: + return self._reduce_helper(start, end, 2 * node, node_start, mid) + else: + if mid + 1 <= start: + return self._reduce_helper(start, end, 2 * node + 1, mid + 1, node_end) + else: + return self._reduce_op( + self._reduce_helper(start, mid, 2 * node, node_start, mid), + self._reduce_helper(mid + 1, end, 2 * node + 1, mid + 1, node_end), + ) + + def reduce(self, start: int = 0, end: Optional[int] = None) -> float: + """ + Returns result of applying ``self.reduce_op`` + to a contiguous subsequence of the array. + + .. 
code-block:: python + + self.reduce_op(arr[start], operation(arr[start+1], operation(... arr[end]))) + + :param start: beginning of the subsequence + :param end: end of the subsequences + :return: result of reducing ``self.reduce_op`` over the specified range of array elements. + """ + if end is None: + end = self._capacity + if end < 0: + end += self._capacity + end -= 1 + return self._reduce_helper(start, end, 1, 0, self._capacity - 1) + + def __setitem__(self, idx: int, val: float) -> None: + """ + Set the value at index `idx` to `val` + + :param idx: index of the value to be updated + :param val: new value + """ + # Indices of the leafs + indices = idx + self._capacity + self._values[indices] = val + if isinstance(indices, int): + indices = np.array([indices]) + # Go up one level in the tree and remove duplicate indices + indices = np.unique(indices // 2) + while len(indices) > 1 or indices[0] > 0: + # As long as there are non-zero indices, update the corresponding values + self._values[indices] = self._reduce_op(self._values[2 * indices], self._values[2 * indices + 1]) + # Go up one level in the tree and remove duplicate indices + indices = np.unique(indices // 2) + + def __getitem__(self, idx: np.ndarray) -> np.ndarray: + """ + Get the value(s) at index `idx` + """ + assert np.max(idx) < self._capacity, f"Index must be less than capacity, got {np.max(idx)} >= {self._capacity}" + assert 0 <= np.min(idx) + return self._values[self._capacity + idx] + + +class SumSegmentTree(SegmentTree): """ - SumTree: a binary tree data structure where the parent's value is the sum of its children. + A Segment Tree data structure where each node contains the sum of the + values in its leaf nodes. Can be used as a Sum Tree for priorities. 
""" - def __init__(self, buffer_size: int, rng_key: jax.Array): - self.buffer_size = buffer_size - self.tree = jnp.zeros(2 * buffer_size - 1) - self.data = jnp.zeros(buffer_size, dtype=jnp.float32) - self.size = 0 - self.key = rng_key - - @staticmethod - # TODO: try forcing on cpu - # partial(jax.jit, backend="cpu") - @jax.jit - def _add( - tree: jnp.ndarray, - data: jnp.ndarray, - size: int, - capacity: int, - priority: float, - new_data: int, - ) -> Tuple[jnp.ndarray, jnp.ndarray, int]: - index = size + capacity - 1 - data = data.at[size].set(new_data) - tree = SumTree._update(tree, index, priority) - size += 1 - return tree, data, size - - def add(self, priority: float, new_data: int) -> None: + def __init__(self, capacity: int) -> None: + super().__init__(capacity=capacity, reduce_op=np.add, neutral_element=0.0) + + def sum(self, start: int = 0, end: Optional[int] = None) -> float: """ - Add a new transition with priority value, - it adds a new leaf node and update cumulative sum. + Returns arr[start] + ... + arr[end] - :param priority: Priority value. - :param new_data: Data for the new leaf node, storing transition index - in the case of the prioritized replay buffer. 
+ :param start: start position of the reduction (must be >= 0) + :param end: end position of the reduction (must be < len(arr), can be None for len(arr) - 1) + :return: reduction of SumSegmentTree """ - self.tree, self.data, self.size = self._add(self.tree, self.data, self.size, self.buffer_size, priority, new_data) - - @staticmethod - @jax.jit - def _update(tree: jnp.ndarray, index: int, priority: float) -> jnp.ndarray: - change = priority - tree[index] - tree = tree.at[index].set(priority) - tree = SumTree._propagate(tree, index, change) - return tree - - def update(self, leaf_node_idx: int, priority: float) -> None: - self.tree = self._update(self.tree, leaf_node_idx, priority) - - @staticmethod - @jax.jit - def _propagate(tree: jnp.ndarray, index: int, change: float) -> jnp.ndarray: - def cond_fun(val) -> bool: - idx, _, _ = val - return idx > 0 - - def body_fun(val) -> Tuple[int, float, jnp.ndarray]: - idx, change, tree = val - parent = (idx - 1) // 2 - tree = tree.at[parent].add(change) - return parent, change, tree - - _, _, tree = jax.lax.while_loop(cond_fun, body_fun, (index, change, tree)) - return tree + return super().reduce(start, end) - @property - def total_sum(self) -> float: - return self.tree[0].item() - - @staticmethod - @jax.jit - def _get( - tree: jnp.ndarray, - data: jnp.ndarray, - capacity: int, - priority_sum: float, - ) -> Tuple[int, jnp.ndarray, jnp.ndarray]: - index = SumTree._retrieve(tree, priority_sum) - data_index = index - capacity + 1 - return index, tree[index], data[data_index] - - def get(self, cumulative_sum: float) -> Tuple[int, float, int]: + def find_prefixsum_idx(self, prefixsum: np.ndarray) -> np.ndarray: """ - Get a leaf node index, its priority value and transition index by cumulative_sum value. + Find the highest index `i` in the array such that + sum(arr[0] + arr[1] + ... 
+ arr[i - i]) <= prefixsum for each entry in prefixsum + + if array values are probabilities, this function + allows to sample indices according to the discrete + probability efficiently. - :param cumulative_sum: Cumulative sum value. - :return: Leaf node index, its priority value and transition index. + :param prefixsum: float upper bounds on the sum of array prefix + :return: highest indices satisfying the prefixsum constraint """ - leaf_tree_index, priority, transition_index = self._get(self.tree, self.data, self.buffer_size, cumulative_sum) - return leaf_tree_index, priority.item(), transition_index.item() - - @staticmethod - @jax.jit - def _retrieve(tree: jnp.ndarray, priority_sum: float) -> int: - def cond_fun(args) -> bool: - idx, _ = args - left = 2 * idx + 1 - return left < len(tree) - - def body_fun(args) -> Tuple[int, float]: - idx, priority_sum = args - left = 2 * idx + 1 - right = left + 1 - - def left_branch(_) -> Tuple[int, float]: - return left, priority_sum - - def right_branch(_) -> Tuple[int, float]: - return right, priority_sum - tree[left] - - idx, priority_sum = jax.lax.cond(priority_sum <= tree[left], left_branch, right_branch, None) - return idx, priority_sum - - index, _ = jax.lax.while_loop(cond_fun, body_fun, (0, priority_sum)) - return index - - @staticmethod - @jax.jit - def _batch_update( - tree: jnp.ndarray, - leaf_nodes_indices: jnp.ndarray, - priorities: jnp.ndarray, - ) -> jnp.ndarray: - for leaf_node_idx, priority in zip(leaf_nodes_indices, priorities): - tree = SumTree._update(tree, leaf_node_idx, priority) - return tree - - def batch_update(self, leaf_nodes_indices: np.ndarray, priorities: np.ndarray) -> None: + if isinstance(prefixsum, float): + prefixsum = np.array([prefixsum]) + assert 0 <= np.min(prefixsum) + assert np.max(prefixsum) <= self.sum() + 1e-5 + assert isinstance(prefixsum[0], float) + + indices = np.ones(len(prefixsum), dtype=int) + should_continue = np.ones(len(prefixsum), dtype=bool) + + while 
np.any(should_continue): # while not all nodes are leafs + indices[should_continue] = 2 * indices[should_continue] + prefixsum_new = np.where( + self._values[indices] <= prefixsum, + prefixsum - self._values[indices], + prefixsum, + ) + # Prepare update of prefixsum for all right children + indices = np.where( + np.logical_or(self._values[indices] > prefixsum, np.logical_not(should_continue)), + indices, + indices + 1, + ) + # Select child node for non-leaf nodes + prefixsum = prefixsum_new + # Update prefixsum + should_continue = indices < self._capacity + # Collect leafs + return indices - self._capacity + + +class MinSegmentTree(SegmentTree): + """ + A Segment Tree data structure where each node contains the minimum of the + values in its leaf nodes. Can be used as a Min Tree for priorities. + """ + + def __init__(self, capacity: int) -> None: + super().__init__(capacity=capacity, reduce_op=np.minimum, neutral_element=float("inf")) + + def min(self, start=0, end=None): """ - Batch update transition priorities. + Returns min(arr[start], ..., arr[end]) - :param leaf_nodes_indices: Indices for the leaf nodes to update - (correponding to the transitions) - :param priorities: New priorities, td error in the case of - proportional prioritized replay buffer. 
+ :param start: start position of the reduction (must be >= 0) + :param end: end position of the reduction (must be < len(arr), can be None for len(arr) - 1) + :return: reduction of MinSegmentTree """ - self.tree = self._batch_update(self.tree, leaf_nodes_indices, priorities) + return super().reduce(start, end) class PrioritizedReplayBuffer(ReplayBuffer): @@ -188,23 +234,36 @@ def __init__( ): super().__init__(buffer_size, observation_space, action_space, device, n_envs) + # TODO: check if we can support optimize_memory_usage assert optimize_memory_usage is False, "PrioritizedReplayBuffer doesn't support optimize_memory_usage=True" + # TODO: add support for multi env + assert n_envs == 1, "PrioritizedReplayBuffer doesn't support n_envs > 1" + + # Find the next power of 2 for the buffer size + power_of_two = int(np.ceil(np.log2(buffer_size))) + tree_capacity = 2**power_of_two + self.min_priority = min_priority - self.alpha = alpha - self.max_priority = self.min_priority # priority for new samples, init as eps + self._max_priority = 1.0 + + self._alpha = alpha + # Track the training progress remaining (from 1 to 0) # this is used to update beta self._current_progress_remaining = 1.0 - self.inital_beta = beta - self.final_beta = final_beta + + # TODO: move beta schedule to the DQN algorithm + self._inital_beta = beta + self._final_beta = final_beta self.beta_schedule = get_linear_fn( - self.inital_beta, - self.final_beta, + self._inital_beta, + self._final_beta, end_fraction=1.0, ) - # SumTree: data structure to store priorities - self.tree = SumTree(buffer_size=buffer_size, rng_key=jax.random.PRNGKey(0)) + + self._sum_tree = SumSegmentTree(tree_capacity) + self._min_tree = MinSegmentTree(tree_capacity) @property def beta(self) -> float: @@ -231,7 +290,8 @@ def add( :param infos: Eventual information given by the environment. 
""" # store transition index with maximum priority in sum tree - self.tree.add(self.max_priority, self.pos) + self._sum_tree[self.pos] = self._max_priority**self._alpha + self._min_tree[self.pos] = self._max_priority**self._alpha # store transition in the buffer super().add(obs, next_obs, action, reward, done, infos) @@ -247,48 +307,32 @@ def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayB """ assert self.buffer_size >= batch_size, "The buffer contains less samples than the batch size requires." - leaf_nodes_indices = np.zeros(batch_size, dtype=np.uint32) - priorities = np.zeros((batch_size, 1)) - sample_indices = np.zeros(batch_size, dtype=np.uint32) - - # To sample a minibatch of size k, the range [0, total_sum] is divided equally into k ranges. - # Next, a value is uniformly sampled from each range. Finally the transitions that correspond - # to each of these sampled values are retrieved from the tree. - segment_size = self.tree.total_sum / batch_size - for batch_idx in range(batch_size): - # extremes of the current segment - start, end = segment_size * batch_idx, segment_size * (batch_idx + 1) + # priorities = np.zeros((batch_size, 1)) + # sample_indices = np.zeros(batch_size, dtype=np.uint32) - # uniformely sample a value from the current segment - cumulative_sum = np.random.uniform(start, end) + # TODO: check how things are sampled in the original implementation - # leaf_node_idx is a index of a sample in the tree, needed further to update priorities - # sample_idx is a sample index in buffer, needed further to sample actual transitions - leaf_node_idx, priority, sample_idx = self.tree.get(cumulative_sum) - - leaf_nodes_indices[batch_idx] = leaf_node_idx - priorities[batch_idx] = priority - sample_indices[batch_idx] = sample_idx - - # sample_indices, priorities, leaf_nodes_indices = self.tree.stratified_sampling(batch_size) + sample_indices = self._sample_proportional(batch_size) + leaf_nodes_indices = sample_indices # probability 
of sampling transition i as P(i) = p_i^alpha / \sum_{k} p_k^alpha # where p_i > 0 is the priority of transition i. - probs = priorities / self.tree.total_sum + # probs = priorities / self.tree.total_sum + probabilities = self._sum_tree[sample_indices] / self._sum_tree.sum() # Importance sampling weights. # All weights w_i were scaled so that max_i w_i = 1. - weights = (self.size() * probs + 1e-7) ** -self.beta + # weights = (self.size() * probs + 1e-7) ** -self.beta + # min_probability = self._min_tree.min() / self._sum_tree.sum() + # max_weight = (min_probability * self.size()) ** (-self.beta) + # weights = (probabilities * self.size()) ** (-self.beta) / max_weight + weights = (probabilities * self.size()) ** (-self.beta) weights = weights / weights.max() # TODO: add proper support for multi env # env_indices = np.random.randint(0, high=self.n_envs, size=(batch_size,)) env_indices = np.zeros(batch_size, dtype=np.uint32) - - if self.optimize_memory_usage: - next_obs = self._normalize_obs(self.observations[(sample_indices + 1) % self.buffer_size, env_indices, :], env) - else: - next_obs = self._normalize_obs(self.next_observations[sample_indices, env_indices, :], env) + next_obs = self._normalize_obs(self.next_observations[sample_indices, env_indices, :], env) batch = ( self._normalize_obs(self.observations[sample_indices, env_indices, :], env), @@ -300,22 +344,39 @@ def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayB ) return ReplayBufferSamples(*tuple(map(self.to_torch, batch)), leaf_nodes_indices) # type: ignore[arg-type,call-arg] - def update_priorities(self, leaf_nodes_indices: np.ndarray, td_errors: np.ndarray, progress_remaining: float) -> None: + def _sample_proportional(self, batch_size: int) -> np.ndarray: + """ + Sample a batch of leaf nodes indices using the proportional prioritization strategy. + In other words, the probability of sampling a transition is proportional to its priority. 
+ + :param batch_size: Number of element to sample + :return: Indices of the sampled leaf nodes + """ + # TODO: double check if this is correct + total = self._sum_tree.sum(0, self.size() - 1) + priorities_sum = np.random.random(size=batch_size) * total + return self._sum_tree.find_prefixsum_idx(priorities_sum) + + # def update_priorities(self, indices: np.ndarray, priorities: np.ndarray) -> None: + def update_priorities(self, indices: np.ndarray, priorities: np.ndarray, progress_remaining: float) -> None: """ - Update transition priorities. + Update priorities of sampled transitions. - :param leaf_nodes_indices: Indices for the leaf nodes to update - (correponding to the transitions) + :param leaf_nodes_indices: Indices of the sampled transitions. :param td_errors: New priorities, td error in the case of proportional prioritized replay buffer. - :param progress_remaining: Current progress remaining (starts from 1 and ends to 0) - to linearly anneal beta from its start value to 1.0 at the end of training """ + # TODO: move beta to the DQN algorithm # Update beta schedule self._current_progress_remaining = progress_remaining - # Batch update - priorities = (np.abs(td_errors) + self.min_priority) ** self.alpha - self.tree.batch_update(leaf_nodes_indices, priorities) + # assert len(indices) == len(priorities) + assert np.min(priorities) > 0 + assert np.min(indices) >= 0 + assert np.max(indices) < self.buffer_size + # TODO: check if we need to add the min_priority here + # priorities = (np.abs(td_errors) + self.min_priority) ** self.alpha + self._sum_tree[indices] = priorities**self._alpha + self._min_tree[indices] = priorities**self._alpha # Update max priority for new samples - self.max_priority = max(self.max_priority, priorities.max()) + self._max_priority = max(self._max_priority, np.max(priorities)) From 5f3cba9819c447cf1c752f2ef855666f44204dbe Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 15 Jul 2024 23:20:53 +0200 Subject: [PATCH 07/15] Add min 
priorities --- sbx/common/prioritized_replay_buffer.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/sbx/common/prioritized_replay_buffer.py b/sbx/common/prioritized_replay_buffer.py index ab81231..05de787 100644 --- a/sbx/common/prioritized_replay_buffer.py +++ b/sbx/common/prioritized_replay_buffer.py @@ -90,7 +90,7 @@ def reduce(self, start: int = 0, end: Optional[int] = None) -> float: end -= 1 return self._reduce_helper(start, end, 1, 0, self._capacity - 1) - def __setitem__(self, idx: int, val: float) -> None: + def __setitem__(self, idx: np.ndarray, val: np.ndarray) -> None: """ Set the value at index `idx` to `val` @@ -244,7 +244,7 @@ def __init__( power_of_two = int(np.ceil(np.log2(buffer_size))) tree_capacity = 2**power_of_two - self.min_priority = min_priority + self._min_priority = min_priority self._max_priority = 1.0 self._alpha = alpha @@ -370,6 +370,9 @@ def update_priorities(self, indices: np.ndarray, priorities: np.ndarray, progres # Update beta schedule self._current_progress_remaining = progress_remaining + # TODO: double check that all samples are updated + # priorities = np.abs(td_errors) + self.min_priority + priorities += self._min_priority # assert len(indices) == len(priorities) assert np.min(priorities) > 0 assert np.min(indices) >= 0 From b5ce09105aa8df2701e85d30f0584380b2ee19f8 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Wed, 17 Jul 2024 09:21:33 +0200 Subject: [PATCH 08/15] Add multi-env support --- sbx/common/prioritized_replay_buffer.py | 89 +++++++++++-------------- sbx/per_dqn/per_dqn.py | 36 +++++++--- 2 files changed, 64 insertions(+), 61 deletions(-) diff --git a/sbx/common/prioritized_replay_buffer.py b/sbx/common/prioritized_replay_buffer.py index 05de787..6216aa3 100644 --- a/sbx/common/prioritized_replay_buffer.py +++ b/sbx/common/prioritized_replay_buffer.py @@ -14,7 +14,6 @@ from gymnasium import spaces from stable_baselines3.common.buffers import ReplayBuffer from 
stable_baselines3.common.type_aliases import ReplayBufferSamples -from stable_baselines3.common.utils import get_linear_fn from stable_baselines3.common.vec_env.vec_normalize import VecNormalize @@ -41,6 +40,8 @@ def __init__(self, capacity: int, reduce_op: Callable, neutral_element: float) - """ assert capacity > 0 and capacity & (capacity - 1) == 0, f"Capacity must be positive and a power of 2, not {capacity}" self._capacity = capacity + # First index is the root, leaf nodes are in [capacity, 2 * capacity - 1]. + # For each parent node i, left child has index [2 * i], right child [2 * i + 1] self._values = np.full(2 * capacity, neutral_element) self._reduce_op = reduce_op self.neutral_element = neutral_element @@ -97,8 +98,10 @@ def __setitem__(self, idx: np.ndarray, val: np.ndarray) -> None: :param idx: index of the value to be updated :param val: new value """ + # assert np.all(0 <= idx < self._capacity), f"Trying to set item outside capacity: {idx}" # Indices of the leafs indices = idx + self._capacity + # Update the leaf nodes and then the related nodes self._values[indices] = val if isinstance(indices, int): indices = np.array([indices]) @@ -153,8 +156,7 @@ def find_prefixsum_idx(self, prefixsum: np.ndarray) -> np.ndarray: if isinstance(prefixsum, float): prefixsum = np.array([prefixsum]) assert 0 <= np.min(prefixsum) - assert np.max(prefixsum) <= self.sum() + 1e-5 - assert isinstance(prefixsum[0], float) + # assert np.max(prefixsum) <= self.sum() + 1e-5 indices = np.ones(len(prefixsum), dtype=int) should_continue = np.ones(len(prefixsum), dtype=bool) @@ -227,8 +229,6 @@ def __init__( device: Union[th.device, str] = "auto", n_envs: int = 1, alpha: float = 0.5, - beta: float = 0.4, - final_beta: float = 1.0, optimize_memory_usage: bool = False, min_priority: float = 1e-6, ): @@ -238,7 +238,7 @@ def __init__( assert optimize_memory_usage is False, "PrioritizedReplayBuffer doesn't support optimize_memory_usage=True" # TODO: add support for multi env - assert 
n_envs == 1, "PrioritizedReplayBuffer doesn't support n_envs > 1" + # assert n_envs == 1, "PrioritizedReplayBuffer doesn't support n_envs > 1" # Find the next power of 2 for the buffer size power_of_two = int(np.ceil(np.log2(buffer_size))) @@ -249,26 +249,12 @@ def __init__( self._alpha = alpha - # Track the training progress remaining (from 1 to 0) - # this is used to update beta - self._current_progress_remaining = 1.0 - - # TODO: move beta schedule to the DQN algorithm - self._inital_beta = beta - self._final_beta = final_beta - self.beta_schedule = get_linear_fn( - self._inital_beta, - self._final_beta, - end_fraction=1.0, - ) - self._sum_tree = SumSegmentTree(tree_capacity) self._min_tree = MinSegmentTree(tree_capacity) - - @property - def beta(self) -> float: - # Linear schedule - return self.beta_schedule(self._current_progress_remaining) + # Flatten the indices from the buffer to store them in the sum tree + # Replay buffer: (idx, env_idx) + # Sum tree: idx * self.n_envs + env_idx + self.env_offsets = np.arange(self.n_envs) def add( self, @@ -289,14 +275,14 @@ def add( :param done: Whether the episode was finished after the transition to be stored. :param infos: Eventual information given by the environment. 
""" - # store transition index with maximum priority in sum tree - self._sum_tree[self.pos] = self._max_priority**self._alpha - self._min_tree[self.pos] = self._max_priority**self._alpha + # Store transition index with maximum priority in sum tree + self._sum_tree[self.pos * self.n_envs + self.env_offsets] = self._max_priority**self._alpha + self._min_tree[self.pos * self.n_envs + self.env_offsets] = self._max_priority**self._alpha - # store transition in the buffer + # Store transition in the buffer super().add(obs, next_obs, action, reward, done, infos) - def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayBufferSamples: + def sample(self, batch_size: int, beta: float, env: Optional[VecNormalize] = None) -> ReplayBufferSamples: """ Sample elements from the prioritized replay buffer. @@ -305,20 +291,24 @@ def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayB to normalize the observations/rewards when sampling :return: a batch of sampled experiences from the buffer. """ - assert self.buffer_size >= batch_size, "The buffer contains less samples than the batch size requires." + # assert self.buffer_size >= batch_size, "The buffer contains less samples than the batch size requires." # priorities = np.zeros((batch_size, 1)) # sample_indices = np.zeros(batch_size, dtype=np.uint32) # TODO: check how things are sampled in the original implementation - sample_indices = self._sample_proportional(batch_size) - leaf_nodes_indices = sample_indices + leaf_nodes_indices = self._sample_proportional(batch_size) + # Convert the leaf nodes indices to buffer indices + # Replay buffer: (idx, env_idx) + # Sum tree: idx * self.n_envs + env_idx + buffer_indices = leaf_nodes_indices // self.n_envs + env_indices = leaf_nodes_indices % self.n_envs # probability of sampling transition i as P(i) = p_i^alpha / \sum_{k} p_k^alpha # where p_i > 0 is the priority of transition i. 
# probs = priorities / self.tree.total_sum - probabilities = self._sum_tree[sample_indices] / self._sum_tree.sum() + probabilities = self._sum_tree[leaf_nodes_indices] / self._sum_tree.sum() # Importance sampling weights. # All weights w_i were scaled so that max_i w_i = 1. @@ -326,20 +316,21 @@ def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayB # min_probability = self._min_tree.min() / self._sum_tree.sum() # max_weight = (min_probability * self.size()) ** (-self.beta) # weights = (probabilities * self.size()) ** (-self.beta) / max_weight - weights = (probabilities * self.size()) ** (-self.beta) + weights = (probabilities * self.size()) ** (-beta) weights = weights / weights.max() - # TODO: add proper support for multi env # env_indices = np.random.randint(0, high=self.n_envs, size=(batch_size,)) - env_indices = np.zeros(batch_size, dtype=np.uint32) - next_obs = self._normalize_obs(self.next_observations[sample_indices, env_indices, :], env) + # env_indices = np.zeros(batch_size, dtype=np.uint32) + next_obs = self._normalize_obs(self.next_observations[buffer_indices, env_indices, :], env) batch = ( - self._normalize_obs(self.observations[sample_indices, env_indices, :], env), - self.actions[sample_indices, env_indices, :], + self._normalize_obs(self.observations[buffer_indices, env_indices, :], env), + self.actions[buffer_indices, env_indices, :], next_obs, - self.dones[sample_indices], - self.rewards[sample_indices], + # Only use dones that are not due to timeouts + # deactivated by default (timeouts is initialized as an array of False) + (self.dones[buffer_indices, env_indices] * (1 - self.timeouts[buffer_indices, env_indices])).reshape(-1, 1), + self._normalize_reward(self.rewards[buffer_indices, env_indices].reshape(-1, 1), env), weights, ) return ReplayBufferSamples(*tuple(map(self.to_torch, batch)), leaf_nodes_indices) # type: ignore[arg-type,call-arg] @@ -358,28 +349,24 @@ def _sample_proportional(self, batch_size: int) -> 
np.ndarray: return self._sum_tree.find_prefixsum_idx(priorities_sum) # def update_priorities(self, indices: np.ndarray, priorities: np.ndarray) -> None: - def update_priorities(self, indices: np.ndarray, priorities: np.ndarray, progress_remaining: float) -> None: + def update_priorities(self, leaf_nodes_indices: np.ndarray, priorities: np.ndarray) -> None: """ Update priorities of sampled transitions. :param leaf_nodes_indices: Indices of the sampled transitions. - :param td_errors: New priorities, td error in the case of + :param priorities: New priorities, td error in the case of proportional prioritized replay buffer. """ - # TODO: move beta to the DQN algorithm - # Update beta schedule - self._current_progress_remaining = progress_remaining - # TODO: double check that all samples are updated # priorities = np.abs(td_errors) + self.min_priority priorities += self._min_priority # assert len(indices) == len(priorities) assert np.min(priorities) > 0 - assert np.min(indices) >= 0 - assert np.max(indices) < self.buffer_size + assert np.min(leaf_nodes_indices) >= 0 + assert np.max(leaf_nodes_indices) < self.buffer_size # TODO: check if we need to add the min_priority here # priorities = (np.abs(td_errors) + self.min_priority) ** self.alpha - self._sum_tree[indices] = priorities**self._alpha - self._min_tree[indices] = priorities**self._alpha + self._sum_tree[leaf_nodes_indices] = priorities**self._alpha + self._min_tree[leaf_nodes_indices] = priorities**self._alpha # Update max priority for new samples self._max_priority = max(self._max_priority, np.max(priorities)) diff --git a/sbx/per_dqn/per_dqn.py b/sbx/per_dqn/per_dqn.py index 7f5ae9a..9aa63a4 100644 --- a/sbx/per_dqn/per_dqn.py +++ b/sbx/per_dqn/per_dqn.py @@ -5,6 +5,7 @@ import numpy as np import optax from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule +from stable_baselines3.common.utils import get_linear_fn from sbx.common.prioritized_replay_buffer import PrioritizedReplayBuffer 
from sbx.common.type_aliases import ReplayBufferSamplesNp, RLTrainState @@ -39,6 +40,8 @@ def __init__( exploration_fraction: float = 0.1, exploration_initial_eps: float = 1.0, exploration_final_eps: float = 0.05, + initial_beta: float = 0.4, + final_beta: float = 1.0, optimize_memory_usage: bool = False, # Note: unused but to match SB3 API # max_grad_norm: float = 10, train_freq: Union[int, Tuple[int, str]] = 4, @@ -77,6 +80,19 @@ def __init__( _init_setup_model=_init_setup_model, ) + self._inital_beta = initial_beta + self._final_beta = final_beta + self.beta_schedule = get_linear_fn( + self._inital_beta, + self._final_beta, + end_fraction=1.0, + ) + + @property + def beta(self) -> float: + # Linear schedule + return self.beta_schedule(self._current_progress_remaining) + def learn( self, total_timesteps: int, @@ -97,7 +113,7 @@ def learn( def train(self, batch_size: int, gradient_steps: int) -> None: # Sample all at once for efficiency (so we can jit the for loop) - data = self.replay_buffer.sample(batch_size * gradient_steps, env=self._vec_normalize_env) + data = self.replay_buffer.sample(batch_size * gradient_steps, self.beta, env=self._vec_normalize_env) # Convert to numpy data = ReplayBufferSamplesNp( data.observations.numpy(), @@ -121,7 +137,7 @@ def train(self, batch_size: int, gradient_steps: int) -> None: "info": { "critic_loss": jnp.array([0.0]), "qf_mean_value": jnp.array([0.0]), - "td_error": jnp.zeros_like(data.rewards), + "priorities": jnp.zeros_like(data.rewards), }, } @@ -137,12 +153,12 @@ def train(self, batch_size: int, gradient_steps: int) -> None: self.policy.qf_state = update_carry["qf_state"] qf_loss_value = update_carry["info"]["critic_loss"] qf_mean_value = update_carry["info"]["qf_mean_value"] / gradient_steps - td_error = update_carry["info"]["td_error"] + priorities = update_carry["info"]["priorities"] # Update priorities, they will be proportional to the td error # Note: compared to the original implementation, we update # the 
priorities after all the gradient steps - self.replay_buffer.update_priorities(data.leaf_nodes_indices, td_error, self._current_progress_remaining) + self.replay_buffer.update_priorities(data.leaf_nodes_indices, priorities) self._n_updates += gradient_steps self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") @@ -179,24 +195,24 @@ def weighted_huber_loss(params): # Retrieve the q-values for the actions from the replay buffer current_q_values = jnp.take_along_axis(current_q_values, replay_actions, axis=1) # TD error in absolute value, to update priorities - td_error = jnp.abs(current_q_values - target_q_values) + priorities = jnp.abs(current_q_values - target_q_values) # Weighted Huber loss using importance sampling weights loss = (sampling_weights * optax.huber_loss(current_q_values, target_q_values)).mean() - return loss, (current_q_values.mean(), td_error.flatten()) + return loss, (current_q_values.mean(), priorities.flatten()) - (qf_loss_value, (qf_mean_value, td_error)), grads = jax.value_and_grad(weighted_huber_loss, has_aux=True)( + (qf_loss_value, (qf_mean_value, priorities)), grads = jax.value_and_grad(weighted_huber_loss, has_aux=True)( qf_state.params ) qf_state = qf_state.apply_gradients(grads=grads) - return qf_state, (qf_loss_value, qf_mean_value, td_error) + return qf_state, (qf_loss_value, qf_mean_value, priorities) @staticmethod @jax.jit def _train(carry, indices): data = carry["data"] - qf_state, (qf_loss_value, qf_mean_value, td_error) = PERDQN.update_qnetwork( + qf_state, (qf_loss_value, qf_mean_value, priorities) = PERDQN.update_qnetwork( carry["gamma"], carry["qf_state"], observations=data.observations[indices], @@ -210,6 +226,6 @@ def _train(carry, indices): carry["qf_state"] = qf_state carry["info"]["critic_loss"] += qf_loss_value carry["info"]["qf_mean_value"] += qf_mean_value - carry["info"]["td_error"] = carry["info"]["td_error"].at[indices].set(td_error) + carry["info"]["priorities"] = 
carry["info"]["priorities"].at[indices].set(priorities) return carry, None From 77c7ebdb30ce3a34bfdd551948ecd1de58ff18a3 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Wed, 17 Jul 2024 10:57:23 +0200 Subject: [PATCH 09/15] Fix for update interval with multi-env --- sbx/dqn/dqn.py | 4 +++- sbx/dqn/policies.py | 21 ++++++++++++++------- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/sbx/dqn/dqn.py b/sbx/dqn/dqn.py index cbc5fea..afcd5ef 100644 --- a/sbx/dqn/dqn.py +++ b/sbx/dqn/dqn.py @@ -227,7 +227,9 @@ def _on_step(self) -> None: This method is called in ``collect_rollouts()`` after each step in the environment. """ self._n_calls += 1 - if self._n_calls % self.target_update_interval == 0: + # Account for multiple environments + # each call to step() corresponds to n_envs transitions + if self._n_calls % max(self.target_update_interval // self.n_envs, 1) == 0: self.policy.qf_state = DQN.soft_update(self.tau, self.policy.qf_state) self.exploration_rate = self.exploration_schedule(self._current_progress_remaining) diff --git a/sbx/dqn/policies.py b/sbx/dqn/policies.py index d8b19ba..e1089be 100644 --- a/sbx/dqn/policies.py +++ b/sbx/dqn/policies.py @@ -69,6 +69,7 @@ def __init__( normalize_images: bool = True, optimizer_class: Callable[..., optax.GradientTransformation] = optax.adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, + max_grad_norm: float = 10.0, ): super().__init__( observation_space, @@ -85,6 +86,7 @@ def __init__( else: self.n_units = 256 self.activation_fn = activation_fn + self.max_grad_norm = max_grad_norm def build(self, key: jax.Array, lr_schedule: Schedule) -> jax.Array: key, qf_key = jax.random.split(key, 2) @@ -101,13 +103,15 @@ def build(self, key: jax.Array, lr_schedule: Schedule) -> jax.Array: apply_fn=self.qf.apply, params=self.qf.init({"params": qf_key}, obs), target_params=self.qf.init({"params": qf_key}, obs), - tx=self.optimizer_class( - learning_rate=lr_schedule(1), # type: ignore[call-arg] - 
**self.optimizer_kwargs, + tx=optax.chain( + optax.clip_by_global_norm(self.max_grad_norm), + self.optimizer_class( + learning_rate=lr_schedule(1), # type: ignore[call-arg] + **self.optimizer_kwargs, + ), ), ) - # TODO: jit qf.apply_fn too? self.qf.apply = jax.jit(self.qf.apply) # type: ignore[method-assign] return key @@ -141,9 +145,12 @@ def build(self, key: jax.Array, lr_schedule: Schedule) -> jax.Array: apply_fn=self.qf.apply, params=self.qf.init({"params": qf_key}, obs), target_params=self.qf.init({"params": qf_key}, obs), - tx=self.optimizer_class( - learning_rate=lr_schedule(1), # type: ignore[call-arg] - **self.optimizer_kwargs, + tx=optax.chain( + optax.clip_by_global_norm(self.max_grad_norm), + self.optimizer_class( + learning_rate=lr_schedule(1), # type: ignore[call-arg] + **self.optimizer_kwargs, + ), ), ) self.qf.apply = jax.jit(self.qf.apply) # type: ignore[method-assign] From 59b72d03b5d2c9b725b8c5e1f8092df07147c570 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Wed, 17 Jul 2024 11:06:52 +0200 Subject: [PATCH 10/15] Fix type check and signatures --- sbx/common/prioritized_replay_buffer.py | 2 +- sbx/dqn/dqn.py | 5 +++-- sbx/per_dqn/per_dqn.py | 25 ++++++++++++++++--------- 3 files changed, 20 insertions(+), 12 deletions(-) diff --git a/sbx/common/prioritized_replay_buffer.py b/sbx/common/prioritized_replay_buffer.py index 6216aa3..816bccc 100644 --- a/sbx/common/prioritized_replay_buffer.py +++ b/sbx/common/prioritized_replay_buffer.py @@ -282,7 +282,7 @@ def add( # Store transition in the buffer super().add(obs, next_obs, action, reward, done, infos) - def sample(self, batch_size: int, beta: float, env: Optional[VecNormalize] = None) -> ReplayBufferSamples: + def sample(self, batch_size: int, beta: float, env: Optional[VecNormalize] = None) -> ReplayBufferSamples: # type: ignore[override] """ Sample elements from the prioritized replay buffer. 
diff --git a/sbx/dqn/dqn.py b/sbx/dqn/dqn.py index afcd5ef..cdae123 100644 --- a/sbx/dqn/dqn.py +++ b/sbx/dqn/dqn.py @@ -135,11 +135,12 @@ def learn( progress_bar=progress_bar, ) - def train(self, batch_size: int, gradient_steps: int) -> None: + def train(self, gradient_steps: int, batch_size: int) -> None: + assert self.replay_buffer is not None # Sample all at once for efficiency (so we can jit the for loop) data = self.replay_buffer.sample(batch_size * gradient_steps, env=self._vec_normalize_env) # Convert to numpy - data = ReplayBufferSamplesNp( + data = ReplayBufferSamplesNp( # type: ignore[assignment] data.observations.numpy(), # Convert to int64 data.actions.long().numpy(), diff --git a/sbx/per_dqn/per_dqn.py b/sbx/per_dqn/per_dqn.py index 9aa63a4..ecd4d37 100644 --- a/sbx/per_dqn/per_dqn.py +++ b/sbx/per_dqn/per_dqn.py @@ -25,6 +25,7 @@ class PERDQN(DQN): # Linear schedule will be defined in `_setup_model()` exploration_schedule: Schedule policy: DQNPolicy + replay_buffer: PrioritizedReplayBuffer def __init__( self, @@ -111,19 +112,20 @@ def learn( progress_bar=progress_bar, ) - def train(self, batch_size: int, gradient_steps: int) -> None: + def train(self, gradient_steps: int, batch_size: int) -> None: + assert self.replay_buffer is not None # Sample all at once for efficiency (so we can jit the for loop) - data = self.replay_buffer.sample(batch_size * gradient_steps, self.beta, env=self._vec_normalize_env) + th_data = self.replay_buffer.sample(batch_size * gradient_steps, self.beta, env=self._vec_normalize_env) # Convert to numpy data = ReplayBufferSamplesNp( - data.observations.numpy(), + th_data.observations.numpy(), # Convert to int64 - data.actions.long().numpy(), - data.next_observations.numpy(), - data.dones.numpy().flatten(), - data.rewards.numpy().flatten(), - data.weights.numpy().flatten(), - data.leaf_nodes_indices, + th_data.actions.long().numpy(), + th_data.next_observations.numpy(), + th_data.dones.numpy().flatten(), + 
th_data.rewards.numpy().flatten(), + th_data.weights.numpy().flatten(), # type: ignore[union-attr] + th_data.leaf_nodes_indices, ) # Pre compute the slice indices # otherwise jax will complain @@ -158,12 +160,17 @@ def train(self, batch_size: int, gradient_steps: int) -> None: # Update priorities, they will be proportional to the td error # Note: compared to the original implementation, we update # the priorities after all the gradient steps + assert data.leaf_nodes_indices is not None self.replay_buffer.update_priorities(data.leaf_nodes_indices, priorities) self._n_updates += gradient_steps self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") self.logger.record("train/critic_loss", qf_loss_value.item()) self.logger.record("train/qf_mean_value", qf_mean_value.item()) + self.logger.record("train/beta", self.beta) + self.logger.record("train/min_priority", priorities.min().item()) + self.logger.record("train/max_priority", priorities.max().item()) + self.logger.record("train/mean_priority", priorities.mean().item()) @staticmethod @jax.jit From 5e3d182556c533d806ed2e3651f4f00fdc94e1be Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Wed, 17 Jul 2024 11:40:01 +0200 Subject: [PATCH 11/15] Normalize by true min proba --- sbx/common/prioritized_replay_buffer.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/sbx/common/prioritized_replay_buffer.py b/sbx/common/prioritized_replay_buffer.py index 816bccc..a0968f5 100644 --- a/sbx/common/prioritized_replay_buffer.py +++ b/sbx/common/prioritized_replay_buffer.py @@ -293,9 +293,6 @@ def sample(self, batch_size: int, beta: float, env: Optional[VecNormalize] = Non """ # assert self.buffer_size >= batch_size, "The buffer contains less samples than the batch size requires." 
- # priorities = np.zeros((batch_size, 1)) - # sample_indices = np.zeros(batch_size, dtype=np.uint32) - # TODO: check how things are sampled in the original implementation leaf_nodes_indices = self._sample_proportional(batch_size) @@ -307,17 +304,16 @@ def sample(self, batch_size: int, beta: float, env: Optional[VecNormalize] = Non # probability of sampling transition i as P(i) = p_i^alpha / \sum_{k} p_k^alpha # where p_i > 0 is the priority of transition i. - # probs = priorities / self.tree.total_sum - probabilities = self._sum_tree[leaf_nodes_indices] / self._sum_tree.sum() + total_priorities = self._sum_tree.sum() + probabilities = self._sum_tree[leaf_nodes_indices] / total_priorities # Importance sampling weights. # All weights w_i were scaled so that max_i w_i = 1. - # weights = (self.size() * probs + 1e-7) ** -self.beta - # min_probability = self._min_tree.min() / self._sum_tree.sum() - # max_weight = (min_probability * self.size()) ** (-self.beta) - # weights = (probabilities * self.size()) ** (-self.beta) / max_weight - weights = (probabilities * self.size()) ** (-beta) - weights = weights / weights.max() + min_probability = self._min_tree.min() / total_priorities + max_weight = (min_probability * self.size()) ** (-beta) + weights = (probabilities * self.size()) ** (-beta) / max_weight + # weights = (probabilities * self.size()) ** (-beta) + # weights = weights / weights.max() # env_indices = np.random.randint(0, high=self.n_envs, size=(batch_size,)) # env_indices = np.zeros(batch_size, dtype=np.uint32) From e80d9054f4ccaf5baa4875cfbd53f39b71be7caf Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Wed, 17 Jul 2024 12:30:14 +0200 Subject: [PATCH 12/15] Deactivate beta schedule --- sbx/per_dqn/per_dqn.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sbx/per_dqn/per_dqn.py b/sbx/per_dqn/per_dqn.py index ecd4d37..9aab4cc 100644 --- a/sbx/per_dqn/per_dqn.py +++ b/sbx/per_dqn/per_dqn.py @@ -91,8 +91,9 @@ def __init__( @property def 
beta(self) -> float: + return 0.5 # same as Dopamine RL # Linear schedule - return self.beta_schedule(self._current_progress_remaining) + # return self.beta_schedule(self._current_progress_remaining) def learn( self, From 538a9085e3c0a7ba5c1f561ae6c2d833545b12e0 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Wed, 17 Jul 2024 12:32:53 +0200 Subject: [PATCH 13/15] Fix stratified sampling and normalization --- sbx/common/prioritized_replay_buffer.py | 52 ++++++++++++++++++------- 1 file changed, 39 insertions(+), 13 deletions(-) diff --git a/sbx/common/prioritized_replay_buffer.py b/sbx/common/prioritized_replay_buffer.py index a0968f5..c467356 100644 --- a/sbx/common/prioritized_replay_buffer.py +++ b/sbx/common/prioritized_replay_buffer.py @@ -191,7 +191,7 @@ class MinSegmentTree(SegmentTree): def __init__(self, capacity: int) -> None: super().__init__(capacity=capacity, reduce_op=np.minimum, neutral_element=float("inf")) - def min(self, start=0, end=None): + def min(self, start: int = 0, end: Optional[int] = None) -> float: """ Returns min(arr[start], ..., arr[end]) @@ -294,24 +294,35 @@ def sample(self, batch_size: int, beta: float, env: Optional[VecNormalize] = Non # assert self.buffer_size >= batch_size, "The buffer contains less samples than the batch size requires." 
# TODO: check how things are sampled in the original implementation + # Note: should be the same as self._sum_tree.sum(0, buffer_size - 1) + total_priorities = self._sum_tree.sum() + buffer_size = self.size() * self.n_envs + + # leaf_nodes_indices = self._sample_proportional(batch_size, total_priorities) + # TODO: check that the sampled indices are valid: + # - they should be in the range [0, buffer_size) + leaf_nodes_indices = self._stratified_sampling(batch_size, total_priorities) + # debug: uniform sampling + # leaf_nodes_indices = np.random.randint(0, buffer_size, size=batch_size) - leaf_nodes_indices = self._sample_proportional(batch_size) # Convert the leaf nodes indices to buffer indices # Replay buffer: (idx, env_idx) # Sum tree: idx * self.n_envs + env_idx buffer_indices = leaf_nodes_indices // self.n_envs env_indices = leaf_nodes_indices % self.n_envs + # assert np.all(buffer_indices < self.size()), f"Invalid indices: {buffer_indices} >= {self.size()}" + # probability of sampling transition i as P(i) = p_i^alpha / \sum_{k} p_k^alpha # where p_i > 0 is the priority of transition i. - total_priorities = self._sum_tree.sum() probabilities = self._sum_tree[leaf_nodes_indices] / total_priorities # Importance sampling weights. # All weights w_i were scaled so that max_i w_i = 1. 
min_probability = self._min_tree.min() / total_priorities - max_weight = (min_probability * self.size()) ** (-beta) - weights = (probabilities * self.size()) ** (-beta) / max_weight + # FIXME: self.size() doesn't take into account the number of envs + max_weight = (min_probability * buffer_size) ** (-beta) + weights = (probabilities * buffer_size) ** (-beta) / max_weight # weights = (probabilities * self.size()) ** (-beta) # weights = weights / weights.max() @@ -331,18 +342,33 @@ def sample(self, batch_size: int, beta: float, env: Optional[VecNormalize] = Non ) return ReplayBufferSamples(*tuple(map(self.to_torch, batch)), leaf_nodes_indices) # type: ignore[arg-type,call-arg] - def _sample_proportional(self, batch_size: int) -> np.ndarray: + # def _sample_proportional(self, batch_size: int, total_priorities: float) -> np.ndarray: + # """ + # Sample a batch of leaf nodes indices using the proportional prioritization strategy. + # In other words, the probability of sampling a transition is proportional to its priority. + + # :param batch_size: Number of element to sample + # :return: Indices of the sampled leaf nodes + # """ + # # TODO: double check if this is correct + # # total = self._sum_tree.sum(0, self.size() - 1) + # # priorities_sum = np.random.random(size=batch_size) * total_priorities + # priorities_sum = np.random.uniform(0, total_priorities, size=batch_size) + # return self._sum_tree.find_prefixsum_idx(priorities_sum) + + def _stratified_sampling(self, batch_size: int, total_priorities: float) -> np.ndarray: """ - Sample a batch of leaf nodes indices using the proportional prioritization strategy. - In other words, the probability of sampling a transition is proportional to its priority. + To sample a minibatch of size k, the range [0, total_sum] is divided equally into k ranges. + Next, a value is uniformly sampled from each range. Finally the transitions that correspond + to each of these sampled values are retrieved from the tree. 
:param batch_size: Number of element to sample + :param total_priorities: Sum of all priorities in the sum tree. :return: Indices of the sampled leaf nodes """ - # TODO: double check if this is correct - total = self._sum_tree.sum(0, self.size() - 1) - priorities_sum = np.random.random(size=batch_size) * total - return self._sum_tree.find_prefixsum_idx(priorities_sum) + segments = np.linspace(0, total_priorities, num=batch_size + 1) + desired_priorities = np.random.uniform(segments[:-1], segments[1:], size=batch_size) + return self._sum_tree.find_prefixsum_idx(desired_priorities) # def update_priorities(self, indices: np.ndarray, priorities: np.ndarray) -> None: def update_priorities(self, leaf_nodes_indices: np.ndarray, priorities: np.ndarray) -> None: @@ -359,7 +385,7 @@ def update_priorities(self, leaf_nodes_indices: np.ndarray, priorities: np.ndarr # assert len(indices) == len(priorities) assert np.min(priorities) > 0 assert np.min(leaf_nodes_indices) >= 0 - assert np.max(leaf_nodes_indices) < self.buffer_size + # assert np.max(leaf_nodes_indices) < self.buffer_size * self.n_envs # TODO: check if we need to add the min_priority here # priorities = (np.abs(td_errors) + self.min_priority) ** self.alpha self._sum_tree[leaf_nodes_indices] = priorities**self._alpha From 94ce7c66ea3b23f3ee71dd199cabec90bba5639c Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Wed, 17 Jul 2024 12:43:30 +0200 Subject: [PATCH 14/15] Try to sample only valid transitions --- sbx/common/prioritized_replay_buffer.py | 39 +++++++++++++------------ 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/sbx/common/prioritized_replay_buffer.py b/sbx/common/prioritized_replay_buffer.py index c467356..f8105ef 100644 --- a/sbx/common/prioritized_replay_buffer.py +++ b/sbx/common/prioritized_replay_buffer.py @@ -296,12 +296,13 @@ def sample(self, batch_size: int, beta: float, env: Optional[VecNormalize] = Non # TODO: check how things are sampled in the original implementation # Note: 
should be the same as self._sum_tree.sum(0, buffer_size - 1) total_priorities = self._sum_tree.sum() + min_priority = self._min_tree.min() buffer_size = self.size() * self.n_envs - # leaf_nodes_indices = self._sample_proportional(batch_size, total_priorities) + # leaf_nodes_indices = self._sample_proportional(batch_size, min_priority, total_priorities) # TODO: check that the sampled indices are valid: # - they should be in the range [0, buffer_size) - leaf_nodes_indices = self._stratified_sampling(batch_size, total_priorities) + leaf_nodes_indices = self._stratified_sampling(batch_size, min_priority, total_priorities) # debug: uniform sampling # leaf_nodes_indices = np.random.randint(0, buffer_size, size=batch_size) @@ -319,7 +320,7 @@ def sample(self, batch_size: int, beta: float, env: Optional[VecNormalize] = Non # Importance sampling weights. # All weights w_i were scaled so that max_i w_i = 1. - min_probability = self._min_tree.min() / total_priorities + min_probability = min_priority / total_priorities # FIXME: self.size() doesn't take into account the number of envs max_weight = (min_probability * buffer_size) ** (-beta) weights = (probabilities * buffer_size) ** (-beta) / max_weight @@ -342,21 +343,21 @@ def sample(self, batch_size: int, beta: float, env: Optional[VecNormalize] = Non ) return ReplayBufferSamples(*tuple(map(self.to_torch, batch)), leaf_nodes_indices) # type: ignore[arg-type,call-arg] - # def _sample_proportional(self, batch_size: int, total_priorities: float) -> np.ndarray: - # """ - # Sample a batch of leaf nodes indices using the proportional prioritization strategy. - # In other words, the probability of sampling a transition is proportional to its priority. 
- - # :param batch_size: Number of element to sample - # :return: Indices of the sampled leaf nodes - # """ - # # TODO: double check if this is correct - # # total = self._sum_tree.sum(0, self.size() - 1) - # # priorities_sum = np.random.random(size=batch_size) * total_priorities - # priorities_sum = np.random.uniform(0, total_priorities, size=batch_size) - # return self._sum_tree.find_prefixsum_idx(priorities_sum) - - def _stratified_sampling(self, batch_size: int, total_priorities: float) -> np.ndarray: + def _sample_proportional(self, batch_size: int, min_priority: float, total_priorities: float) -> np.ndarray: + """ + Sample a batch of leaf nodes indices using the proportional prioritization strategy. + In other words, the probability of sampling a transition is proportional to its priority. + + :param batch_size: Number of element to sample + :return: Indices of the sampled leaf nodes + """ + # TODO: double check if this is correct + # total = self._sum_tree.sum(0, self.size() - 1) + # priorities_sum = np.random.random(size=batch_size) * total_priorities + priorities_sum = np.random.uniform(min_priority, total_priorities, size=batch_size) + return self._sum_tree.find_prefixsum_idx(priorities_sum) + + def _stratified_sampling(self, batch_size: int, min_priority: float, total_priorities: float) -> np.ndarray: """ To sample a minibatch of size k, the range [0, total_sum] is divided equally into k ranges. Next, a value is uniformly sampled from each range. Finally the transitions that correspond @@ -366,7 +367,7 @@ def _stratified_sampling(self, batch_size: int, total_priorities: float) -> np.n :param total_priorities: Sum of all priorities in the sum tree. 
:return: Indices of the sampled leaf nodes """ - segments = np.linspace(0, total_priorities, num=batch_size + 1) + segments = np.linspace(min_priority, total_priorities, num=batch_size + 1) desired_priorities = np.random.uniform(segments[:-1], segments[1:], size=batch_size) return self._sum_tree.find_prefixsum_idx(desired_priorities) From 5d206af4c7fa5f84db2ebc42042f69301e85b861 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sun, 25 Aug 2024 11:46:39 +0200 Subject: [PATCH 15/15] Fix displayed ent coef value --- sbx/crossq/crossq.py | 5 ++++- sbx/sac/sac.py | 6 +++++- sbx/tqc/tqc.py | 6 +++++- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/sbx/crossq/crossq.py b/sbx/crossq/crossq.py index 94e2bbc..4a1ffeb 100644 --- a/sbx/crossq/crossq.py +++ b/sbx/crossq/crossq.py @@ -211,7 +211,7 @@ def train(self, gradient_steps: int, batch_size: int) -> None: self.policy.actor_state, self.ent_coef_state, self.key, - (actor_loss_value, qf_loss_value, ent_coef_value), + (actor_loss_value, qf_loss_value, ent_coef_loss), ) = self._train( self.gamma, self.target_entropy, @@ -224,11 +224,14 @@ def train(self, gradient_steps: int, batch_size: int) -> None: self.ent_coef_state, self.key, ) + ent_coef_value = self.ent_coef_state.apply_fn({"params": self.ent_coef_state.params}) self._n_updates += gradient_steps self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") self.logger.record("train/actor_loss", actor_loss_value.item()) self.logger.record("train/critic_loss", qf_loss_value.item()) self.logger.record("train/ent_coef", ent_coef_value.item()) + if isinstance(self.ent_coef, EntropyCoef): + self.logger.record("train/ent_coef_loss", ent_coef_loss.item()) @staticmethod @jax.jit diff --git a/sbx/sac/sac.py b/sbx/sac/sac.py index 11f8ff5..f566aea 100644 --- a/sbx/sac/sac.py +++ b/sbx/sac/sac.py @@ -213,7 +213,7 @@ def train(self, gradient_steps: int, batch_size: int) -> None: self.policy.actor_state, self.ent_coef_state, self.key, - 
(actor_loss_value, qf_loss_value, ent_coef_value), + (actor_loss_value, qf_loss_value, ent_coef_loss), ) = self._train( self.gamma, self.tau, @@ -227,11 +227,15 @@ def train(self, gradient_steps: int, batch_size: int) -> None: self.ent_coef_state, self.key, ) + ent_coef_value = self.ent_coef_state.apply_fn({"params": self.ent_coef_state.params}) + self._n_updates += gradient_steps self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") self.logger.record("train/actor_loss", actor_loss_value.item()) self.logger.record("train/critic_loss", qf_loss_value.item()) self.logger.record("train/ent_coef", ent_coef_value.item()) + if isinstance(self.ent_coef, EntropyCoef): + self.logger.record("train/ent_coef_loss", ent_coef_loss.item()) @staticmethod @jax.jit diff --git a/sbx/tqc/tqc.py b/sbx/tqc/tqc.py index 5161f4d..b65c0c5 100644 --- a/sbx/tqc/tqc.py +++ b/sbx/tqc/tqc.py @@ -216,7 +216,7 @@ def train(self, gradient_steps: int, batch_size: int) -> None: self.policy.actor_state, self.ent_coef_state, self.key, - (qf1_loss_value, qf2_loss_value, actor_loss_value, ent_coef_value), + (qf1_loss_value, qf2_loss_value, actor_loss_value, ent_coef_loss), ) = self._train( self.gamma, self.tau, @@ -232,11 +232,15 @@ def train(self, gradient_steps: int, batch_size: int) -> None: self.ent_coef_state, self.key, ) + ent_coef_value = self.ent_coef_state.apply_fn({"params": self.ent_coef_state.params}) + self._n_updates += gradient_steps self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") self.logger.record("train/actor_loss", actor_loss_value.item()) self.logger.record("train/critic_loss", qf1_loss_value.item()) self.logger.record("train/ent_coef", ent_coef_value.item()) + if isinstance(self.ent_coef, EntropyCoef): + self.logger.record("train/ent_coef_loss", ent_coef_loss.item()) @staticmethod @partial(jax.jit, static_argnames=["n_target_quantiles"])