|
| 1 | +""" |
| 2 | +AO* (And-Or) Graph Search Algorithm |
| 3 | +=================================== |
| 4 | +
|
| 5 | +This module implements the AO* (And-Or Star) search algorithm for solving |
| 6 | +AND-OR graphs — a generalization of search trees where nodes can represent |
| 7 | +either OR decisions or AND decompositions. |
| 8 | +
|
| 9 | +Each node in the graph maps to a list of *options*, where each option is a |
| 10 | +list of child nodes. This naturally encodes the AND/OR structure: |
| 11 | +
|
| 12 | +- The outer list represents OR options (choices). |
| 13 | +- Each inner list represents AND sets (all must be solved). |
| 14 | +
|
| 15 | +This implementation follows the GeeksforGeeks explanation: |
| 16 | +https://www.geeksforgeeks.org/artificial-intelligence/ao-algorithm-artificial-intelligence/ |
| 17 | +
|
| 18 | +Time Complexity: |
| 19 | + Exponential in worst case (depends on branching factor). |
| 20 | +
|
| 21 | +""" |
| 22 | + |
| 23 | +from __future__ import annotations |
| 24 | + |
| 25 | +from typing import Any |
| 26 | + |
| 27 | + |
| 28 | +def _update_node( |
| 29 | + graph: dict[Any, list[list[Any]]], |
| 30 | + node: Any, |
| 31 | + h: dict[Any, float], |
| 32 | + solved: dict[Any, list[Any]], |
| 33 | + weight: float = 1.0, |
| 34 | +) -> None: |
| 35 | + """ |
| 36 | + Update the heuristic value and solved status of a node. |
| 37 | +
|
| 38 | + Uses the AO* rule: |
| 39 | + value(node) = min_{option in options(node)} |
| 40 | + (sum(value(child) + weight) for child in option) |
| 41 | + """ |
| 42 | + if not graph.get(node): |
| 43 | + solved[node] = [] |
| 44 | + return |
| 45 | + |
| 46 | + best_cost = float("inf") |
| 47 | + best_option: list[Any] | None = None |
| 48 | + best_all_solved = False |
| 49 | + |
| 50 | + for option in graph[node]: |
| 51 | + # option cost is sum of (child value + edge weight) |
| 52 | + total = sum(h[child] + weight for child in option) |
| 53 | + all_solved = all(child in solved for child in option) |
| 54 | + if total < best_cost: |
| 55 | + best_cost = total |
| 56 | + best_option = option |
| 57 | + best_all_solved = all_solved |
| 58 | + |
| 59 | + h[node] = best_cost |
| 60 | + # mark node solved only if best option's children are solved |
| 61 | + if best_option and best_all_solved: |
| 62 | + solved[node] = best_option |
| 63 | + |
| 64 | + |
| 65 | +def _ao_star( |
| 66 | + graph: dict[Any, list[list[Any]]], |
| 67 | + node: Any, |
| 68 | + h: dict[Any, float], |
| 69 | + solved: dict[Any, list[Any]], |
| 70 | + weight: float = 1.0, |
| 71 | +) -> float: |
| 72 | + """ |
| 73 | + Recursive AO* solver. |
| 74 | +
|
| 75 | + Returns the updated heuristic value of the node. |
| 76 | + """ |
| 77 | + # terminal node |
| 78 | + if not graph.get(node): |
| 79 | + solved[node] = [] |
| 80 | + return h[node] |
| 81 | + |
| 82 | + # initial estimate/update using current child heuristics |
| 83 | + _update_node(graph, node, h, solved, weight) |
| 84 | + |
| 85 | + # continue expanding the currently best AND-set until node becomes solved |
| 86 | + while node not in solved: |
| 87 | + # choose best option under current estimates (sum(child + weight)) |
| 88 | + best_option = min( |
| 89 | + graph[node], |
| 90 | + key=lambda option: sum(h[child] + weight for child in option), |
| 91 | + ) |
| 92 | + |
| 93 | + # expand (refine) each child in that best option |
| 94 | + for child in best_option: |
| 95 | + # recursively refine child; this will explore child's best options |
| 96 | + _ao_star(graph, child, h, solved, weight) |
| 97 | + |
| 98 | + # after expanding, update this node again (propagate refinements upward) |
| 99 | + prev = h[node] |
| 100 | + _update_node(graph, node, h, solved, weight) |
| 101 | + |
| 102 | + # if no meaningful change, we can break to avoid infinite loops on cycles |
| 103 | + if abs(h[node] - prev) < 1e-9: |
| 104 | + break |
| 105 | + |
| 106 | + return h[node] |
| 107 | + |
| 108 | + |
| 109 | +def ao_star( |
| 110 | + graph: dict[Any, list[list[Any]]], |
| 111 | + start: Any, |
| 112 | + h: dict[Any, float], |
| 113 | + weight: float = 1.0, |
| 114 | +) -> tuple[dict[Any, list[Any]], float]: |
| 115 | + """ |
| 116 | + Perform AO* (And-Or Star) search on a given AND-OR graph. |
| 117 | +
|
| 118 | + Args: |
| 119 | + graph: Mapping of node → list of AND-options. |
| 120 | + Each option is a list of child nodes. |
| 121 | + start: Start node key. |
| 122 | + h: dictionary of heuristic values. Updated in-place. |
| 123 | + weight: edge cost (g(n)) to add per child (default 1.0). |
| 124 | +
|
| 125 | + Returns: |
| 126 | + A tuple (solution, value): |
| 127 | + - solution: dict mapping solved nodes to their chosen AND-children. |
| 128 | + - value: final value (cost) of the start node. |
| 129 | +
|
| 130 | + Example: |
| 131 | + >>> g = {'A': [['B', 'C'], ['D']], 'B': [], 'C': [], 'D': []} |
| 132 | + >>> h = {'A': 2.0, 'B': 1.0, 'C': 1.0, 'D': 5.0} |
| 133 | + >>> sol, val = ao_star(g, 'A', h) |
| 134 | + >>> val |
| 135 | + 4.0 |
| 136 | + >>> sol['A'] |
| 137 | + ['B', 'C'] |
| 138 | +
|
| 139 | + Chain example: |
| 140 | + >>> g = {'A': [['B']], 'B': [['C']], 'C': []} |
| 141 | + >>> h = {'A': 10.0, 'B': 5.0, 'C': 1.0} |
| 142 | + >>> sol, val = ao_star(g, 'A', h) |
| 143 | + >>> val |
| 144 | + 3.0 |
| 145 | + >>> sol['A'] |
| 146 | + ['B'] |
| 147 | + >>> sol['B'] |
| 148 | + ['C'] |
| 149 | +
|
| 150 | + Example (the case you described; note edge cost = 1 by default): |
| 151 | + >>> graph = { |
| 152 | + ... 'A': [['B'], ['C', 'D']], |
| 153 | + ... 'B': [['E'], ['F']], |
| 154 | + ... 'C': [['G'], ['H', 'I']], |
| 155 | + ... 'D': [['J']], |
| 156 | + ... 'E': [], 'F': [], 'G': [], 'H': [], 'I': [], 'J': [] |
| 157 | + ... } |
| 158 | + >>> h = { |
| 159 | + ... 'A': 100.0, 'B': 5.0, 'C': 2.0, 'D': 4.0, |
| 160 | + ... 'E': 7.0, 'F': 9.0, 'G': 3.0, 'H': 0.0, 'I': 0.0, 'J': 0.0 |
| 161 | + ... } |
| 162 | + >>> sol, v = ao_star(graph, 'A', h) |
| 163 | + >>> v |
| 164 | + 5.0 |
| 165 | + >>> sol['C'] == ['H', 'I'] |
| 166 | + True |
| 167 | + >>> sol['D'] == ['J'] |
| 168 | + True |
| 169 | + """ |
| 170 | + if start not in graph: |
| 171 | + raise ValueError("Start node must exist in graph.") |
| 172 | + |
| 173 | + solved: dict[Any, list[Any]] = {} |
| 174 | + final_value = _ao_star(graph, start, h, solved, weight) |
| 175 | + return solved, final_value |
| 176 | + |
| 177 | + |
| 178 | +if __name__ == "__main__": |
| 179 | + demo_graph = { |
| 180 | + "S": [["A", "B"], ["C", "D"]], |
| 181 | + "A": [["E"], ["F", "G"]], |
| 182 | + "B": [], |
| 183 | + "C": [["H"]], |
| 184 | + "D": [], |
| 185 | + "E": [], |
| 186 | + "F": [], |
| 187 | + "G": [], |
| 188 | + "H": [], |
| 189 | + } |
| 190 | + |
| 191 | + heuristics = { |
| 192 | + "S": 100.0, |
| 193 | + "A": 50.0, |
| 194 | + "B": 3.0, |
| 195 | + "C": 20.0, |
| 196 | + "D": 4.0, |
| 197 | + "E": 1.0, |
| 198 | + "F": 2.0, |
| 199 | + "G": 2.0, |
| 200 | + "H": 1.0, |
| 201 | + } |
| 202 | + |
| 203 | + sol, val = ao_star(demo_graph, "S", heuristics) |
| 204 | + print("Solution graph:", sol) |
| 205 | + print("Final cost of start node:", val) |
0 commit comments