Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pufferlib/ocean/drive/datatypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,10 @@ static inline int is_road(int type) {
return is_road_lane(type) || is_road_line(type) || is_road_edge(type);
}

static inline int is_road_grid_candidate(int type) {
return is_road_lane(type) || is_road_edge(type);
}

static inline int is_controllable_agent(int type) {
return (type == VEHICLE || type == PEDESTRIAN || type == CYCLIST);
}
Expand Down
196 changes: 68 additions & 128 deletions pufferlib/ocean/drive/drive.h
Original file line number Diff line number Diff line change
Expand Up @@ -288,16 +288,13 @@ struct GridMap {
float bottom_right_y;
int grid_cols;
int grid_rows;
int cell_size_x;
int cell_size_y;
int *cell_entities_count; // number of entities in each cell of the GridMap
GridMapEntity **cells; // list of gridEntities in each cell of the GridMap
// Extras/Optimizations
int vision_range;
int *neighbor_cache_count; // number of entities in each cells neighbor cache
GridMapEntity **neighbor_cache_entities; // preallocated array to hold neighbor entities
int *cell_entities_count;
int *neighbor_cache_count;
int *grid_index_drivable;
int num_drivable_grid_cell;
GridMapEntity **cells;
GridMapEntity **neighbor_cache_entities;
};

// Static, read-only map geometry shared across envs loading the same map file
Expand Down Expand Up @@ -575,7 +572,15 @@ static void invalidate_agent(Agent *agent) {
agent->sim_valid = 0;
}

static void update_agent_speed(Agent *agent) {
static inline void apply_infraction_behavior(Agent *agent, int behavior) {
if (behavior == STOP_AGENT && !agent->stopped) {
agent->stopped = 1;
} else if (behavior == REMOVE_AGENT && !agent->removed) {
agent->removed = 1;
}
}

static inline void update_agent_speed(Agent *agent) {
float speed = sqrtf(agent->sim_vx * agent->sim_vx + agent->sim_vy * agent->sim_vy);
float v_dot_heading = agent->sim_vx * agent->cos_heading + agent->sim_vy * agent->sin_heading;
agent->sim_speed = speed;
Expand Down Expand Up @@ -630,18 +635,16 @@ static inline void project_point_to_ego_frame(
static int get_grid_index(Drive *env, float x1, float y1) {
if (env->grid_map->top_left_x >= env->grid_map->bottom_right_x
|| env->grid_map->bottom_right_y >= env->grid_map->top_left_y) {
return -1; // Invalid grid coordinates
return -1;
}

float relativeX = x1 - env->grid_map->top_left_x; // Distance from left
float relativeY = y1 - env->grid_map->bottom_right_y; // Distance from bottom
int gridX = (int) (relativeX / GRID_CELL_SIZE); // Column index
int gridY = (int) (relativeY / GRID_CELL_SIZE); // Row index
if (gridX < 0 || gridX >= env->grid_map->grid_cols || gridY < 0 || gridY >= env->grid_map->grid_rows) {
return -1; // Return -1 for out of bounds
float rel_x = x1 - env->grid_map->top_left_x;
float rel_y = y1 - env->grid_map->bottom_right_y;
int grid_x = (int) (rel_x / GRID_CELL_SIZE);
int grid_y = (int) (rel_y / GRID_CELL_SIZE);
if (grid_x < 0 || grid_x >= env->grid_map->grid_cols || grid_y < 0 || grid_y >= env->grid_map->grid_rows) {
return -1;
}
int index = (gridY * env->grid_map->grid_cols) + gridX;
return index;
return (grid_y * env->grid_map->grid_cols) + grid_x;
}

static void add_entity_to_grid(
Expand Down Expand Up @@ -671,27 +674,18 @@ static void add_entity_to_grid(
}

static void init_grid_map(Drive *env) {
// Allocate memory for the grid map structure
env->grid_map = (GridMap *) malloc(sizeof(GridMap));
env->grid_map->num_drivable_grid_cell = 0;

// Find top left and bottom right points of the map
float top_left_x = 0.0f;
float top_left_y = 0.0f;
float bottom_right_x = 0.0f;
float bottom_right_y = 0.0f;
float top_left_x = 0.0f, top_left_y = 0.0f, bottom_right_x = 0.0f, bottom_right_y = 0.0f;
bool first_valid_point = false;
for (int i = 0; i < env->num_road_elements; i++) {
// Check all points in the geometry for road elements (ROAD_LANE, ROAD_LINE, ROAD_EDGE)
if (!is_road(env->road_elements[i].type)) {
if (!is_road_grid_candidate(env->road_elements[i].type)) {
continue;
Comment thread
vcharraut marked this conversation as resolved.
}
RoadMapElement *element = &env->road_elements[i];
for (int j = 0; j < element->segment_length; j++) {
if (element->x[j] == INVALID_POSITION) {
continue;
}
if (element->y[j] == INVALID_POSITION) {
if (element->x[j] == INVALID_POSITION || element->y[j] == INVALID_POSITION) {
continue;
}
if (!first_valid_point) {
Expand All @@ -714,123 +708,95 @@ static void init_grid_map(Drive *env) {
}
}
}

env->grid_map->top_left_x = top_left_x;
env->grid_map->top_left_y = top_left_y;
env->grid_map->bottom_right_x = bottom_right_x;
env->grid_map->bottom_right_y = bottom_right_y;
env->grid_map->cell_size_x = GRID_CELL_SIZE;
env->grid_map->cell_size_y = GRID_CELL_SIZE;

// Calculate grid dimensions
float grid_width = bottom_right_x - top_left_x;
float grid_height = top_left_y - bottom_right_y;
env->grid_map->grid_cols = ceil(grid_width / GRID_CELL_SIZE);
env->grid_map->grid_rows = ceil(grid_height / GRID_CELL_SIZE);
int grid_cell_count = env->grid_map->grid_cols * env->grid_map->grid_rows;
env->grid_map->cells = (GridMapEntity **) calloc(grid_cell_count, sizeof(GridMapEntity *));
env->grid_map->cell_entities_count = (int *) calloc(grid_cell_count, sizeof(int));

// Calculate number of entities in each grid cell
// First pass to count entities in each grid cell
for (int i = 0; i < env->num_road_elements; i++) {
for (int j = 0; j < env->road_elements[i].segment_length - 1; j++) {
float x_center = (env->road_elements[i].x[j] + env->road_elements[i].x[j + 1]) / 2;
float y_center = (env->road_elements[i].y[j] + env->road_elements[i].y[j + 1]) / 2;
RoadMapElement *element = &env->road_elements[i];
for (int j = 0; j < element->segment_length - 1; j++) {
float x_center = (element->x[j] + element->x[j + 1]) / 2;
float y_center = (element->y[j] + element->y[j + 1]) / 2;
int grid_index = get_grid_index(env, x_center, y_center);
if (grid_index == -1) {
continue; // Skip out-of-bounds entities
continue;
}
env->grid_map->cell_entities_count[grid_index]++;
}
}

// Allocate grid cells based on counts
int *cell_entities_insert_index = (int *) calloc(grid_cell_count, sizeof(int));

// Initialize grid cells
for (int grid_index = 0; grid_index < grid_cell_count; grid_index++) {
env->grid_map->cells[grid_index]
= (GridMapEntity *) calloc(env->grid_map->cell_entities_count[grid_index], sizeof(GridMapEntity));
int count = env->grid_map->cell_entities_count[grid_index];
env->grid_map->cells[grid_index] = (GridMapEntity *) calloc(count, sizeof(GridMapEntity));
}
for (int i = 0; i < grid_cell_count; i++) {
if (cell_entities_insert_index[i] != 0) {
printf("Error: cell_entities_insert_index[%d] not zero during initialization.\n", i);
cell_entities_insert_index[i] = 0;
}
}

// Track which grid cells contain drivable lanes (for spawning)
// Track which grid cells have drivable lanes
bool *drivable_grid_seen = (bool *) calloc(grid_cell_count, sizeof(bool));

// Populate grid cells and count unique drivable grid cells
for (int i = 0; i < env->num_road_elements; i++) {
for (int j = 0; j < env->road_elements[i].segment_length - 1; j++) {
float x_center = (env->road_elements[i].x[j] + env->road_elements[i].x[j + 1]) / 2;
float y_center = (env->road_elements[i].y[j] + env->road_elements[i].y[j + 1]) / 2;
RoadMapElement *element = &env->road_elements[i];
for (int j = 0; j < element->segment_length - 1; j++) {
float x_center = (element->x[j] + element->x[j + 1]) / 2;
float y_center = (element->y[j] + element->y[j + 1]) / 2;
int grid_index = get_grid_index(env, x_center, y_center);
if (grid_index == -1) {
continue; // Skip out-of-bounds entities
continue;
}
add_entity_to_grid(env, grid_index, i, j, cell_entities_insert_index);
// Count unique drivable grid cells
if (is_drivable_road_lane(env->road_elements[i].type) && !drivable_grid_seen[grid_index]) {
if (is_drivable_road_lane(element->type) && !drivable_grid_seen[grid_index]) {
drivable_grid_seen[grid_index] = true;
env->grid_map->num_drivable_grid_cell++;
}
}
}

// Allocate and fill drivable grid index array
// Create a compact array of drivable grid cell indices for quick access
env->grid_map->grid_index_drivable = (int *) malloc(env->grid_map->num_drivable_grid_cell * sizeof(int));
int drivable_idx = 0;
for (int i = 0; i < grid_cell_count; i++) {
if (drivable_grid_seen[i]) {
env->grid_map->grid_index_drivable[drivable_idx++] = i;
}
}

free(drivable_grid_seen);
free(cell_entities_insert_index);
}

static void init_neighbor_offsets(Drive *env) {
// Allocate memory for the offsets
env->neighbor_offsets = (int *) calloc(env->grid_map->vision_range * env->grid_map->vision_range * 2, sizeof(int));
// neighbor offsets in a spiral pattern
int vr = env->grid_map->vision_range;
env->neighbor_offsets = (int *) calloc(vr * vr * 2, sizeof(int));
// Spiral pattern generation
int dx[] = {1, 0, -1, 0};
int dy[] = {0, 1, 0, -1};
int x = 0; // Current x offset
int y = 0; // Current y offset
int dir = 0; // Current direction (0: right, 1: up, 2: left, 3: down)
int steps_to_take = 1; // Number of steps in current direction
int steps_taken = 0; // Steps taken in current direction
int segments_completed = 0; // Count of direction segments completed
int total = 0; // Total offsets added
int max_offsets = env->grid_map->vision_range * env->grid_map->vision_range;
// Start at center (0,0)
int curr_idx = 0;
env->neighbor_offsets[curr_idx++] = 0; // x offset
env->neighbor_offsets[curr_idx++] = 0; // y offset
int x = 0, y = 0, dir = 0, steps_taken = 0, segments_completed = 0, total = 0, curr_idx = 0;
int steps_to_take = 1;
int max_offsets = vr * vr;
env->neighbor_offsets[curr_idx++] = 0;
env->neighbor_offsets[curr_idx++] = 0;
total++;
// Generate spiral pattern
while (total < max_offsets) {
// Move in current direction
x += dx[dir];
y += dy[dir];
// Only add if within vision range bounds
if (abs(x) <= env->grid_map->vision_range / 2 && abs(y) <= env->grid_map->vision_range / 2) {
if (abs(x) <= vr / 2 && abs(y) <= vr / 2) {
env->neighbor_offsets[curr_idx++] = x;
env->neighbor_offsets[curr_idx++] = y;
total++;
}
steps_taken++;
// Check if we need to change direction
if (steps_taken != steps_to_take) {
continue;
}
steps_taken = 0; // Reset steps taken
steps_taken = 0;
dir = (dir + 1) % 4; // Change direction (clockwise: right->up->left->down)
segments_completed++;
// Increase step length every two direction changes
if (segments_completed % 2 == 0) {
steps_to_take++;
}
Expand Down Expand Up @@ -868,7 +834,7 @@ static void cache_neighbor_offsets(Drive *env) {

env->grid_map->neighbor_cache_count[cell_count] = count;
for (int i = 0; i < cell_count; i++) {
int cell_x = i % env->grid_map->grid_cols; // Convert to 2D coordinates
int cell_x = i % env->grid_map->grid_cols;
int cell_y = i / env->grid_map->grid_cols;
int base_index = 0;
for (int j = 0; j < env->grid_map->vision_range * env->grid_map->vision_range; j++) {
Expand All @@ -879,18 +845,14 @@ static void cache_neighbor_offsets(Drive *env) {
continue;
}
int grid_count = env->grid_map->cell_entities_count[grid_index];

// Skip if no entities or source is NULL
if (grid_count == 0 || env->grid_map->cells[grid_index] == NULL) {
continue;
}

int src_idx = grid_index;
int dst_idx = base_index;
// Copy grid_count pairs (entity_idx, geometry_idx) at once
memcpy(
&env->grid_map->neighbor_cache_entities[i][dst_idx],
env->grid_map->cells[src_idx],
&env->grid_map->neighbor_cache_entities[i][base_index],
env->grid_map->cells[grid_index],
grid_count * sizeof(GridMapEntity));
base_index += grid_count;
}
Expand All @@ -905,20 +867,19 @@ static int get_neighbors_entities(
int max_size,
const int (*local_offsets)[2],
int offset_size) {
// Get the grid index for the given position (x, y)
int index = get_grid_index(env, x, y);
if (index == -1) {
return 0; // Return 0 size if position invalid
return 0;
}
// Calculate 2D grid coordinates
int cellsX = env->grid_map->grid_cols;
int gridX = index % cellsX;
int gridY = index / cellsX;
int cols = env->grid_map->grid_cols;
int cell_x = index % cols;
int cell_y = index / cols;
int entity_list_count = 0;
// Fill the provided array
for (int i = 0; i < offset_size; i++) {
int nx = gridX + local_offsets[i][0];
int ny = gridY + local_offsets[i][1];
int nx = cell_x + local_offsets[i][0];
int ny = cell_y + local_offsets[i][1];
// Ensure the neighbor is within grid bounds
if (nx < 0 || nx >= env->grid_map->grid_cols || ny < 0 || ny >= env->grid_map->grid_rows) {
continue;
Expand Down Expand Up @@ -3992,15 +3953,10 @@ static void compute_metrics(Drive *env, int agent_idx, int log_idx) {
if (agent->sim_x == INVALID_POSITION) {
return; // invalid agent position
}

// Current agent is offgrid, treat as offroad
if (get_grid_index(env, agent->sim_x, agent->sim_y) == -1) {
// Current agent is offgrid, treat as offroad
agent->metrics_array[OFFROAD_IDX] = 1.0f;
if (env->offroad_behavior == STOP_AGENT && !agent->stopped) {
agent->stopped = 1;
} else if (env->offroad_behavior == REMOVE_AGENT && !agent->removed) {
agent->removed = 1;
}
apply_infraction_behavior(agent, env->offroad_behavior);
return;
}

Expand Down Expand Up @@ -4252,43 +4208,27 @@ static void compute_metrics(Drive *env, int agent_idx, int log_idx) {
// Priority 1: Handle offroad
if (is_offroad) {
agent->metrics_array[OFFROAD_IDX] = 1.0f;
if (env->offroad_behavior == STOP_AGENT && !agent->stopped) { // Stop
agent->stopped = 1;
} else if (env->offroad_behavior == REMOVE_AGENT && !agent->removed) {
agent->removed = 1;
}
return; // early return: no other terminal flags set when offroad
apply_infraction_behavior(agent, env->offroad_behavior);
return;
}

// Priority 2: Handle vehicle collision
int car_collided_with_index = collision_check(env, agent_idx);

if (car_collided_with_index != -1) {
agent->metrics_array[COLLISION_IDX] = 1.0f;
// Track at-fault collisions for evaluation metrics.
if (env->compute_eval_metrics && is_at_fault_collision(env, agent_idx, car_collided_with_index)) {
log_agent->at_fault_collision_rate = 1.0f;
agent->metrics_array[AT_FAULT_COLLISION_IDX] = 1.0f;
}
if (env->collision_behavior == STOP_AGENT && !agent->stopped) { // Stop
agent->stopped = 1;
} else if (env->collision_behavior == REMOVE_AGENT && !agent->removed) {
agent->removed = 1;
}

return; // early return: red_light not checked after collision
apply_infraction_behavior(agent, env->collision_behavior);
return;
}

// Priority 3: Handle red light violation
if (env->obs_slots_traffic_controls_n && check_red_light_violation(env, agent_idx)) {
agent->metrics_array[RED_LIGHT_IDX] = 1.0f;
if (env->traffic_light_behavior == STOP_AGENT && !agent->stopped) {
agent->stopped = 1;
} else if (env->traffic_light_behavior == REMOVE_AGENT && !agent->removed) {
agent->removed = 1;
}

return; // early return: no goal reaching when red light violation
apply_infraction_behavior(agent, env->traffic_light_behavior);
return;
}

float distance_to_goal
Expand Down
Loading