diff --git a/sbx/common/off_policy_algorithm.py b/sbx/common/off_policy_algorithm.py index ae205a3..74b6c64 100644 --- a/sbx/common/off_policy_algorithm.py +++ b/sbx/common/off_policy_algorithm.py @@ -7,7 +7,7 @@ import optax from gymnasium import spaces from stable_baselines3 import HerReplayBuffer -from stable_baselines3.common.buffers import DictReplayBuffer, ReplayBuffer +from stable_baselines3.common.buffers import DictReplayBuffer, NStepReplayBuffer, ReplayBuffer from stable_baselines3.common.noise import ActionNoise from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm from stable_baselines3.common.policies import BasePolicy @@ -35,6 +35,7 @@ def __init__( replay_buffer_class: Optional[type[ReplayBuffer]] = None, replay_buffer_kwargs: Optional[dict[str, Any]] = None, optimize_memory_usage: bool = False, + n_steps: int = 1, policy_kwargs: Optional[dict[str, Any]] = None, tensorboard_log: Optional[str] = None, verbose: int = 0, @@ -63,6 +64,8 @@ def __init__( gradient_steps=gradient_steps, replay_buffer_class=replay_buffer_class, replay_buffer_kwargs=replay_buffer_kwargs, + optimize_memory_usage=optimize_memory_usage, + n_steps=n_steps, action_noise=action_noise, use_sde=use_sde, sde_sample_freq=sde_sample_freq, @@ -137,6 +140,11 @@ def _setup_model(self) -> None: if self.replay_buffer_class is None: # type: ignore[has-type] if isinstance(self.observation_space, spaces.Dict): self.replay_buffer_class = DictReplayBuffer + assert self.n_steps == 1, "N-step returns are not supported for Dict observation spaces yet." + elif self.n_steps > 1: + self.replay_buffer_class = NStepReplayBuffer + # Add required arguments for computing n-step returns + self.replay_buffer_kwargs.update({"n_steps": self.n_steps, "gamma": self.gamma}) else: self.replay_buffer_class = ReplayBuffer diff --git a/sbx/common/type_aliases.py b/sbx/common/type_aliases.py index f014a1f..c8cb4e3 100644 --- a/sbx/common/type_aliases.py +++ b/sbx/common/type_aliases.py @@ -19,3 +19,4 @@ class ReplayBufferSamplesNp(NamedTuple): next_observations: np.ndarray dones: np.ndarray rewards: np.ndarray + discounts: np.ndarray diff --git a/sbx/crossq/crossq.py b/sbx/crossq/crossq.py index d099faa..78a27e5 100644 --- a/sbx/crossq/crossq.py +++ b/sbx/crossq/crossq.py @@ -66,6 +66,7 @@ def __init__( action_noise: Optional[ActionNoise] = None, replay_buffer_class: Optional[type[ReplayBuffer]] = None, replay_buffer_kwargs: Optional[dict[str, Any]] = None, + n_steps: int = 1, ent_coef: Union[str, float] = "auto", target_entropy: Union[Literal["auto"], float] = "auto", use_sde: bool = False, @@ -94,6 +95,7 @@ def __init__( action_noise=action_noise, replay_buffer_class=replay_buffer_class, replay_buffer_kwargs=replay_buffer_kwargs, + n_steps=n_steps, use_sde=use_sde, sde_sample_freq=sde_sample_freq, use_sde_at_warmup=use_sde_at_warmup, @@ -205,6 +207,12 @@ def train(self, gradient_steps: int, batch_size: int) -> None: obs = data.observations.numpy() next_obs = data.next_observations.numpy() + if data.discounts is None: + discounts = np.full((batch_size * gradient_steps,), self.gamma, dtype=np.float32) + else: + # For bootstrapping with n-step returns + discounts = data.discounts.numpy().flatten() + # Convert to numpy data = ReplayBufferSamplesNp( # type: ignore[assignment] obs, @@ -212,6 +220,7 @@ def train(self, gradient_steps: int, batch_size: int) -> None: next_obs, data.dones.numpy().flatten(), data.rewards.numpy().flatten(), + discounts, ) ( @@ -221,7 +230,6 @@ def train(self, gradient_steps: int, batch_size: int) -> None: self.key, (actor_loss_value, qf_loss_value, ent_coef_loss_value, ent_coef_value), ) = self._train( - self.gamma, self.target_entropy, gradient_steps, data, @@ -242,7 +250,6 @@ def train(self, gradient_steps: int, batch_size: int) -> None: @staticmethod @jax.jit def update_critic( - gamma: float, actor_state: BatchNormTrainState, qf_state: BatchNormTrainState, ent_coef_state: TrainState, @@ -251,6 +258,7 @@ def update_critic( next_observations: jax.Array, rewards: jax.Array, dones: jax.Array, + discounts: jax.Array, key: jax.Array, ): key, noise_key, dropout_key_target, dropout_key_current = jax.random.split(key, 4) @@ -298,9 +306,9 @@ def mse_loss( # Compute target q_values next_q_values = jnp.min(qf_next_values, axis=0) # td error + entropy term - next_q_values = next_q_values - ent_coef_value * next_log_prob.reshape(-1, 1) + next_q_values = next_q_values - ent_coef_value * next_log_prob[:, None] # shape is (batch_size, 1) - target_q_values = rewards.reshape(-1, 1) + (1 - dones.reshape(-1, 1)) * gamma * next_q_values + target_q_values = rewards[:, None] + (1 - dones[:, None]) * discounts[:, None] * next_q_values return 0.5 * ((jax.lax.stop_gradient(target_q_values) - current_q_values) ** 2).mean(axis=1).sum(), state_updates @@ -399,7 +407,6 @@ def update_actor_and_temperature( @partial(jax.jit, static_argnames=["cls", "gradient_steps", "policy_delay", "policy_delay_offset"]) def _train( cls, - gamma: float, target_entropy: ArrayLike, gradient_steps: int, data: ReplayBufferSamplesNp, @@ -435,24 +442,25 @@ def one_update(i: int, carry: dict[str, Any]) -> dict[str, Any]: key = carry["key"] info = carry["info"] batch_obs = jax.lax.dynamic_slice_in_dim(data.observations, i * batch_size, batch_size) - batch_act = jax.lax.dynamic_slice_in_dim(data.actions, i * batch_size, batch_size) + batch_actions = jax.lax.dynamic_slice_in_dim(data.actions, i * batch_size, batch_size) batch_next_obs = jax.lax.dynamic_slice_in_dim(data.next_observations, i * batch_size, batch_size) - batch_rew = jax.lax.dynamic_slice_in_dim(data.rewards, i * batch_size, batch_size) - batch_done = jax.lax.dynamic_slice_in_dim(data.dones, i * batch_size, batch_size) + batch_rewards = jax.lax.dynamic_slice_in_dim(data.rewards, i * batch_size, batch_size) + batch_dones = jax.lax.dynamic_slice_in_dim(data.dones, i * batch_size, batch_size) + batch_discounts = jax.lax.dynamic_slice_in_dim(data.discounts, i * batch_size, batch_size) ( qf_state, (qf_loss_value, ent_coef_value), key, ) = cls.update_critic( - gamma, actor_state, qf_state, ent_coef_state, batch_obs, - batch_act, + batch_actions, batch_next_obs, - batch_rew, - batch_done, + batch_rewards, + batch_dones, + batch_discounts, key, ) # No target q values with CrossQ diff --git a/sbx/ddpg/ddpg.py b/sbx/ddpg/ddpg.py index 7ee5728..2f32d73 100644 --- a/sbx/ddpg/ddpg.py +++ b/sbx/ddpg/ddpg.py @@ -29,6 +29,7 @@ def __init__( action_noise: Optional[ActionNoise] = None, replay_buffer_class: Optional[type[ReplayBuffer]] = None, replay_buffer_kwargs: Optional[dict[str, Any]] = None, + n_steps: int = 1, tensorboard_log: Optional[str] = None, policy_kwargs: Optional[dict[str, Any]] = None, verbose: int = 0, @@ -56,6 +57,7 @@ def __init__( target_noise_clip=0.0, replay_buffer_class=replay_buffer_class, replay_buffer_kwargs=replay_buffer_kwargs, + n_steps=n_steps, policy_kwargs=policy_kwargs, tensorboard_log=tensorboard_log, verbose=verbose, diff --git a/sbx/dqn/dqn.py b/sbx/dqn/dqn.py index 8d4338e..9a72dcb 100644 --- a/sbx/dqn/dqn.py +++ b/sbx/dqn/dqn.py @@ -6,6 +6,7 @@ import jax.numpy as jnp import numpy as np import optax +from stable_baselines3.common.buffers import ReplayBuffer from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import LinearSchedule @@ -37,7 +38,10 @@ def __init__( exploration_fraction: float = 0.1, exploration_initial_eps: float = 1.0, exploration_final_eps: float = 0.05, - optimize_memory_usage: bool = False, # Note: unused but to match SB3 API + replay_buffer_class: Optional[type[ReplayBuffer]] = None, + replay_buffer_kwargs: Optional[dict[str, Any]] = None, + optimize_memory_usage: bool = False, + n_steps: int = 1, # max_grad_norm: float = 10, train_freq: Union[int, tuple[int, str]] = 4, gradient_steps: int = 1, @@ -59,6 +63,10 @@ def __init__( gamma=gamma, train_freq=train_freq, gradient_steps=gradient_steps, + replay_buffer_class=replay_buffer_class, + replay_buffer_kwargs=replay_buffer_kwargs, + optimize_memory_usage=optimize_memory_usage, + n_steps=n_steps, policy_kwargs=policy_kwargs, tensorboard_log=tensorboard_log, verbose=verbose, @@ -133,6 +141,12 @@ def learn( def train(self, batch_size, gradient_steps): # Sample all at once for efficiency (so we can jit the for loop) data = self.replay_buffer.sample(batch_size * gradient_steps, env=self._vec_normalize_env) + + if data.discounts is None: + discounts = np.full((batch_size * gradient_steps,), self.gamma, dtype=np.float32) + else: + # For bootstrapping with n-step returns + discounts = data.discounts.numpy().flatten() # Convert to numpy data = ReplayBufferSamplesNp( data.observations.numpy(), @@ -141,6 +155,7 @@ def train(self, batch_size, gradient_steps): data.next_observations.numpy(), data.dones.numpy().flatten(), data.rewards.numpy().flatten(), + discounts, ) # Pre compute the slice indices # otherwise jax will complain @@ -148,7 +163,6 @@ def train(self, batch_size, gradient_steps): update_carry = { "qf_state": self.policy.qf_state, - "gamma": self.gamma, "data": data, "indices": indices, "info": { @@ -178,13 +192,13 @@ def train(self, batch_size, gradient_steps): @staticmethod @jax.jit def update_qnetwork( - gamma: float, qf_state: RLTrainState, observations: np.ndarray, replay_actions: np.ndarray, next_observations: np.ndarray, rewards: np.ndarray, dones: np.ndarray, + discounts: np.ndarray, ): # Compute the next Q-values using the target network qf_next_values = qf_state.apply_fn(qf_state.target_params, next_observations) @@ -192,10 +206,10 @@ def update_qnetwork( # Follow greedy policy: use the one with the highest value next_q_values = qf_next_values.max(axis=1) # Avoid potential broadcast issue - next_q_values = next_q_values.reshape(-1, 1) + next_q_values = next_q_values[:, None] # shape is (batch_size, 1) - target_q_values = rewards.reshape(-1, 1) + (1 - dones.reshape(-1, 1)) * gamma * next_q_values + target_q_values = rewards[:, None] + (1 - dones[:, None]) * discounts[:, None] * next_q_values def huber_loss(params): # Get current Q-values estimates @@ -264,13 +278,13 @@ def _train(carry, indices): data = carry["data"] qf_state, (qf_loss_value, qf_mean_value) = DQN.update_qnetwork( - carry["gamma"], carry["qf_state"], observations=data.observations[indices], replay_actions=data.actions[indices], next_observations=data.next_observations[indices], rewards=data.rewards[indices], dones=data.dones[indices], + discounts=data.discounts[indices], ) carry["qf_state"] = qf_state diff --git a/sbx/ppo/policies.py b/sbx/ppo/policies.py index 4cb73b8..f4a80e4 100644 --- a/sbx/ppo/policies.py +++ b/sbx/ppo/policies.py @@ -9,9 +9,9 @@ import numpy as np import optax import tensorflow_probability.substrates.jax as tfp -from flax.linen.initializers import constant from flax.training.train_state import TrainState from gymnasium import spaces +from jax.nn.initializers import constant from stable_baselines3.common.type_aliases import Schedule from sbx.common.policies import BaseJaxPolicy, Flatten @@ -174,7 +174,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, max_grad_norm: float) -> obs = jnp.array([self.observation_space.sample()]) if isinstance(self.action_space, spaces.Box): - actor_kwargs = { + actor_kwargs: dict[str, Any] = { "action_dim": int(np.prod(self.action_space.shape)), } elif isinstance(self.action_space, spaces.Discrete): @@ -184,7 +184,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, max_grad_norm: float) -> } elif isinstance(self.action_space, spaces.MultiDiscrete): assert self.action_space.nvec.ndim == 1, ( - f"Only one-dimensional MultiDiscrete action spaces are supported, " + "Only one-dimensional MultiDiscrete action spaces are supported, " f"but found MultiDiscrete({(self.action_space.nvec).tolist()})." ) actor_kwargs = { @@ -232,7 +232,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, max_grad_norm: float) -> self.vf_state = TrainState.create( apply_fn=self.vf.apply, - params=self.vf.init({"params": vf_key}, obs), + params=self.vf.init(vf_key, obs), tx=optax.chain( optax.clip_by_global_norm(max_grad_norm), optimizer_class, diff --git a/sbx/ppo/ppo.py b/sbx/ppo/ppo.py index b9d5f32..46eb5d7 100644 --- a/sbx/ppo/ppo.py +++ b/sbx/ppo/ppo.py @@ -184,8 +184,6 @@ def _setup_model(self) -> None: self.key = self.policy.build(self.key, self.lr_schedule, self.max_grad_norm) - self.key, ent_key = jax.random.split(self.key, 2) - self.actor = self.policy.actor # type: ignore[assignment] self.vf = self.policy.vf # type: ignore[assignment] @@ -236,21 +234,23 @@ def actor_loss(params): entropy_loss = -jnp.mean(entropy) total_policy_loss = policy_loss + ent_coef * entropy_loss - return total_policy_loss, ratio + return total_policy_loss, (ratio, policy_loss, entropy_loss) - (pg_loss_value, ratio), grads = jax.value_and_grad(actor_loss, has_aux=True)(actor_state.params) + (pg_loss_value, (ratio, policy_loss, entropy_loss)), grads = jax.value_and_grad(actor_loss, has_aux=True)( + actor_state.params + ) actor_state = actor_state.apply_gradients(grads=grads) def critic_loss(params): # Value loss using the TD(gae_lambda) target vf_values = vf_state.apply_fn(params, observations).flatten() - return ((returns - vf_values) ** 2).mean() + return vf_coef * ((returns - vf_values) ** 2).mean() vf_loss_value, grads = jax.value_and_grad(critic_loss, has_aux=False)(vf_state.params) vf_state = vf_state.apply_gradients(grads=grads) # loss = policy_loss + ent_coef * entropy_loss + vf_coef * value_loss - return (actor_state, vf_state), (pg_loss_value, vf_loss_value, ratio) + return (actor_state, vf_state), (pg_loss_value, policy_loss, entropy_loss, vf_loss_value, ratio) def train(self) -> None: """ @@ -279,18 +279,20 @@ def train(self) -> None: else: actions = rollout_data.actions.numpy() - (self.policy.actor_state, self.policy.vf_state), (pg_loss, value_loss, ratio) = self._one_update( - actor_state=self.policy.actor_state, - vf_state=self.policy.vf_state, - observations=rollout_data.observations.numpy(), - actions=actions, - advantages=rollout_data.advantages.numpy(), - returns=rollout_data.returns.numpy(), - old_log_prob=rollout_data.old_log_prob.numpy(), - clip_range=clip_range, - ent_coef=self.ent_coef, - vf_coef=self.vf_coef, - normalize_advantage=self.normalize_advantage, + (self.policy.actor_state, self.policy.vf_state), (pg_loss, policy_loss, entropy_loss, value_loss, ratio) = ( + self._one_update( + actor_state=self.policy.actor_state, + vf_state=self.policy.vf_state, + observations=rollout_data.observations.numpy(), + actions=actions, + advantages=rollout_data.advantages.numpy(), + returns=rollout_data.returns.numpy(), + old_log_prob=rollout_data.old_log_prob.numpy(), + clip_range=clip_range, + ent_coef=self.ent_coef, + vf_coef=self.vf_coef, + normalize_advantage=self.normalize_advantage, + ) ) # Calculate approximate form of reverse KL Divergence for adaptive lr @@ -319,9 +321,9 @@ def train(self) -> None: ) # Logs - # self.logger.record("train/entropy_loss", np.mean(entropy_losses)) - # self.logger.record("train/policy_gradient_loss", np.mean(pg_losses)) # TODO: use mean instead of one point + self.logger.record("train/entropy_loss", entropy_loss.item()) + self.logger.record("train/policy_gradient_loss", policy_loss.item()) self.logger.record("train/value_loss", value_loss.item()) self.logger.record("train/approx_kl", mean_kl_div) self.logger.record("train/clip_fraction", mean_clip_fraction) diff --git a/sbx/sac/sac.py b/sbx/sac/sac.py index 73ff2f8..1379c0d 100644 --- a/sbx/sac/sac.py +++ b/sbx/sac/sac.py @@ -68,6 +68,7 @@ def __init__( action_noise: Optional[ActionNoise] = None, replay_buffer_class: Optional[type[ReplayBuffer]] = None, replay_buffer_kwargs: Optional[dict[str, Any]] = None, + n_steps: int = 1, ent_coef: Union[str, float] = "auto", target_entropy: Union[Literal["auto"], float] = "auto", use_sde: bool = False, @@ -97,6 +98,7 @@ def __init__( action_noise=action_noise, replay_buffer_class=replay_buffer_class, replay_buffer_kwargs=replay_buffer_kwargs, + n_steps=n_steps, use_sde=use_sde, sde_sample_freq=sde_sample_freq, use_sde_at_warmup=use_sde_at_warmup, @@ -218,6 +220,12 @@ def train(self, gradient_steps: int, batch_size: int) -> None: obs = data.observations.numpy() next_obs = data.next_observations.numpy() + if data.discounts is None: + discounts = np.full((batch_size * gradient_steps,), self.gamma, dtype=np.float32) + else: + # For bootstrapping with n-step returns + discounts = data.discounts.numpy().flatten() + # Convert to numpy data = ReplayBufferSamplesNp( # type: ignore[assignment] obs, @@ -225,6 +233,7 @@ def train(self, gradient_steps: int, batch_size: int) -> None: next_obs, data.dones.numpy().flatten(), data.rewards.numpy().flatten(), + discounts, ) ( @@ -234,7 +243,6 @@ def train(self, gradient_steps: int, batch_size: int) -> None: self.key, (actor_loss_value, qf_loss_value, ent_coef_loss_value, ent_coef_value), ) = self._train( - self.gamma, self.tau, self.target_entropy, gradient_steps, @@ -256,7 +264,6 @@ def train(self, gradient_steps: int, batch_size: int) -> None: @staticmethod @jax.jit def update_critic( - gamma: float, actor_state: TrainState, qf_state: RLTrainState, ent_coef_state: TrainState, @@ -265,6 +272,7 @@ def update_critic( next_observations: jax.Array, rewards: jax.Array, dones: jax.Array, + discounts: jax.Array, key: jax.Array, ): key, noise_key, dropout_key_target, dropout_key_current = jax.random.split(key, 4) @@ -284,9 +292,9 @@ def update_critic( next_q_values = jnp.min(qf_next_values, axis=0) # td error + entropy term - next_q_values = next_q_values - ent_coef_value * next_log_prob.reshape(-1, 1) + next_q_values = next_q_values - ent_coef_value * next_log_prob[:, None] # shape is (batch_size, 1) - target_q_values = rewards.reshape(-1, 1) + (1 - dones.reshape(-1, 1)) * gamma * next_q_values + target_q_values = rewards[:, None] + (1 - dones[:, None]) * discounts[:, None] * next_q_values def mse_loss(params: flax.core.FrozenDict, dropout_key: jax.Array) -> jax.Array: # shape is (n_critics, batch_size, 1) @@ -380,7 +388,6 @@ def update_actor_and_temperature( @partial(jax.jit, static_argnames=["cls", "gradient_steps", "policy_delay", "policy_delay_offset"]) def _train( cls, - gamma: float, tau: float, target_entropy: ArrayLike, gradient_steps: int, @@ -417,24 +424,25 @@ def one_update(i: int, carry: dict[str, Any]) -> dict[str, Any]: key = carry["key"] info = carry["info"] batch_obs = jax.lax.dynamic_slice_in_dim(data.observations, i * batch_size, batch_size) - batch_act = jax.lax.dynamic_slice_in_dim(data.actions, i * batch_size, batch_size) + batch_actions = jax.lax.dynamic_slice_in_dim(data.actions, i * batch_size, batch_size) batch_next_obs = jax.lax.dynamic_slice_in_dim(data.next_observations, i * batch_size, batch_size) - batch_rew = jax.lax.dynamic_slice_in_dim(data.rewards, i * batch_size, batch_size) - batch_done = jax.lax.dynamic_slice_in_dim(data.dones, i * batch_size, batch_size) + batch_rewards = jax.lax.dynamic_slice_in_dim(data.rewards, i * batch_size, batch_size) + batch_dones = jax.lax.dynamic_slice_in_dim(data.dones, i * batch_size, batch_size) + batch_discounts = jax.lax.dynamic_slice_in_dim(data.discounts, i * batch_size, batch_size) ( qf_state, (qf_loss_value, ent_coef_value), key, ) = cls.update_critic( - gamma, actor_state, qf_state, ent_coef_state, batch_obs, - batch_act, + batch_actions, batch_next_obs, - batch_rew, - batch_done, + batch_rewards, + batch_dones, + batch_discounts, key, ) qf_state = cls.soft_update(tau, qf_state) diff --git a/sbx/td3/td3.py b/sbx/td3/td3.py index 304952a..b3e8e04 100644 --- a/sbx/td3/td3.py +++ b/sbx/td3/td3.py @@ -45,6 +45,7 @@ def __init__( action_noise: Optional[ActionNoise] = None, replay_buffer_class: Optional[type[ReplayBuffer]] = None, replay_buffer_kwargs: Optional[dict[str, Any]] = None, + n_steps: int = 1, tensorboard_log: Optional[str] = None, stats_window_size: int = 100, policy_kwargs: Optional[dict[str, Any]] = None, @@ -69,6 +70,7 @@ def __init__( action_noise=action_noise, replay_buffer_class=replay_buffer_class, replay_buffer_kwargs=replay_buffer_kwargs, + n_steps=n_steps, use_sde=False, stats_window_size=stats_window_size, policy_kwargs=policy_kwargs, @@ -139,6 +141,12 @@ def train(self, gradient_steps: int, batch_size: int) -> None: obs = data.observations.numpy() next_obs = data.next_observations.numpy() + if data.discounts is None: + discounts = np.full((batch_size * gradient_steps,), self.gamma, dtype=np.float32) + else: + # For bootstrapping with n-step returns + discounts = data.discounts.numpy().flatten() + # Convert to numpy data = ReplayBufferSamplesNp( # type: ignore[assignment] obs, @@ -146,6 +154,7 @@ def train(self, gradient_steps: int, batch_size: int) -> None: next_obs, data.dones.numpy().flatten(), data.rewards.numpy().flatten(), + discounts, ) ( @@ -154,7 +163,6 @@ def train(self, gradient_steps: int, batch_size: int) -> None: self.key, (actor_loss_value, qf_loss_value), ) = self._train( - self.gamma, self.tau, gradient_steps, data, @@ -174,7 +182,6 @@ def train(self, gradient_steps: int, batch_size: int) -> None: @staticmethod @jax.jit def update_critic( - gamma: float, actor_state: RLTrainState, qf_state: RLTrainState, observations: jax.Array, @@ -182,6 +189,7 @@ def update_critic( next_observations: jax.Array, rewards: jax.Array, dones: jax.Array, + discounts: jax.Array, target_policy_noise: float, target_noise_clip: float, key: jax.Array, @@ -203,7 +211,7 @@ def update_critic( next_q_values = jnp.min(qf_next_values, axis=0) # shape is (batch_size, 1) - target_q_values = rewards.reshape(-1, 1) + (1 - dones.reshape(-1, 1)) * gamma * next_q_values + target_q_values = rewards[:, None] + (1 - dones[:, None]) * discounts[:, None] * next_q_values def mse_loss(params: flax.core.FrozenDict, dropout_key: jax.Array) -> jax.Array: # shape is (n_critics, batch_size, 1) @@ -261,7 +269,6 @@ def soft_update(tau: float, qf_state: RLTrainState, actor_state: RLTrainState) - @partial(jax.jit, static_argnames=["cls", "gradient_steps", "policy_delay", "policy_delay_offset"]) def _train( cls, - gamma: float, tau: float, gradient_steps: int, data: ReplayBufferSamplesNp, @@ -294,23 +301,24 @@ def one_update(i: int, carry: dict[str, Any]) -> dict[str, Any]: key = carry["key"] info = carry["info"] batch_obs = jax.lax.dynamic_slice_in_dim(data.observations, i * batch_size, batch_size) - batch_act = jax.lax.dynamic_slice_in_dim(data.actions, i * batch_size, batch_size) + batch_actions = jax.lax.dynamic_slice_in_dim(data.actions, i * batch_size, batch_size) batch_next_obs = jax.lax.dynamic_slice_in_dim(data.next_observations, i * batch_size, batch_size) - batch_rew = jax.lax.dynamic_slice_in_dim(data.rewards, i * batch_size, batch_size) - batch_done = jax.lax.dynamic_slice_in_dim(data.dones, i * batch_size, batch_size) + batch_rewards = jax.lax.dynamic_slice_in_dim(data.rewards, i * batch_size, batch_size) + batch_dones = jax.lax.dynamic_slice_in_dim(data.dones, i * batch_size, batch_size) + batch_discounts = jax.lax.dynamic_slice_in_dim(data.discounts, i * batch_size, batch_size) ( qf_state, qf_loss_value, key, ) = cls.update_critic( - gamma, actor_state, qf_state, batch_obs, - batch_act, + batch_actions, batch_next_obs, - batch_rew, - batch_done, + batch_rewards, + batch_dones, + batch_discounts, target_policy_noise, target_noise_clip, key, diff --git a/sbx/tqc/tqc.py b/sbx/tqc/tqc.py index 9593535..7ec382c 100644 --- a/sbx/tqc/tqc.py +++ b/sbx/tqc/tqc.py @@ -68,6 +68,7 @@ def __init__( action_noise: Optional[ActionNoise] = None, replay_buffer_class: Optional[type[ReplayBuffer]] = None, replay_buffer_kwargs: Optional[dict[str, Any]] = None, + n_steps: int = 1, ent_coef: Union[str, float] = "auto", target_entropy: Union[Literal["auto"], float] = "auto", use_sde: bool = False, @@ -97,6 +98,7 @@ def __init__( action_noise=action_noise, replay_buffer_class=replay_buffer_class, replay_buffer_kwargs=replay_buffer_kwargs, + n_steps=n_steps, use_sde=use_sde, sde_sample_freq=sde_sample_freq, use_sde_at_warmup=use_sde_at_warmup, @@ -221,6 +223,12 @@ def train(self, gradient_steps: int, batch_size: int) -> None: obs = data.observations.numpy() next_obs = data.next_observations.numpy() + if data.discounts is None: + discounts = np.full((batch_size * gradient_steps,), self.gamma, dtype=np.float32) + else: + # For bootstrapping with n-step returns + discounts = data.discounts.numpy().flatten() + # Convert to numpy data = ReplayBufferSamplesNp( # type: ignore[assignment] obs, @@ -228,6 +236,7 @@ def train(self, gradient_steps: int, batch_size: int) -> None: next_obs, data.dones.numpy().flatten(), data.rewards.numpy().flatten(), + discounts, ) ( self.policy.qf1_state, @@ -237,7 +246,6 @@ def train(self, gradient_steps: int, batch_size: int) -> None: self.key, (qf1_loss_value, qf2_loss_value, actor_loss_value, ent_coef_loss_value, ent_coef_value), ) = self._train( - self.gamma, self.tau, self.target_entropy, gradient_steps, @@ -261,7 +269,6 @@ def train(self, gradient_steps: int, batch_size: int) -> None: @staticmethod @partial(jax.jit, static_argnames=["n_target_quantiles"]) def update_critic( - gamma: float, n_target_quantiles: int, actor_state: TrainState, qf1_state: RLTrainState, @@ -272,6 +279,7 @@ def update_critic( next_observations: jax.Array, rewards: jax.Array, dones: jax.Array, + discounts: jax.Array, key: jax.Array, ): key, noise_key, dropout_key_1, dropout_key_2 = jax.random.split(key, 4) @@ -306,8 +314,8 @@ def update_critic( next_target_quantiles = next_quantiles[:, :n_target_quantiles] # td error + entropy term - next_target_quantiles = next_target_quantiles - ent_coef_value * next_log_prob.reshape(-1, 1) - target_quantiles = rewards.reshape(-1, 1) + (1 - dones.reshape(-1, 1)) * gamma * next_target_quantiles + next_target_quantiles = next_target_quantiles - ent_coef_value * next_log_prob[:, None] + target_quantiles = rewards[:, None] + (1 - dones[:, None]) * discounts[:, None] * next_target_quantiles # Make target_quantiles broadcastable to (batch_size, n_quantiles, n_target_quantiles). target_quantiles = jnp.expand_dims(target_quantiles, axis=1) @@ -439,7 +447,6 @@ def update_actor_and_temperature( ) def _train( cls, - gamma: float, tau: float, target_entropy: ArrayLike, gradient_steps: int, @@ -481,26 +488,27 @@ def one_update(i: int, carry: dict[str, Any]) -> dict[str, Any]: key = carry["key"] info = carry["info"] batch_obs = jax.lax.dynamic_slice_in_dim(data.observations, i * batch_size, batch_size) - batch_act = jax.lax.dynamic_slice_in_dim(data.actions, i * batch_size, batch_size) + batch_actions = jax.lax.dynamic_slice_in_dim(data.actions, i * batch_size, batch_size) batch_next_obs = jax.lax.dynamic_slice_in_dim(data.next_observations, i * batch_size, batch_size) - batch_rew = jax.lax.dynamic_slice_in_dim(data.rewards, i * batch_size, batch_size) - batch_done = jax.lax.dynamic_slice_in_dim(data.dones, i * batch_size, batch_size) + batch_rewards = jax.lax.dynamic_slice_in_dim(data.rewards, i * batch_size, batch_size) + batch_dones = jax.lax.dynamic_slice_in_dim(data.dones, i * batch_size, batch_size) + batch_discounts = jax.lax.dynamic_slice_in_dim(data.discounts, i * batch_size, batch_size) ( (qf1_state, qf2_state), (qf1_loss_value, qf2_loss_value, ent_coef_value), key, ) = cls.update_critic( - gamma, n_target_quantiles, actor_state, qf1_state, qf2_state, ent_coef_state, batch_obs, - batch_act, + batch_actions, batch_next_obs, - batch_rew, - batch_done, + batch_rewards, + batch_dones, + batch_discounts, key, ) qf1_state, qf2_state = cls.soft_update(tau, qf1_state, qf2_state) diff --git a/sbx/version.txt b/sbx/version.txt index 8854156..2157409 100644 --- a/sbx/version.txt +++ b/sbx/version.txt @@ -1 +1 @@ -0.21.0 +0.22.0 diff --git a/setup.py b/setup.py index a8655da..3655279 100644 --- a/setup.py +++ b/setup.py @@ -41,8 +41,8 @@ packages=[package for package in find_packages() if package.startswith("sbx")], package_data={"sbx": ["py.typed", "version.txt"]}, install_requires=[ - "stable_baselines3>=2.6.1a1,<3.0", - "jax>=0.4.24", + "stable_baselines3>=2.7.0a1,<3.0", + "jax>=0.4.24,<0.7.0", # tf probability not compatible yet with latest jax version "jaxlib", "flax", "optax", @@ -62,7 +62,7 @@ # Lint code "ruff>=0.3.1", # Reformat - "black>=24.2.0,<25", + "black>=25.1.0,<26", ], }, description="Jax version of Stable Baselines, implementations of reinforcement learning algorithms.", diff --git a/tests/test_run.py b/tests/test_run.py index c6e8ca1..7c6ab59 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -97,6 +97,7 @@ def test_sac_td3(tmp_path, model_class) -> None: gradient_steps=1, learning_rate=1e-3, policy_kwargs=net_kwargs, + n_steps=3, ) key_before_learn = model.key model.learn(110)