Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "MBTAclient"
version = "1.1.29"
version = "1.1.30"
description = "A Python client for interacting with the MBTA API"
readme = "README.md"
requires-python = ">=3.12"
Expand Down
51 changes: 23 additions & 28 deletions src/mbtaclient/handlers/trains_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,28 @@ class TrainsHandler(MBTABaseHandler):
"""Handler for managing Trips."""

DEFAULT_MAX_TRIPS = 1

@classmethod
async def create(
cls,
mbta_client: MBTAClient,
departure_stop_name: str ,
departure_stop_name: str,
arrival_stop_name: str,
trip_name: str,
max_trips: Optional[int] = DEFAULT_MAX_TRIPS,
logger: Optional[logging.Logger] = None)-> "TrainsHandler":

logger: Optional[logging.Logger] = None
) -> "TrainsHandler":
"""Asynchronous factory method to initialize TripsHandler."""
instance = await super()._create(
mbta_client=mbta_client,
departure_stop_name=departure_stop_name,
arrival_stop_name=arrival_stop_name,
max_trips=max_trips,
logger=logger)

instance._logger = logger or logging.getLogger(__name__) # Logger instance
logger=logger
)

instance._logger = logger or logging.getLogger(__name__)
instance._mbta_trips_id: list[str] = [] # Initialize trip ID list

await instance.__update_mbta_trips_by_trip_name(trip_name)

Expand All @@ -46,57 +48,53 @@ async def __update_mbta_trips_by_trip_name(self, trip_name: str) -> None:
for mbta_trip in mbta_trips:
if not MBTATripObjStore.get_by_id(mbta_trip.id):
MBTATripObjStore.store(mbta_trip)
self._mbta_trips_id = mbta_trip.id
if mbta_trip.id not in self._mbta_trips_id:
self._mbta_trips_id.append(mbta_trip.id)
else:
self._logger.error(f"Invalid MBTA trip name {trip_name}")
raise MBTATripError(f"Invalid MBTA trip name {trip_name}")

except Exception as e:
self._logger.error(f"Error updating MBTA trips: {e}")
raise

async def __fetch_trips_by_name(self, train_name: str) -> Tuple[list[MBTATrip],float]:

async def __fetch_trips_by_name(self, train_name: str) -> Tuple[list[MBTATrip], float]:
params = {
'filter[revenue]': 'REVENUE',
'filter[name]': train_name
}
}

mbta_trips, timestamp = await self._mbta_client.fetch_trips(params)
return mbta_trips, timestamp


async def update(self) -> list[Trip]:
self._logger.debug("Updating trips scheduling and info")
try:

now = datetime.now().astimezone()

# Initialize trips
weekly_trips: list[dict[str, Trip]] = []

for i in range(8):
daily_trip: dict[str, Trip] = {}
date_to_try = (now + timedelta(days=i)).strftime('%Y-%m-%d')

params = {
'filter[trip]': self._mbta_trips_id,
'filter[trip]': ','.join(self._mbta_trips_id),
'filter[date]': date_to_try
}

daily_updated_trip = await super()._update_scheduling(trips=daily_trip,params=params)
daily_updated_trip = await super()._update_scheduling(trips=daily_trip, params=params)

# Filter out departed trips
daily_filtered_trip = super()._filter_and_sort_trips(
trips=daily_updated_trip,
remove_departed=False)

remove_departed=False
)

if len(daily_filtered_trip) > 0:
weekly_trips.append(daily_filtered_trip)

if len(weekly_trips) == self._max_trips:
break

if len(weekly_trips) == 0:
if i == 7:
self._logger.error(f"No trips between the provided stops till {date_to_try}")
Expand All @@ -105,13 +103,10 @@ async def update(self) -> list[Trip]:

trains: list[Trip] = []
for trips in weekly_trips:

# Update stops for the trip
task_stops = asyncio.create_task(super()._update_mbta_stops_for_trips(trips=trips.values()))
# Update trip details
tasks_trips_details = asyncio.create_task(super()._update_details(trips=trips))

await task_stops
await task_stops
detailed_trip = await tasks_trips_details

trains.append(list(detailed_trip.values())[0])
Expand All @@ -121,6 +116,6 @@ async def update(self) -> list[Trip]:
except Exception as e:
self._logger.error(f"Error updating trips scheduling and info: {e}")
raise

class MBTATripError(Exception):
pass
3 changes: 2 additions & 1 deletion tests/test_train_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
[
("South Station", "Back Bay", "509"),
("Worcester", "South Station", "518"),
("Swampscott", "North Station", "12"),
#("West Natick", "South Station", "520"),
]
)
Expand Down Expand Up @@ -90,4 +91,4 @@ async def test_handler(departure_stop_name, arrival_stop_name, train):
print(f"seconds_arrival: {seconds_arrival}")
print(f"seconds_departure: {seconds_departure}")

print("##############")
print("##############")
Loading