Source code for banditpylib.learners.mnl_bandit_learner.utils

import copy

from typing import Optional, List, Union

import numpy as np

from banditpylib.bandits import MNLBandit
from banditpylib.bandits import Reward
from banditpylib.learners import SinglePlayerLearner, Goal, MaximizeTotalRewards


[docs]class MNLBanditLearner(SinglePlayerLearner): """Abstract class for learners playing with mnl bandit Product 0 is reserved for non-purchase. And it is assumed that the preference parameter for non-purchase is 1. :param np.ndarray revenues: product revenues :param Reward reward: reward the learner wants to maximize :param int card_limit: cardinality constraint :param bool use_local_search: whether to use local search for searching the best assortment :param int random_neighbors: number of random neighbors to look up if local search is enabled :param Optional[str] name: alias name """ def __init__(self, revenues: np.ndarray, reward: Reward, card_limit: int, use_local_search: bool, random_neighbors: int, name: Optional[str]): super().__init__(name) self.__product_num = len(revenues) - 1 if self.__product_num < 2: raise ValueError('Number of products is expected at least 2. Got %d.' % self.__product_num) if revenues[0] != 0: raise ValueError('The revenue of product 0 is expected 0. Got %.2f.' % revenues[0]) self.__revenues = revenues if card_limit < 1: raise ValueError('Cardinality limit is expected at least 1. Got %d.' % card_limit) self.__card_limit = min(card_limit, self.__product_num) self.__reward = copy.deepcopy(reward) self.__reward.set_revenues(self.__revenues) self.__use_local_search = use_local_search if 0 <= random_neighbors < 3: raise ValueError( 'Number of neighbors for local search is expected 3. Got %d.' % random_neighbors) self.__random_neighbors = random_neighbors @property def running_environment(self) -> Union[type, List[type]]: return MNLBandit @property def product_num(self) -> int: """Product numbers (not including product 0)""" return self.__product_num @property def revenues(self) -> np.ndarray: """Revenues of products (product 0 is included)""" return self.__revenues @property def card_limit(self) -> int: """Cardinality limit""" return self.__card_limit @property def reward(self) -> Reward: """Reward the learner wants to maximize""" return self.__reward @property def use_local_search(self) -> bool: """Whether local search is enabled""" return self.__use_local_search @property def random_neighbors(self) -> int: """Number of random neighbors to look up when local search is enabled""" return self.__random_neighbors @property def goal(self) -> Goal: return MaximizeTotalRewards()