From 5481fa23901312526a95ae5a878de0458a3bb8b6 Mon Sep 17 00:00:00 2001 From: Jon Crall Date: Fri, 23 Jan 2026 18:32:54 -0500 Subject: [PATCH 1/2] Add experimental AST optimizer and tests --- CHANGELOG.md | 5 + delayed_image/delayed_base.py | 13 ++ delayed_image/delayed_base.pyi | 3 + delayed_image/experimental/__init__.py | 5 + delayed_image/experimental/astopt/__init__.py | 5 + .../experimental/astopt/optimizer.py | 94 ++++++++++ delayed_image/experimental/astopt/rules.py | 176 ++++++++++++++++++ .../experimental/astopt/signature.py | 40 ++++ .../experimental/astopt/transformer.py | 40 ++++ tests/conftest.py | 15 ++ tests/test_ast_optimize_equivalence.py | 62 ++++++ tests/test_delayed_nodes.py | 8 +- tests/test_delayed_ops.py | 10 +- tests/test_find_reference_scale.py | 8 +- tests/test_huge_scale_ratio.py | 6 +- tests/test_issue_4.py | 24 +-- tests/test_itk_backend.py | 4 +- tests/test_optimize_crop.py | 12 +- tests/test_subband_select.py | 8 +- 19 files changed, 498 insertions(+), 40 deletions(-) create mode 100644 delayed_image/experimental/__init__.py create mode 100644 delayed_image/experimental/astopt/__init__.py create mode 100644 delayed_image/experimental/astopt/optimizer.py create mode 100644 delayed_image/experimental/astopt/rules.py create mode 100644 delayed_image/experimental/astopt/signature.py create mode 100644 delayed_image/experimental/astopt/transformer.py create mode 100644 tests/conftest.py create mode 100644 tests/test_ast_optimize_equivalence.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 6906bb3..8ac0ff0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,11 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm ## Version 0.4.6 - Unreleased +### Added +* Experimental AST-based optimizer available via `optimize_ast()` and + `delayed_image.experimental.astopt.optimize()`. The existing `optimize()` + behavior is unchanged. + ### Fix * Handle case when input sensorchan strings are string subclasses. * Fix issue where lazy warps did not respect explicitly given dsize arguments diff --git a/delayed_image/delayed_base.py b/delayed_image/delayed_base.py index 5abae10..ec4ffa2 100644 --- a/delayed_image/delayed_base.py +++ b/delayed_image/delayed_base.py @@ -392,6 +392,19 @@ def optimize(self): """ raise NotImplementedError + def optimize_ast(self, **kwargs): + """ + Experimental AST-based optimizer. + + Args: + **kwargs: forwarded to the experimental optimizer. + + Returns: + DelayedOperation + """ + from delayed_image.experimental.astopt import optimize + return optimize(self, **kwargs) + def _set_nested_params(self, **kwargs): """ Hack to override nested params on all warps for things like diff --git a/delayed_image/delayed_base.pyi b/delayed_image/delayed_base.pyi index ae741da..b486c82 100644 --- a/delayed_image/delayed_base.pyi +++ b/delayed_image/delayed_base.pyi @@ -60,6 +60,9 @@ class DelayedOperation(ub.NiceRepr): def optimize(self) -> DelayedOperation: ... + def optimize_ast(self, **kwargs) -> DelayedOperation: + ... + class DelayedNaryOperation(DelayedOperation): parts: Incomplete diff --git a/delayed_image/experimental/__init__.py b/delayed_image/experimental/__init__.py new file mode 100644 index 0000000..74e2c00 --- /dev/null +++ b/delayed_image/experimental/__init__.py @@ -0,0 +1,5 @@ +"""Experimental submodules for delayed_image.""" + +from delayed_image.experimental import astopt # noqa: F401 + +__all__ = ["astopt"] diff --git a/delayed_image/experimental/astopt/__init__.py b/delayed_image/experimental/astopt/__init__.py new file mode 100644 index 0000000..d354ece --- /dev/null +++ b/delayed_image/experimental/astopt/__init__.py @@ -0,0 +1,5 @@ +"""AST-based experimental optimizer.""" + +from delayed_image.experimental.astopt.optimizer import optimize, optimize_trace # noqa: F401 + +__all__ = ["optimize", "optimize_trace"] diff --git a/delayed_image/experimental/astopt/optimizer.py b/delayed_image/experimental/astopt/optimizer.py new file mode 100644 index 0000000..17253f1 --- /dev/null +++ b/delayed_image/experimental/astopt/optimizer.py @@ -0,0 +1,94 @@ +"""Experimental AST-based optimizer driver.""" + +from __future__ import annotations + +from collections import Counter +from dataclasses import dataclass, field +from typing import Dict, Tuple + +from delayed_image.experimental.astopt import rules +from delayed_image.experimental.astopt.signature import node_signature +from delayed_image.experimental.astopt.transformer import get_children, rebuild + + +@dataclass +class OptimizeTrace: + applied: list[dict] = field(default_factory=list) + counts: Counter = field(default_factory=Counter) + + def record(self, rule_name, node, new_node): + self.applied.append({ + "rule": rule_name, + "node": node.__class__.__name__, + "new_node": new_node.__class__.__name__, + }) + self.counts[rule_name] += 1 + + +class ASTOptimizer: + """AST optimizer with rule-based local rewrites.""" + + def __init__(self, trace: OptimizeTrace | None = None, legacy_fallback: bool = True): + self.trace = trace + self.legacy_fallback = legacy_fallback + self._memo: Dict[int, Tuple[object, str]] = {} + + def optimize(self, node): + return self._optimize_node(node) + + def _optimize_node(self, node): + node_id = id(node) + if node_id in self._memo: + return self._memo[node_id][0] + + replacements = {} + child_signatures = [] + for path, child in get_children(node): + new_child = self._optimize_node(child) + if new_child is not child: + replacements[path] = new_child + child_signatures.append(node_signature(new_child, [])) + + if replacements: + node = rebuild(node, replacements) + + node = self._apply_rules(node) + + if self.legacy_fallback: + node = node.optimize() + + signature = node_signature(node, child_signatures) + self._memo[node_id] = (node, signature) + return node + + def _apply_rules(self, node): + max_iters = 20 + for _ in range(max_iters): + changed = False + for rule in rules.rules_for(node): + new_node, did_change, name = rule(node) + if did_change: + if self.trace is not None and name: + self.trace.record(name, node, new_node) + node = new_node + changed = True + break + if not changed: + break + return node + + +def optimize(node, **kwargs): + """Optimize a delayed node using the experimental AST optimizer.""" + trace = kwargs.pop("trace", None) + legacy_fallback = kwargs.pop("legacy_fallback", True) + optimizer = ASTOptimizer(trace=trace, legacy_fallback=legacy_fallback) + return optimizer.optimize(node) + + +def optimize_trace(node, **kwargs): + """Optimize and return the trace of applied rules.""" + trace = OptimizeTrace() + kwargs["trace"] = trace + result = optimize(node, **kwargs) + return result, trace diff --git a/delayed_image/experimental/astopt/rules.py b/delayed_image/experimental/astopt/rules.py new file mode 100644 index 0000000..cabe16a --- /dev/null +++ b/delayed_image/experimental/astopt/rules.py @@ -0,0 +1,176 @@ +"""Rule registry for the experimental AST optimizer.""" + +from __future__ import annotations + +from collections import defaultdict + +import ubelt as ub + +from delayed_image import delayed_nodes + + +RULES = defaultdict(list) + + +def register_rule(node_type): + def decorator(func): + RULES[node_type].append(func) + return func + return decorator + + +def rules_for(node): + rules = [] + for cls in node.__class__.__mro__: + rules.extend(RULES.get(cls, [])) + return rules + + +DelayedWarp = delayed_nodes.DelayedWarp +DelayedCrop = delayed_nodes.DelayedCrop +DelayedOverview = delayed_nodes.DelayedOverview +DelayedDequantize = delayed_nodes.DelayedDequantize +DelayedChannelConcat = delayed_nodes.DelayedChannelConcat +isinstance2 = delayed_nodes.isinstance2 + + +@register_rule(DelayedWarp) +def fuse_warps(node): + if isinstance2(node.subdata, DelayedWarp): + return node._opt_fuse_warps(), True, "fuse_warps" + return node, False, None + + +@register_rule(DelayedWarp) +def remove_identity_warp(node): + noop_eps = node.meta.get("noop_eps", 0) + is_negligible = ( + node.dsize == node.subdata.dsize + and node.transform.isclose_identity(rtol=noop_eps, atol=noop_eps) + ) + if is_negligible: + return node.subdata, True, "remove_identity_warp" + return node, False, None + + +@register_rule(DelayedWarp) +def push_warp_under_concat(node): + if isinstance2(node.subdata, DelayedChannelConcat): + return node._opt_push_under_concat(), True, "warp_under_concat" + return node, False, None + + +@register_rule(DelayedWarp) +def warp_on_optimized_subdata(node): + if hasattr(node.subdata, "_optimized_warp"): + warp_kwargs = ub.dict_isect(node.meta, node._data_keys + node._algo_keys) + return node.subdata._optimized_warp(**warp_kwargs), True, "subdata_optimized_warp" + return node, False, None + + +@register_rule(DelayedWarp) +def split_warp_overview(node): + split = node._opt_split_warp_overview() + if split is not node: + return split, True, "split_warp_overview" + return node, False, None + + +@register_rule(DelayedWarp) +def absorb_overview(node): + absorbed = node._opt_absorb_overview() + if absorbed is not node: + return absorbed, True, "absorb_overview" + return node, False, None + + +@register_rule(DelayedCrop) +def fuse_crops(node): + if isinstance2(node.subdata, DelayedCrop): + return node._opt_fuse_crops(), True, "fuse_crops" + return node, False, None + + +@register_rule(DelayedCrop) +def optimized_crop_subdata(node): + if hasattr(node.subdata, "_optimized_crop"): + crop_kwargs = ub.dict_isect(node.meta, {"space_slice", "chan_idxs"}) + return node.subdata._optimized_crop(**crop_kwargs), True, "subdata_optimized_crop" + return node, False, None + + +@register_rule(DelayedCrop) +def crop_after_warp(node): + if isinstance2(node.subdata, DelayedWarp): + return node._opt_warp_after_crop(), True, "crop_after_warp" + return node, False, None + + +@register_rule(DelayedCrop) +def dequant_after_crop(node): + if isinstance2(node.subdata, DelayedDequantize): + return node._opt_dequant_after_crop(), True, "dequant_after_crop" + return node, False, None + + +@register_rule(DelayedCrop) +def crop_under_concat(node): + if isinstance2(node.subdata, DelayedChannelConcat): + return node._opt_push_under_concat(), True, "crop_under_concat" + return node, False, None + + +@register_rule(DelayedOverview) +def fuse_overview(node): + if isinstance2(node.subdata, DelayedOverview): + return node._opt_fuse_overview(), True, "fuse_overview" + return node, False, None + + +@register_rule(DelayedOverview) +def drop_identity_overview(node): + if node.meta.get("overview", None) == 0: + return node.subdata, True, "drop_overview_0" + return node, False, None + + +@register_rule(DelayedOverview) +def crop_after_overview(node): + if isinstance2(node.subdata, DelayedCrop): + return node._opt_crop_after_overview(), True, "crop_after_overview" + return node, False, None + + +@register_rule(DelayedOverview) +def warp_after_overview(node): + if isinstance2(node.subdata, DelayedWarp): + return node._opt_warp_after_overview(), True, "warp_after_overview" + return node, False, None + + +@register_rule(DelayedOverview) +def dequant_after_overview(node): + if isinstance2(node.subdata, DelayedDequantize): + return node._opt_dequant_after_overview(), True, "dequant_after_overview" + return node, False, None + + +@register_rule(DelayedOverview) +def overview_under_concat(node): + if isinstance2(node.subdata, DelayedChannelConcat): + return node._opt_push_under_concat(), True, "overview_under_concat" + return node, False, None + + +@register_rule(DelayedDequantize) +def dequant_before_warp(node): + if isinstance2(node.subdata, DelayedWarp): + return node._opt_dequant_before_other(), True, "dequant_before_warp" + return node, False, None + + +@register_rule(DelayedDequantize) +def dequant_under_concat(node): + if isinstance2(node.subdata, DelayedChannelConcat): + return node._opt_push_under_concat(), True, "dequant_under_concat" + return node, False, None diff --git a/delayed_image/experimental/astopt/signature.py b/delayed_image/experimental/astopt/signature.py new file mode 100644 index 0000000..61d03a8 --- /dev/null +++ b/delayed_image/experimental/astopt/signature.py @@ -0,0 +1,40 @@ +"""Helpers for creating canonical signatures for delayed nodes.""" + +from __future__ import annotations + +from typing import Any + +import ubelt as ub + + +def _normalize_value(value: Any) -> Any: + if hasattr(value, "concise"): + try: + value = value.concise() + except Exception: + pass + if hasattr(value, "spec"): + try: + value = value.spec + except Exception: + pass + if isinstance(value, dict): + return {key: _normalize_value(val) for key, val in value.items()} + if isinstance(value, (list, tuple)): + return [_normalize_value(val) for val in value] + return value + + +def node_signature(node, child_signatures) -> str: + """Return a hashable signature for a delayed node.""" + meta = getattr(node, "meta", {}) + normalized_meta = {key: _normalize_value(val) for key, val in meta.items()} + payload = { + "type": node.__class__.__name__, + "meta": normalized_meta, + "children": child_signatures, + } + try: + return ub.hash_data(payload) + except Exception: + return ub.hash_data(ub.urepr(payload, sort=1)) diff --git a/delayed_image/experimental/astopt/transformer.py b/delayed_image/experimental/astopt/transformer.py new file mode 100644 index 0000000..6664d75 --- /dev/null +++ b/delayed_image/experimental/astopt/transformer.py @@ -0,0 +1,40 @@ +"""Traversal helpers for delayed node trees.""" + +from __future__ import annotations + +import copy +from typing import Dict, Iterable, List, Tuple + + +ChildPath = Tuple[str, int | None] + + +def get_children(node) -> List[Tuple[ChildPath, object]]: + """Return a list of child paths and nodes.""" + children: List[Tuple[ChildPath, object]] = [] + if hasattr(node, "subdata"): + subdata = node.subdata + if subdata is not None: + children.append((("subdata", None), subdata)) + if hasattr(node, "parts"): + parts = node.parts + if parts is not None: + for idx, part in enumerate(parts): + children.append((("parts", idx), part)) + return children + + +def rebuild(node, new_children: Dict[ChildPath, object]): + """Return a shallow copy of node with updated children.""" + if not new_children: + return node + new_node = copy.copy(node) + if hasattr(new_node, "subdata") and ("subdata", None) in new_children: + new_node.subdata = new_children[("subdata", None)] + if hasattr(new_node, "parts"): + parts = list(new_node.parts) + for (kind, idx), child in new_children.items(): + if kind == "parts" and idx is not None: + parts[idx] = child + new_node.parts = parts + return new_node diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..3422043 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,15 @@ +import pytest + + +@pytest.fixture(params=["legacy", "ast"]) +def optimize_func(request): + if request.param == "legacy": + return lambda node, **kwargs: node.optimize(**kwargs) + return lambda node, **kwargs: node.optimize_ast(**kwargs) + + +@pytest.fixture() +def optimize_pair(): + def _pair(node, **kwargs): + return node.optimize(**kwargs), node.optimize_ast(**kwargs) + return _pair diff --git a/tests/test_ast_optimize_equivalence.py b/tests/test_ast_optimize_equivalence.py new file mode 100644 index 0000000..ba225e7 --- /dev/null +++ b/tests/test_ast_optimize_equivalence.py @@ -0,0 +1,62 @@ +import numpy as np +import pytest + +from delayed_image import DelayedIdentity + + +def _make_base(seed=0, dsize=(64, 64)): + rng = np.random.RandomState(seed) + data = rng.randint(0, 255, size=(dsize[1], dsize[0], 3), dtype=np.uint8) + return DelayedIdentity(data, channels='r|g|b') + + +def test_ast_optimize_equivalence_simple(optimize_pair): + base = _make_base() + quantization = {'quant_max': 255, 'nodata': 0} + node = base.dequantize(quantization) + node = node.crop((slice(2, 60), slice(4, 66))) + node = node.take_channels('r|b') + + legacy, ast = optimize_pair(node) + + legacy_final = legacy.finalize() + ast_final = ast.finalize() + + assert legacy.dsize == ast.dsize + np.testing.assert_allclose(legacy_final, ast_final, atol=1e-6) + + +@pytest.mark.parametrize("seed", [0, 1, 2]) +def test_ast_optimize_equivalence_random(seed, optimize_pair): + rng = np.random.RandomState(seed) + base = _make_base(seed=seed) + quantization = {'quant_max': 255, 'nodata': 0} + + node = base + used_dequant = False + for _ in range(5): + op = rng.choice(["crop", "dequant", "channels"]) + if op == "crop": + x0 = rng.randint(0, 20) + y0 = rng.randint(0, 20) + node = node.crop((slice(y0, y0 + 32), slice(x0, x0 + 32))) + elif op == "channels": + node = node.take_channels('r|g') + elif op == "dequant" and not used_dequant: + node = node.dequantize(quantization) + used_dequant = True + + legacy, ast = optimize_pair(node) + np.testing.assert_allclose(legacy.finalize(), ast.finalize(), atol=1e-6) + + +def test_optimize_idempotent(optimize_func): + base = _make_base() + quantization = {'quant_max': 255, 'nodata': 0} + node = base.dequantize(quantization).crop((slice(0, 40), slice(0, 40))) + node = node.take_channels('r|b') + + opt1 = optimize_func(node) + opt2 = optimize_func(opt1) + + np.testing.assert_allclose(opt1.finalize(), opt2.finalize(), atol=1e-6) diff --git a/tests/test_delayed_nodes.py b/tests/test_delayed_nodes.py index 33803f4..31deb11 100644 --- a/tests/test_delayed_nodes.py +++ b/tests/test_delayed_nodes.py @@ -1,4 +1,4 @@ -def test_crop_optimize_issue(): +def test_crop_optimize_issue(optimize_func): """ There was an issue in 0.2.0 where a crop would be optimized incorrectly. @@ -65,10 +65,10 @@ def demo_weird_delayed(): chan2.write_network_text() print('\n-- Chan V1 [opt] --') - chan1_opt = chan1.optimize() + chan1_opt = optimize_func(chan1) chan1_opt.write_network_text() print('\n-- Chan V2 [opt] --') - chan2_opt = chan2.optimize() + chan2_opt = optimize_func(chan2) chan2_opt.write_network_text() assert chan1_opt.dsize == chan2_opt.dsize @@ -128,4 +128,4 @@ def test_lazy_warp_with_explicit_dsize(): CommandLine: python ~/code/delayed_image/tests/test_delayed_nodes.py """ - test_crop_optimize_issue() + test_crop_optimize_issue(lambda node, **kwargs: node.optimize(**kwargs)) diff --git a/tests/test_delayed_ops.py b/tests/test_delayed_ops.py index f5097e7..9d6d394 100644 --- a/tests/test_delayed_ops.py +++ b/tests/test_delayed_ops.py @@ -9,7 +9,7 @@ # @profile -def test_shuffle_delayed_operations(): +def test_shuffle_delayed_operations(optimize_func): """ CommandLine: XDEV_PROFILE=1 xdoctest -m tests/test_delayed_ops.py test_shuffle_delayed_operations @@ -23,7 +23,7 @@ def test_shuffle_delayed_operations(): # overviews=3) base = DelayedLoad(fpath, channels='r|g|b')._load_metadata() quantization = {'quant_max': 255, 'nodata': 0} - base.get_overview(1).dequantize(quantization).optimize() + optimize_func(base.get_overview(1).dequantize(quantization)) operations = [ ('warp', {'scale': 1}), @@ -57,7 +57,7 @@ def test_shuffle_delayed_operations(): delayed = func(args) # delayed.write_network_text(with_labels="name") - opt = delayed.optimize() + opt = optimize_func(delayed) # opt.write_network_text(with_labels="name") # We always expect that we will get a sequence in the form @@ -80,7 +80,7 @@ def test_shuffle_delayed_operations(): prev_idx = this_idx -def test_static_operation_optimize_single_chain(): +def test_static_operation_optimize_single_chain(optimize_func): """ There are 4 main operations: @@ -176,7 +176,7 @@ def __getitem__(self, index): # the manipulated data will get cropped away. # The optimized tree looks like this - optimized = delayed.optimize() + optimized = optimize_func(delayed) optimized.print_graph() """ ╙── Warp dsize=(64,32),transform={offset=(-0.6713,0.1755),scale=(0.5472,0.5773),shearx=0.1653,theta=-0.3208} diff --git a/tests/test_find_reference_scale.py b/tests/test_find_reference_scale.py index a400a97..7824844 100644 --- a/tests/test_find_reference_scale.py +++ b/tests/test_find_reference_scale.py @@ -1,6 +1,6 @@ -def test_find_reference_scale(): +def test_find_reference_scale(optimize_func): try: from rich import print as rprint use_rich = 1 @@ -21,8 +21,8 @@ def test_find_reference_scale(): mod = ref.warp(transform1).warp(transform2) # [0:100, 0:100] - opt_ref = ref.optimize() - opt_mod = mod.optimize() + opt_ref = optimize_func(ref) + opt_mod = optimize_func(mod) rprint('\n-- [green] REF --') ref.write_network_text(rich=use_rich) @@ -45,7 +45,7 @@ def test_find_reference_scale(): tf_mod_from_ref = tf_mod_from_leaf @ tf_leaf_from_ref recon_mod = opt_ref.warp(tf_mod_from_ref, dsize=mod.dsize) - recon_opt_mod = recon_mod.optimize() + recon_opt_mod = optimize_func(recon_mod) rprint('\n-- [orange1] MOD (recon) --') recon_mod.write_network_text(rich=use_rich) diff --git a/tests/test_huge_scale_ratio.py b/tests/test_huge_scale_ratio.py index 41b5f28..2037a1b 100644 --- a/tests/test_huge_scale_ratio.py +++ b/tests/test_huge_scale_ratio.py @@ -5,7 +5,7 @@ """ -def test_100x_scale_difference(): +def test_100x_scale_difference(optimize_func): """ There is an issue here in that the native subdata does not seem to agree with the resampled subdata. @@ -82,7 +82,7 @@ def fancy_checkerboard(dsize, num_squares): lores_resampled_sample_linear = kwimage.fill_nans_with_checkers(lores_resampled_sample_linear) # Native Approach: - native_parts, native_warps = chip.optimize().undo_warps(remove=['scale'], return_warps=True) + native_parts, native_warps = optimize_func(chip).undo_warps(remove=['scale'], return_warps=True) native1, native2 = native_parts warp_native1_from_virtual = native_warps[0] warp_native2_from_virtual = native_warps[1] @@ -112,7 +112,7 @@ def fancy_checkerboard(dsize, num_squares): print(roi_resolution2) # Get the delayed operation tree for just the coarse image for print comparison - resampled2 = chip.optimize().undo_warps(remove=[])[1] + resampled2 = optimize_func(chip).undo_warps(remove=[])[1] print('\n[green]--- Resampled Operations For Data 2 ---') resampled2.print_graph(fields='all') diff --git a/tests/test_issue_4.py b/tests/test_issue_4.py index 51a18ec..efa16b4 100644 --- a/tests/test_issue_4.py +++ b/tests/test_issue_4.py @@ -1,4 +1,4 @@ -def test_issue4(): +def test_issue4(optimize_func): """ The symptom is given this tree: @@ -125,22 +125,22 @@ def test_issue4(): delayed_crop = delayed_crop.prepare() delayed_crop.print_graph() - optimized = delayed_crop.optimize() + optimized = optimize_func(delayed_crop) optimized.print_graph() assert optimized.dsize == (225, 225) -def test_clipped_negative_slice(): +def test_clipped_negative_slice(optimize_func): import delayed_image from delayed_image.helpers import mkslice base = delayed_image.DelayedLoad.demo(dsize=(256, 256)) slices = mkslice[-10:216, 0:256] cropped = base.crop(slices) assert cropped.dsize == (256, 0) - assert cropped.optimize().dsize == (256, 0) + assert optimize_func(cropped).dsize == (256, 0) -def test_oob_crop_after_load(): +def test_oob_crop_after_load(optimize_func): import delayed_image import ubelt as ub from delayed_image.helpers import mkslice @@ -159,7 +159,7 @@ def test_oob_crop_after_load(): print('key = {}'.format(ub.urepr(key, nl=1))) orig.print_graph() orig.prepare() - opt = orig.optimize() + opt = optimize_func(orig) opt.print_graph() outputs[key] = opt print('----------') @@ -169,7 +169,7 @@ def test_oob_crop_after_load(): assert outputs['v3'].dsize == (100, 200) -def test_oob_crop_after_warp(): +def test_oob_crop_after_warp(optimize_func): """ Like test_oob_crop_after_load, but adds in a warp before the slices that triggered errors the previous test did not. @@ -195,7 +195,7 @@ def test_oob_crop_after_warp(): orig.print_graph() orig.prepare() try: - opt = orig.optimize() + opt = optimize_func(orig) except Exception as ex: print('ex = {}'.format(ub.urepr(ex, nl=1))) errors.append((key, ex)) @@ -210,7 +210,7 @@ def test_oob_crop_after_warp(): assert outputs['v3'].dsize == (100, 200) -def test_oob_crop_after_warp_with_overviews(): +def test_oob_crop_after_warp_with_overviews(optimize_func): """ Like test_oob_crop_after_load, but adds in a warp before the slices that triggered errors the previous test did not. @@ -243,7 +243,7 @@ def test_oob_crop_after_warp_with_overviews(): orig.print_graph() orig.prepare() try: - opt = orig.optimize() + opt = optimize_func(orig) except Exception as ex: print('ex = {}'.format(ub.urepr(ex, nl=1))) errors.append((key, ex)) @@ -258,7 +258,7 @@ def test_oob_crop_after_warp_with_overviews(): assert outputs['v3'].dsize == (100, 200) -def test_both_total_negative_slice(): +def test_both_total_negative_slice(optimize_func): import delayed_image try: import osgeo @@ -271,6 +271,6 @@ def test_both_total_negative_slice(): pad = [(0, 0), (0, 0)] crop = base.crop(slices, wrap=False, clip=False, pad=pad) crop.print_graph() - opt = crop.optimize() + opt = optimize_func(crop) assert crop.dsize == opt.dsize opt.print_graph() diff --git a/tests/test_itk_backend.py b/tests/test_itk_backend.py index fa083ad..519b268 100644 --- a/tests/test_itk_backend.py +++ b/tests/test_itk_backend.py @@ -12,7 +12,7 @@ def skip_if_missing_itk(): pytest.skip('requires itk to test') -def test_itk_warp(): +def test_itk_warp(optimize_func): skip_if_missing_itk() from delayed_image import DelayedLoad @@ -23,7 +23,7 @@ def test_itk_warp(): dsize = 'auto' new = self.warp({'scale': 1 / 30}, backend=backend, dsize=dsize, antialias=0) new.print_graph(fields='all') - opt = new.optimize() + opt = optimize_func(new) opt.print_graph(fields='all') result = opt.finalize() # import kwplot diff --git a/tests/test_optimize_crop.py b/tests/test_optimize_crop.py index dd2b6fa..8b68b3a 100644 --- a/tests/test_optimize_crop.py +++ b/tests/test_optimize_crop.py @@ -1,4 +1,4 @@ -def test_optimize_crop_without_clip_reproduction(): +def test_optimize_crop_without_clip_reproduction(optimize_func): """ There was an issue where a non-clipped crop would optimize it with a clip. This reproduces the issue exactly as it was originally seen. @@ -44,7 +44,7 @@ def test_optimize_crop_without_clip_reproduction(): delayed_image.delayed_nodes.TRACE_OPTIMIZE = 1 - optimize = delayed.optimize() + optimize = optimize_func(delayed) optimize.print_graph() if delayed_image.delayed_nodes.TRACE_OPTIMIZE: @@ -57,7 +57,7 @@ def test_optimize_crop_without_clip_reproduction(): assert optimize.dsize == (416, 416), ('optimization should keep that size') -def test_optimize_crop_without_clip_simplified(): +def test_optimize_crop_without_clip_simplified(optimize_func): """ This reproduces a simplified minimal version of the issue """ @@ -75,14 +75,14 @@ def test_optimize_crop_without_clip_simplified(): delayed = delayed.warp({'scale': 0.25}, dsize=(416, 416)) delayed.print_graph() - optimize = delayed.optimize() + optimize = optimize_func(delayed) optimize.print_graph() assert delayed.dsize == (416, 416), ('original image has a specific size') assert optimize.dsize == (416, 416), ('optimization should keep that size') -def test_optimize_crop_without_clip_minimal(): +def test_optimize_crop_without_clip_minimal(optimize_func): """ Minimal operations that caused the issue """ @@ -101,7 +101,7 @@ def test_optimize_crop_without_clip_minimal(): delayed_image.delayed_nodes.TRACE_OPTIMIZE = 1 - optimize = delayed.optimize() + optimize = optimize_func(delayed) optimize.print_graph() if delayed_image.delayed_nodes.TRACE_OPTIMIZE: diff --git a/tests/test_subband_select.py b/tests/test_subband_select.py index e4e5362..f6fe910 100644 --- a/tests/test_subband_select.py +++ b/tests/test_subband_select.py @@ -1,6 +1,6 @@ -def test_subchannel_select_with_overviews_case1(): +def test_subchannel_select_with_overviews_case1(optimize_func): """ This reproduces a bug in version < 0.2.8 with the exact operation tree that caused it in production. @@ -39,7 +39,7 @@ def __getitem__(self, index): print(chr(10) + 'Before Optimization:') node.write_network_text() - optimized = node.optimize() + optimized = optimize_func(node) print(chr(10) + 'After Optimization:') optimized.write_network_text() @@ -54,7 +54,7 @@ def __getitem__(self, index): assert im2.shape[2] == 2 -def test_subchannel_select_with_overviews_case2(): +def test_subchannel_select_with_overviews_case2(optimize_func): """ This reproduces a bug in version < 0.2.8 with a minimal example """ @@ -75,7 +75,7 @@ def test_subchannel_select_with_overviews_case2(): print(chr(10) + 'Before Optimization:') node.write_network_text() - optimized = node.optimize() + optimized = optimize_func(node) print(chr(10) + 'After Optimization:') optimized.write_network_text() From bb1a77e8490189917bb7b23e4e9c1020bc31dafb Mon Sep 17 00:00:00 2001 From: Jon Crall Date: Fri, 23 Jan 2026 18:49:21 -0500 Subject: [PATCH 2/2] Guard concat crop rule and fix __future__ import --- delayed_image/experimental/astopt/optimizer.py | 2 +- delayed_image/experimental/astopt/rules.py | 4 +++- delayed_image/experimental/astopt/signature.py | 2 +- delayed_image/experimental/astopt/transformer.py | 2 +- 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/delayed_image/experimental/astopt/optimizer.py b/delayed_image/experimental/astopt/optimizer.py index 17253f1..b5fce89 100644 --- a/delayed_image/experimental/astopt/optimizer.py +++ b/delayed_image/experimental/astopt/optimizer.py @@ -1,6 +1,6 @@ """Experimental AST-based optimizer driver.""" -from __future__ import annotations +from __future__ import __annotations__ from collections import Counter from dataclasses import dataclass, field diff --git a/delayed_image/experimental/astopt/rules.py b/delayed_image/experimental/astopt/rules.py index cabe16a..f497736 100644 --- a/delayed_image/experimental/astopt/rules.py +++ b/delayed_image/experimental/astopt/rules.py @@ -1,6 +1,6 @@ """Rule registry for the experimental AST optimizer.""" -from __future__ import annotations +from __future__ import __annotations__ from collections import defaultdict @@ -116,6 +116,8 @@ def dequant_after_crop(node): @register_rule(DelayedCrop) def crop_under_concat(node): if isinstance2(node.subdata, DelayedChannelConcat): + if node.meta.get("chan_idxs", None) is not None: + return node, False, None return node._opt_push_under_concat(), True, "crop_under_concat" return node, False, None diff --git a/delayed_image/experimental/astopt/signature.py b/delayed_image/experimental/astopt/signature.py index 61d03a8..9f2f7c1 100644 --- a/delayed_image/experimental/astopt/signature.py +++ b/delayed_image/experimental/astopt/signature.py @@ -1,6 +1,6 @@ """Helpers for creating canonical signatures for delayed nodes.""" -from __future__ import annotations +from __future__ import __annotations__ from typing import Any diff --git a/delayed_image/experimental/astopt/transformer.py b/delayed_image/experimental/astopt/transformer.py index 6664d75..25d36c7 100644 --- a/delayed_image/experimental/astopt/transformer.py +++ b/delayed_image/experimental/astopt/transformer.py @@ -1,6 +1,6 @@ """Traversal helpers for delayed node trees.""" -from __future__ import annotations +from __future__ import __annotations__ import copy from typing import Dict, Iterable, List, Tuple