From ad399c0d449ff62dbe99dcb4f105a93e2d5ec0da Mon Sep 17 00:00:00 2001 From: "paulo10.1977" Date: Tue, 23 Sep 2025 16:19:14 -0300 Subject: [PATCH 1/9] Add CnnPolicy to PPO --- sbx/ppo/policies.py | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/sbx/ppo/policies.py b/sbx/ppo/policies.py index f4a80e4..c56e342 100644 --- a/sbx/ppo/policies.py +++ b/sbx/ppo/policies.py @@ -272,3 +272,40 @@ def _predict_all(actor_state, vf_state, observations, key): log_probs = dist.log_prob(actions) values = vf_state.apply_fn(vf_state.params, observations).flatten() return actions, log_probs, values + +class ActorCriticCNN(nn.Module): + n_actions: int + n_units: int = 512 + activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu + + @nn.compact + def __call__(self, x: jnp.ndarray): + x = jnp.transpose(x, (0, 2, 3, 1)) + x = x.astype(jnp.float32) / 255.0 + + x = nn.Conv(32, kernel_size=(8, 8), strides=(4, 4), padding="VALID")(x) + x = self.activation_fn(x) + x = nn.Conv(64, kernel_size=(4, 4), strides=(2, 2), padding="VALID")(x) + x = self.activation_fn(x) + x = nn.Conv(64, kernel_size=(3, 3), strides=(1, 1), padding="VALID")(x) + x = self.activation_fn(x) + x = x.reshape((x.shape[0], -1)) + + shared_net = nn.Dense(self.n_units)(x) + shared_net = self.activation_fn(shared_net) + + action_logits = nn.Dense(self.n_actions)(shared_net) + + state_value = nn.Dense(1)(shared_net) + + return action_logits, state_value + + +class CnnPolicy(PPOPolicy): + def build_network(self) -> None: + self.network = ActorCriticCNN( + n_actions=self.action_space.n, + n_units=self.net_arch[0] if self.net_arch else 512, + activation_fn=self.activation_fn, + ) + From 12f7e313b1a931a44d233adc1cdb6f91cd9f12ab Mon Sep 17 00:00:00 2001 From: "paulo10.1977" Date: Tue, 23 Sep 2025 17:14:17 -0300 Subject: [PATCH 2/9] After run tests --- sbx/crossq/crossq.py | 2 +- sbx/ppo/policies.py | 10 +++++----- sbx/ppo/ppo.py | 4 ++-- sbx/tqc/tqc.py | 2 +- tests/test_cnn.py | 37 +++++++++++++++++++++++++++++++++++-- 5 files changed, 44 insertions(+), 11 deletions(-) diff --git a/sbx/crossq/crossq.py b/sbx/crossq/crossq.py index 78a27e5..49cd317 100644 --- a/sbx/crossq/crossq.py +++ b/sbx/crossq/crossq.py @@ -261,7 +261,7 @@ def update_critic( discounts: jax.Array, key: jax.Array, ): - key, noise_key, dropout_key_target, dropout_key_current = jax.random.split(key, 4) + key, noise_key, _dropout_key_target, dropout_key_current = jax.random.split(key, 4) # sample action from the actor dist = actor_state.apply_fn( {"params": actor_state.params, "batch_stats": actor_state.batch_stats}, diff --git a/sbx/ppo/policies.py b/sbx/ppo/policies.py index c56e342..95a65d6 100644 --- a/sbx/ppo/policies.py +++ b/sbx/ppo/policies.py @@ -273,6 +273,7 @@ def _predict_all(actor_state, vf_state, observations, key): values = vf_state.apply_fn(vf_state.params, observations).flatten() return actions, log_probs, values + class ActorCriticCNN(nn.Module): n_actions: int n_units: int = 512 @@ -282,7 +283,7 @@ class ActorCriticCNN(nn.Module): def __call__(self, x: jnp.ndarray): x = jnp.transpose(x, (0, 2, 3, 1)) x = x.astype(jnp.float32) / 255.0 - + x = nn.Conv(32, kernel_size=(8, 8), strides=(4, 4), padding="VALID")(x) x = self.activation_fn(x) x = nn.Conv(64, kernel_size=(4, 4), strides=(2, 2), padding="VALID")(x) @@ -290,14 +291,14 @@ def __call__(self, x: jnp.ndarray): x = nn.Conv(64, kernel_size=(3, 3), strides=(1, 1), padding="VALID")(x) x = self.activation_fn(x) x = x.reshape((x.shape[0], -1)) - + shared_net = nn.Dense(self.n_units)(x) shared_net = self.activation_fn(shared_net) action_logits = nn.Dense(self.n_actions)(shared_net) - + state_value = nn.Dense(1)(shared_net) - + return action_logits, state_value @@ -308,4 +309,3 @@ def build_network(self) -> None: n_units=self.net_arch[0] if self.net_arch else 512, activation_fn=self.activation_fn, ) - diff --git a/sbx/ppo/ppo.py b/sbx/ppo/ppo.py index 46eb5d7..c835d1d 100644 --- a/sbx/ppo/ppo.py +++ b/sbx/ppo/ppo.py @@ -12,7 +12,7 @@ from sbx.common.on_policy_algorithm import OnPolicyAlgorithmJax from sbx.common.utils import KLAdaptiveLR -from sbx.ppo.policies import PPOPolicy +from sbx.ppo.policies import CnnPolicy, PPOPolicy PPOSelf = TypeVar("PPOSelf", bound="PPO") @@ -70,7 +70,7 @@ class PPO(OnPolicyAlgorithmJax): policy_aliases: ClassVar[dict[str, type[PPOPolicy]]] = { # type: ignore[assignment] "MlpPolicy": PPOPolicy, - # "CnnPolicy": ActorCriticCnnPolicy, + "CnnPolicy": CnnPolicy, # "MultiInputPolicy": MultiInputActorCriticPolicy, } policy: PPOPolicy # type: ignore[assignment] diff --git a/sbx/tqc/tqc.py b/sbx/tqc/tqc.py index 7ec382c..15d6183 100644 --- a/sbx/tqc/tqc.py +++ b/sbx/tqc/tqc.py @@ -244,7 +244,7 @@ def train(self, gradient_steps: int, batch_size: int) -> None: self.policy.actor_state, self.ent_coef_state, self.key, - (qf1_loss_value, qf2_loss_value, actor_loss_value, ent_coef_loss_value, ent_coef_value), + (qf1_loss_value, _qf2_loss_value, actor_loss_value, ent_coef_loss_value, ent_coef_value), ) = self._train( self.tau, self.target_entropy, diff --git a/tests/test_cnn.py b/tests/test_cnn.py index 78b209b..e36c59e 100644 --- a/tests/test_cnn.py +++ b/tests/test_cnn.py @@ -2,11 +2,11 @@ import pytest from stable_baselines3.common.envs import FakeImageEnv -from sbx import DQN +from sbx import DQN, PPO @pytest.mark.parametrize("model_class", [DQN]) -def test_cnn(tmp_path, model_class): +def test_cnn_dqn(tmp_path, model_class): SAVE_NAME = "cnn_model.zip" # Fake grayscale with frameskip # Atari after preprocessing: 84x84x1, here we are using lower resolution @@ -42,3 +42,36 @@ def test_cnn(tmp_path, model_class): assert np.allclose(action, model.predict(obs, deterministic=True)[0]) (tmp_path / SAVE_NAME).unlink() + + +@pytest.mark.parametrize("model_class", [PPO]) +def test_cnn_ppo(tmp_path, model_class): + SAVE_NAME = "cnn_model.zip" + # Fake grayscale with frameskip + # Atari after preprocessing: 84x84x1, here we are using lower resolution + # to check that the network handle it automatically + env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1) + model = model_class( + "CnnPolicy", + env, + policy_kwargs=dict(net_arch=[64]), + verbose=1, + ) + model.learn(total_timesteps=250) + + obs, _ = env.reset() + + for _ in range(10): + model.predict(obs, deterministic=False) + + action, _ = model.predict(obs, deterministic=True) + + model.save(tmp_path / SAVE_NAME) + del model + + model = model_class.load(tmp_path / SAVE_NAME) + + # Check that the prediction is the same + assert np.allclose(action, model.predict(obs, deterministic=True)[0]) + + (tmp_path / SAVE_NAME).unlink() From 9fcbd7da6ac5390e56afe29a3ab4d79db772e417 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Fri, 26 Sep 2025 12:07:53 +0200 Subject: [PATCH 3/9] Remove unused key in CrossQ --- sbx/crossq/crossq.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sbx/crossq/crossq.py b/sbx/crossq/crossq.py index 49cd317..7c2b667 100644 --- a/sbx/crossq/crossq.py +++ b/sbx/crossq/crossq.py @@ -261,7 +261,7 @@ def update_critic( discounts: jax.Array, key: jax.Array, ): - key, noise_key, _dropout_key_target, dropout_key_current = jax.random.split(key, 4) + key, noise_key, dropout_key_current = jax.random.split(key, 3) # sample action from the actor dist = actor_state.apply_fn( {"params": actor_state.params, "batch_stats": actor_state.batch_stats}, From 6103b1e1e66eed0c95bb03e6837fa00358c5d721 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Fri, 26 Sep 2025 12:28:37 +0200 Subject: [PATCH 4/9] Move NatureCNN to Jax layers --- sbx/common/jax_layers.py | 23 +++++++++++++++++++++++ sbx/dqn/policies.py | 21 ++++----------------- 2 files changed, 27 insertions(+), 17 deletions(-) diff --git a/sbx/common/jax_layers.py b/sbx/common/jax_layers.py index 01880aa..a77c816 100644 --- a/sbx/common/jax_layers.py +++ b/sbx/common/jax_layers.py @@ -225,3 +225,26 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: x = self.activation_fn(x) x = nn.Dense(self.hidden_dim, kernel_init=nn.initializers.he_normal())(x) return residual + x + + +# CNN policy from DQN paper +class NatureCNN(nn.Module): + n_features: int = 512 + activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu + + @nn.compact + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + # Convert from channel-first (PyTorch) to channel-last (Jax) + x = jnp.transpose(x, (0, 2, 3, 1)) + # Convert to float and normalize the image + x = x.astype(jnp.float32) / 255.0 + x = nn.Conv(32, kernel_size=(8, 8), strides=(4, 4), padding="VALID")(x) + x = self.activation_fn(x) + x = nn.Conv(64, kernel_size=(4, 4), strides=(2, 2), padding="VALID")(x) + x = self.activation_fn(x) + x = nn.Conv(64, kernel_size=(3, 3), strides=(1, 1), padding="VALID")(x) + x = self.activation_fn(x) + # Flatten + x = x.reshape((x.shape[0], -1)) + x = nn.Dense(self.n_features)(x) + return self.activation_fn(x) diff --git a/sbx/dqn/policies.py b/sbx/dqn/policies.py index c03be4d..8d25383 100644 --- a/sbx/dqn/policies.py +++ b/sbx/dqn/policies.py @@ -8,6 +8,7 @@ from gymnasium import spaces from stable_baselines3.common.type_aliases import Schedule +from sbx.common.jax_layers import NatureCNN from sbx.common.policies import BaseJaxPolicy, Flatten from sbx.common.type_aliases import RLTrainState @@ -28,28 +29,14 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: return x -# Add CNN policy from DQN paper -class NatureCNN(nn.Module): +class CnnQNetwork(nn.Module): n_actions: int n_units: int = 512 activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu @nn.compact def __call__(self, x: jnp.ndarray) -> jnp.ndarray: - # Convert from channel-first (PyTorch) to channel-last (Jax) - x = jnp.transpose(x, (0, 2, 3, 1)) - # Convert to float and normalize the image - x = x.astype(jnp.float32) / 255.0 - x = nn.Conv(32, kernel_size=(8, 8), strides=(4, 4), padding="VALID")(x) - x = self.activation_fn(x) - x = nn.Conv(64, kernel_size=(4, 4), strides=(2, 2), padding="VALID")(x) - x = self.activation_fn(x) - x = nn.Conv(64, kernel_size=(3, 3), strides=(1, 1), padding="VALID")(x) - x = self.activation_fn(x) - # Flatten - x = x.reshape((x.shape[0], -1)) - x = nn.Dense(self.n_units)(x) - x = self.activation_fn(x) + x = NatureCNN(self.n_units, self.activation_fn)(x) x = nn.Dense(self.n_actions)(x) return x @@ -130,7 +117,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule) -> jax.Array: obs = jnp.array([self.observation_space.sample()]) - self.qf = NatureCNN( + self.qf = CnnQNetwork( n_actions=int(self.action_space.n), n_units=self.n_units, activation_fn=self.activation_fn, From 12f5922ee4035ec4b57cf652cdc9eefd55783119 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Fri, 26 Sep 2025 12:29:09 +0200 Subject: [PATCH 5/9] Refactor CNN implementation for PPO to actually use the CNN --- sbx/common/on_policy_algorithm.py | 2 +- sbx/ppo/policies.py | 106 ++++++++++++++++++++---------- tests/test_cnn.py | 54 ++++----------- 3 files changed, 83 insertions(+), 79 deletions(-) diff --git a/sbx/common/on_policy_algorithm.py b/sbx/common/on_policy_algorithm.py index d0c7d97..13c4843 100644 --- a/sbx/common/on_policy_algorithm.py +++ b/sbx/common/on_policy_algorithm.py @@ -158,7 +158,7 @@ def collect_rollouts( # Always sample new stochastic action self.policy.reset_noise() - obs_tensor, _ = self.policy.prepare_obs(self._last_obs) # type: ignore[has-type] + obs_tensor, _ = self.policy.prepare_obs(self._last_obs) # type: ignore[has-type, arg-type] actions, log_probs, values = self.policy.predict_all(obs_tensor, self.policy.noise_key) actions = np.array(actions) diff --git a/sbx/ppo/policies.py b/sbx/ppo/policies.py index 95a65d6..db6abaf 100644 --- a/sbx/ppo/policies.py +++ b/sbx/ppo/policies.py @@ -14,6 +14,7 @@ from jax.nn.initializers import constant from stable_baselines3.common.type_aliases import Schedule +from sbx.common.jax_layers import NatureCNN from sbx.common.policies import BaseJaxPolicy, Flatten tfd = tfp.distributions @@ -22,9 +23,15 @@ class Critic(nn.Module): net_arch: Sequence[int] activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.tanh + features_extractor: Optional[type[NatureCNN]] = None + features_dim: int = 512 @nn.compact def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + # Note: we are using separate CNN for actor and critic + if self.features_extractor is not None: + x = self.features_extractor(self.features_dim, self.activation_fn)(x) + x = Flatten()(x) for n_units in self.net_arch: x = nn.Dense(n_units)(x) @@ -46,6 +53,8 @@ class Actor(nn.Module): split_indices: np.ndarray = field(default_factory=lambda: np.array([])) # Last layer with small scale ortho_init: bool = False + features_extractor: Optional[type[NatureCNN]] = None + features_dim: int = 512 def get_std(self) -> jnp.ndarray: # Make it work with gSDE @@ -61,6 +70,10 @@ def __post_init__(self) -> None: @nn.compact def __call__(self, x: jnp.ndarray) -> tfd.Distribution: # type: ignore[name-defined] + # Note: we are using separate CNN for actor and critic + if self.features_extractor is not None: + x = self.features_extractor(self.features_dim, self.activation_fn)(x) + x = Flatten()(x) for n_units in self.net_arch: @@ -122,7 +135,7 @@ def __init__( # this is to keep API consistent with SB3 use_expln: bool = False, clip_mean: float = 2.0, - features_extractor_class=None, + features_extractor_class: Optional[type[NatureCNN]] = None, features_extractor_kwargs: Optional[dict[str, Any]] = None, normalize_images: bool = True, optimizer_class: Callable[..., optax.GradientTransformation] = optax.adam, @@ -156,7 +169,12 @@ def __init__( self.net_arch_pi = net_arch["pi"] self.net_arch_vf = net_arch["vf"] else: - self.net_arch_pi = self.net_arch_vf = [64, 64] + if features_extractor_class == NatureCNN: + # Just a linear layer after the CNN + net_arch = [] + else: + net_arch = [64, 64] + self.net_arch_pi = self.net_arch_vf = net_arch self.use_sde = use_sde self.ortho_init = ortho_init self.actor_class = actor_class @@ -204,11 +222,14 @@ def build(self, key: jax.Array, lr_schedule: Schedule, max_grad_norm: float) -> else: raise NotImplementedError(f"{self.action_space}") + actor_kwargs.update(self.features_extractor_kwargs) + self.actor = self.actor_class( net_arch=self.net_arch_pi, log_std_init=self.log_std_init, activation_fn=self.activation_fn, ortho_init=self.ortho_init, + features_extractor=self.features_extractor_class, **actor_kwargs, # type: ignore[arg-type] ) # Hack to make gSDE work without modifying internal SB3 code @@ -228,7 +249,12 @@ def build(self, key: jax.Array, lr_schedule: Schedule, max_grad_norm: float) -> ), ) - self.vf = self.critic_class(net_arch=self.net_arch_vf, activation_fn=self.activation_fn) + self.vf = self.critic_class( + net_arch=self.net_arch_vf, + activation_fn=self.activation_fn, + features_extractor=self.features_extractor_class, + **self.features_extractor_kwargs, + ) self.vf_state = TrainState.create( apply_fn=self.vf.apply, @@ -274,38 +300,46 @@ def _predict_all(actor_state, vf_state, observations, key): return actions, log_probs, values -class ActorCriticCNN(nn.Module): - n_actions: int - n_units: int = 512 - activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu - - @nn.compact - def __call__(self, x: jnp.ndarray): - x = jnp.transpose(x, (0, 2, 3, 1)) - x = x.astype(jnp.float32) / 255.0 - - x = nn.Conv(32, kernel_size=(8, 8), strides=(4, 4), padding="VALID")(x) - x = self.activation_fn(x) - x = nn.Conv(64, kernel_size=(4, 4), strides=(2, 2), padding="VALID")(x) - x = self.activation_fn(x) - x = nn.Conv(64, kernel_size=(3, 3), strides=(1, 1), padding="VALID")(x) - x = self.activation_fn(x) - x = x.reshape((x.shape[0], -1)) - - shared_net = nn.Dense(self.n_units)(x) - shared_net = self.activation_fn(shared_net) - - action_logits = nn.Dense(self.n_actions)(shared_net) - - state_value = nn.Dense(1)(shared_net) - - return action_logits, state_value - - class CnnPolicy(PPOPolicy): - def build_network(self) -> None: - self.network = ActorCriticCNN( - n_actions=self.action_space.n, - n_units=self.net_arch[0] if self.net_arch else 512, - activation_fn=self.activation_fn, + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None, + ortho_init: bool = False, + log_std_init: float = 0, + # ReLU for NatureCNN + activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu, + use_sde: bool = False, + use_expln: bool = False, + clip_mean: float = 2, + features_extractor_class: Optional[type[NatureCNN]] = NatureCNN, + features_extractor_kwargs: Optional[dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Callable[..., optax.GradientTransformation] = optax.adam, + optimizer_kwargs: Optional[dict[str, Any]] = None, + share_features_extractor: bool = False, + actor_class: type[nn.Module] = Actor, + critic_class: type[nn.Module] = Critic, + ): + super().__init__( + observation_space, + action_space, + lr_schedule, + net_arch, + ortho_init, + log_std_init, + activation_fn, + use_sde, + use_expln, + clip_mean, + features_extractor_class, + features_extractor_kwargs, + normalize_images, + optimizer_class, + optimizer_kwargs, + share_features_extractor, + actor_class, + critic_class, ) diff --git a/tests/test_cnn.py b/tests/test_cnn.py index e36c59e..3cda093 100644 --- a/tests/test_cnn.py +++ b/tests/test_cnn.py @@ -5,21 +5,24 @@ from sbx import DQN, PPO -@pytest.mark.parametrize("model_class", [DQN]) +@pytest.mark.parametrize("model_class", [DQN, PPO]) def test_cnn_dqn(tmp_path, model_class): SAVE_NAME = "cnn_model.zip" # Fake grayscale with frameskip # Atari after preprocessing: 84x84x1, here we are using lower resolution # to check that the network handle it automatically env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1) - model = model_class( - "CnnPolicy", - env, - buffer_size=250, - policy_kwargs=dict(net_arch=[64]), - learning_starts=100, - verbose=1, - ) + kwargs = {} + if model_class == DQN: + kwargs = {"buffer_size": 250, "learning_starts": 100} + elif model_class == PPO: + kwargs = { + "n_steps": 128, + "batch_size": 64, + "n_epochs": 2, + } + + model = model_class("CnnPolicy", env, policy_kwargs=dict(net_arch=[64]), verbose=1, **kwargs) model.learn(total_timesteps=250) obs, _ = env.reset() @@ -42,36 +45,3 @@ def test_cnn_dqn(tmp_path, model_class): assert np.allclose(action, model.predict(obs, deterministic=True)[0]) (tmp_path / SAVE_NAME).unlink() - - -@pytest.mark.parametrize("model_class", [PPO]) -def test_cnn_ppo(tmp_path, model_class): - SAVE_NAME = "cnn_model.zip" - # Fake grayscale with frameskip - # Atari after preprocessing: 84x84x1, here we are using lower resolution - # to check that the network handle it automatically - env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1) - model = model_class( - "CnnPolicy", - env, - policy_kwargs=dict(net_arch=[64]), - verbose=1, - ) - model.learn(total_timesteps=250) - - obs, _ = env.reset() - - for _ in range(10): - model.predict(obs, deterministic=False) - - action, _ = model.predict(obs, deterministic=True) - - model.save(tmp_path / SAVE_NAME) - del model - - model = model_class.load(tmp_path / SAVE_NAME) - - # Check that the prediction is the same - assert np.allclose(action, model.predict(obs, deterministic=True)[0]) - - (tmp_path / SAVE_NAME).unlink() From dfa91f2e4680cee692f996ae1ff97aa95dfdd7ff Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 29 Sep 2025 15:32:02 +0200 Subject: [PATCH 6/9] Share features extractor with selective copy --- sbx/common/utils.py | 65 +++++++++++++++++++++++++++++++++++++++++++++ sbx/ppo/ppo.py | 13 +++++++-- tests/test_cnn.py | 11 +++++++- 3 files changed, 86 insertions(+), 3 deletions(-) diff --git a/sbx/common/utils.py b/sbx/common/utils.py index 447c4f3..4f8b2de 100644 --- a/sbx/common/utils.py +++ b/sbx/common/utils.py @@ -1,6 +1,11 @@ from dataclasses import dataclass +from typing import Union +import jax +import jax.numpy as jnp import numpy as np +from flax.core import FrozenDict +from flax.training.train_state import TrainState @dataclass @@ -24,3 +29,63 @@ def update(self, kl_div: float) -> None: self.current_adaptive_lr *= self.adaptive_lr_factor self.current_adaptive_lr = np.clip(self.current_adaptive_lr, self.min_learning_rate, self.max_learning_rate) + + +def mask_from_prefix(params: FrozenDict, prefix: str = "NatureCNN_") -> dict: + """ + Build a pytree mask (same structure as `params`) where a leaf is True + if the top-level module name starts with `prefix`. + """ + + def _traverse(tree: FrozenDict, path: tuple[str, ...] = ()) -> Union[dict, bool]: + if isinstance(tree, dict): + return {key: _traverse(value, (*path, key)) for key, value in tree.items()} + # leaf + return path[1].startswith(prefix) if len(path) > 1 else False + + return _traverse(params) # type: ignore[return-value] + + +def align_pytree(params1: FrozenDict, params2: FrozenDict) -> dict: + """ + Return a pytree with the *exact* structure of `params2`. For every leaf in `params2`, + use the corresponding leaf from `params1` if it exists; otherwise use `params2`'s leaf. + This guarantees the two pytrees have identical structure for tree_map. + """ + if isinstance(params2, dict): + out = {} + for key, params2_sub in params2.items(): + params1_sub = params1[key] if (isinstance(params1, dict) and key in params1) else None + out[key] = align_pytree(params1_sub, params2_sub) # type: ignore[arg-type] + return out + # leaf-case: if params1 value exists (not None and same shape) use it, else use params2 leaf + return params1 if (params1 is not None and params1.shape == params2.shape) else params2 # type: ignore[attr-defined, return-value] + + +@jax.jit +def masked_copy(params1: FrozenDict, params2: FrozenDict, mask: dict) -> FrozenDict: + """ + Leafwise selection: wherever mask is True we take params1 value, + otherwise params2 value. + """ + return jax.tree_util.tree_map( + lambda val1, val2, mask_value: jnp.where(mask_value, val1, val2), + params1, + params2, + mask, + ) + + +@jax.jit +def copy_naturecnn_params(state1: TrainState, state2: TrainState) -> TrainState: + """ + Copy all top-level modules whose names start with "NatureCNN_" from + state1.params into state2.params. + It is useful when sharing features extractor parameters between actor and critic. + """ + # Ensure same structure + aligned_params = align_pytree(state1.params, state2.params) + mask = mask_from_prefix(state2.params, prefix="NatureCNN_") + new_params = masked_copy(aligned_params, state2.params, mask) + + return state2.replace(params=new_params) diff --git a/sbx/ppo/ppo.py b/sbx/ppo/ppo.py index c835d1d..09cde90 100644 --- a/sbx/ppo/ppo.py +++ b/sbx/ppo/ppo.py @@ -11,7 +11,7 @@ from stable_baselines3.common.utils import FloatSchedule, explained_variance from sbx.common.on_policy_algorithm import OnPolicyAlgorithmJax -from sbx.common.utils import KLAdaptiveLR +from sbx.common.utils import KLAdaptiveLR, copy_naturecnn_params from sbx.ppo.policies import CnnPolicy, PPOPolicy PPOSelf = TypeVar("PPOSelf", bound="PPO") @@ -196,7 +196,7 @@ def _setup_model(self) -> None: # self.clip_range_vf = FloatSchedule(self.clip_range_vf) @staticmethod - @partial(jax.jit, static_argnames=["normalize_advantage"]) + @partial(jax.jit, static_argnames=["normalize_advantage", "share_features_extractor"]) def _one_update( actor_state: TrainState, vf_state: TrainState, @@ -209,6 +209,7 @@ def _one_update( ent_coef: float, vf_coef: float, normalize_advantage: bool = True, + share_features_extractor: bool = False, ): # Normalize advantage # Normalization does not make sense if mini batchsize == 1, see GH issue #325 @@ -241,6 +242,10 @@ def actor_loss(params): ) actor_state = actor_state.apply_gradients(grads=grads) + if share_features_extractor: + # Hack: selective copy to share features extractor when using CNN + vf_state = copy_naturecnn_params(actor_state, vf_state) + def critic_loss(params): # Value loss using the TD(gae_lambda) target vf_values = vf_state.apply_fn(params, observations).flatten() @@ -249,6 +254,9 @@ def critic_loss(params): vf_loss_value, grads = jax.value_and_grad(critic_loss, has_aux=False)(vf_state.params) vf_state = vf_state.apply_gradients(grads=grads) + if share_features_extractor: + actor_state = copy_naturecnn_params(vf_state, actor_state) + # loss = policy_loss + ent_coef * entropy_loss + vf_coef * value_loss return (actor_state, vf_state), (pg_loss_value, policy_loss, entropy_loss, vf_loss_value, ratio) @@ -292,6 +300,7 @@ def train(self) -> None: ent_coef=self.ent_coef, vf_coef=self.vf_coef, normalize_advantage=self.normalize_advantage, + share_features_extractor=isinstance(self.policy, CnnPolicy), ) ) diff --git a/tests/test_cnn.py b/tests/test_cnn.py index 3cda093..19a709f 100644 --- a/tests/test_cnn.py +++ b/tests/test_cnn.py @@ -22,7 +22,16 @@ def test_cnn_dqn(tmp_path, model_class): "n_epochs": 2, } - model = model_class("CnnPolicy", env, policy_kwargs=dict(net_arch=[64]), verbose=1, **kwargs) + model = model_class( + "CnnPolicy", + env, + policy_kwargs=dict( + net_arch=[64], + features_extractor_kwargs=dict(features_dim=64), + ), + verbose=1, + **kwargs + ) model.learn(total_timesteps=250) obs, _ = env.reset() From 8086149194343d50ffc645e78e0a987e5a86a97c Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Mon, 29 Sep 2025 17:11:56 +0200 Subject: [PATCH 7/9] Rename method --- sbx/common/utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sbx/common/utils.py b/sbx/common/utils.py index 4f8b2de..afea621 100644 --- a/sbx/common/utils.py +++ b/sbx/common/utils.py @@ -46,17 +46,17 @@ def _traverse(tree: FrozenDict, path: tuple[str, ...] = ()) -> Union[dict, bool] return _traverse(params) # type: ignore[return-value] -def align_pytree(params1: FrozenDict, params2: FrozenDict) -> dict: +def align_params(params1: FrozenDict, params2: FrozenDict) -> dict: """ - Return a pytree with the *exact* structure of `params2`. For every leaf in `params2`, + Return a dict with the *exact* structure of `params2`. For every leaf in `params2`, use the corresponding leaf from `params1` if it exists; otherwise use `params2`'s leaf. - This guarantees the two pytrees have identical structure for tree_map. + This guarantees the two dict have identical structure for tree_map. """ if isinstance(params2, dict): out = {} for key, params2_sub in params2.items(): params1_sub = params1[key] if (isinstance(params1, dict) and key in params1) else None - out[key] = align_pytree(params1_sub, params2_sub) # type: ignore[arg-type] + out[key] = align_params(params1_sub, params2_sub) # type: ignore[arg-type] return out # leaf-case: if params1 value exists (not None and same shape) use it, else use params2 leaf return params1 if (params1 is not None and params1.shape == params2.shape) else params2 # type: ignore[attr-defined, return-value] @@ -84,7 +84,7 @@ def copy_naturecnn_params(state1: TrainState, state2: TrainState) -> TrainState: It is useful when sharing features extractor parameters between actor and critic. """ # Ensure same structure - aligned_params = align_pytree(state1.params, state2.params) + aligned_params = align_params(state1.params, state2.params) mask = mask_from_prefix(state2.params, prefix="NatureCNN_") new_params = masked_copy(aligned_params, state2.params, mask) From 9ba65710f4b58d5713ac1d442944a82c7fe15c13 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Mon, 29 Sep 2025 19:17:55 +0200 Subject: [PATCH 8/9] Update comments --- sbx/ppo/policies.py | 2 -- sbx/ppo/ppo.py | 2 ++ 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sbx/ppo/policies.py b/sbx/ppo/policies.py index db6abaf..7ffa7c6 100644 --- a/sbx/ppo/policies.py +++ b/sbx/ppo/policies.py @@ -28,7 +28,6 @@ class Critic(nn.Module): @nn.compact def __call__(self, x: jnp.ndarray) -> jnp.ndarray: - # Note: we are using separate CNN for actor and critic if self.features_extractor is not None: x = self.features_extractor(self.features_dim, self.activation_fn)(x) @@ -70,7 +69,6 @@ def __post_init__(self) -> None: @nn.compact def __call__(self, x: jnp.ndarray) -> tfd.Distribution: # type: ignore[name-defined] - # Note: we are using separate CNN for actor and critic if self.features_extractor is not None: x = self.features_extractor(self.features_dim, self.activation_fn)(x) diff --git a/sbx/ppo/ppo.py b/sbx/ppo/ppo.py index 09cde90..a718972 100644 --- a/sbx/ppo/ppo.py +++ b/sbx/ppo/ppo.py @@ -300,6 +300,8 @@ def train(self) -> None: ent_coef=self.ent_coef, vf_coef=self.vf_coef, normalize_advantage=self.normalize_advantage, + # Sharing the CNN between actor and critic has a great impact on performance + # for Atari games share_features_extractor=isinstance(self.policy, CnnPolicy), ) ) From ce2ba79bc389d3afa135a648665b30d2241a980c Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Mon, 29 Sep 2025 19:18:47 +0200 Subject: [PATCH 9/9] Update name --- tests/test_cnn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_cnn.py b/tests/test_cnn.py index 19a709f..bd544fb 100644 --- a/tests/test_cnn.py +++ b/tests/test_cnn.py @@ -6,7 +6,7 @@ @pytest.mark.parametrize("model_class", [DQN, PPO]) -def test_cnn_dqn(tmp_path, model_class): +def test_cnn(tmp_path, model_class): SAVE_NAME = "cnn_model.zip" # Fake grayscale with frameskip # Atari after preprocessing: 84x84x1, here we are using lower resolution