Source code for banditpylib.protocols.utils

import multiprocessing
from multiprocessing import Pool
import time
from typing import List

from abc import ABC, abstractmethod
from absl import logging

from google.protobuf.internal.encoder import _VarintBytes  # type: ignore
import numpy as np

from banditpylib.bandits import Bandit
from banditpylib.learners import Learner


def time_seed() -> int:
  """Generate random seed using current time

  Returns:
    random seed
  """
  tem_time = time.time()
  return int((tem_time - int(tem_time)) * 10000000)


[docs]class Protocol(ABC): """Abstract class for a communication protocol which defines the principles of the interactions between the learner and the bandit environment. :param Bandit bandit: bandit environment :param List[Learner] learners: learners used to run simulations """ def __init__(self, bandit: Bandit, learners: List[Learner]): for learner in learners: if not isinstance(learner.running_environment, List): if not isinstance(bandit, learner.running_environment): raise Exception('Learner %s can not recognize environment %s.' % (learner.name, bandit.name)) else: # learner.running_environment is a list environment_is_recognizable = False for env in learner.running_environment: if isinstance(bandit, env): environment_is_recognizable = True if not environment_is_recognizable: raise Exception('Learner %s can not recognize environment %s.' % (learner.name, bandit.name)) self.__bandit = bandit self.__learners = learners @property @abstractmethod def name(self) -> str: """Protocol name""" @property def _bandit(self) -> Bandit: """Bandit environment""" return self.__bandit @property def _current_learner(self) -> Learner: """The learner in simulation currently""" return self.__current_learner @property def _horizon(self) -> int: """Horizon of the game""" return self.__horizon @property def _intermediate_horizons(self) -> List[int]: """Horizons used to report intermediate regrets""" return self.__intermediate_horizons @property def _debug(self) -> bool: """Debug mode""" return self.__debug @abstractmethod def _one_trial(self, random_seed: int) -> bytes: """One trial of the game This method defines how to run one trial of the game. Args: random_seed: random seed debug: whether to run the trial in debug mode Returns: one trial data """ def __write_to_file(self, data: bytes): """Write the result of one trial to file Args: data: one trial data """ with open(self.__output_filename, 'ab') as f: f.write(_VarintBytes(len(data))) f.write(data) f.flush()
[docs] def play( self, trials: int, output_filename: str, processes: int = -1, debug: bool = False, # pylint: disable=dangerous-default-value intermediate_horizons: List[int] = [], horizon: int = np.inf): # type: ignore """Start playing the game Args: trials: number of repetitions output_filename: name of the file used to dump the simulation results processes: maximum number of processes to run. -1 means no limit debug: debug mode. When it is set to `True`, `trials` will be automatically set to 1 and debug information of the trial will be printed out. intermediate_horizons: report intermediate regrets after these horizons horizon: horizon of the game. Different protocols may have different interpretations. .. warning:: By default, `output_filename` will be opened with mode `a`. """ if debug: trials = 1 self.__debug = debug self.__horizon = horizon self.__intermediate_horizons = intermediate_horizons for learner in self.__learners: # Set current learner self.__current_learner = learner logging.info('start %s\'s play with %s', self.__current_learner.name, self.__bandit.name) start_time = time.time() self.__output_filename = output_filename pool = Pool(processes=( multiprocessing.cpu_count() if processes < 0 else processes)) trial_results = [] for _ in range(trials): result = pool.apply_async(self._one_trial, args=[time_seed()], callback=self.__write_to_file) trial_results.append(result) # Can not apply for processes any more pool.close() pool.join() # Check if there are exceptions during the trials for result in trial_results: result.get() logging.info('%s\'s play with %s runs %.2f seconds.', self.__current_learner.name, self.__bandit.name, time.time() - start_time)