Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ files = [
"parcels/field.py",
"parcels/fieldset.py",
]
disable_error_code = "attr-defined,assignment,operator,call-overload,index,valid-type,override,misc,union-attr"

[[tool.mypy.overrides]]
module = [
Expand All @@ -153,9 +154,20 @@ module = [
"scipy.spatial",
"sklearn.cluster",
"zarr",
"zarr.storage",
"uxarray",
"xgcm",
"cftime",
"pykdtree.kdtree",
"netCDF4",
"pooch",
]
ignore_missing_imports = true

[[tool.mypy.overrides]] # TODO: This module is v3 code and should eventually be removed/converted to v4
module = "parcels.interaction.interactionkernel"
ignore_errors = true

[[tool.mypy.overrides]] # TODO: This module should stabilize before release of v4
module = "parcels.interpolators"
ignore_errors = true
4 changes: 2 additions & 2 deletions src/parcels/_core/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,9 +249,9 @@ def __init__(
self.igrid = U.igrid

if W is None:
_assert_same_time_interval((U, V))
_assert_same_time_interval([U, V])
else:
_assert_same_time_interval((U, V, W))
_assert_same_time_interval([U, V, W])

self.time_interval = U.time_interval

Expand Down
37 changes: 22 additions & 15 deletions src/parcels/_core/fieldset.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,12 +175,13 @@ def add_constant(self, name, value):

@property
def gridset(self) -> list[BaseGrid]:
grids = []
grids: list[BaseGrid] = []
for field in self.fields.values():
if field.grid not in grids:
grids.append(field.grid)
return grids

@staticmethod
def from_copernicusmarine(ds: xr.Dataset):
"""Create a FieldSet from a Copernicus Marine Service xarray.Dataset.

Expand Down Expand Up @@ -238,19 +239,21 @@ def from_copernicusmarine(ds: xr.Dataset):
mesh="spherical",
)

fields = {}
fields: dict[str, Field | VectorField] = {}
if "U" in ds.data_vars and "V" in ds.data_vars:
fields["U"] = Field("U", ds["U"], grid, XLinear)
fields["V"] = Field("V", ds["V"], grid, XLinear)

field_U = Field("U", ds["U"], grid, XLinear)
field_V = Field("V", ds["V"], grid, XLinear)
fields["U"] = field_U
fields["V"] = field_V
if "W" in ds.data_vars:
ds["W"] -= ds[
"W"
] # Negate W to convert from up positive to down positive (as that's the direction of positive z)
fields["W"] = Field("W", ds["W"], grid, XLinear)
fields["UVW"] = VectorField("UVW", fields["U"], fields["V"], fields["W"])
field_W = Field("W", ds["W"], grid, XLinear)
fields["W"] = field_W
fields["UVW"] = VectorField("UVW", field_U, field_V, field_W)
else:
fields["UV"] = VectorField("UV", fields["U"], fields["V"])
fields["UV"] = VectorField("UV", field_U, field_V)

for varname in set(ds.data_vars) - set(fields.keys()):
fields[varname] = Field(varname, ds[varname], grid, XLinear)
Expand Down Expand Up @@ -279,16 +282,19 @@ def from_fesom2(ds: ux.UxDataset):
grid = UxGrid(ds.uxgrid, z=ds.coords["nz"], mesh="spherical")
ds = _discover_fesom2_U_and_V(ds)

fields = {}
fields: dict[str, Field | VectorField] = {}
if "U" in ds.data_vars and "V" in ds.data_vars:
fields["U"] = Field("U", ds["U"], grid, _select_uxinterpolator(ds["U"]))
fields["V"] = Field("V", ds["V"], grid, _select_uxinterpolator(ds["U"]))
field_U = Field("U", ds["U"], grid, _select_uxinterpolator(ds["U"]))
field_V = Field("V", ds["V"], grid, _select_uxinterpolator(ds["U"]))
fields["U"] = field_U
fields["V"] = field_V

if "W" in ds.data_vars:
fields["W"] = Field("W", ds["W"], grid, _select_uxinterpolator(ds["U"]))
fields["UVW"] = VectorField("UVW", fields["U"], fields["V"], fields["W"])
field_W = Field("W", ds["W"], grid, _select_uxinterpolator(ds["U"]))
fields["W"] = field_W
fields["UVW"] = VectorField("UVW", field_U, field_V, field_W)
else:
fields["UV"] = VectorField("UV", fields["U"], fields["V"])
fields["UV"] = VectorField("UV", field_U, field_V)

for varname in set(ds.data_vars) - set(fields.keys()):
fields[varname] = Field(varname, ds[varname], grid, _select_uxinterpolator(ds[varname]))
Expand Down Expand Up @@ -325,7 +331,8 @@ def _datetime_to_msg(example_datetime: TimeLike) -> str:
return msg


def _format_calendar_error_message(field: Field, reference_datetime: TimeLike) -> str:
def _format_calendar_error_message(field: Field | VectorField, reference_datetime: TimeLike) -> str:
assert field.time_interval is not None
return f"Expected field {field.name!r} to have calendar compatible with datetime object {_datetime_to_msg(reference_datetime)}. Got field with calendar {_datetime_to_msg(field.time_interval.left)}. Have you considered using xarray to update the time dimension of the dataset to have a compatible calendar?"


Expand Down
11 changes: 6 additions & 5 deletions src/parcels/_core/index_search.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

from datetime import datetime
from typing import TYPE_CHECKING

import numpy as np
Expand All @@ -9,8 +8,8 @@
from parcels._core.utils.time import timedelta_to_float

if TYPE_CHECKING:
from parcels import XGrid
from parcels._core.field import Field
from parcels.xgrid import XGrid


GRID_SEARCH_ERROR = -3
Expand All @@ -21,7 +20,7 @@
def _search_1d_array(
arr: np.array,
x: float,
) -> tuple[int, int]:
) -> tuple[np.array[int], np.array[float]]:
"""
Searches for particle locations in a 1D array and returns barycentric coordinate along dimension.

Expand Down Expand Up @@ -63,14 +62,14 @@ def _search_1d_array(
return np.atleast_1d(index), np.atleast_1d(bcoord)


def _search_time_index(field: Field, time: datetime):
def _search_time_index(field: Field, time: float):
"""Find and return the index and relative coordinate in the time array associated with a given time.

Parameters
----------
field: Field

time: datetime
time: float
This is the amount of time, in seconds (time_delta), in unix epoch
Note that we normalize to either the first or the last index
if the sampled value is outside the time value range.
Expand Down Expand Up @@ -172,6 +171,8 @@ def _search_indices_curvilinear_2d(
"""
if np.any(xi):
# If an initial guess is provided, we first perform a point in cell check for all guessed indices
assert xi is not None
assert yi is not None
is_in_cell, coords = curvilinear_point_in_cell(grid, y, x, yi, xi)
y_check = y[is_in_cell == 0]
x_check = x[is_in_cell == 0]
Expand Down
6 changes: 1 addition & 5 deletions src/parcels/_core/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import types
import warnings
from typing import TYPE_CHECKING

import numpy as np

Expand All @@ -24,9 +23,6 @@
AdvectionRK45,
)

if TYPE_CHECKING:
from collections.abc import Callable

__all__ = ["Kernel"]


Expand Down Expand Up @@ -84,7 +80,7 @@ def __init__(
# if (pyfunc is AdvectionRK4_3D) and fieldset.U.gridindexingtype == "croco":
# pyfunc = AdvectionRK4_3D_CROCO

self._pyfuncs: list[Callable] = pyfuncs
self._pyfuncs: list[types.FunctionType] = pyfuncs

@property #! Ported from v3. To be removed in v4? (/find another way to name kernels in output file)
def funcname(self):
Expand Down
4 changes: 2 additions & 2 deletions src/parcels/_core/particle.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class Variable:
def __init__(
self,
name,
dtype: np.dtype = np.float32,
dtype: type[np.float32 | np.float64 | np.int32 | np.int64] | None = np.float32,
initial=0,
to_write: bool | Literal["once"] = True,
attrs: dict | None = None,
Expand Down Expand Up @@ -147,7 +147,7 @@ def _assert_no_duplicate_variable_names(*, existing_vars: list[Variable], new_va
raise ValueError(f"Variable name already exists: {var.name}")


def get_default_particle(spatial_dtype: np.float32 | np.float64) -> ParticleClass:
def get_default_particle(spatial_dtype: type[np.float32 | np.float64]) -> ParticleClass:
if spatial_dtype not in [np.float32, np.float64]:
raise ValueError(f"spatial_dtype must be np.float32 or np.float64. Got {spatial_dtype=!r}")

Expand Down
22 changes: 14 additions & 8 deletions src/parcels/_core/particleset.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from __future__ import annotations

import datetime
import sys
import types
import warnings
from collections.abc import Iterable
from typing import Literal
from typing import TYPE_CHECKING, Literal

import numpy as np
import xarray as xr
Expand All @@ -22,6 +25,9 @@
from parcels._logger import logger
from parcels._reprs import particleset_repr

if TYPE_CHECKING:
from parcels import FieldSet, ParticleClass, ParticleFile

__all__ = ["ParticleSet"]


Expand Down Expand Up @@ -62,10 +68,10 @@ class ParticleSet:

def __init__(
self,
fieldset,
pclass=Particle,
lon=None,
lat=None,
fieldset: FieldSet,
pclass: ParticleClass = Particle,
lon: np.array[float] = None,
lat: np.array[float] = None,
z=None,
time=None,
trajectory_ids=None,
Expand Down Expand Up @@ -429,12 +435,12 @@ def set_variable_write_status(self, var, write_status):

def execute(
self,
pyfunc,
pyfunc: types.FunctionType | Kernel,
dt: datetime.timedelta | np.timedelta64 | float,
endtime: np.timedelta64 | np.datetime64 | None = None,
runtime: datetime.timedelta | np.timedelta64 | float | None = None,
output_file=None,
verbose_progress=True,
output_file: ParticleFile = None,
verbose_progress: bool = True,
):
"""Execute a given kernel function over the particle set for multiple timesteps.

Expand Down
2 changes: 1 addition & 1 deletion src/parcels/_core/utils/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def phi1D_quad(xsi: float) -> list[float]:
return phi


def phi2D_lin(eta: float, xsi: float) -> list[float]:
def phi2D_lin(eta: float, xsi: float) -> np.ndarray:
phi = np.column_stack([(1-xsi) * (1-eta),
xsi * (1-eta),
xsi * eta ,
Expand Down
13 changes: 7 additions & 6 deletions src/parcels/_core/utils/time.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

from collections.abc import Callable
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, TypeVar
from typing import TYPE_CHECKING, Generic, Literal, TypeVar

import cftime
import numpy as np
Expand All @@ -12,7 +13,7 @@
T = TypeVar("T", bound="TimeLike")


class TimeInterval:
class TimeInterval(Generic[T]):
"""A class representing a time interval between two datetime or np.timedelta64 objects.

Parameters
Expand All @@ -27,7 +28,7 @@ class TimeInterval:
For the purposes of this codebase, the interval can be thought of as closed on the left and right.
"""

def __init__(self, left: T, right: T) -> None:
def __init__(self, left: T, right: T):
if not isinstance(left, (np.timedelta64, datetime, cftime.datetime, np.datetime64)):
raise ValueError(
f"Expected right to be a np.timedelta64, datetime, cftime.datetime, or np.datetime64. Got {type(left)}."
Expand Down Expand Up @@ -128,7 +129,7 @@ def get_datetime_type_calendar(
return type(example_datetime), calendar


_TD_PRECISION_GETTER_FOR_UNIT = (
_TD_PRECISION_GETTER_FOR_UNIT: tuple[tuple[Callable[[timedelta], int], Literal["D", "s", "us"]], ...] = (
(lambda dt: dt.days, "D"),
(lambda dt: dt.seconds, "s"),
(lambda dt: dt.microseconds, "us"),
Expand All @@ -140,14 +141,14 @@ def maybe_convert_python_timedelta_to_numpy(dt: timedelta | np.timedelta64) -> n
return dt

try:
dts = []
dts: list[np.timedelta64] = []
for get_value_for_unit, np_unit in _TD_PRECISION_GETTER_FOR_UNIT:
value = get_value_for_unit(dt)
if value != 0:
dts.append(np.timedelta64(value, np_unit))

if dts:
return sum(dts)
return np.sum(dts)
else:
return np.timedelta64(0, "s")
except Exception as e:
Expand Down
2 changes: 1 addition & 1 deletion src/parcels/_core/uxgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class UxGrid(BaseGrid):
for interpolation on unstructured grids.
"""

def __init__(self, grid: ux.grid.Grid, z: ux.UxDataArray, mesh) -> UxGrid:
def __init__(self, grid: ux.grid.Grid, z: ux.UxDataArray, mesh):
"""
Initializes the UxGrid with a uxarray grid and vertical coordinate array.

Expand Down
4 changes: 2 additions & 2 deletions src/parcels/_core/xgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def _gtype(self):

TODO: Remove
"""
from parcels.grid import GridType
from parcels._core.basegrid import GridType

if len(self.lon.shape) <= 1:
if self.depth is None or len(self.depth.shape) <= 1:
Expand Down Expand Up @@ -380,7 +380,7 @@ def get_axis_dim_mapping(self, dims: list[str]) -> dict[_XGRID_AXES, str]:
return result


def get_axis_from_dim_name(axes: _XGCM_AXES, dim: str) -> _XGCM_AXIS_DIRECTION | None:
def get_axis_from_dim_name(axes: _XGCM_AXES, dim: Hashable) -> _XGCM_AXIS_DIRECTION | None:
"""For a given dimension name in a grid, returns the direction axis it is on."""
for axis_name, axis in axes.items():
if dim in axis.coords.values():
Expand Down
1 change: 0 additions & 1 deletion src/parcels/_reprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ def field_repr(field: Field) -> str: # TODO v4: Rework or remove entirely
out = f"""<{type(field).__name__}>
name : {field.name!r}
data : {field.data!r}
extrapolate time: {field.allow_time_extrapolation!r}
"""
return textwrap.dedent(out).strip()

Expand Down
8 changes: 4 additions & 4 deletions src/parcels/interpolators.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def ZeroInterpolator(
particle_positions: dict[str, float | np.ndarray],
grid_positions: dict[_XGRID_AXES, dict[str, int | float | np.ndarray]],
field: Field,
) -> np.float32 | np.float64:
) -> float:
"""Template function used for the signature check of the lateral interpolation methods."""
return 0.0

Expand All @@ -44,7 +44,7 @@ def ZeroInterpolator_Vector(
particle_positions: dict[str, float | np.ndarray],
grid_positions: dict[_XGRID_AXES, dict[str, int | float | np.ndarray]],
vectorfield: VectorField,
) -> np.float32 | np.float64:
) -> float:
"""Template function used for the signature check of the interpolation methods for velocity fields."""
return 0.0

Expand Down Expand Up @@ -400,8 +400,8 @@ def _Spatialslip(
particle_positions: dict[str, float | np.ndarray],
grid_positions: dict[_XGRID_AXES, dict[str, int | float | np.ndarray]],
vectorfield: VectorField,
a: np.float32,
b: np.float32,
a: float,
b: float,
):
"""Helper function for spatial boundary condition interpolation for velocity fields."""
xi, xsi = grid_positions["X"]["index"], grid_positions["X"]["bcoord"]
Expand Down