"""Source code for genrl.agents.deep.dqn.categorical."""

import collections
from typing import Tuple

import torch

from genrl.agents.deep.dqn.base import DQN
from genrl.agents.deep.dqn.utils import (
    categorical_greedy_action,
    categorical_q_loss,
    categorical_q_target,
    categorical_q_values,
)

class CategoricalDQN(DQN):
    """Categorical DQN Algorithm

    Paper: https://arxiv.org/abs/1707.06887
    (Bellemare et al., "A Distributional Perspective on Reinforcement Learning")

    Instead of a scalar Q-value, the agent models a discrete distribution of
    returns over ``num_atoms`` support points spanning ``[v_min, v_max]``.

    Attributes:
        network (str): The network type of the Q-value function.
            Supported types: ["cnn", "mlp"]
        env (Environment): The environment that the agent is supposed to act on
        create_model (bool): Whether the model of the algo should be created when initialised
        batch_size (int): Mini batch size for loading experiences
        gamma (float): The discount factor for rewards
        layers (:obj:`tuple` of :obj:`int`): Layers in the Neural Network of the Q-value function
        lr_value (float): Learning rate for the Q-value function
        replay_size (int): Capacity of the Replay Buffer
        buffer_type (str): Choose the type of Buffer: ["push", "prioritized"]
        max_epsilon (float): Maximum epsilon for exploration
        min_epsilon (float): Minimum epsilon for exploration
        epsilon_decay (int): Rate of decay of epsilon (in order to decrease exploration with time)
        noisy_layers (:obj:`tuple` of :obj:`int`): Noisy layers in the Neural Network of the
            Q-value function
        num_atoms (int): Number of atoms used in the discrete distribution (must be >= 2)
        v_min (float): Lower bound of value distribution
        v_max (float): Upper bound of value distribution (must be greater than ``v_min``)
        seed (int): Seed for randomness
        render (bool): Should the env be rendered during training?
        device (str): Hardware being used for training. Options:
            ["cuda" -> GPU, "cpu" -> CPU]

    Raises:
        ValueError: If ``num_atoms < 2`` or ``v_min >= v_max``.
    """

    def __init__(
        self,
        *args,
        noisy_layers: Tuple = (32, 128),
        num_atoms: int = 51,
        v_min: float = -10,
        v_max: float = 10,
        **kwargs
    ):
        # Validate the value-distribution support before the base class does
        # any setup work. The atom spacing (v_max - v_min) / (num_atoms - 1)
        # is only positive and finite for at least two atoms on a non-empty
        # interval.
        if num_atoms < 2:
            raise ValueError(
                "num_atoms must be at least 2, got {}".format(num_atoms)
            )
        if v_min >= v_max:
            raise ValueError(
                "v_min must be strictly less than v_max, got v_min={}, v_max={}".format(
                    v_min, v_max
                )
            )

        super().__init__(*args, **kwargs)
        self.noisy_layers = noisy_layers
        self.num_atoms = num_atoms
        self.v_min = v_min
        self.v_max = v_max

        self.dqn_type = "categorical"
        # This variant always builds its Q-network with noisy layers.
        self.noisy = True

        self.empty_logs()
        # The model is built here, after the categorical attributes above are
        # set, because it needs noisy_layers and num_atoms.
        if self.create_model:
            self._create_model(noisy_layers=self.noisy_layers, num_atoms=self.num_atoms)

    def get_greedy_action(self, state: torch.Tensor) -> torch.Tensor:
        """Greedy action selection

        Args:
            state (:obj:`torch.Tensor`): Current state of the environment

        Returns:
            action (:obj:`torch.Tensor`): Action taken by the agent
        """
        return categorical_greedy_action(self, state)

    def get_q_values(self, states: torch.Tensor, actions: torch.Tensor):
        """Get Q values corresponding to specific states and actions

        Args:
            states (:obj:`torch.Tensor`): States for which Q-values need to be found
            actions (:obj:`torch.Tensor`): Actions taken at respective states

        Returns:
            q_values (:obj:`torch.Tensor`): Q values for the given states and actions
        """
        return categorical_q_values(self, states, actions)

    def get_target_q_values(
        self, next_states: torch.Tensor, rewards: torch.Tensor, dones: torch.Tensor
    ):
        """Projected Distribution of Q-values

        Helper function for Categorical/Distributional DQN

        Args:
            next_states (:obj:`torch.Tensor`): Next states being encountered by the agent
            rewards (:obj:`torch.Tensor`): Rewards received by the agent
            dones (:obj:`torch.Tensor`): Game over status of each environment

        Returns:
            target_q_values (object): Projected Q-value Distribution or Target Q Values
        """
        return categorical_q_target(self, next_states, rewards, dones)

    def get_q_loss(self, batch: collections.namedtuple):
        """Categorical DQN loss function to calculate the loss of the Q-function

        Args:
            batch (:obj:`collections.namedtuple` of :obj:`torch.Tensor`): Batch of experiences

        Returns:
            loss (:obj:`torch.Tensor`): Calculated loss of the Q-function
        """
        return categorical_q_loss(self, batch)