Skip to content

Commit b5d7882

Browse files
committed
Fixed parsing
1 parent 7724849 commit b5d7882

File tree

1 file changed

+110
-57
lines changed

1 file changed

+110
-57
lines changed

synth/syntax/dsl.py

Lines changed: 110 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import copy
2+
import itertools
23
from typing import Any, Callable, Dict, Mapping, Optional, List as TList, Set, Tuple
34
from synth.syntax.type_helper import FunctionType
45

@@ -123,23 +124,49 @@ def fix_types(
123124
-----------
124125
A parsed program that matches the given string
125126
"""
126-
return self.__fix_types__(program)
127+
return self.__fix_types__(program)[0]
127128

128129
def __fix_types__(
129-
self, program: Program, forced_type: Optional[Type] = None
130-
) -> Program:
130+
self,
131+
program: Program,
132+
forced_type: Optional[Type] = None,
133+
force_fix: bool = False,
134+
) -> Tuple[Program, bool]:
135+
is_ambiguous = False
131136
if isinstance(program, Function):
132-
fixed_fun = self.__fix_types__(program.function)
133-
out: Program = Function(
134-
fixed_fun,
135-
[
136-
self.__fix_types__(arg, arg_type)
137+
fixed_fun, ambiguous = self.__fix_types__(
138+
program.function, force_fix=force_fix
139+
)
140+
args = [
141+
self.__fix_types__(arg, arg_type)[0]
142+
for arg, arg_type in zip(program.arguments, fixed_fun.type.arguments())
143+
]
144+
145+
if ambiguous and forced_type is not None:
146+
print(
147+
"before:",
148+
fixed_fun,
149+
"type:",
150+
fixed_fun.type,
151+
"args:",
152+
args,
153+
"target:",
154+
FunctionType(*([arg.type for arg in args] + [forced_type])),
155+
)
156+
fixed_fun = self.__fix_types__(
157+
program.function,
158+
FunctionType(*[arg.type for arg in args], forced_type),
159+
force_fix=force_fix,
160+
)[0]
161+
print("after:", fixed_fun, "type:", fixed_fun.type)
162+
args = [
163+
self.__fix_types__(arg, arg_type, force_fix=force_fix)[0]
137164
for arg, arg_type in zip(
138165
program.arguments, fixed_fun.type.arguments()
139166
)
140-
],
141-
)
142-
elif not program.type.is_under_specified():
167+
]
168+
out: Program = Function(fixed_fun, args)
169+
elif not force_fix and not program.type.is_under_specified():
143170
out = program
144171
elif isinstance(program, Variable):
145172
out = Variable(program.variable, forced_type or program.type)
@@ -155,11 +182,12 @@ def __fix_types__(
155182
if len(matching) == 1:
156183
forced_type = matching[0].type
157184
elif len(matching) > 1:
185+
is_ambiguous = True
158186
forced_type = Sum(*list(map(lambda x: x.type, matching)))
159187
out = Primitive(program.primitive, forced_type or program.type)
160188
else:
161189
assert False, "no implemented"
162-
return out
190+
return out, is_ambiguous
163191

164192
def auto_parse_program(
165193
self,
@@ -189,31 +217,19 @@ def auto_parse_program(
189217
tr = FunctionType(*[UnknownType()] * (nvars + 1))
190218
return self.fix_types(self.parse_program(program, tr, constants, False))
191219

192-
def parse_program(
220+
def __parse_program__(
193221
self,
194222
program: str,
195223
type_request: Type,
196224
constants: Dict[str, Tuple[Type, Any]] = {},
197-
check: bool = True,
198-
) -> Program:
225+
) -> TList[Program]:
199226
"""
200-
Parse a program from its string representation given the type request.
201-
202-
Parameters:
203-
-----------
204-
- program: the string representation of the program, i.e. str(prog)
205-
- type_request: the type of the requested program in order to identify variable types
206-
- constants: str representation of constants that map to their (type, value)
207-
- check: ensure the program was correctly parsed
208-
209-
Returns:
210-
-----------
211-
A parsed program that matches the given string
227+
Produce all possible interpretations of a parsed program.
212228
"""
213229
if " " in program:
214230
parts = list(
215231
map(
216-
lambda p: self.parse_program(p, type_request, constants, check),
232+
lambda p: self.__parse_program__(p, type_request, constants),
217233
program.split(" "),
218234
)
219235
)
@@ -234,45 +250,82 @@ def parse_program(
234250
end += 1
235251
levels.pop()
236252

237-
def parse_stack(l: TList[Program], function_calls: TList[int]) -> Program:
238-
if len(l) == 1:
239-
return l[0]
240-
current = l.pop(0)
241-
f_call = function_calls.pop(0)
242-
if current.type.is_instance(Arrow) and f_call > 0:
243-
args = [
244-
parse_stack(l, function_calls)
245-
for _ in current.type.arguments()[:f_call]
246-
]
247-
return Function(current, args)
248-
return current
249-
250-
sol = parse_stack(parts, function_calls)
251-
if check:
252-
str_repr = str(sol)
253-
for ori, (__, rep) in constants.items():
254-
str_repr = str_repr.replace(f" {rep} ", f" {ori} ")
255-
str_repr = str_repr.replace(f" {rep})", f" {ori})")
256-
assert (
257-
str_repr == program
258-
), f"Failed parsing:\n{program}\n\tgot:\n{str_repr}\n\ttype request:{type_request} obtained:{sol.type}"
259-
return sol
253+
n = len(parts)
254+
255+
def parse_stack(i: int) -> TList[Tuple[Program, int]]:
256+
if i + 1 == n:
257+
return [(p, n) for p in parts[-1]]
258+
current = parts[i]
259+
f_call = function_calls[i]
260+
out: TList[Tuple[Program, int]] = []
261+
for some in current:
262+
if some.type.is_instance(Arrow) and f_call > 0:
263+
poss_args: TList[Tuple[TList[Program], int]] = [([], i + 1)]
264+
for _ in some.type.arguments()[:f_call]:
265+
next = []
266+
for poss, j in poss_args:
267+
parsed = parse_stack(j)
268+
for x, k in parsed:
269+
next.append((poss + [x], k))
270+
poss_args = next
271+
272+
for poss, j in poss_args:
273+
out.append((Function(some, list(poss)), j))
274+
else:
275+
out.append((some, i + 1))
276+
return out
277+
278+
sols = parse_stack(0)
279+
280+
return [p for p, _ in sols]
260281
else:
261282
program = program.strip("()")
262-
for P in self.list_primitives:
263-
if P.primitive == program:
264-
return P
265-
if program.startswith("var"):
283+
matching: TList[Program] = [
284+
P for P in self.list_primitives if P.primitive == program
285+
]
286+
if len(matching) > 0:
287+
return matching
288+
elif program.startswith("var"):
266289
varno = int(program[3:])
267290
vart = type_request
268291
if type_request.is_instance(Arrow):
269292
vart = type_request.arguments()[varno]
270-
return Variable(varno, vart)
293+
return [Variable(varno, vart)]
271294
elif program in constants:
272295
t, val = constants[program]
273-
return Constant(t, val, True)
296+
return [Constant(t, val, True)]
274297
assert False, f"can't parse: '{program}'"
275298

299+
def parse_program(
300+
self,
301+
program: str,
302+
type_request: Type,
303+
constants: Dict[str, Tuple[Type, Any]] = {},
304+
check: bool = True,
305+
) -> Program:
306+
"""
307+
Parse a program from its string representation given the type request.
308+
309+
Parameters:
310+
-----------
311+
- program: the string representation of the program, i.e. str(prog)
312+
- type_request: the type of the requested program in order to identify variable types
313+
- constants: str representation of constants that map to their (type, value)
314+
- check: ensure the program was correctly parsed with type checking
315+
316+
Returns:
317+
-----------
318+
A parsed program that matches the given string
319+
"""
320+
possibles = self.__parse_program__(program, type_request, constants)
321+
if check:
322+
coherents = [p for p in possibles if p.type_checks()]
323+
assert (
324+
len(coherents) > 0
325+
), f"failed to parse a program that type checks for: {program}"
326+
return coherents[0]
327+
return possibles[0]
328+
276329
def get_primitive(self, name: str) -> Optional[Primitive]:
277330
"""
278331
Returns the Primitive object with the specified name if it exists and None otherwise

0 commit comments

Comments
 (0)