Source code for banditpylib.learners.mab_learner.explore_then_commit

from typing import Optional

import numpy as np

from banditpylib import argmax_or_min_tuple
from banditpylib.arms import PseudoArm
from banditpylib.data_pb2 import Context, Actions, Feedback
from .utils import MABLearner


class ExploreThenCommit(MABLearner):
  r"""Explore-Then-Commit policy

  During the first :math:`T' \leq T` time steps (the exploration period), play
  each arm in a round-robin fashion. For the remaining time steps, consistently
  play the arm with the maximum empirical mean reward observed during the
  exploration period.

  :param int arm_num: number of arms
  :param int T_prime: time steps to explore
  :param Optional[str] name: alias name
  """
  def __init__(self, arm_num: int, T_prime: int, name: Optional[str] = None):
    super().__init__(arm_num=arm_num, name=name)
    if T_prime < arm_num:
      raise ValueError('T\' is expected to be at least %d, got %d.' %
                       (arm_num, T_prime))
    self.__T_prime = T_prime
    self.__best_arm: int = -1

  def _name(self) -> str:
    return 'explore_then_commit'
  def reset(self):
    self.__pseudo_arms = [PseudoArm() for arm_id in range(self.arm_num)]
    # Current time step
    self.__time = 1
    # Forget any previously committed arm so that a fresh run re-explores
    self.__best_arm = -1
  def actions(self, context: Context) -> Actions:
    del context
    actions = Actions()
    arm_pull = actions.arm_pulls.add()
    if self.__time <= self.__T_prime:
      # Exploration period: play the arms in a round-robin fashion
      arm_pull.arm.id = (self.__time - 1) % self.arm_num
    else:
      # Commitment period: keep playing the empirically best arm
      arm_pull.arm.id = self.__best_arm
    arm_pull.times = 1
    return actions
  def update(self, feedback: Feedback):
    arm_feedback = feedback.arm_feedbacks[0]
    self.__pseudo_arms[arm_feedback.arm.id].update(
        np.array(arm_feedback.rewards))
    self.__time += 1
    # Once the exploration period ends, commit to the arm with the highest
    # empirical mean reward
    if self.__best_arm < 0 and self.__time > self.__T_prime:
      self.__best_arm = argmax_or_min_tuple([
          (self.__pseudo_arms[arm_id].em_mean, arm_id)
          for arm_id in range(self.arm_num)
      ])
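
A minimal usage sketch (not part of the library source): it drives the learner by hand against a toy Bernoulli bandit, relying only on the protobuf fields referenced in this module (``arm_pulls``, ``arm.id``, ``times``, ``arm_feedbacks``, ``rewards``). The arm means, horizon, and random seed are illustrative assumptions; within banditpylib the learner is normally run through the library's simulation protocols rather than by hand.

if __name__ == '__main__':
  # Assumed ground-truth Bernoulli means of a toy 3-armed bandit
  means = [0.3, 0.5, 0.7]
  horizon = 1000
  learner = ExploreThenCommit(arm_num=len(means), T_prime=300)
  learner.reset()

  rng = np.random.default_rng(0)
  total_reward = 0.0
  for _ in range(horizon):
    actions = learner.actions(Context())
    for arm_pull in actions.arm_pulls:
      # Sample Bernoulli rewards for the requested number of pulls
      rewards = rng.binomial(1, means[arm_pull.arm.id], size=arm_pull.times)
      total_reward += float(rewards.sum())
      # Build the feedback message using the fields read by update()
      feedback = Feedback()
      arm_feedback = feedback.arm_feedbacks.add()
      arm_feedback.arm.id = arm_pull.arm.id
      arm_feedback.rewards.extend([float(r) for r in rewards])
      learner.update(feedback)

  print('total reward over %d steps: %.1f' % (horizon, total_reward))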