Source code for genrl.agents.classical.sarsa.sarsa

from typing import Tuple

import gym
import numpy as np


class SARSA:
    """SARSA algorithm with eligibility traces (SARSA(lambda)).

    Paper: http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.17.2539&rep=rep1&type=pdf

    Attributes:
        env (gym.Env): Environment with which the agent interacts.
        epsilon (float, optional): Probability of acting greedily when exploring;
            a random action is sampled with probability ``1 - epsilon``.
        lmbda (float, optional): Decay parameter (lambda) for the eligibility traces.
        gamma (float, optional): Discount factor.
        lr (float, optional): Learning rate.
    """

    def __init__(
        self,
        env: gym.Env,
        epsilon: float = 0.9,
        lmbda: float = 0.9,
        gamma: float = 0.95,
        lr: float = 0.01,
    ):
        self.env = env
        self.epsilon = epsilon
        self.lmbda = lmbda
        self.gamma = gamma
        self.lr = lr

        # Q-table and eligibility traces, one entry per (state, action) pair
        self.Q = np.zeros((self.env.observation_space.n, self.env.action_space.n))
        self.e = np.zeros((self.env.observation_space.n, self.env.action_space.n))
    def get_action(self, state: np.ndarray, explore: bool = True) -> np.ndarray:
        """Epsilon-greedy action selection.

        Args:
            state (np.ndarray): Environment state.
            explore (bool, optional): True if exploration is required, False if not.

        Returns:
            np.ndarray: Selected action.
        """
        if explore:
            # With probability (1 - epsilon) take a random exploratory action
            if np.random.uniform() > self.epsilon:
                return self.env.action_space.sample()
        # Otherwise act greedily with respect to the current Q-table
        return np.argmax(self.Q[state, :])
    def update(self, transition: Tuple) -> None:
        """Update the Q-table and eligibility traces.

        Args:
            transition (Tuple): Transition 4-tuple used to update the Q-table,
                in the form (state, action, reward, next_state).
        """
        state, action, reward, next_state = transition

        # On-policy TD error: delta = r + gamma * Q(s', a') - Q(s, a)
        next_action = self.get_action(next_state)
        delta = (
            reward
            + self.gamma * self.Q[next_state, next_action]
            - self.Q[state, action]
        )

        # Accumulating eligibility trace for the visited (state, action) pair
        self.e[state, action] += 1

        # SARSA(lambda): update every (state, action) pair in proportion to its
        # eligibility trace, then decay the traces by gamma * lambda
        for si in range(self.env.observation_space.n):
            for ai in range(self.env.action_space.n):
                self.Q[si, ai] += self.lr * delta * self.e[si, ai]
                self.e[si, ai] *= self.gamma * self.lmbda
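
For context, here is a minimal training-loop sketch showing how this class might be used. It assumes the older Gym API (``env.step`` returning a 4-tuple) and a discrete-space environment such as ``FrozenLake-v0``; the environment id, episode count, and the per-episode trace reset are illustrative choices, not part of the module above.

# Minimal usage sketch (not part of the module above); names and
# hyperparameters here are illustrative assumptions.
import gym

from genrl.agents.classical.sarsa.sarsa import SARSA

env = gym.make("FrozenLake-v0")
agent = SARSA(env, epsilon=0.9, lmbda=0.9, gamma=0.95, lr=0.01)

for episode in range(1000):
    state = env.reset()
    # Optionally reset eligibility traces between episodes
    # (the class does not do this itself).
    agent.e[:] = 0.0
    done = False
    action = agent.get_action(state)
    while not done:
        next_state, reward, done, _ = env.step(action)
        # Each transition updates the Q-table and eligibility traces
        agent.update((state, action, reward, next_state))
        state = next_state
        action = agent.get_action(state)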