Source code for genrl.agents.bandits.contextual.common.bayesian

from typing import Dict, Optional, Tuple

import torch  # noqa
import torch.nn as nn  # noqa
import torch.nn.functional as F

from genrl.agents.bandits.contextual.common.base_model import Model
from genrl.agents.bandits.contextual.common.transition import TransitionDB


class BayesianLinear(nn.Module):
    """Linear Layer for Bayesian Neural Networks.

    Args:
        in_features (int): size of each input sample
        out_features (int): size of each output sample
        bias (bool, optional): Whether to use an additive bias. Defaults to True.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super(BayesianLinear, self).__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.bias = bias

        # Variational parameters: a mean and a standard deviation per weight (and bias).
        self.w_mu = nn.Parameter(torch.Tensor(out_features, in_features))
        self.w_sigma = nn.Parameter(torch.Tensor(out_features, in_features))

        self.b_mu = nn.Parameter(torch.Tensor(out_features)) if self.bias else None
        self.b_sigma = nn.Parameter(torch.Tensor(out_features)) if self.bias else None

        self.reset_parameters()
    def reset_parameters(self) -> None:
        """Resets weight and bias parameters of the layer."""
        self.w_mu.data.normal_(0, 0.1)
        self.w_sigma.data.normal_(0, 0.1)
        if self.bias:
            self.b_mu.data.normal_(0, 0.1)
            self.b_sigma.data.normal_(0, 0.1)
    def forward(
        self, x: torch.Tensor, kl: bool = True, frozen: bool = False
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Apply linear transformation to input.

        The weight and bias are sampled for each forward pass from a normal
        distribution. The KL divergence of the sampled weight and bias from the
        ``Normal(0, 0.1)`` prior can also be computed if specified.

        Args:
            x (torch.Tensor): Input to be transformed
            kl (bool, optional): Whether to compute the KL divergence. Defaults to True.
            frozen (bool, optional): Whether to freeze current parameters. Defaults to False.

        Returns:
            Tuple[torch.Tensor, Optional[torch.Tensor]]: The transformed input and
                optionally the computed KL divergence value.
        """
        if frozen:
            kl_val = None
            w = self.w_mu
            b = self.b_mu
        else:
            b = None
            w_dist = torch.distributions.Normal(self.w_mu, self.w_sigma)
            w = w_dist.rsample()
            kl_val = (
                torch.sum(
                    w_dist.log_prob(w) - torch.distributions.Normal(0, 0.1).log_prob(w)
                )
                if kl
                else None
            )
            if self.bias:
                b_dist = torch.distributions.Normal(self.b_mu, self.b_sigma)
                b = b_dist.rsample()
                if kl:
                    kl_val += torch.sum(
                        b_dist.log_prob(b)
                        - torch.distributions.Normal(0, 0.1).log_prob(b)
                    )
        return F.linear(x, w, b), kl_val
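A quick, hypothetical usage sketch (not part of the module) of the layer's stochastic behaviour: two forward passes on the same input give different outputs because the weights are resampled each time, while a frozen pass uses the means directly and returns no KL term. Since ``reset_parameters`` draws the sigmas from a normal distribution, the sketch forces them positive before any ``torch.distributions.Normal`` is built from them.

# Illustrative sketch only, assuming the imports at the top of this module.
layer = BayesianLinear(in_features=4, out_features=2)

# The default init can leave sigmas negative; Normal requires a positive scale.
layer.w_sigma.data.abs_()
layer.b_sigma.data.abs_()

x = torch.randn(8, 4)  # batch of 8 inputs with 4 features each

out1, kl1 = layer(x)  # weights resampled on every pass -> stochastic output
out2, kl2 = layer(x)  # a second pass yields a different sample
frozen_out, no_kl = layer(x, frozen=True)  # uses w_mu / b_mu, no KL term

print(torch.allclose(out1, out2))      # almost surely False
print(kl1 is not None, no_kl is None)  # True True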
class BayesianNNBanditModel(Model):
    """Bayesian Neural Network used in Deep Contextual Bandit Models.

    Args:
        context_dim (int): Length of context vector.
        hidden_dims (List[int], optional): Dimensions of hidden layers of network.
        n_actions (int): Number of actions that can be selected. Taken as length
            of output vector for network to predict.
        init_lr (float, optional): Initial learning rate.
        max_grad_norm (float, optional): Maximum norm of gradients for gradient clipping.
        lr_decay (float, optional): Decay rate for learning rate.
        lr_reset (bool, optional): Whether to reset learning rate every train
            interval. Defaults to False.
        dropout_p (Optional[float], optional): Probability for dropout. Defaults
            to None which implies dropout is not to be used.
        noise_std (float): Standard deviation of noise used in the network.
            Defaults to 0.1.

    Attributes:
        use_dropout (bool): Indicates whether or not dropout should be used in
            the forward pass.
    """

    def __init__(self, **kwargs):
        super(BayesianNNBanditModel, self).__init__(BayesianLinear, **kwargs)
        self.noise_std = kwargs.get("noise_std", 0.1)
    def forward(
        self, context: torch.Tensor, kl: bool = True
    ) -> Dict[str, torch.Tensor]:
        """Computes forward pass through the network.

        Args:
            context (torch.Tensor): The context vector to perform forward pass on.
            kl (bool, optional): Whether to return the accumulated KL divergence
                of the sampled weights. Defaults to True.

        Returns:
            Dict[str, torch.Tensor]: Dictionary of outputs with keys ``x`` (final
                hidden representation), ``pred_rewards`` and ``kl_val``.
        """
        kl_val = 0.0
        x = context
        for layer in self.layers[:-1]:
            x, kl_v = layer(x)
            x = F.relu(x)
            if self.dropout_p is not None and self.use_dropout is True:
                x = F.dropout(x, p=self.dropout_p)
            kl_val += kl_v

        pred_rewards, kl_v = self.layers[-1](x)
        kl_val += kl_v
        kl_val = kl_val if kl else None

        return dict(x=x, pred_rewards=pred_rewards, kl_val=kl_val)
    def _compute_loss(
        self,
        db: TransitionDB,
        x: torch.Tensor,
        action_mask: torch.Tensor,
        reward_vec: torch.Tensor,
        batch_size: int,
    ) -> torch.Tensor:
        """Computes loss for the model.

        Args:
            db (TransitionDB): The database of transitions to train on.
            x (torch.Tensor): Context.
            action_mask (torch.Tensor): Mask of actions taken.
            reward_vec (torch.Tensor): Reward vector received.
            batch_size (int): The size of each batch to perform gradient descent on.

        Returns:
            torch.Tensor: The computed loss.
        """
        results = self.forward(x)
        pred_rewards = results["pred_rewards"]
        kl_val = results["kl_val"]

        # Log-likelihood of the observed rewards under a Gaussian centred on the
        # predicted rewards, masked so only the actions actually taken contribute.
        log_likelihood = torch.distributions.Normal(
            pred_rewards, self.noise_std
        ).log_prob(reward_vec)

        loss = torch.sum(action_mask * log_likelihood) / batch_size - (
            kl_val / db.db_size
        )
        return loss
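To see how the pieces fit together, here is a minimal, hypothetical sketch of the model end to end. The constructor keyword arguments follow the docstring above and are assumed to be accepted by the base ``Model``, which is also assumed to expose its ``BayesianLinear`` layers as ``layers`` (as ``forward`` does). The last lines recompute the two terms of ``_compute_loss`` by hand under an assumed transition-database size.

# Illustrative sketch only: not part of the module; constructor kwargs
# assumed to be handled by the base Model as documented above.
model = BayesianNNBanditModel(
    context_dim=8, hidden_dims=[16, 16], n_actions=3, init_lr=1e-3,
    max_grad_norm=0.5, lr_decay=0.5, lr_reset=False, dropout_p=None,
)

# Force positive sigmas (see the note on reset_parameters above).
for layer in model.layers:
    layer.w_sigma.data.abs_()
    if layer.b_sigma is not None:
        layer.b_sigma.data.abs_()

context = torch.randn(4, 8)            # batch of 4 context vectors
results = model(context)               # dict with keys x, pred_rewards, kl_val
pred_rewards = results["pred_rewards"] # shape (4, 3): one estimate per action
kl_val = results["kl_val"]             # scalar KL of sampled weights vs prior

# The two terms of _compute_loss, recomputed by hand: a masked Gaussian
# log-likelihood of observed rewards, and a KL penalty scaled by an assumed
# database size of 100 transitions.
action_mask = F.one_hot(torch.tensor([0, 2, 1, 0]), num_classes=3).float()
reward_vec = torch.randn(4, 3)
log_likelihood = torch.distributions.Normal(
    pred_rewards, model.noise_std
).log_prob(reward_vec)
loss = torch.sum(action_mask * log_likelihood) / 4 - kl_val / 100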