From 66ef270603590c3c24789fe6b1acf8a2ee7ad470 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Fri, 5 Dec 2025 18:46:04 +0100 Subject: [PATCH 1/3] Drop Python 3.9 support --- .github/workflows/ci.yml | 8 ++++---- sbx/version.txt | 2 +- setup.py | 6 +++--- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1c8d7fb..0bc3b8d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -20,12 +20,12 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.9", "3.10", "3.11", "3.12"] + python-version: ["3.10", "3.11", "3.12", "3.13"] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v6 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v6 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -35,7 +35,7 @@ jobs: pip install uv # cpu version of pytorch # See https://github.com/astral-sh/uv/issues/1497 - uv pip install --system torch==2.3.1+cpu --index https://download.pytorch.org/whl/cpu + uv pip install --system torch==2.9.1+cpu --index https://download.pytorch.org/whl/cpu uv pip install --system .[tests] # Use headless version diff --git a/sbx/version.txt b/sbx/version.txt index ca222b7..2094a10 100644 --- a/sbx/version.txt +++ b/sbx/version.txt @@ -1 +1 @@ -0.23.0 +0.24.0 diff --git a/setup.py b/setup.py index 3655279..8f7de74 100644 --- a/setup.py +++ b/setup.py @@ -41,7 +41,7 @@ packages=[package for package in find_packages() if package.startswith("sbx")], package_data={"sbx": ["py.typed", "version.txt"]}, install_requires=[ - "stable_baselines3>=2.7.0a1,<3.0", + "stable_baselines3>=2.8.0a0,<3.0", "jax>=0.4.24,<0.7.0", # tf probability not compatible yet with latest jax version "jaxlib", "flax", @@ -75,13 +75,13 @@ long_description=long_description, long_description_content_type="text/markdown", version=__version__, - python_requires=">=3.9", + python_requires=">=3.10", # PyPI package information. classifiers=[ "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", ], ) From 56e9e70937d928cadc69c5f864a5a4a823577542 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Fri, 5 Dec 2025 18:46:57 +0100 Subject: [PATCH 2/3] Apply autofixes for Python 3.10 --- pyproject.toml | 4 ++-- sbx/common/distributions.py | 4 ++-- sbx/common/jax_layers.py | 13 ++++++------ sbx/common/off_policy_algorithm.py | 32 +++++++++++++++--------------- sbx/common/on_policy_algorithm.py | 20 +++++++++---------- sbx/common/policies.py | 21 ++++++++++---------- sbx/common/utils.py | 3 +-- sbx/crossq/crossq.py | 28 +++++++++++++------------- sbx/crossq/policies.py | 23 +++++++++++---------- sbx/ddpg/ddpg.py | 22 ++++++++++---------- sbx/dqn/dqn.py | 26 ++++++++++++------------ sbx/dqn/policies.py | 9 +++++---- sbx/ppo/policies.py | 25 ++++++++++++----------- sbx/ppo/ppo.py | 20 +++++++++---------- sbx/sac/policies.py | 15 +++++++------- sbx/sac/sac.py | 28 +++++++++++++------------- sbx/td3/policies.py | 9 +++++---- sbx/td3/td3.py | 24 +++++++++++----------- sbx/tqc/policies.py | 15 +++++++------- sbx/tqc/tqc.py | 28 +++++++++++++------------- tests/test_flatten.py | 3 +-- tests/test_run.py | 3 +-- tests/test_spaces.py | 3 +-- 23 files changed, 191 insertions(+), 187 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 57514e5..52c07d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,8 +1,8 @@ [tool.ruff] # Same as Black. line-length = 127 -# Assume Python 3.9 -target-version = "py39" +# Assume Python 3.10 +target-version = "py310" [tool.ruff.lint] # See https://beta.ruff.rs/docs/rules/ diff --git a/sbx/common/distributions.py b/sbx/common/distributions.py index 99e2a40..56ac193 100644 --- a/sbx/common/distributions.py +++ b/sbx/common/distributions.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any import jax.numpy as jnp import tensorflow_probability.substrates.jax as tfp @@ -19,7 +19,7 @@ def mode(self) -> jnp.ndarray: return self.bijector.forward(self.distribution.mode()) @classmethod - def _parameter_properties(cls, dtype: Optional[Any], num_classes=None): + def _parameter_properties(cls, dtype: Any | None, num_classes=None): td_properties = super()._parameter_properties(dtype, num_classes=num_classes) del td_properties["bijector"] return td_properties diff --git a/sbx/common/jax_layers.py b/sbx/common/jax_layers.py index a77c816..af5c053 100644 --- a/sbx/common/jax_layers.py +++ b/sbx/common/jax_layers.py @@ -1,5 +1,6 @@ from collections.abc import Sequence -from typing import Any, Callable, Optional, Union +from typing import Any, Union +from collections.abc import Callable import flax.linen as nn import jax @@ -12,7 +13,7 @@ Array = Any Shape = tuple[int, ...] Dtype = Any # this could be a real type? -Axes = Union[int, Sequence[int]] +Axes = Union[int, Sequence[int]] # noqa: UP007 class BatchRenorm(Module): @@ -78,18 +79,18 @@ class BatchRenorm(Module): calculation for the variance. """ - use_running_average: Optional[bool] = None + use_running_average: bool | None = None axis: int = -1 momentum: float = 0.99 epsilon: float = 0.001 warmup_steps: int = 100_000 - dtype: Optional[Dtype] = None + dtype: Dtype | None = None param_dtype: Dtype = jnp.float32 use_bias: bool = True use_scale: bool = True bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones - axis_name: Optional[str] = None + axis_name: str | None = None axis_index_groups: Any = None # This parameter was added in flax.linen 0.7.2 (08/2023) # commented out to be compatible with a wider range of jax versions @@ -97,7 +98,7 @@ class BatchRenorm(Module): # use_fast_variance: bool = True @compact - def __call__(self, x, use_running_average: Optional[bool] = None): + def __call__(self, x, use_running_average: bool | None = None): """Normalizes the input using batch statistics. NOTE: diff --git a/sbx/common/off_policy_algorithm.py b/sbx/common/off_policy_algorithm.py index 74b6c64..f5f72a1 100644 --- a/sbx/common/off_policy_algorithm.py +++ b/sbx/common/off_policy_algorithm.py @@ -1,6 +1,6 @@ import io import pathlib -from typing import Any, Optional, Union +from typing import Any import jax import numpy as np @@ -21,35 +21,35 @@ class OffPolicyAlgorithmJax(OffPolicyAlgorithm): def __init__( self, policy: type[BasePolicy], - env: Union[GymEnv, str], - learning_rate: Union[float, Schedule], - qf_learning_rate: Optional[float] = None, + env: GymEnv | str, + learning_rate: float | Schedule, + qf_learning_rate: float | None = None, buffer_size: int = 1_000_000, # 1e6 learning_starts: int = 100, batch_size: int = 256, tau: float = 0.005, gamma: float = 0.99, - train_freq: Union[int, tuple[int, str]] = (1, "step"), + train_freq: int | tuple[int, str] = (1, "step"), gradient_steps: int = 1, - action_noise: Optional[ActionNoise] = None, - replay_buffer_class: Optional[type[ReplayBuffer]] = None, - replay_buffer_kwargs: Optional[dict[str, Any]] = None, + action_noise: ActionNoise | None = None, + replay_buffer_class: type[ReplayBuffer] | None = None, + replay_buffer_kwargs: dict[str, Any] | None = None, optimize_memory_usage: bool = False, n_steps: int = 1, - policy_kwargs: Optional[dict[str, Any]] = None, - tensorboard_log: Optional[str] = None, + policy_kwargs: dict[str, Any] | None = None, + tensorboard_log: str | None = None, verbose: int = 0, device: str = "auto", support_multi_env: bool = False, monitor_wrapper: bool = True, - seed: Optional[int] = None, + seed: int | None = None, use_sde: bool = False, sde_sample_freq: int = -1, use_sde_at_warmup: bool = False, sde_support: bool = True, stats_window_size: int = 100, - param_resets: Optional[list[int]] = None, - supported_action_spaces: Optional[tuple[type[spaces.Space], ...]] = None, + param_resets: list[int] | None = None, + supported_action_spaces: tuple[type[spaces.Space], ...] | None = None, ): super().__init__( policy=policy, @@ -108,7 +108,7 @@ def _excluded_save_params(self) -> list[str]: def _update_learning_rate( # type: ignore[override] self, - optimizers: Union[list[optax.OptState], optax.OptState], + optimizers: list[optax.OptState] | optax.OptState, learning_rate: float, name: str = "learning_rate", ) -> None: @@ -129,7 +129,7 @@ def _update_learning_rate( # type: ignore[override] # Note: the optimizer must have been defined with inject_hyperparams optimizer.hyperparams["learning_rate"] = learning_rate - def set_random_seed(self, seed: Optional[int]) -> None: # type: ignore[override] + def set_random_seed(self, seed: int | None) -> None: # type: ignore[override] super().set_random_seed(seed) if seed is None: # Sample random seed @@ -173,7 +173,7 @@ def _setup_model(self) -> None: def load_replay_buffer( self, - path: Union[str, pathlib.Path, io.BufferedIOBase], + path: str | pathlib.Path | io.BufferedIOBase, truncate_last_traj: bool = True, ) -> None: super().load_replay_buffer(path, truncate_last_traj) diff --git a/sbx/common/on_policy_algorithm.py b/sbx/common/on_policy_algorithm.py index 13c4843..aa54b95 100644 --- a/sbx/common/on_policy_algorithm.py +++ b/sbx/common/on_policy_algorithm.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, TypeVar, Union +from typing import Any, TypeVar import gymnasium as gym import jax @@ -25,9 +25,9 @@ class OnPolicyAlgorithmJax(OnPolicyAlgorithm): def __init__( self, - policy: Union[str, type[BasePolicy]], - env: Union[GymEnv, str], - learning_rate: Union[float, Schedule], + policy: str | type[BasePolicy], + env: GymEnv | str, + learning_rate: float | Schedule, n_steps: int, gamma: float, gae_lambda: float, @@ -36,14 +36,14 @@ def __init__( max_grad_norm: float, use_sde: bool, sde_sample_freq: int, - tensorboard_log: Optional[str] = None, + tensorboard_log: str | None = None, monitor_wrapper: bool = True, - policy_kwargs: Optional[dict[str, Any]] = None, + policy_kwargs: dict[str, Any] | None = None, verbose: int = 0, - seed: Optional[int] = None, + seed: int | None = None, device: str = "auto", _init_setup_model: bool = True, - supported_action_spaces: Optional[tuple[type[spaces.Space], ...]] = None, + supported_action_spaces: tuple[type[spaces.Space], ...] | None = None, ): super().__init__( policy=policy, # type: ignore[arg-type] @@ -78,7 +78,7 @@ def _excluded_save_params(self) -> list[str]: def _update_learning_rate( # type: ignore[override] self, - optimizers: Union[list[optax.OptState], optax.OptState], + optimizers: list[optax.OptState] | optax.OptState, learning_rate: float, ) -> None: """ @@ -97,7 +97,7 @@ def _update_learning_rate( # type: ignore[override] # Note: the optimizer must have been defined with inject_hyperparams optimizer.hyperparams["learning_rate"] = learning_rate - def set_random_seed(self, seed: Optional[int]) -> None: # type: ignore[override] + def set_random_seed(self, seed: int | None) -> None: # type: ignore[override] super().set_random_seed(seed) if seed is None: # Sample random seed diff --git a/sbx/common/policies.py b/sbx/common/policies.py index ce23d09..d595a11 100644 --- a/sbx/common/policies.py +++ b/sbx/common/policies.py @@ -1,6 +1,7 @@ # import copy from collections.abc import Sequence -from typing import Callable, Optional, Union, no_type_check +from typing import no_type_check +from collections.abc import Callable import flax.linen as nn import jax @@ -50,11 +51,11 @@ def select_action(actor_state, observations): @no_type_check def predict( self, - observation: Union[np.ndarray, dict[str, np.ndarray]], - state: Optional[tuple[np.ndarray, ...]] = None, - episode_start: Optional[np.ndarray] = None, + observation: np.ndarray | dict[str, np.ndarray], + state: tuple[np.ndarray, ...] | None = None, + episode_start: np.ndarray | None = None, deterministic: bool = False, - ) -> tuple[np.ndarray, Optional[tuple[np.ndarray, ...]]]: + ) -> tuple[np.ndarray, tuple[np.ndarray, ...] | None]: # self.set_training_mode(False) observation, vectorized_env = self.prepare_obs(observation) @@ -81,7 +82,7 @@ def predict( return actions, state - def prepare_obs(self, observation: Union[np.ndarray, dict[str, np.ndarray]]) -> tuple[np.ndarray, bool]: + def prepare_obs(self, observation: np.ndarray | dict[str, np.ndarray]) -> tuple[np.ndarray, bool]: vectorized_env = False if isinstance(observation, dict): assert isinstance(self.observation_space, spaces.Dict) @@ -132,7 +133,7 @@ def set_training_mode(self, mode: bool) -> None: class ContinuousCritic(nn.Module): net_arch: Sequence[int] use_layer_norm: bool = False - dropout_rate: Optional[float] = None + dropout_rate: float | None = None activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu output_dim: int = 1 @@ -154,7 +155,7 @@ def __call__(self, x: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray: class SimbaContinuousCritic(nn.Module): net_arch: Sequence[int] use_layer_norm: bool = False # for consistency, not used - dropout_rate: Optional[float] = None + dropout_rate: float | None = None activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu output_dim: int = 1 scale_factor: int = 4 @@ -179,7 +180,7 @@ def __call__(self, x: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray: class VectorCritic(nn.Module): net_arch: Sequence[int] use_layer_norm: bool = False - dropout_rate: Optional[float] = None + dropout_rate: float | None = None n_critics: int = 2 activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu output_dim: int = 1 @@ -210,7 +211,7 @@ class SimbaVectorCritic(nn.Module): net_arch: Sequence[int] # Note: we have use_layer_norm for consistency but it is not used (always on) use_layer_norm: bool = True - dropout_rate: Optional[float] = None + dropout_rate: float | None = None n_critics: int = 2 activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu output_dim: int = 1 diff --git a/sbx/common/utils.py b/sbx/common/utils.py index afea621..36d3596 100644 --- a/sbx/common/utils.py +++ b/sbx/common/utils.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from typing import Union import jax import jax.numpy as jnp @@ -37,7 +36,7 @@ def mask_from_prefix(params: FrozenDict, prefix: str = "NatureCNN_") -> dict: if the top-level module name starts with `prefix`. """ - def _traverse(tree: FrozenDict, path: tuple[str, ...] = ()) -> Union[dict, bool]: + def _traverse(tree: FrozenDict, path: tuple[str, ...] = ()) -> dict | bool: if isinstance(tree, dict): return {key: _traverse(value, (*path, key)) for key, value in tree.items()} # leaf diff --git a/sbx/crossq/crossq.py b/sbx/crossq/crossq.py index 7c2b667..70f1f15 100644 --- a/sbx/crossq/crossq.py +++ b/sbx/crossq/crossq.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Any, ClassVar, Literal, Optional, Union +from typing import Any, ClassVar, Literal import flax import flax.linen as nn @@ -53,31 +53,31 @@ class CrossQ(OffPolicyAlgorithmJax): def __init__( self, policy, - env: Union[GymEnv, str], - learning_rate: Union[float, Schedule] = 1e-3, - qf_learning_rate: Optional[float] = None, + env: GymEnv | str, + learning_rate: float | Schedule = 1e-3, + qf_learning_rate: float | None = None, buffer_size: int = 1_000_000, # 1e6 learning_starts: int = 100, batch_size: int = 256, gamma: float = 0.99, - train_freq: Union[int, tuple[int, str]] = 1, + train_freq: int | tuple[int, str] = 1, gradient_steps: int = 1, policy_delay: int = 3, - action_noise: Optional[ActionNoise] = None, - replay_buffer_class: Optional[type[ReplayBuffer]] = None, - replay_buffer_kwargs: Optional[dict[str, Any]] = None, + action_noise: ActionNoise | None = None, + replay_buffer_class: type[ReplayBuffer] | None = None, + replay_buffer_kwargs: dict[str, Any] | None = None, n_steps: int = 1, - ent_coef: Union[str, float] = "auto", - target_entropy: Union[Literal["auto"], float] = "auto", + ent_coef: str | float = "auto", + target_entropy: Literal["auto"] | float = "auto", use_sde: bool = False, sde_sample_freq: int = -1, use_sde_at_warmup: bool = False, stats_window_size: int = 100, - tensorboard_log: Optional[str] = None, - policy_kwargs: Optional[dict[str, Any]] = None, - param_resets: Optional[list[int]] = None, # List of timesteps after which to reset the params + tensorboard_log: str | None = None, + policy_kwargs: dict[str, Any] | None = None, + param_resets: list[int] | None = None, # List of timesteps after which to reset the params verbose: int = 0, - seed: Optional[int] = None, + seed: int | None = None, device: str = "auto", _init_setup_model: bool = True, ) -> None: diff --git a/sbx/crossq/policies.py b/sbx/crossq/policies.py index c4e2986..35da296 100644 --- a/sbx/crossq/policies.py +++ b/sbx/crossq/policies.py @@ -1,6 +1,7 @@ from collections.abc import Sequence from functools import partial -from typing import Any, Callable, Optional, Union +from typing import Any +from collections.abc import Callable import flax.linen as nn import jax @@ -23,7 +24,7 @@ class Critic(nn.Module): net_arch: Sequence[int] use_layer_norm: bool = False use_batch_norm: bool = True - dropout_rate: Optional[float] = None + dropout_rate: float | None = None batch_norm_momentum: float = 0.99 renorm_warmup_steps: int = 100_000 activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu @@ -62,7 +63,7 @@ def __call__(self, x: jnp.ndarray, action: jnp.ndarray, train: bool = False) -> class SimbaCritic(nn.Module): net_arch: Sequence[int] - dropout_rate: Optional[float] = None + dropout_rate: float | None = None batch_norm_momentum: float = 0.99 renorm_warmup_steps: int = 100_000 activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu @@ -102,7 +103,7 @@ class VectorCritic(nn.Module): use_batch_norm: bool = True batch_norm_momentum: float = 0.99 renorm_warmup_steps: int = 100_000 - dropout_rate: Optional[float] = None + dropout_rate: float | None = None n_critics: int = 2 activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu @@ -136,7 +137,7 @@ class SimbaVectorCritic(nn.Module): use_batch_norm: bool = True batch_norm_momentum: float = 0.99 renorm_warmup_steps: int = 100_000 - dropout_rate: Optional[float] = None + dropout_rate: float | None = None n_critics: int = 2 activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu scale_factor: int = 4 @@ -265,7 +266,7 @@ def __init__( observation_space: spaces.Space, action_space: spaces.Box, lr_schedule: Schedule, - net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None, + net_arch: list[int] | dict[str, list[int]] | None = None, dropout_rate: float = 0.0, layer_norm: bool = False, batch_norm: bool = True, # for critic @@ -280,10 +281,10 @@ def __init__( use_expln: bool = False, clip_mean: float = 2.0, features_extractor_class=None, - features_extractor_kwargs: Optional[dict[str, Any]] = None, + features_extractor_kwargs: dict[str, Any] | None = None, normalize_images: bool = True, optimizer_class: Callable[..., optax.GradientTransformation] = optax.adam, - optimizer_kwargs: Optional[dict[str, Any]] = None, + optimizer_kwargs: dict[str, Any] | None = None, n_critics: int = 2, share_features_extractor: bool = False, actor_class: type[nn.Module] = Actor, @@ -458,7 +459,7 @@ def __init__( observation_space: spaces.Space, action_space: spaces.Box, lr_schedule: Schedule, - net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None, + net_arch: list[int] | dict[str, list[int]] | None = None, dropout_rate: float = 0, layer_norm: bool = False, batch_norm: bool = True, @@ -471,10 +472,10 @@ def __init__( use_expln: bool = False, clip_mean: float = 2, features_extractor_class=None, - features_extractor_kwargs: Optional[dict[str, Any]] = None, + features_extractor_kwargs: dict[str, Any] | None = None, normalize_images: bool = True, optimizer_class: Callable[..., optax.GradientTransformation] = optax.adamw, - optimizer_kwargs: Optional[dict[str, Any]] = None, + optimizer_kwargs: dict[str, Any] | None = None, n_critics: int = 2, share_features_extractor: bool = False, actor_class: type[nn.Module] = SimbaActor, diff --git a/sbx/ddpg/ddpg.py b/sbx/ddpg/ddpg.py index 2f32d73..02ea1a4 100644 --- a/sbx/ddpg/ddpg.py +++ b/sbx/ddpg/ddpg.py @@ -1,4 +1,4 @@ -from typing import Any, ClassVar, Optional, Union +from typing import Any, ClassVar from stable_baselines3.common.buffers import ReplayBuffer from stable_baselines3.common.noise import ActionNoise @@ -16,24 +16,24 @@ class DDPG(TD3): def __init__( self, policy, - env: Union[GymEnv, str], - learning_rate: Union[float, Schedule] = 3e-4, - qf_learning_rate: Optional[float] = 1e-3, + env: GymEnv | str, + learning_rate: float | Schedule = 3e-4, + qf_learning_rate: float | None = 1e-3, buffer_size: int = 1_000_000, # 1e6 learning_starts: int = 100, batch_size: int = 256, tau: float = 0.005, gamma: float = 0.99, - train_freq: Union[int, tuple[int, str]] = 1, + train_freq: int | tuple[int, str] = 1, gradient_steps: int = 1, - action_noise: Optional[ActionNoise] = None, - replay_buffer_class: Optional[type[ReplayBuffer]] = None, - replay_buffer_kwargs: Optional[dict[str, Any]] = None, + action_noise: ActionNoise | None = None, + replay_buffer_class: type[ReplayBuffer] | None = None, + replay_buffer_kwargs: dict[str, Any] | None = None, n_steps: int = 1, - tensorboard_log: Optional[str] = None, - policy_kwargs: Optional[dict[str, Any]] = None, + tensorboard_log: str | None = None, + policy_kwargs: dict[str, Any] | None = None, verbose: int = 0, - seed: Optional[int] = None, + seed: int | None = None, device: str = "auto", _init_setup_model: bool = True, ) -> None: diff --git a/sbx/dqn/dqn.py b/sbx/dqn/dqn.py index 9a72dcb..a20f5d5 100644 --- a/sbx/dqn/dqn.py +++ b/sbx/dqn/dqn.py @@ -1,5 +1,5 @@ import warnings -from typing import Any, ClassVar, Optional, Union +from typing import Any, ClassVar import gymnasium as gym import jax @@ -27,8 +27,8 @@ class DQN(OffPolicyAlgorithmJax): def __init__( self, policy, - env: Union[GymEnv, str], - learning_rate: Union[float, Schedule] = 1e-4, + env: GymEnv | str, + learning_rate: float | Schedule = 1e-4, buffer_size: int = 1_000_000, # 1e6 learning_starts: int = 100, batch_size: int = 32, @@ -38,17 +38,17 @@ def __init__( exploration_fraction: float = 0.1, exploration_initial_eps: float = 1.0, exploration_final_eps: float = 0.05, - replay_buffer_class: Optional[type[ReplayBuffer]] = None, - replay_buffer_kwargs: Optional[dict[str, Any]] = None, + replay_buffer_class: type[ReplayBuffer] | None = None, + replay_buffer_kwargs: dict[str, Any] | None = None, optimize_memory_usage: bool = False, n_steps: int = 1, # max_grad_norm: float = 10, - train_freq: Union[int, tuple[int, str]] = 4, + train_freq: int | tuple[int, str] = 4, gradient_steps: int = 1, - tensorboard_log: Optional[str] = None, - policy_kwargs: Optional[dict[str, Any]] = None, + tensorboard_log: str | None = None, + policy_kwargs: dict[str, Any] | None = None, verbose: int = 0, - seed: Optional[int] = None, + seed: int | None = None, device: str = "auto", _init_setup_model: bool = True, ) -> None: @@ -244,11 +244,11 @@ def _on_step(self) -> None: def predict( self, - observation: Union[np.ndarray, dict[str, np.ndarray]], - state: Optional[tuple[np.ndarray, ...]] = None, - episode_start: Optional[np.ndarray] = None, + observation: np.ndarray | dict[str, np.ndarray], + state: tuple[np.ndarray, ...] | None = None, + episode_start: np.ndarray | None = None, deterministic: bool = False, - ) -> tuple[np.ndarray, Optional[tuple[np.ndarray, ...]]]: + ) -> tuple[np.ndarray, tuple[np.ndarray, ...] | None]: """ Overrides the base_class predict function to include epsilon-greedy exploration. diff --git a/sbx/dqn/policies.py b/sbx/dqn/policies.py index 8d25383..5dbc9b5 100644 --- a/sbx/dqn/policies.py +++ b/sbx/dqn/policies.py @@ -1,4 +1,5 @@ -from typing import Any, Callable, Optional, Union +from typing import Any +from collections.abc import Callable import flax.linen as nn import jax @@ -49,13 +50,13 @@ def __init__( observation_space: spaces.Space, action_space: spaces.Discrete, lr_schedule: Schedule, - net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None, + net_arch: list[int] | dict[str, list[int]] | None = None, activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu, features_extractor_class=None, - features_extractor_kwargs: Optional[dict[str, Any]] = None, + features_extractor_kwargs: dict[str, Any] | None = None, normalize_images: bool = True, optimizer_class: Callable[..., optax.GradientTransformation] = optax.adam, - optimizer_kwargs: Optional[dict[str, Any]] = None, + optimizer_kwargs: dict[str, Any] | None = None, ): super().__init__( observation_space, diff --git a/sbx/ppo/policies.py b/sbx/ppo/policies.py index 7ffa7c6..d0506a2 100644 --- a/sbx/ppo/policies.py +++ b/sbx/ppo/policies.py @@ -1,6 +1,7 @@ from collections.abc import Sequence from dataclasses import field -from typing import Any, Callable, Optional, Union +from typing import Any +from collections.abc import Callable import flax.linen as nn import gymnasium as gym @@ -23,7 +24,7 @@ class Critic(nn.Module): net_arch: Sequence[int] activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.tanh - features_extractor: Optional[type[NatureCNN]] = None + features_extractor: type[NatureCNN] | None = None features_dim: int = 512 @nn.compact @@ -46,13 +47,13 @@ class Actor(nn.Module): log_std_init: float = 0.0 activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.tanh # For Discrete, MultiDiscrete and MultiBinary actions - num_discrete_choices: Optional[Union[int, Sequence[int]]] = None + num_discrete_choices: int | Sequence[int] | None = None # For MultiDiscrete max_num_choices: int = 0 split_indices: np.ndarray = field(default_factory=lambda: np.array([])) # Last layer with small scale ortho_init: bool = False - features_extractor: Optional[type[NatureCNN]] = None + features_extractor: type[NatureCNN] | None = None features_dim: int = 512 def get_std(self) -> jnp.ndarray: @@ -124,7 +125,7 @@ def __init__( observation_space: gym.spaces.Space, action_space: gym.spaces.Space, lr_schedule: Schedule, - net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None, + net_arch: list[int] | dict[str, list[int]] | None = None, ortho_init: bool = False, log_std_init: float = 0.0, activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.tanh, @@ -133,11 +134,11 @@ def __init__( # this is to keep API consistent with SB3 use_expln: bool = False, clip_mean: float = 2.0, - features_extractor_class: Optional[type[NatureCNN]] = None, - features_extractor_kwargs: Optional[dict[str, Any]] = None, + features_extractor_class: type[NatureCNN] | None = None, + features_extractor_kwargs: dict[str, Any] | None = None, normalize_images: bool = True, optimizer_class: Callable[..., optax.GradientTransformation] = optax.adam, - optimizer_kwargs: Optional[dict[str, Any]] = None, + optimizer_kwargs: dict[str, Any] | None = None, share_features_extractor: bool = False, actor_class: type[nn.Module] = Actor, critic_class: type[nn.Module] = Critic, @@ -304,7 +305,7 @@ def __init__( observation_space: gym.spaces.Space, action_space: gym.spaces.Space, lr_schedule: Schedule, - net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None, + net_arch: list[int] | dict[str, list[int]] | None = None, ortho_init: bool = False, log_std_init: float = 0, # ReLU for NatureCNN @@ -312,11 +313,11 @@ def __init__( use_sde: bool = False, use_expln: bool = False, clip_mean: float = 2, - features_extractor_class: Optional[type[NatureCNN]] = NatureCNN, - features_extractor_kwargs: Optional[dict[str, Any]] = None, + features_extractor_class: type[NatureCNN] | None = NatureCNN, + features_extractor_kwargs: dict[str, Any] | None = None, normalize_images: bool = True, optimizer_class: Callable[..., optax.GradientTransformation] = optax.adam, - optimizer_kwargs: Optional[dict[str, Any]] = None, + optimizer_kwargs: dict[str, Any] | None = None, share_features_extractor: bool = False, actor_class: type[nn.Module] = Actor, critic_class: type[nn.Module] = Critic, diff --git a/sbx/ppo/ppo.py b/sbx/ppo/ppo.py index a718972..0a27fc4 100644 --- a/sbx/ppo/ppo.py +++ b/sbx/ppo/ppo.py @@ -1,6 +1,6 @@ import warnings from functools import partial -from typing import Any, ClassVar, Optional, TypeVar, Union +from typing import Any, ClassVar, TypeVar import jax import jax.numpy as jnp @@ -78,27 +78,27 @@ class PPO(OnPolicyAlgorithmJax): def __init__( self, - policy: Union[str, type[PPOPolicy]], - env: Union[GymEnv, str], - learning_rate: Union[float, Schedule] = 3e-4, + policy: str | type[PPOPolicy], + env: GymEnv | str, + learning_rate: float | Schedule = 3e-4, n_steps: int = 2048, batch_size: int = 64, n_epochs: int = 10, gamma: float = 0.99, gae_lambda: float = 0.95, - clip_range: Union[float, Schedule] = 0.2, - clip_range_vf: Union[None, float, Schedule] = None, + clip_range: float | Schedule = 0.2, + clip_range_vf: None | float | Schedule = None, normalize_advantage: bool = True, ent_coef: float = 0.0, vf_coef: float = 0.5, max_grad_norm: float = 0.5, use_sde: bool = False, sde_sample_freq: int = -1, - target_kl: Optional[float] = None, - tensorboard_log: Optional[str] = None, - policy_kwargs: Optional[dict[str, Any]] = None, + target_kl: float | None = None, + tensorboard_log: str | None = None, + policy_kwargs: dict[str, Any] | None = None, verbose: int = 0, - seed: Optional[int] = None, + seed: int | None = None, device: str = "auto", _init_setup_model: bool = True, ): diff --git a/sbx/sac/policies.py b/sbx/sac/policies.py index e77e57d..ea27052 100644 --- a/sbx/sac/policies.py +++ b/sbx/sac/policies.py @@ -1,4 +1,5 @@ -from typing import Any, Callable, Optional, Union +from typing import Any +from collections.abc import Callable import flax.linen as nn import jax @@ -27,7 +28,7 @@ def __init__( observation_space: spaces.Space, action_space: spaces.Box, lr_schedule: Schedule, - net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None, + net_arch: list[int] | dict[str, list[int]] | None = None, dropout_rate: float = 0.0, layer_norm: bool = False, activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu, @@ -38,10 +39,10 @@ def __init__( use_expln: bool = False, clip_mean: float = 2.0, features_extractor_class=None, - features_extractor_kwargs: Optional[dict[str, Any]] = None, + features_extractor_kwargs: dict[str, Any] | None = None, normalize_images: bool = True, optimizer_class: Callable[..., optax.GradientTransformation] = optax.adam, - optimizer_kwargs: Optional[dict[str, Any]] = None, + optimizer_kwargs: dict[str, Any] | None = None, n_critics: int = 2, share_features_extractor: bool = False, actor_class: type[nn.Module] = SquashedGaussianActor, @@ -164,7 +165,7 @@ def __init__( observation_space: spaces.Space, action_space: spaces.Box, lr_schedule: Schedule, - net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None, + net_arch: list[int] | dict[str, list[int]] | None = None, dropout_rate: float = 0, layer_norm: bool = False, activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu, @@ -173,11 +174,11 @@ def __init__( use_expln: bool = False, clip_mean: float = 2, features_extractor_class=None, - features_extractor_kwargs: Optional[dict[str, Any]] = None, + features_extractor_kwargs: dict[str, Any] | None = None, normalize_images: bool = True, # AdamW for simba optimizer_class: Callable[..., optax.GradientTransformation] = optax.adamw, - optimizer_kwargs: Optional[dict[str, Any]] = None, + optimizer_kwargs: dict[str, Any] | None = None, n_critics: int = 2, share_features_extractor: bool = False, actor_class: type[nn.Module] = SimbaSquashedGaussianActor, diff --git a/sbx/sac/sac.py b/sbx/sac/sac.py index 1379c0d..9c532b2 100644 --- a/sbx/sac/sac.py +++ b/sbx/sac/sac.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Any, ClassVar, Literal, Optional, Union +from typing import Any, ClassVar, Literal import flax import flax.linen as nn @@ -54,32 +54,32 @@ class SAC(OffPolicyAlgorithmJax): def __init__( self, policy, - env: Union[GymEnv, str], - learning_rate: Union[float, Schedule] = 3e-4, - qf_learning_rate: Optional[float] = None, + env: GymEnv | str, + learning_rate: float | Schedule = 3e-4, + qf_learning_rate: float | None = None, buffer_size: int = 1_000_000, # 1e6 learning_starts: int = 100, batch_size: int = 256, tau: float = 0.005, gamma: float = 0.99, - train_freq: Union[int, tuple[int, str]] = 1, + train_freq: int | tuple[int, str] = 1, gradient_steps: int = 1, policy_delay: int = 1, - action_noise: Optional[ActionNoise] = None, - replay_buffer_class: Optional[type[ReplayBuffer]] = None, - replay_buffer_kwargs: Optional[dict[str, Any]] = None, + action_noise: ActionNoise | None = None, + replay_buffer_class: type[ReplayBuffer] | None = None, + replay_buffer_kwargs: dict[str, Any] | None = None, n_steps: int = 1, - ent_coef: Union[str, float] = "auto", - target_entropy: Union[Literal["auto"], float] = "auto", + ent_coef: str | float = "auto", + target_entropy: Literal["auto"] | float = "auto", use_sde: bool = False, sde_sample_freq: int = -1, use_sde_at_warmup: bool = False, stats_window_size: int = 100, - tensorboard_log: Optional[str] = None, - policy_kwargs: Optional[dict[str, Any]] = None, - param_resets: Optional[list[int]] = None, # List of timesteps after which to reset the params + tensorboard_log: str | None = None, + policy_kwargs: dict[str, Any] | None = None, + param_resets: list[int] | None = None, # List of timesteps after which to reset the params verbose: int = 0, - seed: Optional[int] = None, + seed: int | None = None, device: str = "auto", _init_setup_model: bool = True, ) -> None: diff --git a/sbx/td3/policies.py b/sbx/td3/policies.py index d32d7f4..7b3bc9f 100644 --- a/sbx/td3/policies.py +++ b/sbx/td3/policies.py @@ -1,5 +1,6 @@ from collections.abc import Sequence -from typing import Any, Callable, Optional, Union +from typing import Any +from collections.abc import Callable import flax.linen as nn import jax @@ -35,16 +36,16 @@ def __init__( observation_space: spaces.Space, action_space: spaces.Box, lr_schedule: Schedule, - net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None, + net_arch: list[int] | dict[str, list[int]] | None = None, dropout_rate: float = 0.0, layer_norm: bool = False, activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu, use_sde: bool = False, features_extractor_class=None, - features_extractor_kwargs: Optional[dict[str, Any]] = None, + features_extractor_kwargs: dict[str, Any] | None = None, normalize_images: bool = True, optimizer_class: Callable[..., optax.GradientTransformation] = optax.adam, - optimizer_kwargs: Optional[dict[str, Any]] = None, + optimizer_kwargs: dict[str, Any] | None = None, n_critics: int = 2, share_features_extractor: bool = False, ): diff --git a/sbx/td3/td3.py b/sbx/td3/td3.py index b3e8e04..0ca4be1 100644 --- a/sbx/td3/td3.py +++ b/sbx/td3/td3.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Any, ClassVar, Optional, Union +from typing import Any, ClassVar import flax import jax @@ -29,29 +29,29 @@ class TD3(OffPolicyAlgorithmJax): def __init__( self, policy, - env: Union[GymEnv, str], - learning_rate: Union[float, Schedule] = 3e-4, - qf_learning_rate: Optional[float] = None, + env: GymEnv | str, + learning_rate: float | Schedule = 3e-4, + qf_learning_rate: float | None = None, buffer_size: int = 1_000_000, # 1e6 learning_starts: int = 100, batch_size: int = 256, tau: float = 0.005, gamma: float = 0.99, - train_freq: Union[int, tuple[int, str]] = 1, + train_freq: int | tuple[int, str] = 1, gradient_steps: int = 1, policy_delay: int = 2, target_policy_noise: float = 0.2, target_noise_clip: float = 0.5, - action_noise: Optional[ActionNoise] = None, - replay_buffer_class: Optional[type[ReplayBuffer]] = None, - replay_buffer_kwargs: Optional[dict[str, Any]] = None, + action_noise: ActionNoise | None = None, + replay_buffer_class: type[ReplayBuffer] | None = None, + replay_buffer_kwargs: dict[str, Any] | None = None, n_steps: int = 1, - tensorboard_log: Optional[str] = None, + tensorboard_log: str | None = None, stats_window_size: int = 100, - policy_kwargs: Optional[dict[str, Any]] = None, - param_resets: Optional[list[int]] = None, # List of timesteps after which to reset the params + policy_kwargs: dict[str, Any] | None = None, + param_resets: list[int] | None = None, # List of timesteps after which to reset the params verbose: int = 0, - seed: Optional[int] = None, + seed: int | None = None, device: str = "auto", _init_setup_model: bool = True, ) -> None: diff --git a/sbx/tqc/policies.py b/sbx/tqc/policies.py index 4783e91..5d314c8 100644 --- a/sbx/tqc/policies.py +++ b/sbx/tqc/policies.py @@ -1,4 +1,5 @@ -from typing import Any, Callable, Optional, Union +from typing import Any +from collections.abc import Callable import flax.linen as nn import jax @@ -27,7 +28,7 @@ def __init__( observation_space: spaces.Space, action_space: spaces.Box, lr_schedule: Schedule, - net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None, + net_arch: list[int] | dict[str, list[int]] | None = None, dropout_rate: float = 0.0, layer_norm: bool = False, top_quantiles_to_drop_per_net: int = 2, @@ -40,10 +41,10 @@ def __init__( use_expln: bool = False, clip_mean: float = 2.0, features_extractor_class=None, - features_extractor_kwargs: Optional[dict[str, Any]] = None, + features_extractor_kwargs: dict[str, Any] | None = None, normalize_images: bool = True, optimizer_class: Callable[..., optax.GradientTransformation] = optax.adam, - optimizer_kwargs: Optional[dict[str, Any]] = None, + optimizer_kwargs: dict[str, Any] | None = None, n_critics: int = 2, share_features_extractor: bool = False, actor_class: type[nn.Module] = SquashedGaussianActor, @@ -186,7 +187,7 @@ def __init__( observation_space: spaces.Space, action_space: spaces.Box, lr_schedule: Schedule, - net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None, + net_arch: list[int] | dict[str, list[int]] | None = None, dropout_rate: float = 0, layer_norm: bool = False, top_quantiles_to_drop_per_net: int = 2, @@ -197,10 +198,10 @@ def __init__( use_expln: bool = False, clip_mean: float = 2, features_extractor_class=None, - features_extractor_kwargs: Optional[dict[str, Any]] = None, + features_extractor_kwargs: dict[str, Any] | None = None, normalize_images: bool = True, optimizer_class: Callable[..., optax.GradientTransformation] = optax.adamw, - optimizer_kwargs: Optional[dict[str, Any]] = None, + optimizer_kwargs: dict[str, Any] | None = None, n_critics: int = 2, share_features_extractor: bool = False, actor_class: type[nn.Module] = SimbaSquashedGaussianActor, diff --git a/sbx/tqc/tqc.py b/sbx/tqc/tqc.py index 15d6183..8a549dd 100644 --- a/sbx/tqc/tqc.py +++ b/sbx/tqc/tqc.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Any, ClassVar, Literal, Optional, Union +from typing import Any, ClassVar, Literal import flax import flax.linen as nn @@ -53,33 +53,33 @@ class TQC(OffPolicyAlgorithmJax): def __init__( self, policy, - env: Union[GymEnv, str], - learning_rate: Union[float, Schedule] = 3e-4, - qf_learning_rate: Optional[float] = None, + env: GymEnv | str, + learning_rate: float | Schedule = 3e-4, + qf_learning_rate: float | None = None, buffer_size: int = 1_000_000, # 1e6 learning_starts: int = 100, batch_size: int = 256, tau: float = 0.005, gamma: float = 0.99, - train_freq: Union[int, tuple[int, str]] = 1, + train_freq: int | tuple[int, str] = 1, gradient_steps: int = 1, policy_delay: int = 1, top_quantiles_to_drop_per_net: int = 2, - action_noise: Optional[ActionNoise] = None, - replay_buffer_class: Optional[type[ReplayBuffer]] = None, - replay_buffer_kwargs: Optional[dict[str, Any]] = None, + action_noise: ActionNoise | None = None, + replay_buffer_class: type[ReplayBuffer] | None = None, + replay_buffer_kwargs: dict[str, Any] | None = None, n_steps: int = 1, - ent_coef: Union[str, float] = "auto", - target_entropy: Union[Literal["auto"], float] = "auto", + ent_coef: str | float = "auto", + target_entropy: Literal["auto"] | float = "auto", use_sde: bool = False, sde_sample_freq: int = -1, use_sde_at_warmup: bool = False, stats_window_size: int = 100, - tensorboard_log: Optional[str] = None, - policy_kwargs: Optional[dict[str, Any]] = None, - param_resets: Optional[list[int]] = None, # List of timesteps after which to reset the params + tensorboard_log: str | None = None, + policy_kwargs: dict[str, Any] | None = None, + param_resets: list[int] | None = None, # List of timesteps after which to reset the params verbose: int = 0, - seed: Optional[int] = None, + seed: int | None = None, device: str = "auto", _init_setup_model: bool = True, ) -> None: diff --git a/tests/test_flatten.py b/tests/test_flatten.py index c1307f2..beab0f5 100644 --- a/tests/test_flatten.py +++ b/tests/test_flatten.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from typing import Optional import gymnasium as gym import numpy as np @@ -17,7 +16,7 @@ class DummyEnv(gym.Env): def step(self, action): return self.observation_space.sample(), 0.0, False, False, {} - def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): + def reset(self, *, seed: int | None = None, options: dict | None = None): if seed is not None: super().reset(seed=seed) return self.observation_space.sample(), {} diff --git a/tests/test_run.py b/tests/test_run.py index 7c6ab59..c23b48f 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -1,4 +1,3 @@ -from typing import Optional import flax.linen as nn import numpy as np @@ -176,7 +175,7 @@ def test_dqn(tmp_path) -> None: @pytest.mark.parametrize("replay_buffer_class", [None, HerReplayBuffer]) -def test_dict(replay_buffer_class: Optional[type[HerReplayBuffer]]) -> None: +def test_dict(replay_buffer_class: type[HerReplayBuffer] | None) -> None: env = BitFlippingEnv(n_bits=2, continuous=True) model = SAC("MultiInputPolicy", env, replay_buffer_class=replay_buffer_class) diff --git a/tests/test_spaces.py b/tests/test_spaces.py index c9c9595..3466c04 100644 --- a/tests/test_spaces.py +++ b/tests/test_spaces.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from typing import Optional import gymnasium as gym import numpy as np @@ -19,7 +18,7 @@ def step(self, action): assert action in self.action_space return self.observation_space.sample(), 0.0, False, False, {} - def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): + def reset(self, *, seed: int | None = None, options: dict | None = None): if seed is not None: super().reset(seed=seed) return self.observation_space.sample(), {} From 2ff659aa29990090ca7ea7cdceb061c87e475730 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Fri, 5 Dec 2025 18:47:18 +0100 Subject: [PATCH 3/3] Reformat --- sbx/common/jax_layers.py | 3 +-- sbx/common/policies.py | 3 +-- sbx/crossq/policies.py | 3 +-- sbx/dqn/policies.py | 2 +- sbx/ppo/policies.py | 3 +-- sbx/sac/policies.py | 2 +- sbx/td3/policies.py | 3 +-- sbx/tqc/policies.py | 2 +- tests/test_run.py | 1 - 9 files changed, 8 insertions(+), 14 deletions(-) diff --git a/sbx/common/jax_layers.py b/sbx/common/jax_layers.py index af5c053..34f7e66 100644 --- a/sbx/common/jax_layers.py +++ b/sbx/common/jax_layers.py @@ -1,6 +1,5 @@ -from collections.abc import Sequence +from collections.abc import Callable, Sequence from typing import Any, Union -from collections.abc import Callable import flax.linen as nn import jax diff --git a/sbx/common/policies.py b/sbx/common/policies.py index d595a11..f897abf 100644 --- a/sbx/common/policies.py +++ b/sbx/common/policies.py @@ -1,7 +1,6 @@ # import copy -from collections.abc import Sequence +from collections.abc import Callable, Sequence from typing import no_type_check -from collections.abc import Callable import flax.linen as nn import jax diff --git a/sbx/crossq/policies.py b/sbx/crossq/policies.py index 35da296..5b4465f 100644 --- a/sbx/crossq/policies.py +++ b/sbx/crossq/policies.py @@ -1,7 +1,6 @@ -from collections.abc import Sequence +from collections.abc import Callable, Sequence from functools import partial from typing import Any -from collections.abc import Callable import flax.linen as nn import jax diff --git a/sbx/dqn/policies.py b/sbx/dqn/policies.py index 5dbc9b5..0cfd149 100644 --- a/sbx/dqn/policies.py +++ b/sbx/dqn/policies.py @@ -1,5 +1,5 @@ -from typing import Any from collections.abc import Callable +from typing import Any import flax.linen as nn import jax diff --git a/sbx/ppo/policies.py b/sbx/ppo/policies.py index d0506a2..6e8bbe6 100644 --- a/sbx/ppo/policies.py +++ b/sbx/ppo/policies.py @@ -1,7 +1,6 @@ -from collections.abc import Sequence +from collections.abc import Callable, Sequence from dataclasses import field from typing import Any -from collections.abc import Callable import flax.linen as nn import gymnasium as gym diff --git a/sbx/sac/policies.py b/sbx/sac/policies.py index ea27052..2ed054e 100644 --- a/sbx/sac/policies.py +++ b/sbx/sac/policies.py @@ -1,5 +1,5 @@ -from typing import Any from collections.abc import Callable +from typing import Any import flax.linen as nn import jax diff --git a/sbx/td3/policies.py b/sbx/td3/policies.py index 7b3bc9f..29e2adb 100644 --- a/sbx/td3/policies.py +++ b/sbx/td3/policies.py @@ -1,6 +1,5 @@ -from collections.abc import Sequence +from collections.abc import Callable, Sequence from typing import Any -from collections.abc import Callable import flax.linen as nn import jax diff --git a/sbx/tqc/policies.py b/sbx/tqc/policies.py index 5d314c8..8ec1e0f 100644 --- a/sbx/tqc/policies.py +++ b/sbx/tqc/policies.py @@ -1,5 +1,5 @@ -from typing import Any from collections.abc import Callable +from typing import Any import flax.linen as nn import jax diff --git a/tests/test_run.py b/tests/test_run.py index c23b48f..6ac397e 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -1,4 +1,3 @@ - import flax.linen as nn import numpy as np import pytest