diff --git a/pufferlib/ocean/drive/binding.c b/pufferlib/ocean/drive/binding.c index 1061f33b7..1edd05bc4 100644 --- a/pufferlib/ocean/drive/binding.c +++ b/pufferlib/ocean/drive/binding.c @@ -92,7 +92,52 @@ static PyObject* my_shared(PyObject* self, PyObject* args, PyObject* kwargs) { env->init_steps = init_steps; env->goal_behavior = goal_behavior; sprintf(map_file, "resources/drive/binaries/map_%03d.bin", map_id); - env->entities = load_map_binary(map_file, env); + load_map_binary(map_file, env); + + // Skip map if it contains traffic lights + bool has_traffic_light = false; + for(int j=0; jnum_traffic_elements; j++) { + if(env->traffic_elements[j].type == TRAFFIC_LIGHT) { + has_traffic_light = true; + break; + } + } + if(has_traffic_light) { + maps_checked++; + + // Safeguard: if we've checked all available maps and all have traffic lights, raise an error + if(maps_checked >= num_maps) { + for(int j=0;jnum_total_agents;j++) free_agent(&env->agents[j]); + for (int j=0;jnum_road_elements;j++) free_road_element(&env->road_elements[j]); + for (int j=0;jnum_traffic_elements;j++) free_traffic_element(&env->traffic_elements[j]); + free(env->agents); + free(env->road_elements); + free(env->traffic_elements); + free(env->active_agent_indices); + free(env->static_agent_indices); + free(env->expert_static_agent_indices); + free(env); + Py_DECREF(agent_offsets); + Py_DECREF(map_ids); + char error_msg[256]; + sprintf(error_msg, "All %d available maps contain traffic lights which are not supported", num_maps); + PyErr_SetString(PyExc_ValueError, error_msg); + return NULL; + } + + for(int j=0;jnum_total_agents;j++) free_agent(&env->agents[j]); + for (int j=0;jnum_road_elements;j++) free_road_element(&env->road_elements[j]); + for (int j=0;jnum_traffic_elements;j++) free_traffic_element(&env->traffic_elements[j]); + free(env->agents); + free(env->road_elements); + free(env->traffic_elements); + free(env->active_agent_indices); + free(env->static_agent_indices); + free(env->expert_static_agent_indices); + free(env); + continue; + } + set_active_agents(env); // Skip map if it doesn't contain any controllable agents @@ -101,10 +146,12 @@ static PyObject* my_shared(PyObject* self, PyObject* args, PyObject* kwargs) { // Safeguard: if we've checked all available maps and found no active agents, raise an error if(maps_checked >= num_maps) { - for(int j=0;jnum_entities;j++) { - free_entity(&env->entities[j]); - } - free(env->entities); + for(int j=0;jnum_total_agents;j++) free_agent(&env->agents[j]); + for (int j=0;jnum_road_elements;j++) free_road_element(&env->road_elements[j]); + for (int j=0;jnum_traffic_elements;j++) free_traffic_element(&env->traffic_elements[j]); + free(env->agents); + free(env->road_elements); + free(env->traffic_elements); free(env->active_agent_indices); free(env->static_agent_indices); free(env->expert_static_agent_indices); @@ -117,16 +164,18 @@ static PyObject* my_shared(PyObject* self, PyObject* args, PyObject* kwargs) { return NULL; } - for(int j=0;jnum_entities;j++) { - free_entity(&env->entities[j]); - } - free(env->entities); + for(int j=0;jnum_total_agents;j++) free_agent(&env->agents[j]); + for (int j=0;jnum_road_elements;j++) free_road_element(&env->road_elements[j]); + for (int j=0;jnum_traffic_elements;j++) free_traffic_element(&env->traffic_elements[j]); + free(env->agents); + free(env->road_elements); + free(env->traffic_elements); free(env->active_agent_indices); free(env->static_agent_indices); free(env->expert_static_agent_indices); free(env); continue; - } + } // Store map_id PyObject* map_id_obj = PyLong_FromLong(map_id); @@ -136,10 +185,12 @@ static PyObject* my_shared(PyObject* self, PyObject* args, PyObject* kwargs) { PyList_SetItem(agent_offsets, env_count, offset); total_agent_count += env->active_agent_count; env_count++; - for(int j=0;jnum_entities;j++) { - free_entity(&env->entities[j]); - } - free(env->entities); + for(int j=0;jnum_total_agents;j++) free_agent(&env->agents[j]); + for (int j=0;jnum_road_elements;j++) free_road_element(&env->road_elements[j]); + for (int j=0;jnum_traffic_elements;j++) free_traffic_element(&env->traffic_elements[j]); + free(env->agents); + free(env->road_elements); + free(env->traffic_elements); free(env->active_agent_indices); free(env->static_agent_indices); free(env->expert_static_agent_indices); @@ -197,7 +248,7 @@ static int my_init(Env* env, PyObject* args, PyObject* kwargs) { int init_steps = unpack(kwargs, "init_steps"); char map_file[100]; sprintf(map_file, "resources/drive/binaries/map_%03d.bin", map_id); - env->num_agents = max_agents; + env->num_max_agents = max_agents; env->map_name = strdup(map_file); env->init_steps = init_steps; env->timestep = init_steps; diff --git a/pufferlib/ocean/drive/datatypes.h b/pufferlib/ocean/drive/datatypes.h new file mode 100644 index 000000000..ff9b37e5d --- /dev/null +++ b/pufferlib/ocean/drive/datatypes.h @@ -0,0 +1,212 @@ +#define UNKNOWN 0 + +// -- AGENT TYPE +#define VEHICLE 1 +#define PEDESTRIAN 2 +#define CYCLIST 3 +#define OTHER 4 + +// -- ROAD TYPE +#define LANE_FREEWAY 1 +#define LANE_SURFACE_STREET 2 +#define LANE_BIKE_LANE 3 + +#define ROAD_LINE_UNKNOWN 10 +#define ROAD_LINE_BROKEN_SINGLE_WHITE 11 +#define ROAD_LINE_SOLID_SINGLE_WHITE 12 +#define ROAD_LINE_SOLID_DOUBLE_WHITE 13 +#define ROAD_LINE_BROKEN_SINGLE_YELLOW 14 +#define ROAD_LINE_BROKEN_DOUBLE_YELLOW 15 +#define ROAD_LINE_SOLID_SINGLE_YELLOW 16 +#define ROAD_LINE_SOLID_DOUBLE_YELLOW 17 +#define ROAD_LINE_PASSING_DOUBLE_YELLOW 18 + +#define ROAD_EDGE_UNKNOWN 20 +#define ROAD_EDGE_BOUNDARY 21 +#define ROAD_EDGE_MEDIAN 22 +#define ROAD_EDGE_SIDEWALK 23 + +#define CROSSWALK 31 +#define SPEED_BUMP 32 +#define DRIVEWAY 33 + +// -- TRAFFIC CONTROL TYPE +#define TRAFFIC_LIGHT 1 +#define STOP_SIGN 2 +#define YIELD_SIGN 3 +#define SPEED_LIMIT_SIGN 4 + + +int is_road_lane(int type){ + return (type >= 0 && type <= 9); +} + +int is_drivable_road_lane(int type){ + return (type == LANE_FREEWAY || type == LANE_SURFACE_STREET); +} + +int is_road_line(int type){ + return (type >= 10 && type <= 19); +} + +int is_road_edge(int type){ + return (type >= 20 && type <= 29); +} + +int is_road(int type){ + return is_road_lane(type) || is_road_line(type) || is_road_edge(type); +} + +int is_controllable_agent(int type){ + return (type == VEHICLE || type == PEDESTRIAN || type == CYCLIST); +} + +int normalize_road_type(int type){ + if(is_road_lane(type)){ + return 0; + } else if(is_road_line(type)){ + return 1; + } else if(is_road_edge(type)){ + return 2; + } else { + return -1; + } +} + +int unnormalize_road_type(int norm_type){ + if(norm_type == 0){ + return LANE_SURFACE_STREET; + } else if(norm_type == 1){ + return ROAD_LINE_BROKEN_SINGLE_WHITE; + } else if(norm_type == 2){ + return ROAD_EDGE_BOUNDARY; + } else { + return -1; // Invalid + } +} + + + +struct Agent { + int id; + int type; + + // Log trajectory + int trajectory_length; + float* log_trajectory_x; + float* log_trajectory_y; + float* log_trajectory_z; + float* log_heading; + float* log_velocity_x; + float* log_velocity_y; + float* log_length; + float* log_width; + float* log_height; + int* log_valid; + + // Simulation state + float sim_x; + float sim_y; + float sim_z; + float sim_heading; + float sim_vx; + float sim_vy; + float sim_length; + float sim_width; + float sim_height; + int sim_valid; + + // Route information + int route_length; + int* route; + + // Metrics and status tracking + int collision_state; + float metrics_array[5]; // [collision, offroad, reached_goal, lane_aligned, avg_displacement_error] + int current_lane_idx; + int collided_before_goal; + int sampled_new_goal; + int reached_goal_this_episode; + int num_goals_reached; + int active_agent; + int mark_as_expert; + float cumulative_displacement; + int displacement_sample_count; + + // Goal positions + float goal_position_x; + float goal_position_y; + float goal_position_z; + float init_goal_x; // Initialized from goal_position + float init_goal_y; // Initialized from goal_position + + // Respawn tracking + int respawn_timestep; + int respawn_count; + + int stopped; // 0/1 -> freeze if set + int removed; //0/1 -> remove from sim if set + + // Jerk dynamics + float a_long; + float a_lat; + float jerk_long; + float jerk_lat; + float steering_angle; + float wheelbase; +}; + +struct RoadMapElement { + int id; + int type; + + int segment_length; + float* x; + float* y; + float* z; + + // Lane specific info + int num_entries; + int* entry_lanes; + int num_exits; + int* exit_lanes; + float speed_limit; +}; + +struct TrafficControlElement { + int id; + int type; + + int state_length; + int* states; + float x; + float y; + float z; + int controlled_lane; +}; + +void free_agent(struct Agent* agent){ + free(agent->log_trajectory_x); + free(agent->log_trajectory_y); + free(agent->log_trajectory_z); + free(agent->log_heading); + free(agent->log_velocity_x); + free(agent->log_velocity_y); + free(agent->log_length); + free(agent->log_width); + free(agent->log_height); + free(agent->log_valid); + free(agent->route); +} + +void free_road_element(struct RoadMapElement* element){ + free(element->x); + free(element->y); + free(element->z); + free(element->entry_lanes); + free(element->exit_lanes); +} + +void free_traffic_element(struct TrafficControlElement* element){ + free(element->states); +} diff --git a/pufferlib/ocean/drive/drive.h b/pufferlib/ocean/drive/drive.h index 27084cdcf..d88b2803e 100644 --- a/pufferlib/ocean/drive/drive.h +++ b/pufferlib/ocean/drive/drive.h @@ -11,25 +11,15 @@ #include "rlgl.h" #include #include "error.h" +#include "datatypes.h" -// Entity Types -#define NONE 0 -#define VEHICLE 1 -#define PEDESTRIAN 2 -#define CYCLIST 3 -#define ROAD_LANE 4 -#define ROAD_LINE 5 -#define ROAD_EDGE 6 -#define STOP_SIGN 7 -#define CROSSWALK 8 -#define SPEED_BUMP 9 -#define DRIVEWAY 10 +// GridMapEntity types +#define ENTITY_TYPE_DYNAMIC_AGENT 1 +#define ENTITY_TYPE_ROAD_ELEMENT 2 +#define ENTITY_TYPE_TRAFFIC_CONTROL 3 #define INVALID_POSITION -10000.0f -// Trajectory Length -#define TRAJECTORY_LENGTH 91 - // Initialization modes #define INIT_ALL_VALID 0 #define INIT_ONLY_CONTROLLABLE_AGENTS 1 @@ -121,6 +111,9 @@ typedef struct Client Client; typedef struct Log Log; typedef struct Graph Graph; typedef struct AdjListNode AdjListNode; +typedef struct Agent Agent; +typedef struct RoadMapElement RoadMapElement; +typedef struct TrafficControlElement TrafficControlElement; struct Log { float episode_return; @@ -141,76 +134,6 @@ struct Log { float avg_collisions_per_agent; }; -typedef struct Entity Entity; -struct Entity { - int scenario_id; - int type; - int id; - int array_size; - float* traj_x; - float* traj_y; - float* traj_z; - float* traj_vx; - float* traj_vy; - float* traj_vz; - float* traj_heading; - int* traj_valid; - float width; - float length; - float height; - float goal_position_x; - float goal_position_y; - float goal_position_z; - float init_goal_x; - float init_goal_y; - int mark_as_expert; - int collision_state; - float metrics_array[5]; // metrics_array: [collision, offroad, reached_goal, lane_aligned, avg_displacement_error] - float x; - float y; - float z; - float vx; - float vy; - float vz; - float heading; - float heading_x; - float heading_y; - int current_lane_idx; - int valid; - int respawn_timestep; - int respawn_count; - int collided_before_goal; - int sampled_new_goal; - int reached_goal_this_episode; - int num_goals_reached; - int active_agent; - float cumulative_displacement; - int displacement_sample_count; - float goal_radius; - int stopped; - int removed; - - // Jerk dynamics - float a_long; - float a_lat; - float jerk_long; - float jerk_lat; - float steering_angle; - float wheelbase; -}; - -void free_entity(Entity* entity){ - // free trajectory arrays - free(entity->traj_x); - free(entity->traj_y); - free(entity->traj_z); - free(entity->traj_vx); - free(entity->traj_vy); - free(entity->traj_vz); - free(entity->traj_heading); - free(entity->traj_valid); -} - // Utility functions float relative_distance(float a, float b){ float distance = sqrtf(powf(a - b, 2)); @@ -230,28 +153,28 @@ float clip(float value, float min, float max) { return value; } -float compute_displacement_error(Entity* agent, int timestep) { +float compute_displacement_error(Agent* agent, int timestep) { // Check if timestep is within valid range - if (timestep < 0 || timestep >= agent->array_size) { + if (timestep < 0 || timestep >= agent->trajectory_length) { return 0.0f; } // Check if reference trajectory is valid at this timestep - if (!agent->traj_valid[timestep]) { + if (!agent->log_valid[timestep]) { return 0.0f; } - // Get reference position at current timestep, skip invalid ones - float ref_x = agent->traj_x[timestep]; - float ref_y = agent->traj_y[timestep]; + // Get reference position from logged trajectory at current timestep + float ref_x = agent->log_trajectory_x[timestep]; + float ref_y = agent->log_trajectory_y[timestep]; if (ref_x == INVALID_POSITION || ref_y == INVALID_POSITION) { return 0.0f; } - // Compute deltas: Euclidean distance between actual and reference position - float dx = agent->x - ref_x; - float dy = agent->y - ref_y; + // Compute deltas: Euclidean distance between simulated and reference position + float dx = agent->sim_x - ref_x; + float dy = agent->sim_y - ref_y; float displacement = sqrtf(dx*dx + dy*dy); return displacement; @@ -259,8 +182,9 @@ float compute_displacement_error(Entity* agent, int timestep) { typedef struct GridMapEntity GridMapEntity; struct GridMapEntity { - int entity_idx; - int geometry_idx; + int entity_type; // Entity type: 1=Agent, 2=RoadMapElement, 3=TrafficControlElement + int entity_idx; // Index into the corresponding typed array + int geometry_idx; // Index into entity's trajectory/geometry array }; typedef struct GridMap GridMap; @@ -290,17 +214,19 @@ struct Drive { unsigned char* terminals; Log log; Log* logs; - int num_agents; + int num_max_agents; int active_agent_count; int* active_agent_indices; int action_type; int human_agent_idx; - Entity* entities; + Agent* agents; + RoadMapElement* road_elements; + TrafficControlElement* traffic_elements; Graph* topology_graph; - int num_entities; + int num_total_agents; + int num_road_elements; + int num_traffic_elements; int num_actors; - int num_objects; - int num_roads; int static_agent_count; int* static_agent_indices; int expert_static_agent_count; @@ -325,21 +251,28 @@ struct Drive { int logs_capacity; int goal_behavior; char* ini_file; - char* scenario_id; int collision_behavior; int offroad_behavior; - int sdc_track_index; + int control_non_vehicles; + // Metadata fields + char scenario_id[128]; + int map_index; + char dataset_name[64]; + int log_length; + int sdc_index; + int num_objects_of_interest; + int* objects_of_interest; int num_tracks_to_predict; - int* tracks_to_predict_indices; + int* tracks_to_predict; int init_mode; int control_mode; }; void add_log(Drive* env) { for(int i = 0; i < env->active_agent_count; i++){ - Entity* e = &env->entities[env->active_agent_indices[i]]; + Agent* agent = &env->agents[env->active_agent_indices[i]]; - if(e->reached_goal_this_episode){ + if(agent->reached_goal_this_episode){ env->log.completion_rate += 1.0f; } int offroad = env->logs[i].offroad_rate; @@ -352,10 +285,10 @@ void add_log(Drive* env) { env->log.avg_collisions_per_agent += avg_collisions_per_agent; int num_goals_reached = env->logs[i].num_goals_reached; env->log.num_goals_reached += num_goals_reached; - if(e->reached_goal_this_episode && !e->collided_before_goal){ + if(agent->reached_goal_this_episode && !agent->collided_before_goal){ env->log.score += 1.0f; } - if(!offroad && !collided && !e->reached_goal_this_episode){ + if(!offroad && !collided && !agent->reached_goal_this_episode){ env->log.dnf_rate += 1.0f; } int lane_aligned = env->logs[i].lane_alignment_rate; @@ -436,85 +369,178 @@ void freeTopologyGraph(struct Graph* graph) { } -Entity* load_map_binary(const char* filename, Drive* env) { +int load_map_binary(const char* filename, Drive* drive) { FILE* file = fopen(filename, "rb"); - if (!file) return NULL; + if (!file) return -1; + + int num_total_agents, num_roads, num_traffic; + fread(&num_total_agents, sizeof(int), 1, file); + fread(&num_roads, sizeof(int), 1, file); + fread(&num_traffic, sizeof(int), 1, file); + drive->num_total_agents = num_total_agents; + drive->num_road_elements = num_roads; + drive->num_traffic_elements = num_traffic; - // Read sdc_track_index - fread(&env->sdc_track_index, sizeof(int), 1, file); + if (num_total_agents > 0) { + drive->agents = (Agent*)calloc(num_total_agents, sizeof(Agent)); + } - // Read tracks_to_predict - fread(&env->num_tracks_to_predict, sizeof(int), 1, file); - if (env->num_tracks_to_predict > 0) { - env->tracks_to_predict_indices = (int*)malloc(env->num_tracks_to_predict * sizeof(int)); + if (num_roads > 0) { + drive->road_elements = (RoadMapElement*)calloc(num_roads, sizeof(RoadMapElement)); + } - for (int i = 0; i < env->num_tracks_to_predict; i++) { - fread(&env->tracks_to_predict_indices[i], sizeof(int), 1, file); + if (num_traffic > 0) { + drive->traffic_elements = (TrafficControlElement*)calloc(num_traffic, sizeof(TrafficControlElement)); + } + + for (int i = 0; i < num_total_agents; i++) { + Agent* agent = &drive->agents[i]; + + fread(&agent->id, sizeof(int), 1, file); + fread(&agent->type, sizeof(int), 1, file); + fread(&agent->trajectory_length, sizeof(int), 1, file); + + int tlen = agent->trajectory_length; + + agent->log_trajectory_x = (float*)malloc(tlen * sizeof(float)); + agent->log_trajectory_y = (float*)malloc(tlen * sizeof(float)); + agent->log_trajectory_z = (float*)malloc(tlen * sizeof(float)); + agent->log_heading = (float*)malloc(tlen * sizeof(float)); + agent->log_velocity_x = (float*)malloc(tlen * sizeof(float)); + agent->log_velocity_y = (float*)malloc(tlen * sizeof(float)); + agent->log_length = (float*)malloc(tlen * sizeof(float)); + agent->log_width = (float*)malloc(tlen * sizeof(float)); + agent->log_height = (float*)malloc(tlen * sizeof(float)); + agent->log_valid = (int*)malloc(tlen * sizeof(int)); + + fread(agent->log_trajectory_x, sizeof(float), tlen, file); + fread(agent->log_trajectory_y, sizeof(float), tlen, file); + fread(agent->log_trajectory_z, sizeof(float), tlen, file); + fread(agent->log_heading, sizeof(float), tlen, file); + fread(agent->log_velocity_x, sizeof(float), tlen, file); + fread(agent->log_velocity_y, sizeof(float), tlen, file); + fread(agent->log_length, sizeof(float), tlen, file); + fread(agent->log_width, sizeof(float), tlen, file); + fread(agent->log_height, sizeof(float), tlen, file); + fread(agent->log_valid, sizeof(int), tlen, file); + + fread(&agent->route_length, sizeof(int), 1, file); + + if (agent->route_length > 0) { + agent->route = (int*)malloc(agent->route_length * sizeof(int)); + fread(agent->route, sizeof(int), agent->route_length, file); + } else { + agent->route = NULL; } - } else { - env->tracks_to_predict_indices = NULL; - } - - fread(&env->num_objects, sizeof(int), 1, file); - fread(&env->num_roads, sizeof(int), 1, file); - env->num_entities = env->num_objects + env->num_roads; - Entity* entities = (Entity*)malloc(env->num_entities * sizeof(Entity)); - for (int i = 0; i < env->num_entities; i++) { - // Read base entity data - fread(&entities[i].scenario_id, sizeof(int), 1, file); - fread(&entities[i].type, sizeof(int), 1, file); - fread(&entities[i].id, sizeof(int), 1, file); - fread(&entities[i].array_size, sizeof(int), 1, file); - // Allocate arrays based on type - int size = entities[i].array_size; - entities[i].traj_x = (float*)malloc(size * sizeof(float)); - entities[i].traj_y = (float*)malloc(size * sizeof(float)); - entities[i].traj_z = (float*)malloc(size * sizeof(float)); - if (entities[i].type == VEHICLE || entities[i].type == PEDESTRIAN || entities[i].type == CYCLIST) { // Object type - // Allocate arrays for object-specific data - entities[i].traj_vx = (float*)malloc(size * sizeof(float)); - entities[i].traj_vy = (float*)malloc(size * sizeof(float)); - entities[i].traj_vz = (float*)malloc(size * sizeof(float)); - entities[i].traj_heading = (float*)malloc(size * sizeof(float)); - entities[i].traj_valid = (int*)malloc(size * sizeof(int)); + + fread(&agent->goal_position_x, sizeof(float), 1, file); + fread(&agent->goal_position_y, sizeof(float), 1, file); + fread(&agent->goal_position_z, sizeof(float), 1, file); + fread(&agent->mark_as_expert, sizeof(int), 1, file); + } + + for (int i = 0; i < num_roads; i++) { + RoadMapElement* road = &drive->road_elements[i]; + + fread(&road->id, sizeof(int), 1, file); + fread(&road->type, sizeof(int), 1, file); + fread(&road->segment_length, sizeof(int), 1, file); + + int slen = road->segment_length; + + road->x = (float*)malloc(slen * sizeof(float)); + road->y = (float*)malloc(slen * sizeof(float)); + road->z = (float*)malloc(slen * sizeof(float)); + + fread(road->x, sizeof(float), slen, file); + fread(road->y, sizeof(float), slen, file); + fread(road->z, sizeof(float), slen, file); + + if (is_road_lane(road->type)) { + fread(&road->num_entries, sizeof(int), 1, file); + if (road->num_entries > 0) { + road->entry_lanes = (int*)malloc(road->num_entries * sizeof(int)); + fread(road->entry_lanes, sizeof(int), road->num_entries, file); + } else { + road->entry_lanes = NULL; + } + + fread(&road->num_exits, sizeof(int), 1, file); + if (road->num_exits > 0) { + road->exit_lanes = (int*)malloc(road->num_exits * sizeof(int)); + fread(road->exit_lanes, sizeof(int), road->num_exits, file); + } else { + road->exit_lanes = NULL; + } + + fread(&road->speed_limit, sizeof(float), 1, file); } else { - // Roads don't use these arrays - entities[i].traj_vx = NULL; - entities[i].traj_vy = NULL; - entities[i].traj_vz = NULL; - entities[i].traj_heading = NULL; - entities[i].traj_valid = NULL; - } - // Read array data - fread(entities[i].traj_x, sizeof(float), size, file); - fread(entities[i].traj_y, sizeof(float), size, file); - fread(entities[i].traj_z, sizeof(float), size, file); - if (entities[i].type == VEHICLE || entities[i].type == PEDESTRIAN || entities[i].type == CYCLIST) { // Object type - fread(entities[i].traj_vx, sizeof(float), size, file); - fread(entities[i].traj_vy, sizeof(float), size, file); - fread(entities[i].traj_vz, sizeof(float), size, file); - fread(entities[i].traj_heading, sizeof(float), size, file); - fread(entities[i].traj_valid, sizeof(int), size, file); - } - // Read remaining scalar fields - fread(&entities[i].width, sizeof(float), 1, file); - fread(&entities[i].length, sizeof(float), 1, file); - fread(&entities[i].height, sizeof(float), 1, file); - fread(&entities[i].goal_position_x, sizeof(float), 1, file); - fread(&entities[i].goal_position_y, sizeof(float), 1, file); - fread(&entities[i].goal_position_z, sizeof(float), 1, file); - fread(&entities[i].mark_as_expert, sizeof(int), 1, file); + road->entry_lanes = NULL; + road->exit_lanes = NULL; + road->speed_limit = 0.0f; + } + } + + for (int i = 0; i < num_traffic; i++) { + TrafficControlElement* traffic = &drive->traffic_elements[i]; + + fread(&traffic->id, sizeof(int), 1, file); + fread(&traffic->type, sizeof(int), 1, file); + fread(&traffic->x, sizeof(float), 1, file); + fread(&traffic->y, sizeof(float), 1, file); + fread(&traffic->z, sizeof(float), 1, file); + fread(&traffic->state_length, sizeof(int), 1, file); + + int state_len = traffic->state_length; + + traffic->states = (int*)malloc(state_len * sizeof(int)); + fread(traffic->states, sizeof(int), state_len, file); + + // Read controlled_lanes array (but only store first one in single controlled_lane field) + int num_controlled_lanes; + fread(&num_controlled_lanes, sizeof(int), 1, file); + if (num_controlled_lanes > 0) { + fread(&traffic->controlled_lane, sizeof(int), 1, file); + // Skip remaining controlled lanes if any + for (int j = 1; j < num_controlled_lanes; j++) { + int dummy; + fread(&dummy, sizeof(int), 1, file); + } + } else { + traffic->controlled_lane = -1; + } + } + + fread(drive->scenario_id, sizeof(char), 128, file); + fread(&drive->map_index, sizeof(int), 1, file); + fread(drive->dataset_name, sizeof(char), 64, file); + fread(&drive->log_length, sizeof(int), 1, file); + fread(&drive->sdc_index, sizeof(int), 1, file); + fread(&drive->num_objects_of_interest, sizeof(int), 1, file); + + if (drive->num_objects_of_interest > 0) { + drive->objects_of_interest = (int*)malloc(drive->num_objects_of_interest * sizeof(int)); + fread(drive->objects_of_interest, sizeof(int), drive->num_objects_of_interest, file); + } else { + drive->objects_of_interest = NULL; + } + + fread(&drive->num_tracks_to_predict, sizeof(int), 1, file); + + if (drive->num_tracks_to_predict > 0) { + drive->tracks_to_predict = (int*)malloc(drive->num_tracks_to_predict * sizeof(int)); + fread(drive->tracks_to_predict, sizeof(int), drive->num_tracks_to_predict, file); + } else { + drive->tracks_to_predict = NULL; } fclose(file); - return entities; + return 0; } void set_start_position(Drive* env){ - //InitWindow(800, 600, "GPU Drive"); - //BeginDrawing(); - for(int i = 0; i < env->num_entities; i++){ + for(int i = 0; i < env->num_total_agents; i++){ int is_active = 0; for(int j = 0; j < env->active_agent_count; j++){ if(env->active_agent_indices[j] == i){ @@ -522,56 +548,60 @@ void set_start_position(Drive* env){ break; } } - Entity* e = &env->entities[i]; + Agent* agent = &env->agents[i]; // Clamp init_steps to ensure we don't go out of bounds int step = env->init_steps; - if (step >= e->array_size) step = e->array_size - 1; + if (step >= agent->trajectory_length) step = agent->trajectory_length - 1; if (step < 0) step = 0; - e->x = e->traj_x[step]; - e->y = e->traj_y[step]; - e->z = e->traj_z[step]; + // Initialize simulation trajectory from logged trajectory at init_steps + agent->sim_x = agent->log_trajectory_x[step]; + agent->sim_y = agent->log_trajectory_y[step]; + agent->sim_z = agent->log_trajectory_z[step]; + agent->sim_heading = agent->log_heading[step]; + agent->sim_valid = agent->log_valid[step]; + agent->sim_length = agent->log_length[step]; + agent->sim_width = agent->log_width[step]; + agent->sim_height = agent->log_height[step]; + + if(agent->type == UNKNOWN) continue; - if(e->type > CYCLIST || e->type == 0){ - continue; - } if(is_active == 0){ - e->vx = 0; - e->vy = 0; - e->vz = 0; - e->collided_before_goal = 0; + agent->sim_vx = 0.0f; + agent->sim_vy = 0.0f; + agent->collided_before_goal = 0; } else { - e->vx = e->traj_vx[env->init_steps]; - e->vy = e->traj_vy[env->init_steps]; - e->vz = e->traj_vz[env->init_steps]; - } - e->heading = e->traj_heading[env->init_steps]; - e->heading_x = cosf(e->heading); - e->heading_y = sinf(e->heading); - e->valid = e->traj_valid[env->init_steps]; - e->collision_state = 0; - e->metrics_array[COLLISION_IDX] = 0.0f; // vehicle collision - e->metrics_array[OFFROAD_IDX] = 0.0f; // offroad - e->metrics_array[REACHED_GOAL_IDX] = 0.0f; // reached goal - e->metrics_array[LANE_ALIGNED_IDX] = 0.0f; // lane aligned - e->metrics_array[AVG_DISPLACEMENT_ERROR_IDX] = 0.0f; // avg displacement error - e->cumulative_displacement = 0.0f; - e->displacement_sample_count = 0; - e->respawn_timestep = -1; - e->stopped = 0; - e->removed = 0; - e->respawn_count = 0; + agent->sim_vx = agent->log_velocity_x[step]; + agent->sim_vy = agent->log_velocity_y[step]; + } + + // Shrink width and length slightly to avoid initial collisions + agent->sim_length *= 0.7f; + agent->sim_width *= 0.7f; + + // Initialize metrics and state + agent->collision_state = 0; + agent->metrics_array[COLLISION_IDX] = 0.0f; + agent->metrics_array[OFFROAD_IDX] = 0.0f; + agent->metrics_array[REACHED_GOAL_IDX] = 0.0f; + agent->metrics_array[LANE_ALIGNED_IDX] = 0.0f; + agent->metrics_array[AVG_DISPLACEMENT_ERROR_IDX] = 0.0f; + agent->cumulative_displacement = 0.0f; + agent->displacement_sample_count = 0; + agent->respawn_timestep = -1; + agent->stopped = 0; + agent->removed = 0; + agent->respawn_count = 0; // Dynamics - e->a_long = 0.0f; - e->a_lat = 0.0f; - e->jerk_long = 0.0f; - e->jerk_lat = 0.0f; - e->steering_angle = 0.0f; - e->wheelbase = 0.6f * e->length; - } - //EndDrawing(); + agent->a_long = 0.0f; + agent->a_lat = 0.0f; + agent->jerk_long = 0.0f; + agent->jerk_lat = 0.0f; + agent->steering_angle = 0.0f; + agent->wheelbase = 0.6f * agent->sim_length; + } } int getGridIndex(Drive* env, float x1, float y1) { @@ -590,7 +620,7 @@ int getGridIndex(Drive* env, float x1, float y1) { return index; } -void add_entity_to_grid(Drive* env, int grid_index, int entity_idx, int geometry_idx, int* cell_entities_insert_index){ +void add_entity_to_grid(Drive* env, int grid_index, int entity_type, int entity_idx, int geometry_idx, int* cell_entities_insert_index){ if(grid_index == -1){ return; } @@ -601,6 +631,7 @@ void add_entity_to_grid(Drive* env, int grid_index, int entity_idx, int geometry return; } + env->grid_map->cells[grid_index][count].entity_type = entity_type; env->grid_map->cells[grid_index][count].entity_idx = entity_idx; env->grid_map->cells[grid_index][count].geometry_idx = geometry_idx; cell_entities_insert_index[grid_index] = count + 1; @@ -608,10 +639,10 @@ void add_entity_to_grid(Drive* env, int grid_index, int entity_idx, int geometry void init_topology_graph(Drive* env){ - // Count ROAD_LANE entities + // Count ROAD_LANE entities in road_elements int road_lane_count = 0; - for(int i = 0; i < env->num_entities; i++){ - if(env->entities[i].type == ROAD_LANE){ + for(int i = 0; i < env->num_road_elements; i++){ + if(is_drivable_road_lane(env->road_elements[i].type)){ road_lane_count++; } } @@ -621,35 +652,35 @@ void init_topology_graph(Drive* env){ return; } - // Create graph with all entities as vertices (we'll only use ROAD_LANE indices) - env->topology_graph = createGraph(env->num_entities); + // Create graph with road_elements as vertices + env->topology_graph = createGraph(env->num_road_elements); // Connect ROAD_LANE entities based on geometric connectivity - for(int i = 0; i < env->num_entities; i++){ - if(env->entities[i].type != ROAD_LANE) continue; + for(int i = 0; i < env->num_road_elements; i++){ + if(!is_drivable_road_lane(env->road_elements[i].type)) continue; - Entity* lane_i = &env->entities[i]; - if(lane_i->array_size < 2) continue; // Need at least 2 points + RoadMapElement* lane_i = &env->road_elements[i]; + if(lane_i->segment_length < 2) continue; // Need at least 2 points // Get end point of current lane - float end_x = lane_i->traj_x[lane_i->array_size - 1]; - float end_y = lane_i->traj_y[lane_i->array_size - 1]; - float end_vector_x = lane_i->traj_x[lane_i->array_size - 1] - lane_i->traj_x[lane_i->array_size - 2]; - float end_vector_y = lane_i->traj_y[lane_i->array_size - 1] - lane_i->traj_y[lane_i->array_size - 2]; + float end_x = lane_i->x[lane_i->segment_length - 1]; + float end_y = lane_i->y[lane_i->segment_length - 1]; + float end_vector_x = lane_i->x[lane_i->segment_length - 1] - lane_i->x[lane_i->segment_length - 2]; + float end_vector_y = lane_i->y[lane_i->segment_length - 1] - lane_i->y[lane_i->segment_length - 2]; float end_heading = atan2f(end_vector_y, end_vector_x); // Find lanes that start near this lane's end - for(int j = 0; j < env->num_entities; j++){ - if(i == j || env->entities[j].type != ROAD_LANE) continue; + for(int j = 0; j < env->num_road_elements; j++){ + if(i == j || !is_drivable_road_lane(env->road_elements[j].type)) continue; - Entity* lane_j = &env->entities[j]; - if(lane_j->array_size < 2) continue; + RoadMapElement* lane_j = &env->road_elements[j]; + if(lane_j->segment_length < 2) continue; // Get start point of potential next lane - float start_x = lane_j->traj_x[0]; - float start_y = lane_j->traj_y[0]; - float start_vector_x = lane_j->traj_x[1] - lane_j->traj_x[0]; - float start_vector_y = lane_j->traj_y[1] - lane_j->traj_y[0]; + float start_x = lane_j->x[0]; + float start_y = lane_j->y[0]; + float start_vector_x = lane_j->x[1] - lane_j->x[0]; + float start_vector_y = lane_j->y[1] - lane_j->y[0]; float start_heading = atan2f(start_vector_y, start_vector_x); // Check if end of lane_i is close to start of lane_j @@ -679,24 +710,23 @@ void init_grid_map(Drive* env){ float bottom_right_x; float bottom_right_y; int first_valid_point = 0; - for(int i = 0; i < env->num_entities; i++){ - if(env->entities[i].type > 3 && env->entities[i].type < 7){ - // Check all points in the trajectory for road elements - Entity* e = &env->entities[i]; - for(int j = 0; j < e->array_size; j++){ - if(e->traj_x[j] == INVALID_POSITION) continue; - if(e->traj_y[j] == INVALID_POSITION) continue; - if(!first_valid_point) { - top_left_x = bottom_right_x = e->traj_x[j]; - top_left_y = bottom_right_y = e->traj_y[j]; - first_valid_point = true; - continue; - } - if(e->traj_x[j] < top_left_x) top_left_x = e->traj_x[j]; - if(e->traj_x[j] > bottom_right_x) bottom_right_x = e->traj_x[j]; - if(e->traj_y[j] > top_left_y) top_left_y = e->traj_y[j]; - if(e->traj_y[j] < bottom_right_y) bottom_right_y = e->traj_y[j]; + for(int i = 0; i < env->num_road_elements; i++){ + // Check all points in the geometry for road elements (ROAD_LANE, ROAD_LINE, ROAD_EDGE) + if (!is_road(env->road_elements[i].type)) continue; + RoadMapElement* element = &env->road_elements[i]; + for(int j = 0; j < element->segment_length; j++){ + if(element->x[j] == INVALID_POSITION) continue; + if(element->y[j] == INVALID_POSITION) continue; + if(!first_valid_point) { + top_left_x = bottom_right_x = element->x[j]; + top_left_y = bottom_right_y = element->y[j]; + first_valid_point = true; + continue; } + if(element->x[j] < top_left_x) top_left_x = element->x[j]; + if(element->x[j] > bottom_right_x) bottom_right_x = element->x[j]; + if(element->y[j] > top_left_y) top_left_y = element->y[j]; + if(element->y[j] < bottom_right_y) bottom_right_y = element->y[j]; } } @@ -717,16 +747,16 @@ void init_grid_map(Drive* env){ env->grid_map->cell_entities_count = (int*)calloc(grid_cell_count, sizeof(int)); // Calculate number of entities in each grid cell - for(int i = 0; i < env->num_entities; i++){ - if(env->entities[i].type > 3 && env->entities[i].type < 7){ - for(int j = 0; j < env->entities[i].array_size - 1; j++){ - float x_center = (env->entities[i].traj_x[j] + env->entities[i].traj_x[j+1]) / 2; - float y_center = (env->entities[i].traj_y[j] + env->entities[i].traj_y[j+1]) / 2; - int grid_index = getGridIndex(env, x_center, y_center); - env->grid_map->cell_entities_count[grid_index]++; - } + for(int i = 0; i < env->num_road_elements; i++){ + for(int j = 0; j < env->road_elements[i].segment_length - 1; j++){ + float x_center = (env->road_elements[i].x[j] + env->road_elements[i].x[j+1]) / 2; + float y_center = (env->road_elements[i].y[j] + env->road_elements[i].y[j+1]) / 2; + int grid_index = getGridIndex(env, x_center, y_center); + if (grid_index == -1) continue; // Skip out-of-bounds entities + env->grid_map->cell_entities_count[grid_index]++; } } + int cell_entities_insert_index[grid_cell_count]; // Helper array for insertion index memset(cell_entities_insert_index, 0, grid_cell_count * sizeof(int)); @@ -742,14 +772,13 @@ void init_grid_map(Drive* env){ } // Populate grid cells - for(int i = 0; i < env->num_entities; i++){ - if(env->entities[i].type > 3 && env->entities[i].type < 7){ // NOTE: Only Road Edges, Lines, and Lanes in grid map - for(int j = 0; j < env->entities[i].array_size - 1; j++){ - float x_center = (env->entities[i].traj_x[j] + env->entities[i].traj_x[j+1]) / 2; - float y_center = (env->entities[i].traj_y[j] + env->entities[i].traj_y[j+1]) / 2; - int grid_index = getGridIndex(env, x_center, y_center); - add_entity_to_grid(env, grid_index, i, j, cell_entities_insert_index); - } + for(int i = 0; i < env->num_road_elements; i++){ + for(int j = 0; j < env->road_elements[i].segment_length - 1; j++){ + float x_center = (env->road_elements[i].x[j] + env->road_elements[i].x[j+1]) / 2; + float y_center = (env->road_elements[i].y[j] + env->road_elements[i].y[j+1]) / 2; + int grid_index = getGridIndex(env, x_center, y_center); + if (grid_index == -1) continue; // Skip out-of-bounds entities + add_entity_to_grid(env, grid_index, ENTITY_TYPE_ROAD_ELEMENT, i, j, cell_entities_insert_index); } } } @@ -874,68 +903,106 @@ void set_means(Drive* env) { float mean_y = 0.0f; int64_t point_count = 0; - // Compute single mean for all entities (vehicles and roads) - for (int i = 0; i < env->num_entities; i++) { - if (env->entities[i].type == VEHICLE || env->entities[i].type == PEDESTRIAN || env->entities[i].type == CYCLIST) { - for (int j = 0; j < env->entities[i].array_size; j++) { - // Assume a validity flag exists (e.g., valid[j]); adjust if not available - if (env->entities[i].traj_valid[j]) { // Add validity check if applicable - point_count++; - mean_x += (env->entities[i].traj_x[j] - mean_x) / point_count; - mean_y += (env->entities[i].traj_y[j] - mean_y) / point_count; - } - } - } else if (env->entities[i].type >= 4) { - for (int j = 0; j < env->entities[i].array_size; j++) { + // Compute mean from dynamic agents + for (int i = 0; i < env->num_total_agents; i++) { + Agent* agent = &env->agents[i]; + for (int j = 0; j < agent->trajectory_length; j++) { + if (agent->log_valid[j]) { point_count++; - mean_x += (env->entities[i].traj_x[j] - mean_x) / point_count; - mean_y += (env->entities[i].traj_y[j] - mean_y) / point_count; + mean_x += (agent->log_trajectory_x[j] - mean_x) / point_count; + mean_y += (agent->log_trajectory_y[j] - mean_y) / point_count; } } } + + // Compute mean from road elements + for (int i = 0; i < env->num_road_elements; i++) { + RoadMapElement* element = &env->road_elements[i]; + for (int j = 0; j < element->segment_length; j++) { + if(element->x[j] == INVALID_POSITION) continue; + point_count++; + mean_x += (element->x[j] - mean_x) / point_count; + mean_y += (element->y[j] - mean_y) / point_count; + } + } + env->world_mean_x = mean_x; env->world_mean_y = mean_y; - for (int i = 0; i < env->num_entities; i++) { - if (env->entities[i].type == VEHICLE || env->entities[i].type == PEDESTRIAN || env->entities[i].type == CYCLIST || env->entities[i].type >= 4) { - for (int j = 0; j < env->entities[i].array_size; j++) { - if(env->entities[i].traj_x[j] == INVALID_POSITION) continue; - env->entities[i].traj_x[j] -= mean_x; - env->entities[i].traj_y[j] -= mean_y; - } - env->entities[i].goal_position_x -= mean_x; - env->entities[i].goal_position_y -= mean_y; + + // Subtract mean from dynamic agents + for (int i = 0; i < env->num_total_agents; i++) { + Agent* agent = &env->agents[i]; + for (int j = 0; j < agent->trajectory_length; j++) { + if(agent->log_trajectory_x[j] == INVALID_POSITION) continue; + agent->log_trajectory_x[j] -= mean_x; + agent->log_trajectory_y[j] -= mean_y; + } + // Normalize current sim position (scalars) + if(agent->sim_x != INVALID_POSITION) { + agent->sim_x -= mean_x; + agent->sim_y -= mean_y; } + agent->goal_position_x -= mean_x; + agent->goal_position_y -= mean_y; + agent->init_goal_x -= mean_x; + agent->init_goal_y -= mean_y; } + // Subtract mean from road elements + for (int i = 0; i < env->num_road_elements; i++) { + RoadMapElement* element = &env->road_elements[i]; + for (int j = 0; j < element->segment_length; j++) { + if(element->x[j] == INVALID_POSITION) continue; + element->x[j] -= mean_x; + element->y[j] -= mean_y; + } + } } void move_expert(Drive* env, float* actions, int agent_idx){ - Entity* agent = &env->entities[agent_idx]; + Agent* agent = &env->agents[agent_idx]; int t = env->timestep; - if (t < 0 || t >= agent->array_size) { - agent->x = INVALID_POSITION; - agent->y = INVALID_POSITION; - agent->z = 0.0f; - agent->heading = 0.0f; - agent->heading_x = 1.0f; - agent->heading_y = 0.0f; + + // Timestep out of bounds + if (t < 0 || t >= agent->trajectory_length) { + agent->sim_x = INVALID_POSITION; + agent->sim_y = INVALID_POSITION; + agent->sim_z = 0.0f; + agent->sim_heading = 0.0f; + agent->sim_vx = 0.0f; + agent->sim_vy = 0.0f; + agent->sim_valid = 0; + // agent->sim_length = 0.0f; + // agent->sim_width = 0.0f; + // agent->sim_height = 0.0f; return; } - if (agent->traj_valid && agent->traj_valid[t] == 0) { - agent->x = INVALID_POSITION; - agent->y = INVALID_POSITION; - agent->z = 0.0f; - agent->heading = 0.0f; - agent->heading_x = 1.0f; - agent->heading_y = 0.0f; + // Invalid log entry + if (agent->log_valid && agent->log_valid[t] == 0) { + agent->sim_x = INVALID_POSITION; + agent->sim_y = INVALID_POSITION; + agent->sim_z = 0.0f; + agent->sim_heading = 0.0f; + agent->sim_vx = 0.0f; + agent->sim_vy = 0.0f; + agent->sim_valid = 0; + // agent->sim_length = 0.0f; + // agent->sim_width = 0.0f; + // agent->sim_height = 0.0f; return; } - agent->x = agent->traj_x[t]; - agent->y = agent->traj_y[t]; - agent->z = agent->traj_z[t]; - agent->heading = agent->traj_heading[t]; - agent->heading_x = cosf(agent->heading); - agent->heading_y = sinf(agent->heading); + + // Copy from logged trajectory to simulated state + agent->sim_x = agent->log_trajectory_x[t]; + agent->sim_y = agent->log_trajectory_y[t]; + agent->sim_z = agent->log_trajectory_z[t]; + agent->sim_heading = agent->log_heading[t]; + agent->sim_vx = agent->log_velocity_x[t]; + agent->sim_vy = agent->log_velocity_y[t]; + // agent->sim_length = agent->log_length[t]; + // agent->sim_width = agent->log_width[t]; + // agent->sim_height = agent->log_height[t]; + agent->sim_valid = agent->log_valid[t]; } bool check_line_intersection(float p1[2], float p2[2], float q1[2], float q2[2]) { @@ -988,41 +1055,46 @@ int checkNeighbors(Drive* env, float x, float y, GridMapEntity* entity_list, int for (int j = 0; j < count && entity_list_count < max_size; j++) { int entityId = env->grid_map->cells[neighborIndex][j].entity_idx; int geometry_idx = env->grid_map->cells[neighborIndex][j].geometry_idx; + int entity_type = env->grid_map->cells[neighborIndex][j].entity_type; entity_list[entity_list_count].entity_idx = entityId; entity_list[entity_list_count].geometry_idx = geometry_idx; + entity_list[entity_list_count].entity_type = entity_type; entity_list_count += 1; } } return entity_list_count; } -int check_aabb_collision(Entity* car1, Entity* car2) { +int check_aabb_collision(Agent* car1, Agent* car2) { // Get car corners in world space - float cos1 = car1->heading_x; - float sin1 = car1->heading_y; - float cos2 = car2->heading_x; - float sin2 = car2->heading_y; + float heading1 = car1->sim_heading; + float cos1 = cosf(heading1); + float sin1 = sinf(heading1); + + float heading2 = car2->sim_heading; + float cos2 = cosf(heading2); + float sin2 = sinf(heading2); // Calculate half dimensions - float half_len1 = car1->length * 0.5f; - float half_width1 = car1->width * 0.5f; - float half_len2 = car2->length * 0.5f; - float half_width2 = car2->width * 0.5f; + float half_len1 = car1->sim_length * 0.5f; + float half_width1 = car1->sim_width * 0.5f; + float half_len2 = car2->sim_length * 0.5f; + float half_width2 = car2->sim_width * 0.5f; // Calculate car1's corners in world space float car1_corners[4][2] = { - {car1->x + (half_len1 * cos1 - half_width1 * sin1), car1->y + (half_len1 * sin1 + half_width1 * cos1)}, - {car1->x + (half_len1 * cos1 + half_width1 * sin1), car1->y + (half_len1 * sin1 - half_width1 * cos1)}, - {car1->x + (-half_len1 * cos1 - half_width1 * sin1), car1->y + (-half_len1 * sin1 + half_width1 * cos1)}, - {car1->x + (-half_len1 * cos1 + half_width1 * sin1), car1->y + (-half_len1 * sin1 - half_width1 * cos1)} + {car1->sim_x + (half_len1 * cos1 - half_width1 * sin1), car1->sim_y + (half_len1 * sin1 + half_width1 * cos1)}, + {car1->sim_x + (half_len1 * cos1 + half_width1 * sin1), car1->sim_y + (half_len1 * sin1 - half_width1 * cos1)}, + {car1->sim_x + (-half_len1 * cos1 - half_width1 * sin1), car1->sim_y + (-half_len1 * sin1 + half_width1 * cos1)}, + {car1->sim_x + (-half_len1 * cos1 + half_width1 * sin1), car1->sim_y + (-half_len1 * sin1 - half_width1 * cos1)} }; // Calculate car2's corners in world space float car2_corners[4][2] = { - {car2->x + (half_len2 * cos2 - half_width2 * sin2), car2->y + (half_len2 * sin2 + half_width2 * cos2)}, - {car2->x + (half_len2 * cos2 + half_width2 * sin2), car2->y + (half_len2 * sin2 - half_width2 * cos2)}, - {car2->x + (-half_len2 * cos2 - half_width2 * sin2), car2->y + (-half_len2 * sin2 + half_width2 * cos2)}, - {car2->x + (-half_len2 * cos2 + half_width2 * sin2), car2->y + (-half_len2 * sin2 - half_width2 * cos2)} + {car2->sim_x + (half_len2 * cos2 - half_width2 * sin2), car2->sim_y + (half_len2 * sin2 + half_width2 * cos2)}, + {car2->sim_x + (half_len2 * cos2 + half_width2 * sin2), car2->sim_y + (half_len2 * sin2 - half_width2 * cos2)}, + {car2->sim_x + (-half_len2 * cos2 - half_width2 * sin2), car2->sim_y + (-half_len2 * sin2 + half_width2 * cos2)}, + {car2->sim_x + (-half_len2 * cos2 + half_width2 * sin2), car2->sim_y + (-half_len2 * sin2 - half_width2 * cos2)} }; // Get the axes to check (normalized vectors perpendicular to each edge) @@ -1063,9 +1135,9 @@ int check_aabb_collision(Entity* car1, Entity* car2) { } int collision_check(Drive* env, int agent_idx) { - Entity* agent = &env->entities[agent_idx]; + Agent* agent = &env->agents[agent_idx]; - if(agent->x == INVALID_POSITION ) return -1; + if(agent->sim_x == INVALID_POSITION) return -1; int car_collided_with_index = -1; @@ -1080,13 +1152,13 @@ int collision_check(Drive* env, int agent_idx) { } if(index == -1) continue; if(index == agent_idx) continue; - Entity* entity = &env->entities[index]; - if (entity->respawn_timestep != -1) continue; // Skip respawning entities - float x1 = entity->x; - float y1 = entity->y; - float dist = ((x1 - agent->x)*(x1 - agent->x) + (y1 - agent->y)*(y1 - agent->y)); + + Agent* other_agent = &env->agents[index]; + if (other_agent->respawn_timestep != -1) continue; // Skip respawning entities + + float dist = ((other_agent->sim_x - agent->sim_x)*(other_agent->sim_x - agent->sim_x) + (other_agent->sim_y - agent->sim_y)*(other_agent->sim_y - agent->sim_y)); if(dist > 225.0f) continue; - if(check_aabb_collision(agent, entity)) { + if(check_aabb_collision(agent, other_agent)) { car_collided_with_index = index; break; } @@ -1095,27 +1167,27 @@ int collision_check(Drive* env, int agent_idx) { return car_collided_with_index; } -int check_lane_aligned(Entity* car, Entity* lane, int geometry_idx) { +int check_lane_aligned(Agent* car, RoadMapElement* lane, int geometry_idx, int timestep) { // Validate lane geometry length - if (!lane || lane->array_size < 2) return 0; + if (!lane || lane->segment_length < 2) return 0; - // Clamp geometry index to valid segment range [0, array_size-2] + // Clamp geometry index to valid segment range [0, segment_length-2] if (geometry_idx < 0) geometry_idx = 0; - if (geometry_idx >= lane->array_size - 1) geometry_idx = lane->array_size - 2; + if (geometry_idx >= lane->segment_length - 1) geometry_idx = lane->segment_length - 2; // Compute local lane segment heading float heading_x1, heading_y1; if (geometry_idx > 0) { - heading_x1 = lane->traj_x[geometry_idx] - lane->traj_x[geometry_idx - 1]; - heading_y1 = lane->traj_y[geometry_idx] - lane->traj_y[geometry_idx - 1]; + heading_x1 = lane->x[geometry_idx] - lane->x[geometry_idx - 1]; + heading_y1 = lane->y[geometry_idx] - lane->y[geometry_idx - 1]; } else { // For first segment, just use the forward direction - heading_x1 = lane->traj_x[geometry_idx + 1] - lane->traj_x[geometry_idx]; - heading_y1 = lane->traj_y[geometry_idx + 1] - lane->traj_y[geometry_idx]; + heading_x1 = lane->x[geometry_idx + 1] - lane->x[geometry_idx]; + heading_y1 = lane->y[geometry_idx + 1] - lane->y[geometry_idx]; } - float heading_x2 = lane->traj_x[geometry_idx + 1] - lane->traj_x[geometry_idx]; - float heading_y2 = lane->traj_y[geometry_idx + 1] - lane->traj_y[geometry_idx]; + float heading_x2 = lane->x[geometry_idx + 1] - lane->x[geometry_idx]; + float heading_y2 = lane->y[geometry_idx + 1] - lane->y[geometry_idx]; float heading_1 = atan2f(heading_y1, heading_x1); float heading_2 = atan2f(heading_y2, heading_x2); @@ -1126,7 +1198,7 @@ int check_lane_aligned(Entity* car, Entity* lane, int geometry_idx) { if (heading < -M_PI) heading += 2.0f * M_PI; // Compute heading difference - float car_heading = car->heading; // radians + float car_heading = car->sim_heading; // radians float heading_diff = fabsf(car_heading - heading); if (heading_diff > M_PI) heading_diff = 2.0f * M_PI - heading_diff; @@ -1136,7 +1208,7 @@ int check_lane_aligned(Entity* car, Entity* lane, int geometry_idx) { } void reset_agent_metrics(Drive* env, int agent_idx){ - Entity* agent = &env->entities[agent_idx]; + Agent* agent = &env->agents[agent_idx]; agent->metrics_array[COLLISION_IDX] = 0.0f; // vehicle collision agent->metrics_array[OFFROAD_IDX] = 0.0f; // offroad agent->metrics_array[LANE_ALIGNED_IDX] = 0.0f; // lane aligned @@ -1169,11 +1241,11 @@ float point_to_segment_distance_2d(float px, float py, float x1, float y1, float } void compute_agent_metrics(Drive* env, int agent_idx) { - Entity* agent = &env->entities[agent_idx]; + Agent* agent = &env->agents[agent_idx]; reset_agent_metrics(env, agent_idx); - if(agent->x == INVALID_POSITION ) return; // invalid agent position + if(agent->sim_x == INVALID_POSITION) return; // invalid agent position // Compute displacement error float displacement_error = compute_displacement_error(agent, env->timestep); @@ -1188,10 +1260,10 @@ void compute_agent_metrics(Drive* env, int agent_idx) { } int collided = 0; - float half_length = agent->length/2.0f; - float half_width = agent->width/2.0f; - float cos_heading = cosf(agent->heading); - float sin_heading = sinf(agent->heading); + float half_length = agent->sim_length/2.0f; + float half_width = agent->sim_width/2.0f; + float cos_heading = cosf(agent->sim_heading); + float sin_heading = sinf(agent->sim_heading); float min_distance = (float)INT16_MAX; int closest_lane_entity_idx = -1; @@ -1199,23 +1271,25 @@ void compute_agent_metrics(Drive* env, int agent_idx) { float corners[4][2]; for (int i = 0; i < 4; i++) { - corners[i][0] = agent->x + (offsets[i][0]*half_length*cos_heading - offsets[i][1]*half_width*sin_heading); - corners[i][1] = agent->y + (offsets[i][0]*half_length*sin_heading + offsets[i][1]*half_width*cos_heading); + corners[i][0] = agent->sim_x + (offsets[i][0]*half_length*cos_heading - offsets[i][1]*half_width*sin_heading); + corners[i][1] = agent->sim_y + (offsets[i][0]*half_length*sin_heading + offsets[i][1]*half_width*cos_heading); } GridMapEntity entity_list[MAX_ENTITIES_PER_CELL*25]; // Array big enough for all neighboring cells - int list_size = checkNeighbors(env, agent->x, agent->y, entity_list, MAX_ENTITIES_PER_CELL*25, collision_offsets, 25); - for (int i = 0; i < list_size ; i++) { + int list_size = checkNeighbors(env, agent->sim_x, agent->sim_y, entity_list, MAX_ENTITIES_PER_CELL*25, collision_offsets, 25); + for (int i = 0; i < list_size; i++) { if(entity_list[i].entity_idx == -1) continue; - if(entity_list[i].entity_idx == agent_idx) continue; - Entity* entity; - entity = &env->entities[entity_list[i].entity_idx]; + + // Get the road element (only road elements are in grid) + if(entity_list[i].entity_type != ENTITY_TYPE_ROAD_ELEMENT) continue; + + RoadMapElement* element = &env->road_elements[entity_list[i].entity_idx]; + int geometry_idx = entity_list[i].geometry_idx; // Check for offroad collision with road edges - if(entity->type == ROAD_EDGE) { - int geometry_idx = entity_list[i].geometry_idx; - float start[2] = {entity->traj_x[geometry_idx], entity->traj_y[geometry_idx]}; - float end[2] = {entity->traj_x[geometry_idx + 1], entity->traj_y[geometry_idx + 1]}; + if(is_road_edge(element->type)) { + float start[2] = {element->x[geometry_idx], element->y[geometry_idx]}; + float end[2] = {element->x[geometry_idx + 1], element->y[geometry_idx + 1]}; for (int k = 0; k < 4; k++) { // Check each edge of the bounding box int next = (k + 1) % 4; if (check_line_intersection(corners[k], corners[next], start, end)) { @@ -1228,15 +1302,14 @@ void compute_agent_metrics(Drive* env, int agent_idx) { if (collided == OFFROAD) break; // Find closest point on the road centerline to the agent - if(entity->type == ROAD_LANE) { - int entity_idx = entity_list[i].entity_idx; - int geometry_idx = entity_list[i].geometry_idx; + if(is_drivable_road_lane(element->type)) { + int elem_idx = entity_list[i].entity_idx; - float start[2] = {entity->traj_x[geometry_idx], entity->traj_y[geometry_idx]}; - float end[2] = {entity->traj_x[geometry_idx + 1], entity->traj_y[geometry_idx + 1]}; + float start[2] = {element->x[geometry_idx], element->y[geometry_idx]}; + float end[2] = {element->x[geometry_idx + 1], element->y[geometry_idx + 1]}; - float dist = point_to_segment_distance_2d(agent->x, agent->y, start[0], start[1], end[0], end[1]); - float heading_diff = fabsf(atan2f(end[1]-start[1], end[0]-start[0]) - agent->heading); + float dist = point_to_segment_distance_2d(agent->sim_x, agent->sim_y, start[0], start[1], end[0], end[1]); + float heading_diff = fabsf(atan2f(end[1]-start[1], end[0]-start[0]) - agent->sim_heading); // Normalize heading difference to [0, pi] if (heading_diff > M_PI) heading_diff = 2.0f * M_PI - heading_diff; @@ -1246,7 +1319,7 @@ void compute_agent_metrics(Drive* env, int agent_idx) { if (dist < min_distance) { min_distance = dist; - closest_lane_entity_idx = entity_idx; + closest_lane_entity_idx = elem_idx; closest_lane_geometry_idx = geometry_idx; } } @@ -1260,7 +1333,7 @@ void compute_agent_metrics(Drive* env, int agent_idx) { } else { agent->current_lane_idx = closest_lane_entity_idx; - int lane_aligned = check_lane_aligned(agent, &env->entities[closest_lane_entity_idx], closest_lane_geometry_idx); + int lane_aligned = check_lane_aligned(agent, &env->road_elements[closest_lane_entity_idx], closest_lane_geometry_idx, env->timestep); agent->metrics_array[LANE_ALIGNED_IDX] = lane_aligned; } @@ -1274,26 +1347,21 @@ void compute_agent_metrics(Drive* env, int agent_idx) { } bool should_control_agent(Drive* env, int agent_idx){ - // Check if we have room for more agents or are already at capacity - if (env->active_agent_count >= env->num_agents) { + if (env->active_agent_count >= env->num_max_agents) { return false; } - Entity* entity = &env->entities[agent_idx]; - - // Shrink agent size for collision checking - entity->width *= 0.7f; // TODO: Move this somewhere else - entity->length *= 0.7f; + Agent* agent = &env->agents[agent_idx]; if (env->control_mode == CONTROL_SDC_ONLY) { - return (agent_idx == env->sdc_track_index); + return (agent_idx == env->sdc_index); } // Special mode: control only agents in prediction track list if (env->control_mode == CONTROL_TRACKS_TO_PREDICT) { for (int j = 0; j < env->num_tracks_to_predict; j++) { - if (env->tracks_to_predict_indices[j] == agent_idx) { + if (env->tracks_to_predict[j] == agent_idx) { return true; } } @@ -1303,20 +1371,20 @@ bool should_control_agent(Drive* env, int agent_idx){ // Standard mode: check type, distance to goal, and expert status bool type_is_controllable = false; if (env->control_mode == CONTROL_VEHICLES) { - type_is_controllable = (entity->type == VEHICLE); + type_is_controllable = (agent->type == VEHICLE); } else { // CONTROL_AGENTS mode - type_is_controllable = (entity->type == VEHICLE || entity->type == PEDESTRIAN || entity->type == CYCLIST); + type_is_controllable = (agent->type == VEHICLE || agent->type == PEDESTRIAN || agent->type == CYCLIST); } - if (!type_is_controllable || entity->mark_as_expert) { + if (!type_is_controllable || agent->mark_as_expert) { return false; } // Check distance to goal in agent's local frame - float cos_heading = cosf(entity->traj_heading[0]); - float sin_heading = sinf(entity->traj_heading[0]); - float goal_dx = entity->goal_position_x - entity->traj_x[0]; - float goal_dy = entity->goal_position_y - entity->traj_y[0]; + float cos_heading = cosf(agent->log_heading[0]); + float sin_heading = sinf(agent->log_heading[0]); + float goal_dx = agent->goal_position_x - agent->log_trajectory_x[0]; + float goal_dy = agent->goal_position_y - agent->log_trajectory_y[0]; // Transform to agent's local frame float local_goal_x = goal_dx * cos_heading + goal_dy * sin_heading; @@ -1327,7 +1395,6 @@ bool should_control_agent(Drive* env, int agent_idx){ } void set_active_agents(Drive* env){ - // Initialize env->active_agent_count = 0; // Policy-controlled agents env->static_agent_count = 0; // Non-moving background agents @@ -1338,17 +1405,17 @@ void set_active_agents(Drive* env){ int static_agent_indices[MAX_AGENTS]; int expert_static_agent_indices[MAX_AGENTS]; - if(env->num_agents == 0){ - env->num_agents = MAX_AGENTS; + if(env->num_max_agents == 0){ + env->num_max_agents = MAX_AGENTS; } // Iterate through entities to find agents to create and/or control - for(int i = 0; i < env->num_objects && env->num_actors < MAX_AGENTS; i++){ + for(int i = 0; i < env->num_total_agents && env->num_actors < MAX_AGENTS; i++){ - Entity* entity = &env->entities[i]; + Agent* agent = &env->agents[i]; // Skip if not valid at initialization - if (entity->traj_valid[env->init_steps] != 1) { + if (agent->log_valid[env->init_steps] != 1) { continue; } @@ -1357,9 +1424,9 @@ void set_active_agents(Drive* env){ if (env->init_mode == INIT_ALL_VALID) { should_create = true; // All valid entities } else if (env->control_mode == CONTROL_VEHICLES) { - should_create = (entity->type == VEHICLE); + should_create = (agent->type == VEHICLE); } else { // Control all agents - should_create = (entity->type == VEHICLE || entity->type == PEDESTRIAN || entity->type == CYCLIST); + should_create = (is_controllable_agent(agent->type)); } if (!should_create) continue; @@ -1374,15 +1441,15 @@ void set_active_agents(Drive* env){ if(is_controlled){ active_agent_indices[env->active_agent_count] = i; env->active_agent_count++; - env->entities[i].active_agent = 1; + env->agents[i].active_agent = 1; } else if (env->init_mode != INIT_ONLY_CONTROLLABLE_AGENTS) { static_agent_indices[env->static_agent_count] = i; env->static_agent_count++; - env->entities[i].active_agent = 0; - if(env->entities[i].mark_as_expert == 1 || env->active_agent_count == env->num_agents) { + env->agents[i].active_agent = 0; + if(env->agents[i].mark_as_expert == 1 || env->active_agent_count == env->num_max_agents) { expert_static_agent_indices[env->expert_static_agent_count] = i; env->expert_static_agent_count++; - env->entities[i].mark_as_expert = 1; + env->agents[i].mark_as_expert = 1; } } } @@ -1396,7 +1463,6 @@ void set_active_agents(Drive* env){ }; for(int i=0;istatic_agent_count;i++){ env->static_agent_indices[i] = static_agent_indices[i]; - } for(int i=0;iexpert_static_agent_count;i++){ env->expert_static_agent_indices[i] = expert_static_agent_indices[i]; @@ -1426,13 +1492,13 @@ void remove_bad_trajectories(Drive* env){ } for(int i = 0; i < env->expert_static_agent_count; i++){ int expert_idx = env->expert_static_agent_indices[i]; - if(env->entities[expert_idx].x == INVALID_POSITION) continue; + if(env->agents[expert_idx].sim_x == INVALID_POSITION) continue; move_expert(env, env->actions, expert_idx); } // check collisions for(int i = 0; i < env->active_agent_count; i++){ int agent_idx = env->active_agent_indices[i]; - env->entities[agent_idx].collision_state = 0; + env->agents[agent_idx].collision_state = 0; int collided_with_index = collision_check(env, agent_idx); if((collided_with_index >= 0) && collided_agents[i] == 0){ collided_agents[i] = 1; @@ -1447,8 +1513,8 @@ void remove_bad_trajectories(Drive* env){ for(int j = 0; j < env->static_agent_count; j++){ int static_agent_idx = env->static_agent_indices[j]; if(static_agent_idx != collided_with_indices[i]) continue; - env->entities[static_agent_idx].traj_x[0] = INVALID_POSITION; - env->entities[static_agent_idx].traj_y[0] = INVALID_POSITION; + env->agents[static_agent_idx].log_trajectory_x[0] = INVALID_POSITION; + env->agents[static_agent_idx].log_trajectory_y[0] = INVALID_POSITION; } } env->timestep = 0; @@ -1457,15 +1523,15 @@ void remove_bad_trajectories(Drive* env){ void init_goal_positions(Drive* env){ for(int x = 0;xactive_agent_count; x++){ int agent_idx = env->active_agent_indices[x]; - env->entities[agent_idx].init_goal_x = env->entities[agent_idx].goal_position_x; - env->entities[agent_idx].init_goal_y = env->entities[agent_idx].goal_position_y; + env->agents[agent_idx].init_goal_x = env->agents[agent_idx].goal_position_x; + env->agents[agent_idx].init_goal_y = env->agents[agent_idx].goal_position_y; } } void init(Drive* env){ env->human_agent_idx = 0; env->timestep = 0; - env->entities = load_map_binary(env->map_name, env); + load_map_binary(env->map_name, env); set_means(env); init_grid_map(env); if (env->goal_behavior==GOAL_GENERATE_NEW) init_topology_graph(env); @@ -1482,10 +1548,12 @@ void init(Drive* env){ } void c_close(Drive* env){ - for(int i = 0; i < env->num_entities; i++){ - free_entity(&env->entities[i]); - } - free(env->entities); + for(int i = 0; i < env->num_total_agents; i++) free_agent(&env->agents[i]); + for(int i = 0; i < env->num_road_elements; i++) free_road_element(&env->road_elements[i]); + for(int i = 0; i < env->num_traffic_elements; i++) free_traffic_element(&env->traffic_elements[i]); + free(env->agents); + free(env->road_elements); + free(env->traffic_elements); free(env->active_agent_indices); free(env->logs); // GridMap cleanup @@ -1505,6 +1573,8 @@ void c_close(Drive* env){ free(env->grid_map); free(env->static_agent_indices); free(env->expert_static_agent_indices); + free(env->objects_of_interest); + free(env->tracks_to_predict); freeTopologyGraph(env->topology_graph); // free(env->map_name); free(env->ini_file); @@ -1546,12 +1616,12 @@ float normalize_value(float value, float min, float max){ } void move_dynamics(Drive* env, int action_idx, int agent_idx){ - Entity* agent = &env->entities[agent_idx]; + Agent* agent = &env->agents[agent_idx]; if (agent->removed) return; if (agent->stopped) { - agent->vx = 0.0f; - agent->vy = 0.0f; + agent->sim_vx = 0.0f; + agent->sim_vy = 0.0f; return; } @@ -1580,24 +1650,25 @@ void move_dynamics(Drive* env, int action_idx, int agent_idx){ } // Current state - float x = agent->x; - float y = agent->y; - float heading = agent->heading; - float vx = agent->vx; - float vy = agent->vy; + float x = agent->sim_x; + float y = agent->sim_y; + float heading = agent->sim_heading; + float vx = agent->sim_vx; + float vy = agent->sim_vy; + float length = agent->sim_length; // Calculate current speed float speed = sqrtf(vx*vx + vy*vy); // Update speed with acceleration - speed = speed + acceleration*env->dt; + speed = speed + acceleration * env->dt; speed = clipSpeed(speed); // Compute yaw rate float beta = tanh(.5*tanf(steering)); // New heading - float yaw_rate = (speed*cosf(beta)*tanf(steering)) / agent->length; + float yaw_rate = (speed*cosf(beta)*tanf(steering)) / length; // New velocity float new_vx = speed*cosf(heading + beta); @@ -1609,13 +1680,11 @@ void move_dynamics(Drive* env, int action_idx, int agent_idx){ heading = heading + yaw_rate*env->dt; // Apply updates to the agent's state - agent->x = x; - agent->y = y; - agent->heading = heading; - agent->heading_x = cosf(heading); - agent->heading_y = sinf(heading); - agent->vx = new_vx; - agent->vy = new_vy; + agent->sim_x = x; + agent->sim_y = y; + agent->sim_heading = heading; + agent->sim_vx = new_vx; + agent->sim_vy = new_vy; } else { // JERK dynamics model // Extract action components @@ -1659,9 +1728,12 @@ void move_dynamics(Drive* env, int action_idx, int agent_idx){ a_lat_new = clip(a_lat_new, -4.0f, 4.0f); } + float heading_x = cosf(agent->sim_heading); + float heading_y = sinf(agent->sim_heading); + // Calculate new velocity - float v_dot_heading = agent->vx * agent->heading_x + agent->vy * agent->heading_y; - float signed_v = copysignf(sqrtf(agent->vx*agent->vx + agent->vy*agent->vy), v_dot_heading); + float v_dot_heading = agent->sim_vx * heading_x + agent->sim_vy * heading_y; + float signed_v = copysignf(sqrtf(agent->sim_vx*agent->sim_vx + agent->sim_vy*agent->sim_vy), v_dot_heading); float v_new = signed_v + 0.5f * (a_long_new + agent->a_long) * env->dt; // Make it easy to stop with 0 vel @@ -1695,21 +1767,19 @@ void move_dynamics(Drive* env, int action_idx, int agent_idx){ dy_local = (1.0f - cosf(theta)) / signed_curvature; } - float dx = dx_local * agent->heading_x - dy_local * agent->heading_y; - float dy = dx_local * agent->heading_y + dy_local * agent->heading_x; + float dx = dx_local * heading_x - dy_local * heading_y; + float dy = dx_local * heading_y + dy_local * heading_x; - // Update everything - agent->x += dx; - agent->y += dy; + // Update agent state + agent->sim_x += dx; + agent->sim_y += dy; agent->jerk_long = (a_long_new - agent->a_long) / env->dt; agent->jerk_lat = (a_lat_new - agent->a_lat) / env->dt; agent->a_long = a_long_new; agent->a_lat = a_lat_new; - agent->heading = normalize_heading(agent->heading + theta); - agent->heading_x = cosf(agent->heading); - agent->heading_y = sinf(agent->heading); - agent->vx = v_new * agent->heading_x; - agent->vy = v_new * agent->heading_y; + agent->sim_heading = normalize_heading(agent->sim_heading + theta); + agent->sim_vx = v_new * cosf(agent->sim_heading); + agent->sim_vy = v_new * sinf(agent->sim_heading); agent->steering_angle = new_steering_angle; } @@ -1719,32 +1789,32 @@ void move_dynamics(Drive* env, int action_idx, int agent_idx){ void c_get_global_agent_state(Drive* env, float* x_out, float* y_out, float* z_out, float* heading_out, int* id_out) { for(int i = 0; i < env->active_agent_count; i++){ int agent_idx = env->active_agent_indices[i]; - Entity* agent = &env->entities[agent_idx]; + Agent* agent = &env->agents[agent_idx]; // For WOSAC, we need the original world coordinates, so we add the world means back - x_out[i] = agent->x + env->world_mean_x; - y_out[i] = agent->y + env->world_mean_y; - z_out[i] = agent->z; - heading_out[i] = agent->heading; - id_out[i] = env->tracks_to_predict_indices[i]; + x_out[i] = agent->sim_x + env->world_mean_x; + y_out[i] = agent->sim_y + env->world_mean_y; + z_out[i] = agent->sim_z; + heading_out[i] = agent->sim_heading; + id_out[i] = env->tracks_to_predict[i]; } } void c_get_global_ground_truth_trajectories(Drive* env, float* x_out, float* y_out, float* z_out, float* heading_out, int* valid_out, int* id_out, int* scenario_id_out) { for(int i = 0; i < env->active_agent_count; i++){ int agent_idx = env->active_agent_indices[i]; - Entity* agent = &env->entities[agent_idx]; - id_out[i] = env->tracks_to_predict_indices[i]; - scenario_id_out[i] = agent->scenario_id; + Agent* agent = &env->agents[agent_idx]; + id_out[i] = env->tracks_to_predict[i]; + scenario_id_out[i] = env->map_index; - for(int t = env->init_steps; t < agent->array_size; t++){ - int out_idx = i * (agent->array_size - env->init_steps) + (t - env->init_steps); + for(int t = env->init_steps; t < agent->trajectory_length; t++){ + int out_idx = i * (agent->trajectory_length - env->init_steps) + (t - env->init_steps); // Add world means back to get original world coordinates - x_out[out_idx] = agent->traj_x[t] + env->world_mean_x; - y_out[out_idx] = agent->traj_y[t] + env->world_mean_y; - z_out[out_idx] = agent->traj_z[t]; - heading_out[out_idx] = agent->traj_heading[t]; - valid_out[out_idx] = agent->traj_valid[t]; + x_out[out_idx] = agent->log_trajectory_x[t] + env->world_mean_x; + y_out[out_idx] = agent->log_trajectory_y[t] + env->world_mean_y; + z_out[out_idx] = agent->log_trajectory_z[t]; + heading_out[out_idx] = agent->log_heading[t]; + valid_out[out_idx] = agent->log_valid[t]; } } } @@ -1756,16 +1826,16 @@ void compute_observations(Drive* env) { float (*observations)[max_obs] = (float(*)[max_obs])env->observations; for(int i = 0; i < env->active_agent_count; i++) { float* obs = &observations[i][0]; - Entity* ego_entity = &env->entities[env->active_agent_indices[i]]; - if(ego_entity->type > 3) break; + Agent* ego_entity = &env->agents[env->active_agent_indices[i]]; - float cos_heading = ego_entity->heading_x; - float sin_heading = ego_entity->heading_y; - float ego_speed = sqrtf(ego_entity->vx*ego_entity->vx + ego_entity->vy*ego_entity->vy); + float ego_heading = ego_entity->sim_heading; + float cos_heading = cosf(ego_heading); + float sin_heading = sinf(ego_heading); + float ego_speed = sqrtf(ego_entity->sim_vx*ego_entity->sim_vx + ego_entity->sim_vy*ego_entity->sim_vy); // Set goal distances - float goal_x = ego_entity->goal_position_x - ego_entity->x; - float goal_y = ego_entity->goal_position_y - ego_entity->y; + float goal_x = ego_entity->goal_position_x - ego_entity->sim_x; + float goal_y = ego_entity->goal_position_y - ego_entity->sim_y; // Rotate to ego vehicle's frame float rel_goal_x = goal_x*cos_heading + goal_y*sin_heading; @@ -1774,8 +1844,8 @@ void compute_observations(Drive* env) { obs[0] = rel_goal_x* 0.005f; obs[1] = rel_goal_y* 0.005f; obs[2] = ego_speed / MAX_SPEED; - obs[3] = ego_entity->width / MAX_VEH_WIDTH; - obs[4] = ego_entity->length / MAX_VEH_LEN; + obs[3] = ego_entity->sim_width / MAX_VEH_WIDTH; + obs[4] = ego_entity->sim_length / MAX_VEH_LEN; obs[5] = (ego_entity->collision_state > 0) ? 1.0f : 0.0f; if (env->dynamics_model == JERK) { @@ -1799,14 +1869,14 @@ void compute_observations(Drive* env) { index = env->static_agent_indices[j - env->active_agent_count]; } if(index == -1) continue; - if(env->entities[index].type > 3) break; + if(env->agents[index].type > 3) break; if(index == env->active_agent_indices[i]) continue; // Skip self, but don't increment obs_idx - Entity* other_entity = &env->entities[index]; + Agent* other_entity = &env->agents[index]; if(ego_entity->respawn_timestep != -1) continue; if(other_entity->respawn_timestep != -1) continue; // Store original relative positions - float dx = other_entity->x - ego_entity->x; - float dy = other_entity->y - ego_entity->y; + float dx = other_entity->sim_x - ego_entity->sim_x; + float dy = other_entity->sim_y - ego_entity->sim_y; float dist = (dx*dx + dy*dy); if(dist > 2500.0f) continue; // Rotate to ego vehicle's frame @@ -1815,20 +1885,23 @@ void compute_observations(Drive* env) { // Store observations with correct indexing obs[obs_idx] = rel_x * 0.02f; obs[obs_idx + 1] = rel_y * 0.02f; - obs[obs_idx + 2] = other_entity->width / MAX_VEH_WIDTH; - obs[obs_idx + 3] = other_entity->length / MAX_VEH_LEN; + obs[obs_idx + 2] = other_entity->sim_width / MAX_VEH_WIDTH; + obs[obs_idx + 3] = other_entity->sim_length / MAX_VEH_LEN; // relative heading - float rel_heading_x = other_entity->heading_x * ego_entity->heading_x + - other_entity->heading_y * ego_entity->heading_y; // cos(a-b) = cos(a)cos(b) + sin(a)sin(b) - float rel_heading_y = other_entity->heading_y * ego_entity->heading_x - - other_entity->heading_x * ego_entity->heading_y; // sin(a-b) = sin(a)cos(b) - cos(a)sin(b) + float other_heading = other_entity->sim_heading; + float other_cos = cosf(other_heading); + float other_sin = sinf(other_heading); + float rel_heading_x = other_cos * cos_heading + + other_sin * sin_heading; // cos(a-b) = cos(a)cos(b) + sin(a)sin(b) + float rel_heading_y = other_sin * cos_heading - + other_cos * sin_heading; // sin(a-b) = sin(a)cos(b) - cos(a)sin(b) obs[obs_idx + 4] = rel_heading_x; obs[obs_idx + 5] = rel_heading_y; // obs[obs_idx + 4] = cosf(rel_heading) / MAX_ORIENTATION_RAD; // obs[obs_idx + 5] = sinf(rel_heading) / MAX_ORIENTATION_RAD; // // relative speed - float other_speed = sqrtf(other_entity->vx*other_entity->vx + other_entity->vy*other_entity->vy); + float other_speed = sqrtf(other_entity->sim_vx*other_entity->sim_vx + other_entity->sim_vy*other_entity->sim_vy); obs[obs_idx + 6] = other_speed / MAX_SPEED; cars_seen++; obs_idx += 7; // Move to next observation slot @@ -1838,36 +1911,40 @@ void compute_observations(Drive* env) { obs_idx += remaining_partner_obs; // map observations GridMapEntity entity_list[MAX_ENTITIES_PER_CELL*25]; - int grid_idx = getGridIndex(env, ego_entity->x, ego_entity->y); + int grid_idx = getGridIndex(env, ego_entity->sim_x, ego_entity->sim_y); int list_size = get_neighbor_cache_entities(env, grid_idx, entity_list, MAX_ROAD_SEGMENT_OBSERVATIONS); for(int k = 0; k < list_size; k++) { + int entity_type = entity_list[k].entity_type; int entity_idx = entity_list[k].entity_idx; int geometry_idx = entity_list[k].geometry_idx; + // Only process road elements in observations + if(entity_type != ENTITY_TYPE_ROAD_ELEMENT) continue; + // Validate entity_idx before accessing - if(entity_idx < 0 || entity_idx >= env->num_entities) { - printf("ERROR: Invalid entity_idx %d (max: %d)\n", entity_idx, env->num_entities-1); + if(entity_idx < 0 || entity_idx >= env->num_road_elements) { + printf("ERROR: Invalid road element idx %d (max: %d)\n", entity_idx, env->num_road_elements-1); continue; } - Entity* entity = &env->entities[entity_idx]; + RoadMapElement* element = &env->road_elements[entity_idx]; // Validate geometry_idx before accessing - if(geometry_idx < 0 || geometry_idx >= entity->array_size) { - printf("ERROR: Invalid geometry_idx %d for entity %d (max: %d)\n", - geometry_idx, entity_idx, entity->array_size-1); + if(geometry_idx < 0 || geometry_idx > element->segment_length) { + printf("ERROR: Invalid geometry_idx %d for road element %d (max: %d)\n", + geometry_idx, entity_idx, element->segment_length-1); continue; } - float start_x = entity->traj_x[geometry_idx]; - float start_y = entity->traj_y[geometry_idx]; - float end_x = entity->traj_x[geometry_idx+1]; - float end_y = entity->traj_y[geometry_idx+1]; + float start_x = element->x[geometry_idx]; + float start_y = element->y[geometry_idx]; + float end_x = element->x[geometry_idx+1]; + float end_y = element->y[geometry_idx+1]; float mid_x = (start_x + end_x) / 2.0f; float mid_y = (start_y + end_y) / 2.0f; - float rel_x = mid_x - ego_entity->x; - float rel_y = mid_y - ego_entity->y; + float rel_x = mid_x - ego_entity->sim_x; + float rel_y = mid_y - ego_entity->sim_y; float x_obs = rel_x*cos_heading + rel_y*sin_heading; float y_obs = -rel_x*sin_heading + rel_y*cos_heading; float length = relative_distance_2d(mid_x, mid_y, end_x, end_y); @@ -1891,7 +1968,7 @@ void compute_observations(Drive* env) { obs[obs_idx + 3] = width / MAX_ROAD_SCALE; obs[obs_idx + 4] = cos_angle; obs[obs_idx + 5] = sin_angle; - obs[obs_idx + 6] = entity->type - 4.0f; + obs[obs_idx + 6] = normalize_road_type(element->type); obs_idx += 7; } int remaining_obs = (MAX_ROAD_SEGMENT_OBSERVATIONS - list_size) * 7; @@ -1900,22 +1977,28 @@ void compute_observations(Drive* env) { } } -static int find_forward_projection_on_lane(Entity* lane, Entity* agent, int* out_segment_idx, float* out_fraction) { +static int find_forward_projection_on_lane(RoadMapElement* lane, Agent* agent, int timestep, int* out_segment_idx, float* out_fraction) { int best_idx = -1; float best_dist_sq = 1e30f; - for (int i = 1; i < lane->array_size; i++) { - float x0 = lane->traj_x[i - 1]; - float y0 = lane->traj_y[i - 1]; - float x1 = lane->traj_x[i]; - float y1 = lane->traj_y[i]; + float agent_x = agent->sim_x; + float agent_y = agent->sim_y; + float agent_heading = agent->sim_heading; + float agent_heading_x = cosf(agent_heading); + float agent_heading_y = sinf(agent_heading); + + for (int i = 1; i < lane->segment_length; i++) { + float x0 = lane->x[i - 1]; + float y0 = lane->y[i - 1]; + float x1 = lane->x[i]; + float y1 = lane->y[i]; float dx = x1 - x0; float dy = y1 - y0; float seg_len_sq = dx * dx + dy * dy; if (seg_len_sq < 1e-6f) continue; - float to_agent_x = agent->x - x0; - float to_agent_y = agent->y - y0; + float to_agent_x = agent_x - x0; + float to_agent_y = agent_y - y0; float t = (to_agent_x * dx + to_agent_y * dy) / seg_len_sq; if (t < 0.0f) t = 0.0f; else if (t > 1.0f) t = 1.0f; @@ -1923,9 +2006,9 @@ static int find_forward_projection_on_lane(Entity* lane, Entity* agent, int* out float proj_x = x0 + t * dx; float proj_y = y0 + t * dy; - float rel_x = proj_x - agent->x; - float rel_y = proj_y - agent->y; - float forward = rel_x * agent->heading_x + rel_y * agent->heading_y; + float rel_x = proj_x - agent_x; + float rel_y = proj_y - agent_y; + float forward = rel_x * agent_heading_x + rel_y * agent_heading_y; if (forward < 0.0f) continue; float dist_sq = rel_x * rel_x + rel_y * rel_y; @@ -1945,7 +2028,8 @@ static int find_forward_projection_on_lane(Entity* lane, Entity* agent, int* out } void compute_new_goal(Drive* env, int agent_idx) { - Entity* agent = &env->entities[agent_idx]; + Agent* agent = &env->agents[agent_idx]; + int t = env->timestep; int current_lane = agent->current_lane_idx; if (current_lane == -1) return; // No current lane @@ -1953,16 +2037,22 @@ void compute_new_goal(Drive* env, int agent_idx) { // Target distance: 40m ahead along the lane topology from agent's current position float target_distance = 40.0f; int current_entity = current_lane; - Entity* lane = &env->entities[current_entity]; + RoadMapElement* lane = &env->road_elements[current_entity]; + + float agent_x = agent->sim_x; + float agent_y = agent->sim_y; + float agent_heading = agent->sim_heading; + float agent_heading_x = cosf(agent_heading); + float agent_heading_y = sinf(agent_heading); int initial_segment_idx = 1; float initial_fraction = 0.0f; - if (!find_forward_projection_on_lane(lane, agent, &initial_segment_idx, &initial_fraction)) { + if (!find_forward_projection_on_lane(lane, agent, t, &initial_segment_idx, &initial_fraction)) { int forward_idx = -1; - for (int i = 0; i < lane->array_size; i++) { - float to_point_x = lane->traj_x[i] - agent->x; - float to_point_y = lane->traj_y[i] - agent->y; - float dot = to_point_x * agent->heading_x + to_point_y * agent->heading_y; + for (int i = 0; i < lane->segment_length; i++) { + float to_point_x = lane->x[i] - agent_x; + float to_point_y = lane->y[i] - agent_y; + float dot = to_point_x * agent_heading_x + to_point_y * agent_heading_y; if (dot > 0.0f) { forward_idx = i; break; @@ -1970,8 +2060,8 @@ void compute_new_goal(Drive* env, int agent_idx) { } if (forward_idx == -1) { - agent->goal_position_x = lane->traj_x[lane->array_size - 1]; - agent->goal_position_y = lane->traj_y[lane->array_size - 1]; + agent->goal_position_x = lane->x[lane->segment_length - 1]; + agent->goal_position_y = lane->y[lane->segment_length - 1]; agent->sampled_new_goal = 0; return; } @@ -1986,18 +2076,18 @@ void compute_new_goal(Drive* env, int agent_idx) { // Traverse the topology graph starting from the vehicle's position forward while (current_entity != -1) { - lane = &env->entities[current_entity]; + lane = &env->road_elements[current_entity]; int start_idx = first_lane ? initial_segment_idx : 1; - // Ensure start_idx is at least 1 to avoid accessing traj_x[i-1] with i=0 + // Ensure start_idx is at least 1 to avoid accessing x[i-1] with i=0 if (start_idx < 1) start_idx = 1; first_lane = 0; - for (int i = start_idx; i < lane->array_size; i++) { - float prev_x = lane->traj_x[i - 1]; - float prev_y = lane->traj_y[i - 1]; - float next_x = lane->traj_x[i]; - float next_y = lane->traj_y[i]; + for (int i = start_idx; i < lane->segment_length; i++) { + float prev_x = lane->x[i - 1]; + float prev_y = lane->y[i - 1]; + float next_x = lane->x[i]; + float next_y = lane->y[i]; float seg_dx = next_x - prev_x; float seg_dy = next_y - prev_y; float segment_length = relative_distance_2d(prev_x, prev_y, next_x, next_y); @@ -2016,8 +2106,8 @@ void compute_new_goal(Drive* env, int agent_idx) { int num_connected = getNextLanes(env->topology_graph, current_entity, connected_lanes, 5); if (num_connected == 0) { - agent->goal_position_x = lane->traj_x[lane->array_size - 1]; - agent->goal_position_y = lane->traj_y[lane->array_size - 1]; + agent->goal_position_x = lane->x[lane->segment_length - 1]; + agent->goal_position_y = lane->y[lane->segment_length - 1]; agent->sampled_new_goal = 0; return; // No further lanes to traverse } @@ -2033,24 +2123,24 @@ void c_reset(Drive* env){ for(int x = 0;xactive_agent_count; x++){ env->logs[x] = (Log){0}; int agent_idx = env->active_agent_indices[x]; - env->entities[agent_idx].respawn_timestep = -1; - env->entities[agent_idx].respawn_count = 0; - env->entities[agent_idx].collided_before_goal = 0; - env->entities[agent_idx].reached_goal_this_episode = 0; - env->entities[agent_idx].metrics_array[COLLISION_IDX] = 0.0f; - env->entities[agent_idx].metrics_array[OFFROAD_IDX] = 0.0f; - env->entities[agent_idx].metrics_array[REACHED_GOAL_IDX] = 0.0f; - env->entities[agent_idx].metrics_array[LANE_ALIGNED_IDX] = 0.0f; - env->entities[agent_idx].metrics_array[AVG_DISPLACEMENT_ERROR_IDX] = 0.0f; - env->entities[agent_idx].cumulative_displacement = 0.0f; - env->entities[agent_idx].displacement_sample_count = 0; - env->entities[agent_idx].stopped = 0; - env->entities[agent_idx].removed = 0; + env->agents[agent_idx].respawn_timestep = -1; + env->agents[agent_idx].respawn_count = 0; + env->agents[agent_idx].collided_before_goal = 0; + env->agents[agent_idx].reached_goal_this_episode = 0; + env->agents[agent_idx].metrics_array[COLLISION_IDX] = 0.0f; + env->agents[agent_idx].metrics_array[OFFROAD_IDX] = 0.0f; + env->agents[agent_idx].metrics_array[REACHED_GOAL_IDX] = 0.0f; + env->agents[agent_idx].metrics_array[LANE_ALIGNED_IDX] = 0.0f; + env->agents[agent_idx].metrics_array[AVG_DISPLACEMENT_ERROR_IDX] = 0.0f; + env->agents[agent_idx].cumulative_displacement = 0.0f; + env->agents[agent_idx].displacement_sample_count = 0; + env->agents[agent_idx].stopped = 0; + env->agents[agent_idx].removed = 0; if (env->goal_behavior==GOAL_GENERATE_NEW) { - env->entities[agent_idx].goal_position_x = env->entities[agent_idx].init_goal_x; - env->entities[agent_idx].goal_position_y = env->entities[agent_idx].init_goal_y; - env->entities[agent_idx].sampled_new_goal = 0; + env->agents[agent_idx].goal_position_x = env->agents[agent_idx].init_goal_x; + env->agents[agent_idx].goal_position_y = env->agents[agent_idx].init_goal_y; + env->agents[agent_idx].sampled_new_goal = 0; } compute_agent_metrics(env, agent_idx); @@ -2059,28 +2149,27 @@ void c_reset(Drive* env){ } void respawn_agent(Drive* env, int agent_idx){ - env->entities[agent_idx].x = env->entities[agent_idx].traj_x[0]; - env->entities[agent_idx].y = env->entities[agent_idx].traj_y[0]; - env->entities[agent_idx].heading = env->entities[agent_idx].traj_heading[0]; - env->entities[agent_idx].heading_x = cosf(env->entities[agent_idx].heading); - env->entities[agent_idx].heading_y = sinf(env->entities[agent_idx].heading); - env->entities[agent_idx].vx = env->entities[agent_idx].traj_vx[0]; - env->entities[agent_idx].vy = env->entities[agent_idx].traj_vy[0]; - env->entities[agent_idx].metrics_array[COLLISION_IDX] = 0.0f; - env->entities[agent_idx].metrics_array[OFFROAD_IDX] = 0.0f; - env->entities[agent_idx].metrics_array[REACHED_GOAL_IDX] = 0.0f; - env->entities[agent_idx].metrics_array[LANE_ALIGNED_IDX] = 0.0f; - env->entities[agent_idx].metrics_array[AVG_DISPLACEMENT_ERROR_IDX] = 0.0f; - env->entities[agent_idx].cumulative_displacement = 0.0f; - env->entities[agent_idx].displacement_sample_count = 0; - env->entities[agent_idx].respawn_timestep = env->timestep; - env->entities[agent_idx].stopped = 0; - env->entities[agent_idx].removed = 0; - env->entities[agent_idx].a_long = 0.0f; - env->entities[agent_idx].a_lat = 0.0f; - env->entities[agent_idx].jerk_long = 0.0f; - env->entities[agent_idx].jerk_lat = 0.0f; - env->entities[agent_idx].steering_angle = 0.0f; + Agent* agent = &env->agents[agent_idx]; + agent->sim_x = agent->log_trajectory_x[0]; + agent->sim_y = agent->log_trajectory_y[0]; + agent->sim_heading = agent->log_heading[0]; + agent->sim_vx = agent->log_velocity_x[0]; + agent->sim_vy = agent->log_velocity_y[0]; + agent->metrics_array[COLLISION_IDX] = 0.0f; + agent->metrics_array[OFFROAD_IDX] = 0.0f; + agent->metrics_array[REACHED_GOAL_IDX] = 0.0f; + agent->metrics_array[LANE_ALIGNED_IDX] = 0.0f; + agent->metrics_array[AVG_DISPLACEMENT_ERROR_IDX] = 0.0f; + agent->cumulative_displacement = 0.0f; + agent->displacement_sample_count = 0; + agent->respawn_timestep = env->timestep; + agent->stopped = 0; + agent->removed = 0; + agent->a_long = 0.0f; + agent->a_lat = 0.0f; + agent->jerk_long = 0.0f; + agent->jerk_lat = 0.0f; + agent->steering_angle = 0.0f; } void c_step(Drive* env){ @@ -2096,7 +2185,7 @@ void c_step(Drive* env){ // Move static experts for (int i = 0; i < env->expert_static_agent_count; i++) { int expert_idx = env->expert_static_agent_indices[i]; - if(env->entities[expert_idx].x == INVALID_POSITION) continue; + if(env->agents[expert_idx].sim_x == INVALID_POSITION) continue; move_expert(env, env->actions, expert_idx); } // Process actions for all active agents @@ -2104,15 +2193,15 @@ void c_step(Drive* env){ env->logs[i].score = 0.0f; env->logs[i].episode_length += 1; int agent_idx = env->active_agent_indices[i]; - env->entities[agent_idx].collision_state = 0; + env->agents[agent_idx].collision_state = 0; move_dynamics(env, i, agent_idx); } for(int i = 0; i < env->active_agent_count; i++){ int agent_idx = env->active_agent_indices[i]; - env->entities[agent_idx].collision_state = 0; + env->agents[agent_idx].collision_state = 0; compute_agent_metrics(env, agent_idx); - int collision_state = env->entities[agent_idx].collision_state; + int collision_state = env->agents[agent_idx].collision_state; if(collision_state > 0){ if(collision_state == VEHICLE_COLLISION){ @@ -2127,50 +2216,49 @@ void c_step(Drive* env){ env->logs[i].episode_return += env->reward_offroad_collision; env->logs[i].avg_offroad_per_agent += 1.0f; } - if(!env->entities[agent_idx].reached_goal_this_episode){ - env->entities[agent_idx].collided_before_goal = 1; + if(!env->agents[agent_idx].reached_goal_this_episode){ + env->agents[agent_idx].collided_before_goal = 1; } } float distance_to_goal = relative_distance_2d( - env->entities[agent_idx].x, - env->entities[agent_idx].y, - env->entities[agent_idx].goal_position_x, - env->entities[agent_idx].goal_position_y + env->agents[agent_idx].sim_x, + env->agents[agent_idx].sim_y, + env->agents[agent_idx].goal_position_x, + env->agents[agent_idx].goal_position_y ); // Reward agent if it is within X meters of goal if (distance_to_goal < env->goal_radius){ - - if (env->goal_behavior == GOAL_RESPAWN && env->entities[agent_idx].respawn_timestep != -1){ + if (env->goal_behavior == GOAL_RESPAWN && env->agents[agent_idx].respawn_timestep != -1){ env->rewards[i] += env->reward_goal_post_respawn; env->logs[i].episode_return += env->reward_goal_post_respawn; } else if (env->goal_behavior == GOAL_GENERATE_NEW) { env->rewards[i] += env->reward_goal; env->logs[i].episode_return += env->reward_goal; - env->entities[agent_idx].sampled_new_goal = 1; + env->agents[agent_idx].sampled_new_goal = 1; env->logs[i].num_goals_reached += 1; } else { // Zero out the velocity so that the agent stops at the goal env->rewards[i] = env->reward_goal; env->logs[i].episode_return = env->reward_goal; env->logs[i].num_goals_reached = 1; - env->entities[agent_idx].stopped = 1; - env->entities[agent_idx].vx=env->entities[agent_idx].vy = 0.0f; + env->agents[agent_idx].stopped = 1; + env->agents[agent_idx].sim_vx=env->agents[agent_idx].sim_vy = 0.0f; } - env->entities[agent_idx].reached_goal_this_episode = 1; - env->entities[agent_idx].metrics_array[REACHED_GOAL_IDX] = 1.0f; + env->agents[agent_idx].reached_goal_this_episode = 1; + env->agents[agent_idx].metrics_array[REACHED_GOAL_IDX] = 1.0f; } - if(env->entities[agent_idx].sampled_new_goal && env->goal_behavior == GOAL_GENERATE_NEW){ - compute_new_goal(env, agent_idx); + if(env->agents[agent_idx].sampled_new_goal && env->goal_behavior == GOAL_GENERATE_NEW){ + compute_new_goal(env, agent_idx); } - int lane_aligned = env->entities[agent_idx].metrics_array[LANE_ALIGNED_IDX]; + int lane_aligned = env->agents[agent_idx].metrics_array[LANE_ALIGNED_IDX]; env->logs[i].lane_alignment_rate = lane_aligned; // Apply ADE reward - float current_ade = env->entities[agent_idx].metrics_array[AVG_DISPLACEMENT_ERROR_IDX]; + float current_ade = env->agents[agent_idx].metrics_array[AVG_DISPLACEMENT_ERROR_IDX]; if(current_ade > 0.0f && env->reward_ade != 0.0f) { float ade_reward = env->reward_ade * current_ade; env->rewards[i] += ade_reward; @@ -2182,20 +2270,20 @@ void c_step(Drive* env){ if (env->goal_behavior==GOAL_RESPAWN) { for(int i = 0; i < env->active_agent_count; i++){ int agent_idx = env->active_agent_indices[i]; - int reached_goal = env->entities[agent_idx].metrics_array[REACHED_GOAL_IDX]; + int reached_goal = env->agents[agent_idx].metrics_array[REACHED_GOAL_IDX]; if(reached_goal){ respawn_agent(env, agent_idx); - env->entities[agent_idx].respawn_count++; + env->agents[agent_idx].respawn_count++; } } } else if (env->goal_behavior==GOAL_STOP) { for(int i = 0; i < env->active_agent_count; i++){ int agent_idx = env->active_agent_indices[i]; - int reached_goal = env->entities[agent_idx].metrics_array[REACHED_GOAL_IDX]; + int reached_goal = env->agents[agent_idx].metrics_array[REACHED_GOAL_IDX]; if(reached_goal){ - env->entities[agent_idx].stopped = 1; - env->entities[agent_idx].vx=env->entities[agent_idx].vy = 0.0f; + env->agents[agent_idx].stopped = 1; + env->agents[agent_idx].sim_vx = env->agents[agent_idx].sim_vy = 0.0f; } } } @@ -2361,10 +2449,11 @@ void draw_agent_obs(Drive* env, int agent_index, int mode, int obs_only, int las float* agent_obs = &observations[agent_index][0]; // self int active_idx = env->active_agent_indices[agent_index]; - float heading_self_x = env->entities[active_idx].heading_x; - float heading_self_y = env->entities[active_idx].heading_y; - float px = env->entities[active_idx].x; - float py = env->entities[active_idx].y; + float heading_self = env->agents[active_idx].sim_heading; + float heading_self_x = cosf(heading_self); + float heading_self_y = sinf(heading_self); + float px = env->agents[active_idx].sim_x; + float py = env->agents[active_idx].sim_y; // draw goal float goal_x = agent_obs[0] * 200; float goal_y = agent_obs[1] * 200; @@ -2537,9 +2626,9 @@ void draw_agent_obs(Drive* env, int agent_index, int mode, int obs_only, int las Color lineColor = BLUE; // Default color int entity_type = (int)agent_obs[entity_idx + 6]; // Choose color based on entity type - if(entity_type+4 != ROAD_EDGE){ - continue; - } + int unnormalized_type = unnormalize_road_type(entity_type); + if(!is_road_edge(unnormalized_type)) continue; + lineColor = PUFF_CYAN; // For road segments, draw line between start and end points float x_middle = agent_obs[entity_idx] * 50; @@ -2664,264 +2753,263 @@ void draw_scene(Drive* env, Client* client, int mode, int obs_only, int lasers, DrawLine3D((Vector3){env->grid_map->top_left_x, env->grid_map->bottom_right_y, 0}, (Vector3){env->grid_map->top_left_x, env->grid_map->top_left_y, 0}, PUFF_CYAN); DrawLine3D((Vector3){env->grid_map->bottom_right_x, env->grid_map->bottom_right_y, 0}, (Vector3){env->grid_map->bottom_right_x, env->grid_map->top_left_y, 0}, PUFF_CYAN); DrawLine3D((Vector3){env->grid_map->top_left_x, env->grid_map->bottom_right_y, 0}, (Vector3){env->grid_map->bottom_right_x, env->grid_map->bottom_right_y, 0}, PUFF_CYAN); - for(int i = 0; i < env->num_entities; i++) { + + for(int i = 0; i < env->num_total_agents; i++) { + Agent* agent = &env->agents[i]; // Draw objects - if(env->entities[i].type == VEHICLE || env->entities[i].type == PEDESTRIAN || env->entities[i].type == CYCLIST) { - // Check if this vehicle is an active agent - bool is_active_agent = false; - bool is_static_agent = false; - int agent_index = -1; - for(int j = 0; j < env->active_agent_count; j++) { - if(env->active_agent_indices[j] == i) { - is_active_agent = true; - agent_index = j; - break; - } - } - for(int j = 0; j < env->static_agent_count; j++) { - if(env->static_agent_indices[j] == i) { - is_static_agent = true; - break; - } + // Check if this vehicle is an active agent + bool is_active_agent = false; + bool is_static_agent = false; + int agent_index = -1; + for(int j = 0; j < env->active_agent_count; j++) { + if(env->active_agent_indices[j] == i) { + is_active_agent = true; + agent_index = j; + break; } - // HIDE CARS ON RESPAWN - IMPORTANT TO KNOW VISUAL SETTING - if((!is_active_agent && !is_static_agent) || env->entities[i].respawn_timestep != -1){ - continue; + } + for(int j = 0; j < env->static_agent_count; j++) { + if(env->static_agent_indices[j] == i) { + is_static_agent = true; + break; } - Vector3 position; - float heading; - position = (Vector3){ - env->entities[i].x, - env->entities[i].y, - 1 - }; - heading = env->entities[i].heading; - // Create size vector - Vector3 size = { - env->entities[i].length, - env->entities[i].width, - env->entities[i].height - }; + } + // HIDE CARS ON RESPAWN - IMPORTANT TO KNOW VISUAL SETTING + if((!is_active_agent && !is_static_agent) || agent->respawn_timestep != -1){ + continue; + } + Vector3 position; + float heading; + position = (Vector3){ + agent->sim_x, + agent->sim_y, + 1 + }; + heading = agent->sim_heading; + // Create size vector + Vector3 size = { + agent->sim_length, + agent->sim_width, + agent->sim_height + }; - bool is_expert = (!is_active_agent) && (env->entities[i].mark_as_expert == 1); + bool is_expert = (!is_active_agent) && (agent->mark_as_expert == 1); - // Save current transform - if(mode==1){ - float cos_heading = env->entities[i].heading_x; - float sin_heading = env->entities[i].heading_y; + // Save current transform + if(mode==1){ + float cos_heading = cosf(heading); + float sin_heading = sinf(heading); - // Calculate half dimensions - float half_len = env->entities[i].length * 0.5f; - float half_width = env->entities[i].width * 0.5f; + // Calculate half dimensions + float half_len = agent->sim_length * 0.5f; + float half_width = agent->sim_width * 0.5f; - // Calculate the four corners of the collision box - Vector3 corners[4] = { - (Vector3){ - position.x + (half_len * cos_heading - half_width * sin_heading), - position.y + (half_len * sin_heading + half_width * cos_heading), - position.z - }, - - - (Vector3){ - position.x + (half_len * cos_heading + half_width * sin_heading), - position.y + (half_len * sin_heading - half_width * cos_heading), - position.z - }, - (Vector3){ - position.x + (-half_len * cos_heading + half_width * sin_heading), - position.y + (-half_len * sin_heading - half_width * cos_heading), - position.z - }, - (Vector3){ - position.x + (-half_len * cos_heading - half_width * sin_heading), - position.y + (-half_len * sin_heading + half_width * cos_heading), - position.z - }, + // Calculate the four corners of the collision box + Vector3 corners[4] = { + (Vector3){ + position.x + (half_len * cos_heading - half_width * sin_heading), + position.y + (half_len * sin_heading + half_width * cos_heading), + position.z + }, - }; + (Vector3){ + position.x + (half_len * cos_heading + half_width * sin_heading), + position.y + (half_len * sin_heading - half_width * cos_heading), + position.z + }, + (Vector3){ + position.x + (-half_len * cos_heading + half_width * sin_heading), + position.y + (-half_len * sin_heading - half_width * cos_heading), + position.z + }, + (Vector3){ + position.x + (-half_len * cos_heading - half_width * sin_heading), + position.y + (-half_len * sin_heading + half_width * cos_heading), + position.z + }, - if(agent_index == env->human_agent_idx && !env->entities[agent_index].metrics_array[REACHED_GOAL_IDX]) { - draw_agent_obs(env, agent_index, mode, obs_only, lasers); - } - if((obs_only || IsKeyDown(KEY_LEFT_CONTROL)) && agent_index != env->human_agent_idx){ - continue; - } - // --- Draw the car --- + }; - Vector3 carPos = { position.x, position.y, position.z }; - Color car_color = GRAY; // default for static - if (is_expert) car_color = GOLD; // expert replay - if (is_active_agent) car_color = BLUE; // policy-controlled - if (is_active_agent && env->entities[i].collision_state > 0) car_color = RED; - rlSetLineWidth(3.0f); - for (int j = 0; j < 4; j++) { - DrawLine3D(corners[j], corners[(j+1)%4], car_color); - } - // --- Draw a heading arrow pointing forward --- - Vector3 arrowStart = position; - Vector3 arrowEnd = { - position.x + cos_heading * half_len * 1.5f, // extend arrow beyond car - position.y + sin_heading * half_len * 1.5f, - position.z - }; + if(agent_index == env->human_agent_idx && !agent->metrics_array[REACHED_GOAL_IDX]) { + draw_agent_obs(env, agent_index, mode, obs_only, lasers); + } + if((obs_only || IsKeyDown(KEY_LEFT_CONTROL)) && agent_index != env->human_agent_idx){ + continue; + } - DrawLine3D(arrowStart, arrowEnd, car_color); - DrawSphere(arrowEnd, 0.2f, car_color); // arrow tip + // --- Draw the car --- + Vector3 carPos = { position.x, position.y, position.z }; + Color car_color = GRAY; // default for static + if (is_expert) car_color = GOLD; // expert replay + if (is_active_agent) car_color = BLUE; // policy-controlled + if (is_active_agent && agent->collision_state > 0) car_color = RED; + rlSetLineWidth(3.0f); + for (int j = 0; j < 4; j++) { + DrawLine3D(corners[j], corners[(j+1)%4], car_color); } - else { - rlPushMatrix(); - // Translate to position, rotate around Y axis, then draw - rlTranslatef(position.x, position.y, position.z); - rlRotatef(heading*RAD2DEG, 0.0f, 0.0f, 1.0f); // Convert radians to degrees - // Determine color based on status - Color object_color = PUFF_BACKGROUND2; // fill color unused for model tint - Color outline_color = PUFF_CYAN; // not used for model tint - Model car_model = client->cars[5]; - if(is_active_agent){ - car_model = client->cars[client->car_assignments[i %64]]; - } - if(agent_index == env->human_agent_idx){ - object_color = PUFF_CYAN; - outline_color = PUFF_WHITE; - } - if(is_active_agent && env->entities[i].collision_state > 0) { - car_model = client->cars[0]; // Collided agent - } - // Draw obs for human selected agent - if(agent_index == env->human_agent_idx && !env->entities[agent_index].metrics_array[REACHED_GOAL_IDX]) { - draw_agent_obs(env, agent_index, mode, obs_only, lasers); - } - // Draw cube for cars static and active - // Calculate scale factors based on desired size and model dimensions - - BoundingBox bounds = GetModelBoundingBox(car_model); - Vector3 model_size = { - bounds.max.x - bounds.min.x, - bounds.max.y - bounds.min.y, - bounds.max.z - bounds.min.z - }; - Vector3 scale = { - size.x / model_size.x, - size.y / model_size.y, - size.z / model_size.z - }; - if((obs_only || IsKeyDown(KEY_LEFT_CONTROL)) && agent_index != env->human_agent_idx){ - rlPopMatrix(); - continue; - } + // --- Draw a heading arrow pointing forward --- + Vector3 arrowStart = position; + Vector3 arrowEnd = { + position.x + cos_heading * half_len * 1.5f, // extend arrow beyond car + position.y + sin_heading * half_len * 1.5f, + position.z + }; - DrawModelEx(car_model, (Vector3){0, 0, 0}, (Vector3){1, 0, 0}, 90.0f, scale, WHITE); - { - float cos_heading = env->entities[i].heading_x; - float sin_heading = env->entities[i].heading_y; - float half_len = env->entities[i].length * 0.5f; - float half_width = env->entities[i].width * 0.5f; - Vector3 corners[4] = { - (Vector3){ 0 + ( half_len * cos_heading - half_width * sin_heading), 0 + ( half_len * sin_heading + half_width * cos_heading), 0 }, - (Vector3){ 0 + ( half_len * cos_heading + half_width * sin_heading), 0 + ( half_len * sin_heading - half_width * cos_heading), 0 }, - (Vector3){ 0 + (-half_len * cos_heading + half_width * sin_heading), 0 + (-half_len * sin_heading - half_width * cos_heading), 0 }, - (Vector3){ 0 + (-half_len * cos_heading - half_width * sin_heading), 0 + (-half_len * sin_heading + half_width * cos_heading), 0 }, - }; - Color wire_color = GRAY; // static - if (!is_active_agent && env->entities[i].mark_as_expert == 1) wire_color = GOLD; // expert replay - if (is_active_agent) wire_color = BLUE; // policy - if (is_active_agent && env->entities[i].collision_state > 0) wire_color = RED; - rlSetLineWidth(2.0f); - for (int j = 0; j < 4; j++) { - DrawLine3D(corners[j], corners[(j+1)%4], wire_color); - } - } + DrawLine3D(arrowStart, arrowEnd, car_color); + DrawSphere(arrowEnd, 0.2f, car_color); // arrow tip + + } + else { + rlPushMatrix(); + // Translate to position, rotate around Y axis, then draw + rlTranslatef(position.x, position.y, position.z); + rlRotatef(heading*RAD2DEG, 0.0f, 0.0f, 1.0f); // Convert radians to degrees + // Determine color based on status + Color object_color = PUFF_BACKGROUND2; // fill color unused for model tint + Color outline_color = PUFF_CYAN; // not used for model tint + Model car_model = client->cars[5]; + if(is_active_agent){ + car_model = client->cars[client->car_assignments[i %64]]; + } + if(agent_index == env->human_agent_idx){ + object_color = PUFF_CYAN; + outline_color = PUFF_WHITE; + } + if(is_active_agent && agent->collision_state > 0) { + car_model = client->cars[0]; // Collided agent + } + // Draw obs for human selected agent + if(agent_index == env->human_agent_idx && !agent->metrics_array[REACHED_GOAL_IDX]) { + draw_agent_obs(env, agent_index, mode, obs_only, lasers); + } + // Draw cube for cars static and active + // Calculate scale factors based on desired size and model dimensions + + BoundingBox bounds = GetModelBoundingBox(car_model); + Vector3 model_size = { + bounds.max.x - bounds.min.x, + bounds.max.y - bounds.min.y, + bounds.max.z - bounds.min.z + }; + Vector3 scale = { + size.x / model_size.x, + size.y / model_size.y, + size.z / model_size.z + }; + if((obs_only || IsKeyDown(KEY_LEFT_CONTROL)) && agent_index != env->human_agent_idx){ rlPopMatrix(); + continue; } - // FPV Camera Control - if(IsKeyDown(KEY_SPACE) && env->human_agent_idx== agent_index){ - if(env->entities[agent_index].metrics_array[REACHED_GOAL_IDX]){ - env->human_agent_idx = rand() % env->active_agent_count; - } - Vector3 camera_position = (Vector3){ - position.x - (25.0f * cosf(heading)), - position.y - (25.0f * sinf(heading)), - position.z + 15 - }; - - Vector3 camera_target = (Vector3){ - position.x + 40.0f * cosf(heading), - position.y + 40.0f * sinf(heading), - position.z - 5.0f + DrawModelEx(car_model, (Vector3){0, 0, 0}, (Vector3){1, 0, 0}, 90.0f, scale, WHITE); + { + float cos_heading = cosf(heading); + float sin_heading = sinf(heading); + float half_len = agent->sim_length * 0.5f; + float half_width = agent->sim_width * 0.5f; + Vector3 corners[4] = { + (Vector3){ 0 + ( half_len * cos_heading - half_width * sin_heading), 0 + ( half_len * sin_heading + half_width * cos_heading), 0 }, + (Vector3){ 0 + ( half_len * cos_heading + half_width * sin_heading), 0 + ( half_len * sin_heading - half_width * cos_heading), 0 }, + (Vector3){ 0 + (-half_len * cos_heading + half_width * sin_heading), 0 + (-half_len * sin_heading - half_width * cos_heading), 0 }, + (Vector3){ 0 + (-half_len * cos_heading - half_width * sin_heading), 0 + (-half_len * sin_heading + half_width * cos_heading), 0 }, }; - client->camera.position = camera_position; - client->camera.target = camera_target; - client->camera.up = (Vector3){0, 0, 1}; - } - if(IsKeyReleased(KEY_SPACE)){ - client->camera.position = client->default_camera_position; - client->camera.target = client->default_camera_target; - client->camera.up = (Vector3){0, 0, 1}; + Color wire_color = GRAY; // static + if (!is_active_agent && agent->mark_as_expert == 1) wire_color = GOLD; // expert replay + if (is_active_agent) wire_color = BLUE; // policy + if (is_active_agent && agent->collision_state > 0) wire_color = RED; + rlSetLineWidth(2.0f); + for (int j = 0; j < 4; j++) { + DrawLine3D(corners[j], corners[(j+1)%4], wire_color); + } } - // Draw goal position for active agents + rlPopMatrix(); + } - if(!is_active_agent || env->entities[i].valid == 0) { - continue; - } - if(!IsKeyDown(KEY_LEFT_CONTROL) && obs_only==0){ - DrawSphere((Vector3){ - env->entities[i].goal_position_x, - env->entities[i].goal_position_y, - 1 - }, 0.5f, DARKGREEN); - - DrawCircle3D((Vector3){ - env->entities[i].goal_position_x, - env->entities[i].goal_position_y, - 0.1f - }, env->goal_radius, (Vector3){0, 0, 1}, 90.0f, Fade(LIGHTGREEN, 0.3f)); + // FPV Camera Control + if(IsKeyDown(KEY_SPACE) && env->human_agent_idx== agent_index){ + if(agent->metrics_array[REACHED_GOAL_IDX]){ + env->human_agent_idx = rand() % env->active_agent_count; } + Vector3 camera_position = (Vector3){ + position.x - (25.0f * cosf(heading)), + position.y - (25.0f * sinf(heading)), + position.z + 15 + }; + + Vector3 camera_target = (Vector3){ + position.x + 40.0f * cosf(heading), + position.y + 40.0f * sinf(heading), + position.z - 5.0f + }; + client->camera.position = camera_position; + client->camera.target = camera_target; + client->camera.up = (Vector3){0, 0, 1}; } - // Draw road elements - if(env->entities[i].type <=3 && env->entities[i].type >= 7){ + if(IsKeyReleased(KEY_SPACE)){ + client->camera.position = client->default_camera_position; + client->camera.target = client->default_camera_target; + client->camera.up = (Vector3){0, 0, 1}; + } + // Draw goal position for active agents + + if(!is_active_agent || agent->sim_valid == 0) { continue; } - for(int j = 0; j < env->entities[i].array_size - 1; j++) { + if(!IsKeyDown(KEY_LEFT_CONTROL) && obs_only==0){ + DrawSphere((Vector3){ + agent->goal_position_x, + agent->goal_position_y, + 1 + }, 0.5f, DARKGREEN); + + DrawCircle3D((Vector3){ + agent->goal_position_x, + agent->goal_position_y, + 0.1f + }, env->goal_radius, (Vector3){0, 0, 1}, 90.0f, Fade(LIGHTGREEN, 0.3f)); + } + + } + for (int i = 0; i < env->num_road_elements; i++) { + RoadMapElement* element = &env->road_elements[i]; + + for(int j = 0; j < element->segment_length - 1; j++) { Vector3 start = { - env->entities[i].traj_x[j], - env->entities[i].traj_y[j], + element->x[j], + element->y[j], 1 }; Vector3 end = { - env->entities[i].traj_x[j + 1], - env->entities[i].traj_y[j + 1], + element->x[j + 1], + element->y[j + 1], 1 }; Color lineColor = GRAY; - if (env->entities[i].type == ROAD_LANE) lineColor = GRAY; - else if (env->entities[i].type == ROAD_LINE) lineColor = BLUE; - else if (env->entities[i].type == ROAD_EDGE) lineColor = WHITE; - else if (env->entities[i].type == DRIVEWAY) lineColor = RED; - if(env->entities[i].type != ROAD_EDGE){ - continue; - } + + if (is_road_lane(element->type)) lineColor = GRAY; + else if (is_road_line(element->type)) lineColor = BLUE; + else if (is_road_edge(element->type)) lineColor = WHITE; + else if (element->type == DRIVEWAY) lineColor = RED; if(!IsKeyDown(KEY_LEFT_CONTROL) && obs_only==0){ draw_road_edge(env, start.x, start.y, end.x, end.y); } } } if(show_grid) { - // Draw grid cells using the stored bounds - float grid_start_x = env->grid_map->top_left_x; - float grid_start_y = env->grid_map->bottom_right_y; - for(int i = 0; i < env->grid_map->grid_cols; i++) { - for(int j = 0; j < env->grid_map->grid_rows; j++) { - float x = grid_start_x + i*GRID_CELL_SIZE; - float y = grid_start_y + j*GRID_CELL_SIZE; - DrawCubeWires( - (Vector3){x + GRID_CELL_SIZE/2, y + GRID_CELL_SIZE/2, 1}, - GRID_CELL_SIZE, GRID_CELL_SIZE, 0.1f, PUFF_BACKGROUND2); - } + // Draw grid cells using the stored bounds + float grid_start_x = env->grid_map->top_left_x; + float grid_start_y = env->grid_map->bottom_right_y; + for(int i = 0; i < env->grid_map->grid_cols; i++) { + for(int j = 0; j < env->grid_map->grid_rows; j++) { + float x = grid_start_x + i*GRID_CELL_SIZE; + float y = grid_start_y + j*GRID_CELL_SIZE; + DrawCubeWires( + (Vector3){x + GRID_CELL_SIZE/2, y + GRID_CELL_SIZE/2, 1}, + GRID_CELL_SIZE, GRID_CELL_SIZE, 0.1f, PUFF_BACKGROUND2); + } } } @@ -2935,14 +3023,14 @@ void draw_scene(Drive* env, Client* client, int mode, int obs_only, int lasers, for (int i = 0; i < env->active_agent_count; i++) { // Ignore respawned agents - if (env->entities[i].respawn_timestep != -1) { + if (env->agents[i].respawn_timestep != -1) { continue; } int agent_idx = env->active_agent_indices[i]; - int womd_track_idx = env->tracks_to_predict_indices[i]; + int womd_track_idx = env->tracks_to_predict[i]; - float raw_x = -env->entities[agent_idx].x * pixels_per_world_unit; - float raw_y = env->entities[agent_idx].y * pixels_per_world_unit; + float raw_x = -env->agents[agent_idx].sim_x * pixels_per_world_unit; + float raw_y = env->agents[agent_idx].sim_y * pixels_per_world_unit; int screen_x = (int)raw_x + client->width/2 + 20; int screen_y = (int)raw_y + client->height/2 - 25; @@ -2975,12 +3063,13 @@ void saveTopDownImage(Drive* env, Client* client, const char *filename, RenderTe // Draw log trajectories FIRST (in background at lower Z-level) if(log_trajectories){ - for(int i=0; iactive_agent_count;i++){ + for(int i = 0; i < env->num_total_agents; i++) { + Agent* agent = &env->agents[i]; int idx = env->active_agent_indices[i]; - for(int j=0; jentities[idx].array_size;j++){ - float x = env->entities[idx].traj_x[j]; - float y = env->entities[idx].traj_y[j]; - float valid = env->entities[idx].traj_valid[j]; + for(int j=0; jtrajectory_length; j++){ + float x = agent->log_trajectory_x[j]; + float y = agent->log_trajectory_y[j]; + float valid = agent->log_valid[j]; if(!valid) continue; DrawSphere((Vector3){x,y,0.5f}, 0.3f, Fade(LIGHTGREEN, 0.6f)); } @@ -3010,18 +3099,18 @@ void saveTopDownImage(Drive* env, Client* client, const char *filename, RenderTe void saveAgentViewImage(Drive* env, Client* client, const char *filename, RenderTexture2D target, int map_height, int obs_only, int lasers, int show_grid) { // Agent perspective camera following the human agent int agent_idx = env->active_agent_indices[env->human_agent_idx]; - Entity* agent = &env->entities[agent_idx]; + Agent* agent = &env->agents[agent_idx]; Camera3D camera = {0}; // Position camera behind and above the agent camera.position = (Vector3){ - agent->x - (25.0f * cosf(agent->heading)), - agent->y - (25.0f * sinf(agent->heading)), + agent->sim_x - (25.0f * cosf(agent->sim_heading)), + agent->sim_y - (25.0f * sinf(agent->sim_heading)), 15.0f }; camera.target = (Vector3){ - agent->x + 40.0f * cosf(agent->heading), - agent->y + 40.0f * sinf(agent->heading), + agent->sim_x + 40.0f * cosf(agent->sim_heading), + agent->sim_y + 40.0f * sinf(agent->sim_heading), 1.0f }; camera.up = (Vector3){ 0.0f, 0.0f, 1.0f }; diff --git a/pufferlib/ocean/drive/visualize.c b/pufferlib/ocean/drive/visualize.c index c7eb29beb..9642e2cbd 100644 --- a/pufferlib/ocean/drive/visualize.c +++ b/pufferlib/ocean/drive/visualize.c @@ -104,10 +104,10 @@ void renderTopDownView(Drive* env, Client* client, int map_height, int obs, int Vector3 prev_point = {0}; bool has_prev = false; - for(int j = 0; j < env->entities[idx].array_size; j++){ - float x = env->entities[idx].traj_x[j]; - float y = env->entities[idx].traj_y[j]; - float valid = env->entities[idx].traj_valid[j]; + for(int j = 0; j < env->agents[idx].trajectory_length; j++){ + float x = env->agents[idx].log_trajectory_x[j]; + float y = env->agents[idx].log_trajectory_y[j]; + float valid = env->agents[idx].log_valid[j]; if(!valid) { has_prev = false; @@ -142,20 +142,20 @@ void renderTopDownView(Drive* env, Client* client, int map_height, int obs, int void renderAgentView(Drive* env, Client* client, int map_height, int obs_only, int lasers, int show_grid) { // Agent perspective camera following the selected agent int agent_idx = env->active_agent_indices[env->human_agent_idx]; - Entity* agent = &env->entities[agent_idx]; + Agent* agent = &env->agents[agent_idx]; BeginDrawing(); Camera3D camera = {0}; // Position camera behind and above the agent camera.position = (Vector3){ - agent->x - (25.0f * cosf(agent->heading)), - agent->y - (25.0f * sinf(agent->heading)), + agent->sim_x - (25.0f * cosf(agent->sim_heading)), + agent->sim_y - (25.0f * sinf(agent->sim_heading)), 15.0f }; camera.target = (Vector3){ - agent->x + 40.0f * cosf(agent->heading), - agent->y + 40.0f * sinf(agent->heading), + agent->sim_x + 40.0f * cosf(agent->sim_heading), + agent->sim_y + 40.0f * sinf(agent->sim_heading), 1.0f }; camera.up = (Vector3){ 0.0f, 0.0f, 1.0f }; diff --git a/pufferlib/ocean/torch.py b/pufferlib/ocean/torch.py index 6cbf9b152..449d396b4 100644 --- a/pufferlib/ocean/torch.py +++ b/pufferlib/ocean/torch.py @@ -26,7 +26,7 @@ def __init__(self, env, input_size=128, hidden_size=128, **kwargs): # nn.ReLU(), pufferlib.pytorch.layer_init(nn.Linear(input_size, input_size)), ) - max_road_objects = 13 + max_road_objects = 10 self.road_encoder = nn.Sequential( pufferlib.pytorch.layer_init(nn.Linear(max_road_objects, input_size)), nn.LayerNorm(input_size), @@ -75,7 +75,7 @@ def encode_observations(self, observations, state=None): road_objects = road_obs.view(-1, 200, 7) road_continuous = road_objects[:, :, :6] # First 6 features road_categorical = road_objects[:, :, 6] - road_onehot = F.one_hot(road_categorical.long(), num_classes=7) # Shape: [batch, 200, 7] + road_onehot = F.one_hot((road_categorical + 1).long(), num_classes=4) # Shape: [batch, 200, 4] road_objects = torch.cat([road_continuous, road_onehot], dim=2) ego_features = self.ego_encoder(ego_obs) partner_features, _ = self.partner_encoder(partner_objects).max(dim=1)