import random
from typing import Tuple, Union
import torch
[docs]class TransitionDB(object):
"""
Database for storing (context, action, reward) transitions.
Args:
device (str): Device to use for tensor operations.
"cpu" for cpu or "cuda" for cuda. Defaults to "cpu".
Attributes:
db (dict): Dictionary containing list of transitions.
db_size (int): Number of transitions stored in database.
device (torch.device): Device to use for tensor operations.
"""
def __init__(self, device: Union[str, torch.device] = "cpu"):
if type(device) is str:
self.device = (
torch.device(device)
if "cuda" in device and torch.cuda.is_available()
else torch.device("cpu")
)
else:
self.device = device
self.db = {"contexts": [], "actions": [], "rewards": []}
self.db_size = 0
[docs] def add(self, context: torch.Tensor, action: int, reward: int):
"""Add (context, action, reward) transition to database
Args:
context (torch.Tensor): Context recieved
action (int): Action taken
reward (int): Reward recieved
"""
self.db["contexts"].append(context)
self.db["actions"].append(action)
self.db["rewards"].append(reward)
self.db_size += 1
[docs] def get_data(
self, batch_size: Union[int, None] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Get a batch of transition from database
Args:
batch_size (Union[int, None], optional): Size of batch required.
Defaults to None which implies all transitions in the database
are to be included in batch.
Returns:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of stacked
contexts, actions, rewards tensors.
"""
if batch_size is None:
batch_size = self.db_size
else:
batch_size = min(batch_size, self.db_size)
idx = [random.randrange(self.db_size) for _ in range(batch_size)]
x = (
torch.stack([self.db["contexts"][i] for i in idx])
.to(self.device)
.to(torch.float)
)
y = (
torch.tensor([self.db["rewards"][i] for i in idx])
.to(self.device)
.to(torch.float)
.unsqueeze(1)
)
a = (
torch.stack([self.db["actions"][i] for i in idx])
.to(self.device)
.to(torch.long)
)
return x, a, y
[docs] def get_data_for_action(
self, action: int, batch_size: Union[int, None] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Get a batch of transition from database for a given action.
Args:
action (int): The action to sample transitions for.
batch_size (Union[int, None], optional): Size of batch required.
Defaults to None which implies all transitions in the database
are to be included in batch.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of stacked
contexts and rewards tensors.
"""
action_idx = [i for i in range(self.db_size) if self.db["actions"][i] == action]
if batch_size is None:
t_batch_size = len(action_idx)
else:
t_batch_size = min(batch_size, len(action_idx))
idx = random.sample(action_idx, t_batch_size)
x = (
torch.stack([self.db["contexts"][i] for i in idx])
.to(self.device)
.to(torch.float)
)
y = (
torch.tensor([self.db["rewards"][i] for i in idx])
.to(self.device)
.to(torch.float)
.unsqueeze(1)
)
return x, y