Skip to content

Commit 3cde02e

Browse files
Merge branch 'main' into better-community
2 parents 433b45d + f8428ce commit 3cde02e

File tree

18 files changed

+137
-126
lines changed


.github/workflows/ci.yml

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,6 @@ jobs:
170170
name: "TypeChecking: pixi run typing"
171171
runs-on: ubuntu-latest
172172
needs: [cache-pixi-lock]
173-
# TODO v4: Enable typechecking again
174-
if: false
175173
steps:
176174
- name: Checkout
177175
uses: actions/checkout@v5
@@ -185,14 +183,8 @@ jobs:
185183
cache: true
186184
cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}
187185
- name: Typechecking
188-
run: |
189-
pixi run typing --non-interactive --html-report mypy-report
190-
- name: Upload test results
191-
if: ${{ always() }} # Upload even on mypy error
192-
uses: actions/upload-artifact@v7
193-
with:
194-
name: Mypy report
195-
path: mypy-report
186+
run: | # TODO: Remove `|| true` once typechecking is stable
187+
pixi run typing --output-format github || true
196188
build-and-upload-nightly-parcels: # for alpha testing
197189
needs: [cache-pixi-lock]
198190
permissions:
@@ -220,7 +212,7 @@ jobs:
220212
run: |
221213
for pkg in $(find output -type f \( -name "*.conda" -o -name "*.tar.bz2" \) ); do
222214
echo "Uploading ${pkg}"
223-
rattler-build upload prefix -c parcels "${pkg}"
215+
pixi run -e rattler-build rattler-build upload prefix -c parcels "${pkg}"
224216
done
225217
env:
226218
PREFIX_API_KEY: ${{ secrets.PREFIX_API_KEY }}

docs/development/index.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,8 +132,8 @@ See below for more Pixi commands relevant to development.
132132

133133
**Code quality**
134134

135-
- `pixi run lint` - Run pre-commit hooks on all files (includes formatting, linting, and other code quality checks)
136-
- `pixi run typing` - Run mypy type checking on the codebase
135+
- `pixi run lint` - Run [pre-commit](https://pre-commit.com/) hooks on all files (includes formatting, linting, and other code quality checks)
136+
- `pixi run typing` - Run [ty](https://docs.astral.sh/ty/) type checking on the codebase
137137

138138
**Different environments**
139139

pixi.toml

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,25 @@ trajan = "*"
8282
matplotlib-base = ">=2.0.2"
8383
gsw = "*"
8484

85+
[feature.devtools.dependencies]
86+
pdbpp = "*"
87+
line_profiler = "*"
88+
memory_profiler = "*"
89+
snakeviz = "*"
90+
icecream = "*"
91+
ipykernel = "*"
92+
snoop = "*"
93+
pyinstrument = "*"
94+
95+
[feature.devtools.target.linux-64.dependencies]
96+
memray = "*"
97+
98+
[feature.devtools.target.osx-64.dependencies]
99+
memray = "*"
100+
101+
[feature.devtools.target.osx-arm64.dependencies]
102+
memray = "*"
103+
85104
[feature.docs.dependencies]
86105
parcels = { path = "." }
87106
numpydoc = "*"
@@ -113,12 +132,10 @@ numpydoc = "*"
113132
numpydoc-lint = { cmd = "python tools/numpydoc-public-api.py", description = "Lint public API docstrings with numpydoc." }
114133

115134
[feature.typing.dependencies]
116-
mypy = "*"
117-
lxml = "*" # in CI
118-
types-tqdm = "*"
135+
ty = "*"
119136

120137
[feature.typing.tasks]
121-
typing = { cmd = "mypy src/parcels --install-types", description = "Run static type checking with mypy." }
138+
typing = { cmd = "ty check", description = "Run static type checking with ty." }
122139

123140

124141
[environments]
@@ -128,6 +145,7 @@ default = { features = [
128145
"typing",
129146
"pre-commit",
130147
"numpydoc",
148+
"devtools",
131149
], solve-group = "main" }
132150
test = { features = ["test"], solve-group = "main" }
133151
test-minimum = { features = ["test", "minimum"] }

pyproject.toml

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -157,25 +157,8 @@ convention = "numpy"
157157
[tool.ruff.lint.isort]
158158
known-first-party = ["parcels"]
159159

160-
[tool.mypy]
161-
files = [
162-
"parcels/_typing.py",
163-
"parcels/tools/*.py",
164-
"parcels/grid.py",
165-
"parcels/field.py",
166-
"parcels/fieldset.py",
167-
]
168-
169-
[[tool.mypy.overrides]]
170-
module = [
171-
"parcels._version_setup",
172-
"mpi4py",
173-
"scipy.spatial",
174-
"sklearn.cluster",
175-
"zarr",
176-
"cftime",
177-
"pykdtree.kdtree",
178-
"netCDF4",
179-
"pooch",
160+
[tool.ty.src]
161+
include = ["./src/"]
162+
exclude = [
163+
"./src/parcels/interpolators/", # ignore for now
180164
]
181-
ignore_missing_imports = true

src/parcels/_compat.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@
88
KMeans: Any | None = None
99

1010
try:
11-
from mpi4py import MPI # type: ignore[no-redef]
11+
from mpi4py import MPI # type: ignore[import-untyped,no-redef]
1212
except ModuleNotFoundError:
1313
pass
1414

1515
# KMeans is used in MPI. sklearn not installed by default
1616
try:
17-
from sklearn.cluster import KMeans # type: ignore[no-redef]
17+
from sklearn.cluster import KMeans
1818
except ModuleNotFoundError:
1919
pass
2020

src/parcels/_core/field.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import warnings
4-
from collections.abc import Callable
4+
from collections.abc import Callable, Sequence
55
from datetime import datetime
66

77
import numpy as np
@@ -428,7 +428,7 @@ def _assert_valid_uxdataarray(data: ux.UxDataArray):
428428
)
429429

430430

431-
def _assert_compatible_combination(data: xr.DataArray | ux.UxDataArray, grid: ux.Grid | XGrid):
431+
def _assert_compatible_combination(data: xr.DataArray | ux.UxDataArray, grid: UxGrid | XGrid):
432432
if isinstance(data, ux.UxDataArray):
433433
if not isinstance(grid, UxGrid):
434434
raise ValueError(
@@ -448,7 +448,7 @@ def _get_time_interval(data: xr.DataArray | ux.UxDataArray) -> TimeInterval | No
448448
return TimeInterval(data.time.values[0], data.time.values[-1])
449449

450450

451-
def _assert_same_time_interval(fields: list[Field]) -> None:
451+
def _assert_same_time_interval(fields: Sequence[Field]) -> None:
452452
if len(fields) == 0:
453453
return
454454

src/parcels/_core/fieldset.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def from_ugrid_conventions(cls, ds: ux.UxDataset, mesh: str = "spherical"):
223223
)
224224

225225
for varname in set(ds.data_vars) - set(fields.keys()):
226-
fields[varname] = Field(varname, ds[varname], grid, _select_uxinterpolator(ds[varname]))
226+
fields[varname] = Field(str(varname), ds[varname], grid, _select_uxinterpolator(ds[varname]))
227227

228228
return cls(list(fields.values()))
229229

@@ -319,7 +319,7 @@ def from_sgrid_conventions(
319319
)
320320

321321
for varname in set(ds.data_vars) - set(fields.keys()) - skip_vars:
322-
fields[varname] = Field(varname, ds[varname], grid, XLinear)
322+
fields[varname] = Field(str(varname), ds[varname], grid, XLinear)
323323

324324
return cls(list(fields.values()))
325325

@@ -353,7 +353,7 @@ def _datetime_to_msg(example_datetime: TimeLike) -> str:
353353
return msg
354354

355355

356-
def _format_calendar_error_message(field: Field, reference_datetime: TimeLike) -> str:
356+
def _format_calendar_error_message(field: Field | VectorField, reference_datetime: TimeLike) -> str:
357357
return f"Expected field {field.name!r} to have calendar compatible with datetime object {_datetime_to_msg(reference_datetime)}. Got field with calendar {_datetime_to_msg(field.time_interval.left)}. Have you considered using xarray to update the time dimension of the dataset to have a compatible calendar?"
358358

359359

src/parcels/_core/index_search.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
if TYPE_CHECKING:
1212
from parcels._core.field import Field
13-
from parcels.xgrid import XGrid
13+
from parcels._core.xgrid import XGrid
1414

1515

1616
GRID_SEARCH_ERROR = -3
@@ -19,7 +19,7 @@
1919

2020

2121
def _search_1d_array(
22-
arr: np.array,
22+
arr: np.ndarray,
2323
x: float,
2424
) -> tuple[int, int]:
2525
"""

src/parcels/_core/particle.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import operator
4-
from typing import Literal
4+
from typing import Any, Literal
55

66
import numpy as np
77

@@ -37,7 +37,7 @@ class Variable:
3737
def __init__(
3838
self,
3939
name,
40-
dtype: np.dtype = np.float32,
40+
dtype: np.dtype[Any] | type[np.generic] = np.float32,
4141
initial=0,
4242
to_write: bool | Literal["once"] = True,
4343
attrs: dict | None = None,
@@ -122,7 +122,7 @@ def _assert_no_duplicate_variable_names(*, existing_vars: list[Variable], new_va
122122
raise ValueError(f"Variable name already exists: {var.name}")
123123

124124

125-
def get_default_particle(spatial_dtype: np.float32 | np.float64) -> ParticleClass:
125+
def get_default_particle(spatial_dtype: type[np.float32] | type[np.float64]) -> ParticleClass:
126126
if spatial_dtype not in [np.float32, np.float64]:
127127
raise ValueError(f"spatial_dtype must be np.float32 or np.float64. Got {spatial_dtype=!r}")
128128

@@ -177,7 +177,7 @@ def create_particle_data(
177177
nparticles: int,
178178
ngrids: int,
179179
time_interval: TimeInterval,
180-
initial: dict[str, np.array] | None = None,
180+
initial: dict[str, np.ndarray] | None = None,
181181
):
182182
if initial is None:
183183
initial = {}

src/parcels/_core/particlefile.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -196,10 +196,10 @@ def _write_particle_data(self, *, particle_data, pclass, time_interval, time, in
196196
if self.create_new_zarrfile:
197197
if self.chunks is None:
198198
self._chunks = (nparticles, 1)
199-
if (self._maxids > len(ids)) or (self._maxids > self.chunks[0]): # type: ignore[index]
200-
arrsize = (self._maxids, self.chunks[1]) # type: ignore[index]
199+
if (self._maxids > len(ids)) or (self._maxids > self.chunks[0]):
200+
arrsize = (self._maxids, self.chunks[1])
201201
else:
202-
arrsize = (len(ids), self.chunks[1]) # type: ignore[index]
202+
arrsize = (len(ids), self.chunks[1])
203203
ds = xr.Dataset(
204204
attrs=self.metadata,
205205
coords={"trajectory": ("trajectory", pids), "obs": ("obs", np.arange(arrsize[1], dtype=np.int32))},
@@ -221,7 +221,7 @@ def _write_particle_data(self, *, particle_data, pclass, time_interval, time, in
221221
data[ids, 0] = particle_data[var.name][indices_to_write]
222222
dims = ["trajectory", "obs"]
223223
ds[var.name] = xr.DataArray(data=data, dims=dims, attrs=attrs[var.name])
224-
ds[var.name].encoding["chunks"] = self.chunks[0] if var.to_write == "once" else self.chunks # type: ignore[index]
224+
ds[var.name].encoding["chunks"] = self.chunks[0] if var.to_write == "once" else self.chunks
225225
ds.to_zarr(store, mode="w")
226226
self._create_new_zarrfile = False
227227
else:
@@ -234,7 +234,7 @@ def _write_particle_data(self, *, particle_data, pclass, time_interval, time, in
234234
if len(once_ids) > 0:
235235
Z[var.name].vindex[ids_once] = particle_data[var.name][indices_to_write_once]
236236
else:
237-
if max(obs) >= Z[var.name].shape[1]: # type: ignore[type-var]
237+
if max(obs) >= Z[var.name].shape[1]:
238238
self._extend_zarr_dims(Z[var.name], store, dtype=var.dtype, axis=1)
239239
Z[var.name].vindex[ids, obs] = particle_data[var.name][indices_to_write]
240240

0 commit comments

Comments
 (0)