Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
b428cd2
Further cleaning the handling of grid searching errors
erikvansebille Sep 4, 2025
0bce72e
Removing while-loop in _search_indices_curvilinear_2d
erikvansebille Sep 4, 2025
9035be9
Switch hash table build to use source grid overlaps; add point-in-cel…
fluidnumerics-joe Sep 7, 2025
f9e044b
Set nemo mesh type to spherical
fluidnumerics-joe Sep 7, 2025
0cbbbc9
Vectorize hash table initialization
fluidnumerics-joe Sep 9, 2025
59eecf7
Add point_in_cell method to query call
fluidnumerics-joe Sep 9, 2025
4c9661c
Vectorize final query loop in j,i array construction
fluidnumerics-joe Sep 9, 2025
9a3707f
Move _barycentric_coordinates function to uxgrid
fluidnumerics-joe Sep 9, 2025
3c7507c
Drop uint64 to uint32 to save memory (and dilated bits don't need uin…
fluidnumerics-joe Sep 9, 2025
0be54db
Drop temp variables to int32
fluidnumerics-joe Sep 9, 2025
dfbb991
Return coordinates from PIC search and remove redundant coordinate ca…
fluidnumerics-joe Sep 10, 2025
95cb593
Add logic for using provided guess
fluidnumerics-joe Sep 10, 2025
46c988e
Set failed query indices to GRID_SEARCH_ERROR
fluidnumerics-joe Sep 10, 2025
8d62e20
Remove "reconstruct"
fluidnumerics-joe Sep 10, 2025
dc06f63
Add _point_in_cell private method to spatialhash
fluidnumerics-joe Sep 10, 2025
9645ea3
Merge remote-tracking branch 'origin/v4-dev' into curvilinear_index_s…
fluidnumerics-joe Sep 10, 2025
29e5fe0
Merge remote-tracking branch 'origin/v4-dev' into curvilinear_index_s…
fluidnumerics-joe Sep 10, 2025
50f0e85
Remove initial None set for spatial hash
fluidnumerics-joe Sep 12, 2025
6df7035
Merge branch 'v4-dev' into curvilinear_index_search_without_while_loop
fluidnumerics-joe Sep 12, 2025
031eb46
Merge branch 'v4-dev' into curvilinear_index_search_without_while_loop
fluidnumerics-joe Sep 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 57 additions & 71 deletions parcels/_index_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,17 @@

import numpy as np

from parcels._typing import Mesh
from parcels.tools.statuscodes import (
_raise_grid_searching_error,
_raise_time_extrapolation_error,
)
from parcels.tools.statuscodes import _raise_time_extrapolation_error

if TYPE_CHECKING:
from parcels.xgrid import XGrid

from .field import Field


GRID_SEARCH_ERROR = -3


def _search_time_index(field: Field, time: datetime):
"""Find and return the index and relative coordinate in the time array associated with a given time.

Expand All @@ -40,13 +39,7 @@ def _search_time_index(field: Field, time: datetime):
return np.atleast_1d(tau), np.atleast_1d(ti)


def _search_indices_curvilinear_2d(
grid: XGrid, y: float, x: float, yi_guess: int | None = None, xi_guess: int | None = None
): # TODO fix typing instructions to make clear that y, x etc need to be ndarrays
yi, xi = yi_guess, xi_guess
if yi is None or xi is None:
yi, xi = grid.get_spatial_hash().query(y, x)

def curvilinear_point_in_cell(grid, y: np.ndarray, x: np.ndarray, yi: np.ndarray, xi: np.ndarray):
xsi = eta = -1.0 * np.ones(len(x), dtype=float)
invA = np.array(
[
Expand All @@ -56,67 +49,60 @@ def _search_indices_curvilinear_2d(
[1, -1, 1, -1],
]
)
maxIterSearch = 1e6
it = 0
tol = 1.0e-10

# # ! Error handling for out of bounds
# TODO: Re-enable in some capacity
# if x < field.lonlat_minmax[0] or x > field.lonlat_minmax[1]:
# if grid.lon[0, 0] < grid.lon[0, -1]:
# _raise_grid_searching_error(y, x)
# elif x < grid.lon[0, 0] and x > grid.lon[0, -1]: # This prevents from crashing in [160, -160]
# _raise_grid_searching_error(z, y, x)

# if y < field.lonlat_minmax[2] or y > field.lonlat_minmax[3]:
# _raise_grid_searching_error(z, y, x)

while np.any(xsi < -tol) or np.any(xsi > 1 + tol) or np.any(eta < -tol) or np.any(eta > 1 + tol):
px = np.array([grid.lon[yi, xi], grid.lon[yi, xi + 1], grid.lon[yi + 1, xi + 1], grid.lon[yi + 1, xi]])

py = np.array([grid.lat[yi, xi], grid.lat[yi, xi + 1], grid.lat[yi + 1, xi + 1], grid.lat[yi + 1, xi]])
a = np.dot(invA, px)
b = np.dot(invA, py)

aa = a[3] * b[2] - a[2] * b[3]
bb = a[3] * b[0] - a[0] * b[3] + a[1] * b[2] - a[2] * b[1] + x * b[3] - y * a[3]
cc = a[1] * b[0] - a[0] * b[1] + x * b[1] - y * a[1]

det2 = bb * bb - 4 * aa * cc
with np.errstate(divide="ignore", invalid="ignore"):
det = np.where(det2 > 0, np.sqrt(det2), eta)

eta = np.where(abs(aa) < 1e-12, -cc / bb, np.where(det2 > 0, (-bb + det) / (2 * aa), eta))

xsi = np.where(
abs(a[1] + a[3] * eta) < 1e-12,
((y - py[0]) / (py[1] - py[0]) + (y - py[3]) / (py[2] - py[3])) * 0.5,
(x - a[0] - a[2] * eta) / (a[1] + a[3] * eta),
)

xi = np.where(xsi < -tol, xi - 1, np.where(xsi > 1 + tol, xi + 1, xi))
yi = np.where(eta < -tol, yi - 1, np.where(eta > 1 + tol, yi + 1, yi))

(yi, xi) = _reconnect_bnd_indices(yi, xi, grid.ydim, grid.xdim, grid._mesh)
it += 1
if it > maxIterSearch:
print(f"Correct cell not found after {maxIterSearch} iterations")
_raise_grid_searching_error(0, y, x)
xsi = np.where(xsi < 0.0, 0.0, np.where(xsi > 1.0, 1.0, xsi))
eta = np.where(eta < 0.0, 0.0, np.where(eta > 1.0, 1.0, eta))

if np.any((xsi < 0) | (xsi > 1) | (eta < 0) | (eta > 1)):
_raise_grid_searching_error(y, x)
return (yi, eta, xi, xsi)

px = np.array([grid.lon[yi, xi], grid.lon[yi, xi + 1], grid.lon[yi + 1, xi + 1], grid.lon[yi + 1, xi]])
py = np.array([grid.lat[yi, xi], grid.lat[yi, xi + 1], grid.lat[yi + 1, xi + 1], grid.lat[yi + 1, xi]])

a, b = np.dot(invA, px), np.dot(invA, py)
aa = a[3] * b[2] - a[2] * b[3]
bb = a[3] * b[0] - a[0] * b[3] + a[1] * b[2] - a[2] * b[1] + x * b[3] - y * a[3]
cc = a[1] * b[0] - a[0] * b[1] + x * b[1] - y * a[1]
det2 = bb * bb - 4 * aa * cc

def _reconnect_bnd_indices(yi: int, xi: int, ydim: int, xdim: int, mesh: Mesh):
xi = np.where(xi < 0, (xdim - 2) if mesh == "spherical" else 0, xi)
xi = np.where(xi > xdim - 2, 0 if mesh == "spherical" else (xdim - 2), xi)
with np.errstate(divide="ignore", invalid="ignore"):
det = np.where(det2 > 0, np.sqrt(det2), eta)
eta = np.where(abs(aa) < 1e-12, -cc / bb, np.where(det2 > 0, (-bb + det) / (2 * aa), eta))

xi = np.where(yi > ydim - 2, xdim - xi if mesh == "spherical" else xi, xi)
xsi = np.where(
abs(a[1] + a[3] * eta) < 1e-12,
((y - py[0]) / (py[1] - py[0]) + (y - py[3]) / (py[2] - py[3])) * 0.5,
(x - a[0] - a[2] * eta) / (a[1] + a[3] * eta),
)

yi = np.where(yi < 0, 0, yi)
yi = np.where(yi > ydim - 2, ydim - 2, yi)
is_in_cell = np.where((xsi >= 0) & (xsi <= 1) & (eta >= 0) & (eta <= 1), 1, 0)

return yi, xi
return is_in_cell, np.column_stack((xsi, eta))


def _search_indices_curvilinear_2d(
grid: XGrid, y: np.ndarray, x: np.ndarray, yi_guess: np.ndarray | None = None, xi_guess: np.ndarray | None = None
):
yi_guess = np.array(yi_guess)
xi_guess = np.array(xi_guess)
xi = np.full(len(x), GRID_SEARCH_ERROR, dtype=np.int32)
yi = np.full(len(y), GRID_SEARCH_ERROR, dtype=np.int32)
if np.any(xi_guess):
# If an initial guess is provided, we first perform a point in cell check for all guessed indices
is_in_cell, coords = curvilinear_point_in_cell(grid, y, x, yi_guess, xi_guess)
y_check = y[is_in_cell == 0]
x_check = x[is_in_cell == 0]
zero_indices = np.where(is_in_cell == 0)[0]
else:
# Otherwise, we need to check all points
y_check = y
x_check = x
coords = -1.0 * np.ones((len(y), 2), dtype=np.float32)
zero_indices = np.arange(len(y))

# If there are any points that were not found in the first step, we query the spatial hash for those points
if len(zero_indices) > 0:
yi_q, xi_q, coords_q = grid.get_spatial_hash().query(y_check, x_check)
# Only those points that were not found in the first step are updated
coords[zero_indices, :] = coords_q
yi[zero_indices] = yi_q
xi[zero_indices] = xi_q

xsi = coords[:, 0]
eta = coords[:, 1]

return (yi, eta, xi, xsi)
36 changes: 20 additions & 16 deletions parcels/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from parcels.uxgrid import UxGrid
from parcels.xgrid import LEFT_OUT_OF_BOUNDS, RIGHT_OUT_OF_BOUNDS, XGrid, _transpose_xfield_data_to_tzyx

from ._index_search import _search_time_index
from ._index_search import GRID_SEARCH_ERROR, _search_time_index

__all__ = ["Field", "VectorField"]

Expand Down Expand Up @@ -341,21 +341,25 @@ def __getitem__(self, key):

def _update_particle_states_position(particle, position):
"""Update the particle states based on the position dictionary."""
if particle and "X" in position: # TODO also support uxgrid search
particle.state = np.maximum(
np.where(position["X"][0] < 0, StatusCode.ErrorOutOfBounds, particle.state), particle.state
)
particle.state = np.maximum(
np.where(position["Y"][0] < 0, StatusCode.ErrorOutOfBounds, particle.state), particle.state
)
particle.state = np.maximum(
np.where(position["Z"][0] == RIGHT_OUT_OF_BOUNDS, StatusCode.ErrorOutOfBounds, particle.state),
particle.state,
)
particle.state = np.maximum(
np.where(position["Z"][0] == LEFT_OUT_OF_BOUNDS, StatusCode.ErrorThroughSurface, particle.state),
particle.state,
)
if particle: # TODO also support uxgrid search
for dim in ["X", "Y"]:
if dim in position:
particle.state = np.maximum(
np.where(position[dim][0] == -1, StatusCode.ErrorOutOfBounds, particle.state), particle.state
)
particle.state = np.maximum(
np.where(position[dim][0] == GRID_SEARCH_ERROR, StatusCode.ErrorGridSearching, particle.state),
particle.state,
)
if "Z" in position:
particle.state = np.maximum(
np.where(position["Z"][0] == RIGHT_OUT_OF_BOUNDS, StatusCode.ErrorOutOfBounds, particle.state),
particle.state,
)
particle.state = np.maximum(
np.where(position["Z"][0] == LEFT_OUT_OF_BOUNDS, StatusCode.ErrorThroughSurface, particle.state),
particle.state,
)


def _update_particle_states_interp_value(particle, value):
Expand Down
Loading
Loading