88from .results import ECTResult
99
1010
11+ def _thresholds_are_uniform (thresholds : np .ndarray ) -> bool :
12+ thresholds = np .asarray (thresholds , dtype = float )
13+ if thresholds .ndim != 1 :
14+ raise ValueError ("thresholds must be a 1-dimensional array" )
15+ n = thresholds .size
16+ if n <= 1 :
17+ return True
18+ diffs = np .diff (thresholds )
19+ first = diffs [0 ]
20+ if first == 0.0 :
21+ return bool (np .all (diffs == 0.0 ))
22+ tol = 1e-12 * max (1.0 , abs (first ))
23+ return bool (np .all (np .abs (diffs - first ) <= tol ))
24+
25+
1126class ECT :
1227 """
1328 A class to calculate the Euler Characteristic Transform (ECT) from an input :class:`ect.embed_complex.EmbeddedComplex`,
@@ -55,6 +70,24 @@ def __init__(
5570 self .bound_radius = bound_radius
5671 self .thresholds = thresholds
5772 self .dtype = dtype
73+ self ._thresholds_validated = False
74+ if self .thresholds is not None :
75+ self .thresholds = np .asarray (self .thresholds , dtype = float )
76+ if self .thresholds .ndim != 1 :
77+ raise ValueError ("thresholds must be a 1-dimensional array" )
78+ self ._thresholds_validated = True
79+ if num_thresh is not None :
80+ self .is_uniform = True
81+ else :
82+ self .is_uniform = False
83+ if self .thresholds is None :
84+ raise ValueError (
85+ "thresholds must be provided if num_thresh is not provided"
86+ )
87+ if not _thresholds_are_uniform (self .thresholds ):
88+ raise ValueError (
89+ "thresholds must be uniform if num_thresh is not provided"
90+ )
5891
5992 def _ensure_directions (self , graph_dim , theta = None ):
6093 """Ensures directions is a valid Directions object of correct dimension"""
@@ -97,11 +130,14 @@ def _ensure_thresholds(self, graph, override_bound_radius=None):
97130 or graph .get_bounding_radius ()
98131 )
99132 self .thresholds = np .linspace (- radius , radius , self .num_thresh , dtype = float )
133+ self .is_uniform = True
134+ self ._thresholds_validated = True
100135 else :
101- # validate and convert existing thresholds
102- self .thresholds = np .asarray (self .thresholds , dtype = float )
103- if self .thresholds .ndim != 1 :
104- raise ValueError ("thresholds must be a 1-dimensional array" )
136+ if not self ._thresholds_validated :
137+ self .thresholds = np .asarray (self .thresholds , dtype = float )
138+ if self .thresholds .ndim != 1 :
139+ raise ValueError ("thresholds must be a 1-dimensional array" )
140+ self ._thresholds_validated = True
105141
106142 def calculate (
107143 self ,
@@ -124,21 +160,32 @@ def _compute_ect(
124160 cell_vertex_pointers , cell_vertex_indices_flat , cell_euler_signs , N = (
125161 graph ._build_incidence_csr ()
126162 )
127- thresholds = np .asarray (thresholds , dtype = np .float64 )
163+ thresholds = np .asarray (thresholds , dtype = np .float32 )
128164
129165 V = directions .vectors
130166 X = graph .coord_matrix
131167 H = X @ V .T # (N, m)
132168 H_T = np .ascontiguousarray (H .T ) # (m, N) for contiguous per-direction rows
133169
134- out64 = _ect_all_dirs (
135- H_T ,
136- cell_vertex_pointers ,
137- cell_vertex_indices_flat ,
138- cell_euler_signs ,
139- thresholds ,
140- N ,
141- )
170+ is_uniform = bool (self .is_uniform ) and thresholds [0 ] != thresholds [- 1 ]
171+ if is_uniform :
172+ out64 = _ect_all_dirs_uniform (
173+ H_T ,
174+ cell_vertex_pointers ,
175+ cell_vertex_indices_flat ,
176+ cell_euler_signs ,
177+ thresholds ,
178+ N ,
179+ )
180+ else :
181+ out64 = _ect_all_dirs_search (
182+ H_T ,
183+ cell_vertex_pointers ,
184+ cell_vertex_indices_flat ,
185+ cell_euler_signs ,
186+ thresholds ,
187+ N ,
188+ )
142189 if dtype == np .int32 :
143190 return out64 .astype (np .int32 )
144191 return out64
@@ -246,3 +293,126 @@ def _ect_all_dirs(
246293 ect_values [dir_idx , thresh_idx ] = euler_prefix [rank_cursor ]
247294
248295 return ect_values
296+
297+
298+ @njit (cache = True , parallel = True )
299+ def _ect_all_dirs_uniform (
300+ heights_by_direction ,
301+ cell_vertex_pointers ,
302+ cell_vertex_indices_flat ,
303+ cell_euler_signs ,
304+ threshold_values ,
305+ num_vertices ,
306+ ):
307+ num_directions = heights_by_direction .shape [0 ]
308+ num_thresholds = threshold_values .shape [0 ]
309+ t_min = threshold_values [0 ] if num_thresholds > 0 else 0.0
310+ t_max = threshold_values [- 1 ] if num_thresholds > 0 else 0.0
311+ span = t_max - t_min
312+ inv_span = 1.0 / span
313+ n_minus_1 = num_thresholds - 1
314+
315+ ect_values = np .empty ((num_directions , num_thresholds ), dtype = np .int64 )
316+
317+ for dir_idx in prange (num_directions ):
318+ heights = heights_by_direction [dir_idx ]
319+
320+ diff = np .zeros (num_thresholds , dtype = np .int64 )
321+ vertex_thresh_index = np .empty (num_vertices , dtype = np .int64 )
322+
323+ for v in range (num_vertices ):
324+ h = heights [v ]
325+ u = (h - t_min ) * inv_span
326+ idx = int (np .ceil (u * n_minus_1 ))
327+ if idx < 0 :
328+ idx = 0
329+ elif idx >= num_thresholds :
330+ idx = num_thresholds
331+
332+ vertex_thresh_index [v ] = idx
333+ if idx < num_thresholds :
334+ diff [idx ] += 1
335+
336+ num_cells = cell_vertex_pointers .shape [0 ] - 1
337+
338+ for cell_idx in range (num_cells ):
339+ start = cell_vertex_pointers [cell_idx ]
340+ end = cell_vertex_pointers [cell_idx + 1 ]
341+
342+ entrance_idx = - 1
343+ for k in range (start , end ):
344+ v = cell_vertex_indices_flat [k ]
345+ t_idx = vertex_thresh_index [v ]
346+ if t_idx > entrance_idx :
347+ entrance_idx = t_idx
348+
349+ if 0 <= entrance_idx < num_thresholds :
350+ diff [entrance_idx ] += cell_euler_signs [cell_idx ]
351+
352+ running = 0
353+ for j in range (num_thresholds ):
354+ running += diff [j ]
355+ ect_values [dir_idx , j ] = running
356+
357+ return ect_values
358+
359+
360+ @njit (cache = True , parallel = True )
361+ def _ect_all_dirs_search (
362+ heights_by_direction ,
363+ cell_vertex_pointers ,
364+ cell_vertex_indices_flat ,
365+ cell_euler_signs ,
366+ threshold_values ,
367+ num_vertices ,
368+ ):
369+ num_directions = heights_by_direction .shape [0 ]
370+ num_thresholds = threshold_values .shape [0 ]
371+
372+ ect_values = np .empty ((num_directions , num_thresholds ), dtype = np .int64 )
373+
374+ for dir_idx in prange (num_directions ):
375+ heights = heights_by_direction [dir_idx ]
376+
377+ diff = np .zeros (num_thresholds , dtype = np .int64 )
378+ vertex_thresh_index = np .empty (num_vertices , dtype = np .int64 )
379+
380+ for v in range (num_vertices ):
381+ h = heights [v ]
382+
383+ left = 0
384+ right = num_thresholds
385+ while left < right :
386+ mid = (left + right ) // 2
387+ if threshold_values [mid ] >= h :
388+ right = mid
389+ else :
390+ left = mid + 1
391+ idx = left
392+
393+ vertex_thresh_index [v ] = idx
394+ if idx < num_thresholds :
395+ diff [idx ] += 1
396+
397+ num_cells = cell_vertex_pointers .shape [0 ] - 1
398+
399+ for cell_idx in range (num_cells ):
400+ start = cell_vertex_pointers [cell_idx ]
401+ end = cell_vertex_pointers [cell_idx + 1 ]
402+
403+ entrance_idx = - 1
404+ for k in range (start , end ):
405+ v = cell_vertex_indices_flat [k ]
406+ t_idx = vertex_thresh_index [v ]
407+ if t_idx > entrance_idx :
408+ entrance_idx = t_idx
409+
410+ if 0 <= entrance_idx < num_thresholds :
411+ diff [entrance_idx ] += cell_euler_signs [cell_idx ]
412+
413+ running = 0
414+ for j in range (num_thresholds ):
415+ running += diff [j ]
416+ ect_values [dir_idx , j ] = running
417+
418+ return ect_values
0 commit comments