Source code for banditpylib.bandits.thresholding_bandit

from typing import List

from banditpylib.arms import StochasticArm
from banditpylib.data_pb2 import Context, Actions, Feedback, ArmPull, \
    ArmFeedback
from banditpylib.learners import Goal, MaximizeCorrectAnswers, \
    MakeAllAnswersCorrect
from .utils import Bandit


class ThresholdingBandit(Bandit):
  r"""Thresholding bandit environment

  Arms are indexed from 0 by default. Each time the learner pulls arm
  :math:`i`, she will obtain an `i.i.d.` reward generated from an `unknown`
  distribution :math:`\mathcal{D}_i`. Different from the ordinary MAB, there
  is a threshold parameter :math:`\theta`. The learner should try to infer
  whether an arm's expected reward is above the threshold or not. Besides, the
  environment also accepts a parameter :math:`\epsilon \geq 0`, the radius of
  the indifference zone, meaning that the answers about arms with expected
  rewards within :math:`[\theta - \epsilon, \theta + \epsilon]` do not matter.

  :param List[StochasticArm] arms: arms in the thresholding bandit
  :param float theta: threshold
  :param float eps: radius of the indifference zone
  """
  def __init__(self, arms: List[StochasticArm], theta: float, eps: float):
    if len(arms) < 2:
      raise ValueError('Number of arms is expected at least 2. Got %d.' %
                       len(arms))
    self.__arms = arms
    self.__arm_num = len(arms)
    # Correct answers for all the arms, i.e., whether each arm's expected
    # reward is above the threshold or not
    self.__correct_answers = [
        1 if self.__arms[arm_id].mean >= theta else 0
        for arm_id in range(self.__arm_num)
    ]
    if eps < 0:
      raise ValueError(
          'Radius of indifference zone is expected at least 0. Got %.2f.' %
          eps)
    # The answer of the learner does not matter if the expected reward of an
    # arm is within the range [theta - eps, theta + eps]. Hence the weight
    # assigned to such an arm is 0.
    self.__weights = [
        0 if theta - eps <= self.__arms[arm_id].mean <= theta + eps else 1
        for arm_id in range(self.__arm_num)
    ]

  @property
  def name(self) -> str:
    return 'thresholding_bandit'

  def reset(self):
    pass

  @property
  def arm_num(self) -> int:
    """Total number of arms"""
    return self.__arm_num

  def _take_action(self, arm_pull: ArmPull) -> ArmFeedback:
    """Pull one arm

    Args:
      arm_pull: arm and its pulls

    Returns:
      arm_feedback: arm and its feedback
    """
    arm_id = arm_pull.arm.id
    pulls = arm_pull.times

    if arm_id not in range(self.__arm_num):
      raise Exception('Arm id %d is out of range [0, %d)!' %
                      (arm_id, self.__arm_num))

    arm_feedback = ArmFeedback()
    if pulls < 1:
      return arm_feedback
    # Empirical rewards when `arm_id` is pulled for `pulls` times
    em_rewards = self.__arms[arm_id].pull(pulls=pulls)
    arm_feedback.arm.id = arm_id
    arm_feedback.rewards.extend(list(em_rewards))  # type: ignore
    return arm_feedback

  def feed(self, actions: Actions) -> Feedback:
    feedback = Feedback()
    for arm_pull in actions.arm_pulls:
      arm_feedback = self._take_action(arm_pull=arm_pull)
      if arm_feedback.rewards:
        feedback.arm_feedbacks.append(arm_feedback)
    return feedback

  @property
  def context(self) -> Context:
    return Context()

  def regret(self, goal: Goal) -> float:
    if isinstance(goal, MaximizeCorrectAnswers):
      # Aggregate regret which is equal to the number of wrong answers
      agg_regret = 0
      for arm_id in range(self.__arm_num):
        agg_regret += (goal.answers[arm_id] !=
                       self.__correct_answers[arm_id]) * self.__weights[arm_id]
      return agg_regret
    elif isinstance(goal, MakeAllAnswersCorrect):
      # Simple regret which is 1 when there is at least one wrong answer and
      # 0 otherwise
      for arm_id in range(self.__arm_num):
        if (goal.answers[arm_id] != self.__correct_answers[arm_id]
            and self.__weights[arm_id] == 1):
          return 1
      return 0
    raise Exception('Goal %s is not supported.' % goal.name)
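

A minimal usage sketch of this environment follows. It assumes that BernoulliArm is available in banditpylib.arms as a StochasticArm, that ThresholdingBandit is importable from banditpylib.bandits, and that MaximizeCorrectAnswers can be constructed directly from a list of answers; these are illustrative assumptions, not guarantees from the source above.

# Usage sketch (assumptions: BernoulliArm is a StochasticArm and
# MaximizeCorrectAnswers(answers=...) is a valid constructor).
from banditpylib.arms import BernoulliArm
from banditpylib.bandits import ThresholdingBandit
from banditpylib.data_pb2 import Actions
from banditpylib.learners import MaximizeCorrectAnswers

# Three Bernoulli arms; threshold 0.5 with an indifference radius of 0.1,
# so the answer for the arm with mean 0.45 does not affect the regret.
arms = [BernoulliArm(mu) for mu in [0.3, 0.45, 0.7]]
bandit = ThresholdingBandit(arms=arms, theta=0.5, eps=0.1)
bandit.reset()

# Pull arm 0 ten times and arm 2 five times in a single feed call.
actions = Actions()
for arm_id, times in [(0, 10), (2, 5)]:
  arm_pull = actions.arm_pulls.add()
  arm_pull.arm.id = arm_id
  arm_pull.times = times
feedback = bandit.feed(actions)
for arm_feedback in feedback.arm_feedbacks:
  print(arm_feedback.arm.id, list(arm_feedback.rewards))

# Answer 1 means "mean >= theta". Arm 1 lies inside the indifference zone,
# so its (strictly wrong) answer carries zero weight and the aggregate
# regret below is 0.
goal = MaximizeCorrectAnswers(answers=[0, 1, 1])
print(bandit.regret(goal))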