Skip to content

Commit 6718d2e

Browse files
committed
offer tqdm when adding loops
1 parent 617091b commit 6718d2e

File tree

2 files changed

+34
-4
lines changed

2 files changed

+34
-4
lines changed

grape/automaton/loop_manager.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22
import itertools
33
from typing import Generator
44

5+
from tqdm import tqdm
6+
57
from grape import types
68
from grape.automaton.spec_manager import is_specialized
79
from grape.automaton.tree_automaton import DFTA
810
from grape.dsl import DSL
9-
from grape.program import Function, Primitive, Program, Variable
11+
from grape.program import Primitive, Program, Variable
1012

1113

1214
class LoopingAlgorithm(StrEnum):
@@ -157,10 +159,18 @@ def __all_sub_args__(
157159
yield new_args
158160

159161

162+
def __product__(elements: list[int]) -> int:
163+
out = 1
164+
for x in elements:
165+
out *= x
166+
return out
167+
168+
160169
def add_loops(
161170
dfta: DFTA[str, Program | str],
162171
dsl: DSL,
163172
algorithm: LoopingAlgorithm = LoopingAlgorithm.OBSERVATIONAL_EQUIVALENCE,
173+
use_tqdm: bool = False,
164174
) -> DFTA[str, Program]:
165175
"""
166176
Assumes specialized DFTA, one state = one letter and that variants are mapped.
@@ -172,7 +182,9 @@ def add_loops(
172182
else:
173183
match algorithm:
174184
case LoopingAlgorithm.OBSERVATIONAL_EQUIVALENCE:
175-
is_allowed = lambda *args, **kwargs: True
185+
186+
def is_allowed(*args, **kwargs):
187+
return True
176188
case LoopingAlgorithm.GRAPE:
177189

178190
def is_allowed(
@@ -238,10 +250,27 @@ def is_allowed(
238250
new_dfta.refresh_reversed_rules()
239251
merge_memory = {}
240252
largest_merge = {}
253+
254+
update = lambda: 1
255+
if use_tqdm:
256+
pbar = tqdm(
257+
total=sum(
258+
__product__(
259+
[
260+
len(states_by_types[arg_t])
261+
for arg_t in types.arguments(Ptype)
262+
]
263+
)
264+
for (Ptype, _) in dsl.primitives.values()
265+
),
266+
desc="adding loops",
267+
)
268+
update = lambda: pbar.update()
241269
for P, (Ptype, _) in dsl.primitives.items():
242270
args_types, rtype = types.parse(Ptype)
243271
possibles = [states_by_types[arg_t] for arg_t in args_types]
244272
for combi in itertools.product(*possibles):
273+
update()
245274
key = (P, combi)
246275
dst_size = sum(map(lambda x: state_to_size[x], combi)) + 1
247276
if dst_size > max_size and is_allowed(
@@ -266,7 +295,8 @@ def is_allowed(
266295
)
267296
assert new_state in state_to_size
268297
new_dfta.rules[key] = new_state
269-
298+
if use_tqdm:
299+
pbar.close()
270300
for no in virtual_vars:
271301
dst = str(Variable(no))
272302
del new_dfta.rules[(dst, tuple())]

grape/cli/prune.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def main():
124124
type_req = type_request_from_specialized(reduced_grammar, dsl)
125125
loop_algorithm = args.strategy
126126
if loop_algorithm != "none":
127-
grammar = add_loops(reduced_grammar, dsl, loop_algorithm)
127+
grammar = add_loops(reduced_grammar, dsl, loop_algorithm, use_tqdm=True)
128128
else:
129129
grammar = reduced_grammar
130130

0 commit comments

Comments
 (0)