-
Notifications
You must be signed in to change notification settings - Fork 113
Project Submitted By - Rishit Aggarwal (23BAI10329) Lakshya Mangla (23BAI10814) Tanmay Singh (23BAI10328) #110
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,169 @@ | ||
| # DDPG Robot Path Planning (Continuous Space) | ||
| # Beginner-friendly single file implementation | ||
|
|
||
| import numpy as np | ||
| import torch | ||
| import torch.nn as nn | ||
| import torch.optim as optim | ||
| import random | ||
|
|
||
| # ---------------- ENVIRONMENT ---------------- | ||
class RobotEnv:
    """2-D continuous point-robot world with one goal and one obstacle.

    The robot moves by adding the action vector to its position on every
    step. The reward is the negative Euclidean distance to the goal, minus
    ``obstacle_penalty`` while the robot is closer than ``obstacle_radius``
    to the obstacle. An episode is done when the robot is closer than
    ``goal_radius`` to the goal or once ``max_steps`` steps have been taken.

    All constructor arguments default to the values that were previously
    hard-coded, so ``RobotEnv()`` behaves as before.
    """

    def __init__(self, start=(0.0, 0.0), goal=(5.0, 5.0), obstacle=(2.5, 2.5),
                 max_steps=100, goal_radius=0.5, obstacle_radius=1.0,
                 obstacle_penalty=5.0):
        self.state_dim = 2   # (x, y)
        self.action_dim = 2  # move dx, dy
        self.max_steps = max_steps
        self.goal_radius = goal_radius
        self.obstacle_radius = obstacle_radius
        self.obstacle_penalty = obstacle_penalty

        # Templates that reset() copies from, so an episode can never
        # corrupt the configured layout.
        self._start = self._as_point(start, "start")
        self._goal = self._as_point(goal, "goal")
        self._obstacle = self._as_point(obstacle, "obstacle")
        self.reset()

    def _as_point(self, value, name):
        """Convert *value* to a float array of shape ``(state_dim,)``."""
        point = np.asarray(value, dtype=float)
        if point.shape != (self.state_dim,):
            raise ValueError(
                f"{name} must have shape ({self.state_dim},), got {point.shape}"
            )
        return point

    def reset(self):
        """Start a new episode and return a copy of the initial position."""
        self.pos = self._start.copy()
        self.goal = self._goal.copy()
        self.obstacle = self._obstacle.copy()
        self.steps = 0
        # Return a copy so callers cannot mutate the internal state.
        return self.pos.copy()

    def step(self, action):
        """Apply *action* and return ``(next_state, reward, done)``.

        Args:
            action: displacement ``(dx, dy)``; any array-like of shape
                ``(action_dim,)``.

        Returns:
            A copy of the new position, the reward as a ``float`` and the
            done flag as a ``bool``.

        Raises:
            ValueError: if *action* does not have shape ``(action_dim,)``.
        """
        move = np.asarray(action, dtype=float)
        if move.shape != (self.action_dim,):
            # A mis-shaped action would silently broadcast and change the
            # shape of the state, so reject it up front.
            raise ValueError(
                f"action must have shape ({self.action_dim},), got {move.shape}"
            )

        self.pos = self.pos + move
        self.steps += 1

        dist_goal = np.linalg.norm(self.pos - self.goal)
        dist_obs = np.linalg.norm(self.pos - self.obstacle)

        reward = -dist_goal
        # Penalty while the robot is near the obstacle.
        if dist_obs < self.obstacle_radius:
            reward -= self.obstacle_penalty

        done = dist_goal < self.goal_radius or self.steps >= self.max_steps

        # Plain Python scalars keep downstream containers homogeneous.
        return self.pos.copy(), float(reward), bool(done)
|
|
||
| # ---------------- REPLAY BUFFER ---------------- | ||
class ReplayBuffer:
    """Fixed-capacity store of ``(s, a, r, s2, d)`` transitions.

    Implemented as a ring buffer over a plain list: once the buffer is full,
    each new transition overwrites the oldest one in place. This keeps both
    insertion and random access O(1), unlike evicting with ``list.pop(0)``,
    which shifts every remaining element.
    """

    def __init__(self, capacity=10000):
        if capacity <= 0:
            raise ValueError("capacity must be a positive integer")
        self.buffer = []
        self.capacity = capacity
        self._next = 0  # slot the next transition is written to

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.buffer)

    def add(self, s, a, r, s2, d):
        """Store one transition, overwriting the oldest when full."""
        transition = (s, a, r, s2, d)
        if len(self.buffer) < self.capacity:
            self.buffer.append(transition)
        else:
            self.buffer[self._next] = transition
        self._next = (self._next + 1) % self.capacity

    def sample(self, batch_size):
        """Return *batch_size* distinct transitions as five stacked arrays.

        Raises:
            ValueError: if *batch_size* exceeds the number of stored
                transitions (raised by ``random.sample``).
        """
        batch = random.sample(self.buffer, batch_size)
        s, a, r, s2, d = zip(*batch)
        return np.array(s), np.array(a), np.array(r), np.array(s2), np.array(d)
|
|
||
| # ---------------- NETWORKS ---------------- | ||
class Actor(nn.Module):
    """Deterministic policy network mapping a state to an action.

    A two-layer MLP whose final ``Tanh`` keeps every action component in
    ``[-1, 1]``.

    Args:
        state_dim: size of the state vector.
        action_dim: size of the action vector.
        hidden_dim: width of the hidden layer (defaults to the previously
            hard-coded 64).
    """

    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super().__init__()
        # Layer order and indices are kept so state-dict keys stay
        # ``net.0.*`` / ``net.2.*``.
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Tanh(),
        )

    def forward(self, x):
        """Return the action for a batch of states ``x``."""
        return self.net(x)
|
|
||
class Critic(nn.Module):
    """Q-network scoring a (state, action) pair with a single value.

    Args:
        state_dim: size of the state vector.
        action_dim: size of the action vector.
        hidden_dim: width of the hidden layer (defaults to the previously
            hard-coded 64).
    """

    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super().__init__()
        # Layer order and indices are kept so state-dict keys stay
        # ``net.0.*`` / ``net.2.*``.
        self.net = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, s, a):
        """Return Q(s, a) with shape ``(batch, 1)``.

        ``s`` and ``a`` are concatenated along dimension 1, so both must be
        2-D with the same batch size.
        """
        return self.net(torch.cat([s, a], dim=1))
|
|
||
| # ---------------- DDPG AGENT ---------------- | ||
class DDPG:
    """Deep Deterministic Policy Gradient agent.

    Holds an actor and a critic, a slowly-tracking target copy of each, one
    Adam optimizer per online network and a replay buffer.

    Args:
        state_dim: size of the state vector.
        action_dim: size of the action vector.
        actor_lr: learning rate of the actor optimizer.
        critic_lr: learning rate of the critic optimizer.
        gamma: discount factor used in the TD target.
        tau: interpolation factor of the soft target update.
    """

    def __init__(self, state_dim, action_dim, actor_lr=1e-3, critic_lr=1e-3,
                 gamma=0.99, tau=0.005):
        self.actor = Actor(state_dim, action_dim)
        self.critic = Critic(state_dim, action_dim)

        self.target_actor = Actor(state_dim, action_dim)
        self.target_critic = Critic(state_dim, action_dim)
        # The targets must start as exact copies of the online networks;
        # otherwise the first TD targets come from unrelated random weights
        # and the soft update needs many steps to wash that out.
        self.target_actor.load_state_dict(self.actor.state_dict())
        self.target_critic.load_state_dict(self.critic.state_dict())

        self.actor_opt = optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_opt = optim.Adam(self.critic.parameters(), lr=critic_lr)

        self.buffer = ReplayBuffer()
        self.gamma = gamma
        self.tau = tau

    def select_action(self, state, noise_std=0.1):
        """Return the actor's action for *state*, optionally perturbed.

        Args:
            state: a single state (array-like of length ``state_dim``).
            noise_std: standard deviation of the zero-mean Gaussian
                exploration noise added to the action. A deterministic
                policy never explores on its own, so noise is on by
                default; pass ``0.0`` for greedy evaluation.

        Returns:
            A float32 NumPy array of length ``action_dim``. When noise is
            added the result is clipped back to ``[-1, 1]``, the range of
            the actor's ``Tanh`` output.
        """
        state_t = torch.as_tensor(
            np.asarray(state), dtype=torch.float32
        ).unsqueeze(0)
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            action = self.actor(state_t).numpy()[0]

        if noise_std > 0:
            noise = np.random.normal(0.0, noise_std, size=action.shape)
            action = np.clip(action + noise, -1.0, 1.0).astype(np.float32)
        return action

    def train(self, batch_size=32):
        """Run one critic step, one actor step and a soft target update.

        Does nothing until the replay buffer holds at least *batch_size*
        transitions.
        """
        if len(self.buffer.buffer) < batch_size:
            return

        s, a, r, s2, d = self.buffer.sample(batch_size)

        # Explicit dtype conversion: the done flags arrive as booleans and
        # the states as float64.
        s = torch.as_tensor(s, dtype=torch.float32)
        a = torch.as_tensor(a, dtype=torch.float32)
        r = torch.as_tensor(r, dtype=torch.float32).unsqueeze(1)
        s2 = torch.as_tensor(s2, dtype=torch.float32)
        d = torch.as_tensor(d, dtype=torch.float32).unsqueeze(1)

        # Critic loss: the TD target is a constant, so build it without
        # tracking gradients through the target networks.
        with torch.no_grad():
            target_a = self.target_actor(s2)
            target_q = self.target_critic(s2, target_a)
            y = r + self.gamma * target_q * (1 - d)

        critic_loss = nn.functional.mse_loss(self.critic(s, a), y)

        self.critic_opt.zero_grad()
        critic_loss.backward()
        self.critic_opt.step()

        # Actor loss: ascend the critic's estimate of the actor's actions.
        actor_loss = -self.critic(s, self.actor(s)).mean()

        self.actor_opt.zero_grad()
        actor_loss.backward()
        self.actor_opt.step()

        # Soft update
        self._soft_update(self.target_actor, self.actor)
        self._soft_update(self.target_critic, self.critic)

    def _soft_update(self, target_net, source_net):
        """Move *target_net* a fraction ``tau`` of the way to *source_net*."""
        with torch.no_grad():
            for target, source in zip(target_net.parameters(),
                                      source_net.parameters()):
                # target <- (1 - tau) * target + tau * source
                target.mul_(1.0 - self.tau).add_(source, alpha=self.tau)
|
|
||
| # ---------------- TRAINING ---------------- | ||
def main(episodes=200):
    """Train a DDPG agent on ``RobotEnv`` and print each episode's return.

    Args:
        episodes: number of training episodes to run.
    """
    env = RobotEnv()
    agent = DDPG(env.state_dim, env.action_dim)

    for ep in range(episodes):
        state = env.reset()
        total_reward = 0.0

        # Use the environment's own step limit instead of repeating the
        # number here.
        for _ in range(env.max_steps):
            action = agent.select_action(state)
            next_state, reward, done = env.step(action)

            # An episode cut off by the step limit is not a true terminal
            # state: the state carries no time information, so storing it
            # as terminal would teach the critic that the return stops at
            # an arbitrary position. Only goal arrival is stored as
            # terminal. (A goal reached exactly on the last allowed step is
            # treated as a truncation, which is harmless.)
            timed_out = env.steps >= env.max_steps
            terminal = bool(done) and not timed_out

            agent.buffer.add(state, action, reward, next_state, terminal)
            agent.train()

            state = next_state
            total_reward += reward

            if done:
                break

        print(f"Episode {ep+1}, Reward: {total_reward:.2f}")

    print("Training Complete!")


if __name__ == "__main__":
    main()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ReplayBuffer evicts items with list.pop(0), which is O(n) per eviction and will slow down once capacity is reached. Consider using collections.deque with a maxlen (or a ring buffer) to make appends/evictions O(1).