Source code for banditpylib.protocols.collaborative_learning_protocol

from typing import List, cast, Dict, Tuple
from copy import deepcopy as dcopy

import numpy as np
from absl import logging

from banditpylib.bandits import Bandit
from banditpylib.data_pb2 import Trial, Actions
from banditpylib.learners import Learner, CollaborativeLearner
from .utils import Protocol


class CollaborativeLearningProtocol(Protocol):
  """Collaborative learning protocol :cite:`tao2019collaborative`

  This class defines the communication protocol for the collaborative learning
  multi-agent game discussed in the reference paper.

  The game runs in rounds. During each round, the protocol runs the following
  steps in sequence:

  - For each agent,

    * fetch the state of the corresponding bandit environment and ask the agent
      for actions;
    * send the actions to the bandit environment for execution;
    * update the agent with the feedback from the bandit environment;
    * repeat the above steps until the agent enters the `WAIT` or `STOP` state.

  - If there is at least one agent in the `WAIT` state, fetch the information
    broadcasted by every waiting agent and send it to the master to decide the
    arm assignment of the next round. Otherwise, stop the game.

  :param Bandit bandit: bandit environment
  :param List[CollaborativeLearner] learners: learners to be compared

  .. note::
    Each agent interacts with an independent copy of the bandit environment.

  .. note::
    Each action counts as a timestep. The time (or sample) complexity equals
    the maximum number of pulls across different agents.

  .. note::
    Under this protocol, the number of rounds always equals the number of
    communication rounds plus one.
  """
  def __init__(self, bandit: Bandit, learners: List[CollaborativeLearner]):
    super().__init__(bandit=bandit, learners=cast(List[Learner], learners))

  @property
  def name(self) -> str:
    return 'collaborative_learning_protocol'

  def _one_trial(self, random_seed: int) -> bytes:
    if self._debug:
      logging.set_verbosity(logging.DEBUG)
    np.random.seed(random_seed)

    # Initialization: reset the learner and give each agent its own copy of
    # the bandit environment
    current_learner = cast(CollaborativeLearner, self._current_learner)
    current_learner.reset()
    agents = current_learner.agents
    bandits = []
    master = current_learner.master
    for _ in range(len(agents)):
      bandits.append(dcopy(self._bandit))
      bandits[-1].reset()

    trial = Trial()
    trial.bandit = self._bandit.name
    trial.learner = current_learner.name

    communication_rounds, total_pulls = 0, 0
    active_agent_ids = list(range(len(agents)))

    agent_arm_assignment = master.initial_arm_assignment()
    for agent_id in agent_arm_assignment:
      agents[agent_id].set_input_arms(agent_arm_assignment[agent_id])

    while True:
      max_pulls = 0
      agent_in_wait_ids = []

      # Preparation and learning
      for agent_id in active_agent_ids:
        agent = agents[agent_id]
        pulls = 0
        while True:
          actions = agent.actions(bandits[agent_id].context)
          for arm_pull in actions.arm_pulls:
            pulls += arm_pull.times
          if actions.state == Actions.DEFAULT_NORMAL:
            feedback = bandits[agent_id].feed(actions)
            agent.update(feedback)
          elif actions.state == Actions.WAIT:
            agent_in_wait_ids.append(agent_id)
            break
          else:
            # actions.state == Actions.STOP
            break

        max_pulls = max(max_pulls, pulls)

      # Agents act in parallel, so a round costs the maximum number of pulls
      # across agents
      total_pulls += max_pulls

      # Stop if every agent is in the STOP state, which is equivalent to no
      # agent being in the WAIT state
      if not agent_in_wait_ids:
        break

      # Communication and aggregation
      # Key is the agent id and value is a dict storing the information
      # broadcasted by that agent
      accumulated_messages: Dict[int, Dict[int, Tuple[float, int]]] = {}
      for agent_id in agent_in_wait_ids:
        agent = agents[agent_id]
        message_from_agent = agent.broadcast()
        accumulated_messages[agent_id] = message_from_agent

      # Send the accumulated messages to the master, which eliminates arms and
      # produces the arm assignment for the next round
      agent_arm_assignment = master.elimination(accumulated_messages)
      for agent_id in agent_arm_assignment:
        agents[agent_id].set_input_arms(agent_arm_assignment[agent_id])

      communication_rounds += 1

    # Add simulation results when the algorithm stops running
    result = trial.results.add()
    result.rounds = communication_rounds
    result.total_actions = total_pulls
    result.regret = self._bandit.regret(current_learner.goal)
    return trial.SerializeToString()
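
For orientation, here is a minimal usage sketch of how this protocol could be driven. The `make_bandit()` and `make_collaborative_learner()` helpers are hypothetical placeholders for concrete `Bandit` and `CollaborativeLearner` constructors provided elsewhere in banditpylib, and the `play(trials=..., output_filename=...)` call is an assumption about the interface inherited from the `Protocol` base class in `.utils`; check that class for the actual signature.

import tempfile

from banditpylib.protocols import CollaborativeLearningProtocol

# Hypothetical placeholders -- substitute concrete Bandit and
# CollaborativeLearner instances from banditpylib here.
bandit = make_bandit()
learners = [make_collaborative_learner()]

game = CollaborativeLearningProtocol(bandit=bandit, learners=learners)

# Assumed interface from the Protocol base class: run a number of independent
# trials and write serialized Trial messages to the output file.
trials_filename = tempfile.NamedTemporaryFile(suffix='.data', delete=False).name
game.play(trials=20, output_filename=trials_filename)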
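
To make the pull accounting concrete, the following self-contained sketch (with made-up per-round pull counts) shows how the reported `total_actions` and `rounds` relate: agents act in parallel, so each round costs the maximum number of pulls across agents, and the final round ends with every agent in `STOP`, so the number of rounds exceeds the number of communication rounds by one.

# Made-up per-round pull counts keyed by agent id, for illustration only.
rounds_of_agent_pulls = [
    {0: 40, 1: 35, 2: 40},  # round 1: all three agents are active
    {0: 20, 1: 25},         # round 2: agent 2 has already stopped
]

# A round costs the maximum number of pulls over its active agents.
total_actions = sum(max(pulls.values()) for pulls in rounds_of_agent_pulls)

# The last round ends with every agent in STOP, so it triggers no communication.
communication_rounds = len(rounds_of_agent_pulls) - 1

print(total_actions)         # 65 (= 40 + 25)
print(communication_rounds)  # 1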