from typing import Tuple
import torch # noqa
import torch.nn as nn # noqa
from gym import spaces
from torch.distributions import Categorical, Normal
from genrl.core.base import BaseActorCritic
from genrl.core.policies import MlpPolicy
from genrl.core.values import MlpValue
from genrl.utils.utils import cnn, mlp
class MlpActorCritic(BaseActorCritic):
"""MLP Actor Critic
Attributes:
state_dim (int): State dimensions of the environment
action_dim (int): Action space dimensions of the environment
shared_layers (:obj:`list` or :obj:`tuple`): Unused in this class; accepted to keep a uniform constructor signature
policy_layers (:obj:`list` or :obj:`tuple`): Hidden layers in the policy MLP
value_layers (:obj:`list` or :obj:`tuple`): Hidden layers in the value MLP
val_type (str): Value type of the critic network: "V" for V(s), "Qs" for Q(s), "Qsa" for Q(s,a)
discrete (bool): True if the action space is discrete, else False
sac (bool): True if a SAC-like network is needed, else False
activation (str): Activation function to be used. Can be either "tanh" or "relu"
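Example:
A minimal construction sketch (dimensions are illustrative: a 4-dim state and 2 discrete actions):
>>> model = MlpActorCritic(4, 2)
>>> actor_params, critic_params = model.get_params()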
"""
def __init__(
self,
state_dim: spaces.Space,
action_dim: spaces.Space,
shared_layers: Tuple = None,
policy_layers: Tuple = (32, 32),
value_layers: Tuple = (32, 32),
val_type: str = "V",
discrete: bool = True,
**kwargs,
):
super(MlpActorCritic, self).__init__()
self.actor = MlpPolicy(state_dim, action_dim, policy_layers, discrete, **kwargs)
self.critic = MlpValue(state_dim, action_dim, val_type, value_layers, **kwargs)
def get_params(self):
"""Returns parameters of the actor and critic networks"""
actor_params = self.actor.parameters()
critic_params = self.critic.parameters()
return actor_params, critic_params
class MlpSharedActorCritic(BaseActorCritic):
"""MLP Shared Actor Critic
Attributes:
state_dim (int): State dimensions of the environment
action_dim (int): Action space dimensions of the environment
shared_layers (:obj:`list` or :obj:`tuple`): Hidden layers in the shared MLP
policy_layers (:obj:`list` or :obj:`tuple`): Hidden layers in the policy MLP
value_layers (:obj:`list` or :obj:`tuple`): Hidden layers in the value MLP
val_type (str): Value type of the critic network: "V" for V(s), "Qs" for Q(s), "Qsa" for Q(s,a)
discrete (bool): True if the action space is discrete, else False
sac (bool): True if a SAC-like network is needed, else False
activation (str): Activation function to be used. Can be either "tanh" or "relu"
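Example:
A minimal usage sketch (illustrative 4-dim state and 2 discrete actions, with a batch of one state):
>>> model = MlpSharedActorCritic(4, 2)
>>> action, dist = model.get_action(torch.rand(1, 4))
>>> value = model.get_value(torch.rand(1, 4))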
"""
def __init__(
self,
state_dim: spaces.Space,
action_dim: spaces.Space,
shared_layers: Tuple = (32, 32),
policy_layers: Tuple = (32, 32),
value_layers: Tuple = (32, 32),
val_type: str = "V",
discrete: bool = True,
**kwargs,
):
super(MlpSharedActorCritic, self).__init__()
self.shared_network = mlp([state_dim] + list(shared_layers))
self.actor = MlpPolicy(
shared_layers[-1], action_dim, policy_layers, discrete, **kwargs
)
self.critic = MlpValue(
shared_layers[-1], action_dim, val_type, value_layers, **kwargs
)
self.state_dim = state_dim
self.action_dim = action_dim
def get_params(self):
"""Returns parameters of the actor and critic networks, each including the shared network's parameters"""
actor_params = list(self.shared_network.parameters()) + list(
self.actor.parameters()
)
critic_params = list(self.shared_network.parameters()) + list(
self.critic.parameters()
)
return actor_params, critic_params
def get_features(self, state: torch.Tensor):
"""Extract features from the state, which are then inputs to get_action and get_value
Args:
state (:obj:`torch.Tensor`): The state(s) being passed
Returns:
features (:obj:`torch.Tensor`): The feature(s) extracted from the state
"""
features = self.shared_network(state)
return features
def get_action(self, state: torch.Tensor, deterministic: bool = False):
"""Get Actions from the actor
Args:
state (:obj:`torch.Tensor`): The state(s) being passed to the actor
deterministic (bool): True if the action is to be selected deterministically, else False
Returns:
action (:obj:`torch.Tensor`): Action(s) chosen by the policy
distribution (:obj:`torch.distributions.Categorical`): The distribution from which the action was sampled
(None if deterministic)
"""
state = torch.as_tensor(state).float()
shared_features = self.get_features(state)
action_probs = self.actor(shared_features)
action_probs = nn.Softmax(dim=-1)(action_probs)
if deterministic:
action = torch.argmax(action_probs, dim=-1).unsqueeze(-1).float()
distribution = None
else:
distribution = Categorical(probs=action_probs)
action = distribution.sample()
return action, distribution
def get_value(self, state: torch.Tensor):
"""Get Values from the Critic
Args:
state (:obj:`torch.Tensor`): The state(s) being passed to the critic
Returns:
values (:obj:`torch.Tensor`): Values as estimated by the critic
"""
state = torch.as_tensor(state).float()
if self.critic.val_type == "Qsa":
# state shape = [batch_size, number of vec envs, (state_dim + action_dim)]
# extract shared_features from just the state
# state[:, :, :-action_dim] -> [batch_size, number of vec envs, state_dim]
shared_features = self.shared_network(state[:, :, : -self.action_dim])
# concatenate the actions to the extracted shared_features
# state[:, :, -action_dim:] -> [batch_size, number of vec envs, action_dim]
shared_features = torch.cat(
[shared_features, state[:, :, -self.action_dim :]], dim=-1
)
value = self.critic(shared_features).float().squeeze(-1)
else:
shared_features = self.shared_network(state)
value = self.critic(shared_features)
return value
class MlpSingleActorTwoCritic(BaseActorCritic):
"""MLP Actor Critic with a single actor and two critics
Attributes:
state_dim (int): State dimensions of the environment
action_dim (int): Action space dimensions of the environment
policy_layers (:obj:`list` or :obj:`tuple`): Hidden layers in the policy MLP
value_layers (:obj:`list` or :obj:`tuple`): Hidden layers in the value MLP
val_type (str): Value type of the critic network (both critics are always constructed as "Qsa" here)
discrete (bool): True if the action space is discrete, else False
num_critics (int): Number of critics in the architecture
sac (bool): True if a SAC-like network is needed, else False
activation (str): Activation function to be used. Can be either "tanh" or "relu"
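Example:
A sketch of the twin-critic interface (illustrative 4-dim state and 2-dim continuous action; get_value expects the state and action concatenated along the last dimension):
>>> model = MlpSingleActorTwoCritic(4, 2, discrete=False)
>>> x = torch.rand(1, 4 + 2)
>>> q1, q2 = model.get_value(x, mode="both")
>>> q_min = model.get_value(x, mode="min")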
"""
def __init__(
self,
state_dim: spaces.Space,
action_dim: spaces.Space,
policy_layers: Tuple = (32, 32),
value_layers: Tuple = (32, 32),
val_type: str = "V",
discrete: bool = True,
num_critics: int = 2,
**kwargs,
):
super(MlpSingleActorTwoCritic, self).__init__()
self.num_critics = num_critics
self.actor = MlpPolicy(state_dim, action_dim, policy_layers, discrete, **kwargs)
self.critic1 = MlpValue(state_dim, action_dim, "Qsa", value_layers, **kwargs)
self.critic2 = MlpValue(state_dim, action_dim, "Qsa", value_layers, **kwargs)
self.action_scale = kwargs.get("action_scale", 1)
self.action_bias = kwargs.get("action_bias", 0)
def get_params(self):
"""Returns parameters of the actor and of both critics"""
actor_params = self.actor.parameters()
critic_params = list(self.critic1.parameters()) + list(
self.critic2.parameters()
)
return actor_params, critic_params
def forward(self, x):
"""Returns Q-values from both critics for a concatenated state-action input"""
q1_values = self.critic1(x).squeeze(-1)
q2_values = self.critic2(x).squeeze(-1)
return (q1_values, q2_values)
def get_action(self, state: torch.Tensor, deterministic: bool = False):
"""Get Actions from the actor
Args:
state (:obj:`torch.Tensor`): The state(s) being passed to the actor
deterministic (bool): True if the action is to be selected deterministically, else False
Returns:
action: Action(s) from the policy; for a SAC actor, a tuple of (action, log_probs, mean)
distribution: The distribution from which the action was sampled
(None if deterministic)
"""
state = torch.as_tensor(state).float()
if self.actor.sac:
mean, log_std = self.actor(state)
std = log_std.exp()
distribution = Normal(mean, std)
action_probs = distribution.rsample()
log_probs = distribution.log_prob(action_probs)
action_probs = torch.tanh(action_probs)
action = action_probs * self.action_scale + self.action_bias
# enforcing action bound (appendix of SAC paper)
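# change of variables for the tanh squashing:
# log pi(a|s) = log rho(u|s) - sum_i log(action_scale * (1 - tanh(u_i)^2) + eps),
# where u is the pre-tanh sample drawn from the Gaussian above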
log_probs -= torch.log(
self.action_scale * (1 - action_probs.pow(2))
+ torch.finfo(torch.float32).eps
)
log_probs = log_probs.sum(1, keepdim=True)
mean = torch.tanh(mean) * self.action_scale + self.action_bias
action = (action.float(), log_probs, mean)
else:
action = self.actor.get_action(state, deterministic=deterministic)
return action
def get_value(self, state: torch.Tensor, mode="first") -> torch.Tensor:
"""Get Values from the Critics
Args:
state (:obj:`torch.Tensor`): The state-action input being passed to the critics
mode (str): What values should be returned. Types:
"both" --> Both values will be returned
"min" --> The minimum of both values will be returned
"first" --> The value from the first critic only will be returned
Returns:
values (:obj:`torch.Tensor` or :obj:`tuple`): Value estimate(s) according to the specified mode
"""
state = torch.as_tensor(state).float()
if mode == "both":
values = self.forward(state)
elif mode == "min":
values = self.forward(state)
values = torch.min(*values).squeeze(-1)
elif mode == "first":
values = self.critic1(state)
else:
raise KeyError(f"Mode {mode} doesn't exist")
return values
class MlpSharedSingleActorTwoCritic(MlpSingleActorTwoCritic):
"""MLP Shared Actor Critic with a single actor and two critics
Attributes:
state_dim (int): State dimensions of the environment
action_dim (int): Action space dimensions of the environment
shared_layers (:obj:`list` or :obj:`tuple`): Hidden layers in the shared MLP
policy_layers (:obj:`list` or :obj:`tuple`): Hidden layers in the policy MLP
value_layers (:obj:`list` or :obj:`tuple`): Hidden layers in the value MLP
val_type (str): Value type of the critic network (both critics are always constructed as "Qsa" here)
discrete (bool): True if the action space is discrete, else False
num_critics (int): Number of critics in the architecture
sac (bool): True if a SAC-like network is needed, else False
activation (str): Activation function to be used. Can be either "tanh" or "relu"
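Example:
A sketch of the shared variant; get_value expects states and actions packed along the last dimension (illustrative shapes, one vectorised env):
>>> model = MlpSharedSingleActorTwoCritic(4, 2, discrete=False)
>>> x = torch.rand(1, 1, 4 + 2)  # [batch, n_envs, state_dim + action_dim]
>>> q_min = model.get_value(x, mode="min")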
"""
def __init__(
self,
state_dim: spaces.Space,
action_dim: spaces.Space,
shared_layers: Tuple = (32, 32),
policy_layers: Tuple = (32, 32),
value_layers: Tuple = (32, 32),
val_type: str = "Qsa",
discrete: bool = True,
num_critics: int = 2,
**kwargs,
):
super(MlpSharedSingleActorTwoCritic, self).__init__(
shared_layers[-1],
action_dim,
policy_layers,
value_layers,
val_type,
discrete,
num_critics,
**kwargs,
)
self.shared_network = mlp([state_dim] + list(shared_layers))
self.action_dim = action_dim
def get_params(self):
"""Returns parameters of the actor and of both critics, each including the shared network's parameters"""
actor_params = list(self.shared_network.parameters()) + list(
self.actor.parameters()
)
critic_params = (
list(self.shared_network.parameters())
+ list(self.critic1.parameters())
+ list(self.critic2.parameters())
)
return actor_params, critic_params
def get_features(self, state: torch.Tensor):
"""Extract features from the state, which are then inputs to get_action and get_value
Args:
state (:obj:`torch.Tensor`): The state(s) being passed
Returns:
features (:obj:`torch.Tensor`): The feature(s) extracted from the state
"""
features = self.shared_network(state)
return features
def get_action(self, state: torch.Tensor, deterministic: bool = False):
"""Get Actions from the actor
Args:
state (:obj:`torch.Tensor`): The state(s) being passed to the actor
deterministic (bool): True if the action is to be selected deterministically, else False
Returns:
action: Action(s) from the policy; for a SAC actor, a tuple of (action, log_probs, mean)
distribution: The distribution from which the action was sampled
(None if deterministic)
"""
return super(MlpSharedSingleActorTwoCritic, self).get_action(
self.get_features(state), deterministic=deterministic
)
def get_value(self, state: torch.Tensor, mode="first"):
"""Get Values from both the Critics
Args:
state (:obj:`torch.Tensor`): The state-action input being passed to the critics
mode (str): What values should be returned. Types:
"both" --> Both values will be returned
"min" --> The minimum of both values will be returned
"first" --> The value from the first critic only will be returned
Returns:
values (:obj:`torch.Tensor` or :obj:`tuple`): Value estimate(s) according to the specified mode
"""
state = torch.as_tensor(state).float()
# state shape = [batch_size, number of vec envs, (state_dim + action_dim)]
# extract shared features for just the state
# state[:, :, :-action_dim] -> [batch_size, number of vec envs, state_dim]
x = self.get_features(state[:, :, : -self.action_dim])
# concatenate the actions to the extracted shared features
# state[:, :, -action_dim:] -> [batch_size, number of vec envs, action_dim]
state = torch.cat([x, state[:, :, -self.action_dim :]], dim=-1)
return super(MlpSharedSingleActorTwoCritic, self).get_value(state, mode)
class CNNActorCritic(BaseActorCritic):
"""
CNN Actor Critic
:param framestack: Number of previous frames to stack together
:param action_dim: Action dimensions of the environment
:param policy_layers: Sizes of hidden layers in the policy MLP
:param value_layers: Sizes of hidden layers in the value MLP
:param val_type: Specifies type of value function: (
"V" for V(s), "Qs" for Q(s), "Qsa" for Q(s,a))
:param discrete: True if action space is discrete, else False
:type framestack: int
:type action_dim: int
:type policy_layers: tuple or list
:type value_layers: tuple or list
:type val_type: str
:type discrete: bool
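Example:
A minimal sketch (assumes 4 stacked 84x84 frames and 6 discrete actions; valid input shapes depend on the cnn feature extractor):
>>> model = CNNActorCritic(4, 6)
>>> action, dist = model.get_action(torch.rand(1, 4, 84, 84))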
"""
def __init__(
self,
framestack: int,
action_dim: spaces.Space,
policy_layers: Tuple = (256,),
value_layers: Tuple = (256,),
val_type: str = "V",
discrete: bool = True,
*args,
**kwargs,
):
super(CNNActorCritic, self).__init__()
self.feature, output_size = cnn((framestack, 16, 32))
self.actor = MlpPolicy(
output_size, action_dim, policy_layers, discrete, **kwargs
)
self.critic = MlpValue(output_size, action_dim, val_type, value_layers)
def get_params(self):
"""Returns parameters of the actor and critic networks, each including the CNN feature extractor's parameters"""
actor_params = list(self.feature.parameters()) + list(self.actor.parameters())
critic_params = list(self.feature.parameters()) + list(self.critic.parameters())
return actor_params, critic_params
def get_action(
self, state: torch.Tensor, deterministic: bool = False
) -> torch.Tensor:
"""
Get action from the Actor based on input
:param state: The state being passed as input to the Actor
:param deterministic: (True if the action is to be selected
deterministically, else False)
:type state: Tensor
:type deterministic: boolean
:returns: action, distribution (None if deterministic)
"""
state = self.feature(state)
state = state.view(state.size(0), -1)
action_probs = self.actor(state)
action_probs = nn.Softmax(dim=-1)(action_probs)
if deterministic:
action = torch.argmax(action_probs, dim=-1)
distribution = None
else:
distribution = Categorical(probs=action_probs)
action = distribution.sample()
return action, distribution
def get_value(self, inp: torch.Tensor) -> torch.Tensor:
"""
Get value from the Critic based on input
:param inp: Input to the Critic
:type inp: Tensor
:returns: value
"""
inp = self.feature(inp)
inp = inp.view(inp.size(0), -1)
value = self.critic(inp).squeeze(-1)
return value
actor_critic_registry = {
"mlp": MlpActorCritic,
"cnn": CNNActorCritic,
"mlp12": MlpSingleActorTwoCritic,
"mlps": MlpSharedActorCritic,
"mlp12s": MlpSharedSingleActorTwoCritic,
}
def get_actor_critic_from_name(name_: str):
"""
Returns Actor Critic given the type of the Actor Critic
:param name_: Name of the actor critic needed
:type name_: str
:returns: Actor Critic class to be used
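Example:
Looking up a class by its registry key (keys as listed in actor_critic_registry above; constructor dimensions are illustrative):
>>> ac_cls = get_actor_critic_from_name("mlp")
>>> model = ac_cls(4, 2)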
"""
if name_ in actor_critic_registry:
return actor_critic_registry[name_]
raise NotImplementedError(f"Actor Critic {name_} not implemented")