Skip to content

Commit b92a3ec

Browse files
authored
Add path-based overrides (#39)
* Add path-based overrides. Sometimes neighbor algorithms depend on where you're coming from. Trains are not able to do 90° turns. Straight roads or rails make for more speed. Thus this patch allows overriding neighbor and cost estimates that take the path-so-far into account. TODO: add testcases ... * Use both nodes in path_neighbor, add a cache It's more regular and the cache can be used to directly store relevant information in the destination node which otherwise would have to be recalculated. * README: Document path_* methods and the cache.
1 parent 6aa97c7 commit b92a3ec

2 files changed

Lines changed: 56 additions & 21 deletions

File tree

README.rst

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ The `astar` library only requires the following property from these objects:
3434
For the default implementation of `is_goal_reached`, the objects must be
3535
comparable for same-ness (i.e. implement `__eq__`).
3636

37-
A simple way to achieve this, is to use simple objects based on strings,
37+
A simple way to achieve this is to use simple objects based on strings,
3838
floats, integers, tuples.
3939
[`dataclass`](https://docs.python.org/3/library/dataclasses.html#dataclasses.dataclass)
4040
objects declared with `@dataclass(frozen=True)` directly implement `__hash__`
@@ -54,7 +54,14 @@ For a given node, returns (or yields) the list of its neighbors.
5454
This is the method that one would provide in order to give to the
5555
algorithm the description of the graph to use during for computation.
5656

57-
This method must be implemented in a subclass.
57+
Alternately, your override method may be named "path\_neighbors". Instead of
58+
your node, this method receives a "SearchNode" object whose "came_from"
59+
attribute points to the previous node; your node is in its "data" attribute.
60+
You might want to use this if your path is directional, like the track of a
61+
train that can't do 90° turns.
62+
63+
One of these methods must be implemented in a subclass.
64+
5865

5966
distance\_between
6067
~~~~~~~~~~~~~~~~~
@@ -68,7 +75,14 @@ Gives the real distance/cost between two adjacent nodes n1 and n2 (i.e
6875
n2 belongs to the list of n1's neighbors). n2 is guaranteed to belong to
6976
the list returned by a call to neighbors(n1).
7077

71-
This method must be implemented in a subclass.
78+
Alternately, you may override "path\_distance\_between". The arguments
79+
will be a "SearchNode", as in "path\_neighbors". You might want to use this
80+
if your distance measure should include the path's attainable speed, the
81+
kind and number of turns on it, or similar. You can use the nodes' "cache"
82+
attributes to store some data, to speed up calculation.
83+
84+
One of these methods must be implemented in a subclass.
85+
7286

7387
heuristic\_cost\_estimate
7488
~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -82,7 +96,7 @@ Computes the estimated (rough) distance/cost between a node and the
8296
goal. The first argument is the start node, or any node that have been
8397
returned by a call to the neighbors() method.
8498

85-
This method is used to give to the algorithm an hint about the node he
99+
This method is used to give to the algorithm an hint about the node it
86100
may try next during search.
87101

88102
This method must be implemented in a subclass.
@@ -92,7 +106,6 @@ is\_goal\_reached
92106

93107
.. code:: py
94108
95-
96109
def is_goal_reached(self, current, goal)
97110
98111
This method shall return a truthy value when the goal is 'reached'. By

astar/__init__.py

Lines changed: 38 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
class SearchNode(Generic[T]):
1616
"""Representation of a search node"""
1717

18-
__slots__ = ("data", "gscore", "fscore", "closed", "came_from", "in_openset")
18+
__slots__ = ("data", "gscore", "fscore", "closed", "came_from", "in_openset", "cache")
1919

2020
def __init__(
2121
self, data: T, gscore: float = infinity, fscore: float = infinity
@@ -26,6 +26,7 @@ def __init__(
2626
self.closed = False
2727
self.in_openset = False
2828
self.came_from: Union[None, SearchNode[T]] = None
29+
self.cache: Any = None
2930

3031
def __lt__(self, b: "SearchNode[T]") -> bool:
3132
"""Natural order is based on the fscore value & is used by heapq operations"""
@@ -84,29 +85,48 @@ def heuristic_cost_estimate(self, current: T, goal: T) -> float:
8485
"""
8586
Computes the estimated (rough) distance between a node and the goal.
8687
The second parameter is always the goal.
88+
8789
This method must be implemented in a subclass.
8890
"""
8991
raise NotImplementedError
9092

91-
@abstractmethod
9293
def distance_between(self, n1: T, n2: T) -> float:
9394
"""
9495
Gives the real distance between two adjacent nodes n1 and n2 (i.e n2
9596
belongs to the list of n1's neighbors).
9697
n2 is guaranteed to belong to the list returned by the call to neighbors(n1).
97-
This method must be implemented in a subclass.
98+
99+
This method (or "path_distance_between") must be implemented in a subclass.
98100
"""
101+
raise NotImplementedError
102+
103+
def path_distance_between(self, n1: SearchNode[T], n2: SearchNode[T]) -> float:
104+
"""
105+
Gives the real distance between the node n1 and its neighbor n2.
106+
n2 is guaranteed to belong to the list returned by the call to
107+
path_neighbors(n1).
108+
109+
Calls "distance_between"`by default.
110+
"""
111+
return self.distance_between(n1.data, n2.data)
99112

100-
@abstractmethod
101113
def neighbors(self, node: T) -> Iterable[T]:
102114
"""
103115
For a given node, returns (or yields) the list of its neighbors.
104-
This method must be implemented in a subclass.
116+
117+
This method (or "path_neighbors") must be implemented in a subclass.
105118
"""
106119
raise NotImplementedError
107120

121+
def path_neighbors(self, node: SearchNode[T]) -> Iterable[T]:
122+
"""
123+
For a given node, returns (or yields) the list of its reachable neighbors.
124+
Calls "neighbors" by default.
125+
"""
126+
return self.neighbors(node.data)
127+
108128
def _neighbors(self, current: SearchNode[T], search_nodes: SearchNodeDict[T]) -> Iterable[SearchNode]:
109-
return (search_nodes[n] for n in self.neighbors(current.data))
129+
return (search_nodes[n] for n in self.path_neighbors(current))
110130

111131
def is_goal_reached(self, current: T, goal: T) -> bool:
112132
"""
@@ -153,25 +173,27 @@ def astar(
153173
if neighbor.closed:
154174
continue
155175

156-
tentative_gscore = current.gscore + self.distance_between(
157-
current.data, neighbor.data
158-
)
176+
gscore = current.gscore + self.path_distance_between(current, neighbor)
159177

160-
if tentative_gscore >= neighbor.gscore:
178+
if gscore >= neighbor.gscore:
161179
continue
162180

163-
neighbor_from_openset = neighbor.in_openset
181+
fscore = gscore + self.heuristic_cost_estimate(
182+
neighbor.data, goal
183+
)
184+
185+
if neighbor.in_openset:
186+
if neighbor.fscore < fscore:
187+
# the new path to this node isn't better
188+
continue
164189

165-
if neighbor_from_openset:
166190
# we have to remove the item from the heap, as its score has changed
167191
openSet.remove(neighbor)
168192

169193
# update the node
170194
neighbor.came_from = current
171-
neighbor.gscore = tentative_gscore
172-
neighbor.fscore = tentative_gscore + self.heuristic_cost_estimate(
173-
neighbor.data, goal
174-
)
195+
neighbor.gscore = gscore
196+
neighbor.fscore = fscore
175197

176198
openSet.push(neighbor)
177199

0 commit comments

Comments
 (0)