from typing import List
from banditpylib.arms import StochasticArm
from banditpylib.data_pb2 import Context, Actions, Feedback, ArmPull, \
    ArmFeedback
from banditpylib.learners import Goal, MaximizeCorrectAnswers, \
MakeAllAnswersCorrect
from .utils import Bandit
class ThresholdingBandit(Bandit):
r"""Thresholding bandit environment
Arms are indexed from 0 by default. Each time the learner pulls arm :math:`i`,
she will obtain an `i.i.d.` reward generated from an `unknown` distribution
:math:`\mathcal{D}_i`. Different from the ordinary MAB, there is a threshold
parameter :math:`\theta`. The learner should try to infer whether an arm's
expected reward is above the threshold or not. Besides, the environment also
accepts a parameter :math:`\epsilon >= 0` which is the radius of indifference
zone meaning that the answers about the arms with expected rewards within
:math:`[\theta - \epsilon, \theta + \epsilon]` do not matter.
:param List[StochasticArm] arms: arms in thresholding bandit
:param float theta: threshold
:param float eps: radius of indifferent zone
"""
def __init__(self, arms: List[StochasticArm], theta: float, eps: float):
if len(arms) < 2:
      raise ValueError('Number of arms is expected to be at least 2. Got %d.' %
                       len(arms))
self.__arms = arms
self.__arm_num = len(arms)
    # Correct answers for all arms i.e., whether each arm's expected reward is
    # at least the threshold
self.__correct_answers = [
1 if self.__arms[arm_id].mean >= theta else 0
for arm_id in range(self.__arm_num)
]
    if eps < 0:
      raise ValueError(
          'Radius of indifference zone is expected to be at least 0. Got %.2f.'
          % eps)
    # The answer of the learner does not matter if an arm's expected reward
    # lies within the range [theta - eps, theta + eps]. Hence the weight
    # assigned to such an arm is 0.
self.__weights = [
0 if theta - eps <= self.__arms[arm_id].mean <= theta + eps else 1
for arm_id in range(self.__arm_num)
]
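
    # Worked example, continuing the sketch above: with means (0.3, 0.7, 0.52),
    # theta = 0.5 and eps = 0.05, the correct answers are [0, 1, 1] (only the
    # last two means are >= 0.5) and the weights are [1, 1, 0] (0.52 falls
    # inside [0.45, 0.55], so that arm does not affect the regret).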
@property
def name(self) -> str:
return 'thresholding_bandit'
@property
def arm_num(self) -> int:
"""Total number of arms"""
return self.__arm_num
def _take_action(self, arm_pull: ArmPull) -> ArmFeedback:
"""Pull one arm
Args:
arm_pull: arm and its pulls
Returns:
arm_feedback: arm and its feedback
"""
arm_id = arm_pull.arm.id
pulls = arm_pull.times
if arm_id not in range(self.__arm_num):
      raise Exception('Arm id %d is out of range [0, %d)!' %
                      (arm_id, self.__arm_num))
arm_feedback = ArmFeedback()
if pulls < 1:
return arm_feedback
# Empirical rewards when `arm_id` is pulled for `pulls` times
em_rewards = self.__arms[arm_id].pull(pulls=pulls)
arm_feedback.arm.id = arm_id
arm_feedback.rewards.extend(list(em_rewards)) # type: ignore
return arm_feedback
  def feed(self, actions: Actions) -> Feedback:
    """Serve one round of actions

    Args:
      actions: actions that specify the arms to pull and the number of pulls

    Returns:
      feedback of the environment after the actions are executed
    """
feedback = Feedback()
for arm_pull in actions.arm_pulls:
arm_feedback = self._take_action(arm_pull=arm_pull)
if arm_feedback.rewards:
feedback.arm_feedbacks.append(arm_feedback)
return feedback
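
  # A sketch of building the Actions proto consumed by feed. The field names
  # are taken from the usage above; .add() on a repeated message field is
  # standard protobuf API. `bandit` is the hypothetical instance from the
  # construction sketch above.
  #
  #   actions = Actions()
  #   arm_pull = actions.arm_pulls.add()
  #   arm_pull.arm.id = 0    # pull arm 0
  #   arm_pull.times = 100   # 100 times
  #   feedback = bandit.feed(actions)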
@property
def context(self) -> Context:
return Context()
  def regret(self, goal: Goal) -> float:
    """Regret of the learner with respect to the given goal

    Args:
      goal: goal of the learner

    Returns:
      regret incurred by the learner
    """
if isinstance(goal, MaximizeCorrectAnswers):
      # Aggregate regret, which equals the number of wrong answers among the
      # arms outside the indifference zone
agg_regret = 0
for arm_id in range(self.__arm_num):
agg_regret += (goal.answers[arm_id] !=
self.__correct_answers[arm_id]) * self.__weights[arm_id]
return agg_regret
elif isinstance(goal, MakeAllAnswersCorrect):
      # Simple regret, which is 1 when there is at least one wrong answer
      # among the arms outside the indifference zone, and 0 otherwise
for arm_id in range(self.__arm_num):
if (goal.answers[arm_id] !=
self.__correct_answers[arm_id]) and self.__weights[arm_id] == 1:
return 1
return 0
raise Exception('Goal %s is not supported.' % goal.name)
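

if __name__ == '__main__':
  # A minimal end-to-end sketch, not part of the library. It assumes that
  # banditpylib.arms provides BernoulliArm and that MaximizeCorrectAnswers can
  # be constructed directly from a list of answers; both are assumptions about
  # the surrounding library rather than guarantees.
  from banditpylib.arms import BernoulliArm

  arms = [BernoulliArm(mu) for mu in (0.3, 0.7, 0.52)]
  bandit = ThresholdingBandit(arms=arms, theta=0.5, eps=0.05)

  # Pull every arm 100 times in a single round of actions.
  actions = Actions()
  for arm_id in range(bandit.arm_num):
    arm_pull = actions.arm_pulls.add()
    arm_pull.arm.id = arm_id
    arm_pull.times = 100
  feedback = bandit.feed(actions)

  # Answer 1 for every arm whose empirical mean is at least theta.
  answers = [0] * bandit.arm_num
  for arm_feedback in feedback.arm_feedbacks:
    em_mean = sum(arm_feedback.rewards) / len(arm_feedback.rewards)
    answers[arm_feedback.arm.id] = 1 if em_mean >= 0.5 else 0

  # 0 if every answer outside the indifference zone is correct, otherwise the
  # number of wrong answers among arms that matter.
  print(bandit.regret(MaximizeCorrectAnswers(answers=answers)))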