Source code for genrl.agents.deep.sac.sac

from copy import deepcopy
from typing import Any, Dict, List, Tuple

import torch
import torch.optim as opt

from genrl.agents import OffPolicyAgentAC
from genrl.utils import get_env_properties, get_model, safe_mean


class SAC(OffPolicyAgentAC):
    """Soft Actor-Critic algorithm (SAC)

    Paper: https://arxiv.org/abs/1812.05905

    Attributes:
        network (str): The network type of the Q-value function.
            Supported types: ["cnn", "mlp"]
        env (Environment): The environment that the agent is supposed to act on
        create_model (bool): Whether the model of the algo should be created when initialised
        batch_size (int): Mini batch size for loading experiences
        gamma (float): The discount factor for rewards
        policy_layers (:obj:`tuple` of :obj:`int`): Neural network layer dimensions for the policy
        value_layers (:obj:`tuple` of :obj:`int`): Neural network layer dimensions for the critics
        shared_layers (:obj:`tuple` of :obj:`int`): Sizes of shared layers in Actor Critic if using
        lr_policy (float): Learning rate for the policy/actor
        lr_value (float): Learning rate for the critic
        replay_size (int): Capacity of the Replay Buffer
        buffer_type (str): Choose the type of Buffer: ["push", "prioritized"]
        alpha (float): Entropy coefficient (temperature)
        polyak (float): Target model update parameter (1 for hard update)
        entropy_tuning (bool): True if entropy tuning should be done, False otherwise
        seed (int): Seed for randomness
        render (bool): Should the env be rendered during training?
        device (str): Hardware being used for training. Options: ["cuda" -> GPU, "cpu" -> CPU]
    """

    def __init__(
        self,
        *args,
        alpha: float = 0.01,
        polyak: float = 0.995,
        entropy_tuning: bool = True,
        **kwargs,
    ):
        super(SAC, self).__init__(*args, **kwargs)
        self.alpha = alpha
        self.polyak = polyak
        self.entropy_tuning = entropy_tuning
        self.doublecritic = True

        self.empty_logs()
        if self.create_model:
            self._create_model()

    def _create_model(self, **kwargs) -> None:
        """Initializes class objects

        Initializes the actor-critic architecture, replay buffer and optimizers
        """
        if self.env.action_space is None:
            # No action space available: fall back to identity scaling.
            # ``torch.FloatTensor`` cannot be built from a bare float (it
            # raises a TypeError), so use ``torch.tensor`` for the scalars.
            self.action_scale = torch.tensor(1.0)
            self.action_bias = torch.tensor(0.0)
        else:
            self.action_scale = torch.FloatTensor(
                (self.env.action_space.high - self.env.action_space.low) / 2.0
            )
            self.action_bias = torch.FloatTensor(
                (self.env.action_space.high + self.env.action_space.low) / 2.0
            )

        if isinstance(self.network, str):
            state_dim, action_dim, discrete, _ = get_env_properties(
                self.env, self.network
            )
            arch = self.network + "12"
            if self.shared_layers is not None:
                arch += "s"
            self.ac = get_model("ac", arch)(
                state_dim,
                action_dim,
                policy_layers=self.policy_layers,
                value_layers=self.value_layers,
                val_type="Qsa",
                discrete=False,
                sac=True,
                action_scale=self.action_scale,
                action_bias=self.action_bias,
            )
        else:
            # A custom network instance was passed in; use it as the
            # actor-critic. (The original assigned it to ``self.model``,
            # leaving ``self.ac`` unset for the deepcopy below.)
            self.ac = self.network

        self.ac_target = deepcopy(self.ac)

        actor_params, critic_params = self.ac.get_params()
        self.optimizer_value = opt.Adam(critic_params, self.lr_value)
        self.optimizer_policy = opt.Adam(actor_params, self.lr_policy)

        if self.entropy_tuning:
            self.target_entropy = -torch.prod(
                torch.Tensor(self.env.action_space.shape)
            ).item()
            self.log_alpha = torch.zeros(1, requires_grad=True)
            self.optimizer_alpha = opt.Adam([self.log_alpha], lr=self.lr_policy)
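
    # ``action_scale`` and ``action_bias`` map the policy's squashed output
    # (tanh in standard SAC, so u is in [-1, 1]) onto the environment's
    # action bounds:
    #
    #     a = action_scale * tanh(u) + action_bias
    #
    # e.g. for Pendulum-v0 (low = -2, high = 2): scale = 2.0, bias = 0.0.
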
    def select_action(
        self, state: torch.Tensor, deterministic: bool = False
    ) -> torch.Tensor:
        """Select action given state

        Args:
            state (:obj:`torch.Tensor`): Current state of the environment
            deterministic (bool): Should the policy be deterministic or stochastic

        Returns:
            action (:obj:`torch.Tensor`): Action taken by the agent
        """
        action, _, _ = self.ac.get_action(state, deterministic)
        return action.detach()
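
    # Example (illustrative, not part of the original source): selecting a
    # deterministic action for evaluation, assuming ``state`` is a float
    # tensor shaped like the observation space:
    #
    #     state = torch.as_tensor(env.reset(), dtype=torch.float32)
    #     action = agent.select_action(state, deterministic=True)
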
    def update_target_model(self) -> None:
        """Function to update the target Q model

        Updates the target model with a polyak average of the training model's weights
        """
        for param, param_target in zip(
            self.ac.parameters(), self.ac_target.parameters()
        ):
            param_target.data.mul_(self.polyak)
            param_target.data.add_((1 - self.polyak) * param.data)
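
    # The loop above implements the polyak (exponential moving average)
    # update:
    #
    #     theta_target <- polyak * theta_target + (1 - polyak) * theta
    #
    # A standalone sketch of the same soft update, assuming two modules with
    # identical parameter shapes (illustrative, not part of the original
    # source):
    #
    #     with torch.no_grad():
    #         for p, p_targ in zip(net.parameters(), target_net.parameters()):
    #             p_targ.mul_(polyak).add_((1 - polyak) * p)
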
    def get_target_q_values(
        self, next_states: torch.Tensor, rewards: List[float], dones: List[bool]
    ) -> torch.Tensor:
        """Get target Q values for the SAC

        Args:
            next_states (:obj:`torch.Tensor`): Next states for which target Q-values
                need to be found
            rewards (:obj:`list`): Rewards at each timestep for each environment
            dones (:obj:`list`): Game over status for each environment

        Returns:
            target_q_values (:obj:`torch.Tensor`): Target Q values for the SAC
        """
        next_target_actions, next_log_probs, _ = self.ac.get_action(next_states)
        next_q_target_values = self.ac_target.get_value(
            torch.cat([next_states, next_target_actions], dim=-1), mode="min"
        ).squeeze() - self.alpha * next_log_probs.squeeze(1)
        target_q_values = rewards + self.gamma * (1 - dones) * next_q_target_values
        return target_q_values
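
    # The target above is the soft Bellman backup used by SAC:
    #
    #     y = r + gamma * (1 - done)
    #             * (min_{i=1,2} Q_targ_i(s', a') - alpha * log pi(a'|s'))
    #
    # where a' ~ pi(.|s') is freshly sampled from the current policy and the
    # minimum over the two target critics (mode="min") counters Q-value
    # overestimation.
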
    def get_p_loss(self, states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Function to get the Policy loss

        Args:
            states (:obj:`torch.Tensor`): States for which Q-values need to be found

        Returns:
            loss (:obj:`torch.Tensor`): Calculated policy loss
            log_probs (:obj:`torch.Tensor`): Log probabilities of the sampled actions
        """
        next_best_actions, log_probs, _ = self.ac.get_action(states)
        q_values = self.ac.get_value(
            torch.cat([states, next_best_actions], dim=-1), mode="min"
        )
        policy_loss = ((self.alpha * log_probs) - q_values).mean()
        return policy_loss, log_probs
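
    # The policy objective above is the reparameterized SAC actor loss:
    #
    #     J_pi = E_{s ~ D, a ~ pi(.|s)} [ alpha * log pi(a|s) - min_i Q_i(s, a) ]
    #
    # Minimising it pushes the policy toward high-value actions while the
    # alpha * log pi term keeps it stochastic (maximum-entropy RL).
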
    def get_alpha_loss(self, log_probs: torch.Tensor) -> torch.Tensor:
        """Calculate Entropy Loss

        Args:
            log_probs (:obj:`torch.Tensor`): Log probabilities of the sampled actions
        """
        if self.entropy_tuning:
            alpha_loss = -torch.mean(
                self.log_alpha * (log_probs + self.target_entropy).detach()
            )
        else:
            # The original built ``torch.FloatTensor(0.0)``, which raises a
            # TypeError; a scalar zero tensor keeps the ``.item()`` logging
            # call in ``update_params`` working when tuning is disabled.
            alpha_loss = torch.tensor(0.0)
        return alpha_loss
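
    # With automatic temperature tuning, alpha is adjusted by minimising
    #
    #     J(alpha) = E_{a ~ pi} [ -log_alpha * (log pi(a|s) + H_target) ]
    #
    # where H_target = -|A| (set in ``_create_model``). When the policy's
    # entropy falls below the target, the gradient step increases alpha,
    # re-weighting exploration; optimising log_alpha keeps alpha positive.
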
    def update_params(self, update_interval: int) -> None:
        """Update parameters of the model

        Args:
            update_interval (int): Interval between successive updates of the target model
        """
        for timestep in range(update_interval):
            batch = self.sample_from_buffer()

            value_loss = self.get_q_loss(batch)
            self.logs["value_loss"].append(value_loss.item())

            policy_loss, log_probs = self.get_p_loss(batch.states)
            self.logs["policy_loss"].append(policy_loss.item())

            alpha_loss = self.get_alpha_loss(log_probs)
            self.logs["alpha_loss"].append(alpha_loss.item())

            self.optimizer_value.zero_grad()
            value_loss.backward()
            self.optimizer_value.step()

            self.optimizer_policy.zero_grad()
            policy_loss.backward()
            self.optimizer_policy.step()

            if self.entropy_tuning:
                # ``optimizer_alpha`` and ``log_alpha`` only exist when entropy
                # tuning is on; stepping them unconditionally (as the original
                # did) would fail with entropy_tuning=False.
                self.optimizer_alpha.zero_grad()
                alpha_loss.backward()
                self.optimizer_alpha.step()
                # Refresh alpha so the tuned temperature actually feeds back
                # into the critic and actor losses (the original never updated
                # ``self.alpha`` from ``log_alpha``).
                self.alpha = self.log_alpha.exp().item()

            self.update_target_model()
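
    # Example (illustrative): once the replay buffer holds enough transitions,
    # a single call performs ``update_interval`` gradient steps on the
    # critics, actor and temperature, each followed by a soft target update:
    #
    #     agent.update_params(update_interval=50)
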
    def get_hyperparams(self) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Get relevant hyperparameters to save

        Returns:
            hyperparams (:obj:`dict`): Hyperparameters to be saved
            weights (:obj:`dict`): Neural network weights (state dict)
        """
        hyperparams = {
            "network": self.network,
            "gamma": self.gamma,
            "lr_value": self.lr_value,
            "lr_policy": self.lr_policy,
            "replay_size": self.replay_size,
            "entropy_tuning": self.entropy_tuning,
            "alpha": self.alpha,
            "polyak": self.polyak,
        }
        return hyperparams, self.ac.state_dict()
    def get_logging_params(self) -> Dict[str, Any]:
        """Gets relevant parameters for logging

        Returns:
            logs (:obj:`dict`): Logging parameters for monitoring training
        """
        logs = {
            "policy_loss": safe_mean(self.logs["policy_loss"]),
            "value_loss": safe_mean(self.logs["value_loss"]),
            "alpha_loss": safe_mean(self.logs["alpha_loss"]),
        }
        self.empty_logs()
        return logs
    def empty_logs(self):
        """Empties logs"""
        self.logs = {
            "value_loss": [],
            "policy_loss": [],
            "alpha_loss": [],
        }
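

# A minimal usage sketch (illustrative, not part of the original source). It
# assumes genrl's ``VectorEnv`` and ``OffPolicyTrainer`` helpers and a
# continuous-control environment; exact trainer arguments may differ across
# genrl versions.
if __name__ == "__main__":
    from genrl.environments import VectorEnv
    from genrl.trainers import OffPolicyTrainer

    env = VectorEnv("Pendulum-v0")
    agent = SAC("mlp", env, alpha=0.01, polyak=0.995, entropy_tuning=True)
    trainer = OffPolicyTrainer(agent, env, log_mode=["stdout"], epochs=5)
    trainer.train()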