From b428cd2ccd5d394d9dbff58fefbd6646ec8ce991 Mon Sep 17 00:00:00 2001 From: Erik van Sebille Date: Thu, 4 Sep 2025 09:15:49 +0200 Subject: [PATCH 01/16] Further cleaning the handling of grid searching errors --- parcels/_index_search.py | 16 +++++++--------- parcels/field.py | 36 ++++++++++++++++++++---------------- 2 files changed, 27 insertions(+), 25 deletions(-) diff --git a/parcels/_index_search.py b/parcels/_index_search.py index dde987261..a141620d7 100644 --- a/parcels/_index_search.py +++ b/parcels/_index_search.py @@ -6,10 +6,7 @@ import numpy as np from parcels._typing import Mesh -from parcels.tools.statuscodes import ( - _raise_grid_searching_error, - _raise_time_extrapolation_error, -) +from parcels.tools.statuscodes import _raise_time_extrapolation_error if TYPE_CHECKING: from parcels.xgrid import XGrid @@ -17,6 +14,9 @@ from .field import Field +GRID_SEARCH_ERROR = -3 + + def _search_time_index(field: Field, time: datetime): """Find and return the index and relative coordinate in the time array associated with a given time. @@ -99,12 +99,10 @@ def _search_indices_curvilinear_2d( it += 1 if it > maxIterSearch: print(f"Correct cell not found after {maxIterSearch} iterations") - _raise_grid_searching_error(0, y, x) - xsi = np.where(xsi < 0.0, 0.0, np.where(xsi > 1.0, 1.0, xsi)) - eta = np.where(eta < 0.0, 0.0, np.where(eta > 1.0, 1.0, eta)) - if np.any((xsi < 0) | (xsi > 1) | (eta < 0) | (eta > 1)): - _raise_grid_searching_error(y, x) + # checking if xsi or eta is outside [0, 1] + xi = np.where(xsi < 0, GRID_SEARCH_ERROR, np.where(xsi > 1, GRID_SEARCH_ERROR, xi)) + yi = np.where(eta < 0, GRID_SEARCH_ERROR, np.where(eta > 1, GRID_SEARCH_ERROR, yi)) return (yi, eta, xi, xsi) diff --git a/parcels/field.py b/parcels/field.py index 360e84393..736a82033 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -25,7 +25,7 @@ from parcels.uxgrid import UxGrid from parcels.xgrid import LEFT_OUT_OF_BOUNDS, RIGHT_OUT_OF_BOUNDS, XGrid, _transpose_xfield_data_to_tzyx -from ._index_search import _search_time_index +from ._index_search import GRID_SEARCH_ERROR, _search_time_index __all__ = ["Field", "VectorField"] @@ -353,21 +353,25 @@ def __getitem__(self, key): def _update_particle_states_position(particle, position): """Update the particle states based on the position dictionary.""" - if particle and "X" in position: # TODO also support uxgrid search - particle.state = np.maximum( - np.where(position["X"][0] < 0, StatusCode.ErrorOutOfBounds, particle.state), particle.state - ) - particle.state = np.maximum( - np.where(position["Y"][0] < 0, StatusCode.ErrorOutOfBounds, particle.state), particle.state - ) - particle.state = np.maximum( - np.where(position["Z"][0] == RIGHT_OUT_OF_BOUNDS, StatusCode.ErrorOutOfBounds, particle.state), - particle.state, - ) - particle.state = np.maximum( - np.where(position["Z"][0] == LEFT_OUT_OF_BOUNDS, StatusCode.ErrorThroughSurface, particle.state), - particle.state, - ) + if particle: # TODO also support uxgrid search + for dim in ["X", "Y"]: + if dim in position: + particle.state = np.maximum( + np.where(position[dim][0] == -1, StatusCode.ErrorOutOfBounds, particle.state), particle.state + ) + particle.state = np.maximum( + np.where(position[dim][0] == GRID_SEARCH_ERROR, StatusCode.ErrorGridSearching, particle.state), + particle.state, + ) + if "Z" in position: + particle.state = np.maximum( + np.where(position["Z"][0] == RIGHT_OUT_OF_BOUNDS, StatusCode.ErrorOutOfBounds, particle.state), + particle.state, + ) + particle.state = np.maximum( + np.where(position["Z"][0] == LEFT_OUT_OF_BOUNDS, StatusCode.ErrorThroughSurface, particle.state), + particle.state, + ) def _update_particle_states_interp_value(particle, value): From 0bce72e2195a79ecce30ffa608a1ca6356ead845 Mon Sep 17 00:00:00 2001 From: Erik van Sebille Date: Thu, 4 Sep 2025 13:39:59 +0200 Subject: [PATCH 02/16] Removing while-loop in _search_indices_curvilinear_2d Note that this breaks some of the unit tests --- parcels/_index_search.py | 41 ++++++++++------------------------------ 1 file changed, 10 insertions(+), 31 deletions(-) diff --git a/parcels/_index_search.py b/parcels/_index_search.py index a141620d7..79244ded1 100644 --- a/parcels/_index_search.py +++ b/parcels/_index_search.py @@ -56,33 +56,17 @@ def _search_indices_curvilinear_2d( [1, -1, 1, -1], ] ) - maxIterSearch = 1e6 - it = 0 - tol = 1.0e-10 - # # ! Error handling for out of bounds - # TODO: Re-enable in some capacity - # if x < field.lonlat_minmax[0] or x > field.lonlat_minmax[1]: - # if grid.lon[0, 0] < grid.lon[0, -1]: - # _raise_grid_searching_error(y, x) - # elif x < grid.lon[0, 0] and x > grid.lon[0, -1]: # This prevents from crashing in [160, -160] - # _raise_grid_searching_error(z, y, x) + px = np.array([grid.lon[yi, xi], grid.lon[yi, xi + 1], grid.lon[yi + 1, xi + 1], grid.lon[yi + 1, xi]]) + py = np.array([grid.lat[yi, xi], grid.lat[yi, xi + 1], grid.lat[yi + 1, xi + 1], grid.lat[yi + 1, xi]]) - # if y < field.lonlat_minmax[2] or y > field.lonlat_minmax[3]: - # _raise_grid_searching_error(z, y, x) + a, b = np.dot(invA, px), np.dot(invA, py) + aa = a[3] * b[2] - a[2] * b[3] + bb = a[3] * b[0] - a[0] * b[3] + a[1] * b[2] - a[2] * b[1] + x * b[3] - y * a[3] + cc = a[1] * b[0] - a[0] * b[1] + x * b[1] - y * a[1] + det2 = bb * bb - 4 * aa * cc - while np.any(xsi < -tol) or np.any(xsi > 1 + tol) or np.any(eta < -tol) or np.any(eta > 1 + tol): - px = np.array([grid.lon[yi, xi], grid.lon[yi, xi + 1], grid.lon[yi + 1, xi + 1], grid.lon[yi + 1, xi]]) - - py = np.array([grid.lat[yi, xi], grid.lat[yi, xi + 1], grid.lat[yi + 1, xi + 1], grid.lat[yi + 1, xi]]) - a = np.dot(invA, px) - b = np.dot(invA, py) - - aa = a[3] * b[2] - a[2] * b[3] - bb = a[3] * b[0] - a[0] * b[3] + a[1] * b[2] - a[2] * b[1] + x * b[3] - y * a[3] - cc = a[1] * b[0] - a[0] * b[1] + x * b[1] - y * a[1] - - det2 = bb * bb - 4 * aa * cc + with np.errstate(divide="ignore", invalid="ignore"): det = np.where(det2 > 0, np.sqrt(det2), eta) eta = np.where(abs(aa) < 1e-12, -cc / bb, np.where(det2 > 0, (-bb + det) / (2 * aa), eta)) @@ -92,17 +76,12 @@ def _search_indices_curvilinear_2d( (x - a[0] - a[2] * eta) / (a[1] + a[3] * eta), ) - xi = np.where(xsi < -tol, xi - 1, np.where(xsi > 1 + tol, xi + 1, xi)) - yi = np.where(eta < -tol, yi - 1, np.where(eta > 1 + tol, yi + 1, yi)) - - (yi, xi) = _reconnect_bnd_indices(yi, xi, grid.ydim, grid.xdim, grid._mesh) - it += 1 - if it > maxIterSearch: - print(f"Correct cell not found after {maxIterSearch} iterations") + (yi, xi) = _reconnect_bnd_indices(yi, xi, grid.ydim, grid.xdim, grid._mesh) # checking if xsi or eta is outside [0, 1] xi = np.where(xsi < 0, GRID_SEARCH_ERROR, np.where(xsi > 1, GRID_SEARCH_ERROR, xi)) yi = np.where(eta < 0, GRID_SEARCH_ERROR, np.where(eta > 1, GRID_SEARCH_ERROR, yi)) + return (yi, eta, xi, xsi) From 9035be91c7a0f582bb068cab316fc491f9e40a0c Mon Sep 17 00:00:00 2001 From: Joe Schoonover Date: Sun, 7 Sep 2025 18:33:27 -0400 Subject: [PATCH 03/16] Switch hash table build to use source grid overlaps; add point-in-cell check The query method now requires a point in cell check method to be sent in as an argument --- parcels/_index_search.py | 37 ++++- parcels/spatialhash.py | 288 ++++++++++++++++++++++++++------------- 2 files changed, 231 insertions(+), 94 deletions(-) diff --git a/parcels/_index_search.py b/parcels/_index_search.py index 79244ded1..87f12650b 100644 --- a/parcels/_index_search.py +++ b/parcels/_index_search.py @@ -40,12 +40,47 @@ def _search_time_index(field: Field, time: datetime): return np.atleast_1d(tau), np.atleast_1d(ti) +def curvilinear_point_in_cell(grid, y: np.ndarray, x: np.ndarray, yi: np.ndarray, xi: np.ndarray): + xsi = eta = -1.0 * np.ones(len(x), dtype=float) + invA = np.array( + [ + [1, 0, 0, 0], + [-1, 1, 0, 0], + [-1, 0, 0, 1], + [1, -1, 1, -1], + ] + ) + + px = np.array([grid.lon[yi, xi], grid.lon[yi, xi + 1], grid.lon[yi + 1, xi + 1], grid.lon[yi + 1, xi]]) + py = np.array([grid.lat[yi, xi], grid.lat[yi, xi + 1], grid.lat[yi + 1, xi + 1], grid.lat[yi + 1, xi]]) + + a, b = np.dot(invA, px), np.dot(invA, py) + aa = a[3] * b[2] - a[2] * b[3] + bb = a[3] * b[0] - a[0] * b[3] + a[1] * b[2] - a[2] * b[1] + x * b[3] - y * a[3] + cc = a[1] * b[0] - a[0] * b[1] + x * b[1] - y * a[1] + det2 = bb * bb - 4 * aa * cc + + with np.errstate(divide="ignore", invalid="ignore"): + det = np.where(det2 > 0, np.sqrt(det2), eta) + eta = np.where(abs(aa) < 1e-12, -cc / bb, np.where(det2 > 0, (-bb + det) / (2 * aa), eta)) + + xsi = np.where( + abs(a[1] + a[3] * eta) < 1e-12, + ((y - py[0]) / (py[1] - py[0]) + (y - py[3]) / (py[2] - py[3])) * 0.5, + (x - a[0] - a[2] * eta) / (a[1] + a[3] * eta), + ) + + is_in_cell = np.where((xsi >= 0) & (xsi <= 1) & (eta >= 0) & (eta <= 1), 1, 0) + + return is_in_cell + + def _search_indices_curvilinear_2d( grid: XGrid, y: float, x: float, yi_guess: int | None = None, xi_guess: int | None = None ): # TODO fix typing instructions to make clear that y, x etc need to be ndarrays yi, xi = yi_guess, xi_guess if yi is None or xi is None: - yi, xi = grid.get_spatial_hash().query(y, x) + yi, xi = grid.get_spatial_hash().query(y, x, curvilinear_point_in_cell) xsi = eta = -1.0 * np.ones(len(x), dtype=float) invA = np.array( diff --git a/parcels/spatialhash.py b/parcels/spatialhash.py index 76447bc23..0e0dcd2f2 100644 --- a/parcels/spatialhash.py +++ b/parcels/spatialhash.py @@ -23,12 +23,14 @@ def __init__( self, grid, reconstruct=False, + bitwidth=1023, ): # TODO : Enforce grid to be an instance of parcels.xgrid.XGrid # Currently, this is not done due to circular import with parcels.xgrid self._source_grid = grid self.reconstruct = reconstruct + self._bitwidth = bitwidth # Max integer to use per coordinate in quantization (10 bits = 0..1023) if self._source_grid._mesh == "spherical": # Boundaries of the hash grid are the unit cube @@ -69,9 +71,12 @@ def __init__( axis=-1, ) # Compute centroid locations of each cells - self._xc = np.mean(_xbound, axis=-1) - self._yc = np.mean(_ybound, axis=-1) - self._zc = np.mean(_zbound, axis=-1) + self._xlow = np.min(_xbound, axis=-1) + self._xhigh = np.max(_xbound, axis=-1) + self._ylow = np.min(_ybound, axis=-1) + self._yhigh = np.max(_ybound, axis=-1) + self._zlow = np.min(_zbound, axis=-1) + self._zhigh = np.max(_zbound, axis=-1) else: # Boundaries of the hash grid are the bounding box of the source grid @@ -104,9 +109,12 @@ def __init__( axis=-1, ) # Compute centroid locations of each cells - self._xc = np.mean(_xbound, axis=-1) - self._yc = np.mean(_ybound, axis=-1) - self._zc = np.zeros_like(self._xc) + self._xlow = np.min(_xbound, axis=-1) + self._xhigh = np.max(_xbound, axis=-1) + self._ylow = np.min(_ybound, axis=-1) + self._yhigh = np.max(_ybound, axis=-1) + self._zlow = np.zeros_like(self._xlow) + self._zhigh = np.zeros_like(self._xlow) # Generate the mapping from the hash indices to unstructured grid elements self._hash_table = None @@ -117,17 +125,72 @@ def _initialize_hash_table(self): which faces overlap with which hash cells """ if self._hash_table is None or self.reconstruct: - j, i = np.indices(self._xc.shape) # Get the indices of the curvilinear grid + # j, i = np.indices(self._xlow.shape) # Get the indices of the curvilinear grid + + # Quantize the bounding box in each direction + xqlow, yqlow, zqlow = quantize_coordinates( + self._xlow, + self._ylow, + self._zlow, + self._xmin, + self._xmax, + self._ymin, + self._ymax, + self._zmin, + self._zmax, + self._bitwidth, + ) - morton_codes = _encode_morton3d( - self._xc, self._yc, self._zc, self._xmin, self._xmax, self._ymin, self._ymax, self._zmin, self._zmax + xqhigh, yqhigh, zqhigh = quantize_coordinates( + self._xhigh, + self._yhigh, + self._zhigh, + self._xmin, + self._xmax, + self._ymin, + self._ymax, + self._zmin, + self._zmax, + self._bitwidth, ) - ## Prepare quick lookup (hash) table for relating i,j indices to morton codes - # Sort i,j indices by morton code - order = np.argsort(morton_codes.ravel()) - morton_codes_sorted = morton_codes.ravel()[order] - i_sorted = i.ravel()[order] - j_sorted = j.ravel()[order] + xqlow = xqlow.ravel() + yqlow = yqlow.ravel() + zqlow = zqlow.ravel() + xqhigh = xqhigh.ravel() + yqhigh = yqhigh.ravel() + zqhigh = zqhigh.ravel() + nx = xqhigh - xqlow + 1 + ny = yqhigh - yqlow + 1 + nz = zqhigh - zqlow + 1 + num_hash_per_face = nx * ny * nz + total_hash_entries = np.sum(num_hash_per_face) + + morton_codes = np.zeros(total_hash_entries, dtype=np.uint32) + + # Compute the j, i indices corresponding to each hash entry + nface = np.size(self._xlow) + face_ids = np.repeat(np.arange(nface, dtype=np.int64), num_hash_per_face) + offsets = np.concatenate(([0], np.cumsum(num_hash_per_face))).astype(np.int64) + + for k in range(len(num_hash_per_face)): + if num_hash_per_face[k] == 0: + continue + start, end = offsets[k], offsets[k + 1] + # Local sizes + nxk, nyk, nzk = int(nx[k]), int(ny[k]), int(nz[k]) + + # Build the Cartesian product + xq_block = xqlow[k] + np.repeat(np.arange(nxk), nyk * nzk) + yq_block = yqlow[k] + np.tile(np.repeat(np.arange(nyk), nzk), nxk) + zq_block = zqlow[k] + np.tile(np.arange(nzk), nxk * nyk) + + morton_codes[start:end] = _encode_quantized_morton3d(xq_block, yq_block, zq_block) + + # Sort face indices by morton code + order = np.argsort(morton_codes) + morton_codes_sorted = morton_codes[order] + face_sorted = face_ids[order] + j_sorted, i_sorted = np.unravel_index(face_sorted, self._xlow.shape) # Get a list of unique morton codes and their corresponding starts and counts (CSR format) keys, starts, counts = np.unique(morton_codes_sorted, return_index=True, return_counts=True) @@ -140,11 +203,7 @@ def _initialize_hash_table(self): } return hash_table - def query( - self, - y, - x, - ): + def query(self, y, x, point_in_cell): """ Queries the hash table and finds the closes face in the source grid for each coordinate pair. @@ -167,10 +226,6 @@ def query( i = self._hash_table["i"] j = self._hash_table["j"] - xc = self._xc - yc = self._yc - zc = self._zc - y = np.asarray(y) x = np.asarray(x) if self._source_grid._mesh == "spherical": @@ -190,10 +245,10 @@ def query( num_queries = query_codes.size # Locate each query in the unique key array - pos = np.searchsorted(keys, query_codes) # pos is shape (N,) + pos = np.searchsorted(keys, query_codes) # pos is shape (num_queries,) - # Valid hits: inside range with finite query coordinates - valid = (pos < len(keys)) & np.isfinite(x) & np.isfinite(y) + # Valid hits: inside range with finite query coordinates and query codes give exact morton code match. + valid = (pos < len(keys)) & np.isfinite(x) & np.isfinite(y) & (query_codes == keys[pos]) # Pre-allocate i and j indices of the best match for each query # Default values to -1 (no match case) @@ -201,69 +256,65 @@ def query( i_best = np.full(num_queries, -1, dtype=np.int64) # How many matches each query has; hit_counts[i] is the number of hits for query i - hit_counts = np.where(valid, counts[pos], 0).astype(np.int64) # has shape (N,) + hit_counts = np.where(valid, counts[pos], 0).astype(np.int64) # has shape (num_queries,) if hit_counts.sum() == 0: return (j_best.reshape(query_codes.shape), i_best.reshape(query_codes.shape)) - # CSR-style offsets (prefix sum), total number of hits - offsets = np.empty(hit_counts.size + 1, dtype=np.int64) - offsets[0] = 0 - np.cumsum(hit_counts, out=offsets[1:]) - total = int(offsets[-1]) + # Now, for each query, we need to gather the candidate (j,i) indices from the hash table + # Each j,i pair needs to be repeated hit_counts[i] times, only when there are hits. + + # Boolean array for keeping track of which queries have candidates + has_hits = hit_counts > 0 # shape (num_queries,), True for queries that had candidates - # Now, we need to create some quick lookup arrays that give us the list of positions in the hash table - # that correspond to each query. - # Create a quick lookup array that maps each element of all the valid queries (with repeats) to its query index - q_index_for_elem = np.repeat(np.arange(num_queries, dtype=np.int64), hit_counts) # This has shape (total,) + # A quick lookup array that maps all candindates back to its query index + q_index_for_candidate = np.repeat( + np.arange(num_queries, dtype=np.int64), hit_counts + ) # shape (hit_counts.sum(),) + # Map all candidates to positions in the hash table + hash_positions = pos[q_index_for_candidate] # shape (hit_counts.sum(),) - # For each element, compute its "intra-group" offset (0..hits_i-1). - intra = np.arange(total, dtype=np.int64) - np.repeat(offsets[:-1], hit_counts) + # Now that we have the positions in the hash table for each table, we can gather the (j,i) pairs for each candidate + # We do this in a vectorized way by using a CSR-like approach + # starts[pos[q_index_for_candidate]] gives the starting point in the hash table for each candidate + # hit_counts gives the number of candidates for each query - # starts[pos[q_index_for_elem]] + intra gives a list of positions in the hash table that we can + # We need to build an array that gives the offset within each query's candidates + offsets = np.concatenate(([0], np.cumsum(hit_counts))).astype(np.int64) # shape (num_queries+1,) + total = int(offsets[-1]) # total number of candidates across all queries + + # Now, for each candidate, we need a simple array that tells us its "local candidate id" within its query + # This way, we can easily take the starts[pos[q_index_for_candidate]] and add this local id to get the absolute index + # We calculate this by computing the "global candidate number" (0..total-1) and subtracting the offsets of the corresponding query + # This gives us an array that goes from 0..hit_counts[i]-1 for each query i + intra = np.arange(total, dtype=np.int64) - np.repeat(offsets[:-1], hit_counts) # shape (hit_counts.sum(),) + + # starts[pos[q_index_for_candidate]] + intra gives a list of positions in the hash table that we can # use to quickly gather the (i,j) pairs for each query - source_idx = starts[pos[q_index_for_elem]].astype(np.int64) + intra + source_idx = starts[hash_positions].astype(np.int64) + intra - # Gather all (j,i) pairs in one shot + # Gather all candidate (j,i) pairs in one shot j_all = j[source_idx] i_all = i[source_idx] - # Segment-wise minima per query using reduceat - # For each query, we need to find the minimum distance. - if total == 1: - # Build absolute source index for the winning candidate in each query - start_for_q = np.where(valid, starts[pos], 0) # 0 is dummy for invalid queries - src_best = (start_for_q).astype(np.int64) - else: - # Gather centroid coordinates at those (j,i) - xc_all = xc[j_all, i_all] - yc_all = yc[j_all, i_all] - zc_all = zc[j_all, i_all] - - # Broadcast to flat (same as q_flat), then repeat per candidate - # This makes it easy to compute distances from the query points - # to the candidate points for minimization. - qx_all = np.repeat(qx.ravel(), hit_counts) - qy_all = np.repeat(qy.ravel(), hit_counts) - qz_all = np.repeat(qz.ravel(), hit_counts) - - # Squared distances for all candidates - dist_all = (xc_all - qx_all) ** 2 + (yc_all - qy_all) ** 2 + (zc_all - qz_all) ** 2 - - dmin_per_q = np.minimum.reduceat(dist_all, offsets[:-1]) - # To get argmin indices per query (without loops): - # Build a masked "within-index" array that is large unless it equals the segment-min. - big = np.iinfo(np.int64).max - within_masked = np.where(dist_all == np.repeat(dmin_per_q, hit_counts), intra, big) - argmin_within = np.minimum.reduceat(within_masked, offsets[:-1]) # first occurrence in ties - - # Build absolute source index for the winning candidate in each query - start_for_q = np.where(valid, starts[pos], 0) # 0 is dummy for invalid queries - src_best = (start_for_q + argmin_within).astype(np.int64) - - # Write outputs only for queries that had candidates - has_hits = hit_counts > 0 - j_best[has_hits] = j[src_best[has_hits]] - i_best[has_hits] = i[src_best[has_hits]] + # Now we need to construct arrays that repeats the y and x coordinates for each candidate + # to enable vectorized point-in-cell checks + y_rep = np.repeat(y, hit_counts) # shape (hit_counts.sum(),) + x_rep = np.repeat(x, hit_counts) # shape (hit_counts.sum(),) + + # For each query we perform a point in cell check. + is_in_face = point_in_cell(self._source_grid, y_rep, x_rep, j_all, i_all) + + # For each query that has hits, we need to find the first candidate that was inside the face + for q in range(num_queries): + if has_hits[q]: + # Masked array for the current query + mask = is_in_face[offsets[q] : offsets[q + 1]] + if mask.any(): + # Find the candidate face that contains the query point + argmin_within = mask.argmax() + # Store the result + j_best[q] = j_all[offsets[q] + argmin_within] + i_best[q] = i_all[offsets[q] + argmin_within] return (j_best.reshape(query_codes.shape), i_best.reshape(query_codes.shape)) @@ -370,15 +421,14 @@ def _dilate_bits(n): return n -def _encode_morton3d(x, y, z, xmin, xmax, ymin, ymax, zmin, zmax): +def quantize_coordinates(x, y, z, xmin, xmax, ymin, ymax, zmin, zmax, bitwidth=1023): """ - Quantize (x, y, z) to 10 bits each (0..1023), dilate the bits so there are - two zeros between successive bits, and interleave them into a 3D Morton code. + Normalize (x, y, z) to [0, 1] over their bounding box, then quantize to 10 bits each (0..1023). Parameters ---------- x, y, z : array_like - Input coordinates to encode. Can be scalars or arrays (broadcasting applies). + Input coordinates to quantize. Can be scalars or arrays (broadcasting applies). xmin, xmax : float Minimum and maximum bounds for x coordinate. ymin, ymax : float @@ -388,13 +438,8 @@ def _encode_morton3d(x, y, z, xmin, xmax, ymin, ymax, zmin, zmax): Returns ------- - code : ndarray, dtype=uint32 - The resulting Morton codes, same shape as the broadcasted input coordinates. - - Notes - ----- - - Works with scalars or NumPy arrays (broadcasting applies). - - Output is up to 30 bits returned as uint32. + xq, yq, zq : ndarray, dtype=uint32 + The quantized coordinates, each in range [0, 1023], same shape as the broadcasted input coordinates. """ # Convert inputs to ndarray for consistent dtype/ufunc behavior. x = np.asarray(x) @@ -415,9 +460,66 @@ def _encode_morton3d(x, y, z, xmin, xmax, ymin, ymax, zmin, zmax): # --- 2) Quantize to 10 bits (0..1023). --- # Multiply by 1023, round down, and clip to be safe against slight overshoot. - xq = np.clip((xn * 1023.0).astype(np.uint32), 0, 1023) - yq = np.clip((yn * 1023.0).astype(np.uint32), 0, 1023) - zq = np.clip((zn * 1023.0).astype(np.uint32), 0, 1023) + xq = np.clip((xn * bitwidth).astype(np.uint32), 0, bitwidth) + yq = np.clip((yn * bitwidth).astype(np.uint32), 0, bitwidth) + zq = np.clip((zn * bitwidth).astype(np.uint32), 0, bitwidth) + + return xq, yq, zq + + +def _encode_quantized_morton3d(xq, yq, zq): + xq = np.asarray(xq) + yq = np.asarray(yq) + zq = np.asarray(zq) + + # --- 3) Bit-dilate each 10-bit number so each bit is separated by two zeros. --- + # _dilate_bits maps: b9..b0 -> b9 0 0 b8 0 0 ... b0 0 0 + dx3 = _dilate_bits(xq).astype(np.uint64) + dy3 = _dilate_bits(yq).astype(np.uint64) + dz3 = _dilate_bits(zq).astype(np.uint64) + + # --- 4) Interleave the dilated bits into a single Morton code. --- + # Bit layout (from LSB upward): x0,y0,z0, x1,y1,z1, ..., x9,y9,z9 + # We shift z's bits by 2, y's by 1, x stays at 0, then OR them together. + # Cast to a wide type before shifting/OR to be safe when arrays are used. + code = (dz3 << 2) | (dy3 << 1) | dx3 + + # Since our compact type fits in 30 bits, uint32 is enough. + return code.astype(np.uint32) + + +def _encode_morton3d(x, y, z, xmin, xmax, ymin, ymax, zmin, zmax, bitwidth=1023): + """ + Quantize (x, y, z) to 10 bits each (0..1023), dilate the bits so there are + two zeros between successive bits, and interleave them into a 3D Morton code. + + Parameters + ---------- + x, y, z : array_like + Input coordinates to encode. Can be scalars or arrays (broadcasting applies). + xmin, xmax : float + Minimum and maximum bounds for x coordinate. + ymin, ymax : float + Minimum and maximum bounds for y coordinate. + zmin, zmax : float + Minimum and maximum bounds for z coordinate. + + Returns + ------- + code : ndarray, dtype=uint32 + The resulting Morton codes, same shape as the broadcasted input coordinates. + + Notes + ----- + - Works with scalars or NumPy arrays (broadcasting applies). + - Output is up to 30 bits returned as uint32. + """ + # Convert inputs to ndarray for consistent dtype/ufunc behavior. + x = np.asarray(x) + y = np.asarray(y) + z = np.asarray(z) + + xq, yq, zq = quantize_coordinates(x, y, z, xmin, xmax, ymin, ymax, zmin, zmax, bitwidth) # --- 3) Bit-dilate each 10-bit number so each bit is separated by two zeros. --- # _dilate_bits maps: b9..b0 -> b9 0 0 b8 0 0 ... b0 0 0 From f9e044b7b73b616932c22aa9ff177aec4cccdc73 Mon Sep 17 00:00:00 2001 From: Joe Schoonover Date: Sun, 7 Sep 2025 18:53:47 -0400 Subject: [PATCH 04/16] Set nemo mesh type to spherical This ensures overlaps are properly calculated across antimeridian --- tests/v4/test_index_search.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/v4/test_index_search.py b/tests/v4/test_index_search.py index 5606913fd..994fe22dc 100644 --- a/tests/v4/test_index_search.py +++ b/tests/v4/test_index_search.py @@ -63,7 +63,7 @@ def test_indexing_nemo_curvilinear(): {"glamf": "lon", "gphif": "lat", "z": "depth"} ) xgcm_grid = Grid(ds, coords={"X": {"left": "x"}, "Y": {"left": "y"}}, periodic=False) - grid = XGrid(xgcm_grid) + grid = XGrid(xgcm_grid, mesh="spherical") # Test points on the NEMO 1/4 degree curvilinear grid lats = np.array([-30, 0, 88]) From 0cbbbc9f8176666881bb281e9f6999434cb25508 Mon Sep 17 00:00:00 2001 From: Joe Schoonover Date: Mon, 8 Sep 2025 21:00:33 -0400 Subject: [PATCH 05/16] Vectorize hash table initialization --- parcels/spatialhash.py | 73 +++++++++++++++++++++++++++++++----------- 1 file changed, 55 insertions(+), 18 deletions(-) diff --git a/parcels/spatialhash.py b/parcels/spatialhash.py index 0e0dcd2f2..638284436 100644 --- a/parcels/spatialhash.py +++ b/parcels/spatialhash.py @@ -164,27 +164,63 @@ def _initialize_hash_table(self): nz = zqhigh - zqlow + 1 num_hash_per_face = nx * ny * nz total_hash_entries = np.sum(num_hash_per_face) - morton_codes = np.zeros(total_hash_entries, dtype=np.uint32) # Compute the j, i indices corresponding to each hash entry nface = np.size(self._xlow) face_ids = np.repeat(np.arange(nface, dtype=np.int64), num_hash_per_face) - offsets = np.concatenate(([0], np.cumsum(num_hash_per_face))).astype(np.int64) - - for k in range(len(num_hash_per_face)): - if num_hash_per_face[k] == 0: - continue - start, end = offsets[k], offsets[k + 1] - # Local sizes - nxk, nyk, nzk = int(nx[k]), int(ny[k]), int(nz[k]) - - # Build the Cartesian product - xq_block = xqlow[k] + np.repeat(np.arange(nxk), nyk * nzk) - yq_block = yqlow[k] + np.tile(np.repeat(np.arange(nyk), nzk), nxk) - zq_block = zqlow[k] + np.tile(np.arange(nzk), nxk * nyk) - - morton_codes[start:end] = _encode_quantized_morton3d(xq_block, yq_block, zq_block) + offsets = np.concatenate(([0], np.cumsum(num_hash_per_face))).astype(np.int64)[:-1] + + valid = num_hash_per_face != 0 + if not np.any(valid): + # nothing to do + pass + else: + # Grab only valid faces to avoid empty arrays + nx_v = np.asarray(nx[valid], dtype=np.int64) + ny_v = np.asarray(ny[valid], dtype=np.int64) + nz_v = np.asarray(nz[valid], dtype=np.int64) + xlow_v = np.asarray(xqlow[valid], dtype=np.int64) + ylow_v = np.asarray(yqlow[valid], dtype=np.int64) + zlow_v = np.asarray(zqlow[valid], dtype=np.int64) + starts_v = np.asarray(offsets[valid], dtype=np.int64) + + # Count of elements per valid face (should match num_hash_per_face[valid]) + counts = (nx_v * ny_v * nz_v).astype(np.int64) + total = int(counts.sum()) + + # Map each global element to its face and output position + start_for_elem = np.repeat(starts_v, counts) # shape (total,) + + # Intra-face linear index for each element (0..counts_i-1) + # Offsets per face within the concatenation of valid faces: + face_starts_local = np.cumsum(np.r_[0, counts[:-1]]) + intra = np.arange(total, dtype=np.int64) - np.repeat(face_starts_local, counts) + + # Derive (zi, yi, xi) from intra using per-face sizes + ny_nz = np.repeat(ny_v * nz_v, counts) + nz_rep = np.repeat(nz_v, counts) + + xi = intra // ny_nz + rem = intra % ny_nz + yi = rem // nz_rep + zi = rem % nz_rep + + # Add per-face lows + x0 = np.repeat(xlow_v, counts) + y0 = np.repeat(ylow_v, counts) + z0 = np.repeat(zlow_v, counts) + + xq = x0 + xi + yq = y0 + yi + zq = z0 + zi + + # Vectorized morton encode for all elements at once + codes_all = _encode_quantized_morton3d(xq, yq, zq) + + # Scatter into the preallocated output using computed absolute indices + out_idx = start_for_elem + intra + morton_codes[out_idx] = codes_all # Sort face indices by morton code order = np.argsort(morton_codes) @@ -194,6 +230,7 @@ def _initialize_hash_table(self): # Get a list of unique morton codes and their corresponding starts and counts (CSR format) keys, starts, counts = np.unique(morton_codes_sorted, return_index=True, return_counts=True) + hash_table = { "keys": keys, "starts": starts, @@ -458,8 +495,8 @@ def quantize_coordinates(x, y, z, xmin, xmax, ymin, ymax, zmin, zmax, bitwidth=1 yn = np.where(dy != 0, (y - ymin) / dy, 0.0) zn = np.where(dz != 0, (z - zmin) / dz, 0.0) - # --- 2) Quantize to 10 bits (0..1023). --- - # Multiply by 1023, round down, and clip to be safe against slight overshoot. + # --- 2) Quantize to (0..bitwidth). --- + # Multiply by bitwidth, round down, and clip to be safe against slight overshoot. xq = np.clip((xn * bitwidth).astype(np.uint32), 0, bitwidth) yq = np.clip((yn * bitwidth).astype(np.uint32), 0, bitwidth) zq = np.clip((zn * bitwidth).astype(np.uint32), 0, bitwidth) From 59eecf75cea6acb08afbd0ea20dad8b27455b007 Mon Sep 17 00:00:00 2001 From: Joe Schoonover Date: Mon, 8 Sep 2025 21:08:07 -0400 Subject: [PATCH 06/16] Add point_in_cell method to query call --- tests/v4/test_spatialhash.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/v4/test_spatialhash.py b/tests/v4/test_spatialhash.py index c9026e7af..af967af46 100644 --- a/tests/v4/test_spatialhash.py +++ b/tests/v4/test_spatialhash.py @@ -1,6 +1,7 @@ import numpy as np from parcels._datasets.structured.generic import datasets +from parcels._index_search import curvilinear_point_in_cell from parcels.xgrid import XGrid @@ -15,7 +16,7 @@ def test_invalid_positions(): ds = datasets["2d_left_rotated"] grid = XGrid.from_dataset(ds) - j, i = grid.get_spatial_hash().query([np.nan, np.inf], [np.nan, np.inf]) + j, i = grid.get_spatial_hash().query([np.nan, np.inf], [np.nan, np.inf], curvilinear_point_in_cell) assert np.all(j == -1) assert np.all(i == -1) @@ -27,7 +28,7 @@ def test_mixed_positions(): lon = grid.lon.mean() y = [lat, np.nan] x = [lon, np.nan] - j, i = grid.get_spatial_hash().query(y, x) + j, i = grid.get_spatial_hash().query(y, x, curvilinear_point_in_cell) assert j[0] == 29 # Actual value for 2d_left_rotated center assert i[0] == 14 # Actual value for 2d_left_rotated center assert j[1] == -1 From 4c9661c31c493c8fc5f95eafa8ba1f3cf22dc470 Mon Sep 17 00:00:00 2001 From: Joe Schoonover Date: Mon, 8 Sep 2025 21:51:07 -0400 Subject: [PATCH 07/16] Vectorize final query loop in j,i array construction --- parcels/spatialhash.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/parcels/spatialhash.py b/parcels/spatialhash.py index 638284436..b945f3b8d 100644 --- a/parcels/spatialhash.py +++ b/parcels/spatialhash.py @@ -342,16 +342,21 @@ def query(self, y, x, point_in_cell): is_in_face = point_in_cell(self._source_grid, y_rep, x_rep, j_all, i_all) # For each query that has hits, we need to find the first candidate that was inside the face - for q in range(num_queries): - if has_hits[q]: - # Masked array for the current query - mask = is_in_face[offsets[q] : offsets[q + 1]] - if mask.any(): - # Find the candidate face that contains the query point - argmin_within = mask.argmax() - # Store the result - j_best[q] = j_all[offsets[q] + argmin_within] - i_best[q] = i_all[offsets[q] + argmin_within] + f_indices = np.flatnonzero(is_in_face) # Indices of all faces that contained the point + # For each true position, find which query it belongs to by searching offsets + # Query index q satisfies offsets[q] <= pos < offsets[q+1]. + q = np.searchsorted(offsets[1:], f_indices, side="right") + + uniq_q, q_idx = np.unique(q, return_index=True) + keep = has_hits[uniq_q] + + if keep.any(): + uniq_q = uniq_q[keep] + pos_first = f_indices[q_idx[keep]] + + # Directly scatter: the code wants the first True inside each slice + j_best[uniq_q] = j_all[pos_first] + i_best[uniq_q] = i_all[pos_first] return (j_best.reshape(query_codes.shape), i_best.reshape(query_codes.shape)) From 9a3707f60da7fcb000b75246ddfa2a05ac28b743 Mon Sep 17 00:00:00 2001 From: Joe Schoonover Date: Tue, 9 Sep 2025 13:35:46 -0400 Subject: [PATCH 08/16] Move _barycentric_coordinates function to uxgrid This function is now only needed within the uxgrid module --- parcels/spatialhash.py | 45 ----------------------------------------- parcels/uxgrid.py | 46 +++++++++++++++++++++++++++++++++++++++++- 2 files changed, 45 insertions(+), 46 deletions(-) diff --git a/parcels/spatialhash.py b/parcels/spatialhash.py index b945f3b8d..b45b42c3b 100644 --- a/parcels/spatialhash.py +++ b/parcels/spatialhash.py @@ -361,51 +361,6 @@ def query(self, y, x, point_in_cell): return (j_best.reshape(query_codes.shape), i_best.reshape(query_codes.shape)) -def _triangle_area(A, B, C): - """Compute the area of a triangle given by three points.""" - d1 = B - A - d2 = C - A - d3 = np.cross(d1, d2) - return 0.5 * np.linalg.norm(d3) - - -def _barycentric_coordinates(nodes, point, min_area=1e-8): - """ - Compute the barycentric coordinates of a point P inside a convex polygon using area-based weights. - So that this method generalizes to n-sided polygons, we use the Waschpress points as the generalized - barycentric coordinates, which is only valid for convex polygons. - - Parameters - ---------- - nodes : numpy.ndarray - Spherical coordinates (lat,lon) of each corner node of a face - point : numpy.ndarray - Spherical coordinates (lat,lon) of the point - - Returns - ------- - numpy.ndarray - Barycentric coordinates corresponding to each vertex. - - """ - n = len(nodes) - sum_wi = 0 - w = [] - - for i in range(0, n): - vim1 = nodes[i - 1] - vi = nodes[i] - vi1 = nodes[(i + 1) % n] - a0 = _triangle_area(vim1, vi, vi1) - a1 = max(_triangle_area(point, vim1, vi), min_area) - a2 = max(_triangle_area(point, vi, vi1), min_area) - sum_wi += a0 / (a1 * a2) - w.append(a0 / (a1 * a2)) - barycentric_coords = [w_i / sum_wi for w_i in w] - - return barycentric_coords - - def _latlon_rad_to_xyz(lat, lon): """Converts Spherical latitude and longitude coordinates into Cartesian x, y, z coordinates. diff --git a/parcels/uxgrid.py b/parcels/uxgrid.py index bc8014f76..a82037ef3 100644 --- a/parcels/uxgrid.py +++ b/parcels/uxgrid.py @@ -6,7 +6,6 @@ import uxarray as ux from parcels._typing import assert_valid_mesh -from parcels.spatialhash import _barycentric_coordinates from parcels.tools.statuscodes import FieldOutOfBoundError from parcels.xgrid import _search_1d_array @@ -164,3 +163,48 @@ def _lonlat_rad_to_xyz( z = np.sin(lat) return x, y, z + + +def _triangle_area(A, B, C): + """Compute the area of a triangle given by three points.""" + d1 = B - A + d2 = C - A + d3 = np.cross(d1, d2) + return 0.5 * np.linalg.norm(d3) + + +def _barycentric_coordinates(nodes, point, min_area=1e-8): + """ + Compute the barycentric coordinates of a point P inside a convex polygon using area-based weights. + So that this method generalizes to n-sided polygons, we use the Waschpress points as the generalized + barycentric coordinates, which is only valid for convex polygons. + + Parameters + ---------- + nodes : numpy.ndarray + Spherical coordinates (lat,lon) of each corner node of a face + point : numpy.ndarray + Spherical coordinates (lat,lon) of the point + + Returns + ------- + numpy.ndarray + Barycentric coordinates corresponding to each vertex. + + """ + n = len(nodes) + sum_wi = 0 + w = [] + + for i in range(0, n): + vim1 = nodes[i - 1] + vi = nodes[i] + vi1 = nodes[(i + 1) % n] + a0 = _triangle_area(vim1, vi, vi1) + a1 = max(_triangle_area(point, vim1, vi), min_area) + a2 = max(_triangle_area(point, vi, vi1), min_area) + sum_wi += a0 / (a1 * a2) + w.append(a0 / (a1 * a2)) + barycentric_coords = [w_i / sum_wi for w_i in w] + + return barycentric_coords From 3c7507c44d082c3ed61c71314b5a83f056a9ec02 Mon Sep 17 00:00:00 2001 From: Joe Schoonover Date: Tue, 9 Sep 2025 13:55:47 -0400 Subject: [PATCH 09/16] Drop uint64 to uint32 to save memory (and dilated bits don't need uint64) --- parcels/spatialhash.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/parcels/spatialhash.py b/parcels/spatialhash.py index b45b42c3b..de3662882 100644 --- a/parcels/spatialhash.py +++ b/parcels/spatialhash.py @@ -471,9 +471,9 @@ def _encode_quantized_morton3d(xq, yq, zq): # --- 3) Bit-dilate each 10-bit number so each bit is separated by two zeros. --- # _dilate_bits maps: b9..b0 -> b9 0 0 b8 0 0 ... b0 0 0 - dx3 = _dilate_bits(xq).astype(np.uint64) - dy3 = _dilate_bits(yq).astype(np.uint64) - dz3 = _dilate_bits(zq).astype(np.uint64) + dx3 = _dilate_bits(xq).astype(np.uint32) + dy3 = _dilate_bits(yq).astype(np.uint32) + dz3 = _dilate_bits(zq).astype(np.uint32) # --- 4) Interleave the dilated bits into a single Morton code. --- # Bit layout (from LSB upward): x0,y0,z0, x1,y1,z1, ..., x9,y9,z9 @@ -520,9 +520,9 @@ def _encode_morton3d(x, y, z, xmin, xmax, ymin, ymax, zmin, zmax, bitwidth=1023) # --- 3) Bit-dilate each 10-bit number so each bit is separated by two zeros. --- # _dilate_bits maps: b9..b0 -> b9 0 0 b8 0 0 ... b0 0 0 - dx3 = _dilate_bits(xq).astype(np.uint64) - dy3 = _dilate_bits(yq).astype(np.uint64) - dz3 = _dilate_bits(zq).astype(np.uint64) + dx3 = _dilate_bits(xq).astype(np.uint32) + dy3 = _dilate_bits(yq).astype(np.uint32) + dz3 = _dilate_bits(zq).astype(np.uint32) # --- 4) Interleave the dilated bits into a single Morton code. --- # Bit layout (from LSB upward): x0,y0,z0, x1,y1,z1, ..., x9,y9,z9 From 0be54db304b2835a862c296b8039d98391bc1b3a Mon Sep 17 00:00:00 2001 From: Joe Schoonover Date: Tue, 9 Sep 2025 13:59:31 -0400 Subject: [PATCH 10/16] Drop temp variables to int32 We're not anticipating having beyond 2^31 faces in problems for the forseeable future. --- parcels/spatialhash.py | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/parcels/spatialhash.py b/parcels/spatialhash.py index de3662882..9b58d62b9 100644 --- a/parcels/spatialhash.py +++ b/parcels/spatialhash.py @@ -168,8 +168,8 @@ def _initialize_hash_table(self): # Compute the j, i indices corresponding to each hash entry nface = np.size(self._xlow) - face_ids = np.repeat(np.arange(nface, dtype=np.int64), num_hash_per_face) - offsets = np.concatenate(([0], np.cumsum(num_hash_per_face))).astype(np.int64)[:-1] + face_ids = np.repeat(np.arange(nface, dtype=np.int32), num_hash_per_face) + offsets = np.concatenate(([0], np.cumsum(num_hash_per_face))).astype(np.int32)[:-1] valid = num_hash_per_face != 0 if not np.any(valid): @@ -177,16 +177,16 @@ def _initialize_hash_table(self): pass else: # Grab only valid faces to avoid empty arrays - nx_v = np.asarray(nx[valid], dtype=np.int64) - ny_v = np.asarray(ny[valid], dtype=np.int64) - nz_v = np.asarray(nz[valid], dtype=np.int64) - xlow_v = np.asarray(xqlow[valid], dtype=np.int64) - ylow_v = np.asarray(yqlow[valid], dtype=np.int64) - zlow_v = np.asarray(zqlow[valid], dtype=np.int64) - starts_v = np.asarray(offsets[valid], dtype=np.int64) + nx_v = np.asarray(nx[valid], dtype=np.int32) + ny_v = np.asarray(ny[valid], dtype=np.int32) + nz_v = np.asarray(nz[valid], dtype=np.int32) + xlow_v = np.asarray(xqlow[valid], dtype=np.int32) + ylow_v = np.asarray(yqlow[valid], dtype=np.int32) + zlow_v = np.asarray(zqlow[valid], dtype=np.int32) + starts_v = np.asarray(offsets[valid], dtype=np.int32) # Count of elements per valid face (should match num_hash_per_face[valid]) - counts = (nx_v * ny_v * nz_v).astype(np.int64) + counts = (nx_v * ny_v * nz_v).astype(np.int32) total = int(counts.sum()) # Map each global element to its face and output position @@ -195,7 +195,7 @@ def _initialize_hash_table(self): # Intra-face linear index for each element (0..counts_i-1) # Offsets per face within the concatenation of valid faces: face_starts_local = np.cumsum(np.r_[0, counts[:-1]]) - intra = np.arange(total, dtype=np.int64) - np.repeat(face_starts_local, counts) + intra = np.arange(total, dtype=np.int32) - np.repeat(face_starts_local, counts) # Derive (zi, yi, xi) from intra using per-face sizes ny_nz = np.repeat(ny_v * nz_v, counts) @@ -289,11 +289,11 @@ def query(self, y, x, point_in_cell): # Pre-allocate i and j indices of the best match for each query # Default values to -1 (no match case) - j_best = np.full(num_queries, -1, dtype=np.int64) - i_best = np.full(num_queries, -1, dtype=np.int64) + j_best = np.full(num_queries, -1, dtype=np.int32) + i_best = np.full(num_queries, -1, dtype=np.int32) # How many matches each query has; hit_counts[i] is the number of hits for query i - hit_counts = np.where(valid, counts[pos], 0).astype(np.int64) # has shape (num_queries,) + hit_counts = np.where(valid, counts[pos], 0).astype(np.int32) # has shape (num_queries,) if hit_counts.sum() == 0: return (j_best.reshape(query_codes.shape), i_best.reshape(query_codes.shape)) @@ -305,7 +305,7 @@ def query(self, y, x, point_in_cell): # A quick lookup array that maps all candindates back to its query index q_index_for_candidate = np.repeat( - np.arange(num_queries, dtype=np.int64), hit_counts + np.arange(num_queries, dtype=np.int32), hit_counts ) # shape (hit_counts.sum(),) # Map all candidates to positions in the hash table hash_positions = pos[q_index_for_candidate] # shape (hit_counts.sum(),) @@ -316,18 +316,18 @@ def query(self, y, x, point_in_cell): # hit_counts gives the number of candidates for each query # We need to build an array that gives the offset within each query's candidates - offsets = np.concatenate(([0], np.cumsum(hit_counts))).astype(np.int64) # shape (num_queries+1,) + offsets = np.concatenate(([0], np.cumsum(hit_counts))).astype(np.int32) # shape (num_queries+1,) total = int(offsets[-1]) # total number of candidates across all queries # Now, for each candidate, we need a simple array that tells us its "local candidate id" within its query # This way, we can easily take the starts[pos[q_index_for_candidate]] and add this local id to get the absolute index # We calculate this by computing the "global candidate number" (0..total-1) and subtracting the offsets of the corresponding query # This gives us an array that goes from 0..hit_counts[i]-1 for each query i - intra = np.arange(total, dtype=np.int64) - np.repeat(offsets[:-1], hit_counts) # shape (hit_counts.sum(),) + intra = np.arange(total, dtype=np.int32) - np.repeat(offsets[:-1], hit_counts) # shape (hit_counts.sum(),) # starts[pos[q_index_for_candidate]] + intra gives a list of positions in the hash table that we can # use to quickly gather the (i,j) pairs for each query - source_idx = starts[hash_positions].astype(np.int64) + intra + source_idx = starts[hash_positions].astype(np.int32) + intra # Gather all candidate (j,i) pairs in one shot j_all = j[source_idx] From dfbb99122876af42803471ecdabdf8fc3c1306dd Mon Sep 17 00:00:00 2001 From: Joe Schoonover Date: Wed, 10 Sep 2025 09:43:54 -0400 Subject: [PATCH 11/16] Return coordinates from PIC search and remove redundant coordinate calculation --- parcels/_index_search.py | 38 ++++++------------------------------ parcels/spatialhash.py | 17 +++++++++++++--- tests/v4/test_spatialhash.py | 4 ++-- 3 files changed, 22 insertions(+), 37 deletions(-) diff --git a/parcels/_index_search.py b/parcels/_index_search.py index 87f12650b..f98193579 100644 --- a/parcels/_index_search.py +++ b/parcels/_index_search.py @@ -72,44 +72,18 @@ def curvilinear_point_in_cell(grid, y: np.ndarray, x: np.ndarray, yi: np.ndarray is_in_cell = np.where((xsi >= 0) & (xsi <= 1) & (eta >= 0) & (eta <= 1), 1, 0) - return is_in_cell + return is_in_cell, np.column_stack((xsi, eta)) def _search_indices_curvilinear_2d( - grid: XGrid, y: float, x: float, yi_guess: int | None = None, xi_guess: int | None = None -): # TODO fix typing instructions to make clear that y, x etc need to be ndarrays + grid: XGrid, y: np.ndarray, x: np.ndarray, yi_guess: np.ndarray | None = None, xi_guess: np.ndarray | None = None +): yi, xi = yi_guess, xi_guess if yi is None or xi is None: - yi, xi = grid.get_spatial_hash().query(y, x, curvilinear_point_in_cell) + yi, xi, coords = grid.get_spatial_hash().query(y, x, curvilinear_point_in_cell) - xsi = eta = -1.0 * np.ones(len(x), dtype=float) - invA = np.array( - [ - [1, 0, 0, 0], - [-1, 1, 0, 0], - [-1, 0, 0, 1], - [1, -1, 1, -1], - ] - ) - - px = np.array([grid.lon[yi, xi], grid.lon[yi, xi + 1], grid.lon[yi + 1, xi + 1], grid.lon[yi + 1, xi]]) - py = np.array([grid.lat[yi, xi], grid.lat[yi, xi + 1], grid.lat[yi + 1, xi + 1], grid.lat[yi + 1, xi]]) - - a, b = np.dot(invA, px), np.dot(invA, py) - aa = a[3] * b[2] - a[2] * b[3] - bb = a[3] * b[0] - a[0] * b[3] + a[1] * b[2] - a[2] * b[1] + x * b[3] - y * a[3] - cc = a[1] * b[0] - a[0] * b[1] + x * b[1] - y * a[1] - det2 = bb * bb - 4 * aa * cc - - with np.errstate(divide="ignore", invalid="ignore"): - det = np.where(det2 > 0, np.sqrt(det2), eta) - eta = np.where(abs(aa) < 1e-12, -cc / bb, np.where(det2 > 0, (-bb + det) / (2 * aa), eta)) - - xsi = np.where( - abs(a[1] + a[3] * eta) < 1e-12, - ((y - py[0]) / (py[1] - py[0]) + (y - py[3]) / (py[2] - py[3])) * 0.5, - (x - a[0] - a[2] * eta) / (a[1] + a[3] * eta), - ) + xsi = coords[:, 0] + eta = coords[:, 1] (yi, xi) = _reconnect_bnd_indices(yi, xi, grid.ydim, grid.xdim, grid._mesh) diff --git a/parcels/spatialhash.py b/parcels/spatialhash.py index 9b58d62b9..36b049889 100644 --- a/parcels/spatialhash.py +++ b/parcels/spatialhash.py @@ -295,7 +295,11 @@ def query(self, y, x, point_in_cell): # How many matches each query has; hit_counts[i] is the number of hits for query i hit_counts = np.where(valid, counts[pos], 0).astype(np.int32) # has shape (num_queries,) if hit_counts.sum() == 0: - return (j_best.reshape(query_codes.shape), i_best.reshape(query_codes.shape)) + return ( + j_best.reshape(query_codes.shape), + i_best.reshape(query_codes.shape), + np.full((num_queries, 2), -1.0, dtype=np.float32), + ) # Now, for each query, we need to gather the candidate (j,i) indices from the hash table # Each j,i pair needs to be repeated hit_counts[i] times, only when there are hits. @@ -339,7 +343,9 @@ def query(self, y, x, point_in_cell): x_rep = np.repeat(x, hit_counts) # shape (hit_counts.sum(),) # For each query we perform a point in cell check. - is_in_face = point_in_cell(self._source_grid, y_rep, x_rep, j_all, i_all) + is_in_face, coordinates = point_in_cell(self._source_grid, y_rep, x_rep, j_all, i_all) + + coords_best = np.full((num_queries, coordinates.shape[1]), -1.0, dtype=np.float32) # For each query that has hits, we need to find the first candidate that was inside the face f_indices = np.flatnonzero(is_in_face) # Indices of all faces that contained the point @@ -357,8 +363,13 @@ def query(self, y, x, point_in_cell): # Directly scatter: the code wants the first True inside each slice j_best[uniq_q] = j_all[pos_first] i_best[uniq_q] = i_all[pos_first] + coords_best[uniq_q] = coordinates[pos_first] - return (j_best.reshape(query_codes.shape), i_best.reshape(query_codes.shape)) + return ( + j_best.reshape(query_codes.shape), + i_best.reshape(query_codes.shape), + coords_best.reshape((num_queries, coordinates.shape[1])), + ) def _latlon_rad_to_xyz(lat, lon): diff --git a/tests/v4/test_spatialhash.py b/tests/v4/test_spatialhash.py index af967af46..ea0d8ed7b 100644 --- a/tests/v4/test_spatialhash.py +++ b/tests/v4/test_spatialhash.py @@ -16,7 +16,7 @@ def test_invalid_positions(): ds = datasets["2d_left_rotated"] grid = XGrid.from_dataset(ds) - j, i = grid.get_spatial_hash().query([np.nan, np.inf], [np.nan, np.inf], curvilinear_point_in_cell) + j, i, coords = grid.get_spatial_hash().query([np.nan, np.inf], [np.nan, np.inf], curvilinear_point_in_cell) assert np.all(j == -1) assert np.all(i == -1) @@ -28,7 +28,7 @@ def test_mixed_positions(): lon = grid.lon.mean() y = [lat, np.nan] x = [lon, np.nan] - j, i = grid.get_spatial_hash().query(y, x, curvilinear_point_in_cell) + j, i, coords = grid.get_spatial_hash().query(y, x, curvilinear_point_in_cell) assert j[0] == 29 # Actual value for 2d_left_rotated center assert i[0] == 14 # Actual value for 2d_left_rotated center assert j[1] == -1 From 95cb59318fb41ec1d03a6fd38b742e0b7d622ffb Mon Sep 17 00:00:00 2001 From: Joe Schoonover Date: Wed, 10 Sep 2025 10:33:23 -0400 Subject: [PATCH 12/16] Add logic for using provided guess --- parcels/_index_search.py | 42 +++++++++++++++++++++++----------------- 1 file changed, 24 insertions(+), 18 deletions(-) diff --git a/parcels/_index_search.py b/parcels/_index_search.py index f98193579..e71414600 100644 --- a/parcels/_index_search.py +++ b/parcels/_index_search.py @@ -5,7 +5,6 @@ import numpy as np -from parcels._typing import Mesh from parcels.tools.statuscodes import _raise_time_extrapolation_error if TYPE_CHECKING: @@ -78,29 +77,36 @@ def curvilinear_point_in_cell(grid, y: np.ndarray, x: np.ndarray, yi: np.ndarray def _search_indices_curvilinear_2d( grid: XGrid, y: np.ndarray, x: np.ndarray, yi_guess: np.ndarray | None = None, xi_guess: np.ndarray | None = None ): - yi, xi = yi_guess, xi_guess - if yi is None or xi is None: - yi, xi, coords = grid.get_spatial_hash().query(y, x, curvilinear_point_in_cell) + yi_guess = np.array(yi_guess) + xi_guess = np.array(xi_guess) + xi = np.full(len(x), GRID_SEARCH_ERROR, dtype=np.int32) + yi = np.full(len(y), GRID_SEARCH_ERROR, dtype=np.int32) + if np.any(xi_guess): + # If an initial guess is provided, we first perform a point in cell check for all guessed indices + is_in_cell, coords = curvilinear_point_in_cell(grid, y, x, yi_guess, xi_guess) + y_check = y[is_in_cell == 0] + x_check = x[is_in_cell == 0] + zero_indices = np.where(is_in_cell == 0)[0] + else: + # Otherwise, we need to check all points + y_check = y + x_check = x + coords = -1.0 * np.ones((len(y), 2), dtype=np.float32) + zero_indices = np.arange(len(y)) + + # If there are any points that were not found in the first step, we query the spatial hash for those points + if len(zero_indices) > 0: + yi_q, xi_q, coords_q = grid.get_spatial_hash().query(y_check, x_check, curvilinear_point_in_cell) + # Only those points that were not found in the first step are updated + coords[zero_indices, :] = coords_q + yi[zero_indices] = yi_q + xi[zero_indices] = xi_q xsi = coords[:, 0] eta = coords[:, 1] - (yi, xi) = _reconnect_bnd_indices(yi, xi, grid.ydim, grid.xdim, grid._mesh) - # checking if xsi or eta is outside [0, 1] xi = np.where(xsi < 0, GRID_SEARCH_ERROR, np.where(xsi > 1, GRID_SEARCH_ERROR, xi)) yi = np.where(eta < 0, GRID_SEARCH_ERROR, np.where(eta > 1, GRID_SEARCH_ERROR, yi)) return (yi, eta, xi, xsi) - - -def _reconnect_bnd_indices(yi: int, xi: int, ydim: int, xdim: int, mesh: Mesh): - xi = np.where(xi < 0, (xdim - 2) if mesh == "spherical" else 0, xi) - xi = np.where(xi > xdim - 2, 0 if mesh == "spherical" else (xdim - 2), xi) - - xi = np.where(yi > ydim - 2, xdim - xi if mesh == "spherical" else xi, xi) - - yi = np.where(yi < 0, 0, yi) - yi = np.where(yi > ydim - 2, ydim - 2, yi) - - return yi, xi From 46c988e34e5b0a131de18dfb2ea12e63dfc9386f Mon Sep 17 00:00:00 2001 From: Joe Schoonover Date: Wed, 10 Sep 2025 10:37:29 -0400 Subject: [PATCH 13/16] Set failed query indices to GRID_SEARCH_ERROR This eliminates the extra step needed to mask out grid search errors --- parcels/_index_search.py | 4 ---- parcels/spatialhash.py | 16 +++++++++++----- tests/v4/test_spatialhash.py | 8 ++++---- 3 files changed, 15 insertions(+), 13 deletions(-) diff --git a/parcels/_index_search.py b/parcels/_index_search.py index e71414600..4b3c7831e 100644 --- a/parcels/_index_search.py +++ b/parcels/_index_search.py @@ -105,8 +105,4 @@ def _search_indices_curvilinear_2d( xsi = coords[:, 0] eta = coords[:, 1] - # checking if xsi or eta is outside [0, 1] - xi = np.where(xsi < 0, GRID_SEARCH_ERROR, np.where(xsi > 1, GRID_SEARCH_ERROR, xi)) - yi = np.where(eta < 0, GRID_SEARCH_ERROR, np.where(eta > 1, GRID_SEARCH_ERROR, yi)) - return (yi, eta, xi, xsi) diff --git a/parcels/spatialhash.py b/parcels/spatialhash.py index 36b049889..88c35ab94 100644 --- a/parcels/spatialhash.py +++ b/parcels/spatialhash.py @@ -1,5 +1,7 @@ import numpy as np +from parcels._index_search import GRID_SEARCH_ERROR + class SpatialHash: """Custom data structure that is used for performing grid searches using Spatial Hashing. This class constructs an overlying @@ -253,9 +255,13 @@ def query(self, y, x, point_in_cell): Returns ------- - faces : ndarray of shape (N,2), dtype=np.int32 - For each coordinate pair, returns the (j,i) indices of the closest face in the hash grid. - If no face is found, returns (-1,-1) for that query. + j : ndarray, shape (N,) + j-indices of the located face in the source grid for each query. If no face was found, GRID_SEARCH_ERROR is returned. + i : ndarray, shape (N,) + i-indices of the located face in the source grid for each query. If no face was found, GRID_SEARCH_ERROR is returned. + coords : ndarray, shape (N, 2) + The local coordinates (xsi, eta) of the located face in the source grid for each query. + If no face was found, (-1.0, -1.0) """ keys = self._hash_table["keys"] starts = self._hash_table["starts"] @@ -289,8 +295,8 @@ def query(self, y, x, point_in_cell): # Pre-allocate i and j indices of the best match for each query # Default values to -1 (no match case) - j_best = np.full(num_queries, -1, dtype=np.int32) - i_best = np.full(num_queries, -1, dtype=np.int32) + j_best = np.full(num_queries, GRID_SEARCH_ERROR, dtype=np.int32) + i_best = np.full(num_queries, GRID_SEARCH_ERROR, dtype=np.int32) # How many matches each query has; hit_counts[i] is the number of hits for query i hit_counts = np.where(valid, counts[pos], 0).astype(np.int32) # has shape (num_queries,) diff --git a/tests/v4/test_spatialhash.py b/tests/v4/test_spatialhash.py index ea0d8ed7b..61d1b6b6c 100644 --- a/tests/v4/test_spatialhash.py +++ b/tests/v4/test_spatialhash.py @@ -17,8 +17,8 @@ def test_invalid_positions(): grid = XGrid.from_dataset(ds) j, i, coords = grid.get_spatial_hash().query([np.nan, np.inf], [np.nan, np.inf], curvilinear_point_in_cell) - assert np.all(j == -1) - assert np.all(i == -1) + assert np.all(j == -3) + assert np.all(i == -3) def test_mixed_positions(): @@ -31,5 +31,5 @@ def test_mixed_positions(): j, i, coords = grid.get_spatial_hash().query(y, x, curvilinear_point_in_cell) assert j[0] == 29 # Actual value for 2d_left_rotated center assert i[0] == 14 # Actual value for 2d_left_rotated center - assert j[1] == -1 - assert i[1] == -1 + assert j[1] == -3 + assert i[1] == -3 From 8d62e20ec43616ad138dc365b0375ab1fee6c534 Mon Sep 17 00:00:00 2001 From: Joe Schoonover Date: Wed, 10 Sep 2025 10:52:11 -0400 Subject: [PATCH 14/16] Remove "reconstruct" --- parcels/spatialhash.py | 229 ++++++++++++++++++++--------------------- parcels/xgrid.py | 2 +- 2 files changed, 112 insertions(+), 119 deletions(-) diff --git a/parcels/spatialhash.py b/parcels/spatialhash.py index 88c35ab94..068f48b10 100644 --- a/parcels/spatialhash.py +++ b/parcels/spatialhash.py @@ -13,8 +13,6 @@ class SpatialHash: ---------- grid : parcels.xgrid.XGrid Source grid used to construct the hash grid and hash table - reconstruct : bool, default=False - If true, reconstructs the spatial hash Note ---- @@ -24,14 +22,12 @@ class SpatialHash: def __init__( self, grid, - reconstruct=False, bitwidth=1023, ): # TODO : Enforce grid to be an instance of parcels.xgrid.XGrid # Currently, this is not done due to circular import with parcels.xgrid self._source_grid = grid - self.reconstruct = reconstruct self._bitwidth = bitwidth # Max integer to use per coordinate in quantization (10 bits = 0..1023) if self._source_grid._mesh == "spherical": @@ -126,121 +122,118 @@ def _initialize_hash_table(self): """Create a mapping that relates unstructured grid faces to hash indices by determining which faces overlap with which hash cells """ - if self._hash_table is None or self.reconstruct: - # j, i = np.indices(self._xlow.shape) # Get the indices of the curvilinear grid - - # Quantize the bounding box in each direction - xqlow, yqlow, zqlow = quantize_coordinates( - self._xlow, - self._ylow, - self._zlow, - self._xmin, - self._xmax, - self._ymin, - self._ymax, - self._zmin, - self._zmax, - self._bitwidth, - ) + # Quantize the bounding box in each direction + xqlow, yqlow, zqlow = quantize_coordinates( + self._xlow, + self._ylow, + self._zlow, + self._xmin, + self._xmax, + self._ymin, + self._ymax, + self._zmin, + self._zmax, + self._bitwidth, + ) - xqhigh, yqhigh, zqhigh = quantize_coordinates( - self._xhigh, - self._yhigh, - self._zhigh, - self._xmin, - self._xmax, - self._ymin, - self._ymax, - self._zmin, - self._zmax, - self._bitwidth, - ) - xqlow = xqlow.ravel() - yqlow = yqlow.ravel() - zqlow = zqlow.ravel() - xqhigh = xqhigh.ravel() - yqhigh = yqhigh.ravel() - zqhigh = zqhigh.ravel() - nx = xqhigh - xqlow + 1 - ny = yqhigh - yqlow + 1 - nz = zqhigh - zqlow + 1 - num_hash_per_face = nx * ny * nz - total_hash_entries = np.sum(num_hash_per_face) - morton_codes = np.zeros(total_hash_entries, dtype=np.uint32) - - # Compute the j, i indices corresponding to each hash entry - nface = np.size(self._xlow) - face_ids = np.repeat(np.arange(nface, dtype=np.int32), num_hash_per_face) - offsets = np.concatenate(([0], np.cumsum(num_hash_per_face))).astype(np.int32)[:-1] - - valid = num_hash_per_face != 0 - if not np.any(valid): - # nothing to do - pass - else: - # Grab only valid faces to avoid empty arrays - nx_v = np.asarray(nx[valid], dtype=np.int32) - ny_v = np.asarray(ny[valid], dtype=np.int32) - nz_v = np.asarray(nz[valid], dtype=np.int32) - xlow_v = np.asarray(xqlow[valid], dtype=np.int32) - ylow_v = np.asarray(yqlow[valid], dtype=np.int32) - zlow_v = np.asarray(zqlow[valid], dtype=np.int32) - starts_v = np.asarray(offsets[valid], dtype=np.int32) - - # Count of elements per valid face (should match num_hash_per_face[valid]) - counts = (nx_v * ny_v * nz_v).astype(np.int32) - total = int(counts.sum()) - - # Map each global element to its face and output position - start_for_elem = np.repeat(starts_v, counts) # shape (total,) - - # Intra-face linear index for each element (0..counts_i-1) - # Offsets per face within the concatenation of valid faces: - face_starts_local = np.cumsum(np.r_[0, counts[:-1]]) - intra = np.arange(total, dtype=np.int32) - np.repeat(face_starts_local, counts) - - # Derive (zi, yi, xi) from intra using per-face sizes - ny_nz = np.repeat(ny_v * nz_v, counts) - nz_rep = np.repeat(nz_v, counts) - - xi = intra // ny_nz - rem = intra % ny_nz - yi = rem // nz_rep - zi = rem % nz_rep - - # Add per-face lows - x0 = np.repeat(xlow_v, counts) - y0 = np.repeat(ylow_v, counts) - z0 = np.repeat(zlow_v, counts) - - xq = x0 + xi - yq = y0 + yi - zq = z0 + zi - - # Vectorized morton encode for all elements at once - codes_all = _encode_quantized_morton3d(xq, yq, zq) - - # Scatter into the preallocated output using computed absolute indices - out_idx = start_for_elem + intra - morton_codes[out_idx] = codes_all - - # Sort face indices by morton code - order = np.argsort(morton_codes) - morton_codes_sorted = morton_codes[order] - face_sorted = face_ids[order] - j_sorted, i_sorted = np.unravel_index(face_sorted, self._xlow.shape) - - # Get a list of unique morton codes and their corresponding starts and counts (CSR format) - keys, starts, counts = np.unique(morton_codes_sorted, return_index=True, return_counts=True) - - hash_table = { - "keys": keys, - "starts": starts, - "counts": counts, - "i": i_sorted, - "j": j_sorted, - } - return hash_table + xqhigh, yqhigh, zqhigh = quantize_coordinates( + self._xhigh, + self._yhigh, + self._zhigh, + self._xmin, + self._xmax, + self._ymin, + self._ymax, + self._zmin, + self._zmax, + self._bitwidth, + ) + xqlow = xqlow.ravel() + yqlow = yqlow.ravel() + zqlow = zqlow.ravel() + xqhigh = xqhigh.ravel() + yqhigh = yqhigh.ravel() + zqhigh = zqhigh.ravel() + nx = xqhigh - xqlow + 1 + ny = yqhigh - yqlow + 1 + nz = zqhigh - zqlow + 1 + num_hash_per_face = nx * ny * nz + total_hash_entries = np.sum(num_hash_per_face) + morton_codes = np.zeros(total_hash_entries, dtype=np.uint32) + + # Compute the j, i indices corresponding to each hash entry + nface = np.size(self._xlow) + face_ids = np.repeat(np.arange(nface, dtype=np.int32), num_hash_per_face) + offsets = np.concatenate(([0], np.cumsum(num_hash_per_face))).astype(np.int32)[:-1] + + valid = num_hash_per_face != 0 + if not np.any(valid): + # nothing to do + pass + else: + # Grab only valid faces to avoid empty arrays + nx_v = np.asarray(nx[valid], dtype=np.int32) + ny_v = np.asarray(ny[valid], dtype=np.int32) + nz_v = np.asarray(nz[valid], dtype=np.int32) + xlow_v = np.asarray(xqlow[valid], dtype=np.int32) + ylow_v = np.asarray(yqlow[valid], dtype=np.int32) + zlow_v = np.asarray(zqlow[valid], dtype=np.int32) + starts_v = np.asarray(offsets[valid], dtype=np.int32) + + # Count of elements per valid face (should match num_hash_per_face[valid]) + counts = (nx_v * ny_v * nz_v).astype(np.int32) + total = int(counts.sum()) + + # Map each global element to its face and output position + start_for_elem = np.repeat(starts_v, counts) # shape (total,) + + # Intra-face linear index for each element (0..counts_i-1) + # Offsets per face within the concatenation of valid faces: + face_starts_local = np.cumsum(np.r_[0, counts[:-1]]) + intra = np.arange(total, dtype=np.int32) - np.repeat(face_starts_local, counts) + + # Derive (zi, yi, xi) from intra using per-face sizes + ny_nz = np.repeat(ny_v * nz_v, counts) + nz_rep = np.repeat(nz_v, counts) + + xi = intra // ny_nz + rem = intra % ny_nz + yi = rem // nz_rep + zi = rem % nz_rep + + # Add per-face lows + x0 = np.repeat(xlow_v, counts) + y0 = np.repeat(ylow_v, counts) + z0 = np.repeat(zlow_v, counts) + + xq = x0 + xi + yq = y0 + yi + zq = z0 + zi + + # Vectorized morton encode for all elements at once + codes_all = _encode_quantized_morton3d(xq, yq, zq) + + # Scatter into the preallocated output using computed absolute indices + out_idx = start_for_elem + intra + morton_codes[out_idx] = codes_all + + # Sort face indices by morton code + order = np.argsort(morton_codes) + morton_codes_sorted = morton_codes[order] + face_sorted = face_ids[order] + j_sorted, i_sorted = np.unravel_index(face_sorted, self._xlow.shape) + + # Get a list of unique morton codes and their corresponding starts and counts (CSR format) + keys, starts, counts = np.unique(morton_codes_sorted, return_index=True, return_counts=True) + + hash_table = { + "keys": keys, + "starts": starts, + "counts": counts, + "i": i_sorted, + "j": j_sorted, + } + return hash_table def query(self, y, x, point_in_cell): """ diff --git a/parcels/xgrid.py b/parcels/xgrid.py index 6979d6f11..db94294ad 100644 --- a/parcels/xgrid.py +++ b/parcels/xgrid.py @@ -369,7 +369,7 @@ def get_spatial_hash( """ if self._spatialhash is None or reconstruct: - self._spatialhash = SpatialHash(self, reconstruct) + self._spatialhash = SpatialHash(self) return self._spatialhash From dc06f6345040c9e28e3bef8ab1cdab98b1679d3f Mon Sep 17 00:00:00 2001 From: Joe Schoonover Date: Wed, 10 Sep 2025 10:56:46 -0400 Subject: [PATCH 15/16] Add _point_in_cell private method to spatialhash This removes the need for passing the PIC method during the query call. At the moment we're assuming that if an XGrid type comes in for the source_grid that it is indeed curvilinear. --- parcels/_index_search.py | 2 +- parcels/spatialhash.py | 13 ++++++++----- tests/v4/test_spatialhash.py | 5 ++--- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/parcels/_index_search.py b/parcels/_index_search.py index 4b3c7831e..beb77c351 100644 --- a/parcels/_index_search.py +++ b/parcels/_index_search.py @@ -96,7 +96,7 @@ def _search_indices_curvilinear_2d( # If there are any points that were not found in the first step, we query the spatial hash for those points if len(zero_indices) > 0: - yi_q, xi_q, coords_q = grid.get_spatial_hash().query(y_check, x_check, curvilinear_point_in_cell) + yi_q, xi_q, coords_q = grid.get_spatial_hash().query(y_check, x_check) # Only those points that were not found in the first step are updated coords[zero_indices, :] = coords_q yi[zero_indices] = yi_q diff --git a/parcels/spatialhash.py b/parcels/spatialhash.py index 068f48b10..5da2e673d 100644 --- a/parcels/spatialhash.py +++ b/parcels/spatialhash.py @@ -1,6 +1,7 @@ import numpy as np -from parcels._index_search import GRID_SEARCH_ERROR +import parcels +from parcels._index_search import GRID_SEARCH_ERROR, curvilinear_point_in_cell class SpatialHash: @@ -24,8 +25,10 @@ def __init__( grid, bitwidth=1023, ): - # TODO : Enforce grid to be an instance of parcels.xgrid.XGrid - # Currently, this is not done due to circular import with parcels.xgrid + if isinstance(grid, parcels.xgrid.XGrid): + self._point_in_cell = curvilinear_point_in_cell + else: + raise NotImplementedError("SpatialHash only supports parcels.xgrid.XGrid grids at this time.") self._source_grid = grid self._bitwidth = bitwidth # Max integer to use per coordinate in quantization (10 bits = 0..1023) @@ -235,7 +238,7 @@ def _initialize_hash_table(self): } return hash_table - def query(self, y, x, point_in_cell): + def query(self, y, x): """ Queries the hash table and finds the closes face in the source grid for each coordinate pair. @@ -342,7 +345,7 @@ def query(self, y, x, point_in_cell): x_rep = np.repeat(x, hit_counts) # shape (hit_counts.sum(),) # For each query we perform a point in cell check. - is_in_face, coordinates = point_in_cell(self._source_grid, y_rep, x_rep, j_all, i_all) + is_in_face, coordinates = self._point_in_cell(self._source_grid, y_rep, x_rep, j_all, i_all) coords_best = np.full((num_queries, coordinates.shape[1]), -1.0, dtype=np.float32) diff --git a/tests/v4/test_spatialhash.py b/tests/v4/test_spatialhash.py index 61d1b6b6c..bdec91173 100644 --- a/tests/v4/test_spatialhash.py +++ b/tests/v4/test_spatialhash.py @@ -1,7 +1,6 @@ import numpy as np from parcels._datasets.structured.generic import datasets -from parcels._index_search import curvilinear_point_in_cell from parcels.xgrid import XGrid @@ -16,7 +15,7 @@ def test_invalid_positions(): ds = datasets["2d_left_rotated"] grid = XGrid.from_dataset(ds) - j, i, coords = grid.get_spatial_hash().query([np.nan, np.inf], [np.nan, np.inf], curvilinear_point_in_cell) + j, i, coords = grid.get_spatial_hash().query([np.nan, np.inf], [np.nan, np.inf]) assert np.all(j == -3) assert np.all(i == -3) @@ -28,7 +27,7 @@ def test_mixed_positions(): lon = grid.lon.mean() y = [lat, np.nan] x = [lon, np.nan] - j, i, coords = grid.get_spatial_hash().query(y, x, curvilinear_point_in_cell) + j, i, coords = grid.get_spatial_hash().query(y, x) assert j[0] == 29 # Actual value for 2d_left_rotated center assert i[0] == 14 # Actual value for 2d_left_rotated center assert j[1] == -3 From 50f0e8500df9f6fc3762300e546e752a5c8aea85 Mon Sep 17 00:00:00 2001 From: Joe Schoonover <11430768+fluidnumerics-joe@users.noreply.github.com> Date: Fri, 12 Sep 2025 05:58:53 -0400 Subject: [PATCH 16/16] Remove initial None set for spatial hash Co-authored-by: Nick Hodgskin <36369090+VeckoTheGecko@users.noreply.github.com> --- parcels/spatialhash.py | 1 - 1 file changed, 1 deletion(-) diff --git a/parcels/spatialhash.py b/parcels/spatialhash.py index 5da2e673d..8f5e3ffc6 100644 --- a/parcels/spatialhash.py +++ b/parcels/spatialhash.py @@ -118,7 +118,6 @@ def __init__( self._zhigh = np.zeros_like(self._xlow) # Generate the mapping from the hash indices to unstructured grid elements - self._hash_table = None self._hash_table = self._initialize_hash_table() def _initialize_hash_table(self):