Q-Learning & Deep Q-Networks
Q-Learning is one of the most important RL algorithms. It learns the Q-function (action-value function) directly, without needing a model of the environment.
Q-Table Learning
For environments with small, discrete state and action spaces, we can represent Q(s, a) as a table:
| State | Left | Right |
|---|---|---|
| s0 | 0.5 | 1.2 |
| s1 | 2.1 | 0.3 |
| s2 | 0.0 | 3.5 |
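Reading a greedy policy out of such a table is just a row-wise argmax; a quick sketch using the values above:

```python
import numpy as np

# Q-table from above: rows are states s0..s2, columns are actions [Left, Right]
q_table = np.array([
    [0.5, 1.2],  # s0
    [2.1, 0.3],  # s1
    [0.0, 3.5],  # s2
])

# Greedy policy: in each state, pick the action with the highest Q-value
greedy_actions = np.argmax(q_table, axis=1)
print(greedy_actions)  # 0 = Left, 1 = Right
```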
The Q-Learning Update Rule
After taking action *a* in state *s*, observing reward *r* and next state *s'*:
**Q(s, a) <- Q(s, a) + alpha * [r + gamma * max_a' Q(s', a') - Q(s, a)]**
Where:
- *alpha* is the learning rate (how far each update moves the estimate)
- *gamma* is the discount factor weighting future rewards
- *r* is the reward observed after taking *a* in *s*
- *max_a' Q(s', a')* is the current estimate of the best achievable value from the next state *s'*
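Worked through numerically with the table values above (assuming alpha = 0.1 and gamma = 0.99): taking Right in s0 earns r = 1 and lands in s1.

```python
import numpy as np

q_table = np.array([
    [0.5, 1.2],  # s0: [Left, Right]
    [2.1, 0.3],  # s1
    [0.0, 3.5],  # s2
])
alpha, gamma = 0.1, 0.99

# Take Right (action 1) in s0, observe r = 1, land in s1
s, a, r, s_next = 0, 1, 1.0, 1

td_target = r + gamma * np.max(q_table[s_next])  # 1 + 0.99 * 2.1 = 3.079
td_error = td_target - q_table[s, a]             # 3.079 - 1.2 = 1.879
q_table[s, a] += alpha * td_error                # 1.2 + 0.1879 = 1.3879
print(q_table[s, a])
```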
Off-Policy vs On-Policy
Q-Learning is an *off-policy* algorithm: its target bootstraps from the greedy action (the max over a'), regardless of which action the behavior policy (e.g., epsilon-greedy) actually takes next. Its on-policy counterpart, SARSA, instead bootstraps from the action the agent really chose in s'.
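The distinction shows up directly in the targets. A minimal sketch (with hypothetical values) contrasting the Q-learning target with SARSA's on-policy target:

```python
import numpy as np

q_table = np.array([[0.5, 1.2], [2.1, 0.3]])
gamma = 0.99
r, s_next = 1.0, 1   # observed reward and next state
a_next = 1           # action the behavior policy actually took in s_next

# Off-policy (Q-learning): bootstrap from the greedy action, whatever we did
q_learning_target = r + gamma * np.max(q_table[s_next])   # uses max Q[s1] = 2.1

# On-policy (SARSA): bootstrap from the action actually taken
sarsa_target = r + gamma * q_table[s_next, a_next]        # uses Q[s1, 1] = 0.3

print(q_learning_target, sarsa_target)
```

When the behavior policy explores a bad action, the SARSA target drops while the Q-learning target does not; this is exactly the off-policy property.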
```python
import numpy as np

class QLearningAgent:
    """Tabular Q-Learning agent."""

    def __init__(self, n_states, n_actions, lr=0.1, gamma=0.99, epsilon=1.0):
        self.q_table = np.zeros((n_states, n_actions))
        self.lr = lr
        self.gamma = gamma
        self.epsilon = epsilon

    def select_action(self, state):
        # Epsilon-greedy: explore with probability epsilon, else act greedily
        if np.random.random() < self.epsilon:
            return np.random.randint(self.q_table.shape[1])
        return np.argmax(self.q_table[state])

    def update(self, state, action, reward, next_state, done):
        """Apply the Q-learning update rule."""
        if done:
            td_target = reward
        else:
            td_target = reward + self.gamma * np.max(self.q_table[next_state])

        td_error = td_target - self.q_table[state, action]
        self.q_table[state, action] += self.lr * td_error

# Example: a simple environment with 16 states and 4 actions
agent = QLearningAgent(n_states=16, n_actions=4)
print(f"Q-table shape: {agent.q_table.shape}")
print(f"Initial Q-values for state 0: {agent.q_table[0]}")
```

From Q-Tables to Deep Q-Networks (DQN)
Q-tables don't scale. If your state is an image (e.g., Atari game frames at 210x160 pixels), the table would need an astronomically large number of rows. DQN replaces the table with a neural network that approximates Q(s, a).
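To make "astronomically large" concrete: even a single grayscale 210x160 frame with 256 intensity levels per pixel has 256^33600 possible configurations. A quick back-of-the-envelope:

```python
import math

pixels = 210 * 160       # 33,600 pixels per grayscale frame
values_per_pixel = 256   # 8-bit intensities

# Number of decimal digits in 256 ** 33600
digits = int(pixels * math.log10(values_per_pixel)) + 1
print(f"The state count is a number with {digits} decimal digits")
```

No table can hold that many rows, which is why DQN learns a parametric function instead.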
DQN Architecture
The network takes the state as input and outputs one Q-value per action in a single forward pass; greedy action selection is then just an argmax over the output layer. For vector states a small MLP suffices; for image states the original DQN used convolutional layers in front of the fully connected head.
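A sketch of the idea in plain numpy (untrained random weights, hypothetical dimensions): the network maps a state vector to one Q-value per action, and acting greedily is a forward pass plus argmax.

```python
import numpy as np

rng = np.random.default_rng(0)
state_dim, hidden, n_actions = 4, 128, 2

# Randomly initialized 2-layer MLP standing in for the Q-network
W1 = rng.normal(0, 0.1, (state_dim, hidden))
b1 = np.zeros(hidden)
W2 = rng.normal(0, 0.1, (hidden, n_actions))
b2 = np.zeros(n_actions)

def q_values(state):
    h = np.maximum(0, state @ W1 + b1)  # ReLU hidden layer
    return h @ W2 + b2                  # one Q-value per action

state = rng.normal(size=state_dim)
q = q_values(state)
action = int(np.argmax(q))
print(f"Q-values: {q}, greedy action: {action}")
```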
Key DQN Innovations
1. **Experience Replay.** Instead of learning from consecutive experiences (which are correlated), store transitions in a replay buffer and sample random mini-batches. This breaks correlations and improves data efficiency.
2. **Target Network.** Use a separate, slowly updated copy of the Q-network for computing TD targets. This stabilizes training by preventing the targets from shifting rapidly.
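The target-network sync itself is only a couple of lines. A sketch with parameters held as numpy arrays, showing both the hard copy used in the original DQN and the soft (Polyak) update some implementations prefer:

```python
import numpy as np

rng = np.random.default_rng(1)
online = {"W": rng.normal(size=(4, 2)), "b": np.zeros(2)}
target = {k: v.copy() for k, v in online.items()}

# Hard update (original DQN): copy online weights every N steps
def hard_update(target, online):
    for k in online:
        target[k] = online[k].copy()

# Soft update (Polyak averaging): blend in a little of the online net each step
def soft_update(target, online, tau=0.005):
    for k in online:
        target[k] = (1 - tau) * target[k] + tau * online[k]

online["W"] += 1.0                 # pretend a gradient step changed the online net
soft_update(target, online)
gap = np.max(np.abs(target["W"] - online["W"]))
print(gap)                         # target lags the online net by (1 - tau) of the change
```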
```python
import numpy as np
from collections import deque
import random

class ReplayBuffer:
    """Experience replay buffer for DQN."""

    def __init__(self, capacity=10000):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size=32):
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return (
            np.array(states),
            np.array(actions),
            np.array(rewards, dtype=np.float32),
            np.array(next_states),
            np.array(dones, dtype=np.float32),
        )

    def __len__(self):
        return len(self.buffer)

# Demo: fill the buffer with random transitions, then sample a mini-batch
buffer = ReplayBuffer(capacity=5000)
for i in range(100):
    state = np.random.randn(4)
    action = np.random.randint(2)
    reward = np.random.randn()
    next_state = np.random.randn(4)
    done = i % 20 == 0
    buffer.push(state, action, reward, next_state, done)

states, actions, rewards, next_states, dones = buffer.sample(8)
print(f"Sampled batch - states shape: {states.shape}, actions: {actions}")
```

Double DQN
Standard DQN tends to overestimate Q-values because it uses the max operator for both selecting and evaluating actions. Double DQN fixes this by decoupling selection and evaluation:
Standard DQN target: r + gamma * max_a' Q_target(s', a')
Double DQN target: r + gamma * Q_target(s', argmax_a' Q_online(s', a'))
The online network selects the best action, but the target network evaluates it. This significantly reduces overestimation and improves performance.
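The overestimation is easy to reproduce: if the true Q-values are all zero and the estimates are pure noise, the max over one noisy estimate is biased upward, while the decoupled (double) estimator is not. A small simulation under that hypothetical noise model:

```python
import numpy as np

rng = np.random.default_rng(42)
n_actions, n_trials = 10, 100_000

# True Q-values are all zero; two independent noisy estimates of them
noise_a = rng.normal(0, 1, (n_trials, n_actions))
noise_b = rng.normal(0, 1, (n_trials, n_actions))

# Single estimator (standard DQN): max over one noisy estimate -> biased up
single = noise_a.max(axis=1).mean()

# Double estimator: select with one estimate, evaluate with the other -> unbiased
best = noise_a.argmax(axis=1)
double = noise_b[np.arange(n_trials), best].mean()

print(f"single-estimator bias: {single:.3f}, double-estimator bias: {double:.3f}")
```

The single estimator's bias grows with the number of actions, which is why the fix matters more in large action spaces.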
DQN Timeline
- 2013: DQN introduced (Mnih et al., arXiv), learning Atari games from raw pixels with experience replay.
- 2015: The Nature version of DQN adds the target network and reaches human-level play on many Atari games.
- 2015: Double DQN (van Hasselt et al.) decouples action selection from evaluation to curb overestimation.
- 2015-2016: Prioritized Experience Replay (Schaul et al.) and Dueling DQN (Wang et al.) improve sampling and architecture.
- 2017: Rainbow (Hessel et al.) combines six DQN extensions into a single agent.
```python
# Pseudocode for a DQN training loop (PyTorch-style), kept in a string so
# this cell runs without PyTorch installed.
"""
import torch
import torch.nn as nn

class DQN(nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, action_dim),
        )

    def forward(self, x):
        return self.network(x)

online_net = DQN(state_dim=4, action_dim=2)
target_net = DQN(state_dim=4, action_dim=2)
target_net.load_state_dict(online_net.state_dict())

optimizer = torch.optim.Adam(online_net.parameters(), lr=1e-3)
buffer = ReplayBuffer(capacity=10000)
step = 0

for episode in range(1000):
    state = env.reset()
    done = False

    while not done:
        # Epsilon-greedy action selection using online_net
        action = select_action(state, online_net, epsilon)
        next_state, reward, done, _ = env.step(action)
        buffer.push(state, action, reward, next_state, done)
        step += 1

        if len(buffer) >= batch_size:
            # Sample a batch (convert the numpy arrays to torch tensors here)
            states, actions, rewards, next_states, dones = buffer.sample(32)

            # Double DQN target: online net selects, target net evaluates
            with torch.no_grad():
                best_actions = online_net(next_states).argmax(dim=1)
                target_q = target_net(next_states).gather(1, best_actions.unsqueeze(1))
                targets = rewards + gamma * target_q.squeeze(1) * (1 - dones)

            current_q = online_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)
            loss = nn.MSELoss()(current_q, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Periodically sync the target network with the online network
            if step % target_update_freq == 0:
                target_net.load_state_dict(online_net.state_dict())

        state = next_state
"""
print("DQN architecture and training loop defined (pseudocode)")
print("Key components: online net, target net, replay buffer, Double DQN targets")
```