Skip to content

Commit b597f63

Browse files
committed
Fixed type issues in compression
1 parent 06d2f20 commit b597f63

File tree

2 files changed

+8
-10
lines changed

2 files changed

+8
-10
lines changed

synth/semantic/evaluator.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from abc import ABC, abstractmethod
2-
from typing import Any, Dict, List, Set, Callable
2+
from typing import Any, Dict, List, Set, Callable, Tuple
33

44
from synth.syntax.program import Constant, Function, Primitive, Program, Variable
55

@@ -35,10 +35,10 @@ def __init__(self, semantics: Dict[Primitive, Any], use_cache: bool = True) -> N
3535
# Statistics
3636
self._total_requests = 0
3737
self._cache_hits = 0
38-
self._dsl_constants: Dict[Any, Primitive] = {}
38+
self._dsl_constants: Dict[Tuple[Type, Any], Primitive] = {}
3939
for p, val in semantics.items():
4040
if len(p.type.arguments()) == 0:
41-
self._dsl_constants[__tuplify__(val)] = p
41+
self._dsl_constants[(p.type, __tuplify__(val))] = p
4242

4343
def compress(self, program: Program, allow_constants: bool = True) -> Program:
4444
"""
@@ -61,8 +61,9 @@ def compress(self, program: Program, allow_constants: bool = True) -> Program:
6161
if isinstance(value, Callable): # type: ignore
6262
return Function(program.function, args)
6363
tval = __tuplify__(value)
64-
if tval in self._dsl_constants:
65-
return self._dsl_constants[tval]
64+
rtype = program.type
65+
if (rtype, tval) in self._dsl_constants:
66+
return self._dsl_constants[(rtype, tval)]
6667
if allow_constants:
6768
return Constant(program.type.returns(), value, True)
6869
else:

tests/semantic/test_evaluator.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,12 @@
2828
dsl = DSL(syntax)
2929
cfg = CFG.depth_constraint(dsl, FunctionType(INT, INT), max_depth)
3030

31-
other_syntax = {
32-
"+1": FunctionType(INT, INT),
33-
"0": INT,
34-
"2": INT,
35-
}
31+
other_syntax = {"+1": FunctionType(INT, INT), "0": INT, "2": INT, "True": STRING}
3632

3733
other_semantics = {
3834
"+1": lambda x: x + 1,
3935
"0": 0,
36+
"True": True,
4037
"2": 2,
4138
}
4239
other_dsl = DSL(other_syntax)

0 commit comments

Comments
 (0)