From 19612c53f34e8d56f962ef08583c93f6feae2eaf Mon Sep 17 00:00:00 2001 From: Valentin Charraut Date: Mon, 8 Jun 2026 10:35:57 +0200 Subject: [PATCH 1/4] Add obs_lane_stride and obs_boundary_stride parameters for observation configuration --- pufferlib/config/ocean/drive.ini | 3 +++ pufferlib/ocean/drive/binding.c | 2 ++ pufferlib/ocean/drive/drive.h | 29 ++++++++++++++++++++++++----- pufferlib/ocean/drive/drive.py | 12 ++++++++++++ pufferlib/ocean/env_config.h | 6 ++++++ pufferlib/pufferl.py | 2 ++ 6 files changed, 49 insertions(+), 5 deletions(-) diff --git a/pufferlib/config/ocean/drive.ini b/pufferlib/config/ocean/drive.ini index 6b01fa5e97..61e1c5b067 100644 --- a/pufferlib/config/ocean/drive.ini +++ b/pufferlib/config/ocean/drive.ini @@ -131,6 +131,9 @@ obs_slots_traffic_controls_n = 4 ; Fraction of segment observation slots to drop (reduces obs size) obs_dropout_lane = 0.0 obs_dropout_boundary = 0.0 +; Stride for lane and boundary segments; 1 means use every segment, 2 means use every other segment, etc. +obs_lane_stride = 1 +obs_boundary_stride = 1 ; --- Observation normalization --- obs_norm_goal_offset_m = 120.0 obs_norm_xy_offset_m = 120.0 diff --git a/pufferlib/ocean/drive/binding.c b/pufferlib/ocean/drive/binding.c index afc952acca..336956a10c 100644 --- a/pufferlib/ocean/drive/binding.c +++ b/pufferlib/ocean/drive/binding.c @@ -1979,6 +1979,8 @@ static int my_init(Env *env, PyObject *args, PyObject *kwargs) { env->obs_slots_partners_n = (int) unpack(kwargs, "obs_slots_partners_n"); env->obs_slots_traffic_controls_n = (int) unpack(kwargs, "obs_slots_traffic_controls_n"); env->traffic_control_scope = (int) unpack(kwargs, "traffic_control_scope"); + env->obs_lane_stride = (int) unpack(kwargs, "obs_lane_stride"); + env->obs_boundary_stride = (int) unpack(kwargs, "obs_boundary_stride"); env->dt = (float) unpack(kwargs, "dt"); env->spawn_initial_speed = (float) unpack(kwargs, "spawn_initial_speed"); env->goal_speed = (float) unpack(kwargs, "goal_speed"); diff --git a/pufferlib/ocean/drive/drive.h b/pufferlib/ocean/drive/drive.h index 66c3b570a0..0cbcdc533f 100644 --- a/pufferlib/ocean/drive/drive.h +++ b/pufferlib/ocean/drive/drive.h @@ -290,8 +290,9 @@ struct Log { }; struct GridMapEntity { - int entity_idx; // Index into the road_elements array - int geometry_idx; // Index into element's geometry array + int entity_idx; // Index into the road_elements array + int geometry_idx; // Index into element's geometry array + int valid_for_obs; // Whether this entity should be included in observations }; struct GridMap { @@ -328,6 +329,8 @@ struct SharedMapData { GridMap *grid_map; int *neighbor_offsets; struct LaneGraph lane_graph; + int obs_lane_stride; + int obs_boundary_stride; int ref_count; pid_t owner_pid; }; @@ -440,6 +443,8 @@ struct Drive { int obs_slots_partners_n; int obs_slots_traffic_controls_n; int traffic_control_scope; + int obs_lane_stride; + int obs_boundary_stride; int obs_slots_lane_kept; int obs_slots_boundary_kept; int road_dropout_enabled; @@ -706,6 +711,12 @@ static void add_entity_to_grid( env->grid_map->cells[grid_index][count].entity_idx = entity_idx; env->grid_map->cells[grid_index][count].geometry_idx = geometry_idx; + env->grid_map->cells[grid_index][count].valid_for_obs = 1; + if (is_road_lane(env->road_elements[entity_idx].type)) { + env->grid_map->cells[grid_index][count].valid_for_obs = geometry_idx % env->obs_lane_stride == 0; + } else if (is_road_edge(env->road_elements[entity_idx].type)) { + env->grid_map->cells[grid_index][count].valid_for_obs = geometry_idx % env->obs_boundary_stride == 0; + } cell_entities_insert_index[grid_index] = count + 1; } @@ -968,6 +979,7 @@ static int get_neighbors_entities( for (int j = 0; j < count && entity_list_count < max_size; j++) { entity_list[entity_list_count].entity_idx = env->grid_map->cells[neighbor_idx][j].entity_idx; entity_list[entity_list_count].geometry_idx = env->grid_map->cells[neighbor_idx][j].geometry_idx; + entity_list[entity_list_count].valid_for_obs = env->grid_map->cells[neighbor_idx][j].valid_for_obs; entity_list_count += 1; } } @@ -3665,9 +3677,11 @@ void remove_bad_trajectories(Drive *env) { env->timestep = 0; } -static struct SharedMapData *map_cache_lookup(const char *map_name) { +static struct SharedMapData *map_cache_lookup(Drive *env) { for (int i = 0; i < g_map_cache_count; i++) { - if (g_map_cache[i] != NULL && strcmp(g_map_cache[i]->map_name, map_name) == 0) { + if (g_map_cache[i] != NULL && strcmp(g_map_cache[i]->map_name, env->map_name) == 0 + && g_map_cache[i]->obs_lane_stride == env->obs_lane_stride + && g_map_cache[i]->obs_boundary_stride == env->obs_boundary_stride) { return g_map_cache[i]; } } @@ -3727,7 +3741,7 @@ void init(Drive *env) { env->timestep = 0; env->shared_map = NULL; - struct SharedMapData *shared = env->use_map_cache ? map_cache_lookup(env->map_name) : NULL; + struct SharedMapData *shared = env->use_map_cache ? map_cache_lookup(env) : NULL; if (shared != NULL) { // Cache hit: load only the per-env data (agents, traffic-control elements), // then discard the freshly-loaded geometry and borrow the shared copy. @@ -3764,6 +3778,8 @@ void init(Drive *env) { entry->grid_map = env->grid_map; entry->neighbor_offsets = env->neighbor_offsets; entry->lane_graph = env->lane_graph; + entry->obs_lane_stride = env->obs_lane_stride; + entry->obs_boundary_stride = env->obs_boundary_stride; entry->ref_count = 1; entry->owner_pid = getpid(); map_cache_insert(entry); @@ -4769,6 +4785,9 @@ static int write_road_obs(Drive *env, Agent *ego, float *obs, int obs_idx, int * if (lanes_found >= env->obs_slots_lane_n && boundaries_found >= env->obs_slots_boundary_n) { break; } + if (!neighbor_entities[k].valid_for_obs) { + continue; + } int entity_idx = neighbor_entities[k].entity_idx; int geometry_idx = neighbor_entities[k].geometry_idx; RoadMapElement *road_element = &env->road_elements[entity_idx]; diff --git a/pufferlib/ocean/drive/drive.py b/pufferlib/ocean/drive/drive.py index 61da6c8d10..3dcc867ea5 100644 --- a/pufferlib/ocean/drive/drive.py +++ b/pufferlib/ocean/drive/drive.py @@ -88,6 +88,8 @@ def __init__( split_network=False, obs_slots_lane_n=32, obs_slots_boundary_n=32, + obs_lane_stride=1, + obs_boundary_stride=1, obs_slots_partners_n=16, obs_slots_traffic_controls_n=4, traffic_control_scope=0, @@ -181,8 +183,16 @@ def __init__( self.ego_features = binding.EGO_FEATURES # Extract observation shapes from constants + obs_lane_stride = int(obs_lane_stride) + obs_boundary_stride = int(obs_boundary_stride) + if obs_lane_stride < 1: + raise ValueError(f"obs_lane_stride must be >= 1. Got: {obs_lane_stride}") + if obs_boundary_stride < 1: + raise ValueError(f"obs_boundary_stride must be >= 1. Got: {obs_boundary_stride}") self.obs_slots_lane_n = obs_slots_lane_n self.obs_slots_boundary_n = obs_slots_boundary_n + self.obs_lane_stride = obs_lane_stride + self.obs_boundary_stride = obs_boundary_stride self.obs_slots_partners_n = obs_slots_partners_n self.traffic_control_scope = traffic_control_scope self.obs_slots_traffic_controls_n = obs_slots_traffic_controls_n @@ -435,6 +445,8 @@ def _env_init_kwargs(self, map_file, max_agents): "goal_on_lane": self.goal_on_lane, "obs_slots_lane_n": self.obs_slots_lane_n, "obs_slots_boundary_n": self.obs_slots_boundary_n, + "obs_lane_stride": self.obs_lane_stride, + "obs_boundary_stride": self.obs_boundary_stride, "obs_slots_partners_n": self.obs_slots_partners_n, "obs_slots_traffic_controls_n": self.obs_slots_traffic_controls_n, "traffic_control_scope": self.traffic_control_scope, diff --git a/pufferlib/ocean/env_config.h b/pufferlib/ocean/env_config.h index 89269a83cc..cb4d14a2e0 100644 --- a/pufferlib/ocean/env_config.h +++ b/pufferlib/ocean/env_config.h @@ -49,6 +49,8 @@ typedef struct { int max_agents_per_env; int obs_slots_lane_n; int obs_slots_boundary_n; + int obs_lane_stride; + int obs_boundary_stride; float obs_dropout_lane; float obs_dropout_boundary; int obs_slots_partners_n; @@ -195,6 +197,10 @@ static int handler(void *config, const char *section, const char *name, const ch env_config->obs_slots_boundary_n = atoi(value); } else if (MATCH("env", "obs_slots_lane_n")) { env_config->obs_slots_lane_n = atoi(value); + } else if (MATCH("env", "obs_lane_stride")) { + env_config->obs_lane_stride = atoi(value); + } else if (MATCH("env", "obs_boundary_stride")) { + env_config->obs_boundary_stride = atoi(value); } else if (MATCH("env", "obs_dropout_lane")) { env_config->obs_dropout_lane = atof(value); } else if (MATCH("env", "obs_dropout_boundary")) { diff --git a/pufferlib/pufferl.py b/pufferlib/pufferl.py index 8f505a6533..887f5dcf6a 100644 --- a/pufferlib/pufferl.py +++ b/pufferlib/pufferl.py @@ -1388,6 +1388,8 @@ def train(env_name, args=None, vecenv=None, policy=None, logger=None, early_stop "trajectory_scaling_factors", "obs_slots_boundary_n", "obs_slots_lane_n", + "obs_boundary_stride", + "obs_lane_stride", "obs_dropout_boundary", "obs_dropout_lane", "obs_slots_partners_n", From 113a69e268cba0b504423c4e257ad3b6357bb2fd Mon Sep 17 00:00:00 2001 From: Valentin Charraut Date: Tue, 9 Jun 2026 11:19:18 +0200 Subject: [PATCH 2/4] Add observation stride parameters and update related configurations --- notebooks/01_observations.py | 18 +++++++-- notebooks/02_rewards.py | 3 +- notebooks/03_metrics.py | 3 +- notebooks/05_inference.py | 38 +++++++++++++------ notebooks/notebook_utils.py | 2 + pufferlib/ocean/drive/drive.c | 4 ++ pufferlib/ocean/drive/visualize.c | 2 + pufferlib/pufferl.py | 2 + pufferlib/viz.py | 12 +++++- .../test_validation_replay_html.py | 2 + tests/unit_tests/test_drive_config.py | 36 +++++++++++++++++- tests/unit_tests/test_eval_manager.py | 6 +++ tests/unit_tests/test_map_cache.py | 26 ++++++++++++- 13 files changed, 132 insertions(+), 22 deletions(-) diff --git a/notebooks/01_observations.py b/notebooks/01_observations.py index 7a2f99d60b..80906fcdf7 100644 --- a/notebooks/01_observations.py +++ b/notebooks/01_observations.py @@ -38,6 +38,10 @@ print(f"NaN: {np.isnan(obs).sum()}, Inf: {np.isinf(obs).sum()}") print(f"% zeros: {(obs == 0).mean() * 100:.1f}%") print(f"% outside [-1,1]: {((obs < -1) | (obs > 1)).mean() * 100:.2f}%") +print( + f"road obs: lanes={env.obs_slots_lane_kept}/{env.obs_slots_lane_n} stride={env.obs_lane_stride}, " + f"boundaries={env.obs_slots_boundary_kept}/{env.obs_slots_boundary_n} stride={env.obs_boundary_stride}" +) fig, axes = plt.subplots(1, 2, figsize=(14, 4)) axes[0].hist(obs.flatten(), bins=100, edgecolor="black", alpha=0.7) @@ -61,9 +65,11 @@ reward_conditioning=env.reward_conditioning, num_target_waypoints=env.num_target_waypoints, obs_slots_partners_n=env.obs_slots_partners_n, - obs_slots_lane_n=env.obs_slots_lane_kept, - obs_slots_boundary_n=env.obs_slots_boundary_kept, + obs_slots_lane_n=env.obs_slots_lane_n, + obs_slots_boundary_n=env.obs_slots_boundary_n, obs_slots_traffic_controls_n=env.obs_slots_traffic_controls_n, + obs_dropout_lane=env.obs_dropout_lane, + obs_dropout_boundary=env.obs_dropout_boundary, ) print(f"ego: {ego.shape} = {ego}") print(f"target: {target.shape}") @@ -273,9 +279,13 @@ reward_conditioning=env.reward_conditioning, num_target_waypoints=env.num_target_waypoints, obs_slots_partners_n=env.obs_slots_partners_n, - obs_slots_lane_n=env.obs_slots_lane_kept, - obs_slots_boundary_n=env.obs_slots_boundary_kept, + obs_slots_lane_n=env.obs_slots_lane_n, + obs_slots_boundary_n=env.obs_slots_boundary_n, obs_slots_traffic_controls_n=env.obs_slots_traffic_controls_n, + obs_dropout_lane=env.obs_dropout_lane, + obs_dropout_boundary=env.obs_dropout_boundary, + obs_lane_stride=env.obs_lane_stride, + obs_boundary_stride=env.obs_boundary_stride, ) fig, ax = plt.subplots(figsize=(10, 10)) ax.imshow(img) diff --git a/notebooks/02_rewards.py b/notebooks/02_rewards.py index 536f2e0850..9457f97576 100644 --- a/notebooks/02_rewards.py +++ b/notebooks/02_rewards.py @@ -30,7 +30,8 @@ f"ego_features={env.ego_features}, num_reward_coefs={env.num_reward_coefs}, obs_slots_partners_n={env.obs_slots_partners_n}, partner_features={env.partner_features}" ) print( - f"obs_slots_lane_kept={env.obs_slots_lane_kept}, obs_slots_boundary_kept={env.obs_slots_boundary_kept}, road_features={env.road_features}" + f"obs_slots_lane_kept={env.obs_slots_lane_kept}, obs_slots_boundary_kept={env.obs_slots_boundary_kept}, " + f"road_features={env.road_features}, stride={env.obs_lane_stride}/{env.obs_boundary_stride}" ) print( f"obs_slots_traffic_controls_n={env.obs_slots_traffic_controls_n}, traffic_control_features={env.traffic_control_features}" diff --git a/notebooks/03_metrics.py b/notebooks/03_metrics.py index 3c04a7dd81..c07f029aa5 100644 --- a/notebooks/03_metrics.py +++ b/notebooks/03_metrics.py @@ -32,7 +32,8 @@ f"ego_features={env.ego_features}, num_reward_coefs={env.num_reward_coefs}, obs_slots_partners_n={env.obs_slots_partners_n}, partner_features={env.partner_features}" ) print( - f"obs_slots_lane_kept={env.obs_slots_lane_kept}, obs_slots_boundary_kept={env.obs_slots_boundary_kept}, road_features={env.road_features}" + f"obs_slots_lane_kept={env.obs_slots_lane_kept}, obs_slots_boundary_kept={env.obs_slots_boundary_kept}, " + f"road_features={env.road_features}, stride={env.obs_lane_stride}/{env.obs_boundary_stride}" ) print( f"obs_slots_traffic_controls_n={env.obs_slots_traffic_controls_n}, traffic_control_features={env.traffic_control_features}" diff --git a/notebooks/05_inference.py b/notebooks/05_inference.py index f074c28c90..a260fb9868 100644 --- a/notebooks/05_inference.py +++ b/notebooks/05_inference.py @@ -194,8 +194,12 @@ def run_rollout(env, policy, deterministic=False, horizon=HORIZON): reward_conditioning=rew_cond, num_target_waypoints=n_tgt_wp, obs_slots_partners_n=env.obs_slots_partners_n, - obs_slots_lane_n=env.obs_slots_lane_kept, - obs_slots_boundary_n=env.obs_slots_boundary_kept, + obs_slots_lane_n=env.obs_slots_lane_n, + obs_slots_boundary_n=env.obs_slots_boundary_n, + obs_dropout_lane=env.obs_dropout_lane, + obs_dropout_boundary=env.obs_dropout_boundary, + obs_lane_stride=env.obs_lane_stride, + obs_boundary_stride=env.obs_boundary_stride, obs_slots_traffic_controls_n=env.obs_slots_traffic_controls_n, obs_norm_goal_offset_m=env.obs_norm_goal_offset_m, obs_norm_xy_offset_m=env.obs_norm_xy_offset_m, @@ -227,8 +231,10 @@ def run_rollout(env, policy, deterministic=False, horizon=HORIZON): reward_conditioning=rew_cond, num_target_waypoints=n_tgt_wp, obs_slots_partners_n=env.obs_slots_partners_n, - obs_slots_lane_n=env.obs_slots_lane_kept, - obs_slots_boundary_n=env.obs_slots_boundary_kept, + obs_slots_lane_n=env.obs_slots_lane_n, + obs_slots_boundary_n=env.obs_slots_boundary_n, + obs_dropout_lane=env.obs_dropout_lane, + obs_dropout_boundary=env.obs_dropout_boundary, obs_slots_traffic_controls_n=env.obs_slots_traffic_controls_n, ) ego_features_over_time.append(ego) @@ -275,8 +281,10 @@ def run_rollout(env, policy, deterministic=False, horizon=HORIZON): reward_conditioning=rew_cond, num_target_waypoints=n_tgt_wp, obs_slots_partners_n=env.obs_slots_partners_n, - obs_slots_lane_n=env.obs_slots_lane_kept, - obs_slots_boundary_n=env.obs_slots_boundary_kept, + obs_slots_lane_n=env.obs_slots_lane_n, + obs_slots_boundary_n=env.obs_slots_boundary_n, + obs_dropout_lane=env.obs_dropout_lane, + obs_dropout_boundary=env.obs_dropout_boundary, obs_slots_traffic_controls_n=env.obs_slots_traffic_controls_n, ) @@ -464,8 +472,10 @@ def unpack_all_timesteps(bufs, agent_idx): reward_conditioning=rew_cond, num_target_waypoints=n_tgt_wp, obs_slots_partners_n=env.obs_slots_partners_n, - obs_slots_lane_n=env.obs_slots_lane_kept, - obs_slots_boundary_n=env.obs_slots_boundary_kept, + obs_slots_lane_n=env.obs_slots_lane_n, + obs_slots_boundary_n=env.obs_slots_boundary_n, + obs_dropout_lane=env.obs_dropout_lane, + obs_dropout_boundary=env.obs_dropout_boundary, obs_slots_traffic_controls_n=env.obs_slots_traffic_controls_n, ) egos.append(ego) @@ -540,8 +550,10 @@ def unpack_all_timesteps(bufs, agent_idx): reward_conditioning=rew_cond, num_target_waypoints=n_tgt_wp, obs_slots_partners_n=env.obs_slots_partners_n, - obs_slots_lane_n=env.obs_slots_lane_kept, - obs_slots_boundary_n=env.obs_slots_boundary_kept, + obs_slots_lane_n=env.obs_slots_lane_n, + obs_slots_boundary_n=env.obs_slots_boundary_n, + obs_dropout_lane=env.obs_dropout_lane, + obs_dropout_boundary=env.obs_dropout_boundary, obs_slots_traffic_controls_n=env.obs_slots_traffic_controls_n, ) dists = np.sqrt(part[:, 0] ** 2 + part[:, 1] ** 2) @@ -566,8 +578,10 @@ def unpack_all_timesteps(bufs, agent_idx): reward_conditioning=rew_cond, num_target_waypoints=n_tgt_wp, obs_slots_partners_n=env.obs_slots_partners_n, - obs_slots_lane_n=env.obs_slots_lane_kept, - obs_slots_boundary_n=env.obs_slots_boundary_kept, + obs_slots_lane_n=env.obs_slots_lane_n, + obs_slots_boundary_n=env.obs_slots_boundary_n, + obs_dropout_lane=env.obs_dropout_lane, + obs_dropout_boundary=env.obs_dropout_boundary, obs_slots_traffic_controls_n=env.obs_slots_traffic_controls_n, ) diff --git a/notebooks/notebook_utils.py b/notebooks/notebook_utils.py index df3c8dc063..38071d3014 100644 --- a/notebooks/notebook_utils.py +++ b/notebooks/notebook_utils.py @@ -73,6 +73,8 @@ "offroad_behavior": 1, "obs_slots_lane_n": 80, "obs_slots_boundary_n": 80, + "obs_lane_stride": 1, + "obs_boundary_stride": 1, "obs_slots_partners_n": 16, "obs_slots_traffic_controls_n": 4, "obs_dropout_lane": 0.0, diff --git a/pufferlib/ocean/drive/drive.c b/pufferlib/ocean/drive/drive.c index c25b61d789..5579740c1b 100644 --- a/pufferlib/ocean/drive/drive.c +++ b/pufferlib/ocean/drive/drive.c @@ -123,6 +123,8 @@ void demo() { .compute_eval_metrics = conf.compute_eval_metrics, .obs_slots_lane_n = conf.obs_slots_lane_n, .obs_slots_boundary_n = conf.obs_slots_boundary_n, + .obs_lane_stride = conf.obs_lane_stride, + .obs_boundary_stride = conf.obs_boundary_stride, .obs_slots_partners_n = conf.obs_slots_partners_n, .obs_slots_traffic_controls_n = conf.obs_slots_traffic_controls_n, .traffic_control_scope = conf.traffic_control_scope, @@ -246,6 +248,8 @@ void performance_test() { .num_max_agents = conf.max_agents_per_env, .obs_slots_lane_n = conf.obs_slots_lane_n, .obs_slots_boundary_n = conf.obs_slots_boundary_n, + .obs_lane_stride = conf.obs_lane_stride, + .obs_boundary_stride = conf.obs_boundary_stride, .obs_slots_partners_n = conf.obs_slots_partners_n, .obs_slots_traffic_controls_n = conf.obs_slots_traffic_controls_n, .traffic_control_scope = conf.traffic_control_scope, diff --git a/pufferlib/ocean/drive/visualize.c b/pufferlib/ocean/drive/visualize.c index 3c60f0c2f9..3c26fe33a1 100644 --- a/pufferlib/ocean/drive/visualize.c +++ b/pufferlib/ocean/drive/visualize.c @@ -294,6 +294,8 @@ int eval_gif( .collision_behavior = conf.collision_behavior, .offroad_behavior = conf.offroad_behavior, .compute_eval_metrics = conf.compute_eval_metrics, + .obs_lane_stride = conf.obs_lane_stride, + .obs_boundary_stride = conf.obs_boundary_stride, .goal_behavior = goal_behavior, .init_mode = init_mode, .control_mode = control_mode, diff --git a/pufferlib/pufferl.py b/pufferlib/pufferl.py index 887f5dcf6a..eeb8364053 100644 --- a/pufferlib/pufferl.py +++ b/pufferlib/pufferl.py @@ -1583,6 +1583,8 @@ def train(env_name, args=None, vecenv=None, policy=None, logger=None, early_stop "obs_slots_partners_n", "obs_slots_lane_n", "obs_slots_boundary_n", + "obs_lane_stride", + "obs_boundary_stride", "obs_slots_traffic_controls_n", "obs_dropout_lane", "obs_dropout_boundary", diff --git a/pufferlib/viz.py b/pufferlib/viz.py index eadea0b883..28ffa8f9f4 100644 --- a/pufferlib/viz.py +++ b/pufferlib/viz.py @@ -615,6 +615,8 @@ def plot_observation( obs_slots_traffic_controls_n=4, obs_dropout_lane=0.0, obs_dropout_boundary=0.0, + obs_lane_stride=1, + obs_boundary_stride=1, agent_idx=0, obs_norm_goal_offset_m=100.0, obs_norm_xy_offset_m=100.0, @@ -785,7 +787,7 @@ def plot_observation( ax.text( 0.12, 0.95, - f"Lanes: {count_lane}\nBoundaries: {count_boundary}", + f"Lanes: {count_lane}\nBoundaries: {count_boundary}\nStride: {obs_lane_stride}/{obs_boundary_stride}", transform=ax.transAxes, fontsize=10, verticalalignment="top", @@ -957,6 +959,12 @@ def generate_interactive_replay(scenario, replay, filename="replay.html"): "partner_features": int(binding.PARTNER_FEATURES), "lane_count": int(lane_count), "boundary_count": int(boundary_count), + "obs_slots_lane_n": int(env_cfg["obs_slots_lane_n"]), + "obs_slots_boundary_n": int(env_cfg["obs_slots_boundary_n"]), + "obs_dropout_lane": float(env_cfg.get("obs_dropout_lane", 0.0)), + "obs_dropout_boundary": float(env_cfg.get("obs_dropout_boundary", 0.0)), + "obs_lane_stride": int(env_cfg.get("obs_lane_stride", 1)), + "obs_boundary_stride": int(env_cfg.get("obs_boundary_stride", 1)), "traffic_obs_count": int(env_cfg["obs_slots_traffic_controls_n"]), "target_features": 3 if env_cfg.get("target_type", "static") == "static" else 5, "scales": scales, @@ -1029,6 +1037,7 @@ def generate_interactive_replay(scenario, replay, filename="replay.html"):
ID
-
Step
0
Camera
Free Roam
+
Obs Road
-
@@ -1101,6 +1110,7 @@ def generate_interactive_replay(scenario, replay, filename="replay.html"): for (const name of Object.keys(H.chunks)) C[name] = chunk(name); document.getElementById('meta-map').textContent = String(H.map_name).split('/').pop(); document.getElementById('meta-id').textContent = H.scenario_id || "-"; + document.getElementById('meta-obs-road').textContent = `L ${H.lane_count}/${H.obs_slots_lane_n} s${H.obs_lane_stride} d${Number(H.obs_dropout_lane).toFixed(2)} | B ${H.boundary_count}/${H.obs_slots_boundary_n} s${H.obs_boundary_stride} d${Number(H.obs_dropout_boundary).toFixed(2)}`; document.getElementById('sld').max = frameMax(); const first = getFrameAgents(0)[0]; if (first) { cam.x = first.x; cam.y = first.y; } document.getElementById('loading-overlay').style.display = 'none'; diff --git a/tests/smoke_tests/test_validation_replay_html.py b/tests/smoke_tests/test_validation_replay_html.py index 751ecd3c82..c5d154ddb1 100644 --- a/tests/smoke_tests/test_validation_replay_html.py +++ b/tests/smoke_tests/test_validation_replay_html.py @@ -103,6 +103,8 @@ def test_html_render_backend_produces_html(tmp_path, backend): "obs_slots_partners_n": 16, "obs_slots_lane_n": 80, "obs_slots_boundary_n": 80, + "obs_lane_stride": 2, + "obs_boundary_stride": 3, "obs_slots_traffic_controls_n": 4, }, "eval": { diff --git a/tests/unit_tests/test_drive_config.py b/tests/unit_tests/test_drive_config.py index 566444557f..2510e68e79 100644 --- a/tests/unit_tests/test_drive_config.py +++ b/tests/unit_tests/test_drive_config.py @@ -8,13 +8,15 @@ import os import sys +import tempfile import unittest from unittest.mock import patch from pathlib import Path sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from pufferlib.pufferl import load_config, pufferlib +from pufferlib.ocean.drive.drive import Drive +from pufferlib.pufferl import _ARCH_ENV_KEYS, _merge_checkpoint_arch, load_config, pufferlib VERBOSITY = 0 @@ -36,10 +38,42 @@ def test_load_config(self): # load_config should return a populated config dict without raising. self.assertIsInstance(args, dict) self.assertTrue(len(args) > 0) + self.assertEqual(args["env"]["obs_lane_stride"], 1) + self.assertEqual(args["env"]["obs_boundary_stride"], 1) except Exception as err: self.fail(f"load_config failed with an unexpected exception: {err}") + @patch("sys.argv", ["pufferl.py", "--env.obs-lane-stride=3", "--env.obs-boundary-stride=4"]) + def test_obs_stride_cli_override(self): + args = load_config("puffer_drive") + self.assertEqual(args["env"]["obs_lane_stride"], 3) + self.assertEqual(args["env"]["obs_boundary_stride"], 4) + + def test_obs_stride_validation(self): + with self.assertRaisesRegex(ValueError, "obs_lane_stride"): + Drive(obs_lane_stride=0) + with self.assertRaisesRegex(ValueError, "obs_boundary_stride"): + Drive(obs_boundary_stride=0) + + def test_checkpoint_arch_merge_keeps_obs_stride(self): + with tempfile.TemporaryDirectory() as tmp_dir: + exp_dir = Path(tmp_dir) + models_dir = exp_dir / "models" + models_dir.mkdir() + model_path = models_dir / "model.pt" + model_path.touch() + with open(exp_dir / "config.yaml", "w") as f: + f.write("env:\n obs_lane_stride: 5\n obs_boundary_stride: 6\n") + + args = {"env": {"obs_lane_stride": 1, "obs_boundary_stride": 1}, "train": {}} + _merge_checkpoint_arch(args, str(model_path)) + + self.assertIn("obs_lane_stride", _ARCH_ENV_KEYS) + self.assertIn("obs_boundary_stride", _ARCH_ENV_KEYS) + self.assertEqual(args["env"]["obs_lane_stride"], 5) + self.assertEqual(args["env"]["obs_boundary_stride"], 6) + @patch("sys.argv", ["pufferl.py", "--train.learning-rate=0.5"]) def test_cli_override(self): """Test that command-line arguments override INI file values.""" diff --git a/tests/unit_tests/test_eval_manager.py b/tests/unit_tests/test_eval_manager.py index 4d3a470994..2a5b3b541c 100644 --- a/tests/unit_tests/test_eval_manager.py +++ b/tests/unit_tests/test_eval_manager.py @@ -156,10 +156,12 @@ def test_clean_macro_loses_to_explicit_override(): "foo": { "type": "multi_scenario", "env.obs_dropout_lane": 0.5, # explicit > macro default of 0.0 + "env.obs_lane_stride": 3, } } cfg = _build_section_config("foo", sections["foo"], sections) assert cfg["env"]["obs_dropout_lane"] == 0.5 + assert cfg["env"]["obs_lane_stride"] == 3 def test_manager_from_config_skips_template_sections(): @@ -623,6 +625,8 @@ def test_eval_args_compose_train_section_and_clean_macro(): train_config = { "env": { "obs_dropout_lane": 0.5, # training perturbation + "obs_lane_stride": 4, + "obs_boundary_stride": 5, "scenario_length": 91, "num_agents": 1024, # only present in train baseline }, @@ -643,6 +647,8 @@ def test_eval_args_compose_train_section_and_clean_macro(): assert args["env"]["scenario_length"] == 201, "section override wins" assert args["env"]["obs_dropout_lane"] == 0.0, "clean macro applied" + assert args["env"]["obs_lane_stride"] == 1, "clean macro applied" + assert args["env"]["obs_boundary_stride"] == 1, "clean macro applied" assert args["env"]["num_agents"] == 1024, "train baseline preserved" diff --git a/tests/unit_tests/test_map_cache.py b/tests/unit_tests/test_map_cache.py index 1b3c15790e..26b8cd0e50 100644 --- a/tests/unit_tests/test_map_cache.py +++ b/tests/unit_tests/test_map_cache.py @@ -1,8 +1,8 @@ """Tests for the per-process map cache (use_map_cache env knob). Covers: cache-on vs cache-off observation parity, refcount discipline across -close orderings, slot-reuse keeping cache size bounded, and per-entry owner_pid -correctness in a forked child. +close orderings, stride-specific cache keys, slot-reuse keeping cache size +bounded, and per-entry owner_pid correctness in a forked child. """ import os @@ -77,6 +77,28 @@ def test_obs_reward_done_parity_cache_on_vs_off(): np.testing.assert_array_equal(off_trunc[t], on_trunc[t], err_msg=f"truncations diverged at step {t}") +def test_map_cache_key_includes_road_obs_stride(): + env_a = None + env_b = None + try: + env_a = _make_drive(use_map_cache=1, obs_lane_stride=1, obs_boundary_stride=1) + env_a.reset(seed=0) + live_after_a = drive_binding.map_cache_live_count() + + env_b = _make_drive(use_map_cache=1, obs_lane_stride=2, obs_boundary_stride=3) + env_b.reset(seed=0) + live_after_b = drive_binding.map_cache_live_count() + + assert live_after_b == live_after_a + 1, ( + "Map cache reused a grid built with different obs_lane_stride/obs_boundary_stride." + ) + finally: + if env_b is not None: + env_b.close() + if env_a is not None: + env_a.close() + + @pytest.mark.parametrize( "close_order", [ From ed75e669a160db4cdd03aca9ea79031a4270a1d5 Mon Sep 17 00:00:00 2001 From: Valentin Charraut Date: Thu, 11 Jun 2026 10:14:28 +0200 Subject: [PATCH 3/4] Add heading deviation threshold and update observation stride logic --- pufferlib/ocean/drive/drive.h | 36 +++++++++++++++++++++++++---------- 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/pufferlib/ocean/drive/drive.h b/pufferlib/ocean/drive/drive.h index 0cbcdc533f..840417d1b8 100644 --- a/pufferlib/ocean/drive/drive.h +++ b/pufferlib/ocean/drive/drive.h @@ -118,6 +118,8 @@ // Depends on resolution of data Formula: 3 * (2 + GRID_CELL_SIZE*sqrt(2)/resolution) // => For each entity type in gridmap, diagonal poly-lines -> sqrt(2), include diagonal ends -> 2 #define MAX_ENTITIES_PER_CELL 30 +// Heading deviation since last kept point that forces a keep when obs stride > 1 (~30 degrees). +#define OBS_STRIDE_HEADING_THRESHOLD 0.5236f // TARGET_TYPE modes (controls what target info is in observations). #define TARGET_STATIC 0 @@ -693,6 +695,7 @@ static void add_entity_to_grid( int grid_index, int entity_idx, int geometry_idx, + int valid_for_obs, int *cell_entities_insert_index) { if (grid_index == -1) { return; @@ -711,12 +714,7 @@ static void add_entity_to_grid( env->grid_map->cells[grid_index][count].entity_idx = entity_idx; env->grid_map->cells[grid_index][count].geometry_idx = geometry_idx; - env->grid_map->cells[grid_index][count].valid_for_obs = 1; - if (is_road_lane(env->road_elements[entity_idx].type)) { - env->grid_map->cells[grid_index][count].valid_for_obs = geometry_idx % env->obs_lane_stride == 0; - } else if (is_road_edge(env->road_elements[entity_idx].type)) { - env->grid_map->cells[grid_index][count].valid_for_obs = geometry_idx % env->obs_boundary_stride == 0; - } + env->grid_map->cells[grid_index][count].valid_for_obs = valid_for_obs; cell_entities_insert_index[grid_index] = count + 1; } @@ -813,14 +811,32 @@ static void init_grid_map(Drive *env) { // Populate grid cells and count unique drivable grid cells for (int i = 0; i < env->num_road_elements; i++) { - for (int j = 0; j < env->road_elements[i].segment_length - 1; j++) { - float x_center = (env->road_elements[i].x[j] + env->road_elements[i].x[j + 1]) / 2; - float y_center = (env->road_elements[i].y[j] + env->road_elements[i].y[j + 1]) / 2; + RoadMapElement *element = &env->road_elements[i]; + int obs_stride = 1; + if (is_road_lane(element->type)) { + obs_stride = env->obs_lane_stride; + } else if (is_road_edge(element->type)) { + obs_stride = env->obs_boundary_stride; + } + int last_kept_idx = 0; + for (int j = 0; j < element->segment_length - 1; j++) { + // Keep a point every obs_stride points, plus wherever heading deviates enough + // since the last kept point (densifies curves/intersections) + int valid_for_obs = 1; + if (obs_stride > 1 && j > 0) { + float heading_dev = fabsf(normalize_heading(element->headings[j] - element->headings[last_kept_idx])); + valid_for_obs = j - last_kept_idx >= obs_stride || heading_dev > OBS_STRIDE_HEADING_THRESHOLD; + } + if (valid_for_obs) { + last_kept_idx = j; + } + float x_center = (element->x[j] + element->x[j + 1]) / 2; + float y_center = (element->y[j] + element->y[j + 1]) / 2; int grid_index = get_grid_index(env, x_center, y_center); if (grid_index == -1) { continue; // Skip out-of-bounds entities } - add_entity_to_grid(env, grid_index, i, j, cell_entities_insert_index); + add_entity_to_grid(env, grid_index, i, j, valid_for_obs, cell_entities_insert_index); // Count unique drivable grid cells if (is_drivable_road_lane(env->road_elements[i].type) && !drivable_grid_seen[grid_index]) { drivable_grid_seen[grid_index] = true; From c067307a1b6260dc23f3f4c2c992c6316abd7b93 Mon Sep 17 00:00:00 2001 From: Valentin Charraut Date: Thu, 11 Jun 2026 18:33:58 +0200 Subject: [PATCH 4/4] Remove obs_lane_stride and obs_boundary_stride from train section in eval_args test --- tests/unit_tests/test_eval_manager.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/unit_tests/test_eval_manager.py b/tests/unit_tests/test_eval_manager.py index 2a5b3b541c..07628fb8bf 100644 --- a/tests/unit_tests/test_eval_manager.py +++ b/tests/unit_tests/test_eval_manager.py @@ -625,8 +625,6 @@ def test_eval_args_compose_train_section_and_clean_macro(): train_config = { "env": { "obs_dropout_lane": 0.5, # training perturbation - "obs_lane_stride": 4, - "obs_boundary_stride": 5, "scenario_length": 91, "num_agents": 1024, # only present in train baseline }, @@ -647,8 +645,6 @@ def test_eval_args_compose_train_section_and_clean_macro(): assert args["env"]["scenario_length"] == 201, "section override wins" assert args["env"]["obs_dropout_lane"] == 0.0, "clean macro applied" - assert args["env"]["obs_lane_stride"] == 1, "clean macro applied" - assert args["env"]["obs_boundary_stride"] == 1, "clean macro applied" assert args["env"]["num_agents"] == 1024, "train baseline preserved"