Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .github/ISSUE_TEMPLATE/bug_report.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ body:
description: Which version or versions of Python are you seeing the problem on?
multiple: true
options:
- "3.9"
- "3.10"
- "3.11"
- "3.12"
Expand Down
1 change: 0 additions & 1 deletion .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ jobs:
fail-fast: false
matrix:
python-version:
- "3.9"
- "3.10"
- "3.11"
- "3.12"
Expand Down
1 change: 0 additions & 1 deletion .github/workflows/unittests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ jobs:
fail-fast: false
matrix:
python-version:
- "3.9"
- "3.10"
- "3.11"
- "3.12"
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ repos:
rev: v0.13.0
hooks:
# Run the linter.
- id: ruff
- id: ruff-check
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason this has changed?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

'ruff-check' is the preferred id as of v0.11.10 of the pre-commit hook (change was made in astral-sh/ruff-pre-commit@39f54b7).

args: [ --fix ]
# Abort if ruff linter fails as there is some duplication of functionality with
# the slow pylint hook
Expand Down
2 changes: 1 addition & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ persistent=yes

# Minimum Python version to use for version dependent checks. Will default to
# the version used to run pylint.
py-version=3.9
py-version=3.10

# Discover python modules and packages in the file system subtree.
recursive=no
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Removed

- Internal `SilentTQDM` class has been removed; all user-facing functionality has been preserved.
- Dropped support for python 3.9 which is near end-of-life.

### Deprecated

Expand Down
5 changes: 2 additions & 3 deletions benchmark/blobs_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import json
import os
import time
from typing import Union

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -96,8 +95,8 @@ def setup_stein_kernel(

# Define the score function as the gradient of log density given by the KDE
def score_function(
x: Union[Shaped[Array, " n d"], Shaped[Array, ""], float, int],
) -> Union[Shaped[Array, " n d"], Shaped[Array, " 1 1"]]:
x: Shaped[Array, " n d"] | Shaped[Array, ""] | float | int,
) -> Shaped[Array, " n d"] | Shaped[Array, " 1 1"]:
"""
Compute the score function (gradient of log density) for a single point.

Expand Down
3 changes: 1 addition & 2 deletions benchmark/david_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
import math
import time
from pathlib import Path
from typing import Optional

import equinox as eqx
import jax.numpy as jnp
Expand All @@ -60,7 +59,7 @@
# (disable line too long and too many statements ruff)
def benchmark_coreset_algorithms(
in_path: Path = Path("../examples/data/david_orig.png"),
out_path: Optional[Path] = Path(
out_path: Path | None = Path(
"../examples/benchmarking_images/david_benchmark_results.png"
),
downsampling_factor: int = 1,
Expand Down
6 changes: 3 additions & 3 deletions benchmark/mnist_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
import json
import os
import time
from typing import Any, NamedTuple, Optional, Union
from typing import Any, NamedTuple

import equinox as eqx
import jax
Expand Down Expand Up @@ -143,7 +143,7 @@ def __call__(self, x: jnp.ndarray, training: bool = True) -> jnp.ndarray:
class TrainState(train_state.TrainState):
"""Custom train state with batch statistics and dropout RNG."""

batch_stats: Optional[dict[str, jnp.ndarray]]
batch_stats: dict[str, jnp.ndarray] | None
dropout_rng: KeyArrayLike


Expand Down Expand Up @@ -416,7 +416,7 @@ def prepare_datasets() -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarr
def train_model(
data_bundle: dict[str, jnp.ndarray],
key: KeyArrayLike,
config: dict[str, Union[int, float]],
config: dict[str, int | float],
) -> dict[str, float]:
"""
Train the model and return the results.
Expand Down
46 changes: 20 additions & 26 deletions coreax/approximation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,14 @@

from collections.abc import Callable
from functools import partial
from typing import TYPE_CHECKING, Any, Union
from typing import TYPE_CHECKING, Any, Literal

import jax
import jax.numpy as jnp
import jax.random as jr
from jax import Array
from jaxtyping import Shaped
from typing_extensions import Literal, override
from typing_extensions import override

from coreax.data import Data, _atleast_2d_consistent
from coreax.kernels import UniCompositeKernel
Expand Down Expand Up @@ -177,14 +177,12 @@ class MonteCarloApproximateKernel(RandomRegressionKernel):

def gramian_row_mean(
self,
x: Union[
Shaped[Array, " n d"],
Shaped[Array, " d"],
Shaped[Array, ""],
float,
int,
Data,
],
x: Shaped[Array, " n d"]
| Shaped[Array, " d"]
| Shaped[Array, ""]
| float
| int
| Data,
**kwargs: Any,
) -> Shaped[Array, " n"]:
r"""
Expand Down Expand Up @@ -232,14 +230,12 @@ class ANNchorApproximateKernel(RandomRegressionKernel):

def gramian_row_mean(
self,
x: Union[
Shaped[Array, " n d"],
Shaped[Array, " d"],
Shaped[Array, ""],
float,
int,
Data,
],
x: Shaped[Array, " n d"]
| Shaped[Array, " d"]
| Shaped[Array, ""]
| float
| int
| Data,
**kwargs: Any,
) -> Shaped[Array, " n"]:
r"""
Expand Down Expand Up @@ -305,14 +301,12 @@ class NystromApproximateKernel(RandomRegressionKernel):

def gramian_row_mean(
self,
x: Union[
Shaped[Array, " n d"],
Shaped[Array, " d"],
Shaped[Array, ""],
float,
int,
Data,
],
x: Shaped[Array, " n d"]
| Shaped[Array, " d"]
| Shaped[Array, ""]
| float
| int
| Data,
**kwargs: Any,
) -> Shaped[Array, " n"]:
r"""
Expand Down
24 changes: 12 additions & 12 deletions coreax/benchmark_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
"""

from collections.abc import Callable
from typing import Optional, TypeVar, Union
from typing import TypeVar

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -61,12 +61,12 @@ class IterativeKernelHerding(KernelHerding[_Data]): # pylint: disable=too-many-
"""

num_iterations: int = 1
t_schedule: Optional[Array] = None
t_schedule: Array | None = None

def reduce(
self,
dataset: _Data,
solver_state: Optional[HerdingState] = None,
solver_state: HerdingState | None = None,
) -> tuple[Coresubset[_Data], HerdingState]:
"""
Perform Kernel Herding reduction followed by additional refinement iterations.
Expand Down Expand Up @@ -118,7 +118,7 @@ def initialise_solvers( # noqa: C901
train_data_umap: Data,
key: KeyArrayLike,
cpp_oversampling_factor: int,
leaf_size: Optional[int] = None,
leaf_size: int | None = None,
) -> dict[str, Callable[[int], Solver]]:
"""
Initialise and return a list of solvers for various coreset algorithms.
Expand Down Expand Up @@ -147,7 +147,7 @@ def initialise_solvers( # noqa: C901
kernel = SquaredExponentialKernel(length_scale=length_scale)
sqrt_kernel = kernel.get_sqrt_kernel(16)

def _get_thinning_solver(_size: int) -> Union[KernelThinning, MapReduce]:
def _get_thinning_solver(_size: int) -> KernelThinning | MapReduce:
"""
Set up kernel thinning solver.

Expand All @@ -169,7 +169,7 @@ def _get_thinning_solver(_size: int) -> Union[KernelThinning, MapReduce]:
return thinning_solver
return MapReduce(thinning_solver, leaf_size=leaf_size)

def _get_herding_solver(_size: int) -> Union[KernelHerding, MapReduce]:
def _get_herding_solver(_size: int) -> KernelHerding | MapReduce:
"""
Set up kernel herding solver.

Expand All @@ -185,7 +185,7 @@ def _get_herding_solver(_size: int) -> Union[KernelHerding, MapReduce]:
return herding_solver
return MapReduce(herding_solver, leaf_size=leaf_size)

def _get_stein_solver(_size: int) -> Union[SteinThinning, MapReduce]:
def _get_stein_solver(_size: int) -> SteinThinning | MapReduce:
"""
Set up Stein thinning solver.

Expand All @@ -200,8 +200,8 @@ def _get_stein_solver(_size: int) -> Union[SteinThinning, MapReduce]:

# Define the score function as the gradient of log density given by the KDE
def score_function(
x: Union[Shaped[Array, " n d"], Shaped[Array, ""], float, int],
) -> Union[Shaped[Array, " n d"], Shaped[Array, " 1 1"]]:
x: Shaped[Array, " n d"] | Shaped[Array, ""] | float | int,
) -> Shaped[Array, " n d"] | Shaped[Array, " 1 1"]:
"""
Compute the score function (gradient of log density) for a single point.

Expand Down Expand Up @@ -264,7 +264,7 @@ def _get_compress_solver(_size: int) -> CompressPlusPlus:

def _get_probabilistic_herding_solver(
_size: int,
) -> Union[IterativeKernelHerding, MapReduce]:
) -> IterativeKernelHerding | MapReduce:
"""
Set up KernelHerding with probabilistic selection.

Expand All @@ -289,7 +289,7 @@ def _get_probabilistic_herding_solver(

def _get_iterative_herding_solver(
_size: int,
) -> Union[IterativeKernelHerding, MapReduce]:
) -> IterativeKernelHerding | MapReduce:
"""
Set up KernelHerding with probabilistic selection.

Expand All @@ -314,7 +314,7 @@ def _get_iterative_herding_solver(

def _get_cubic_iterative_herding_solver(
_size: int,
) -> Union[IterativeKernelHerding, MapReduce]:
) -> IterativeKernelHerding | MapReduce:
"""
Set up KernelHerding with probabilistic selection.

Expand Down
21 changes: 10 additions & 11 deletions coreax/coreset.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
Final,
Generic,
TypeVar,
Union,
overload,
)

Expand Down Expand Up @@ -131,30 +130,30 @@ def __init__(self, nodes: Data, pre_coreset_data: _TOriginalData_co) -> None:
@classmethod
@overload
def build(
cls, nodes: Union[Data, Array], pre_coreset_data: Array
cls, nodes: Data | Array, pre_coreset_data: Array
) -> "PseudoCoreset[Data]": ...

@classmethod
@overload
def build(
cls,
nodes: Union[Data, Array],
nodes: Data | Array,
pre_coreset_data: tuple[Array, Array],
) -> "PseudoCoreset[SupervisedData]": ...

@classmethod
@overload
def build(
cls,
nodes: Union[Data, Array],
nodes: Data | Array,
pre_coreset_data: _TOriginalData,
) -> "PseudoCoreset[_TOriginalData]": ...

@classmethod
def build(
cls,
nodes: Union[Data, Array],
pre_coreset_data: Union[_TOriginalData, Array, tuple[Array, Array]],
nodes: Data | Array,
pre_coreset_data: _TOriginalData | Array | tuple[Array, Array],
) -> "PseudoCoreset[Data]\
| PseudoCoreset[SupervisedData]\
| PseudoCoreset[_TOriginalData]\
Expand Down Expand Up @@ -253,30 +252,30 @@ def __init__(self, indices: Data, pre_coreset_data: _TOriginalData_co) -> None:
@classmethod
@overload
def build(
cls, indices: Union[Data, Array], pre_coreset_data: Array
cls, indices: Data | Array, pre_coreset_data: Array
) -> "Coresubset[Data]": ...

@classmethod
@overload
def build(
cls,
indices: Union[Data, Array],
indices: Data | Array,
pre_coreset_data: tuple[Array, Array],
) -> "Coresubset[SupervisedData]": ...

@classmethod
@overload
def build(
cls,
indices: Union[Data, Array],
indices: Data | Array,
pre_coreset_data: _TOriginalData,
) -> "Coresubset[_TOriginalData]": ...

@classmethod
def build(
cls,
indices: Union[Data, Array],
pre_coreset_data: Union[_TOriginalData, Array, tuple[Array, Array]],
indices: Data | Array,
pre_coreset_data: _TOriginalData | Array | tuple[Array, Array],
) -> "Coresubset[Data] | Coresubset[SupervisedData] | Coresubset[_TOriginalData]":
"""
Construct a Coresubset from Data or raw Arrays.
Expand Down
Loading
Loading