Prioritized Deep Q-Networks¶
Objective¶
The main motivation for using prioritized experience replay instead of uniformly sampled experience replay is that an agent can learn more from some transitions than from others. With uniform sampling, transitions that are redundant or carry little learning signal are replayed just as often as those with more learning potential. Prioritized experience replay addresses this by replaying the more useful transitions more frequently.
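The loss function for prioritized DQN is the squared TD error weighted by the importance sampling weights \(w_i\) introduced below. In standard DQN notation, where \(\theta\) denotes the online network parameters, \(\theta^{-}\) the target network parameters, \(\gamma\) the discount factor and \(P\) the sampling distribution over the replay buffer, it can be written as
\[L(\theta) = E_{i \sim P}\left[ w_i \, \delta_i^2 \right], \qquad \delta_i = r_i + \gamma \max_{a'} Q(s'_i, a'; \theta^{-}) - Q(s_i, a_i; \theta)\]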
Algorithm Details¶
Epsilon-Greedy Action Selection¶
Exploration is stochastic: the greedy action is chosen with probability \(1 - \epsilon\), and the rest of the time the action is sampled at random. During evaluation, only greedy actions are used to judge how well the agent performs.
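The selection rule above can be sketched in a few lines of plain Python. This is only an illustration: the function name `select_action` and the list-of-floats representation of Q-values are assumptions made here, not part of the GenRL API.

```python
import random


def select_action(q_values, epsilon, rng=random):
    """Return an action index using epsilon-greedy exploration.

    q_values: list of estimated Q-values, one entry per action.
    epsilon:  probability of taking a random action instead of the greedy one.
    """
    if rng.random() < epsilon:
        # Explore: pick any action uniformly at random.
        return rng.randrange(len(q_values))
    # Exploit: pick the action with the highest estimated Q-value.
    return max(range(len(q_values)), key=lambda a: q_values[a])


q = [0.1, 0.7, 0.2]
print(select_action(q, epsilon=0.1))  # training: mostly greedy, sometimes random
print(select_action(q, epsilon=0.0))  # evaluation: always the greedy action
```

Setting `epsilon=0.0` reproduces the evaluation behaviour described above, since the random branch is then never taken.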
Prioritized Experience Replay¶
The replay buffer is no longer sampled uniformly; instead, transitions are sampled according to their priority. Transitions with more scope for learning are assigned higher priorities. The priority of a transition is derived from its TD error \(\delta_i\), since the magnitude of the TD error can be interpreted as how unexpected the transition is.
The transition with the largest TD error is given the highest priority. Every new transition is also given the highest priority, to ensure that each transition is replayed at least once.
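A minimal sketch of this bookkeeping, assuming a fixed-capacity buffer stored as plain Python lists. The class `PriorityStore`, its method names, the default priority of 1.0 for an empty buffer and the small constant `eps` are hypothetical choices for illustration and are not taken from GenRL's implementation.

```python
class PriorityStore:
    """Toy circular buffer that tracks one priority per stored transition."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.transitions = []
        self.priorities = []
        self.position = 0  # index that the next overwrite will use

    def push(self, transition):
        # New transitions receive the current maximum priority so that
        # they are likely to be replayed at least once.
        max_priority = max(self.priorities, default=1.0)
        if len(self.transitions) < self.capacity:
            self.transitions.append(transition)
            self.priorities.append(max_priority)
        else:
            self.transitions[self.position] = transition
            self.priorities[self.position] = max_priority
        self.position = (self.position + 1) % self.capacity

    def update_priorities(self, indices, td_errors, eps=1e-5):
        # After a replay, the priority follows the magnitude of the TD error.
        for i, delta in zip(indices, td_errors):
            self.priorities[i] = abs(delta) + eps


store = PriorityStore(capacity=2)
store.push("t0")
store.update_priorities([0], [3.0])
store.push("t1")  # inherits the maximum priority seen so far
print(store.priorities)
```

A production buffer would typically keep priorities in a structure that supports fast sampling, such as a sum tree; the lists here only keep the example short.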
Stochastic Prioritization¶
Sampling transitions greedily has several disadvantages: transitions with a low TD error on their first replay might never be sampled again; the risk of overfitting is higher, since only a small set of high-priority transitions is replayed over and over; and the scheme is sensitive to noise spikes. To tackle these problems, instead of always sampling greedily, we use a stochastic approach in which each transition is assigned a probability of being sampled. The sampling probability is defined as
\[P(i) = \frac{p_i^{\alpha}}{\sum_k p_k^{\alpha}}\]
where \(p_i > 0\) is the priority of transition \(i\) and \(\alpha\) determines how much prioritization is used, with \(\alpha = 0\) corresponding to uniform sampling. The priority of a transition can be defined in either of the following two ways:
- \(p_i = |\delta_i| + \epsilon\)
- \(p_i = \frac{1}{rank(i)}\)
where \(\epsilon\) is a small positive constant that ensures the sampling probability is not zero for any transition, and \(rank(i)\) is the rank of transition \(i\) when the replay buffer is sorted by \(|\delta_i|\) in descending order. This \(\epsilon\) is unrelated to the exploration rate used in epsilon-greedy action selection.
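The two priority definitions and the resulting sampling distribution can be sketched as follows, assuming the usual normalisation \(P(i) = p_i^{\alpha} / \sum_k p_k^{\alpha}\). The function names, the value `eps=1e-5` and the example TD errors are made up for illustration.

```python
import random


def proportional_priorities(td_errors, eps=1e-5):
    """p_i = |delta_i| + eps."""
    return [abs(d) + eps for d in td_errors]


def rank_based_priorities(td_errors):
    """p_i = 1 / rank(i), where rank 1 is the largest |delta_i|."""
    order = sorted(range(len(td_errors)),
                   key=lambda i: abs(td_errors[i]), reverse=True)
    priorities = [0.0] * len(td_errors)
    for rank, i in enumerate(order, start=1):
        priorities[i] = 1.0 / rank
    return priorities


def sampling_probabilities(priorities, alpha):
    """Normalise p_i ** alpha so that the probabilities sum to one."""
    scaled = [p ** alpha for p in priorities]
    total = sum(scaled)
    return [s / total for s in scaled]


td_errors = [0.5, -2.0, 0.1, 1.0]
probs = sampling_probabilities(proportional_priorities(td_errors), alpha=0.6)
# Draw a batch of two indices (with replacement) according to the probabilities.
batch_indices = random.choices(range(len(td_errors)), weights=probs, k=2)
print(probs, batch_indices)
```

With `alpha=0` every scaled priority becomes 1, so the distribution collapses to uniform sampling; larger values of `alpha` concentrate sampling on high-error transitions.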
We also use importance sampling (IS) weights to correct the bias that prioritized experience replay introduces.
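One common way to compute these weights, shown here as a sketch rather than as GenRL's exact implementation, is \(w_i = (N \cdot P(i))^{-\beta}\), scaled by the largest weight in the batch so that updates are only ever scaled down. The exponent `beta` and the per-batch normalisation are assumptions that do not appear elsewhere on this page.

```python
def importance_weights(probs, batch_indices, beta):
    """Return normalised IS weights for the sampled transitions.

    probs:         sampling probability P(i) of every transition in the buffer.
    batch_indices: indices of the transitions drawn for this batch.
    beta:          strength of the correction (0 = none, 1 = full).
    """
    n = len(probs)
    weights = [(n * probs[i]) ** (-beta) for i in batch_indices]
    max_weight = max(weights)
    # Normalise so that the largest weight in the batch equals 1.
    return [w / max_weight for w in weights]


probs = [0.1, 0.2, 0.3, 0.4]
print(importance_weights(probs, batch_indices=[0, 3], beta=1.0))
```

Transitions that are sampled more often than uniform sampling would allow receive smaller weights, which offsets their over-representation in the batch.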
Update the Q-value Networks¶
The importance sampling weights can be folded into the Q-learning update by using \(w_i\delta_i\) instead of \(\delta_i\). Once the replay buffer has enough experiences, the training loop below starts updating the Q-value networks according to the above objective.
for timestep in range(0, self.max_timesteps, self.env.n_envs):
    self.agent.update_params_before_select_action(timestep)

    action = self.get_action(state, timestep)
    next_state, reward, done, info = self.env.step(action)

    if self.render:
        self.env.render()

    # true_dones contains the "true" value of the dones (game over statuses). It is set
    # to False when the environment is not actually done but instead reaches the max
    # episode length.
    true_dones = [info[i]["done"] for i in range(self.env.n_envs)]
    self.buffer.push((state, action, reward, next_state, true_dones))

    state = next_state.detach().clone()

    if self.check_game_over_status(done):
        self.noise_reset()

        if self.episodes % self.log_interval == 0:
            self.log(timestep)

        if self.episodes == self.epochs:
            break

    if timestep >= self.start_update and timestep % self.update_interval == 0:
        self.agent.update_params(self.update_interval)

    if (
        timestep >= self.start_update
        and self.save_interval != 0
        and timestep % self.save_interval == 0
    ):
        self.save(timestep)

self.env.close()
self.logger.close()
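The weighted update itself happens inside `update_params`, which is not shown in the loop. As a rough sketch of what folding the weights into the loss can look like, the hypothetical helper below uses plain Python floats and assumes a mean of \(w_i \delta_i^2\) over the batch; an actual agent would compute this on tensors and backpropagate through the Q-values.

```python
def weighted_td_loss(q_values, targets, weights):
    """Return the IS-weighted squared TD loss and the per-sample TD errors.

    q_values: Q(s_i, a_i) predicted by the online network for each sample.
    targets:  bootstrapped targets, treated as constants.
    weights:  importance sampling weights w_i for the same samples.
    """
    td_errors = [t - q for q, t in zip(q_values, targets)]
    loss = sum(w * d * d for w, d in zip(weights, td_errors)) / len(td_errors)
    return loss, td_errors


loss, td_errors = weighted_td_loss(
    q_values=[1.0, 2.0], targets=[2.0, 2.0], weights=[0.5, 1.0]
)
print(loss, td_errors)
```

The gradient of this loss with respect to each predicted Q-value is proportional to \(w_i \delta_i\), which is the substitution described above. The returned TD errors are the values that would be used to refresh the priorities of the sampled transitions.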
Training through the API¶
from genrl.agents import PrioritizedReplayDQN
from genrl.environments import VectorEnv
from genrl.trainers import OffPolicyTrainer
env = VectorEnv("CartPole-v0")
agent = PrioritizedReplayDQN("mlp", env)
trainer = OffPolicyTrainer(agent, env, max_timesteps=20000)
trainer.train()
trainer.evaluate()