Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 30 additions & 24 deletions src/track_linearization/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,8 +646,8 @@ def _calculate_linear_position(
position : np.ndarray, shape (n_time, n_space)
Spatial positions.
track_segment_id : np.ndarray, shape (n_time,)
Integer 'edge_id' for each time point. NaNs should be pre-handled or
will lead to errors/defaulting to edge_id 0.
Integer segment indices (0..E-1) for each time point. NaNs should be pre-handled or
will lead to errors/defaulting to index 0.
edge_order : list of 2-tuples
Ordered list of edge tuples (node1, node2) defining the linearization path.
These tuples are keys in `track_graph.edges`.
Expand Down Expand Up @@ -697,27 +697,18 @@ def _calculate_linear_position(

start_node_linear_position = np.asarray(start_node_linear_position)

track_segment_id_to_start_node_linear_position = {
track_graph.edges[e]["edge_id"]: snlp
for e, snlp in zip(edge_order, start_node_linear_position)
}
# Use segment indices directly to look up start node linear positions
start_node_linear_position_by_idx = start_node_linear_position[track_segment_id]

start_node_linear_position = np.asarray(
[
track_segment_id_to_start_node_linear_position[edge_id]
for edge_id in track_segment_id
]
)

track_segment_id_to_edge = {track_graph.edges[e]["edge_id"]: e for e in edge_order}
# Use segment indices to look up the corresponding edge and get start node
start_node_id = np.asarray(
[track_segment_id_to_edge[edge_id][0] for edge_id in track_segment_id]
[edge_order[seg_idx][0] for seg_idx in track_segment_id]
)
start_node_2D_position = np.asarray(
[track_graph.nodes[node]["pos"] for node in start_node_id]
)

linear_position = start_node_linear_position + (
linear_position = start_node_linear_position_by_idx + (
np.linalg.norm(start_node_2D_position - projected_track_positions, axis=1)
)
linear_position[is_nan] = np.nan
Expand Down Expand Up @@ -791,10 +782,14 @@ def get_linearized_position(
if edge_order is None:
edge_order = list(track_graph.edges)

# Create mapping between edge IDs and indices
edge_id_by_index = np.array([track_graph.edges[e]["edge_id"] for e in edge_order])
index_by_edge_id = {eid: i for i, eid in enumerate(edge_id_by_index)}

# Figure out the most probable track segement that correponds to
# 2D position
# 2D position (returns segment indices 0..E-1)
if use_HMM:
track_segment_id = classify_track_segments(
seg_idx = classify_track_segments(
track_graph,
position,
route_euclidean_distance_scaling=route_euclidean_distance_scaling,
Expand All @@ -803,25 +798,36 @@ def get_linearized_position(
)
else:
track_segments = get_track_segments_from_graph(track_graph)
track_segment_id = find_nearest_segment(track_segments, position)
seg_idx = find_nearest_segment(track_segments, position)

# Convert segment indices to edge labels
edge_ids = edge_id_by_index[seg_idx]

# Allow resassignment of edges
# Apply edge_map to labels and validate
if edge_map is not None:
for cur_edge, new_edge in edge_map.items():
track_segment_id[track_segment_id == cur_edge] = new_edge
invalid_keys = [k for k in edge_map.keys() if k not in index_by_edge_id]
if invalid_keys:
raise ValueError(f"edge_map contains invalid source edge_ids: {invalid_keys}. "
f"Valid edge_ids are: {list(index_by_edge_id.keys())}")

# Apply mapping, keeping original dtype flexible for strings/mixed types
mapped_edge_ids = np.array([edge_map.get(eid, eid) for eid in edge_ids])
# Keep using original seg_idx for internal operations - only use mapped_edge_ids for output
else:
mapped_edge_ids = edge_ids

(
linear_position,
projected_x_position,
projected_y_position,
) = _calculate_linear_position(
track_graph, position, track_segment_id, edge_order, edge_spacing
track_graph, position, seg_idx, edge_order, edge_spacing
)

return pd.DataFrame(
{
"linear_position": linear_position,
"track_segment_id": track_segment_id,
"track_segment_id": mapped_edge_ids,
"projected_x_position": projected_x_position,
"projected_y_position": projected_y_position,
}
Expand Down
16 changes: 7 additions & 9 deletions src/track_linearization/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,17 +670,15 @@ def test_edge_map_validation_invalid_keys(self, simple_rectangular_track):
"""Test edge_map validation with invalid keys."""
position = np.array([[15, 0]])

# Map non-existent edge ID - should be ignored
# Map non-existent edge ID - should raise ValueError
invalid_edge_map = {999: 1} # 999 doesn't exist

pos_df = get_linearized_position(
position=position,
track_graph=simple_rectangular_track,
edge_map=invalid_edge_map
)

# Should still work (invalid keys ignored)
assert hasattr(pos_df, 'linear_position'), "Invalid edge_map keys should be ignored"
with pytest.raises(ValueError, match="edge_map contains invalid source edge_ids"):
get_linearized_position(
position=position,
track_graph=simple_rectangular_track,
edge_map=invalid_edge_map
)

def test_edge_map_none_vs_no_parameter(self, simple_rectangular_track):
"""Test that edge_map=None is equivalent to not providing edge_map."""
Expand Down
38 changes: 38 additions & 0 deletions src/track_linearization/tests/test_edge_map.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import numpy as np
import pytest

import track_linearization.core as core
from track_linearization import make_track_graph

def _mk_line_graph():
pos = np.array([[0.0,0.0],[1.0,0.0],[2.0,0.0]], dtype=float)
edges = [(0,1),(1,2)]
g = make_track_graph(pos, edges)
# Explicit, non-index edge IDs to highlight id/index mismatch
g.edges[(0,1)]["edge_id"] = 10
g.edges[(1,2)]["edge_id"] = 20
# Ensure edge distances exist
g.edges[(0,1)]["distance"] = 1.0
g.edges[(1,2)]["distance"] = 1.0
return g

def test_edge_map_label_passthrough_no_change():
g = _mk_line_graph()
pts = np.array([[0.2,0.0],[1.7,0.0]])
df_nomap = core.get_linearized_position(pts, g, use_HMM=False)
df_map = core.get_linearized_position(pts, g, edge_map={10:10, 20:20}, use_HMM=False)
# Same geometry -> identical linear positions
assert np.allclose(df_nomap["linear_position"], df_map["linear_position"])

def test_edge_map_merge_two_edges_to_one_label():
g = _mk_line_graph()
pts = np.array([[0.2,0.0],[1.7,0.0]])
df = core.get_linearized_position(pts, g, edge_map={10:99, 20:99}, use_HMM=False)
assert set(df["track_segment_id"].unique()) == {99}

def test_edge_map_invalid_source_raises():
g = _mk_line_graph()
pts = np.array([[0.2,0.0]])
with pytest.raises(ValueError, match="edge_map contains invalid source edge_ids"):
# 999 is not a real edge_id in the graph; should raise ValueError
core.get_linearized_position(pts, g, edge_map={999:42, 10:50}, use_HMM=False)
Loading