diff --git a/parcels/_index_search.py b/parcels/_index_search.py
index 6c7f8a26c..beb77c351 100644
--- a/parcels/_index_search.py
+++ b/parcels/_index_search.py
@@ -5,11 +5,7 @@
 
 import numpy as np
 
-from parcels._typing import Mesh
-from parcels.tools.statuscodes import (
-    _raise_grid_searching_error,
-    _raise_time_extrapolation_error,
-)
+from parcels.tools.statuscodes import _raise_time_extrapolation_error
 
 if TYPE_CHECKING:
     from parcels.xgrid import XGrid
@@ -17,6 +13,9 @@
     from .field import Field
 
 
+GRID_SEARCH_ERROR = -3
+
+
 def _search_time_index(field: Field, time: datetime):
     """Find and return the index and relative coordinate in the time array associated with a given time.
 
@@ -40,13 +39,7 @@ def _search_time_index(field: Field, time: datetime):
     return np.atleast_1d(tau), np.atleast_1d(ti)
 
 
-def _search_indices_curvilinear_2d(
-    grid: XGrid, y: float, x: float, yi_guess: int | None = None, xi_guess: int | None = None
-):  # TODO fix typing instructions to make clear that y, x etc need to be ndarrays
-    yi, xi = yi_guess, xi_guess
-    if yi is None or xi is None:
-        yi, xi = grid.get_spatial_hash().query(y, x)
-
+def curvilinear_point_in_cell(grid: XGrid, y: np.ndarray, x: np.ndarray, yi: np.ndarray, xi: np.ndarray):
+    """Check whether each point (y[i], x[i]) lies inside curvilinear cell (yi[i], xi[i]); returns in-cell flags and local (xsi, eta) coordinates."""
     xsi = eta = -1.0 * np.ones(len(x), dtype=float)
     invA = np.array(
         [
             [1, 0, 0, 0],
             [-1, 1, 0, 0],
             [-1, 0, 0, 1],
             [1, -1, 1, -1],
         ]
     )
-    maxIterSearch = 1e6
-    it = 0
-    tol = 1.0e-10
-
-    # # ! Error handling for out of bounds
-    # TODO: Re-enable in some capacity
-    # if x < field.lonlat_minmax[0] or x > field.lonlat_minmax[1]:
-    #     if grid.lon[0, 0] < grid.lon[0, -1]:
-    #         _raise_grid_searching_error(y, x)
-    #     elif x < grid.lon[0, 0] and x > grid.lon[0, -1]:  # This prevents from crashing in [160, -160]
-    #         _raise_grid_searching_error(z, y, x)
-
-    # if y < field.lonlat_minmax[2] or y > field.lonlat_minmax[3]:
-    #     _raise_grid_searching_error(z, y, x)
-
-    while np.any(xsi < -tol) or np.any(xsi > 1 + tol) or np.any(eta < -tol) or np.any(eta > 1 + tol):
-        px = np.array([grid.lon[yi, xi], grid.lon[yi, xi + 1], grid.lon[yi + 1, xi + 1], grid.lon[yi + 1, xi]])
-
-        py = np.array([grid.lat[yi, xi], grid.lat[yi, xi + 1], grid.lat[yi + 1, xi + 1], grid.lat[yi + 1, xi]])
-        a = np.dot(invA, px)
-        b = np.dot(invA, py)
-
-        aa = a[3] * b[2] - a[2] * b[3]
-        bb = a[3] * b[0] - a[0] * b[3] + a[1] * b[2] - a[2] * b[1] + x * b[3] - y * a[3]
-        cc = a[1] * b[0] - a[0] * b[1] + x * b[1] - y * a[1]
-
-        det2 = bb * bb - 4 * aa * cc
-        with np.errstate(divide="ignore", invalid="ignore"):
-            det = np.where(det2 > 0, np.sqrt(det2), eta)
-
-            eta = np.where(abs(aa) < 1e-12, -cc / bb, np.where(det2 > 0, (-bb + det) / (2 * aa), eta))
-
-            xsi = np.where(
-                abs(a[1] + a[3] * eta) < 1e-12,
-                ((y - py[0]) / (py[1] - py[0]) + (y - py[3]) / (py[2] - py[3])) * 0.5,
-                (x - a[0] - a[2] * eta) / (a[1] + a[3] * eta),
-            )
-
-        xi = np.where(xsi < -tol, xi - 1, np.where(xsi > 1 + tol, xi + 1, xi))
-        yi = np.where(eta < -tol, yi - 1, np.where(eta > 1 + tol, yi + 1, yi))
-
-        (yi, xi) = _reconnect_bnd_indices(yi, xi, grid.ydim, grid.xdim, grid._mesh)
-        it += 1
-        if it > maxIterSearch:
-            print(f"Correct cell not found after {maxIterSearch} iterations")
-            _raise_grid_searching_error(0, y, x)
-    xsi = np.where(xsi < 0.0, 0.0, np.where(xsi > 1.0, 1.0, xsi))
-    eta = np.where(eta < 0.0, 0.0, np.where(eta > 1.0, 1.0, eta))
-
-    if np.any((xsi < 0) | (xsi > 1) | (eta < 0) | (eta > 1)):
-        _raise_grid_searching_error(y, x)
-    return (yi, eta, xi, xsi)
+    px = np.array([grid.lon[yi, xi], grid.lon[yi, xi + 1], grid.lon[yi + 1, xi + 1], grid.lon[yi + 1, xi]])
+    py = np.array([grid.lat[yi, xi], grid.lat[yi, xi + 1], grid.lat[yi + 1, xi + 1], grid.lat[yi + 1, xi]])
+
+    a, b = np.dot(invA, px), np.dot(invA, py)
+    aa = a[3] * b[2] - a[2] * b[3]
+    bb = a[3] * b[0] - a[0] * b[3] + a[1] * b[2] - a[2] * b[1] + x * b[3] - y * a[3]
+    cc = a[1] * b[0] - a[0] * b[1] + x * b[1] - y * a[1]
+    det2 = bb * bb - 4 * aa * cc
 
-def _reconnect_bnd_indices(yi: int, xi: int, ydim: int, xdim: int, mesh: Mesh):
-    xi = np.where(xi < 0, (xdim - 2) if mesh == "spherical" else 0, xi)
-    xi = np.where(xi > xdim - 2, 0 if mesh == "spherical" else (xdim - 2), xi)
+    with np.errstate(divide="ignore", invalid="ignore"):
+        det = np.where(det2 > 0, np.sqrt(det2), eta)
+        eta = np.where(abs(aa) < 1e-12, -cc / bb, np.where(det2 > 0, (-bb + det) / (2 * aa), eta))
 
-    xi = np.where(yi > ydim - 2, xdim - xi if mesh == "spherical" else xi, xi)
+    xsi = np.where(
+        abs(a[1] + a[3] * eta) < 1e-12,
+        ((y - py[0]) / (py[1] - py[0]) + (y - py[3]) / (py[2] - py[3])) * 0.5,
+        (x - a[0] - a[2] * eta) / (a[1] + a[3] * eta),
+    )
 
-    yi = np.where(yi < 0, 0, yi)
-    yi = np.where(yi > ydim - 2, ydim - 2, yi)
+    is_in_cell = np.where((xsi >= 0) & (xsi <= 1) & (eta >= 0) & (eta <= 1), 1, 0)
 
-    return yi, xi
+    return is_in_cell, np.column_stack((xsi, eta))
+
+
+def _search_indices_curvilinear_2d(
+    grid: XGrid, y: np.ndarray, x: np.ndarray, yi_guess: np.ndarray | None = None, xi_guess: np.ndarray | None = None
+):
+    xi = np.full(len(x), GRID_SEARCH_ERROR, dtype=np.int32)
+    yi = np.full(len(y), GRID_SEARCH_ERROR, dtype=np.int32)
+    if yi_guess is not None and xi_guess is not None:
+        yi_guess = np.asarray(yi_guess, dtype=np.int32)
+        xi_guess = np.asarray(xi_guess, dtype=np.int32)
+        # If an initial guess is provided, we first perform a point-in-cell check for all guessed indices
+        is_in_cell, coords = curvilinear_point_in_cell(grid, y, x, yi_guess, xi_guess)
+        # Guesses that pass the check are final; keep their indices
+        yi[is_in_cell == 1] = yi_guess[is_in_cell == 1]
+        xi[is_in_cell == 1] = xi_guess[is_in_cell == 1]
+        y_check = y[is_in_cell == 0]
+        x_check = x[is_in_cell == 0]
+        zero_indices = np.where(is_in_cell == 0)[0]
+    else:
+        # Otherwise, we need to check all points
+        y_check = y
+        x_check = x
+        coords = -1.0 * np.ones((len(y), 2), dtype=np.float32)
+        zero_indices = np.arange(len(y))
+
+    # If any points were not found in the first step, we query the spatial hash for those points
+    if len(zero_indices) > 0:
+        yi_q, xi_q, coords_q = grid.get_spatial_hash().query(y_check, x_check)
+        # Only the points that were not found in the first step are updated
+        coords[zero_indices, :] = coords_q
+        yi[zero_indices] = yi_q
+        xi[zero_indices] = xi_q
+
+    xsi = coords[:, 0]
+    eta = coords[:, 1]
+
+    return (yi, eta, xi, xsi)
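Reviewer note (not part of the patch): on a rectangular cell, the quadratic solve in curvilinear_point_in_cell reduces to plain fractional coordinates, which makes for a quick sanity check. The ToyGrid class below is a hypothetical stand-in exposing only the lon/lat arrays the function reads:

    import numpy as np
    from parcels._index_search import curvilinear_point_in_cell

    class ToyGrid:
        lon = np.array([[0.0, 1.0], [0.0, 1.0]])
        lat = np.array([[0.0, 0.0], [2.0, 2.0]])  # one cell spanning [0, 1] x [0, 2]

    is_in, coords = curvilinear_point_in_cell(
        ToyGrid(), y=np.array([0.5]), x=np.array([0.25]), yi=np.array([0]), xi=np.array([0])
    )
    # is_in -> [1]; coords -> [[0.25, 0.25]], i.e. the fractional position inside the cell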
diff --git a/parcels/field.py b/parcels/field.py
index ce437c471..4d612ea15 100644
--- a/parcels/field.py
+++ b/parcels/field.py
@@ -30,7 +30,7 @@
 from parcels.uxgrid import UxGrid
 from parcels.xgrid import LEFT_OUT_OF_BOUNDS, RIGHT_OUT_OF_BOUNDS, XGrid, _transpose_xfield_data_to_tzyx
 
-from ._index_search import _search_time_index
+from ._index_search import GRID_SEARCH_ERROR, _search_time_index
 
 __all__ = ["Field", "VectorField"]
 
@@ -341,21 +341,25 @@ def __getitem__(self, key):
 
 def _update_particle_states_position(particle, position):
     """Update the particle states based on the position dictionary."""
-    if particle and "X" in position:  # TODO also support uxgrid search
-        particle.state = np.maximum(
-            np.where(position["X"][0] < 0, StatusCode.ErrorOutOfBounds, particle.state), particle.state
-        )
-        particle.state = np.maximum(
-            np.where(position["Y"][0] < 0, StatusCode.ErrorOutOfBounds, particle.state), particle.state
-        )
-        particle.state = np.maximum(
-            np.where(position["Z"][0] == RIGHT_OUT_OF_BOUNDS, StatusCode.ErrorOutOfBounds, particle.state),
-            particle.state,
-        )
-        particle.state = np.maximum(
-            np.where(position["Z"][0] == LEFT_OUT_OF_BOUNDS, StatusCode.ErrorThroughSurface, particle.state),
-            particle.state,
-        )
+    if particle:  # TODO also support uxgrid search
+        for dim in ["X", "Y"]:
+            if dim in position:
+                particle.state = np.maximum(
+                    np.where(position[dim][0] == -1, StatusCode.ErrorOutOfBounds, particle.state), particle.state
+                )
+                particle.state = np.maximum(
+                    np.where(position[dim][0] == GRID_SEARCH_ERROR, StatusCode.ErrorGridSearching, particle.state),
+                    particle.state,
+                )
+        if "Z" in position:
+            particle.state = np.maximum(
+                np.where(position["Z"][0] == RIGHT_OUT_OF_BOUNDS, StatusCode.ErrorOutOfBounds, particle.state),
+                particle.state,
+            )
+            particle.state = np.maximum(
+                np.where(position["Z"][0] == LEFT_OUT_OF_BOUNDS, StatusCode.ErrorThroughSurface, particle.state),
+                particle.state,
+            )
 
 
 def _update_particle_states_interp_value(particle, value):
diff --git a/parcels/spatialhash.py b/parcels/spatialhash.py
index 76447bc23..8f5e3ffc6 100644
--- a/parcels/spatialhash.py
+++ b/parcels/spatialhash.py
@@ -1,5 +1,8 @@
 import numpy as np
 
+import parcels
+from parcels._index_search import GRID_SEARCH_ERROR, curvilinear_point_in_cell
+
 
 class SpatialHash:
     """Custom data structure that is used for performing grid searches using Spatial Hashing. This class constructs an overlying
@@ -11,8 +14,6 @@
     ----------
     grid : parcels.xgrid.XGrid
         Source grid used to construct the hash grid and hash table
-    reconstruct : bool, default=False
-        If true, reconstructs the spatial hash
+    bitwidth : int, default=1023
+        Largest quantized integer used per coordinate when hashing (1023 corresponds to 10 bits per axis)
 
     Note
     ----
@@ -22,13 +23,15 @@
     def __init__(
         self,
         grid,
-        reconstruct=False,
+        bitwidth=1023,
     ):
-        # TODO : Enforce grid to be an instance of parcels.xgrid.XGrid
-        # Currently, this is not done due to circular import with parcels.xgrid
+        if isinstance(grid, parcels.xgrid.XGrid):
+            self._point_in_cell = curvilinear_point_in_cell
+        else:
+            raise NotImplementedError("SpatialHash only supports parcels.xgrid.XGrid grids at this time.")
 
         self._source_grid = grid
-        self.reconstruct = reconstruct
+        self._bitwidth = bitwidth  # Max integer to use per coordinate in quantization (10 bits = 0..1023)
 
         if self._source_grid._mesh == "spherical":
             # Boundaries of the hash grid are the unit cube
@@ -69,9 +72,12 @@
                 axis=-1,
             )
-            # Compute centroid locations of each cells
-            self._xc = np.mean(_xbound, axis=-1)
-            self._yc = np.mean(_ybound, axis=-1)
-            self._zc = np.mean(_zbound, axis=-1)
+            # Compute the min/max bounds of each cell
+            self._xlow = np.min(_xbound, axis=-1)
+            self._xhigh = np.max(_xbound, axis=-1)
+            self._ylow = np.min(_ybound, axis=-1)
+            self._yhigh = np.max(_ybound, axis=-1)
+            self._zlow = np.min(_zbound, axis=-1)
+            self._zhigh = np.max(_zbound, axis=-1)
 
         else:
             # Boundaries of the hash grid are the bounding box of the source grid
@@ -104,47 +110,134 @@
                 axis=-1,
             )
-            # Compute centroid locations of each cells
-            self._xc = np.mean(_xbound, axis=-1)
-            self._yc = np.mean(_ybound, axis=-1)
-            self._zc = np.zeros_like(self._xc)
+            # Compute the min/max bounds of each cell
+            self._xlow = np.min(_xbound, axis=-1)
+            self._xhigh = np.max(_xbound, axis=-1)
+            self._ylow = np.min(_ybound, axis=-1)
+            self._yhigh = np.max(_ybound, axis=-1)
+            self._zlow = np.zeros_like(self._xlow)
+            self._zhigh = np.zeros_like(self._xlow)
 
         # Generate the mapping from the hash indices to unstructured grid elements
-        self._hash_table = None
         self._hash_table = self._initialize_hash_table()
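Reviewer note (not part of the patch): the min/max bounds introduced above replace the old cell centroids, so a face can now occupy every hash cell its bounding box touches instead of a single one. A minimal sketch of the per-cell bounds, assuming _xbound stacks the four corner values along the last axis (the stacking order is an assumption here):

    import numpy as np

    lon = np.array([[0.0, 1.0], [0.1, 1.2]])  # hypothetical 2x2 corner longitudes -> one cell
    _xbound = np.stack((lon[:-1, :-1], lon[:-1, 1:], lon[1:, 1:], lon[1:, :-1]), axis=-1)
    xlow = np.min(_xbound, axis=-1)   # [[0.0]], tightest axis-aligned bound per cell
    xhigh = np.max(_xbound, axis=-1)  # [[1.2]]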
 
     def _initialize_hash_table(self):
         """Create a mapping that relates unstructured grid faces to hash
         indices by determining which faces overlap with which hash cells
         """
-        if self._hash_table is None or self.reconstruct:
-            j, i = np.indices(self._xc.shape)  # Get the indices of the curvilinear grid
-
-            morton_codes = _encode_morton3d(
-                self._xc, self._yc, self._zc, self._xmin, self._xmax, self._ymin, self._ymax, self._zmin, self._zmax
-            )
-            ## Prepare quick lookup (hash) table for relating i,j indices to morton codes
-            # Sort i,j indices by morton code
-            order = np.argsort(morton_codes.ravel())
-            morton_codes_sorted = morton_codes.ravel()[order]
-            i_sorted = i.ravel()[order]
-            j_sorted = j.ravel()[order]
-
-            # Get a list of unique morton codes and their corresponding starts and counts (CSR format)
-            keys, starts, counts = np.unique(morton_codes_sorted, return_index=True, return_counts=True)
-            hash_table = {
-                "keys": keys,
-                "starts": starts,
-                "counts": counts,
-                "i": i_sorted,
-                "j": j_sorted,
-            }
-            return hash_table
-
-    def query(
-        self,
-        y,
-        x,
-    ):
+        # Quantize the bounding box in each direction
+        xqlow, yqlow, zqlow = quantize_coordinates(
+            self._xlow,
+            self._ylow,
+            self._zlow,
+            self._xmin,
+            self._xmax,
+            self._ymin,
+            self._ymax,
+            self._zmin,
+            self._zmax,
+            self._bitwidth,
+        )
+
+        xqhigh, yqhigh, zqhigh = quantize_coordinates(
+            self._xhigh,
+            self._yhigh,
+            self._zhigh,
+            self._xmin,
+            self._xmax,
+            self._ymin,
+            self._ymax,
+            self._zmin,
+            self._zmax,
+            self._bitwidth,
+        )
+        xqlow = xqlow.ravel()
+        yqlow = yqlow.ravel()
+        zqlow = zqlow.ravel()
+        xqhigh = xqhigh.ravel()
+        yqhigh = yqhigh.ravel()
+        zqhigh = zqhigh.ravel()
+        nx = xqhigh - xqlow + 1
+        ny = yqhigh - yqlow + 1
+        nz = zqhigh - zqlow + 1
+        num_hash_per_face = nx * ny * nz
+        total_hash_entries = np.sum(num_hash_per_face)
+        morton_codes = np.zeros(total_hash_entries, dtype=np.uint32)
+
+        # Compute the j, i indices corresponding to each hash entry
+        nface = np.size(self._xlow)
+        face_ids = np.repeat(np.arange(nface, dtype=np.int32), num_hash_per_face)
+        offsets = np.concatenate(([0], np.cumsum(num_hash_per_face))).astype(np.int32)[:-1]
+
+        valid = num_hash_per_face != 0
+        if np.any(valid):
+            # Grab only valid faces to avoid empty arrays
+            nx_v = np.asarray(nx[valid], dtype=np.int32)
+            ny_v = np.asarray(ny[valid], dtype=np.int32)
+            nz_v = np.asarray(nz[valid], dtype=np.int32)
+            xlow_v = np.asarray(xqlow[valid], dtype=np.int32)
+            ylow_v = np.asarray(yqlow[valid], dtype=np.int32)
+            zlow_v = np.asarray(zqlow[valid], dtype=np.int32)
+            starts_v = np.asarray(offsets[valid], dtype=np.int32)
+
+            # Count of elements per valid face (should match num_hash_per_face[valid])
+            counts = (nx_v * ny_v * nz_v).astype(np.int32)
+            total = int(counts.sum())
+
+            # Map each global element to its face and output position
+            start_for_elem = np.repeat(starts_v, counts)  # shape (total,)
+
+            # Intra-face linear index for each element (0..counts_i-1)
+            # Offsets per face within the concatenation of valid faces:
+            face_starts_local = np.cumsum(np.r_[0, counts[:-1]])
+            intra = np.arange(total, dtype=np.int32) - np.repeat(face_starts_local, counts)
+
+            # Derive (zi, yi, xi) from intra using per-face sizes
+            ny_nz = np.repeat(ny_v * nz_v, counts)
+            nz_rep = np.repeat(nz_v, counts)
+
+            xi = intra // ny_nz
+            rem = intra % ny_nz
+            yi = rem // nz_rep
+            zi = rem % nz_rep
+
+            # Add per-face lows
+            x0 = np.repeat(xlow_v, counts)
+            y0 = np.repeat(ylow_v, counts)
+            z0 = np.repeat(zlow_v, counts)
+
+            xq = x0 + xi
+            yq = y0 + yi
+            zq = z0 + zi
+
+            # Vectorized morton encode for all elements at once
+            codes_all = _encode_quantized_morton3d(xq, yq, zq)
+
+            # Scatter into the preallocated output using computed absolute indices
+            out_idx = start_for_elem + intra
+            morton_codes[out_idx] = codes_all
+
+        # Sort face indices by morton code
+        order = np.argsort(morton_codes)
+        morton_codes_sorted = morton_codes[order]
+        face_sorted = face_ids[order]
+        j_sorted, i_sorted = np.unravel_index(face_sorted, self._xlow.shape)
+
+        # Get a list of unique morton codes and their corresponding starts and counts (CSR format)
+        keys, starts, counts = np.unique(morton_codes_sorted, return_index=True, return_counts=True)
+
+        hash_table = {
+            "keys": keys,
+            "starts": starts,
+            "counts": counts,
+            "i": i_sorted,
+            "j": j_sorted,
+        }
+        return hash_table
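Reviewer note (not part of the patch): the keys/starts/counts triple is a CSR-style index over the sorted Morton codes, so all faces hashed to the same cell can be sliced out without loops. A small self-contained demonstration:

    import numpy as np

    codes = np.array([9, 3, 9, 1, 3, 9], dtype=np.uint32)  # toy Morton codes, one per face
    order = np.argsort(codes)
    keys, starts, counts = np.unique(codes[order], return_index=True, return_counts=True)
    # keys   -> [1, 3, 9]
    # starts -> [0, 1, 3], first slot of each key in the sorted array
    # counts -> [1, 2, 3], number of faces hashed to that key
    # The faces for keys[k] sit at sorted positions starts[k] : starts[k] + counts[k].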
+
+    def query(self, y, x):
         """
-        Queries the hash table and finds the closes face in the source grid
-        for each coordinate pair.
+        Queries the hash table and finds the face in the source grid that
+        contains each coordinate pair.
@@ -157,9 +250,13 @@
 
         Returns
         -------
-        faces : ndarray of shape (N,2), dtype=np.int32
-            For each coordinate pair, returns the (j,i) indices of the closest face in the hash grid.
-            If no face is found, returns (-1,-1) for that query.
+        j : ndarray, shape (N,)
+            j-indices of the located face in the source grid for each query. If no face was found, GRID_SEARCH_ERROR is returned.
+        i : ndarray, shape (N,)
+            i-indices of the located face in the source grid for each query. If no face was found, GRID_SEARCH_ERROR is returned.
+        coords : ndarray, shape (N, 2)
+            The local coordinates (xsi, eta) within the located face for each query.
+            If no face was found, (-1.0, -1.0) is returned.
         """
         keys = self._hash_table["keys"]
         starts = self._hash_table["starts"]
@@ -167,10 +264,6 @@
         i = self._hash_table["i"]
         j = self._hash_table["j"]
 
-        xc = self._xc
-        yc = self._yc
-        zc = self._zc
-
         y = np.asarray(y)
         x = np.asarray(x)
         if self._source_grid._mesh == "spherical":
@@ -190,127 +283,94 @@
         num_queries = query_codes.size
 
         # Locate each query in the unique key array
-        pos = np.searchsorted(keys, query_codes)  # pos is shape (N,)
+        pos = np.searchsorted(keys, query_codes)  # pos is shape (num_queries,)
+        pos_safe = np.minimum(pos, len(keys) - 1)  # clip so that keys[pos_safe] stays in bounds
 
-        # Valid hits: inside range with finite query coordinates
-        valid = (pos < len(keys)) & np.isfinite(x) & np.isfinite(y)
+        # Valid hits: in range, finite query coordinates, and an exact morton code match.
+        valid = (pos < len(keys)) & np.isfinite(x) & np.isfinite(y) & (query_codes == keys[pos_safe])
 
         # Pre-allocate i and j indices of the best match for each query
-        # Default values to -1 (no match case)
-        j_best = np.full(num_queries, -1, dtype=np.int64)
-        i_best = np.full(num_queries, -1, dtype=np.int64)
+        # Default values to GRID_SEARCH_ERROR (no match case)
+        j_best = np.full(num_queries, GRID_SEARCH_ERROR, dtype=np.int32)
+        i_best = np.full(num_queries, GRID_SEARCH_ERROR, dtype=np.int32)
 
         # How many matches each query has; hit_counts[i] is the number of hits for query i
-        hit_counts = np.where(valid, counts[pos], 0).astype(np.int64)  # has shape (N,)
+        hit_counts = np.where(valid, counts[pos_safe], 0).astype(np.int32)  # shape (num_queries,)
 
         if hit_counts.sum() == 0:
-            return (j_best.reshape(query_codes.shape), i_best.reshape(query_codes.shape))
-
-        # CSR-style offsets (prefix sum), total number of hits
-        offsets = np.empty(hit_counts.size + 1, dtype=np.int64)
-        offsets[0] = 0
-        np.cumsum(hit_counts, out=offsets[1:])
-        total = int(offsets[-1])
-
-        # Now, we need to create some quick lookup arrays that give us the list of positions in the hash table
-        # that correspond to each query.
-        # Create a quick lookup array that maps each element of all the valid queries (with repeats) to its query index
-        q_index_for_elem = np.repeat(np.arange(num_queries, dtype=np.int64), hit_counts)  # This has shape (total,)
-
-        # For each element, compute its "intra-group" offset (0..hits_i-1).
-        intra = np.arange(total, dtype=np.int64) - np.repeat(offsets[:-1], hit_counts)
-
-        # starts[pos[q_index_for_elem]] + intra gives a list of positions in the hash table that we can
-        # use to quickly gather the (i,j) pairs for each query
-        source_idx = starts[pos[q_index_for_elem]].astype(np.int64) + intra
-
-        # Gather all (j,i) pairs in one shot
-        j_all = j[source_idx]
-        i_all = i[source_idx]
+            return (
+                j_best.reshape(query_codes.shape),
+                i_best.reshape(query_codes.shape),
+                np.full((num_queries, 2), -1.0, dtype=np.float32),
+            )
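Reviewer note (not part of the patch): np.searchsorted returns len(keys) for queries that sort past the last key, which is why the lookup above clips before indexing and then demands an exact match. Toy values:

    import numpy as np

    keys = np.array([2, 5, 9], dtype=np.uint32)      # sorted unique Morton codes
    queries = np.array([5, 7, 11], dtype=np.uint32)  # 11 sorts past the end of keys
    pos = np.searchsorted(keys, queries)             # -> [1, 2, 3]; keys[3] would raise IndexError
    pos_safe = np.minimum(pos, len(keys) - 1)        # -> [1, 2, 2]
    valid = (pos < len(keys)) & (queries == keys[pos_safe])  # -> [True, False, False]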
 
-        # Segment-wise minima per query using reduceat
-        # For each query, we need to find the minimum distance.
-        if total == 1:
-            # Build absolute source index for the winning candidate in each query
-            start_for_q = np.where(valid, starts[pos], 0)  # 0 is dummy for invalid queries
-            src_best = (start_for_q).astype(np.int64)
-        else:
-            # Gather centroid coordinates at those (j,i)
-            xc_all = xc[j_all, i_all]
-            yc_all = yc[j_all, i_all]
-            zc_all = zc[j_all, i_all]
+        # Now, for each query, we need to gather the candidate (j,i) indices from the hash table
+        # Each (j,i) pair is repeated hit_counts[q] times; queries without hits contribute nothing.
 
-            # Broadcast to flat (same as q_flat), then repeat per candidate
-            # This makes it easy to compute distances from the query points
-            # to the candidate points for minimization.
-            qx_all = np.repeat(qx.ravel(), hit_counts)
-            qy_all = np.repeat(qy.ravel(), hit_counts)
-            qz_all = np.repeat(qz.ravel(), hit_counts)
+        # Boolean array for keeping track of which queries have candidates
+        has_hits = hit_counts > 0  # shape (num_queries,), True for queries that had candidates
 
-            # Squared distances for all candidates
-            dist_all = (xc_all - qx_all) ** 2 + (yc_all - qy_all) ** 2 + (zc_all - qz_all) ** 2
+        # A quick lookup array that maps each candidate back to its query index
+        q_index_for_candidate = np.repeat(
+            np.arange(num_queries, dtype=np.int32), hit_counts
+        )  # shape (hit_counts.sum(),)
+        # Map all candidates to positions in the hash table
+        hash_positions = pos[q_index_for_candidate]  # shape (hit_counts.sum(),)
 
-            dmin_per_q = np.minimum.reduceat(dist_all, offsets[:-1])
-            # To get argmin indices per query (without loops):
-            # Build a masked "within-index" array that is large unless it equals the segment-min.
-            big = np.iinfo(np.int64).max
-            within_masked = np.where(dist_all == np.repeat(dmin_per_q, hit_counts), intra, big)
-            argmin_within = np.minimum.reduceat(within_masked, offsets[:-1])  # first occurrence in ties
+        # Now that we have the hash-table positions for each candidate, we can gather the (j,i) pairs
+        # We do this in a vectorized way by using a CSR-like approach:
+        # starts[hash_positions] gives the starting point in the hash table for each candidate
+        # hit_counts gives the number of candidates for each query
 
-            # Build absolute source index for the winning candidate in each query
-            start_for_q = np.where(valid, starts[pos], 0)  # 0 is dummy for invalid queries
-            src_best = (start_for_q + argmin_within).astype(np.int64)
+        # We need to build an array that gives the offset within each query's candidates
+        offsets = np.concatenate(([0], np.cumsum(hit_counts))).astype(np.int32)  # shape (num_queries+1,)
+        total = int(offsets[-1])  # total number of candidates across all queries
 
-        # Write outputs only for queries that had candidates
-        has_hits = hit_counts > 0
-        j_best[has_hits] = j[src_best[has_hits]]
-        i_best[has_hits] = i[src_best[has_hits]]
+        # Now, for each candidate, we need a simple array that gives its "local candidate id" within its query
+        # This way, we can take starts[hash_positions] and add this local id to get the absolute index
+        # We calculate it by taking the "global candidate number" (0..total-1) and subtracting the offset of its query
+        # This gives us an array that runs from 0..hit_counts[i]-1 for each query i
+        intra = np.arange(total, dtype=np.int32) - np.repeat(offsets[:-1], hit_counts)  # shape (hit_counts.sum(),)
 
-        return (j_best.reshape(query_codes.shape), i_best.reshape(query_codes.shape))
+        # starts[hash_positions] + intra gives a list of positions in the hash table that we can
+        # use to quickly gather the (i,j) pairs for each query
+        source_idx = starts[hash_positions].astype(np.int32) + intra
 
+        # Gather all candidate (j,i) pairs in one shot
+        j_all = j[source_idx]
+        i_all = i[source_idx]
 
-def _triangle_area(A, B, C):
-    """Compute the area of a triangle given by three points."""
-    d1 = B - A
-    d2 = C - A
-    d3 = np.cross(d1, d2)
-    return 0.5 * np.linalg.norm(d3)
+        # Now we need to construct arrays that repeat the y and x coordinates for each candidate
+        # to enable vectorized point-in-cell checks
+        y_rep = np.repeat(y, hit_counts)  # shape (hit_counts.sum(),)
+        x_rep = np.repeat(x, hit_counts)  # shape (hit_counts.sum(),)
 
+        # For each candidate cell we perform a point-in-cell check.
+        is_in_face, coordinates = self._point_in_cell(self._source_grid, y_rep, x_rep, j_all, i_all)
 
-def _barycentric_coordinates(nodes, point, min_area=1e-8):
-    """
-    Compute the barycentric coordinates of a point P inside a convex polygon using area-based weights.
-    So that this method generalizes to n-sided polygons, we use the Waschpress points as the generalized
-    barycentric coordinates, which is only valid for convex polygons.
+        coords_best = np.full((num_queries, coordinates.shape[1]), -1.0, dtype=np.float32)
 
-    Parameters
-    ----------
-    nodes : numpy.ndarray
-        Spherical coordinates (lat,lon) of each corner node of a face
-    point : numpy.ndarray
-        Spherical coordinates (lat,lon) of the point
+        # For each query that has hits, we need to find the first candidate that was inside the face
+        f_indices = np.flatnonzero(is_in_face)  # Indices of all candidates whose face contained the point
+        # For each such position, find which query it belongs to by searching offsets
+        # Query index q satisfies offsets[q] <= f < offsets[q+1].
+        q = np.searchsorted(offsets[1:], f_indices, side="right")
 
-    Returns
-    -------
-    numpy.ndarray
-        Barycentric coordinates corresponding to each vertex.
+        uniq_q, q_idx = np.unique(q, return_index=True)
+        keep = has_hits[uniq_q]
 
-    """
-    n = len(nodes)
-    sum_wi = 0
-    w = []
+        if keep.any():
+            uniq_q = uniq_q[keep]
+            pos_first = f_indices[q_idx[keep]]
 
-    for i in range(0, n):
-        vim1 = nodes[i - 1]
-        vi = nodes[i]
-        vi1 = nodes[(i + 1) % n]
-        a0 = _triangle_area(vim1, vi, vi1)
-        a1 = max(_triangle_area(point, vim1, vi), min_area)
-        a2 = max(_triangle_area(point, vi, vi1), min_area)
-        sum_wi += a0 / (a1 * a2)
-        w.append(a0 / (a1 * a2))
-    barycentric_coords = [w_i / sum_wi for w_i in w]
+            # Scatter directly: take the first in-cell candidate for each query
+            j_best[uniq_q] = j_all[pos_first]
+            i_best[uniq_q] = i_all[pos_first]
+            coords_best[uniq_q] = coordinates[pos_first]
 
-    return barycentric_coords
+        return (
+            j_best.reshape(query_codes.shape),
+            i_best.reshape(query_codes.shape),
+            coords_best.reshape((num_queries, coordinates.shape[1])),
+        )
 
 
 def _latlon_rad_to_xyz(lat, lon):
@@ -370,15 +430,14 @@
     return n
 
 
-def _encode_morton3d(x, y, z, xmin, xmax, ymin, ymax, zmin, zmax):
+def quantize_coordinates(x, y, z, xmin, xmax, ymin, ymax, zmin, zmax, bitwidth=1023):
     """
-    Quantize (x, y, z) to 10 bits each (0..1023), dilate the bits so there are
-    two zeros between successive bits, and interleave them into a 3D Morton code.
+    Normalize (x, y, z) to [0, 1] over their bounding box, then quantize to integers in [0, bitwidth].
 
     Parameters
     ----------
     x, y, z : array_like
-        Input coordinates to encode. Can be scalars or arrays (broadcasting applies).
+        Input coordinates to quantize. Can be scalars or arrays (broadcasting applies).
     xmin, xmax : float
         Minimum and maximum bounds for x coordinate.
     ymin, ymax : float
@@ -388,13 +447,8 @@
 
     Returns
     -------
-    code : ndarray, dtype=uint32
-        The resulting Morton codes, same shape as the broadcasted input coordinates.
-
-    Notes
-    -----
-    - Works with scalars or NumPy arrays (broadcasting applies).
-    - Output is up to 30 bits returned as uint32.
+    xq, yq, zq : ndarray, dtype=uint32
+        The quantized coordinates, each in range [0, bitwidth], same shape as the broadcasted input coordinates.
     """
     # Convert inputs to ndarray for consistent dtype/ufunc behavior.
     x = np.asarray(x)
     y = np.asarray(y)
     z = np.asarray(z)
@@ -413,17 +467,74 @@
     yn = np.where(dy != 0, (y - ymin) / dy, 0.0)
     zn = np.where(dz != 0, (z - zmin) / dz, 0.0)
 
-    # --- 2) Quantize to 10 bits (0..1023). ---
-    # Multiply by 1023, round down, and clip to be safe against slight overshoot.
-    xq = np.clip((xn * 1023.0).astype(np.uint32), 0, 1023)
-    yq = np.clip((yn * 1023.0).astype(np.uint32), 0, 1023)
-    zq = np.clip((zn * 1023.0).astype(np.uint32), 0, 1023)
+    # --- 2) Quantize to (0..bitwidth). ---
+    # Multiply by bitwidth, round down, and clip to be safe against slight overshoot.
+    xq = np.clip((xn * bitwidth).astype(np.uint32), 0, bitwidth)
+    yq = np.clip((yn * bitwidth).astype(np.uint32), 0, bitwidth)
+    zq = np.clip((zn * bitwidth).astype(np.uint32), 0, bitwidth)
+
+    return xq, yq, zq
+
+
+def _encode_quantized_morton3d(xq, yq, zq):
+    """Interleave already-quantized (xq, yq, zq) integers into a 3D Morton code."""
+    xq = np.asarray(xq)
+    yq = np.asarray(yq)
+    zq = np.asarray(zq)
+
+    # --- 3) Bit-dilate each 10-bit number so each bit is separated by two zeros. ---
+    # _dilate_bits maps: b9..b0 -> b9 0 0 b8 0 0 ... b0 0 0
+    dx3 = _dilate_bits(xq).astype(np.uint32)
+    dy3 = _dilate_bits(yq).astype(np.uint32)
+    dz3 = _dilate_bits(zq).astype(np.uint32)
+
+    # --- 4) Interleave the dilated bits into a single Morton code. ---
+    # Bit layout (from LSB upward): x0,y0,z0, x1,y1,z1, ..., x9,y9,z9
+    # We shift z's bits by 2, y's by 1, x stays at 0, then OR them together.
+    # The dilated values use at most 28 bits, so uint32 is wide enough for the shifts.
+    code = (dz3 << 2) | (dy3 << 1) | dx3
+
+    # Since our compact type fits in 30 bits, uint32 is enough.
+    return code.astype(np.uint32)
+
+
+def _encode_morton3d(x, y, z, xmin, xmax, ymin, ymax, zmin, zmax, bitwidth=1023):
+    """
+    Quantize (x, y, z) to 10 bits each (0..1023), dilate the bits so there are
+    two zeros between successive bits, and interleave them into a 3D Morton code.
+
+    Parameters
+    ----------
+    x, y, z : array_like
+        Input coordinates to encode. Can be scalars or arrays (broadcasting applies).
+    xmin, xmax : float
+        Minimum and maximum bounds for x coordinate.
+    ymin, ymax : float
+        Minimum and maximum bounds for y coordinate.
+    zmin, zmax : float
+        Minimum and maximum bounds for z coordinate.
+    bitwidth : int, default=1023
+        Largest quantized integer per coordinate (1023 corresponds to 10 bits).
+
+    Returns
+    -------
+    code : ndarray, dtype=uint32
+        The resulting Morton codes, same shape as the broadcasted input coordinates.
+
+    Notes
+    -----
+    - Works with scalars or NumPy arrays (broadcasting applies).
+    - Output is up to 30 bits returned as uint32.
+    """
+    # Convert inputs to ndarray for consistent dtype/ufunc behavior.
+    x = np.asarray(x)
+    y = np.asarray(y)
+    z = np.asarray(z)
+
+    xq, yq, zq = quantize_coordinates(x, y, z, xmin, xmax, ymin, ymax, zmin, zmax, bitwidth)
 
     # --- 3) Bit-dilate each 10-bit number so each bit is separated by two zeros. ---
     # _dilate_bits maps: b9..b0 -> b9 0 0 b8 0 0 ... b0 0 0
-    dx3 = _dilate_bits(xq).astype(np.uint64)
-    dy3 = _dilate_bits(yq).astype(np.uint64)
-    dz3 = _dilate_bits(zq).astype(np.uint64)
+    dx3 = _dilate_bits(xq).astype(np.uint32)
+    dy3 = _dilate_bits(yq).astype(np.uint32)
+    dz3 = _dilate_bits(zq).astype(np.uint32)
 
     # --- 4) Interleave the dilated bits into a single Morton code. ---
     # Bit layout (from LSB upward): x0,y0,z0, x1,y1,z1, ..., x9,y9,z9
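Reviewer note (not part of the patch): a worked example of the dilate-and-interleave step. The magic-number dilation below is the standard 10-bit version; that it matches _dilate_bits bit for bit is an assumption, but the bit layout is the one documented above:

    def dilate10(n: int) -> int:
        n &= 0x3FF
        n = (n | (n << 16)) & 0x030000FF
        n = (n | (n << 8)) & 0x0300F00F
        n = (n | (n << 4)) & 0x030C30C3
        n = (n | (n << 2)) & 0x09249249
        return n

    xq, yq, zq = 3, 1, 0  # quantized inputs: x = 0b11, y = 0b01, z = 0b00
    code = (dilate10(zq) << 2) | (dilate10(yq) << 1) | dilate10(xq)
    assert code == 0b1011 == 11  # bit layout from the LSB up: x0, y0, z0, x1, ...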
diff --git a/parcels/uxgrid.py b/parcels/uxgrid.py
index bc8014f76..a82037ef3 100644
--- a/parcels/uxgrid.py
+++ b/parcels/uxgrid.py
@@ -6,7 +6,6 @@
 import uxarray as ux
 
 from parcels._typing import assert_valid_mesh
-from parcels.spatialhash import _barycentric_coordinates
 from parcels.tools.statuscodes import FieldOutOfBoundError
 from parcels.xgrid import _search_1d_array
 
@@ -164,3 +163,48 @@
     z = np.sin(lat)
 
     return x, y, z
+
+
+def _triangle_area(A, B, C):
+    """Compute the area of a triangle given by three points."""
+    d1 = B - A
+    d2 = C - A
+    d3 = np.cross(d1, d2)
+    return 0.5 * np.linalg.norm(d3)
+
+
+def _barycentric_coordinates(nodes, point, min_area=1e-8):
+    """
+    Compute the barycentric coordinates of a point P inside a convex polygon using area-based weights.
+    So that this method generalizes to n-sided polygons, we use the Wachspress points as the generalized
+    barycentric coordinates; these are only valid for convex polygons.
+
+    Parameters
+    ----------
+    nodes : numpy.ndarray
+        Spherical coordinates (lat,lon) of each corner node of a face
+    point : numpy.ndarray
+        Spherical coordinates (lat,lon) of the point
+
+    Returns
+    -------
+    numpy.ndarray
+        Barycentric coordinates corresponding to each vertex.
+
+    """
+    n = len(nodes)
+    sum_wi = 0
+    w = []
+
+    for i in range(0, n):
+        vim1 = nodes[i - 1]
+        vi = nodes[i]
+        vi1 = nodes[(i + 1) % n]
+        a0 = _triangle_area(vim1, vi, vi1)
+        a1 = max(_triangle_area(point, vim1, vi), min_area)
+        a2 = max(_triangle_area(point, vi, vi1), min_area)
+        sum_wi += a0 / (a1 * a2)
+        w.append(a0 / (a1 * a2))
+    barycentric_coords = [w_i / sum_wi for w_i in w]
+
+    return barycentric_coords
diff --git a/parcels/xgrid.py b/parcels/xgrid.py
index ac6dd89e2..0ab5feec1 100644
--- a/parcels/xgrid.py
+++ b/parcels/xgrid.py
@@ -375,7 +375,7 @@ def get_spatial_hash(
         """
         if self._spatialhash is None or reconstruct:
-            self._spatialhash = SpatialHash(self, reconstruct)
+            self._spatialhash = SpatialHash(self)
         return self._spatialhash
diff --git a/tests/v4/test_index_search.py b/tests/v4/test_index_search.py
index 5606913fd..994fe22dc 100644
--- a/tests/v4/test_index_search.py
+++ b/tests/v4/test_index_search.py
@@ -63,7 +63,7 @@ def test_indexing_nemo_curvilinear():
         {"glamf": "lon", "gphif": "lat", "z": "depth"}
     )
     xgcm_grid = Grid(ds, coords={"X": {"left": "x"}, "Y": {"left": "y"}}, periodic=False)
-    grid = XGrid(xgcm_grid)
+    grid = XGrid(xgcm_grid, mesh="spherical")
 
     # Test points on the NEMO 1/4 degree curvilinear grid
     lats = np.array([-30, 0, 88])
diff --git a/tests/v4/test_spatialhash.py b/tests/v4/test_spatialhash.py
index c9026e7af..bdec91173 100644
--- a/tests/v4/test_spatialhash.py
+++ b/tests/v4/test_spatialhash.py
@@ -15,9 +15,9 @@
     ds = datasets["2d_left_rotated"]
     grid = XGrid.from_dataset(ds)
 
-    j, i = grid.get_spatial_hash().query([np.nan, np.inf], [np.nan, np.inf])
-    assert np.all(j == -1)
-    assert np.all(i == -1)
+    j, i, coords = grid.get_spatial_hash().query([np.nan, np.inf], [np.nan, np.inf])
+    assert np.all(j == -3)  # -3 is GRID_SEARCH_ERROR
+    assert np.all(i == -3)
 
 
 def test_mixed_positions():
@@ -27,8 +27,8 @@
     lon = grid.lon.mean()
     y = [lat, np.nan]
     x = [lon, np.nan]
-    j, i = grid.get_spatial_hash().query(y, x)
+    j, i, coords = grid.get_spatial_hash().query(y, x)
 
     assert j[0] == 29  # Actual value for 2d_left_rotated center
     assert i[0] == 14  # Actual value for 2d_left_rotated center
-    assert j[1] == -1
-    assert i[1] == -1
+    assert j[1] == -3  # GRID_SEARCH_ERROR
+    assert i[1] == -3
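Reviewer note (not part of the patch): a worked check of the Wachspress weights now living in uxgrid.py. For the center of a unit square (a planar stand-in for a convex face, hypothetical values), symmetry forces every vertex weight to 0.25:

    import numpy as np

    def tri_area(A, B, C):
        return 0.5 * np.linalg.norm(np.cross(B - A, C - A))

    nodes = [np.array(p) for p in ([0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0])]
    point = np.array([0.5, 0.5, 0.0])
    w = []
    for k in range(len(nodes)):
        vim1, vi, vi1 = nodes[k - 1], nodes[k], nodes[(k + 1) % len(nodes)]
        w.append(tri_area(vim1, vi, vi1) / (tri_area(point, vim1, vi) * tri_area(point, vi, vi1)))
    assert np.allclose(np.array(w) / sum(w), 0.25)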