VPG

genrl.agents.deep.vpg.vpg module
class genrl.agents.deep.vpg.vpg.VPG(*args, **kwargs)

Bases: genrl.agents.deep.base.onpolicy.OnPolicyAgent

Vanilla Policy Gradient algorithm
Parameters:
- network (str) – The network type of the Q-value function. Supported types: ["cnn", "mlp"]
- env (Environment) – The environment that the agent is supposed to act on
- create_model (bool) – Whether the model of the algorithm should be created when initialised
- batch_size (int) – Mini-batch size for loading experiences
- gamma (float) – The discount factor for rewards
- layers (tuple of int) – Layers in the neural network of the Q-value function
- lr_policy (float) – Learning rate for the policy/actor
- lr_value (float) – Learning rate for the Q-value function
- rollout_size (int) – Capacity of the rollout buffer
- buffer_type (str) – Choose the type of buffer: ["rollout"]
- seed (int) – Seed for randomness
- render (bool) – Should the environment be rendered during training?
- device (str) – Hardware being used for training. Options: ["cuda" -> GPU, "cpu" -> CPU]
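As a usage illustration, here is a minimal training sketch. It assumes GenRL's usual on-policy workflow (VectorEnv and OnPolicyTrainer, and the genrl.agents re-export of VPG); the environment name, keyword-argument values, and trainer options shown are arbitrary examples, not prescriptions from this module.

    from genrl.agents import VPG
    from genrl.environments import VectorEnv
    from genrl.trainers import OnPolicyTrainer

    # Vectorised gym environment; "mlp" selects the MLP network type
    env = VectorEnv("CartPole-v0")
    agent = VPG("mlp", env, batch_size=128, gamma=0.99, lr_policy=0.001)

    # Hand the agent to GenRL's on-policy trainer and run training
    trainer = OnPolicyTrainer(agent, env, log_mode=["stdout"], epochs=100)
    trainer.train()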
get_hyperparams() → Dict[str, Any]

Get relevant hyperparameters to save

Returns: Hyperparameters to be saved, including the neural network weights (torch.Tensor)
Return type: hyperparams (dict)
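A hedged checkpointing sketch built on this method: saving the returned dict with torch.save is an illustrative choice rather than a documented GenRL workflow, and the "weights" key and agent.actor attribute are assumptions based on the description above.

    import torch

    hyperparams = agent.get_hyperparams()  # dict of hyperparameters plus weights
    torch.save(hyperparams, "vpg_checkpoint.pt")

    # Restoring later (assumes the weights sit under a "weights" key and
    # that agent.actor is the policy network):
    checkpoint = torch.load("vpg_checkpoint.pt")
    agent.actor.load_state_dict(checkpoint["weights"])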
get_log_probs(states: torch.Tensor, actions: torch.Tensor)

Get log probabilities of action values

The actions taken by the actor and their respective states are analysed to get the log probabilities.

Parameters:
- states (torch.Tensor) – States encountered in the rollout
- actions (torch.Tensor) – Actions taken in response to the respective states

Returns: Log of action probabilities given a state
Return type: log_probs (torch.Tensor)
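For intuition, here is a sketch of the computation this method performs for a discrete action space; policy_net is a hypothetical stand-in for the agent's actor network, not a GenRL symbol.

    import torch
    from torch.distributions import Categorical

    def log_probs_sketch(policy_net, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
        # Map states to action logits, build the action distribution,
        # then read off the log probability of each action actually taken
        logits = policy_net(states)               # shape: (batch, n_actions)
        dist = Categorical(logits=logits)
        return dist.log_prob(actions.squeeze())   # shape: (batch,), log pi(a|s)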
get_logging_params() → Dict[str, Any]

Gets relevant parameters for logging

Returns: Logging parameters for monitoring training
Return type: logs (dict)
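Illustrative only: polling these values from a custom loop instead of GenRL's built-in logger.

    # Print each logged metric once per update step
    logs = agent.get_logging_params()
    for name, value in logs.items():
        print(f"{name}: {value}")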
get_traj_loss(values, dones)

Get loss from the trajectory traversed by the agent during rollouts

Computes the returns and advantages needed for calculating the loss.

Parameters:
- values (torch.Tensor) – Values of states encountered during the rollout
- dones (list of bool) – Game-over statuses of each environment
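A sketch of the standard return computation behind this method: discounted returns are accumulated backwards through the rollout and cut wherever an environment finished. Since VPG has no critic baseline, these returns serve directly as the weights in the policy-gradient loss. GenRL's actual buffer handling may differ in detail.

    import torch

    def discounted_returns(rewards: torch.Tensor, dones: torch.Tensor, gamma: float) -> torch.Tensor:
        # rewards, dones: shape (timesteps, num_envs); dones is a float
        # tensor of 0./1. episode-end flags
        returns = torch.zeros_like(rewards)
        running = torch.zeros_like(rewards[0])
        for t in reversed(range(rewards.size(0))):
            # A finished episode (done == 1) resets the accumulated return
            running = rewards[t] + gamma * running * (1 - dones[t])
            returns[t] = running
        return returns

    # The VPG loss then weights log-probabilities by these returns:
    #   loss = -(log_probs * returns).mean()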
select_action(state: torch.Tensor, deterministic: bool = False) → torch.Tensor

Select action given state

Action selection for Vanilla Policy Gradient.

Parameters:
- state (np.ndarray) – Current state of the environment
- deterministic (bool) – Should the policy be deterministic or stochastic?

Returns:
- action (np.ndarray) – Action taken by the agent
- value (torch.Tensor) – Value of the given state. In VPG there is no critic to find the value, so this is set to a default of 0 for convenience
- log_prob (torch.Tensor) – Log probability of the selected action
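An illustrative rollout step using this method, assuming the three return values described above and a gym-style env; depending on the action's actual type, it may need conversion before being passed to env.step.

    import torch

    # Step one episode, querying the policy at every state
    state = env.reset()
    done = False
    while not done:
        action, value, log_prob = agent.select_action(
            torch.as_tensor(state, dtype=torch.float32)
        )
        state, reward, done, info = env.step(action)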