Source code for banditpylib.learners.mab_fbbai_learner.sr

from typing import Optional, Dict

import math
import numpy as np

from banditpylib import argmax_or_min_tuple
from banditpylib.arms import PseudoArm
from banditpylib.data_pb2 import Context, Actions, Feedback
from .utils import MABFixedBudgetBAILearner


class SR(MABFixedBudgetBAILearner):
    """Successive rejects policy :cite:`audibert2010best`

    Eliminate one arm in each round. Rounds ``1 .. arm_num - 2`` pull every
    surviving arm a precomputed number of times. The final round
    (``arm_num - 1``, when exactly two arms survive) spends whatever budget
    is left, split as evenly as possible between the two arms.

    :param int arm_num: number of arms
    :param int budget: total number of pulls
    :param Optional[str] name: alias name

    :raises Exception: if ``budget`` is too small for the pull schedule
    """

    def __init__(self,
                 arm_num: int,
                 budget: int,
                 name: Optional[str] = None):
        super().__init__(arm_num=arm_num, budget=budget, name=name)
        # bar_log_K = 1/2 + sum_{i=2}^{K} 1/i; it normalizes the per-round
        # pull schedule computed in ``reset``.
        self.__bar_log_K = 0.5 + sum(1 / i for i in range(2, self.arm_num + 1))
        if (budget - arm_num) < arm_num * self.__bar_log_K:
            # Smallest integer budget that passes the check above. Formatting
            # the raw float with ``%d`` would truncate it and report a budget
            # that is still rejected.
            min_budget = arm_num + math.ceil(arm_num * self.__bar_log_K)
            raise Exception('Budget is expected at least %d. Got %d.' %
                            (min_budget, budget))

    def _name(self) -> str:
        return 'sr'

    def reset(self):
        """Rebuild the pull schedule and restore all arms as active."""
        # pulls_per_round[k] is how many pulls each surviving arm receives in
        # round k, i.e. n_k - n_{k-1} with
        # n_k = ceil((budget - K) / (bar_log_K * (K + 1 - k))).
        # Index 0 is a placeholder so that rounds can be indexed from 1.
        self.__pulls_per_round = [-1]
        nk = [0]
        for k in range(1, self.arm_num):
            nk.append(
                math.ceil(1 / self.__bar_log_K * (self.budget - self.arm_num) /
                          (self.arm_num + 1 - k)))
            self.__pulls_per_round.append(nk[k] - nk[k - 1])

        # Surviving arms keyed by arm id; insertion order follows arm ids.
        self.__active_arms: Dict[int, PseudoArm] = {
            arm_id: PseudoArm() for arm_id in range(self.arm_num)
        }
        self.__budget_left = self.budget
        self.__best_arm: Optional[int] = None
        # Current round (1-indexed)
        self.__round = 1

    def actions(self, context: Context) -> Actions:
        """Return the pulls for the current round.

        An empty ``Actions`` is returned once all ``arm_num - 1`` rounds have
        been played.
        """
        del context
        actions = Actions()
        if self.__round >= self.arm_num:
            # No rounds left: nothing to pull.
            return actions

        if self.__round < self.arm_num - 1:
            times = self.__pulls_per_round[self.__round]
            for arm_id in self.__active_arms:
                arm_pull = actions.arm_pulls.add()
                arm_pull.arm.id = arm_id
                arm_pull.times = times
        else:
            # Use up the remaining budget when there are only two arms left
            first = self.__budget_left // 2
            pulls = (first, self.__budget_left - first)
            for arm_id, times in zip(self.__active_arms, pulls):
                arm_pull = actions.arm_pulls.add()
                arm_pull.arm.id = arm_id
                arm_pull.times = times
        return actions

    def update(self, feedback: Feedback):
        """Record rewards, then eliminate the arm with the lowest mean.

        After the final round the single surviving arm becomes ``best_arm``.
        """
        for arm_feedback in feedback.arm_feedbacks:
            self.__active_arms[arm_feedback.arm.id].update(
                np.array(arm_feedback.rewards))
            self.__budget_left -= len(arm_feedback.rewards)

        if len(self.__active_arms) <= 1:
            # Never eliminate the last surviving arm (e.g. if update is called
            # again after the final round).
            return

        # Eliminate the arm with the smallest mean reward
        arm_id_to_remove = argmax_or_min_tuple(
            [(arm.em_mean, arm_id) for arm_id, arm in self.__active_arms.items()],
            find_min=True)
        del self.__active_arms[arm_id_to_remove]
        if self.__round == self.arm_num - 1:
            self.__best_arm = next(iter(self.__active_arms))
        self.__round += 1

    @property
    def best_arm(self) -> int:
        """Identified best arm; only available after the final round."""
        if self.__best_arm is None:
            raise Exception('I don\'t have an answer yet.')
        return self.__best_arm