From 7c3d6439e2f8f4741c504a253b3919fe8ed4a31e Mon Sep 17 00:00:00 2001 From: Valentin Charraut Date: Mon, 15 Jun 2026 11:10:02 +0200 Subject: [PATCH 1/2] Compact code grid code Refactor code to enhance readability and optimize performance by reducing redundancy and improving structure. --- pufferlib/ocean/drive/drive.h | 196 ++++++++++++---------------------- 1 file changed, 68 insertions(+), 128 deletions(-) diff --git a/pufferlib/ocean/drive/drive.h b/pufferlib/ocean/drive/drive.h index d5bf0fb3e..48c70c068 100644 --- a/pufferlib/ocean/drive/drive.h +++ b/pufferlib/ocean/drive/drive.h @@ -288,16 +288,13 @@ struct GridMap { float bottom_right_y; int grid_cols; int grid_rows; - int cell_size_x; - int cell_size_y; - int *cell_entities_count; // number of entities in each cell of the GridMap - GridMapEntity **cells; // list of gridEntities in each cell of the GridMap - // Extras/Optimizations int vision_range; - int *neighbor_cache_count; // number of entities in each cells neighbor cache - GridMapEntity **neighbor_cache_entities; // preallocated array to hold neighbor entities + int *cell_entities_count; + int *neighbor_cache_count; int *grid_index_drivable; int num_drivable_grid_cell; + GridMapEntity **cells; + GridMapEntity **neighbor_cache_entities; }; // Static, read-only map geometry shared across envs loading the same map file @@ -575,7 +572,15 @@ static void invalidate_agent(Agent *agent) { agent->sim_valid = 0; } -static void update_agent_speed(Agent *agent) { +static inline void apply_infraction_behavior(Agent *agent, int behavior) { + if (behavior == STOP_AGENT && !agent->stopped) { + agent->stopped = 1; + } else if (behavior == REMOVE_AGENT && !agent->removed) { + agent->removed = 1; + } +} + +static inline void update_agent_speed(Agent *agent) { float speed = sqrtf(agent->sim_vx * agent->sim_vx + agent->sim_vy * agent->sim_vy); float v_dot_heading = agent->sim_vx * agent->cos_heading + agent->sim_vy * agent->sin_heading; agent->sim_speed = speed; @@ -630,18 +635,16 @@ static inline void project_point_to_ego_frame( static int get_grid_index(Drive *env, float x1, float y1) { if (env->grid_map->top_left_x >= env->grid_map->bottom_right_x || env->grid_map->bottom_right_y >= env->grid_map->top_left_y) { - return -1; // Invalid grid coordinates + return -1; } - - float relativeX = x1 - env->grid_map->top_left_x; // Distance from left - float relativeY = y1 - env->grid_map->bottom_right_y; // Distance from bottom - int gridX = (int) (relativeX / GRID_CELL_SIZE); // Column index - int gridY = (int) (relativeY / GRID_CELL_SIZE); // Row index - if (gridX < 0 || gridX >= env->grid_map->grid_cols || gridY < 0 || gridY >= env->grid_map->grid_rows) { - return -1; // Return -1 for out of bounds + float rel_x = x1 - env->grid_map->top_left_x; + float rel_y = y1 - env->grid_map->bottom_right_y; + int grid_x = (int) (rel_x / GRID_CELL_SIZE); + int grid_y = (int) (rel_y / GRID_CELL_SIZE); + if (grid_x < 0 || grid_x >= env->grid_map->grid_cols || grid_y < 0 || grid_y >= env->grid_map->grid_rows) { + return -1; } - int index = (gridY * env->grid_map->grid_cols) + gridX; - return index; + return (grid_y * env->grid_map->grid_cols) + grid_x; } static void add_entity_to_grid( @@ -671,27 +674,18 @@ static void add_entity_to_grid( } static void init_grid_map(Drive *env) { - // Allocate memory for the grid map structure env->grid_map = (GridMap *) malloc(sizeof(GridMap)); env->grid_map->num_drivable_grid_cell = 0; - // Find top left and bottom right points of the map - float top_left_x = 0.0f; - float top_left_y = 0.0f; - float bottom_right_x = 0.0f; - float bottom_right_y = 0.0f; + float top_left_x = 0.0f, top_left_y = 0.0f, bottom_right_x = 0.0f, bottom_right_y = 0.0f; bool first_valid_point = false; for (int i = 0; i < env->num_road_elements; i++) { - // Check all points in the geometry for road elements (ROAD_LANE, ROAD_LINE, ROAD_EDGE) - if (!is_road(env->road_elements[i].type)) { + if (!is_road_grid_candidate(env->road_elements[i].type)) { continue; } RoadMapElement *element = &env->road_elements[i]; for (int j = 0; j < element->segment_length; j++) { - if (element->x[j] == INVALID_POSITION) { - continue; - } - if (element->y[j] == INVALID_POSITION) { + if (element->x[j] == INVALID_POSITION || element->y[j] == INVALID_POSITION) { continue; } if (!first_valid_point) { @@ -714,15 +708,11 @@ static void init_grid_map(Drive *env) { } } } - env->grid_map->top_left_x = top_left_x; env->grid_map->top_left_y = top_left_y; env->grid_map->bottom_right_x = bottom_right_x; env->grid_map->bottom_right_y = bottom_right_y; - env->grid_map->cell_size_x = GRID_CELL_SIZE; - env->grid_map->cell_size_y = GRID_CELL_SIZE; - // Calculate grid dimensions float grid_width = bottom_right_x - top_left_x; float grid_height = top_left_y - bottom_right_y; env->grid_map->grid_cols = ceil(grid_width / GRID_CELL_SIZE); @@ -730,56 +720,44 @@ static void init_grid_map(Drive *env) { int grid_cell_count = env->grid_map->grid_cols * env->grid_map->grid_rows; env->grid_map->cells = (GridMapEntity **) calloc(grid_cell_count, sizeof(GridMapEntity *)); env->grid_map->cell_entities_count = (int *) calloc(grid_cell_count, sizeof(int)); - - // Calculate number of entities in each grid cell + // First pass to count entities in each grid cell for (int i = 0; i < env->num_road_elements; i++) { - for (int j = 0; j < env->road_elements[i].segment_length - 1; j++) { - float x_center = (env->road_elements[i].x[j] + env->road_elements[i].x[j + 1]) / 2; - float y_center = (env->road_elements[i].y[j] + env->road_elements[i].y[j + 1]) / 2; + RoadMapElement *element = &env->road_elements[i]; + for (int j = 0; j < element->segment_length - 1; j++) { + float x_center = (element->x[j] + element->x[j + 1]) / 2; + float y_center = (element->y[j] + element->y[j + 1]) / 2; int grid_index = get_grid_index(env, x_center, y_center); if (grid_index == -1) { - continue; // Skip out-of-bounds entities + continue; } env->grid_map->cell_entities_count[grid_index]++; } } - + // Allocate grid cells based on counts int *cell_entities_insert_index = (int *) calloc(grid_cell_count, sizeof(int)); - - // Initialize grid cells for (int grid_index = 0; grid_index < grid_cell_count; grid_index++) { - env->grid_map->cells[grid_index] - = (GridMapEntity *) calloc(env->grid_map->cell_entities_count[grid_index], sizeof(GridMapEntity)); + int count = env->grid_map->cell_entities_count[grid_index]; + env->grid_map->cells[grid_index] = (GridMapEntity *) calloc(count, sizeof(GridMapEntity)); } - for (int i = 0; i < grid_cell_count; i++) { - if (cell_entities_insert_index[i] != 0) { - printf("Error: cell_entities_insert_index[%d] not zero during initialization.\n", i); - cell_entities_insert_index[i] = 0; - } - } - - // Track which grid cells contain drivable lanes (for spawning) + // Track which grid cells have drivable lanes bool *drivable_grid_seen = (bool *) calloc(grid_cell_count, sizeof(bool)); - - // Populate grid cells and count unique drivable grid cells for (int i = 0; i < env->num_road_elements; i++) { - for (int j = 0; j < env->road_elements[i].segment_length - 1; j++) { - float x_center = (env->road_elements[i].x[j] + env->road_elements[i].x[j + 1]) / 2; - float y_center = (env->road_elements[i].y[j] + env->road_elements[i].y[j + 1]) / 2; + RoadMapElement *element = &env->road_elements[i]; + for (int j = 0; j < element->segment_length - 1; j++) { + float x_center = (element->x[j] + element->x[j + 1]) / 2; + float y_center = (element->y[j] + element->y[j + 1]) / 2; int grid_index = get_grid_index(env, x_center, y_center); if (grid_index == -1) { - continue; // Skip out-of-bounds entities + continue; } add_entity_to_grid(env, grid_index, i, j, cell_entities_insert_index); - // Count unique drivable grid cells - if (is_drivable_road_lane(env->road_elements[i].type) && !drivable_grid_seen[grid_index]) { + if (is_drivable_road_lane(element->type) && !drivable_grid_seen[grid_index]) { drivable_grid_seen[grid_index] = true; env->grid_map->num_drivable_grid_cell++; } } } - - // Allocate and fill drivable grid index array + // Create a compact array of drivable grid cell indices for quick access env->grid_map->grid_index_drivable = (int *) malloc(env->grid_map->num_drivable_grid_cell * sizeof(int)); int drivable_idx = 0; for (int i = 0; i < grid_cell_count; i++) { @@ -787,50 +765,38 @@ static void init_grid_map(Drive *env) { env->grid_map->grid_index_drivable[drivable_idx++] = i; } } - free(drivable_grid_seen); free(cell_entities_insert_index); } static void init_neighbor_offsets(Drive *env) { - // Allocate memory for the offsets - env->neighbor_offsets = (int *) calloc(env->grid_map->vision_range * env->grid_map->vision_range * 2, sizeof(int)); - // neighbor offsets in a spiral pattern + int vr = env->grid_map->vision_range; + env->neighbor_offsets = (int *) calloc(vr * vr * 2, sizeof(int)); + // Spiral pattern generation int dx[] = {1, 0, -1, 0}; int dy[] = {0, 1, 0, -1}; - int x = 0; // Current x offset - int y = 0; // Current y offset - int dir = 0; // Current direction (0: right, 1: up, 2: left, 3: down) - int steps_to_take = 1; // Number of steps in current direction - int steps_taken = 0; // Steps taken in current direction - int segments_completed = 0; // Count of direction segments completed - int total = 0; // Total offsets added - int max_offsets = env->grid_map->vision_range * env->grid_map->vision_range; - // Start at center (0,0) - int curr_idx = 0; - env->neighbor_offsets[curr_idx++] = 0; // x offset - env->neighbor_offsets[curr_idx++] = 0; // y offset + int x = 0, y = 0, dir = 0, steps_taken = 0, segments_completed = 0, total = 0, curr_idx = 0; + int steps_to_take = 1; + int max_offsets = vr * vr; + env->neighbor_offsets[curr_idx++] = 0; + env->neighbor_offsets[curr_idx++] = 0; total++; // Generate spiral pattern while (total < max_offsets) { - // Move in current direction x += dx[dir]; y += dy[dir]; - // Only add if within vision range bounds - if (abs(x) <= env->grid_map->vision_range / 2 && abs(y) <= env->grid_map->vision_range / 2) { + if (abs(x) <= vr / 2 && abs(y) <= vr / 2) { env->neighbor_offsets[curr_idx++] = x; env->neighbor_offsets[curr_idx++] = y; total++; } steps_taken++; - // Check if we need to change direction if (steps_taken != steps_to_take) { continue; } - steps_taken = 0; // Reset steps taken + steps_taken = 0; dir = (dir + 1) % 4; // Change direction (clockwise: right->up->left->down) segments_completed++; - // Increase step length every two direction changes if (segments_completed % 2 == 0) { steps_to_take++; } @@ -868,7 +834,7 @@ static void cache_neighbor_offsets(Drive *env) { env->grid_map->neighbor_cache_count[cell_count] = count; for (int i = 0; i < cell_count; i++) { - int cell_x = i % env->grid_map->grid_cols; // Convert to 2D coordinates + int cell_x = i % env->grid_map->grid_cols; int cell_y = i / env->grid_map->grid_cols; int base_index = 0; for (int j = 0; j < env->grid_map->vision_range * env->grid_map->vision_range; j++) { @@ -879,18 +845,14 @@ static void cache_neighbor_offsets(Drive *env) { continue; } int grid_count = env->grid_map->cell_entities_count[grid_index]; - // Skip if no entities or source is NULL if (grid_count == 0 || env->grid_map->cells[grid_index] == NULL) { continue; } - - int src_idx = grid_index; - int dst_idx = base_index; // Copy grid_count pairs (entity_idx, geometry_idx) at once memcpy( - &env->grid_map->neighbor_cache_entities[i][dst_idx], - env->grid_map->cells[src_idx], + &env->grid_map->neighbor_cache_entities[i][base_index], + env->grid_map->cells[grid_index], grid_count * sizeof(GridMapEntity)); base_index += grid_count; } @@ -905,20 +867,19 @@ static int get_neighbors_entities( int max_size, const int (*local_offsets)[2], int offset_size) { - // Get the grid index for the given position (x, y) int index = get_grid_index(env, x, y); if (index == -1) { - return 0; // Return 0 size if position invalid + return 0; } // Calculate 2D grid coordinates - int cellsX = env->grid_map->grid_cols; - int gridX = index % cellsX; - int gridY = index / cellsX; + int cols = env->grid_map->grid_cols; + int cell_x = index % cols; + int cell_y = index / cols; int entity_list_count = 0; // Fill the provided array for (int i = 0; i < offset_size; i++) { - int nx = gridX + local_offsets[i][0]; - int ny = gridY + local_offsets[i][1]; + int nx = cell_x + local_offsets[i][0]; + int ny = cell_y + local_offsets[i][1]; // Ensure the neighbor is within grid bounds if (nx < 0 || nx >= env->grid_map->grid_cols || ny < 0 || ny >= env->grid_map->grid_rows) { continue; @@ -3992,15 +3953,10 @@ static void compute_metrics(Drive *env, int agent_idx, int log_idx) { if (agent->sim_x == INVALID_POSITION) { return; // invalid agent position } - - // Current agent is offgrid, treat as offroad if (get_grid_index(env, agent->sim_x, agent->sim_y) == -1) { + // Current agent is offgrid, treat as offroad agent->metrics_array[OFFROAD_IDX] = 1.0f; - if (env->offroad_behavior == STOP_AGENT && !agent->stopped) { - agent->stopped = 1; - } else if (env->offroad_behavior == REMOVE_AGENT && !agent->removed) { - agent->removed = 1; - } + apply_infraction_behavior(agent, env->offroad_behavior); return; } @@ -4252,43 +4208,27 @@ static void compute_metrics(Drive *env, int agent_idx, int log_idx) { // Priority 1: Handle offroad if (is_offroad) { agent->metrics_array[OFFROAD_IDX] = 1.0f; - if (env->offroad_behavior == STOP_AGENT && !agent->stopped) { // Stop - agent->stopped = 1; - } else if (env->offroad_behavior == REMOVE_AGENT && !agent->removed) { - agent->removed = 1; - } - return; // early return: no other terminal flags set when offroad + apply_infraction_behavior(agent, env->offroad_behavior); + return; } // Priority 2: Handle vehicle collision int car_collided_with_index = collision_check(env, agent_idx); - if (car_collided_with_index != -1) { agent->metrics_array[COLLISION_IDX] = 1.0f; - // Track at-fault collisions for evaluation metrics. if (env->compute_eval_metrics && is_at_fault_collision(env, agent_idx, car_collided_with_index)) { log_agent->at_fault_collision_rate = 1.0f; agent->metrics_array[AT_FAULT_COLLISION_IDX] = 1.0f; } - if (env->collision_behavior == STOP_AGENT && !agent->stopped) { // Stop - agent->stopped = 1; - } else if (env->collision_behavior == REMOVE_AGENT && !agent->removed) { - agent->removed = 1; - } - - return; // early return: red_light not checked after collision + apply_infraction_behavior(agent, env->collision_behavior); + return; } // Priority 3: Handle red light violation if (env->obs_slots_traffic_controls_n && check_red_light_violation(env, agent_idx)) { agent->metrics_array[RED_LIGHT_IDX] = 1.0f; - if (env->traffic_light_behavior == STOP_AGENT && !agent->stopped) { - agent->stopped = 1; - } else if (env->traffic_light_behavior == REMOVE_AGENT && !agent->removed) { - agent->removed = 1; - } - - return; // early return: no goal reaching when red light violation + apply_infraction_behavior(agent, env->traffic_light_behavior); + return; } float distance_to_goal From dc25793a0708a5db485316e378e17b8bb9091871 Mon Sep 17 00:00:00 2001 From: Valentin Charraut Date: Mon, 15 Jun 2026 12:13:33 +0200 Subject: [PATCH 2/2] Add function to check if a type is a road grid candidate --- pufferlib/ocean/drive/datatypes.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pufferlib/ocean/drive/datatypes.h b/pufferlib/ocean/drive/datatypes.h index 9df963627..85c0a5647 100644 --- a/pufferlib/ocean/drive/datatypes.h +++ b/pufferlib/ocean/drive/datatypes.h @@ -127,6 +127,10 @@ static inline int is_road(int type) { return is_road_lane(type) || is_road_line(type) || is_road_edge(type); } +static inline int is_road_grid_candidate(int type) { + return is_road_lane(type) || is_road_edge(type); +} + static inline int is_controllable_agent(int type) { return (type == VEHICLE || type == PEDESTRIAN || type == CYCLIST); }