Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
4a60b30
Refactor DrivePolicy architecture and configuration
vcharraut Jun 2, 2026
a9952e5
Update tests
vcharraut Jun 2, 2026
212fd0e
Potential fix for pull request finding
vcharraut Jun 2, 2026
b3c91f0
Rename mask_padded_observations to mask_padded_features for consistency
vcharraut Jun 2, 2026
7186131
Rename OBS_COUNT_FEATURES to OBS_SLOT_NUM_TYPES for consistency acros…
vcharraut Jun 2, 2026
e1d817f
Refactor conditioning terminology to target in training, inference, a…
vcharraut Jun 2, 2026
77ca722
Add neural network architecture notebook and initialize notebooks pac…
vcharraut Jun 3, 2026
69e0f25
Merge branch 'emerge/temp_training' into vcha/encoders
vcharraut Jun 3, 2026
0961e3a
Rename target_dim to goal_dim for clarity in Drive class
vcharraut Jun 3, 2026
03cc538
Update inference and architecture notebooks to include mask_padded_fe…
vcharraut Jun 3, 2026
4e58499
Update CI configuration and test data for improved performance and co…
vcharraut Jun 3, 2026
f6fcbf4
Merge remote-tracking branch 'emerge/emerge/temp_training' into vcha/…
vcharraut Jun 3, 2026
3f78356
Merge remote-tracking branch 'emerge/emerge/temp_training' into vcha/…
vcharraut Jun 5, 2026
a9ee68a
Add carla_lhs binaries for Town01 to Town10HD
vcharraut Jun 5, 2026
c2326a0
Update .dockerignore to include additional directories and file types
vcharraut Jun 5, 2026
2fdfcad
Refactor partner observation features to replace relative velocity wi…
vcharraut Jun 5, 2026
3016104
Update golden JSON files with revised environment metrics and loss va…
vcharraut Jun 5, 2026
a341ba1
Merge branch 'emerge/temp_training' into vcha/encoders
vcharraut Jun 5, 2026
0161adb
Update golden JSON files and refactor input sizes in training configu…
vcharraut Jun 5, 2026
b80a634
Add unit test for encoding and pooling of masked padded objects in Dr…
vcharraut Jun 5, 2026
16667c9
Renamed encoder to context
vcharraut Jun 5, 2026
c94955e
Rename 'conditioning' to 'context' in inference processing
vcharraut Jun 5, 2026
2069382
Rename OBS_SLOT_NUM_TYPES to OBS_VALID_COUNT_FEATURES across multiple…
vcharraut Jun 5, 2026
9d3fcc5
Merge branch 'emerge/temp_training' into vcha/encoders
vcharraut Jun 11, 2026
6156be2
Add configuration files for reward-dense and reward-sparse environments
vcharraut Jun 11, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions .dockerignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,42 @@ build
*.egg-info
experiments
wandb
.neptune
.pytest_cache
.ruff_cache

benchmark*/
!pufferlib/ocean/benchmark/
!pufferlib/ocean/benchmark/**
runs*/
weights/
checkpoints/
/data/
!/tests/smoke_tests/data/
!/tests/smoke_tests/data/**
/artifacts/
external/

pufferlib/resources/drive/binaries/*/
!pufferlib/resources/drive/binaries/carla/
!pufferlib/resources/drive/binaries/carla/**
!pufferlib/resources/drive/binaries/dense/
!pufferlib/resources/drive/binaries/dense/**
!pufferlib/resources/drive/binaries/lateral/
!pufferlib/resources/drive/binaries/lateral/**
!pufferlib/resources/drive/binaries/longitudinal/
!pufferlib/resources/drive/binaries/longitudinal/**
!pufferlib/resources/drive/binaries/nuplan/
!pufferlib/resources/drive/binaries/nuplan/**
!pufferlib/resources/drive/binaries/obstacles/
!pufferlib/resources/drive/binaries/obstacles/**
!pufferlib/resources/drive/binaries/vru/
!pufferlib/resources/drive/binaries/vru/**

pufferlib/resources/drive/output*.gif
pufferlib/resources/drive/pufferdrive_*.gif
*.mp4
*.mov
*.webm
*.avi
*.mkv
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ jobs:
PIP_NO_CACHE_DIR: 1
run: |
sudo apt-get update && sudo apt-get install -y build-essential cmake
python -m pip install -U pip pytest jupytext nbclient ipykernel ipywidgets
python -m pip install -U pip pytest jupytext nbclient ipykernel
pip install -e . --no-cache-dir
python setup.py build_ext --inplace --force

Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ pufferlib/ocean/impulse_wars/benchmark/
data/
pufferlib/resources/drive/binaries/*/
!pufferlib/resources/drive/binaries/carla/
!pufferlib/resources/drive/binaries/carla_lhs/
!pufferlib/resources/drive/binaries/carla/**
# Re-ignore .DS_Store inside carla binaries
pufferlib/resources/drive/binaries/carla/.DS_Store
Expand Down
3 changes: 2 additions & 1 deletion notebooks/01_observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@
idx += env.obs_slots_traffic_controls_n * env.traffic_control_features
assert np.allclose(traffic_manual, traffic), "traffic mismatch"

idx += env.obs_valid_count_features
assert idx == obs.shape[1], f"obs size mismatch: used {idx}, total {obs.shape[1]}"
print(f"All slices match. Total features used: {idx}")

Expand Down Expand Up @@ -175,7 +176,7 @@
"width",
"heading_cos",
"heading_sin",
"speed",
"sim_speed_signed",
"seconds_stopped",
]
active_mask = ~np.all(partners == 0, axis=1)
Expand Down
16 changes: 8 additions & 8 deletions notebooks/04_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,11 @@
f"ego_obs: shape={ego_obs.shape}, NaN={torch.isnan(ego_obs).sum().item()}, range=[{ego_obs.min():.3f}, {ego_obs.max():.3f}]"
)

cond_dim = backbone.conditioning_dim
if cond_dim > 0:
cond_obs = x[:, slide_idx : slide_idx + cond_dim]
slide_idx += cond_dim
print(f"cond_obs: shape={cond_obs.shape}, NaN={torch.isnan(cond_obs).sum().item()}")
context_dim = backbone.context_dim
if context_dim > 0:
context_obs = x[:, slide_idx : slide_idx + context_dim]
slide_idx += context_dim
print(f"context_obs: shape={context_obs.shape}, NaN={torch.isnan(context_obs).sum().item()}")

partner_dim = env.obs_slots_partners_n * env.partner_features
lane_dim = env.obs_slots_lane_kept * env.road_features
Expand Down Expand Up @@ -141,11 +141,11 @@
f"{name:>10s}_enc: NaN={torch.isnan(enc).sum().item()}, dead={((enc.abs().sum(dim=0) == 0).sum().item())}, range=[{enc.min():.3f}, {enc.max():.3f}]"
)

if cond_dim > 0:
if context_dim > 0:
with torch.no_grad():
cond_enc = backbone.conditioning_encoder(cond_obs)
context_enc = backbone.context_encoder(context_obs)
print(
f"{'cond':>10s}_enc: NaN={torch.isnan(cond_enc).sum().item()}, dead={((cond_enc.abs().sum(dim=0) == 0).sum().item())}, range=[{cond_enc.min():.3f}, {cond_enc.max():.3f}]"
f"{'context':>10s}_enc: NaN={torch.isnan(context_enc).sum().item()}, dead={((context_enc.abs().sum(dim=0) == 0).sum().item())}, range=[{context_enc.min():.3f}, {context_enc.max():.3f}]"
)

# %% [markdown]
Expand Down
50 changes: 28 additions & 22 deletions notebooks/05_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def run_rollout(env, policy, deterministic=False, horizon=HORIZON):
# - **Ego**: speed, width, length, [jerk: steering, a_long, a_lat], lane_center_dist, lane_angle, speed_limit
# - **Conditioning** (if enabled): 17 reward coefs (goal_radius, goal_speed, collision, offroad, comfort, lane_align, vel_align, lane_center, center_bias, velocity, reverse, stop_line, timestep, overspeed, throttle, steer, acc) + target waypoints
# - **Target**: static=rel_x,rel_y,rel_z per waypoint; dynamic=rel_x,rel_y,rel_z,heading_cos,heading_sin per waypoint
# - **Partners** (MAX_PARTNERS x 9): rel_x, rel_y, rel_z, length, width, heading_cos, heading_sin, speed, seconds_stopped
# - **Partners** (MAX_PARTNERS x 9): rel_x, rel_y, rel_z, length, width, heading_cos, heading_sin, sim_speed_signed, seconds_stopped
# - **Lanes** (MAX_LANES x 7): rel_x, rel_y, rel_z, seg_length, seg_width, dir_cos, dir_sin
# - **Boundaries** (MAX_BOUNDS x 7): same as lanes
# - **Traffic controls** (MAX_TRAFFIC x 7): rel_x1, rel_y1, rel_x2, rel_y2, rel_z, type, state
Expand Down Expand Up @@ -351,11 +351,11 @@ def layer_stats(name, arr):
"width",
"heading_cos",
"heading_sin",
"speed",
"sim_speed_signed",
"seconds_stopped",
]
for p in range(min(int(n_visible), 5)):
vals = ", ".join(f"{partner_labels[j]}={partners[p, j]:.3f}" for j in range(len(partner_labels)))
vals = ", ".join(f"{partner_labels[j]}={partners[p, j]:.3f}" for j in range(env.partner_features))
print(f" [{p}] {vals}")
if n_visible > 5:
print(f" ... ({n_visible - 5} more)")
Expand Down Expand Up @@ -631,12 +631,14 @@ def unpack_all_timesteps(bufs, agent_idx):
for i in range(partners.shape[0]):
if np.allclose(partners[i], 0):
continue
rx, ry, rz, w, l, hc, hs, vx, vy = partners[i]
rx, ry, rz, length, width, hc, hs, speed, _ = partners[i]
heading = np.arctan2(hs, hc)
rect = Rectangle((-l / 2, -w / 2), l, w, facecolor="orange", edgecolor="black", alpha=0.6, zorder=9)
rect = Rectangle(
(-length / 2, -width / 2), length, width, facecolor="orange", edgecolor="black", alpha=0.6, zorder=9
)
rect.set_transform(plt.matplotlib.transforms.Affine2D().rotate(heading).translate(rx, ry) + ax.transData)
ax.add_patch(rect)
ax.annotate(f"{vx:.2f}, {vy:.2f}", (rx, ry), fontsize=7, ha="center", color="darkred", zorder=12)
ax.annotate(f"{speed:.2f}", (rx, ry), fontsize=7, ha="center", color="darkred", zorder=12)
part_mask = np.any(partners != 0, axis=1)
if part_mask.any():
ax.scatter(
Expand Down Expand Up @@ -773,7 +775,7 @@ def unpack_all_timesteps(bufs, agent_idx):
"width",
"heading_cos",
"heading_sin",
"speed",
"sim_speed_signed",
"seconds_stopped",
]
obs_slots_partners_n = env.obs_slots_partners_n
Expand All @@ -799,7 +801,7 @@ def unpack_all_timesteps(bufs, agent_idx):
f"({100 * len(visible_partners) / (all_partners.shape[0] * obs_slots_partners_n):.1f}%)"
)

fig, axes = plt.subplots(3, 3, figsize=(21, 10))
fig, axes = plt.subplots(3, 4, figsize=(21, 11))
axes = axes.flatten()

for i, label in enumerate(partner_labels):
Expand All @@ -811,12 +813,16 @@ def unpack_all_timesteps(bufs, agent_idx):
axes[i].tick_params(labelsize=7)

# rel_x vs rel_y scatter in last panel
axes[8].scatter(visible_partners[:, 0], visible_partners[:, 1], s=1, alpha=0.15, color="darkorange")
axes[8].set_xlabel("rel_x")
axes[8].set_ylabel("rel_y")
axes[8].set_title("Partner positions (ego frame)")
axes[8].set_aspect("equal")
axes[8].grid(True, alpha=0.3)
pos_ax = axes[len(partner_labels)]
pos_ax.scatter(visible_partners[:, 0], visible_partners[:, 1], s=1, alpha=0.15, color="darkorange")
pos_ax.set_xlabel("rel_x")
pos_ax.set_ylabel("rel_y")
pos_ax.set_title("Partner positions (ego frame)")
pos_ax.set_aspect("equal")
pos_ax.grid(True, alpha=0.3)

for ax in axes[len(partner_labels) + 1 :]:
ax.axis("off")

fig.suptitle("Partner features: all visible, full rollout", fontsize=13)
plt.tight_layout()
Expand Down Expand Up @@ -1303,7 +1309,7 @@ def compute_gae(rewards, values, terminals, truncations, gamma, lam):
# %% [markdown]
# ## Encoder analysis — what the policy encodes
#
# Each obs layer has its own encoder projecting raw features → `input_size` embedding:
# Each obs layer has its own encoder projecting raw features → embedding width:
# - **ego** and **conditioning** (reward coefs + target): single vector, no pooling.
# - **partners / lanes / boundaries / traffic**: per-slot encoder, padded slots masked to `-inf`, then **max-pooled** across slots → one embedding. Fully-padded layers are zeroed.
#
Expand Down Expand Up @@ -1343,8 +1349,8 @@ def compute_gae(rewards, values, terminals, truncations, gamma, lam):
True,
)
)
if bb.conditioning_dim > 0:
enc_inventory.append(("conditioning", bb.conditioning_encoder, bb.conditioning_dim, 1, False))
if bb.context_dim > 0:
enc_inventory.append(("context", bb.context_encoder, bb.context_dim, 1, False))

enc_names = [n for n, *_ in enc_inventory]
set_encs = [n for n, _, _, _, is_set in enc_inventory if is_set]
Expand All @@ -1354,10 +1360,10 @@ def compute_gae(rewards, values, terminals, truncations, gamma, lam):
for name, mod, rin, nslots, is_set in enc_inventory:
nparam = sum(p.numel() for p in mod.parameters())
print(
f"{name:>13s} | {rin:>6d} | {bb.input_size:>7d} | {nslots:>5d} | {('max' if is_set else '-'):>6s} | {nparam:>9,d}"
f"{name:>13s} | {rin:>6d} | {mod[-1].out_features:>7d} | {nslots:>5d} | {('max' if is_set else '-'):>6s} | {nparam:>9,d}"
)
print(
f"\nBackbone input = {len(enc_inventory)} x {bb.input_size} = {len(enc_inventory) * bb.input_size} -> backbone -> {bb.out_dim}"
f"\nBackbone input = {sum(mod[-1].out_features for _, mod, _, _, _ in enc_inventory)} -> backbone -> {bb.out_dim}"
)

# Capture pre-pool encoder outputs via forward hooks
Expand All @@ -1383,7 +1389,7 @@ def fn(m, i, o):
lane_dim = bb.obs_slots_lane_kept * bb.road_features_count
boundary_dim = bb.obs_slots_boundary_kept * bb.road_features_count
traffic_dim = bb.obs_slots_traffic_controls_n * bb.traffic_control_features_count
_s = ego_dim + bb.conditioning_dim
_s = ego_dim + bb.context_dim
sl = {}
sl["partner"] = (_s, _s + partner_dim, bb.obs_slots_partners_n, bb.partner_features_count)
_s += partner_dim
Expand Down Expand Up @@ -1413,10 +1419,10 @@ def fn(m, i, o):
masked = captured[name].masked_fill(pad[name].unsqueeze(2), -torch.inf)
vm = (~pad[name]).any(dim=1)
valid_sample[name] = vm
winners[name] = masked.max(dim=1).indices # (B, input_size): winning slot per dim
winners[name] = masked.max(dim=1).indices # (B, embedding dim): winning slot per dim
pooled[name] = torch.where(vm.unsqueeze(1), masked.max(dim=1).values, torch.zeros_like(masked.max(dim=1).values))

for name in ("ego", "conditioning"):
for name in ("ego", "context"):
if name in enc_names:
pooled[name] = captured[name]

Expand Down
Loading
Loading