Dueling Deep Q-Network¶
Objective¶
The main objective of DQN is to learn a function approximator for the Q-function using a neural network. The approximator is trained to get as close as possible to the Bellman target for the Q-value function by minimising a loss on the difference between its prediction and that target.
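For reference, the loss in its commonly used squared-error form can be written as below. The notation is an assumption rather than something taken from this page: \(\gamma\) is the discount factor, \(\theta^{-}\) denotes the parameters of a periodically updated target network, and the expectation is over transitions \((s, a, r, s')\) sampled from the replay buffer.

\[
L(\theta) = \mathbb{E}_{(s, a, r, s')}\left[\left(r + \gamma \max_{a'} Q(s', a'; \theta^{-}) - Q(s, a; \theta)\right)^{2}\right]
\]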
The Dueling Deep Q-Network modifies the architecture of a simple DQN into one better suited for model-free RL.
Algorithm Details¶
Network architecture¶
The Dueling DQN architecture splits the single stream of fully connected layers in a normal DQN into two separate streams: one representing the value function and the other representing the advantage function.
The advantage of a state-action pair represents how beneficial it is to take that action over the others in a particular state. The dueling architecture can learn which states are or are not valuable without having to learn the effect of each action in each state. This is useful in states where the choice of action does not affect the environment in any significant way.
Another layer combines the value stream and the advantage stream to get the Q-values.
Combining the value and the advantage streams¶
- Value Function : \(V(s; \theta, \beta)\)
- Advantage Function : \(A(s, a; \theta, \alpha)\)
where \(\theta\) denotes the parameters of the underlying convolutional layers, whereas \(\alpha\) and \(\beta\) are the parameters of the two separate streams of fully connected layers.
The two streams cannot simply be added (\(Q(s, a; \theta, \alpha, \beta) = V(s; \theta, \beta) + A(s, a; \theta, \alpha)\)) to get the Q-values because:
- \(Q(s, a; \theta, \alpha, \beta)\) is only a parameterized estimate of the true Q-function
- It would be wrong to assume that \(V(s; \theta, \beta)\) and \(A(s, a; \theta, \alpha)\) are reasonable estimates of the value and the advantage functions
To address these concerns, the advantage estimates are constrained so that their expected value over actions is zero, which matches the fact that the expected advantage under the policy is always zero.
Thus, the combining module combines the value and advantage streams to get the Q-values in the following fashion:
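A common choice, and the one assumed in this sketch, is the mean-subtraction form from the original Dueling DQN paper:

\[
Q(s, a; \theta, \alpha, \beta) = V(s; \theta, \beta) + \left(A(s, a; \theta, \alpha) - \frac{1}{|\mathcal{A}|} \sum_{a'} A(s, a'; \theta, \alpha)\right)
\]

The snippet below illustrates this for a single state using plain Python lists instead of tensors. The helper name `combine_streams` and the example numbers are hypothetical and are not part of GenRL.

```python
from typing import List


def combine_streams(value: float, advantages: List[float]) -> List[float]:
    """Combine a scalar state value with per-action advantages into Q-values.

    Subtracting the mean advantage makes the centred advantages average to
    zero over the actions, so the mean of the Q-values equals the value.
    """
    mean_advantage = sum(advantages) / len(advantages)
    return [value + (adv - mean_advantage) for adv in advantages]


# Hypothetical outputs of the two streams for one state with three actions.
q_values = combine_streams(value=2.0, advantages=[1.0, 0.0, -1.0])
print(q_values)
```

Because the centred advantages sum to zero, adding a constant to every advantage leaves the Q-values unchanged, which is what makes the value stream identifiable.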
Epsilon-Greedy Action Selection¶
As in a normal DQN, action exploration is stochastic: the greedy action is chosen with a probability of \(1 - \epsilon\), and the rest of the time the action is sampled uniformly at random. During evaluation, only greedy actions are used to judge how well the agent performs.
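The selection rule above can be sketched in a few lines of standard-library Python. The function `select_action` and its arguments are hypothetical names chosen for illustration; GenRL's own implementation operates on tensors and may differ.

```python
import random
from typing import Optional, Sequence


def select_action(
    q_values: Sequence[float],
    epsilon: float,
    rng: Optional[random.Random] = None,
) -> int:
    """Return a random action with probability epsilon, else the greedy one."""
    if rng is None:
        rng = random.Random()
    if rng.random() < epsilon:
        # Explore: pick any action index uniformly at random.
        return rng.randrange(len(q_values))
    # Exploit: pick the index of the largest Q-value.
    return max(range(len(q_values)), key=lambda a: q_values[a])


# With epsilon = 0 the choice is always greedy, as during evaluation.
greedy_action = select_action([0.1, 0.9, 0.3], epsilon=0.0)
```

Passing a seeded `random.Random` instance as `rng` makes the exploration reproducible.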
Experience Replay¶
Every transition occurring during training is stored in a separate replay buffer.
def log(self, timestep: int) -> None:
    """Helper function to log

    Sends useful parameters to the logger.

    Args:
        timestep (int): Current timestep of training
    """
    self.logger.write(
        {
            "timestep": timestep,
            "Episode": self.episodes,
            **self.agent.get_logging_params(),
            "Episode Reward": safe_mean(self.training_rewards),
The transitions are later sampled in batches from the replay buffer to update the network.
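To make the store-then-sample pattern concrete, here is a minimal sketch of a fixed-capacity buffer with uniform sampling. The class below is an illustrative assumption built on the standard library and is not GenRL's buffer; only the `push` method name mirrors the training loop shown on this page.

```python
import random
from collections import deque
from typing import Deque, List, Tuple

# (state, action, reward, next_state, done); states are left untyped here.
Transition = Tuple[object, int, float, object, bool]


class ReplayBuffer:
    """Fixed-size store that discards the oldest transition when full."""

    def __init__(self, capacity: int) -> None:
        self.memory: Deque[Transition] = deque(maxlen=capacity)

    def push(self, transition: Transition) -> None:
        self.memory.append(transition)

    def sample(self, batch_size: int) -> List[Transition]:
        # Uniform sampling without replacement from the stored transitions.
        return random.sample(list(self.memory), batch_size)

    def __len__(self) -> int:
        return len(self.memory)


buffer = ReplayBuffer(capacity=3)
for step in range(5):
    buffer.push((step, 0, 1.0, step + 1, False))
batch = buffer.sample(batch_size=2)
```

Sampling uniformly from past transitions breaks the correlation between consecutive steps, which is the usual motivation for experience replay.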
Update the Q Network¶
Once enough transitions are stored in the replay buffer, we start updating the Q-values according to the given objective. The loss function is defined in the same way as for a DQN. This allows improvements in DQN training techniques, such as Double DQN or NoisyNet DQN, to be readily adapted to the dueling architecture.
for timestep in range(0, self.max_timesteps, self.env.n_envs):
    self.agent.update_params_before_select_action(timestep)

    action = self.get_action(state, timestep)
    next_state, reward, done, info = self.env.step(action)

    if self.render:
        self.env.render()

    # true_dones contains the "true" value of the dones (game over statuses). It is set
    # to False when the environment is not actually done but instead reaches the max
    # episode length.
    true_dones = [info[i]["done"] for i in range(self.env.n_envs)]
    self.buffer.push((state, action, reward, next_state, true_dones))

    state = next_state.detach().clone()

    if self.check_game_over_status(done):
        self.noise_reset()

        if self.episodes % self.log_interval == 0:
            self.log(timestep)

        if self.episodes == self.epochs:
            break

    if timestep >= self.start_update and timestep % self.update_interval == 0:
        self.agent.update_params(self.update_interval)

    if (
        timestep >= self.start_update
        and self.save_interval != 0
        and timestep % self.save_interval == 0
    ):
        self.save(timestep)

self.env.close()
self.logger.close()
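What an update step computes for a sampled batch can be sketched as follows, assuming the usual one-step TD target and a mean squared error loss. The function names, the `gamma` parameter, and the example numbers are hypothetical; `next_q_values` stands in for the Q-values a target network would produce for each next state, and none of this is GenRL's actual `update_params` code.

```python
from typing import List, Sequence


def td_targets(
    rewards: Sequence[float],
    next_q_values: Sequence[Sequence[float]],
    dones: Sequence[bool],
    gamma: float = 0.99,
) -> List[float]:
    """One-step targets: r + gamma * max_a' Q(s', a'), with no bootstrap at episode end."""
    targets = []
    for reward, next_q, done in zip(rewards, next_q_values, dones):
        bootstrap = 0.0 if done else gamma * max(next_q)
        targets.append(reward + bootstrap)
    return targets


def mse_loss(predicted: Sequence[float], targets: Sequence[float]) -> float:
    """Mean squared error between predicted Q(s, a) and the TD targets."""
    return sum((p - t) ** 2 for p, t in zip(predicted, targets)) / len(targets)


# Two hypothetical transitions; the second one ends its episode.
targets = td_targets(
    rewards=[1.0, 1.0],
    next_q_values=[[0.0, 2.0], [5.0, 1.0]],
    dones=[False, True],
    gamma=0.5,
)
loss = mse_loss(predicted=[1.0, 1.0], targets=targets)
```

Since the dueling architecture only changes how the Q-values are produced, the same target and loss computation applies unchanged.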
Training through the API¶
from genrl.agents import DuelingDQN
from genrl.environments import VectorEnv
from genrl.trainers import OffPolicyTrainer
env = VectorEnv("CartPole-v0")
agent = DuelingDQN("mlp", env)
trainer = OffPolicyTrainer(agent, env, max_timesteps=20000)
trainer.train()
trainer.evaluate()