From acacb96d476a3a0b64b53fa8634549434284677f Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Fri, 11 Jul 2025 16:52:21 +0200 Subject: [PATCH 1/6] Add grid localization to array of interest --- parcels/xgrid.py | 49 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/parcels/xgrid.py b/parcels/xgrid.py index c3c110485..86cff982a 100644 --- a/parcels/xgrid.py +++ b/parcels/xgrid.py @@ -129,6 +129,23 @@ def get_axis_dim(self, axis: _XGRID_AXES) -> int: return get_cell_count_along_dim(self.xgcm_grid.axes[axis]) + def localize(self, position: dict[_XGRID_AXES, tuple[int, float]], dims: list[str]) -> dict[str, tuple[int, float]]: + """ + Uses the grid context (i.e., the staggering of the grid) to convert a position relative + to the F-points in the grid to a position relative to the dimensions of the array + of interest. + """ + axis_to_var = {get_axis_from_dim_name(self.xgcm_grid.axes, dim): dim for dim in dims} + var_positions = { + axis: get_xgcm_position_from_dim_name(self.xgcm_grid.axes, dim) for axis, dim in axis_to_var.items() + } + return { + axis_to_var[axis]: _convert_center_pos_to_fpoint( + index=index, bcoord=bcoord, position=var_positions[axis], f_points_position=self._fpoint_info[axis] + ) + for axis, (index, bcoord) in position.items() + } + @property def _z4d(self) -> Literal[0, 1]: """ @@ -185,6 +202,19 @@ def search(self, z, y, x, ei=None): raise NotImplementedError("Searching in >2D lon/lat arrays is not implemented yet.") + @cached_property + def _fpoint_info(self): + xgcm_axes = self.xgcm_grid.axes + f_point_positions = ["left", "right", "inner", "outer"] + axis_position_mapping = {} + for axis in self.axes: + coords = xgcm_axes[axis].coords + edge_positions = list(filter(lambda x: x in f_point_positions, coords.keys())) + assert len(edge_positions) == 1, f"Axis {axis} has multiple edge positions: {edge_positions}" + axis_position_mapping[axis] = edge_positions[0] + + return axis_position_mapping + def get_axis_from_dim_name(axes: _XGCM_AXES, dim: str) -> _XGCM_AXIS_DIRECTION | None: """For a given dimension name in a grid, returns the direction axis it is on.""" @@ -337,3 +367,22 @@ def _search_1d_array( i = np.argmin(arr <= x) - 1 bcoord = (x - arr[i]) / (arr[i + 1] - arr[i]) return i, bcoord + + +def _convert_center_pos_to_fpoint( + *, index: int, bcoord: float, position: _XGCM_AXIS_POSITION, f_points_position: _XGCM_AXIS_POSITION +) -> tuple[int, float]: + """Converts a position relative to the center point along an axis to a reposition relative to the cell edges.""" + if position != "center": + return index, bcoord + + bcoord = bcoord - 0.5 + if bcoord < 0: + bcoord += 1.0 + index -= 1 + + # Correct relative to the f-point position + if f_points_position in ["inner", "right"]: + index += 1 + + return index, bcoord From 7c4e9b3bd62b5cfd1a5ad9fbef7ca38c2c6068bb Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Mon, 14 Jul 2025 11:57:42 +0200 Subject: [PATCH 2/6] Update vairable naming and docstrings --- parcels/xgrid.py | 51 ++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 43 insertions(+), 8 deletions(-) diff --git a/parcels/xgrid.py b/parcels/xgrid.py index 86cff982a..9e78d9dad 100644 --- a/parcels/xgrid.py +++ b/parcels/xgrid.py @@ -132,8 +132,33 @@ def get_axis_dim(self, axis: _XGRID_AXES) -> int: def localize(self, position: dict[_XGRID_AXES, tuple[int, float]], dims: list[str]) -> dict[str, tuple[int, float]]: """ Uses the grid context (i.e., the staggering of the grid) to convert a position relative - to the F-points in the grid to a position relative to the dimensions of the array - of interest. + to the F-points in the grid to a position relative to the staggered grid the array + of interest is defined on. + + Uses dimensions of the DataArray to determine the staggered grid. + + Parameters + ---------- + position : dict + A mapping of the axis to a tuple of (index, barycentric coordinate) for the + F-points in the grid. + dims : list[str] + A list of dimension names that the DataArray is defined on. This is used to determine + the staggering of the grid and which axis each dimension corresponds to. + + Returns + ------- + dict[str, tuple[int, float]] + A mapping of the dimension names to a tuple of (index, barycentric coordinate) for + the staggered grid the DataArray is defined on. + + Example + ------- + >>> position = {'X': (5, 0.51), 'Y': ( + 10, 0.25), 'Z': (3, 0.75)} + >>> dims = ['time', 'depth', 'YC', 'XC'] + >>> grid.localize(position, dims) + {'depth': (3, 0.75), 'YC': (9, 0.75), 'XC': (5, 0.01)} """ axis_to_var = {get_axis_from_dim_name(self.xgcm_grid.axes, dim): dim for dim in dims} var_positions = { @@ -141,7 +166,10 @@ def localize(self, position: dict[_XGRID_AXES, tuple[int, float]], dims: list[st } return { axis_to_var[axis]: _convert_center_pos_to_fpoint( - index=index, bcoord=bcoord, position=var_positions[axis], f_points_position=self._fpoint_info[axis] + index=index, + bcoord=bcoord, + xgcm_position=var_positions[axis], + f_points_xgcm_position=self._fpoint_info[axis], ) for axis, (index, bcoord) in position.items() } @@ -204,12 +232,13 @@ def search(self, z, y, x, ei=None): @cached_property def _fpoint_info(self): + """Returns a mapping of the spatial axes in the Grid to their XGCM positions.""" xgcm_axes = self.xgcm_grid.axes f_point_positions = ["left", "right", "inner", "outer"] axis_position_mapping = {} for axis in self.axes: coords = xgcm_axes[axis].coords - edge_positions = list(filter(lambda x: x in f_point_positions, coords.keys())) + edge_positions = [pos for pos in coords.keys() if pos in f_point_positions] assert len(edge_positions) == 1, f"Axis {axis} has multiple edge positions: {edge_positions}" axis_position_mapping[axis] = edge_positions[0] @@ -370,10 +399,16 @@ def _search_1d_array( def _convert_center_pos_to_fpoint( - *, index: int, bcoord: float, position: _XGCM_AXIS_POSITION, f_points_position: _XGCM_AXIS_POSITION + *, index: int, bcoord: float, xgcm_position: _XGCM_AXIS_POSITION, f_points_xgcm_position: _XGCM_AXIS_POSITION ) -> tuple[int, float]: - """Converts a position relative to the center point along an axis to a reposition relative to the cell edges.""" - if position != "center": + """Converts a physical position relative to the cell edges defined in the grid to be relative to the center point. + + This is used to "localize" a position to be relative to the staggered grid at which the field is defined, so that + it can be easily interpolated. + + This also handles different model input cell edges and centers are staggered in different directions (e.g., with NEMO and MITgcm). + """ + if xgcm_position != "center": # Data is already defined on the F points return index, bcoord bcoord = bcoord - 0.5 @@ -382,7 +417,7 @@ def _convert_center_pos_to_fpoint( index -= 1 # Correct relative to the f-point position - if f_points_position in ["inner", "right"]: + if f_points_xgcm_position in ["inner", "right"]: index += 1 return index, bcoord From 89dcfd7adce4dde5dfb6988d2a146633865e5294 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Mon, 14 Jul 2025 12:09:42 +0200 Subject: [PATCH 3/6] Add ds_2d_right dataset to structured.generic With tests --- parcels/_datasets/structured/generic.py | 42 +++++++++++++++++++++++++ tests/v4/test_datasets.py | 21 +++++++++++++ 2 files changed, 63 insertions(+) create mode 100644 tests/v4/test_datasets.py diff --git a/parcels/_datasets/structured/generic.py b/parcels/_datasets/structured/generic.py index 2432f079c..0ad09093d 100644 --- a/parcels/_datasets/structured/generic.py +++ b/parcels/_datasets/structured/generic.py @@ -178,5 +178,47 @@ def _unrolled_cone_curvilinear_grid(): "time": (["time"], TIME, {"axis": "T"}), }, ), + "ds_2d_right": xr.Dataset( + { + "data_g": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)), + "data_c": (["time", "ZC", "YC", "XC"], np.random.rand(T, Z, Y, X)), + "U (A grid)": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)), + "V (A grid)": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)), + "U (C grid)": (["time", "ZG", "YC", "XG"], np.random.rand(T, Z, Y, X)), + "V (C grid)": (["time", "ZG", "YG", "XC"], np.random.rand(T, Z, Y, X)), + }, + coords={ + "XG": ( + ["XG"], + 2 * np.pi / X * np.arange(0, X), + {"axis": "X", "c_grid_axis_shift": 0.5}, + ), + "XC": (["XC"], 2 * np.pi / X * (np.arange(0, X) - 0.5), {"axis": "X"}), + "YG": ( + ["YG"], + 2 * np.pi / (Y) * np.arange(0, Y), + {"axis": "Y", "c_grid_axis_shift": 0.5}, + ), + "YC": ( + ["YC"], + 2 * np.pi / (Y) * (np.arange(0, Y) - 0.5), + {"axis": "Y"}, + ), + "ZG": ( + ["ZG"], + np.arange(Z), + {"axis": "Z", "c_grid_axis_shift": 0.5}, + ), + "ZC": ( + ["ZC"], + np.arange(Z) - 0.5, + {"axis": "Z"}, + ), + "lon": (["XG"], 2 * np.pi / X * np.arange(0, X)), + "lat": (["YG"], 2 * np.pi / (Y) * np.arange(0, Y)), + "depth": (["ZG"], np.arange(Z)), + "time": (["time"], TIME, {"axis": "T"}), + }, + ), "2d_left_unrolled_cone": _unrolled_cone_curvilinear_grid(), } diff --git a/tests/v4/test_datasets.py b/tests/v4/test_datasets.py new file mode 100644 index 000000000..e5201e58b --- /dev/null +++ b/tests/v4/test_datasets.py @@ -0,0 +1,21 @@ +from parcels._datasets.structured.generic import datasets +from parcels.xgcm import Grid + + +def test_left_indexed_dataset(): + """Checks that 'ds_2d_left' is right indexed on all variables.""" + ds = datasets["ds_2d_left"] + grid = Grid(ds) + + for _axis_name, axis in grid.axes.items(): + for pos, _dim_name in axis.coords.items(): + assert pos in ["left", "center"] + + +def test_right_indexed_dataset(): + """Checks that 'ds_2d_right' is right indexed on all variables.""" + ds = datasets["ds_2d_right"] + grid = Grid(ds) + for _axis_name, axis in grid.axes.items(): + for pos, _dim_name in axis.coords.items(): + assert pos in ["center", "right"] From e2e549cc26e67168cfab0fb36c681093412e11d5 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Mon, 14 Jul 2025 13:30:10 +0200 Subject: [PATCH 4/6] Add test_xgrid_localize --- parcels/_datasets/structured/generic.py | 4 +- tests/v4/test_xgrid.py | 54 +++++++++++++++++++++++++ 2 files changed, 56 insertions(+), 2 deletions(-) diff --git a/parcels/_datasets/structured/generic.py b/parcels/_datasets/structured/generic.py index 0ad09093d..1b1cd7d81 100644 --- a/parcels/_datasets/structured/generic.py +++ b/parcels/_datasets/structured/generic.py @@ -136,7 +136,7 @@ def _unrolled_cone_curvilinear_grid(): datasets = { "2d_left_rotated": _rotated_curvilinear_grid(), - "ds_2d_left": xr.Dataset( + "ds_2d_left": xr.Dataset( # MITgcm indexing style { "data_g": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)), "data_c": (["time", "ZC", "YC", "XC"], np.random.rand(T, Z, Y, X)), @@ -178,7 +178,7 @@ def _unrolled_cone_curvilinear_grid(): "time": (["time"], TIME, {"axis": "T"}), }, ), - "ds_2d_right": xr.Dataset( + "ds_2d_right": xr.Dataset( # NEMO indexing style { "data_g": (["time", "ZG", "YG", "XG"], np.random.rand(T, Z, Y, X)), "data_c": (["time", "ZC", "YC", "XC"], np.random.rand(T, Z, Y, X)), diff --git a/tests/v4/test_xgrid.py b/tests/v4/test_xgrid.py index 2af1e1f01..210ca6b74 100644 --- a/tests/v4/test_xgrid.py +++ b/tests/v4/test_xgrid.py @@ -145,3 +145,57 @@ def test_search_1d_array(array, x, expected_xi, expected_xsi): xi, xsi = _search_1d_array(array, x) assert xi == expected_xi assert np.isclose(xsi, expected_xsi) + + +@pytest.mark.parametrize( + "grid, da_name, expected", + [ + pytest.param( + XGrid(xgcm.Grid(datasets["ds_2d_left"], periodic=False)), + "U (C grid)", + { + "XG": (np.int64(0), np.float64(0.0)), + "YC": (np.int64(-1), np.float64(0.5)), + "ZG": (np.int64(0), np.float64(0.0)), + }, + id="MITgcm indexing style U (C grid)", + ), + pytest.param( + XGrid(xgcm.Grid(datasets["ds_2d_left"], periodic=False)), + "V (C grid)", + { + "XC": (np.int64(-1), np.float64(0.5)), + "YG": (np.int64(0), np.float64(0.0)), + "ZG": (np.int64(0), np.float64(0.0)), + }, + id="MITgcm indexing style V (C grid)", + ), + pytest.param( + XGrid(xgcm.Grid(datasets["ds_2d_right"], periodic=False)), + "U (C grid)", + { + "XG": (np.int64(0), np.float64(0.0)), + "YC": (np.int64(0), np.float64(0.5)), + "ZG": (np.int64(0), np.float64(0.0)), + }, + id="NEMO indexing style U (C grid)", + ), + pytest.param( + XGrid(xgcm.Grid(datasets["ds_2d_right"], periodic=False)), + "V (C grid)", + { + "XC": (np.int64(0), np.float64(0.5)), + "YG": (np.int64(0), np.float64(0.0)), + "ZG": (np.int64(0), np.float64(0.0)), + }, + id="NEMO indexing style V (C grid)", + ), + ], +) +def test_xgrid_localize_zero_position(grid, da_name, expected): + """Test localize function using left and right datasets.""" + position = grid.search(0, 0, 0) + da = grid.xgcm_grid._ds[da_name] + + local_position = grid.localize(position, da.dims) + assert local_position == expected, f"Expected {expected}, got {local_position}" From 551db5ee8aea40877ce0b0a9d30f8e3d367b4e09 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 16 Jul 2025 11:38:39 +0200 Subject: [PATCH 5/6] Disallow numpydoc 1.9.0 https://github.com/numpy/numpydoc/issues/638 --- environment.yml | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/environment.yml b/environment.yml index 4b8694f0c..2a4704054 100644 --- a/environment.yml +++ b/environment.yml @@ -41,7 +41,7 @@ dependencies: #! Keep in sync with [tool.pixi.dependencies] in pyproject.toml # Docs - ipython - - numpydoc + - numpydoc!=1.9.0 - nbsphinx - sphinx - pandoc diff --git a/pyproject.toml b/pyproject.toml index 8fde067a2..318f5851f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -96,7 +96,7 @@ pre_commit = "*" # Docs ipython = "*" -numpydoc = "*" +numpydoc = "!=1.9.0" nbsphinx = "*" sphinx = "*" pandoc = "*" From a9f2b601a1651b6ac45a61fbfcbddf110ccc8297 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 16 Jul 2025 11:42:59 +0200 Subject: [PATCH 6/6] Update XGrid localize docstring --- parcels/xgrid.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/parcels/xgrid.py b/parcels/xgrid.py index 9e78d9dad..2cfda507e 100644 --- a/parcels/xgrid.py +++ b/parcels/xgrid.py @@ -137,6 +137,8 @@ def localize(self, position: dict[_XGRID_AXES, tuple[int, float]], dims: list[st Uses dimensions of the DataArray to determine the staggered grid. + WARNING: This API is unstable and subject to change in future versions. + Parameters ---------- position : dict