import collections
from typing import List
import numpy as np
import torch
from torch.nn import functional as F
from genrl.agents.deep.base import BaseAgent
from genrl.core import (
PrioritizedBuffer,
PrioritizedReplayBufferSamples,
PushReplayBuffer,
ReplayBufferSamples,
)
class OffPolicyAgent(BaseAgent):
"""Off Policy Agent Base Class
Attributes:
network (str): The network type of the Q-value function.
Supported types: ["cnn", "mlp"]
env (Environment): The environment that the agent is supposed to act on
create_model (bool): Whether the model of the algo should be created when initialised
batch_size (int): Mini batch size for loading experiences
gamma (float): The discount factor for rewards
layers (:obj:`tuple` of :obj:`int`): Layers in the Neural Network
of the Q-value function
lr_policy (float): Learning rate for the policy/actor
lr_value (float): Learning rate for the Q-value function
replay_size (int): Capacity of the Replay Buffer
buffer_type (str): Choose the type of Buffer: ["push", "prioritized"]
seed (int): Seed for randomness
render (bool): Should the env be rendered during training?
device (str): Hardware being used for training. Options:
["cuda" -> GPU, "cpu" -> CPU]
"""
def __init__(
self, *args, replay_size: int = 5000, buffer_type: str = "push", **kwargs
):
super(OffPolicyAgent, self).__init__(*args, **kwargs)
self.replay_size = replay_size
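        # Select the experience replay implementation: a plain replay buffer ("push")
        # or prioritized experience replay ("prioritized")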
if buffer_type == "push":
self.replay_buffer = PushReplayBuffer(self.replay_size)
elif buffer_type == "prioritized":
self.replay_buffer = PrioritizedBuffer(self.replay_size)
        else:
            raise NotImplementedError(
                f"Unknown buffer_type: {buffer_type}. Supported types: 'push', 'prioritized'"
            )
    def update_params_before_select_action(self, timestep: int) -> None:
        """Update any parameters before selecting an action, e.g. epsilon for decaying epsilon-greedy exploration
Args:
timestep (int): Timestep in the training process
"""
pass
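        # A concrete subclass could, for instance, anneal an exploration parameter here.
        # Hypothetical sketch (`epsilon`, `min_epsilon` and `epsilon_decay` are illustrative
        # names, not attributes defined by this base class):
        #     self.epsilon = max(self.min_epsilon, self.epsilon * self.epsilon_decay)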
    def update_params(self, update_interval: int) -> None:
        """Update parameters of the model
        Args:
            update_interval (int): Interval between successive updates of the model
        """
raise NotImplementedError
    def update_target_model(self) -> None:
"""Function to update the target Q model
Updates the target model with the training model's weights when called
"""
raise NotImplementedError
def _reshape_batch(self, batch: List):
"""Function to reshape experiences
Can be modified for individual algorithm usage
Args:
batch (:obj:`list`): List of experiences that are being replayed
Returns:
batch (:obj:`list`): Reshaped experiences for replay
"""
return [*batch]
    def sample_from_buffer(self, beta: float = None):
"""Samples experiences from the buffer and converts them into usable formats
Args:
beta (float): Importance-Sampling beta for prioritized replay
Returns:
batch (:obj:`list`): Replay experiences sampled from the buffer
"""
# Samples from the buffer
if beta is not None:
batch = self.replay_buffer.sample(self.batch_size, beta=beta)
else:
batch = self.replay_buffer.sample(self.batch_size)
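        # Both buffer types return experiences ordered as (states, actions, rewards,
        # next_states, dones); the prioritized buffer additionally appends the sampled
        # indices and importance-sampling weights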
states, actions, rewards, next_states, dones = self._reshape_batch(batch)
        # Convert the experiences into a named tuple: either ReplayBufferSamples or PrioritizedReplayBufferSamples
if isinstance(self.replay_buffer, PushReplayBuffer):
batch = ReplayBufferSamples(*[states, actions, rewards, next_states, dones])
elif isinstance(self.replay_buffer, PrioritizedBuffer):
indices, weights = batch[5], batch[6]
batch = PrioritizedReplayBufferSamples(
*[states, actions, rewards, next_states, dones, indices, weights]
)
else:
raise NotImplementedError
return batch
    def get_q_loss(self, batch: collections.namedtuple) -> torch.Tensor:
        """Base function to calculate the loss of the Q-function or critic
Args:
batch (:obj:`collections.namedtuple` of :obj:`torch.Tensor`): Batch of experiences
Returns:
loss (:obj:`torch.Tensor`): Calculated loss of the Q-function
"""
q_values = self.get_q_values(batch.states, batch.actions)
target_q_values = self.get_target_q_values(
batch.next_states, batch.rewards, batch.dones
)
loss = F.mse_loss(q_values, target_q_values)
return loss
class OffPolicyAgentAC(OffPolicyAgent):
    """Off Policy Actor-Critic Agent Base Class
Attributes:
network (str): The network type of the Q-value function.
Supported types: ["cnn", "mlp"]
env (Environment): The environment that the agent is supposed to act on
create_model (bool): Whether the model of the algo should be created when initialised
batch_size (int): Mini batch size for loading experiences
gamma (float): The discount factor for rewards
layers (:obj:`tuple` of :obj:`int`): Layers in the Neural Network
of the Q-value function
lr_policy (float): Learning rate for the policy/actor
lr_value (float): Learning rate for the Q-value function
replay_size (int): Capacity of the Replay Buffer
        buffer_type (str): Choose the type of Buffer: ["push", "prioritized"]
        polyak (float): Interpolation factor for the soft (Polyak-averaged) target network update
seed (int): Seed for randomness
render (bool): Should the env be rendered during training?
device (str): Hardware being used for training. Options:
["cuda" -> GPU, "cpu" -> CPU]
"""
def __init__(self, *args, polyak=0.995, **kwargs):
super(OffPolicyAgentAC, self).__init__(*args, **kwargs)
self.polyak = polyak
self.doublecritic = False
    def select_action(
self, state: np.ndarray, deterministic: bool = True
) -> np.ndarray:
"""Select action given state
Deterministic Action Selection with Noise
Args:
state (:obj:`np.ndarray`): Current state of the environment
deterministic (bool): Should the policy be deterministic or stochastic
Returns:
action (:obj:`np.ndarray`): Action taken by the agent
"""
state = torch.as_tensor(state).float()
action, _ = self.ac.get_action(state, deterministic)
action = action.detach().cpu().numpy()
# add noise to output from policy network
if self.noise is not None:
action += self.noise()
return np.clip(
action, self.env.action_space.low[0], self.env.action_space.high[0]
)
    def update_target_model(self) -> None:
"""Function to update the target Q model
Updates the target model with the training model's weights when called
"""
for param, param_target in zip(
self.ac.parameters(), self.ac_target.parameters()
):
param_target.data.mul_(self.polyak)
param_target.data.add_((1 - self.polyak) * param.data)
    def get_q_values(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
"""Get Q values corresponding to specific states and actions
Args:
states (:obj:`torch.Tensor`): States for which Q-values need to be found
actions (:obj:`torch.Tensor`): Actions taken at respective states
Returns:
q_values (:obj:`torch.Tensor`): Q values for the given states and actions
"""
if self.doublecritic:
q_values = self.ac.get_value(
torch.cat([states, actions], dim=-1), mode="both"
)
else:
q_values = self.ac.get_value(torch.cat([states, actions], dim=-1))
return q_values
    def get_target_q_values(
self, next_states: torch.Tensor, rewards: List[float], dones: List[bool]
) -> torch.Tensor:
"""Get target Q values for the TD3
Args:
next_states (:obj:`torch.Tensor`): Next states for which target Q-values
need to be found
rewards (:obj:`list`): Rewards at each timestep for each environment
dones (:obj:`list`): Game over status for each environment
Returns:
            target_q_values (:obj:`torch.Tensor`): Target Q values
"""
next_target_actions = self.ac_target.get_action(next_states, True)[0]
if self.doublecritic:
next_q_target_values = self.ac_target.get_value(
torch.cat([next_states, next_target_actions], dim=-1), mode="min"
)
else:
next_q_target_values = self.ac_target.get_value(
torch.cat([next_states, next_target_actions], dim=-1)
)
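        # One-step TD target: r + gamma * (1 - done) * Q_target(s', pi_target(s'))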
target_q_values = rewards + self.gamma * (1 - dones) * next_q_target_values
return target_q_values
    def get_q_loss(self, batch: collections.namedtuple) -> torch.Tensor:
        """Actor-Critic function to calculate the loss of the Q-function or critic
Args:
batch (:obj:`collections.namedtuple` of :obj:`torch.Tensor`): Batch of experiences
Returns:
loss (:obj:`torch.Tensor`): Calculated loss of the Q-function
"""
q_values = self.get_q_values(batch.states, batch.actions)
target_q_values = self.get_target_q_values(
batch.next_states, batch.rewards, batch.dones
)
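        # With a double critic, both Q estimates are regressed to the same target
        # (computed in get_target_q_values from the minimum of the two target critics)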
if self.doublecritic:
loss = F.mse_loss(q_values[0], target_q_values) + F.mse_loss(
q_values[1], target_q_values
)
else:
loss = F.mse_loss(q_values, target_q_values)
return loss
    def get_p_loss(self, states: torch.Tensor) -> torch.Tensor:
"""Function to get the Policy loss
Args:
states (:obj:`torch.Tensor`): States for which Q-values need to be found
Returns:
loss (:obj:`torch.Tensor`): Calculated policy loss
"""
next_best_actions = self.ac.get_action(states, True)[0]
q_values = self.ac.get_value(torch.cat([states, next_best_actions], dim=-1))
policy_loss = -torch.mean(q_values)
return policy_loss
    def load_weights(self, weights) -> None:
"""Load weights for the agent from pretrained model
Args:
weights (:obj:`dict`): Dictionary of different neural net weights
"""
self.ac.load_state_dict(weights["weights"])
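

# Hypothetical usage sketch (not part of genrl): a minimal DDPG-style subclass showing how the
# methods above compose into a single update step. The optimizer attributes below
# (`optimizer_value`, `optimizer_policy`) are illustrative assumptions; concrete agents in the
# library define their own networks, noise processes and optimizers.
class _ExampleOffPolicyACAgent(OffPolicyAgentAC):
    def update_params(self, update_interval: int) -> None:
        """Run one critic step and one actor step per step of the update interval"""
        for _ in range(update_interval):
            batch = self.sample_from_buffer()

            # Critic step: regress Q(s, a) towards the one-step TD target
            value_loss = self.get_q_loss(batch)
            self.optimizer_value.zero_grad()
            value_loss.backward()
            self.optimizer_value.step()

            # Actor step: deterministic policy gradient on -Q(s, pi(s))
            policy_loss = self.get_p_loss(batch.states)
            self.optimizer_policy.zero_grad()
            policy_loss.backward()
            self.optimizer_policy.step()

            # Soft-update the target networks with Polyak averaging
            self.update_target_model()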