Source code for banditpylib.learners.mab_fbbai_learner.uniform

from typing import Optional

import numpy as np

from banditpylib import argmax_or_min
from banditpylib.arms import PseudoArm
from banditpylib.data_pb2 import Context, Actions, Feedback
from .utils import MABFixedBudgetBAILearner


class Uniform(MABFixedBudgetBAILearner):
  """Uniform sampling policy

  Play each arm the same number of times and then output the arm with the
  highest empirical mean.

  :param int arm_num: number of arms
  :param int budget: total number of pulls
  :param Optional[str] name: alias name
  """
  def __init__(self, arm_num: int, budget: int, name: Optional[str] = None):
    super().__init__(arm_num=arm_num, budget=budget, name=name)

  def _name(self) -> str:
    return 'uniform'

  def reset(self):
    self.__pseudo_arms = [PseudoArm() for arm_id in range(self.arm_num)]
    self.__best_arm = None
    self.__stop = False

  def actions(self, context: Context) -> Actions:
    del context
    actions = Actions()

    if not self.__stop:
      # Make sure each arm is sampled at least once
      pulls = np.random.multinomial(self.budget - self.arm_num,
                                    np.ones(self.arm_num) / self.arm_num,
                                    size=1)[0]
      for arm_id in range(self.arm_num):
        arm_pull = actions.arm_pulls.add()
        arm_pull.arm.id = arm_id
        arm_pull.times = pulls[arm_id] + 1

      self.__stop = True

    return actions

  def update(self, feedback: Feedback):
    for arm_feedback in feedback.arm_feedbacks:
      self.__pseudo_arms[arm_feedback.arm.id].update(
          np.array(arm_feedback.rewards))

    if self.__stop:
      self.__best_arm = argmax_or_min(
          [arm.em_mean for arm in self.__pseudo_arms])

  @property
  def best_arm(self) -> int:
    if self.__best_arm is None:
      raise Exception('I don\'t have an answer yet.')
    return self.__best_arm
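
A minimal usage sketch, not part of the library source: it drives the learner by hand against a simulated 3-armed Bernoulli bandit, whereas in banditpylib the Actions/Feedback exchange is normally handled by the bandit environment and simulation protocol. The import path for Uniform and the exact Feedback/Actions fields beyond those used in the code above are assumptions.

import numpy as np

from banditpylib.data_pb2 import Context, Feedback
from banditpylib.learners.mab_fbbai_learner import Uniform  # assumed re-export

# Assumed ground-truth means of three Bernoulli arms
means = [0.3, 0.5, 0.7]

learner = Uniform(arm_num=3, budget=300)
learner.reset()

# The uniform policy asks for all of its pulls in a single round
actions = learner.actions(Context())

# Simulate the environment: draw Bernoulli rewards for each requested pull
feedback = Feedback()
for arm_pull in actions.arm_pulls:
  arm_feedback = feedback.arm_feedbacks.add()
  arm_feedback.arm.id = arm_pull.arm.id
  arm_feedback.rewards.extend(
      np.random.binomial(1, means[arm_pull.arm.id], arm_pull.times))

learner.update(feedback)
print(learner.best_arm)  # most likely 2, the arm with the highest true mean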