Source code for genrl.utils.models

from typing import Tuple

import numpy as np


class TabularModel:
    """
    Sample-based tabular model class for deterministic, discrete environments

    For every (state, action) pair passed to :meth:`add`, the model stores the
    most recently observed next state and reward in two ``(s_dim, a_dim)``
    tables. A separate boolean table records which pairs have been added, so
    that a stored next state of ``0`` or a reward of ``0`` is not mistaken
    for "never visited".

    :param s_dim: environment state dimension
    :param a_dim: environment action dimension
    :type s_dim: int
    :type a_dim: int
    """

    def __init__(self, s_dim: int, a_dim: int):
        self.s_dim = s_dim
        self.a_dim = a_dim
        # int64 (not uint8) so next-state indices above 255 are stored intact.
        self.s_model = np.zeros((s_dim, a_dim), dtype=np.int64)
        self.r_model = np.zeros((s_dim, a_dim))
        # True at [state, action] once that pair has been added.
        self._visited = np.zeros((s_dim, a_dim), dtype=bool)

    def add(
        self,
        state: np.ndarray,
        action: np.ndarray,
        reward: float,
        next_state: np.ndarray,
    ) -> None:
        """
        add transition to model

        ``state`` and ``action`` are used as indices into the tables; a
        later call with the same pair overwrites the earlier entry.

        :param state: state
        :param action: action
        :param reward: reward
        :param next_state: next state
        :type state: int
        :type action: int
        :type reward: float
        :type next_state: int
        """
        self.s_model[state, action] = next_state
        self.r_model[state, action] = reward
        self._visited[state, action] = True

    def sample(self) -> Tuple:
        """
        sample state action pair from model

        A state is drawn uniformly from the states with at least one added
        action, then an action is drawn uniformly from the actions added for
        that state.

        :returns: state and action
        :rtype: int; int
        :raises ValueError: if no transition has been added yet
        """
        # select random visited state
        visited_states = np.flatnonzero(self._visited.any(axis=1))
        if visited_states.size == 0:
            raise ValueError("cannot sample from an empty model")
        state = np.random.choice(visited_states)
        # random action in that state
        action = np.random.choice(np.flatnonzero(self._visited[state]))
        return state, action

    def step(self, state: np.ndarray, action: np.ndarray) -> Tuple:
        """
        return consequence of action at state

        Reads the tables directly; for a pair that was never added this
        returns the zero-initialised entries.

        :param state: state
        :param action: action
        :type state: int
        :type action: int
        :returns: reward and next state
        :rtype: float; int
        """
        reward = self.r_model[state, action]
        next_state = self.s_model[state, action]
        return reward, next_state

    def is_empty(self) -> bool:
        """
        Check if the model has been updated or not

        :returns: True if model not updated yet
        :rtype: bool
        """
        return not bool(self._visited.any())
# Maps a model name (as accepted by get_model_from_name) to its model class.
model_registry = {"tabular": TabularModel}
def get_model_from_name(name_: str):
    """
    get model object from name

    Looks ``name_`` up in ``model_registry`` and returns the registered
    model class (not an instance).

    :param name_: name of the model ['tabular']
    :type name_: str
    :returns: the model
    :raises NotImplementedError: if no model is registered under ``name_``
    """
    if name_ in model_registry:
        return model_registry[name_]
    # Raise rather than return the exception class, so an unknown name fails
    # here instead of being handed back to the caller as if it were a model.
    raise NotImplementedError(f"Model '{name_}' is not implemented")