Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ dependencies: #! Keep in sync with [tool.pixi.dependencies] in pyproject.toml

# Docs
- ipython
- numpydoc
- numpydoc!=1.9.0
- nbsphinx
- sphinx
- pandoc
Expand Down
44 changes: 43 additions & 1 deletion parcels/_datasets/structured/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def _unrolled_cone_curvilinear_grid():

datasets = {
"2d_left_rotated": _rotated_curvilinear_grid(),
"ds_2d_left": xr.Dataset(
"ds_2d_left": xr.Dataset( # MITgcm indexing style
{
"data_g": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
"data_c": (["time", "ZC", "YC", "XC"], np.random.rand(T, Z, Y, X)),
Expand Down Expand Up @@ -178,5 +178,47 @@ def _unrolled_cone_curvilinear_grid():
"time": (["time"], TIME, {"axis": "T"}),
},
),
"ds_2d_right": xr.Dataset( # NEMO indexing style
{
"data_g": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
"data_c": (["time", "ZC", "YC", "XC"], np.random.rand(T, Z, Y, X)),
"U (A grid)": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
"V (A grid)": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)),
"U (C grid)": (["time", "ZG", "YC", "XG"], np.random.rand(T, Z, Y, X)),
"V (C grid)": (["time", "ZG", "YG", "XC"], np.random.rand(T, Z, Y, X)),
},
coords={
"XG": (
["XG"],
2 * np.pi / X * np.arange(0, X),
{"axis": "X", "c_grid_axis_shift": 0.5},
),
"XC": (["XC"], 2 * np.pi / X * (np.arange(0, X) - 0.5), {"axis": "X"}),
"YG": (
["YG"],
2 * np.pi / (Y) * np.arange(0, Y),
{"axis": "Y", "c_grid_axis_shift": 0.5},
),
"YC": (
["YC"],
2 * np.pi / (Y) * (np.arange(0, Y) - 0.5),
{"axis": "Y"},
),
"ZG": (
["ZG"],
np.arange(Z),
{"axis": "Z", "c_grid_axis_shift": 0.5},
),
"ZC": (
["ZC"],
np.arange(Z) - 0.5,
{"axis": "Z"},
),
"lon": (["XG"], 2 * np.pi / X * np.arange(0, X)),
"lat": (["YG"], 2 * np.pi / (Y) * np.arange(0, Y)),
"depth": (["ZG"], np.arange(Z)),
"time": (["time"], TIME, {"axis": "T"}),
},
),
"2d_left_unrolled_cone": _unrolled_cone_curvilinear_grid(),
}
86 changes: 86 additions & 0 deletions parcels/xgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,53 @@ def get_axis_dim(self, axis: _XGRID_AXES) -> int:

return get_cell_count_along_dim(self.xgcm_grid.axes[axis])

def localize(self, position: dict[_XGRID_AXES, tuple[int, float]], dims: list[str]) -> dict[str, tuple[int, float]]:
"""
Uses the grid context (i.e., the staggering of the grid) to convert a position relative
to the F-points in the grid to a position relative to the staggered grid the array
of interest is defined on.

Uses dimensions of the DataArray to determine the staggered grid.

WARNING: This API is unstable and subject to change in future versions.

Parameters
----------
position : dict
A mapping of the axis to a tuple of (index, barycentric coordinate) for the
F-points in the grid.
dims : list[str]
A list of dimension names that the DataArray is defined on. This is used to determine
the staggering of the grid and which axis each dimension corresponds to.

Returns
-------
dict[str, tuple[int, float]]
A mapping of the dimension names to a tuple of (index, barycentric coordinate) for
the staggered grid the DataArray is defined on.

Example
-------
>>> position = {'X': (5, 0.51), 'Y': (
10, 0.25), 'Z': (3, 0.75)}
>>> dims = ['time', 'depth', 'YC', 'XC']
>>> grid.localize(position, dims)
{'depth': (3, 0.75), 'YC': (9, 0.75), 'XC': (5, 0.01)}
"""
axis_to_var = {get_axis_from_dim_name(self.xgcm_grid.axes, dim): dim for dim in dims}
var_positions = {
axis: get_xgcm_position_from_dim_name(self.xgcm_grid.axes, dim) for axis, dim in axis_to_var.items()
}
return {
axis_to_var[axis]: _convert_center_pos_to_fpoint(
index=index,
bcoord=bcoord,
xgcm_position=var_positions[axis],
f_points_xgcm_position=self._fpoint_info[axis],
)
for axis, (index, bcoord) in position.items()
}

@property
def _z4d(self) -> Literal[0, 1]:
"""
Expand Down Expand Up @@ -185,6 +232,20 @@ def search(self, z, y, x, ei=None):

raise NotImplementedError("Searching in >2D lon/lat arrays is not implemented yet.")

@cached_property
def _fpoint_info(self):
"""Returns a mapping of the spatial axes in the Grid to their XGCM positions."""
xgcm_axes = self.xgcm_grid.axes
f_point_positions = ["left", "right", "inner", "outer"]
axis_position_mapping = {}
for axis in self.axes:
coords = xgcm_axes[axis].coords
edge_positions = [pos for pos in coords.keys() if pos in f_point_positions]
assert len(edge_positions) == 1, f"Axis {axis} has multiple edge positions: {edge_positions}"
axis_position_mapping[axis] = edge_positions[0]

return axis_position_mapping


def get_axis_from_dim_name(axes: _XGCM_AXES, dim: str) -> _XGCM_AXIS_DIRECTION | None:
"""For a given dimension name in a grid, returns the direction axis it is on."""
Expand Down Expand Up @@ -337,3 +398,28 @@ def _search_1d_array(
i = np.argmin(arr <= x) - 1
bcoord = (x - arr[i]) / (arr[i + 1] - arr[i])
return i, bcoord


def _convert_center_pos_to_fpoint(
*, index: int, bcoord: float, xgcm_position: _XGCM_AXIS_POSITION, f_points_xgcm_position: _XGCM_AXIS_POSITION
) -> tuple[int, float]:
"""Converts a physical position relative to the cell edges defined in the grid to be relative to the center point.

This is used to "localize" a position to be relative to the staggered grid at which the field is defined, so that
it can be easily interpolated.

This also handles different model input cell edges and centers are staggered in different directions (e.g., with NEMO and MITgcm).
"""
if xgcm_position != "center": # Data is already defined on the F points
return index, bcoord

bcoord = bcoord - 0.5
if bcoord < 0:
bcoord += 1.0
index -= 1

# Correct relative to the f-point position
if f_points_xgcm_position in ["inner", "right"]:
index += 1

return index, bcoord
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ pre_commit = "*"

# Docs
ipython = "*"
numpydoc = "*"
numpydoc = "!=1.9.0"
nbsphinx = "*"
sphinx = "*"
pandoc = "*"
Expand Down
21 changes: 21 additions & 0 deletions tests/v4/test_datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from parcels._datasets.structured.generic import datasets
from parcels.xgcm import Grid


def test_left_indexed_dataset():
"""Checks that 'ds_2d_left' is right indexed on all variables."""
ds = datasets["ds_2d_left"]
grid = Grid(ds)

for _axis_name, axis in grid.axes.items():
for pos, _dim_name in axis.coords.items():
assert pos in ["left", "center"]


def test_right_indexed_dataset():
"""Checks that 'ds_2d_right' is right indexed on all variables."""
ds = datasets["ds_2d_right"]
grid = Grid(ds)
for _axis_name, axis in grid.axes.items():
for pos, _dim_name in axis.coords.items():
assert pos in ["center", "right"]
54 changes: 54 additions & 0 deletions tests/v4/test_xgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,57 @@ def test_search_1d_array(array, x, expected_xi, expected_xsi):
xi, xsi = _search_1d_array(array, x)
assert xi == expected_xi
assert np.isclose(xsi, expected_xsi)


@pytest.mark.parametrize(
"grid, da_name, expected",
[
pytest.param(
XGrid(xgcm.Grid(datasets["ds_2d_left"], periodic=False)),
"U (C grid)",
{
"XG": (np.int64(0), np.float64(0.0)),
"YC": (np.int64(-1), np.float64(0.5)),
"ZG": (np.int64(0), np.float64(0.0)),
},
id="MITgcm indexing style U (C grid)",
),
pytest.param(
XGrid(xgcm.Grid(datasets["ds_2d_left"], periodic=False)),
"V (C grid)",
{
"XC": (np.int64(-1), np.float64(0.5)),
"YG": (np.int64(0), np.float64(0.0)),
"ZG": (np.int64(0), np.float64(0.0)),
},
id="MITgcm indexing style V (C grid)",
),
pytest.param(
XGrid(xgcm.Grid(datasets["ds_2d_right"], periodic=False)),
"U (C grid)",
{
"XG": (np.int64(0), np.float64(0.0)),
"YC": (np.int64(0), np.float64(0.5)),
"ZG": (np.int64(0), np.float64(0.0)),
},
id="NEMO indexing style U (C grid)",
),
pytest.param(
XGrid(xgcm.Grid(datasets["ds_2d_right"], periodic=False)),
"V (C grid)",
{
"XC": (np.int64(0), np.float64(0.5)),
"YG": (np.int64(0), np.float64(0.0)),
"ZG": (np.int64(0), np.float64(0.0)),
},
id="NEMO indexing style V (C grid)",
),
],
)
def test_xgrid_localize_zero_position(grid, da_name, expected):
"""Test localize function using left and right datasets."""
position = grid.search(0, 0, 0)
da = grid.xgcm_grid._ds[da_name]

local_position = grid.localize(position, da.dims)
assert local_position == expected, f"Expected {expected}, got {local_position}"
Loading