import random
from collections import deque
from typing import List, NamedTuple, Optional, Tuple

import numpy as np
import torch
class ReplayBufferSamples(NamedTuple):
    """
    Named container for one batch of replay-buffer samples

    The field order is the same order in which ``ReplayBuffer.sample``
    returns its tensors, so a batch can be wrapped with
    ``ReplayBufferSamples(*buffer.sample(batch_size))``.
    """

    # One entry per sampled experience in every field.
    states: torch.Tensor
    actions: torch.Tensor
    rewards: torch.Tensor
    next_states: torch.Tensor
    dones: torch.Tensor
class PrioritizedReplayBufferSamples(NamedTuple):
    """
    Named container for one batch of prioritized replay-buffer samples

    The field order is the same order in which ``PrioritizedBuffer.sample``
    returns its tensors, so a batch can be wrapped with
    ``PrioritizedReplayBufferSamples(*buffer.sample(batch_size))``.
    """

    # One entry per sampled experience in every field.
    states: torch.Tensor
    actions: torch.Tensor
    rewards: torch.Tensor
    next_states: torch.Tensor
    dones: torch.Tensor
    # Buffer positions the samples were drawn from.
    indices: torch.Tensor
    # Importance-sampling weights of the samples.
    weights: torch.Tensor
class ReplayBuffer:
    """
    Implements the basic Experience Replay Mechanism

    Experiences are kept in a ``deque`` bounded by ``capacity``; once it is
    full, every new push discards the oldest stored experience.

    :param capacity: Size of the replay buffer
    :type capacity: int
    """

    def __init__(self, capacity: int):
        self.capacity = capacity
        self.memory = deque([], maxlen=capacity)

    def push(self, inp: Tuple) -> None:
        """
        Adds new experience to buffer

        :param inp: Tuple containing state, action, reward, next_state and done
        :type inp: tuple
        :returns: None
        """
        self.memory.append(inp)

    def sample(self, batch_size: int) -> List[torch.Tensor]:
        """
        Returns randomly sampled experiences from replay memory

        Experiences are drawn without replacement and each of the five
        components is stacked along a new leading batch axis.

        :param batch_size: Number of samples per batch
        :type batch_size: int
        :returns: (List composed of `state`, `action`, `reward`,
            `next_state` and `done`, each converted to a float tensor)
        :raises ValueError: If `batch_size` exceeds the number of stored
            experiences (raised by ``random.sample``)
        """
        batch = random.sample(self.memory, batch_size)
        # zip(*batch) regroups per-experience tuples into per-field tuples.
        state, action, reward, next_state, done = map(np.stack, zip(*batch))
        return [
            torch.from_numpy(v).float()
            for v in (state, action, reward, next_state, done)
        ]

    def __len__(self) -> int:
        """
        Gives number of experiences in buffer currently

        :returns: Length of replay memory
        """
        # This class defines no ``pos`` attribute, so the length has to come
        # from the underlying deque.
        return len(self.memory)
class PrioritizedBuffer:
    """
    Implements the Prioritized Experience Replay Mechanism

    Experiences and their priorities are kept in two parallel ``deque``
    objects bounded by ``capacity``; once full, every new push discards the
    oldest experience together with its priority.

    :param capacity: Size of the replay buffer
    :param alpha: Level of prioritization (exponent applied to priorities)
    :param beta: (Default bias exponent for the Importance Sampling (IS)
        weights, used when `sample` is called without `beta`)
    :type capacity: int
    :type alpha: float
    :type beta: float
    """

    def __init__(self, capacity: int, alpha: float = 0.6, beta: float = 0.4):
        self.alpha = alpha
        self.beta = beta
        self.capacity = capacity
        self.buffer = deque([], maxlen=capacity)
        self.priorities = deque([], maxlen=capacity)

    def push(self, inp: Tuple) -> None:
        """
        Adds new experience to buffer

        The new experience receives the largest priority currently stored
        (1.0 if the buffer is empty).

        :param inp: (Tuple containing `state`, `action`, `reward`,
            `next_state` and `done`)
        :type inp: tuple
        :returns: None
        """
        # Must be read before appending: a full deque evicts its oldest
        # priority on append.
        max_priority = max(self.priorities) if self.priorities else 1.0
        self.buffer.append(inp)
        self.priorities.append(max_priority)

    def sample(
        self, batch_size: int, beta: Optional[float] = None
    ) -> List[torch.Tensor]:
        """
        (Returns randomly sampled memories from replay memory along with their
        respective indices and weights)

        Index ``i`` is drawn (with replacement) with probability
        ``priority_i ** alpha / sum(priority ** alpha)``. The weights are
        ``(N * probability_i) ** -beta`` divided by their maximum over the
        batch, where ``N`` is the current buffer length.

        :param batch_size: Number of samples per batch
        :param beta: (Bias exponent used to correct
            Importance Sampling (IS) weights); defaults to the value given
            to the constructor
        :type batch_size: int
        :type beta: float
        :returns: (List containing `states`, `actions`, `rewards`,
            `next_states`, `dones`, `indices` and `weights`, all as float32
            tensors)
        :raises ValueError: If the buffer is empty
        """
        if beta is None:
            beta = self.beta

        total = len(self.buffer)
        if total == 0:
            raise ValueError("Cannot sample from an empty PrioritizedBuffer")

        # Force a float array so the in-place normalisation below is valid
        # even if integer priorities were stored.
        priorities = np.asarray(self.priorities, dtype=np.float64)
        probabilities = priorities ** self.alpha
        probabilities /= probabilities.sum()

        indices = np.random.choice(total, batch_size, p=probabilities)

        weights = (total * probabilities[indices]) ** (-beta)
        weights /= weights.max()
        weights = np.asarray(weights, dtype=np.float32)

        samples = [self.buffer[i] for i in indices]
        # zip(*samples) regroups per-experience tuples into per-field tuples.
        states, actions, rewards, next_states, dones = map(np.stack, zip(*samples))

        # NOTE: `indices` is returned as float32 like every other entry;
        # `update_priorities` casts each index back with int().
        return [
            torch.as_tensor(v, dtype=torch.float32)
            for v in (
                states,
                actions,
                rewards,
                next_states,
                dones,
                indices,
                weights,
            )
        ]

    def update_priorities(self, batch_indices: Tuple, batch_priorities: Tuple) -> None:
        """
        Updates list of priorities with new order of priorities

        Each priority may be a plain number or an object exposing ``mean()``
        (e.g. a NumPy array or a tensor), in which case its mean is used.

        :param batch_indices: List of indices of batch
        :param batch_priorities: (List of priorities of the batch at the
            specific indices)
        :type batch_indices: list or tuple
        :type batch_priorities: list or tuple
        """
        for idx, priority in zip(batch_indices, batch_priorities):
            value = priority.mean() if hasattr(priority, "mean") else priority
            # Store plain floats so `sample` can always build a NumPy array
            # from the priorities, whatever type the caller passed in.
            self.priorities[int(idx)] = float(value)

    def __len__(self) -> int:
        """
        Gives number of experiences in buffer currently

        :returns: Length of replay memory
        """
        return len(self.buffer)

    @property
    def pos(self):
        """
        Number of experiences currently stored (same value as ``len(self)``)
        """
        return len(self.buffer)