Deep Q-Networks (DQN)

For background on Deep RL, its core definitions and problem formulations, refer to Deep RL Background.

Objective

DQN builds on the concept of Q-learning. When the state space is very large, tabular Q-learning requires a huge number of episodes to explore and update the Q-value of every state even once. Hence, we make use of function approximators. DQN uses a neural network as its function approximator, and the objective is to get as close to the Bellman expectation of the Q-value function as possible. This is done by minimising the loss function, which is defined as

\[L_i(\theta_i) = E_{(s, a, s', r) \sim D}\left[\left(r + \gamma \max_{a'} Q(s', a'; \theta_{i}^{-}) - Q(s, a; \theta_i)\right)^2\right]\]

Unlike regular Q-learning, DQNs need more stability while updating, so we use a second neural network, called the target network (its parameters \(\theta_i^{-}\) appear in the loss above), which is synchronised with the online network only periodically.
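
As a minimal sketch of this idea (assuming PyTorch; q_net, target_net and sync_target are hypothetical names, not GenRL's API), the target network is just a periodically refreshed copy of the online network:

import copy

import torch.nn as nn

# Hypothetical online Q-network; GenRL builds something similar internally for the "mlp" agent.
q_net = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 2))

# The target network starts as an exact copy of the online network.
target_net = copy.deepcopy(q_net)

def sync_target(step: int, sync_interval: int = 1000) -> None:
    # Copy the online weights into the target network every sync_interval steps,
    # so the bootstrap targets stay fixed between syncs.
    if step % sync_interval == 0:
        target_net.load_state_dict(q_net.state_dict())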

Algorithm Details

Epsilon-Greedy Action Selection

We choose the greedy action with a probability of \(1 - \epsilon\) and the rest of the time, we sample the action randomly. During evaluation, we use only greedy actions to judge how well the agent performs.
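
As a rough sketch (not GenRL's actual implementation; select_action, q_net and n_actions are assumed names), epsilon-greedy selection can be written as:

import random

import torch

def select_action(q_net, state, epsilon: float, n_actions: int) -> int:
    # With probability epsilon, explore by sampling a random action.
    if random.random() < epsilon:
        return random.randrange(n_actions)
    # Otherwise act greedily with respect to the current Q-value estimates.
    with torch.no_grad():
        q_values = q_net(state)
    return int(q_values.argmax(dim=-1).item())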

Experience Replay

Whenever an experience is played through (during the training loop), the experience is stored in what we call a Replay Buffer. The trainer also periodically logs episode statistics through a small helper, shown below.

    def log(self, timestep: int) -> None:
        """Helper function to log

        Sends useful parameters to the logger.

        Args:
            timestep (int): Current timestep of training
        """
        self.logger.write(
            {
                "timestep": timestep,
                "Episode": self.episodes,
                **self.agent.get_logging_params(),
                "Episode Reward": safe_mean(self.training_rewards),

The transitions are later sampled in batches from the replay buffer for updating the network.
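
Conceptually, a replay buffer is just a bounded container of transitions with push and sample operations; the following is an illustrative sketch (SimpleReplayBuffer is a hypothetical class, not the ReplayBuffer used by GenRL):

import random
from collections import deque

class SimpleReplayBuffer:
    def __init__(self, capacity: int):
        # Old transitions are discarded automatically once capacity is reached.
        self.memory = deque(maxlen=capacity)

    def push(self, transition) -> None:
        # transition is a (state, action, reward, next_state, done) tuple.
        self.memory.append(transition)

    def sample(self, batch_size: int):
        # Uniformly sample a batch of past transitions for the update step.
        return random.sample(self.memory, batch_size)

    def __len__(self) -> int:
        return len(self.memory)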

Update Q-value Network

Once our Replay Buffer has enough experiences, we start updating the Q-value network according to the above objective, as shown in the following code.


        for timestep in range(0, self.max_timesteps, self.env.n_envs):
            self.agent.update_params_before_select_action(timestep)

            action = self.get_action(state, timestep)
            next_state, reward, done, info = self.env.step(action)

            if self.render:
                self.env.render()

            # true_dones contains the "true" value of the dones (game over statuses). It is set
            # to False when the environment is not actually done but instead reaches the max
            # episode length.
            true_dones = [info[i]["done"] for i in range(self.env.n_envs)]
            self.buffer.push((state, action, reward, next_state, true_dones))

            state = next_state.detach().clone()

            if self.check_game_over_status(done):
                self.noise_reset()

                if self.episodes % self.log_interval == 0:
                    self.log(timestep)

                if self.episodes == self.epochs:
                    break

            if timestep >= self.start_update and timestep % self.update_interval == 0:
                self.agent.update_params(self.update_interval)

            if (
                timestep >= self.start_update
                and self.save_interval != 0
                and timestep % self.save_interval == 0
            ):
                self.save(timestep)

        self.env.close()
        self.logger.close()

The get_q_values function calculates the Q-values of the states sampled from the replay buffer. The get_target_q_values function computes the target Q-values, using the target network on the corresponding next states. The update_params function calculates the MSE loss between the Q-values and the target Q-values, and the network parameters are updated using stochastic gradient descent.
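
For illustration, a minimal update along those lines could look like the sketch below (assuming PyTorch; dqn_update is a hypothetical helper, and agent.model / agent.target_model are assumed attribute names, so the real GenRL signatures may differ):

import torch
import torch.nn.functional as F

def dqn_update(agent, batch, optimizer, gamma: float = 0.99) -> float:
    states, actions, rewards, next_states, dones = batch

    # Q(s, a) from the online network for the actions actually taken.
    q_values = agent.model(states).gather(1, actions.long().unsqueeze(-1)).squeeze(-1)

    # Target: r + gamma * max_a' Q(s', a'; theta^-), with no bootstrap on terminal states.
    with torch.no_grad():
        next_q = agent.target_model(next_states).max(dim=-1).values
        target_q_values = rewards + gamma * next_q * (1 - dones.float())

    # MSE loss between the online Q-values and the targets, minimised by gradient descent.
    loss = F.mse_loss(q_values, target_q_values)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()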

Training through the API

from genrl.agents import DQN
from genrl.environments import VectorEnv
from genrl.trainers import OffPolicyTrainer

env = VectorEnv("CartPole-v0")
agent = DQN("mlp", env)
trainer = OffPolicyTrainer(agent, env, max_timesteps=20000)
trainer.train()
trainer.evaluate()

Variants of DQN

Some of the other variants of DQN that we have implemented in the repo are:

- Double DQN
- Dueling DQN
- Prioritized Replay DQN
- Noisy DQN
- Categorical DQN

For some extensions of the DQN (like DoubleDQN), we have provided the required methods in a file under genrl/agents/dqn/utils.py.

class DuelingDQN(DQN):
    def __init__(self, *args, **kwargs):
        super(DuelingDQN, self).__init__(*args, **kwargs)
        self.dqn_type = "dueling"  # You can choose "noisy" for NoisyDQN and "categorical" for CategoricalDQN
        self._create_model()

class DoubleDQN(DQN):
    def __init__(self, *args, **kwargs):
        super(DoubleDQN, self).__init__(*args, **kwargs)
        self._create_model()

    def get_target_q_values(self, *args):
        return ddqn_q_target(self, *args)

The above two snippets show how the variants are built on top of the base DQN class. You can find similar APIs for the other variants in the genrl/deep/agents/dqn folder.
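
Training a variant through the API is analogous to the base DQN; for example (assuming DoubleDQN is exported from genrl.agents in the same way as DQN):

from genrl.agents import DoubleDQN
from genrl.environments import VectorEnv
from genrl.trainers import OffPolicyTrainer

env = VectorEnv("CartPole-v0")
agent = DoubleDQN("mlp", env)
trainer = OffPolicyTrainer(agent, env, max_timesteps=20000)
trainer.train()
trainer.evaluate()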