10 changes: 9 additions & 1 deletion sbx/common/off_policy_algorithm.py
@@ -7,7 +7,7 @@
import optax
from gymnasium import spaces
from stable_baselines3 import HerReplayBuffer
from stable_baselines3.common.buffers import DictReplayBuffer, ReplayBuffer
from stable_baselines3.common.buffers import DictReplayBuffer, NStepReplayBuffer, ReplayBuffer
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
@@ -35,6 +35,7 @@ def __init__(
replay_buffer_class: Optional[type[ReplayBuffer]] = None,
replay_buffer_kwargs: Optional[dict[str, Any]] = None,
optimize_memory_usage: bool = False,
n_steps: int = 1,
policy_kwargs: Optional[dict[str, Any]] = None,
tensorboard_log: Optional[str] = None,
verbose: int = 0,
@@ -63,6 +64,8 @@ def __init__(
gradient_steps=gradient_steps,
replay_buffer_class=replay_buffer_class,
replay_buffer_kwargs=replay_buffer_kwargs,
optimize_memory_usage=optimize_memory_usage,
n_steps=n_steps,
action_noise=action_noise,
use_sde=use_sde,
sde_sample_freq=sde_sample_freq,
@@ -137,6 +140,11 @@ def _setup_model(self) -> None:
if self.replay_buffer_class is None: # type: ignore[has-type]
if isinstance(self.observation_space, spaces.Dict):
self.replay_buffer_class = DictReplayBuffer
assert self.n_steps == 1, "N-step returns are not supported for Dict observation spaces yet."
elif self.n_steps > 1:
self.replay_buffer_class = NStepReplayBuffer
# Add required arguments for computing n-step returns
self.replay_buffer_kwargs.update({"n_steps": self.n_steps, "gamma": self.gamma})
else:
self.replay_buffer_class = ReplayBuffer

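Note on the n-step wiring above: this diff only selects NStepReplayBuffer and forwards n_steps/gamma through replay_buffer_kwargs; the return accumulation itself lives in the SB3 buffer. As a rough, hypothetical sketch of what such a buffer is expected to produce per sampled transition (illustrative helper, not code from this PR or SB3):

def n_step_components(rewards, dones, gamma, n_steps):
    # Accumulate the n-step return and the matching bootstrap discount,
    # truncating at episode end. rewards/dones are the next n_steps
    # transitions starting at the sampled index.
    n_step_return, discount, done = 0.0, 1.0, 0.0
    for reward, terminal in zip(rewards[:n_steps], dones[:n_steps]):
        n_step_return += discount * reward
        discount *= gamma
        if terminal:
            done = 1.0
            break
    return n_step_return, discount, done

# e.g. gamma=0.99, n_steps=3, no termination: discount == 0.99 ** 3 and the
# TD target becomes n_step_return + (1 - done) * discount * Q(s_{t+n}, a').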
1 change: 1 addition & 0 deletions sbx/common/type_aliases.py
@@ -19,3 +19,4 @@ class ReplayBufferSamplesNp(NamedTuple):
next_observations: np.ndarray
dones: np.ndarray
rewards: np.ndarray
discounts: np.ndarray
32 changes: 20 additions & 12 deletions sbx/crossq/crossq.py
@@ -66,6 +66,7 @@ def __init__(
action_noise: Optional[ActionNoise] = None,
replay_buffer_class: Optional[type[ReplayBuffer]] = None,
replay_buffer_kwargs: Optional[dict[str, Any]] = None,
n_steps: int = 1,
ent_coef: Union[str, float] = "auto",
target_entropy: Union[Literal["auto"], float] = "auto",
use_sde: bool = False,
@@ -94,6 +95,7 @@ def __init__(
action_noise=action_noise,
replay_buffer_class=replay_buffer_class,
replay_buffer_kwargs=replay_buffer_kwargs,
n_steps=n_steps,
use_sde=use_sde,
sde_sample_freq=sde_sample_freq,
use_sde_at_warmup=use_sde_at_warmup,
@@ -205,13 +207,20 @@ def train(self, gradient_steps: int, batch_size: int) -> None:
obs = data.observations.numpy()
next_obs = data.next_observations.numpy()

if data.discounts is None:
discounts = np.full((batch_size * gradient_steps,), self.gamma, dtype=np.float32)
else:
# For bootstrapping with n-step returns
discounts = data.discounts.numpy().flatten()

# Convert to numpy
data = ReplayBufferSamplesNp( # type: ignore[assignment]
obs,
data.actions.numpy(),
next_obs,
data.dones.numpy().flatten(),
data.rewards.numpy().flatten(),
discounts,
)

(
@@ -221,7 +230,6 @@ def train(self, gradient_steps: int, batch_size: int) -> None:
self.key,
(actor_loss_value, qf_loss_value, ent_coef_loss_value, ent_coef_value),
) = self._train(
self.gamma,
self.target_entropy,
gradient_steps,
data,
@@ -242,7 +250,6 @@ def train(self, gradient_steps: int, batch_size: int) -> None:
@staticmethod
@jax.jit
def update_critic(
gamma: float,
actor_state: BatchNormTrainState,
qf_state: BatchNormTrainState,
ent_coef_state: TrainState,
@@ -251,6 +258,7 @@ def update_critic(
next_observations: jax.Array,
rewards: jax.Array,
dones: jax.Array,
discounts: jax.Array,
key: jax.Array,
):
key, noise_key, dropout_key_target, dropout_key_current = jax.random.split(key, 4)
Expand Down Expand Up @@ -298,9 +306,9 @@ def mse_loss(
# Compute target q_values
next_q_values = jnp.min(qf_next_values, axis=0)
# td error + entropy term
next_q_values = next_q_values - ent_coef_value * next_log_prob.reshape(-1, 1)
next_q_values = next_q_values - ent_coef_value * next_log_prob[:, None]
# shape is (batch_size, 1)
target_q_values = rewards.reshape(-1, 1) + (1 - dones.reshape(-1, 1)) * gamma * next_q_values
target_q_values = rewards[:, None] + (1 - dones[:, None]) * discounts[:, None] * next_q_values

return 0.5 * ((jax.lax.stop_gradient(target_q_values) - current_q_values) ** 2).mean(axis=1).sum(), state_updates

@@ -399,7 +407,6 @@ def update_actor_and_temperature(
@partial(jax.jit, static_argnames=["cls", "gradient_steps", "policy_delay", "policy_delay_offset"])
def _train(
cls,
gamma: float,
target_entropy: ArrayLike,
gradient_steps: int,
data: ReplayBufferSamplesNp,
@@ -435,24 +442,25 @@ def one_update(i: int, carry: dict[str, Any]) -> dict[str, Any]:
key = carry["key"]
info = carry["info"]
batch_obs = jax.lax.dynamic_slice_in_dim(data.observations, i * batch_size, batch_size)
batch_act = jax.lax.dynamic_slice_in_dim(data.actions, i * batch_size, batch_size)
batch_actions = jax.lax.dynamic_slice_in_dim(data.actions, i * batch_size, batch_size)
batch_next_obs = jax.lax.dynamic_slice_in_dim(data.next_observations, i * batch_size, batch_size)
batch_rew = jax.lax.dynamic_slice_in_dim(data.rewards, i * batch_size, batch_size)
batch_done = jax.lax.dynamic_slice_in_dim(data.dones, i * batch_size, batch_size)
batch_rewards = jax.lax.dynamic_slice_in_dim(data.rewards, i * batch_size, batch_size)
batch_dones = jax.lax.dynamic_slice_in_dim(data.dones, i * batch_size, batch_size)
batch_discounts = jax.lax.dynamic_slice_in_dim(data.discounts, i * batch_size, batch_size)
(
qf_state,
(qf_loss_value, ent_coef_value),
key,
) = cls.update_critic(
gamma,
actor_state,
qf_state,
ent_coef_state,
batch_obs,
batch_act,
batch_actions,
batch_next_obs,
batch_rew,
batch_done,
batch_rewards,
batch_dones,
batch_discounts,
key,
)
# No target q values with CrossQ
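Usage sketch for the new argument (environment and hyperparameters are illustrative only; the constructor signature follows this diff):

from sbx import CrossQ

# n_steps > 1 makes _setup_model pick NStepReplayBuffer and forward
# {"n_steps": 3, "gamma": model.gamma} via replay_buffer_kwargs.
model = CrossQ("MlpPolicy", "Pendulum-v1", n_steps=3, verbose=1)
model.learn(total_timesteps=10_000)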
2 changes: 2 additions & 0 deletions sbx/ddpg/ddpg.py
@@ -29,6 +29,7 @@ def __init__(
action_noise: Optional[ActionNoise] = None,
replay_buffer_class: Optional[type[ReplayBuffer]] = None,
replay_buffer_kwargs: Optional[dict[str, Any]] = None,
n_steps: int = 1,
tensorboard_log: Optional[str] = None,
policy_kwargs: Optional[dict[str, Any]] = None,
verbose: int = 0,
@@ -56,6 +57,7 @@ def __init__(
target_noise_clip=0.0,
replay_buffer_class=replay_buffer_class,
replay_buffer_kwargs=replay_buffer_kwargs,
n_steps=n_steps,
policy_kwargs=policy_kwargs,
tensorboard_log=tensorboard_log,
verbose=verbose,
26 changes: 20 additions & 6 deletions sbx/dqn/dqn.py
@@ -6,6 +6,7 @@
import jax.numpy as jnp
import numpy as np
import optax
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import LinearSchedule

@@ -37,7 +38,10 @@ def __init__(
exploration_fraction: float = 0.1,
exploration_initial_eps: float = 1.0,
exploration_final_eps: float = 0.05,
optimize_memory_usage: bool = False, # Note: unused but to match SB3 API
replay_buffer_class: Optional[type[ReplayBuffer]] = None,
replay_buffer_kwargs: Optional[dict[str, Any]] = None,
optimize_memory_usage: bool = False,
n_steps: int = 1,
# max_grad_norm: float = 10,
train_freq: Union[int, tuple[int, str]] = 4,
gradient_steps: int = 1,
@@ -59,6 +63,10 @@ def __init__(
gamma=gamma,
train_freq=train_freq,
gradient_steps=gradient_steps,
replay_buffer_class=replay_buffer_class,
replay_buffer_kwargs=replay_buffer_kwargs,
optimize_memory_usage=optimize_memory_usage,
n_steps=n_steps,
policy_kwargs=policy_kwargs,
tensorboard_log=tensorboard_log,
verbose=verbose,
@@ -133,6 +141,12 @@ def learn(
def train(self, batch_size, gradient_steps):
# Sample all at once for efficiency (so we can jit the for loop)
data = self.replay_buffer.sample(batch_size * gradient_steps, env=self._vec_normalize_env)

if data.discounts is None:
discounts = np.full((batch_size * gradient_steps,), self.gamma, dtype=np.float32)
else:
# For bootstrapping with n-step returns
discounts = data.discounts.numpy().flatten()
# Convert to numpy
data = ReplayBufferSamplesNp(
data.observations.numpy(),
@@ -141,14 +155,14 @@ def train(self, batch_size, gradient_steps):
data.next_observations.numpy(),
data.dones.numpy().flatten(),
data.rewards.numpy().flatten(),
discounts,
)
# Pre compute the slice indices
# otherwise jax will complain
indices = jnp.arange(len(data.dones)).reshape(gradient_steps, batch_size)

update_carry = {
"qf_state": self.policy.qf_state,
"gamma": self.gamma,
"data": data,
"indices": indices,
"info": {
@@ -178,24 +192,24 @@ def train(self, batch_size, gradient_steps):
@staticmethod
@jax.jit
def update_qnetwork(
gamma: float,
qf_state: RLTrainState,
observations: np.ndarray,
replay_actions: np.ndarray,
next_observations: np.ndarray,
rewards: np.ndarray,
dones: np.ndarray,
discounts: np.ndarray,
):
# Compute the next Q-values using the target network
qf_next_values = qf_state.apply_fn(qf_state.target_params, next_observations)

# Follow greedy policy: use the one with the highest value
next_q_values = qf_next_values.max(axis=1)
# Avoid potential broadcast issue
next_q_values = next_q_values.reshape(-1, 1)
next_q_values = next_q_values[:, None]

# shape is (batch_size, 1)
target_q_values = rewards.reshape(-1, 1) + (1 - dones.reshape(-1, 1)) * gamma * next_q_values
target_q_values = rewards[:, None] + (1 - dones[:, None]) * discounts[:, None] * next_q_values

def huber_loss(params):
# Get current Q-values estimates
@@ -264,13 +278,13 @@ def _train(carry, indices):
data = carry["data"]

qf_state, (qf_loss_value, qf_mean_value) = DQN.update_qnetwork(
carry["gamma"],
carry["qf_state"],
observations=data.observations[indices],
replay_actions=data.actions[indices],
next_observations=data.next_observations[indices],
rewards=data.rewards[indices],
dones=data.dones[indices],
discounts=data.discounts[indices],
)

carry["qf_state"] = qf_state
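The reshaped target above broadcasts a per-sample discount where a scalar gamma used to be; a small numpy sketch of the intended shapes (values made up):

import numpy as np

gamma = 0.99
rewards = np.array([1.0, 0.0, 0.5, 1.0], dtype=np.float32)
dones = np.array([0.0, 0.0, 1.0, 0.0], dtype=np.float32)
# Plain ReplayBuffer: np.full((batch_size,), gamma).
# NStepReplayBuffer: gamma ** k per sample, k <= n_steps (shorter near episode ends).
discounts = np.array([gamma**3, gamma**3, gamma**2, gamma**3], dtype=np.float32)
next_q_values = np.ones((4, 1), dtype=np.float32)

# (batch_size,) -> (batch_size, 1) so every term broadcasts against next_q_values
target_q_values = rewards[:, None] + (1 - dones[:, None]) * discounts[:, None] * next_q_values
assert target_q_values.shape == (4, 1)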
8 changes: 4 additions & 4 deletions sbx/ppo/policies.py
@@ -9,9 +9,9 @@
import numpy as np
import optax
import tensorflow_probability.substrates.jax as tfp
from flax.linen.initializers import constant
from flax.training.train_state import TrainState
from gymnasium import spaces
from jax.nn.initializers import constant
from stable_baselines3.common.type_aliases import Schedule

from sbx.common.policies import BaseJaxPolicy, Flatten
@@ -174,7 +174,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, max_grad_norm: float) ->
obs = jnp.array([self.observation_space.sample()])

if isinstance(self.action_space, spaces.Box):
actor_kwargs = {
actor_kwargs: dict[str, Any] = {
"action_dim": int(np.prod(self.action_space.shape)),
}
elif isinstance(self.action_space, spaces.Discrete):
@@ -184,7 +184,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, max_grad_norm: float) ->
}
elif isinstance(self.action_space, spaces.MultiDiscrete):
assert self.action_space.nvec.ndim == 1, (
f"Only one-dimensional MultiDiscrete action spaces are supported, "
"Only one-dimensional MultiDiscrete action spaces are supported, "
f"but found MultiDiscrete({(self.action_space.nvec).tolist()})."
)
actor_kwargs = {
@@ -232,7 +232,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, max_grad_norm: float) ->

self.vf_state = TrainState.create(
apply_fn=self.vf.apply,
params=self.vf.init({"params": vf_key}, obs),
params=self.vf.init(vf_key, obs),
tx=optax.chain(
optax.clip_by_global_norm(max_grad_norm),
optimizer_class,
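The simplified vf.init call above relies on Flax accepting a bare PRNGKey and treating it as the "params" rng; a toy sketch of that equivalence (assuming standard flax.linen semantics, unrelated to the PPO policy itself):

import flax.linen as nn
import jax
import jax.numpy as jnp

class ValueNet(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(1)(x)

vf = ValueNet()
key = jax.random.PRNGKey(0)
obs = jnp.ones((1, 4))

# Equivalent: a bare key is used as the "params" rng.
params_a = vf.init(key, obs)
params_b = vf.init({"params": key}, obs)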
42 changes: 22 additions & 20 deletions sbx/ppo/ppo.py
@@ -184,8 +184,6 @@ def _setup_model(self) -> None:

self.key = self.policy.build(self.key, self.lr_schedule, self.max_grad_norm)

self.key, ent_key = jax.random.split(self.key, 2)

self.actor = self.policy.actor # type: ignore[assignment]
self.vf = self.policy.vf # type: ignore[assignment]

@@ -236,21 +234,23 @@ def actor_loss(params):
entropy_loss = -jnp.mean(entropy)

total_policy_loss = policy_loss + ent_coef * entropy_loss
return total_policy_loss, ratio
return total_policy_loss, (ratio, policy_loss, entropy_loss)

(pg_loss_value, ratio), grads = jax.value_and_grad(actor_loss, has_aux=True)(actor_state.params)
(pg_loss_value, (ratio, policy_loss, entropy_loss)), grads = jax.value_and_grad(actor_loss, has_aux=True)(
actor_state.params
)
actor_state = actor_state.apply_gradients(grads=grads)

def critic_loss(params):
# Value loss using the TD(gae_lambda) target
vf_values = vf_state.apply_fn(params, observations).flatten()
return ((returns - vf_values) ** 2).mean()
return vf_coef * ((returns - vf_values) ** 2).mean()

vf_loss_value, grads = jax.value_and_grad(critic_loss, has_aux=False)(vf_state.params)
vf_state = vf_state.apply_gradients(grads=grads)

# loss = policy_loss + ent_coef * entropy_loss + vf_coef * value_loss
return (actor_state, vf_state), (pg_loss_value, vf_loss_value, ratio)
return (actor_state, vf_state), (pg_loss_value, policy_loss, entropy_loss, vf_loss_value, ratio)

def train(self) -> None:
"""
@@ -279,18 +279,20 @@ def train(self) -> None:
else:
actions = rollout_data.actions.numpy()

(self.policy.actor_state, self.policy.vf_state), (pg_loss, value_loss, ratio) = self._one_update(
actor_state=self.policy.actor_state,
vf_state=self.policy.vf_state,
observations=rollout_data.observations.numpy(),
actions=actions,
advantages=rollout_data.advantages.numpy(),
returns=rollout_data.returns.numpy(),
old_log_prob=rollout_data.old_log_prob.numpy(),
clip_range=clip_range,
ent_coef=self.ent_coef,
vf_coef=self.vf_coef,
normalize_advantage=self.normalize_advantage,
(self.policy.actor_state, self.policy.vf_state), (pg_loss, policy_loss, entropy_loss, value_loss, ratio) = (
self._one_update(
actor_state=self.policy.actor_state,
vf_state=self.policy.vf_state,
observations=rollout_data.observations.numpy(),
actions=actions,
advantages=rollout_data.advantages.numpy(),
returns=rollout_data.returns.numpy(),
old_log_prob=rollout_data.old_log_prob.numpy(),
clip_range=clip_range,
ent_coef=self.ent_coef,
vf_coef=self.vf_coef,
normalize_advantage=self.normalize_advantage,
)
)

# Calculate approximate form of reverse KL Divergence for adaptive lr
@@ -319,9 +321,9 @@ def train(self) -> None:
)

# Logs
# self.logger.record("train/entropy_loss", np.mean(entropy_losses))
# self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
# TODO: use mean instead of one point
self.logger.record("train/entropy_loss", entropy_loss.item())
self.logger.record("train/policy_gradient_loss", policy_loss.item())
self.logger.record("train/value_loss", value_loss.item())
self.logger.record("train/approx_kl", mean_kl_div)
self.logger.record("train/clip_fraction", mean_clip_fraction)
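The new per-update logging works because actor_loss returns the individual losses as auxiliary outputs; a standalone toy sketch of the has_aux pattern (not the PPO code itself):

import jax
import jax.numpy as jnp

def loss_fn(params):
    policy_loss = jnp.sum(params ** 2)
    entropy_loss = -jnp.mean(params)
    total = policy_loss + 0.01 * entropy_loss
    # Only the first return value is differentiated; the tuple is passed through as aux.
    return total, (policy_loss, entropy_loss)

params = jnp.array([1.0, 2.0])
(total, (policy_loss, entropy_loss)), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)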