Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 52 additions & 14 deletions map_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
from scipy import interpolate
from scipy import ndimage
from scipy import spatial
import concurrent.futures


StrideZYX = float | Sequence[float]
Expand Down Expand Up @@ -389,11 +390,37 @@ def inner_box(
)


def _invert_2d_slice(z, map_slice, src_coords, query_points, u_shape, v_shape):
"""
Helper function to invert a single Z-slice of the coordinate map.
"""
# Check if this slice has any valid data
valid = np.all(np.isfinite(map_slice), axis=0)
if not np.any(valid):
return z, None, None

# Extract source (u,v) points from the map
src_points = tuple([c[valid] for c in map_slice])

# Extract target (x,y) values corresponding to those points
# src_coords is [y, x], so we reverse to get [x, y] values
values = [s[valid] for s in src_coords[::-1]]

try:
# Perform interpolation
u, v = _interpolate_points(src_points, query_points, *values)
return z, u.reshape(u_shape), v.reshape(v_shape)
except spatial.qhull.QhullError:
return z, None, None


def invert_map(
coord_map: np.ndarray,
src_box: bounding_box.BoundingBox,
dst_box: bounding_box.BoundingBox,
stride: StrideZYX,
parallelism: int = 1,
verbose: bool = False,
) -> np.ndarray:
"""Inverts a coordinate map.

Expand All @@ -406,6 +433,7 @@ def invert_map(
dst_box: uv coordinate box for which to compute output
stride: distance between nearest neighbors of the coordinate map ([z]yx
sequence or a single float)
parallelism: number of threads to use to invert maps (for 2D only)

Returns:
inverted coordinate map in relative format
Expand Down Expand Up @@ -445,20 +473,30 @@ def _sel_size(box):
dtype=coord_map.dtype,
)

for z in range(coord_map.shape[1]):
valid = np.all(np.isfinite(coord_map[:, z, ...]), axis=0)
if not np.any(valid):
continue

src_points = tuple([c[z][valid] for c in coord_map])
try:
u, v = _interpolate_points(
src_points, query_points, *[s[valid] for s in src_coords[::-1]]
)
ret_uv[0, z, ...] = u.reshape(query_coords[1].shape)
ret_uv[1, z, ...] = v.reshape(query_coords[0].shape)
except spatial.qhull.QhullError:
pass
with concurrent.futures.ProcessPoolExecutor(max_workers=parallelism) as executor:
futures = []
for z in range(coord_map.shape[1]):
# Pass the specific slice to the worker
map_slice = coord_map[:, z, ...]
futures.append(
executor.submit(
_invert_2d_slice,
z,
map_slice,
src_coords,
query_points,
query_coords[1].shape,
query_coords[0].shape
)
)

# Collect results as they finish
for future in concurrent.futures.as_completed(futures):
if verbose: print('z =', z)
z, u_res, v_res = future.result()
if u_res is not None:
ret_uv[0, z, ...] = u_res
ret_uv[1, z, ...] = v_res

return to_relative(ret_uv, stride, dst_box)

Expand Down