Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions parcels/application_kernels/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,63 @@

if TYPE_CHECKING:
from parcels.uxgrid import _UXGRID_AXES
from parcels.xgrid import _XGRID_AXES

__all__ = [
"UXPiecewiseConstantFace",
"UXPiecewiseLinearNode",
]


def XTriCurviLinear(
field: Field,
ti: int,
position: dict[_XGRID_AXES, tuple[int, float | np.ndarray]],
tau: np.float32 | np.float64,
t: np.float32 | np.float64,
z: np.float32 | np.float64,
y: np.float32 | np.float64,
x: np.float32 | np.float64,
):
"""Trilinear interpolation on a curvilinear grid."""
xi, xsi = position["X"]
yi, eta = position["Y"]
zi, zeta = position["Z"]
data = field.data
axis_dim = field.grid.get_axis_dim_mapping(field.data.dims)

return (
(
(1 - xsi) * (1 - eta) * data.isel({axis_dim["Y"]: yi, axis_dim["X"]: xi})
+ xsi * (1 - eta) * data.isel({axis_dim["Y"]: yi, axis_dim["X"]: xi + 1})
+ xsi * eta * data.isel({axis_dim["Y"]: yi + 1, axis_dim["X"]: xi + 1})
+ (1 - xsi) * eta * data.isel({axis_dim["Y"]: yi + 1, axis_dim["X"]: xi})
)
.interp(time=t, **{axis_dim["Z"]: zi + zeta})
.values
)


def XTriRectiLinear(
field: Field,
ti: int,
position: dict[_XGRID_AXES, tuple[int, float | np.ndarray]],
tau: np.float32 | np.float64,
t: np.float32 | np.float64,
z: np.float32 | np.float64,
y: np.float32 | np.float64,
x: np.float32 | np.float64,
):
"""Trilinear interpolation on a rectilinear grid."""
axis_dim = field.grid.get_axis_dim_mapping(field.data.dims)

xi, xsi = position["X"]
yi, eta = position["Y"]
zi, zeta = position["Z"]
kwargs = {axis_dim["X"]: xi + xsi, axis_dim["Y"]: yi + eta, axis_dim["Z"]: zi + zeta}
return field.data.interp(time=t, **kwargs).values


def UXPiecewiseConstantFace(
field: Field,
ti: int,
Expand Down
32 changes: 32 additions & 0 deletions parcels/xgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,38 @@ def _fpoint_info(self):

return axis_position_mapping

def get_axis_dim_mapping(self, dims: list[str]) -> dict[_XGRID_AXES, str]:
"""
Maps xarray dimension names to their corresponding axis (X, Y, Z).

WARNING: This API is unstable and subject to change in future versions.

Parameters
----------
dims : list[str]
List of xarray dimension names

Returns
-------
dict[_XGRID_AXES, str]
Dictionary mapping axes (X, Y, Z) to their corresponding dimension names

Examples
--------
>>> grid.get_axis_dim_mapping(['time', 'lat', 'lon'])
{'Y': 'lat', 'X': 'lon'}

Notes
-----
Only returns mappings for spatial axes (X, Y, Z) that are present in the grid.
"""
result = {}
for dim in dims:
axis = get_axis_from_dim_name(self.xgcm_grid.axes, dim)
if axis in self.axes: # Only include spatial axes (X, Y, Z)
result[cast(_XGRID_AXES, axis)] = dim
return result


def get_axis_from_dim_name(axes: _XGCM_AXES, dim: str) -> _XGCM_AXIS_DIRECTION | None:
"""For a given dimension name in a grid, returns the direction axis it is on."""
Expand Down
102 changes: 102 additions & 0 deletions tests/v4/test_interpolation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import itertools

import numpy as np
import xarray as xr

from parcels.application_kernels.interpolation import XTriCurviLinear
from parcels.field import Field
from parcels.xgcm import Grid
from parcels.xgrid import XGrid


def get_unit_square_ds():
T, Z, Y, X = 2, 2, 2, 2
TIME = xr.date_range("2000", "2001", T)

_, data_z, data_y, data_x = np.meshgrid(
np.zeros(T),
np.linspace(0, 1, Z),
np.linspace(0, 1, Y),
np.linspace(0, 1, X),
indexing="ij",
)

return xr.Dataset(
{
"0 to 1 in X": (["time", "ZG", "YG", "XG"], data_x),
"0 to 1 in Y": (["time", "ZG", "YG", "XG"], data_y),
"0 to 1 in Z": (["time", "ZG", "YG", "XG"], data_z),
"0 to 1 in X (T-points)": (["time", "ZC", "YC", "XC"], data_x + 0.5),
"0 to 1 in Y (T-points)": (["time", "ZC", "YC", "XC"], data_y + 0.5),
"0 to 1 in Z (T-points)": (["time", "ZC", "YC", "XC"], data_z + 0.5),
"0 to 1 in X (U velocity C-grid points)": (["time", "ZC", "YC", "XG"], data_x),
"0 to 1 in Y (V velocity C-grid points)": (["time", "ZC", "YG", "XC"], data_y),
},
coords={
"XG": (
["XG"],
np.arange(0, X),
{"axis": "X", "c_grid_axis_shift": -0.5},
),
"XC": (["XC"], np.arange(0, X) + 0.5, {"axis": "X"}),
"YG": (
["YG"],
np.arange(0, Y),
{"axis": "Y", "c_grid_axis_shift": -0.5},
),
"YC": (
["YC"],
np.arange(0, Y) + 0.5,
{"axis": "Y"},
),
"ZG": (
["ZG"],
np.arange(Z),
{"axis": "Z", "c_grid_axis_shift": -0.5},
),
"ZC": (
["ZC"],
np.arange(Z) + 0.5,
{"axis": "Z"},
),
"lon": (["XG"], np.arange(0, X)),
"lat": (["YG"], np.arange(0, Y)),
"depth": (["ZG"], np.arange(Z)),
"time": (["time"], TIME, {"axis": "T"}),
},
)


def test_XTriRectiLinear_interpolation():
ds = get_unit_square_ds()
grid = XGrid(Grid(ds))
field = Field("test", ds["0 to 1 in X"], grid=grid, interp_method=XTriCurviLinear)
left = field.time_interval.left

epsilon = 1e-6
N = 4

# Interpolate wrt. items on f-points
for x, y, z in itertools.product(np.linspace(0 + epsilon, 1 - epsilon, N), repeat=3):
assert np.isclose(x, field.eval(left, z, y, x)), f"Failed for x={x}, y={y}, z={z}"

field = Field("test", ds["0 to 1 in Y"], grid=grid, interp_method=XTriCurviLinear)
for x, y, z in itertools.product(np.linspace(0 + epsilon, 1 - epsilon, N), repeat=3):
assert np.isclose(y, field.eval(left, z, y, x)), f"Failed for x={x}, y={y}, z={z}"

field = Field("test", ds["0 to 1 in Z"], grid=grid, interp_method=XTriCurviLinear)
for x, y, z in itertools.product(np.linspace(0 + epsilon, 1 - epsilon, N), repeat=3):
assert np.isclose(z, field.eval(left, z, y, x)), f"Failed for x={x}, y={y}, z={z}"

# Interpolate wrt. items on T-points
field = Field("test", ds["0 to 1 in X (T-points)"], grid=grid, interp_method=XTriCurviLinear)
for x, y, z in itertools.product(np.linspace(0.5 + epsilon, 1 - epsilon, N), repeat=3):
assert np.isclose(x, field.eval(left, z, y, x)), f"Failed for x={x}, y={y}, z={z}"

field = Field("test", ds["0 to 1 in Y (T-points)"], grid=grid, interp_method=XTriCurviLinear)
for x, y, z in itertools.product(np.linspace(0.5 + epsilon, 1 - epsilon, N), repeat=3):
assert np.isclose(y, field.eval(left, z, y, x)), f"Failed for x={x}, y={y}, z={z}"

field = Field("test", ds["0 to 1 in Z (T-points)"], grid=grid, interp_method=XTriCurviLinear)
for x, y, z in itertools.product(np.linspace(0.5 + epsilon, 1 - epsilon, N), repeat=3):
assert np.isclose(z, field.eval(left, z, y, x)), f"Failed for x={x}, y={y}, z={z}"
Loading