Skip to content

Commit cc3aae7

Browse files
committed
Merge branch 'reviewer-edits' of github.com:MunchLab/ect into reviewer-edits
2 parents 66e30da + 585f399 commit cc3aae7

2 files changed

Lines changed: 200 additions & 19 deletions

File tree

src/ect/ect.py

Lines changed: 183 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,21 @@
88
from .results import ECTResult
99

1010

11+
def _thresholds_are_uniform(thresholds: np.ndarray) -> bool:
12+
thresholds = np.asarray(thresholds, dtype=float)
13+
if thresholds.ndim != 1:
14+
raise ValueError("thresholds must be a 1-dimensional array")
15+
n = thresholds.size
16+
if n <= 1:
17+
return True
18+
diffs = np.diff(thresholds)
19+
first = diffs[0]
20+
if first == 0.0:
21+
return bool(np.all(diffs == 0.0))
22+
tol = 1e-12 * max(1.0, abs(first))
23+
return bool(np.all(np.abs(diffs - first) <= tol))
24+
25+
1126
class ECT:
1227
"""
1328
A class to calculate the Euler Characteristic Transform (ECT) from an input :class:`ect.embed_complex.EmbeddedComplex`,
@@ -55,6 +70,24 @@ def __init__(
5570
self.bound_radius = bound_radius
5671
self.thresholds = thresholds
5772
self.dtype = dtype
73+
self._thresholds_validated = False
74+
if self.thresholds is not None:
75+
self.thresholds = np.asarray(self.thresholds, dtype=float)
76+
if self.thresholds.ndim != 1:
77+
raise ValueError("thresholds must be a 1-dimensional array")
78+
self._thresholds_validated = True
79+
if num_thresh is not None:
80+
self.is_uniform = True
81+
else:
82+
self.is_uniform = False
83+
if self.thresholds is None:
84+
raise ValueError(
85+
"thresholds must be provided if num_thresh is not provided"
86+
)
87+
if not _thresholds_are_uniform(self.thresholds):
88+
raise ValueError(
89+
"thresholds must be uniform if num_thresh is not provided"
90+
)
5891

5992
def _ensure_directions(self, graph_dim, theta=None):
6093
"""Ensures directions is a valid Directions object of correct dimension"""
@@ -97,11 +130,14 @@ def _ensure_thresholds(self, graph, override_bound_radius=None):
97130
or graph.get_bounding_radius()
98131
)
99132
self.thresholds = np.linspace(-radius, radius, self.num_thresh, dtype=float)
133+
self.is_uniform = True
134+
self._thresholds_validated = True
100135
else:
101-
# validate and convert existing thresholds
102-
self.thresholds = np.asarray(self.thresholds, dtype=float)
103-
if self.thresholds.ndim != 1:
104-
raise ValueError("thresholds must be a 1-dimensional array")
136+
if not self._thresholds_validated:
137+
self.thresholds = np.asarray(self.thresholds, dtype=float)
138+
if self.thresholds.ndim != 1:
139+
raise ValueError("thresholds must be a 1-dimensional array")
140+
self._thresholds_validated = True
105141

106142
def calculate(
107143
self,
@@ -124,21 +160,32 @@ def _compute_ect(
124160
cell_vertex_pointers, cell_vertex_indices_flat, cell_euler_signs, N = (
125161
graph._build_incidence_csr()
126162
)
127-
thresholds = np.asarray(thresholds, dtype=np.float64)
163+
thresholds = np.asarray(thresholds, dtype=np.float32)
128164

129165
V = directions.vectors
130166
X = graph.coord_matrix
131167
H = X @ V.T # (N, m)
132168
H_T = np.ascontiguousarray(H.T) # (m, N) for contiguous per-direction rows
133169

134-
out64 = _ect_all_dirs(
135-
H_T,
136-
cell_vertex_pointers,
137-
cell_vertex_indices_flat,
138-
cell_euler_signs,
139-
thresholds,
140-
N,
141-
)
170+
is_uniform = bool(self.is_uniform) and thresholds[0] != thresholds[-1]
171+
if is_uniform:
172+
out64 = _ect_all_dirs_uniform(
173+
H_T,
174+
cell_vertex_pointers,
175+
cell_vertex_indices_flat,
176+
cell_euler_signs,
177+
thresholds,
178+
N,
179+
)
180+
else:
181+
out64 = _ect_all_dirs_search(
182+
H_T,
183+
cell_vertex_pointers,
184+
cell_vertex_indices_flat,
185+
cell_euler_signs,
186+
thresholds,
187+
N,
188+
)
142189
if dtype == np.int32:
143190
return out64.astype(np.int32)
144191
return out64
@@ -246,3 +293,126 @@ def _ect_all_dirs(
246293
ect_values[dir_idx, thresh_idx] = euler_prefix[rank_cursor]
247294

248295
return ect_values
296+
297+
298+
@njit(cache=True, parallel=True)
299+
def _ect_all_dirs_uniform(
300+
heights_by_direction,
301+
cell_vertex_pointers,
302+
cell_vertex_indices_flat,
303+
cell_euler_signs,
304+
threshold_values,
305+
num_vertices,
306+
):
307+
num_directions = heights_by_direction.shape[0]
308+
num_thresholds = threshold_values.shape[0]
309+
t_min = threshold_values[0] if num_thresholds > 0 else 0.0
310+
t_max = threshold_values[-1] if num_thresholds > 0 else 0.0
311+
span = t_max - t_min
312+
inv_span = 1.0 / span
313+
n_minus_1 = num_thresholds - 1
314+
315+
ect_values = np.empty((num_directions, num_thresholds), dtype=np.int64)
316+
317+
for dir_idx in prange(num_directions):
318+
heights = heights_by_direction[dir_idx]
319+
320+
diff = np.zeros(num_thresholds, dtype=np.int64)
321+
vertex_thresh_index = np.empty(num_vertices, dtype=np.int64)
322+
323+
for v in range(num_vertices):
324+
h = heights[v]
325+
u = (h - t_min) * inv_span
326+
idx = int(np.ceil(u * n_minus_1))
327+
if idx < 0:
328+
idx = 0
329+
elif idx >= num_thresholds:
330+
idx = num_thresholds
331+
332+
vertex_thresh_index[v] = idx
333+
if idx < num_thresholds:
334+
diff[idx] += 1
335+
336+
num_cells = cell_vertex_pointers.shape[0] - 1
337+
338+
for cell_idx in range(num_cells):
339+
start = cell_vertex_pointers[cell_idx]
340+
end = cell_vertex_pointers[cell_idx + 1]
341+
342+
entrance_idx = -1
343+
for k in range(start, end):
344+
v = cell_vertex_indices_flat[k]
345+
t_idx = vertex_thresh_index[v]
346+
if t_idx > entrance_idx:
347+
entrance_idx = t_idx
348+
349+
if 0 <= entrance_idx < num_thresholds:
350+
diff[entrance_idx] += cell_euler_signs[cell_idx]
351+
352+
running = 0
353+
for j in range(num_thresholds):
354+
running += diff[j]
355+
ect_values[dir_idx, j] = running
356+
357+
return ect_values
358+
359+
360+
@njit(cache=True, parallel=True)
361+
def _ect_all_dirs_search(
362+
heights_by_direction,
363+
cell_vertex_pointers,
364+
cell_vertex_indices_flat,
365+
cell_euler_signs,
366+
threshold_values,
367+
num_vertices,
368+
):
369+
num_directions = heights_by_direction.shape[0]
370+
num_thresholds = threshold_values.shape[0]
371+
372+
ect_values = np.empty((num_directions, num_thresholds), dtype=np.int64)
373+
374+
for dir_idx in prange(num_directions):
375+
heights = heights_by_direction[dir_idx]
376+
377+
diff = np.zeros(num_thresholds, dtype=np.int64)
378+
vertex_thresh_index = np.empty(num_vertices, dtype=np.int64)
379+
380+
for v in range(num_vertices):
381+
h = heights[v]
382+
383+
left = 0
384+
right = num_thresholds
385+
while left < right:
386+
mid = (left + right) // 2
387+
if threshold_values[mid] >= h:
388+
right = mid
389+
else:
390+
left = mid + 1
391+
idx = left
392+
393+
vertex_thresh_index[v] = idx
394+
if idx < num_thresholds:
395+
diff[idx] += 1
396+
397+
num_cells = cell_vertex_pointers.shape[0] - 1
398+
399+
for cell_idx in range(num_cells):
400+
start = cell_vertex_pointers[cell_idx]
401+
end = cell_vertex_pointers[cell_idx + 1]
402+
403+
entrance_idx = -1
404+
for k in range(start, end):
405+
v = cell_vertex_indices_flat[k]
406+
t_idx = vertex_thresh_index[v]
407+
if t_idx > entrance_idx:
408+
entrance_idx = t_idx
409+
410+
if 0 <= entrance_idx < num_thresholds:
411+
diff[entrance_idx] += cell_euler_signs[cell_idx]
412+
413+
running = 0
414+
for j in range(num_thresholds):
415+
running += diff[j]
416+
ect_values[dir_idx, j] = running
417+
418+
return ect_values

src/ect/embed_complex.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,6 @@
99
from sklearn.decomposition import PCA
1010

1111
from .utils.naming import next_vert_name
12-
from .utils.face_check import (
13-
point_in_polygon,
14-
validate_face_embedding,
15-
validate_edge_embedding,
16-
)
1712
from .validation import EmbeddingValidator, ValidationRule
1813

1914

@@ -55,6 +50,7 @@ def __init__(self, validate_embedding=False, embedding_tol=1e-10):
5550
self._node_to_index = {}
5651
self._coord_matrix = None
5752
self.cells = defaultdict(list)
53+
self._incidence_csr_cache = None
5854

5955
self.validate_embedding = validate_embedding
6056
self.embedding_tol = embedding_tol
@@ -69,6 +65,12 @@ def edge_checker(v1_idx: int, v2_idx: int) -> bool:
6965

7066
self._validator = EmbeddingValidator(embedding_tol, edge_checker)
7167

68+
def _invalidate_incidence_csr_cache(self) -> None:
69+
self._incidence_csr_cache = None
70+
71+
def precompute_incidence_csr(self) -> tuple:
72+
return self._build_incidence_csr()
73+
7274
@property
7375
def coord_matrix(self):
7476
"""
@@ -192,6 +194,7 @@ def add_node(self, node_id, coord):
192194
self._node_list.append(node_id)
193195
self._node_to_index[node_id] = len(self._node_list) - 1
194196
super().add_node(node_id)
197+
self._invalidate_incidence_csr_cache()
195198

196199
def add_nodes_from_dict(self, nodes_with_coords: Dict[Union[str, int], np.ndarray]):
197200
"""Add multiple vertices to the complex.
@@ -232,6 +235,7 @@ def add_edge(self, node_id1, node_id2):
232235
raise ValueError(node_result.message)
233236

234237
super().add_edge(node_id1, node_id2)
238+
self._invalidate_incidence_csr_cache()
235239

236240
def add_cell(
237241
self,
@@ -302,6 +306,7 @@ def add_cell(
302306
self.add_edge(cell_vertices[0], cell_vertices[1])
303307

304308
self.cells[dim].append(cell_indices)
309+
self._invalidate_incidence_csr_cache()
305310

306311
def enable_embedding_validation(self, tol: float = 1e-10):
307312
"""
@@ -434,6 +439,7 @@ def add_cycle(self, coord_matrix):
434439
new_names = next_vert_name(self._node_list[-1] if self._node_list else 0, n)
435440
self.add_nodes_from(zip(new_names, coord_matrix))
436441
self.add_edges_from([(new_names[i], new_names[(i + 1) % n]) for i in range(n)])
442+
self.precompute_incidence_csr()
437443

438444
def get_center(self, method: str = "bounding_box") -> np.ndarray:
439445
"""
@@ -967,6 +973,9 @@ def _build_incidence_csr(self) -> tuple:
967973
Example: takes the complex [(1,3),(2,4),(1,2,3)] and returns [(0,2,4,7),(1,3,2,4,1,2,3),(-1,-1,1),4]
968974
969975
"""
976+
if self._incidence_csr_cache is not None:
977+
return self._incidence_csr_cache
978+
970979
n_vertices = len(self.node_list)
971980

972981
cells_by_dimension = {}
@@ -1010,12 +1019,14 @@ def _build_incidence_csr(self) -> tuple:
10101019
cell_vertex_pointers[cell_index] = len(cell_vertex_indices_flat)
10111020

10121021
cell_vertex_indices_flat = np.asarray(cell_vertex_indices_flat, dtype=np.int32)
1013-
return (
1022+
out = (
10141023
cell_vertex_pointers,
10151024
cell_vertex_indices_flat,
10161025
cell_euler_signs,
10171026
n_vertices,
10181027
)
1028+
self._incidence_csr_cache = out
1029+
return out
10191030

10201031

10211032
EmbeddedGraph = EmbeddedComplex

0 commit comments

Comments
 (0)