11from abc import ABC , abstractmethod
2- from typing import Any , Dict , List , Set , Callable
2+ from typing import Any , Dict , List , Set , Callable , Tuple
33
44from synth .syntax .program import Constant , Function , Primitive , Program , Variable
55
@@ -35,10 +35,10 @@ def __init__(self, semantics: Dict[Primitive, Any], use_cache: bool = True) -> N
3535 # Statistics
3636 self ._total_requests = 0
3737 self ._cache_hits = 0
38- self ._dsl_constants : Dict [Any , Primitive ] = {}
38+ self ._dsl_constants : Dict [Tuple [ Type , Any ] , Primitive ] = {}
3939 for p , val in semantics .items ():
4040 if len (p .type .arguments ()) == 0 :
41- self ._dsl_constants [__tuplify__ (val )] = p
41+ self ._dsl_constants [( p . type , __tuplify__ (val ) )] = p
4242
4343 def compress (self , program : Program , allow_constants : bool = True ) -> Program :
4444 """
@@ -61,8 +61,9 @@ def compress(self, program: Program, allow_constants: bool = True) -> Program:
6161 if isinstance (value , Callable ): # type: ignore
6262 return Function (program .function , args )
6363 tval = __tuplify__ (value )
64- if tval in self ._dsl_constants :
65- return self ._dsl_constants [tval ]
64+ rtype = program .type
65+ if (rtype , tval ) in self ._dsl_constants :
66+ return self ._dsl_constants [(rtype , tval )]
6667 if allow_constants :
6768 return Constant (program .type .returns (), value , True )
6869 else :
0 commit comments