@@ -71,20 +71,25 @@ def get_best_stats(self) -> Tuple[T, float, float, float, float]:
7171 max (best_returns ),
7272 )
7373
def run_at_least(self, min_budget: int, min_score: float = -float("inf")) -> int:
    """Sample the bandit's worst arm until it has at least ``min_budget`` samples.

    Sampling stops early when the arm's mean return falls below
    ``min_score`` or when ``get_return`` reports a failure.

    Args:
        min_budget: Number of samples the arm should reach.
        min_score: Lowest mean return for which sampling continues. The
            default of negative infinity disables this early stop.

    Returns:
        The number of ``get_return`` calls made by this invocation.
    """
    arm: int = self.bandit.worst_arm()
    # Arms at or past ``_last_ejected`` map to the candidate one slot earlier.
    candidate = (
        self._arm2candidate[arm]
        if arm < self._last_ejected
        else self._arm2candidate[arm - 1]
    )

    # The bandit stores negated returns (see ``add_return`` calls below), so
    # negating the stored total gives the sum of the raw returns.
    recorded = self.bandit.returns[arm]
    sum_ret = -sum(recorded)
    n = len(recorded)

    budget_used: int = 0
    while self.bandit.samples(arm) < min_budget:
        # With no samples there is no mean to judge yet, so keep sampling
        # instead of dividing by zero.
        if n > 0 and not sum_ret / n >= min_score:
            break
        has_no_error, arm_return = self.get_return(candidate)
        budget_used += 1
        if not has_no_error:
            # Record a very large stored value for the failed evaluation.
            self.bandit.add_return(arm, 1e10)
            break
        self.bandit.add_return(arm, -arm_return)
        n += 1
        # Accumulate the raw return so the running total keeps the same sign
        # convention as its initial value.
        sum_ret += arm_return
    return budget_used
8994
9095 def __run_until_ejection__ (self , max_budget : int ) -> Tuple [Optional [T ], int ]:
@@ -95,8 +100,9 @@ def __run_until_ejection__(self, max_budget: int) -> Tuple[Optional[T], int]:
95100 while self .__get_candidate_to_eject__ () is None and budget_used < max_budget :
96101 arm : int = self .bandit .choose_arm_ucb ()
97102 candidate : T = self ._arm2candidate [arm ]
98- can_continue , arm_return = self .get_return (candidate )
99- if not can_continue :
103+ has_no_error , arm_return = self .get_return (candidate )
104+ if not has_no_error :
105+ self .bandit .add_return (arm , 1e10 )
100106 return candidate , budget_used
101107 self .bandit .add_return (arm , - arm_return )
102108 budget_used += 1
0 commit comments