genrl.agents.deep.vpg.vpg module

class genrl.agents.deep.vpg.vpg.VPG(*args, **kwargs)

Bases: genrl.agents.deep.base.onpolicy.OnPolicyAgent

Vanilla Policy Gradient algorithm

Paper https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf
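For reference, one common way to write the gradient that vanilla policy gradient ascends is the REINFORCE form. It weights each action's log-probability by the discounted reward-to-go $G_t$, with $\gamma$ playing the role of the `gamma` discount factor. This is a textbook statement of the estimator, not a transcription of genrl's code. The second line is the surrogate loss whose gradient gives (the negative of) a sample estimate of it.

```latex
\nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}\left[\sum_{t=0}^{T-1} \nabla_\theta \log \pi_\theta(a_t \mid s_t)\, G_t\right],
\qquad G_t = \sum_{k=t}^{T-1} \gamma^{\,k-t} r_k

L(\theta) = -\frac{1}{N} \sum_{t=0}^{N-1} \log \pi_\theta(a_t \mid s_t)\, G_t
```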

Parameters:
  • network (str) – The network type of the Q-value function. Supported types: [“cnn”, “mlp”]
  • env (Environment) – The environment that the agent is supposed to act on
  • create_model (bool) – Whether the model of the algorithm should be created when initialised
  • batch_size (int) – Mini-batch size for loading experiences
  • gamma (float) – The discount factor for rewards
  • layers (tuple of int) – Layers in the neural network of the Q-value function
  • lr_policy (float) – Learning rate for the policy/actor
  • lr_value (float) – Learning rate for the Q-value function
  • rollout_size (int) – Capacity of the rollout buffer
  • buffer_type (str) – Type of buffer to use: [“rollout”]
  • seed (int) – Seed for randomness
  • render (bool) – Whether the environment should be rendered during training
  • device (str) – Hardware used for training. Options: [“cuda” -> GPU, “cpu” -> CPU]

empty_logs()

Empties logs

get_hyperparams() → Dict[str, Any]

Get relevant hyperparameters to save

Returns:
  • hyperparams (dict) – Hyperparameters to be saved
  • weights (torch.Tensor) – Neural network weights

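The sketch below shows one way the two returned values might be persisted together as a single checkpoint. It assumes PyTorch and uses a small hypothetical network and hypothetical hyperparameter values in place of a real agent. It is not genrl's own saving code.

```python
import io

import torch
import torch.nn as nn

# Hypothetical stand-in for the agent's policy network (genrl builds its own).
policy = nn.Sequential(nn.Linear(4, 32), nn.Tanh(), nn.Linear(32, 2))

# Hypothetical hyperparameter values, keyed by documented constructor arguments.
hyperparams = {"network": "mlp", "batch_size": 64, "gamma": 0.99,
               "lr_policy": 1e-3, "rollout_size": 2048}

# Bundle hyperparameters and weights into one checkpoint object.
checkpoint = {"hyperparams": hyperparams, "weights": policy.state_dict()}

buffer = io.BytesIO()  # in-memory file instead of a path on disk
torch.save(checkpoint, buffer)
buffer.seek(0)

# Restore the weights into a fresh network with the same architecture.
restored = torch.load(buffer)
new_policy = nn.Sequential(nn.Linear(4, 32), nn.Tanh(), nn.Linear(32, 2))
new_policy.load_state_dict(restored["weights"])
```

Keeping hyperparameters next to the weights means the network architecture can be rebuilt before `load_state_dict` is called.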
get_log_probs(states: torch.Tensor, actions: torch.Tensor)

Get log probabilities of action values

Analyses the actions taken by the actor, together with the states in which they were taken, to get their log probabilities.

Parameters:
  • states (torch.Tensor) – States encountered in the rollout
  • actions (torch.Tensor) – Actions taken in response to the respective states
Returns: Log of action probabilities given a state
Return type: log_probs (torch.Tensor)
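For a discrete action space, this computation amounts to building the policy's action distribution for each state and evaluating the taken action under it. The sketch below assumes a categorical policy produced by a hypothetical single linear layer. A continuous-action policy would use a different distribution, such as a Gaussian.

```python
import torch
from torch.distributions import Categorical

torch.manual_seed(0)

# Hypothetical policy head: maps 4-dimensional states to logits over 2 actions.
policy = torch.nn.Linear(4, 2)

states = torch.randn(5, 4)               # batch of 5 states from a rollout
actions = torch.tensor([0, 1, 1, 0, 1])  # actions taken in those states


def get_log_probs(states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
    """Log probability of each taken action under the current policy."""
    dist = Categorical(logits=policy(states))  # one distribution per state
    return dist.log_prob(actions)              # shape: (batch,)


log_probs = get_log_probs(states, actions)
```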

get_logging_params() → Dict[str, Any]

Gets relevant parameters for logging

Returns: Logging parameters for monitoring training
Return type: logs (dict)

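The sketch below shows how per-update logs could be collected, summarised into scalars for monitoring, and then emptied. The class name, the logged keys, and the choice to reset the logs inside `get_logging_params` are all assumptions for illustration, not genrl's implementation.

```python
from statistics import mean
from typing import Any, Dict, List


class LogBook:
    """Hypothetical helper: per-update logs that are averaged, then reset."""

    def __init__(self) -> None:
        self.empty_logs()

    def empty_logs(self) -> None:
        # Start a fresh list for each tracked quantity (key names are assumptions).
        self.logs: Dict[str, List[float]] = {"policy_loss": [], "policy_entropy": []}

    def get_logging_params(self) -> Dict[str, Any]:
        # Average each list into one scalar; NaN if nothing was recorded yet.
        params = {k: (mean(v) if v else float("nan")) for k, v in self.logs.items()}
        self.empty_logs()  # so the next report covers only new updates
        return params


book = LogBook()
book.logs["policy_loss"] += [0.5, 0.25]
book.logs["policy_entropy"] += [0.75]
report = book.get_logging_params()
```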
get_traj_loss(values, dones)

Get the loss from the trajectories traversed by the agent during rollouts

Computes the returns and advantages needed for calculating the loss

Parameters:
  • values (torch.Tensor) – Values of states encountered during the rollout
  • dones (list of bool) – Game-over status of each environment

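The returns part of this computation can be sketched for a single environment as shown below. The sketch assumes plain discounted reward-to-go that resets at episode boundaries. The function name and the exact treatment of `dones` are assumptions, and genrl's own helper may differ; for example, it might use generalized advantage estimation. Because VPG sets the value of every state to 0, the advantage in this sketch reduces to the return itself.

```python
from typing import List


def discounted_returns(rewards: List[float], dones: List[bool], gamma: float,
                       last_value: float = 0.0) -> List[float]:
    """Hypothetical helper: reward-to-go G_t, reset where an episode ended."""
    returns = [0.0] * len(rewards)
    running = last_value  # bootstrap after the final step (0 when there is no critic)
    for t in reversed(range(len(rewards))):
        if dones[t]:
            running = 0.0  # episode ended at step t, so later rewards do not count
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


rets = discounted_returns([1.0, 1.0, 1.0, 1.0], [False, True, False, False], gamma=0.5)
# With values fixed at 0, the advantages equal these returns.
```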
select_action(state: torch.Tensor, deterministic: bool = False) → torch.Tensor

Select action given state

Action selection for Vanilla Policy Gradient

Parameters:
  • state (np.ndarray) – Current state of the environment
  • deterministic (bool) – Whether the policy should be deterministic or stochastic
Returns:
  • action (np.ndarray) – Action taken by the agent
  • value (torch.Tensor) – Value of the given state. VPG has no critic to estimate this value, so it is set to a default of 0 for convenience
  • log_prob (torch.Tensor) – Log probability of the selected action
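The sketch below shows stochastic versus deterministic action selection for a discrete-action policy, returning the same three quantities described above. It assumes a hypothetical linear policy head and a categorical distribution. It also returns tensors throughout rather than converting the action to `np.ndarray`.

```python
import torch
from torch.distributions import Categorical

torch.manual_seed(0)

# Hypothetical policy head for 4-dimensional states and 3 discrete actions.
policy = torch.nn.Linear(4, 3)


def select_action(state: torch.Tensor, deterministic: bool = False):
    """Return (action, value, log_prob) for a batch of states."""
    with torch.no_grad():
        logits = policy(state)
        dist = Categorical(logits=logits)
        # Deterministic: most probable action. Stochastic: sample from the policy.
        action = logits.argmax(dim=-1) if deterministic else dist.sample()
        value = torch.zeros(action.shape)  # no critic in VPG, so a placeholder 0
        return action, value, dist.log_prob(action)


state = torch.randn(2, 4)  # e.g. one state from each of 2 parallel environments
action, value, log_prob = select_action(state)
greedy, _, _ = select_action(state, deterministic=True)
```

Sampling keeps exploration during training. The greedy choice is typically reserved for evaluation.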

update_params() → None

Updates the VPG policy network

Performs a policy-gradient update of the actor using the experiences collected in the rollout buffer. VPG has no critic, so only the policy is updated.
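The sketch below shows a single vanilla policy-gradient step that minimises the negative mean of log-probability times return. The actor, the rollout batch, and the choice of plain SGD are hypothetical stand-ins, and the learning rate plays the role of `lr_policy`. This is not a transcription of genrl's update loop.

```python
import torch
from torch.distributions import Categorical

torch.manual_seed(0)

# Hypothetical actor: linear policy over 4-dimensional states and 2 discrete actions.
policy = torch.nn.Linear(4, 2)
# Plain SGD is an assumption; lr stands in for lr_policy.
optimizer = torch.optim.SGD(policy.parameters(), lr=1e-2)

# Hypothetical rollout batch: states, actions taken in them, and discounted returns.
states = torch.randn(8, 4)
actions = torch.randint(0, 2, (8,))
returns = torch.randn(8)


def policy_loss() -> torch.Tensor:
    """Surrogate loss whose gradient is the negated policy-gradient estimate."""
    log_probs = Categorical(logits=policy(states)).log_prob(actions)
    return -(log_probs * returns).mean()


def update_params() -> float:
    """One vanilla policy-gradient step on the actor; there is no critic to update."""
    loss = policy_loss()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


weight_before = policy.weight.detach().clone()
loss_before = update_params()
loss_after = policy_loss().item()
```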