Adding a new Deep Contextual Bandit Agent
The bandit submodule, like all of genrl, has been designed to be easily extensible for custom additions. This tutorial will show how to create a deep contextual bandit agent which will work with the rest of genrl.bandit.
For the purpose of this tutorial we will consider a simple neural-network-based agent. Although this is a simplistic agent, implementing an agent of any complexity will involve the same steps.
To start off with, let's import the necessary modules and make a class which inherits from genrl.agents.bandits.contextual.base.DCBAgent:
from typing import Optional

import torch

from genrl.agents.bandits.contextual.base import DCBAgent
from genrl.agents.bandits.contextual.common import NeuralBanditModel, TransitionDB
from genrl.utils.data_bandits.base import DataBasedBandit


class NeuralAgent(DCBAgent):
    """Deep contextual bandit agent based on a neural network."""

    def __init__(self, bandit: DataBasedBandit, **kwargs):
        ...

    def select_action(self, context: torch.Tensor) -> int:
        ...

    def update_db(self, context: torch.Tensor, action: int, reward: int):
        ...

    def update_params(
        self,
        action: Optional[int] = None,
        batch_size: int = 512,
        train_epochs: int = 20,
    ):
        ...
We will need to implement __init__, select_action, update_db and update_params to make the class functional.
Let's start off with __init__. Here we will need to initialise some required parameters (init_pulls, eval_with_dropout, t and update_count) along with our transition database and the neural network. For the neural network, you can use the NeuralBanditModel class, which packages together many of the functionalities a neural network might require. Refer to the docs for more details.
def __init__(self, bandit: DataBasedBandit, **kwargs):
    super(NeuralAgent, self).__init__(bandit, **kwargs)
    self.model = (
        NeuralBanditModel(
            context_dim=self.context_dim,
            n_actions=self.n_actions,
            **kwargs
        )
        .to(torch.float)
        .to(self.device)
    )
    self.eval_with_dropout = kwargs.get("eval_with_dropout", False)
    self.db = TransitionDB(self.device)
    self.t = 0
    self.update_count = 0
In the select_action function, the agent will pass the context vector through the neural network to produce logits for each action. It will then select the action with the highest logit value. Note that it must also increment the timestep, and take every action at least init_pulls times initially.
def select_action(self, context: torch.Tensor) -> int:
    """Selects action for a given context"""
    self.model.use_dropout = self.eval_with_dropout
    self.t += 1
    if self.t < self.n_actions * self.init_pulls:
        return torch.tensor(
            self.t % self.n_actions, device=self.device, dtype=torch.int
        )
    results = self.model(context)
    action = torch.argmax(results["pred_rewards"]).to(torch.int)
    return action
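The cold-start logic above (cycling through every action for the first init_pulls rounds, then switching to greedy selection) can be sketched without any genrl or torch dependencies. The names here (select_action_sketch, predict_rewards) are illustrative stand-ins, not genrl APIs:

```python
from typing import Callable, List


def select_action_sketch(
    t: int,
    n_actions: int,
    init_pulls: int,
    predict_rewards: Callable[[], List[float]],
) -> int:
    """Round-robin over actions for the first n_actions * init_pulls
    steps, so each action is tried init_pulls times; afterwards pick
    the action with the highest predicted reward (greedy argmax)."""
    if t < n_actions * init_pulls:
        return t % n_actions
    preds = predict_rewards()
    return max(range(len(preds)), key=preds.__getitem__)
```

With 3 actions and init_pulls=2, the first six calls cycle 0, 1, 2, 0, 1, 2 before the model's predictions start driving the choice.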
For updating the database we can use the add method of the TransitionDB class.
def update_db(self, context: torch.Tensor, action: int, reward: int):
    """Updates transition database."""
    self.db.add(context, action, reward)
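If you would rather roll your own transition store than use genrl's TransitionDB, it only needs to record (context, action, reward) triples and hand batches back for training. A minimal stand-in (an assumption for illustration, not the genrl implementation) might look like:

```python
import random
from typing import List, Tuple


class SimpleTransitionDB:
    """Minimal in-memory store of (context, action, reward) triples."""

    def __init__(self) -> None:
        self.db: List[Tuple[List[float], int, float]] = []

    def add(self, context: List[float], action: int, reward: float) -> None:
        # Append one observed transition.
        self.db.append((context, action, reward))

    def get_data(self, batch_size: int) -> List[Tuple[List[float], int, float]]:
        # Sample up to batch_size transitions without replacement.
        k = min(batch_size, len(self.db))
        return random.sample(self.db, k)
```

The real TransitionDB additionally handles device placement and tensor batching, which this sketch omits.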
In update_params we need to train the model on the observations seen so far. Since the NeuralBanditModel class already has a train_model method, we just need to call that. However, if you are writing your own model, this is where the updates to the parameters would happen.
def update_params(
    self,
    action: Optional[int] = None,
    batch_size: int = 512,
    train_epochs: int = 20,
):
    """Update parameters of the agent."""
    self.update_count += 1
    self.model.train_model(self.db, train_epochs, batch_size)
Note that some of these functions have unused arguments. The signatures have been chosen to ensure generality over all classes of algorithms.
Once you are done with the above, you can use the NeuralAgent class like you would any other agent from genrl.bandit. You can use it with any of the bandits, as well as train it with genrl.bandit.DCBTrainer.
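To see how the four methods fit together, here is a dependency-free sketch of the loop a trainer such as DCBTrainer runs: observe a context, select an action, log the transition, and periodically update parameters. DummyAgent and the reward signal here are hypothetical stand-ins for illustration, not genrl classes:

```python
import random


class DummyAgent:
    """Stand-in agent: random action selection, counts parameter updates."""

    def __init__(self, n_actions: int) -> None:
        self.n_actions = n_actions
        self.db = []          # flat list standing in for TransitionDB
        self.update_count = 0

    def select_action(self, context):
        return random.randrange(self.n_actions)

    def update_db(self, context, action, reward):
        self.db.append((context, action, reward))

    def update_params(self, action=None, batch_size=512, train_epochs=20):
        self.update_count += 1  # a real agent would fit its model here


def run(agent, n_steps: int, update_interval: int) -> None:
    """Basic contextual bandit interaction loop."""
    for t in range(1, n_steps + 1):
        context = [random.random()]   # stand-in for the bandit's context
        action = agent.select_action(context)
        reward = float(action == 0)   # stand-in reward signal
        agent.update_db(context, action, reward)
        if t % update_interval == 0:
            agent.update_params()


agent = DummyAgent(n_actions=3)
run(agent, n_steps=20, update_interval=5)
```

Swapping DummyAgent for your NeuralAgent and the stand-in context/reward for a DataBasedBandit gives essentially the loop the trainer automates for you.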