Skip to content

Commit be15566

Browse files
committed
optimisation for reduced number of episodes
1 parent 2c7f056 commit be15566

File tree

2 files changed

+14
-8
lines changed

2 files changed

+14
-8
lines changed

examples/rl/bandits/clever_evaluator.py

Lines changed: 12 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -71,20 +71,25 @@ def get_best_stats(self) -> Tuple[T, float, float, float, float]:
7171
max(best_returns),
7272
)
7373

74-
def run_at_least(self, min_budget: int) -> int:
74+
def run_at_least(self, min_budget: int, min_score: float = -float("inf")) -> int:
7575
arm: int = self.bandit.worst_arm()
7676
candidate = (
7777
self._arm2candidate[arm]
7878
if arm < self._last_ejected
7979
else self._arm2candidate[arm - 1]
8080
)
8181
budget_used: int = 0
82-
while self.bandit.samples(arm) < min_budget:
83-
can_continue, arm_return = self.get_return(candidate)
82+
sum_ret = sum([-x for x in self.bandit.returns[arm]])
83+
n = len(self.bandit.returns[arm])
84+
while self.bandit.samples(arm) < min_budget and sum_ret / n >= min_score:
85+
has_no_error, arm_return = self.get_return(candidate)
8486
budget_used += 1
85-
if not can_continue:
87+
if not has_no_error:
88+
self.bandit.add_return(arm, 1e10)
8689
break
8790
self.bandit.add_return(arm, -arm_return)
91+
n += 1
92+
sum_ret += -arm_return
8893
return budget_used
8994

9095
def __run_until_ejection__(self, max_budget: int) -> Tuple[Optional[T], int]:
@@ -95,8 +100,9 @@ def __run_until_ejection__(self, max_budget: int) -> Tuple[Optional[T], int]:
95100
while self.__get_candidate_to_eject__() is None and budget_used < max_budget:
96101
arm: int = self.bandit.choose_arm_ucb()
97102
candidate: T = self._arm2candidate[arm]
98-
can_continue, arm_return = self.get_return(candidate)
99-
if not can_continue:
103+
has_no_error, arm_return = self.get_return(candidate)
104+
if not has_no_error:
105+
self.bandit.add_return(arm, 1e10)
100106
return candidate, budget_used
101107
self.bandit.add_return(arm, -arm_return)
102108
budget_used += 1

examples/rl/solve.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -77,7 +77,7 @@
7777
# =========================================================================
7878
# GLOBAL PARAMETERS
7979
# maximum number of episodes to run when comparing two possibly equal (optimised) candidates
80-
MAX_BUDGET: int = 40
80+
MAX_BUDGET: int = 80
8181

8282
np.random.seed(SEED)
8383

@@ -220,7 +220,7 @@ def is_solved() -> bool:
220220
current_best_return = evaluator.get_best_stats()[1]
221221
if current_best_return >= TARGET_RETURN:
222222
with chronometer.clock("evaluation.confirm"):
223-
budget_used = evaluator.run_at_least(100)
223+
budget_used = evaluator.run_at_least(100, TARGET_RETURN)
224224
counter.count("episodes.confirm", budget_used)
225225
current_best_return = evaluator.get_best_stats()[1]
226226
if current_best_return >= TARGET_RETURN:

0 commit comments

Comments
 (0)