Source code for genrl.core.policies

from typing import Tuple

import numpy as np

from genrl.core.base import BasePolicy
from genrl.utils.utils import cnn, mlp


[docs]class MlpPolicy(BasePolicy): """ MLP Policy :param state_dim: State dimensions of the environment :param action_dim: Action dimensions of the environment :param hidden: Sizes of hidden layers :param discrete: True if action space is discrete, else False :type state_dim: int :type action_dim: int :type hidden: tuple or list :type discrete: bool """ def __init__( self, state_dim: int, action_dim: int, hidden: Tuple = (32, 32), discrete: bool = True, *args, **kwargs ): super(MlpPolicy, self).__init__( state_dim, action_dim, hidden, discrete, **kwargs ) self.activation = kwargs["activation"] if "activation" in kwargs else "relu" self.model = mlp( [state_dim] + list(hidden) + [action_dim], activation=self.activation, sac=self.sac, )
[docs]class CNNPolicy(BasePolicy): """ CNN Policy :param framestack: Number of previous frames to stack together :param action_dim: Action dimensions of the environment :param fc_layers: Sizes of hidden layers :param discrete: True if action space is discrete, else False :param channels: Channel sizes for cnn layers :type framestack: int :type action_dim: int :type fc_layers: tuple or list :type discrete: bool :type channels: list or tuple """ def __init__( self, framestack: int, action_dim: int, hidden: Tuple = (32, 32), discrete: bool = True, *args, **kwargs ): super(CNNPolicy, self).__init__( framestack, action_dim, hidden, discrete, **kwargs ) channels = (framestack, 16, 32) self.conv, output_size = cnn(channels) self.fc = mlp([output_size] + list(hidden) + [action_dim], sac=self.sac)
[docs] def forward(self, state: np.ndarray) -> np.ndarray: state = self.conv(state) state = state.view(state.size(0), -1) action = self.fc(state) return action
policy_registry = {"mlp": MlpPolicy, "cnn": CNNPolicy}
[docs]def get_policy_from_name(name_: str): """ Returns policy given the name of the policy :param name_: Name of the policy needed :type name_: str :returns: Policy Function to be used """ if name_ in policy_registry: return policy_registry[name_] raise NotImplementedError