Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions notebooks/01_observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@
print(f"NaN: {np.isnan(obs).sum()}, Inf: {np.isinf(obs).sum()}")
print(f"% zeros: {(obs == 0).mean() * 100:.1f}%")
print(f"% outside [-1,1]: {((obs < -1) | (obs > 1)).mean() * 100:.2f}%")
print(
f"road obs: lanes={env.obs_slots_lane_kept}/{env.obs_slots_lane_n} stride={env.obs_lane_stride}, "
f"boundaries={env.obs_slots_boundary_kept}/{env.obs_slots_boundary_n} stride={env.obs_boundary_stride}"
)

fig, axes = plt.subplots(1, 2, figsize=(14, 4))
axes[0].hist(obs.flatten(), bins=100, edgecolor="black", alpha=0.7)
Expand All @@ -61,9 +65,11 @@
reward_conditioning=env.reward_conditioning,
num_target_waypoints=env.num_target_waypoints,
obs_slots_partners_n=env.obs_slots_partners_n,
obs_slots_lane_n=env.obs_slots_lane_kept,
obs_slots_boundary_n=env.obs_slots_boundary_kept,
obs_slots_lane_n=env.obs_slots_lane_n,
obs_slots_boundary_n=env.obs_slots_boundary_n,
obs_slots_traffic_controls_n=env.obs_slots_traffic_controls_n,
obs_dropout_lane=env.obs_dropout_lane,
obs_dropout_boundary=env.obs_dropout_boundary,
)
print(f"ego: {ego.shape} = {ego}")
print(f"target: {target.shape}")
Expand Down Expand Up @@ -274,9 +280,13 @@
reward_conditioning=env.reward_conditioning,
num_target_waypoints=env.num_target_waypoints,
obs_slots_partners_n=env.obs_slots_partners_n,
obs_slots_lane_n=env.obs_slots_lane_kept,
obs_slots_boundary_n=env.obs_slots_boundary_kept,
obs_slots_lane_n=env.obs_slots_lane_n,
obs_slots_boundary_n=env.obs_slots_boundary_n,
obs_slots_traffic_controls_n=env.obs_slots_traffic_controls_n,
obs_dropout_lane=env.obs_dropout_lane,
obs_dropout_boundary=env.obs_dropout_boundary,
obs_lane_stride=env.obs_lane_stride,
obs_boundary_stride=env.obs_boundary_stride,
)
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(img)
Expand Down
3 changes: 2 additions & 1 deletion notebooks/02_rewards.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@
f"ego_features={env.ego_features}, num_reward_coefs={env.num_reward_coefs}, obs_slots_partners_n={env.obs_slots_partners_n}, partner_features={env.partner_features}"
)
print(
f"obs_slots_lane_kept={env.obs_slots_lane_kept}, obs_slots_boundary_kept={env.obs_slots_boundary_kept}, road_features={env.road_features}"
f"obs_slots_lane_kept={env.obs_slots_lane_kept}, obs_slots_boundary_kept={env.obs_slots_boundary_kept}, "
f"road_features={env.road_features}, stride={env.obs_lane_stride}/{env.obs_boundary_stride}"
)
print(
f"obs_slots_traffic_controls_n={env.obs_slots_traffic_controls_n}, traffic_control_features={env.traffic_control_features}"
Expand Down
3 changes: 2 additions & 1 deletion notebooks/03_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@
f"ego_features={env.ego_features}, num_reward_coefs={env.num_reward_coefs}, obs_slots_partners_n={env.obs_slots_partners_n}, partner_features={env.partner_features}"
)
print(
f"obs_slots_lane_kept={env.obs_slots_lane_kept}, obs_slots_boundary_kept={env.obs_slots_boundary_kept}, road_features={env.road_features}"
f"obs_slots_lane_kept={env.obs_slots_lane_kept}, obs_slots_boundary_kept={env.obs_slots_boundary_kept}, "
f"road_features={env.road_features}, stride={env.obs_lane_stride}/{env.obs_boundary_stride}"
)
print(
f"obs_slots_traffic_controls_n={env.obs_slots_traffic_controls_n}, traffic_control_features={env.traffic_control_features}"
Expand Down
38 changes: 26 additions & 12 deletions notebooks/05_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,12 @@ def run_rollout(env, policy, deterministic=False, horizon=HORIZON):
reward_conditioning=rew_cond,
num_target_waypoints=n_tgt_wp,
obs_slots_partners_n=env.obs_slots_partners_n,
obs_slots_lane_n=env.obs_slots_lane_kept,
obs_slots_boundary_n=env.obs_slots_boundary_kept,
obs_slots_lane_n=env.obs_slots_lane_n,
obs_slots_boundary_n=env.obs_slots_boundary_n,
obs_dropout_lane=env.obs_dropout_lane,
obs_dropout_boundary=env.obs_dropout_boundary,
obs_lane_stride=env.obs_lane_stride,
obs_boundary_stride=env.obs_boundary_stride,
obs_slots_traffic_controls_n=env.obs_slots_traffic_controls_n,
obs_norm_goal_offset_m=env.obs_norm_goal_offset_m,
obs_norm_xy_offset_m=env.obs_norm_xy_offset_m,
Expand Down Expand Up @@ -227,8 +231,10 @@ def run_rollout(env, policy, deterministic=False, horizon=HORIZON):
reward_conditioning=rew_cond,
num_target_waypoints=n_tgt_wp,
obs_slots_partners_n=env.obs_slots_partners_n,
obs_slots_lane_n=env.obs_slots_lane_kept,
obs_slots_boundary_n=env.obs_slots_boundary_kept,
obs_slots_lane_n=env.obs_slots_lane_n,
obs_slots_boundary_n=env.obs_slots_boundary_n,
obs_dropout_lane=env.obs_dropout_lane,
obs_dropout_boundary=env.obs_dropout_boundary,
obs_slots_traffic_controls_n=env.obs_slots_traffic_controls_n,
)
ego_features_over_time.append(ego)
Expand Down Expand Up @@ -275,8 +281,10 @@ def run_rollout(env, policy, deterministic=False, horizon=HORIZON):
reward_conditioning=rew_cond,
num_target_waypoints=n_tgt_wp,
obs_slots_partners_n=env.obs_slots_partners_n,
obs_slots_lane_n=env.obs_slots_lane_kept,
obs_slots_boundary_n=env.obs_slots_boundary_kept,
obs_slots_lane_n=env.obs_slots_lane_n,
obs_slots_boundary_n=env.obs_slots_boundary_n,
obs_dropout_lane=env.obs_dropout_lane,
obs_dropout_boundary=env.obs_dropout_boundary,
obs_slots_traffic_controls_n=env.obs_slots_traffic_controls_n,
)

Expand Down Expand Up @@ -464,8 +472,10 @@ def unpack_all_timesteps(bufs, agent_idx):
reward_conditioning=rew_cond,
num_target_waypoints=n_tgt_wp,
obs_slots_partners_n=env.obs_slots_partners_n,
obs_slots_lane_n=env.obs_slots_lane_kept,
obs_slots_boundary_n=env.obs_slots_boundary_kept,
obs_slots_lane_n=env.obs_slots_lane_n,
obs_slots_boundary_n=env.obs_slots_boundary_n,
obs_dropout_lane=env.obs_dropout_lane,
obs_dropout_boundary=env.obs_dropout_boundary,
obs_slots_traffic_controls_n=env.obs_slots_traffic_controls_n,
)
egos.append(ego)
Expand Down Expand Up @@ -540,8 +550,10 @@ def unpack_all_timesteps(bufs, agent_idx):
reward_conditioning=rew_cond,
num_target_waypoints=n_tgt_wp,
obs_slots_partners_n=env.obs_slots_partners_n,
obs_slots_lane_n=env.obs_slots_lane_kept,
obs_slots_boundary_n=env.obs_slots_boundary_kept,
obs_slots_lane_n=env.obs_slots_lane_n,
obs_slots_boundary_n=env.obs_slots_boundary_n,
obs_dropout_lane=env.obs_dropout_lane,
obs_dropout_boundary=env.obs_dropout_boundary,
obs_slots_traffic_controls_n=env.obs_slots_traffic_controls_n,
)
dists = np.sqrt(part[:, 0] ** 2 + part[:, 1] ** 2)
Expand All @@ -566,8 +578,10 @@ def unpack_all_timesteps(bufs, agent_idx):
reward_conditioning=rew_cond,
num_target_waypoints=n_tgt_wp,
obs_slots_partners_n=env.obs_slots_partners_n,
obs_slots_lane_n=env.obs_slots_lane_kept,
obs_slots_boundary_n=env.obs_slots_boundary_kept,
obs_slots_lane_n=env.obs_slots_lane_n,
obs_slots_boundary_n=env.obs_slots_boundary_n,
obs_dropout_lane=env.obs_dropout_lane,
obs_dropout_boundary=env.obs_dropout_boundary,
obs_slots_traffic_controls_n=env.obs_slots_traffic_controls_n,
)

Expand Down
2 changes: 2 additions & 0 deletions notebooks/notebook_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@
"offroad_behavior": 1,
"obs_slots_lane_n": 80,
"obs_slots_boundary_n": 80,
"obs_lane_stride": 1,
"obs_boundary_stride": 1,
"obs_slots_partners_n": 16,
"obs_slots_traffic_controls_n": 4,
"obs_dropout_lane": 0.0,
Expand Down
3 changes: 3 additions & 0 deletions pufferlib/config/ocean/drive.ini
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,9 @@ obs_slots_traffic_controls_n = 4
; Fraction of segment observation slots to drop (reduces obs size)
obs_dropout_lane = 0.0
obs_dropout_boundary = 0.0
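; Stride for lane and boundary segments: 1 keeps every segment, 2 keeps every other segment, etc.
; With a stride > 1, a segment is also kept wherever its heading deviates by more than ~30 degrees from the last kept one.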
obs_lane_stride = 1
obs_boundary_stride = 1
; --- Observation normalization ---
obs_norm_goal_offset_m = 120.0
obs_norm_xy_offset_m = 120.0
Expand Down
2 changes: 2 additions & 0 deletions pufferlib/ocean/drive/binding.c
Original file line number Diff line number Diff line change
Expand Up @@ -1979,6 +1979,8 @@ static int my_init(Env *env, PyObject *args, PyObject *kwargs) {
env->obs_slots_partners_n = (int) unpack(kwargs, "obs_slots_partners_n");
env->obs_slots_traffic_controls_n = (int) unpack(kwargs, "obs_slots_traffic_controls_n");
env->traffic_control_scope = (int) unpack(kwargs, "traffic_control_scope");
env->obs_lane_stride = (int) unpack(kwargs, "obs_lane_stride");
env->obs_boundary_stride = (int) unpack(kwargs, "obs_boundary_stride");
env->dt = (float) unpack(kwargs, "dt");
env->spawn_initial_speed = (float) unpack(kwargs, "spawn_initial_speed");
env->goal_speed = (float) unpack(kwargs, "goal_speed");
Expand Down
4 changes: 4 additions & 0 deletions pufferlib/ocean/drive/drive.c
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ void demo() {
.compute_eval_metrics = conf.compute_eval_metrics,
.obs_slots_lane_n = conf.obs_slots_lane_n,
.obs_slots_boundary_n = conf.obs_slots_boundary_n,
.obs_lane_stride = conf.obs_lane_stride,
.obs_boundary_stride = conf.obs_boundary_stride,
.obs_slots_partners_n = conf.obs_slots_partners_n,
.obs_slots_traffic_controls_n = conf.obs_slots_traffic_controls_n,
.traffic_control_scope = conf.traffic_control_scope,
Expand Down Expand Up @@ -246,6 +248,8 @@ void performance_test() {
.num_max_agents = conf.max_agents_per_env,
.obs_slots_lane_n = conf.obs_slots_lane_n,
.obs_slots_boundary_n = conf.obs_slots_boundary_n,
.obs_lane_stride = conf.obs_lane_stride,
.obs_boundary_stride = conf.obs_boundary_stride,
.obs_slots_partners_n = conf.obs_slots_partners_n,
.obs_slots_traffic_controls_n = conf.obs_slots_traffic_controls_n,
.traffic_control_scope = conf.traffic_control_scope,
Expand Down
53 changes: 44 additions & 9 deletions pufferlib/ocean/drive/drive.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@
// Depends on the resolution of the data. Formula: 3 * (2 + GRID_CELL_SIZE*sqrt(2)/resolution)
// => For each entity type in gridmap, diagonal poly-lines -> sqrt(2), include diagonal ends -> 2
#define MAX_ENTITIES_PER_CELL 30
// Heading deviation (radians, ~30 degrees) from the last kept point beyond which a point is
// kept even if fewer than `stride` points have passed; only applies when the obs stride is > 1.
#define OBS_STRIDE_HEADING_THRESHOLD 0.5236f

// TARGET_TYPE modes (controls what target info is in observations).
#define TARGET_STATIC 0
Expand Down Expand Up @@ -290,8 +292,9 @@ struct Log {
};

struct GridMapEntity {
int entity_idx; // Index into the road_elements array
int geometry_idx; // Index into element's geometry array
int entity_idx; // Index into the road_elements array
int geometry_idx; // Index into element's geometry array
int valid_for_obs; // Whether this entity should be included in observations
};

struct GridMap {
Expand Down Expand Up @@ -328,6 +331,8 @@ struct SharedMapData {
GridMap *grid_map;
int *neighbor_offsets;
struct LaneGraph lane_graph;
int obs_lane_stride;
int obs_boundary_stride;
int ref_count;
pid_t owner_pid;
};
Expand Down Expand Up @@ -440,6 +445,8 @@ struct Drive {
int obs_slots_partners_n;
int obs_slots_traffic_controls_n;
int traffic_control_scope;
int obs_lane_stride;
int obs_boundary_stride;
int obs_slots_lane_kept;
int obs_slots_boundary_kept;
int road_dropout_enabled;
Expand Down Expand Up @@ -688,6 +695,7 @@ static void add_entity_to_grid(
int grid_index,
int entity_idx,
int geometry_idx,
int valid_for_obs,
int *cell_entities_insert_index) {
if (grid_index == -1) {
return;
Expand All @@ -706,6 +714,7 @@ static void add_entity_to_grid(

env->grid_map->cells[grid_index][count].entity_idx = entity_idx;
env->grid_map->cells[grid_index][count].geometry_idx = geometry_idx;
env->grid_map->cells[grid_index][count].valid_for_obs = valid_for_obs;
cell_entities_insert_index[grid_index] = count + 1;
}

Expand Down Expand Up @@ -802,14 +811,32 @@ static void init_grid_map(Drive *env) {

// Populate grid cells and count unique drivable grid cells
for (int i = 0; i < env->num_road_elements; i++) {
for (int j = 0; j < env->road_elements[i].segment_length - 1; j++) {
float x_center = (env->road_elements[i].x[j] + env->road_elements[i].x[j + 1]) / 2;
float y_center = (env->road_elements[i].y[j] + env->road_elements[i].y[j + 1]) / 2;
RoadMapElement *element = &env->road_elements[i];
int obs_stride = 1;
if (is_road_lane(element->type)) {
obs_stride = env->obs_lane_stride;
} else if (is_road_edge(element->type)) {
obs_stride = env->obs_boundary_stride;
}
int last_kept_idx = 0;
for (int j = 0; j < element->segment_length - 1; j++) {
// Keep a point every obs_stride points, plus wherever heading deviates enough
// since the last kept point (densifies curves/intersections)
int valid_for_obs = 1;
if (obs_stride > 1 && j > 0) {
float heading_dev = fabsf(normalize_heading(element->headings[j] - element->headings[last_kept_idx]));
valid_for_obs = j - last_kept_idx >= obs_stride || heading_dev > OBS_STRIDE_HEADING_THRESHOLD;
}
if (valid_for_obs) {
last_kept_idx = j;
}
float x_center = (element->x[j] + element->x[j + 1]) / 2;
float y_center = (element->y[j] + element->y[j + 1]) / 2;
int grid_index = get_grid_index(env, x_center, y_center);
if (grid_index == -1) {
continue; // Skip out-of-bounds entities
}
add_entity_to_grid(env, grid_index, i, j, cell_entities_insert_index);
add_entity_to_grid(env, grid_index, i, j, valid_for_obs, cell_entities_insert_index);
// Count unique drivable grid cells
if (is_drivable_road_lane(env->road_elements[i].type) && !drivable_grid_seen[grid_index]) {
drivable_grid_seen[grid_index] = true;
Expand Down Expand Up @@ -968,6 +995,7 @@ static int get_neighbors_entities(
for (int j = 0; j < count && entity_list_count < max_size; j++) {
entity_list[entity_list_count].entity_idx = env->grid_map->cells[neighbor_idx][j].entity_idx;
entity_list[entity_list_count].geometry_idx = env->grid_map->cells[neighbor_idx][j].geometry_idx;
entity_list[entity_list_count].valid_for_obs = env->grid_map->cells[neighbor_idx][j].valid_for_obs;
entity_list_count += 1;
}
Comment on lines 995 to 1000
}
Expand Down Expand Up @@ -3665,9 +3693,11 @@ void remove_bad_trajectories(Drive *env) {
env->timestep = 0;
}

static struct SharedMapData *map_cache_lookup(const char *map_name) {
static struct SharedMapData *map_cache_lookup(Drive *env) {
for (int i = 0; i < g_map_cache_count; i++) {
if (g_map_cache[i] != NULL && strcmp(g_map_cache[i]->map_name, map_name) == 0) {
if (g_map_cache[i] != NULL && strcmp(g_map_cache[i]->map_name, env->map_name) == 0
&& g_map_cache[i]->obs_lane_stride == env->obs_lane_stride
&& g_map_cache[i]->obs_boundary_stride == env->obs_boundary_stride) {
return g_map_cache[i];
}
}
Expand Down Expand Up @@ -3727,7 +3757,7 @@ void init(Drive *env) {
env->timestep = 0;
env->shared_map = NULL;

struct SharedMapData *shared = env->use_map_cache ? map_cache_lookup(env->map_name) : NULL;
struct SharedMapData *shared = env->use_map_cache ? map_cache_lookup(env) : NULL;
if (shared != NULL) {
// Cache hit: load only the per-env data (agents, traffic-control elements),
// then discard the freshly-loaded geometry and borrow the shared copy.
Expand Down Expand Up @@ -3764,6 +3794,8 @@ void init(Drive *env) {
entry->grid_map = env->grid_map;
entry->neighbor_offsets = env->neighbor_offsets;
entry->lane_graph = env->lane_graph;
entry->obs_lane_stride = env->obs_lane_stride;
entry->obs_boundary_stride = env->obs_boundary_stride;
entry->ref_count = 1;
entry->owner_pid = getpid();
map_cache_insert(entry);
Expand Down Expand Up @@ -4746,6 +4778,9 @@ static int write_road_obs(Drive *env, Agent *ego, float *obs, int obs_idx, int *
if (lanes_found >= env->obs_slots_lane_n && boundaries_found >= env->obs_slots_boundary_n) {
break;
}
if (!neighbor_entities[k].valid_for_obs) {
continue;
}
int entity_idx = neighbor_entities[k].entity_idx;
int geometry_idx = neighbor_entities[k].geometry_idx;
RoadMapElement *road_element = &env->road_elements[entity_idx];
Expand Down
12 changes: 12 additions & 0 deletions pufferlib/ocean/drive/drive.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ def __init__(
shared_network=True,
obs_slots_lane_n=32,
obs_slots_boundary_n=32,
obs_lane_stride=1,
obs_boundary_stride=1,
obs_slots_partners_n=16,
obs_slots_traffic_controls_n=4,
traffic_control_scope=0,
Expand Down Expand Up @@ -181,8 +183,16 @@ def __init__(
self.ego_features = binding.EGO_FEATURES

# Extract observation shapes from constants
obs_lane_stride = int(obs_lane_stride)
obs_boundary_stride = int(obs_boundary_stride)
if obs_lane_stride < 1:
raise ValueError(f"obs_lane_stride must be >= 1. Got: {obs_lane_stride}")
if obs_boundary_stride < 1:
raise ValueError(f"obs_boundary_stride must be >= 1. Got: {obs_boundary_stride}")
self.obs_slots_lane_n = obs_slots_lane_n
self.obs_slots_boundary_n = obs_slots_boundary_n
self.obs_lane_stride = obs_lane_stride
self.obs_boundary_stride = obs_boundary_stride
self.obs_slots_partners_n = obs_slots_partners_n
self.traffic_control_scope = traffic_control_scope
self.obs_slots_traffic_controls_n = obs_slots_traffic_controls_n
Expand Down Expand Up @@ -437,6 +447,8 @@ def _env_init_kwargs(self, map_file, max_agents):
"goal_on_lane": self.goal_on_lane,
"obs_slots_lane_n": self.obs_slots_lane_n,
"obs_slots_boundary_n": self.obs_slots_boundary_n,
"obs_lane_stride": self.obs_lane_stride,
"obs_boundary_stride": self.obs_boundary_stride,
"obs_slots_partners_n": self.obs_slots_partners_n,
"obs_slots_traffic_controls_n": self.obs_slots_traffic_controls_n,
"traffic_control_scope": self.traffic_control_scope,
Expand Down
2 changes: 2 additions & 0 deletions pufferlib/ocean/drive/visualize.c
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,8 @@ int eval_gif(
.collision_behavior = conf.collision_behavior,
.offroad_behavior = conf.offroad_behavior,
.compute_eval_metrics = conf.compute_eval_metrics,
.obs_lane_stride = conf.obs_lane_stride,
.obs_boundary_stride = conf.obs_boundary_stride,
.goal_behavior = goal_behavior,
.init_mode = init_mode,
.control_mode = control_mode,
Expand Down
6 changes: 6 additions & 0 deletions pufferlib/ocean/env_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ typedef struct {
int max_agents_per_env;
int obs_slots_lane_n;
int obs_slots_boundary_n;
int obs_lane_stride;
int obs_boundary_stride;
float obs_dropout_lane;
float obs_dropout_boundary;
int obs_slots_partners_n;
Expand Down Expand Up @@ -195,6 +197,10 @@ static int handler(void *config, const char *section, const char *name, const ch
env_config->obs_slots_boundary_n = atoi(value);
} else if (MATCH("env", "obs_slots_lane_n")) {
env_config->obs_slots_lane_n = atoi(value);
} else if (MATCH("env", "obs_lane_stride")) {
env_config->obs_lane_stride = atoi(value);
} else if (MATCH("env", "obs_boundary_stride")) {
env_config->obs_boundary_stride = atoi(value);
Comment on lines +200 to +203
} else if (MATCH("env", "obs_dropout_lane")) {
env_config->obs_dropout_lane = atof(value);
} else if (MATCH("env", "obs_dropout_boundary")) {
Expand Down
Loading
Loading