Double DQN – Breakout (ALE/Breakout-v5)

Trained on ALE/Breakout-v5 using Double DQN (van Hasselt et al., 2016) with 4 parallel environments.

Algorithm: Double DQN

```
# Vanilla DQN (tends to overestimate Q-values):
next_q = q_target(s').max()

# Double DQN (this model):
best_a = q_online(s').argmax()      # online network selects the action
next_q = q_target(s')[best_a]       # target network evaluates it
target = r + γ * (1 - done) * next_q
```
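The pseudocode above can be written as a batched PyTorch sketch. This assumes `q_online` and `q_target` map a batch of states to `(B, 4)` Q-values and that `dones` is a float tensor of 0/1 flags. The function names are hypothetical and not taken from the training code. The vanilla target is included only for comparison.

```python
import torch
import torch.nn as nn


def double_dqn_target(q_online: nn.Module, q_target: nn.Module,
                      rewards: torch.Tensor, next_states: torch.Tensor,
                      dones: torch.Tensor, gamma: float = 0.99) -> torch.Tensor:
    """Double DQN target for a batch of B transitions; returns shape (B,)."""
    with torch.no_grad():
        # Online network picks the greedy next action per sample -> (B, 1)
        best_a = q_online(next_states).argmax(dim=1, keepdim=True)
        # Target network scores that action -> (B,)
        next_q = q_target(next_states).gather(1, best_a).squeeze(1)
        # dones is 1.0 for terminal transitions, which removes the bootstrap term
        return rewards + gamma * (1.0 - dones) * next_q


def vanilla_dqn_target(q_target: nn.Module, rewards: torch.Tensor,
                       next_states: torch.Tensor, dones: torch.Tensor,
                       gamma: float = 0.99) -> torch.Tensor:
    """Vanilla DQN target: the target network both selects and evaluates."""
    with torch.no_grad():
        next_q = q_target(next_states).max(dim=1).values
        return rewards + gamma * (1.0 - dones) * next_q
```

Consider a non-terminal transition with r = 1 and γ = 0.5, where the online Q-values are [5, 1] and the target Q-values are [2, 4]. The Double DQN target is 1 + 0.5 · 2 = 2.0, because the online network picks action 0. The vanilla target is 1 + 0.5 · 4 = 3.0, because it takes the target network's own maximum.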

Preprocessing

| Step | Detail |
|---|---|
| Grayscale | cv2.cvtColor, RGB → gray |
| Resize | 210×160 → 84×84 (INTER_AREA) |
| Frame stacking | 4 consecutive frames → state of shape (4, 84, 84) |
| Normalization | Stored as uint8 in the replay buffer; divided by 255 at sample time |
| Reward clipping | Clipped to [−1, +1] during training |
| Frameskip | 4 (ALE built-in) |
| Fire reset | FIRE pressed after reset and after every life loss |

Network Architecture

```
Input:  (4, 84, 84)
Conv2d(4→32,  kernel=8, stride=4) → ReLU
Conv2d(32→64, kernel=4, stride=2) → ReLU
Conv2d(64→64, kernel=3, stride=1) → ReLU
Flatten → 3136
Linear(3136 → 512) → ReLU
Linear(512  → 4)
Output: Q(s,a) for [NOOP, FIRE, RIGHT, LEFT]
```
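The 3136 at the Flatten layer follows from the output-size formula for an unpadded convolution, floor((n − k) / s) + 1. Applying it to the three layers shows where the number comes from:

```python
def conv_out(size: int, kernel: int, stride: int) -> int:
    """Spatial output size of an unpadded convolution: floor((size - kernel) / stride) + 1."""
    return (size - kernel) // stride + 1


size = 84
for kernel, stride in [(8, 4), (4, 2), (3, 1)]:
    size = conv_out(size, kernel, stride)
    print(size)           # 20, then 9, then 7

flat = 64 * size * size   # 64 channels × 7 × 7
print(flat)               # 3136
```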

Training Setup

| Parameter | Value |
|---|---|
| Environment | ALE/Breakout-v5 |
| Total steps | 750,000 |
| Parallel envs | 4 (AsyncVectorEnv) |
| Replay buffer | 100,000 (uint8 storage, 4× smaller than float32) |
| Batch size | 64 |
| Learning rate | 1e-4 (Adam) |
| Discount γ | 0.99 |
| Target sync | Hard update every 1,000 steps |
| Gradient clipping | Max norm 10 |
| Epsilon | 1.0 → 0.01 over the first 5% of steps (37,500 steps) |
| Learning starts | After 5,000 steps |
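A minimal sketch of how these settings could fit into one update. It combines the Double DQN target, gradient clipping at max norm 10, and a hard target sync every 1,000 steps. Several details are assumptions rather than facts from this card:

- the Huber (smooth L1) loss;
- the linear shape of the epsilon decay;
- the reading of "steps" as the same counter used for target syncing.

`epsilon` and `train_step` are hypothetical names.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

TOTAL_STEPS = 750_000
EPS_START, EPS_END = 1.0, 0.01
EPS_DECAY_STEPS = TOTAL_STEPS * 5 // 100  # first 5% of steps = 37,500


def epsilon(step: int) -> float:
    """Decay from 1.0 to 0.01 over EPS_DECAY_STEPS, then hold (linear shape assumed)."""
    frac = min(step / EPS_DECAY_STEPS, 1.0)
    return EPS_START + frac * (EPS_END - EPS_START)


def train_step(online: nn.Module, target: nn.Module, optimizer: torch.optim.Optimizer,
               batch, step: int, gamma: float = 0.99,
               target_sync: int = 1_000, max_grad_norm: float = 10.0) -> float:
    """One gradient update on a sampled batch; returns the loss value."""
    states, actions, rewards, next_states, dones = batch
    # Q(s, a) for the actions actually taken -> (B,)
    q_sa = online(states).gather(1, actions.unsqueeze(1)).squeeze(1)
    with torch.no_grad():
        # Double DQN: online network selects, target network evaluates
        best_a = online(next_states).argmax(dim=1, keepdim=True)
        next_q = target(next_states).gather(1, best_a).squeeze(1)
        y = rewards + gamma * (1.0 - dones) * next_q
    loss = F.smooth_l1_loss(q_sa, y)  # assumption: the card does not state the loss
    optimizer.zero_grad()
    loss.backward()
    nn.utils.clip_grad_norm_(online.parameters(), max_grad_norm)
    optimizer.step()
    # Hard update: copy all online weights into the target network
    if step % target_sync == 0:
        target.load_state_dict(online.state_dict())
    return loss.item()
```

With the table's values, the optimizer would be `torch.optim.Adam(online.parameters(), lr=1e-4)`. Each sampled batch would hold 64 transitions, and updates would begin only after the first 5,000 steps.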

Usage

```python
import torch
import torch.nn as nn


class DQN(nn.Module):
    """Q-network: (N, 4, 84, 84) float32 input in [0, 1] -> (N, n_actions) Q-values."""

    def __init__(self, n_actions=4):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(4, 32, kernel_size=8, stride=4), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(),
            nn.Flatten(),
            nn.Linear(3136, 512), nn.ReLU(),
            nn.Linear(512, n_actions),
        )

    def forward(self, x):
        return self.model(x)


# map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only machine
checkpoint = torch.load('best_breakout.pt', map_location='cpu')
model = DQN()
model.load_state_dict(checkpoint['model'])  # network weights are stored under the 'model' key
model.eval()
```
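The network was trained on frames divided by 255, so inputs at inference should be scaled the same way. Below is a minimal greedy-action helper. It assumes a (4, 84, 84) uint8 frame stack built with the preprocessing described above. `greedy_action` is a hypothetical name and not part of the released code.

```python
import numpy as np
import torch
import torch.nn as nn


def greedy_action(model: nn.Module, state: np.ndarray) -> int:
    """Map a (4, 84, 84) uint8 state to an index into [NOOP, FIRE, RIGHT, LEFT]."""
    # Add a batch dimension and scale to [0, 1], matching the ÷255 used in training
    x = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0) / 255.0
    with torch.no_grad():
        q_values = model(x)  # (1, n_actions)
    return int(q_values.argmax(dim=1).item())
```

Calling `greedy_action(model, state)` with the `model` loaded in the snippet above returns the action with the highest predicted Q-value.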