Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.9", "3.10", "3.11", "3.12"]
python-version: ["3.10", "3.11", "3.12", "3.13"]

steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v6
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
uses: actions/setup-python@v6
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
Expand All @@ -35,7 +35,7 @@ jobs:
pip install uv
# cpu version of pytorch
# See https://github.com/astral-sh/uv/issues/1497
uv pip install --system torch==2.3.1+cpu --index https://download.pytorch.org/whl/cpu
uv pip install --system torch==2.9.1+cpu --index https://download.pytorch.org/whl/cpu

uv pip install --system .[tests]
# Use headless version
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
[tool.ruff]
# Same as Black.
line-length = 127
# Assume Python 3.9
target-version = "py39"
# Assume Python 3.10
target-version = "py310"

[tool.ruff.lint]
# See https://beta.ruff.rs/docs/rules/
Expand Down
4 changes: 2 additions & 2 deletions sbx/common/distributions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Optional
from typing import Any

import jax.numpy as jnp
import tensorflow_probability.substrates.jax as tfp
Expand All @@ -19,7 +19,7 @@ def mode(self) -> jnp.ndarray:
return self.bijector.forward(self.distribution.mode())

@classmethod
def _parameter_properties(cls, dtype: Optional[Any], num_classes=None):
def _parameter_properties(cls, dtype: Any | None, num_classes=None):
td_properties = super()._parameter_properties(dtype, num_classes=num_classes)
del td_properties["bijector"]
return td_properties
14 changes: 7 additions & 7 deletions sbx/common/jax_layers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Sequence
from typing import Any, Callable, Optional, Union
from collections.abc import Callable, Sequence
from typing import Any, Union

import flax.linen as nn
import jax
Expand All @@ -12,7 +12,7 @@
Array = Any
Shape = tuple[int, ...]
Dtype = Any # this could be a real type?
Axes = Union[int, Sequence[int]]
Axes = Union[int, Sequence[int]] # noqa: UP007


class BatchRenorm(Module):
Expand Down Expand Up @@ -78,26 +78,26 @@ class BatchRenorm(Module):
calculation for the variance.
"""

use_running_average: Optional[bool] = None
use_running_average: bool | None = None
axis: int = -1
momentum: float = 0.99
epsilon: float = 0.001
warmup_steps: int = 100_000
dtype: Optional[Dtype] = None
dtype: Dtype | None = None
param_dtype: Dtype = jnp.float32
use_bias: bool = True
use_scale: bool = True
bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros
scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones
axis_name: Optional[str] = None
axis_name: str | None = None
axis_index_groups: Any = None
# This parameter was added in flax.linen 0.7.2 (08/2023)
# commented out to be compatible with a wider range of jax versions
# TODO: re-activate in some months (04/2024)
# use_fast_variance: bool = True

@compact
def __call__(self, x, use_running_average: Optional[bool] = None):
def __call__(self, x, use_running_average: bool | None = None):
"""Normalizes the input using batch statistics.

NOTE:
Expand Down
32 changes: 16 additions & 16 deletions sbx/common/off_policy_algorithm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import io
import pathlib
from typing import Any, Optional, Union
from typing import Any

import jax
import numpy as np
Expand All @@ -21,35 +21,35 @@ class OffPolicyAlgorithmJax(OffPolicyAlgorithm):
def __init__(
self,
policy: type[BasePolicy],
env: Union[GymEnv, str],
learning_rate: Union[float, Schedule],
qf_learning_rate: Optional[float] = None,
env: GymEnv | str,
learning_rate: float | Schedule,
qf_learning_rate: float | None = None,
buffer_size: int = 1_000_000, # 1e6
learning_starts: int = 100,
batch_size: int = 256,
tau: float = 0.005,
gamma: float = 0.99,
train_freq: Union[int, tuple[int, str]] = (1, "step"),
train_freq: int | tuple[int, str] = (1, "step"),
gradient_steps: int = 1,
action_noise: Optional[ActionNoise] = None,
replay_buffer_class: Optional[type[ReplayBuffer]] = None,
replay_buffer_kwargs: Optional[dict[str, Any]] = None,
action_noise: ActionNoise | None = None,
replay_buffer_class: type[ReplayBuffer] | None = None,
replay_buffer_kwargs: dict[str, Any] | None = None,
optimize_memory_usage: bool = False,
n_steps: int = 1,
policy_kwargs: Optional[dict[str, Any]] = None,
tensorboard_log: Optional[str] = None,
policy_kwargs: dict[str, Any] | None = None,
tensorboard_log: str | None = None,
verbose: int = 0,
device: str = "auto",
support_multi_env: bool = False,
monitor_wrapper: bool = True,
seed: Optional[int] = None,
seed: int | None = None,
use_sde: bool = False,
sde_sample_freq: int = -1,
use_sde_at_warmup: bool = False,
sde_support: bool = True,
stats_window_size: int = 100,
param_resets: Optional[list[int]] = None,
supported_action_spaces: Optional[tuple[type[spaces.Space], ...]] = None,
param_resets: list[int] | None = None,
supported_action_spaces: tuple[type[spaces.Space], ...] | None = None,
):
super().__init__(
policy=policy,
Expand Down Expand Up @@ -108,7 +108,7 @@ def _excluded_save_params(self) -> list[str]:

def _update_learning_rate( # type: ignore[override]
self,
optimizers: Union[list[optax.OptState], optax.OptState],
optimizers: list[optax.OptState] | optax.OptState,
learning_rate: float,
name: str = "learning_rate",
) -> None:
Expand All @@ -129,7 +129,7 @@ def _update_learning_rate( # type: ignore[override]
# Note: the optimizer must have been defined with inject_hyperparams
optimizer.hyperparams["learning_rate"] = learning_rate

def set_random_seed(self, seed: Optional[int]) -> None: # type: ignore[override]
def set_random_seed(self, seed: int | None) -> None: # type: ignore[override]
super().set_random_seed(seed)
if seed is None:
# Sample random seed
Expand Down Expand Up @@ -173,7 +173,7 @@ def _setup_model(self) -> None:

def load_replay_buffer(
self,
path: Union[str, pathlib.Path, io.BufferedIOBase],
path: str | pathlib.Path | io.BufferedIOBase,
truncate_last_traj: bool = True,
) -> None:
super().load_replay_buffer(path, truncate_last_traj)
Expand Down
20 changes: 10 additions & 10 deletions sbx/common/on_policy_algorithm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Optional, TypeVar, Union
from typing import Any, TypeVar

import gymnasium as gym
import jax
Expand All @@ -25,9 +25,9 @@ class OnPolicyAlgorithmJax(OnPolicyAlgorithm):

def __init__(
self,
policy: Union[str, type[BasePolicy]],
env: Union[GymEnv, str],
learning_rate: Union[float, Schedule],
policy: str | type[BasePolicy],
env: GymEnv | str,
learning_rate: float | Schedule,
n_steps: int,
gamma: float,
gae_lambda: float,
Expand All @@ -36,14 +36,14 @@ def __init__(
max_grad_norm: float,
use_sde: bool,
sde_sample_freq: int,
tensorboard_log: Optional[str] = None,
tensorboard_log: str | None = None,
monitor_wrapper: bool = True,
policy_kwargs: Optional[dict[str, Any]] = None,
policy_kwargs: dict[str, Any] | None = None,
verbose: int = 0,
seed: Optional[int] = None,
seed: int | None = None,
device: str = "auto",
_init_setup_model: bool = True,
supported_action_spaces: Optional[tuple[type[spaces.Space], ...]] = None,
supported_action_spaces: tuple[type[spaces.Space], ...] | None = None,
):
super().__init__(
policy=policy, # type: ignore[arg-type]
Expand Down Expand Up @@ -78,7 +78,7 @@ def _excluded_save_params(self) -> list[str]:

def _update_learning_rate( # type: ignore[override]
self,
optimizers: Union[list[optax.OptState], optax.OptState],
optimizers: list[optax.OptState] | optax.OptState,
learning_rate: float,
) -> None:
"""
Expand All @@ -97,7 +97,7 @@ def _update_learning_rate( # type: ignore[override]
# Note: the optimizer must have been defined with inject_hyperparams
optimizer.hyperparams["learning_rate"] = learning_rate

def set_random_seed(self, seed: Optional[int]) -> None: # type: ignore[override]
def set_random_seed(self, seed: int | None) -> None: # type: ignore[override]
super().set_random_seed(seed)
if seed is None:
# Sample random seed
Expand Down
22 changes: 11 additions & 11 deletions sbx/common/policies.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# import copy
from collections.abc import Sequence
from typing import Callable, Optional, Union, no_type_check
from collections.abc import Callable, Sequence
from typing import no_type_check

import flax.linen as nn
import jax
Expand Down Expand Up @@ -50,11 +50,11 @@ def select_action(actor_state, observations):
@no_type_check
def predict(
self,
observation: Union[np.ndarray, dict[str, np.ndarray]],
state: Optional[tuple[np.ndarray, ...]] = None,
episode_start: Optional[np.ndarray] = None,
observation: np.ndarray | dict[str, np.ndarray],
state: tuple[np.ndarray, ...] | None = None,
episode_start: np.ndarray | None = None,
deterministic: bool = False,
) -> tuple[np.ndarray, Optional[tuple[np.ndarray, ...]]]:
) -> tuple[np.ndarray, tuple[np.ndarray, ...] | None]:
# self.set_training_mode(False)

observation, vectorized_env = self.prepare_obs(observation)
Expand All @@ -81,7 +81,7 @@ def predict(

return actions, state

def prepare_obs(self, observation: Union[np.ndarray, dict[str, np.ndarray]]) -> tuple[np.ndarray, bool]:
def prepare_obs(self, observation: np.ndarray | dict[str, np.ndarray]) -> tuple[np.ndarray, bool]:
vectorized_env = False
if isinstance(observation, dict):
assert isinstance(self.observation_space, spaces.Dict)
Expand Down Expand Up @@ -132,7 +132,7 @@ def set_training_mode(self, mode: bool) -> None:
class ContinuousCritic(nn.Module):
net_arch: Sequence[int]
use_layer_norm: bool = False
dropout_rate: Optional[float] = None
dropout_rate: float | None = None
activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
output_dim: int = 1

Expand All @@ -154,7 +154,7 @@ def __call__(self, x: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray:
class SimbaContinuousCritic(nn.Module):
net_arch: Sequence[int]
use_layer_norm: bool = False # for consistency, not used
dropout_rate: Optional[float] = None
dropout_rate: float | None = None
activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
output_dim: int = 1
scale_factor: int = 4
Expand All @@ -179,7 +179,7 @@ def __call__(self, x: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray:
class VectorCritic(nn.Module):
net_arch: Sequence[int]
use_layer_norm: bool = False
dropout_rate: Optional[float] = None
dropout_rate: float | None = None
n_critics: int = 2
activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
output_dim: int = 1
Expand Down Expand Up @@ -210,7 +210,7 @@ class SimbaVectorCritic(nn.Module):
net_arch: Sequence[int]
# Note: we have use_layer_norm for consistency but it is not used (always on)
use_layer_norm: bool = True
dropout_rate: Optional[float] = None
dropout_rate: float | None = None
n_critics: int = 2
activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
output_dim: int = 1
Expand Down
3 changes: 1 addition & 2 deletions sbx/common/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from dataclasses import dataclass
from typing import Union

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -37,7 +36,7 @@ def mask_from_prefix(params: FrozenDict, prefix: str = "NatureCNN_") -> dict:
if the top-level module name starts with `prefix`.
"""

def _traverse(tree: FrozenDict, path: tuple[str, ...] = ()) -> Union[dict, bool]:
def _traverse(tree: FrozenDict, path: tuple[str, ...] = ()) -> dict | bool:
if isinstance(tree, dict):
return {key: _traverse(value, (*path, key)) for key, value in tree.items()}
# leaf
Expand Down
28 changes: 14 additions & 14 deletions sbx/crossq/crossq.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import partial
from typing import Any, ClassVar, Literal, Optional, Union
from typing import Any, ClassVar, Literal

import flax
import flax.linen as nn
Expand Down Expand Up @@ -53,31 +53,31 @@ class CrossQ(OffPolicyAlgorithmJax):
def __init__(
self,
policy,
env: Union[GymEnv, str],
learning_rate: Union[float, Schedule] = 1e-3,
qf_learning_rate: Optional[float] = None,
env: GymEnv | str,
learning_rate: float | Schedule = 1e-3,
qf_learning_rate: float | None = None,
buffer_size: int = 1_000_000, # 1e6
learning_starts: int = 100,
batch_size: int = 256,
gamma: float = 0.99,
train_freq: Union[int, tuple[int, str]] = 1,
train_freq: int | tuple[int, str] = 1,
gradient_steps: int = 1,
policy_delay: int = 3,
action_noise: Optional[ActionNoise] = None,
replay_buffer_class: Optional[type[ReplayBuffer]] = None,
replay_buffer_kwargs: Optional[dict[str, Any]] = None,
action_noise: ActionNoise | None = None,
replay_buffer_class: type[ReplayBuffer] | None = None,
replay_buffer_kwargs: dict[str, Any] | None = None,
n_steps: int = 1,
ent_coef: Union[str, float] = "auto",
target_entropy: Union[Literal["auto"], float] = "auto",
ent_coef: str | float = "auto",
target_entropy: Literal["auto"] | float = "auto",
use_sde: bool = False,
sde_sample_freq: int = -1,
use_sde_at_warmup: bool = False,
stats_window_size: int = 100,
tensorboard_log: Optional[str] = None,
policy_kwargs: Optional[dict[str, Any]] = None,
param_resets: Optional[list[int]] = None, # List of timesteps after which to reset the params
tensorboard_log: str | None = None,
policy_kwargs: dict[str, Any] | None = None,
param_resets: list[int] | None = None, # List of timesteps after which to reset the params
verbose: int = 0,
seed: Optional[int] = None,
seed: int | None = None,
device: str = "auto",
_init_setup_model: bool = True,
) -> None:
Expand Down
Loading