Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# 🚦 Adaptive Traffic Signal Control using Reinforcement Learning

An intelligent traffic signal controller that dynamically adjusts signal timings based on real-time traffic conditions using Reinforcement Learning.

## 🎯 Overview

Traditional traffic signals use fixed timers that don't adapt to traffic flow. This project implements RL agents (Q-Learning and DQN) that **learn** optimal signal switching policies through interaction with a simulated 4-way intersection, reducing congestion and waiting time.

## 📁 Project Structure

```
traffic_signal_rl/
├── traffic_env.py # Custom Gymnasium environment
├── q_learning_agent.py # Tabular Q-Learning agent
├── dqn_agent.py # Deep Q-Network (PyTorch) agent
├── fixed_time_controller.py # Fixed-time baseline
├── train.py # Training script (CLI)
├── evaluate.py # Evaluation & plot generation
├── visualize.py # Pygame real-time visualisation
├── report.md # Full project report
├── requirements.txt # Python dependencies
└── results/ # Models, metrics, and plots (generated)
```

## 🚀 Quick Start

### 1. Install Dependencies

```bash
pip install -r requirements.txt
```

### 2. Train an Agent

```bash
# Train DQN (default, recommended)
python train.py --agent dqn --episodes 500 --traffic medium

# Train Q-Learning
python train.py --agent qlearning --episodes 500 --traffic medium

# Train on high traffic
python train.py --agent dqn --episodes 500 --traffic high
```

### 3. Evaluate & Generate Plots

```bash
python evaluate.py --agent dqn --traffic medium
```

This generates:
- Training curves (reward, waiting time, queue length)
- RL vs Fixed-Time comparison bar chart
- Multi-traffic-level comparison

Plots are saved to `results/plots/`.

### 4. Real-Time Visualisation

```bash
# Watch trained DQN agent
python visualize.py --agent dqn --traffic medium

# Compare with fixed-time controller
python visualize.py --agent fixed --traffic medium
```

**Controls:** `R` = restart episode, `ESC` = quit.

## ⚙️ Environment Design

| Component | Description |
|-----------|-------------|
| **Intersection** | 4-way (N, S, E, W) |
| **State** | `[N_queue, S_queue, E_queue, W_queue, phase, time_in_phase]` |
| **Actions** | `0 = Keep signal`, `1 = Switch signal` |
| **Reward** | `-(queue + α × waiting_time) - β × switch_penalty` |
| **Traffic profiles** | `low`, `medium`, `high` (Poisson arrivals) |

## 🤖 Agents

- **Q-Learning** — Tabular with state discretisation & ε-greedy exploration
- **DQN** — 3-layer MLP, experience replay, target network, gradient clipping

## 📊 Expected Results

After training, the RL agent typically reduces average waiting time by **20–40%** compared to fixed-time control, with smoother traffic flow and adaptive behaviour under varying congestion levels.
169 changes: 169 additions & 0 deletions dqn_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
"""
Deep Q-Network (DQN) agent for traffic signal control.

Uses a small MLP, experience replay, and a target network.
"""

import random
from collections import deque
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim


# ── Neural network ────────────────────────────────────────────────────
class QNetwork(nn.Module):
    """Feed-forward Q-value estimator: two ReLU hidden layers plus a linear head.

    Args:
        state_dim: Number of features in one observation vector.
        n_actions: Number of discrete actions; one Q-value is produced per action.
        hidden: Width shared by both hidden layers.
    """

    def __init__(self, state_dim: int = 6, n_actions: int = 2, hidden: int = 64):
        super().__init__()
        # Layers are listed input -> output and stored under ``net`` so the
        # state-dict keys stay ``net.0``, ``net.2`` and ``net.4``.
        layers = [
            nn.Linear(state_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_actions),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the Q-values the network assigns to ``x``, one per action."""
        q_values = self.net(x)
        return q_values


# ── Replay buffer ────────────────────────────────────────────────────
class ReplayBuffer:
    """Bounded FIFO store of ``(state, action, reward, next_state, done)`` tuples."""

    def __init__(self, capacity: int = 10_000):
        # A deque with ``maxlen`` silently drops the oldest entry once full.
        self.buffer: deque = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append one transition to the buffer."""
        transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size: int):
        """Draw ``batch_size`` distinct stored transitions at random.

        Returns:
            A 5-tuple of NumPy arrays ``(states, actions, rewards,
            next_states, dones)``; ``actions`` is int64, the rest float32.
        """
        drawn = random.sample(self.buffer, batch_size)
        # Transpose the list of transitions into one column per field.
        states, actions, rewards, next_states, dones = zip(*drawn)
        state_arr = np.array(states, dtype=np.float32)
        action_arr = np.array(actions, dtype=np.int64)
        reward_arr = np.array(rewards, dtype=np.float32)
        next_state_arr = np.array(next_states, dtype=np.float32)
        done_arr = np.array(dones, dtype=np.float32)
        return (state_arr, action_arr, reward_arr, next_state_arr, done_arr)

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.buffer)


# ── DQN Agent ─────────────────────────────────────────────────────────
class DQNAgent:
"""Deep Q-Network agent with experience replay and target network."""

def __init__(
self,
state_dim: int = 6,
n_actions: int = 2,
hidden: int = 64,
lr: float = 1e-3,
gamma: float = 0.99,
epsilon: float = 1.0,
epsilon_min: float = 0.01,
epsilon_decay: float = 0.995,
buffer_size: int = 10_000,
batch_size: int = 64,
target_update_freq: int = 10, # episodes between target sync
):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.n_actions = n_actions
self.gamma = gamma
self.epsilon = epsilon
self.epsilon_min = epsilon_min
self.epsilon_decay = epsilon_decay
self.batch_size = batch_size
self.target_update_freq = target_update_freq

# Networks
self.policy_net = QNetwork(state_dim, n_actions, hidden).to(self.device)
self.target_net = QNetwork(state_dim, n_actions, hidden).to(self.device)
self.target_net.load_state_dict(self.policy_net.state_dict())
self.target_net.eval()

self.optimizer = optim.Adam(self.policy_net.parameters(), lr=lr)
self.loss_fn = nn.SmoothL1Loss()

self.replay = ReplayBuffer(buffer_size)
self._episode_counter = 0

# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def select_action(self, state: np.ndarray, training: bool = True) -> int:
if training and random.random() < self.epsilon:
return random.randrange(self.n_actions)
with torch.no_grad():
t = torch.FloatTensor(state).unsqueeze(0).to(self.device)
q = self.policy_net(t)
return int(q.argmax(dim=1).item())

def store_transition(self, state, action, reward, next_state, done):
self.replay.push(state, action, reward, next_state, done)

def train_step(self):
"""One gradient step from a mini-batch."""
if len(self.replay) < self.batch_size:
return None
states, actions, rewards, next_states, dones = self.replay.sample(self.batch_size)

states_t = torch.FloatTensor(states).to(self.device)
actions_t = torch.LongTensor(actions).to(self.device)
rewards_t = torch.FloatTensor(rewards).to(self.device)
next_states_t = torch.FloatTensor(next_states).to(self.device)
dones_t = torch.FloatTensor(dones).to(self.device)

# Current Q values
q_values = self.policy_net(states_t).gather(1, actions_t.unsqueeze(1)).squeeze(1)

# Target Q values
with torch.no_grad():
next_q = self.target_net(next_states_t).max(dim=1)[0]
target = rewards_t + self.gamma * next_q * (1 - dones_t)

loss = self.loss_fn(q_values, target)
self.optimizer.zero_grad()
loss.backward()
# Gradient clipping
nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
self.optimizer.step()
return loss.item()

def end_episode(self):
"""Call at the end of each episode for bookkeeping."""
self._episode_counter += 1
self.decay_epsilon()
if self._episode_counter % self.target_update_freq == 0:
self.target_net.load_state_dict(self.policy_net.state_dict())

def decay_epsilon(self):
self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

# ------------------------------------------------------------------
# Persistence
# ------------------------------------------------------------------
def save(self, path: str | Path):
path = Path(path)
path.parent.mkdir(parents=True, exist_ok=True)
torch.save({
"policy_state": self.policy_net.state_dict(),
"target_state": self.target_net.state_dict(),
"optimizer_state": self.optimizer.state_dict(),
"epsilon": self.epsilon,
"episode": self._episode_counter,
}, str(path))

def load(self, path: str | Path):
ckpt = torch.load(str(path), map_location=self.device, weights_only=True)
self.policy_net.load_state_dict(ckpt["policy_state"])
self.target_net.load_state_dict(ckpt["target_state"])
self.optimizer.load_state_dict(ckpt["optimizer_state"])
self.epsilon = ckpt["epsilon"]
self._episode_counter = ckpt["episode"]
Loading