From 50514dd529b059d34bdc3138f18860a750e40496 Mon Sep 17 00:00:00 2001 From: Maximilian Naumann Date: Mon, 7 Jul 2025 20:07:46 +0200 Subject: [PATCH] Revert "[FEATURE] cherry-pick TL states commit (#485)" This reverts commit 5279c9ec44ed7429421cb58bab67180f89431d97. --- gpudrive/datatypes/observation.py | 135 ------------------------------ gpudrive/visualize/core.py | 100 ---------------------- src/bindings.cpp | 3 +- src/consts.hpp | 1 - src/init.hpp | 3 - src/level_gen.cpp | 12 --- src/mgr.cpp | 15 +--- src/mgr.hpp | 1 - src/sim.cpp | 3 +- src/sim.hpp | 1 - src/types.hpp | 54 ++---------- 11 files changed, 12 insertions(+), 316 deletions(-) diff --git a/gpudrive/datatypes/observation.py b/gpudrive/datatypes/observation.py index a799ea0a9..6aaaab864 100644 --- a/gpudrive/datatypes/observation.py +++ b/gpudrive/datatypes/observation.py @@ -354,138 +354,3 @@ def one_hot_encode_bev_map(self): self.bev_segmentation_map.long(), num_classes=constants.NUM_MADRONA_ENTITY_TYPES, # From size of Madrona EntityType ) - - -@dataclass -class TrafficLightObs: - """ - A dataclass that represents traffic light information in the scenario. - This data struct contains the time series of traffic light information. - It contains the state (unknown: 0, stop: 1, caution: 2, go: 3), position, and lane_ids. - Initialized from tl_states_tensor (from Manager.trafficLightTensor()). - For details, see `TrafficLightState` in src/types.hpp. - Shape: (num_worlds, max_traffic_lights, num_timesteps * features). - Attributes: - state: The state of each traffic light (0=unknown, 1=stop, 2=caution, 3=go) - pos_x: X-coordinate of the traffic light - pos_y: Y-coordinate of the traffic light - pos_z: Z-coordinate of the traffic light - time_index: Time index of the traffic light state - lane_id: Lane ID associated with the traffic light - valid_mask: Boolean mask indicating valid traffic lights - """ - - state: torch.Tensor - pos_x: torch.Tensor - pos_y: torch.Tensor - pos_z: torch.Tensor - time_index: torch.Tensor - lane_id: torch.Tensor - valid_mask: torch.Tensor - - def __init__( - self, - tl_states_tensor: torch.Tensor, - - ): - """Initializes the traffic light observation from a tensor.""" - traj_length = constants.LOG_TRAJECTORY_LENGTH - - # Calculate indices based on C++ struct layout: - # laneId (1) + state[traj_length] + x[traj_length] + y[traj_length] + z[traj_length] + timeIndex[traj_length] + numStates (1) - - lane_id_end_idx = 1 - state_end_idx = lane_id_end_idx + traj_length - pos_x_end_idx = state_end_idx + traj_length - pos_y_end_idx = pos_x_end_idx + traj_length - pos_z_end_idx = pos_y_end_idx + traj_length - time_index_end_idx = pos_z_end_idx + traj_length - - # Extract fields according to C++ struct layout - # See `TrafficLightState` in src/types.hpp for details - self.lane_id = tl_states_tensor[:, :, 0] # Single lane ID value - self.state = tl_states_tensor[ - :, :, lane_id_end_idx:state_end_idx - ] # state[traj_length] - self.pos_x = tl_states_tensor[ - :, :, state_end_idx:pos_x_end_idx - ].float() # x[traj_length] - self.pos_y = tl_states_tensor[ - :, :, pos_x_end_idx:pos_y_end_idx - ].float() # y[traj_length] - self.pos_z = tl_states_tensor[ - :, :, pos_y_end_idx:pos_z_end_idx - ].float() # z[traj_length] - self.time_index = tl_states_tensor[ - :, :, pos_z_end_idx:time_index_end_idx - ] # timeIndex[traj_length] - self.num_states = tl_states_tensor[ - :, :, time_index_end_idx - ] # Single numStates value - - # Create a valid mask based on numStates - # Traffic lights are valid if they have numStates > 0 - self.valid_mask = self.num_states > 0 - - @classmethod - def from_tensor( - cls, - tl_states_tensor: madrona_gpudrive.madrona.Tensor, - backend="torch", - device="cuda", - ): - """Creates a TrafficLightObs from a tensor. - Args: - tl_states_tensor: The traffic light state tensor from the simulation - backend: Which backend to use ("torch" or "jax") - device: The device to place tensors on - Returns: - A TrafficLightObs instance - """ - if backend == "torch": - tensor = tl_states_tensor.to_torch().clone().to(device) - obj = cls(tensor) - return obj - elif backend == "jax": - raise NotImplementedError("JAX backend not implemented yet.") - - def normalize(self): - """Normalizes the traffic light observation coordinates.""" - - # Normalize position coordinates - self.pos_x = normalize_min_max( - tensor=self.pos_x, - min_val=constants.MIN_REL_COORD, - max_val=constants.MAX_REL_COORD, - ) - self.pos_y = normalize_min_max( - tensor=self.pos_y, - min_val=constants.MIN_REL_COORD, - max_val=constants.MAX_REL_COORD, - ) - self.pos_z = normalize_min_max( - tensor=self.pos_z, - min_val=constants.MIN_Z_COORD, - max_val=constants.MAX_Z_COORD, - ) - def one_hot_encode_states(self): - """One-hot encodes the traffic light states. - Converts the state values to one-hot encoded vectors with 4 classes: - 0: Unknown - 1: Stop - 2: Caution - 3: Go - """ - # Make sure values are in range 0-3 - state_clamped = torch.clamp(self.state, 0, 3) - # One-hot encode - self.state_onehot = torch.nn.functional.one_hot( - state_clamped, num_classes=4 - ) * self.valid_mask.unsqueeze(-1) - - return self.state_onehot - - @property - def shape(self) -> tuple[int, ...]: - """Shape: (num_worlds, max_traffic_lights, num_timesteps).""" - return self.state.shape \ No newline at end of file diff --git a/gpudrive/visualize/core.py b/gpudrive/visualize/core.py index 4cecf6804..f961f82a5 100644 --- a/gpudrive/visualize/core.py +++ b/gpudrive/visualize/core.py @@ -22,7 +22,6 @@ LocalEgoState, GlobalEgoState, PartnerObs, - TrafficLightObs, ) from gpudrive.datatypes.trajectory import LogTrajectory from gpudrive.datatypes.control import ResponseType @@ -93,12 +92,6 @@ def initialize_static_scenario_data(self, controlled_agent_mask): backend=self.backend, ) - self.tl_obs = TrafficLightObs.from_tensor( - tl_states_tensor=self.sim_object.tl_state_tensor(), - backend=self.backend, - device=self.device, - ) - def plot_simulator_state( self, env_indices: List[int], @@ -321,14 +314,6 @@ def plot_simulator_state( world_based_policy_mask=world_based_policy_mask, ) - self._plot_traffic_lights( - ax=ax, - env_idx=env_idx, - tl_obs=self.tl_obs, - time_step=time_step if time_step is not None else 0, - marker_size_scale=marker_scale, - ) - if agent_positions is not None: # First calculate the maximum valid trajectory length across all agents for this env_idx max_valid_length = 0 @@ -564,91 +549,6 @@ def _plot_log_replay_trajectory( zorder=0, ) - def _plot_traffic_lights( - self, - ax: matplotlib.axes.Axes, - env_idx: int, - tl_obs: "TrafficLightObs", - time_step: int = 0, - marker_size_scale: float = 1.0, - ): - """Plot traffic light states as colored dots. - Args: - ax: Matplotlib axis to plot on - env_idx: Environment index - tl_obs: Traffic light observation object - time_step: Current time step - marker_size_scale: Scale factor for marker size - """ - - # Traffic light state colors - TL_STATE_COLORS = { - 0: "#C5C5C5", # Unknown - gray - 1: "r", # Stop - red - 2: "tab:orange", # Caution - orange - 3: "g", # Go - green - } - - # Get valid traffic lights for this environment - valid_mask = tl_obs.valid_mask[env_idx, :] - if not valid_mask.any(): - return - - # Clamp time_step to available data - max_time_idx = tl_obs.state.shape[2] - 1 - time_step = min(time_step, max_time_idx) - - # Get traffic light data for valid lights at current time step - valid_indices = torch.where(valid_mask)[0] - - for tl_idx in valid_indices: - # Get position (use first valid position if time series) - pos_x = tl_obs.pos_x[env_idx, tl_idx, time_step].item() - pos_y = tl_obs.pos_y[env_idx, tl_idx, time_step].item() - - # Skip if position is invalid (0,0 or out of bounds) - if ( - (pos_x == 0 and pos_y == 0) - or abs(pos_x) > 1000 - or abs(pos_y) > 1000 - ): - continue - - # Get current state - state = int(tl_obs.state[env_idx, tl_idx, time_step].item()) - state = max(0, min(3, state)) # Clamp to valid range - - color = TL_STATE_COLORS[state] - - if self.render_3d: - # Plot as elevated marker in 3D - height = 0.2 # Height above ground for visibility - ax.scatter3D( - [pos_x], - [pos_y], - [height], - color=color, - s=60 * marker_size_scale, - marker="o", - edgecolors="black", - linewidth=0.5, - alpha=0.8, - zorder=10, - ) - else: - # Plot as 2D marker - ax.scatter( - pos_x, - pos_y, - color=color, - s=30 * marker_size_scale, - marker="o", - edgecolors="black", - linewidth=0.5, - alpha=0.9, - zorder=10, - ) - def _get_endpoints(self, x, y, length, yaw): """Compute the start and end points of a road segment.""" center = np.array([x, y]) diff --git a/src/bindings.cpp b/src/bindings.cpp index 70434589b..6b7414ddb 100755 --- a/src/bindings.cpp +++ b/src/bindings.cpp @@ -23,7 +23,7 @@ namespace madrona_gpudrive m.attr("kMaxAgentCount") = consts::kMaxAgentCount; m.attr("kMaxRoadEntityCount") = consts::kMaxRoadEntityCount; m.attr("kMaxAgentMapObservationsCount") = consts::kMaxAgentMapObservationsCount; - m.attr("episodeLen") = consts::episodeLen; + m.attr("episodeLen") = consts::episodeLen; m.attr("numLidarSamples") = consts::numLidarSamples; m.attr("vehicleScale") = consts::vehicleLengthScale; @@ -144,7 +144,6 @@ namespace madrona_gpudrive self.deleteAgents(agents_to_delete); }) - .def("tl_state_tensor", &Manager::trafficLightTensor) .def("deleted_agents_tensor", &Manager::deletedAgentsTensor) .def("map_name_tensor", &Manager::mapNameTensor) .def("scenario_id_tensor", &Manager::scenarioIdTensor); diff --git a/src/consts.hpp b/src/consts.hpp index 806bd7976..41d7a5282 100644 --- a/src/consts.hpp +++ b/src/consts.hpp @@ -11,7 +11,6 @@ namespace consts { inline constexpr madrona::CountT kMaxAgentCount = 64; inline constexpr madrona::CountT kMaxRoadEntityCount = 10000; inline constexpr madrona::CountT kMaxAgentMapObservationsCount = 200; -inline constexpr madrona::CountT kMaxTrafficLightCount = 16; inline constexpr bool useEstimatedYaw = true; diff --git a/src/init.hpp b/src/init.hpp index 9c93586f0..253e1d06e 100755 --- a/src/init.hpp +++ b/src/init.hpp @@ -54,13 +54,10 @@ namespace madrona_gpudrive { MapObject objects[MAX_OBJECTS]; MapRoad roads[MAX_ROADS]; - TrafficLightState trafficLightStates[consts::kMaxTrafficLightCount]; uint32_t numObjects; uint32_t numRoads; uint32_t numRoadSegments; - uint32_t numTrafficLights; - bool hasTrafficLights; MapVector2 mean; char mapName[32]; diff --git a/src/level_gen.cpp b/src/level_gen.cpp index 7ac9af55e..e550efff1 100755 --- a/src/level_gen.cpp +++ b/src/level_gen.cpp @@ -462,18 +462,6 @@ void createPersistentEntities(Engine &ctx) { other_agents.e[out_idx++] = other_agent; } } - - for(CountT i = 0; i < map.numTrafficLights; i++) - { - auto &trafficLight = ctx.singleton().trafficLights[i]; - trafficLight = map.trafficLightStates[i]; - - for (size_t t = 0; t < consts::kTrajectoryLength-1; t++) { - trafficLight.x[t] -= ctx.singleton().mean.x; - trafficLight.y[t] -= ctx.singleton().mean.y; - trafficLight.z[t] -= ctx.singleton().mean.z; - } - } } static void resetPersistentEntities(Engine &ctx) diff --git a/src/mgr.cpp b/src/mgr.cpp index 98907db37..e314735a4 100755 --- a/src/mgr.cpp +++ b/src/mgr.cpp @@ -896,24 +896,11 @@ Tensor Manager::scenarioIdTensor() const { Tensor Manager::metadataTensor() const { return impl_->exportTensor( - ExportID::MetaData, TensorElementType::Float32, + ExportID::MetaData, TensorElementType::Int32, {impl_->numWorlds, consts::kMaxAgentCount, MetaDataExportSize} ); } -Tensor Manager::trafficLightTensor() const -{ - return impl_->exportTensor( - ExportID::TrafficLights, - TensorElementType::Float32, - { - impl_->numWorlds, - consts::kMaxTrafficLightCount, - TrafficLightsStateExportSize - } - ); -} - void Manager::triggerReset(int32_t world_idx) { WorldReset reset { diff --git a/src/mgr.hpp b/src/mgr.hpp index f2ef8d93d..d9613b61b 100755 --- a/src/mgr.hpp +++ b/src/mgr.hpp @@ -70,7 +70,6 @@ class Manager { MGR_EXPORT madrona::py::Tensor expertTrajectoryTensor() const; MGR_EXPORT madrona::py::Tensor worldMeansTensor() const; MGR_EXPORT madrona::py::Tensor metadataTensor() const; - MGR_EXPORT madrona::py::Tensor trafficLightTensor() const; MGR_EXPORT madrona::py::Tensor deletedAgentsTensor() const; MGR_EXPORT madrona::py::Tensor mapNameTensor() const; MGR_EXPORT madrona::py::Tensor scenarioIdTensor() const; diff --git a/src/sim.cpp b/src/sim.cpp index ea911817b..bcc0fe51e 100755 --- a/src/sim.cpp +++ b/src/sim.cpp @@ -59,7 +59,6 @@ void Sim::registerTypes(ECSRegistry ®istry, const Config &cfg) registry.registerComponent(); registry.registerComponent(); registry.registerComponent(); - registry.registerSingleton(); registry.registerSingleton(); registry.registerSingleton(); @@ -84,7 +83,6 @@ void Sim::registerTypes(ECSRegistry ®istry, const Config &cfg) registry.exportSingleton((uint32_t)ExportID::DeletedAgents); registry.exportSingleton((uint32_t)ExportID::MapName); registry.exportSingleton((uint32_t)ExportID::ScenarioId); - registry.exportSingleton((uint32_t)ExportID::TrafficLights); registry.exportColumn( (uint32_t)ExportID::Action); @@ -94,6 +92,7 @@ void Sim::registerTypes(ECSRegistry ®istry, const Config &cfg) (uint32_t)ExportID::AgentMapObservations); registry.exportColumn( (uint32_t)ExportID::MapObservation); + registry.exportColumn( (uint32_t)ExportID::PartnerObservations); registry.exportColumn( diff --git a/src/sim.hpp b/src/sim.hpp index 71059e09d..16ddc2e81 100755 --- a/src/sim.hpp +++ b/src/sim.hpp @@ -41,7 +41,6 @@ enum class ExportID : uint32_t { DeletedAgents, MapName, ScenarioId, - TrafficLights, NumExports }; diff --git a/src/types.hpp b/src/types.hpp index 5c8245115..0ac27d418 100755 --- a/src/types.hpp +++ b/src/types.hpp @@ -411,60 +411,24 @@ namespace madrona_gpudrive const size_t ScenarioIdExportSize = 32; static_assert(sizeof(ScenarioId) == sizeof(char32_t) * ScenarioIdExportSize); - - enum class TLState : int32_t - { - Unknown = 0, - Stop = 1, - Caution = 2, - Go = 3 - }; - - struct TrafficLightState - { - // Lane ID for this traffic light - float laneId = 0; - - // Arrays of state data for each timestep - float state[consts::kTrajectoryLength] = {}; - float x [consts::kTrajectoryLength] = {}; - float y [consts::kTrajectoryLength] = {}; - float z [consts::kTrajectoryLength] = {}; - float timeIndex[consts::kTrajectoryLength] = {}; - // Number of valid states - float numStates = 0; - }; - - // 1 (lane_id) + 5 (state, x, y, z, timeIndex) + 1 (numStates) = 6 - const size_t TrafficLightsStateExportSize = 1 + (consts::kTrajectoryLength) * 5 + 1; - static_assert(sizeof(TrafficLightState) == sizeof(float) * TrafficLightsStateExportSize); - - struct TrafficLights - { - TrafficLightState trafficLights[consts::kMaxTrafficLightCount]; - }; - - const size_t TrafficLightsExportSize = consts::kMaxTrafficLightCount * TrafficLightsStateExportSize; - static_assert(sizeof(TrafficLights) == sizeof(float) * TrafficLightsExportSize); - //Metadata struct : using agent IDs. struct MetaData { - float isSdc; - float isObjectOfInterest; - float isTrackToPredict; - float difficulty; + int32_t isSdc; + int32_t isObjectOfInterest; + int32_t isTrackToPredict; + int32_t difficulty; static inline void zero(MetaData& metadata) { - metadata.isSdc = -1.0f; - metadata.isObjectOfInterest = -1.0f; - metadata.isTrackToPredict = -1.0f; - metadata.difficulty = -1.0f; + metadata.isSdc = -1; + metadata.isObjectOfInterest = -1; + metadata.isTrackToPredict = -1; + metadata.difficulty = -1; } }; const size_t MetaDataExportSize = 4; - static_assert(sizeof(MetaData) == sizeof(int32_t) * (MetaDataExportSize - 1) + sizeof(float)); + static_assert(sizeof(MetaData) == sizeof(int32_t) * MetaDataExportSize); struct AgentInterface : public madrona::Archetype< Action,