SAC

genrl.agents.deep.sac.sac module

class genrl.agents.deep.sac.sac.SAC(*args, alpha: float = 0.01, polyak: float = 0.995, entropy_tuning: bool = True, **kwargs)[source]

Bases: genrl.agents.deep.base.offpolicy.OffPolicyAgentAC

Soft Actor-Critic (SAC) algorithm

Paper: https://arxiv.org/abs/1812.05905
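
A minimal usage sketch is shown below. The SAC import follows the module path documented on this page; VectorEnv, OffPolicyTrainer, the environment name, and the trainer arguments are assumptions based on typical genrl usage rather than on this page.

    # Minimal usage sketch -- VectorEnv, OffPolicyTrainer, the environment name and
    # the trainer keyword arguments are assumptions, not taken from this page.
    from genrl.agents.deep.sac.sac import SAC     # module path documented above
    from genrl.environments import VectorEnv      # assumed vectorised-env helper
    from genrl.trainers import OffPolicyTrainer   # assumed off-policy training loop

    env = VectorEnv("MountainCarContinuous-v0")   # any continuous-action Gym environment
    agent = SAC("mlp", env, alpha=0.01, polyak=0.995, entropy_tuning=True)
    trainer = OffPolicyTrainer(agent, env, epochs=100)
    trainer.train()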

network

The network type of the Q-value function. Supported types: [“cnn”, “mlp”]

Type:str
env

The environment that the agent is supposed to act on

Type:Environment
create_model

Whether the model of the algorithm should be created when it is initialised

Type:bool
batch_size

Mini batch size for loading experiences

Type:int
gamma

The discount factor for rewards

Type:float
policy_layers

Neural network layer dimensions for the policy

Type:tuple of int
value_layers

Neural network layer dimensions for the critics

Type:tuple of int
lr_policy

Learning rate for the policy/actor

Type:float
lr_value

Learning rate for the critic

Type:float
replay_size

Capacity of the Replay Buffer

Type:int
buffer_type

The type of replay buffer. Supported types: [“push”, “prioritized”]

Type:str
alpha

Entropy coefficient (weight of the entropy term in the objective)

Type:float
polyak

Target model update parameter (1 for hard update)

Type:float
entropy_tuning

True if entropy tuning should be done, False otherwise

Type:bool
seed

Seed for randomness

Type:int
render

Whether the environment should be rendered during training

Type:bool
device

Hardware being used for training. Options: [“cuda” -> GPU, “cpu” -> CPU]

Type:str
empty_logs()[source]

Empties logs

get_alpha_loss(log_probs)[source]

Calculate the entropy (alpha) loss

Parameters:log_probs (float) – Log probabilities of the actions sampled from the current policy
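
For reference, a sketch of the standard SAC entropy-tuning loss this method corresponds to; log_alpha and target_entropy are hypothetical names, not part of the genrl API.

    import torch

    # Sketch of the standard SAC entropy (alpha) loss, used when entropy_tuning=True.
    # `log_alpha` (learnable log-temperature) and `target_entropy` (commonly the
    # negative action dimension) are hypothetical names, not genrl attributes.
    def alpha_loss(log_alpha: torch.Tensor,
                   log_probs: torch.Tensor,
                   target_entropy: float) -> torch.Tensor:
        # Only log_alpha receives gradients; the sampled log-probs are detached.
        return -(log_alpha * (log_probs + target_entropy).detach()).mean()
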
get_hyperparams() → Dict[str, Any][source]

Get relevant hyperparameters to save

Returns:Hyperparameters to be saved
Return type:hyperparams (dict)
get_logging_params() → Dict[str, Any][source]

Gets relevant parameters for logging

Returns:Logging parameters for monitoring training
Return type:logs (dict)
get_p_loss(states: torch.Tensor) → torch.Tensor[source]

Compute the policy (actor) loss

Parameters:states (torch.Tensor) – States for which the policy loss needs to be computed
Returns:Calculated policy loss
Return type:loss (torch.Tensor)
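
A sketch of the standard SAC policy objective this loss implements (clipped double-Q minus the entropy term); the function and argument names are illustrative, not genrl's internals.

    import torch

    # Sketch of the standard SAC actor loss: minimise alpha * log pi(a|s) - min(Q1, Q2),
    # where a is sampled (reparameterised) from the current policy at the given states.
    def policy_loss(alpha: float,
                    log_probs: torch.Tensor,
                    q1: torch.Tensor,
                    q2: torch.Tensor) -> torch.Tensor:
        q_min = torch.min(q1, q2)  # clipped double-Q value of the sampled actions
        return (alpha * log_probs - q_min).mean()
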
get_target_q_values(next_states: torch.Tensor, rewards: List[float], dones: List[bool]) → torch.Tensor[source]

Compute the target Q-values for the SAC critics

Parameters:
  • next_states (torch.Tensor) – Next states for which target Q-values need to be found
  • rewards (list) – Rewards at each timestep for each environment
  • dones (list) – Game over status for each environment
Returns:Target Q-values for the SAC critics
Return type:target_q_values (torch.Tensor)
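
The target follows the standard soft Bellman backup; below is a sketch with illustrative names, where next_q1/next_q2 are target-critic values for actions resampled from the current policy at the next states.

    import torch

    # Sketch of the standard SAC target:
    #   y = r + gamma * (1 - done) * (min(Q1', Q2')(s', a') - alpha * log pi(a'|s')),
    # with a' sampled from the current policy at s'. Names are illustrative.
    def target_q_values(rewards: torch.Tensor,
                        dones: torch.Tensor,
                        next_q1: torch.Tensor,
                        next_q2: torch.Tensor,
                        next_log_probs: torch.Tensor,
                        gamma: float,
                        alpha: float) -> torch.Tensor:
        soft_next_q = torch.min(next_q1, next_q2) - alpha * next_log_probs
        return rewards + gamma * (1.0 - dones) * soft_next_q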

select_action(state: numpy.ndarray, deterministic: bool = False) → numpy.ndarray[source]

Select an action given the current state of the environment

Parameters:
  • state (np.ndarray) – Current state of the environment
  • deterministic (bool) – Whether the action should be selected deterministically (True) or sampled from the policy (False)
Returns:Action taken by the agent
Return type:action (np.ndarray)
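
A usage sketch, assuming env and agent are constructed as in the example near the top of this page.

    # Usage sketch; `env` and `agent` as constructed in the example above.
    state = env.reset()
    action = agent.select_action(state)                           # stochastic, for exploration
    eval_action = agent.select_action(state, deterministic=True)  # deterministic, for evaluation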

update_params(update_interval: int) → None[source]

Update parameters of the model

Parameters:update_interval (int) – Interval between successive updates of the target model
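
Internally, each update regresses the critics onto the target Q-values described above; below is a sketch of that regression with an assumed MSE form and illustrative names.

    import torch
    import torch.nn.functional as F

    # Sketch of the critic regression performed during parameter updates (assumed MSE
    # form, illustrative names): both critics regress onto the same soft target.
    def critic_loss(q1_pred: torch.Tensor,
                    q2_pred: torch.Tensor,
                    target_q: torch.Tensor) -> torch.Tensor:
        target_q = target_q.detach()  # no gradient flows through the target networks
        return F.mse_loss(q1_pred, target_q) + F.mse_loss(q2_pred, target_q)
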
update_target_model() → None[source]

Function to update the target Q model

Updates the target model with the training model’s weights when called
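
A common polyak-averaging sketch of such an update is shown below; the exact mixing convention (whether polyak weights the target or the training network) is an assumption here, so consult the source if the distinction matters.

    import torch.nn as nn

    # Polyak-averaging sketch (convention assumed, not taken from this page):
    #   target <- polyak * target + (1 - polyak) * online.
    # Some implementations use the inverse convention, where polyak = 1 copies the
    # online weights outright (a hard update).
    def soft_update(target: nn.Module, online: nn.Module, polyak: float) -> None:
        for t_param, o_param in zip(target.parameters(), online.parameters()):
            t_param.data.mul_(polyak).add_((1.0 - polyak) * o_param.data)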