diff --git a/deutsche_bahn_api/api_authentication.py b/deutsche_bahn_api/api_authentication.py index ff6ff51..61682c9 100644 --- a/deutsche_bahn_api/api_authentication.py +++ b/deutsche_bahn_api/api_authentication.py @@ -12,7 +12,7 @@ def test_credentials(self) -> bool: headers={ "DB-Api-Key": self.client_secret, "DB-Client-Id": self.client_id, - } + }, ) return response.status_code == 200 diff --git a/deutsche_bahn_api/station_helper.py b/deutsche_bahn_api/station_helper.py index d5965cd..a18a030 100644 --- a/deutsche_bahn_api/station_helper.py +++ b/deutsche_bahn_api/station_helper.py @@ -1,6 +1,6 @@ import json import pkgutil -import mpu +from haversine import haversine from deutsche_bahn_api.station import Station @@ -37,7 +37,7 @@ def find_stations_by_lat_long(self, target_lat: float, target_long: float, radiu for station in self.stations_list: lat_long: dict[str, float] = normalize_lat_or_long_from_station(station) - distance = mpu.haversine_distance( + distance = haversine( (lat_long['lat'], lat_long['long']), (target_lat, target_long)) if distance < radius: @@ -53,3 +53,23 @@ def find_stations_by_name(self, query: str) -> list[Station]: results.append(station) return results + + def find_stations_by_id(self, query: int) -> list[Station]: + results: list[Station] = [] + + for station in self.stations_list: + if query == station.EVA_NR: + results.append(station) + + return results + + def find_station_by_id(self, query: int) -> list[Station]: + results: list[Station] = [] + for station in self.stations_list: + if query == station.EVA_NR: + results.append(station) + + if len(results)>1: + raise Exception("More than one station with id '%s' found" % query) + else: + return results[0] diff --git a/deutsche_bahn_api/timetable_helper.py b/deutsche_bahn_api/timetable_helper.py index ab10f42..243c301 100644 --- a/deutsche_bahn_api/timetable_helper.py +++ b/deutsche_bahn_api/timetable_helper.py @@ -32,7 +32,7 @@ def get_timetable_xml(self, hour: Optional[int] = None, date: Optional[datetime] response = requests.get( f"https://apis.deutschebahn.com/db-api-marketplace/apis/timetables/v1" f"/plan/{self.station.EVA_NR}/{date_string}/{hour}", - headers=self.api_authentication.get_headers() + headers=self.api_authentication.get_headers(), ) if response.status_code == 410: return self.get_timetable_xml(int(hour), datetime.now() + timedelta(days=1)) @@ -59,23 +59,27 @@ def get_timetable(self, hour: Optional[int] = None, date: Optional[datetime] = N if train_details.tag == "ar": arrival_object = train_details.attrib - if not departure_object: - """ Arrival without department """ - continue - train_object: Train = Train() train_object.stop_id = train.attrib["id"] train_object.train_type = trip_label_object["c"] train_object.train_number = trip_label_object["n"] - train_object.platform = departure_object['pp'] - train_object.stations = departure_object['ppth'] - train_object.departure = departure_object['pt'] + # If Stop has departure_object, get some informations from it + if departure_object: + train_object.platform = departure_object['pp'] + train_object.stations = departure_object['ppth'] + train_object.departure = departure_object['pt'] + if "l" in departure_object: + train_object.train_line = departure_object['l'] + # If not, get them from arrival_object + else: + train_object.platform = arrival_object['pp'] + if "l" in arrival_object: + train_object.train_line = arrival_object['l'] if "f" in trip_label_object: train_object.trip_type = trip_label_object["f"] - if "l" in departure_object: - train_object.train_line = departure_object['l'] + if arrival_object: train_object.passed_stations = arrival_object['ppth'] @@ -114,12 +118,18 @@ def get_timetable_changes(self, trains: list) -> list[Train]: train_changes.stations = changes.attrib["cpth"] if "cp" in changes.attrib: train_changes.platform = changes.attrib["cp"] + if "cs" in changes.attrib: + train_changes.departure_cancelled = changes.attrib["cs"] if changes.tag == "ar": if "ct" in changes.attrib: train_changes.arrival = changes.attrib["ct"] if "cpth" in changes.attrib: train_changes.passed_stations = changes.attrib["cpth"] + if "cp" in changes.attrib: + train_changes.platform = changes.attrib["cp"] + if "cs" in changes.attrib: + train_changes.arrival_cancelled = changes.attrib["cs"] for message in changes: new_message = Message() diff --git a/deutsche_bahn_api/train_changes.py b/deutsche_bahn_api/train_changes.py index 32c0573..cc53a7f 100644 --- a/deutsche_bahn_api/train_changes.py +++ b/deutsche_bahn_api/train_changes.py @@ -9,3 +9,5 @@ class TrainChanges: stations: str platform: str messages: list[Message] + arrival_cancelled: str + departure_cancelled: str diff --git a/setup.py b/setup.py index f93f511..2cd4f40 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ long_description_content_type="text/markdown", url="https://github.com/Tutorialwork/deutsche_bahn_api", packages=find_packages(), - install_requires=["mpu", "requests"], + install_requires=["haversine", "requests"], package_data={"deutsche_bahn_api": ["static/*"]}, classifiers=[ "Programming Language :: Python :: 3",