Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ lint:
# see https://www.flake8rules.com/
ruff check ${LINT_PATHS} --select=E9,F63,F7,F82 --output-format=full
# exit-zero treats all errors as warnings.
ruff check ${LINT_PATHS} --exit-zero
ruff check ${LINT_PATHS} --exit-zero --output-format=concise

format:
# Sort imports
Expand Down
2 changes: 2 additions & 0 deletions sbx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from sbx.crossq import CrossQ
from sbx.ddpg import DDPG
from sbx.dqn import DQN
from sbx.per_dqn import PERDQN
from sbx.ppo import PPO
from sbx.sac import SAC
from sbx.td3 import TD3
Expand All @@ -26,6 +27,7 @@ def DroQ(*args, **kwargs):
"CrossQ",
"DDPG",
"DQN",
"PERDQN",
"PPO",
"SAC",
"TD3",
Expand Down
395 changes: 395 additions & 0 deletions sbx/common/prioritized_replay_buffer.py

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion sbx/common/type_aliases.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import NamedTuple
from typing import NamedTuple, Optional, Union

import flax
import numpy as np
Expand All @@ -19,3 +19,5 @@ class ReplayBufferSamplesNp(NamedTuple):
next_observations: np.ndarray
dones: np.ndarray
rewards: np.ndarray
weights: Union[np.ndarray, float] = 1.0
leaf_nodes_indices: Optional[np.ndarray] = None
5 changes: 4 additions & 1 deletion sbx/crossq/crossq.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def train(self, gradient_steps: int, batch_size: int) -> None:
self.policy.actor_state,
self.ent_coef_state,
self.key,
(actor_loss_value, qf_loss_value, ent_coef_value),
(actor_loss_value, qf_loss_value, ent_coef_loss),
) = self._train(
self.gamma,
self.target_entropy,
Expand All @@ -224,11 +224,14 @@ def train(self, gradient_steps: int, batch_size: int) -> None:
self.ent_coef_state,
self.key,
)
ent_coef_value = self.ent_coef_state.apply_fn({"params": self.ent_coef_state.params})
self._n_updates += gradient_steps
self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
self.logger.record("train/actor_loss", actor_loss_value.item())
self.logger.record("train/critic_loss", qf_loss_value.item())
self.logger.record("train/ent_coef", ent_coef_value.item())
if isinstance(self.ent_coef, EntropyCoef):
self.logger.record("train/ent_coef_loss", ent_coef_loss.item())

@staticmethod
@jax.jit
Expand Down
14 changes: 11 additions & 3 deletions sbx/dqn/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import jax.numpy as jnp
import numpy as np
import optax
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import get_linear_fn

Expand Down Expand Up @@ -41,6 +42,8 @@ def __init__(
# max_grad_norm: float = 10,
train_freq: Union[int, Tuple[int, str]] = 4,
gradient_steps: int = 1,
replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
tensorboard_log: Optional[str] = None,
policy_kwargs: Optional[Dict[str, Any]] = None,
verbose: int = 0,
Expand All @@ -59,6 +62,8 @@ def __init__(
gamma=gamma,
train_freq=train_freq,
gradient_steps=gradient_steps,
replay_buffer_class=replay_buffer_class,
replay_buffer_kwargs=replay_buffer_kwargs,
policy_kwargs=policy_kwargs,
tensorboard_log=tensorboard_log,
verbose=verbose,
Expand Down Expand Up @@ -130,11 +135,12 @@ def learn(
progress_bar=progress_bar,
)

def train(self, batch_size, gradient_steps):
def train(self, gradient_steps: int, batch_size: int) -> None:
assert self.replay_buffer is not None
# Sample all at once for efficiency (so we can jit the for loop)
data = self.replay_buffer.sample(batch_size * gradient_steps, env=self._vec_normalize_env)
# Convert to numpy
data = ReplayBufferSamplesNp(
data = ReplayBufferSamplesNp( # type: ignore[assignment]
data.observations.numpy(),
# Convert to int64
data.actions.long().numpy(),
Expand Down Expand Up @@ -222,7 +228,9 @@ def _on_step(self) -> None:
This method is called in ``collect_rollouts()`` after each step in the environment.
"""
self._n_calls += 1
if self._n_calls % self.target_update_interval == 0:
# Account for multiple environments
# each call to step() corresponds to n_envs transitions
if self._n_calls % max(self.target_update_interval // self.n_envs, 1) == 0:
self.policy.qf_state = DQN.soft_update(self.tau, self.policy.qf_state)

self.exploration_rate = self.exploration_schedule(self._current_progress_remaining)
Expand Down
20 changes: 14 additions & 6 deletions sbx/dqn/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def __init__(
normalize_images: bool = True,
optimizer_class: Callable[..., optax.GradientTransformation] = optax.adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
max_grad_norm: float = 10.0,
):
super().__init__(
observation_space,
Expand All @@ -85,6 +86,7 @@ def __init__(
else:
self.n_units = 256
self.activation_fn = activation_fn
self.max_grad_norm = max_grad_norm

def build(self, key: jax.Array, lr_schedule: Schedule) -> jax.Array:
key, qf_key = jax.random.split(key, 2)
Expand All @@ -101,9 +103,12 @@ def build(self, key: jax.Array, lr_schedule: Schedule) -> jax.Array:
apply_fn=self.qf.apply,
params=self.qf.init({"params": qf_key}, obs),
target_params=self.qf.init({"params": qf_key}, obs),
tx=self.optimizer_class(
learning_rate=lr_schedule(1), # type: ignore[call-arg]
**self.optimizer_kwargs,
tx=optax.chain(
optax.clip_by_global_norm(self.max_grad_norm),
self.optimizer_class(
learning_rate=lr_schedule(1), # type: ignore[call-arg]
**self.optimizer_kwargs,
),
),
)

Expand Down Expand Up @@ -140,9 +145,12 @@ def build(self, key: jax.Array, lr_schedule: Schedule) -> jax.Array:
apply_fn=self.qf.apply,
params=self.qf.init({"params": qf_key}, obs),
target_params=self.qf.init({"params": qf_key}, obs),
tx=self.optimizer_class(
learning_rate=lr_schedule(1), # type: ignore[call-arg]
**self.optimizer_kwargs,
tx=optax.chain(
optax.clip_by_global_norm(self.max_grad_norm),
self.optimizer_class(
learning_rate=lr_schedule(1), # type: ignore[call-arg]
**self.optimizer_kwargs,
),
),
)
self.qf.apply = jax.jit(self.qf.apply) # type: ignore[method-assign]
Expand Down
3 changes: 3 additions & 0 deletions sbx/per_dqn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from sbx.per_dqn.per_dqn import PERDQN

__all__ = ["PERDQN"]
Loading