Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm

## Version 0.4.6 - Unreleased

### Added
* Experimental AST-based optimizer available via `optimize_ast()` and
`delayed_image.experimental.astopt.optimize()`. The existing `optimize()`
behavior is unchanged.

### Fix
* Handle case when input sensorchan strings are string subclasses.
* Fix issue where lazy warps did not respect explicitly given dsize arguments
Expand Down
13 changes: 13 additions & 0 deletions delayed_image/delayed_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,19 @@ def optimize(self):
"""
raise NotImplementedError

def optimize_ast(self, **kwargs):
"""
Experimental AST-based optimizer.

Args:
**kwargs: forwarded to the experimental optimizer.

Returns:
DelayedOperation
"""
from delayed_image.experimental.astopt import optimize
return optimize(self, **kwargs)

def _set_nested_params(self, **kwargs):
"""
Hack to override nested params on all warps for things like
Expand Down
3 changes: 3 additions & 0 deletions delayed_image/delayed_base.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ class DelayedOperation(ub.NiceRepr):
def optimize(self) -> DelayedOperation:
...

def optimize_ast(self, **kwargs) -> DelayedOperation:
...


class DelayedNaryOperation(DelayedOperation):
parts: Incomplete
Expand Down
5 changes: 5 additions & 0 deletions delayed_image/experimental/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Experimental submodules for delayed_image."""

from delayed_image.experimental import astopt # noqa: F401

__all__ = ["astopt"]
5 changes: 5 additions & 0 deletions delayed_image/experimental/astopt/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""AST-based experimental optimizer."""

from delayed_image.experimental.astopt.optimizer import optimize, optimize_trace # noqa: F401

__all__ = ["optimize", "optimize_trace"]
94 changes: 94 additions & 0 deletions delayed_image/experimental/astopt/optimizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""Experimental AST-based optimizer driver."""

from __future__ import __annotations__

from collections import Counter
from dataclasses import dataclass, field
from typing import Dict, Tuple

from delayed_image.experimental.astopt import rules
from delayed_image.experimental.astopt.signature import node_signature
from delayed_image.experimental.astopt.transformer import get_children, rebuild


@dataclass
class OptimizeTrace:
applied: list[dict] = field(default_factory=list)
counts: Counter = field(default_factory=Counter)

def record(self, rule_name, node, new_node):
self.applied.append({
"rule": rule_name,
"node": node.__class__.__name__,
"new_node": new_node.__class__.__name__,
})
self.counts[rule_name] += 1


class ASTOptimizer:
"""AST optimizer with rule-based local rewrites."""

def __init__(self, trace: OptimizeTrace | None = None, legacy_fallback: bool = True):
self.trace = trace
self.legacy_fallback = legacy_fallback
self._memo: Dict[int, Tuple[object, str]] = {}

def optimize(self, node):
return self._optimize_node(node)

def _optimize_node(self, node):
node_id = id(node)
if node_id in self._memo:
return self._memo[node_id][0]

replacements = {}
child_signatures = []
for path, child in get_children(node):
new_child = self._optimize_node(child)
if new_child is not child:
replacements[path] = new_child
child_signatures.append(node_signature(new_child, []))

if replacements:
node = rebuild(node, replacements)

node = self._apply_rules(node)

if self.legacy_fallback:
node = node.optimize()

signature = node_signature(node, child_signatures)
self._memo[node_id] = (node, signature)
return node

def _apply_rules(self, node):
max_iters = 20
for _ in range(max_iters):
changed = False
for rule in rules.rules_for(node):
new_node, did_change, name = rule(node)
if did_change:
if self.trace is not None and name:
self.trace.record(name, node, new_node)
node = new_node
changed = True
break
if not changed:
break
return node


def optimize(node, **kwargs):
"""Optimize a delayed node using the experimental AST optimizer."""
trace = kwargs.pop("trace", None)
legacy_fallback = kwargs.pop("legacy_fallback", True)
optimizer = ASTOptimizer(trace=trace, legacy_fallback=legacy_fallback)
return optimizer.optimize(node)


def optimize_trace(node, **kwargs):
"""Optimize and return the trace of applied rules."""
trace = OptimizeTrace()
kwargs["trace"] = trace
result = optimize(node, **kwargs)
return result, trace
178 changes: 178 additions & 0 deletions delayed_image/experimental/astopt/rules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
"""Rule registry for the experimental AST optimizer."""

from __future__ import __annotations__

from collections import defaultdict

import ubelt as ub

from delayed_image import delayed_nodes


RULES = defaultdict(list)


def register_rule(node_type):
def decorator(func):
RULES[node_type].append(func)
return func
return decorator


def rules_for(node):
rules = []
for cls in node.__class__.__mro__:
rules.extend(RULES.get(cls, []))
return rules


DelayedWarp = delayed_nodes.DelayedWarp
DelayedCrop = delayed_nodes.DelayedCrop
DelayedOverview = delayed_nodes.DelayedOverview
DelayedDequantize = delayed_nodes.DelayedDequantize
DelayedChannelConcat = delayed_nodes.DelayedChannelConcat
isinstance2 = delayed_nodes.isinstance2


@register_rule(DelayedWarp)
def fuse_warps(node):
if isinstance2(node.subdata, DelayedWarp):
return node._opt_fuse_warps(), True, "fuse_warps"
return node, False, None


@register_rule(DelayedWarp)
def remove_identity_warp(node):
noop_eps = node.meta.get("noop_eps", 0)
is_negligible = (
node.dsize == node.subdata.dsize
and node.transform.isclose_identity(rtol=noop_eps, atol=noop_eps)
)
if is_negligible:
return node.subdata, True, "remove_identity_warp"
return node, False, None


@register_rule(DelayedWarp)
def push_warp_under_concat(node):
if isinstance2(node.subdata, DelayedChannelConcat):
return node._opt_push_under_concat(), True, "warp_under_concat"
return node, False, None


@register_rule(DelayedWarp)
def warp_on_optimized_subdata(node):
if hasattr(node.subdata, "_optimized_warp"):
warp_kwargs = ub.dict_isect(node.meta, node._data_keys + node._algo_keys)
return node.subdata._optimized_warp(**warp_kwargs), True, "subdata_optimized_warp"
return node, False, None


@register_rule(DelayedWarp)
def split_warp_overview(node):
split = node._opt_split_warp_overview()
if split is not node:
return split, True, "split_warp_overview"
return node, False, None


@register_rule(DelayedWarp)
def absorb_overview(node):
absorbed = node._opt_absorb_overview()
if absorbed is not node:
return absorbed, True, "absorb_overview"
return node, False, None


@register_rule(DelayedCrop)
def fuse_crops(node):
if isinstance2(node.subdata, DelayedCrop):
return node._opt_fuse_crops(), True, "fuse_crops"
return node, False, None


@register_rule(DelayedCrop)
def optimized_crop_subdata(node):
if hasattr(node.subdata, "_optimized_crop"):
crop_kwargs = ub.dict_isect(node.meta, {"space_slice", "chan_idxs"})
return node.subdata._optimized_crop(**crop_kwargs), True, "subdata_optimized_crop"
return node, False, None


@register_rule(DelayedCrop)
def crop_after_warp(node):
if isinstance2(node.subdata, DelayedWarp):
return node._opt_warp_after_crop(), True, "crop_after_warp"
return node, False, None


@register_rule(DelayedCrop)
def dequant_after_crop(node):
if isinstance2(node.subdata, DelayedDequantize):
return node._opt_dequant_after_crop(), True, "dequant_after_crop"
return node, False, None


@register_rule(DelayedCrop)
def crop_under_concat(node):
if isinstance2(node.subdata, DelayedChannelConcat):
if node.meta.get("chan_idxs", None) is not None:
return node, False, None
return node._opt_push_under_concat(), True, "crop_under_concat"
return node, False, None


@register_rule(DelayedOverview)
def fuse_overview(node):
if isinstance2(node.subdata, DelayedOverview):
return node._opt_fuse_overview(), True, "fuse_overview"
return node, False, None


@register_rule(DelayedOverview)
def drop_identity_overview(node):
if node.meta.get("overview", None) == 0:
return node.subdata, True, "drop_overview_0"
return node, False, None


@register_rule(DelayedOverview)
def crop_after_overview(node):
if isinstance2(node.subdata, DelayedCrop):
return node._opt_crop_after_overview(), True, "crop_after_overview"
return node, False, None


@register_rule(DelayedOverview)
def warp_after_overview(node):
if isinstance2(node.subdata, DelayedWarp):
return node._opt_warp_after_overview(), True, "warp_after_overview"
return node, False, None


@register_rule(DelayedOverview)
def dequant_after_overview(node):
if isinstance2(node.subdata, DelayedDequantize):
return node._opt_dequant_after_overview(), True, "dequant_after_overview"
return node, False, None


@register_rule(DelayedOverview)
def overview_under_concat(node):
if isinstance2(node.subdata, DelayedChannelConcat):
return node._opt_push_under_concat(), True, "overview_under_concat"
return node, False, None


@register_rule(DelayedDequantize)
def dequant_before_warp(node):
if isinstance2(node.subdata, DelayedWarp):
return node._opt_dequant_before_other(), True, "dequant_before_warp"
return node, False, None


@register_rule(DelayedDequantize)
def dequant_under_concat(node):
if isinstance2(node.subdata, DelayedChannelConcat):
return node._opt_push_under_concat(), True, "dequant_under_concat"
return node, False, None
40 changes: 40 additions & 0 deletions delayed_image/experimental/astopt/signature.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""Helpers for creating canonical signatures for delayed nodes."""

from __future__ import __annotations__

from typing import Any

import ubelt as ub


def _normalize_value(value: Any) -> Any:
if hasattr(value, "concise"):
try:
value = value.concise()
except Exception:
pass
if hasattr(value, "spec"):
try:
value = value.spec
except Exception:
pass
if isinstance(value, dict):
return {key: _normalize_value(val) for key, val in value.items()}
if isinstance(value, (list, tuple)):
return [_normalize_value(val) for val in value]
return value


def node_signature(node, child_signatures) -> str:
"""Return a hashable signature for a delayed node."""
meta = getattr(node, "meta", {})
normalized_meta = {key: _normalize_value(val) for key, val in meta.items()}
payload = {
"type": node.__class__.__name__,
"meta": normalized_meta,
"children": child_signatures,
}
try:
return ub.hash_data(payload)
except Exception:
return ub.hash_data(ub.urepr(payload, sort=1))
Loading