DQN

Figure: DQN algorithm pseudocode [1]

toyrl.dqn.simple_config module-attribute

simple_config = DqnConfig(env_name='CartPole-v1', render_mode=None, solved_threshold=475.0, max_training_steps=500000, learning_rate=0.00025, use_target_network=True, target_soft_update_beta=0.0, target_update_frequency=5, log_wandb=True)

toyrl.dqn.trainer module-attribute

toyrl.dqn.DqnConfig dataclass

DqnConfig(env_name: str = 'CartPole-v1', render_mode: str | None = None, solved_threshold: float = 475.0, gamma: float = 0.999, replay_buffer_capacity: int = 10000, max_training_steps: int = 500000, learning_starts: int = 10000, policy_update_frequency: int = 10, batches_per_training_step: int = 16, batch_size: int = 128, updates_per_batch: int = 1, learning_rate: float = 0.01, use_target_network: bool = False, target_update_frequency: int = 10, target_soft_update_beta: float = 0.0, log_wandb: bool = False)

Configuration for the DQN algorithm.

env_name class-attribute instance-attribute

env_name: str = 'CartPole-v1'

render_mode class-attribute instance-attribute

render_mode: str | None = None

solved_threshold class-attribute instance-attribute

solved_threshold: float = 475.0

gamma class-attribute instance-attribute

gamma: float = 0.999

The discount factor for future rewards.

replay_buffer_capacity class-attribute instance-attribute

replay_buffer_capacity: int = 10000

The maximum capacity of the experience replay buffer.

max_training_steps class-attribute instance-attribute

max_training_steps: int = 500000

The maximum number of environment steps to train for.

learning_starts class-attribute instance-attribute

learning_starts: int = 10000

The number of steps to collect before starting learning.

policy_update_frequency class-attribute instance-attribute

policy_update_frequency: int = 10

How often to update the policy network (in environment steps).

batches_per_training_step class-attribute instance-attribute

batches_per_training_step: int = 16

The number of experience batches to sample in each training step.

batch_size class-attribute instance-attribute

batch_size: int = 128

The size of each training batch.

updates_per_batch class-attribute instance-attribute

updates_per_batch: int = 1

The number of optimization steps to perform on each batch.

learning_rate class-attribute instance-attribute

learning_rate: float = 0.01

The learning rate for the optimizer.

use_target_network class-attribute instance-attribute

use_target_network: bool = False

Whether to use a separate target network (Double DQN when True).

target_update_frequency class-attribute instance-attribute

target_update_frequency: int = 10

How often to update the target network (in environment steps).

target_soft_update_beta class-attribute instance-attribute

target_soft_update_beta: float = 0.0

The Polyak soft-update coefficient for the target network (0.0 means a hard update that copies the policy network weights directly).

log_wandb class-attribute instance-attribute

log_wandb: bool = False

Whether to log the training process to Weights & Biases.
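
As with any dataclass, a configuration can be built by overriding just the fields of interest; the values below are illustrative, not tuned hyperparameters.

from toyrl.dqn import DqnConfig

config = DqnConfig(
    env_name="CartPole-v1",
    learning_rate=2.5e-4,
    use_target_network=True,      # enables the Double DQN target in policy_update
    target_update_frequency=5,
    target_soft_update_beta=0.0,  # 0.0 -> hard copy of the policy network weights
    log_wandb=False,
)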

toyrl.dqn.PolicyNet

PolicyNet(env_dim: int, action_num: int)

Bases: Module

Source code in toyrl/dqn.py
def __init__(
    self,
    env_dim: int,
    action_num: int,
) -> None:
    super().__init__()
    self.env_dim = env_dim
    self.action_num = action_num

    layers = [
        nn.Linear(self.env_dim, 128),
        nn.ReLU(),
        nn.Linear(128, 64),
        nn.ReLU(),
        nn.Linear(64, self.action_num),
    ]
    self.model = nn.Sequential(*layers)

env_dim instance-attribute

env_dim = env_dim

action_num instance-attribute

action_num = action_num

model instance-attribute

model = Sequential(*layers)

forward

forward(x: Tensor) -> Tensor
Source code in toyrl/dqn.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    return self.model(x)
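
A minimal usage sketch: the network maps an observation vector to one Q-value per action (CartPole-v1 has 4 observation dimensions and 2 actions; the input below is an illustrative dummy observation).

import torch

from toyrl.dqn import PolicyNet

net = PolicyNet(env_dim=4, action_num=2)     # CartPole-v1 dimensions
obs = torch.zeros(4)                         # dummy observation for illustration
q_values = net(obs)                          # shape (2,): one Q-value per action
greedy_action = int(torch.argmax(q_values))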

toyrl.dqn.Experience dataclass

Experience(terminated: bool, truncated: bool, observation: Any, action: Any, reward: float, next_observation: Any)

terminated instance-attribute

terminated: bool

truncated instance-attribute

truncated: bool

observation instance-attribute

observation: Any

action instance-attribute

action: Any

reward instance-attribute

reward: float

next_observation instance-attribute

next_observation: Any
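
An Experience simply packages one transition as returned by the environment; a minimal sketch, assuming the Gymnasium-style reset/step API used by the trainer below:

import gymnasium as gym

from toyrl.dqn import Experience

env = gym.make("CartPole-v1")
observation, _ = env.reset()
action = env.action_space.sample()  # illustrative random action
next_observation, reward, terminated, truncated, _ = env.step(action)
experience = Experience(
    observation=observation,
    action=action,
    reward=float(reward),
    next_observation=next_observation,
    terminated=terminated,
    truncated=truncated,
)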

toyrl.dqn.ReplayBuffer dataclass

ReplayBuffer(replay_buffer_size: int = 10000, buffer: list[Experience] = list(), _head_pointer: int = 0)

replay_buffer_size class-attribute instance-attribute

replay_buffer_size: int = 10000

buffer class-attribute instance-attribute

buffer: list[Experience] = field(default_factory=list)

__len__

__len__() -> int
Source code in toyrl/dqn.py
def __len__(self) -> int:
    return len(self.buffer)

add_experience

add_experience(experience: Experience) -> None
Source code in toyrl/dqn.py
def add_experience(self, experience: Experience) -> None:
    if len(self.buffer) < self.replay_buffer_size:
        # Buffer not full yet, append new experience
        self.buffer.append(experience)
    else:
        # Buffer full, overwrite oldest experience
        index = self._head_pointer % self.replay_buffer_size
        self.buffer[index] = experience

    # Increment pointer
    self._head_pointer += 1

reset

reset() -> None
Source code in toyrl/dqn.py
def reset(self) -> None:
    self.buffer = []
    self._head_pointer = 0

sample

sample(batch_size: int) -> list[Experience]
Source code in toyrl/dqn.py
def sample(self, batch_size: int) -> list[Experience]:
    return random.sample(self.buffer, min(batch_size, len(self.buffer)))
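
The buffer acts as a fixed-size ring: once replay_buffer_size experiences have been added, _head_pointer wraps around and the oldest entries are overwritten, while sample draws a uniform batch without replacement. A minimal sketch with a deliberately tiny capacity:

from toyrl.dqn import Experience, ReplayBuffer

buffer = ReplayBuffer(replay_buffer_size=3)  # tiny capacity for illustration
for i in range(5):
    buffer.add_experience(
        Experience(
            observation=[float(i)],
            action=0,
            reward=1.0,
            next_observation=[float(i + 1)],
            terminated=False,
            truncated=False,
        )
    )
assert len(buffer) == 3   # experiences 0 and 1 were overwritten
batch = buffer.sample(2)  # 2 experiences drawn uniformly at random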

toyrl.dqn.Agent

Agent(policy_net: PolicyNet, target_net: PolicyNet | None, optimizer: Optimizer, replay_buffer_size: int)
Source code in toyrl/dqn.py
def __init__(
    self,
    policy_net: PolicyNet,
    target_net: PolicyNet | None,
    optimizer: torch.optim.Optimizer,
    replay_buffer_size: int,
) -> None:
    self._policy_net = policy_net
    self._target_net = target_net
    self._optimizer = optimizer
    self._replay_buffer = ReplayBuffer(replay_buffer_size)
    self._action_num = policy_net.action_num

add_experience

add_experience(experience: Experience) -> None
Source code in toyrl/dqn.py
def add_experience(self, experience: Experience) -> None:
    self._replay_buffer.add_experience(experience)

act

act(observation: floating, tau: float) -> tuple[int, float]
Source code in toyrl/dqn.py
def act(self, observation: np.floating, tau: float) -> tuple[int, float]:
    x = torch.from_numpy(observation.astype(np.float32))
    with torch.no_grad():
        logits = self._policy_net(x)
    next_action = torch.distributions.Categorical(logits=logits / tau).sample().item()
    q_value = logits[next_action].item()
    return next_action, q_value
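
act samples from a Boltzmann (softmax) policy over the predicted Q-values: action probabilities are proportional to exp(Q(s, a) / tau), so a large tau explores almost uniformly while tau close to 0 approaches greedy selection. A small numeric sketch of the same rule, with illustrative Q-values:

import torch

q_values = torch.tensor([1.0, 2.0])  # illustrative Q-values for two actions
for tau in (5.0, 0.1):
    probs = torch.softmax(q_values / tau, dim=0)
    # tau=5.0 -> probs ≈ [0.45, 0.55]  (near-uniform, exploratory)
    # tau=0.1 -> probs ≈ [0.00, 1.00]  (near-greedy)
    action = torch.distributions.Categorical(probs=probs).sample()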

sample

sample(batch_size: int) -> list[Experience]
Source code in toyrl/dqn.py
def sample(self, batch_size: int) -> list[Experience]:
    return self._replay_buffer.sample(batch_size)

policy_update

policy_update(gamma: float, experiences: list[Experience]) -> float
Source code in toyrl/dqn.py
def policy_update(self, gamma: float, experiences: list[Experience]) -> float:
    observations = torch.tensor([experience.observation for experience in experiences])
    actions = torch.tensor([experience.action for experience in experiences], dtype=torch.float32)
    next_observations = torch.tensor([experience.next_observation for experience in experiences])
    rewards = torch.tensor([experience.reward for experience in experiences])
    terminated = torch.tensor(
        [experience.terminated for experience in experiences],
        dtype=torch.float32,
    )

    # q preds
    action_q_preds = self._policy_net(observations).gather(1, actions.long().unsqueeze(1)).squeeze(1)

    with torch.no_grad():
        next_action_logits = self._policy_net(next_observations)
        next_actions = torch.argmax(next_action_logits, dim=1)
        if self._target_net is None:  # Vanilla DQN
            next_action_q_preds = torch.gather(next_action_logits, 1, next_actions.unsqueeze(1)).squeeze(1)
        else:  # Double DQN
            next_action_q_preds = (
                self._target_net(next_observations).gather(1, next_actions.unsqueeze(1)).squeeze(1)
            )
    action_q_targets = rewards + gamma * (1 - terminated) * next_action_q_preds
    loss = torch.nn.functional.mse_loss(action_q_preds, action_q_targets)
    # update
    self._optimizer.zero_grad()
    loss.backward()
    self._optimizer.step()
    return loss.item()
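
The regression target is the one-step TD target r + gamma * (1 - terminated) * Q_next, where Q_next is max_a Q(s', a) for vanilla DQN and Q_target(s', argmax_a Q(s', a)) when a target network is used (Double DQN). A minimal numeric sketch of the target computation with illustrative values:

import torch

gamma = 0.999
rewards = torch.tensor([1.0, 1.0])
terminated = torch.tensor([0.0, 1.0])
next_q = torch.tensor([10.0, 10.0])  # illustrative bootstrap values

targets = rewards + gamma * (1 - terminated) * next_q
# targets ≈ [10.99, 1.00]: terminal transitions bootstrap to the reward alone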

polyak_update

polyak_update(beta: float) -> None
Source code in toyrl/dqn.py
def polyak_update(self, beta: float) -> None:
    if self._target_net is not None:
        for target_param, param in zip(self._target_net.parameters(), self._policy_net.parameters()):
            target_param.data.copy_(beta * target_param.data + (1 - beta) * param.data)
    else:
        raise ValueError("Target net is not set.")
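
Each call interpolates the target parameters toward the policy parameters, target = beta * target + (1 - beta) * policy, so beta = 0.0 reproduces a hard copy and a beta close to 1.0 yields a slowly tracking target. A one-parameter numeric sketch:

beta = 0.9
target_weight, policy_weight = 1.0, 2.0
updated = beta * target_weight + (1 - beta) * policy_weight
# updated = 1.1 with beta=0.9; with beta=0.0 it would be exactly 2.0 (hard copy)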

toyrl.dqn.DqnTrainer

DqnTrainer(config: DqnConfig)
Source code in toyrl/dqn.py
def __init__(self, config: DqnConfig) -> None:
    self.config = config
    self.env = self._make_env()
    if isinstance(self.env.action_space, gym.spaces.Discrete) is False:
        raise ValueError("Only discrete action space is supported.")
    env_dim = self.env.observation_space.shape[0]
    action_num = self.env.action_space.n

    policy_net = PolicyNet(env_dim=env_dim, action_num=action_num)
    optimizer = optim.Adam(policy_net.parameters(), lr=config.learning_rate)
    if config.use_target_network:
        target_net = PolicyNet(env_dim=env_dim, action_num=action_num)
        target_net.load_state_dict(policy_net.state_dict())
    else:
        target_net = None
    self.agent = Agent(
        policy_net=policy_net,
        target_net=target_net,
        optimizer=optimizer,
        replay_buffer_size=config.replay_buffer_capacity,
    )

    self.gamma = config.gamma
    self.solved_threshold = config.solved_threshold
    if config.log_wandb:
        wandb.init(
            # set the wandb project where this run will be logged
            project=self._get_dqn_name(),
            name=f"[{config.env_name}],lr={config.learning_rate}",
            # track hyperparameters and run metadata
            config=asdict(config),
        )
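
Only environments with a discrete action space are supported; the constructor raises a ValueError otherwise. A quick pre-flight check, assuming Gymnasium is imported as gym (matching the reset/step signatures used in train below):

import gymnasium as gym

env = gym.make("CartPole-v1")
assert isinstance(env.action_space, gym.spaces.Discrete)  # required by DqnTrainer
env_dim = env.observation_space.shape[0]                  # 4 for CartPole-v1
action_num = env.action_space.n                           # 2 for CartPole-v1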

config instance-attribute

config = config

env instance-attribute

env = _make_env()

agent instance-attribute

agent = Agent(policy_net=policy_net, target_net=target_net, optimizer=optimizer, replay_buffer_size=replay_buffer_capacity)

gamma instance-attribute

gamma = gamma

solved_threshold instance-attribute

solved_threshold = solved_threshold

train

train() -> None
Source code in toyrl/dqn.py
def train(self) -> None:
    tau = 5.0
    global_step = 0
    observation, _ = self.env.reset()
    while global_step < self.config.max_training_steps:
        global_step += 1
        # decay tau
        tau = max(0.1, tau * 0.995)

        action, q_value = self.agent.act(observation, tau)
        if self.config.log_wandb:
            wandb.log({"global_step": global_step, "q_value": q_value})

        next_observation, reward, terminated, truncated, info = self.env.step(action)
        experience = Experience(
            observation=observation,
            action=action,
            reward=float(reward),
            next_observation=next_observation,
            terminated=terminated,
            truncated=truncated,
        )
        self.agent.add_experience(experience)
        observation = next_observation

        if terminated or truncated:
            if info and "episode" in info:
                reward = info["episode"]["r"]
                print(f"global_step={global_step}, episodic_return={reward}")
                if self.config.log_wandb:
                    wandb.log(
                        {
                            "global_step": global_step,
                            "episode_reward": reward,
                        }
                    )

        if self.env.render_mode is not None:
            self.env.render()

        if global_step >= self.config.learning_starts and global_step % self.config.policy_update_frequency == 0:
            loss = self._train_step()
            if self.config.log_wandb:
                wandb.log(
                    {
                        "global_step": global_step,
                        "loss": loss,
                    }
                )
        # update target net
        if self.config.use_target_network and global_step % self.config.target_update_frequency == 0:
            self.agent.polyak_update(beta=self.config.target_soft_update_beta)
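
Putting it together, training with the module's example configuration takes only a few lines; the sketch below additionally disables Weights & Biases logging via dataclasses.replace, which is optional.

from dataclasses import replace

from toyrl.dqn import DqnTrainer, simple_config

config = replace(simple_config, log_wandb=False)  # drop replace to keep W&B logging
trainer = DqnTrainer(config)
trainer.train()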

  1. L. Graesser and W. L. Keng, Foundations of Deep Reinforcement Learning: Theory and Practice in Python. Addison-Wesley Professional, 2019.