Skip to content

Commit 16dd93c

Browse files
committed
better evaluation of candidates
1 parent be15566 commit 16dd93c

File tree

2 files changed

+49
-20
lines changed

examples/rl/bandits/bandit.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy as np
22

3-
from typing import List, Tuple
3+
from typing import List, Optional, Tuple
44

55

66
class MultiArmedBandit:
@@ -50,6 +50,9 @@ def choose_arm_ucb(self) -> int:
5050
]
5151
return np.argmax(best_returns)
5252

53+
def choose_arm_incertitude(self) -> int:
54+
return np.argmin(self._counts)
55+
5356
def best_arm(self) -> int:
5457
return self._best_arm
5558

@@ -80,12 +83,34 @@ def possible_returns(self) -> List[Tuple[float, float]]:
8083
for q_value, inc in zip(self._mean_returns, incertitudes)
8184
]
8285

83-
def best_possible_return(self) -> float:
84-
assert np.min(self._counts) > 0
85-
incertitudes = [
86-
self.c * np.sqrt(np.log(self._time) / count) for count in self._counts
87-
]
88-
best_returns = [
89-
self._mean_returns[i] + incertitudes[i] for i in range(self.arms)
90-
]
91-
return np.max(best_returns)
86+
def best_possible_return(self, arm: Optional[int] = None) -> float:
87+
if np.min(self._counts) <= 0:
88+
return float("inf")
89+
if arm is None:
90+
incertitudes = [
91+
self.c * np.sqrt(np.log(self._time) / count) for count in self._counts
92+
]
93+
best_returns = [
94+
self._mean_returns[i] + incertitudes[i] for i in range(self.arms)
95+
]
96+
return np.max(best_returns)
97+
else:
98+
return self._mean_returns[arm] + self.c * np.sqrt(
99+
np.log(self._time) / self._counts[arm]
100+
)
101+
102+
def worst_possible_return(self, arm: Optional[int] = None) -> float:
103+
if np.min(self._counts) > 0:
104+
return -float("inf")
105+
if arm is None:
106+
incertitudes = [
107+
self.c * np.sqrt(np.log(self._time) / count) for count in self._counts
108+
]
109+
best_returns = [
110+
self._mean_returns[i] - incertitudes[i] for i in range(self.arms)
111+
]
112+
return np.min(best_returns)
113+
else:
114+
return self._mean_returns[arm] - self.c * np.sqrt(
115+
np.log(self._time) / self._counts[arm]
116+
)

examples/rl/bandits/clever_evaluator.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def challenge_with(
4141
self._arm2candidate.insert(arm, new_candidate)
4242
# Add prior experience
4343
for arm_return in prior_experience:
44-
self.bandit.add_return(arm, -arm_return)
44+
self.bandit.add_return(arm, arm_return)
4545
break
4646

4747
assert new_candidate in self.candidates
@@ -53,13 +53,13 @@ def challenge_with(
5353
return ejected_candidate, budget_used
5454

5555
def get_best_stats(self) -> Tuple[T, float, float, float, float]:
56-
arm: int = self.bandit.worst_arm()
56+
arm: int = self.bandit.best_arm()
5757
candidate = (
5858
self._arm2candidate[arm]
5959
if arm < self._last_ejected
6060
else self._arm2candidate[arm - 1]
6161
)
62-
best_returns = [-x for x in self.bandit.returns[arm]]
62+
best_returns = [x for x in self.bandit.returns[arm]]
6363
if len(best_returns) == 0:
6464
return candidate, float("nan"), float("inf"), -float("inf"), float("inf")
6565
mean_return = sum(best_returns) / len(best_returns)
@@ -72,24 +72,24 @@ def get_best_stats(self) -> Tuple[T, float, float, float, float]:
7272
)
7373

7474
def run_at_least(self, min_budget: int, min_score: float = -float("inf")) -> int:
75-
arm: int = self.bandit.worst_arm()
75+
arm: int = self.bandit.best_arm()
7676
candidate = (
7777
self._arm2candidate[arm]
7878
if arm < self._last_ejected
7979
else self._arm2candidate[arm - 1]
8080
)
8181
budget_used: int = 0
82-
sum_ret = sum([-x for x in self.bandit.returns[arm]])
82+
sum_ret = sum([x for x in self.bandit.returns[arm]])
8383
n = len(self.bandit.returns[arm])
8484
while self.bandit.samples(arm) < min_budget and sum_ret / n >= min_score:
8585
has_no_error, arm_return = self.get_return(candidate)
8686
budget_used += 1
8787
if not has_no_error:
8888
self.bandit.add_return(arm, 1e10)
8989
break
90-
self.bandit.add_return(arm, -arm_return)
90+
self.bandit.add_return(arm, arm_return)
9191
n += 1
92-
sum_ret += -arm_return
92+
sum_ret += arm_return
9393
return budget_used
9494

9595
def __run_until_ejection__(self, max_budget: int) -> Tuple[Optional[T], int]:
@@ -98,20 +98,24 @@ def __run_until_ejection__(self, max_budget: int) -> Tuple[Optional[T], int]:
9898
"""
9999
budget_used: int = 0
100100
while self.__get_candidate_to_eject__() is None and budget_used < max_budget:
101-
arm: int = self.bandit.choose_arm_ucb()
101+
arm: int = self.bandit.choose_arm_incertitude()
102102
candidate: T = self._arm2candidate[arm]
103103
has_no_error, arm_return = self.get_return(candidate)
104104
if not has_no_error:
105105
self.bandit.add_return(arm, 1e10)
106106
return candidate, budget_used
107-
self.bandit.add_return(arm, -arm_return)
107+
self.bandit.add_return(arm, arm_return)
108108
budget_used += 1
109109
return self.__get_candidate_to_eject__(True), budget_used
110110

111111
def __get_candidate_to_eject__(self, force: bool = False) -> Optional[T]:
112-
worst_arm = self.bandit.best_arm()
112+
worst_arm = self.bandit.worst_arm()
113113
if force:
114114
return self._arm2candidate[worst_arm]
115+
if self.bandit.best_possible_return(
116+
worst_arm
117+
) < self.bandit.worst_possible_return(self.bandit.best_arm()):
118+
return self._arm2candidate[worst_arm]
115119
return_intervals = self.bandit.possible_returns()
116120
low, high = return_intervals.pop(worst_arm)
117121
midpoint: float = (high + low) / 2

0 commit comments

Comments (0)