SAC
genrl.agents.deep.sac.sac module

class genrl.agents.deep.sac.sac.SAC(*args, alpha: float = 0.01, polyak: float = 0.995, entropy_tuning: bool = True, **kwargs)
Bases: genrl.agents.deep.base.offpolicy.OffPolicyAgentAC

Soft Actor Critic algorithm (SAC)
Paper: https://arxiv.org/abs/1812.05905
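
A minimal training sketch is shown below. The top-level imports (genrl.agents.SAC, genrl.environments.VectorEnv, genrl.trainers.OffPolicyTrainer), the Pendulum-v0 environment and the trainer arguments follow the pattern used elsewhere in the GenRL documentation and are assumptions here rather than part of this page:

    from genrl.agents import SAC
    from genrl.environments import VectorEnv
    from genrl.trainers import OffPolicyTrainer

    # SAC needs a continuous action space, so a Pendulum-style task is assumed
    env = VectorEnv("Pendulum-v0")

    # "mlp" networks with the default entropy and polyak settings documented below
    agent = SAC("mlp", env, alpha=0.01, polyak=0.995, entropy_tuning=True)

    # SAC is off-policy, so it is driven by the off-policy trainer
    trainer = OffPolicyTrainer(agent, env, log_mode=["stdout"])
    trainer.train()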

Attributes:
- network (str) – The network type of the Q-value function. Supported types: [“cnn”, “mlp”]
- env (Environment) – The environment that the agent is supposed to act on
- create_model (bool) – Whether the model of the algo should be created when initialised
- batch_size (int) – Mini batch size for loading experiences
- gamma (float) – The discount factor for rewards
- policy_layers (tuple of int) – Neural network layer dimensions for the policy
- value_layers (tuple of int) – Neural network layer dimensions for the critics
- lr_policy (float) – Learning rate for the policy/actor
- lr_value (float) – Learning rate for the critic
- replay_size (int) – Capacity of the Replay Buffer
- buffer_type (str) – Choose the type of Buffer: [“push”, “prioritized”]
- alpha (float) – Entropy factor
- polyak (float) – Target model update parameter (1 for hard update)
- entropy_tuning (bool) – True if entropy tuning should be done, False otherwise
- seed (int) – Seed for randomness
- render (bool) – Should the env be rendered during training?
- device (str) – Hardware being used for training. Options: [“cuda” -> GPU, “cpu” -> CPU]

get_hyperparams() → Dict[str, Any]
Get relevant hyperparameters to save
Returns: Hyperparameters to be saved
Return type: hyperparams (dict)

get_logging_params() → Dict[str, Any]
Gets relevant parameters for logging
Returns: Logging parameters for monitoring training
Return type: logs (dict)

get_p_loss(states: torch.Tensor) → torch.Tensor
Function to get the Policy loss
Parameters: states (torch.Tensor) – States for which Q-values need to be found
Returns: Calculated policy loss
Return type: loss (torch.Tensor)
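
For reference, the policy objective in the SAC paper is the expectation of alpha * log pi(a|s) - Q(s, a) over actions sampled from the current policy. A minimal sketch of that computation, assuming the minimum over the two critics and the sampled log-probabilities are already available (illustrative only, not necessarily GenRL's exact implementation):

    import torch

    def sac_policy_loss(min_q: torch.Tensor, log_probs: torch.Tensor, alpha: float) -> torch.Tensor:
        # min_q: element-wise minimum of the two critics, evaluated at actions
        #        sampled from the current policy for the given states
        # log_probs: log-probabilities of those sampled actions under the policy
        # The actor is trained to minimise E[alpha * log pi(a|s) - Q(s, a)]
        return (alpha * log_probs - min_q).mean()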

get_target_q_values(next_states: torch.Tensor, rewards: List[float], dones: List[bool]) → torch.Tensor
Get target Q values for the SAC
Parameters:
- next_states (torch.Tensor) – Next states for which target Q-values need to be found
- rewards (list) – Rewards at each timestep for each environment
- dones (list) – Game over status for each environment
Returns: Target Q values for the SAC
Return type: target_q_values (torch.Tensor)
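
The target follows the standard SAC Bellman backup with an entropy bonus: the reward plus the discounted, terminal-masked minimum of the target critics at the next state, minus alpha times the log-probability of the sampled next action. A minimal sketch under those assumptions (not necessarily GenRL's exact code):

    import torch

    def sac_target_q(rewards: torch.Tensor, dones: torch.Tensor, next_min_q: torch.Tensor,
                     next_log_probs: torch.Tensor, gamma: float, alpha: float) -> torch.Tensor:
        # next_min_q: minimum of the two target critics at (s', a'), with a' ~ pi(.|s')
        # next_log_probs: log pi(a'|s') for those sampled next actions
        # (1 - dones) masks terminal transitions so there is no bootstrapping past them
        return rewards + gamma * (1.0 - dones) * (next_min_q - alpha * next_log_probs)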

select_action(state: numpy.ndarray, deterministic: bool = False) → numpy.ndarray
Select action given state
Parameters:
- state (np.ndarray) – Current state of the environment
- deterministic (bool) – Should the policy be deterministic or stochastic
Returns: Action taken by the agent
Return type: action (np.ndarray)
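
A short usage sketch for action selection; the environment construction mirrors the example above, and the classic four-value Gym step API is assumed:

    from genrl.agents import SAC
    from genrl.environments import VectorEnv

    env = VectorEnv("Pendulum-v0")
    agent = SAC("mlp", env)

    state = env.reset()
    # Stochastic action while exploring; pass deterministic=True at evaluation
    # time to act with the policy mean instead of sampling
    action = agent.select_action(state)
    next_state, reward, done, info = env.step(action)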