diff --git a/sbx/common/jax_layers.py b/sbx/common/jax_layers.py index 01880aa..a77c816 100644 --- a/sbx/common/jax_layers.py +++ b/sbx/common/jax_layers.py @@ -225,3 +225,26 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: x = self.activation_fn(x) x = nn.Dense(self.hidden_dim, kernel_init=nn.initializers.he_normal())(x) return residual + x + + +# CNN policy from DQN paper +class NatureCNN(nn.Module): + n_features: int = 512 + activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu + + @nn.compact + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + # Convert from channel-first (PyTorch) to channel-last (Jax) + x = jnp.transpose(x, (0, 2, 3, 1)) + # Convert to float and normalize the image + x = x.astype(jnp.float32) / 255.0 + x = nn.Conv(32, kernel_size=(8, 8), strides=(4, 4), padding="VALID")(x) + x = self.activation_fn(x) + x = nn.Conv(64, kernel_size=(4, 4), strides=(2, 2), padding="VALID")(x) + x = self.activation_fn(x) + x = nn.Conv(64, kernel_size=(3, 3), strides=(1, 1), padding="VALID")(x) + x = self.activation_fn(x) + # Flatten + x = x.reshape((x.shape[0], -1)) + x = nn.Dense(self.n_features)(x) + return self.activation_fn(x) diff --git a/sbx/common/on_policy_algorithm.py b/sbx/common/on_policy_algorithm.py index d0c7d97..13c4843 100644 --- a/sbx/common/on_policy_algorithm.py +++ b/sbx/common/on_policy_algorithm.py @@ -158,7 +158,7 @@ def collect_rollouts( # Always sample new stochastic action self.policy.reset_noise() - obs_tensor, _ = self.policy.prepare_obs(self._last_obs) # type: ignore[has-type] + obs_tensor, _ = self.policy.prepare_obs(self._last_obs) # type: ignore[has-type, arg-type] actions, log_probs, values = self.policy.predict_all(obs_tensor, self.policy.noise_key) actions = np.array(actions) diff --git a/sbx/common/utils.py b/sbx/common/utils.py index 447c4f3..afea621 100644 --- a/sbx/common/utils.py +++ b/sbx/common/utils.py @@ -1,6 +1,11 @@ from dataclasses import dataclass +from typing import Union +import jax +import 
jax.numpy as jnp import numpy as np +from flax.core import FrozenDict +from flax.training.train_state import TrainState @dataclass @@ -24,3 +29,63 @@ def update(self, kl_div: float) -> None: self.current_adaptive_lr *= self.adaptive_lr_factor self.current_adaptive_lr = np.clip(self.current_adaptive_lr, self.min_learning_rate, self.max_learning_rate) + + +def mask_from_prefix(params: FrozenDict, prefix: str = "NatureCNN_") -> dict: + """ + Build a pytree mask (same structure as `params`) where a leaf is True if the second key of its path + (the module name one level below the root) starts with `prefix`; leaves with a shorter path are False. + """ + + def _traverse(tree: FrozenDict, path: tuple[str, ...] = ()) -> Union[dict, bool]: + if isinstance(tree, dict): + return {key: _traverse(value, (*path, key)) for key, value in tree.items()} + # leaf + return path[1].startswith(prefix) if len(path) > 1 else False + + return _traverse(params) # type: ignore[return-value] + + +def align_params(params1: FrozenDict, params2: FrozenDict) -> dict: + """ + Return a dict with the *exact* structure of `params2`. For every leaf in `params2`, + use the corresponding leaf from `params1` if it exists and has the same shape; otherwise use `params2`'s leaf. + This guarantees that the two dicts have identical structure for tree_map. + """ + if isinstance(params2, dict): + out = {} + for key, params2_sub in params2.items(): + params1_sub = params1[key] if (isinstance(params1, dict) and key in params1) else None + out[key] = align_params(params1_sub, params2_sub) # type: ignore[arg-type] + return out + # leaf-case: if params1 value exists (not None and same shape) use it, else use params2 leaf + return params1 if (params1 is not None and params1.shape == params2.shape) else params2 # type: ignore[attr-defined, return-value] + + +@jax.jit +def masked_copy(params1: FrozenDict, params2: FrozenDict, mask: dict) -> FrozenDict: + """ + Leafwise selection: wherever mask is True we take params1 value, + otherwise params2 value.
+ """ + return jax.tree_util.tree_map( + lambda val1, val2, mask_value: jnp.where(mask_value, val1, val2), + params1, + params2, + mask, + ) + + +@jax.jit +def copy_naturecnn_params(state1: TrainState, state2: TrainState) -> TrainState: + """ + Copy all top-level modules whose names start with "NatureCNN_" from + state1.params into state2.params. + It is useful when sharing features extractor parameters between actor and critic. + """ + # Ensure same structure + aligned_params = align_params(state1.params, state2.params) + mask = mask_from_prefix(state2.params, prefix="NatureCNN_") + new_params = masked_copy(aligned_params, state2.params, mask) + + return state2.replace(params=new_params) diff --git a/sbx/crossq/crossq.py b/sbx/crossq/crossq.py index 78a27e5..7c2b667 100644 --- a/sbx/crossq/crossq.py +++ b/sbx/crossq/crossq.py @@ -261,7 +261,7 @@ def update_critic( discounts: jax.Array, key: jax.Array, ): - key, noise_key, dropout_key_target, dropout_key_current = jax.random.split(key, 4) + key, noise_key, dropout_key_current = jax.random.split(key, 3) # sample action from the actor dist = actor_state.apply_fn( {"params": actor_state.params, "batch_stats": actor_state.batch_stats}, diff --git a/sbx/dqn/policies.py b/sbx/dqn/policies.py index c03be4d..8d25383 100644 --- a/sbx/dqn/policies.py +++ b/sbx/dqn/policies.py @@ -8,6 +8,7 @@ from gymnasium import spaces from stable_baselines3.common.type_aliases import Schedule +from sbx.common.jax_layers import NatureCNN from sbx.common.policies import BaseJaxPolicy, Flatten from sbx.common.type_aliases import RLTrainState @@ -28,28 +29,14 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: return x -# Add CNN policy from DQN paper -class NatureCNN(nn.Module): +class CnnQNetwork(nn.Module): n_actions: int n_units: int = 512 activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu @nn.compact def __call__(self, x: jnp.ndarray) -> jnp.ndarray: - # Convert from channel-first (PyTorch) to channel-last (Jax) - x = 
jnp.transpose(x, (0, 2, 3, 1)) - # Convert to float and normalize the image - x = x.astype(jnp.float32) / 255.0 - x = nn.Conv(32, kernel_size=(8, 8), strides=(4, 4), padding="VALID")(x) - x = self.activation_fn(x) - x = nn.Conv(64, kernel_size=(4, 4), strides=(2, 2), padding="VALID")(x) - x = self.activation_fn(x) - x = nn.Conv(64, kernel_size=(3, 3), strides=(1, 1), padding="VALID")(x) - x = self.activation_fn(x) - # Flatten - x = x.reshape((x.shape[0], -1)) - x = nn.Dense(self.n_units)(x) - x = self.activation_fn(x) + x = NatureCNN(self.n_units, self.activation_fn)(x) x = nn.Dense(self.n_actions)(x) return x @@ -130,7 +117,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule) -> jax.Array: obs = jnp.array([self.observation_space.sample()]) - self.qf = NatureCNN( + self.qf = CnnQNetwork( n_actions=int(self.action_space.n), n_units=self.n_units, activation_fn=self.activation_fn, diff --git a/sbx/ppo/policies.py b/sbx/ppo/policies.py index f4a80e4..7ffa7c6 100644 --- a/sbx/ppo/policies.py +++ b/sbx/ppo/policies.py @@ -14,6 +14,7 @@ from jax.nn.initializers import constant from stable_baselines3.common.type_aliases import Schedule +from sbx.common.jax_layers import NatureCNN from sbx.common.policies import BaseJaxPolicy, Flatten tfd = tfp.distributions @@ -22,9 +23,14 @@ class Critic(nn.Module): net_arch: Sequence[int] activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.tanh + features_extractor: Optional[type[NatureCNN]] = None + features_dim: int = 512 @nn.compact def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + if self.features_extractor is not None: + x = self.features_extractor(self.features_dim, self.activation_fn)(x) + x = Flatten()(x) for n_units in self.net_arch: x = nn.Dense(n_units)(x) @@ -46,6 +52,8 @@ class Actor(nn.Module): split_indices: np.ndarray = field(default_factory=lambda: np.array([])) # Last layer with small scale ortho_init: bool = False + features_extractor: Optional[type[NatureCNN]] = None + features_dim: int = 512 def 
get_std(self) -> jnp.ndarray: # Make it work with gSDE @@ -61,6 +69,9 @@ def __post_init__(self) -> None: @nn.compact def __call__(self, x: jnp.ndarray) -> tfd.Distribution: # type: ignore[name-defined] + if self.features_extractor is not None: + x = self.features_extractor(self.features_dim, self.activation_fn)(x) + x = Flatten()(x) for n_units in self.net_arch: @@ -122,7 +133,7 @@ def __init__( # this is to keep API consistent with SB3 use_expln: bool = False, clip_mean: float = 2.0, - features_extractor_class=None, + features_extractor_class: Optional[type[NatureCNN]] = None, features_extractor_kwargs: Optional[dict[str, Any]] = None, normalize_images: bool = True, optimizer_class: Callable[..., optax.GradientTransformation] = optax.adam, @@ -156,7 +167,12 @@ def __init__( self.net_arch_pi = net_arch["pi"] self.net_arch_vf = net_arch["vf"] else: - self.net_arch_pi = self.net_arch_vf = [64, 64] + if features_extractor_class == NatureCNN: + # Just a linear layer after the CNN + net_arch = [] + else: + net_arch = [64, 64] + self.net_arch_pi = self.net_arch_vf = net_arch self.use_sde = use_sde self.ortho_init = ortho_init self.actor_class = actor_class @@ -204,11 +220,14 @@ def build(self, key: jax.Array, lr_schedule: Schedule, max_grad_norm: float) -> else: raise NotImplementedError(f"{self.action_space}") + actor_kwargs.update(self.features_extractor_kwargs) + self.actor = self.actor_class( net_arch=self.net_arch_pi, log_std_init=self.log_std_init, activation_fn=self.activation_fn, ortho_init=self.ortho_init, + features_extractor=self.features_extractor_class, **actor_kwargs, # type: ignore[arg-type] ) # Hack to make gSDE work without modifying internal SB3 code @@ -228,7 +247,12 @@ def build(self, key: jax.Array, lr_schedule: Schedule, max_grad_norm: float) -> ), ) - self.vf = self.critic_class(net_arch=self.net_arch_vf, activation_fn=self.activation_fn) + self.vf = self.critic_class( + net_arch=self.net_arch_vf, + activation_fn=self.activation_fn, + 
features_extractor=self.features_extractor_class, + **self.features_extractor_kwargs, + ) self.vf_state = TrainState.create( apply_fn=self.vf.apply, @@ -272,3 +296,48 @@ def _predict_all(actor_state, vf_state, observations, key): log_probs = dist.log_prob(actions) values = vf_state.apply_fn(vf_state.params, observations).flatten() return actions, log_probs, values + + +class CnnPolicy(PPOPolicy): + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None, + ortho_init: bool = False, + log_std_init: float = 0, + # ReLU for NatureCNN + activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu, + use_sde: bool = False, + use_expln: bool = False, + clip_mean: float = 2, + features_extractor_class: Optional[type[NatureCNN]] = NatureCNN, + features_extractor_kwargs: Optional[dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Callable[..., optax.GradientTransformation] = optax.adam, + optimizer_kwargs: Optional[dict[str, Any]] = None, + share_features_extractor: bool = False, + actor_class: type[nn.Module] = Actor, + critic_class: type[nn.Module] = Critic, + ): + super().__init__( + observation_space, + action_space, + lr_schedule, + net_arch, + ortho_init, + log_std_init, + activation_fn, + use_sde, + use_expln, + clip_mean, + features_extractor_class, + features_extractor_kwargs, + normalize_images, + optimizer_class, + optimizer_kwargs, + share_features_extractor, + actor_class, + critic_class, + ) diff --git a/sbx/ppo/ppo.py b/sbx/ppo/ppo.py index 46eb5d7..a718972 100644 --- a/sbx/ppo/ppo.py +++ b/sbx/ppo/ppo.py @@ -11,8 +11,8 @@ from stable_baselines3.common.utils import FloatSchedule, explained_variance from sbx.common.on_policy_algorithm import OnPolicyAlgorithmJax -from sbx.common.utils import KLAdaptiveLR -from sbx.ppo.policies import PPOPolicy +from sbx.common.utils import KLAdaptiveLR, 
copy_naturecnn_params +from sbx.ppo.policies import CnnPolicy, PPOPolicy PPOSelf = TypeVar("PPOSelf", bound="PPO") @@ -70,7 +70,7 @@ class PPO(OnPolicyAlgorithmJax): policy_aliases: ClassVar[dict[str, type[PPOPolicy]]] = { # type: ignore[assignment] "MlpPolicy": PPOPolicy, - # "CnnPolicy": ActorCriticCnnPolicy, + "CnnPolicy": CnnPolicy, # "MultiInputPolicy": MultiInputActorCriticPolicy, } policy: PPOPolicy # type: ignore[assignment] @@ -196,7 +196,7 @@ def _setup_model(self) -> None: # self.clip_range_vf = FloatSchedule(self.clip_range_vf) @staticmethod - @partial(jax.jit, static_argnames=["normalize_advantage"]) + @partial(jax.jit, static_argnames=["normalize_advantage", "share_features_extractor"]) def _one_update( actor_state: TrainState, vf_state: TrainState, @@ -209,6 +209,7 @@ def _one_update( ent_coef: float, vf_coef: float, normalize_advantage: bool = True, + share_features_extractor: bool = False, ): # Normalize advantage # Normalization does not make sense if mini batchsize == 1, see GH issue #325 @@ -241,6 +242,10 @@ def actor_loss(params): ) actor_state = actor_state.apply_gradients(grads=grads) + if share_features_extractor: + # Hack: selective copy to share features extractor when using CNN + vf_state = copy_naturecnn_params(actor_state, vf_state) + def critic_loss(params): # Value loss using the TD(gae_lambda) target vf_values = vf_state.apply_fn(params, observations).flatten() @@ -249,6 +254,9 @@ def critic_loss(params): vf_loss_value, grads = jax.value_and_grad(critic_loss, has_aux=False)(vf_state.params) vf_state = vf_state.apply_gradients(grads=grads) + if share_features_extractor: + actor_state = copy_naturecnn_params(vf_state, actor_state) + # loss = policy_loss + ent_coef * entropy_loss + vf_coef * value_loss return (actor_state, vf_state), (pg_loss_value, policy_loss, entropy_loss, vf_loss_value, ratio) @@ -292,6 +300,9 @@ def train(self) -> None: ent_coef=self.ent_coef, vf_coef=self.vf_coef, normalize_advantage=self.normalize_advantage, + # 
Sharing the CNN between actor and critic has a great impact on performance + # for Atari games + share_features_extractor=isinstance(self.policy, CnnPolicy), ) ) diff --git a/sbx/tqc/tqc.py b/sbx/tqc/tqc.py index 7ec382c..15d6183 100644 --- a/sbx/tqc/tqc.py +++ b/sbx/tqc/tqc.py @@ -244,7 +244,7 @@ def train(self, gradient_steps: int, batch_size: int) -> None: self.policy.actor_state, self.ent_coef_state, self.key, - (qf1_loss_value, qf2_loss_value, actor_loss_value, ent_coef_loss_value, ent_coef_value), + (qf1_loss_value, _qf2_loss_value, actor_loss_value, ent_coef_loss_value, ent_coef_value), ) = self._train( self.tau, self.target_entropy, diff --git a/tests/test_cnn.py b/tests/test_cnn.py index 78b209b..bd544fb 100644 --- a/tests/test_cnn.py +++ b/tests/test_cnn.py @@ -2,23 +2,35 @@ import pytest from stable_baselines3.common.envs import FakeImageEnv -from sbx import DQN +from sbx import DQN, PPO -@pytest.mark.parametrize("model_class", [DQN]) +@pytest.mark.parametrize("model_class", [DQN, PPO]) def test_cnn(tmp_path, model_class): SAVE_NAME = "cnn_model.zip" # Fake grayscale with frameskip # Atari after preprocessing: 84x84x1, here we are using lower resolution # to check that the network handle it automatically env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1) + kwargs = {} + if model_class == DQN: + kwargs = {"buffer_size": 250, "learning_starts": 100} + elif model_class == PPO: + kwargs = { + "n_steps": 128, + "batch_size": 64, + "n_epochs": 2, + } + model = model_class( "CnnPolicy", env, - buffer_size=250, - policy_kwargs=dict(net_arch=[64]), - learning_starts=100, + policy_kwargs=dict( + net_arch=[64], + features_extractor_kwargs=dict(features_dim=64), + ), verbose=1, + **kwargs ) model.learn(total_timesteps=250)