Deep Q-Networks (DQN)¶
For background on Deep RL, its core definitions, and problem formulations, refer to Deep RL Background.
Objective¶
The DQN builds on the concept of Q-learning. When the state space is very large, it requires a huge number of epochs to explore and update the Q-value of every state even once. Hence, we make use of function approximators. DQN uses a neural network as a function approximator, and the objective is to get as close to the Bellman expectation of the Q-value function as possible. This is done by minimising the loss function, defined as

\(L(\theta) = \mathbb{E}\left[\left(r + \gamma \max_{a'} Q(s', a'; \theta^{-}) - Q(s, a; \theta)\right)^2\right]\)

where \(\theta\) are the parameters of the Q-network and \(\theta^{-}\) those of the target network.
Unlike regular Q-learning, DQN needs more stability while updating, so we often use a second neural network, which we call the target model.
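The target model is kept frozen during updates and periodically synchronised with the online network. As a minimal sketch (parameters shown as plain lists of floats; the function name is illustrative, not GenRL's API):

```python
def sync_target(model_params, target_params, tau=1.0):
    """Move target-network parameters toward the online network's.

    tau = 1.0 is a hard copy (classic DQN); tau < 1.0 gives a
    Polyak/soft update, blending old target weights with new ones.
    """
    return [tau * p + (1.0 - tau) * tp for p, tp in zip(model_params, target_params)]
```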
Algorithm Details¶
Epsilon-Greedy Action Selection¶
We choose the greedy action with a probability of \(1 - \epsilon\) and the rest of the time, we sample the action randomly. During evaluation, we use only greedy actions to judge how well the agent performs.
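The selection rule above can be sketched as follows (an illustrative standalone function, not GenRL's internal implementation):

```python
import random

def select_action(q_values, epsilon):
    """Epsilon-greedy selection over a list of Q-values.

    With probability epsilon, pick a random action (exploration);
    otherwise pick the argmax action (exploitation). Setting
    epsilon = 0 gives the purely greedy policy used at evaluation time.
    """
    if random.random() < epsilon:
        return random.randrange(len(q_values))
    return max(range(len(q_values)), key=lambda a: q_values[a])
```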
Experience Replay¶
Whenever an experience is played through (during the training loop), the experience is stored in what we call a Replay Buffer.
def log(self, timestep: int) -> None:
    """Helper function to log

    Sends useful parameters to the logger.

    Args:
        timestep (int): Current timestep of training
    """
    self.logger.write(
        {
            "timestep": timestep,
            "Episode": self.episodes,
            **self.agent.get_logging_params(),
            "Episode Reward": safe_mean(self.training_rewards),
        }
    )
The transitions are later sampled in batches from the replay buffer for updating the network.
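A minimal replay buffer with this push/sample behaviour can be sketched as below (illustrative only; GenRL's buffer lives elsewhere in the library and differs in detail):

```python
import random
from collections import deque

class ReplayBuffer:
    """Minimal FIFO replay buffer."""

    def __init__(self, capacity):
        # deque with maxlen evicts the oldest transition once full
        self.memory = deque(maxlen=capacity)

    def push(self, transition):
        # transition = (state, action, reward, next_state, done)
        self.memory.append(transition)

    def sample(self, batch_size):
        # uniform sampling decorrelates consecutive environment steps
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)
```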
Update Q-value Network¶
Once our Replay Buffer has enough experiences, we start updating the Q-value network according to the above objective, as in the following code.
for timestep in range(0, self.max_timesteps, self.env.n_envs):
    self.agent.update_params_before_select_action(timestep)
    action = self.get_action(state, timestep)
    next_state, reward, done, info = self.env.step(action)

    if self.render:
        self.env.render()

    # true_dones contains the "true" value of the dones (game over statuses). It is set
    # to False when the environment is not actually done but instead reaches the max
    # episode length.
    true_dones = [info[i]["done"] for i in range(self.env.n_envs)]
    self.buffer.push((state, action, reward, next_state, true_dones))

    state = next_state.detach().clone()

    if self.check_game_over_status(done):
        self.noise_reset()
        if self.episodes % self.log_interval == 0:
            self.log(timestep)
        if self.episodes == self.epochs:
            break

    if timestep >= self.start_update and timestep % self.update_interval == 0:
        self.agent.update_params(self.update_interval)

    if (
        timestep >= self.start_update
        and self.save_interval != 0
        and timestep % self.save_interval == 0
    ):
        self.save(timestep)

self.env.close()
self.logger.close()
The get_q_values function calculates the Q-values of the states sampled from the replay buffer. The get_target_q_values function gets the Q-values of the same states as calculated by the target network. The update_params function calculates the MSE loss between the Q-values and the target Q-values, and updates the network using Stochastic Gradient Descent.
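The computation inside that update can be sketched in plain Python (in practice PyTorch tensors and autograd perform the gradient step; the function names here are illustrative, not GenRL's):

```python
def td_targets(rewards, next_q_max, dones, gamma=0.99):
    """Bellman targets r + gamma * max_a' Q_target(s', a').

    The bootstrap term is zeroed at terminal states, where there is
    no future return to estimate.
    """
    return [
        r + gamma * nq * (0.0 if d else 1.0)
        for r, nq, d in zip(rewards, next_q_max, dones)
    ]

def mse_loss(q_taken, targets):
    """Mean squared error between predicted Q-values and TD targets."""
    return sum((q - t) ** 2 for q, t in zip(q_taken, targets)) / len(targets)
```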
Training through the API¶
from genrl.agents import DQN
from genrl.environments import VectorEnv
from genrl.trainers import OffPolicyTrainer
env = VectorEnv("CartPole-v0")
agent = DQN("mlp", env)
trainer = OffPolicyTrainer(agent, env, max_timesteps=20000)
trainer.train()
trainer.evaluate()
Variants of DQN¶
Some of the other variants of DQN that we have implemented in the repo are:
- Double DQN
- Dueling DQN
- Prioritized Replay DQN
- Noisy DQN
- Categorical DQN
For some extensions of the DQN (like DoubleDQN) we have provided the methods in a file under genrl/agents/dqn/utils.py
class DuelingDQN(DQN):
    def __init__(self, *args, **kwargs):
        super(DuelingDQN, self).__init__(*args, **kwargs)
        self.dqn_type = "dueling"  # You can choose "noisy" for NoisyDQN and "categorical" for CategoricalDQN
        self._create_model()

    def get_target_q_values(self, *args):
        return ddqn_q_target(self, *args)
The above snippet defines the DuelingDQN class. You can find similar APIs for the other variants in the genrl/deep/agents/dqn folder.
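The Double DQN target referenced by ddqn_q_target decouples action selection from action evaluation: the online network picks the best next action, while the target network scores it, which reduces the overestimation bias of vanilla DQN. A sketch of that computation (illustrative, using plain lists of per-action Q-values rather than GenRL's actual signature):

```python
def double_dqn_targets(q_online_next, q_target_next, rewards, dones, gamma=0.99):
    """Double DQN targets for a batch of transitions.

    q_online_next / q_target_next: per-transition lists of Q-values over
    actions in the next state, from the online and target networks.
    """
    targets = []
    for q_on, q_tg, r, d in zip(q_online_next, q_target_next, rewards, dones):
        # online network selects the action...
        a_star = max(range(len(q_on)), key=lambda a: q_on[a])
        # ...target network evaluates it (zero bootstrap at terminal states)
        bootstrap = 0.0 if d else gamma * q_tg[a_star]
        targets.append(r + bootstrap)
    return targets
```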