diff --git a/src/mbtaclient/handlers/trains_handler.py b/src/mbtaclient/handlers/trains_handler.py index a1429e3..566bbb5 100644 --- a/src/mbtaclient/handlers/trains_handler.py +++ b/src/mbtaclient/handlers/trains_handler.py @@ -13,26 +13,28 @@ class TrainsHandler(MBTABaseHandler): """Handler for managing Trips.""" DEFAULT_MAX_TRIPS = 1 - + @classmethod async def create( cls, mbta_client: MBTAClient, - departure_stop_name: str , + departure_stop_name: str, arrival_stop_name: str, trip_name: str, max_trips: Optional[int] = DEFAULT_MAX_TRIPS, - logger: Optional[logging.Logger] = None)-> "TrainsHandler": - + logger: Optional[logging.Logger] = None + ) -> "TrainsHandler": """Asynchronous factory method to initialize TripsHandler.""" instance = await super()._create( mbta_client=mbta_client, departure_stop_name=departure_stop_name, arrival_stop_name=arrival_stop_name, max_trips=max_trips, - logger=logger) - - instance._logger = logger or logging.getLogger(__name__) # Logger instance + logger=logger + ) + + instance._logger = logger or logging.getLogger(__name__) + instance._mbta_trips_id: list[str] = [] # Initialize trip ID list await instance.__update_mbta_trips_by_trip_name(trip_name) @@ -46,33 +48,29 @@ async def __update_mbta_trips_by_trip_name(self, trip_name: str) -> None: for mbta_trip in mbta_trips: if not MBTATripObjStore.get_by_id(mbta_trip.id): MBTATripObjStore.store(mbta_trip) - self._mbta_trips_id = mbta_trip.id + if mbta_trip.id not in self._mbta_trips_id: + self._mbta_trips_id.append(mbta_trip.id) else: self._logger.error(f"Invalid MBTA trip name {trip_name}") raise MBTATripError(f"Invalid MBTA trip name {trip_name}") - + except Exception as e: self._logger.error(f"Error updating MBTA trips: {e}") raise - async def __fetch_trips_by_name(self, train_name: str) -> Tuple[list[MBTATrip],float]: - + async def __fetch_trips_by_name(self, train_name: str) -> Tuple[list[MBTATrip], float]: params = { 'filter[revenue]': 'REVENUE', 'filter[name]': train_name - } + } mbta_trips, timestamp = await self._mbta_client.fetch_trips(params) return mbta_trips, timestamp - async def update(self) -> list[Trip]: self._logger.debug("Updating trips scheduling and info") try: - now = datetime.now().astimezone() - - # Initialize trips weekly_trips: list[dict[str, Trip]] = [] for i in range(8): @@ -80,23 +78,23 @@ async def update(self) -> list[Trip]: date_to_try = (now + timedelta(days=i)).strftime('%Y-%m-%d') params = { - 'filter[trip]': self._mbta_trips_id, + 'filter[trip]': ','.join(self._mbta_trips_id), 'filter[date]': date_to_try } - daily_updated_trip = await super()._update_scheduling(trips=daily_trip,params=params) + daily_updated_trip = await super()._update_scheduling(trips=daily_trip, params=params) - # Filter out departed trips daily_filtered_trip = super()._filter_and_sort_trips( trips=daily_updated_trip, - remove_departed=False) - + remove_departed=False + ) + if len(daily_filtered_trip) > 0: weekly_trips.append(daily_filtered_trip) - + if len(weekly_trips) == self._max_trips: break - + if len(weekly_trips) == 0: if i == 7: self._logger.error(f"No trips between the provided stops till {date_to_try}") @@ -105,13 +103,10 @@ async def update(self) -> list[Trip]: trains: list[Trip] = [] for trips in weekly_trips: - - # Update stops for the trip task_stops = asyncio.create_task(super()._update_mbta_stops_for_trips(trips=trips.values())) - # Update trip details tasks_trips_details = asyncio.create_task(super()._update_details(trips=trips)) - await task_stops + await task_stops detailed_trip = await tasks_trips_details trains.append(list(detailed_trip.values())[0]) @@ -121,6 +116,6 @@ async def update(self) -> list[Trip]: except Exception as e: self._logger.error(f"Error updating trips scheduling and info: {e}") raise - + class MBTATripError(Exception): pass diff --git a/tests/test_train_handler.py b/tests/test_train_handler.py index 7840645..114e807 100644 --- a/tests/test_train_handler.py +++ b/tests/test_train_handler.py @@ -16,6 +16,7 @@ [ ("South Station", "Back Bay", "509"), ("Worcester", "South Station", "518"), + ("Swampscott", "North Station", "12"), #("West Natick", "South Station", "520"), ] ) @@ -90,4 +91,4 @@ async def test_handler(departure_stop_name, arrival_stop_name, train): print(f"seconds_arrival: {seconds_arrival}") print(f"seconds_departure: {seconds_departure}") - print("##############") \ No newline at end of file + print("##############")