Source code for genrl.agents.deep.base.base

from abc import ABC
from typing import Any, Dict, Tuple

import numpy as np
import torch

from genrl.utils import set_seeds


class BaseAgent(ABC):
    """Base Agent Class

    Attributes:
        network (str): The network type of the Q-value function.
            Supported types: ["cnn", "mlp"]
        env (Environment): The environment that the agent is supposed to act on
        create_model (bool): Whether the model of the algo should be created
            when initialised
        batch_size (int): Mini batch size for loading experiences
        gamma (float): The discount factor for rewards
        shared_layers (:obj:`tuple` of :obj:`int`): Sizes of the layers shared
            between the policy and the Q-value function, ``None`` if not shared
        policy_layers (:obj:`tuple` of :obj:`int`): Layers in the Neural Network
            of the policy/actor
        value_layers (:obj:`tuple` of :obj:`int`): Layers in the Neural Network
            of the Q-value function
        lr_policy (float): Learning rate for the policy/actor
        lr_value (float): Learning rate for the Q-value function
        seed (int): Seed for randomness
        render (bool): Should the env be rendered during training?
        device (str): Hardware being used for training. Options:
            ["cuda" -> GPU, "cpu" -> CPU]
    """

    def __init__(
        self,
        network: Any,
        env: Any,
        create_model: bool = True,
        batch_size: int = 64,
        gamma: float = 0.99,
        shared_layers=None,
        policy_layers: Tuple = (64, 64),
        value_layers: Tuple = (64, 64),
        lr_policy: float = 0.0001,
        lr_value: float = 0.001,
        **kwargs
    ):
        self.network = network
        self.env = env
        self.create_model = create_model
        self.batch_size = batch_size
        self.gamma = gamma
        self.shared_layers = shared_layers
        self.policy_layers = policy_layers
        self.rewards = []
        self.value_layers = value_layers
        self.lr_policy = lr_policy
        self.lr_value = lr_value
        self.seed = kwargs.get("seed")
        self.render = kwargs.get("render", False)

        # Assign device: fall back to CPU when CUDA is requested but unavailable
        device = kwargs.get("device", "cpu")
        if "cuda" in device and torch.cuda.is_available():
            self.device = torch.device(device)
        else:
            self.device = torch.device("cpu")

        # Assign seed
        if self.seed is not None:
            set_seeds(self.seed, self.env)

    def _create_model(self) -> None:
        """Function to initialize all models of the agent"""
        raise NotImplementedError
    def select_action(
        self, state: np.ndarray, deterministic: bool = False
    ) -> np.ndarray:
        """Select action given state

        Action selection method

        Args:
            state (:obj:`np.ndarray`): Current state of the environment
            deterministic (bool): Should the policy be deterministic or stochastic

        Returns:
            action (:obj:`np.ndarray`): Action taken by the agent
        """
        raise NotImplementedError
    def get_hyperparams(self) -> Dict[str, Any]:
        """Get relevant hyperparameters to save

        Returns:
            hyperparams (:obj:`dict`): Hyperparameters to be saved
        """
        raise NotImplementedError
    def _load_weights(self, weights) -> None:
        """Load weights for the agent from a pretrained model

        Args:
            weights (:obj:`torch.Tensor`): neural net weights
        """
        raise NotImplementedError
    def get_logging_params(self) -> Dict[str, Any]:
        """Gets relevant parameters for logging

        Returns:
            logs (:obj:`dict`): Logging parameters for monitoring training
        """
        raise NotImplementedError
    def empty_logs(self):
        """Empties logs"""
        raise NotImplementedError
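
A minimal usage sketch (not part of the library source): a hypothetical RandomAgent that fills in the abstract methods above with a uniform-random policy. The RandomAgent class and the gym environment are assumptions made for illustration; only BaseAgent and its method names come from this module.

import gym  # assumed available; any env exposing action_space.sample() works


class RandomAgent(BaseAgent):
    """Toy agent that samples actions uniformly at random"""

    def _create_model(self) -> None:
        pass  # a random policy needs no networks

    def select_action(self, state, deterministic=False):
        return self.env.action_space.sample()  # state is ignored

    def get_hyperparams(self):
        return {"gamma": self.gamma, "batch_size": self.batch_size}

    def _load_weights(self, weights) -> None:
        pass  # nothing to load

    def get_logging_params(self):
        return {}

    def empty_logs(self):
        pass


agent = RandomAgent(network="mlp", env=gym.make("CartPole-v0"), create_model=False)
action = agent.select_action(agent.env.reset())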