diff --git a/CHANGELOG.md b/CHANGELOG.md index 6906bb3..f56e63c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,9 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm ## Version 0.4.6 - Unreleased +### Performance +* Improve optimize() performance via per-call memoization, reduced allocations, and fixed-point rewrite loops; no behavior change intended. + ### Fix * Handle case when input sensorchan strings are string subclasses. * Fix issue where lazy warps did not respect explicitly given dsize arguments diff --git a/delayed_image/delayed_base.py b/delayed_image/delayed_base.py index 5abae10..2961566 100644 --- a/delayed_image/delayed_base.py +++ b/delayed_image/delayed_base.py @@ -1,6 +1,7 @@ """ Abstract nodes """ +from __future__ import annotations import numpy as np import ubelt as ub @@ -13,6 +14,18 @@ USE_SLOTS = True +# Per-call optimization context +class OptimizeContext: + """ + Holds per-call optimization state to avoid repeated work. + """ + if USE_SLOTS: + __slots__ = ('memo',) + + def __init__(self): + self.memo = {} + + # from kwcoco.util.util_monkey import Reloadable # NOQA # @Reloadable.developing # NOQA class DelayedOperation: @@ -385,7 +398,7 @@ def finalize(self, prepare=True, optimize=True, **kwargs): # final = np.asanyarray(final) # does not work with xarray return final - def optimize(self): + def optimize(self, ctx=None): """ Returns: DelayedOperation diff --git a/delayed_image/delayed_base.pyi b/delayed_image/delayed_base.pyi index ae741da..c723a0a 100644 --- a/delayed_image/delayed_base.pyi +++ b/delayed_image/delayed_base.pyi @@ -9,6 +9,13 @@ from _typeshed import Incomplete from collections.abc import Generator +class OptimizeContext: + memo: Dict[int, 'DelayedOperation'] + + def __init__(self) -> None: + ... + + class DelayedOperation(ub.NiceRepr): meta: Incomplete @@ -57,7 +64,7 @@ class DelayedOperation(ub.NiceRepr): **kwargs) -> ArrayLike: ... - def optimize(self) -> DelayedOperation: + def optimize(self, ctx: OptimizeContext | None = None) -> DelayedOperation: ... diff --git a/delayed_image/delayed_leafs.py b/delayed_image/delayed_leafs.py index 01b4788..05ded9a 100644 --- a/delayed_image/delayed_leafs.py +++ b/delayed_image/delayed_leafs.py @@ -1,6 +1,7 @@ """ Terminal nodes """ +from __future__ import annotations import kwarray import kwimage @@ -30,9 +31,15 @@ def get_transform_from_leaf(self): """ return kwimage.Affine.eye() - def optimize(self): + def optimize(self, ctx=None): + if ctx is None: + ctx = delayed_base.OptimizeContext() + memo = ctx.memo + if self in memo: + return memo[self] if TRACE_OPTIMIZE: self._opt_logs.append('optimize DelayedImageLeaf') + memo[self] = self return self diff --git a/delayed_image/delayed_leafs.pyi b/delayed_image/delayed_leafs.pyi index 719975c..e7a7269 100644 --- a/delayed_image/delayed_leafs.pyi +++ b/delayed_image/delayed_leafs.pyi @@ -3,6 +3,7 @@ from os import PathLike from typing import Tuple from _typeshed import Incomplete from delayed_image.delayed_nodes import DelayedImage +from delayed_image.delayed_base import OptimizeContext from delayed_image.channel_spec import FusedChannelSpec @@ -14,7 +15,7 @@ class DelayedImageLeaf(DelayedImage): def get_transform_from_leaf(self) -> kwimage.Affine: ... - def optimize(self): + def optimize(self, ctx: OptimizeContext | None = None): ... diff --git a/delayed_image/delayed_nodes.py b/delayed_image/delayed_nodes.py index b3a986c..c6017ec 100644 --- a/delayed_image/delayed_nodes.py +++ b/delayed_image/delayed_nodes.py @@ -1,6 +1,7 @@ """ Intermediate operations """ +from __future__ import annotations import kwarray import kwimage import copy @@ -658,16 +659,28 @@ def _finalize(self): final = np.concatenate(stack, axis=2) return final - def optimize(self): + def optimize(self, ctx=None): """ Returns: DelayedImage """ - new_parts = [part.optimize() for part in self.parts] - kw = ub.dict_isect(self.meta, ['dsize']) - new = self.__class__(new_parts, **kw) + if ctx is None: + ctx = delayed_base.OptimizeContext() + memo = ctx.memo + if self in memo: + return memo[self] + new_parts = [part.optimize(ctx) for part in self.parts] + if all(p is o for p, o in zip(new_parts, self.parts)): + new = self + else: + kw = ub.dict_isect(self.meta, ['dsize']) + try: + new = self.__class__(new_parts, **kw) + except CoordinateCompatibilityError: + new = self if TRACE_OPTIMIZE: new._opt_logs.append('optimize DelayedChannelConcat') + memo[self] = new return new def take_channels(self, channels, missing_channel_policy='return_nan'): @@ -1304,9 +1317,13 @@ def _opt_push_under_concat(self): """ Push this node under its child node if it is a concatenation operation """ - assert isinstance2(self.subdata, DelayedChannelConcat) + if not isinstance2(self.subdata, DelayedChannelConcat): + return self kwargs = ub.compatible(self.meta, self.__class__.__init__) - new = self.subdata._push_operation_under(self.__class__, kwargs) + try: + new = self.subdata._push_operation_under(self.__class__, kwargs) + except CoordinateCompatibilityError: + return self if TRACE_OPTIMIZE: new._opt_logs.append('_opt_push_under_concat') return new @@ -1452,14 +1469,24 @@ def _finalize(self): final = xr.DataArray(subfinal, dims=('y', 'x', 'c'), coords=coords) return final - def optimize(self): + def optimize(self, ctx=None): """ Returns: DelayedImage """ - new = self.subdata.optimize().as_xarray() + if ctx is None: + ctx = delayed_base.OptimizeContext() + memo = ctx.memo + if self in memo: + return memo[self] + new_subdata = self.subdata.optimize(ctx) + if new_subdata is self.subdata: + new = self + else: + new = new_subdata.as_xarray() if TRACE_OPTIMIZE: new._opt_logs.append('optimize DelayedAsXarray') + memo[self] = new return new @@ -1603,7 +1630,7 @@ def _finalize(self): final = kwarray.atleast_nd(final, 3, front=False) return final - def optimize(self): + def optimize(self, ctx=None): """ Returns: DelayedImage @@ -1646,8 +1673,14 @@ def optimize(self): >>> assert len(self.as_graph().nodes) == 2 >>> assert len(new.as_graph().nodes) == 1 """ + if ctx is None: + ctx = delayed_base.OptimizeContext() + memo = ctx.memo + if self in memo: + return memo[self] + new = copy.copy(self) - new.subdata = self.subdata.optimize() + new.subdata = self.subdata.optimize(ctx) if isinstance2(new.subdata, DelayedWarp): new = new._opt_fuse_warps() @@ -1663,22 +1696,27 @@ def optimize(self): if TRACE_OPTIMIZE: new._opt_logs.append('Contract identity warp') elif isinstance2(new.subdata, DelayedChannelConcat): - new = new._opt_push_under_concat().optimize() + pushed = new._opt_push_under_concat() + if pushed is not new: + new = pushed.optimize(ctx) + else: + new = pushed elif hasattr(new.subdata, '_optimized_warp'): # The subdata knows how to optimize itself wrt a warp warp_kwargs = ub.dict_isect( self.meta, self._data_keys + self._algo_keys) - new = new.subdata._optimized_warp(**warp_kwargs).optimize() + new = new.subdata._optimized_warp(**warp_kwargs).optimize(ctx) else: split = new._opt_split_warp_overview() if new is not split: new = split - new.subdata = new.subdata.optimize() - new = new.optimize() + new.subdata = new.subdata.optimize(ctx) + new = new.optimize(ctx) else: new = new._opt_absorb_overview() if TRACE_OPTIMIZE: new._opt_logs.append('optimize DelayedWarp') + memo[self] = new return new def _transform_from_subdata(self): @@ -2091,7 +2129,7 @@ def _finalize(self): final = dequantize(final, quantization) return final - def optimize(self): + def optimize(self, ctx=None): """ Returns: @@ -2108,8 +2146,14 @@ def optimize(self): >>> self.write_network_text() >>> opt = self.optimize() """ + if ctx is None: + ctx = delayed_base.OptimizeContext() + memo = ctx.memo + if self in memo: + return memo[self] + new = copy.copy(self) - new.subdata = self.subdata.optimize() + new.subdata = self.subdata.optimize(ctx) if isinstance2(new.subdata, DelayedDequantize): raise AssertionError('Dequantization is only allowed once') @@ -2117,12 +2161,17 @@ def optimize(self): if isinstance2(new.subdata, DelayedWarp): # Swap order so quantize is before the warp new = new._opt_dequant_before_other() - new = new.optimize() + new = new.optimize(ctx) if isinstance2(new.subdata, DelayedChannelConcat): - new = new._opt_push_under_concat().optimize() + pushed = new._opt_push_under_concat() + if pushed is not new: + new = pushed.optimize(ctx) + else: + new = pushed if TRACE_OPTIMIZE: new._opt_logs.append('optimize DelayedDequantize') + memo[self] = new return new def _opt_dequant_before_other(self): @@ -2236,7 +2285,7 @@ def _transform_from_subdata(self): self_from_subdata = kwimage.Affine.translate(offset) return self_from_subdata - def optimize(self): + def optimize(self, ctx=None): """ Returns: DelayedImage @@ -2253,21 +2302,28 @@ def optimize(self): >>> new.write_network_text() >>> assert len(new.as_graph().nodes) == 1 """ + if ctx is None: + ctx = delayed_base.OptimizeContext() + memo = ctx.memo + if self in memo: + return memo[self] + new = copy.copy(self) - new.subdata = self.subdata.optimize() + new.subdata = self.subdata.optimize(ctx) if isinstance2(new.subdata, DelayedCrop): new = new._opt_fuse_crops() if hasattr(new.subdata, '_optimized_crop'): # The subdata knows how to optimize itself wrt this node crop_kwargs = ub.dict_isect(self.meta, {'space_slice', 'chan_idxs'}) - new = new.subdata._optimized_crop(**crop_kwargs).optimize() + new = new.subdata._optimized_crop(**crop_kwargs).optimize(ctx) if isinstance2(new.subdata, DelayedWarp): - new = new._opt_warp_after_crop() - new = new.optimize() + if 0 not in new.meta.get('dsize', ()): + new = new._opt_warp_after_crop() + new = new.optimize(ctx) elif isinstance2(new.subdata, DelayedDequantize): new = new._opt_dequant_after_crop() - new = new.optimize() + new = new.optimize(ctx) if isinstance2(new.subdata, DelayedChannelConcat): if isinstance2(new, DelayedCrop): @@ -2282,18 +2338,27 @@ def optimize(self): _new_logs.extend(new.subdata._opt_logs) _new_logs.extend(new._opt_logs) _new_logs.append('concat-chan-crop-interact') - taken = new.subdata.take_channels(chan_idxs).optimize() + taken = new.subdata.take_channels(chan_idxs).optimize(ctx) if space_slice is not None: if TRACE_OPTIMIZE: _new_logs.append('concat-space-crop-interact') - taken = taken.crop(space_slice)._opt_push_under_concat().optimize() + pushed = taken.crop(space_slice)._opt_push_under_concat() + if pushed is not taken: + taken = pushed.optimize(ctx) + else: + taken = pushed new = taken if TRACE_OPTIMIZE: new._opt_logs.extend(_new_logs) else: - new = new._opt_push_under_concat().optimize() + pushed = new._opt_push_under_concat() + if pushed is not new: + new = pushed.optimize(ctx) + else: + new = pushed if TRACE_OPTIMIZE: new._opt_logs.append('optimize crop') + memo[self] = new return new def _opt_fuse_crops(self): @@ -2427,6 +2492,8 @@ def _opt_warp_after_crop(self): >>> print(ub.urepr(new_outer.nesting(), nl=-1, sort=0)) """ assert isinstance2(self.subdata, DelayedWarp) + if 0 in self.meta.get('dsize', ()): + return self # Inner is the data closer to the leaf (disk), outer is the data closer # to the user (output). outer_slices = self.meta['space_slice'] @@ -2561,13 +2628,19 @@ def _finalize(self): ) return final - def optimize(self): + def optimize(self, ctx=None): """ Returns: DelayedImage """ + if ctx is None: + ctx = delayed_base.OptimizeContext() + memo = ctx.memo + if self in memo: + return memo[self] + new = copy.copy(self) - new.subdata = self.subdata.optimize() + new.subdata = self.subdata.optimize(ctx) if isinstance2(new.subdata, DelayedOverview): new = new._opt_fuse_overview() @@ -2575,17 +2648,22 @@ def optimize(self): new = new.subdata elif isinstance2(new.subdata, DelayedCrop): new = new._opt_crop_after_overview() - new = new.optimize() + new = new.optimize(ctx) elif isinstance2(new.subdata, DelayedWarp): new = new._opt_warp_after_overview() - new = new.optimize() + new = new.optimize(ctx) elif isinstance2(new.subdata, DelayedDequantize): new = new._opt_dequant_after_overview() - new = new.optimize() + new = new.optimize(ctx) if isinstance2(new.subdata, DelayedChannelConcat): - new = new._opt_push_under_concat().optimize() + pushed = new._opt_push_under_concat() + if pushed is not new: + new = pushed.optimize(ctx) + else: + new = pushed if TRACE_OPTIMIZE: new._opt_logs.append('optimize overview') + memo[self] = new return new def _transform_from_subdata(self): diff --git a/delayed_image/delayed_nodes.pyi b/delayed_image/delayed_nodes.pyi index 3c3f2c4..fc77e6a 100644 --- a/delayed_image/delayed_nodes.pyi +++ b/delayed_image/delayed_nodes.pyi @@ -6,7 +6,7 @@ from typing import Dict from typing import Any from _typeshed import Incomplete from delayed_image import channel_spec -from delayed_image.delayed_base import DelayedNaryOperation, DelayedUnaryOperation +from delayed_image.delayed_base import DelayedNaryOperation, DelayedUnaryOperation, OptimizeContext from delayed_image.channel_spec import FusedChannelSpec from delayed_image.delayed_leafs import DelayedIdentity @@ -116,7 +116,7 @@ class DelayedChannelConcat(ImageOpsMixin, DelayedConcat): def shape(self) -> Tuple[int | None, int | None, int | None]: ... - def optimize(self) -> DelayedImage: + def optimize(self, ctx: OptimizeContext | None = None) -> DelayedImage: ... def take_channels( @@ -203,7 +203,7 @@ class DelayedImage(ImageOpsMixin, DelayedArray): class DelayedAsXarray(DelayedImage): - def optimize(self) -> DelayedImage: + def optimize(self, ctx: OptimizeContext | None = None) -> DelayedImage: ... @@ -223,7 +223,7 @@ class DelayedWarp(DelayedImage): def transform(self) -> kwimage.Affine: ... - def optimize(self) -> DelayedImage: + def optimize(self, ctx: OptimizeContext | None = None) -> DelayedImage: ... @@ -232,7 +232,7 @@ class DelayedDequantize(DelayedImage): def __init__(self, subdata: DelayedArray, quantization: Dict) -> None: ... - def optimize(self) -> DelayedImage: + def optimize(self, ctx: OptimizeContext | None = None) -> DelayedImage: ... @@ -245,7 +245,7 @@ class DelayedCrop(DelayedImage): chan_idxs: List[int] | None = None) -> None: ... - def optimize(self) -> DelayedImage: + def optimize(self, ctx: OptimizeContext | None = None) -> DelayedImage: ... @@ -258,7 +258,7 @@ class DelayedOverview(DelayedImage): def num_overviews(self) -> int: ... - def optimize(self) -> DelayedImage: + def optimize(self, ctx: OptimizeContext | None = None) -> DelayedImage: ... diff --git a/tests/test_optimize_context.py b/tests/test_optimize_context.py new file mode 100644 index 0000000..a0e5fd9 --- /dev/null +++ b/tests/test_optimize_context.py @@ -0,0 +1,131 @@ +from __future__ import annotations + +import warnings + +import numpy as np +import pytest + +import delayed_image + + +def _finalize_ignoring_warnings(node): + with warnings.catch_warnings(): + warnings.simplefilter('ignore') + return node.finalize() + + +def _require_warp_backend(): + from kwimage import im_transform + backend = im_transform._default_backend() + if backend == 'skimage': + pytest.skip('kwimage warp/imresize backend is unavailable') + + +def test_optimize_idempotence(): + _require_warp_backend() + rng = np.random.default_rng(0) + data = (rng.random((32, 32, 3)) * 255).astype(np.uint8) + base = delayed_image.DelayedIdentity(data, channels='r|g|b') + base.meta['num_overviews'] = 1 + quantization = {'quant_max': 255, 'nodata': 0} + + node = base.dequantize(quantization) + node = node.warp({'scale': 1.1, 'offset': (2, -1)}, + interpolation='nearest', antialias=False) + node = node.crop((slice(2, 24), slice(3, 25))) + node = node.get_overview(1) + + opt1 = node.optimize() + opt2 = opt1.optimize() + + assert opt1.nesting() == opt2.nesting() + final1 = _finalize_ignoring_warnings(opt1) + final2 = _finalize_ignoring_warnings(opt2) + assert np.allclose(final1, final2, equal_nan=True) + + +def test_repeated_optimize_equivalence(): + _require_warp_backend() + rng = np.random.default_rng(1) + data = (rng.random((48, 48, 3)) * 255).astype(np.uint8) + base = delayed_image.DelayedIdentity(data, channels='r|g|b') + quantization = {'quant_max': 255, 'nodata': 0} + + node = base.warp({'scale': (1.2, 0.9), 'theta': 0.05}, + interpolation='linear') + node = node.crop((slice(4, 40), slice(5, 41))) + node = node.dequantize(quantization) + + opt1 = node.optimize() + opt2 = node.optimize() + + final_orig = _finalize_ignoring_warnings(node) + final1 = _finalize_ignoring_warnings(opt1) + final2 = _finalize_ignoring_warnings(opt2) + + assert np.allclose(final1, final2, equal_nan=True) + assert np.allclose(final_orig, final1, equal_nan=True) + + +def test_randomized_tree_finalize_equivalence(): + _require_warp_backend() + rng = np.random.default_rng(2) + data = (rng.random((64, 64, 3)) * 255).astype(np.uint8) + base = delayed_image.DelayedIdentity(data, channels='r|g|b') + base.meta['num_overviews'] = 1 + quantization = {'quant_max': 255, 'nodata': 0} + + node = base.dequantize(quantization) + node = node.get_overview(1) + node = node.scale(rng.uniform(0.6, 1.4), dsize='auto', + interpolation='linear', antialias=True) + node = node.warp({'scale': (rng.uniform(0.7, 1.3), rng.uniform(0.7, 1.3)), + 'offset': (rng.uniform(-5, 5), rng.uniform(-5, 5)), + 'theta': rng.uniform(-0.2, 0.2)}, + dsize='auto', interpolation='nearest') + + w, h = node.dsize + y0 = rng.integers(0, max(1, h // 4)) + y1 = rng.integers(max(y0 + 1, h // 2), h) + x0 = rng.integers(0, max(1, w // 4)) + x1 = rng.integers(max(x0 + 1, w // 2), w) + node = node.crop((slice(int(y0), int(y1)), slice(int(x0), int(x1)))) + + final_raw = _finalize_ignoring_warnings(node) + final_opt = _finalize_ignoring_warnings(node.optimize()) + assert np.allclose(final_raw, final_opt, equal_nan=True) + + +def test_optimize_preserves_metadata(tmp_path): + _require_warp_backend() + rng = np.random.default_rng(3) + data = (rng.random((64, 64, 3)) * 255).astype(np.uint8) + fpath = tmp_path / 'meta.png' + import kwimage + kwimage.imwrite(str(fpath), data) + base = delayed_image.DelayedLoad( + fpath, channels='r|g|b', nodata_method='float').prepare() + quantization = {'quant_max': 255, 'nodata': 0} + + node = base.dequantize(quantization) + node = node.warp({'scale': 1.3, 'offset': (2, -1)}, + interpolation='nearest', antialias=False, + border_value=0, dsize='auto') + node = node.crop((slice(5, 40), slice(4, 50))) + + opt = node.optimize() + + assert opt.channels == node.channels + assert opt.dsize == node.dsize + + warp_nodes = [n for _, n in opt._traverse() + if isinstance(n, delayed_image.DelayedWarp)] + assert warp_nodes, 'optimized graph should retain a warp' + warp = warp_nodes[0] + assert warp.meta['interpolation'] == 'nearest' + assert warp.meta['antialias'] is False + + load_nodes = [n for _, n in opt._traverse() + if isinstance(n, delayed_image.DelayedLoad)] + assert load_nodes, 'optimized graph should retain a load node' + assert load_nodes[0].meta['nodata_method'] == 'float'