Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions src/adapters/repositories/gtfs_repository_adapter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""GTFS repository adapter - SQLite implementation."""

from ...adapters.database.gtfs_connection import get_gtfs_db
from ...core.models.bus import RouteIdentifier
from ...core.models.coordinate import Coordinate
from ...core.models.route_shape import RouteShape, RouteShapePoint
from ...core.ports.gtfs_repository import GTFSRepositoryPort
Expand All @@ -13,26 +14,29 @@ class GTFSRepositoryAdapter(GTFSRepositoryPort):
Implements the GTFS repository port using a SQLite database.
"""

def get_route_shape(self, route_id: str) -> RouteShape | None:
def get_route_shape(self, route: RouteIdentifier) -> RouteShape | None:
"""
Get the geographic shape of a route from GTFS database.

Args:
route_id: Route identifier
route: Route identifier with bus_line and direction

Returns:
RouteShape with ordered coordinates, or None if route not found
"""
with get_gtfs_db() as conn:
# First, get the shape_id for this route
# First, get the shape_id for this route filtering by route_id and direction_id
# In GTFS, direction_id is 0 or 1, while our RouteIdentifier uses 1 or 2
direction_id = route.bus_direction - 1 # Convert: 1->0, 2->1

cursor = conn.execute(
"""
SELECT DISTINCT shape_id
FROM trips
WHERE route_id = ?
WHERE route_id = ? AND direction_id = ?
LIMIT 1
""",
(route_id,),
(route.bus_line, direction_id),
)

row = cursor.fetchone()
Expand Down Expand Up @@ -68,7 +72,7 @@ def get_route_shape(self, route_id: str) -> RouteShape | None:
return None

return RouteShape(
route_id=route_id,
route=route,
shape_id=shape_id,
points=points,
)
5 changes: 3 additions & 2 deletions src/core/models/route_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from dataclasses import dataclass

from .bus import RouteIdentifier
from .coordinate import Coordinate


Expand All @@ -27,11 +28,11 @@ class RouteShape:
Complete shape of a route with ordered coordinates.

Attributes:
route_id: Route identifier
route: Route identifier (bus_line and direction)
shape_id: Shape identifier from GTFS
points: List of points defining the route shape, ordered by sequence
"""

route_id: str
route: RouteIdentifier
shape_id: str
points: list[RouteShapePoint]
5 changes: 3 additions & 2 deletions src/core/ports/gtfs_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from abc import ABC, abstractmethod

from ..models.bus import RouteIdentifier
from ..models.route_shape import RouteShape


Expand All @@ -14,12 +15,12 @@ class GTFSRepositoryPort(ABC):
"""

@abstractmethod
def get_route_shape(self, route_id: str) -> RouteShape | None:
def get_route_shape(self, route: RouteIdentifier) -> RouteShape | None:
"""
Get the geographic shape of a route.

Args:
route_id: Route identifier
route: Route identifier with bus_line and direction

Returns:
RouteShape with ordered coordinates, or None if route not found
Expand Down
15 changes: 10 additions & 5 deletions src/core/services/route_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,19 @@ async def get_route_details(self, route: RouteIdentifier) -> list[BusRoute]:
await self.bus_provider.authenticate()
return await self.bus_provider.get_route_details(route)

def get_route_shape(self, route_id: str) -> RouteShape | None:
def get_route_shapes(self, routes: list[RouteIdentifier]) -> list[RouteShape]:
"""
Get the geographic shape coordinates of a route from GTFS data.
Get the geographic shape coordinates for multiple routes from GTFS data.

Args:
route_id: Route identifier (e.g., "1012-10")
routes: List of route identifiers with bus_line and direction

Returns:
RouteShape with ordered coordinates, or None if route not found
List of RouteShapes with ordered coordinates (excludes routes not found)
"""
return self.gtfs_repository.get_route_shape(route_id)
shapes: list[RouteShape] = []
for route in routes:
shape = self.gtfs_repository.get_route_shape(route)
if shape is not None:
shapes.append(shape)
return shapes
38 changes: 18 additions & 20 deletions src/web/controllers/route_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from ..mappers import (
map_bus_position_list_to_schema,
map_route_identifier_schema_to_domain,
map_route_shape_to_response,
map_route_shapes_to_response,
)
from ..schemas import (
BusPositionsRequest,
Expand All @@ -25,7 +25,8 @@
BusRoutesDetailsRequest,
BusRoutesDetailsResponse,
RouteIdentifierSchema,
RouteShapeResponse,
RouteShapesRequest,
RouteShapesResponse,
)

router = APIRouter(prefix="/routes", tags=["routes"])
Expand Down Expand Up @@ -144,40 +145,37 @@ async def get_bus_positions(

# NOTE: Having `current_user: User = Depends(get_current_user)` as a dependency
# makes this endpoint only accessible to authenticated users (requires valid JWT token).
@router.get("/shape/{route_id}", response_model=RouteShapeResponse)
async def get_route_shape(
route_id: str,
@router.post("/shapes", response_model=RouteShapesResponse)
async def get_route_shapes(
request: RouteShapesRequest,
route_service: RouteService = Depends(get_route_service),
current_user: User = Depends(get_current_user),
) -> RouteShapeResponse:
) -> RouteShapesResponse:
"""
Get the geographic shape (coordinates) of a route from GTFS data.
Get the geographic shapes (coordinates) for multiple routes from GTFS data.

Args:
route_id: Route identifier (e.g., "1012-10")
request: Request containing list of route identifiers (bus_line and direction)
route_service: Injected route service

Returns:
Ordered list of coordinates defining the route shape
List of route shapes with ordered coordinates

Raises:
HTTPException: If route not found or database error occurs
HTTPException: If database error occurs
"""
try:
shape = route_service.get_route_shape(route_id)
route_identifiers = [
map_route_identifier_schema_to_domain(route_schema) for route_schema in request.routes
]

if shape is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Route '{route_id}' not found in GTFS database",
)
shapes = route_service.get_route_shapes(route_identifiers)

return map_route_shape_to_response(shape)
shape_responses = map_route_shapes_to_response(shapes)
return RouteShapesResponse(shapes=shape_responses)

except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to retrieve route shape: {str(e)}",
detail=f"Failed to retrieve route shapes: {str(e)}",
) from e
15 changes: 14 additions & 1 deletion src/web/mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,12 +154,25 @@ def map_route_shape_to_response(shape: RouteShape) -> RouteShapeResponse:
RouteShapeResponse for API
"""
return RouteShapeResponse(
route_id=shape.route_id,
route=map_route_identifier_domain_to_schema(shape.route),
shape_id=shape.shape_id,
points=[map_coordinate_domain_to_schema(point.coordinate) for point in shape.points],
)


def map_route_shapes_to_response(shapes: list[RouteShape]) -> list[RouteShapeResponse]:
"""
Map a list of RouteShape domain models to RouteShapeResponse list.

Args:
shapes: List of RouteShape domain models

Returns:
List of RouteShapeResponse for API
"""
return [map_route_shape_to_response(shape) for shape in shapes]


# ===== History Mappers =====


Expand Down
16 changes: 15 additions & 1 deletion src/web/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,25 @@ class BusRoutesDetailsResponse(BaseModel):
class RouteShapeResponse(BaseModel):
"""Response schema for route shape coordinates."""

route_id: str = Field(..., description="Route identifier")
route: RouteIdentifierSchema = Field(..., description="Route identifier")
shape_id: str = Field(..., description="GTFS shape identifier")
points: list[CoordinateSchema] = Field(..., description="Ordered list of coordinates")


class RouteShapesRequest(BaseModel):
"""Request schema for querying multiple route shapes."""

routes: list[RouteIdentifierSchema] = Field(
..., description="List of route identifiers to query shapes for"
)


class RouteShapesResponse(BaseModel):
"""Response schema for multiple route shapes."""

shapes: list[RouteShapeResponse] = Field(..., description="List of route shapes")


# ===== Ranking Schemas =====


Expand Down
22 changes: 13 additions & 9 deletions tests/adapters/test_gtfs_repository_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@
from unittest.mock import MagicMock, patch

from src.adapters.repositories.gtfs_repository_adapter import GTFSRepositoryAdapter
from src.core.models.bus import RouteIdentifier
from src.core.models.route_shape import RouteShape


def test_get_route_shape_found() -> None:
"""Test getting a route shape when the route exists in the database."""
# Arrange
adapter = GTFSRepositoryAdapter()
route = RouteIdentifier(bus_line="test_route_1", bus_direction=1)

# Mock the database connection and cursors
mock_conn = MagicMock()
Expand Down Expand Up @@ -53,12 +54,13 @@ def test_get_route_shape_found() -> None:
mock_get_db.return_value.__enter__.return_value = mock_conn

# Act
result = adapter.get_route_shape("test_route_1")
result = adapter.get_route_shape(route)

# Assert
assert result is not None
assert isinstance(result, RouteShape)
assert result.route_id == "test_route_1"
assert result.route.bus_line == "test_route_1"
assert result.route.bus_direction == 1
assert result.shape_id == "test_shape_123"
assert len(result.points) == 3

Expand All @@ -76,9 +78,9 @@ def test_get_route_shape_found() -> None:


def test_get_route_shape_route_not_found() -> None:
"""Test getting a route shape when the route doesn't exist."""
# Arrange
adapter = GTFSRepositoryAdapter()
route = RouteIdentifier(bus_line="nonexistent_route", bus_direction=1)

mock_conn = MagicMock()
mock_cursor = MagicMock()
Expand All @@ -91,16 +93,16 @@ def test_get_route_shape_route_not_found() -> None:
mock_get_db.return_value.__enter__.return_value = mock_conn

# Act
result = adapter.get_route_shape("nonexistent_route")
result = adapter.get_route_shape(route)

# Assert
assert result is None


def test_get_route_shape_no_shape_points() -> None:
"""Test when route exists but has no shape points."""
# Arrange
adapter = GTFSRepositoryAdapter()
route = RouteIdentifier(bus_line="route_without_points", bus_direction=1)

mock_conn = MagicMock()
mock_cursor1 = MagicMock()
Expand All @@ -118,7 +120,7 @@ def test_get_route_shape_no_shape_points() -> None:
mock_get_db.return_value.__enter__.return_value = mock_conn

# Act
result = adapter.get_route_shape("route_without_points")
result = adapter.get_route_shape(route)

# Assert
assert result is None
Expand All @@ -128,6 +130,7 @@ def test_get_route_shape_single_point() -> None:
"""Test getting a route shape with only one point."""
# Arrange
adapter = GTFSRepositoryAdapter()
route = RouteIdentifier(bus_line="single_point_route", bus_direction=1)

mock_conn = MagicMock()
mock_cursor1 = MagicMock()
Expand All @@ -150,7 +153,7 @@ def test_get_route_shape_single_point() -> None:
mock_get_db.return_value.__enter__.return_value = mock_conn

# Act
result = adapter.get_route_shape("single_point_route")
result = adapter.get_route_shape(route)

# Assert
assert result is not None
Expand All @@ -163,6 +166,7 @@ def test_get_route_shape_null_distance_traveled() -> None:
"""Test getting a route shape with NULL distance_traveled values."""
# Arrange
adapter = GTFSRepositoryAdapter()
route = RouteIdentifier(bus_line="route_no_distance", bus_direction=1)

mock_conn = MagicMock()
mock_cursor1 = MagicMock()
Expand Down Expand Up @@ -191,7 +195,7 @@ def test_get_route_shape_null_distance_traveled() -> None:
mock_get_db.return_value.__enter__.return_value = mock_conn

# Act
result = adapter.get_route_shape("route_no_distance")
result = adapter.get_route_shape(route)

# Assert
assert result is not None
Expand Down
Loading