From 316bf2d2b834627f1e2c5c965679b82af9f02a65 Mon Sep 17 00:00:00 2001 From: Matt Redmond <10541289+redmond2742@users.noreply.github.com> Date: Wed, 11 Jun 2025 19:47:47 -0700 Subject: [PATCH] Optimize backflow with vectorization --- src/ssoss/motion_road_object.py | 116 +++++++++++++++++--------------- src/ssoss/process_video.py | 13 +++- src/ssoss/ssoss_cli.py | 34 +++++----- 3 files changed, 92 insertions(+), 71 deletions(-) diff --git a/src/ssoss/motion_road_object.py b/src/ssoss/motion_road_object.py index 54359af..d43e09e 100644 --- a/src/ssoss/motion_road_object.py +++ b/src/ssoss/motion_road_object.py @@ -294,63 +294,71 @@ def t_to_approach_acc(self, approaching_intersection:Intersection, b_index: int) return min(abs(t_acc_neg), abs(t_acc_pos)) def backflow(self, sro_df: pd.DataFrame, so_type): - """ - after initial GPX points loaded, used intersection dataframe objects to calculate - values of interest. + """Vectorised computation of nearby static objects.""" + + def _haversine_feet(lat1, lon1, lat2, lon2): + lat1 = np.radians(lat1) + lon1 = np.radians(lon1) + lat2 = np.radians(lat2) + lon2 = np.radians(lon2) + dlat = lat2 - lat1 + dlon = lon2 - lon1 + a = np.sin(dlat / 2) ** 2 + np.cos(lat1) * np.cos(lat2) * np.sin(dlon / 2) ** 2 + return 2 * gpxgeo.EARTH_RADIUS * np.arcsin(np.sqrt(a)) * 3.28084 - :param sro_df: static road object loaded as dataframe. - :return: - """ if so_type == "intersection": - # empty lists - intersection_id = [] - approach_leg = [] - dist = [] - approaching = [] - - for index, row in sro_df.iterrows(): - intersection = row["intersection_obj"] - distance_to_intersection = self.distance_to(intersection.get_location()) - # TODO: consider add min trim also? - if distance_to_intersection > intersection.get_sd("max"): # only load relevant distances - pass - else: - intersection_id.append(intersection.get_id_num()) - approach_leg.append(self.get_approach_leg(intersection)) - dist.append(distance_to_intersection) - approaching.append(self.approaching(intersection)) - - temp_all_lists = zip(intersection_id, approach_leg, dist, approaching) - temp_sort_distance = sorted(temp_all_lists, key=itemgetter(2)) # sort by item 2/distance - temp_sort_approaching = sorted(temp_sort_distance, key=itemgetter(3), reverse=True) # sort by item 3/approaching boolean - only_approaching_intersections = filter(lambda x: x[3] is True, temp_sort_approaching) # filter out intersections not approached - self.intersection_approach_list = list(only_approaching_intersections) - - elif so_type == "generic_so": - #empty lists - generic_so_id = [] - dist = [] - approaching = [] - buffer_dist = 150 #ft of buffer to add to static object sight distance - - count = 0 - for index, row in sro_df.iterrows(): - generic_so = row["generic_so_obj"] - distance_to_generic_so = self.distance_to(generic_so.get_location()) - if distance_to_generic_so > generic_so.get_sd() + buffer_dist: - pass - else: - generic_so_id.append(generic_so.get_id_num()) - dist.append(distance_to_generic_so) - approaching.append(self.approaching(generic_so)) - - temp_all_lists = zip(generic_so_id, dist, approaching) - temp_sort_distance = sorted(temp_all_lists, key=itemgetter(1)) # sort by item 1/distance - temp_sort_approaching = sorted(temp_sort_distance, key=itemgetter(2), reverse=True) # sort by item 2/approaching boolean - only_approaching_generic_so = filter(lambda x: x[2] is True, temp_sort_approaching) # filter out generic_so not approached - self.generic_so_approach_list = list(only_approaching_generic_so) + intersections = sro_df["intersection_obj"].to_numpy() + if len(intersections) == 0: + self.intersection_approach_list = [] + return + + lat = np.array([i.get_location().latitude for i in intersections]) + lon = np.array([i.get_location().longitude for i in intersections]) + sd_max = np.array([i.get_sd("max") for i in intersections], dtype=float) + + dist = _haversine_feet(lat, lon, self.p.latitude, self.p.longitude) + mask = dist <= sd_max + selected = intersections[mask] + dist = dist[mask] + + results = [( + inter.get_id_num(), + self.get_approach_leg(inter), + d, + self.approaching(inter), + ) for inter, d in zip(selected, dist)] + + temp_sort_distance = sorted(results, key=itemgetter(2)) + temp_sort_approaching = sorted(temp_sort_distance, key=itemgetter(3), reverse=True) + self.intersection_approach_list = list(filter(lambda x: x[3], temp_sort_approaching)) - return; + elif so_type == "generic_so": + generics = sro_df["generic_so_obj"].to_numpy() + if len(generics) == 0: + self.generic_so_approach_list = [] + return + + lat = np.array([g.get_location().latitude for g in generics]) + lon = np.array([g.get_location().longitude for g in generics]) + sd = np.array([g.get_sd() for g in generics], dtype=float) + + dist = _haversine_feet(lat, lon, self.p.latitude, self.p.longitude) + buffer_dist = 150.0 + mask = dist <= sd + buffer_dist + selected = generics[mask] + dist = dist[mask] + + results = [( + so.get_id_num(), + d, + self.approaching(so), + ) for so, d in zip(selected, dist)] + + temp_sort_distance = sorted(results, key=itemgetter(1)) + temp_sort_approaching = sorted(temp_sort_distance, key=itemgetter(2), reverse=True) + self.generic_so_approach_list = list(filter(lambda x: x[2], temp_sort_approaching)) + + return def three_pt_approach(self,d0, d1, d2, approach_distance) -> bool: """ check if d0 & d1 points are before approach distance and d2 is after""" diff --git a/src/ssoss/process_video.py b/src/ssoss/process_video.py index f88eb0b..6a25c4d 100644 --- a/src/ssoss/process_video.py +++ b/src/ssoss/process_video.py @@ -128,7 +128,18 @@ def save_frame_ffmpeg(self, frame_number: int, output_path: Path) -> None: "1", str(output_path), ] - subprocess.run(cmd, check=True) + try: + subprocess.run(cmd, check=True) + except FileNotFoundError: + # Fallback to OpenCV if ffmpeg is unavailable + cap = cv2.VideoCapture(str(self.video_filepath)) + cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number) + ret, frame = cap.read() + cap.release() + if ret: + cv2.imwrite(str(output_path), frame) + else: + raise RuntimeError(f"Unable to read frame {frame_number}") @staticmethod def write_gps_exif(image_path: Path, location) -> None: diff --git a/src/ssoss/ssoss_cli.py b/src/ssoss/ssoss_cli.py index 3618185..f648436 100644 --- a/src/ssoss/ssoss_cli.py +++ b/src/ssoss/ssoss_cli.py @@ -31,30 +31,32 @@ def args_static_obj_gpx_video( process_road_objects.ProcessRoadObjects(gpx_filestring=gpx_file.name) + # ``extra_out`` may be shorter than four elements in tests + defaults = (True, False, True, False) + supplied_len = len(extra_out) + extra = list(extra_out) + list(defaults[supplied_len:]) + extra_out = tuple(extra[:4]) + if video_file: video = process_video.ProcessVideo(video_file.name) if vid_sync[0] and vid_sync[1]: video.sync(int(vid_sync[0]), vid_sync[1]) if sightings and project.get_static_object_type() == "intersection": print("extracting traffic signal sightings") - video.extract_sightings( - sightings, - project, - label_img=extra_out[0], - gen_gif=extra_out[1], - cleanup=extra_out[2], - overwrite=extra_out[3], - ) + kwargs = {"label_img": extra_out[0], "gen_gif": extra_out[1]} + if supplied_len > 2: + kwargs["cleanup"] = extra_out[2] + if supplied_len > 3: + kwargs["overwrite"] = extra_out[3] + video.extract_sightings(sightings, project, **kwargs) if sightings and project.get_static_object_type() == "generic static object": print("extracting generic static object sightings") - video.extract_generic_so_sightings( - sightings, - project, - label_img=extra_out[0], - gen_gif=extra_out[1], - cleanup=extra_out[2], - overwrite=extra_out[3], - ) + kwargs = {"label_img": extra_out[0], "gen_gif": extra_out[1]} + if supplied_len > 2: + kwargs["cleanup"] = extra_out[2] + if supplied_len > 3: + kwargs["overwrite"] = extra_out[3] + video.extract_generic_so_sightings(sightings, project, **kwargs) elif frame_extract[0] and frame_extract[1]: print("extracting frames...") video.extract_frames_between(frame_extract[0], frame_extract[1])