diff --git a/Makefile b/Makefile index 0177d5a..a8bdbf4 100644 --- a/Makefile +++ b/Makefile @@ -14,7 +14,7 @@ lint: # see https://www.flake8rules.com/ ruff check ${LINT_PATHS} --select=E9,F63,F7,F82 --output-format=full # exit-zero treats all errors as warnings. - ruff check ${LINT_PATHS} --exit-zero + ruff check ${LINT_PATHS} --exit-zero --output-format=concise format: # Sort imports diff --git a/sbx/__init__.py b/sbx/__init__.py index a7c13bc..62dd9be 100644 --- a/sbx/__init__.py +++ b/sbx/__init__.py @@ -3,6 +3,7 @@ from sbx.crossq import CrossQ from sbx.ddpg import DDPG from sbx.dqn import DQN +from sbx.per_dqn import PERDQN from sbx.ppo import PPO from sbx.sac import SAC from sbx.td3 import TD3 @@ -26,6 +27,7 @@ def DroQ(*args, **kwargs): "CrossQ", "DDPG", "DQN", + "PERDQN", "PPO", "SAC", "TD3", diff --git a/sbx/common/prioritized_replay_buffer.py b/sbx/common/prioritized_replay_buffer.py new file mode 100644 index 0000000..f8105ef --- /dev/null +++ b/sbx/common/prioritized_replay_buffer.py @@ -0,0 +1,395 @@ +""" +Segment tree implementation taken from Stable Baselines 2: +https://github.com/hill-a/stable-baselines/blob/master/stable_baselines/common/segment_tree.py + +Notable differences: +- This implementation uses numpy arrays to store the values (faster initialization) +- We don't use a special function to have unique indices (no significant performance difference found) +""" + +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch as th +from gymnasium import spaces +from stable_baselines3.common.buffers import ReplayBuffer +from stable_baselines3.common.type_aliases import ReplayBufferSamples +from stable_baselines3.common.vec_env.vec_normalize import VecNormalize + + +class SegmentTree: + def __init__(self, capacity: int, reduce_op: Callable, neutral_element: float) -> None: + """ + Build a Segment Tree data structure. 
+ + https://en.wikipedia.org/wiki/Segment_tree + + Can be used as regular array that supports Index arrays, but with two + important differences: + + a) setting item's value is slightly slower. + It is O(log capacity) instead of O(1). + b) user has access to an efficient ( O(log segment size) ) + `reduce` operation which reduces `operation` over + a contiguous subsequence of items in the array. + + :param capacity: Total size of the array - must be a power of two. + :param reduce_op: Operation for combining elements (eg. sum, max) must form a + mathematical group together with the set of possible values for array elements (i.e. be associative) + :param neutral_element: Neutral element for the operation above. eg. float('-inf') for max and 0 for sum. + """ + assert capacity > 0 and capacity & (capacity - 1) == 0, f"Capacity must be positive and a power of 2, not {capacity}" + self._capacity = capacity + # First index is the root, leaf nodes are in [capacity, 2 * capacity - 1]. + # For each parent node i, left child has index [2 * i], right child [2 * i + 1] + self._values = np.full(2 * capacity, neutral_element) + self._reduce_op = reduce_op + self.neutral_element = neutral_element + + def _reduce_helper(self, start: int, end: int, node: int, node_start: int, node_end: int) -> float: + """ + Query the value of the segment tree for the given range + + :param start: start of the range + :param end: end of the range + :param node: current node in the segment tree + :param node_start: start of the range represented by the current node + :param node_end: end of the range represented by the current node + :return: result of reducing ``self.reduce_op`` over the specified range of array elements. 
+ """ + if start == node_start and end == node_end: + return self._values[node] + mid = (node_start + node_end) // 2 + if end <= mid: + return self._reduce_helper(start, end, 2 * node, node_start, mid) + else: + if mid + 1 <= start: + return self._reduce_helper(start, end, 2 * node + 1, mid + 1, node_end) + else: + return self._reduce_op( + self._reduce_helper(start, mid, 2 * node, node_start, mid), + self._reduce_helper(mid + 1, end, 2 * node + 1, mid + 1, node_end), + ) + + def reduce(self, start: int = 0, end: Optional[int] = None) -> float: + """ + Returns result of applying ``self.reduce_op`` + to a contiguous subsequence of the array. + + .. code-block:: python + + self.reduce_op(arr[start], operation(arr[start+1], operation(... arr[end]))) + + :param start: beginning of the subsequence + :param end: end of the subsequences + :return: result of reducing ``self.reduce_op`` over the specified range of array elements. + """ + if end is None: + end = self._capacity + if end < 0: + end += self._capacity + end -= 1 + return self._reduce_helper(start, end, 1, 0, self._capacity - 1) + + def __setitem__(self, idx: np.ndarray, val: np.ndarray) -> None: + """ + Set the value at index `idx` to `val` + + :param idx: index of the value to be updated + :param val: new value + """ + # assert np.all(0 <= idx < self._capacity), f"Trying to set item outside capacity: {idx}" + # Indices of the leafs + indices = idx + self._capacity + # Update the leaf nodes and then the related nodes + self._values[indices] = val + if isinstance(indices, int): + indices = np.array([indices]) + # Go up one level in the tree and remove duplicate indices + indices = np.unique(indices // 2) + while len(indices) > 1 or indices[0] > 0: + # As long as there are non-zero indices, update the corresponding values + self._values[indices] = self._reduce_op(self._values[2 * indices], self._values[2 * indices + 1]) + # Go up one level in the tree and remove duplicate indices + indices = np.unique(indices // 
2) + + def __getitem__(self, idx: np.ndarray) -> np.ndarray: + """ + Get the value(s) at index `idx` + """ + assert np.max(idx) < self._capacity, f"Index must be less than capacity, got {np.max(idx)} >= {self._capacity}" + assert 0 <= np.min(idx) + return self._values[self._capacity + idx] + + +class SumSegmentTree(SegmentTree): + """ + A Segment Tree data structure where each node contains the sum of the + values in its leaf nodes. Can be used as a Sum Tree for priorities. + """ + + def __init__(self, capacity: int) -> None: + super().__init__(capacity=capacity, reduce_op=np.add, neutral_element=0.0) + + def sum(self, start: int = 0, end: Optional[int] = None) -> float: + """ + Returns arr[start] + ... + arr[end] + + :param start: start position of the reduction (must be >= 0) + :param end: end position of the reduction (must be < len(arr), can be None for len(arr) - 1) + :return: reduction of SumSegmentTree + """ + return super().reduce(start, end) + + def find_prefixsum_idx(self, prefixsum: np.ndarray) -> np.ndarray: + """ + Find the highest index `i` in the array such that + sum(arr[0] + arr[1] + ... + arr[i - 1]) <= prefixsum for each entry in prefixsum + + if array values are probabilities, this function + allows to sample indices according to the discrete + probability efficiently. 
+ + :param prefixsum: float upper bounds on the sum of array prefix + :return: highest indices satisfying the prefixsum constraint + """ + if isinstance(prefixsum, float): + prefixsum = np.array([prefixsum]) + assert 0 <= np.min(prefixsum) + # assert np.max(prefixsum) <= self.sum() + 1e-5 + + indices = np.ones(len(prefixsum), dtype=int) + should_continue = np.ones(len(prefixsum), dtype=bool) + + while np.any(should_continue): # while not all nodes are leafs + indices[should_continue] = 2 * indices[should_continue] + prefixsum_new = np.where( + self._values[indices] <= prefixsum, + prefixsum - self._values[indices], + prefixsum, + ) + # Prepare update of prefixsum for all right children + indices = np.where( + np.logical_or(self._values[indices] > prefixsum, np.logical_not(should_continue)), + indices, + indices + 1, + ) + # Select child node for non-leaf nodes + prefixsum = prefixsum_new + # Update prefixsum + should_continue = indices < self._capacity + # Collect leafs + return indices - self._capacity + + +class MinSegmentTree(SegmentTree): + """ + A Segment Tree data structure where each node contains the minimum of the + values in its leaf nodes. Can be used as a Min Tree for priorities. + """ + + def __init__(self, capacity: int) -> None: + super().__init__(capacity=capacity, reduce_op=np.minimum, neutral_element=float("inf")) + + def min(self, start: int = 0, end: Optional[int] = None) -> float: + """ + Returns min(arr[start], ..., arr[end]) + + :param start: start position of the reduction (must be >= 0) + :param end: end position of the reduction (must be < len(arr), can be None for len(arr) - 1) + :return: reduction of MinSegmentTree + """ + return super().reduce(start, end) + + +class PrioritizedReplayBuffer(ReplayBuffer): + """ + Prioritized Replay Buffer (proportional priorities version). 
+ Paper: https://arxiv.org/abs/1511.05952 + This code is inspired by: https://github.com/Howuhh/prioritized_experience_replay + + :param buffer_size: Max number of element in the buffer + :param observation_space: Observation space + :param action_space: Action space + :param device: PyTorch device + :param n_envs: Number of parallel environments + :param alpha: How much prioritization is used (0 - no prioritization aka uniform case, 1 - full prioritization) + :param beta: To what degree to use importance weights (0 - no corrections, 1 - full correction) + :param final_beta: Value of beta at the end of training. + Linear annealing is used to interpolate between initial value of beta and final beta. + :param min_priority: Minimum priority, prevents zero probabilities, so that all samples + always have a non-zero probability to be sampled. + """ + + def __init__( + self, + buffer_size: int, + observation_space: spaces.Space, + action_space: spaces.Space, + device: Union[th.device, str] = "auto", + n_envs: int = 1, + alpha: float = 0.5, + optimize_memory_usage: bool = False, + min_priority: float = 1e-6, + ): + super().__init__(buffer_size, observation_space, action_space, device, n_envs) + + # TODO: check if we can support optimize_memory_usage + assert optimize_memory_usage is False, "PrioritizedReplayBuffer doesn't support optimize_memory_usage=True" + + # TODO: add support for multi env + # assert n_envs == 1, "PrioritizedReplayBuffer doesn't support n_envs > 1" + + # Find the next power of 2 for the buffer size + power_of_two = int(np.ceil(np.log2(buffer_size))) + tree_capacity = 2**power_of_two + + self._min_priority = min_priority + self._max_priority = 1.0 + + self._alpha = alpha + + self._sum_tree = SumSegmentTree(tree_capacity) + self._min_tree = MinSegmentTree(tree_capacity) + # Flatten the indices from the buffer to store them in the sum tree + # Replay buffer: (idx, env_idx) + # Sum tree: idx * self.n_envs + env_idx + self.env_offsets = 
np.arange(self.n_envs) + + def add( + self, + obs: np.ndarray, + next_obs: np.ndarray, + action: np.ndarray, + reward: np.ndarray, + done: np.ndarray, + infos: List[Dict[str, Any]], + ) -> None: + """ + Add a new transition to the buffer. + + :param obs: Starting observation of the transition to be stored. + :param next_obs: Destination observation of the transition to be stored. + :param action: Action performed in the transition to be stored. + :param reward: Reward received in the transition to be stored. + :param done: Whether the episode was finished after the transition to be stored. + :param infos: Eventual information given by the environment. + """ + # Store transition index with maximum priority in sum tree + self._sum_tree[self.pos * self.n_envs + self.env_offsets] = self._max_priority**self._alpha + self._min_tree[self.pos * self.n_envs + self.env_offsets] = self._max_priority**self._alpha + + # Store transition in the buffer + super().add(obs, next_obs, action, reward, done, infos) + + def sample(self, batch_size: int, beta: float, env: Optional[VecNormalize] = None) -> ReplayBufferSamples: # type: ignore[override] + """ + Sample elements from the prioritized replay buffer. + + :param batch_size: Number of elements to sample + :param env: associated gym VecEnv + to normalize the observations/rewards when sampling + :return: a batch of sampled experiences from the buffer. + """ + # assert self.buffer_size >= batch_size, "The buffer contains fewer samples than the batch size requires." 
+ + # TODO: check how things are sampled in the original implementation + # Note: should be the same as self._sum_tree.sum(0, buffer_size - 1) + total_priorities = self._sum_tree.sum() + min_priority = self._min_tree.min() + buffer_size = self.size() * self.n_envs + + # leaf_nodes_indices = self._sample_proportional(batch_size, min_priority, total_priorities) + # TODO: check that the sampled indices are valid: + # - they should be in the range [0, buffer_size) + leaf_nodes_indices = self._stratified_sampling(batch_size, min_priority, total_priorities) + # debug: uniform sampling + # leaf_nodes_indices = np.random.randint(0, buffer_size, size=batch_size) + + # Convert the leaf nodes indices to buffer indices + # Replay buffer: (idx, env_idx) + # Sum tree: idx * self.n_envs + env_idx + buffer_indices = leaf_nodes_indices // self.n_envs + env_indices = leaf_nodes_indices % self.n_envs + + # assert np.all(buffer_indices < self.size()), f"Invalid indices: {buffer_indices} >= {self.size()}" + + # probability of sampling transition i as P(i) = p_i^alpha / \sum_{k} p_k^alpha + # where p_i > 0 is the priority of transition i. + probabilities = self._sum_tree[leaf_nodes_indices] / total_priorities + + # Importance sampling weights. + # All weights w_i were scaled so that max_i w_i = 1. 
+ min_probability = min_priority / total_priorities + # FIXME: self.size() doesn't take into account the number of envs + max_weight = (min_probability * buffer_size) ** (-beta) + weights = (probabilities * buffer_size) ** (-beta) / max_weight + # weights = (probabilities * self.size()) ** (-beta) + # weights = weights / weights.max() + + # env_indices = np.random.randint(0, high=self.n_envs, size=(batch_size,)) + # env_indices = np.zeros(batch_size, dtype=np.uint32) + next_obs = self._normalize_obs(self.next_observations[buffer_indices, env_indices, :], env) + + batch = ( + self._normalize_obs(self.observations[buffer_indices, env_indices, :], env), + self.actions[buffer_indices, env_indices, :], + next_obs, + # Only use dones that are not due to timeouts + # deactivated by default (timeouts is initialized as an array of False) + (self.dones[buffer_indices, env_indices] * (1 - self.timeouts[buffer_indices, env_indices])).reshape(-1, 1), + self._normalize_reward(self.rewards[buffer_indices, env_indices].reshape(-1, 1), env), + weights, + ) + return ReplayBufferSamples(*tuple(map(self.to_torch, batch)), leaf_nodes_indices) # type: ignore[arg-type,call-arg] + + def _sample_proportional(self, batch_size: int, min_priority: float, total_priorities: float) -> np.ndarray: + """ + Sample a batch of leaf nodes indices using the proportional prioritization strategy. + In other words, the probability of sampling a transition is proportional to its priority. 
+ + :param batch_size: Number of elements to sample + :return: Indices of the sampled leaf nodes + """ + # TODO: double check if this is correct + # total = self._sum_tree.sum(0, self.size() - 1) + # priorities_sum = np.random.random(size=batch_size) * total_priorities + priorities_sum = np.random.uniform(min_priority, total_priorities, size=batch_size) + return self._sum_tree.find_prefixsum_idx(priorities_sum) + + def _stratified_sampling(self, batch_size: int, min_priority: float, total_priorities: float) -> np.ndarray: + """ + To sample a minibatch of size k, the range [min_priority, total_priorities] is divided equally into k ranges. + Next, a value is uniformly sampled from each range. Finally, the transitions that correspond + to each of these sampled values are retrieved from the tree. + + :param batch_size: Number of elements to sample + :param total_priorities: Sum of all priorities in the sum tree. + :return: Indices of the sampled leaf nodes + """ + segments = np.linspace(min_priority, total_priorities, num=batch_size + 1) + desired_priorities = np.random.uniform(segments[:-1], segments[1:], size=batch_size) + return self._sum_tree.find_prefixsum_idx(desired_priorities) + + # def update_priorities(self, indices: np.ndarray, priorities: np.ndarray) -> None: + def update_priorities(self, leaf_nodes_indices: np.ndarray, priorities: np.ndarray) -> None: + """ + Update priorities of sampled transitions. + + :param leaf_nodes_indices: Indices of the sampled transitions. + :param priorities: New priorities, td error in the case of + proportional prioritized replay buffer. 
+ """ + # TODO: double check that all samples are updated + # priorities = np.abs(td_errors) + self.min_priority + priorities += self._min_priority + # assert len(indices) == len(priorities) + assert np.min(priorities) > 0 + assert np.min(leaf_nodes_indices) >= 0 + # assert np.max(leaf_nodes_indices) < self.buffer_size * self.n_envs + # TODO: check if we need to add the min_priority here + # priorities = (np.abs(td_errors) + self.min_priority) ** self.alpha + self._sum_tree[leaf_nodes_indices] = priorities**self._alpha + self._min_tree[leaf_nodes_indices] = priorities**self._alpha + # Update max priority for new samples + self._max_priority = max(self._max_priority, np.max(priorities)) diff --git a/sbx/common/type_aliases.py b/sbx/common/type_aliases.py index f014a1f..bf23356 100644 --- a/sbx/common/type_aliases.py +++ b/sbx/common/type_aliases.py @@ -1,4 +1,4 @@ -from typing import NamedTuple +from typing import NamedTuple, Optional, Union import flax import numpy as np @@ -19,3 +19,5 @@ class ReplayBufferSamplesNp(NamedTuple): next_observations: np.ndarray dones: np.ndarray rewards: np.ndarray + weights: Union[np.ndarray, float] = 1.0 + leaf_nodes_indices: Optional[np.ndarray] = None diff --git a/sbx/crossq/crossq.py b/sbx/crossq/crossq.py index f888672..f8a0c20 100644 --- a/sbx/crossq/crossq.py +++ b/sbx/crossq/crossq.py @@ -211,7 +211,7 @@ def train(self, gradient_steps: int, batch_size: int) -> None: self.policy.actor_state, self.ent_coef_state, self.key, - (actor_loss_value, qf_loss_value, ent_coef_value), + (actor_loss_value, qf_loss_value, ent_coef_loss), ) = self._train( self.gamma, self.target_entropy, @@ -224,11 +224,14 @@ def train(self, gradient_steps: int, batch_size: int) -> None: self.ent_coef_state, self.key, ) + ent_coef_value = self.ent_coef_state.apply_fn({"params": self.ent_coef_state.params}) self._n_updates += gradient_steps self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") self.logger.record("train/actor_loss", 
actor_loss_value.item()) self.logger.record("train/critic_loss", qf_loss_value.item()) self.logger.record("train/ent_coef", ent_coef_value.item()) + if isinstance(self.ent_coef, EntropyCoef): + self.logger.record("train/ent_coef_loss", ent_coef_loss.item()) @staticmethod @jax.jit diff --git a/sbx/dqn/dqn.py b/sbx/dqn/dqn.py index 852b823..cdae123 100644 --- a/sbx/dqn/dqn.py +++ b/sbx/dqn/dqn.py @@ -6,6 +6,7 @@ import jax.numpy as jnp import numpy as np import optax +from stable_baselines3.common.buffers import ReplayBuffer from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import get_linear_fn @@ -41,6 +42,8 @@ def __init__( # max_grad_norm: float = 10, train_freq: Union[int, Tuple[int, str]] = 4, gradient_steps: int = 1, + replay_buffer_class: Optional[Type[ReplayBuffer]] = None, + replay_buffer_kwargs: Optional[Dict[str, Any]] = None, tensorboard_log: Optional[str] = None, policy_kwargs: Optional[Dict[str, Any]] = None, verbose: int = 0, @@ -59,6 +62,8 @@ def __init__( gamma=gamma, train_freq=train_freq, gradient_steps=gradient_steps, + replay_buffer_class=replay_buffer_class, + replay_buffer_kwargs=replay_buffer_kwargs, policy_kwargs=policy_kwargs, tensorboard_log=tensorboard_log, verbose=verbose, @@ -130,11 +135,12 @@ def learn( progress_bar=progress_bar, ) - def train(self, batch_size, gradient_steps): + def train(self, gradient_steps: int, batch_size: int) -> None: + assert self.replay_buffer is not None # Sample all at once for efficiency (so we can jit the for loop) data = self.replay_buffer.sample(batch_size * gradient_steps, env=self._vec_normalize_env) # Convert to numpy - data = ReplayBufferSamplesNp( + data = ReplayBufferSamplesNp( # type: ignore[assignment] data.observations.numpy(), # Convert to int64 data.actions.long().numpy(), @@ -222,7 +228,9 @@ def _on_step(self) -> None: This method is called in ``collect_rollouts()`` after each step in the environment. 
""" self._n_calls += 1 - if self._n_calls % self.target_update_interval == 0: + # Account for multiple environments + # each call to step() corresponds to n_envs transitions + if self._n_calls % max(self.target_update_interval // self.n_envs, 1) == 0: self.policy.qf_state = DQN.soft_update(self.tau, self.policy.qf_state) self.exploration_rate = self.exploration_schedule(self._current_progress_remaining) diff --git a/sbx/dqn/policies.py b/sbx/dqn/policies.py index 4cff77b..e1089be 100644 --- a/sbx/dqn/policies.py +++ b/sbx/dqn/policies.py @@ -69,6 +69,7 @@ def __init__( normalize_images: bool = True, optimizer_class: Callable[..., optax.GradientTransformation] = optax.adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, + max_grad_norm: float = 10.0, ): super().__init__( observation_space, @@ -85,6 +86,7 @@ def __init__( else: self.n_units = 256 self.activation_fn = activation_fn + self.max_grad_norm = max_grad_norm def build(self, key: jax.Array, lr_schedule: Schedule) -> jax.Array: key, qf_key = jax.random.split(key, 2) @@ -101,9 +103,12 @@ def build(self, key: jax.Array, lr_schedule: Schedule) -> jax.Array: apply_fn=self.qf.apply, params=self.qf.init({"params": qf_key}, obs), target_params=self.qf.init({"params": qf_key}, obs), - tx=self.optimizer_class( - learning_rate=lr_schedule(1), # type: ignore[call-arg] - **self.optimizer_kwargs, + tx=optax.chain( + optax.clip_by_global_norm(self.max_grad_norm), + self.optimizer_class( + learning_rate=lr_schedule(1), # type: ignore[call-arg] + **self.optimizer_kwargs, + ), ), ) @@ -140,9 +145,12 @@ def build(self, key: jax.Array, lr_schedule: Schedule) -> jax.Array: apply_fn=self.qf.apply, params=self.qf.init({"params": qf_key}, obs), target_params=self.qf.init({"params": qf_key}, obs), - tx=self.optimizer_class( - learning_rate=lr_schedule(1), # type: ignore[call-arg] - **self.optimizer_kwargs, + tx=optax.chain( + optax.clip_by_global_norm(self.max_grad_norm), + self.optimizer_class( + learning_rate=lr_schedule(1), # 
type: ignore[call-arg] + **self.optimizer_kwargs, + ), ), ) self.qf.apply = jax.jit(self.qf.apply) # type: ignore[method-assign] diff --git a/sbx/per_dqn/__init__.py b/sbx/per_dqn/__init__.py new file mode 100644 index 0000000..188179e --- /dev/null +++ b/sbx/per_dqn/__init__.py @@ -0,0 +1,3 @@ +from sbx.per_dqn.per_dqn import PERDQN + +__all__ = ["PERDQN"] diff --git a/sbx/per_dqn/per_dqn.py b/sbx/per_dqn/per_dqn.py new file mode 100644 index 0000000..9aab4cc --- /dev/null +++ b/sbx/per_dqn/per_dqn.py @@ -0,0 +1,239 @@ +from typing import Any, ClassVar, Dict, Optional, Tuple, Type, Union + +import jax +import jax.numpy as jnp +import numpy as np +import optax +from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule +from stable_baselines3.common.utils import get_linear_fn + +from sbx.common.prioritized_replay_buffer import PrioritizedReplayBuffer +from sbx.common.type_aliases import ReplayBufferSamplesNp, RLTrainState +from sbx.dqn import DQN +from sbx.dqn.policies import CNNPolicy, DQNPolicy + + +class PERDQN(DQN): + """ + DQN with Prioritized Experience Replay (PER). 
+ """ + + policy_aliases: ClassVar[Dict[str, Type[DQNPolicy]]] = { # type: ignore[assignment] + "MlpPolicy": DQNPolicy, + "CnnPolicy": CNNPolicy, + } + # Linear schedule will be defined in `_setup_model()` + exploration_schedule: Schedule + policy: DQNPolicy + replay_buffer: PrioritizedReplayBuffer + + def __init__( + self, + policy, + env: Union[GymEnv, str], + learning_rate: Union[float, Schedule] = 1e-4, + buffer_size: int = 1_000_000, # 1e6 + learning_starts: int = 100, + batch_size: int = 32, + tau: float = 1.0, + gamma: float = 0.99, + target_update_interval: int = 1000, + exploration_fraction: float = 0.1, + exploration_initial_eps: float = 1.0, + exploration_final_eps: float = 0.05, + initial_beta: float = 0.4, + final_beta: float = 1.0, + optimize_memory_usage: bool = False, # Note: unused but to match SB3 API + # max_grad_norm: float = 10, + train_freq: Union[int, Tuple[int, str]] = 4, + gradient_steps: int = 1, + # replay_buffer_class: Optional[Type[ReplayBuffer]] = None, + replay_buffer_kwargs: Optional[Dict[str, Any]] = None, + tensorboard_log: Optional[str] = None, + policy_kwargs: Optional[Dict[str, Any]] = None, + verbose: int = 0, + seed: Optional[int] = None, + device: str = "auto", + _init_setup_model: bool = True, + ) -> None: + super().__init__( + policy=policy, + env=env, + learning_rate=learning_rate, + buffer_size=buffer_size, + learning_starts=learning_starts, + batch_size=batch_size, + tau=tau, + gamma=gamma, + target_update_interval=target_update_interval, + exploration_fraction=exploration_fraction, + exploration_initial_eps=exploration_initial_eps, + exploration_final_eps=exploration_final_eps, + optimize_memory_usage=optimize_memory_usage, + train_freq=train_freq, + gradient_steps=gradient_steps, + replay_buffer_class=PrioritizedReplayBuffer, + replay_buffer_kwargs=replay_buffer_kwargs, + policy_kwargs=policy_kwargs, + tensorboard_log=tensorboard_log, + verbose=verbose, + seed=seed, + _init_setup_model=_init_setup_model, + ) + + 
self._inital_beta = initial_beta + self._final_beta = final_beta + self.beta_schedule = get_linear_fn( + self._inital_beta, + self._final_beta, + end_fraction=1.0, + ) + + @property + def beta(self) -> float: + return 0.5 # same as Dopamine RL + # Linear schedule + # return self.beta_schedule(self._current_progress_remaining) + + def learn( + self, + total_timesteps: int, + callback: MaybeCallback = None, + log_interval: int = 4, + tb_log_name: str = "PERDQN", + reset_num_timesteps: bool = True, + progress_bar: bool = False, + ): + return super().learn( + total_timesteps=total_timesteps, + callback=callback, + log_interval=log_interval, + tb_log_name=tb_log_name, + reset_num_timesteps=reset_num_timesteps, + progress_bar=progress_bar, + ) + + def train(self, gradient_steps: int, batch_size: int) -> None: + assert self.replay_buffer is not None + # Sample all at once for efficiency (so we can jit the for loop) + th_data = self.replay_buffer.sample(batch_size * gradient_steps, self.beta, env=self._vec_normalize_env) + # Convert to numpy + data = ReplayBufferSamplesNp( + th_data.observations.numpy(), + # Convert to int64 + th_data.actions.long().numpy(), + th_data.next_observations.numpy(), + th_data.dones.numpy().flatten(), + th_data.rewards.numpy().flatten(), + th_data.weights.numpy().flatten(), # type: ignore[union-attr] + th_data.leaf_nodes_indices, + ) + # Pre compute the slice indices + # otherwise jax will complain + indices = jnp.arange(len(data.dones)).reshape(gradient_steps, batch_size) + + update_carry = { + "qf_state": self.policy.qf_state, + "gamma": self.gamma, + "data": data, + "indices": indices, + "info": { + "critic_loss": jnp.array([0.0]), + "qf_mean_value": jnp.array([0.0]), + "priorities": jnp.zeros_like(data.rewards), + }, + } + + # jit the loop similar to https://github.com/Howuhh/sac-n-jax + # we use scan to be able to play with unroll parameter + update_carry, _ = jax.lax.scan( + self._train, + update_carry, + indices, + unroll=1, + ) + + 
self.policy.qf_state = update_carry["qf_state"] + qf_loss_value = update_carry["info"]["critic_loss"] + qf_mean_value = update_carry["info"]["qf_mean_value"] / gradient_steps + priorities = update_carry["info"]["priorities"] + + # Update priorities, they will be proportional to the td error + # Note: compared to the original implementation, we update + # the priorities after all the gradient steps + assert data.leaf_nodes_indices is not None + self.replay_buffer.update_priorities(data.leaf_nodes_indices, priorities) + + self._n_updates += gradient_steps + self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") + self.logger.record("train/critic_loss", qf_loss_value.item()) + self.logger.record("train/qf_mean_value", qf_mean_value.item()) + self.logger.record("train/beta", self.beta) + self.logger.record("train/min_priority", priorities.min().item()) + self.logger.record("train/max_priority", priorities.max().item()) + self.logger.record("train/mean_priority", priorities.mean().item()) + + @staticmethod + @jax.jit + def update_qnetwork( + gamma: float, + qf_state: RLTrainState, + observations: np.ndarray, + replay_actions: np.ndarray, + next_observations: np.ndarray, + rewards: np.ndarray, + dones: np.ndarray, + sampling_weights: np.ndarray, + ): + # Compute the next Q-values using the target network + qf_next_values = qf_state.apply_fn(qf_state.target_params, next_observations) + + # Follow greedy policy: use the one with the highest value + next_q_values = qf_next_values.max(axis=1) + # Avoid potential broadcast issue + next_q_values = next_q_values.reshape(-1, 1) + + # shape is (batch_size, 1) + target_q_values = rewards.reshape(-1, 1) + (1 - dones.reshape(-1, 1)) * gamma * next_q_values + + # Special case when using PrioritizedReplayBuffer (PER) + def weighted_huber_loss(params): + # Get current Q-values estimates + current_q_values = qf_state.apply_fn(params, observations) + # Retrieve the q-values for the actions from the replay buffer + 
current_q_values = jnp.take_along_axis(current_q_values, replay_actions, axis=1) + # TD error in absolute value, to update priorities + priorities = jnp.abs(current_q_values - target_q_values) + # Weighted Huber loss using importance sampling weights + loss = (sampling_weights * optax.huber_loss(current_q_values, target_q_values)).mean() + return loss, (current_q_values.mean(), priorities.flatten()) + + (qf_loss_value, (qf_mean_value, priorities)), grads = jax.value_and_grad(weighted_huber_loss, has_aux=True)( + qf_state.params + ) + qf_state = qf_state.apply_gradients(grads=grads) + + return qf_state, (qf_loss_value, qf_mean_value, priorities) + + @staticmethod + @jax.jit + def _train(carry, indices): + data = carry["data"] + + qf_state, (qf_loss_value, qf_mean_value, priorities) = PERDQN.update_qnetwork( + carry["gamma"], + carry["qf_state"], + observations=data.observations[indices], + replay_actions=data.actions[indices], + next_observations=data.next_observations[indices], + rewards=data.rewards[indices], + dones=data.dones[indices], + sampling_weights=data.weights[indices], + ) + + carry["qf_state"] = qf_state + carry["info"]["critic_loss"] += qf_loss_value + carry["info"]["qf_mean_value"] += qf_mean_value + carry["info"]["priorities"] = carry["info"]["priorities"].at[indices].set(priorities) + + return carry, None diff --git a/sbx/sac/sac.py b/sbx/sac/sac.py index e3795cd..1618371 100644 --- a/sbx/sac/sac.py +++ b/sbx/sac/sac.py @@ -211,7 +211,7 @@ def train(self, gradient_steps: int, batch_size: int) -> None: self.policy.actor_state, self.ent_coef_state, self.key, - (actor_loss_value, qf_loss_value, ent_coef_value), + (actor_loss_value, qf_loss_value, ent_coef_loss), ) = self._train( self.gamma, self.tau, @@ -225,11 +225,15 @@ def train(self, gradient_steps: int, batch_size: int) -> None: self.ent_coef_state, self.key, ) + ent_coef_value = self.ent_coef_state.apply_fn({"params": self.ent_coef_state.params}) + self._n_updates += gradient_steps 
self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") self.logger.record("train/actor_loss", actor_loss_value.item()) self.logger.record("train/critic_loss", qf_loss_value.item()) self.logger.record("train/ent_coef", ent_coef_value.item()) + if isinstance(self.ent_coef, EntropyCoef): + self.logger.record("train/ent_coef_loss", ent_coef_loss.item()) @staticmethod @jax.jit diff --git a/sbx/tqc/tqc.py b/sbx/tqc/tqc.py index f723c31..9594cad 100644 --- a/sbx/tqc/tqc.py +++ b/sbx/tqc/tqc.py @@ -216,7 +216,7 @@ def train(self, gradient_steps: int, batch_size: int) -> None: self.policy.actor_state, self.ent_coef_state, self.key, - (qf1_loss_value, qf2_loss_value, actor_loss_value, ent_coef_value), + (qf1_loss_value, qf2_loss_value, actor_loss_value, ent_coef_loss), ) = self._train( self.gamma, self.tau, @@ -232,11 +232,15 @@ def train(self, gradient_steps: int, batch_size: int) -> None: self.ent_coef_state, self.key, ) + ent_coef_value = self.ent_coef_state.apply_fn({"params": self.ent_coef_state.params}) + self._n_updates += gradient_steps self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") self.logger.record("train/actor_loss", actor_loss_value.item()) self.logger.record("train/critic_loss", qf1_loss_value.item()) self.logger.record("train/ent_coef", ent_coef_value.item()) + if isinstance(self.ent_coef, EntropyCoef): + self.logger.record("train/ent_coef_loss", ent_coef_loss.item()) @staticmethod @partial(jax.jit, static_argnames=["n_target_quantiles"]) diff --git a/tests/test_run.py b/tests/test_run.py index 18d6dec..b7272c7 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -8,7 +8,7 @@ from stable_baselines3.common.envs import BitFlippingEnv from stable_baselines3.common.evaluation import evaluate_policy -from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ, DroQ +from sbx import DDPG, DQN, PERDQN, PPO, SAC, TD3, TQC, CrossQ, DroQ def check_save_load(model, model_class, tmp_path): @@ -116,9 +116,9 @@ def 
test_dropout(model_class): model.learn(110) -@pytest.mark.parametrize("model_class", [SAC, TD3, DDPG, DQN, CrossQ]) +@pytest.mark.parametrize("model_class", [SAC, TD3, DDPG, DQN, PERDQN, CrossQ]) def test_policy_kwargs(model_class) -> None: - env_id = "CartPole-v1" if model_class == DQN else "Pendulum-v1" + env_id = "CartPole-v1" if model_class in [DQN, PERDQN] else "Pendulum-v1" model = model_class( "MlpPolicy", @@ -147,8 +147,9 @@ def test_ppo(tmp_path, env_id: str) -> None: check_save_load(model, PPO, tmp_path) -def test_dqn(tmp_path) -> None: - model = DQN( +@pytest.mark.parametrize("model_class", [DQN, PERDQN]) +def test_dqn(tmp_path, model_class) -> None: + model = model_class( "MlpPolicy", "CartPole-v1", verbose=1, @@ -156,7 +157,7 @@ def test_dqn(tmp_path) -> None: target_update_interval=10, ) model.learn(128) - check_save_load(model, DQN, tmp_path) + check_save_load(model, model_class, tmp_path) @pytest.mark.parametrize("replay_buffer_class", [None, HerReplayBuffer])