@@ -41,7 +41,7 @@ def challenge_with(
4141 self ._arm2candidate .insert (arm , new_candidate )
4242 # Add prior experience
4343 for arm_return in prior_experience :
44- self .bandit .add_return (arm , - arm_return )
44+ self .bandit .add_return (arm , arm_return )
4545 break
4646
4747 assert new_candidate in self .candidates
@@ -53,13 +53,13 @@ def challenge_with(
5353 return ejected_candidate , budget_used
5454
5555 def get_best_stats (self ) -> Tuple [T , float , float , float , float ]:
56- arm : int = self .bandit .worst_arm ()
56+ arm : int = self .bandit .best_arm ()
5757 candidate = (
5858 self ._arm2candidate [arm ]
5959 if arm < self ._last_ejected
6060 else self ._arm2candidate [arm - 1 ]
6161 )
62- best_returns = [- x for x in self .bandit .returns [arm ]]
62+ best_returns = [x for x in self .bandit .returns [arm ]]
6363 if len (best_returns ) == 0 :
6464 return candidate , float ("nan" ), float ("inf" ), - float ("inf" ), float ("inf" )
6565 mean_return = sum (best_returns ) / len (best_returns )
@@ -72,24 +72,24 @@ def get_best_stats(self) -> Tuple[T, float, float, float, float]:
7272 )
7373
7474 def run_at_least (self , min_budget : int , min_score : float = - float ("inf" )) -> int :
75- arm : int = self .bandit .worst_arm ()
75+ arm : int = self .bandit .best_arm ()
7676 candidate = (
7777 self ._arm2candidate [arm ]
7878 if arm < self ._last_ejected
7979 else self ._arm2candidate [arm - 1 ]
8080 )
8181 budget_used : int = 0
82- sum_ret = sum ([- x for x in self .bandit .returns [arm ]])
82+ sum_ret = sum ([x for x in self .bandit .returns [arm ]])
8383 n = len (self .bandit .returns [arm ])
8484 while self .bandit .samples (arm ) < min_budget and sum_ret / n >= min_score :
8585 has_no_error , arm_return = self .get_return (candidate )
8686 budget_used += 1
8787 if not has_no_error :
8888 self .bandit .add_return (arm , 1e10 )
8989 break
90- self .bandit .add_return (arm , - arm_return )
90+ self .bandit .add_return (arm , arm_return )
9191 n += 1
92- sum_ret += - arm_return
92+ sum_ret += arm_return
9393 return budget_used
9494
9595 def __run_until_ejection__ (self , max_budget : int ) -> Tuple [Optional [T ], int ]:
@@ -98,20 +98,24 @@ def __run_until_ejection__(self, max_budget: int) -> Tuple[Optional[T], int]:
9898 """
9999 budget_used : int = 0
100100 while self .__get_candidate_to_eject__ () is None and budget_used < max_budget :
101- arm : int = self .bandit .choose_arm_ucb ()
101+ arm : int = self .bandit .choose_arm_incertitude ()
102102 candidate : T = self ._arm2candidate [arm ]
103103 has_no_error , arm_return = self .get_return (candidate )
104104 if not has_no_error :
105105 self .bandit .add_return (arm , 1e10 )
106106 return candidate , budget_used
107- self .bandit .add_return (arm , - arm_return )
107+ self .bandit .add_return (arm , arm_return )
108108 budget_used += 1
109109 return self .__get_candidate_to_eject__ (True ), budget_used
110110
111111 def __get_candidate_to_eject__ (self , force : bool = False ) -> Optional [T ]:
112- worst_arm = self .bandit .best_arm ()
112+ worst_arm = self .bandit .worst_arm ()
113113 if force :
114114 return self ._arm2candidate [worst_arm ]
115+ if self .bandit .best_possible_return (
116+ worst_arm
117+ ) < self .bandit .worst_possible_return (self .bandit .best_arm ()):
118+ return self ._arm2candidate [worst_arm ]
115119 return_intervals = self .bandit .possible_returns ()
116120 low , high = return_intervals .pop (worst_arm )
117121 midpoint : float = (high + low ) / 2
0 commit comments