22import itertools
33from typing import Generator
44
5+ from tqdm import tqdm
6+
57from grape import types
68from grape .automaton .spec_manager import is_specialized
79from grape .automaton .tree_automaton import DFTA
810from grape .dsl import DSL
9- from grape .program import Function , Primitive , Program , Variable
11+ from grape .program import Primitive , Program , Variable
1012
1113
1214class LoopingAlgorithm (StrEnum ):
@@ -157,10 +159,18 @@ def __all_sub_args__(
157159 yield new_args
158160
159161
162+ def __product__ (elements : list [int ]) -> int :
163+ out = 1
164+ for x in elements :
165+ out *= x
166+ return out
167+
168+
160169def add_loops (
161170 dfta : DFTA [str , Program | str ],
162171 dsl : DSL ,
163172 algorithm : LoopingAlgorithm = LoopingAlgorithm .OBSERVATIONAL_EQUIVALENCE ,
173+ use_tqdm : bool = False ,
164174) -> DFTA [str , Program ]:
165175 """
166176 Assumes specialized DFTA, one state = one letter and that variants are mapped.
@@ -172,7 +182,9 @@ def add_loops(
172182 else :
173183 match algorithm :
174184 case LoopingAlgorithm .OBSERVATIONAL_EQUIVALENCE :
175- is_allowed = lambda * args , ** kwargs : True
185+
186+ def is_allowed (* args , ** kwargs ):
187+ return True
176188 case LoopingAlgorithm .GRAPE :
177189
178190 def is_allowed (
@@ -238,10 +250,27 @@ def is_allowed(
238250 new_dfta .refresh_reversed_rules ()
239251 merge_memory = {}
240252 largest_merge = {}
253+
254+ update = lambda : 1
255+ if use_tqdm :
256+ pbar = tqdm (
257+ total = sum (
258+ __product__ (
259+ [
260+ len (states_by_types [arg_t ])
261+ for arg_t in types .arguments (Ptype )
262+ ]
263+ )
264+ for (Ptype , _ ) in dsl .primitives .values ()
265+ ),
266+ desc = "adding loops" ,
267+ )
268+ update = lambda : pbar .update ()
241269 for P , (Ptype , _ ) in dsl .primitives .items ():
242270 args_types , rtype = types .parse (Ptype )
243271 possibles = [states_by_types [arg_t ] for arg_t in args_types ]
244272 for combi in itertools .product (* possibles ):
273+ update ()
245274 key = (P , combi )
246275 dst_size = sum (map (lambda x : state_to_size [x ], combi )) + 1
247276 if dst_size > max_size and is_allowed (
@@ -266,7 +295,8 @@ def is_allowed(
266295 )
267296 assert new_state in state_to_size
268297 new_dfta .rules [key ] = new_state
269-
298+ if use_tqdm :
299+ pbar .close ()
270300 for no in virtual_vars :
271301 dst = str (Variable (no ))
272302 del new_dfta .rules [(dst , tuple ())]
0 commit comments