From ec13c37d278b50b83534fd5cb2494f1af01998d4 Mon Sep 17 00:00:00 2001 From: Michal Januszewski Date: Mon, 26 Jan 2026 12:59:00 -0800 Subject: [PATCH] Preload image and mask in EstimateMissingFlow. This avoids repeated volume reads, which can be slow even when the underlying data is cached in memory (due to cache trashing or need to reassemble the image array out of the underlying chunks). PiperOrigin-RevId: 861319469 --- processor/flow.py | 113 +++++++++++++++++--------- processor/flow_test.py | 174 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 249 insertions(+), 38 deletions(-) create mode 100644 processor/flow_test.py diff --git a/processor/flow.py b/processor/flow.py index b064d4c..2f5d369 100644 --- a/processor/flow.py +++ b/processor/flow.py @@ -15,6 +15,7 @@ """Flow field estimation from SOFIMA.""" import dataclasses +import gc import time from typing import Any, Sequence @@ -594,6 +595,7 @@ def __init__( ) self._config = config + logging.info('EstimateMissingFlow running with config: %r', config) def _build_mask( self, @@ -661,6 +663,20 @@ def process(self, subvol: Subvolume) -> SubvolumeOrMany: out_box = out_box.adjusted_by(end=-offset) input_ndarray = input_ndarray[:, :, : out_box.size[1], : out_box.size[0]] + # The input flow forms the initial state of the output. We will try + # to fill-in any invalid (NaN) pixels by computing flow against + # earlier sections. + ret = np.zeros([3] + list(out_box.size[::-1])) + ret[:2, ...] = input_ndarray + ret[2, ...] = self._config.delta_z + + sel_mask = None + if self._config.selection_mask_configs: + sel_mask = self._build_mask(self._config.selection_mask_configs, out_box) + + mfc = flow_field.JAXMaskedXCorrWithStatsCalculator() + invalid = np.isnan(input_ndarray[0, ...]) + patch_size = self._config.patch_size curr_image_box = bounding_box.BoundingBox( start=( @@ -671,25 +687,55 @@ def process(self, subvol: Subvolume) -> SubvolumeOrMany: size=( (out_box.size[0] - 1) * stride + patch_size, (out_box.size[1] - 1) * stride + patch_size, - 1, + invalid.shape[0], ), ) curr_image_box = image_volume.clip_box_to_volume(curr_image_box) assert curr_image_box is not None - # The input flow forms the initial state of the output. We will try - # to fill-in any invalid (NaN) pixels by computing flow against - # earlier sections. - ret = np.zeros([3] + list(out_box.size[::-1])) - ret[:2, ...] = input_ndarray - ret[2, ...] = self._config.delta_z + if self._config.delta_z > 0: + search_deltas = range( + self._config.delta_z + 1, self._config.max_delta_z + 1 + ) + load_start_z = out_box.start[2] - self._config.max_delta_z + load_end_z = out_box.end[2] + else: + search_deltas = range( + self._config.delta_z - 1, self._config.max_delta_z - 1, -1 + ) + load_start_z = out_box.start[2] + # max_delta_z is negative. + load_end_z = out_box.end[2] - self._config.max_delta_z - sel_mask = None - if self._config.selection_mask_configs: - sel_mask = self._build_mask(self._config.selection_mask_configs, out_box) + load_box = bounding_box.BoundingBox( + start=( + prev_image_box.start[0], + prev_image_box.start[1], + load_start_z, + ), + size=( + prev_image_box.size[0], + prev_image_box.size[1], + load_end_z - load_start_z, + ), + ) + load_box = image_volume.clip_box_to_volume(load_box) + + logging.info('Loading image data: %r', load_box) + full_image_stack = image_volume.asarray[load_box.to_slice4d()][0, ...] + full_mask = None + if self._config.mask_configs: + full_mask = self._build_mask(self._config.mask_configs, load_box) + logging.info('Loaaded mask: %r', full_mask.shape) + + # The 'curr' image is a subset of the loaded stack, centered within the + # 'prev' image (which includes the search radius). + curr_rel_start = curr_image_box.start - load_box.start + curr_slice = ( + slice(curr_rel_start[1], curr_rel_start[1] + curr_image_box.size[1]), + slice(curr_rel_start[0], curr_rel_start[0] + curr_image_box.size[0]), + ) - mfc = flow_field.JAXMaskedXCorrWithStatsCalculator() - invalid = np.isnan(input_ndarray[0, ...]) for z in range(0, invalid.shape[0]): z0 = box.start[2] + z logging.info('Processing rel_z=%d abs_z=%d', z, z0) @@ -698,12 +744,13 @@ def process(self, subvol: Subvolume) -> SubvolumeOrMany: beam_utils.counter(namespace, 'sections-already-valid').inc() continue - image_box = curr_image_box.translate([0, 0, z]) + curr_z_idx = (out_box.start[2] + z) - load_box.start[2] + assert curr_z_idx >= 0 + assert curr_z_idx < full_image_stack.shape[0] + curr_mask = None if self._config.mask_configs: - curr_mask = self._build_mask( - self._config.mask_configs, image_box - ).squeeze() + curr_mask = full_mask[curr_z_idx, ...][curr_slice] if np.all(curr_mask): beam_utils.counter(namespace, 'sections-masked').inc() continue @@ -715,37 +762,23 @@ def process(self, subvol: Subvolume) -> SubvolumeOrMany: if sel_mask is not None: mask &= sel_mask[z, ...] - curr = image_volume.asarray[image_box.to_slice4d()].squeeze() - - delta_z = self._config.delta_z - if delta_z > 0: - rng = range(delta_z + 1, self._config.max_delta_z + 1) - else: - rng = range(delta_z - 1, self._config.max_delta_z - 1, -1) + curr = full_image_stack[curr_z_idx, ...][curr_slice] - for delta_z in rng: - if ( - box.start[2] - delta_z < 0 - or box.end[2] - delta_z >= image_volume.volume_size[2] - ): + for delta_z in search_deltas: + prev_z_idx = curr_z_idx - delta_z + if prev_z_idx < 0 or prev_z_idx >= full_image_stack.shape[0]: break t_start = time.time() - prev_box = prev_image_box.translate([0, 0, z - delta_z]) - logging.info('Trying delta_z=%d (%r)', delta_z, prev_box) - prev = image_volume.asarray[prev_box.to_slice4d()].squeeze() - logging.info('.. image loaded.') + logging.info('Trying delta_z=%d', delta_z) + prev_mask = None + prev = full_image_stack[prev_z_idx, ...] t1 = time.time() if self._config.mask_configs: - prev_mask = self._build_mask( - self._config.mask_configs, prev_box - ).squeeze() + prev_mask = full_mask[prev_z_idx, ...] if np.all(prev_mask): continue - else: - prev_mask = None - logging.info('.. mask loaded.') # Limit the number of estimation attempts per voxel. Attempts # are only counted when voxels in both sections are unmasked. @@ -804,4 +837,8 @@ def process(self, subvol: Subvolume) -> SubvolumeOrMany: t5 - t4, ) + del full_image_stack + del full_mask + gc.collect() + return Subvolume(ret, out_box) diff --git a/processor/flow_test.py b/processor/flow_test.py new file mode 100644 index 0000000..7a6c499 --- /dev/null +++ b/processor/flow_test.py @@ -0,0 +1,174 @@ +# coding=utf-8 +# Copyright 2026 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest +from connectomics.common import bounding_box +from connectomics.volume import subvolume +import numpy as np +from sofima.processor import flow + + +class MockVolume: + + def __init__(self, data): + self._data = data # CZYX + + def clip_box_to_volume(self, box): + vol_box = bounding_box.BoundingBox(start=(0, 0, 0), size=self.volume_size) + return box.intersection(vol_box) + + @property + def asarray(self): + return self._data + + @property + def volume_size(self): + # XYZ + return (self._data.shape[3], self._data.shape[2], self._data.shape[1]) + + def __getitem__(self, key): + return self._data[key] + + +class TestEstimateMissingFlow(flow.EstimateMissingFlow): + + def __init__(self, config, image_vol): + super().__init__(config) + self.image_vol = image_vol + + def _open_volume(self, path): + return self.image_vol + + +class EstimateMissingFlowTest(absltest.TestCase): + + def test_process(self): + config = flow.EstimateMissingFlow.Config( + patch_size=16, + stride=16, + delta_z=1, + max_delta_z=2, + max_attempts=1, + mask_configs=None, + mask_only_for_patch_selection=False, + selection_mask_configs=None, + min_peak_sharpness=0.0, + min_peak_ratio=0.0, + max_magnitude=0, + batch_size=10, # Must be > 0 for batch processing + image_volinfo="dummy_path", + image_cache_bytes=0, + mask_cache_bytes=0, + search_radius=16, + ) + + # Larger volume to avoid boundary clipping with required context size + vol_shape = (1, 10, 128, 128) + vol_data = np.random.rand(*vol_shape).astype(np.float32) + + # Create a synthetic shift between z=3 and z=5. + dx, dy = 2, 3 + prev_slice = vol_data[0, 3, :, :] + shifted_slice = np.zeros_like(prev_slice) + shifted_slice[dy:, dx:] = prev_slice[:-dy, :-dx] + shifted_slice[:dy, :] = np.random.rand(dy, 128) + shifted_slice[:, :dx] = np.random.rand(128, dx) + + vol_data[0, 5, :, :] = shifted_slice + + mock_vol = MockVolume(vol_data) + processor = TestEstimateMissingFlow(config, mock_vol) + + # Start at 2,2,5 (flow coords) corresponds to 32,32,5 (image coords). + box = bounding_box.BoundingBox((2, 2, 5), (2, 2, 1)) + + # No pre-existing flow data. + input_data = np.full((2, 1, 2, 2), np.nan, dtype=np.float32) + subvol = subvolume.Subvolume(input_data, box) + + result_subvol = processor.process(subvol) + + self.assertEqual(result_subvol.data.shape, (3, 1, 2, 2)) + self.assertFalse( + np.any(np.isnan(result_subvol.data)), "Result contains NaNs" + ) + + np.testing.assert_allclose( + result_subvol.data[2, ...], 2, err_msg="delta_z incorrect" + ) + np.testing.assert_allclose( + result_subvol.data[0, 0, 0, 0], + -dx, + atol=0.5, + err_msg="Flow X incorrect", + ) + np.testing.assert_allclose( + result_subvol.data[1, 0, 0, 0], + -dy, + atol=0.5, + err_msg="Flow Y incorrect", + ) + + def test_process_clipped_context(self): + config = flow.EstimateMissingFlow.Config( + patch_size=16, + stride=16, + delta_z=1, + max_delta_z=5, # Large lookback + max_attempts=1, + mask_configs=None, + mask_only_for_patch_selection=False, + selection_mask_configs=None, + min_peak_sharpness=0.0, + min_peak_ratio=0.0, + max_magnitude=0, + batch_size=10, + image_volinfo="dummy_path", + image_cache_bytes=0, + mask_cache_bytes=0, + search_radius=16, + ) + + vol_shape = (1, 10, 128, 128) + vol_data = np.random.rand(*vol_shape).astype(np.float32) + + mock_vol = MockVolume(vol_data) + processor = TestEstimateMissingFlow(config, mock_vol) + + box = bounding_box.BoundingBox(start=(2, 2, 1), size=(2, 2, 1)) + + # No pre-existing flow data. + input_data = np.full((2, 1, 2, 2), np.nan, dtype=np.float32) + subvol = subvolume.Subvolume(input_data, box) + + result_subvol = processor.process(subvol) + + self.assertEqual(result_subvol.data.shape, (3, 1, 2, 2)) + + # Result should be NaNs because z=1 only has z=0 as valid prev. + # delta_z=1 (matching z=0) was not calculated (assumed missing). + # delta_z=2,3,4,5 look at z < 0, which is out of bounds. + self.assertTrue( + np.all(np.isnan(result_subvol.data[0, ...])), "Result X should be NaN" + ) + self.assertTrue( + np.all(np.isnan(result_subvol.data[1, ...])), "Result Y should be NaN" + ) + # Channel 2 is initialized to delta_z (1). + self.assertEqual(result_subvol.data[2, 0, 0, 0], 1) + + +if __name__ == "__main__": + absltest.main()