Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions sbx/common/jax_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,3 +225,26 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
x = self.activation_fn(x)
x = nn.Dense(self.hidden_dim, kernel_init=nn.initializers.he_normal())(x)
return residual + x


# Feature extractor from the Nature DQN paper (Mnih et al., 2015).
class NatureCNN(nn.Module):
    """Nature-DQN convolutional feature extractor.

    Takes channel-first (PyTorch layout) uint8 images and returns a feature
    vector of size ``n_features`` per batch element.
    """

    n_features: int = 512
    activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        # Channel-first (PyTorch) -> channel-last, which is what Flax convs expect.
        x = jnp.transpose(x, (0, 2, 3, 1))
        # Convert raw pixel values to floats in [0, 1].
        x = x.astype(jnp.float32) / 255.0
        # The three convolutional stages of the Nature-DQN architecture:
        # (filters, kernel size, stride). Submodules are created in the same
        # order as before, so Flax auto-naming (Conv_0..Conv_2) is unchanged.
        for n_filters, kernel, stride in ((32, 8, 4), (64, 4, 2), (64, 3, 1)):
            x = nn.Conv(
                n_filters,
                kernel_size=(kernel, kernel),
                strides=(stride, stride),
                padding="VALID",
            )(x)
            x = self.activation_fn(x)
        # Flatten the spatial dimensions and project to the feature size.
        flat = x.reshape((x.shape[0], -1))
        return self.activation_fn(nn.Dense(self.n_features)(flat))
2 changes: 1 addition & 1 deletion sbx/common/on_policy_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def collect_rollouts(
# Always sample new stochastic action
self.policy.reset_noise()

obs_tensor, _ = self.policy.prepare_obs(self._last_obs) # type: ignore[has-type]
obs_tensor, _ = self.policy.prepare_obs(self._last_obs) # type: ignore[has-type, arg-type]
actions, log_probs, values = self.policy.predict_all(obs_tensor, self.policy.noise_key)

actions = np.array(actions)
Expand Down
65 changes: 65 additions & 0 deletions sbx/common/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from dataclasses import dataclass
from typing import Union

import jax
import jax.numpy as jnp
import numpy as np
from flax.core import FrozenDict
from flax.training.train_state import TrainState


@dataclass
Expand All @@ -24,3 +29,63 @@ def update(self, kl_div: float) -> None:
self.current_adaptive_lr *= self.adaptive_lr_factor

self.current_adaptive_lr = np.clip(self.current_adaptive_lr, self.min_learning_rate, self.max_learning_rate)


def mask_from_prefix(params: FrozenDict, prefix: str = "NatureCNN_") -> dict:
    """
    Build a boolean pytree mirroring the structure of ``params`` whose leaves
    are True exactly when the top-level module name (the key one level below
    the "params" root) starts with ``prefix``.
    """

    def _walk(node: FrozenDict, path: tuple[str, ...] = ()) -> Union[dict, bool]:
        if not isinstance(node, dict):
            # Leaf: path[0] is the "params" root, path[1] is the module name.
            return len(path) > 1 and path[1].startswith(prefix)
        return {name: _walk(child, (*path, name)) for name, child in node.items()}

    return _walk(params)  # type: ignore[return-value]


def align_params(params1: FrozenDict, params2: FrozenDict) -> dict:
    """
    Return a dict shaped exactly like ``params2``. Each leaf of ``params2`` is
    replaced by the matching leaf of ``params1`` when one exists at the same
    path with the same shape; otherwise ``params2``'s leaf is kept. The result
    therefore always has a structure identical to ``params2``, which makes it
    safe to combine with ``params2`` via tree_map.
    """
    if not isinstance(params2, dict):
        # Leaf case: prefer params1's value when it exists and is shape-compatible.
        if params1 is not None and params1.shape == params2.shape:  # type: ignore[attr-defined]
            return params1  # type: ignore[return-value]
        return params2  # type: ignore[return-value]
    params1_is_dict = isinstance(params1, dict)
    return {
        key: align_params(
            params1[key] if params1_is_dict and key in params1 else None,  # type: ignore[arg-type]
            params2_sub,
        )
        for key, params2_sub in params2.items()
    }


@jax.jit
def masked_copy(params1: FrozenDict, params2: FrozenDict, mask: dict) -> FrozenDict:
    """
    Leafwise selection between two pytrees of identical structure:
    wherever ``mask`` is True the leaf comes from ``params1``,
    otherwise from ``params2``.
    """

    def _select(val1: jnp.ndarray, val2: jnp.ndarray, take_first) -> jnp.ndarray:
        # jnp.where (not a Python `if`) so the selection also works when
        # the mask values are traced under jit.
        return jnp.where(take_first, val1, val2)

    return jax.tree_util.tree_map(_select, params1, params2, mask)


@jax.jit
def copy_naturecnn_params(state1: TrainState, state2: TrainState) -> TrainState:
    """
    Copy every top-level module whose name starts with "NatureCNN_" from
    ``state1.params`` into ``state2.params`` and return the updated ``state2``.
    Useful for sharing features-extractor parameters between actor and critic.
    """
    # Align structures first so tree_map sees two identical pytrees.
    source_params = align_params(state1.params, state2.params)
    cnn_mask = mask_from_prefix(state2.params, prefix="NatureCNN_")
    merged_params = masked_copy(source_params, state2.params, cnn_mask)
    return state2.replace(params=merged_params)
2 changes: 1 addition & 1 deletion sbx/crossq/crossq.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def update_critic(
discounts: jax.Array,
key: jax.Array,
):
key, noise_key, dropout_key_target, dropout_key_current = jax.random.split(key, 4)
key, noise_key, dropout_key_current = jax.random.split(key, 3)
# sample action from the actor
dist = actor_state.apply_fn(
{"params": actor_state.params, "batch_stats": actor_state.batch_stats},
Expand Down
21 changes: 4 additions & 17 deletions sbx/dqn/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from gymnasium import spaces
from stable_baselines3.common.type_aliases import Schedule

from sbx.common.jax_layers import NatureCNN
from sbx.common.policies import BaseJaxPolicy, Flatten
from sbx.common.type_aliases import RLTrainState

Expand All @@ -28,28 +29,14 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
return x


# Add CNN policy from DQN paper
class NatureCNN(nn.Module):
class CnnQNetwork(nn.Module):
n_actions: int
n_units: int = 512
activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu

@nn.compact
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
# Convert from channel-first (PyTorch) to channel-last (Jax)
x = jnp.transpose(x, (0, 2, 3, 1))
# Convert to float and normalize the image
x = x.astype(jnp.float32) / 255.0
x = nn.Conv(32, kernel_size=(8, 8), strides=(4, 4), padding="VALID")(x)
x = self.activation_fn(x)
x = nn.Conv(64, kernel_size=(4, 4), strides=(2, 2), padding="VALID")(x)
x = self.activation_fn(x)
x = nn.Conv(64, kernel_size=(3, 3), strides=(1, 1), padding="VALID")(x)
x = self.activation_fn(x)
# Flatten
x = x.reshape((x.shape[0], -1))
x = nn.Dense(self.n_units)(x)
x = self.activation_fn(x)
x = NatureCNN(self.n_units, self.activation_fn)(x)
x = nn.Dense(self.n_actions)(x)
return x

Expand Down Expand Up @@ -130,7 +117,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule) -> jax.Array:

obs = jnp.array([self.observation_space.sample()])

self.qf = NatureCNN(
self.qf = CnnQNetwork(
n_actions=int(self.action_space.n),
n_units=self.n_units,
activation_fn=self.activation_fn,
Expand Down
75 changes: 72 additions & 3 deletions sbx/ppo/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from jax.nn.initializers import constant
from stable_baselines3.common.type_aliases import Schedule

from sbx.common.jax_layers import NatureCNN
from sbx.common.policies import BaseJaxPolicy, Flatten

tfd = tfp.distributions
Expand All @@ -22,9 +23,14 @@
class Critic(nn.Module):
net_arch: Sequence[int]
activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.tanh
features_extractor: Optional[type[NatureCNN]] = None
features_dim: int = 512

@nn.compact
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
if self.features_extractor is not None:
x = self.features_extractor(self.features_dim, self.activation_fn)(x)

x = Flatten()(x)
for n_units in self.net_arch:
x = nn.Dense(n_units)(x)
Expand All @@ -46,6 +52,8 @@ class Actor(nn.Module):
split_indices: np.ndarray = field(default_factory=lambda: np.array([]))
# Last layer with small scale
ortho_init: bool = False
features_extractor: Optional[type[NatureCNN]] = None
features_dim: int = 512

def get_std(self) -> jnp.ndarray:
# Make it work with gSDE
Expand All @@ -61,6 +69,9 @@ def __post_init__(self) -> None:

@nn.compact
def __call__(self, x: jnp.ndarray) -> tfd.Distribution: # type: ignore[name-defined]
if self.features_extractor is not None:
x = self.features_extractor(self.features_dim, self.activation_fn)(x)

x = Flatten()(x)

for n_units in self.net_arch:
Expand Down Expand Up @@ -122,7 +133,7 @@ def __init__(
# this is to keep API consistent with SB3
use_expln: bool = False,
clip_mean: float = 2.0,
features_extractor_class=None,
features_extractor_class: Optional[type[NatureCNN]] = None,
features_extractor_kwargs: Optional[dict[str, Any]] = None,
normalize_images: bool = True,
optimizer_class: Callable[..., optax.GradientTransformation] = optax.adam,
Expand Down Expand Up @@ -156,7 +167,12 @@ def __init__(
self.net_arch_pi = net_arch["pi"]
self.net_arch_vf = net_arch["vf"]
else:
self.net_arch_pi = self.net_arch_vf = [64, 64]
if features_extractor_class == NatureCNN:
# Just a linear layer after the CNN
net_arch = []
else:
net_arch = [64, 64]
self.net_arch_pi = self.net_arch_vf = net_arch
self.use_sde = use_sde
self.ortho_init = ortho_init
self.actor_class = actor_class
Expand Down Expand Up @@ -204,11 +220,14 @@ def build(self, key: jax.Array, lr_schedule: Schedule, max_grad_norm: float) ->
else:
raise NotImplementedError(f"{self.action_space}")

actor_kwargs.update(self.features_extractor_kwargs)

self.actor = self.actor_class(
net_arch=self.net_arch_pi,
log_std_init=self.log_std_init,
activation_fn=self.activation_fn,
ortho_init=self.ortho_init,
features_extractor=self.features_extractor_class,
**actor_kwargs, # type: ignore[arg-type]
)
# Hack to make gSDE work without modifying internal SB3 code
Expand All @@ -228,7 +247,12 @@ def build(self, key: jax.Array, lr_schedule: Schedule, max_grad_norm: float) ->
),
)

self.vf = self.critic_class(net_arch=self.net_arch_vf, activation_fn=self.activation_fn)
self.vf = self.critic_class(
net_arch=self.net_arch_vf,
activation_fn=self.activation_fn,
features_extractor=self.features_extractor_class,
**self.features_extractor_kwargs,
)

self.vf_state = TrainState.create(
apply_fn=self.vf.apply,
Expand Down Expand Up @@ -272,3 +296,48 @@ def _predict_all(actor_state, vf_state, observations, key):
log_probs = dist.log_prob(actions)
values = vf_state.apply_fn(vf_state.params, observations).flatten()
return actions, log_probs, values


class CnnPolicy(PPOPolicy):
    """PPO policy for image observations: same as ``PPOPolicy`` but defaults to
    the ``NatureCNN`` features extractor and ReLU activations."""

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Schedule,
        net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None,
        ortho_init: bool = False,
        log_std_init: float = 0,
        # ReLU for NatureCNN
        activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu,
        use_sde: bool = False,
        use_expln: bool = False,
        clip_mean: float = 2,
        features_extractor_class: Optional[type[NatureCNN]] = NatureCNN,
        features_extractor_kwargs: Optional[dict[str, Any]] = None,
        normalize_images: bool = True,
        optimizer_class: Callable[..., optax.GradientTransformation] = optax.adam,
        optimizer_kwargs: Optional[dict[str, Any]] = None,
        share_features_extractor: bool = False,
        actor_class: type[nn.Module] = Actor,
        critic_class: type[nn.Module] = Critic,
    ):
        # Pure pass-through: only the defaults differ from the parent class.
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            net_arch=net_arch,
            ortho_init=ortho_init,
            log_std_init=log_std_init,
            activation_fn=activation_fn,
            use_sde=use_sde,
            use_expln=use_expln,
            clip_mean=clip_mean,
            features_extractor_class=features_extractor_class,
            features_extractor_kwargs=features_extractor_kwargs,
            normalize_images=normalize_images,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
            share_features_extractor=share_features_extractor,
            actor_class=actor_class,
            critic_class=critic_class,
        )
19 changes: 15 additions & 4 deletions sbx/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from stable_baselines3.common.utils import FloatSchedule, explained_variance

from sbx.common.on_policy_algorithm import OnPolicyAlgorithmJax
from sbx.common.utils import KLAdaptiveLR
from sbx.ppo.policies import PPOPolicy
from sbx.common.utils import KLAdaptiveLR, copy_naturecnn_params
from sbx.ppo.policies import CnnPolicy, PPOPolicy

PPOSelf = TypeVar("PPOSelf", bound="PPO")

Expand Down Expand Up @@ -70,7 +70,7 @@ class PPO(OnPolicyAlgorithmJax):

policy_aliases: ClassVar[dict[str, type[PPOPolicy]]] = { # type: ignore[assignment]
"MlpPolicy": PPOPolicy,
# "CnnPolicy": ActorCriticCnnPolicy,
"CnnPolicy": CnnPolicy,
# "MultiInputPolicy": MultiInputActorCriticPolicy,
}
policy: PPOPolicy # type: ignore[assignment]
Expand Down Expand Up @@ -196,7 +196,7 @@ def _setup_model(self) -> None:
# self.clip_range_vf = FloatSchedule(self.clip_range_vf)

@staticmethod
@partial(jax.jit, static_argnames=["normalize_advantage"])
@partial(jax.jit, static_argnames=["normalize_advantage", "share_features_extractor"])
def _one_update(
actor_state: TrainState,
vf_state: TrainState,
Expand All @@ -209,6 +209,7 @@ def _one_update(
ent_coef: float,
vf_coef: float,
normalize_advantage: bool = True,
share_features_extractor: bool = False,
):
# Normalize advantage
# Normalization does not make sense if mini batchsize == 1, see GH issue #325
Expand Down Expand Up @@ -241,6 +242,10 @@ def actor_loss(params):
)
actor_state = actor_state.apply_gradients(grads=grads)

if share_features_extractor:
# Hack: selective copy to share features extractor when using CNN
vf_state = copy_naturecnn_params(actor_state, vf_state)

def critic_loss(params):
# Value loss using the TD(gae_lambda) target
vf_values = vf_state.apply_fn(params, observations).flatten()
Expand All @@ -249,6 +254,9 @@ def critic_loss(params):
vf_loss_value, grads = jax.value_and_grad(critic_loss, has_aux=False)(vf_state.params)
vf_state = vf_state.apply_gradients(grads=grads)

if share_features_extractor:
actor_state = copy_naturecnn_params(vf_state, actor_state)

# loss = policy_loss + ent_coef * entropy_loss + vf_coef * value_loss
return (actor_state, vf_state), (pg_loss_value, policy_loss, entropy_loss, vf_loss_value, ratio)

Expand Down Expand Up @@ -292,6 +300,9 @@ def train(self) -> None:
ent_coef=self.ent_coef,
vf_coef=self.vf_coef,
normalize_advantage=self.normalize_advantage,
# Sharing the CNN between actor and critic has a great impact on performance
# for Atari games
share_features_extractor=isinstance(self.policy, CnnPolicy),
)
)

Expand Down
2 changes: 1 addition & 1 deletion sbx/tqc/tqc.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def train(self, gradient_steps: int, batch_size: int) -> None:
self.policy.actor_state,
self.ent_coef_state,
self.key,
(qf1_loss_value, qf2_loss_value, actor_loss_value, ent_coef_loss_value, ent_coef_value),
(qf1_loss_value, _qf2_loss_value, actor_loss_value, ent_coef_loss_value, ent_coef_value),
) = self._train(
self.tau,
self.target_entropy,
Expand Down
Loading