From b7347d38bc94a1f84f258faa09bebf0e3a743bbd Mon Sep 17 00:00:00 2001 From: Erik van Sebille Date: Thu, 10 Jul 2025 17:54:01 +0200 Subject: [PATCH 1/2] First attempt at implementing XGrid Interpolators --- parcels/application_kernels/interpolation.py | 46 ++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/parcels/application_kernels/interpolation.py b/parcels/application_kernels/interpolation.py index 1622ffcac7..7332389863 100644 --- a/parcels/application_kernels/interpolation.py +++ b/parcels/application_kernels/interpolation.py @@ -10,6 +10,7 @@ if TYPE_CHECKING: from parcels.uxgrid import _UXGRID_AXES + from parcels.xgrid import _XGRID_AXES __all__ = [ "UXPiecewiseConstantFace", @@ -17,6 +18,51 @@ ] +def XTriCurviLinear( + field: Field, + ti: int, + position: dict[_XGRID_AXES, tuple[int, float | np.ndarray]], + tau: np.float32 | np.float64, + t: np.float32 | np.float64, + z: np.float32 | np.float64, + y: np.float32 | np.float64, + x: np.float32 | np.float64, +): + """Trilinear interpolation on a curvilinear grid.""" + xi, xsi = position["X"] + yi, eta = position["Y"] + zi, zeta = position["Z"] + data = field.data + + return ( + ( + (1 - xsi) * (1 - eta) * data.isel(YG=yi, XG=xi) + + xsi * (1 - eta) * data.isel(YG=yi, XG=xi + 1) + + xsi * eta * data.isel(YG=yi + 1, XG=xi + 1) + + (1 - xsi) * eta * data.isel(YG=yi + 1, XG=xi) + ) + .interp(time=t, ZG=zi + zeta) + .values + ) + + +def XTriRectiLinear( + field: Field, + ti: int, + position: dict[_XGRID_AXES, tuple[int, float | np.ndarray]], + tau: np.float32 | np.float64, + t: np.float32 | np.float64, + z: np.float32 | np.float64, + y: np.float32 | np.float64, + x: np.float32 | np.float64, +): + """Trilinear interpolation on a rectilinear grid.""" + xi, xsi = position["X"] + yi, eta = position["Y"] + zi, zeta = position["Z"] + return field.data.interp(time=t, ZG=zi + zeta, YG=yi + eta, XG=xi + xsi).values + + def UXPiecewiseConstantFace( field: Field, ti: int, From aeada5dd9ffdad66a6dae64d6f377387da9f4ad7 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Tue, 15 Jul 2025 19:00:37 +0200 Subject: [PATCH 2/2] Add XGrid.get_axis_dim_mapping --- parcels/application_kernels/interpolation.py | 16 +-- parcels/xgrid.py | 32 ++++++ tests/v4/test_interpolation.py | 102 +++++++++++++++++++ 3 files changed, 144 insertions(+), 6 deletions(-) create mode 100644 tests/v4/test_interpolation.py diff --git a/parcels/application_kernels/interpolation.py b/parcels/application_kernels/interpolation.py index 7332389863..03d1c587ed 100644 --- a/parcels/application_kernels/interpolation.py +++ b/parcels/application_kernels/interpolation.py @@ -33,15 +33,16 @@ def XTriCurviLinear( yi, eta = position["Y"] zi, zeta = position["Z"] data = field.data + axis_dim = field.grid.get_axis_dim_mapping(field.data.dims) return ( ( - (1 - xsi) * (1 - eta) * data.isel(YG=yi, XG=xi) - + xsi * (1 - eta) * data.isel(YG=yi, XG=xi + 1) - + xsi * eta * data.isel(YG=yi + 1, XG=xi + 1) - + (1 - xsi) * eta * data.isel(YG=yi + 1, XG=xi) + (1 - xsi) * (1 - eta) * data.isel({axis_dim["Y"]: yi, axis_dim["X"]: xi}) + + xsi * (1 - eta) * data.isel({axis_dim["Y"]: yi, axis_dim["X"]: xi + 1}) + + xsi * eta * data.isel({axis_dim["Y"]: yi + 1, axis_dim["X"]: xi + 1}) + + (1 - xsi) * eta * data.isel({axis_dim["Y"]: yi + 1, axis_dim["X"]: xi}) ) - .interp(time=t, ZG=zi + zeta) + .interp(time=t, **{axis_dim["Z"]: zi + zeta}) .values ) @@ -57,10 +58,13 @@ def XTriRectiLinear( x: np.float32 | np.float64, ): """Trilinear interpolation on a rectilinear grid.""" + axis_dim = field.grid.get_axis_dim_mapping(field.data.dims) + xi, xsi = position["X"] yi, eta = position["Y"] zi, zeta = position["Z"] - return field.data.interp(time=t, ZG=zi + zeta, YG=yi + eta, XG=xi + xsi).values + kwargs = {axis_dim["X"]: xi + xsi, axis_dim["Y"]: yi + eta, axis_dim["Z"]: zi + zeta} + return field.data.interp(time=t, **kwargs).values def UXPiecewiseConstantFace( diff --git a/parcels/xgrid.py b/parcels/xgrid.py index 81391ee818..63a50c0f47 100644 --- a/parcels/xgrid.py +++ b/parcels/xgrid.py @@ -304,6 +304,38 @@ def _fpoint_info(self): return axis_position_mapping + def get_axis_dim_mapping(self, dims: list[str]) -> dict[_XGRID_AXES, str]: + """ + Maps xarray dimension names to their corresponding axis (X, Y, Z). + + WARNING: This API is unstable and subject to change in future versions. + + Parameters + ---------- + dims : list[str] + List of xarray dimension names + + Returns + ------- + dict[_XGRID_AXES, str] + Dictionary mapping axes (X, Y, Z) to their corresponding dimension names + + Examples + -------- + >>> grid.get_axis_dim_mapping(['time', 'lat', 'lon']) + {'Y': 'lat', 'X': 'lon'} + + Notes + ----- + Only returns mappings for spatial axes (X, Y, Z) that are present in the grid. + """ + result = {} + for dim in dims: + axis = get_axis_from_dim_name(self.xgcm_grid.axes, dim) + if axis in self.axes: # Only include spatial axes (X, Y, Z) + result[cast(_XGRID_AXES, axis)] = dim + return result + def get_axis_from_dim_name(axes: _XGCM_AXES, dim: str) -> _XGCM_AXIS_DIRECTION | None: """For a given dimension name in a grid, returns the direction axis it is on.""" diff --git a/tests/v4/test_interpolation.py b/tests/v4/test_interpolation.py new file mode 100644 index 0000000000..f385bbdb11 --- /dev/null +++ b/tests/v4/test_interpolation.py @@ -0,0 +1,102 @@ +import itertools + +import numpy as np +import xarray as xr + +from parcels.application_kernels.interpolation import XTriCurviLinear +from parcels.field import Field +from parcels.xgcm import Grid +from parcels.xgrid import XGrid + + +def get_unit_square_ds(): + T, Z, Y, X = 2, 2, 2, 2 + TIME = xr.date_range("2000", "2001", T) + + _, data_z, data_y, data_x = np.meshgrid( + np.zeros(T), + np.linspace(0, 1, Z), + np.linspace(0, 1, Y), + np.linspace(0, 1, X), + indexing="ij", + ) + + return xr.Dataset( + { + "0 to 1 in X": (["time", "ZG", "YG", "XG"], data_x), + "0 to 1 in Y": (["time", "ZG", "YG", "XG"], data_y), + "0 to 1 in Z": (["time", "ZG", "YG", "XG"], data_z), + "0 to 1 in X (T-points)": (["time", "ZC", "YC", "XC"], data_x + 0.5), + "0 to 1 in Y (T-points)": (["time", "ZC", "YC", "XC"], data_y + 0.5), + "0 to 1 in Z (T-points)": (["time", "ZC", "YC", "XC"], data_z + 0.5), + "0 to 1 in X (U velocity C-grid points)": (["time", "ZC", "YC", "XG"], data_x), + "0 to 1 in Y (V velocity C-grid points)": (["time", "ZC", "YG", "XC"], data_y), + }, + coords={ + "XG": ( + ["XG"], + np.arange(0, X), + {"axis": "X", "c_grid_axis_shift": -0.5}, + ), + "XC": (["XC"], np.arange(0, X) + 0.5, {"axis": "X"}), + "YG": ( + ["YG"], + np.arange(0, Y), + {"axis": "Y", "c_grid_axis_shift": -0.5}, + ), + "YC": ( + ["YC"], + np.arange(0, Y) + 0.5, + {"axis": "Y"}, + ), + "ZG": ( + ["ZG"], + np.arange(Z), + {"axis": "Z", "c_grid_axis_shift": -0.5}, + ), + "ZC": ( + ["ZC"], + np.arange(Z) + 0.5, + {"axis": "Z"}, + ), + "lon": (["XG"], np.arange(0, X)), + "lat": (["YG"], np.arange(0, Y)), + "depth": (["ZG"], np.arange(Z)), + "time": (["time"], TIME, {"axis": "T"}), + }, + ) + + +def test_XTriRectiLinear_interpolation(): + ds = get_unit_square_ds() + grid = XGrid(Grid(ds)) + field = Field("test", ds["0 to 1 in X"], grid=grid, interp_method=XTriCurviLinear) + left = field.time_interval.left + + epsilon = 1e-6 + N = 4 + + # Interpolate wrt. items on f-points + for x, y, z in itertools.product(np.linspace(0 + epsilon, 1 - epsilon, N), repeat=3): + assert np.isclose(x, field.eval(left, z, y, x)), f"Failed for x={x}, y={y}, z={z}" + + field = Field("test", ds["0 to 1 in Y"], grid=grid, interp_method=XTriCurviLinear) + for x, y, z in itertools.product(np.linspace(0 + epsilon, 1 - epsilon, N), repeat=3): + assert np.isclose(y, field.eval(left, z, y, x)), f"Failed for x={x}, y={y}, z={z}" + + field = Field("test", ds["0 to 1 in Z"], grid=grid, interp_method=XTriCurviLinear) + for x, y, z in itertools.product(np.linspace(0 + epsilon, 1 - epsilon, N), repeat=3): + assert np.isclose(z, field.eval(left, z, y, x)), f"Failed for x={x}, y={y}, z={z}" + + # Interpolate wrt. items on T-points + field = Field("test", ds["0 to 1 in X (T-points)"], grid=grid, interp_method=XTriCurviLinear) + for x, y, z in itertools.product(np.linspace(0.5 + epsilon, 1 - epsilon, N), repeat=3): + assert np.isclose(x, field.eval(left, z, y, x)), f"Failed for x={x}, y={y}, z={z}" + + field = Field("test", ds["0 to 1 in Y (T-points)"], grid=grid, interp_method=XTriCurviLinear) + for x, y, z in itertools.product(np.linspace(0.5 + epsilon, 1 - epsilon, N), repeat=3): + assert np.isclose(y, field.eval(left, z, y, x)), f"Failed for x={x}, y={y}, z={z}" + + field = Field("test", ds["0 to 1 in Z (T-points)"], grid=grid, interp_method=XTriCurviLinear) + for x, y, z in itertools.product(np.linspace(0.5 + epsilon, 1 - epsilon, N), repeat=3): + assert np.isclose(z, field.eval(left, z, y, x)), f"Failed for x={x}, y={y}, z={z}"