
Commit 4b3257e

committed
big rework of evaluation
1 parent a9c9d8b commit 4b3257e

File tree

6 files changed: +247 -255 lines changed


examples/rl/bandits/clever_evaluator.py

Lines changed: 0 additions & 136 deletions
This file was deleted.
Lines changed: 117 additions & 0 deletions
@@ -0,0 +1,117 @@
from examples.rl.program_evaluator import ProgramEvaluator
import numpy as np

from typing import List, Optional, Tuple, TypeVar, Generic


T = TypeVar("T", covariant=True)


class TopkManager(Generic[T]):
    # Manages a pool of candidate programs: a new challenger is added, evaluated
    # under a budget, and the statistically worst candidate is ejected using
    # UCB-style confidence bounds (exploration constant c). An ejection is forced
    # once the pool holds at least k candidates and the budget is spent.
    def __init__(
        self,
        evaluator: ProgramEvaluator,
        c: float = 0.7,
        k: int = 2,
    ) -> None:
        self.evaluator = evaluator
        self.candidates: List[T] = []
        self.k = k
        self.c = c

    def num_candidates(self) -> int:
        return len(self.candidates)

    def challenge_with(
        self,
        new_candidate: T,
        max_budget: int = 100,
        prior_experience: Optional[List[float]] = None,
    ) -> Tuple[Optional[T], int]:
        """
        Returns the ejected candidate (if any) and the number of calls made to the evaluator.
        """
        # Add the new program along with any returns already collected for it.
        self.evaluator.add_returns(new_candidate, prior_experience or [])
        self.candidates.append(new_candidate)
        ejected_candidate, budget_used = self.__run_until_ejection__(max_budget)
        if ejected_candidate is not None:
            self.__eject__(ejected_candidate)
        return ejected_candidate, budget_used

    def get_best_stats(self) -> Tuple[T, float, float, float, float]:
        best_arm = np.argmax([self.evaluator.mean_return(p) for p in self.candidates])
        candidate = self.candidates[best_arm]
        n = self.evaluator.samples(candidate)
        if n == 0:
            return candidate, float("nan"), float("inf"), -float("inf"), float("inf")
        rew = self.evaluator.returns(candidate)
        mean_return = np.mean(rew)
        return (
            candidate,
            mean_return,
            n,
            min(rew),
            max(rew),
        )

    def run_at_least(self, min_budget: int, min_score: float = -float("inf")) -> int:
        # Keep evaluating the current best candidate until it has at least
        # min_budget samples or its mean return drops below min_score.
        best_arm = np.argmax([self.evaluator.mean_return(p) for p in self.candidates])
        candidate = self.candidates[best_arm]
        initial: int = self.evaluator.samples(candidate)
        budget_used: int = 0
        while (
            initial + budget_used < min_budget
            and self.evaluator.mean_return(candidate) >= min_score
        ):
            budget_used += 1
            has_no_error = self.evaluator.eval(candidate)
            if not has_no_error:
                break
        return budget_used

    def __run_until_ejection__(self, max_budget: int) -> Tuple[Optional[T], int]:
        """
        Returns the candidate to eject (if any) and the number of evaluations spent.
        """
        budget_used: int = 0
        while self.__get_candidate_to_eject__() is None and budget_used < max_budget:
            # Evaluate the candidate with the fewest samples next.
            index: int = np.argmin([self.evaluator.samples(p) for p in self.candidates])
            candidate: T = self.candidates[index]
            has_no_error = self.evaluator.eval(candidate)
            if not has_no_error:
                # A candidate that fails evaluation is ejected immediately.
                return candidate, budget_used
            budget_used += 1
        # Once the budget is spent, force an ejection only if the pool is full.
        return self.__get_candidate_to_eject__(
            len(self.candidates) >= self.k
        ), budget_used

    def __get_candidate_to_eject__(self, force: bool = False) -> Optional[T]:
        if len(self.candidates) == 1:
            return None
        mean_returns = [self.evaluator.mean_return(p) for p in self.candidates]
        worst_arm = np.argmin(mean_returns)
        worst = self.candidates[worst_arm]
        if force:
            return worst
        best_arm = np.argmax(mean_returns)
        best = self.candidates[best_arm]

        # Eject the worst candidate only once the best candidate's lower confidence
        # bound is above the worst candidate's upper confidence bound.
        if mean_returns[best_arm] - self.uncertainty(best) >= mean_returns[
            worst_arm
        ] + self.uncertainty(worst):
            return worst
        return None

    def uncertainty(self, candidate: T) -> float:
        n = self.evaluator.samples(candidate)
        if n == 0:
            return float("inf")
        return self.c * np.sqrt(
            np.log(sum(self.evaluator.samples(p) for p in self.candidates)) / n
        )

    def __eject__(self, candidate: T):
        self.evaluator.delete_data(candidate)
        self.candidates.remove(candidate)
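
For intuition, the ejection test in __get_candidate_to_eject__ compares UCB-style confidence intervals around each candidate's mean return: the worst candidate is dropped only once the best candidate's lower bound clears the worst candidate's upper bound. A small standalone sketch of that comparison (the sample counts and mean returns below are made-up numbers for illustration, not values from this commit):

import numpy as np

c = 0.7                                       # same default exploration constant as TopkManager
samples = {"best": 20, "worst": 12}           # illustrative per-candidate evaluation counts
means = {"best": 195.0, "worst": 140.0}       # illustrative per-candidate mean returns
total = sum(samples.values())

def uncertainty(n: int) -> float:
    # c * sqrt(log(total samples) / n), mirroring TopkManager.uncertainty
    return c * np.sqrt(np.log(total) / n)

lower_best = means["best"] - uncertainty(samples["best"])     # ~194.7
upper_worst = means["worst"] + uncertainty(samples["worst"])  # ~140.4
eject_worst = lower_best >= upper_worst                       # True: intervals no longer overlap
print(lower_best, upper_worst, eject_worst)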

examples/rl/optim/constant_optimizer.py

Lines changed: 5 additions & 7 deletions
@@ -14,27 +14,25 @@ def __init__(self, seed: Optional[int] = None) -> None:
         self.c = 0.7
         self.min_budget_per_arm = 15
         self.best_return = 0
-        self._rng = np.random.default_rng(seed)
+        # self._rng = np.random.default_rng(seed)
 
     def optimize(
-        self,
-        eval: Callable[[], float],
-        constants: List[Constant],
+        self, eval: Callable[[], float], constants: List[Constant], **kwargs
     ) -> Tuple[List[Tile], List[List[float]]]:
         self.budget_used = 0
         self._constants = constants
         tiles = [tile_split(-np.inf, np.inf, splits=4) for _ in constants]
         self._eval = eval
         self._can_hope_to_beat_best = True
-        return self._optimize_tiles_(constants, tiles)
+        return self._optimize_tiles_(constants, tiles, **kwargs)
 
     def _pick_values(self) -> List[int]:
         arms = []
         for index, bandit in enumerate(self._bandits):
             arm = bandit.choose_arm_ucb()
             arms.append(arm)
             self._constants[index].assign(
-                self._tiles_list[index][arm].map(self._rng.uniform(0, 1))
+                self._tiles_list[index][arm].map(np.random.uniform(0, 1))
             )
         return arms
 
@@ -62,7 +60,7 @@ def _optimize_tiles_(
         constants: List[Constant],
         tiles_list: List[List[Tile]],
         prev_experiences=None,
-        max_total_budget=1500,
+        max_total_budget=1000,
     ) -> Tuple[List[Tile], List[List[float]]]:
         self._constants = constants
         self._tiles_list = tiles_list
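
In _pick_values, each constant has its own bandit whose choose_arm_ucb call selects the tile from which a value is then sampled uniformly. The bandit class itself is not part of this diff; as a rough idea of what a UCB1-style arm choice typically looks like, here is a hypothetical standalone sketch (the choose_arm_ucb function below is an assumption for illustration, not the repository's implementation):

import numpy as np

def choose_arm_ucb(counts, mean_rewards, c=0.7):
    # counts[i] = number of pulls of arm i, mean_rewards[i] = its average reward so far
    total = sum(counts)
    scores = []
    for n, mean in zip(counts, mean_rewards):
        if n == 0:
            return len(scores)  # play any untested arm first
        scores.append(mean + c * np.sqrt(np.log(total) / n))
    return int(np.argmax(scores))

print(choose_arm_ucb([10, 5, 3], [1.2, 1.5, 0.9]))  # -> 1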

examples/rl/program_evaluator.py

Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
from typing import Callable, List, Tuple
import gymnasium as gym
from synth.semantic.evaluator import Evaluator
from synth.syntax.program import Program
import numpy as np


def __state2env__(state: np.ndarray) -> Tuple:
    # Convert a gym observation into the tuple format expected by the DSL evaluator.
    return tuple(state.tolist())


def __adapt_action2env__(env: gym.Env, action) -> List:
    # For 1-dimensional Box action spaces, clip the scalar action into bounds and wrap it in a list.
    if isinstance(env.action_space, gym.spaces.Box):
        if len(env.action_space.shape) == 1 and env.action_space.shape[0] == 1:
            return [min(max(action, env.action_space.low[0]), env.action_space.high[0])]
    return action


class ProgramEvaluator:
    # Runs DSL programs as policies in a gym environment and caches, per program,
    # a dedicated environment instance together with the list of episode returns.
    def __init__(self, env_factory: Callable[[], gym.Env], evaluator: Evaluator):
        self.cache = {}
        self.env_factory = env_factory
        self.dsl_eval = evaluator
        self.recording = True
        self.tmp_keys = []

    def record(self, record: bool):
        # When recording is switched back on, drop the cache entries created while it was off.
        if not self.recording and record:
            for key in self.tmp_keys:
                del self.cache[key]
            self.tmp_keys.clear()
        self.recording = record

    def delete_data(self, program: Program):
        del self.cache[program.hash]

    def returns(self, program: Program) -> List[float]:
        return self.cache.get(program.hash, (0, []))[1]

    def mean_return(self, program: Program) -> float:
        r = self.returns(program)
        if len(r) == 0:
            return 0
        return sum(r) / len(r)

    def samples(self, program: Program) -> int:
        return len(self.cache.get(program.hash, (0, []))[1])

    def add_returns(self, program: Program, returns: List[float]):
        if program.hash not in self.cache:
            self.cache[program.hash] = (self.env_factory(), [])
            if not self.recording:
                self.tmp_keys.append(program.hash)
        li = self.returns(program)
        for el in returns:
            li.append(el)

    def eval(self, program: Program, n_episodes: int = 1) -> bool:
        # Roll out the program for n_episodes, appending each episode return to the cache.
        # Returns False if the program produced an invalid action or an overflow.
        if program.hash not in self.cache:
            self.cache[program.hash] = (self.env_factory(), [])
            if not self.recording:
                self.tmp_keys.append(program.hash)
        env, returns = self.cache[program.hash]
        try:
            state = None
            for _ in range(n_episodes):
                episode = []
                state = env.reset()[0]
                done = False
                while not done:
                    input = __state2env__(state)
                    action = self.dsl_eval.eval(program, input)
                    adapted_action = __adapt_action2env__(env, action)
                    if adapted_action not in env.action_space:
                        return False
                    next_state, reward, done, truncated, _ = env.step(adapted_action)
                    done |= truncated
                    episode.append(reward)
                    state = next_state
                returns.append(sum(episode))
        except OverflowError:
            return False
        return True
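
Taken together, the two new classes are presumably meant to be wired up along these lines. This is a rough usage sketch only: the CartPole environment, the dsl_evaluator instance and the candidate_programs iterable are placeholders for whatever the synthesis loop actually provides, and the import path of TopkManager is not shown in this diff.

import gymnasium as gym
# dsl_evaluator: a concrete synth.semantic.evaluator.Evaluator for the DSL (placeholder)
# candidate_programs: programs proposed by the synthesizer (placeholder)

evaluator = ProgramEvaluator(lambda: gym.make("CartPole-v1"), dsl_evaluator)
manager = TopkManager(evaluator, c=0.7, k=2)

for program in candidate_programs:
    # Challenge the current pool; the return value is the candidate that was ejected
    # (possibly the challenger itself) and the number of evaluation episodes spent.
    ejected, cost = manager.challenge_with(program, max_budget=100)

best, mean_return, n, min_return, max_return = manager.get_best_stats()
print(f"best candidate: mean return {mean_return} over {n} episodes")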
