From 216a74038cd5a80673c6b04d3a7b6cfa6f6d6688 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Fri, 5 Dec 2025 11:55:52 +0100 Subject: [PATCH 1/8] Fix type annotations --- src/parcels/_core/field.py | 4 ++-- src/parcels/_core/fieldset.py | 10 ++++++---- src/parcels/_core/kernel.py | 6 +----- src/parcels/_core/utils/interpolation.py | 2 +- src/parcels/_core/utils/time.py | 9 +++++---- 5 files changed, 15 insertions(+), 16 deletions(-) diff --git a/src/parcels/_core/field.py b/src/parcels/_core/field.py index b8c22ac0a..e225c0133 100644 --- a/src/parcels/_core/field.py +++ b/src/parcels/_core/field.py @@ -249,9 +249,9 @@ def __init__( self.igrid = U.igrid if W is None: - _assert_same_time_interval((U, V)) + _assert_same_time_interval([U, V]) else: - _assert_same_time_interval((U, V, W)) + _assert_same_time_interval([U, V, W]) self.time_interval = U.time_interval diff --git a/src/parcels/_core/fieldset.py b/src/parcels/_core/fieldset.py index 4346fb233..8ca9a30cc 100644 --- a/src/parcels/_core/fieldset.py +++ b/src/parcels/_core/fieldset.py @@ -175,12 +175,13 @@ def add_constant(self, name, value): @property def gridset(self) -> list[BaseGrid]: - grids = [] + grids: list[BaseGrid] = [] for field in self.fields.values(): if field.grid not in grids: grids.append(field.grid) return grids + @staticmethod def from_copernicusmarine(ds: xr.Dataset): """Create a FieldSet from a Copernicus Marine Service xarray.Dataset. @@ -238,7 +239,7 @@ def from_copernicusmarine(ds: xr.Dataset): mesh="spherical", ) - fields = {} + fields: dict[str, Field | VectorField] = {} if "U" in ds.data_vars and "V" in ds.data_vars: fields["U"] = Field("U", ds["U"], grid, XLinear) fields["V"] = Field("V", ds["V"], grid, XLinear) @@ -279,7 +280,7 @@ def from_fesom2(ds: ux.UxDataset): grid = UxGrid(ds.uxgrid, z=ds.coords["nz"], mesh="spherical") ds = _discover_fesom2_U_and_V(ds) - fields = {} + fields: dict[str, Field | VectorField] = {} if "U" in ds.data_vars and "V" in ds.data_vars: fields["U"] = Field("U", ds["U"], grid, _select_uxinterpolator(ds["U"])) fields["V"] = Field("V", ds["V"], grid, _select_uxinterpolator(ds["U"])) @@ -325,7 +326,8 @@ def _datetime_to_msg(example_datetime: TimeLike) -> str: return msg -def _format_calendar_error_message(field: Field, reference_datetime: TimeLike) -> str: +def _format_calendar_error_message(field: Field | VectorField, reference_datetime: TimeLike) -> str: + assert field.time_interval is not None return f"Expected field {field.name!r} to have calendar compatible with datetime object {_datetime_to_msg(reference_datetime)}. Got field with calendar {_datetime_to_msg(field.time_interval.left)}. Have you considered using xarray to update the time dimension of the dataset to have a compatible calendar?" diff --git a/src/parcels/_core/kernel.py b/src/parcels/_core/kernel.py index 3fe8b717d..30811e2d3 100644 --- a/src/parcels/_core/kernel.py +++ b/src/parcels/_core/kernel.py @@ -2,7 +2,6 @@ import types import warnings -from typing import TYPE_CHECKING import numpy as np @@ -24,9 +23,6 @@ AdvectionRK45, ) -if TYPE_CHECKING: - from collections.abc import Callable - __all__ = ["Kernel"] @@ -84,7 +80,7 @@ def __init__( # if (pyfunc is AdvectionRK4_3D) and fieldset.U.gridindexingtype == "croco": # pyfunc = AdvectionRK4_3D_CROCO - self._pyfuncs: list[Callable] = pyfuncs + self._pyfuncs: list[types.FunctionType] = pyfuncs @property #! Ported from v3. To be removed in v4? (/find another way to name kernels in output file) def funcname(self): diff --git a/src/parcels/_core/utils/interpolation.py b/src/parcels/_core/utils/interpolation.py index e4893a813..3957d5885 100644 --- a/src/parcels/_core/utils/interpolation.py +++ b/src/parcels/_core/utils/interpolation.py @@ -22,7 +22,7 @@ def phi1D_quad(xsi: float) -> list[float]: return phi -def phi2D_lin(eta: float, xsi: float) -> list[float]: +def phi2D_lin(eta: float, xsi: float) -> np.ndarray: phi = np.column_stack([(1-xsi) * (1-eta), xsi * (1-eta), xsi * eta , diff --git a/src/parcels/_core/utils/time.py b/src/parcels/_core/utils/time.py index fe62813ef..0967e87ef 100644 --- a/src/parcels/_core/utils/time.py +++ b/src/parcels/_core/utils/time.py @@ -1,7 +1,8 @@ from __future__ import annotations +from collections.abc import Callable from datetime import datetime, timedelta -from typing import TYPE_CHECKING, TypeVar +from typing import TYPE_CHECKING, Literal, TypeVar import cftime import numpy as np @@ -128,7 +129,7 @@ def get_datetime_type_calendar( return type(example_datetime), calendar -_TD_PRECISION_GETTER_FOR_UNIT = ( +_TD_PRECISION_GETTER_FOR_UNIT: tuple[tuple[Callable[[timedelta], int], Literal["D", "s", "us"]], ...] = ( (lambda dt: dt.days, "D"), (lambda dt: dt.seconds, "s"), (lambda dt: dt.microseconds, "us"), @@ -140,14 +141,14 @@ def maybe_convert_python_timedelta_to_numpy(dt: timedelta | np.timedelta64) -> n return dt try: - dts = [] + dts: list[np.timedelta64] = [] for get_value_for_unit, np_unit in _TD_PRECISION_GETTER_FOR_UNIT: value = get_value_for_unit(dt) if value != 0: dts.append(np.timedelta64(value, np_unit)) if dts: - return sum(dts) + return np.sum(dts) else: return np.timedelta64(0, "s") except Exception as e: From 08c8b6eef5232793e72b6c557fb38fc587fa0039 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Fri, 5 Dec 2025 13:46:30 +0100 Subject: [PATCH 2/8] Remove time extrapolation from repr --- src/parcels/_reprs.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/parcels/_reprs.py b/src/parcels/_reprs.py index 6ea42992a..ba708027e 100644 --- a/src/parcels/_reprs.py +++ b/src/parcels/_reprs.py @@ -14,7 +14,6 @@ def field_repr(field: Field) -> str: # TODO v4: Rework or remove entirely out = f"""<{type(field).__name__}> name : {field.name!r} data : {field.data!r} - extrapolate time: {field.allow_time_extrapolation!r} """ return textwrap.dedent(out).strip() From 761a643cce3607b11cebdcf7f2f15a4d4d3725ae Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Fri, 5 Dec 2025 14:09:30 +0100 Subject: [PATCH 3/8] Ignore select mypy errors Allowing for gradual resolution --- pyproject.toml | 1 + src/parcels/interpolators.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ca43f03ba..131eba1ed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -145,6 +145,7 @@ files = [ "parcels/field.py", "parcels/fieldset.py", ] +disable_error_code = "import-untyped,attr-defined,return-value,return,arg-type,assignment,operator,call-overload,index,valid-type,override,misc,union-attr" [[tool.mypy.overrides]] module = [ diff --git a/src/parcels/interpolators.py b/src/parcels/interpolators.py index abd707eea..d593d5929 100644 --- a/src/parcels/interpolators.py +++ b/src/parcels/interpolators.py @@ -400,8 +400,8 @@ def _Spatialslip( particle_positions: dict[str, float | np.ndarray], grid_positions: dict[_XGRID_AXES, dict[str, int | float | np.ndarray]], vectorfield: VectorField, - a: np.float32, - b: np.float32, + a: float, + b: float, ): """Helper function for spatial boundary condition interpolation for velocity fields.""" xi, xsi = grid_positions["X"]["index"], grid_positions["X"]["bcoord"] From 9079247b0fa6975a0bf9eed867e56abd31454464 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Fri, 5 Dec 2025 14:28:10 +0100 Subject: [PATCH 4/8] Fix "import-untyped" errors --- pyproject.toml | 9 ++++++++- src/parcels/_core/index_search.py | 2 +- src/parcels/_core/xgrid.py | 2 +- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 131eba1ed..29f2a7e7d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -145,7 +145,7 @@ files = [ "parcels/field.py", "parcels/fieldset.py", ] -disable_error_code = "import-untyped,attr-defined,return-value,return,arg-type,assignment,operator,call-overload,index,valid-type,override,misc,union-attr" +disable_error_code = "attr-defined,return-value,return,arg-type,assignment,operator,call-overload,index,valid-type,override,misc,union-attr" [[tool.mypy.overrides]] module = [ @@ -154,9 +154,16 @@ module = [ "scipy.spatial", "sklearn.cluster", "zarr", + "zarr.storage", + "uxarray", + "xgcm", "cftime", "pykdtree.kdtree", "netCDF4", "pooch", ] ignore_missing_imports = true + +[[tool.mypy.overrides]] # TODO: This module is v3 code and should eventually be removed/converted to v4 +module = "parcels.interaction.interactionkernel" +ignore_errors = true diff --git a/src/parcels/_core/index_search.py b/src/parcels/_core/index_search.py index 173d14226..35c22fda5 100644 --- a/src/parcels/_core/index_search.py +++ b/src/parcels/_core/index_search.py @@ -9,8 +9,8 @@ from parcels._core.utils.time import timedelta_to_float if TYPE_CHECKING: + from parcels import XGrid from parcels._core.field import Field - from parcels.xgrid import XGrid GRID_SEARCH_ERROR = -3 diff --git a/src/parcels/_core/xgrid.py b/src/parcels/_core/xgrid.py index 86805a60f..238faae1e 100644 --- a/src/parcels/_core/xgrid.py +++ b/src/parcels/_core/xgrid.py @@ -278,7 +278,7 @@ def _gtype(self): TODO: Remove """ - from parcels.grid import GridType + from parcels._core.basegrid import GridType if len(self.lon.shape) <= 1: if self.depth is None or len(self.depth.shape) <= 1: From 584279a74ab4e84499a3670ed17260f119bbc290 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Fri, 5 Dec 2025 14:43:25 +0100 Subject: [PATCH 5/8] Fix type annotations --- pyproject.toml | 2 +- src/parcels/_core/index_search.py | 2 +- src/parcels/_core/uxgrid.py | 2 +- src/parcels/interpolators.py | 4 ++-- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 29f2a7e7d..5e7cfdf18 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -145,7 +145,7 @@ files = [ "parcels/field.py", "parcels/fieldset.py", ] -disable_error_code = "attr-defined,return-value,return,arg-type,assignment,operator,call-overload,index,valid-type,override,misc,union-attr" +disable_error_code = "attr-defined,arg-type,assignment,operator,call-overload,index,valid-type,override,misc,union-attr" [[tool.mypy.overrides]] module = [ diff --git a/src/parcels/_core/index_search.py b/src/parcels/_core/index_search.py index 35c22fda5..88af63fd9 100644 --- a/src/parcels/_core/index_search.py +++ b/src/parcels/_core/index_search.py @@ -21,7 +21,7 @@ def _search_1d_array( arr: np.array, x: float, -) -> tuple[int, int]: +) -> tuple[np.array[int], np.array[float]]: """ Searches for particle locations in a 1D array and returns barycentric coordinate along dimension. diff --git a/src/parcels/_core/uxgrid.py b/src/parcels/_core/uxgrid.py index a3fc9e240..fecda2d9f 100644 --- a/src/parcels/_core/uxgrid.py +++ b/src/parcels/_core/uxgrid.py @@ -18,7 +18,7 @@ class UxGrid(BaseGrid): for interpolation on unstructured grids. """ - def __init__(self, grid: ux.grid.Grid, z: ux.UxDataArray, mesh) -> UxGrid: + def __init__(self, grid: ux.grid.Grid, z: ux.UxDataArray, mesh): """ Initializes the UxGrid with a uxarray grid and vertical coordinate array. diff --git a/src/parcels/interpolators.py b/src/parcels/interpolators.py index d593d5929..dc7b5a778 100644 --- a/src/parcels/interpolators.py +++ b/src/parcels/interpolators.py @@ -35,7 +35,7 @@ def ZeroInterpolator( particle_positions: dict[str, float | np.ndarray], grid_positions: dict[_XGRID_AXES, dict[str, int | float | np.ndarray]], field: Field, -) -> np.float32 | np.float64: +) -> float: """Template function used for the signature check of the lateral interpolation methods.""" return 0.0 @@ -44,7 +44,7 @@ def ZeroInterpolator_Vector( particle_positions: dict[str, float | np.ndarray], grid_positions: dict[_XGRID_AXES, dict[str, int | float | np.ndarray]], vectorfield: VectorField, -) -> np.float32 | np.float64: +) -> float: """Template function used for the signature check of the interpolation methods for velocity fields.""" return 0.0 From f30482a253e7ea3fa3b032e68c823674acb1ddf2 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Fri, 5 Dec 2025 15:00:41 +0100 Subject: [PATCH 6/8] Make TimeInterval into Generic class --- src/parcels/_core/utils/time.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/parcels/_core/utils/time.py b/src/parcels/_core/utils/time.py index 0967e87ef..c807b004e 100644 --- a/src/parcels/_core/utils/time.py +++ b/src/parcels/_core/utils/time.py @@ -2,7 +2,7 @@ from collections.abc import Callable from datetime import datetime, timedelta -from typing import TYPE_CHECKING, Literal, TypeVar +from typing import TYPE_CHECKING, Generic, Literal, TypeVar import cftime import numpy as np @@ -13,7 +13,7 @@ T = TypeVar("T", bound="TimeLike") -class TimeInterval: +class TimeInterval(Generic[T]): """A class representing a time interval between two datetime or np.timedelta64 objects. Parameters @@ -28,7 +28,7 @@ class TimeInterval: For the purposes of this codebase, the interval can be thought of as closed on the left and right. """ - def __init__(self, left: T, right: T) -> None: + def __init__(self, left: T, right: T): if not isinstance(left, (np.timedelta64, datetime, cftime.datetime, np.datetime64)): raise ValueError( f"Expected right to be a np.timedelta64, datetime, cftime.datetime, or np.datetime64. Got {type(left)}." From 2414796e50967689417ce8700cb3aa367552ed8b Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Fri, 5 Dec 2025 16:27:46 +0100 Subject: [PATCH 7/8] Fix type annotations --- pyproject.toml | 6 +++++- src/parcels/_core/fieldset.py | 14 ++++++++------ src/parcels/_core/particle.py | 4 ++-- src/parcels/_core/particleset.py | 22 ++++++++++++++-------- 4 files changed, 29 insertions(+), 17 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5e7cfdf18..6b3b7c708 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -145,7 +145,7 @@ files = [ "parcels/field.py", "parcels/fieldset.py", ] -disable_error_code = "attr-defined,arg-type,assignment,operator,call-overload,index,valid-type,override,misc,union-attr" +disable_error_code = "attr-defined,assignment,operator,call-overload,index,valid-type,override,misc,union-attr" [[tool.mypy.overrides]] module = [ @@ -167,3 +167,7 @@ ignore_missing_imports = true [[tool.mypy.overrides]] # TODO: This module is v3 code and should eventually be removed/converted to v4 module = "parcels.interaction.interactionkernel" ignore_errors = true + +[[tool.mypy.overrides]] # TODO: This module should stabilize before release of v4 +module = "parcels.interpolators" +ignore_errors = true diff --git a/src/parcels/_core/fieldset.py b/src/parcels/_core/fieldset.py index 8ca9a30cc..c25a6d921 100644 --- a/src/parcels/_core/fieldset.py +++ b/src/parcels/_core/fieldset.py @@ -241,17 +241,19 @@ def from_copernicusmarine(ds: xr.Dataset): fields: dict[str, Field | VectorField] = {} if "U" in ds.data_vars and "V" in ds.data_vars: - fields["U"] = Field("U", ds["U"], grid, XLinear) - fields["V"] = Field("V", ds["V"], grid, XLinear) - + field_U = Field("U", ds["U"], grid, XLinear) + field_V = Field("V", ds["V"], grid, XLinear) + fields["U"] = field_U + fields["V"] = field_V if "W" in ds.data_vars: ds["W"] -= ds[ "W" ] # Negate W to convert from up positive to down positive (as that's the direction of positive z) - fields["W"] = Field("W", ds["W"], grid, XLinear) - fields["UVW"] = VectorField("UVW", fields["U"], fields["V"], fields["W"]) + field_W = Field("W", ds["W"], grid, XLinear) + fields["W"] = field_W + fields["UVW"] = VectorField("UVW", field_U, field_V, field_W) else: - fields["UV"] = VectorField("UV", fields["U"], fields["V"]) + fields["UV"] = VectorField("UV", field_U, field_V) for varname in set(ds.data_vars) - set(fields.keys()): fields[varname] = Field(varname, ds[varname], grid, XLinear) diff --git a/src/parcels/_core/particle.py b/src/parcels/_core/particle.py index 69f37d68f..c1d81fd01 100644 --- a/src/parcels/_core/particle.py +++ b/src/parcels/_core/particle.py @@ -37,7 +37,7 @@ class Variable: def __init__( self, name, - dtype: np.dtype = np.float32, + dtype: type[np.float32 | np.float64 | np.int32 | np.int64] | None = np.float32, initial=0, to_write: bool | Literal["once"] = True, attrs: dict | None = None, @@ -147,7 +147,7 @@ def _assert_no_duplicate_variable_names(*, existing_vars: list[Variable], new_va raise ValueError(f"Variable name already exists: {var.name}") -def get_default_particle(spatial_dtype: np.float32 | np.float64) -> ParticleClass: +def get_default_particle(spatial_dtype: type[np.float32 | np.float64]) -> ParticleClass: if spatial_dtype not in [np.float32, np.float64]: raise ValueError(f"spatial_dtype must be np.float32 or np.float64. Got {spatial_dtype=!r}") diff --git a/src/parcels/_core/particleset.py b/src/parcels/_core/particleset.py index fe23db61a..4ece9f1b5 100644 --- a/src/parcels/_core/particleset.py +++ b/src/parcels/_core/particleset.py @@ -1,8 +1,11 @@ +from __future__ import annotations + import datetime import sys +import types import warnings from collections.abc import Iterable -from typing import Literal +from typing import TYPE_CHECKING, Literal import numpy as np import xarray as xr @@ -22,6 +25,9 @@ from parcels._logger import logger from parcels._reprs import particleset_repr +if TYPE_CHECKING: + from parcels import FieldSet, ParticleClass, ParticleFile + __all__ = ["ParticleSet"] @@ -62,10 +68,10 @@ class ParticleSet: def __init__( self, - fieldset, - pclass=Particle, - lon=None, - lat=None, + fieldset: FieldSet, + pclass: ParticleClass = Particle, + lon: np.array[float] = None, + lat: np.array[float] = None, z=None, time=None, trajectory_ids=None, @@ -429,12 +435,12 @@ def set_variable_write_status(self, var, write_status): def execute( self, - pyfunc, + pyfunc: types.FunctionType | Kernel, dt: datetime.timedelta | np.timedelta64 | float, endtime: np.timedelta64 | np.datetime64 | None = None, runtime: datetime.timedelta | np.timedelta64 | float | None = None, - output_file=None, - verbose_progress=True, + output_file: ParticleFile = None, + verbose_progress: bool = True, ): """Execute a given kernel function over the particle set for multiple timesteps. From 829feaa21df47c3d6f0d2bcd47ed7eee2ac41831 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Fri, 5 Dec 2025 16:51:38 +0100 Subject: [PATCH 8/8] Fix type annotations --- src/parcels/_core/fieldset.py | 13 ++++++++----- src/parcels/_core/index_search.py | 7 ++++--- src/parcels/_core/xgrid.py | 2 +- 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/src/parcels/_core/fieldset.py b/src/parcels/_core/fieldset.py index c25a6d921..547711c89 100644 --- a/src/parcels/_core/fieldset.py +++ b/src/parcels/_core/fieldset.py @@ -284,14 +284,17 @@ def from_fesom2(ds: ux.UxDataset): fields: dict[str, Field | VectorField] = {} if "U" in ds.data_vars and "V" in ds.data_vars: - fields["U"] = Field("U", ds["U"], grid, _select_uxinterpolator(ds["U"])) - fields["V"] = Field("V", ds["V"], grid, _select_uxinterpolator(ds["U"])) + field_U = Field("U", ds["U"], grid, _select_uxinterpolator(ds["U"])) + field_V = Field("V", ds["V"], grid, _select_uxinterpolator(ds["U"])) + fields["U"] = field_U + fields["V"] = field_V if "W" in ds.data_vars: - fields["W"] = Field("W", ds["W"], grid, _select_uxinterpolator(ds["U"])) - fields["UVW"] = VectorField("UVW", fields["U"], fields["V"], fields["W"]) + field_W = Field("W", ds["W"], grid, _select_uxinterpolator(ds["U"])) + fields["W"] = field_W + fields["UVW"] = VectorField("UVW", field_U, field_V, field_W) else: - fields["UV"] = VectorField("UV", fields["U"], fields["V"]) + fields["UV"] = VectorField("UV", field_U, field_V) for varname in set(ds.data_vars) - set(fields.keys()): fields[varname] = Field(varname, ds[varname], grid, _select_uxinterpolator(ds[varname])) diff --git a/src/parcels/_core/index_search.py b/src/parcels/_core/index_search.py index 88af63fd9..2198f7b58 100644 --- a/src/parcels/_core/index_search.py +++ b/src/parcels/_core/index_search.py @@ -1,6 +1,5 @@ from __future__ import annotations -from datetime import datetime from typing import TYPE_CHECKING import numpy as np @@ -63,14 +62,14 @@ def _search_1d_array( return np.atleast_1d(index), np.atleast_1d(bcoord) -def _search_time_index(field: Field, time: datetime): +def _search_time_index(field: Field, time: float): """Find and return the index and relative coordinate in the time array associated with a given time. Parameters ---------- field: Field - time: datetime + time: float This is the amount of time, in seconds (time_delta), in unix epoch Note that we normalize to either the first or the last index if the sampled value is outside the time value range. @@ -172,6 +171,8 @@ def _search_indices_curvilinear_2d( """ if np.any(xi): # If an initial guess is provided, we first perform a point in cell check for all guessed indices + assert xi is not None + assert yi is not None is_in_cell, coords = curvilinear_point_in_cell(grid, y, x, yi, xi) y_check = y[is_in_cell == 0] x_check = x[is_in_cell == 0] diff --git a/src/parcels/_core/xgrid.py b/src/parcels/_core/xgrid.py index 238faae1e..a072029ba 100644 --- a/src/parcels/_core/xgrid.py +++ b/src/parcels/_core/xgrid.py @@ -380,7 +380,7 @@ def get_axis_dim_mapping(self, dims: list[str]) -> dict[_XGRID_AXES, str]: return result -def get_axis_from_dim_name(axes: _XGCM_AXES, dim: str) -> _XGCM_AXIS_DIRECTION | None: +def get_axis_from_dim_name(axes: _XGCM_AXES, dim: Hashable) -> _XGCM_AXIS_DIRECTION | None: """For a given dimension name in a grid, returns the direction axis it is on.""" for axis_name, axis in axes.items(): if dim in axis.coords.values():