Prioritized Deep Q-Networks

Objective

The main motivation behind using prioritized experience replay over uniformly sampled experience replay stems from the fact that an agent may be able to learn more from some transitions than others. In uniformly sampled experience replay, some transitions which might not be very useful for the agent or that might be redundant will be replayed with the same frequency as those having more learning potential. Prioritized experience replay solves this problem by replaying more useful transitions more frequently.

The loss function for prioritized DQN is defined as

\[E_{(s, a, s', r, p) \sim D}\left[\left(r + \gamma \max_{a'} Q(s', a';\theta_{i}^{-}) - Q(s, a; \theta_i)\right)^2\right]\]
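
A minimal PyTorch sketch of computing this loss for a sampled batch is shown below; the function and tensor names are assumptions for illustration, not GenRL code.

import torch
import torch.nn.functional as F

def dqn_loss(q_net, target_net, states, actions, rewards, next_states, dones, gamma=0.99):
    # Q(s, a; theta_i) for the actions actually taken
    q_values = q_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)
    # r + gamma * max_a' Q(s', a'; theta_i^-) from the frozen target network
    with torch.no_grad():
        target_q = rewards + gamma * target_net(next_states).max(dim=1).values * (1 - dones)
    # squared TD error averaged over the sampled batch
    return F.mse_loss(q_values, target_q)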

Algorithm Details

Epsilon-Greedy Action Selection

Action exploration is stochastic: the greedy action is chosen with probability \(1 - \epsilon\), and a random action is sampled the rest of the time. During evaluation, we use only greedy actions to judge how well the agent performs.
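
A minimal sketch of this selection rule (the helper name and network interface are assumptions, not GenRL's API):

import random
import torch

def select_action(q_net, state, epsilon, n_actions):
    # Explore with probability epsilon, otherwise act greedily w.r.t. the Q-network
    if random.random() < epsilon:
        return random.randrange(n_actions)
    with torch.no_grad():
        return int(q_net(state).argmax(dim=-1).item())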

Prioritized Experience Replay

The replay buffer is no longer sampled uniformly; instead, transitions are sampled according to their priority. Transitions with greater learning potential are assigned higher priorities. The priority of a transition is derived from its TD-error, since the magnitude of the TD-error can be interpreted as how unexpected the transition is.

\[\delta = R + \gamma \max_{a'} Q(s', a';\theta_{i}^{-}) - Q(s, a; \theta_i)\]

The transition with the maximum TD-error is given the maximum priority. Every new transition is given the highest priority to ensure that each transition is considered at least once.
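
One common way to implement this rule is sketched below with an assumed list-backed buffer, rather than the sum-tree structure a real implementation would typically use for efficiency:

class SimplePrioritizedBuffer:
    def __init__(self, capacity):
        self.capacity = capacity
        self.transitions, self.priorities = [], []

    def push(self, transition):
        # New transitions get the current maximum priority so that each one
        # is replayed at least once before its priority is updated.
        max_priority = max(self.priorities, default=1.0)
        if len(self.transitions) >= self.capacity:
            self.transitions.pop(0)
            self.priorities.pop(0)
        self.transitions.append(transition)
        self.priorities.append(max_priority)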

Stochastic Prioritization

Sampling transitions greedily has some disadvantages: transitions with a low TD-error on their first replay may never be sampled again; there is a higher chance of overfitting, since only a small set of high-priority transitions is replayed over and over; and the approach is sensitive to noise spikes. To tackle these problems, instead of sampling transitions greedily every time, we use a stochastic approach in which each transition is assigned a probability with which it is sampled. The sampling probability is defined as

\[P(i) = \frac{p_i^{\alpha}}{\sum_k p_k^{\alpha}}\]

where \(p_i > 0\) is the priority of transition \(i\) and \(\alpha\) determines how much prioritization is applied. The priority of a transition can be defined in either of the following two ways (see the sketch after the list):

  • \(p_i = |\delta_i| + \epsilon\)
  • \(p_i = \frac{1}{rank(i)}\)

where \(\epsilon\) is a small positive constant to ensure that the sampling probability is not zero for any transition and \(rank(i)\) is the rank of the transition when the replay buffer is sorted with respect to priorities.
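For the proportional variant, the sampling distribution over a batch of TD-errors can be computed as below; this is a sketch, and the default values for \(\alpha\) and \(\epsilon\) are assumptions rather than GenRL's defaults.

import numpy as np

def sampling_probabilities(td_errors, alpha=0.6, eps=1e-6):
    # p_i = |delta_i| + eps  (proportional prioritization)
    priorities = np.abs(td_errors) + eps
    # P(i) = p_i^alpha / sum_k p_k^alpha
    scaled = priorities ** alpha
    return scaled / scaled.sum()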

We also use importance sampling (IS) weights to correct the bias introduced by prioritized experience replay.

\[w_i = \left(\frac{1}{N} \cdot \frac{1}{P(i)}\right)^{\beta}\]
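
A sketch of how these weights might be computed for a sampled batch; normalizing by the largest weight, as in the original prioritized replay paper, keeps update magnitudes bounded, and the default \(\beta\) here is an assumption.

import numpy as np

def importance_weights(probabilities, beta=0.4):
    n = len(probabilities)
    # w_i = (1/N * 1/P(i))^beta
    weights = (1.0 / (n * probabilities)) ** beta
    # Normalize by the maximum weight so updates only ever scale down.
    return weights / weights.max()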

Update the Q-value Networks

The importance sampling weights can be folded into the Q-learning update by using \(w_i \delta_i\) instead of \(\delta_i\). Once the replay buffer has enough experiences, we start updating the Q-value networks in the following training loop according to the above objective.

        for timestep in range(0, self.max_timesteps, self.env.n_envs):
            self.agent.update_params_before_select_action(timestep)

            action = self.get_action(state, timestep)
            next_state, reward, done, info = self.env.step(action)

            if self.render:
                self.env.render()

            # true_dones contains the "true" value of the dones (game over statuses). It is set
            # to False when the environment is not actually done but instead reaches the max
            # episode length.
            true_dones = [info[i]["done"] for i in range(self.env.n_envs)]
            self.buffer.push((state, action, reward, next_state, true_dones))

            state = next_state.detach().clone()

            if self.check_game_over_status(done):
                self.noise_reset()

                if self.episodes % self.log_interval == 0:
                    self.log(timestep)

                if self.episodes == self.epochs:
                    break

            if timestep >= self.start_update and timestep % self.update_interval == 0:
                self.agent.update_params(self.update_interval)

            if (
                timestep >= self.start_update
                and self.save_interval != 0
                and timestep % self.save_interval == 0
            ):
                self.save(timestep)

        self.env.close()
        self.logger.close()
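
As a rough sketch of the weighted update step itself (not the library's update_params code; every name below, including buffer.update_priorities, is an assumption):

# target_q comes from the target network (no gradient); q_values from the online network
td_errors = target_q - q_values                      # delta_i for each sampled transition
loss = (weights * td_errors.pow(2)).mean()           # IS-weighted squared TD error
optimizer.zero_grad()
loss.backward()
optimizer.step()
# Priorities are refreshed with the new absolute TD errors plus a small constant.
buffer.update_priorities(indices, td_errors.abs().detach().cpu().numpy() + 1e-6)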

Training through the API

from genrl.agents import PrioritizedReplayDQN
from genrl.environments import VectorEnv
from genrl.trainers import OffPolicyTrainer

env = VectorEnv("CartPole-v0")
agent = PrioritizedReplayDQN("mlp", env)
trainer = OffPolicyTrainer(agent, env, max_timesteps=20000)
trainer.train()
trainer.evaluate()