Source code for genrl.agents.bandits.contextual.common.base_model

from abc import ABC, abstractmethod
from typing import Dict

import torch  # noqa
import torch.nn as nn  # noqa
import torch.nn.functional as F

from genrl.agents.bandits.contextual.common.transition import TransitionDB


class Model(nn.Module, ABC):
    """Bayesian Neural Network used in Deep Contextual Bandit Models.

    Args:
        context_dim (int): Length of context vector.
        hidden_dims (List[int], optional): Dimensions of hidden layers of network.
        n_actions (int): Number of actions that can be selected. Taken as length
            of output vector for network to predict.
        init_lr (float, optional): Initial learning rate.
        max_grad_norm (float, optional): Maximum norm of gradients for gradient clipping.
        lr_decay (float, optional): Decay rate for learning rate.
        lr_reset (bool, optional): Whether to reset the learning rate every train
            interval. Defaults to False.
        dropout_p (Optional[float], optional): Probability for dropout. Defaults
            to None, which implies dropout is not to be used.
        noise_std (float): Standard deviation of noise used in the network.
            Defaults to 0.1.

    Attributes:
        use_dropout (bool): Indicates whether or not dropout should be used in
            the forward pass.
    """

    def __init__(self, layer, **kwargs):
        super(Model, self).__init__()
        self.context_dim = kwargs.get("context_dim")
        self.hidden_dims = kwargs.get("hidden_dims")
        self.n_actions = kwargs.get("n_actions")
        t_hidden_dims = [self.context_dim, *self.hidden_dims, self.n_actions]
        self.layers = nn.ModuleList([])
        for i in range(len(t_hidden_dims) - 1):
            self.layers.append(layer(t_hidden_dims[i], t_hidden_dims[i + 1]))
        self.init_lr = kwargs.get("init_lr", 3e-4)
        self.optimizer = torch.optim.Adam(self.layers.parameters(), lr=self.init_lr)
        self.lr_decay = kwargs.get("lr_decay", None)
        self.lr_scheduler = (
            torch.optim.lr_scheduler.LambdaLR(
                self.optimizer, lambda i: 1 / (1 + self.lr_decay * i)
            )
            if self.lr_decay is not None
            else None
        )
        self.lr_reset = kwargs.get("lr_reset", False)
        self.dropout_p = kwargs.get("dropout_p", None)
        self.use_dropout = True if self.dropout_p is not None else False
        self.max_grad_norm = kwargs.get("max_grad_norm")
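    # Construction sketch (illustrative, not part of the library): a concrete
    # subclass passes a layer constructor such as ``nn.Linear``, so that
    #   super().__init__(nn.Linear, context_dim=8, hidden_dims=[64, 64], n_actions=4)
    # builds an ``nn.ModuleList`` of linear layers with sizes 8 -> 64 -> 64 -> 4.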
    @abstractmethod
    def forward(self, context: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
        """Computes forward pass through the network.

        Args:
            context (torch.Tensor): The context vector to perform forward pass on.

        Returns:
            Dict[str, torch.Tensor]: Dictionary of outputs
        """
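    # Note: the contents of the returned dictionary are subclass-defined; one
    # possible contract (assumed here, see the sketch at the end of this file)
    # maps names such as "x" and "pred_rewards" to intermediate activations
    # and predicted per-action rewards.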
    def train_model(self, db: TransitionDB, epochs: int, batch_size: int):
        """Trains the network on a given database for given epochs and batch_size.

        Args:
            db (TransitionDB): The database of transitions to train on.
            epochs (int): Number of gradient steps to take.
            batch_size (int): The size of each batch to perform gradient descent on.
        """
        self.use_dropout = True if self.dropout_p is not None else False
        if self.lr_decay is not None and self.lr_reset is True:
            self._reset_lr(self.init_lr)
        for _ in range(epochs):
            x, a, y = db.get_data(batch_size)
            action_mask = F.one_hot(a, num_classes=self.n_actions)
            reward_vec = y.view(-1).repeat(self.n_actions, 1).T * action_mask
            loss = self._compute_loss(db, x, action_mask, reward_vec, batch_size)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            if self.lr_decay is not None:
                self.lr_scheduler.step()
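    # Shape sketch (illustrative): with batch size B and K = n_actions,
    # ``a`` is (B,), so ``action_mask = F.one_hot(a, K)`` is (B, K), and
    # ``reward_vec = y.view(-1).repeat(K, 1).T * action_mask`` is (B, K),
    # holding the observed reward in the column of the taken action and
    # zeros everywhere else.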
    @abstractmethod
    def _compute_loss(
        self,
        db: TransitionDB,
        x: torch.Tensor,
        action_mask: torch.Tensor,
        reward_vec: torch.Tensor,
        batch_size: int,
    ) -> torch.Tensor:
        """Computes loss for the model

        Args:
            db (TransitionDB): The database of transitions to train on.
            x (torch.Tensor): Context.
            action_mask (torch.Tensor): Mask of actions taken.
            reward_vec (torch.Tensor): Reward vector received.
            batch_size (int): The size of each batch to perform gradient descent on.

        Returns:
            torch.Tensor: The computed loss.
        """

    def _reset_lr(self, lr: float) -> None:
        """Resets learning rate of optimizer.

        Args:
            lr (float): New value of learning rate
        """
        for o in self.optimizer.param_groups:
            o["lr"] = lr
        self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer, lambda i: 1 / (1 + self.lr_decay * i)
        )
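
# ---------------------------------------------------------------------------
# Minimal usage sketch (hypothetical, not part of this module): a concrete
# subclass wiring ``forward`` and ``_compute_loss`` together, assuming plain
# ``nn.Linear`` layers and a mean-squared-error loss restricted to the actions
# actually taken. The name ``MSEBanditModel`` is illustrative only.
# ---------------------------------------------------------------------------


class MSEBanditModel(Model):
    def __init__(self, **kwargs):
        super(MSEBanditModel, self).__init__(nn.Linear, **kwargs)

    def forward(self, context: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
        x = context
        for layer in self.layers[:-1]:
            x = F.relu(layer(x))
            if self.use_dropout and self.dropout_p is not None:
                x = F.dropout(x, p=self.dropout_p)
        # Final layer is left linear: one predicted reward per action.
        pred_rewards = self.layers[-1](x)
        return dict(x=x, pred_rewards=pred_rewards)

    def _compute_loss(
        self,
        db: TransitionDB,
        x: torch.Tensor,
        action_mask: torch.Tensor,
        reward_vec: torch.Tensor,
        batch_size: int,
    ) -> torch.Tensor:
        results = self.forward(x)
        # Penalise squared error only on the actions that were taken.
        return torch.mean(((reward_vec - results["pred_rewards"]) * action_mask) ** 2)


# Expected call pattern (assuming a populated TransitionDB):
#   model = MSEBanditModel(context_dim=8, hidden_dims=[64, 64], n_actions=4)
#   model.train_model(db, epochs=100, batch_size=64)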