Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 0 additions & 135 deletions gpudrive/datatypes/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,138 +354,3 @@ def one_hot_encode_bev_map(self):
self.bev_segmentation_map.long(),
num_classes=constants.NUM_MADRONA_ENTITY_TYPES, # From size of Madrona EntityType
)


@dataclass
class TrafficLightObs:
"""
A dataclass that represents traffic light information in the scenario.
This data struct contains the time series of traffic light information.
It contains the state (unknown: 0, stop: 1, caution: 2, go: 3), position, and lane_ids.
Initialized from tl_states_tensor (from Manager.trafficLightTensor()).
For details, see `TrafficLightState` in src/types.hpp.
Shape: (num_worlds, max_traffic_lights, num_timesteps * features).
Attributes:
state: The state of each traffic light (0=unknown, 1=stop, 2=caution, 3=go)
pos_x: X-coordinate of the traffic light
pos_y: Y-coordinate of the traffic light
pos_z: Z-coordinate of the traffic light
time_index: Time index of the traffic light state
lane_id: Lane ID associated with the traffic light
valid_mask: Boolean mask indicating valid traffic lights
"""

state: torch.Tensor
pos_x: torch.Tensor
pos_y: torch.Tensor
pos_z: torch.Tensor
time_index: torch.Tensor
lane_id: torch.Tensor
valid_mask: torch.Tensor

def __init__(
self,
tl_states_tensor: torch.Tensor,

):
"""Initializes the traffic light observation from a tensor."""
traj_length = constants.LOG_TRAJECTORY_LENGTH

# Calculate indices based on C++ struct layout:
# laneId (1) + state[traj_length] + x[traj_length] + y[traj_length] + z[traj_length] + timeIndex[traj_length] + numStates (1)

lane_id_end_idx = 1
state_end_idx = lane_id_end_idx + traj_length
pos_x_end_idx = state_end_idx + traj_length
pos_y_end_idx = pos_x_end_idx + traj_length
pos_z_end_idx = pos_y_end_idx + traj_length
time_index_end_idx = pos_z_end_idx + traj_length

# Extract fields according to C++ struct layout
# See `TrafficLightState` in src/types.hpp for details
self.lane_id = tl_states_tensor[:, :, 0] # Single lane ID value
self.state = tl_states_tensor[
:, :, lane_id_end_idx:state_end_idx
] # state[traj_length]
self.pos_x = tl_states_tensor[
:, :, state_end_idx:pos_x_end_idx
].float() # x[traj_length]
self.pos_y = tl_states_tensor[
:, :, pos_x_end_idx:pos_y_end_idx
].float() # y[traj_length]
self.pos_z = tl_states_tensor[
:, :, pos_y_end_idx:pos_z_end_idx
].float() # z[traj_length]
self.time_index = tl_states_tensor[
:, :, pos_z_end_idx:time_index_end_idx
] # timeIndex[traj_length]
self.num_states = tl_states_tensor[
:, :, time_index_end_idx
] # Single numStates value

# Create a valid mask based on numStates
# Traffic lights are valid if they have numStates > 0
self.valid_mask = self.num_states > 0

@classmethod
def from_tensor(
cls,
tl_states_tensor: madrona_gpudrive.madrona.Tensor,
backend="torch",
device="cuda",
):
"""Creates a TrafficLightObs from a tensor.
Args:
tl_states_tensor: The traffic light state tensor from the simulation
backend: Which backend to use ("torch" or "jax")
device: The device to place tensors on
Returns:
A TrafficLightObs instance
"""
if backend == "torch":
tensor = tl_states_tensor.to_torch().clone().to(device)
obj = cls(tensor)
return obj
elif backend == "jax":
raise NotImplementedError("JAX backend not implemented yet.")

def normalize(self):
"""Normalizes the traffic light observation coordinates."""

# Normalize position coordinates
self.pos_x = normalize_min_max(
tensor=self.pos_x,
min_val=constants.MIN_REL_COORD,
max_val=constants.MAX_REL_COORD,
)
self.pos_y = normalize_min_max(
tensor=self.pos_y,
min_val=constants.MIN_REL_COORD,
max_val=constants.MAX_REL_COORD,
)
self.pos_z = normalize_min_max(
tensor=self.pos_z,
min_val=constants.MIN_Z_COORD,
max_val=constants.MAX_Z_COORD,
)
def one_hot_encode_states(self):
"""One-hot encodes the traffic light states.
Converts the state values to one-hot encoded vectors with 4 classes:
0: Unknown
1: Stop
2: Caution
3: Go
"""
# Make sure values are in range 0-3
state_clamped = torch.clamp(self.state, 0, 3)
# One-hot encode
self.state_onehot = torch.nn.functional.one_hot(
state_clamped, num_classes=4
) * self.valid_mask.unsqueeze(-1)

return self.state_onehot

@property
def shape(self) -> tuple[int, ...]:
"""Shape: (num_worlds, max_traffic_lights, num_timesteps)."""
return self.state.shape
100 changes: 0 additions & 100 deletions gpudrive/visualize/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
LocalEgoState,
GlobalEgoState,
PartnerObs,
TrafficLightObs,
)
from gpudrive.datatypes.trajectory import LogTrajectory
from gpudrive.datatypes.control import ResponseType
Expand Down Expand Up @@ -93,12 +92,6 @@ def initialize_static_scenario_data(self, controlled_agent_mask):
backend=self.backend,
)

self.tl_obs = TrafficLightObs.from_tensor(
tl_states_tensor=self.sim_object.tl_state_tensor(),
backend=self.backend,
device=self.device,
)

def plot_simulator_state(
self,
env_indices: List[int],
Expand Down Expand Up @@ -321,14 +314,6 @@ def plot_simulator_state(
world_based_policy_mask=world_based_policy_mask,
)

self._plot_traffic_lights(
ax=ax,
env_idx=env_idx,
tl_obs=self.tl_obs,
time_step=time_step if time_step is not None else 0,
marker_size_scale=marker_scale,
)

if agent_positions is not None:
# First calculate the maximum valid trajectory length across all agents for this env_idx
max_valid_length = 0
Expand Down Expand Up @@ -564,91 +549,6 @@ def _plot_log_replay_trajectory(
zorder=0,
)

def _plot_traffic_lights(
self,
ax: matplotlib.axes.Axes,
env_idx: int,
tl_obs: "TrafficLightObs",
time_step: int = 0,
marker_size_scale: float = 1.0,
):
"""Plot traffic light states as colored dots.
Args:
ax: Matplotlib axis to plot on
env_idx: Environment index
tl_obs: Traffic light observation object
time_step: Current time step
marker_size_scale: Scale factor for marker size
"""

# Traffic light state colors
TL_STATE_COLORS = {
0: "#C5C5C5", # Unknown - gray
1: "r", # Stop - red
2: "tab:orange", # Caution - orange
3: "g", # Go - green
}

# Get valid traffic lights for this environment
valid_mask = tl_obs.valid_mask[env_idx, :]
if not valid_mask.any():
return

# Clamp time_step to available data
max_time_idx = tl_obs.state.shape[2] - 1
time_step = min(time_step, max_time_idx)

# Get traffic light data for valid lights at current time step
valid_indices = torch.where(valid_mask)[0]

for tl_idx in valid_indices:
# Get position (use first valid position if time series)
pos_x = tl_obs.pos_x[env_idx, tl_idx, time_step].item()
pos_y = tl_obs.pos_y[env_idx, tl_idx, time_step].item()

# Skip if position is invalid (0,0 or out of bounds)
if (
(pos_x == 0 and pos_y == 0)
or abs(pos_x) > 1000
or abs(pos_y) > 1000
):
continue

# Get current state
state = int(tl_obs.state[env_idx, tl_idx, time_step].item())
state = max(0, min(3, state)) # Clamp to valid range

color = TL_STATE_COLORS[state]

if self.render_3d:
# Plot as elevated marker in 3D
height = 0.2 # Height above ground for visibility
ax.scatter3D(
[pos_x],
[pos_y],
[height],
color=color,
s=60 * marker_size_scale,
marker="o",
edgecolors="black",
linewidth=0.5,
alpha=0.8,
zorder=10,
)
else:
# Plot as 2D marker
ax.scatter(
pos_x,
pos_y,
color=color,
s=30 * marker_size_scale,
marker="o",
edgecolors="black",
linewidth=0.5,
alpha=0.9,
zorder=10,
)

def _get_endpoints(self, x, y, length, yaw):
"""Compute the start and end points of a road segment."""
center = np.array([x, y])
Expand Down
3 changes: 1 addition & 2 deletions src/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ namespace madrona_gpudrive
m.attr("kMaxAgentCount") = consts::kMaxAgentCount;
m.attr("kMaxRoadEntityCount") = consts::kMaxRoadEntityCount;
m.attr("kMaxAgentMapObservationsCount") = consts::kMaxAgentMapObservationsCount;
m.attr("episodeLen") = consts::episodeLen;
m.attr("episodeLen") = consts::episodeLen;
m.attr("numLidarSamples") = consts::numLidarSamples;
m.attr("vehicleScale") = consts::vehicleLengthScale;

Expand Down Expand Up @@ -144,7 +144,6 @@ namespace madrona_gpudrive

self.deleteAgents(agents_to_delete);
})
.def("tl_state_tensor", &Manager::trafficLightTensor)
.def("deleted_agents_tensor", &Manager::deletedAgentsTensor)
.def("map_name_tensor", &Manager::mapNameTensor)
.def("scenario_id_tensor", &Manager::scenarioIdTensor);
Expand Down
1 change: 0 additions & 1 deletion src/consts.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ namespace consts {
inline constexpr madrona::CountT kMaxAgentCount = 64;
inline constexpr madrona::CountT kMaxRoadEntityCount = 10000;
inline constexpr madrona::CountT kMaxAgentMapObservationsCount = 200;
inline constexpr madrona::CountT kMaxTrafficLightCount = 16;

inline constexpr bool useEstimatedYaw = true;

Expand Down
3 changes: 0 additions & 3 deletions src/init.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,10 @@ namespace madrona_gpudrive
{
MapObject objects[MAX_OBJECTS];
MapRoad roads[MAX_ROADS];
TrafficLightState trafficLightStates[consts::kMaxTrafficLightCount];

uint32_t numObjects;
uint32_t numRoads;
uint32_t numRoadSegments;
uint32_t numTrafficLights;
bool hasTrafficLights;
MapVector2 mean;

char mapName[32];
Expand Down
12 changes: 0 additions & 12 deletions src/level_gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -462,18 +462,6 @@ void createPersistentEntities(Engine &ctx) {
other_agents.e[out_idx++] = other_agent;
}
}

for(CountT i = 0; i < map.numTrafficLights; i++)
{
auto &trafficLight = ctx.singleton<TrafficLights>().trafficLights[i];
trafficLight = map.trafficLightStates[i];

for (size_t t = 0; t < consts::kTrajectoryLength-1; t++) {
trafficLight.x[t] -= ctx.singleton<WorldMeans>().mean.x;
trafficLight.y[t] -= ctx.singleton<WorldMeans>().mean.y;
trafficLight.z[t] -= ctx.singleton<WorldMeans>().mean.z;
}
}
}

static void resetPersistentEntities(Engine &ctx)
Expand Down
15 changes: 1 addition & 14 deletions src/mgr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -896,24 +896,11 @@ Tensor Manager::scenarioIdTensor() const {

Tensor Manager::metadataTensor() const {
return impl_->exportTensor(
ExportID::MetaData, TensorElementType::Float32,
ExportID::MetaData, TensorElementType::Int32,
{impl_->numWorlds, consts::kMaxAgentCount, MetaDataExportSize}
);
}

Tensor Manager::trafficLightTensor() const
{
return impl_->exportTensor(
ExportID::TrafficLights,
TensorElementType::Float32,
{
impl_->numWorlds,
consts::kMaxTrafficLightCount,
TrafficLightsStateExportSize
}
);
}

void Manager::triggerReset(int32_t world_idx)
{
WorldReset reset {
Expand Down
1 change: 0 additions & 1 deletion src/mgr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ class Manager {
MGR_EXPORT madrona::py::Tensor expertTrajectoryTensor() const;
MGR_EXPORT madrona::py::Tensor worldMeansTensor() const;
MGR_EXPORT madrona::py::Tensor metadataTensor() const;
MGR_EXPORT madrona::py::Tensor trafficLightTensor() const;
MGR_EXPORT madrona::py::Tensor deletedAgentsTensor() const;
MGR_EXPORT madrona::py::Tensor mapNameTensor() const;
MGR_EXPORT madrona::py::Tensor scenarioIdTensor() const;
Expand Down
3 changes: 1 addition & 2 deletions src/sim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ void Sim::registerTypes(ECSRegistry &registry, const Config &cfg)
registry.registerComponent<RoadMapId>();
registry.registerComponent<MapType>();
registry.registerComponent<MetaData>();
registry.registerSingleton<TrafficLights>();

registry.registerSingleton<WorldReset>();
registry.registerSingleton<Shape>();
Expand All @@ -84,7 +83,6 @@ void Sim::registerTypes(ECSRegistry &registry, const Config &cfg)
registry.exportSingleton<DeletedAgents>((uint32_t)ExportID::DeletedAgents);
registry.exportSingleton<MapName>((uint32_t)ExportID::MapName);
registry.exportSingleton<ScenarioId>((uint32_t)ExportID::ScenarioId);
registry.exportSingleton<TrafficLights>((uint32_t)ExportID::TrafficLights);

registry.exportColumn<AgentInterface, Action>(
(uint32_t)ExportID::Action);
Expand All @@ -94,6 +92,7 @@ void Sim::registerTypes(ECSRegistry &registry, const Config &cfg)
(uint32_t)ExportID::AgentMapObservations);
registry.exportColumn<RoadInterface, MapObservation>(
(uint32_t)ExportID::MapObservation);

registry.exportColumn<AgentInterface, PartnerObservations>(
(uint32_t)ExportID::PartnerObservations);
registry.exportColumn<AgentInterface, Lidar>(
Expand Down
1 change: 0 additions & 1 deletion src/sim.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ enum class ExportID : uint32_t {
DeletedAgents,
MapName,
ScenarioId,
TrafficLights,
NumExports
};

Expand Down
Loading
Loading