diff --git a/map_utils.py b/map_utils.py index bfb081d..a3b6535 100644 --- a/map_utils.py +++ b/map_utils.py @@ -61,6 +61,7 @@ from scipy import interpolate from scipy import ndimage from scipy import spatial +import concurrent.futures StrideZYX = float | Sequence[float] @@ -389,11 +390,37 @@ def inner_box( ) +def _invert_2d_slice(z, map_slice, src_coords, query_points, u_shape, v_shape): + """ + Helper function to invert a single Z-slice of the coordinate map. + """ + # Check if this slice has any valid data + valid = np.all(np.isfinite(map_slice), axis=0) + if not np.any(valid): + return z, None, None + + # Extract source (u,v) points from the map + src_points = tuple([c[valid] for c in map_slice]) + + # Extract target (x,y) values corresponding to those points + # src_coords is [y, x], so we reverse to get [x, y] values + values = [s[valid] for s in src_coords[::-1]] + + try: + # Perform interpolation + u, v = _interpolate_points(src_points, query_points, *values) + return z, u.reshape(u_shape), v.reshape(v_shape) + except spatial.qhull.QhullError: + return z, None, None + + def invert_map( coord_map: np.ndarray, src_box: bounding_box.BoundingBox, dst_box: bounding_box.BoundingBox, stride: StrideZYX, + parallelism: int = 1, + verbose: bool = False, ) -> np.ndarray: """Inverts a coordinate map. @@ -406,6 +433,7 @@ def invert_map( dst_box: uv coordinate box for which to compute output stride: distance between nearest neighbors of the coordinate map ([z]yx sequence or a single float) + parallelism: number of threads to use to invert maps (for 2D only) Returns: inverted coordinate map in relative format @@ -445,20 +473,30 @@ def _sel_size(box): dtype=coord_map.dtype, ) - for z in range(coord_map.shape[1]): - valid = np.all(np.isfinite(coord_map[:, z, ...]), axis=0) - if not np.any(valid): - continue - - src_points = tuple([c[z][valid] for c in coord_map]) - try: - u, v = _interpolate_points( - src_points, query_points, *[s[valid] for s in src_coords[::-1]] - ) - ret_uv[0, z, ...] = u.reshape(query_coords[1].shape) - ret_uv[1, z, ...] = v.reshape(query_coords[0].shape) - except spatial.qhull.QhullError: - pass + with concurrent.futures.ProcessPoolExecutor(max_workers=parallelism) as executor: + futures = [] + for z in range(coord_map.shape[1]): + # Pass the specific slice to the worker + map_slice = coord_map[:, z, ...] + futures.append( + executor.submit( + _invert_2d_slice, + z, + map_slice, + src_coords, + query_points, + query_coords[1].shape, + query_coords[0].shape + ) + ) + + # Collect results as they finish + for future in concurrent.futures.as_completed(futures): + if verbose: print('z =', z) + z, u_res, v_res = future.result() + if u_res is not None: + ret_uv[0, z, ...] = u_res + ret_uv[1, z, ...] = v_res return to_relative(ret_uv, stride, dst_box)