# Source code for genrl.agents.deep.td3.td3

from copy import deepcopy
from typing import Any, Dict

import numpy as np
import torch

from genrl.agents import OffPolicyAgentAC
from genrl.core import ActionNoise
from genrl.utils import get_env_properties, get_model, safe_mean


class TD3(OffPolicyAgentAC):
    """Twin Delayed DDPG (TD3) Algorithm

    Paper: https://arxiv.org/abs/1802.09477

    Attributes:
        network (str): The network type of the Q-value function.
            Supported types: ["cnn", "mlp"]
        env (Environment): The environment that the agent is supposed to act on
        create_model (bool): Whether the model of the algo should be created when initialised
        batch_size (int): Mini batch size for loading experiences
        gamma (float): The discount factor for rewards
        policy_layers (:obj:`tuple` of :obj:`int`): Neural network layer dimensions for the policy
        value_layers (:obj:`tuple` of :obj:`int`): Neural network layer dimensions for the critics
        lr_policy (float): Learning rate for the policy/actor
        lr_value (float): Learning rate for the critic
        replay_size (int): Capacity of the Replay Buffer
        buffer_type (str): Choose the type of Buffer: ["push", "prioritized"]
        polyak (float): Target model update parameter (1 for hard update)
        policy_frequency (int): Number of critic updates per actor/target update
            (must be >= 1)
        noise (:obj:`ActionNoise`): Action noise class (not an instance) added to aid in
            exploration; it is instantiated in ``_create_model`` as ``noise(mean, std)``
        noise_std (float): Standard deviation of the action noise distribution
        seed (int): Seed for randomness
        render (bool): Should the env be rendered during training?
        device (str): Hardware being used for training. Options:
            ["cuda" -> GPU, "cpu" -> CPU]
    """

    def __init__(
        self,
        *args,
        policy_frequency: int = 2,
        noise: ActionNoise = None,
        noise_std: float = 0.2,
        **kwargs,
    ):
        # Validate early: a zero frequency would otherwise only surface as a
        # ZeroDivisionError (from ``timestep % policy_frequency``) mid-training.
        if policy_frequency < 1:
            raise ValueError(
                "policy_frequency must be a positive integer, got {}".format(
                    policy_frequency
                )
            )
        super().__init__(*args, **kwargs)
        self.policy_frequency = policy_frequency
        self.noise = noise
        self.noise_std = noise_std
        # Flag consumed outside this class (presumably by the base class's
        # loss computation) to select the twin-critic target -- TODO confirm.
        self.doublecritic = True

        self.empty_logs()
        if self.create_model:
            self._create_model()

    def _create_model(self) -> None:
        """Initializes class objects

        Builds the single-actor / double-critic network (unless a network object
        was passed in), its target copy, the exploration noise process and the
        actor and critic optimizers.

        Raises:
            Exception: If the environment has a discrete action space
        """
        state_dim, action_dim, discrete, _ = get_env_properties(self.env, self.network)
        if discrete:
            raise Exception(
                "Discrete Environments not supported for {}.".format(__class__.__name__)
            )

        if isinstance(self.network, str):
            # Below, the "12" corresponds to the Single Actor, Double Critic network architecture
            self.ac = get_model("ac", self.network + "12")(
                state_dim,
                action_dim,
                policy_layers=self.policy_layers,
                value_layers=self.value_layers,
                val_type="Qsa",
                discrete=False,
            )
        else:
            # A pre-built network object was supplied instead of a type name.
            self.ac = self.network

        if self.noise is not None:
            # Build one mean/std entry per action dimension. With an integer
            # ``action_dim``, np.zeros_like/np.ones_like would return 0-d scalars,
            # so a single noise sample would be broadcast to every action component.
            self.noise = self.noise(
                np.zeros(action_dim), self.noise_std * np.ones(action_dim)
            )

        self.ac_target = deepcopy(self.ac)

        # Both critics share one optimizer.
        self.critic_params = list(self.ac.critic1.parameters()) + list(
            self.ac.critic2.parameters()
        )
        self.optimizer_value = torch.optim.Adam(self.critic_params, lr=self.lr_value)
        self.optimizer_policy = torch.optim.Adam(
            self.ac.actor.parameters(), lr=self.lr_policy
        )

    def update_params(self, update_interval: int) -> None:
        """Update parameters of the model

        Performs ``update_interval`` critic updates. On every
        ``policy_frequency``-th step the actor is also updated and the target
        networks are refreshed (the "delayed" updates of TD3).

        Args:
            update_interval (int): Number of critic update steps to run in this call
        """
        for timestep in range(update_interval):
            batch = self.sample_from_buffer()

            value_loss = self.get_q_loss(batch)
            self.optimizer_value.zero_grad()
            value_loss.backward()
            self.optimizer_value.step()
            # Log every critic step so the reported mean is not a subsample
            # restricted to the steps that coincide with a policy update.
            self.logs["value_loss"].append(value_loss.item())

            # Delayed policy and target-network updates
            if timestep % self.policy_frequency == 0:
                policy_loss = self.get_p_loss(batch.states)
                # policy_loss.backward() also leaves gradients on the critics;
                # they are cleared by optimizer_value.zero_grad() next step.
                self.optimizer_policy.zero_grad()
                policy_loss.backward()
                self.optimizer_policy.step()
                self.logs["policy_loss"].append(policy_loss.item())

                self.update_target_model()

    def get_hyperparams(self) -> Dict[str, Any]:
        """Get relevant hyperparameters to save

        Returns:
            hyperparams (:obj:`dict`): Hyperparameters to be saved, including the
                actor-critic ``state_dict`` under ``"weights"``
        """
        hyperparams = {
            "network": self.network,
            "gamma": self.gamma,
            "batch_size": self.batch_size,
            "replay_size": self.replay_size,
            "lr_policy": self.lr_policy,
            "lr_value": self.lr_value,
            "polyak": self.polyak,
            "policy_frequency": self.policy_frequency,
            "noise_std": self.noise_std,
            "weights": self.ac.state_dict(),
        }
        return hyperparams

    def get_logging_params(self) -> Dict[str, Any]:
        """Gets relevant parameters for logging

        Averages the losses collected since the last call and then clears them.

        Returns:
            logs (:obj:`dict`): Logging parameters for monitoring training
        """
        logs = {
            "policy_loss": safe_mean(self.logs["policy_loss"]),
            "value_loss": safe_mean(self.logs["value_loss"]),
        }
        self.empty_logs()
        return logs

    def empty_logs(self):
        """Empties logs"""
        self.logs = {"policy_loss": [], "value_loss": []}