# Source code for genrl.agents.bandits.contextual.base

from typing import Optional

import torch

from genrl.core import Bandit, BanditAgent


class DCBAgent(BanditAgent):
    """Base class for deep contextual bandit solving agents.

    Concrete agents subclass this and implement :meth:`select_action` and
    :meth:`update_parameters`; both raise ``NotImplementedError`` here.

    Args:
        bandit (genrl.core.Bandit): The bandit to solve. Must expose
            ``context_dim`` and ``n_actions`` attributes.
        device (str): Device to use for tensor operations. "cpu" for cpu
            or "cuda" for cuda. Defaults to "cpu". A device string
            containing "cuda" is honoured only when
            ``torch.cuda.is_available()`` is true; in every other case the
            agent silently falls back to the CPU.
        **kwargs: Optional settings. Only ``init_pulls`` (default ``3``)
            is read by this base class; other keys are ignored here.

    Attributes:
        device (torch.device): Device to use for tensor operations.
        context_dim: Copied from ``bandit.context_dim``.
        n_actions: Copied from ``bandit.n_actions``.
        init_pulls: Value of the ``init_pulls`` keyword argument.
            NOTE(review): presumably the number of initial pulls per
            action before the learned policy is used -- confirm against
            the concrete agents.
    """

    def __init__(self, bandit: Bandit, device: str = "cpu", **kwargs):
        super().__init__()

        # Honour a CUDA request only when CUDA is actually usable; anything
        # else (including a GPU request on a CPU-only machine) runs on CPU.
        use_cuda = "cuda" in device and torch.cuda.is_available()
        self.device = torch.device(device if use_cuda else "cpu")

        # The bandit is stored privately; its dimensions are mirrored on the
        # agent so subclasses do not have to reach into the bandit.
        self._bandit = bandit
        self.context_dim = self._bandit.context_dim
        self.n_actions = self._bandit.n_actions

        # Starts empty. NOTE(review): presumably filled with chosen actions
        # by subclasses or the trainer -- nothing in this class appends to it.
        self._action_hist = []
        self.init_pulls = kwargs.get("init_pulls", 3)

    def select_action(self, context: torch.Tensor) -> int:
        """Select an action based on given context.

        Args:
            context (torch.Tensor): The context vector to select action for.

        Note:
            This method needs to be implemented in the specific agent.

        Returns:
            int: The action to take.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def update_parameters(
        self,
        action: Optional[int] = None,
        batch_size: Optional[int] = None,
        train_epochs: Optional[int] = None,
    ) -> None:
        """Update parameters of the agent.

        Args:
            action (Optional[int], optional): Action to update the
                parameters for. Defaults to None.
            batch_size (Optional[int], optional): Size of batch to update
                parameters with. Defaults to None.
            train_epochs (Optional[int], optional): Epochs to train neural
                network for. Defaults to None.

        Note:
            This method needs to be implemented in the specific agent.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError