diff --git a/notebooks/01_observations.py b/notebooks/01_observations.py index 99608478e1..38d123bb96 100644 --- a/notebooks/01_observations.py +++ b/notebooks/01_observations.py @@ -38,6 +38,10 @@ print(f"NaN: {np.isnan(obs).sum()}, Inf: {np.isinf(obs).sum()}") print(f"% zeros: {(obs == 0).mean() * 100:.1f}%") print(f"% outside [-1,1]: {((obs < -1) | (obs > 1)).mean() * 100:.2f}%") +print( + f"road obs: lanes={env.obs_slots_lane_kept}/{env.obs_slots_lane_n} stride={env.obs_lane_stride}, " + f"boundaries={env.obs_slots_boundary_kept}/{env.obs_slots_boundary_n} stride={env.obs_boundary_stride}" +) fig, axes = plt.subplots(1, 2, figsize=(14, 4)) axes[0].hist(obs.flatten(), bins=100, edgecolor="black", alpha=0.7) @@ -61,9 +65,11 @@ reward_conditioning=env.reward_conditioning, num_target_waypoints=env.num_target_waypoints, obs_slots_partners_n=env.obs_slots_partners_n, - obs_slots_lane_n=env.obs_slots_lane_kept, - obs_slots_boundary_n=env.obs_slots_boundary_kept, + obs_slots_lane_n=env.obs_slots_lane_n, + obs_slots_boundary_n=env.obs_slots_boundary_n, obs_slots_traffic_controls_n=env.obs_slots_traffic_controls_n, + obs_dropout_lane=env.obs_dropout_lane, + obs_dropout_boundary=env.obs_dropout_boundary, ) print(f"ego: {ego.shape} = {ego}") print(f"target: {target.shape}") @@ -274,9 +280,13 @@ reward_conditioning=env.reward_conditioning, num_target_waypoints=env.num_target_waypoints, obs_slots_partners_n=env.obs_slots_partners_n, - obs_slots_lane_n=env.obs_slots_lane_kept, - obs_slots_boundary_n=env.obs_slots_boundary_kept, + obs_slots_lane_n=env.obs_slots_lane_n, + obs_slots_boundary_n=env.obs_slots_boundary_n, obs_slots_traffic_controls_n=env.obs_slots_traffic_controls_n, + obs_dropout_lane=env.obs_dropout_lane, + obs_dropout_boundary=env.obs_dropout_boundary, + obs_lane_stride=env.obs_lane_stride, + obs_boundary_stride=env.obs_boundary_stride, ) fig, ax = plt.subplots(figsize=(10, 10)) ax.imshow(img) diff --git a/notebooks/02_rewards.py b/notebooks/02_rewards.py index 536f2e0850..9457f97576 100644 --- a/notebooks/02_rewards.py +++ b/notebooks/02_rewards.py @@ -30,7 +30,8 @@ f"ego_features={env.ego_features}, num_reward_coefs={env.num_reward_coefs}, obs_slots_partners_n={env.obs_slots_partners_n}, partner_features={env.partner_features}" ) print( - f"obs_slots_lane_kept={env.obs_slots_lane_kept}, obs_slots_boundary_kept={env.obs_slots_boundary_kept}, road_features={env.road_features}" + f"obs_slots_lane_kept={env.obs_slots_lane_kept}, obs_slots_boundary_kept={env.obs_slots_boundary_kept}, " + f"road_features={env.road_features}, stride={env.obs_lane_stride}/{env.obs_boundary_stride}" ) print( f"obs_slots_traffic_controls_n={env.obs_slots_traffic_controls_n}, traffic_control_features={env.traffic_control_features}" diff --git a/notebooks/03_metrics.py b/notebooks/03_metrics.py index 3c04a7dd81..c07f029aa5 100644 --- a/notebooks/03_metrics.py +++ b/notebooks/03_metrics.py @@ -32,7 +32,8 @@ f"ego_features={env.ego_features}, num_reward_coefs={env.num_reward_coefs}, obs_slots_partners_n={env.obs_slots_partners_n}, partner_features={env.partner_features}" ) print( - f"obs_slots_lane_kept={env.obs_slots_lane_kept}, obs_slots_boundary_kept={env.obs_slots_boundary_kept}, road_features={env.road_features}" + f"obs_slots_lane_kept={env.obs_slots_lane_kept}, obs_slots_boundary_kept={env.obs_slots_boundary_kept}, " + f"road_features={env.road_features}, stride={env.obs_lane_stride}/{env.obs_boundary_stride}" ) print( f"obs_slots_traffic_controls_n={env.obs_slots_traffic_controls_n}, traffic_control_features={env.traffic_control_features}" diff --git a/notebooks/05_inference.py b/notebooks/05_inference.py index 8ab2896493..c0c2e9e814 100644 --- a/notebooks/05_inference.py +++ b/notebooks/05_inference.py @@ -194,8 +194,12 @@ def run_rollout(env, policy, deterministic=False, horizon=HORIZON): reward_conditioning=rew_cond, num_target_waypoints=n_tgt_wp, obs_slots_partners_n=env.obs_slots_partners_n, - obs_slots_lane_n=env.obs_slots_lane_kept, - obs_slots_boundary_n=env.obs_slots_boundary_kept, + obs_slots_lane_n=env.obs_slots_lane_n, + obs_slots_boundary_n=env.obs_slots_boundary_n, + obs_dropout_lane=env.obs_dropout_lane, + obs_dropout_boundary=env.obs_dropout_boundary, + obs_lane_stride=env.obs_lane_stride, + obs_boundary_stride=env.obs_boundary_stride, obs_slots_traffic_controls_n=env.obs_slots_traffic_controls_n, obs_norm_goal_offset_m=env.obs_norm_goal_offset_m, obs_norm_xy_offset_m=env.obs_norm_xy_offset_m, @@ -227,8 +231,10 @@ def run_rollout(env, policy, deterministic=False, horizon=HORIZON): reward_conditioning=rew_cond, num_target_waypoints=n_tgt_wp, obs_slots_partners_n=env.obs_slots_partners_n, - obs_slots_lane_n=env.obs_slots_lane_kept, - obs_slots_boundary_n=env.obs_slots_boundary_kept, + obs_slots_lane_n=env.obs_slots_lane_n, + obs_slots_boundary_n=env.obs_slots_boundary_n, + obs_dropout_lane=env.obs_dropout_lane, + obs_dropout_boundary=env.obs_dropout_boundary, obs_slots_traffic_controls_n=env.obs_slots_traffic_controls_n, ) ego_features_over_time.append(ego) @@ -275,8 +281,10 @@ def run_rollout(env, policy, deterministic=False, horizon=HORIZON): reward_conditioning=rew_cond, num_target_waypoints=n_tgt_wp, obs_slots_partners_n=env.obs_slots_partners_n, - obs_slots_lane_n=env.obs_slots_lane_kept, - obs_slots_boundary_n=env.obs_slots_boundary_kept, + obs_slots_lane_n=env.obs_slots_lane_n, + obs_slots_boundary_n=env.obs_slots_boundary_n, + obs_dropout_lane=env.obs_dropout_lane, + obs_dropout_boundary=env.obs_dropout_boundary, obs_slots_traffic_controls_n=env.obs_slots_traffic_controls_n, ) @@ -464,8 +472,10 @@ def unpack_all_timesteps(bufs, agent_idx): reward_conditioning=rew_cond, num_target_waypoints=n_tgt_wp, obs_slots_partners_n=env.obs_slots_partners_n, - obs_slots_lane_n=env.obs_slots_lane_kept, - obs_slots_boundary_n=env.obs_slots_boundary_kept, + obs_slots_lane_n=env.obs_slots_lane_n, + obs_slots_boundary_n=env.obs_slots_boundary_n, + obs_dropout_lane=env.obs_dropout_lane, + obs_dropout_boundary=env.obs_dropout_boundary, obs_slots_traffic_controls_n=env.obs_slots_traffic_controls_n, ) egos.append(ego) @@ -540,8 +550,10 @@ def unpack_all_timesteps(bufs, agent_idx): reward_conditioning=rew_cond, num_target_waypoints=n_tgt_wp, obs_slots_partners_n=env.obs_slots_partners_n, - obs_slots_lane_n=env.obs_slots_lane_kept, - obs_slots_boundary_n=env.obs_slots_boundary_kept, + obs_slots_lane_n=env.obs_slots_lane_n, + obs_slots_boundary_n=env.obs_slots_boundary_n, + obs_dropout_lane=env.obs_dropout_lane, + obs_dropout_boundary=env.obs_dropout_boundary, obs_slots_traffic_controls_n=env.obs_slots_traffic_controls_n, ) dists = np.sqrt(part[:, 0] ** 2 + part[:, 1] ** 2) @@ -566,8 +578,10 @@ def unpack_all_timesteps(bufs, agent_idx): reward_conditioning=rew_cond, num_target_waypoints=n_tgt_wp, obs_slots_partners_n=env.obs_slots_partners_n, - obs_slots_lane_n=env.obs_slots_lane_kept, - obs_slots_boundary_n=env.obs_slots_boundary_kept, + obs_slots_lane_n=env.obs_slots_lane_n, + obs_slots_boundary_n=env.obs_slots_boundary_n, + obs_dropout_lane=env.obs_dropout_lane, + obs_dropout_boundary=env.obs_dropout_boundary, obs_slots_traffic_controls_n=env.obs_slots_traffic_controls_n, ) diff --git a/notebooks/notebook_utils.py b/notebooks/notebook_utils.py index 5afe94b2ee..850fc315c2 100644 --- a/notebooks/notebook_utils.py +++ b/notebooks/notebook_utils.py @@ -73,6 +73,8 @@ "offroad_behavior": 1, "obs_slots_lane_n": 80, "obs_slots_boundary_n": 80, + "obs_lane_stride": 1, + "obs_boundary_stride": 1, "obs_slots_partners_n": 16, "obs_slots_traffic_controls_n": 4, "obs_dropout_lane": 0.0, diff --git a/pufferlib/config/ocean/drive.ini b/pufferlib/config/ocean/drive.ini index cd6a7c4940..5283573d38 100644 --- a/pufferlib/config/ocean/drive.ini +++ b/pufferlib/config/ocean/drive.ini @@ -141,6 +141,9 @@ obs_slots_traffic_controls_n = 4 ; Fraction of segment observation slots to drop (reduces obs size) obs_dropout_lane = 0.0 obs_dropout_boundary = 0.0 +; Stride for lane and boundary segments; 1 means use every segment, 2 means use every other segment, etc. +obs_lane_stride = 1 +obs_boundary_stride = 1 ; --- Observation normalization --- obs_norm_goal_offset_m = 120.0 obs_norm_xy_offset_m = 120.0 diff --git a/pufferlib/ocean/drive/binding.c b/pufferlib/ocean/drive/binding.c index afc952acca..336956a10c 100644 --- a/pufferlib/ocean/drive/binding.c +++ b/pufferlib/ocean/drive/binding.c @@ -1979,6 +1979,8 @@ static int my_init(Env *env, PyObject *args, PyObject *kwargs) { env->obs_slots_partners_n = (int) unpack(kwargs, "obs_slots_partners_n"); env->obs_slots_traffic_controls_n = (int) unpack(kwargs, "obs_slots_traffic_controls_n"); env->traffic_control_scope = (int) unpack(kwargs, "traffic_control_scope"); + env->obs_lane_stride = (int) unpack(kwargs, "obs_lane_stride"); + env->obs_boundary_stride = (int) unpack(kwargs, "obs_boundary_stride"); env->dt = (float) unpack(kwargs, "dt"); env->spawn_initial_speed = (float) unpack(kwargs, "spawn_initial_speed"); env->goal_speed = (float) unpack(kwargs, "goal_speed"); diff --git a/pufferlib/ocean/drive/drive.c b/pufferlib/ocean/drive/drive.c index c25b61d789..5579740c1b 100644 --- a/pufferlib/ocean/drive/drive.c +++ b/pufferlib/ocean/drive/drive.c @@ -123,6 +123,8 @@ void demo() { .compute_eval_metrics = conf.compute_eval_metrics, .obs_slots_lane_n = conf.obs_slots_lane_n, .obs_slots_boundary_n = conf.obs_slots_boundary_n, + .obs_lane_stride = conf.obs_lane_stride, + .obs_boundary_stride = conf.obs_boundary_stride, .obs_slots_partners_n = conf.obs_slots_partners_n, .obs_slots_traffic_controls_n = conf.obs_slots_traffic_controls_n, .traffic_control_scope = conf.traffic_control_scope, @@ -246,6 +248,8 @@ void performance_test() { .num_max_agents = conf.max_agents_per_env, .obs_slots_lane_n = conf.obs_slots_lane_n, .obs_slots_boundary_n = conf.obs_slots_boundary_n, + .obs_lane_stride = conf.obs_lane_stride, + .obs_boundary_stride = conf.obs_boundary_stride, .obs_slots_partners_n = conf.obs_slots_partners_n, .obs_slots_traffic_controls_n = conf.obs_slots_traffic_controls_n, .traffic_control_scope = conf.traffic_control_scope, diff --git a/pufferlib/ocean/drive/drive.h b/pufferlib/ocean/drive/drive.h index 019e74cbdd..c548240903 100644 --- a/pufferlib/ocean/drive/drive.h +++ b/pufferlib/ocean/drive/drive.h @@ -118,6 +118,8 @@ // Depends on resolution of data Formula: 3 * (2 + GRID_CELL_SIZE*sqrt(2)/resolution) // => For each entity type in gridmap, diagonal poly-lines -> sqrt(2), include diagonal ends -> 2 #define MAX_ENTITIES_PER_CELL 30 +// Heading deviation since last kept point that forces a keep when obs stride > 1 (~30 degrees). +#define OBS_STRIDE_HEADING_THRESHOLD 0.5236f // TARGET_TYPE modes (controls what target info is in observations). #define TARGET_STATIC 0 @@ -290,8 +292,9 @@ struct Log { }; struct GridMapEntity { - int entity_idx; // Index into the road_elements array - int geometry_idx; // Index into element's geometry array + int entity_idx; // Index into the road_elements array + int geometry_idx; // Index into element's geometry array + int valid_for_obs; // Whether this entity should be included in observations }; struct GridMap { @@ -328,6 +331,8 @@ struct SharedMapData { GridMap *grid_map; int *neighbor_offsets; struct LaneGraph lane_graph; + int obs_lane_stride; + int obs_boundary_stride; int ref_count; pid_t owner_pid; }; @@ -440,6 +445,8 @@ struct Drive { int obs_slots_partners_n; int obs_slots_traffic_controls_n; int traffic_control_scope; + int obs_lane_stride; + int obs_boundary_stride; int obs_slots_lane_kept; int obs_slots_boundary_kept; int road_dropout_enabled; @@ -688,6 +695,7 @@ static void add_entity_to_grid( int grid_index, int entity_idx, int geometry_idx, + int valid_for_obs, int *cell_entities_insert_index) { if (grid_index == -1) { return; @@ -706,6 +714,7 @@ static void add_entity_to_grid( env->grid_map->cells[grid_index][count].entity_idx = entity_idx; env->grid_map->cells[grid_index][count].geometry_idx = geometry_idx; + env->grid_map->cells[grid_index][count].valid_for_obs = valid_for_obs; cell_entities_insert_index[grid_index] = count + 1; } @@ -802,14 +811,32 @@ static void init_grid_map(Drive *env) { // Populate grid cells and count unique drivable grid cells for (int i = 0; i < env->num_road_elements; i++) { - for (int j = 0; j < env->road_elements[i].segment_length - 1; j++) { - float x_center = (env->road_elements[i].x[j] + env->road_elements[i].x[j + 1]) / 2; - float y_center = (env->road_elements[i].y[j] + env->road_elements[i].y[j + 1]) / 2; + RoadMapElement *element = &env->road_elements[i]; + int obs_stride = 1; + if (is_road_lane(element->type)) { + obs_stride = env->obs_lane_stride; + } else if (is_road_edge(element->type)) { + obs_stride = env->obs_boundary_stride; + } + int last_kept_idx = 0; + for (int j = 0; j < element->segment_length - 1; j++) { + // Keep a point every obs_stride points, plus wherever heading deviates enough + // since the last kept point (densifies curves/intersections) + int valid_for_obs = 1; + if (obs_stride > 1 && j > 0) { + float heading_dev = fabsf(normalize_heading(element->headings[j] - element->headings[last_kept_idx])); + valid_for_obs = j - last_kept_idx >= obs_stride || heading_dev > OBS_STRIDE_HEADING_THRESHOLD; + } + if (valid_for_obs) { + last_kept_idx = j; + } + float x_center = (element->x[j] + element->x[j + 1]) / 2; + float y_center = (element->y[j] + element->y[j + 1]) / 2; int grid_index = get_grid_index(env, x_center, y_center); if (grid_index == -1) { continue; // Skip out-of-bounds entities } - add_entity_to_grid(env, grid_index, i, j, cell_entities_insert_index); + add_entity_to_grid(env, grid_index, i, j, valid_for_obs, cell_entities_insert_index); // Count unique drivable grid cells if (is_drivable_road_lane(env->road_elements[i].type) && !drivable_grid_seen[grid_index]) { drivable_grid_seen[grid_index] = true; @@ -968,6 +995,7 @@ static int get_neighbors_entities( for (int j = 0; j < count && entity_list_count < max_size; j++) { entity_list[entity_list_count].entity_idx = env->grid_map->cells[neighbor_idx][j].entity_idx; entity_list[entity_list_count].geometry_idx = env->grid_map->cells[neighbor_idx][j].geometry_idx; + entity_list[entity_list_count].valid_for_obs = env->grid_map->cells[neighbor_idx][j].valid_for_obs; entity_list_count += 1; } } @@ -3665,9 +3693,11 @@ void remove_bad_trajectories(Drive *env) { env->timestep = 0; } -static struct SharedMapData *map_cache_lookup(const char *map_name) { +static struct SharedMapData *map_cache_lookup(Drive *env) { for (int i = 0; i < g_map_cache_count; i++) { - if (g_map_cache[i] != NULL && strcmp(g_map_cache[i]->map_name, map_name) == 0) { + if (g_map_cache[i] != NULL && strcmp(g_map_cache[i]->map_name, env->map_name) == 0 + && g_map_cache[i]->obs_lane_stride == env->obs_lane_stride + && g_map_cache[i]->obs_boundary_stride == env->obs_boundary_stride) { return g_map_cache[i]; } } @@ -3727,7 +3757,7 @@ void init(Drive *env) { env->timestep = 0; env->shared_map = NULL; - struct SharedMapData *shared = env->use_map_cache ? map_cache_lookup(env->map_name) : NULL; + struct SharedMapData *shared = env->use_map_cache ? map_cache_lookup(env) : NULL; if (shared != NULL) { // Cache hit: load only the per-env data (agents, traffic-control elements), // then discard the freshly-loaded geometry and borrow the shared copy. @@ -3764,6 +3794,8 @@ void init(Drive *env) { entry->grid_map = env->grid_map; entry->neighbor_offsets = env->neighbor_offsets; entry->lane_graph = env->lane_graph; + entry->obs_lane_stride = env->obs_lane_stride; + entry->obs_boundary_stride = env->obs_boundary_stride; entry->ref_count = 1; entry->owner_pid = getpid(); map_cache_insert(entry); @@ -4746,6 +4778,9 @@ static int write_road_obs(Drive *env, Agent *ego, float *obs, int obs_idx, int * if (lanes_found >= env->obs_slots_lane_n && boundaries_found >= env->obs_slots_boundary_n) { break; } + if (!neighbor_entities[k].valid_for_obs) { + continue; + } int entity_idx = neighbor_entities[k].entity_idx; int geometry_idx = neighbor_entities[k].geometry_idx; RoadMapElement *road_element = &env->road_elements[entity_idx]; diff --git a/pufferlib/ocean/drive/drive.py b/pufferlib/ocean/drive/drive.py index cc05e7004e..ecb2fd9692 100644 --- a/pufferlib/ocean/drive/drive.py +++ b/pufferlib/ocean/drive/drive.py @@ -88,6 +88,8 @@ def __init__( shared_network=True, obs_slots_lane_n=32, obs_slots_boundary_n=32, + obs_lane_stride=1, + obs_boundary_stride=1, obs_slots_partners_n=16, obs_slots_traffic_controls_n=4, traffic_control_scope=0, @@ -181,8 +183,16 @@ def __init__( self.ego_features = binding.EGO_FEATURES # Extract observation shapes from constants + obs_lane_stride = int(obs_lane_stride) + obs_boundary_stride = int(obs_boundary_stride) + if obs_lane_stride < 1: + raise ValueError(f"obs_lane_stride must be >= 1. Got: {obs_lane_stride}") + if obs_boundary_stride < 1: + raise ValueError(f"obs_boundary_stride must be >= 1. Got: {obs_boundary_stride}") self.obs_slots_lane_n = obs_slots_lane_n self.obs_slots_boundary_n = obs_slots_boundary_n + self.obs_lane_stride = obs_lane_stride + self.obs_boundary_stride = obs_boundary_stride self.obs_slots_partners_n = obs_slots_partners_n self.traffic_control_scope = traffic_control_scope self.obs_slots_traffic_controls_n = obs_slots_traffic_controls_n @@ -437,6 +447,8 @@ def _env_init_kwargs(self, map_file, max_agents): "goal_on_lane": self.goal_on_lane, "obs_slots_lane_n": self.obs_slots_lane_n, "obs_slots_boundary_n": self.obs_slots_boundary_n, + "obs_lane_stride": self.obs_lane_stride, + "obs_boundary_stride": self.obs_boundary_stride, "obs_slots_partners_n": self.obs_slots_partners_n, "obs_slots_traffic_controls_n": self.obs_slots_traffic_controls_n, "traffic_control_scope": self.traffic_control_scope, diff --git a/pufferlib/ocean/drive/visualize.c b/pufferlib/ocean/drive/visualize.c index 3c60f0c2f9..3c26fe33a1 100644 --- a/pufferlib/ocean/drive/visualize.c +++ b/pufferlib/ocean/drive/visualize.c @@ -294,6 +294,8 @@ int eval_gif( .collision_behavior = conf.collision_behavior, .offroad_behavior = conf.offroad_behavior, .compute_eval_metrics = conf.compute_eval_metrics, + .obs_lane_stride = conf.obs_lane_stride, + .obs_boundary_stride = conf.obs_boundary_stride, .goal_behavior = goal_behavior, .init_mode = init_mode, .control_mode = control_mode, diff --git a/pufferlib/ocean/env_config.h b/pufferlib/ocean/env_config.h index 89269a83cc..cb4d14a2e0 100644 --- a/pufferlib/ocean/env_config.h +++ b/pufferlib/ocean/env_config.h @@ -49,6 +49,8 @@ typedef struct { int max_agents_per_env; int obs_slots_lane_n; int obs_slots_boundary_n; + int obs_lane_stride; + int obs_boundary_stride; float obs_dropout_lane; float obs_dropout_boundary; int obs_slots_partners_n; @@ -195,6 +197,10 @@ static int handler(void *config, const char *section, const char *name, const ch env_config->obs_slots_boundary_n = atoi(value); } else if (MATCH("env", "obs_slots_lane_n")) { env_config->obs_slots_lane_n = atoi(value); + } else if (MATCH("env", "obs_lane_stride")) { + env_config->obs_lane_stride = atoi(value); + } else if (MATCH("env", "obs_boundary_stride")) { + env_config->obs_boundary_stride = atoi(value); } else if (MATCH("env", "obs_dropout_lane")) { env_config->obs_dropout_lane = atof(value); } else if (MATCH("env", "obs_dropout_boundary")) { diff --git a/pufferlib/pufferl.py b/pufferlib/pufferl.py index 8f505a6533..eeb8364053 100644 --- a/pufferlib/pufferl.py +++ b/pufferlib/pufferl.py @@ -1388,6 +1388,8 @@ def train(env_name, args=None, vecenv=None, policy=None, logger=None, early_stop "trajectory_scaling_factors", "obs_slots_boundary_n", "obs_slots_lane_n", + "obs_boundary_stride", + "obs_lane_stride", "obs_dropout_boundary", "obs_dropout_lane", "obs_slots_partners_n", @@ -1581,6 +1583,8 @@ def train(env_name, args=None, vecenv=None, policy=None, logger=None, early_stop "obs_slots_partners_n", "obs_slots_lane_n", "obs_slots_boundary_n", + "obs_lane_stride", + "obs_boundary_stride", "obs_slots_traffic_controls_n", "obs_dropout_lane", "obs_dropout_boundary", diff --git a/pufferlib/viz.py b/pufferlib/viz.py index eadea0b883..28ffa8f9f4 100644 --- a/pufferlib/viz.py +++ b/pufferlib/viz.py @@ -615,6 +615,8 @@ def plot_observation( obs_slots_traffic_controls_n=4, obs_dropout_lane=0.0, obs_dropout_boundary=0.0, + obs_lane_stride=1, + obs_boundary_stride=1, agent_idx=0, obs_norm_goal_offset_m=100.0, obs_norm_xy_offset_m=100.0, @@ -785,7 +787,7 @@ def plot_observation( ax.text( 0.12, 0.95, - f"Lanes: {count_lane}\nBoundaries: {count_boundary}", + f"Lanes: {count_lane}\nBoundaries: {count_boundary}\nStride: {obs_lane_stride}/{obs_boundary_stride}", transform=ax.transAxes, fontsize=10, verticalalignment="top", @@ -957,6 +959,12 @@ def generate_interactive_replay(scenario, replay, filename="replay.html"): "partner_features": int(binding.PARTNER_FEATURES), "lane_count": int(lane_count), "boundary_count": int(boundary_count), + "obs_slots_lane_n": int(env_cfg["obs_slots_lane_n"]), + "obs_slots_boundary_n": int(env_cfg["obs_slots_boundary_n"]), + "obs_dropout_lane": float(env_cfg.get("obs_dropout_lane", 0.0)), + "obs_dropout_boundary": float(env_cfg.get("obs_dropout_boundary", 0.0)), + "obs_lane_stride": int(env_cfg.get("obs_lane_stride", 1)), + "obs_boundary_stride": int(env_cfg.get("obs_boundary_stride", 1)), "traffic_obs_count": int(env_cfg["obs_slots_traffic_controls_n"]), "target_features": 3 if env_cfg.get("target_type", "static") == "static" else 5, "scales": scales, @@ -1029,6 +1037,7 @@ def generate_interactive_replay(scenario, replay, filename="replay.html"):