11import copy
2+ import itertools
23from typing import Any , Callable , Dict , Mapping , Optional , List as TList , Set , Tuple
34from synth .syntax .type_helper import FunctionType
45
@@ -123,23 +124,49 @@ def fix_types(
123124 -----------
124125 A parsed program that matches the given string
125126 """
126- return self .__fix_types__ (program )
127+ return self .__fix_types__ (program )[ 0 ]
127128
128129 def __fix_types__ (
129- self , program : Program , forced_type : Optional [Type ] = None
130- ) -> Program :
130+ self ,
131+ program : Program ,
132+ forced_type : Optional [Type ] = None ,
133+ force_fix : bool = False ,
134+ ) -> Tuple [Program , bool ]:
135+ is_ambiguous = False
131136 if isinstance (program , Function ):
132- fixed_fun = self .__fix_types__ (program .function )
133- out : Program = Function (
134- fixed_fun ,
135- [
136- self .__fix_types__ (arg , arg_type )
137+ fixed_fun , ambiguous = self .__fix_types__ (
138+ program .function , force_fix = force_fix
139+ )
140+ args = [
141+ self .__fix_types__ (arg , arg_type )[0 ]
142+ for arg , arg_type in zip (program .arguments , fixed_fun .type .arguments ())
143+ ]
144+
145+ if ambiguous and forced_type is not None :
146+ print (
147+ "before:" ,
148+ fixed_fun ,
149+ "type:" ,
150+ fixed_fun .type ,
151+ "args:" ,
152+ args ,
153+ "target:" ,
154+ FunctionType (* ([arg .type for arg in args ] + [forced_type ])),
155+ )
156+ fixed_fun = self .__fix_types__ (
157+ program .function ,
158+ FunctionType (* [arg .type for arg in args ], forced_type ),
159+ force_fix = force_fix ,
160+ )[0 ]
161+ print ("after:" , fixed_fun , "type:" , fixed_fun .type )
162+ args = [
163+ self .__fix_types__ (arg , arg_type , force_fix = force_fix )[0 ]
137164 for arg , arg_type in zip (
138165 program .arguments , fixed_fun .type .arguments ()
139166 )
140- ],
141- )
142- elif not program .type .is_under_specified ():
167+ ]
168+ out : Program = Function ( fixed_fun , args )
169+ elif not force_fix and not program .type .is_under_specified ():
143170 out = program
144171 elif isinstance (program , Variable ):
145172 out = Variable (program .variable , forced_type or program .type )
@@ -155,11 +182,12 @@ def __fix_types__(
155182 if len (matching ) == 1 :
156183 forced_type = matching [0 ].type
157184 elif len (matching ) > 1 :
185+ is_ambiguous = True
158186 forced_type = Sum (* list (map (lambda x : x .type , matching )))
159187 out = Primitive (program .primitive , forced_type or program .type )
160188 else :
161189 assert False , "no implemented"
162- return out
190+ return out , is_ambiguous
163191
164192 def auto_parse_program (
165193 self ,
@@ -189,31 +217,19 @@ def auto_parse_program(
189217 tr = FunctionType (* [UnknownType ()] * (nvars + 1 ))
190218 return self .fix_types (self .parse_program (program , tr , constants , False ))
191219
192- def parse_program (
220+ def __parse_program__ (
193221 self ,
194222 program : str ,
195223 type_request : Type ,
196224 constants : Dict [str , Tuple [Type , Any ]] = {},
197- check : bool = True ,
198- ) -> Program :
225+ ) -> TList [Program ]:
199226 """
200- Parse a program from its string representation given the type request.
201-
202- Parameters:
203- -----------
204- - program: the string representation of the program, i.e. str(prog)
205- - type_request: the type of the requested program in order to identify variable types
206- - constants: str representation of constants that map to their (type, value)
207- - check: ensure the program was correctly parsed
208-
209- Returns:
210- -----------
211- A parsed program that matches the given string
227+ Produce all possible interpretations of a parsed program.
212228 """
213229 if " " in program :
214230 parts = list (
215231 map (
216- lambda p : self .parse_program (p , type_request , constants , check ),
232+ lambda p : self .__parse_program__ (p , type_request , constants ),
217233 program .split (" " ),
218234 )
219235 )
@@ -234,45 +250,82 @@ def parse_program(
234250 end += 1
235251 levels .pop ()
236252
237- def parse_stack (l : TList [Program ], function_calls : TList [int ]) -> Program :
238- if len (l ) == 1 :
239- return l [0 ]
240- current = l .pop (0 )
241- f_call = function_calls .pop (0 )
242- if current .type .is_instance (Arrow ) and f_call > 0 :
243- args = [
244- parse_stack (l , function_calls )
245- for _ in current .type .arguments ()[:f_call ]
246- ]
247- return Function (current , args )
248- return current
249-
250- sol = parse_stack (parts , function_calls )
251- if check :
252- str_repr = str (sol )
253- for ori , (__ , rep ) in constants .items ():
254- str_repr = str_repr .replace (f" { rep } " , f" { ori } " )
255- str_repr = str_repr .replace (f" { rep } )" , f" { ori } )" )
256- assert (
257- str_repr == program
258- ), f"Failed parsing:\n { program } \n \t got:\n { str_repr } \n \t type request:{ type_request } obtained:{ sol .type } "
259- return sol
253+ n = len (parts )
254+
255+ def parse_stack (i : int ) -> TList [Tuple [Program , int ]]:
256+ if i + 1 == n :
257+ return [(p , n ) for p in parts [- 1 ]]
258+ current = parts [i ]
259+ f_call = function_calls [i ]
260+ out : TList [Tuple [Program , int ]] = []
261+ for some in current :
262+ if some .type .is_instance (Arrow ) and f_call > 0 :
263+ poss_args : TList [Tuple [TList [Program ], int ]] = [([], i + 1 )]
264+ for _ in some .type .arguments ()[:f_call ]:
265+ next = []
266+ for poss , j in poss_args :
267+ parsed = parse_stack (j )
268+ for x , k in parsed :
269+ next .append ((poss + [x ], k ))
270+ poss_args = next
271+
272+ for poss , j in poss_args :
273+ out .append ((Function (some , list (poss )), j ))
274+ else :
275+ out .append ((some , i + 1 ))
276+ return out
277+
278+ sols = parse_stack (0 )
279+
280+ return [p for p , _ in sols ]
260281 else :
261282 program = program .strip ("()" )
262- for P in self .list_primitives :
263- if P .primitive == program :
264- return P
265- if program .startswith ("var" ):
283+ matching : TList [Program ] = [
284+ P for P in self .list_primitives if P .primitive == program
285+ ]
286+ if len (matching ) > 0 :
287+ return matching
288+ elif program .startswith ("var" ):
266289 varno = int (program [3 :])
267290 vart = type_request
268291 if type_request .is_instance (Arrow ):
269292 vart = type_request .arguments ()[varno ]
270- return Variable (varno , vart )
293+ return [ Variable (varno , vart )]
271294 elif program in constants :
272295 t , val = constants [program ]
273- return Constant (t , val , True )
296+ return [ Constant (t , val , True )]
274297 assert False , f"can't parse: '{ program } '"
275298
299+ def parse_program (
300+ self ,
301+ program : str ,
302+ type_request : Type ,
303+ constants : Dict [str , Tuple [Type , Any ]] = {},
304+ check : bool = True ,
305+ ) -> Program :
306+ """
307+ Parse a program from its string representation given the type request.
308+
309+ Parameters:
310+ -----------
311+ - program: the string representation of the program, i.e. str(prog)
312+ - type_request: the type of the requested program in order to identify variable types
313+ - constants: str representation of constants that map to their (type, value)
314+ - check: ensure the program was correctly parsed with type checking
315+
316+ Returns:
317+ -----------
318+ A parsed program that matches the given string
319+ """
320+ possibles = self .__parse_program__ (program , type_request , constants )
321+ if check :
322+ coherents = [p for p in possibles if p .type_checks ()]
323+ assert (
324+ len (coherents ) > 0
325+ ), f"failed to parse a program that type checks for: { program } "
326+ return coherents [0 ]
327+ return possibles [0 ]
328+
276329 def get_primitive (self , name : str ) -> Optional [Primitive ]:
277330 """
278331 Returns the Primitive object with the specified name if it exists and None otherwise
0 commit comments