From 163b6e1123da13b892fea5bc845a9d254f569562 Mon Sep 17 00:00:00 2001 From: tzuk Date: Wed, 1 Nov 2023 11:14:55 +0200 Subject: [PATCH 01/13] checkpoint for computer switching --- simba/SimBA.py | 14 ++- .../directing_animal_to_bodypart.py | 13 +-- simba/mixins/pop_up_mixin.py | 88 ++++++++++++------- .../pose_configurations/bp_names/bp_names.csv | 2 + .../configuration_names/pose_config_names.csv | 2 + .../no_animals/no_animals.csv | 2 + simba/pose_importers/dlc_importer_csv.py | 13 +-- ...bodypart_directionality_features_pop_up.py | 59 +++++++++++++ 8 files changed, 147 insertions(+), 46 deletions(-) create mode 100644 simba/ui/pop_ups/append_bodypart_directionality_features_pop_up.py diff --git a/simba/SimBA.py b/simba/SimBA.py index 9964deb68..9da296c95 100644 --- a/simba/SimBA.py +++ b/simba/SimBA.py @@ -3,6 +3,8 @@ import os.path import warnings +from simba.ui.pop_ups.append_bodypart_directionality_features_pop_up import AppendBodyPartDirectionalityFeaturesPopUp + warnings.filterwarnings("ignore", category=FutureWarning) warnings.filterwarnings("ignore", category=DeprecationWarning) from simba.ui.pop_ups.direction_animal_to_bodypart_settings_pop_up import DirectionAnimalToBodyPartSettingsPopUp @@ -140,7 +142,7 @@ # from simba.unsupervised.unsupervised_ui import UnsupervisedGUI -sys.setrecursionlimit(10**6) +sys.setrecursionlimit(10 ** 6) currentPlatform = platform.system() @@ -597,6 +599,14 @@ def activate(box, *args): config_path=self.config_path ), ) + append_body_part_directionality_features = Button( + roi_feature_frm, + text="APPEND BODY PART DIRECTIONALITY DATA TO FEATURES (CAUTION)", + fg="green", + command=lambda: AppendBodyPartDirectionalityFeaturesPopUp( + config_path=self.config_path + ), + ) # remove_roi_features_from_feature_set = Button( # roi_feature_frm, # text="REMOVE ROI FEATURES FROM FEATURE SET", @@ -1145,6 +1155,7 @@ def activate(box, *args): roi_feature_frm.grid(row=1, column=0, sticky=NW) append_roi_features_by_animal.grid(row=0, column=0, sticky=NW) append_roi_features_by_body_part.grid(row=1, column=0, sticky=NW) + append_body_part_directionality_features.grid(row=2, column=0, sticky=NW) # remove_roi_features_from_feature_set.grid(row=2, column=0, sticky=NW) feature_tools_frm.grid(row=2, column=0, sticky=NW) @@ -1255,7 +1266,6 @@ def directing_other_animals_analysis(self): def directing_animal_to_bp_analysis(self): _ = DirectionAnimalToBodyPartSettingsPopUp(config_path=self.config_path) - def directing_other_animals_visualizer(self): _ = DirectingOtherAnimalsVisualizerPopUp(config_path=self.config_path) diff --git a/simba/data_processors/directing_animal_to_bodypart.py b/simba/data_processors/directing_animal_to_bodypart.py index 43f9e0e87..2426a8cbb 100644 --- a/simba/data_processors/directing_animal_to_bodypart.py +++ b/simba/data_processors/directing_animal_to_bodypart.py @@ -146,7 +146,7 @@ def create_directionality_dfs(self): for animal_permutation, permutation_data in video_data.items(): for bp_name, bp_data in permutation_data.items(): directing_df = ( - bp_data#[bp_data["Directing_BOOL"] == 1] + bp_data # [bp_data["Directing_BOOL"] == 1] .reset_index() .rename( columns={ @@ -194,14 +194,15 @@ def save_directionality_dfs(self): ------- None """ - if not os.path.exists(self.body_part_directionality_df_dir): - os.makedirs(self.body_part_directionality_df_dir) + output_dir = os.path.join(self.body_part_directionality_df_dir, self.bodypart_direction) + if not os.path.exists(output_dir): + os.makedirs(output_dir) for video_name, video_data in 
self.directionality_df_dict.items(): - save_name = os.path.join(self.body_part_directionality_df_dir, video_name + ".csv") + save_name = os.path.join(output_dir, video_name + ".csv") video_data.to_csv(save_name) print(f"Detailed directional data saved for video {video_name}...") stdout_success( - msg=f"All detailed directional data saved in the {self.body_part_directionality_df_dir} directory" + msg=f"All detailed directional data saved in the {output_dir} directory" ) @@ -238,7 +239,7 @@ def summary_statistics(self): .set_index("Video") ) self.save_path = os.path.join( - self.logs_path, "Body_part_directions_data_{}.csv".format(str(self.datetime)) + self.logs_path, "Body_part_directions_data_{}_{}.csv".format(str(self.bodypart_direction, self.datetime)) ) self.summary_df.to_csv(self.save_path) self.timer.stop_timer() diff --git a/simba/mixins/pop_up_mixin.py b/simba/mixins/pop_up_mixin.py index 8921312c5..e15ef946a 100644 --- a/simba/mixins/pop_up_mixin.py +++ b/simba/mixins/pop_up_mixin.py @@ -43,10 +43,10 @@ class PopUpMixin(object): """ def __init__( - self, - title: str, - config_path: Optional[str] = None, - size: Tuple[int, int] = (960,720), + self, + title: str, + config_path: Optional[str] = None, + size: Tuple[int, int] = (960, 720), ): self.root = Toplevel() self.root.minsize(size[0], size[1]) @@ -92,7 +92,7 @@ def create_clf_checkboxes(self, main_frm: Frame, clfs: List[str]): self.choose_clf_frm.grid(row=self.children_cnt_main(), column=0, sticky=NW) def create_cb_frame( - self, main_frm: Frame, cb_titles: List[str], frm_title: str + self, main_frm: Frame, cb_titles: List[str], frm_title: str ) -> Dict[str, BooleanVar]: cb_frm = LabelFrame( main_frm, text=frm_title, font=Formats.LABELFRAME_HEADER_FORMAT.value @@ -106,11 +106,11 @@ def create_cb_frame( return cb_dict def create_dropdown_frame( - self, - main_frm: Frame, - drop_down_titles: List[str], - drop_down_options: List[str], - frm_title: str, + self, + main_frm: Frame, + drop_down_titles: List[str], + drop_down_options: List[str], + frm_title: str, ): dropdown_frm = LabelFrame( main_frm, text=frm_title, font=Formats.LABELFRAME_HEADER_FORMAT.value @@ -209,8 +209,33 @@ def create_run_frm(self, run_function: object, title: str = "RUN"): self.run_frm.grid(row=self.children_cnt_main() + 1, column=0, sticky=NW) self.run_btn.grid(row=0, column=0, sticky=NW) + def create_choose_number_of_body_parts_directionality_frm( + self, path_to_directionality_dir: str, run_function: object + ): + self.bp_cnt_frm = LabelFrame( + self.main_frm, + text="SELECT DIR", + font=Formats.LABELFRAME_HEADER_FORMAT.value, + ) + root, types_of_directionality, files = list(os.walk(path_to_directionality_dir))[0] + self.bp_cnt_dropdown = DropDownMenu( + self.bp_cnt_frm, + "# of body-parts directionality", + list(types_of_directionality), + "20", + ) + self.bp_cnt_dropdown.setChoices(types_of_directionality[0]) + self.bp_cnt_confirm_btn = Button( + self.bp_cnt_frm, + text="Confirm", + command=lambda: run_function, + ) + self.bp_cnt_frm.grid(row=0, sticky=NW) + self.bp_cnt_dropdown.grid(row=0, column=0, sticky=NW) + self.bp_cnt_confirm_btn.grid(row=0, column=1, sticky=NW) + def create_choose_number_of_body_parts_frm( - self, project_body_parts: List[str], run_function: object + self, project_body_parts: List[str], run_function: object ): self.bp_cnt_frm = LabelFrame( self.main_frm, @@ -244,7 +269,7 @@ def add_value_to_listbox(self, list_box: Listbox, value: float): list_box.insert(0, value) def add_values_to_several_listboxes( - self, list_boxes: 
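Note: the new create_choose_number_of_body_parts_directionality_frm above builds its dropdown from the first level of sub-directories returned by os.walk, and its Confirm button is wired with command=lambda: run_function, which returns the function object without calling it. Below is a minimal, stand-alone sketch of the same idea with a callback that is actually invoked; the helper names and the example path are hypothetical, not part of the patch.

# Hedged sketch (not the SimBA implementation): list the per-body-part sub-directories
# of a directionality folder and hand the chosen one to a callback.
import os
from typing import Callable, List


def list_directionality_subdirs(directionality_dir: str) -> List[str]:
    # os.walk yields (root, dirs, files); the first tuple holds the top-level sub-directories.
    if not os.path.isdir(directionality_dir):
        return []
    _, subdirs, _ = next(os.walk(directionality_dir))
    return sorted(subdirs)


def confirm_choice(choice: str, on_confirm: Callable[[str], None]) -> None:
    # In Tkinter, command=lambda: run_function only returns the function object;
    # it needs to be command=run_function (or lambda: run_function()) to execute on click.
    on_confirm(choice)


if __name__ == "__main__":
    # Hypothetical project layout used only for illustration.
    subdirs = list_directionality_subdirs("project_folder/csv/body_part_directionality")
    if subdirs:
        confirm_choice(subdirs[0], on_confirm=lambda d: print(f"Selected sub-directory: {d}"))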
List[Listbox], values: List[float] + self, list_boxes: List[Listbox], values: List[float] ): if len(list_boxes) != len(values): raise CountError(msg="Value count and listboxes count are not equal") @@ -270,7 +295,7 @@ def create_choose_bp_frm(self, project_body_parts, run_function): for bp_cnt in range(int(self.bp_cnt_dropdown.getChoices())): self.body_parts_dropdowns[bp_cnt] = DropDownMenu( self.body_part_frm, - f"Body-part {str(bp_cnt+1)}:", + f"Body-part {str(bp_cnt + 1)}:", project_body_parts, "25", ) @@ -326,7 +351,7 @@ def choose_bp_threshold_frm(self, parent: LabelFrame): self.probability_entry.grid(row=0, column=0, sticky=NW) def enable_dropdown_from_checkbox( - self, check_box_var: BooleanVar, dropdown_menus: List[DropDownMenu] + self, check_box_var: BooleanVar, dropdown_menus: List[DropDownMenu] ): if check_box_var.get(): for menu in dropdown_menus: @@ -336,13 +361,13 @@ def enable_dropdown_from_checkbox( menu.disable() def create_entry_boxes_from_entrybox( - self, count: int, parent: Frame, current_entries: list + self, count: int, parent: Frame, current_entries: list ): check_int(name="CLASSIFIER COUNT", value=count, min_value=1) for entry in current_entries: entry.destroy() for clf_cnt in range(int(count)): - entry = Entry_Box(parent, f"Classifier {str(clf_cnt+1)}:", labelwidth=15) + entry = Entry_Box(parent, f"Classifier {str(clf_cnt + 1)}:", labelwidth=15) current_entries.append(entry) entry.grid(row=clf_cnt + 2, column=0, sticky=NW) @@ -353,12 +378,12 @@ def create_animal_names_entry_boxes(self, animal_cnt: str): if not hasattr(self, "multi_animal_id_list"): self.multi_animal_id_list = [] for i in range(int(animal_cnt)): - self.multi_animal_id_list.append(f"Animal {i+1}") + self.multi_animal_id_list.append(f"Animal {i + 1}") self.animal_names_frm = Frame(self.animal_settings_frm, pady=5, padx=5) self.animal_name_entry_boxes = {} for i in range(int(animal_cnt)): self.animal_name_entry_boxes[i + 1] = Entry_Box( - self.animal_names_frm, f"Animal {str(i+1)} name: ", "25" + self.animal_names_frm, f"Animal {str(i + 1)} name: ", "25" ) if i <= len(self.multi_animal_id_list) - 1: self.animal_name_entry_boxes[i + 1].entry_set( @@ -369,10 +394,10 @@ def create_animal_names_entry_boxes(self, animal_cnt: str): self.animal_names_frm.grid(row=1, column=0, sticky=NW) def enable_entrybox_from_checkbox( - self, - check_box_var: BooleanVar, - entry_boxes: List[Entry_Box], - reverse: bool = False, + self, + check_box_var: BooleanVar, + entry_boxes: List[Entry_Box], + reverse: bool = False, ): if reverse: if check_box_var.get(): @@ -390,16 +415,16 @@ def enable_entrybox_from_checkbox( box.set_state("disable") def create_import_pose_menu( - self, parent_frm: Frame, idx_row: int = 0, idx_column: int = 0 + self, parent_frm: Frame, idx_row: int = 0, idx_column: int = 0 ): def run_call( - data_type: str, - interpolation: str, - smoothing: str, - smoothing_window: str, - animal_names: dict, - data_path: str, - tracking_data_type: str or None = None, + data_type: str, + interpolation: str, + smoothing: str, + smoothing_window: str, + animal_names: dict, + data_path: str, + tracking_data_type: str or None = None, ): smooth_settings = {} smooth_settings["Method"] = smoothing @@ -880,7 +905,7 @@ def import_menu(data_type_choice: str): self.data_type_dropdown.grid(row=0, column=0, sticky=NW) def create_import_videos_menu( - self, parent_frm: Frame, idx_row: int = 0, idx_column: int = 0 + self, parent_frm: Frame, idx_row: int = 0, idx_column: int = 0 ): def run_import(multiple_videos: bool): if 
multiple_videos: @@ -991,7 +1016,6 @@ def run_import(multiple_videos: bool): # #self.main_frm.config(width=e.x_root, height=e.y_root) # #self.main_frm.update() - # test = PopUpMixin(config_path='/Users/simon/Desktop/envs/troubleshooting/two_animals_16bp_032023/project_folder/project_config.ini', # title='ss') # test.create_import_pose_menu(parent_frm=test.main_frm) diff --git a/simba/pose_configurations/bp_names/bp_names.csv b/simba/pose_configurations/bp_names/bp_names.csv index 379594315..4f204cbd7 100644 --- a/simba/pose_configurations/bp_names/bp_names.csv +++ b/simba/pose_configurations/bp_names/bp_names.csv @@ -10,3 +10,5 @@ Ear_left_1,Ear_right_1,Nose_1,Tail_base_1,Ear_left_2,Right_ear_2,Nose_2,Tail_bas Ear_left_1,Ear_right_1,Nose_1,Center_1,Lat_left_1,Lat_right_1,Tail_base_1,Ear_left_2,Ear_right_2,Nose_2,Center_2,Lat_left_2,Lat_right_2,Tail_base_2,, Ear_left_1,Ear_right_1,Nose_1,Center_1,Lat_left_1,Lat_right_1,Tail_base_1,Tail_end_1,Ear_left_2,Ear_right_2,Nose_2,Center_2,Lat_left_2,Lat_right_2,Tail_base_2,Tail_end_2 3D,,,,,,,,,,,,,,, +nose,left_ear,right_ear,tail_base,bug_center +nose,right_hand,left_hand,right_leg,left_leg,tail_base diff --git a/simba/pose_configurations/configuration_names/pose_config_names.csv b/simba/pose_configurations/configuration_names/pose_config_names.csv index 8a28ee459..5ddb28e5b 100644 --- a/simba/pose_configurations/configuration_names/pose_config_names.csv +++ b/simba/pose_configurations/configuration_names/pose_config_names.csv @@ -10,3 +10,5 @@ Multi-animals; 4 body-parts Multi-animals; 7 body-parts Multi-animals; 8 body-parts 3D tracking +mouse_and_bug +from_below diff --git a/simba/pose_configurations/no_animals/no_animals.csv b/simba/pose_configurations/no_animals/no_animals.csv index dc113873d..66cb1e9a7 100644 --- a/simba/pose_configurations/no_animals/no_animals.csv +++ b/simba/pose_configurations/no_animals/no_animals.csv @@ -10,3 +10,5 @@ 2 2 1 +1 +1 diff --git a/simba/pose_importers/dlc_importer_csv.py b/simba/pose_importers/dlc_importer_csv.py index 0e040d810..92c58ad11 100644 --- a/simba/pose_importers/dlc_importer_csv.py +++ b/simba/pose_importers/dlc_importer_csv.py @@ -70,12 +70,13 @@ def import_dlc_csv(config_path: Union[str, os.PathLike], source: str) -> List[st new_file_name_wo_ext = new_file_name.split(".")[0] video_basename = os.path.basename(file_path) print(f"Importing {video_name} to SimBA project...") - if new_file_name_wo_ext in imported_file_names: - raise FileExistError( - "SIMBA IMPORT ERROR: {} already exist in project. Remove file from project or rename imported video file name before importing.".format( - new_file_name - ) - ) + #if new_file_name_wo_ext in imported_file_names: + # raise FileExistError( + # "SIMBA IMPORT ERROR: {} already exist in project. 
Remove file from project or rename imported video file name before importing.".format( + # new_file_name + # ) + # ) + shutil.copy(file_path, input_csv_dir) shutil.copy(file_path, original_file_name_dir) os.rename( diff --git a/simba/ui/pop_ups/append_bodypart_directionality_features_pop_up.py b/simba/ui/pop_ups/append_bodypart_directionality_features_pop_up.py new file mode 100644 index 000000000..430a57ede --- /dev/null +++ b/simba/ui/pop_ups/append_bodypart_directionality_features_pop_up.py @@ -0,0 +1,59 @@ +__author__ = "Tzuk Polinsky" + +import glob +import os + +import pandas + +from simba.mixins.config_reader import ConfigReader +from simba.mixins.pop_up_mixin import PopUpMixin +from simba.roi_tools.ROI_feature_analyzer import ROIFeatureCreator +from simba.utils.errors import NoROIDataError + + +class AppendBodyPartDirectionalityFeaturesPopUp(PopUpMixin, ConfigReader): + def __init__(self, config_path: str): + PopUpMixin.__init__(self, config_path=config_path, title="APPEND BODY PART DIRECTIONALITY FEATURES") + ConfigReader.__init__(self, config_path=config_path) + if not os.path.isfile(self.roi_coordinates_path): + raise NoROIDataError( + msg="SIMBA ERROR: No ROIs have been defined. Please define ROIs before appending ROI-based features" + ) + self.create_choose_number_of_body_parts_directionality_frm( + path_to_directionality_dir=self.body_part_directionality_df_dir, run_function=self.run + ) + # self.main_frm.mainloop() + + def run(self): + settings = {} + settings["body_parts_directionality"] = {} + for bp_cnt, bp_dropdown in self.bp_cnt_dropdown.items(): + settings["body_parts_directionality"] = bp_dropdown.getChoices() + directionality_data_path = os.path.join(self.body_part_directionality_df_dir, + settings["body_parts_directionality"]) + data_dic = {} + for root, dirs, files in os.walk(directionality_data_path): + for file in files: + data = pandas.read_csv(os.path.join(root, file))["Directing_BOOL"] + data_dic[file] = data + files_found = glob.glob( + self.outlier_corrected_dir + "/*." + self.file_type + ) + concatenate_data = {} + for file in files_found: + data = pandas.read_csv(os.path.join(root, file)) + c_data = pandas.concat([data, data_dic[file]]) + concatenate_data[file] = c_data + for file_name, data in concatenate_data.items(): + save_path = os.path.join( + self.features_dir, file_name + "." + self.file_type + ) + data.to_csv(save_path) + + # roi_feature_creator = ROIFeatureCreator( + # config_path=self.config_path, settings=settings + # ) + # roi_feature_creator.run() + # roi_feature_creator.save() + +# _ = AppendROIFeaturesByBodyPartPopUp(config_path='/Users/simon/Desktop/envs/troubleshooting/two_animals_16bp_032023/project_folder/project_config.ini') From 914aef226e1c1a33812d99eaeb37b27b34849bd6 Mon Sep 17 00:00:00 2001 From: tzuk Date: Wed, 8 Nov 2023 09:58:25 +0200 Subject: [PATCH 02/13] 1. added the directionality by body part to the config creation 2. 
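Note: the run() method of the new AppendBodyPartDirectionalityFeaturesPopUp above gathers each video's Directing_BOOL series and concatenates it with the corresponding data file before writing to the features directory. A minimal sketch of that append step is shown below, matching files by video name and joining column-wise (axis=1) so the boolean stays aligned frame-by-frame; the directory arguments are placeholders and not taken from the patch.

# Hedged sketch: append a per-frame "Directing_BOOL" column to each features CSV.
# Only the "Directing_BOOL" column name comes from the patch above; paths are illustrative.
import glob
import os

import pandas as pd


def append_directionality_to_features(directionality_dir: str, features_dir: str, save_dir: str) -> None:
    os.makedirs(save_dir, exist_ok=True)
    for directionality_path in glob.glob(os.path.join(directionality_dir, "*.csv")):
        video_name = os.path.splitext(os.path.basename(directionality_path))[0]
        features_path = os.path.join(features_dir, f"{video_name}.csv")
        if not os.path.isfile(features_path):
            continue
        directing_bool = pd.read_csv(directionality_path)["Directing_BOOL"].reset_index(drop=True)
        features = pd.read_csv(features_path, index_col=0).reset_index(drop=True)
        # axis=1 keeps one row per frame; the default axis=0 would stack the boolean
        # underneath the feature rows instead of alongside them.
        pd.concat([features, directing_bool], axis=1).to_csv(os.path.join(save_dir, f"{video_name}.csv"))


# Example call with a hypothetical SimBA-style layout:
# append_directionality_to_features(
#     directionality_dir="project_folder/csv/body_part_directionality/bug",
#     features_dir="project_folder/csv/outlier_corrected_movement_location",
#     save_dir="project_folder/csv/features_extracted",
# )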
changed the value in the pop-up to the enum value for generalization --- .../pop_ups/direction_animal_to_bodypart_settings_pop_up.py | 4 ++-- simba/utils/config_creator.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/simba/ui/pop_ups/direction_animal_to_bodypart_settings_pop_up.py b/simba/ui/pop_ups/direction_animal_to_bodypart_settings_pop_up.py index 0c272466f..1ef907e4c 100644 --- a/simba/ui/pop_ups/direction_animal_to_bodypart_settings_pop_up.py +++ b/simba/ui/pop_ups/direction_animal_to_bodypart_settings_pop_up.py @@ -58,13 +58,13 @@ def run(self): for animal_cnt, animal_name in enumerate(self.animal_bp_dict.keys()): try: self.config.set(ConfigKey.DIRECTIONALITY_SETTINGS.value, - "bodypart_direction", + ConfigKey.BODYPART_DIRECTION_VALUE.value, self.criterion_dropdowns[animal_name]["bp"].getChoices(), ) except configparser.NoSectionError as e: self.config.add_section(ConfigKey.DIRECTIONALITY_SETTINGS.value) self.config.set(ConfigKey.DIRECTIONALITY_SETTINGS.value, - "bodypart_direction", + ConfigKey.BODYPART_DIRECTION_VALUE.value, self.criterion_dropdowns[animal_name]["bp"].getChoices(), ) with open(self.config_path, "w") as f: diff --git a/simba/utils/config_creator.py b/simba/utils/config_creator.py index 7980a84fc..b07171739 100644 --- a/simba/utils/config_creator.py +++ b/simba/utils/config_creator.py @@ -269,7 +269,8 @@ def __create_configparser_config(self): self.config[ConfigKey.OUTLIER_SETTINGS.value][ ConfigKey.LOCATION_CRITERION.value ] = Dtypes.NONE.value - + self.config.add_section(ConfigKey.DIRECTIONALITY_SETTINGS.value) + self.config[ConfigKey.DIRECTIONALITY_SETTINGS.value][ConfigKey.BODYPART_DIRECTION_VALUE.value] = Dtypes.NONE.value self.config_path = os.path.join(self.project_folder, "project_config.ini") with open(self.config_path, "w") as file: self.config.write(file) From 4b98366e46513dbfe9a0dfd7f041e8d6214456d9 Mon Sep 17 00:00:00 2001 From: tzuk Date: Sun, 12 Nov 2023 12:21:03 +0200 Subject: [PATCH 03/13] preventing error in config creation --- simba/utils/config_creator.py | 1 - 1 file changed, 1 deletion(-) diff --git a/simba/utils/config_creator.py b/simba/utils/config_creator.py index b07171739..98f338891 100644 --- a/simba/utils/config_creator.py +++ b/simba/utils/config_creator.py @@ -269,7 +269,6 @@ def __create_configparser_config(self): self.config[ConfigKey.OUTLIER_SETTINGS.value][ ConfigKey.LOCATION_CRITERION.value ] = Dtypes.NONE.value - self.config.add_section(ConfigKey.DIRECTIONALITY_SETTINGS.value) self.config[ConfigKey.DIRECTIONALITY_SETTINGS.value][ConfigKey.BODYPART_DIRECTION_VALUE.value] = Dtypes.NONE.value self.config_path = os.path.join(self.project_folder, "project_config.ini") with open(self.config_path, "w") as file: From 10949836cef89533d6b9af275b85b52140eef308 Mon Sep 17 00:00:00 2001 From: tzuk Date: Sun, 12 Nov 2023 12:36:39 +0200 Subject: [PATCH 04/13] preventing error in seaborn when creating plots, because in seabrom 0.9.0 it has np.float which was removed in later versions of numpy --- requirements.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 8a1dd9a19..5531d89a7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,7 +14,7 @@ pandas==0.25.3;python_version=="3.6" pandas;python_version>="3.9" scikit-image scipy -seaborn == 0.9.0 +seaborn scikit-learn tabulate == 0.8.3 tqdm == 4.30.0 @@ -32,7 +32,7 @@ plotly == 4.9.0 statsmodels cefpython3 == 66.0 pyarrow == 6.0.1 -shap == 0.35.0 +shap tables>=3.6.1 xlrd==1.2.0 trafaret==2.1.1 From 
c4ffdac73e9afe13a54b74ed9d7453c97d2fc601 Mon Sep 17 00:00:00 2001 From: tzuk Date: Mon, 13 Nov 2023 12:51:40 +0200 Subject: [PATCH 05/13] I create my body part directionality in a sub folder inside the body_part_directionality_df_dir so this is an implementation of how to read the directory tree --- simba/mixins/config_reader.py | 33 ++++++++++++++++++--------------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/simba/mixins/config_reader.py b/simba/mixins/config_reader.py index 7ba217849..e1a8cfc76 100644 --- a/simba/mixins/config_reader.py +++ b/simba/mixins/config_reader.py @@ -121,22 +121,25 @@ def __init__(self, config_path: str, read_video_info: bool = True): Dtypes.INT.value, ) self.clf_names = get_all_clf_names(config=self.config, target_cnt=self.clf_cnt) - self.feature_file_paths = glob.glob(self.features_dir + "/*." + self.file_type) - self.target_file_paths = glob.glob(self.targets_folder + "/*." + self.file_type) - self.input_csv_paths = glob.glob(self.input_csv_dir + "/*." + self.file_type) - self.body_part_directionality_paths = glob.glob( - self.body_part_directionality_df_dir + "/*." + self.file_type - ) - self.outlier_corrected_paths = glob.glob( - self.outlier_corrected_dir + "/*." + self.file_type - ) - self.outlier_corrected_movement_paths = glob.glob( - self.outlier_corrected_movement_dir + "/*." + self.file_type - ) + self.feature_file_paths = glob.glob(os.path.join(self.features_dir , "*." + self.file_type)) + self.target_file_paths = glob.glob(os.path.join(self.targets_folder ,"*." + self.file_type)) + self.input_csv_paths = glob.glob(os.path.join(self.input_csv_dir , "*." + self.file_type)) + self.body_part_directionality_paths = [] + for root,dirs,files in os.walk(self.body_part_directionality_df_dir): + for d in dirs: + for root2,dirs2,files2 in os.walk(os.path.join(root,d)): + for file in glob.glob(os.path.join(root2, "*." + self.file_type)): + self.body_part_directionality_paths.append(file) + self.outlier_corrected_paths = glob.glob(os.path.join( + self.outlier_corrected_dir , "*." + self.file_type + )) + self.outlier_corrected_movement_paths = glob.glob(os.path.join( + self.outlier_corrected_movement_dir , "*." + self.file_type + )) self.cpu_cnt, self.cpu_to_use = find_core_cnt() - self.machine_results_paths = glob.glob( - self.machine_results_dir + "/*." + self.file_type - ) + self.machine_results_paths = glob.glob(os.path.join( + self.machine_results_dir , "*." 
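Note: the hunk above collects the per-body-part directionality files by walking one level of sub-directories with nested os.walk loops. An equivalent, more compact sketch using a wildcard sub-directory in the glob pattern is shown below; the directory argument and file-type default are placeholders.

# Hedged sketch: collect "<directionality_dir>/<body_part>/<video>.<ext>" files one level down.
import glob
import os
from typing import List


def collect_directionality_paths(directionality_dir: str, file_type: str = "csv") -> List[str]:
    # "*" matches each per-body-part sub-directory written by the directionality analyzer.
    return sorted(glob.glob(os.path.join(directionality_dir, "*", f"*.{file_type}")))


# Example: paths = collect_directionality_paths("project_folder/csv/body_part_directionality")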
+ self.file_type + )) self.logs_path = os.path.join(self.project_path, "logs") self.body_parts_path = os.path.join(self.project_path, Paths.BP_NAMES.value) check_file_exist_and_readable(file_path=self.body_parts_path) From 219f7cff1da5db6cc5fe653a881f59ab1a3cb41a Mon Sep 17 00:00:00 2001 From: tzuk Date: Mon, 13 Nov 2023 12:52:04 +0200 Subject: [PATCH 06/13] I create my body part directionality in a sub folder inside the body_part_directionality_df_dir so this is an implementation of how to read the directory tree --- .../directing_animal_to_bodypart.py | 14 ++++-- ...irecting_animals_to_bodypart_visualizer.py | 49 ++++++++++--------- 2 files changed, 35 insertions(+), 28 deletions(-) diff --git a/simba/data_processors/directing_animal_to_bodypart.py b/simba/data_processors/directing_animal_to_bodypart.py index 2426a8cbb..d893d10ba 100644 --- a/simba/data_processors/directing_animal_to_bodypart.py +++ b/simba/data_processors/directing_animal_to_bodypart.py @@ -167,14 +167,18 @@ def create_directionality_dfs(self): def read_directionality_dfs(self): results = {} + body_parts_directionality = [] for file_cnt, file_path in enumerate(self.body_part_directionality_paths): video_timer = SimbaTimer(start=True) - _, file_name, _ = get_fn_ext(file_path) - results[file_name] = read_df(file_path, self.file_type) + dir_name, file_name, _ = get_fn_ext(file_path) + bp_name = os.path.basename(dir_name) + body_parts_directionality.append(bp_name) + key = file_name+"_"+bp_name + results[key] = read_df(file_path, self.file_type) video_timer.stop_timer() print( "read body part directionality data completed for video {} ({}/{}, elapsed time: {}s)...".format( - file_name, + key, str(file_cnt + 1), str(len(self.outlier_corrected_paths)), video_timer.elapsed_time_str, @@ -183,7 +187,7 @@ def read_directionality_dfs(self): stdout_success( msg='reading body part directionality data completed' ) - return results + return results,body_parts_directionality def save_directionality_dfs(self): """ @@ -239,7 +243,7 @@ def summary_statistics(self): .set_index("Video") ) self.save_path = os.path.join( - self.logs_path, "Body_part_directions_data_{}_{}.csv".format(str(self.bodypart_direction, self.datetime)) + self.logs_path, "Body_part_directions_data_{}_{}.csv".format(self.bodypart_direction,str( self.datetime)) ) self.summary_df.to_csv(self.save_path) self.timer.stop_timer() diff --git a/simba/plotting/directing_animals_to_bodypart_visualizer.py b/simba/plotting/directing_animals_to_bodypart_visualizer.py index b5d5737ce..272c9a1f2 100644 --- a/simba/plotting/directing_animals_to_bodypart_visualizer.py +++ b/simba/plotting/directing_animals_to_bodypart_visualizer.py @@ -59,7 +59,7 @@ def __init__(self, config_path: str, data_path: str, style_attr: dict): self.direction_analyzer = DirectingAnimalsToBodyPartAnalyzer( config_path=config_path, ) - self.data_dict = self.direction_analyzer.read_directionality_dfs() + self.data_dict, self.body_parts_directionality_names = self.direction_analyzer.read_directionality_dfs() self.fourcc = cv2.VideoWriter_fourcc(*Formats.MP4_CODEC.value) self.style_attr, self.pose_colors = style_attr, [] self.colors = get_color_dict() @@ -75,7 +75,6 @@ def __init__(self, config_path: str, data_path: str, style_attr: dict): video_dir=self.video_dir, filename=self.video_name) if not os.path.exists(self.directing_body_part_animal_video_output_path): os.makedirs(self.directing_body_part_animal_video_output_path) - print(f"Processing video {self.video_name}...") def run(self): """ @@ -88,25 +87,29 @@ def 
run(self): """ self.data_df = read_df(self.data_path, file_type=self.file_type) - self.video_save_path = os.path.join( - self.directing_body_part_animal_video_output_path, self.video_name + ".mp4" - ) + self.cap = cv2.VideoCapture(self.video_path) self.video_meta_data = get_video_meta_data(self.video_path) - self.video_data = self.data_dict[self.video_name] - self.writer = cv2.VideoWriter( - self.video_save_path, - self.fourcc, - self.video_meta_data["fps"], - (self.video_meta_data["width"], self.video_meta_data["height"]), - ) - self.__create_video() - self.timer.stop_timer() + for bp in self.body_parts_directionality_names: + key = self.video_name + "_" + bp + print(f"Processing video {key}...") + + self.video_save_path = os.path.join( + self.directing_body_part_animal_video_output_path, self.video_name + "_" + bp + ".mp4" + ) + self.video_data = self.data_dict[key] + self.writer = cv2.VideoWriter( + self.video_save_path, + self.fourcc, + self.video_meta_data["fps"], + (self.video_meta_data["width"], self.video_meta_data["height"]), + ) + self.__create_video() def __draw_individual_lines(self, animal_img_data: pd.DataFrame, img: np.array): color = self.direction_colors[0] - bp_x_name = "Animal_"+self.bodypart_direction+"_x" - bp_y_name = "Animal_"+self.bodypart_direction+"_y" + bp_x_name = "Animal_" + self.bodypart_direction + "_x" + bp_y_name = "Animal_" + self.bodypart_direction + "_y" for cnt, (i, r) in enumerate(animal_img_data.iterrows()): if self.style_attr["Direction_color"] == "Random": color = random.sample(self.direction_colors[0], 1)[0] @@ -162,13 +165,13 @@ def __create_video(self): img_cnt += 1 self.writer.write(np.uint8(img)) - # print( - # "Frame: {} / {}. Video: {}".format( - # str(img_cnt), - # str(self.video_meta_data["frame_count"]), - # self.video_name, - # ) - # ) + print( + "Frame: {} / {}. Video: {}".format( + str(img_cnt), + str(self.video_meta_data["frame_count"]), + self.video_name, + ) + ) else: self.cap.release() self.writer.release() From ec253298c4b8d5ff27cbe81cbbe10bf38fb71908 Mon Sep 17 00:00:00 2001 From: tzuk Date: Mon, 13 Nov 2023 12:52:13 +0200 Subject: [PATCH 07/13] I create my body part directionality in a sub folder inside the body_part_directionality_df_dir so this is an implementation of how to read the directory tree --- simba/ui/pop_ups/movement_analysis_pop_up.py | 1 + simba/ui/pop_ups/movement_analysis_time_bins_pop_up.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/simba/ui/pop_ups/movement_analysis_pop_up.py b/simba/ui/pop_ups/movement_analysis_pop_up.py index 71219ecef..e0a1a3d15 100644 --- a/simba/ui/pop_ups/movement_analysis_pop_up.py +++ b/simba/ui/pop_ups/movement_analysis_pop_up.py @@ -80,6 +80,7 @@ def run(self): ) movement_processor.run() movement_processor.save() + self.root.destroy() pass diff --git a/simba/ui/pop_ups/movement_analysis_time_bins_pop_up.py b/simba/ui/pop_ups/movement_analysis_time_bins_pop_up.py index fd3ac5917..27dd187a4 100644 --- a/simba/ui/pop_ups/movement_analysis_time_bins_pop_up.py +++ b/simba/ui/pop_ups/movement_analysis_time_bins_pop_up.py @@ -87,6 +87,6 @@ def run(self): body_parts=body_parts, ) time_bin_movement_analyzer.run() - + self.root.destroy() # MovementAnalysisTimeBinsPopUp(config_path='/Users/simon/Desktop/envs/troubleshooting/locomotion/project_folder/project_config.ini') From b3744334e4f2b95e390ce4a9c4555d07089c4e4f Mon Sep 17 00:00:00 2001 From: tzuk polinsky Date: Sun, 3 Dec 2023 19:05:55 +0200 Subject: [PATCH 08/13] changes: 1. added freezing feature extractor 2. 
fixed bug in data.py, that didnt run the user defined class 3. changed the arg name in read_all_files_in_folder_mp_futures func to be more specific 4. fixed error in "apply all" button in ROI menu. 5. move the button of extract features below the "apply" button, UX reason, but not important 6. added a new function that take the annotation file and concatenate it with the feature file. 7. --- simba/SimBA.py | 4 +- .../feature_extractor_freezing.py | 145 ++++++ simba/feature_extractors/feature_subsets.py | 4 +- simba/mixins/config_reader.py | 1 + simba/mixins/train_model_mixin.py | 422 ++++++++++-------- simba/model/inference_validation.py | 2 + simba/model/train_rf.py | 31 +- simba/roi_tools/ROI_multiply.py | 6 +- simba/ui/pop_ups/validation_plot_pop_up.py | 1 + simba/utils/data.py | 2 +- 10 files changed, 418 insertions(+), 200 deletions(-) create mode 100644 simba/feature_extractors/feature_extractor_freezing.py diff --git a/simba/SimBA.py b/simba/SimBA.py index 9da296c95..bf6cd2fa1 100644 --- a/simba/SimBA.py +++ b/simba/SimBA.py @@ -1147,8 +1147,8 @@ def activate(box, *args): button_skipOC.grid(row=2, sticky=W, pady=5) label_extractfeatures.grid(row=0, column=0, sticky=NW) - button_extractfeatures.grid(row=0, column=0, sticky=NW) - labelframe_usrdef.grid(row=1, column=0, sticky=NW, pady=5) + button_extractfeatures.grid(row=1, column=0, sticky=NW) + labelframe_usrdef.grid(row=0, column=0, sticky=NW, pady=15) userscript.grid(row=1, column=0, sticky=NW) self.scriptfile.grid(row=2, column=0, sticky=NW) diff --git a/simba/feature_extractors/feature_extractor_freezing.py b/simba/feature_extractors/feature_extractor_freezing.py new file mode 100644 index 000000000..9ea8d2b2c --- /dev/null +++ b/simba/feature_extractors/feature_extractor_freezing.py @@ -0,0 +1,145 @@ +__author__ = "Tzuk Polinsky" + +import os +from itertools import product + +import numpy as np +import pandas as pd + +from simba.mixins.config_reader import ConfigReader +from simba.mixins.feature_extraction_mixin import FeatureExtractionMixin +from simba.utils.checks import check_str +from simba.utils.enums import Paths +from simba.utils.printing import SimbaTimer, stdout_success +from simba.utils.read_write import get_fn_ext, read_df, write_df + + +class MiceFreezingFeatureExtractor(ConfigReader, FeatureExtractionMixin): + """ + Generic featurizer of data within SimBA project using user-defined body-parts in the pose-estimation data. + Results are stored in the `project_folder/csv/features_extracted` directory of the SimBA project. + + :parameter str config_path: path to SimBA project config file in Configparser format + + .. note:: + `Feature extraction tutorial `__. 
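Note: the freezing feature extractor added by this commit scores activity as a rolling sum of frame-to-frame coordinate changes, and head direction as the angle between the arena-centre-to-body-centre vector and the body-centre-to-nose vector. A small self-contained sketch of those two computations follows; the column names are illustrative, and the np.clip guard (which keeps arccos inside its domain when rounding error pushes the dot product past 1) is an addition not present in the patch.

# Hedged sketch of the two freezing features: a windowed "activity" score from
# frame-to-frame coordinate changes, and the nose-direction angle relative to the arena centre.
import numpy as np
import pandas as pd


def angle_between(v1: np.ndarray, v2: np.ndarray) -> float:
    # Normalise, then clip the dot product so float error cannot push arccos out of [-1, 1].
    u1 = v1 / np.linalg.norm(v1)
    u2 = v2 / np.linalg.norm(v2)
    return float(np.degrees(np.arccos(np.clip(np.dot(u1, u2), -1.0, 1.0))))


def activity_score(pose: pd.DataFrame, window: int = 50) -> pd.Series:
    # Sum of per-frame coordinate changes, accumulated over a rolling window.
    frame_diffs = pose.diff(axis=0).sum(axis=1)
    return frame_diffs.rolling(window=window, min_periods=1).sum().abs().fillna(0)


if __name__ == "__main__":
    pose = pd.DataFrame({"nose_x": [10.0, 11.0, 12.0], "nose_y": [5.0, 5.5, 6.0],
                         "center_x": [20.0, 20.2, 20.4], "center_y": [15.0, 15.1, 15.2]})
    arena_center = np.array([50.0, 50.0])
    center = pose[["center_x", "center_y"]].iloc[0].to_numpy()
    nose = pose[["nose_x", "nose_y"]].iloc[0].to_numpy()
    print(angle_between(center - arena_center, nose - center))
    print(activity_score(pose, window=2))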
+ + Examples + ---------- + >>> feature_extractor = MiceFreezingFeatureExtractor(config_path='MyProjectConfig') + >>> feature_extractor.run() + """ + + def __init__(self, config_path: str): + FeatureExtractionMixin.__init__(self, config_path=config_path) + ConfigReader.__init__(self, config_path=config_path) + print( + "Extracting features from {} file(s)...".format(str(len(self.files_found))) + ) + + # Function to calculate the direction vector + def angle_between_vectors(self, v1, v2): + unit_vector_1 = v1 / np.linalg.norm(v1) + unit_vector_2 = v2 / np.linalg.norm(v2) + dot_product = unit_vector_2.dot(unit_vector_1.T) + angle = np.arccos(dot_product) + return np.degrees(angle) + + def calculate_direction_vector(self, from_point, to_point): + return np.array(to_point) - np.array(from_point) + + def extract_features(self, input_file_path: str, window_size: int, video_center: [int, int], pixel_mm: float,directionality_data:pd.DataFrame): + print("Calculating freezing features ...") + + input_data = pd.read_csv(input_file_path) + output_data = pd.DataFrame( + columns=["activity"]) + columns_to_drop = [col for col in input_data.columns if ('bug' in col) or ("_p" in col)] + columns_to_drop.append("Unnamed: 0") + without_bug = input_data.drop(columns_to_drop, axis=1) + + body_parts_diffs = without_bug.diff(axis=0) + body_parts_diffs["nose_x"]*=5 + body_parts_diffs["nose_y"]*=5 + + time_point_diff = body_parts_diffs.sum(axis=1) + rolling_windows = time_point_diff.rolling(window=window_size, min_periods=1).sum() + output_data["activity"] = rolling_windows.abs().fillna(0) + center_cols = [colName for colName in without_bug.columns if ("center" in colName) and ("_p") not in colName] + nose_cols = [colName for colName in without_bug.columns if ("nose" in colName) and ("_p") not in colName] + centers = without_bug[center_cols].to_numpy() + noses = without_bug[nose_cols].to_numpy() + angles = [] + for i, center in enumerate(centers): + nose = noses[i] + vector_fixed_to_center = self.calculate_direction_vector(video_center, center) + vector_center_to_nose = self.calculate_direction_vector(center, nose) + angles.append(self.angle_between_vectors(vector_center_to_nose, vector_fixed_to_center)) + # output_data["nose_direction"] = angles + angles_df = pd.DataFrame(angles) + # angles_diff = angles_df.diff() + # angles_diff_sum = angles_diff.rolling(window=window_size, min_periods=1).sum() + # output_data["nose_direction_sum_of_diffs"] = angles_diff_sum.abs().fillna(0) + output_data["nose_direction_avg"] = angles_df.rolling(window=window_size, min_periods=1).mean().fillna(0) + # directionality_rolling = directionality_data.rolling(window=window_size, min_periods=1) + # output_data["amount_of_looking_at_bug"] = directionality_rolling.sum() + # onsets = [-1] * len(output_data["amount_of_looking_at_bug"]) + # for j, rol in enumerate(directionality_rolling): + # for i, r in enumerate(rol): + # if r: + # onsets[j] = i + # break + # output_data["looking_at_bug_onset"] = onsets + return output_data + + def run(self): + """ + Method to compute and save features to disk. Results are saved in the `project_folder/csv/features_extracted` + directory of the SimBA project. 
+ + Returns + ------- + None + """ + self.roi_coordinates_path = os.path.join( + self.logs_path, Paths.ROI_DEFINITIONS.value + ) + polygons = pd.read_hdf(self.roi_coordinates_path, key="polygons") + directionality_dir_path = os.path.join(self.body_part_directionality_df_dir, "bug") + for file_cnt, file_path in enumerate(self.files_found): + video_timer = SimbaTimer(start=True) + print( + "Extracting features for video {}/{}...".format( + str(file_cnt + 1), str(len(self.files_found)) + ) + ) + _, file_name, _ = get_fn_ext(file_path) + current_polygon = polygons[polygons["Video"] == file_name] + directionality_data_path = os.path.join(directionality_dir_path, file_name + ".csv") + directionality_data = pd.read_csv(directionality_data_path)["Directing_BOOL"] + check_str("file name", file_name) + video_settings, self.px_per_mm, fps = self.read_video_info( + video_name=file_name + ) + self.data_df = self.extract_features(file_path, 50, ( + current_polygon["Center_X"].values[0], current_polygon["Center_Y"].values[0]), + video_settings["pixels/mm"].values[0],directionality_data) + save_path = os.path.join(self.save_dir, file_name + "." + self.file_type) + self.data_df = self.data_df.reset_index(drop=True).fillna(0) + write_df(df=self.data_df, file_type=self.file_type, save_path=save_path) + video_timer.stop_timer() + print( + f"Feature extraction complete for video {file_name} (elapsed time: {video_timer.elapsed_time_str}s)" + ) + print( + f"Feature extraction file for video {file_name} saved to {save_path})" + ) + + self.timer.stop_timer() + stdout_success( + f"Feature extraction complete for {str(len(self.files_found))} video(s). Results are saved inside the project_folder/csv/features_extracted directory", + elapsed_time=self.timer.elapsed_time_str, + ) + +# test = UserDefinedFeatureExtractor(config_path='/Users/simon/Desktop/envs/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini') +# test.run() diff --git a/simba/feature_extractors/feature_subsets.py b/simba/feature_extractors/feature_subsets.py index bd6e78e21..16da24b8b 100644 --- a/simba/feature_extractors/feature_subsets.py +++ b/simba/feature_extractors/feature_subsets.py @@ -456,7 +456,7 @@ def append_to_data(self): self.features_extracted_temp_path + f"/*.{self.file_type}" ) self.data_df = self.read_all_files_in_folder_mp_futures( - file_paths=file_paths, file_type=self.file_type + annotations_file_paths=file_paths, file_type=self.file_type ) self.check_raw_dataset_integrity( df=self.data_df, logs_path=self.logs_path @@ -469,7 +469,7 @@ def append_to_data(self): self.targets_inserted_temp_path + f"/*.{self.file_type}" ) self.data_df = self.read_all_files_in_folder_mp_futures( - file_paths=file_paths, file_type=self.file_type + annotations_file_paths=file_paths, file_type=self.file_type ) self.check_raw_dataset_integrity( df=self.data_df, logs_path=self.logs_path diff --git a/simba/mixins/config_reader.py b/simba/mixins/config_reader.py index e1a8cfc76..8f648668f 100644 --- a/simba/mixins/config_reader.py +++ b/simba/mixins/config_reader.py @@ -559,6 +559,7 @@ def drop_bp_cords(self, df: pd.DataFrame) -> pd.DataFrame: BodypartColumnNotFoundWarning( msg=f"SimBA could not drop body-part coordinates, some body-part names are missing in dataframe. 
SimBA expected the following body-parts, that could not be found inside the file: {missing_body_part_fields}" ) + return df else: return df.drop(self.bp_col_names, axis=1) diff --git a/simba/mixins/train_model_mixin.py b/simba/mixins/train_model_mixin.py index 8c48f8bad..96d0cefa7 100644 --- a/simba/mixins/train_model_mixin.py +++ b/simba/mixins/train_model_mixin.py @@ -1,6 +1,5 @@ __author__ = "Simon Nilsson" - import warnings warnings.simplefilter(action="ignore", category=FutureWarning) @@ -79,10 +78,10 @@ def __init__(self): pass def read_all_files_in_folder( - self, - file_paths: List[str], - file_type: str, - classifier_names: Optional[List[str]] = None, + self, + file_paths: List[str], + file_type: str, + classifier_names: Optional[List[str]] = None, ) -> pd.DataFrame: """ Read in all data files in a folder to a single pd.DataFrame for downstream ML algo. @@ -139,8 +138,8 @@ def read_all_files_in_folder( source=self.__class__.__name__, ) df_concat = df_concat.loc[ - :, ~df_concat.columns.str.contains("^Unnamed") - ].fillna(0) + :, ~df_concat.columns.str.contains("^Unnamed") + ].fillna(0) timer.stop_timer() memory_size = get_memory_usage_of_df(df=df_concat) print( @@ -154,7 +153,7 @@ def read_all_files_in_folder( return df_concat.astype(np.float32) def read_in_all_model_names_to_remove( - self, config: configparser.ConfigParser, model_cnt: int, clf_name: str + self, config: configparser.ConfigParser, model_cnt: int, clf_name: str ) -> List[str]: """ Helper to find all field names that are annotations but are not the target. @@ -178,7 +177,7 @@ def read_in_all_model_names_to_remove( return annotation_cols_to_remove def delete_other_annotation_columns( - self, df: pd.DataFrame, annotations_lst: List[str] + self, df: pd.DataFrame, annotations_lst: List[str] ) -> pd.DataFrame: """ Helper to drop fields that contain annotations which are not the target. @@ -196,7 +195,7 @@ def delete_other_annotation_columns( return df def split_df_to_x_y( - self, df: pd.DataFrame, clf_name: str + self, df: pd.DataFrame, clf_name: str ) -> (pd.DataFrame, pd.DataFrame): """ Helper to split dataframe into features and target. @@ -216,7 +215,7 @@ def split_df_to_x_y( return df, y def random_undersampler( - self, x_train: np.ndarray, y_train: np.ndarray, sample_ratio: float + self, x_train: np.ndarray, y_train: np.ndarray, sample_ratio: float ) -> (pd.DataFrame, pd.DataFrame): """ Helper to perform random under-sampling of behavior-absent frames in a dataframe. @@ -252,7 +251,7 @@ def random_undersampler( return self.split_df_to_x_y(data_df, y_train.name) def smoteen_oversampler( - self, x_train: pd.DataFrame, y_train: pd.DataFrame, sample_ratio: float + self, x_train: pd.DataFrame, y_train: pd.DataFrame, sample_ratio: float ) -> (np.ndarray, np.ndarray): """ Helper to perform SMOTEEN oversampling of behavior-present annotations. @@ -272,10 +271,10 @@ def smoteen_oversampler( return smt.fit_sample(x_train, y_train) def smote_oversampler( - self, - x_train: pd.DataFrame or np.array, - y_train: pd.DataFrame or np.array, - sample_ratio: float, + self, + x_train: pd.DataFrame or np.array, + y_train: pd.DataFrame or np.array, + sample_ratio: float, ) -> (np.ndarray, np.ndarray): """ Helper to perform SMOTE oversampling of behavior-present annotations. 
@@ -294,14 +293,14 @@ def smote_oversampler( return smt.fit_sample(x_train, y_train) def calc_permutation_importance( - self, - x_test: np.ndarray, - y_test: np.ndarray, - clf: RandomForestClassifier, - feature_names: List[str], - clf_name: str, - save_dir: Union[str, os.PathLike], - save_file_no: Optional[int] = None, + self, + x_test: np.ndarray, + y_test: np.ndarray, + clf: RandomForestClassifier, + feature_names: List[str], + clf_name: str, + save_dir: Union[str, os.PathLike], + save_file_no: Optional[int] = None, ) -> None: """ Helper to calculate feature permutation importance scores. @@ -356,15 +355,15 @@ def calc_permutation_importance( ) def calc_learning_curve( - self, - x_y_df: pd.DataFrame, - clf_name: str, - shuffle_splits: int, - dataset_splits: int, - tt_size: float, - rf_clf: RandomForestClassifier, - save_dir: str, - save_file_no: Optional[int] = None, + self, + x_y_df: pd.DataFrame, + clf_name: str, + shuffle_splits: int, + dataset_splits: int, + tt_size: float, + rf_clf: RandomForestClassifier, + save_dir: str, + save_file_no: Optional[int] = None, ) -> None: """ Helper to compute random forest learning curves with cross-validation. @@ -437,13 +436,13 @@ def calc_learning_curve( ) def calc_pr_curve( - self, - rf_clf: RandomForestClassifier, - x_df: pd.DataFrame, - y_df: pd.DataFrame, - clf_name: str, - save_dir: str, - save_file_no: Optional[int] = None, + self, + rf_clf: RandomForestClassifier, + x_df: pd.DataFrame, + y_df: pd.DataFrame, + clf_name: str, + save_dir: str, + save_file_no: Optional[int] = None, ) -> None: """ Helper to compute random forest precision-recall curve. @@ -469,10 +468,10 @@ def calc_pr_curve( pr_df["PRECISION"] = precision pr_df["RECALL"] = recall pr_df["F1"] = ( - 2 - * pr_df["RECALL"] - * pr_df["PRECISION"] - / (pr_df["RECALL"] + pr_df["PRECISION"]) + 2 + * pr_df["RECALL"] + * pr_df["PRECISION"] + / (pr_df["RECALL"] + pr_df["PRECISION"]) ) thresholds = list(thresholds) thresholds.insert(0, 0.00) @@ -492,13 +491,13 @@ def calc_pr_curve( ) def create_example_dt( - self, - rf_clf: RandomForestClassifier, - clf_name: str, - feature_names: List[str], - class_names: List[str], - save_dir: str, - save_file_no: Optional[int] = None, + self, + rf_clf: RandomForestClassifier, + clf_name: str, + feature_names: List[str], + class_names: List[str], + save_dir: str, + save_file_no: Optional[int] = None, ) -> None: """ Helper to produce visualization of random forest decision tree using graphviz. @@ -538,13 +537,13 @@ def create_example_dt( call(command, shell=True) def create_clf_report( - self, - rf_clf: RandomForestClassifier, - x_df: pd.DataFrame, - y_df: pd.DataFrame, - class_names: List[str], - save_dir: str, - save_file_no: Optional[int] = None, + self, + rf_clf: RandomForestClassifier, + x_df: pd.DataFrame, + y_df: pd.DataFrame, + class_names: List[str], + save_dir: str, + save_file_no: Optional[int] = None, ) -> None: """ Helper to create classifier truth table report. @@ -588,12 +587,12 @@ def create_clf_report( ) def create_x_importance_log( - self, - rf_clf: RandomForestClassifier, - x_names: List[str], - clf_name: str, - save_dir: str, - save_file_no: Optional[int] = None, + self, + rf_clf: RandomForestClassifier, + x_names: List[str], + clf_name: str, + save_dir: str, + save_file_no: Optional[int] = None, ) -> None: """ Helper to save gini or entropy based feature importance scores. 
@@ -627,13 +626,13 @@ def create_x_importance_log( df.to_csv(self.f_importance_save_path, index=False) def create_x_importance_bar_chart( - self, - rf_clf: RandomForestClassifier, - x_names: list, - clf_name: str, - save_dir: str, - n_bars: int, - save_file_no: Optional[int] = None, + self, + rf_clf: RandomForestClassifier, + x_names: list, + clf_name: str, + save_dir: str, + n_bars: int, + save_file_no: Optional[int] = None, ) -> None: """ Helper to create a bar chart displaying the top N gini or entropy feature importance scores. @@ -689,12 +688,12 @@ def create_x_importance_bar_chart( plt.close("all") def dviz_classification_visualization( - self, - x_train: np.ndarray, - y_train: np.ndarray, - clf_name: str, - class_names: List[str], - save_dir: str, + self, + x_train: np.ndarray, + y_train: np.ndarray, + clf_name: str, + class_names: List[str], + save_dir: str, ) -> None: """ Helper to create visualization of example decision tree using dtreeviz. @@ -735,7 +734,7 @@ def dviz_classification_visualization( @staticmethod def split_and_group_df( - df: pd.DataFrame, splits: int, include_split_order: bool = True + df: pd.DataFrame, splits: int, include_split_order: bool = True ) -> (List[pd.DataFrame], int): """ Helper to split a dataframe for multiprocessing. If include_split_order, then include the group number @@ -749,18 +748,18 @@ def split_and_group_df( return data_arr, obs_per_split def create_shap_log( - self, - ini_file_path: str, - rf_clf: RandomForestClassifier, - x_df: pd.DataFrame, - y_df: pd.Series, - x_names: List[str], - clf_name: str, - cnt_present: int, - cnt_absent: int, - save_path: str, - save_it: int = 100, - save_file_no: Optional[int] = None, + self, + ini_file_path: str, + rf_clf: RandomForestClassifier, + x_df: pd.DataFrame, + y_df: pd.Series, + x_names: List[str], + clf_name: str, + cnt_present: int, + cnt_absent: int, + save_path: str, + save_it: int = 100, + save_file_no: Optional[int] = None, ) -> None: """ Compute SHAP values for a random forest classifier. @@ -912,7 +911,7 @@ def print_machine_model_information(self, model_dict: dict) -> None: print(f"{table} {Defaults.STR_SPLIT_DELIMITER.value}TABLE") def create_meta_data_csv_training_one_model( - self, meta_data_lst: list, clf_name: str, save_dir: Union[str, os.PathLike] + self, meta_data_lst: list, clf_name: str, save_dir: Union[str, os.PathLike] ) -> None: """ Helper to save single model meta data (hyperparameters, sampling settings etc.) from list format into SimBA @@ -930,7 +929,7 @@ def create_meta_data_csv_training_one_model( out_df.to_csv(save_path) def create_meta_data_csv_training_multiple_models( - self, meta_data, clf_name, save_dir, save_file_no: Optional[int] = None + self, meta_data, clf_name, save_dir, save_file_no: Optional[int] = None ) -> None: print("Saving model meta data file...") save_path = os.path.join(save_dir, f"{clf_name}_{str(save_file_no)}_meta.csv") @@ -938,11 +937,11 @@ def create_meta_data_csv_training_multiple_models( out_df.to_csv(save_path) def save_rf_model( - self, - rf_clf: RandomForestClassifier, - clf_name: str, - save_dir: Union[str, os.PathLike], - save_file_no: Optional[int] = None, + self, + rf_clf: RandomForestClassifier, + clf_name: str, + save_dir: Union[str, os.PathLike], + save_file_no: Optional[int] = None, ) -> None: """ Helper to save pickled classifier object to disk. 
@@ -962,7 +961,7 @@ def save_rf_model( pickle.dump(rf_clf, open(save_path, "wb")) def get_model_info( - self, config: configparser.ConfigParser, model_cnt: int + self, config: configparser.ConfigParser, model_cnt: int ) -> Dict[int, Any]: """ Helper to read in N SimBA random forest config meta files to python dict memory. @@ -983,8 +982,8 @@ def get_model_info( ) continue if ( - config.get("SML settings", "model_path_" + str(n + 1)) - == "No file selected" + config.get("SML settings", "model_path_" + str(n + 1)) + == "No file selected" ): MissingUserInputWarning( msg=f'Skipping {str(config.get("SML settings", "target_name_" + str(n + 1)))} classifier analysis: The classifier path is set to "No file selected', @@ -1026,7 +1025,7 @@ def get_model_info( return model_dict def get_all_clf_names( - self, config: configparser.ConfigParser, target_cnt: int + self, config: configparser.ConfigParser, target_cnt: int ) -> List[str]: """ Helper to get all classifier names in a SimBA project. @@ -1054,10 +1053,10 @@ def get_all_clf_names( return model_names def insert_column_headers_for_outlier_correction( - self, - data_df: pd.DataFrame, - new_headers: List[str], - filepath: Union[str, os.PathLike], + self, + data_df: pd.DataFrame, + new_headers: List[str], + filepath: Union[str, os.PathLike], ) -> pd.DataFrame: """ Helper to insert new column headers onto a dataframe following outlier correction. @@ -1103,7 +1102,7 @@ def read_pickle(self, file_path: Union[str, os.PathLike]) -> object: return clf def bout_train_test_splitter( - self, x_df: pd.DataFrame, y_df: pd.Series, test_size: float + self, x_df: pd.DataFrame, y_df: pd.Series, test_size: float ) -> (pd.DataFrame, pd.DataFrame, pd.Series, pd.Series): """ Helper to split train and test based on annotated `bouts`. @@ -1170,7 +1169,7 @@ def find_bouts(s: pd.Series, type: str): return x_train, x_test, y_train, y_test def check_sampled_dataset_integrity( - self, x_df: pd.DataFrame, y_df: pd.DataFrame + self, x_df: pd.DataFrame, y_df: pd.DataFrame ) -> None: """ Helper to check for non-numerical entries post data sampling @@ -1189,15 +1188,15 @@ def check_sampled_dataset_integrity( if len(x_nan_cnt) < 10: raise FaultyTrainingSetError( msg=f"{str(len(x_nan_cnt))} feature column(s) exist in some files within the project_folder/csv/targets_inserted directory, but missing in others. " - f"SimBA expects all files within the project_folder/csv/targets_inserted directory to have the same number of features: the " - f"column names with mismatches are: {list(x_nan_cnt.index)}", + f"SimBA expects all files within the project_folder/csv/targets_inserted directory to have the same number of features: the " + f"column names with mismatches are: {list(x_nan_cnt.index)}", source=self.__class__.__name__, ) else: raise FaultyTrainingSetError( msg=f"{str(len(x_nan_cnt))} feature columns exist in some files, but missing in others. The feature files are found in the project_folder/csv/targets_inserted directory. 
" - f"SimBA expects all files within the project_folder/csv/targets_inserted directory to have the same number of features: the first 10 " - f"column names with mismatches are: {list(x_nan_cnt.index)[0:9]}", + f"SimBA expects all files within the project_folder/csv/targets_inserted directory to have the same number of features: the first 10 " + f"column names with mismatches are: {list(x_nan_cnt.index)[0:9]}", source=self.__class__.__name__, ) @@ -1214,12 +1213,12 @@ def check_sampled_dataset_integrity( ) def partial_dependence_calculator( - self, - clf: RandomForestClassifier, - x_df: pd.DataFrame, - clf_name: str, - save_dir: Union[str, os.PathLike], - clf_cnt: Optional[int] = None, + self, + clf: RandomForestClassifier, + x_df: pd.DataFrame, + clf_name: str, + save_dir: Union[str, os.PathLike], + clf_cnt: Optional[int] = None, ) -> None: """ Compute feature partial dependencies for every feature in training set. @@ -1256,11 +1255,11 @@ def partial_dependence_calculator( print(f"Partial dependencies for {feature_name} complete...") def clf_predict_proba( - self, - clf: RandomForestClassifier, - x_df: pd.DataFrame, - model_name: Optional[str] = None, - data_path: Optional[Union[str, os.PathLike]] = None, + self, + clf: RandomForestClassifier, + x_df: pd.DataFrame, + model_name: Optional[str] = None, + data_path: Optional[Union[str, os.PathLike]] = None, ) -> np.ndarray: """ @@ -1301,7 +1300,7 @@ def clf_predict_proba( return p_vals[:, 1] def clf_fit( - self, clf: RandomForestClassifier, x_df: pd.DataFrame, y_df: pd.DataFrame + self, clf: RandomForestClassifier, x_df: pd.DataFrame, y_df: pd.DataFrame ) -> RandomForestClassifier: """ Helper to fit clf model @@ -1326,7 +1325,7 @@ def clf_fit( return clf.fit(x_df, y_df) def _read_data_file_helper( - self, file_path: str, file_type: str, clf_names: Optional[List[str]] = None + self, file_path: str, file_type: str, clf_names: Optional[List[str]] = None ): """ Private function called by :meth:`simba.train_model_functions.read_all_files_in_folder_mp` @@ -1356,10 +1355,10 @@ def _read_data_file_helper( return df def read_all_files_in_folder_mp( - self, - file_paths: List[str], - file_type: Literal["csv", "parquet", "pickle"], - classifier_names: Optional[List[str]] = None, + self, + file_paths: List[str], + file_type: Literal["csv", "parquet", "pickle"], + classifier_names: Optional[List[str]] = None, ) -> pd.DataFrame: """ @@ -1384,10 +1383,10 @@ def read_all_files_in_folder_mp( try: with ProcessPoolExecutor(int(np.ceil(cpu_cnt / 2))) as pool: for res in pool.map( - self._read_data_file_helper, - file_paths, - repeat(file_type), - repeat(classifier_names), + self._read_data_file_helper, + file_paths, + repeat(file_type), + repeat(classifier_names), ): df_lst.append(res) df_concat = pd.concat(df_lst, axis=0).round(4) @@ -1399,8 +1398,8 @@ def read_all_files_in_folder_mp( source=self.read_all_files_in_folder_mp.__name__, ) df_concat = df_concat.loc[ - :, ~df_concat.columns.str.contains("^Unnamed") - ].astype(np.float32) + :, ~df_concat.columns.str.contains("^Unnamed") + ].astype(np.float32) memory_size = get_memory_usage_of_df(df=df_concat) print( f'Dataset size: {memory_size["megabytes"]}MB / {memory_size["gigabytes"]}GB' @@ -1420,32 +1419,101 @@ def read_all_files_in_folder_mp( @staticmethod def _read_data_file_helper_futures( - file_path: str, file_type: str, clf_names: Optional[List[str]] = None + annotation_file_path: str, features_dir: str, file_type: str, clf_names: Optional[List[str]] = None ): """ Private function called by 
:meth:`simba.train_model_functions.read_all_files_in_folder_mp_futures` """ timer = SimbaTimer(start=True) - _, vid_name, _ = get_fn_ext(file_path) - df = read_df(file_path, file_type).dropna(axis=0, how="all").fillna(0) - df.index = [vid_name] * len(df) + _, vid_name, _ = get_fn_ext(annotation_file_path) + annotation_df = read_df(annotation_file_path, file_type).dropna(axis=0, how="all").fillna(0) + if features_dir != "": + features_file_path = os.path.join(features_dir, vid_name + ".csv") + features_df = read_df(features_file_path, file_type).dropna(axis=0, how="all").fillna(0) + features_df = pd.concat([features_df, annotation_df[clf_names]], axis=1) + else: + features_df = annotation_df + features_df.index = [vid_name] * len(features_df) if clf_names != None: for clf_name in clf_names: - if not clf_name in df.columns: - raise ColumnNotFoundError(column_name=clf_name, file_name=file_path) - elif len(set(df[clf_name].unique()) - {0, 1}) > 0: + if not clf_name in annotation_df.columns: + raise ColumnNotFoundError(column_name=clf_name, file_name=annotation_file_path) + elif len(set(annotation_df[clf_name].unique()) - {0, 1}) > 0: raise InvalidInputError( - msg=f"The annotation column for a classifier should contain only 0 or 1 values. However, in file {file_path} the {clf_name} field contains additional value(s): {list(set(df[clf_name].unique()) - {0, 1})}." + msg=f"The annotation column for a classifier should contain only 0 or 1 values. However, in file {annotation_file_path} the {clf_name} field contains additional value(s): {list(set(annotation_df[clf_name].unique()) - {0, 1})}." ) timer.stop_timer() - return df, vid_name, timer.elapsed_time_str + return features_df, vid_name, timer.elapsed_time_str + + def read_and_concatenate_all_files_in_folder_mp_futures( + self, + annotations_file_paths: List[str], + features_dir: str, + file_type: Literal["csv", "parquet", "pickle"], + classifier_names: Optional[List[str]] = None, + ) -> pd.DataFrame: + """ + Multiprocessing helper function to read in all data files in a folder to a single + pd.DataFrame for downstream ML through concurrent.Futures. Asserts that all classifiers + have annotation fields present in each dataframe. + + .. note:: + A ``concurrent.Futures`` alternative to :meth:`simba.mixins.train_model_mixin.read_all_files_in_folder_mp` which + has uses ``multiprocessing.ProcessPoolExecutor`` and reported unstable on Linux machines. + + If multiprocess failure, reverts to :meth:`simba.mixins.train_model_mixin.read_all_files_in_folder` + + :parameter List[str] file_paths: List of file-paths + :parameter List[str] file_paths: The filetype of ``file_paths`` OPTIONS: csv or parquet. + :parameter Optional[List[str]] classifier_names: List of classifier names representing fields of human annotations. If not None, then assert that classifier names + are present in each data file. + :return pd.DataFrame: Concatenated dataframe of all data in ``file_paths``. 
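Note: the reworked helper above pairs each annotation file with the features file of the same video and appends the classifier (target) columns before the per-video frames are pooled for training. A minimal sketch of that pairing step, including the 0/1 sanity check, is given below; the CSV layout and the file_type default are assumptions for illustration.

# Hedged sketch: join one video's features with its human-annotation (target) columns.
import os
from typing import List

import pandas as pd


def join_features_and_annotations(annotation_path: str, features_dir: str, clf_names: List[str],
                                  file_type: str = "csv") -> pd.DataFrame:
    video_name = os.path.splitext(os.path.basename(annotation_path))[0]
    annotations = pd.read_csv(annotation_path, index_col=0)
    features = pd.read_csv(os.path.join(features_dir, f"{video_name}.{file_type}"), index_col=0)
    for clf_name in clf_names:
        if clf_name not in annotations.columns:
            raise KeyError(f"Annotation column {clf_name} missing in {annotation_path}")
        invalid = set(annotations[clf_name].unique()) - {0, 1}
        if invalid:
            raise ValueError(f"{clf_name} in {annotation_path} contains non-binary values: {invalid}")
    # Column-wise join keeps one row per frame; the video name becomes the index,
    # mirroring how the helper above tags rows before concatenating across videos.
    joined = pd.concat([features.reset_index(drop=True),
                        annotations[clf_names].reset_index(drop=True)], axis=1)
    joined.index = [video_name] * len(joined)
    return joined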
+ + """ + try: + if platform.system() == "Darwin": + multiprocessing.set_start_method("spawn") + cpu_cnt, _ = find_core_cnt() + df_lst = [] + with concurrent.futures.ProcessPoolExecutor( + max_workers=cpu_cnt + ) as executor: + results = [ + executor.submit( + self._read_data_file_helper_futures, + annotation_file_path, + features_dir, + file_type, + classifier_names, + ) + for annotation_file_path in annotations_file_paths + ] + for result in concurrent.futures.as_completed(results): + df_lst.append(result.result()[0]) + print( + f"Reading complete {result.result()[1]} (elapsed time: {result.result()[2]}s)..." + ) + df_concat = pd.concat(df_lst, axis=0).round(4) + if "scorer" in df_concat.columns: + df_concat = df_concat.drop(["scorer"], axis=1) + return df_concat + + except: + MultiProcessingFailedWarning( + msg="Multi-processing file read failed, reverting to single core (increased run-time on large datasets)." + ) + return self.read_all_files_in_folder( + file_paths=annotations_file_paths, + file_type=file_type, + classifier_names=classifier_names, + ) def read_all_files_in_folder_mp_futures( - self, - file_paths: List[str], - file_type: Literal["csv", "parquet", "pickle"], - classifier_names: Optional[List[str]] = None, + self, + annotations_file_paths: List[str], + file_type: Literal["csv", "parquet", "pickle"], + classifier_names: Optional[List[str]] = None, ) -> pd.DataFrame: """ Multiprocessing helper function to read in all data files in a folder to a single @@ -1471,16 +1539,17 @@ def read_all_files_in_folder_mp_futures( cpu_cnt, _ = find_core_cnt() df_lst = [] with concurrent.futures.ProcessPoolExecutor( - max_workers=cpu_cnt + max_workers=cpu_cnt ) as executor: results = [ executor.submit( self._read_data_file_helper_futures, - data, + annotation_file_path, + "", file_type, classifier_names, ) - for data in file_paths + for annotation_file_path in annotations_file_paths ] for result in concurrent.futures.as_completed(results): df_lst.append(result.result()[0]) @@ -1497,13 +1566,13 @@ def read_all_files_in_folder_mp_futures( msg="Multi-processing file read failed, reverting to single core (increased run-time on large datasets)." ) return self.read_all_files_in_folder( - file_paths=file_paths, + file_paths=annotations_file_paths, file_type=file_type, classifier_names=classifier_names, ) def check_raw_dataset_integrity( - self, df: pd.DataFrame, logs_path: Optional[Union[str, os.PathLike]] + self, df: pd.DataFrame, logs_path: Optional[Union[str, os.PathLike]] ) -> None: """ Helper to check column-wise NaNs in raw input data for fitting model. @@ -1540,18 +1609,18 @@ def check_raw_dataset_integrity( results.to_csv(save_log_path) raise FaultyTrainingSetError( msg=f"{len(nan_cols)} feature columns exist in some files, but missing in others. The feature files are found in the project_folder/csv/targets_inserted directory. " - f"SimBA expects all files within the project_folder/csv/targets_inserted directory to have the same number of features: the first 10 " - f"column names with mismatches are: {nan_cols[0:9]}. For a log of the files that contain, and not contain, the mis-matched columns, see {save_log_path}", + f"SimBA expects all files within the project_folder/csv/targets_inserted directory to have the same number of features: the first 10 " + f"column names with mismatches are: {nan_cols[0:9]}. 
For a log of the files that contain, and not contain, the mis-matched columns, see {save_log_path}", source=self.__class__.__name__, ) @staticmethod def _create_shap_mp_helper( - data: pd.DataFrame, - explainer: shap.TreeExplainer, - clf_name: str, - rf_clf: RandomForestClassifier, - expected_value: float, + data: pd.DataFrame, + explainer: shap.TreeExplainer, + clf_name: str, + rf_clf: RandomForestClassifier, + expected_value: float, ): target = data.pop(clf_name).values.reshape(-1, 1) frame_batch_shap = explainer.shap_values(data.values, check_additivity=False)[1] @@ -1570,7 +1639,7 @@ def _create_shap_mp_helper( @staticmethod def _create_shap_mp_helper( - data: pd.DataFrame, explainer: shap.TreeExplainer, clf_name: str + data: pd.DataFrame, explainer: shap.TreeExplainer, clf_name: str ): target = data.pop(clf_name).values.reshape(-1, 1) group_cnt = data.pop("group").values[0] @@ -1583,18 +1652,18 @@ def _create_shap_mp_helper( return shap_vals, data.values, target def create_shap_log_mp( - self, - ini_file_path: str, - rf_clf: RandomForestClassifier, - x_df: pd.DataFrame, - y_df: pd.DataFrame, - x_names: List[str], - clf_name: str, - cnt_present: int, - cnt_absent: int, - save_path: str, - batch_size: int = 10, - save_file_no: Optional[int] = None, + self, + ini_file_path: str, + rf_clf: RandomForestClassifier, + x_df: pd.DataFrame, + y_df: pd.DataFrame, + x_names: List[str], + clf_name: str, + cnt_present: int, + cnt_absent: int, + save_path: str, + batch_size: int = 10, + save_file_no: Optional[int] = None, ) -> None: """ Helper to compute SHAP values using multiprocessing. @@ -1680,10 +1749,10 @@ def create_shap_log_mp( self._create_shap_mp_helper, explainer=explainer, clf_name=clf_name ) for cnt, result in enumerate( - pool.imap_unordered(constants, shap_data, chunksize=1) + pool.imap_unordered(constants, shap_data, chunksize=1) ): print( - f"Concatenating multi-processed SHAP data (batch {cnt+1}/{len(shap_data)})" + f"Concatenating multi-processed SHAP data (batch {cnt + 1}/{len(shap_data)})" ) proba = rf_clf.predict_proba(result[1])[:, 1].reshape(-1, 1) shap_sum = np.sum(result[0], axis=1).reshape(-1, 1) @@ -1706,7 +1775,7 @@ def create_shap_log_mp( shap_save_df = pd.DataFrame( data=np.row_stack(shap_results), columns=list(x_names) - + ["Expected_value", "Sum", "Prediction_probability", clf_name], + + ["Expected_value", "Sum", "Prediction_probability", clf_name], ) raw_save_df = pd.DataFrame( data=np.row_stack(shap_raw), columns=list(x_names) @@ -1746,7 +1815,7 @@ def create_shap_log_mp( ) def check_df_dataset_integrity( - self, df: pd.DataFrame, file_name: str, logs_path: Union[str, os.PathLike] + self, df: pd.DataFrame, file_name: str, logs_path: Union[str, os.PathLike] ) -> None: """ Helper to check for non-numerical np.inf, -np.inf, NaN, None in a single dataframe. 
@@ -1768,7 +1837,6 @@ def check_df_dataset_integrity( else: pass - # test = TrainModelMixin() # test.read_all_files_in_folder(file_paths=['/Users/simon/Desktop/envs/troubleshooting/jake/project_folder/csv/targets_inserted/22-437C_c3_2022-11-01_13-16-23_color.csv', '/Users/simon/Desktop/envs/troubleshooting/jake/project_folder/csv/targets_inserted/22-437D_c4_2022-11-01_13-16-39_color.csv'], # file_type='csv', classifier_names=['attack', 'non-agresive parallel swimming']) diff --git a/simba/model/inference_validation.py b/simba/model/inference_validation.py index e4a21ddfb..f898e2467 100644 --- a/simba/model/inference_validation.py +++ b/simba/model/inference_validation.py @@ -51,6 +51,8 @@ def __init__( data_df = read_df(input_file_path, self.file_type) output_df = deepcopy(data_df) data_df = self.drop_bp_cords(df=data_df) + if data_df is None: + data_df = deepcopy(output_df) clf = self.read_pickle(file_path=clf_path) probability_col_name = f"Probability_{classifier_name}" output_df[probability_col_name] = self.clf_predict_proba( diff --git a/simba/model/train_rf.py b/simba/model/train_rf.py index 877801dc5..9b378cd3c 100644 --- a/simba/model/train_rf.py +++ b/simba/model/train_rf.py @@ -110,7 +110,7 @@ def __init__(self, config_path: Union[str, os.PathLike]): else: self.under_sample_ratio = Dtypes.NAN.value if (self.over_sample_setting == Methods.SMOTEENN.value.lower()) or ( - self.over_sample_setting == Methods.SMOTE.value.lower() + self.over_sample_setting == Methods.SMOTE.value.lower() ): self.over_sample_ratio = read_config_entry( self.config, @@ -132,21 +132,23 @@ def __init__(self, config_path: Union[str, os.PathLike]): print( "Reading in {} annotated files...".format(str(len(self.target_file_paths))) ) - self.data_df = self.read_all_files_in_folder_mp_futures( - self.target_file_paths, self.file_type, [self.clf_name] - ) - self.data_df = self.check_raw_dataset_integrity( - df=self.data_df, logs_path=self.logs_path + self.data_df = self.read_and_concatenate_all_files_in_folder_mp_futures( + self.target_file_paths, self.features_dir, self.file_type, [self.clf_name] ) + # self.data_df = self.check_raw_dataset_integrity( + # df=self.data_df, logs_path=self.logs_path + # ) self.data_df_wo_cords = self.drop_bp_cords(df=self.data_df) - annotation_cols_to_remove = self.read_in_all_model_names_to_remove( - self.config, self.clf_cnt, self.clf_name - ) - self.x_y_df = self.delete_other_annotation_columns( - self.data_df_wo_cords, list(annotation_cols_to_remove) - ) + if self.data_df_wo_cords is None: + self.data_df_wo_cords = self.data_df + # annotation_cols_to_remove = self.read_in_all_model_names_to_remove( + # self.config, self.clf_cnt, self.clf_name + # ) + # self.x_y_df = self.delete_other_annotation_columns( + # self.data_df_wo_cords, list(annotation_cols_to_remove) + # ) self.class_names = ["Not_" + self.clf_name, self.clf_name] - self.x_df, self.y_df = self.split_df_to_x_y(self.x_y_df, self.clf_name) + self.x_df, self.y_df = self.split_df_to_x_y(self.data_df_wo_cords, self.clf_name) self.feature_names = self.x_df.columns self.check_sampled_dataset_integrity(x_df=self.x_df, y_df=self.y_df) print("Number of features in dataset: " + str(len(self.x_df.columns))) @@ -305,7 +307,7 @@ def train_model(self): ) if self.config.has_option( - ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, ConfigKey.CLASS_WEIGHTS.value + ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, ConfigKey.CLASS_WEIGHTS.value ): class_weights = read_config_entry( self.config, @@ -555,7 +557,6 @@ def save_model(self) -> None: 
             msg=f"Evaluation files are in models/generated_models/model_evaluations folders"
         )
 
-
 # test = TrainRandomForestClassifier(config_path='/Users/simon/Desktop/envs/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini')
 # test.perform_sampling()
 # test.train_model()
diff --git a/simba/roi_tools/ROI_multiply.py b/simba/roi_tools/ROI_multiply.py
index 7fc69fe67..4fff6132b 100644
--- a/simba/roi_tools/ROI_multiply.py
+++ b/simba/roi_tools/ROI_multiply.py
@@ -113,9 +113,9 @@ def multiply_ROIs(config_path, filename):
             vid_name,
             vid_name,
         )
-        r_df = r_df.append(duplicatedRec, ignore_index=True)
-        c_df = c_df.append(duplicatedCirc, ignore_index=True)
-        p_df = p_df.append(duplicatedPoly, ignore_index=True)
+        r_df = pd.concat([r_df, duplicatedRec], axis=0)  # r_df.append(duplicatedRec, ignore_index=True)
+        c_df = pd.concat([c_df, duplicatedCirc], axis=0)  # c_df.append(duplicatedCirc, ignore_index=True)
+        p_df = pd.concat([p_df, duplicatedPoly], axis=0)  # p_df.append(duplicatedPoly, ignore_index=True)
         r_df = r_df.drop_duplicates(subset=["Video", "Name"], keep="first")
         c_df = c_df.drop_duplicates(subset=["Video", "Name"], keep="first")
         p_df = p_df.drop_duplicates(subset=["Video", "Name"], keep="first")
diff --git a/simba/ui/pop_ups/validation_plot_pop_up.py b/simba/ui/pop_ups/validation_plot_pop_up.py
index 2ae3b81bf..ab39068d2 100644
--- a/simba/ui/pop_ups/validation_plot_pop_up.py
+++ b/simba/ui/pop_ups/validation_plot_pop_up.py
@@ -172,3 +172,4 @@ def run(self):
             create_gantt=self.gantt_dropdown.getChoices(),
         )
         validation_video_creator.run()
+        self.root.destroy()
diff --git a/simba/utils/data.py b/simba/utils/data.py
index d843b0b4f..4051a2bc8 100644
--- a/simba/utils/data.py
+++ b/simba/utils/data.py
@@ -528,7 +528,7 @@ def run_user_defined_feature_extraction_class(
     spec.loader.exec_module(user_module)
     user_class = getattr(user_module, class_name)
     print(f"Running user-defined {class_name} feature extraction file...")
-    user_class(config_path=config_path)
+    user_class(config_path=config_path).run()


 def slp_to_df_convert(

From 29c3f109e8c1b90680e46921a1e534a37210a73a Mon Sep 17 00:00:00 2001
From: tzuk polinsky
Date: Thu, 7 Dec 2023 15:30:36 +0200
Subject: [PATCH 09/13] changes:

1. added imbalanced random forest and the infrastructure for more machine
   learning algorithms that have the methods "fit" and "predict"
2. added 2 buttons that enable running a model on all the project data
3. added a method to detect and return the graphviz dot executable if it
   exists on the computer; tested on Linux
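
A minimal sketch of the classifier-selection idea introduced here (simplified
from the train_rf.py changes below; the helper name build_clf is illustrative
and not part of the codebase):

    from sklearn.ensemble import RandomForestClassifier
    from imblearn.ensemble import BalancedRandomForestClassifier

    def build_clf(algo: str, n_estimators: int, **kwargs):
        # Both estimators expose the same fit()/predict_proba() interface,
        # so the downstream training and evaluation code is unchanged.
        if algo == "RF":
            return RandomForestClassifier(n_estimators=n_estimators, n_jobs=-1, **kwargs)
        if algo == "imbalanced_rf":
            return BalancedRandomForestClassifier(n_estimators=n_estimators, n_jobs=-1, **kwargs)
        raise ValueError(f"Unknown algorithm: {algo}")

The graphviz detection added in train_model_mixin.py works in the same spirit:
it first checks shutil.which("dot"), then a couple of common Windows install
paths, and skips the example-tree rendering with an install hint if no dot
executable is found.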
--- simba/SimBA.py | 43 +- .../feature_extractor_freezing.py | 50 +- simba/mixins/train_model_mixin.py | 67 +- simba/model/grid_search_rf.py | 2 +- simba/model/inference_validation.py | 11 +- simba/model/train_rf.py | 623 +++++++++--------- simba/ui/pop_ups/validation_plot_pop_up.py | 37 +- simba/utils/enums.py | 2 +- 8 files changed, 464 insertions(+), 371 deletions(-) diff --git a/simba/SimBA.py b/simba/SimBA.py index bf6cd2fa1..d2f1bc774 100644 --- a/simba/SimBA.py +++ b/simba/SimBA.py @@ -841,7 +841,26 @@ def activate(box, *args): ) ).start(), ) - + label_run_model_on_all = label_model_validation = CreateLabelFrameWithIcon( + parent=tab9, + header="Run model on all the data", + icon_name=Keys.DOCUMENTATION.value, + icon_link=Links.OUT_OF_SAMPLE_VALIDATION.value, + ) + button_run_model_on_all = Button( + label_run_model_on_all, + text="RUN", + fg="red", + command=lambda: self.validate_model_first_step(), + ) + button_create_video_for_all = Button( + label_run_model_on_all, + text="CREATE VALIDATION VIDEOS", + fg="blue", + command=lambda: ValidationVideoPopUp( + config_path=config_path, simba_main_frm=self, run_on_all=True + ), + ) label_model_validation = CreateLabelFrameWithIcon( parent=tab9, header="VALIDATE MODEL ON SINGLE VIDEO", @@ -885,7 +904,7 @@ def activate(box, *args): text="CREATE VALIDATION VIDEO", fg="blue", command=lambda: ValidationVideoPopUp( - config_path=config_path, simba_main_frm=self + config_path=config_path, simba_main_frm=self,run_on_all=False ), ) @@ -1196,15 +1215,17 @@ def activate(box, *args): button_trainmachinesettings.grid(row=0, column=0, sticky=NW, padx=5) button_trainmachinemodel.grid(row=1, column=0, sticky=NW, padx=5) button_train_multimodel.grid(row=2, column=0, sticky=NW, padx=5) - - label_model_validation.grid(row=7, sticky=W, pady=5) - self.csvfile.grid(row=0, sticky=W) - self.modelfile.grid(row=1, sticky=W) - button_runvalidmodel.grid(row=2, sticky=W) - button_generateplot.grid(row=3, sticky=W) - self.dis_threshold.grid(row=4, sticky=W) - self.min_behaviorbout.grid(row=5, sticky=W) - button_validate_model.grid(row=6, sticky=W) + label_run_model_on_all.grid(row=0,sticky=W) + button_run_model_on_all.grid(row=0,column=0,sticky=W) + button_create_video_for_all.grid(row=0,column=1,sticky=W) + label_model_validation.grid(row=1, sticky=W, pady=5) + self.csvfile.grid(row=1, sticky=W) + self.modelfile.grid(row=2, sticky=W) + button_runvalidmodel.grid(row=3, sticky=W) + button_generateplot.grid(row=4, sticky=W) + self.dis_threshold.grid(row=5, sticky=W) + self.min_behaviorbout.grid(row=6, sticky=W) + button_validate_model.grid(row=7, sticky=W) label_runmachinemodel.grid(row=8, sticky=NW) button_run_rfmodelsettings.grid(row=0, sticky=NW) diff --git a/simba/feature_extractors/feature_extractor_freezing.py b/simba/feature_extractors/feature_extractor_freezing.py index 9ea8d2b2c..b33b7ee56 100644 --- a/simba/feature_extractors/feature_extractor_freezing.py +++ b/simba/feature_extractors/feature_extractor_freezing.py @@ -48,7 +48,8 @@ def angle_between_vectors(self, v1, v2): def calculate_direction_vector(self, from_point, to_point): return np.array(to_point) - np.array(from_point) - def extract_features(self, input_file_path: str, window_size: int, video_center: [int, int], pixel_mm: float,directionality_data:pd.DataFrame): + def extract_features(self, input_file_path: str, window_size: int, video_center: [int, int], pixel_mm: float, + directionality_data: pd.DataFrame): print("Calculating freezing features ...") input_data = pd.read_csv(input_file_path) @@ 
-59,16 +60,25 @@ def extract_features(self, input_file_path: str, window_size: int, video_center: without_bug = input_data.drop(columns_to_drop, axis=1) body_parts_diffs = without_bug.diff(axis=0) - body_parts_diffs["nose_x"]*=5 - body_parts_diffs["nose_y"]*=5 - time_point_diff = body_parts_diffs.sum(axis=1) + #second_time_point_diff = time_point_diff.diff() rolling_windows = time_point_diff.rolling(window=window_size, min_periods=1).sum() - output_data["activity"] = rolling_windows.abs().fillna(0) + output_data["activity"] = rolling_windows.abs().fillna(500) + bug_cols = [colName for colName in input_data.columns if ("bug" in colName) and ("_p") not in colName] center_cols = [colName for colName in without_bug.columns if ("center" in colName) and ("_p") not in colName] + #tails_cols = [colName for colName in without_bug.columns if ("tail" in colName) and ("_p") not in colName] nose_cols = [colName for colName in without_bug.columns if ("nose" in colName) and ("_p") not in colName] centers = without_bug[center_cols].to_numpy() + #tails = without_bug[tails_cols].to_numpy() noses = without_bug[nose_cols].to_numpy() + bug = input_data[bug_cols].to_numpy() + distances_from_bug = np.linalg.norm(bug - noses,axis=1) + video_centers = np.array([video_center]*len(centers)) + distances_from_center = np.linalg.norm(video_centers - noses,axis=1) + #body_size = np.insert(np.diff(np.linalg.norm(tails - noses,axis=1), axis=0),0,0) + output_data["distances_from_bug"] = pd.DataFrame(distances_from_bug).rolling(window=window_size, min_periods=1).mean().fillna(100).to_numpy() + output_data["distances_from_center"] = pd.DataFrame(distances_from_center).rolling(window=window_size, min_periods=1).mean().fillna(100).to_numpy() + #output_data["body_size"] = pd.DataFrame(body_size).rolling(window=window_size, min_periods=1).sum().abs().fillna(100).to_numpy() angles = [] for i, center in enumerate(centers): nose = noses[i] @@ -77,19 +87,19 @@ def extract_features(self, input_file_path: str, window_size: int, video_center: angles.append(self.angle_between_vectors(vector_center_to_nose, vector_fixed_to_center)) # output_data["nose_direction"] = angles angles_df = pd.DataFrame(angles) - # angles_diff = angles_df.diff() - # angles_diff_sum = angles_diff.rolling(window=window_size, min_periods=1).sum() - # output_data["nose_direction_sum_of_diffs"] = angles_diff_sum.abs().fillna(0) - output_data["nose_direction_avg"] = angles_df.rolling(window=window_size, min_periods=1).mean().fillna(0) - # directionality_rolling = directionality_data.rolling(window=window_size, min_periods=1) - # output_data["amount_of_looking_at_bug"] = directionality_rolling.sum() - # onsets = [-1] * len(output_data["amount_of_looking_at_bug"]) - # for j, rol in enumerate(directionality_rolling): - # for i, r in enumerate(rol): - # if r: - # onsets[j] = i - # break - # output_data["looking_at_bug_onset"] = onsets + angles_diff = angles_df.diff() + angles_diff_sum = angles_diff.rolling(window=window_size, min_periods=1).sum() + output_data["nose_direction_sum_of_diffs"] = angles_diff_sum.abs().fillna(0) + # output_data["nose_direction_avg"] = angles_df.rolling(window=window_size, min_periods=1).mean().fillna(0) + directionality_rolling = directionality_data.rolling(window=window_size, min_periods=1) + output_data["amount_of_looking_at_bug"] = directionality_rolling.sum().fillna(0) + onsets = [-1] * len(output_data["amount_of_looking_at_bug"]) + for j, rol in enumerate(directionality_rolling): + for i, r in enumerate(rol): + if r: + onsets[j] = i + 
break + output_data["looking_at_bug_onset"] = onsets return output_data def run(self): @@ -121,9 +131,9 @@ def run(self): video_settings, self.px_per_mm, fps = self.read_video_info( video_name=file_name ) - self.data_df = self.extract_features(file_path, 50, ( + self.data_df = self.extract_features(file_path, 25, ( current_polygon["Center_X"].values[0], current_polygon["Center_Y"].values[0]), - video_settings["pixels/mm"].values[0],directionality_data) + video_settings["pixels/mm"].values[0], directionality_data) save_path = os.path.join(self.save_dir, file_name + "." + self.file_type) self.data_df = self.data_df.reset_index(drop=True).fillna(0) write_df(df=self.data_df, file_type=self.file_type, save_path=save_path) diff --git a/simba/mixins/train_model_mixin.py b/simba/mixins/train_model_mixin.py index 96d0cefa7..35288733c 100644 --- a/simba/mixins/train_model_mixin.py +++ b/simba/mixins/train_model_mixin.py @@ -1,5 +1,7 @@ __author__ = "Simon Nilsson" +import shutil +import subprocess import warnings warnings.simplefilter(action="ignore", category=FutureWarning) @@ -195,7 +197,7 @@ def delete_other_annotation_columns( return df def split_df_to_x_y( - self, df: pd.DataFrame, clf_name: str + self, df: pd.DataFrame, col_names: [str] ) -> (pd.DataFrame, pd.DataFrame): """ Helper to split dataframe into features and target. @@ -207,12 +209,16 @@ def split_df_to_x_y( :return pd.DataFrame: target :examples: - >>> self.split_df_to_x_y(df=df, clf_name='Attack') + >>> self.split_df_to_x_y(df=df, col_names='Attack') """ df = deepcopy(df) - y = df.pop(clf_name) - return df, y + ys = np.array([0]*len(df.index)) + for i,col_name in enumerate(col_names): + y = df.pop(col_name) + ys[y == 1] = i+1 + + return df, pd.DataFrame(ys) def random_undersampler( self, x_train: np.ndarray, y_train: np.ndarray, sample_ratio: float @@ -248,7 +254,7 @@ def random_undersampler( data_df = pd.concat( [present_df, absent_df.sample(n=ratio_n, replace=False)], axis=0 ) - return self.split_df_to_x_y(data_df, y_train.name) + return self.split_df_to_x_y(data_df, [y_train.name]) def smoteen_oversampler( self, x_train: pd.DataFrame, y_train: pd.DataFrame, sample_ratio: float @@ -385,7 +391,7 @@ def calc_learning_curve( print("Calculating learning curves...") timer = SimbaTimer(start=True) - x_df, y_df = self.split_df_to_x_y(x_y_df, clf_name) + x_df, y_df = self.split_df_to_x_y(x_y_df, [clf_name]) cv = ShuffleSplit(n_splits=shuffle_splits, test_size=tt_size) if platform.system() == "Darwin": with parallel_backend("threading", n_jobs=-2): @@ -510,8 +516,11 @@ def create_example_dt( :parameter Optional[int] save_file_no: If integer, represents the count of the classifier within a grid search. If none, the classifier is not part of a grid search. 
""" - - print("Visualizing example decision tree using graphviz...") + dot_path = self.find_graphviz_dot() + if not dot_path: + print("please install graphviz using the following link: https://graphviz.org/download/") + return + print(f"Visualizing example decision tree using graphviz using {dot_path}") estimator = rf_clf.estimators_[3] if save_file_no != None: dot_name = os.path.join( @@ -522,7 +531,7 @@ def create_example_dt( ) else: dot_name = os.path.join(save_dir, str(clf_name) + "_tree.dot") - file_name = os.path.join(save_dir, str(clf_name) + "_tree.pdf") + file_name = os.path.join(save_dir, str(clf_name) + "_tree.png") export_graphviz( estimator, out_file=dot_name, @@ -533,8 +542,24 @@ def create_example_dt( class_names=class_names, feature_names=feature_names, ) - command = "dot " + str(dot_name) + " -T pdf -o " + str(file_name) + " -Gdpi=600" - call(command, shell=True) + subprocess.run([dot_path, '-Tpng', dot_name, '-o', file_name]) + + def find_graphviz_dot(self): + # Check if dot is in PATH + dot_path = shutil.which("dot") + if dot_path: + return dot_path + + common_paths = [ + r"C:\Program Files\Graphviz\bin\dot.exe", + r"C:\Program Files (x86)\Graphviz\bin\dot.exe" + ] + + # Check common paths + for path in common_paths: + if os.path.isfile(path): + return path + return None def create_clf_report( self, @@ -1199,14 +1224,14 @@ def check_sampled_dataset_integrity( f"column names with mismatches are: {list(x_nan_cnt.index)[0:9]}", source=self.__class__.__name__, ) - - if len(y_df.unique()) == 1: - if y_df.unique()[0] == 0: + labels = np.unique(y_df) + if len(labels)== 1: + if labels[0] == 0: raise FaultyTrainingSetError( msg=f"All training annotations for classifier {str(y_df.name)} is labelled as ABSENT. A classifier has be be trained with both behavior PRESENT and ABSENT ANNOTATIONS.", source=self.__class__.__name__, ) - if y_df.unique()[0] == 1: + if labels[0] == 1: raise FaultyTrainingSetError( msg=f"All training annotations for classifier {str(y_df.name)} is labelled as PRESENT. A classifier has be be trained with both behavior PRESENT and ABSENT ANNOTATIONS.", source=self.__class__.__name__, @@ -1292,11 +1317,11 @@ def clf_predict_proba( source=self.__class__.__name__, ) p_vals = clf.predict_proba(x_df) - if p_vals.shape[1] != 2: - raise ClassifierInferenceError( - msg=f"The classifier {model_name} (data path {data_path}) has not been created properly. See The SimBA GitHub FAQ page or Gitter for more information and suggested fixes. The classifier is not a binary classifier and does not predict two targets (absence and presence of behavior)", - source=self.__class__.__name__, - ) + # if p_vals.shape[1] != 2: + # raise ClassifierInferenceError( + # msg=f"The classifier {model_name} (data path {data_path}) has not been created properly. See The SimBA GitHub FAQ page or Gitter for more information and suggested fixes. 
The classifier is not a binary classifier and does not predict two targets (absence and presence of behavior)", + # source=self.__class__.__name__, + # ) return p_vals[:, 1] def clf_fit( @@ -1311,7 +1336,7 @@ def clf_fit( :return RandomForestClassifier: Fitted random forest classifier object """ nan_features = x_df[~x_df.applymap(np.isreal).all(1)] - nan_target = y_df.loc[pd.to_numeric(y_df).isna()] + nan_target = y_df[y_df.isna().to_numpy()] if len(nan_features) > 0: raise FaultyTrainingSetError( msg=f"{len(nan_features)} frame(s) in your project_folder/csv/targets_inserted directory contains FEATURES with non-numerical values", diff --git a/simba/model/grid_search_rf.py b/simba/model/grid_search_rf.py index af86f6d25..eb0b071ac 100644 --- a/simba/model/grid_search_rf.py +++ b/simba/model/grid_search_rf.py @@ -416,7 +416,7 @@ def run(self): self.data_df, annotation_cols_to_remove ) self.x_df, self.y_df = self.split_df_to_x_y( - self.x_y_df, meta_dict[MetaKeys.CLF_NAME.value] + self.x_y_df, [meta_dict[MetaKeys.CLF_NAME.value]] ) self.feature_names = self.x_df.columns self.check_sampled_dataset_integrity(x_df=self.x_df, y_df=self.y_df) diff --git a/simba/model/inference_validation.py b/simba/model/inference_validation.py index f898e2467..1ee717e3c 100644 --- a/simba/model/inference_validation.py +++ b/simba/model/inference_validation.py @@ -33,10 +33,10 @@ class InferenceValidation(ConfigReader, TrainModelMixin): """ def __init__( - self, - config_path: Union[str, os.PathLike], - input_file_path: Union[str, os.PathLike], - clf_path: Union[str, os.PathLike], + self, + config_path: Union[str, os.PathLike], + input_file_path: Union[str, os.PathLike], + clf_path: Union[str, os.PathLike], ): ConfigReader.__init__(self, config_path=config_path) TrainModelMixin.__init__(self) @@ -63,14 +63,13 @@ def __init__( self.timer.stop_timer() stdout_success( - msg=f'Validation predictions generated for "{file_name}" within the project_folder/csv/validation directory', + msg=f'Validation predictions generated for "{file_name}" saved in {save_filename}', elapsed_time=self.timer.elapsed_time_str, ) print( 'Click on "Interactive probability plot" to inspect classifier probability thresholds. If satisfactory proceed to specify threshold and minimum bout length and click on "Validate" to create video.' 
) - # # ValidateModelRunClf(config_path=r"Z:\DeepLabCut\DLC_extract\Troubleshooting\DLC_two_mice\project_folder\project_config.ini", # input_file_path=r"Z:\DeepLabCut\DLC_extract\Troubleshooting\DLC_2_black_060320\project_folder\csv\features_extracted\Together_1.csv", diff --git a/simba/model/train_rf.py b/simba/model/train_rf.py index 9b378cd3c..a58e4817c 100644 --- a/simba/model/train_rf.py +++ b/simba/model/train_rf.py @@ -14,6 +14,7 @@ from simba.utils.enums import ConfigKey, Dtypes, Methods, Options from simba.utils.printing import SimbaTimer, stdout_success from simba.utils.read_write import read_config_entry +from imblearn.ensemble import BalancedRandomForestClassifier class TrainRandomForestClassifier(ConfigReader, TrainModelMixin): @@ -132,8 +133,12 @@ def __init__(self, config_path: Union[str, os.PathLike]): print( "Reading in {} annotated files...".format(str(len(self.target_file_paths))) ) + annotation_cols = self.read_in_all_model_names_to_remove( + self.config, self.clf_cnt, self.clf_name + ) + cls = [self.clf_name] + annotation_cols self.data_df = self.read_and_concatenate_all_files_in_folder_mp_futures( - self.target_file_paths, self.features_dir, self.file_type, [self.clf_name] + self.target_file_paths, self.features_dir, self.file_type, cls ) # self.data_df = self.check_raw_dataset_integrity( # df=self.data_df, logs_path=self.logs_path @@ -141,22 +146,18 @@ def __init__(self, config_path: Union[str, os.PathLike]): self.data_df_wo_cords = self.drop_bp_cords(df=self.data_df) if self.data_df_wo_cords is None: self.data_df_wo_cords = self.data_df - # annotation_cols_to_remove = self.read_in_all_model_names_to_remove( - # self.config, self.clf_cnt, self.clf_name - # ) - # self.x_y_df = self.delete_other_annotation_columns( - # self.data_df_wo_cords, list(annotation_cols_to_remove) - # ) - self.class_names = ["Not_" + self.clf_name, self.clf_name] - self.x_df, self.y_df = self.split_df_to_x_y(self.data_df_wo_cords, self.clf_name) + + self.class_names = ["Not_" + self.clf_name] + cls + + self.x_df, self.y_df = self.split_df_to_x_y(self.data_df_wo_cords, cls) self.feature_names = self.x_df.columns self.check_sampled_dataset_integrity(x_df=self.x_df, y_df=self.y_df) print("Number of features in dataset: " + str(len(self.x_df.columns))) print( "Number of {} frames in dataset: {} ({}%)".format( self.clf_name, - str(self.y_df.sum()), - str(round(self.y_df.sum() / len(self.y_df), 4) * 100), + str(self.y_df[self.y_df == (cls.index(self.clf_name)+1)].sum()), + str(round(self.y_df[self.y_df == (cls.index(self.clf_name)+1)].sum() / len(self.y_df[self.y_df == (cls.index(self.clf_name)+1)]), 4) * 100), ) ) print("Training and evaluating model...") @@ -199,215 +200,214 @@ def train_model(self): """ Method for training single random forest model. 
""" + n_estimators = read_config_entry( + self.config, + ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, + ConfigKey.RF_ESTIMATORS.value, + data_type=Dtypes.INT.value, + ) + max_features = read_config_entry( + self.config, + ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, + ConfigKey.RF_MAX_FEATURES.value, + data_type=Dtypes.STR.value, + ) + if max_features == "None": + max_features = None + criterion = read_config_entry( + self.config, + ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, + ConfigKey.RF_CRITERION.value, + data_type=Dtypes.STR.value, + options=Options.CLF_CRITERION.value, + ) + min_sample_leaf = read_config_entry( + self.config, + ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, + ConfigKey.MIN_LEAF.value, + data_type=Dtypes.INT.value, + ) + compute_permutation_importance = read_config_entry( + self.config, + ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, + ConfigKey.PERMUTATION_IMPORTANCE.value, + data_type=Dtypes.STR.value, + default_value=False, + ) + generate_learning_curve = read_config_entry( + self.config, + ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, + ConfigKey.LEARNING_CURVE.value, + data_type=Dtypes.STR.value, + default_value=False, + ) + generate_precision_recall_curve = read_config_entry( + self.config, + ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, + ConfigKey.PRECISION_RECALL.value, + data_type=Dtypes.STR.value, + default_value=False, + ) + generate_example_decision_tree = read_config_entry( + self.config, + ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, + ConfigKey.EX_DECISION_TREE.value, + data_type=Dtypes.STR.value, + default_value=False, + ) + generate_classification_report = read_config_entry( + self.config, + ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, + ConfigKey.CLF_REPORT.value, + data_type=Dtypes.STR.value, + default_value=False, + ) + generate_features_importance_log = read_config_entry( + self.config, + ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, + ConfigKey.IMPORTANCE_LOG.value, + data_type=Dtypes.STR.value, + default_value=False, + ) + generate_features_importance_bar_graph = read_config_entry( + self.config, + ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, + ConfigKey.IMPORTANCE_LOG.value, + data_type=Dtypes.STR.value, + default_value=False, + ) + generate_example_decision_tree_fancy = read_config_entry( + self.config, + ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, + ConfigKey.EX_DECISION_TREE_FANCY.value, + data_type=Dtypes.STR.value, + default_value=False, + ) + generate_shap_scores = read_config_entry( + self.config, + ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, + ConfigKey.SHAP_SCORES.value, + data_type=Dtypes.STR.value, + default_value=False, + ) + save_meta_data = read_config_entry( + self.config, + ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, + ConfigKey.RF_METADATA.value, + data_type=Dtypes.STR.value, + default_value=False, + ) + compute_partial_dependency = read_config_entry( + self.config, + ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, + ConfigKey.PARTIAL_DEPENDENCY.value, + data_type=Dtypes.STR.value, + default_value=False, + ) - if self.algo == "RF": - n_estimators = read_config_entry( - self.config, - ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.RF_ESTIMATORS.value, - data_type=Dtypes.INT.value, - ) - max_features = read_config_entry( - self.config, - ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.RF_MAX_FEATURES.value, - data_type=Dtypes.STR.value, - ) - if max_features == "None": - max_features = None - criterion = read_config_entry( + if self.config.has_option( + ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, ConfigKey.CLASS_WEIGHTS.value + ): + class_weights = read_config_entry( 
self.config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.RF_CRITERION.value, + ConfigKey.CLASS_WEIGHTS.value, data_type=Dtypes.STR.value, - options=Options.CLF_CRITERION.value, + default_value=Dtypes.NONE.value, ) - min_sample_leaf = read_config_entry( + if class_weights == "custom": + class_weights = ast.literal_eval( + read_config_entry( + self.config, + ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, + ConfigKey.CUSTOM_WEIGHTS.value, + data_type=Dtypes.STR.value, + ) + ) + for k, v in class_weights.items(): + class_weights[k] = int(v) + if class_weights == Dtypes.NONE.value: + class_weights = None + else: + class_weights = None + + if generate_learning_curve in Options.PERFORM_FLAGS.value: + shuffle_splits = read_config_entry( self.config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.MIN_LEAF.value, + ConfigKey.LEARNING_CURVE_K_SPLITS.value, data_type=Dtypes.INT.value, + default_value=Dtypes.NAN.value, ) - compute_permutation_importance = read_config_entry( + dataset_splits = read_config_entry( self.config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.PERMUTATION_IMPORTANCE.value, - data_type=Dtypes.STR.value, - default_value=False, + ConfigKey.LEARNING_DATA_SPLITS.value, + data_type=Dtypes.INT.value, + default_value=Dtypes.NAN.value, ) - generate_learning_curve = read_config_entry( - self.config, - ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.LEARNING_CURVE.value, - data_type=Dtypes.STR.value, - default_value=False, + check_int( + name=ConfigKey.LEARNING_CURVE_K_SPLITS.value, value=shuffle_splits ) - generate_precision_recall_curve = read_config_entry( - self.config, - ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.PRECISION_RECALL.value, - data_type=Dtypes.STR.value, - default_value=False, + check_int( + name=ConfigKey.LEARNING_DATA_SPLITS.value, value=dataset_splits ) - generate_example_decision_tree = read_config_entry( - self.config, - ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.EX_DECISION_TREE.value, - data_type=Dtypes.STR.value, - default_value=False, - ) - generate_classification_report = read_config_entry( + else: + shuffle_splits, dataset_splits = Dtypes.NAN.value, Dtypes.NAN.value + if generate_features_importance_bar_graph in Options.PERFORM_FLAGS.value: + feature_importance_bars = read_config_entry( self.config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.CLF_REPORT.value, - data_type=Dtypes.STR.value, - default_value=False, + ConfigKey.IMPORTANCE_BARS_N.value, + Dtypes.INT.value, + Dtypes.NAN.value, ) - generate_features_importance_log = read_config_entry( - self.config, - ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.IMPORTANCE_LOG.value, - data_type=Dtypes.STR.value, - default_value=False, + check_int( + name=ConfigKey.IMPORTANCE_BARS_N.value, + value=feature_importance_bars, + min_value=1, ) - generate_features_importance_bar_graph = read_config_entry( + else: + feature_importance_bars = Dtypes.NAN.value + shap_target_present_cnt, shap_target_absent_cnt, shap_save_n = ( + None, + None, + None, + ) + if generate_shap_scores in Options.PERFORM_FLAGS.value: + shap_target_present_cnt = read_config_entry( self.config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.IMPORTANCE_LOG.value, - data_type=Dtypes.STR.value, - default_value=False, + ConfigKey.SHAP_PRESENT.value, + data_type=Dtypes.INT.value, + default_value=0, ) - generate_example_decision_tree_fancy = read_config_entry( + shap_target_absent_cnt = read_config_entry( self.config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - 
ConfigKey.EX_DECISION_TREE_FANCY.value, - data_type=Dtypes.STR.value, - default_value=False, + ConfigKey.SHAP_ABSENT.value, + data_type=Dtypes.INT.value, + default_value=0, ) - generate_shap_scores = read_config_entry( + shap_save_n = read_config_entry( self.config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.SHAP_SCORES.value, + ConfigKey.SHAP_SAVE_ITERATION.value, data_type=Dtypes.STR.value, - default_value=False, + default_value=Dtypes.NONE.value, ) - save_meta_data = read_config_entry( - self.config, - ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.RF_METADATA.value, - data_type=Dtypes.STR.value, - default_value=False, + try: + shap_save_n = int(shap_save_n) + except ValueError: + shap_save_n = shap_target_present_cnt + shap_target_absent_cnt + check_int( + name=ConfigKey.SHAP_PRESENT.value, value=shap_target_present_cnt ) - compute_partial_dependency = read_config_entry( - self.config, - ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.PARTIAL_DEPENDENCY.value, - data_type=Dtypes.STR.value, - default_value=False, + check_int( + name=ConfigKey.SHAP_ABSENT.value, value=shap_target_absent_cnt ) - - if self.config.has_option( - ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, ConfigKey.CLASS_WEIGHTS.value - ): - class_weights = read_config_entry( - self.config, - ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.CLASS_WEIGHTS.value, - data_type=Dtypes.STR.value, - default_value=Dtypes.NONE.value, - ) - if class_weights == "custom": - class_weights = ast.literal_eval( - read_config_entry( - self.config, - ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.CUSTOM_WEIGHTS.value, - data_type=Dtypes.STR.value, - ) - ) - for k, v in class_weights.items(): - class_weights[k] = int(v) - if class_weights == Dtypes.NONE.value: - class_weights = None - else: - class_weights = None - - if generate_learning_curve in Options.PERFORM_FLAGS.value: - shuffle_splits = read_config_entry( - self.config, - ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.LEARNING_CURVE_K_SPLITS.value, - data_type=Dtypes.INT.value, - default_value=Dtypes.NAN.value, - ) - dataset_splits = read_config_entry( - self.config, - ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.LEARNING_DATA_SPLITS.value, - data_type=Dtypes.INT.value, - default_value=Dtypes.NAN.value, - ) - check_int( - name=ConfigKey.LEARNING_CURVE_K_SPLITS.value, value=shuffle_splits - ) - check_int( - name=ConfigKey.LEARNING_DATA_SPLITS.value, value=dataset_splits - ) - else: - shuffle_splits, dataset_splits = Dtypes.NAN.value, Dtypes.NAN.value - if generate_features_importance_bar_graph in Options.PERFORM_FLAGS.value: - feature_importance_bars = read_config_entry( - self.config, - ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.IMPORTANCE_BARS_N.value, - Dtypes.INT.value, - Dtypes.NAN.value, - ) - check_int( - name=ConfigKey.IMPORTANCE_BARS_N.value, - value=feature_importance_bars, - min_value=1, - ) - else: - feature_importance_bars = Dtypes.NAN.value - shap_target_present_cnt, shap_target_absent_cnt, shap_save_n = ( - None, - None, - None, - ) - if generate_shap_scores in Options.PERFORM_FLAGS.value: - shap_target_present_cnt = read_config_entry( - self.config, - ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.SHAP_PRESENT.value, - data_type=Dtypes.INT.value, - default_value=0, - ) - shap_target_absent_cnt = read_config_entry( - self.config, - ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.SHAP_ABSENT.value, - data_type=Dtypes.INT.value, - default_value=0, - ) - shap_save_n = read_config_entry( - self.config, - 
ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.SHAP_SAVE_ITERATION.value, - data_type=Dtypes.STR.value, - default_value=Dtypes.NONE.value, - ) - try: - shap_save_n = int(shap_save_n) - except ValueError: - shap_save_n = shap_target_present_cnt + shap_target_absent_cnt - check_int( - name=ConfigKey.SHAP_PRESENT.value, value=shap_target_present_cnt - ) - check_int( - name=ConfigKey.SHAP_ABSENT.value, value=shap_target_absent_cnt - ) - + print(f"Fitting {self.clf_name} model...") + if self.algo == "RF": self.rf_clf = RandomForestClassifier( n_estimators=n_estimators, max_features=max_features, @@ -419,127 +419,140 @@ def train_model(self): class_weight=class_weights, ) - print(f"Fitting {self.clf_name} model...") self.rf_clf = self.clf_fit( clf=self.rf_clf, x_df=self.x_train, y_df=self.y_train ) + elif self.algo == "imbalanced_rf": + self.rf_clf = BalancedRandomForestClassifier( + n_estimators=n_estimators, + max_features=max_features, + max_depth=7, + n_jobs=-1, + criterion=criterion, + min_samples_leaf=min_sample_leaf, + bootstrap=True, + verbose=1, + class_weight=class_weights, + ) + self.rf_clf = self.clf_fit( + clf=self.rf_clf, x_df=self.x_train, y_df=self.y_train + ) + if compute_permutation_importance in Options.PERFORM_FLAGS.value: + self.calc_permutation_importance( + self.x_test, + self.y_test, + self.rf_clf, + self.feature_names, + self.clf_name, + self.eval_out_path, + ) + if generate_learning_curve in Options.PERFORM_FLAGS.value: + self.calc_learning_curve( + x_y_df=self.x_y_df, + clf_name=self.clf_name, + shuffle_splits=shuffle_splits, + dataset_splits=dataset_splits, + tt_size=self.tt_size, + rf_clf=self.rf_clf, + save_dir=self.eval_out_path, + ) - if compute_permutation_importance in Options.PERFORM_FLAGS.value: - self.calc_permutation_importance( - self.x_test, - self.y_test, - self.rf_clf, - self.feature_names, - self.clf_name, - self.eval_out_path, - ) - if generate_learning_curve in Options.PERFORM_FLAGS.value: - self.calc_learning_curve( - x_y_df=self.x_y_df, - clf_name=self.clf_name, - shuffle_splits=shuffle_splits, - dataset_splits=dataset_splits, - tt_size=self.tt_size, - rf_clf=self.rf_clf, - save_dir=self.eval_out_path, - ) - - if generate_precision_recall_curve in Options.PERFORM_FLAGS.value: - self.calc_pr_curve( - self.rf_clf, - self.x_test, - self.y_test, - self.clf_name, - self.eval_out_path, - ) - if generate_example_decision_tree in Options.PERFORM_FLAGS.value: - self.create_example_dt( - self.rf_clf, - self.clf_name, - self.feature_names, - self.class_names, - self.eval_out_path, - ) - if generate_classification_report in Options.PERFORM_FLAGS.value: - self.create_clf_report( - self.rf_clf, - self.x_test, - self.y_test, - self.class_names, - self.eval_out_path, - ) - if generate_features_importance_log in Options.PERFORM_FLAGS.value: - self.create_x_importance_log( - self.rf_clf, self.feature_names, self.clf_name, self.eval_out_path - ) - if generate_features_importance_bar_graph in Options.PERFORM_FLAGS.value: - self.create_x_importance_bar_chart( - self.rf_clf, - self.feature_names, - self.clf_name, - self.eval_out_path, - feature_importance_bars, - ) - if generate_example_decision_tree_fancy in Options.PERFORM_FLAGS.value: - self.dviz_classification_visualization( - self.x_train, - self.y_train, - self.clf_name, - self.class_names, - self.eval_out_path, - ) - if generate_shap_scores in Options.PERFORM_FLAGS.value: - self.create_shap_log_mp( - ini_file_path=self.config_path, - rf_clf=self.rf_clf, - x_df=self.x_train, - y_df=self.y_train, - 
x_names=self.feature_names, - clf_name=self.clf_name, - cnt_present=shap_target_present_cnt, - cnt_absent=shap_target_absent_cnt, - save_it=shap_save_n, - save_path=self.eval_out_path, - ) - - if compute_partial_dependency in Options.PERFORM_FLAGS.value: - self.partial_dependence_calculator( - clf=self.rf_clf, - x_df=self.x_train, - clf_name=self.clf_name, - save_dir=self.eval_out_path, - ) + if generate_precision_recall_curve in Options.PERFORM_FLAGS.value: + self.calc_pr_curve( + self.rf_clf, + self.x_test, + self.y_test, + self.clf_name, + self.eval_out_path, + ) + if generate_example_decision_tree in Options.PERFORM_FLAGS.value: + self.create_example_dt( + self.rf_clf, + self.clf_name, + self.feature_names, + self.class_names, + self.eval_out_path, + ) + if generate_classification_report in Options.PERFORM_FLAGS.value: + self.create_clf_report( + self.rf_clf, + self.x_test, + self.y_test, + self.class_names, + self.eval_out_path, + ) + if generate_features_importance_log in Options.PERFORM_FLAGS.value: + self.create_x_importance_log( + self.rf_clf, self.feature_names, self.clf_name, self.eval_out_path + ) + if generate_features_importance_bar_graph in Options.PERFORM_FLAGS.value: + self.create_x_importance_bar_chart( + self.rf_clf, + self.feature_names, + self.clf_name, + self.eval_out_path, + feature_importance_bars, + ) + if generate_example_decision_tree_fancy in Options.PERFORM_FLAGS.value: + self.dviz_classification_visualization( + self.x_train, + self.y_train, + self.clf_name, + self.class_names, + self.eval_out_path, + ) + if generate_shap_scores in Options.PERFORM_FLAGS.value: + self.create_shap_log_mp( + ini_file_path=self.config_path, + rf_clf=self.rf_clf, + x_df=self.x_train, + y_df=self.y_train, + x_names=self.feature_names, + clf_name=self.clf_name, + cnt_present=shap_target_present_cnt, + cnt_absent=shap_target_absent_cnt, + save_it=shap_save_n, + save_path=self.eval_out_path, + ) - if save_meta_data in Options.PERFORM_FLAGS.value: - meta_data_lst = [ - self.clf_name, - criterion, - max_features, - min_sample_leaf, - n_estimators, - compute_permutation_importance, - generate_classification_report, - generate_example_decision_tree, - generate_features_importance_bar_graph, - generate_features_importance_log, - generate_precision_recall_curve, - save_meta_data, - generate_learning_curve, - dataset_splits, - shuffle_splits, - feature_importance_bars, - self.over_sample_ratio, - self.over_sample_setting, - self.tt_size, - self.split_type, - self.under_sample_ratio, - self.under_sample_setting, - str(class_weights), - ] + if compute_partial_dependency in Options.PERFORM_FLAGS.value: + self.partial_dependence_calculator( + clf=self.rf_clf, + x_df=self.x_train, + clf_name=self.clf_name, + save_dir=self.eval_out_path, + ) - self.create_meta_data_csv_training_one_model( - meta_data_lst, self.clf_name, self.eval_out_path - ) + if save_meta_data in Options.PERFORM_FLAGS.value: + meta_data_lst = [ + self.clf_name, + criterion, + max_features, + min_sample_leaf, + n_estimators, + compute_permutation_importance, + generate_classification_report, + generate_example_decision_tree, + generate_features_importance_bar_graph, + generate_features_importance_log, + generate_precision_recall_curve, + save_meta_data, + generate_learning_curve, + dataset_splits, + shuffle_splits, + feature_importance_bars, + self.over_sample_ratio, + self.over_sample_setting, + self.tt_size, + self.split_type, + self.under_sample_ratio, + self.under_sample_setting, + str(class_weights), + ] + + 
self.create_meta_data_csv_training_one_model( + meta_data_lst, self.clf_name, self.eval_out_path + ) def save_model(self) -> None: """ diff --git a/simba/ui/pop_ups/validation_plot_pop_up.py b/simba/ui/pop_ups/validation_plot_pop_up.py index ab39068d2..0c9b73ba7 100644 --- a/simba/ui/pop_ups/validation_plot_pop_up.py +++ b/simba/ui/pop_ups/validation_plot_pop_up.py @@ -1,5 +1,6 @@ __author__ = "Simon Nilsson" +import os from tkinter import * from simba.mixins.config_reader import ConfigReader @@ -16,7 +17,7 @@ class ValidationVideoPopUp(PopUpMixin, ConfigReader): - def __init__(self, config_path: str, simba_main_frm: object): + def __init__(self, config_path: str, simba_main_frm: object, run_on_all=False): PopUpMixin.__init__(self, title="CREATE VALIDATION VIDEO") ConfigReader.__init__(self, config_path=config_path) self.feature_file_path = simba_main_frm.csvfile.file_path @@ -119,11 +120,14 @@ def __init__(self, config_path: str, simba_main_frm: object): self.multiprocess_dropdown.grid(row=0, column=1, sticky=NW) gantt_frm.grid(row=3, column=0, sticky=NW) self.gantt_dropdown.grid(row=0, column=0, sticky=NW) + if run_on_all: + self.create_run_frm(run_function=self.run_on_all) - self.create_run_frm(run_function=self.run) + else: + self.create_run_frm(run_function=self.run) self.main_frm.mainloop() - def run(self): + def create_style_dict(self) -> dict: settings = { "pose": str_2_bool(self.show_pose_dropdown.getChoices()), "animal_names": str_2_bool(self.show_animal_names_dropdown.getChoices()), @@ -136,23 +140,33 @@ def run(self): settings["styles"]["circle size"] = int(self.circle_size.entry_get) settings["styles"]["font size"] = self.font_size_eb.entry_get settings["styles"]["space_scale"] = int(self.spacing_eb.entry_get) + return settings + + def get_short_bout(self): try: self.shortest_bout = int(self.shortest_bout) except ValueError as e: print(e.args[1]) return + def run_on_all(self): + settings = self.create_style_dict() + self.get_short_bout() # check_float(name="MINIMUM BOUT LENGTH", value=self.shortest_bout) check_float( name="DISCRIMINATION THRESHOLD", value=self.discrimination_threshold ) - check_file_exist_and_readable(file_path=self.feature_file_path) check_file_exist_and_readable(file_path=self.model_path) + for root, dir, files in os.walk(self.features_dir): + for file in files: + self.create_video(os.path.join(root, file), settings) + self.root.destroy() + def create_video(self, feature_file: str, settings: dict): if not self.multiprocess_var.get(): validation_video_creator = ValidateModelOneVideo( config_path=self.config_path, - feature_file_path=self.feature_file_path, + feature_file_path=feature_file, model_path=self.model_path, discrimination_threshold=float(self.discrimination_threshold), shortest_bout=int(self.shortest_bout), @@ -163,7 +177,7 @@ def run(self): else: validation_video_creator = ValidateModelOneVideoMultiprocess( config_path=self.config_path, - feature_file_path=self.feature_file_path, + feature_file_path=feature_file, model_path=self.model_path, discrimination_threshold=float(self.discrimination_threshold), shortest_bout=int(self.shortest_bout), @@ -172,4 +186,15 @@ def run(self): create_gantt=self.gantt_dropdown.getChoices(), ) validation_video_creator.run() + + def run(self): + settings = self.create_style_dict() + self.get_short_bout() + # check_float(name="MINIMUM BOUT LENGTH", value=self.shortest_bout) + check_float( + name="DISCRIMINATION THRESHOLD", value=self.discrimination_threshold + ) + 
check_file_exist_and_readable(file_path=self.feature_file_path) + check_file_exist_and_readable(file_path=self.model_path) + self.create_video(self.feature_file_path, settings) self.root.destroy() diff --git a/simba/utils/enums.py b/simba/utils/enums.py index 2a925e9db..487d2cb66 100644 --- a/simba/utils/enums.py +++ b/simba/utils/enums.py @@ -170,7 +170,7 @@ class Formats(Enum): class Options(Enum): ROLLING_WINDOW_DIVISORS = [2, 5, 6, 7.5, 15] - CLF_MODELS = ["RF", "GBC", "XGBoost"] + CLF_MODELS = ["RF", "imbalanced_rf", "GBC", "XGBoost"] CLF_MAX_FEATURES = ["sqrt", "log", "None"] CLF_CRITERION = ["gini", "entropy"] UNDERSAMPLE_OPTIONS = ["None", "random undersample"] From 8d4888c8492bd6a88c8020f01b042e953951d200 Mon Sep 17 00:00:00 2001 From: tzuk polinsky Date: Thu, 21 Dec 2023 14:59:28 +0200 Subject: [PATCH 10/13] added max depth and cleaned duplicated values in enums and moved all the machine learning vals to a single enum --- simba/mixins/train_model_mixin.py | 20 +-- simba/model/grid_search_rf.py | 246 +++++++++++++------------- simba/model/train_rf.py | 92 +++++----- simba/ui/machine_model_settings_ui.py | 166 ++++++++++------- simba/utils/config_creator.py | 67 +++---- simba/utils/enums.py | 52 ++---- 6 files changed, 338 insertions(+), 305 deletions(-) diff --git a/simba/mixins/train_model_mixin.py b/simba/mixins/train_model_mixin.py index 35288733c..0f4034a6d 100644 --- a/simba/mixins/train_model_mixin.py +++ b/simba/mixins/train_model_mixin.py @@ -53,7 +53,7 @@ from simba.utils.checks import (check_float, check_if_dir_exists, check_int, check_str) from simba.utils.data import create_color_palette, detect_bouts -from simba.utils.enums import ConfigKey, Defaults, Dtypes, MetaKeys +from simba.utils.enums import ConfigKey, Defaults, Dtypes, MachineLearningMetaKeys from simba.utils.errors import (ClassifierInferenceError, ColumnNotFoundError, CorruptedFileError, DataHeaderError, FaultyTrainingSetError, @@ -921,16 +921,16 @@ def print_machine_model_information(self, model_dict: dict) -> None: """ table_view = [ - ["Model name", model_dict[MetaKeys.CLF_NAME.value]], + ["Model name", model_dict[MachineLearningMetaKeys.CLASSIFIER.value]], ["Ensemble method", "RF"], - ["Estimators (trees)", model_dict[MetaKeys.RF_ESTIMATORS.value]], - ["Max features", model_dict[MetaKeys.RF_MAX_FEATURES.value]], - ["Under sampling setting", model_dict[ConfigKey.UNDERSAMPLE_SETTING.value]], - ["Under sampling ratio", model_dict[ConfigKey.UNDERSAMPLE_RATIO.value]], - ["Over sampling setting", model_dict[ConfigKey.OVERSAMPLE_SETTING.value]], - ["Over sampling ratio", model_dict[ConfigKey.OVERSAMPLE_RATIO.value]], - ["criterion", model_dict[MetaKeys.CRITERION.value]], - ["Min sample leaf", model_dict[MetaKeys.MIN_LEAF.value]], + ["Estimators (trees)", model_dict[MachineLearningMetaKeys.RF_ESTIMATORS.value]], + ["Max features", model_dict[MachineLearningMetaKeys.RF_MAX_FEATURES.value]], + ["Under sampling setting", model_dict[MachineLearningMetaKeys.UNDERSAMPLE_SETTING.value]], + ["Under sampling ratio", model_dict[MachineLearningMetaKeys.UNDERSAMPLE_RATIO.value]], + ["Over sampling setting", model_dict[MachineLearningMetaKeys.OVERSAMPLE_SETTING.value]], + ["Over sampling ratio", model_dict[MachineLearningMetaKeys.OVERSAMPLE_RATIO.value]], + ["criterion", model_dict[MachineLearningMetaKeys.RF_CRITERION.value]], + ["Min sample leaf", model_dict[MachineLearningMetaKeys.MIN_LEAF.value]], ] table = tabulate(table_view, ["Setting", "value"], tablefmt="grid") print(f"{table} 
{Defaults.STR_SPLIT_DELIMITER.value}TABLE") diff --git a/simba/model/grid_search_rf.py b/simba/model/grid_search_rf.py index eb0b071ac..1c7318427 100644 --- a/simba/model/grid_search_rf.py +++ b/simba/model/grid_search_rf.py @@ -11,7 +11,7 @@ from simba.ui.tkinter_functions import TwoOptionQuestionPopUp from simba.utils.checks import (check_float, check_if_filepath_list_is_empty, check_if_valid_input, check_int, check_str) -from simba.utils.enums import ConfigKey, Dtypes, MetaKeys, Methods, Options +from simba.utils.enums import ConfigKey, Dtypes, MachineLearningMetaKeys, Methods, Options from simba.utils.errors import InvalidInputError, NoDataError from simba.utils.printing import stdout_success from simba.utils.read_write import (get_fn_ext, read_config_entry, @@ -68,14 +68,14 @@ def __init__(self, config_path: str): def perform_sampling(self, meta_dict: dict): if ( - meta_dict[MetaKeys.TRAIN_TEST_SPLIT_TYPE.value] + meta_dict[MachineLearningMetaKeys.TRAIN_TEST_SPLIT_TYPE.value] == Methods.SPLIT_TYPE_FRAMES.value ): self.x_train, self.x_test, self.y_train, self.y_test = train_test_split( self.x_df, self.y_df, test_size=meta_dict["train_test_size"] ) elif ( - meta_dict[MetaKeys.TRAIN_TEST_SPLIT_TYPE.value] + meta_dict[MachineLearningMetaKeys.TRAIN_TEST_SPLIT_TYPE.value] == Methods.SPLIT_TYPE_BOUTS.value ): ( @@ -88,27 +88,27 @@ def perform_sampling(self, meta_dict: dict): ) else: raise InvalidInputError( - msg=f"{meta_dict[MetaKeys.TRAIN_TEST_SPLIT_TYPE.value]} is not recognized as a valid SPLIT TYPE (OPTIONS: FRAMES, BOUTS" + msg=f"{meta_dict[MachineLearningMetaKeys.TRAIN_TEST_SPLIT_TYPE.value]} is not recognized as a valid SPLIT TYPE (OPTIONS: FRAMES, BOUTS" ) if ( - meta_dict[ConfigKey.UNDERSAMPLE_SETTING.value].lower() + meta_dict[MachineLearningMetaKeys.UNDERSAMPLE_SETTING.value].lower() == Methods.RANDOM_UNDERSAMPLE.value ): self.x_train, self.y_train = self.random_undersampler( - self.x_train, self.y_train, meta_dict[ConfigKey.UNDERSAMPLE_RATIO.value] + self.x_train, self.y_train, meta_dict[MachineLearningMetaKeys.UNDERSAMPLE_RATIO.value] ) if ( - meta_dict[ConfigKey.OVERSAMPLE_SETTING.value].lower() + meta_dict[MachineLearningMetaKeys.OVERSAMPLE_SETTING.value].lower() == Methods.SMOTEENN.value ): self.x_train, self.y_train = self.smoteen_oversampler( - self.x_train, self.y_train, meta_dict[ConfigKey.OVERSAMPLE_RATIO.value] + self.x_train, self.y_train, meta_dict[MachineLearningMetaKeys.OVERSAMPLE_RATIO.value] ) elif ( - meta_dict[ConfigKey.OVERSAMPLE_SETTING.value].lower() == Methods.SMOTE.value + meta_dict[MachineLearningMetaKeys.OVERSAMPLE_SETTING.value].lower() == Methods.SMOTE.value ): self.x_train, self.y_train = self.smote_oversampler( - self.x_train, self.y_train, meta_dict[ConfigKey.OVERSAMPLE_RATIO.value] + self.x_train, self.y_train, meta_dict[MachineLearningMetaKeys.OVERSAMPLE_RATIO.value] ) def __check_validity_of_meta_files(self, meta_file_paths: list): @@ -120,48 +120,48 @@ def __check_validity_of_meta_files(self, meta_file_paths: list): meta_dict = {k.lower(): v for k, v in meta_dict.items()} errors.append( check_str( - name=meta_dict[MetaKeys.CLF_NAME.value], - value=meta_dict[MetaKeys.CLF_NAME.value], + name=meta_dict[MachineLearningMetaKeys.CLASSIFIER.value], + value=meta_dict[MachineLearningMetaKeys.CLASSIFIER.value], raise_error=False, )[1] ) errors.append( check_str( - name=MetaKeys.CRITERION.value, - value=meta_dict[MetaKeys.CRITERION.value], + name=MachineLearningMetaKeys.RF_CRITERION.value, + value=meta_dict[MachineLearningMetaKeys.RF_CRITERION.value], 
options=Options.CLF_CRITERION.value, raise_error=False, )[1] ) errors.append( check_str( - name=MetaKeys.RF_MAX_FEATURES.value, - value=meta_dict[MetaKeys.RF_MAX_FEATURES.value], + name=MachineLearningMetaKeys.RF_MAX_FEATURES.value, + value=meta_dict[MachineLearningMetaKeys.RF_MAX_FEATURES.value], options=Options.CLF_MAX_FEATURES.value, raise_error=False, )[1] ) errors.append( check_str( - ConfigKey.UNDERSAMPLE_SETTING.value, - meta_dict[ConfigKey.UNDERSAMPLE_SETTING.value].lower(), + MachineLearningMetaKeys.UNDERSAMPLE_SETTING.value, + meta_dict[MachineLearningMetaKeys.UNDERSAMPLE_SETTING.value].lower(), options=[x.lower() for x in Options.UNDERSAMPLE_OPTIONS.value], raise_error=False, )[1] ) errors.append( check_str( - ConfigKey.OVERSAMPLE_SETTING.value, - meta_dict[ConfigKey.OVERSAMPLE_SETTING.value].lower(), + MachineLearningMetaKeys.OVERSAMPLE_SETTING.value, + meta_dict[MachineLearningMetaKeys.OVERSAMPLE_SETTING.value].lower(), options=[x.lower() for x in Options.OVERSAMPLE_OPTIONS.value], raise_error=False, )[1] ) - if MetaKeys.TRAIN_TEST_SPLIT_TYPE.value in meta_dict.keys(): + if MachineLearningMetaKeys.TRAIN_TEST_SPLIT_TYPE.value in meta_dict.keys(): errors.append( check_str( - name=meta_dict[MetaKeys.TRAIN_TEST_SPLIT_TYPE.value], - value=meta_dict[MetaKeys.TRAIN_TEST_SPLIT_TYPE.value], + name=meta_dict[MachineLearningMetaKeys.TRAIN_TEST_SPLIT_TYPE.value], + value=meta_dict[MachineLearningMetaKeys.TRAIN_TEST_SPLIT_TYPE.value], options=Options.TRAIN_TEST_SPLIT.value, raise_error=False, )[1] @@ -169,212 +169,212 @@ def __check_validity_of_meta_files(self, meta_file_paths: list): errors.append( check_int( - name=MetaKeys.RF_ESTIMATORS.value, - value=meta_dict[MetaKeys.RF_ESTIMATORS.value], + name=MachineLearningMetaKeys.RF_ESTIMATORS.value, + value=meta_dict[MachineLearningMetaKeys.RF_ESTIMATORS.value], min_value=1, raise_error=False, )[1] ) errors.append( check_int( - name=MetaKeys.MIN_LEAF.value, - value=meta_dict[MetaKeys.MIN_LEAF.value], + name=MachineLearningMetaKeys.MIN_LEAF.value, + value=meta_dict[MachineLearningMetaKeys.MIN_LEAF.value], raise_error=False, )[1] ) - if meta_dict[MetaKeys.LEARNING_CURVE.value] in Options.PERFORM_FLAGS.value: + if meta_dict[MachineLearningMetaKeys.LEARNING_CURVE.value] in Options.PERFORM_FLAGS.value: errors.append( check_int( - name=MetaKeys.LEARNING_CURVE_K_SPLITS.value, - value=meta_dict[MetaKeys.LEARNING_CURVE_K_SPLITS.value], + name=MachineLearningMetaKeys.LEARNING_CURVE_K_SPLITS.value, + value=meta_dict[MachineLearningMetaKeys.LEARNING_CURVE_K_SPLITS.value], raise_error=False, )[1] ) errors.append( check_int( - name=MetaKeys.LEARNING_CURVE_DATA_SPLITS.value, - value=meta_dict[MetaKeys.LEARNING_CURVE_DATA_SPLITS.value], + name=MachineLearningMetaKeys.LEARNING_CURVE_DATA_SPLITS.value, + value=meta_dict[MachineLearningMetaKeys.LEARNING_CURVE_DATA_SPLITS.value], raise_error=False, )[1] ) if ( - meta_dict[MetaKeys.IMPORTANCE_BAR_CHART.value] + meta_dict[MachineLearningMetaKeys.IMPORTANCE_BAR_CHART.value] in Options.PERFORM_FLAGS.value ): errors.append( check_int( - name=MetaKeys.N_FEATURE_IMPORTANCE_BARS.value, - value=meta_dict[MetaKeys.N_FEATURE_IMPORTANCE_BARS.value], + name=MachineLearningMetaKeys.N_FEATURE_IMPORTANCE_BARS.value, + value=meta_dict[MachineLearningMetaKeys.N_FEATURE_IMPORTANCE_BARS.value], raise_error=False, )[1] ) - if MetaKeys.SHAP_SCORES.value in meta_dict.keys(): - if meta_dict[MetaKeys.SHAP_SCORES.value] in Options.PERFORM_FLAGS.value: + if MachineLearningMetaKeys.SHAP_SCORES.value in meta_dict.keys(): + if 
meta_dict[MachineLearningMetaKeys.SHAP_SCORES.value] in Options.PERFORM_FLAGS.value: errors.append( check_int( - name=MetaKeys.SHAP_PRESENT.value, - value=meta_dict[MetaKeys.SHAP_PRESENT.value], + name=MachineLearningMetaKeys.SHAP_PRESENT.value, + value=meta_dict[MachineLearningMetaKeys.SHAP_PRESENT.value], raise_error=False, )[1] ) errors.append( check_int( - name=MetaKeys.SHAP_ABSENT.value, - value=meta_dict[MetaKeys.SHAP_ABSENT.value], + name=MachineLearningMetaKeys.SHAP_ABSENT.value, + value=meta_dict[MachineLearningMetaKeys.SHAP_ABSENT.value], raise_error=False, )[1] ) errors.append( check_float( - name=MetaKeys.TT_SIZE.value, - value=meta_dict[MetaKeys.TT_SIZE.value], + name=MachineLearningMetaKeys.TT_SIZE.value, + value=meta_dict[MachineLearningMetaKeys.TT_SIZE.value], raise_error=False, )[1] ) if ( - meta_dict[ConfigKey.UNDERSAMPLE_SETTING.value].lower() + meta_dict[MachineLearningMetaKeys.UNDERSAMPLE_SETTING.value].lower() == Methods.RANDOM_UNDERSAMPLE.value ): errors.append( check_float( - name=ConfigKey.UNDERSAMPLE_RATIO.value, - value=meta_dict[ConfigKey.UNDERSAMPLE_RATIO.value], + name=MachineLearningMetaKeys.UNDERSAMPLE_RATIO.value, + value=meta_dict[MachineLearningMetaKeys.UNDERSAMPLE_RATIO.value], raise_error=False, )[1] ) try: present_len, absent_len = len( self.data_df[ - self.data_df[meta_dict[MetaKeys.CLF_NAME.value]] == 1 - ] + self.data_df[meta_dict[MachineLearningMetaKeys.CLASSIFIER.value]] == 1 + ] ), len( self.data_df[ - self.data_df[meta_dict[MetaKeys.CLF_NAME.value]] == 0 - ] + self.data_df[meta_dict[MachineLearningMetaKeys.CLASSIFIER.value]] == 0 + ] ) ratio_n = int( - present_len * meta_dict[ConfigKey.UNDERSAMPLE_RATIO.value] + present_len * meta_dict[MachineLearningMetaKeys.UNDERSAMPLE_RATIO.value] ) if absent_len < ratio_n: errors.append( - f"The under-sample ratio of {meta_dict[ConfigKey.UNDERSAMPLE_RATIO.value]} in \n classifier {meta_dict[MetaKeys.CLF_NAME.value]} demands {ratio_n} behavior-absent annotations." + f"The under-sample ratio of {meta_dict[MachineLearningMetaKeys.UNDERSAMPLE_RATIO.value]} in \n classifier {meta_dict[MachineLearningMetaKeys.CLASSIFIER.value]} demands {ratio_n} behavior-absent annotations." 
) except: pass if ( - meta_dict[ConfigKey.OVERSAMPLE_SETTING.value].lower() + meta_dict[MachineLearningMetaKeys.OVERSAMPLE_SETTING.value].lower() == Methods.SMOTEENN.value.lower() ) or ( - meta_dict[ConfigKey.OVERSAMPLE_SETTING.value].lower() + meta_dict[MachineLearningMetaKeys.OVERSAMPLE_SETTING.value].lower() == Methods.SMOTE.value.lower() ): errors.append( check_float( - name=ConfigKey.OVERSAMPLE_RATIO.value, - value=meta_dict[ConfigKey.OVERSAMPLE_RATIO.value], + name=MachineLearningMetaKeys.OVERSAMPLE_RATIO.value, + value=meta_dict[MachineLearningMetaKeys.OVERSAMPLE_RATIO.value], raise_error=False, )[1] ) errors.append( check_if_valid_input( - name=MetaKeys.META_FILE.value, - input=meta_dict[MetaKeys.META_FILE.value], + name=MachineLearningMetaKeys.RF_METADATA.value, + input=meta_dict[MachineLearningMetaKeys.RF_METADATA.value], options=Options.RUN_OPTIONS_FLAGS.value, raise_error=False, )[1] ) errors.append( check_if_valid_input( - MetaKeys.EX_DECISION_TREE.value, - input=meta_dict[MetaKeys.EX_DECISION_TREE.value], + MachineLearningMetaKeys.EX_DECISION_TREE.value, + input=meta_dict[MachineLearningMetaKeys.EX_DECISION_TREE.value], options=Options.RUN_OPTIONS_FLAGS.value, raise_error=False, )[1] ) errors.append( check_if_valid_input( - MetaKeys.CLF_REPORT.value, - input=meta_dict[MetaKeys.CLF_REPORT.value], + MachineLearningMetaKeys.CLF_REPORT.value, + input=meta_dict[MachineLearningMetaKeys.CLF_REPORT.value], options=Options.RUN_OPTIONS_FLAGS.value, raise_error=False, )[1] ) errors.append( check_if_valid_input( - MetaKeys.IMPORTANCE_LOG.value, - input=meta_dict[MetaKeys.IMPORTANCE_LOG.value], + MachineLearningMetaKeys.IMPORTANCE_LOG.value, + input=meta_dict[MachineLearningMetaKeys.IMPORTANCE_LOG.value], options=Options.RUN_OPTIONS_FLAGS.value, raise_error=False, )[1] ) errors.append( check_if_valid_input( - MetaKeys.IMPORTANCE_BAR_CHART.value, - input=meta_dict[MetaKeys.IMPORTANCE_BAR_CHART.value], + MachineLearningMetaKeys.IMPORTANCE_BAR_CHART.value, + input=meta_dict[MachineLearningMetaKeys.IMPORTANCE_BAR_CHART.value], options=Options.RUN_OPTIONS_FLAGS.value, raise_error=False, )[1] ) errors.append( check_if_valid_input( - MetaKeys.PERMUTATION_IMPORTANCE.value, - input=meta_dict[MetaKeys.PERMUTATION_IMPORTANCE.value], + MachineLearningMetaKeys.PERMUTATION_IMPORTANCE.value, + input=meta_dict[MachineLearningMetaKeys.PERMUTATION_IMPORTANCE.value], options=Options.RUN_OPTIONS_FLAGS.value, raise_error=False, )[1] ) errors.append( check_if_valid_input( - MetaKeys.LEARNING_CURVE.value, - input=meta_dict[MetaKeys.LEARNING_CURVE.value], + MachineLearningMetaKeys.LEARNING_CURVE.value, + input=meta_dict[MachineLearningMetaKeys.LEARNING_CURVE.value], options=Options.RUN_OPTIONS_FLAGS.value, raise_error=False, )[1] ) errors.append( check_if_valid_input( - MetaKeys.PRECISION_RECALL.value, - input=meta_dict[MetaKeys.PRECISION_RECALL.value], + MachineLearningMetaKeys.PRECISION_RECALL.value, + input=meta_dict[MachineLearningMetaKeys.PRECISION_RECALL.value], options=Options.RUN_OPTIONS_FLAGS.value, raise_error=False, )[1] ) - if MetaKeys.PARTIAL_DEPENDENCY.value in meta_dict.keys(): + if MachineLearningMetaKeys.PARTIAL_DEPENDENCY.value in meta_dict.keys(): errors.append( check_if_valid_input( - MetaKeys.PARTIAL_DEPENDENCY.value, - input=meta_dict[MetaKeys.PARTIAL_DEPENDENCY.value], + MachineLearningMetaKeys.PARTIAL_DEPENDENCY.value, + input=meta_dict[MachineLearningMetaKeys.PARTIAL_DEPENDENCY.value], options=Options.RUN_OPTIONS_FLAGS.value, raise_error=False, )[1] ) - if 
meta_dict[MetaKeys.RF_MAX_FEATURES.value] == Dtypes.NONE.value: - meta_dict[MetaKeys.RF_MAX_FEATURES.value] = None - if MetaKeys.TRAIN_TEST_SPLIT_TYPE.value not in meta_dict.keys(): + if meta_dict[MachineLearningMetaKeys.RF_MAX_FEATURES.value] == Dtypes.NONE.value: + meta_dict[MachineLearningMetaKeys.RF_MAX_FEATURES.value] = None + if MachineLearningMetaKeys.TRAIN_TEST_SPLIT_TYPE.value not in meta_dict.keys(): meta_dict[ - MetaKeys.TRAIN_TEST_SPLIT_TYPE.value + MachineLearningMetaKeys.TRAIN_TEST_SPLIT_TYPE.value ] = Methods.SPLIT_TYPE_FRAMES.value - if ConfigKey.CLASS_WEIGHTS.value in meta_dict.keys(): + if MachineLearningMetaKeys.CLASS_WEIGHTS.value in meta_dict.keys(): if ( - meta_dict[ConfigKey.CLASS_WEIGHTS.value] + meta_dict[MachineLearningMetaKeys.CLASS_WEIGHTS.value] not in Options.CLASS_WEIGHT_OPTIONS.value ): - meta_dict[ConfigKey.CLASS_WEIGHTS.value] = None - if meta_dict[ConfigKey.CLASS_WEIGHTS.value] == "custom": - meta_dict[ConfigKey.CLASS_WEIGHTS.value] = literal_eval( + meta_dict[MachineLearningMetaKeys.CLASS_WEIGHTS.value] = None + if meta_dict[MachineLearningMetaKeys.CLASS_WEIGHTS.value] == "custom": + meta_dict[MachineLearningMetaKeys.CLASS_WEIGHTS.value] = literal_eval( meta_dict["class_custom_weights"] ) - for k, v in meta_dict[ConfigKey.CLASS_WEIGHTS.value].items(): - meta_dict[ConfigKey.CLASS_WEIGHTS.value][k] = int(v) - if meta_dict[ConfigKey.CLASS_WEIGHTS.value] == Dtypes.NONE.value: - meta_dict[ConfigKey.CLASS_WEIGHTS.value] = None + for k, v in meta_dict[MachineLearningMetaKeys.CLASS_WEIGHTS.value].items(): + meta_dict[MachineLearningMetaKeys.CLASS_WEIGHTS.value][k] = int(v) + if meta_dict[MachineLearningMetaKeys.CLASS_WEIGHTS.value] == Dtypes.NONE.value: + meta_dict[MachineLearningMetaKeys.CLASS_WEIGHTS.value] = None else: - meta_dict[ConfigKey.CLASS_WEIGHTS.value] = None + meta_dict[MachineLearningMetaKeys.CLASS_WEIGHTS.value] = None errors = [x for x in errors if x != ""] if errors: @@ -401,22 +401,22 @@ def run(self): if len(self.meta_dicts.keys()) == 0: raise NoDataError(msg="No valid hyper-parameter config files") for config_cnt, meta_dict in self.meta_dicts.items(): - self.clf_name = meta_dict[MetaKeys.CLF_NAME.value] + self.clf_name = meta_dict[MachineLearningMetaKeys.CLASSIFIER.value] print( - f"Training model {config_cnt+1}/{len(self.meta_dicts.keys())} ({meta_dict[MetaKeys.CLF_NAME.value]})..." + f"Training model {config_cnt+1}/{len(self.meta_dicts.keys())} ({meta_dict[MachineLearningMetaKeys.CLASSIFIER.value]})..." 
) self.class_names = [ - f"Not_{meta_dict[MetaKeys.CLF_NAME.value]}", - meta_dict[MetaKeys.CLF_NAME.value], + f"Not_{meta_dict[MachineLearningMetaKeys.CLASSIFIER.value]}", + meta_dict[MachineLearningMetaKeys.CLASSIFIER.value], ] annotation_cols_to_remove = self.read_in_all_model_names_to_remove( - self.config, self.clf_cnt, meta_dict[MetaKeys.CLF_NAME.value] + self.config, self.clf_cnt, meta_dict[MachineLearningMetaKeys.CLASSIFIER.value] ) self.x_y_df = self.delete_other_annotation_columns( self.data_df, annotation_cols_to_remove ) self.x_df, self.y_df = self.split_df_to_x_y( - self.x_y_df, [meta_dict[MetaKeys.CLF_NAME.value]] + self.x_y_df, [meta_dict[MachineLearningMetaKeys.CLASSIFIER.value]] ) self.feature_names = self.x_df.columns self.check_sampled_dataset_integrity(x_df=self.x_df, y_df=self.y_df) @@ -425,14 +425,14 @@ def run(self): self.print_machine_model_information(meta_dict) print("# {} features.".format(len(self.feature_names))) self.rf_clf = RandomForestClassifier( - n_estimators=meta_dict[MetaKeys.RF_ESTIMATORS.value], - max_features=meta_dict[MetaKeys.RF_MAX_FEATURES.value], + n_estimators=meta_dict[MachineLearningMetaKeys.RF_ESTIMATORS.value], + max_features=meta_dict[MachineLearningMetaKeys.RF_MAX_FEATURES.value], n_jobs=-1, - criterion=meta_dict[MetaKeys.CRITERION.value], - min_samples_leaf=meta_dict[MetaKeys.MIN_LEAF.value], + criterion=meta_dict[MachineLearningMetaKeys.RF_CRITERION.value], + min_samples_leaf=meta_dict[MachineLearningMetaKeys.MIN_LEAF.value], bootstrap=True, verbose=1, - class_weight=meta_dict[ConfigKey.CLASS_WEIGHTS.value], + class_weight=meta_dict[MachineLearningMetaKeys.CLASS_WEIGHTS.value], ) print(f"Fitting {self.clf_name} model...") @@ -440,7 +440,7 @@ def run(self): clf=self.rf_clf, x_df=self.x_train, y_df=self.y_train ) if ( - meta_dict[MetaKeys.PERMUTATION_IMPORTANCE.value] + meta_dict[MachineLearningMetaKeys.PERMUTATION_IMPORTANCE.value] in Options.PERFORM_FLAGS.value ): self.calc_permutation_importance( @@ -452,19 +452,19 @@ def run(self): self.model_dir_out, save_file_no=config_cnt, ) - if meta_dict[MetaKeys.LEARNING_CURVE.value] in Options.PERFORM_FLAGS.value: + if meta_dict[MachineLearningMetaKeys.LEARNING_CURVE.value] in Options.PERFORM_FLAGS.value: self.calc_learning_curve( self.x_y_df, self.clf_name, - meta_dict[MetaKeys.LEARNING_CURVE_K_SPLITS.value], - meta_dict[MetaKeys.LEARNING_CURVE_DATA_SPLITS.value], - meta_dict[MetaKeys.TT_SIZE.value], + meta_dict[MachineLearningMetaKeys.LEARNING_CURVE_K_SPLITS.value], + meta_dict[MachineLearningMetaKeys.LEARNING_CURVE_DATA_SPLITS.value], + meta_dict[MachineLearningMetaKeys.TT_SIZE.value], self.rf_clf, self.model_dir_out, save_file_no=config_cnt, ) if ( - meta_dict[MetaKeys.PRECISION_RECALL.value] + meta_dict[MachineLearningMetaKeys.PRECISION_RECALL.value] in Options.PERFORM_FLAGS.value ): self.calc_pr_curve( @@ -476,7 +476,7 @@ def run(self): save_file_no=config_cnt, ) if ( - meta_dict[MetaKeys.EX_DECISION_TREE.value] + meta_dict[MachineLearningMetaKeys.EX_DECISION_TREE.value] in Options.PERFORM_FLAGS.value ): self.create_example_dt( @@ -487,7 +487,7 @@ def run(self): self.model_dir_out, save_file_no=config_cnt, ) - if meta_dict[MetaKeys.CLF_REPORT.value] in Options.PERFORM_FLAGS.value: + if meta_dict[MachineLearningMetaKeys.CLF_REPORT.value] in Options.PERFORM_FLAGS.value: self.create_clf_report( self.rf_clf, self.x_test, @@ -496,7 +496,7 @@ def run(self): self.model_dir_out, save_file_no=config_cnt, ) - if meta_dict[MetaKeys.IMPORTANCE_LOG.value] in Options.PERFORM_FLAGS.value: + if 
meta_dict[MachineLearningMetaKeys.IMPORTANCE_LOG.value] in Options.PERFORM_FLAGS.value: self.create_x_importance_log( self.rf_clf, self.feature_names, @@ -505,7 +505,7 @@ def run(self): save_file_no=config_cnt, ) if ( - meta_dict[MetaKeys.IMPORTANCE_BAR_CHART.value] + meta_dict[MachineLearningMetaKeys.IMPORTANCE_BAR_CHART.value] in Options.PERFORM_FLAGS.value ): self.create_x_importance_bar_chart( @@ -513,23 +513,23 @@ def run(self): self.feature_names, self.clf_name, self.model_dir_out, - meta_dict[MetaKeys.N_FEATURE_IMPORTANCE_BARS.value], + meta_dict[MachineLearningMetaKeys.N_FEATURE_IMPORTANCE_BARS.value], save_file_no=config_cnt, ) - if MetaKeys.SHAP_SCORES.value in meta_dict.keys(): + if MachineLearningMetaKeys.SHAP_SCORES.value in meta_dict.keys(): save_n = ( - meta_dict[MetaKeys.SHAP_PRESENT.value] - + meta_dict[MetaKeys.SHAP_ABSENT.value] + meta_dict[MachineLearningMetaKeys.SHAP_PRESENT.value] + + meta_dict[MachineLearningMetaKeys.SHAP_ABSENT.value] ) - if MetaKeys.SHAP_SAVE_ITERATION.value in meta_dict.keys(): + if MachineLearningMetaKeys.SHAP_SAVE_ITERATION.value in meta_dict.keys(): try: - save_n = int(meta_dict[MetaKeys.SHAP_SAVE_ITERATION.value]) + save_n = int(meta_dict[MachineLearningMetaKeys.SHAP_SAVE_ITERATION.value]) except ValueError: save_n = ( - meta_dict[MetaKeys.SHAP_PRESENT.value] - + meta_dict[MetaKeys.SHAP_ABSENT.value] + meta_dict[MachineLearningMetaKeys.SHAP_PRESENT.value] + + meta_dict[MachineLearningMetaKeys.SHAP_ABSENT.value] ) - if meta_dict[MetaKeys.SHAP_SCORES.value] in Options.PERFORM_FLAGS.value: + if meta_dict[MachineLearningMetaKeys.SHAP_SCORES.value] in Options.PERFORM_FLAGS.value: self.create_shap_log_mp( self.config_path, self.rf_clf, @@ -537,15 +537,15 @@ def run(self): self.y_train, self.feature_names, self.clf_name, - meta_dict[MetaKeys.SHAP_PRESENT.value], - meta_dict[MetaKeys.SHAP_ABSENT.value], + meta_dict[MachineLearningMetaKeys.SHAP_PRESENT.value], + meta_dict[MachineLearningMetaKeys.SHAP_ABSENT.value], self.model_dir_out, save_it=save_n, save_file_no=config_cnt, ) - if MetaKeys.PARTIAL_DEPENDENCY.value in meta_dict.keys(): + if MachineLearningMetaKeys.PARTIAL_DEPENDENCY.value in meta_dict.keys(): if ( - meta_dict[MetaKeys.PARTIAL_DEPENDENCY.value] + meta_dict[MachineLearningMetaKeys.PARTIAL_DEPENDENCY.value] in Options.PERFORM_FLAGS.value ): self.partial_dependence_calculator( diff --git a/simba/model/train_rf.py b/simba/model/train_rf.py index a58e4817c..288ab0731 100644 --- a/simba/model/train_rf.py +++ b/simba/model/train_rf.py @@ -11,7 +11,7 @@ from simba.mixins.train_model_mixin import TrainModelMixin from simba.utils.checks import (check_float, check_if_filepath_list_is_empty, check_int) -from simba.utils.enums import ConfigKey, Dtypes, Methods, Options +from simba.utils.enums import ConfigKey, Dtypes, Methods, Options, MachineLearningMetaKeys from simba.utils.printing import SimbaTimer, stdout_success from simba.utils.read_write import read_config_entry from imblearn.ensemble import BalancedRandomForestClassifier @@ -54,25 +54,25 @@ def __init__(self, config_path: Union[str, os.PathLike]): self.clf_name = read_config_entry( self.config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.CLASSIFIER.value, + MachineLearningMetaKeys.CLASSIFIER.value, data_type=Dtypes.STR.value, ) self.tt_size = read_config_entry( self.config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.TT_SIZE.value, + MachineLearningMetaKeys.TT_SIZE.value, data_type=Dtypes.FLOAT.value, ) self.algo = read_config_entry( self.config, 
ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.MODEL_TO_RUN.value, + MachineLearningMetaKeys.MODEL_TO_RUN.value, data_type=Dtypes.STR.value, ) self.split_type = read_config_entry( self.config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.SPLIT_TYPE.value, + MachineLearningMetaKeys.TRAIN_TEST_SPLIT_TYPE.value, data_type=Dtypes.STR.value, options=Options.TRAIN_TEST_SPLIT.value, default_value=Methods.SPLIT_TYPE_FRAMES.value, @@ -81,7 +81,7 @@ def __init__(self, config_path: Union[str, os.PathLike]): read_config_entry( self.config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.UNDERSAMPLE_SETTING.value, + MachineLearningMetaKeys.UNDERSAMPLE_SETTING.value, data_type=Dtypes.STR.value, ) .lower() @@ -91,7 +91,7 @@ def __init__(self, config_path: Union[str, os.PathLike]): read_config_entry( self.config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.OVERSAMPLE_SETTING.value, + MachineLearningMetaKeys.OVERSAMPLE_SETTING.value, data_type=Dtypes.STR.value, ) .lower() @@ -101,12 +101,12 @@ def __init__(self, config_path: Union[str, os.PathLike]): self.under_sample_ratio = read_config_entry( self.config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.UNDERSAMPLE_RATIO.value, + MachineLearningMetaKeys.UNDERSAMPLE_RATIO.value, data_type=Dtypes.FLOAT.value, default_value=Dtypes.NAN.value, ) check_float( - name=ConfigKey.UNDERSAMPLE_RATIO.value, value=self.under_sample_ratio + name=MachineLearningMetaKeys.UNDERSAMPLE_RATIO.value, value=self.under_sample_ratio ) else: self.under_sample_ratio = Dtypes.NAN.value @@ -116,12 +116,12 @@ def __init__(self, config_path: Union[str, os.PathLike]): self.over_sample_ratio = read_config_entry( self.config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.OVERSAMPLE_RATIO.value, + MachineLearningMetaKeys.OVERSAMPLE_RATIO.value, data_type=Dtypes.FLOAT.value, default_value=Dtypes.NAN.value, ) check_float( - name=ConfigKey.OVERSAMPLE_RATIO.value, value=self.over_sample_ratio + name=MachineLearningMetaKeys.OVERSAMPLE_RATIO.value, value=self.over_sample_ratio ) else: self.over_sample_ratio = Dtypes.NAN.value @@ -156,8 +156,9 @@ def __init__(self, config_path: Union[str, os.PathLike]): print( "Number of {} frames in dataset: {} ({}%)".format( self.clf_name, - str(self.y_df[self.y_df == (cls.index(self.clf_name)+1)].sum()), - str(round(self.y_df[self.y_df == (cls.index(self.clf_name)+1)].sum() / len(self.y_df[self.y_df == (cls.index(self.clf_name)+1)]), 4) * 100), + str(self.y_df[self.y_df == (cls.index(self.clf_name) + 1)].sum()), + str(round(self.y_df[self.y_df == (cls.index(self.clf_name) + 1)].sum() / len( + self.y_df[self.y_df == (cls.index(self.clf_name) + 1)]), 4) * 100), ) ) print("Training and evaluating model...") @@ -203,13 +204,19 @@ def train_model(self): n_estimators = read_config_entry( self.config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.RF_ESTIMATORS.value, + MachineLearningMetaKeys.RF_ESTIMATORS.value, data_type=Dtypes.INT.value, ) max_features = read_config_entry( self.config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.RF_MAX_FEATURES.value, + MachineLearningMetaKeys.RF_MAX_FEATURES.value, + data_type=Dtypes.STR.value, + ) + max_depth = read_config_entry( + self.config, + ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, + MachineLearningMetaKeys.RF_MAX_DEPTH.value, data_type=Dtypes.STR.value, ) if max_features == "None": @@ -217,83 +224,83 @@ def train_model(self): criterion = read_config_entry( self.config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.RF_CRITERION.value, + 
MachineLearningMetaKeys.RF_CRITERION.value, data_type=Dtypes.STR.value, options=Options.CLF_CRITERION.value, ) min_sample_leaf = read_config_entry( self.config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.MIN_LEAF.value, + MachineLearningMetaKeys.MIN_LEAF.value, data_type=Dtypes.INT.value, ) compute_permutation_importance = read_config_entry( self.config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.PERMUTATION_IMPORTANCE.value, + MachineLearningMetaKeys.PERMUTATION_IMPORTANCE.value, data_type=Dtypes.STR.value, default_value=False, ) generate_learning_curve = read_config_entry( self.config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.LEARNING_CURVE.value, + MachineLearningMetaKeys.LEARNING_CURVE.value, data_type=Dtypes.STR.value, default_value=False, ) generate_precision_recall_curve = read_config_entry( self.config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.PRECISION_RECALL.value, + MachineLearningMetaKeys.PRECISION_RECALL.value, data_type=Dtypes.STR.value, default_value=False, ) generate_example_decision_tree = read_config_entry( self.config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.EX_DECISION_TREE.value, + MachineLearningMetaKeys.EX_DECISION_TREE.value, data_type=Dtypes.STR.value, default_value=False, ) generate_classification_report = read_config_entry( self.config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.CLF_REPORT.value, + MachineLearningMetaKeys.CLF_REPORT.value, data_type=Dtypes.STR.value, default_value=False, ) generate_features_importance_log = read_config_entry( self.config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.IMPORTANCE_LOG.value, + MachineLearningMetaKeys.IMPORTANCE_LOG.value, data_type=Dtypes.STR.value, default_value=False, ) generate_features_importance_bar_graph = read_config_entry( self.config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.IMPORTANCE_LOG.value, + MachineLearningMetaKeys.IMPORTANCE_LOG.value, data_type=Dtypes.STR.value, default_value=False, ) generate_example_decision_tree_fancy = read_config_entry( self.config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.EX_DECISION_TREE_FANCY.value, + MachineLearningMetaKeys.EX_DECISION_TREE_FANCY.value, data_type=Dtypes.STR.value, default_value=False, ) generate_shap_scores = read_config_entry( self.config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.SHAP_SCORES.value, + MachineLearningMetaKeys.SHAP_SCORES.value, data_type=Dtypes.STR.value, default_value=False, ) save_meta_data = read_config_entry( self.config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.RF_METADATA.value, + MachineLearningMetaKeys.RF_METADATA.value, data_type=Dtypes.STR.value, default_value=False, ) @@ -306,12 +313,12 @@ def train_model(self): ) if self.config.has_option( - ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, ConfigKey.CLASS_WEIGHTS.value + ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, MachineLearningMetaKeys.CLASS_WEIGHTS.value ): class_weights = read_config_entry( self.config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.CLASS_WEIGHTS.value, + MachineLearningMetaKeys.CLASS_WEIGHTS.value, data_type=Dtypes.STR.value, default_value=Dtypes.NONE.value, ) @@ -320,7 +327,7 @@ def train_model(self): read_config_entry( self.config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.CUSTOM_WEIGHTS.value, + MachineLearningMetaKeys.CLASS_CUSTOM_WEIGHTS.value, data_type=Dtypes.STR.value, ) ) @@ -335,22 +342,22 @@ def train_model(self): shuffle_splits = read_config_entry( self.config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - 
ConfigKey.LEARNING_CURVE_K_SPLITS.value, + MachineLearningMetaKeys.LEARNING_CURVE_K_SPLITS.value, data_type=Dtypes.INT.value, default_value=Dtypes.NAN.value, ) dataset_splits = read_config_entry( self.config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.LEARNING_DATA_SPLITS.value, + MachineLearningMetaKeys.LEARNING_DATA_SPLITS.value, data_type=Dtypes.INT.value, default_value=Dtypes.NAN.value, ) check_int( - name=ConfigKey.LEARNING_CURVE_K_SPLITS.value, value=shuffle_splits + name=MachineLearningMetaKeys.LEARNING_CURVE_K_SPLITS.value, value=shuffle_splits ) check_int( - name=ConfigKey.LEARNING_DATA_SPLITS.value, value=dataset_splits + name=MachineLearningMetaKeys.LEARNING_DATA_SPLITS.value, value=dataset_splits ) else: shuffle_splits, dataset_splits = Dtypes.NAN.value, Dtypes.NAN.value @@ -358,12 +365,12 @@ def train_model(self): feature_importance_bars = read_config_entry( self.config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.IMPORTANCE_BARS_N.value, + MachineLearningMetaKeys.IMPORTANCE_BARS_N.value, Dtypes.INT.value, Dtypes.NAN.value, ) check_int( - name=ConfigKey.IMPORTANCE_BARS_N.value, + name=MachineLearningMetaKeys.IMPORTANCE_BARS_N.value, value=feature_importance_bars, min_value=1, ) @@ -378,21 +385,21 @@ def train_model(self): shap_target_present_cnt = read_config_entry( self.config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.SHAP_PRESENT.value, + MachineLearningMetaKeys.SHAP_PRESENT.value, data_type=Dtypes.INT.value, default_value=0, ) shap_target_absent_cnt = read_config_entry( self.config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.SHAP_ABSENT.value, + MachineLearningMetaKeys.SHAP_ABSENT.value, data_type=Dtypes.INT.value, default_value=0, ) shap_save_n = read_config_entry( self.config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.SHAP_SAVE_ITERATION.value, + MachineLearningMetaKeys.SHAP_SAVE_ITERATION.value, data_type=Dtypes.STR.value, default_value=Dtypes.NONE.value, ) @@ -401,10 +408,10 @@ def train_model(self): except ValueError: shap_save_n = shap_target_present_cnt + shap_target_absent_cnt check_int( - name=ConfigKey.SHAP_PRESENT.value, value=shap_target_present_cnt + name=MachineLearningMetaKeys.SHAP_PRESENT.value, value=shap_target_present_cnt ) check_int( - name=ConfigKey.SHAP_ABSENT.value, value=shap_target_absent_cnt + name=MachineLearningMetaKeys.SHAP_ABSENT.value, value=shap_target_absent_cnt ) print(f"Fitting {self.clf_name} model...") if self.algo == "RF": @@ -412,6 +419,7 @@ def train_model(self): n_estimators=n_estimators, max_features=max_features, n_jobs=-1, + max_depth=max_depth, criterion=criterion, min_samples_leaf=min_sample_leaf, bootstrap=True, @@ -426,7 +434,7 @@ def train_model(self): self.rf_clf = BalancedRandomForestClassifier( n_estimators=n_estimators, max_features=max_features, - max_depth=7, + max_depth=max_depth, n_jobs=-1, criterion=criterion, min_samples_leaf=min_sample_leaf, diff --git a/simba/ui/machine_model_settings_ui.py b/simba/ui/machine_model_settings_ui.py index 3d275cc6d..d9754c5a9 100644 --- a/simba/ui/machine_model_settings_ui.py +++ b/simba/ui/machine_model_settings_ui.py @@ -13,11 +13,11 @@ Entry_Box, FileSelect, hxtScrollbar) from simba.utils.checks import (check_file_exist_and_readable, check_float, check_int) -from simba.utils.enums import Formats, Keys, Links, Options +from simba.utils.enums import Formats, Keys, Links, Options, ConfigKey, Dtypes, MachineLearningMetaKeys from simba.utils.errors import InvalidHyperparametersFileError from simba.utils.printing import 
stdout_success, stdout_trash, stdout_warning from simba.utils.read_write import (find_files_of_filetypes_in_directory, - get_fn_ext) + get_fn_ext, read_config_entry) class MachineModelSettingsPopUp(PopUpMixin, ConfigReader): @@ -90,30 +90,78 @@ def __init__(self, config_path: str): "25", validation="numeric", ) - self.estimators_entrybox.entry_set(val=2000) + n_estimators = read_config_entry( + self.config, + ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, + MachineLearningMetaKeys.RF_ESTIMATORS.value, + data_type=Dtypes.INT.value, + ) + + self.estimators_entrybox.entry_set(val=n_estimators) + max_features = read_config_entry( + self.config, + ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, + MachineLearningMetaKeys.RF_MAX_FEATURES.value, + data_type=Dtypes.STR.value, + ) + self.max_features_dropdown = DropDownMenu( self.hyperparameters_frm, "Max features: ", self.max_features_options, "25" ) - self.max_features_dropdown.setChoices(self.max_features_options[0]) + self.max_features_dropdown.setChoices(max_features) self.criterion_dropdown = DropDownMenu( self.hyperparameters_frm, "Criterion: ", self.criterion_options, "25" ) self.criterion_dropdown.setChoices(self.criterion_options[0]) + train_test_size = read_config_entry( + self.config, + ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, + MachineLearningMetaKeys.TT_SIZE.value, + data_type=Dtypes.STR.value, + ) self.train_test_size_dropdown = DropDownMenu( self.hyperparameters_frm, "Test Size: ", self.train_test_sizes_options, "25" ) - self.train_test_size_dropdown.setChoices("0.2") + self.train_test_size_dropdown.setChoices(str(train_test_size)) + train_test_split_type = read_config_entry( + self.config, + ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, + MachineLearningMetaKeys.TRAIN_TEST_SPLIT_TYPE.value, + data_type=Dtypes.STR.value, + ) self.train_test_type_dropdown = DropDownMenu( self.hyperparameters_frm, "Train-test Split Type: ", Options.TRAIN_TEST_SPLIT.value, "25", ) - self.train_test_type_dropdown.setChoices(Options.TRAIN_TEST_SPLIT.value[0]) + self.train_test_type_dropdown.setChoices(str(train_test_split_type)) + max_depth = read_config_entry( + self.config, + ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, + MachineLearningMetaKeys.RF_MAX_DEPTH.value, + data_type=Dtypes.STR.value, + ) + self.max_depth_eb = Entry_Box( + self.hyperparameters_frm, "Max Depth: ", "25", validation="numeric" + ) + self.max_depth_eb.entry_set(val=max_depth) + min_sample_leaf = read_config_entry( + self.config, + ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, + MachineLearningMetaKeys.MIN_LEAF.value, + data_type=Dtypes.INT.value, + ) self.min_sample_leaf_eb = Entry_Box( self.hyperparameters_frm, "Minimum sample leaf", "25", validation="numeric" ) - self.min_sample_leaf_eb.entry_set(val=1) + self.min_sample_leaf_eb.entry_set(val=min_sample_leaf) + undersample_settings = read_config_entry( + self.config, + ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, + MachineLearningMetaKeys.UNDERSAMPLE_SETTING.value, + data_type=Dtypes.STR.value, + ) self.under_sample_ratio_entrybox = Entry_Box( self.hyperparameters_frm, "UNDER-sample ratio: ", "25", status=DISABLED ) @@ -126,7 +174,7 @@ def __init__(self, config_path: str): self.under_sample_ratio_entrybox, self.undersample_settings_dropdown ), ) - self.undersample_settings_dropdown.setChoices("None") + self.undersample_settings_dropdown.setChoices(undersample_settings) self.over_sample_ratio_entrybox = Entry_Box( self.hyperparameters_frm, "OVER-sample ratio: ", "25", status=DISABLED ) @@ -317,16 +365,17 @@ def __init__(self, config_path: str): 
self.hyperparameters_frm.grid(row=3, column=0, sticky=NW) self.estimators_entrybox.grid(row=0, column=0, sticky=NW) - self.max_features_dropdown.grid(row=1, column=0, sticky=NW) - self.criterion_dropdown.grid(row=2, column=0, sticky=NW) - self.train_test_size_dropdown.grid(row=3, column=0, sticky=NW) - self.train_test_type_dropdown.grid(row=4, column=0, sticky=NW) - self.min_sample_leaf_eb.grid(row=5, column=0, sticky=NW) - self.undersample_settings_dropdown.grid(row=6, column=0, sticky=NW) - self.under_sample_ratio_entrybox.grid(row=7, column=0, sticky=NW) - self.oversample_settings_dropdown.grid(row=8, column=0, sticky=NW) - self.over_sample_ratio_entrybox.grid(row=9, column=0, sticky=NW) - self.class_weights_dropdown.grid(row=10, column=0, sticky=NW) + self.max_depth_eb.grid(row=1, column=0, sticky=NW) + self.max_features_dropdown.grid(row=2, column=0, sticky=NW) + self.criterion_dropdown.grid(row=3, column=0, sticky=NW) + self.train_test_size_dropdown.grid(row=4, column=0, sticky=NW) + self.train_test_type_dropdown.grid(row=5, column=0, sticky=NW) + self.min_sample_leaf_eb.grid(row=6, column=0, sticky=NW) + self.undersample_settings_dropdown.grid(row=7, column=0, sticky=NW) + self.under_sample_ratio_entrybox.grid(row=8, column=0, sticky=NW) + self.oversample_settings_dropdown.grid(row=9, column=0, sticky=NW) + self.over_sample_ratio_entrybox.grid(row=10, column=0, sticky=NW) + self.class_weights_dropdown.grid(row=11, column=0, sticky=NW) self.evaluations_frm.grid(row=4, column=0, sticky=NW) self.meta_data_file_cb.grid(row=0, column=0, sticky=NW) @@ -669,7 +718,7 @@ def save_config(self): save_path = os.path.join(self.configs_meta_dir, file_name) meta_df.to_csv(save_path, index=FALSE) stdout_success( - msg=f"Hyper-parameter config saved ({str(len(self.total_meta_files)+1)} saved in project_folder/configs folder)." + msg=f"Hyper-parameter config saved ({str(len(self.total_meta_files) + 1)} saved in project_folder/configs folder)." 
) def clear_cache(self): @@ -682,7 +731,7 @@ def clear_cache(self): def check_meta_data_integrity(self): self.meta = {k.lower(): v for k, v in self.meta.items()} - for i in self.expected_meta_dict_entries: + for i in MachineLearningMetaKeys: if i not in self.meta.keys(): stdout_warning( msg=f"The file does not contain an expected entry for {i} parameter" @@ -708,88 +757,84 @@ def load_config(self): self.meta = {} for m in meta_df.columns: self.meta[m] = meta_df[m][0] - self.get_expected_meta_dict_entry_keys() self.check_meta_data_integrity() - self.behavior_name_dropdown.setChoices(self.meta["classifier_name"]) - self.estimators_entrybox.entry_set(val=self.meta["rf_n_estimators"]) - self.max_features_dropdown.setChoices(self.meta["rf_max_features"]) - self.criterion_dropdown.setChoices(self.meta["rf_criterion"]) - self.train_test_size_dropdown.setChoices(self.meta["train_test_size"]) - self.min_sample_leaf_eb.entry_set(val=self.meta["rf_min_sample_leaf"]) - self.undersample_settings_dropdown.setChoices(self.meta["under_sample_setting"]) + self.behavior_name_dropdown.setChoices(self.meta[MachineLearningMetaKeys.CLASSIFIER]) + self.estimators_entrybox.entry_set(val=self.meta[MachineLearningMetaKeys.RF_ESTIMATORS]) + self.max_features_dropdown.setChoices(self.meta[MachineLearningMetaKeys.RF_MAX_FEATURES]) + self.criterion_dropdown.setChoices(self.meta[MachineLearningMetaKeys.RF_CRITERION]) + self.train_test_size_dropdown.setChoices(self.meta[MachineLearningMetaKeys.TT_SIZE]) + self.min_sample_leaf_eb.entry_set(val=self.meta[MachineLearningMetaKeys.MIN_LEAF]) + self.max_depth_eb.entry_set(val=self.meta[MachineLearningMetaKeys.RF_MAX_DEPTH]) + self.undersample_settings_dropdown.setChoices(self.meta[MachineLearningMetaKeys.UNDER_SAMPLE_SETTING]) if self.undersample_settings_dropdown.getChoices() != "None": self.under_sample_ratio_entrybox.entry_set( - val=self.meta["under_sample_ratio"] + val=self.meta[MachineLearningMetaKeys.UNDER_SAMPLE_RATIO] ) self.under_sample_ratio_entrybox.set_state(NORMAL) else: self.under_sample_ratio_entrybox.set_state(DISABLED) - self.oversample_settings_dropdown.setChoices(self.meta["over_sample_setting"]) + self.oversample_settings_dropdown.setChoices(self.meta[MachineLearningMetaKeys.OVER_SAMPLE_SETTING]) if self.oversample_settings_dropdown.getChoices() != "None": self.over_sample_ratio_entrybox.entry_set( - val=self.meta["over_sample_ratio"] + val=self.meta[MachineLearningMetaKeys.OVER_SAMPLE_RATIO] ) self.over_sample_ratio_entrybox.set_state(NORMAL) else: self.over_sample_ratio_entrybox.set_state(DISABLED) - if self.meta["generate_rf_model_meta_data_file"]: + if self.meta[MachineLearningMetaKeys.RF_METADATA]: self.create_meta_data_file_var.set(value=True) else: self.create_meta_data_file_var.set(value=False) - if self.meta["generate_example_decision_tree"]: + if self.meta[MachineLearningMetaKeys.EX_DECISION_TREE]: self.create_example_decision_tree_graphviz_var.set(value=True) else: self.create_example_decision_tree_graphviz_var.set(value=False) - if self.meta["generate_example_decision_tree"]: - self.create_example_decision_tree_graphviz_var.set(value=True) - else: - self.create_example_decision_tree_graphviz_var.set(value=False) - if self.meta["generate_classification_report"]: + if self.meta[MachineLearningMetaKeys.CLF_REPORT]: self.create_clf_report_var.set(value=True) else: self.create_clf_report_var.set(value=False) if ( - self.meta["generate_features_importance_log"] - or self.meta["generate_features_importance_bar_graph"] + 
self.meta[MachineLearningMetaKeys.IMPORTANCE_LOG] + or self.meta[MachineLearningMetaKeys.IMPORTANCE_BAR_CHART] ): self.create_clf_importance_bars_var.set(value=True) self.n_features_bars_entry_box.set_state(NORMAL) self.n_features_bars_entry_box.entry_set( - val=self.meta["n_feature_importance_bars"] + val=self.meta[MachineLearningMetaKeys.N_FEATURE_IMPORTANCE_BARS] ) else: self.create_clf_importance_bars_var.set(value=False) self.n_features_bars_entry_box.set_state(DISABLED) - if self.meta["compute_feature_permutation_importance"]: + if self.meta[MachineLearningMetaKeys.PERMUTATION_IMPORTANCE]: self.feature_permutation_importance_var.set(value=True) - if self.meta["generate_sklearn_learning_curves"]: + if self.meta[MachineLearningMetaKeys.LEARNING_CURVE]: self.learning_curve_var.set(value=True) self.learning_curve_k_splits_entry_box.set_state(NORMAL) self.learning_curve_data_splits_entry_box.set_state(NORMAL) self.learning_curve_k_splits_entry_box.entry_set( - val=self.meta["learning_curve_k_splits"] + val=self.meta[MachineLearningMetaKeys.LEARNING_CURVE_K_SPLITS] ) self.learning_curve_data_splits_entry_box.entry_set( - val=self.meta["learning_curve_data_splits"] + val=self.meta[MachineLearningMetaKeys.LEARNING_CURVE_DATA_SPLITS] ) else: self.learning_curve_var.set(value=False) self.learning_curve_k_splits_entry_box.set_state(DISABLED) self.learning_curve_data_splits_entry_box.set_state(DISABLED) - if self.meta["generate_shap_scores"]: + if self.meta[MachineLearningMetaKeys.SHAP_SCORES]: self.calc_shap_scores_var.set(value=True) self.shap_present.set_state(NORMAL) self.shap_absent.set_state(NORMAL) self.shap_absent.set_state(NORMAL) self.shap_save_it_dropdown.enable() - self.shap_present.entry_set(val=self.meta["shap_target_present_no"]) - self.shap_absent.entry_set(val=self.meta["shap_target_absent_no"]) + self.shap_present.entry_set(val=self.meta[MachineLearningMetaKeys.SHAP_PRESENT]) + self.shap_absent.entry_set(val=self.meta[MachineLearningMetaKeys.SHAP_ABSENT]) if "shap_save_iteration" in self.meta.keys(): - self.shap_save_it_dropdown.setChoices(self.meta["shap_save_iteration"]) + self.shap_save_it_dropdown.setChoices(self.meta[MachineLearningMetaKeys.SHAP_SAVE_ITERATION]) else: self.shap_save_it_dropdown.setChoices("ALL FRAMES") else: @@ -798,24 +843,22 @@ def load_config(self): self.shap_absent.set_state(DISABLED) self.shap_save_it_dropdown.enable() - if "train_test_split_type" in self.meta.keys(): - self.train_test_type_dropdown.setChoices(self.meta["train_test_split_type"]) + if MachineLearningMetaKeys.TRAIN_TEST_SPLIT_TYPE in self.meta.keys(): + self.train_test_type_dropdown.setChoices(self.meta[MachineLearningMetaKeys.TT_SIZE]) else: self.train_test_type_dropdown.setChoices(Options.TRAIN_TEST_SPLIT.value[0]) - if "shap_save_iteration" in self.meta.keys(): - self.shap_save_it_dropdown.setChoices(self.meta["shap_save_iteration"]) - if "partial_dependency" in self.meta.keys(): - if self.meta["partial_dependency"] in Options.RUN_OPTIONS_FLAGS.value: + if MachineLearningMetaKeys.PARTIAL_DEPENDENCY in self.meta.keys(): + if self.meta[MachineLearningMetaKeys.PARTIAL_DEPENDENCY] in Options.RUN_OPTIONS_FLAGS.value: self.partial_dependency_var.set(value=True) else: self.shap_save_it_dropdown.setChoices("None") - if "class_weights" in self.meta.keys(): - if self.meta["class_weights"] not in Options.CLASS_WEIGHT_OPTIONS.value: - self.meta["class_weights"] = "None" - self.class_weights_dropdown.setChoices(self.meta["class_weights"]) - if self.meta["class_weights"] == "custom": + if 
MachineLearningMetaKeys.CLASS_WEIGHTS in self.meta.keys(): + if self.meta[MachineLearningMetaKeys.CLASS_WEIGHTS] not in Options.CLASS_WEIGHT_OPTIONS.value: + self.meta[MachineLearningMetaKeys.CLASS_WEIGHTS] = "None" + self.class_weights_dropdown.setChoices(self.meta[MachineLearningMetaKeys.CLASS_WEIGHTS]) + if self.meta[MachineLearningMetaKeys.CLASS_WEIGHTS] == "custom": self.create_class_weight_table() - weights = ast.literal_eval(self.meta["class_custom_weights"]) + weights = ast.literal_eval(self.meta[MachineLearningMetaKeys.CLASS_CUSTOM_WEIGHTS]) self.weight_present.setChoices(weights[1]) self.weight_absent.setChoices(weights[0]) @@ -858,5 +901,4 @@ def get_expected_meta_dict_entry_keys(self): "class_custom_weights", ] - # _ = MachineModelSettingsPopUp(config_path='/Users/simon/Desktop/envs/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini') diff --git a/simba/utils/config_creator.py b/simba/utils/config_creator.py index 98f338891..5152b50b2 100644 --- a/simba/utils/config_creator.py +++ b/simba/utils/config_creator.py @@ -7,7 +7,7 @@ from typing import List import simba -from simba.utils.enums import ConfigKey, DirNames, Dtypes, Paths +from simba.utils.enums import ConfigKey, DirNames, Dtypes, Paths, MachineLearningMetaKeys from simba.utils.errors import DirectoryExistError from simba.utils.printing import SimbaTimer, stdout_success @@ -36,14 +36,14 @@ class ProjectConfigCreator(object): """ def __init__( - self, - project_path: str, - project_name: str, - target_list: List[str], - pose_estimation_bp_cnt: str, - body_part_config_idx: int, - animal_cnt: int, - file_type: str = "csv", + self, + project_path: str, + project_name: str, + target_list: List[str], + pose_estimation_bp_cnt: str, + body_part_config_idx: int, + animal_cnt: int, + file_type: str = "csv", ): self.simba_dir = os.path.dirname(simba.__file__) self.animal_cnt = animal_cnt @@ -185,76 +185,78 @@ def __create_configparser_config(self): self.config.add_section(ConfigKey.ROI_SETTINGS.value) self.config.add_section(ConfigKey.DIRECTIONALITY_SETTINGS.value) self.config.add_section(ConfigKey.PROCESS_MOVEMENT_SETTINGS.value) - self.config.add_section(ConfigKey.CREATE_ENSEMBLE_SETTINGS.value) self.config[ConfigKey.CREATE_ENSEMBLE_SETTINGS.value][ ConfigKey.POSE_SETTING.value ] = str(self.pose_estimation_bp_cnt) self.config[ConfigKey.CREATE_ENSEMBLE_SETTINGS.value][ - ConfigKey.CLASSIFIER.value + MachineLearningMetaKeys.CLASSIFIER.value ] = Dtypes.NONE.value self.config[ConfigKey.CREATE_ENSEMBLE_SETTINGS.value][ - ConfigKey.TT_SIZE.value + MachineLearningMetaKeys.TT_SIZE.value ] = str(0.20) self.config[ConfigKey.CREATE_ENSEMBLE_SETTINGS.value][ - ConfigKey.UNDERSAMPLE_SETTING.value + MachineLearningMetaKeys.UNDERSAMPLE_SETTING.value ] = Dtypes.NONE.value self.config[ConfigKey.CREATE_ENSEMBLE_SETTINGS.value][ - ConfigKey.UNDERSAMPLE_RATIO.value + MachineLearningMetaKeys.UNDERSAMPLE_RATIO.value ] = Dtypes.NONE.value self.config[ConfigKey.CREATE_ENSEMBLE_SETTINGS.value][ - ConfigKey.OVERSAMPLE_SETTING.value + MachineLearningMetaKeys.OVERSAMPLE_SETTING.value ] = Dtypes.NONE.value self.config[ConfigKey.CREATE_ENSEMBLE_SETTINGS.value][ - ConfigKey.OVERSAMPLE_RATIO.value + MachineLearningMetaKeys.OVERSAMPLE_RATIO.value ] = Dtypes.NONE.value self.config[ConfigKey.CREATE_ENSEMBLE_SETTINGS.value][ - ConfigKey.RF_ESTIMATORS.value + MachineLearningMetaKeys.RF_ESTIMATORS.value ] = str(2000) self.config[ConfigKey.CREATE_ENSEMBLE_SETTINGS.value][ - ConfigKey.MIN_LEAF.value + MachineLearningMetaKeys.MIN_LEAF.value ] = 
str(1) self.config[ConfigKey.CREATE_ENSEMBLE_SETTINGS.value][ - ConfigKey.RF_MAX_FEATURES.value + MachineLearningMetaKeys.RF_MAX_DEPTH.value + ] = str(7) + self.config[ConfigKey.CREATE_ENSEMBLE_SETTINGS.value][ + MachineLearningMetaKeys.RF_MAX_FEATURES.value ] = Dtypes.SQRT.value self.config[ConfigKey.CREATE_ENSEMBLE_SETTINGS.value][ ConfigKey.RF_JOBS.value ] = str(-1) self.config[ConfigKey.CREATE_ENSEMBLE_SETTINGS.value][ - ConfigKey.RF_CRITERION.value + MachineLearningMetaKeys.RF_CRITERION.value ] = Dtypes.ENTROPY.value self.config[ConfigKey.CREATE_ENSEMBLE_SETTINGS.value][ - ConfigKey.RF_METADATA.value + MachineLearningMetaKeys.RF_METADATA.value ] = Dtypes.NONE.value self.config[ConfigKey.CREATE_ENSEMBLE_SETTINGS.value][ - ConfigKey.EX_DECISION_TREE.value + MachineLearningMetaKeys.EX_DECISION_TREE.value ] = Dtypes.NONE.value self.config[ConfigKey.CREATE_ENSEMBLE_SETTINGS.value][ - ConfigKey.EX_DECISION_TREE_FANCY.value + MachineLearningMetaKeys.EX_DECISION_TREE_FANCY.value ] = Dtypes.NONE.value self.config[ConfigKey.CREATE_ENSEMBLE_SETTINGS.value][ - ConfigKey.IMPORTANCE_LOG.value + MachineLearningMetaKeys.IMPORTANCE_LOG.value ] = Dtypes.NONE.value self.config[ConfigKey.CREATE_ENSEMBLE_SETTINGS.value][ - ConfigKey.IMPORTANCE_BAR_CHART.value + MachineLearningMetaKeys.IMPORTANCE_BAR_CHART.value ] = Dtypes.NONE.value self.config[ConfigKey.CREATE_ENSEMBLE_SETTINGS.value][ - ConfigKey.PERMUTATION_IMPORTANCE.value + MachineLearningMetaKeys.PERMUTATION_IMPORTANCE.value ] = Dtypes.NONE.value self.config[ConfigKey.CREATE_ENSEMBLE_SETTINGS.value][ - ConfigKey.LEARNING_CURVE.value + MachineLearningMetaKeys.LEARNING_CURVE.value ] = Dtypes.NONE.value self.config[ConfigKey.CREATE_ENSEMBLE_SETTINGS.value][ - ConfigKey.PRECISION_RECALL.value + MachineLearningMetaKeys.PRECISION_RECALL.value ] = Dtypes.NONE.value self.config[ConfigKey.CREATE_ENSEMBLE_SETTINGS.value][ - ConfigKey.IMPORTANCE_BARS_N.value + MachineLearningMetaKeys.IMPORTANCE_BARS_N.value ] = Dtypes.NONE.value self.config[ConfigKey.CREATE_ENSEMBLE_SETTINGS.value][ - ConfigKey.LEARNING_CURVE_K_SPLITS.value + MachineLearningMetaKeys.LEARNING_CURVE_K_SPLITS.value ] = Dtypes.NONE.value self.config[ConfigKey.CREATE_ENSEMBLE_SETTINGS.value][ - ConfigKey.LEARNING_DATA_SPLITS.value + MachineLearningMetaKeys.LEARNING_DATA_SPLITS.value ] = Dtypes.NONE.value self.config.add_section(ConfigKey.MULTI_ANIMAL_ID_SETTING.value) @@ -269,7 +271,8 @@ def __create_configparser_config(self): self.config[ConfigKey.OUTLIER_SETTINGS.value][ ConfigKey.LOCATION_CRITERION.value ] = Dtypes.NONE.value - self.config[ConfigKey.DIRECTIONALITY_SETTINGS.value][ConfigKey.BODYPART_DIRECTION_VALUE.value] = Dtypes.NONE.value + self.config[ConfigKey.DIRECTIONALITY_SETTINGS.value][ + ConfigKey.BODYPART_DIRECTION_VALUE.value] = Dtypes.NONE.value self.config_path = os.path.join(self.project_folder, "project_config.ini") with open(self.config_path, "w") as file: self.config.write(file) diff --git a/simba/utils/enums.py b/simba/utils/enums.py index 487d2cb66..0c88c48de 100644 --- a/simba/utils/enums.py +++ b/simba/utils/enums.py @@ -38,37 +38,6 @@ class ConfigKey(Enum): MULTI_ANIMAL_ID_SETTING = "Multi animal IDs" MULTI_ANIMAL_IDS = "ID_list" OUTLIER_SETTINGS = "Outlier settings" - CLASS_WEIGHTS = "class_weights" - CUSTOM_WEIGHTS = "custom_weights" - CLASSIFIER = "classifier" - TT_SIZE = "train_test_size" - MODEL_TO_RUN = "model_to_run" - UNDERSAMPLE_SETTING = "under_sample_setting" - OVERSAMPLE_SETTING = "over_sample_setting" - UNDERSAMPLE_RATIO = "under_sample_ratio" - 
OVERSAMPLE_RATIO = "over_sample_ratio" - RF_ESTIMATORS = "RF_n_estimators" - RF_MAX_FEATURES = "RF_max_features" - RF_CRITERION = "RF_criterion" - MIN_LEAF = "RF_min_sample_leaf" - PERMUTATION_IMPORTANCE = "compute_permutation_importance" - LEARNING_CURVE = "generate_learning_curve" - PRECISION_RECALL = "generate_precision_recall_curve" - EX_DECISION_TREE = "generate_example_decision_tree" - EX_DECISION_TREE_FANCY = "generate_example_decision_tree_fancy" - CLF_REPORT = "generate_classification_report" - IMPORTANCE_LOG = "generate_features_importance_log" - PARTIAL_DEPENDENCY = "partial_dependency" - IMPORTANCE_BAR_CHART = "generate_features_importance_bar_graph" - SHAP_SCORES = "generate_shap_scores" - RF_METADATA = "RF_meta_data" - LEARNING_CURVE_K_SPLITS = "LearningCurve_shuffle_k_splits" - LEARNING_DATA_SPLITS = "LearningCurve_shuffle_data_splits" - IMPORTANCE_BARS_N = "N_feature_importance_bars" - SHAP_PRESENT = "shap_target_present_no" - SHAP_ABSENT = "shap_target_absent_no" - SHAP_SAVE_ITERATION = "shap_save_iteration" - SHAP_MULTIPROCESS = "shap_multiprocess" POSE_SETTING = "pose_estimation_body_parts" RF_JOBS = "RF_n_jobs" VALIDATION_VIDEO = "generate_validation_video" @@ -77,7 +46,6 @@ class ConfigKey(Enum): ROI_ANIMAL_CNT = "no_of_animals" DISTANCE_MM = "distance_mm" SKLEARN_BP_PROB_THRESH = "bp_threshold_sklearn" - SPLIT_TYPE = "train_test_split_type" class Paths(Enum): @@ -423,13 +391,13 @@ class Methods(Enum): THIRD_PARTY_ANNOTATION_FILE_NOT_FOUND = "Annotations data file NOT FOUND" -class MetaKeys(Enum): - CLF_NAME = "classifier_name" +class MachineLearningMetaKeys(Enum): + CLASSIFIER = "classifier" RF_ESTIMATORS = "rf_n_estimators" - CRITERION = "rf_criterion" + RF_CRITERION = "rf_criterion" TT_SIZE = "train_test_size" MIN_LEAF = "rf_min_sample_leaf" - META_FILE = "generate_rf_model_meta_data_file" + RF_METADATA = "generate_rf_model_meta_data_file" EX_DECISION_TREE = "generate_example_decision_tree" CLF_REPORT = "generate_classification_report" IMPORTANCE_LOG = "generate_features_importance_log" @@ -438,6 +406,7 @@ class MetaKeys(Enum): LEARNING_CURVE = "generate_sklearn_learning_curves" PRECISION_RECALL = "generate_precision_recall_curves" RF_MAX_FEATURES = "rf_max_features" + RF_MAX_DEPTH = "rf_max_depth" LEARNING_CURVE_K_SPLITS = "learning_curve_k_splits" LEARNING_CURVE_DATA_SPLITS = "learning_curve_data_splits" N_FEATURE_IMPORTANCE_BARS = "n_feature_importance_bars" @@ -447,6 +416,17 @@ class MetaKeys(Enum): SHAP_SAVE_ITERATION = "shap_save_iteration" PARTIAL_DEPENDENCY = "partial_dependency" TRAIN_TEST_SPLIT_TYPE = "train_test_split_type" + UNDERSAMPLE_SETTING = "under_sample_setting" + UNDERSAMPLE_RATIO = "under_sample_ratio" + OVERSAMPLE_SETTING = "over_sample_setting" + OVERSAMPLE_RATIO = "over_sample_ratio" + CLASS_WEIGHTS = "class_weights" + CLASS_CUSTOM_WEIGHTS = "class_custom_weights" + EX_DECISION_TREE_FANCY = "generate_example_decision_tree_fancy" + IMPORTANCE_BARS_N = "N_feature_importance_bars" + LEARNING_DATA_SPLITS = "LearningCurve_shuffle_data_splits" + MODEL_TO_RUN = "model_to_run" + class OS(Enum): From d7c31ff935ee88aa7688a8d4af57617e4e82ef9c Mon Sep 17 00:00:00 2001 From: tzuk polinsky Date: Thu, 21 Dec 2023 15:22:04 +0200 Subject: [PATCH 11/13] trying to merge --- simba/utils/enums.py | 1 - 1 file changed, 1 deletion(-) diff --git a/simba/utils/enums.py b/simba/utils/enums.py index 0c88c48de..6016f41c2 100644 --- a/simba/utils/enums.py +++ b/simba/utils/enums.py @@ -428,7 +428,6 @@ class MachineLearningMetaKeys(Enum): MODEL_TO_RUN = "model_to_run" - 
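A minimal conversion sketch (not taken from the patch): the new MachineLearningMetaKeys.RF_MAX_DEPTH entry is stored and read as a string (data_type=Dtypes.STR.value in train_rf.py above, default "7" in config_creator.py), while scikit-learn's RandomForestClassifier expects an int or None for max_depth. The helper below is hypothetical; read_config_entry, ConfigKey, Dtypes and the new enum come from the modules shown in this series, but the int/None coercion itself is an assumption about how a caller might bridge the two.

    from typing import Optional

    from simba.utils.enums import ConfigKey, Dtypes, MachineLearningMetaKeys
    from simba.utils.read_write import read_config_entry

    def read_rf_max_depth(config) -> Optional[int]:
        # Hypothetical helper: coerce the stored string into the int/None
        # that sklearn's RandomForestClassifier accepts for max_depth.
        raw = read_config_entry(
            config,
            ConfigKey.CREATE_ENSEMBLE_SETTINGS.value,
            MachineLearningMetaKeys.RF_MAX_DEPTH.value,
            data_type=Dtypes.STR.value,
            default_value=Dtypes.NONE.value,
        )
        return None if str(raw).lower() in ("none", "nan", "") else int(raw)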
class OS(Enum): WINDOWS = "Windows" LINUX = "Linux" From a43553bc958087eead5c8c9a01f9875f995201ad Mon Sep 17 00:00:00 2001 From: tzuk polinsky Date: Sat, 23 Dec 2023 17:27:08 +0200 Subject: [PATCH 12/13] trying to merge --- simba/mixins/train_model_mixin.py | 670 ++++++++++++++-------------- simba/model/train_rf.py | 696 +++++++++++++++--------------- simba/utils/enums.py | 60 +-- 3 files changed, 698 insertions(+), 728 deletions(-) diff --git a/simba/mixins/train_model_mixin.py b/simba/mixins/train_model_mixin.py index bcad20232..4949186d1 100644 --- a/simba/mixins/train_model_mixin.py +++ b/simba/mixins/train_model_mixin.py @@ -1,10 +1,5 @@ __author__ = "Simon Nilsson" - -import warnings - -warnings.simplefilter(action="ignore", category=FutureWarning) - import ast import concurrent import configparser @@ -55,7 +50,7 @@ from simba.utils.checks import (check_float, check_if_dir_exists, check_if_valid_input, check_int, check_str) from simba.utils.data import create_color_palette, detect_bouts -from simba.utils.enums import (ConfigKey, Defaults, Dtypes, MetaKeys, Methods, +from simba.utils.enums import (ConfigKey, Defaults, Dtypes, MachineLearningMetaKeys, Methods, Options) from simba.utils.errors import (ClassifierInferenceError, ColumnNotFoundError, CorruptedFileError, DataHeaderError, @@ -73,6 +68,9 @@ NoModuleWarning, NotEnoughDataWarning, SamplingWarning, ShapWarning) +import warnings + +warnings.simplefilter(action="ignore", category=FutureWarning) plt.switch_backend("agg") @@ -83,11 +81,11 @@ def __init__(self): pass def read_all_files_in_folder( - self, - file_paths: List[str], - file_type: str, - classifier_names: Optional[List[str]] = None, - raise_bool_clf_error: bool = True, + self, + file_paths: List[str], + file_type: str, + classifier_names: Optional[List[str]] = None, + raise_bool_clf_error: bool = True, ) -> pd.DataFrame: """ Read in all data files in a folder to a single pd.DataFrame for downstream ML algo. @@ -126,8 +124,8 @@ def read_all_files_in_folder( source=self.__class__.__name__, ) elif ( - len(set(df[clf_name].unique()) - {0, 1}) > 0 - and raise_bool_clf_error + len(set(df[clf_name].unique()) - {0, 1}) > 0 + and raise_bool_clf_error ): raise InvalidInputError( msg=f"The annotation column for a classifier should contain only 0 or 1 values. However, in file {file} the {clf_name} field contains additional value(s): {list(set(df[clf_name].unique()) - {0, 1})}.", @@ -147,8 +145,8 @@ def read_all_files_in_folder( source=self.__class__.__name__, ) df_concat = df_concat.loc[ - :, ~df_concat.columns.str.contains("^Unnamed") - ].fillna(0) + :, ~df_concat.columns.str.contains("^Unnamed") + ].fillna(0) timer.stop_timer() memory_size = get_memory_usage_of_df(df=df_concat) print( @@ -162,7 +160,7 @@ def read_all_files_in_folder( return df_concat.astype(np.float32) def read_in_all_model_names_to_remove( - self, config: configparser.ConfigParser, model_cnt: int, clf_name: str + self, config: configparser.ConfigParser, model_cnt: int, clf_name: str ) -> List[str]: """ Helper to find all field names that are annotations but are not the target. @@ -186,7 +184,7 @@ def read_in_all_model_names_to_remove( return annotation_cols_to_remove def delete_other_annotation_columns( - self, df: pd.DataFrame, annotations_lst: List[str], raise_error: bool = True + self, df: pd.DataFrame, annotations_lst: List[str], raise_error: bool = True ) -> pd.DataFrame: """ Helper to drop fields that contain annotations which are not the target. 
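To make the call order of the mixin helpers touched above easier to follow, here is a rough usage sketch, not taken from the patch: the file path, the "Attack" classifier name and the classifier count are placeholders, and config is assumed to be the project's parsed project_config.ini. It loosely mirrors the sequence used in grid_search_rf.py earlier in this series (read annotated files, drop the other classifiers' annotation columns, then split into features and target).

    import configparser

    from simba.mixins.train_model_mixin import TrainModelMixin

    config = configparser.ConfigParser()
    config.read("project_folder/project_config.ini")  # placeholder path

    mixin = TrainModelMixin()
    data_df = mixin.read_all_files_in_folder(
        file_paths=["project_folder/csv/targets_inserted/Video1.csv"],  # placeholder
        file_type="csv",
        classifier_names=["Attack"],  # placeholder classifier name
    )
    # Drop annotation columns of the other classifiers (assuming 2 in the project).
    cols_to_drop = mixin.read_in_all_model_names_to_remove(config, 2, "Attack")
    x_y_df = mixin.delete_other_annotation_columns(data_df, cols_to_drop)
    # Matches the clf_name: str hint above; grid_search_rf.py passes ["Attack"] instead.
    x_df, y_df = mixin.split_df_to_x_y(x_y_df, "Attack")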
@@ -212,7 +210,7 @@ def delete_other_annotation_columns( return df def split_df_to_x_y( - self, df: pd.DataFrame, clf_name: str + self, df: pd.DataFrame, clf_name: str ) -> (pd.DataFrame, pd.DataFrame): """ Helper to split dataframe into features and target. @@ -232,7 +230,7 @@ def split_df_to_x_y( return df, y def random_undersampler( - self, x_train: np.ndarray, y_train: np.ndarray, sample_ratio: float + self, x_train: np.ndarray, y_train: np.ndarray, sample_ratio: float ) -> (pd.DataFrame, pd.DataFrame): """ Helper to perform random under-sampling of behavior-absent frames in a dataframe. @@ -268,7 +266,7 @@ def random_undersampler( return self.split_df_to_x_y(data_df, y_train.name) def smoteen_oversampler( - self, x_train: pd.DataFrame, y_train: pd.DataFrame, sample_ratio: float + self, x_train: pd.DataFrame, y_train: pd.DataFrame, sample_ratio: float ) -> (np.ndarray, np.ndarray): """ Helper to perform SMOTEEN oversampling of behavior-present annotations. @@ -288,10 +286,10 @@ def smoteen_oversampler( return smt.fit_sample(x_train, y_train) def smote_oversampler( - self, - x_train: pd.DataFrame or np.array, - y_train: pd.DataFrame or np.array, - sample_ratio: float, + self, + x_train: pd.DataFrame or np.array, + y_train: pd.DataFrame or np.array, + sample_ratio: float, ) -> (np.ndarray, np.ndarray): """ Helper to perform SMOTE oversampling of behavior-present annotations. @@ -310,14 +308,14 @@ def smote_oversampler( return smt.fit_sample(x_train, y_train) def calc_permutation_importance( - self, - x_test: np.ndarray, - y_test: np.ndarray, - clf: RandomForestClassifier, - feature_names: List[str], - clf_name: str, - save_dir: Union[str, os.PathLike], - save_file_no: Optional[int] = None, + self, + x_test: np.ndarray, + y_test: np.ndarray, + clf: RandomForestClassifier, + feature_names: List[str], + clf_name: str, + save_dir: Union[str, os.PathLike], + save_file_no: Optional[int] = None, ) -> None: """ Helper to calculate feature permutation importance scores. @@ -372,16 +370,16 @@ def calc_permutation_importance( ) def calc_learning_curve( - self, - x_y_df: pd.DataFrame, - clf_name: str, - shuffle_splits: int, - dataset_splits: int, - tt_size: float, - rf_clf: RandomForestClassifier, - save_dir: str, - save_file_no: Optional[int] = None, - multiclass: bool = False, + self, + x_y_df: pd.DataFrame, + clf_name: str, + shuffle_splits: int, + dataset_splits: int, + tt_size: float, + rf_clf: RandomForestClassifier, + save_dir: str, + save_file_no: Optional[int] = None, + multiclass: bool = False, ) -> None: """ Helper to compute random forest learning curves with cross-validation. @@ -457,13 +455,13 @@ def calc_learning_curve( ) def calc_pr_curve( - self, - rf_clf: RandomForestClassifier, - x_df: pd.DataFrame, - y_df: pd.DataFrame, - clf_name: str, - save_dir: str, - save_file_no: Optional[int] = None, + self, + rf_clf: RandomForestClassifier, + x_df: pd.DataFrame, + y_df: pd.DataFrame, + clf_name: str, + save_dir: str, + save_file_no: Optional[int] = None, ) -> None: """ Helper to compute random forest precision-recall curve. 
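One concrete number may help with the sampling helpers just shown: the under-sample ratio is expressed relative to the behavior-present class, so random_undersampler keeps ratio * present-count behavior-absent frames. That is also why the validation code in grid_search_rf.py above rejects a meta file whose behavior-absent count is smaller than present_len * under_sample_ratio. A tiny worked example with made-up numbers:

    # Hypothetical numbers, for illustration only.
    present_len, under_sample_ratio = 1000, 2.0
    ratio_n = int(present_len * under_sample_ratio)  # 2000 behavior-absent frames required
    # With only 1500 behavior-absent annotations available, the meta-file check fails.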
@@ -488,10 +486,10 @@ def calc_pr_curve( pr_df["PRECISION"] = precision pr_df["RECALL"] = recall pr_df["F1"] = ( - 2 - * pr_df["RECALL"] - * pr_df["PRECISION"] - / (pr_df["RECALL"] + pr_df["PRECISION"]) + 2 + * pr_df["RECALL"] + * pr_df["PRECISION"] + / (pr_df["RECALL"] + pr_df["PRECISION"]) ) thresholds = list(thresholds) thresholds.insert(0, 0.00) @@ -511,13 +509,13 @@ def calc_pr_curve( ) def create_example_dt( - self, - rf_clf: RandomForestClassifier, - clf_name: str, - feature_names: List[str], - class_names: List[str], - save_dir: str, - save_file_no: Optional[int] = None, + self, + rf_clf: RandomForestClassifier, + clf_name: str, + feature_names: List[str], + class_names: List[str], + save_dir: str, + save_file_no: Optional[int] = None, ) -> None: """ Helper to produce visualization of random forest decision tree using graphviz. @@ -557,13 +555,13 @@ def create_example_dt( call(command, shell=True) def create_clf_report( - self, - rf_clf: RandomForestClassifier, - x_df: pd.DataFrame, - y_df: pd.DataFrame, - class_names: List[str], - save_dir: str, - save_file_no: Optional[int] = None, + self, + rf_clf: RandomForestClassifier, + x_df: pd.DataFrame, + y_df: pd.DataFrame, + class_names: List[str], + save_dir: str, + save_file_no: Optional[int] = None, ) -> None: """ Helper to create classifier truth table report. @@ -607,12 +605,12 @@ def create_clf_report( ) def create_x_importance_log( - self, - rf_clf: RandomForestClassifier, - x_names: List[str], - clf_name: str, - save_dir: str, - save_file_no: Optional[int] = None, + self, + rf_clf: RandomForestClassifier, + x_names: List[str], + clf_name: str, + save_dir: str, + save_file_no: Optional[int] = None, ) -> None: """ Helper to save gini or entropy based feature importance scores. @@ -646,13 +644,13 @@ def create_x_importance_log( df.to_csv(self.f_importance_save_path, index=False) def create_x_importance_bar_chart( - self, - rf_clf: RandomForestClassifier, - x_names: list, - clf_name: str, - save_dir: str, - n_bars: int, - save_file_no: Optional[int] = None, + self, + rf_clf: RandomForestClassifier, + x_names: list, + clf_name: str, + save_dir: str, + n_bars: int, + save_file_no: Optional[int] = None, ) -> None: """ Helper to create a bar chart displaying the top N gini or entropy feature importance scores. @@ -709,12 +707,12 @@ def create_x_importance_bar_chart( plt.close("all") def dviz_classification_visualization( - self, - x_train: np.ndarray, - y_train: np.ndarray, - clf_name: str, - class_names: List[str], - save_dir: str, + self, + x_train: np.ndarray, + y_train: np.ndarray, + clf_name: str, + class_names: List[str], + save_dir: str, ) -> None: """ Helper to create visualization of example decision tree using dtreeviz. @@ -756,7 +754,7 @@ def dviz_classification_visualization( @staticmethod def split_and_group_df( - df: pd.DataFrame, splits: int, include_split_order: bool = True + df: pd.DataFrame, splits: int, include_split_order: bool = True ) -> (List[pd.DataFrame], int): """ Helper to split a dataframe for multiprocessing. 
If include_split_order, then include the group number @@ -770,18 +768,18 @@ def split_and_group_df( return data_arr, obs_per_split def create_shap_log( - self, - ini_file_path: str, - rf_clf: RandomForestClassifier, - x_df: pd.DataFrame, - y_df: pd.Series, - x_names: List[str], - clf_name: str, - cnt_present: int, - cnt_absent: int, - save_path: str, - save_it: int = 100, - save_file_no: Optional[int] = None, + self, + ini_file_path: str, + rf_clf: RandomForestClassifier, + x_df: pd.DataFrame, + y_df: pd.Series, + x_names: List[str], + clf_name: str, + cnt_present: int, + cnt_absent: int, + save_path: str, + save_it: int = 100, + save_file_no: Optional[int] = None, ) -> None: """ Compute SHAP values for a random forest classifier. @@ -918,22 +916,23 @@ def print_machine_model_information(self, model_dict: dict) -> None: """ table_view = [ - ["Model name", model_dict[MetaKeys.CLF_NAME.value]], + ["Model name", model_dict[MachineLearningMetaKeys.CLASSIFIER.value]], ["Ensemble method", "RF"], - ["Estimators (trees)", model_dict[MetaKeys.RF_ESTIMATORS.value]], - ["Max features", model_dict[MetaKeys.RF_MAX_FEATURES.value]], - ["Under sampling setting", model_dict[ConfigKey.UNDERSAMPLE_SETTING.value]], - ["Under sampling ratio", model_dict[ConfigKey.UNDERSAMPLE_RATIO.value]], - ["Over sampling setting", model_dict[ConfigKey.OVERSAMPLE_SETTING.value]], - ["Over sampling ratio", model_dict[ConfigKey.OVERSAMPLE_RATIO.value]], - ["criterion", model_dict[MetaKeys.CRITERION.value]], - ["Min sample leaf", model_dict[MetaKeys.MIN_LEAF.value]], + ["Estimators (trees)", model_dict[MachineLearningMetaKeys.RF_ESTIMATORS.value]], + ["Max depth", model_dict[MachineLearningMetaKeys.RF_MAX_DEPTH.value]], + ["Max features", model_dict[MachineLearningMetaKeys.RF_MAX_FEATURES.value]], + ["Under sampling setting", model_dict[MachineLearningMetaKeys.UNDERSAMPLE_SETTING.value]], + ["Under sampling ratio", model_dict[MachineLearningMetaKeys.UNDERSAMPLE_RATIO.value]], + ["Over sampling setting", model_dict[MachineLearningMetaKeys.OVERSAMPLE_SETTING.value]], + ["Over sampling ratio", model_dict[MachineLearningMetaKeys.OVERSAMPLE_RATIO.value]], + ["criterion", model_dict[MachineLearningMetaKeys.RF_CRITERION.value]], + ["Min sample leaf", model_dict[MachineLearningMetaKeys.MIN_LEAF.value]], ] table = tabulate(table_view, ["Setting", "value"], tablefmt="grid") print(f"{table} {Defaults.STR_SPLIT_DELIMITER.value}TABLE") def create_meta_data_csv_training_one_model( - self, meta_data_lst: list, clf_name: str, save_dir: Union[str, os.PathLike] + self, meta_data_lst: list, clf_name: str, save_dir: Union[str, os.PathLike] ) -> None: """ Helper to save single model meta data (hyperparameters, sampling settings etc.) 
from list format into SimBA @@ -951,7 +950,7 @@ def create_meta_data_csv_training_one_model( out_df.to_csv(save_path) def create_meta_data_csv_training_multiple_models( - self, meta_data, clf_name, save_dir, save_file_no: Optional[int] = None + self, meta_data, clf_name, save_dir, save_file_no: Optional[int] = None ) -> None: print("Saving model meta data file...") save_path = os.path.join(save_dir, f"{clf_name}_{str(save_file_no)}_meta.csv") @@ -959,11 +958,11 @@ def create_meta_data_csv_training_multiple_models( out_df.to_csv(save_path) def save_rf_model( - self, - rf_clf: RandomForestClassifier, - clf_name: str, - save_dir: Union[str, os.PathLike], - save_file_no: Optional[int] = None, + self, + rf_clf: RandomForestClassifier, + clf_name: str, + save_dir: Union[str, os.PathLike], + save_file_no: Optional[int] = None, ) -> None: """ Helper to save pickled classifier object to disk. @@ -983,7 +982,7 @@ def save_rf_model( pickle.dump(rf_clf, open(save_path, "wb")) def get_model_info( - self, config: configparser.ConfigParser, model_cnt: int + self, config: configparser.ConfigParser, model_cnt: int ) -> Dict[int, Any]: """ Helper to read in N SimBA random forest config meta files to python dict memory. @@ -1004,8 +1003,8 @@ def get_model_info( ) continue if ( - config.get("SML settings", "model_path_" + str(n + 1)) - == "No file selected" + config.get("SML settings", "model_path_" + str(n + 1)) + == "No file selected" ): MissingUserInputWarning( msg=f'Skipping {str(config.get("SML settings", "target_name_" + str(n + 1)))} classifier analysis: The classifier path is set to "No file selected', @@ -1033,17 +1032,17 @@ def get_model_info( ) check_int("minimum_bout_length", model_dict[n]["minimum_bout_length"]) if config.has_option( - ConfigKey.SML_SETTINGS.value, f"classifier_map_{n+1}" + ConfigKey.SML_SETTINGS.value, f"classifier_map_{n + 1}" ): model_dict[n]["classifier_map"] = config.get( - ConfigKey.SML_SETTINGS.value, f"classifier_map_{n+1}" + ConfigKey.SML_SETTINGS.value, f"classifier_map_{n + 1}" ) model_dict[n]["classifier_map"] = ast.literal_eval( model_dict[n]["classifier_map"] ) if type(model_dict[n]["classifier_map"]) != dict: raise InvalidInputError( - msg=f"SimBA found a classifier map for classifier {n+1} that could not be interpreted as a dictionary", + msg=f"SimBA found a classifier map for classifier {n + 1} that could not be interpreted as a dictionary", source=self.__class__.__name__, ) @@ -1062,7 +1061,7 @@ def get_model_info( return model_dict def get_all_clf_names( - self, config: configparser.ConfigParser, target_cnt: int + self, config: configparser.ConfigParser, target_cnt: int ) -> List[str]: """ Helper to get all classifier names in a SimBA project. @@ -1090,10 +1089,10 @@ def get_all_clf_names( return model_names def insert_column_headers_for_outlier_correction( - self, - data_df: pd.DataFrame, - new_headers: List[str], - filepath: Union[str, os.PathLike], + self, + data_df: pd.DataFrame, + new_headers: List[str], + filepath: Union[str, os.PathLike], ) -> pd.DataFrame: """ Helper to insert new column headers onto a dataframe following outlier correction. @@ -1139,7 +1138,7 @@ def read_pickle(self, file_path: Union[str, os.PathLike]) -> object: return clf def bout_train_test_splitter( - self, x_df: pd.DataFrame, y_df: pd.Series, test_size: float + self, x_df: pd.DataFrame, y_df: pd.Series, test_size: float ) -> (pd.DataFrame, pd.DataFrame, pd.Series, pd.Series): """ Helper to split train and test based on annotated `bouts`. 
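save_rf_model and read_pickle, both touched only cosmetically here, amount to a plain pickle round-trip of the fitted classifier. A self-contained sketch of that pattern (the save path and toy training data are invented):

import os
import pickle

from sklearn.ensemble import RandomForestClassifier

# Fit a tiny classifier on toy data, then round-trip it through pickle.
clf = RandomForestClassifier(n_estimators=10, random_state=0)
clf.fit([[0.0, 1.0], [1.0, 0.0], [0.5, 0.5], [0.9, 0.1]], [0, 1, 0, 1])

save_path = os.path.join("generated_models", "attack.sav")   # hypothetical location and name
os.makedirs(os.path.dirname(save_path), exist_ok=True)
with open(save_path, "wb") as fh:                            # save_rf_model equivalent
    pickle.dump(clf, fh)

with open(save_path, "rb") as fh:                            # read_pickle equivalent
    restored = pickle.load(fh)
print(restored.predict([[0.8, 0.2]]))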
@@ -1209,9 +1208,9 @@ def find_bouts(s: pd.Series, type: str): @staticmethod @njit("(float32[:, :], float64, types.ListType(types.unicode_type))") def find_highly_correlated_fields( - data: np.ndarray, - threshold: float, - field_names: types.ListType(types.unicode_type), + data: np.ndarray, + threshold: float, + field_names: types.ListType(types.unicode_type), ) -> List[str]: """ Find highly correlated fields in a dataset. @@ -1251,7 +1250,7 @@ def find_highly_correlated_fields( return [field_names[x] for x in remove_col_idx] def check_sampled_dataset_integrity( - self, x_df: pd.DataFrame, y_df: pd.DataFrame + self, x_df: pd.DataFrame, y_df: pd.DataFrame ) -> None: """ Helper to check for non-numerical entries post data sampling @@ -1270,15 +1269,15 @@ def check_sampled_dataset_integrity( if len(x_nan_cnt) < 10: raise FaultyTrainingSetError( msg=f"{str(len(x_nan_cnt))} feature column(s) exist in some files within the project_folder/csv/targets_inserted directory, but missing in others. " - f"SimBA expects all files within the project_folder/csv/targets_inserted directory to have the same number of features: the " - f"column names with mismatches are: {list(x_nan_cnt.index)}", + f"SimBA expects all files within the project_folder/csv/targets_inserted directory to have the same number of features: the " + f"column names with mismatches are: {list(x_nan_cnt.index)}", source=self.__class__.__name__, ) else: raise FaultyTrainingSetError( msg=f"{str(len(x_nan_cnt))} feature columns exist in some files, but missing in others. The feature files are found in the project_folder/csv/targets_inserted directory. " - f"SimBA expects all files within the project_folder/csv/targets_inserted directory to have the same number of features: the first 10 " - f"column names with mismatches are: {list(x_nan_cnt.index)[0:9]}", + f"SimBA expects all files within the project_folder/csv/targets_inserted directory to have the same number of features: the first 10 " + f"column names with mismatches are: {list(x_nan_cnt.index)[0:9]}", source=self.__class__.__name__, ) @@ -1295,12 +1294,12 @@ def check_sampled_dataset_integrity( ) def partial_dependence_calculator( - self, - clf: RandomForestClassifier, - x_df: pd.DataFrame, - clf_name: str, - save_dir: Union[str, os.PathLike], - clf_cnt: Optional[int] = None, + self, + clf: RandomForestClassifier, + x_df: pd.DataFrame, + clf_name: str, + save_dir: Union[str, os.PathLike], + clf_cnt: Optional[int] = None, ) -> None: """ Compute feature partial dependencies for every feature in training set. 
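find_highly_correlated_fields flags one member of each feature pair whose correlation exceeds a threshold, using a numba-compiled kernel; the hunk above only reflows its signature. The same idea can be sketched in plain pandas (toy columns, no numba), which may help when reviewing the threshold logic:

import numpy as np
import pandas as pd

def highly_correlated_fields(df: pd.DataFrame, threshold: float) -> list:
    # Absolute pairwise correlations; keep only the upper triangle so each pair is counted once.
    corr = df.corr().abs()
    upper = corr.where(np.triu(np.ones(corr.shape, dtype=bool), k=1))
    return [col for col in upper.columns if (upper[col] > threshold).any()]

df = pd.DataFrame({"a": [1, 2, 3, 4], "b": [2, 4, 6, 8], "c": [4, 1, 3, 2]})
print(highly_correlated_fields(df, threshold=0.95))   # 'b' duplicates 'a', so it is flagged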
@@ -1337,12 +1336,12 @@ def partial_dependence_calculator( print(f"Partial dependencies for {feature_name} complete...") def clf_predict_proba( - self, - clf: RandomForestClassifier, - x_df: pd.DataFrame, - multiclass: bool = False, - model_name: Optional[str] = None, - data_path: Optional[Union[str, os.PathLike]] = None, + self, + clf: RandomForestClassifier, + x_df: pd.DataFrame, + multiclass: bool = False, + model_name: Optional[str] = None, + data_path: Optional[Union[str, os.PathLike]] = None, ) -> np.ndarray: """ @@ -1392,7 +1391,7 @@ def clf_predict_proba( return p_vals def clf_fit( - self, clf: RandomForestClassifier, x_df: pd.DataFrame, y_df: pd.DataFrame + self, clf: RandomForestClassifier, x_df: pd.DataFrame, y_df: pd.DataFrame ) -> RandomForestClassifier: """ Helper to fit clf model @@ -1417,11 +1416,11 @@ def clf_fit( return clf.fit(x_df, y_df) def _read_data_file_helper( - self, - file_path: str, - file_type: str, - clf_names: Optional[List[str]] = None, - raise_bool_clf_error: bool = True, + self, + file_path: str, + file_type: str, + clf_names: Optional[List[str]] = None, + raise_bool_clf_error: bool = True, ): """ Private function called by :meth:`simba.train_model_functions.read_all_files_in_folder_mp` @@ -1440,8 +1439,8 @@ def _read_data_file_helper( source=self.__class__.__name__, ) elif ( - len(set(df[clf_name].unique()) - {0, 1}) > 0 - and raise_bool_clf_error + len(set(df[clf_name].unique()) - {0, 1}) > 0 + and raise_bool_clf_error ): raise InvalidInputError( msg=f"The annotation column for a classifier should contain only 0 or 1 values. However, in file {file_path} the {clf_name} field contains additional value(s): {list(set(df[clf_name].unique()) - {0, 1})}.", @@ -1454,11 +1453,11 @@ def _read_data_file_helper( return df def read_all_files_in_folder_mp( - self, - file_paths: List[str], - file_type: Literal["csv", "parquet", "pickle"], - classifier_names: Optional[List[str]] = None, - raise_bool_clf_error: bool = True, + self, + file_paths: List[str], + file_type: Literal["csv", "parquet", "pickle"], + classifier_names: Optional[List[str]] = None, + raise_bool_clf_error: bool = True, ) -> pd.DataFrame: """ @@ -1483,11 +1482,11 @@ def read_all_files_in_folder_mp( try: with ProcessPoolExecutor(int(np.ceil(cpu_cnt / 2))) as pool: for res in pool.map( - self._read_data_file_helper, - file_paths, - repeat(file_type), - repeat(classifier_names), - repeat(raise_bool_clf_error), + self._read_data_file_helper, + file_paths, + repeat(file_type), + repeat(classifier_names), + repeat(raise_bool_clf_error), ): df_lst.append(res) df_concat = pd.concat(df_lst, axis=0).round(4) @@ -1499,8 +1498,8 @@ def read_all_files_in_folder_mp( source=self.read_all_files_in_folder_mp.__name__, ) df_concat = df_concat.loc[ - :, ~df_concat.columns.str.contains("^Unnamed") - ].astype(np.float32) + :, ~df_concat.columns.str.contains("^Unnamed") + ].astype(np.float32) memory_size = get_memory_usage_of_df(df=df_concat) print( f'Dataset size: {memory_size["megabytes"]}MB / {memory_size["gigabytes"]}GB' @@ -1521,10 +1520,10 @@ def read_all_files_in_folder_mp( @staticmethod def _read_data_file_helper_futures( - file_path: str, - file_type: str, - clf_names: Optional[List[str]] = None, - raise_bool_clf_error: bool = True, + file_path: str, + file_type: str, + clf_names: Optional[List[str]] = None, + raise_bool_clf_error: bool = True, ): """ Private function called by :meth:`simba.train_model_functions.read_all_files_in_folder_mp_futures` @@ -1539,8 +1538,8 @@ def _read_data_file_helper_futures( if not 
clf_name in df.columns: raise ColumnNotFoundError(column_name=clf_name, file_name=file_path) elif ( - len(set(df[clf_name].unique()) - {0, 1}) > 0 - and raise_bool_clf_error + len(set(df[clf_name].unique()) - {0, 1}) > 0 + and raise_bool_clf_error ): raise InvalidInputError( msg=f"The annotation column for a classifier should contain only 0 or 1 values. However, in file {file_path} the {clf_name} field contains additional value(s): {list(set(df[clf_name].unique()) - {0, 1})}." @@ -1549,11 +1548,11 @@ def _read_data_file_helper_futures( return df, vid_name, timer.elapsed_time_str def read_all_files_in_folder_mp_futures( - self, - file_paths: List[str], - file_type: Literal["csv", "parquet", "pickle"], - classifier_names: Optional[List[str]] = None, - raise_bool_clf_error: bool = True, + self, + file_paths: List[str], + file_type: Literal["csv", "parquet", "pickle"], + classifier_names: Optional[List[str]] = None, + raise_bool_clf_error: bool = True, ) -> pd.DataFrame: """ Multiprocessing helper function to read in all data files in a folder to a single @@ -1579,7 +1578,7 @@ def read_all_files_in_folder_mp_futures( cpu_cnt, _ = find_core_cnt() df_lst = [] with concurrent.futures.ProcessPoolExecutor( - max_workers=cpu_cnt + max_workers=cpu_cnt ) as executor: results = [ executor.submit( @@ -1613,7 +1612,7 @@ def read_all_files_in_folder_mp_futures( ) def check_raw_dataset_integrity( - self, df: pd.DataFrame, logs_path: Optional[Union[str, os.PathLike]] + self, df: pd.DataFrame, logs_path: Optional[Union[str, os.PathLike]] ) -> None: """ Helper to check column-wise NaNs in raw input data for fitting model. @@ -1650,18 +1649,18 @@ def check_raw_dataset_integrity( results.to_csv(save_log_path) raise FaultyTrainingSetError( msg=f"{len(nan_cols)} feature columns exist in some files, but missing in others. The feature files are found in the project_folder/csv/targets_inserted directory. " - f"SimBA expects all files within the project_folder/csv/targets_inserted directory to have the same number of features: the first 10 " - f"column names with mismatches are: {nan_cols[0:9]}. For a log of the files that contain, and not contain, the mis-matched columns, see {save_log_path}", + f"SimBA expects all files within the project_folder/csv/targets_inserted directory to have the same number of features: the first 10 " + f"column names with mismatches are: {nan_cols[0:9]}. 
For a log of the files that contain, and not contain, the mis-matched columns, see {save_log_path}", source=self.__class__.__name__, ) @staticmethod def _create_shap_mp_helper( - data: pd.DataFrame, - explainer: shap.TreeExplainer, - clf_name: str, - rf_clf: RandomForestClassifier, - expected_value: float, + data: pd.DataFrame, + explainer: shap.TreeExplainer, + clf_name: str, + rf_clf: RandomForestClassifier, + expected_value: float, ): target = data.pop(clf_name).values.reshape(-1, 1) frame_batch_shap = explainer.shap_values(data.values, check_additivity=False)[1] @@ -1680,7 +1679,7 @@ def _create_shap_mp_helper( @staticmethod def _create_shap_mp_helper( - data: pd.DataFrame, explainer: shap.TreeExplainer, clf_name: str + data: pd.DataFrame, explainer: shap.TreeExplainer, clf_name: str ): target = data.pop(clf_name).values.reshape(-1, 1) group_cnt = data.pop("group").values[0] @@ -1693,18 +1692,18 @@ def _create_shap_mp_helper( return shap_vals, data.values, target def create_shap_log_mp( - self, - ini_file_path: str, - rf_clf: RandomForestClassifier, - x_df: pd.DataFrame, - y_df: pd.DataFrame, - x_names: List[str], - clf_name: str, - cnt_present: int, - cnt_absent: int, - save_path: str, - batch_size: int = 10, - save_file_no: Optional[int] = None, + self, + ini_file_path: str, + rf_clf: RandomForestClassifier, + x_df: pd.DataFrame, + y_df: pd.DataFrame, + x_names: List[str], + clf_name: str, + cnt_present: int, + cnt_absent: int, + save_path: str, + batch_size: int = 10, + save_file_no: Optional[int] = None, ) -> None: """ Helper to compute SHAP values using multiprocessing. @@ -1790,10 +1789,10 @@ def create_shap_log_mp( self._create_shap_mp_helper, explainer=explainer, clf_name=clf_name ) for cnt, result in enumerate( - pool.imap_unordered(constants, shap_data, chunksize=1) + pool.imap_unordered(constants, shap_data, chunksize=1) ): print( - f"Concatenating multi-processed SHAP data (batch {cnt+1}/{len(shap_data)})" + f"Concatenating multi-processed SHAP data (batch {cnt + 1}/{len(shap_data)})" ) proba = rf_clf.predict_proba(result[1])[:, 1].reshape(-1, 1) shap_sum = np.sum(result[0], axis=1).reshape(-1, 1) @@ -1816,7 +1815,7 @@ def create_shap_log_mp( shap_save_df = pd.DataFrame( data=np.row_stack(shap_results), columns=list(x_names) - + ["Expected_value", "Sum", "Prediction_probability", clf_name], + + ["Expected_value", "Sum", "Prediction_probability", clf_name], ) raw_save_df = pd.DataFrame( data=np.row_stack(shap_raw), columns=list(x_names) @@ -1856,7 +1855,7 @@ def create_shap_log_mp( ) def check_df_dataset_integrity( - self, df: pd.DataFrame, file_name: str, logs_path: Union[str, os.PathLike] + self, df: pd.DataFrame, file_name: str, logs_path: Union[str, os.PathLike] ) -> None: """ Helper to check for non-numerical np.inf, -np.inf, NaN, None in a single dataframe. 
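read_all_files_in_folder_mp and its futures variant fan the per-file reads out over a process pool, concatenate the results, and drop "Unnamed" index columns afterwards. A stripped-down sketch of that pattern (the reader helper and file paths are placeholders, not SimBA's reader):

import concurrent.futures
import multiprocessing

import pandas as pd

def _read_one(path: str) -> pd.DataFrame:
    # Placeholder per-file reader: load one CSV and remember where it came from.
    df = pd.read_csv(path, index_col=0)
    df["source_file"] = path
    return df

def read_all(paths: list) -> pd.DataFrame:
    workers = max(1, multiprocessing.cpu_count() // 2)
    with concurrent.futures.ProcessPoolExecutor(max_workers=workers) as pool:
        frames = list(pool.map(_read_one, paths))
    out = pd.concat(frames, axis=0)
    return out.loc[:, ~out.columns.str.contains("^Unnamed")].fillna(0)

if __name__ == "__main__":
    # Hypothetical file list; on Windows the pool must be started under this guard.
    data = read_all(["targets_inserted/video_1.csv", "targets_inserted/video_2.csv"])
    print(data.shape)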
@@ -1896,25 +1895,25 @@ def read_model_settings_from_config(self, config: configparser.ConfigParser): self.clf_name = read_config_entry( config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.CLASSIFIER.value, + MachineLearningMetaKeys.CLASSIFIER.value, data_type=Dtypes.STR.value, ) self.tt_size = read_config_entry( config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.TT_SIZE.value, + MachineLearningMetaKeys.TT_SIZE.value, data_type=Dtypes.FLOAT.value, ) self.algo = read_config_entry( config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.MODEL_TO_RUN.value, + MachineLearningMetaKeys.MODEL_TO_RUN.value, data_type=Dtypes.STR.value, ) self.split_type = read_config_entry( config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.SPLIT_TYPE.value, + MachineLearningMetaKeys.TRAIN_TEST_SPLIT_TYPE.value, data_type=Dtypes.STR.value, options=Options.TRAIN_TEST_SPLIT.value, default_value=Methods.SPLIT_TYPE_FRAMES.value, @@ -1923,7 +1922,7 @@ def read_model_settings_from_config(self, config: configparser.ConfigParser): read_config_entry( config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.UNDERSAMPLE_SETTING.value, + MachineLearningMetaKeys.UNDERSAMPLE_SETTING.value, data_type=Dtypes.STR.value, ) .lower() @@ -1933,7 +1932,7 @@ def read_model_settings_from_config(self, config: configparser.ConfigParser): read_config_entry( config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.OVERSAMPLE_SETTING.value, + MachineLearningMetaKeys.OVERSAMPLE_SETTING.value, data_type=Dtypes.STR.value, ) .lower() @@ -1942,102 +1941,102 @@ def read_model_settings_from_config(self, config: configparser.ConfigParser): self.n_estimators = read_config_entry( config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.RF_ESTIMATORS.value, + MachineLearningMetaKeys.RF_ESTIMATORS.value, data_type=Dtypes.INT.value, ) self.max_features = read_config_entry( config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.RF_MAX_FEATURES.value, + MachineLearningMetaKeys.RF_MAX_FEATURES.value, data_type=Dtypes.STR.value, ) self.criterion = read_config_entry( config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.RF_CRITERION.value, + MachineLearningMetaKeys.RF_CRITERION.value, data_type=Dtypes.STR.value, options=Options.CLF_CRITERION.value, ) self.min_sample_leaf = read_config_entry( config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.MIN_LEAF.value, + MachineLearningMetaKeys.MIN_LEAF.value, data_type=Dtypes.INT.value, ) self.compute_permutation_importance = read_config_entry( config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.PERMUTATION_IMPORTANCE.value, + MachineLearningMetaKeys.PERMUTATION_IMPORTANCE.value, data_type=Dtypes.STR.value, default_value=False, ) self.generate_learning_curve = read_config_entry( config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.LEARNING_CURVE.value, + MachineLearningMetaKeys.LEARNING_CURVE.value, data_type=Dtypes.STR.value, default_value=False, ) self.generate_precision_recall_curve = read_config_entry( config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.PRECISION_RECALL.value, + MachineLearningMetaKeys.PRECISION_RECALL.value, data_type=Dtypes.STR.value, default_value=False, ) self.generate_example_decision_tree = read_config_entry( config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.EX_DECISION_TREE.value, + MachineLearningMetaKeys.EX_DECISION_TREE.value, data_type=Dtypes.STR.value, default_value=False, ) self.generate_classification_report = read_config_entry( config, 
ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.CLF_REPORT.value, + MachineLearningMetaKeys.CLF_REPORT.value, data_type=Dtypes.STR.value, default_value=False, ) self.generate_features_importance_log = read_config_entry( config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.IMPORTANCE_LOG.value, + MachineLearningMetaKeys.IMPORTANCE_LOG.value, data_type=Dtypes.STR.value, default_value=False, ) self.generate_features_importance_bar_graph = read_config_entry( config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.IMPORTANCE_LOG.value, + MachineLearningMetaKeys.IMPORTANCE_LOG.value, data_type=Dtypes.STR.value, default_value=False, ) self.generate_example_decision_tree_fancy = read_config_entry( config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.EX_DECISION_TREE_FANCY.value, + MachineLearningMetaKeys.EX_DECISION_TREE_FANCY.value, data_type=Dtypes.STR.value, default_value=False, ) self.generate_shap_scores = read_config_entry( config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.SHAP_SCORES.value, + MachineLearningMetaKeys.SHAP_SCORES.value, data_type=Dtypes.STR.value, default_value=False, ) self.save_meta_data = read_config_entry( config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.RF_METADATA.value, + MachineLearningMetaKeys.RF_METADATA.value, data_type=Dtypes.STR.value, default_value=False, ) self.compute_partial_dependency = read_config_entry( config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.PARTIAL_DEPENDENCY.value, + MachineLearningMetaKeys.PARTIAL_DEPENDENCY.value, data_type=Dtypes.STR.value, default_value=False, ) @@ -2045,38 +2044,38 @@ def read_model_settings_from_config(self, config: configparser.ConfigParser): self.under_sample_ratio = read_config_entry( config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.UNDERSAMPLE_RATIO.value, + MachineLearningMetaKeys.UNDERSAMPLE_RATIO.value, data_type=Dtypes.FLOAT.value, default_value=Dtypes.NAN.value, ) check_float( - name=ConfigKey.UNDERSAMPLE_RATIO.value, value=self.under_sample_ratio + name=MachineLearningMetaKeys.UNDERSAMPLE_RATIO.value, value=self.under_sample_ratio ) else: self.under_sample_ratio = Dtypes.NAN.value if (self.over_sample_setting == Methods.SMOTEENN.value.lower()) or ( - self.over_sample_setting == Methods.SMOTE.value.lower() + self.over_sample_setting == Methods.SMOTE.value.lower() ): self.over_sample_ratio = read_config_entry( config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.OVERSAMPLE_RATIO.value, + MachineLearningMetaKeys.OVERSAMPLE_RATIO.value, data_type=Dtypes.FLOAT.value, default_value=Dtypes.NAN.value, ) check_float( - name=ConfigKey.OVERSAMPLE_RATIO.value, value=self.over_sample_ratio + name=MachineLearningMetaKeys.OVERSAMPLE_RATIO.value, value=self.over_sample_ratio ) else: self.over_sample_ratio = Dtypes.NAN.value if config.has_option( - ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, ConfigKey.CLASS_WEIGHTS.value + ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, MachineLearningMetaKeys.CLASS_WEIGHTS.value ): self.class_weights = read_config_entry( config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.CLASS_WEIGHTS.value, + MachineLearningMetaKeys.CLASS_WEIGHTS.value, data_type=Dtypes.STR.value, default_value=Dtypes.NONE.value, ) @@ -2085,7 +2084,7 @@ def read_model_settings_from_config(self, config: configparser.ConfigParser): read_config_entry( config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.CUSTOM_WEIGHTS.value, + MachineLearningMetaKeys.CLASS_CUSTOM_WEIGHTS.value, data_type=Dtypes.STR.value, ) ) @@ -2100,22 +2099,22 
@@ def read_model_settings_from_config(self, config: configparser.ConfigParser): self.shuffle_splits = read_config_entry( config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.LEARNING_CURVE_K_SPLITS.value, + MachineLearningMetaKeys.LEARNING_CURVE_K_SPLITS.value, data_type=Dtypes.INT.value, default_value=Dtypes.NAN.value, ) self.dataset_splits = read_config_entry( config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.LEARNING_DATA_SPLITS.value, + MachineLearningMetaKeys.LEARNING_DATA_SPLITS.value, data_type=Dtypes.INT.value, default_value=Dtypes.NAN.value, ) check_int( - name=ConfigKey.LEARNING_CURVE_K_SPLITS.value, value=self.shuffle_splits + name=MachineLearningMetaKeys.LEARNING_CURVE_K_SPLITS.value, value=self.shuffle_splits ) check_int( - name=ConfigKey.LEARNING_DATA_SPLITS.value, value=self.dataset_splits + name=MachineLearningMetaKeys.LEARNING_DATA_SPLITS.value, value=self.dataset_splits ) else: self.shuffle_splits, self.dataset_splits = ( @@ -2126,12 +2125,12 @@ def read_model_settings_from_config(self, config: configparser.ConfigParser): self.feature_importance_bars = read_config_entry( config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.IMPORTANCE_BARS_N.value, + MachineLearningMetaKeys.IMPORTANCE_BARS_N.value, Dtypes.INT.value, Dtypes.NAN.value, ) check_int( - name=ConfigKey.IMPORTANCE_BARS_N.value, + name=MachineLearningMetaKeys.IMPORTANCE_BARS_N.value, value=self.feature_importance_bars, min_value=1, ) @@ -2147,21 +2146,21 @@ def read_model_settings_from_config(self, config: configparser.ConfigParser): self.shap_target_present_cnt = read_config_entry( config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.SHAP_PRESENT.value, + MachineLearningMetaKeys.SHAP_PRESENT.value, data_type=Dtypes.INT.value, default_value=0, ) self.shap_target_absent_cnt = read_config_entry( config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.SHAP_ABSENT.value, + MachineLearningMetaKeys.SHAP_ABSENT.value, data_type=Dtypes.INT.value, default_value=0, ) self.shap_save_n = read_config_entry( config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.SHAP_SAVE_ITERATION.value, + MachineLearningMetaKeys.SHAP_SAVE_ITERATION.value, data_type=Dtypes.STR.value, default_value=Dtypes.NONE.value, ) @@ -2176,17 +2175,17 @@ def read_model_settings_from_config(self, config: configparser.ConfigParser): self.shap_save_n = int(self.shap_save_n) except ValueError or TypeError: self.shap_save_n = ( - self.shap_target_present_cnt + self.shap_target_absent_cnt + self.shap_target_present_cnt + self.shap_target_absent_cnt ) check_int( - name=ConfigKey.SHAP_PRESENT.value, value=self.shap_target_present_cnt + name=MachineLearningMetaKeys.SHAP_PRESENT.value, value=self.shap_target_present_cnt ) check_int( - name=ConfigKey.SHAP_ABSENT.value, value=self.shap_target_absent_cnt + name=MachineLearningMetaKeys.SHAP_ABSENT.value, value=self.shap_target_absent_cnt ) def check_validity_of_meta_files( - self, data_df: pd.DataFrame, meta_file_paths: List[Union[str, os.PathLike]] + self, data_df: pd.DataFrame, meta_file_paths: List[Union[str, os.PathLike]] ): meta_dicts, errors = {}, [] for config_cnt, path in enumerate(meta_file_paths): @@ -2195,48 +2194,48 @@ def check_validity_of_meta_files( meta_dict = {k.lower(): v for k, v in meta_dict.items()} errors.append( check_str( - name=meta_dict[MetaKeys.CLF_NAME.value], - value=meta_dict[MetaKeys.CLF_NAME.value], + name=meta_dict[MachineLearningMetaKeys.CLASSIFIER.value], + value=meta_dict[MachineLearningMetaKeys.CLASSIFIER.value], raise_error=False, 
)[1] ) errors.append( check_str( - name=MetaKeys.CRITERION.value, - value=meta_dict[MetaKeys.CRITERION.value], + name=MachineLearningMetaKeys.RF_CRITERION.value, + value=meta_dict[MachineLearningMetaKeys.RF_CRITERION.value], options=Options.CLF_CRITERION.value, raise_error=False, )[1] ) errors.append( check_str( - name=MetaKeys.RF_MAX_FEATURES.value, - value=meta_dict[MetaKeys.RF_MAX_FEATURES.value], + name=MachineLearningMetaKeys.RF_MAX_FEATURES.value, + value=meta_dict[MachineLearningMetaKeys.RF_MAX_FEATURES.value], options=Options.CLF_MAX_FEATURES.value, raise_error=False, )[1] ) errors.append( check_str( - ConfigKey.UNDERSAMPLE_SETTING.value, - meta_dict[ConfigKey.UNDERSAMPLE_SETTING.value].lower(), + MachineLearningMetaKeys.UNDERSAMPLE_SETTING.value, + meta_dict[MachineLearningMetaKeys.UNDERSAMPLE_SETTING.value].lower(), options=[x.lower() for x in Options.UNDERSAMPLE_OPTIONS.value], raise_error=False, )[1] ) errors.append( check_str( - ConfigKey.OVERSAMPLE_SETTING.value, - meta_dict[ConfigKey.OVERSAMPLE_SETTING.value].lower(), + MachineLearningMetaKeys.OVERSAMPLE_SETTING.value, + meta_dict[MachineLearningMetaKeys.OVERSAMPLE_SETTING.value].lower(), options=[x.lower() for x in Options.OVERSAMPLE_OPTIONS.value], raise_error=False, )[1] ) - if MetaKeys.TRAIN_TEST_SPLIT_TYPE.value in meta_dict.keys(): + if MachineLearningMetaKeys.TRAIN_TEST_SPLIT_TYPE.value in meta_dict.keys(): errors.append( check_str( - name=meta_dict[MetaKeys.TRAIN_TEST_SPLIT_TYPE.value], - value=meta_dict[MetaKeys.TRAIN_TEST_SPLIT_TYPE.value], + name=meta_dict[MachineLearningMetaKeys.TRAIN_TEST_SPLIT_TYPE.value], + value=meta_dict[MachineLearningMetaKeys.TRAIN_TEST_SPLIT_TYPE.value], options=Options.TRAIN_TEST_SPLIT.value, raise_error=False, )[1] @@ -2244,178 +2243,178 @@ def check_validity_of_meta_files( errors.append( check_int( - name=MetaKeys.RF_ESTIMATORS.value, - value=meta_dict[MetaKeys.RF_ESTIMATORS.value], + name=MachineLearningMetaKeys.RF_ESTIMATORS.value, + value=meta_dict[MachineLearningMetaKeys.RF_ESTIMATORS.value], min_value=1, raise_error=False, )[1] ) errors.append( check_int( - name=MetaKeys.MIN_LEAF.value, - value=meta_dict[MetaKeys.MIN_LEAF.value], + name=MachineLearningMetaKeys.MIN_LEAF.value, + value=meta_dict[MachineLearningMetaKeys.MIN_LEAF.value], raise_error=False, )[1] ) - if meta_dict[MetaKeys.LEARNING_CURVE.value] in Options.PERFORM_FLAGS.value: + if meta_dict[MachineLearningMetaKeys.LEARNING_CURVE.value] in Options.PERFORM_FLAGS.value: errors.append( check_int( - name=MetaKeys.LEARNING_CURVE_K_SPLITS.value, - value=meta_dict[MetaKeys.LEARNING_CURVE_K_SPLITS.value], + name=MachineLearningMetaKeys.LEARNING_CURVE_K_SPLITS.value, + value=meta_dict[MachineLearningMetaKeys.LEARNING_CURVE_K_SPLITS.value], raise_error=False, )[1] ) errors.append( check_int( - name=MetaKeys.LEARNING_CURVE_DATA_SPLITS.value, - value=meta_dict[MetaKeys.LEARNING_CURVE_DATA_SPLITS.value], + name=MachineLearningMetaKeys.LEARNING_CURVE_DATA_SPLITS.value, + value=meta_dict[MachineLearningMetaKeys.LEARNING_CURVE_DATA_SPLITS.value], raise_error=False, )[1] ) if ( - meta_dict[MetaKeys.IMPORTANCE_BAR_CHART.value] - in Options.PERFORM_FLAGS.value + meta_dict[MachineLearningMetaKeys.IMPORTANCE_BAR_CHART.value] + in Options.PERFORM_FLAGS.value ): errors.append( check_int( - name=MetaKeys.N_FEATURE_IMPORTANCE_BARS.value, - value=meta_dict[MetaKeys.N_FEATURE_IMPORTANCE_BARS.value], + name=MachineLearningMetaKeys.N_FEATURE_IMPORTANCE_BARS.value, + value=meta_dict[MachineLearningMetaKeys.N_FEATURE_IMPORTANCE_BARS.value], 
raise_error=False, )[1] ) - if MetaKeys.SHAP_SCORES.value in meta_dict.keys(): - if meta_dict[MetaKeys.SHAP_SCORES.value] in Options.PERFORM_FLAGS.value: + if MachineLearningMetaKeys.SHAP_SCORES.value in meta_dict.keys(): + if meta_dict[MachineLearningMetaKeys.SHAP_SCORES.value] in Options.PERFORM_FLAGS.value: errors.append( check_int( - name=MetaKeys.SHAP_PRESENT.value, - value=meta_dict[MetaKeys.SHAP_PRESENT.value], + name=MachineLearningMetaKeys.SHAP_PRESENT.value, + value=meta_dict[MachineLearningMetaKeys.SHAP_PRESENT.value], raise_error=False, )[1] ) errors.append( check_int( - name=MetaKeys.SHAP_ABSENT.value, - value=meta_dict[MetaKeys.SHAP_ABSENT.value], + name=MachineLearningMetaKeys.SHAP_ABSENT.value, + value=meta_dict[MachineLearningMetaKeys.SHAP_ABSENT.value], raise_error=False, )[1] ) errors.append( check_float( - name=MetaKeys.TT_SIZE.value, - value=meta_dict[MetaKeys.TT_SIZE.value], + name=MachineLearningMetaKeys.TT_SIZE.value, + value=meta_dict[MachineLearningMetaKeys.TT_SIZE.value], raise_error=False, )[1] ) if ( - meta_dict[ConfigKey.UNDERSAMPLE_SETTING.value].lower() - == Methods.RANDOM_UNDERSAMPLE.value + meta_dict[MachineLearningMetaKeys.UNDERSAMPLE_SETTING.value].lower() + == Methods.RANDOM_UNDERSAMPLE.value ): errors.append( check_float( - name=ConfigKey.UNDERSAMPLE_RATIO.value, - value=meta_dict[ConfigKey.UNDERSAMPLE_RATIO.value], + name=MachineLearningMetaKeys.UNDERSAMPLE_RATIO.value, + value=meta_dict[MachineLearningMetaKeys.UNDERSAMPLE_RATIO.value], raise_error=False, )[1] ) try: present_len, absent_len = len( - data_df[data_df[meta_dict[MetaKeys.CLF_NAME.value]] == 1] - ), len(data_df[data_df[meta_dict[MetaKeys.CLF_NAME.value]] == 0]) + data_df[data_df[meta_dict[MachineLearningMetaKeys.CLASSIFIER.value]] == 1] + ), len(data_df[data_df[meta_dict[MachineLearningMetaKeys.CLASSIFIER.value]] == 0]) ratio_n = int( - present_len * meta_dict[ConfigKey.UNDERSAMPLE_RATIO.value] + present_len * meta_dict[MachineLearningMetaKeys.UNDERSAMPLE_RATIO.value] ) if absent_len < ratio_n: errors.append( - f"The under-sample ratio of {meta_dict[ConfigKey.UNDERSAMPLE_RATIO.value]} in \n classifier {meta_dict[MetaKeys.CLF_NAME.value]} demands {ratio_n} behavior-absent annotations." + f"The under-sample ratio of {meta_dict[MachineLearningMetaKeys.UNDERSAMPLE_RATIO.value]} in \n classifier {meta_dict[MachineLearningMetaKeys.CLASSIFIER.value]} demands {ratio_n} behavior-absent annotations." 
) except: pass if ( - meta_dict[ConfigKey.OVERSAMPLE_SETTING.value].lower() - == Methods.SMOTEENN.value.lower() + meta_dict[MachineLearningMetaKeys.OVERSAMPLE_SETTING.value].lower() + == Methods.SMOTEENN.value.lower() ) or ( - meta_dict[ConfigKey.OVERSAMPLE_SETTING.value].lower() - == Methods.SMOTE.value.lower() + meta_dict[MachineLearningMetaKeys.OVERSAMPLE_SETTING.value].lower() + == Methods.SMOTE.value.lower() ): errors.append( check_float( - name=ConfigKey.OVERSAMPLE_RATIO.value, - value=meta_dict[ConfigKey.OVERSAMPLE_RATIO.value], + name=MachineLearningMetaKeys.OVERSAMPLE_RATIO.value, + value=meta_dict[MachineLearningMetaKeys.OVERSAMPLE_RATIO.value], raise_error=False, )[1] ) errors.append( check_if_valid_input( - name=MetaKeys.META_FILE.value, - input=meta_dict[MetaKeys.META_FILE.value], + name=MachineLearningMetaKeys.RF_METADATA.value, + input=meta_dict[MachineLearningMetaKeys.RF_METADATA.value], options=Options.RUN_OPTIONS_FLAGS.value, raise_error=False, )[1] ) errors.append( check_if_valid_input( - MetaKeys.EX_DECISION_TREE.value, - input=meta_dict[MetaKeys.EX_DECISION_TREE.value], + MachineLearningMetaKeys.EX_DECISION_TREE.value, + input=meta_dict[MachineLearningMetaKeys.EX_DECISION_TREE.value], options=Options.RUN_OPTIONS_FLAGS.value, raise_error=False, )[1] ) errors.append( check_if_valid_input( - MetaKeys.CLF_REPORT.value, - input=meta_dict[MetaKeys.CLF_REPORT.value], + MachineLearningMetaKeys.CLF_REPORT.value, + input=meta_dict[MachineLearningMetaKeys.CLF_REPORT.value], options=Options.RUN_OPTIONS_FLAGS.value, raise_error=False, )[1] ) errors.append( check_if_valid_input( - MetaKeys.IMPORTANCE_LOG.value, - input=meta_dict[MetaKeys.IMPORTANCE_LOG.value], + MachineLearningMetaKeys.IMPORTANCE_LOG.value, + input=meta_dict[MachineLearningMetaKeys.IMPORTANCE_LOG.value], options=Options.RUN_OPTIONS_FLAGS.value, raise_error=False, )[1] ) errors.append( check_if_valid_input( - MetaKeys.IMPORTANCE_BAR_CHART.value, - input=meta_dict[MetaKeys.IMPORTANCE_BAR_CHART.value], + MachineLearningMetaKeys.IMPORTANCE_BAR_CHART.value, + input=meta_dict[MachineLearningMetaKeys.IMPORTANCE_BAR_CHART.value], options=Options.RUN_OPTIONS_FLAGS.value, raise_error=False, )[1] ) errors.append( check_if_valid_input( - MetaKeys.PERMUTATION_IMPORTANCE.value, - input=meta_dict[MetaKeys.PERMUTATION_IMPORTANCE.value], + MachineLearningMetaKeys.PERMUTATION_IMPORTANCE.value, + input=meta_dict[MachineLearningMetaKeys.PERMUTATION_IMPORTANCE.value], options=Options.RUN_OPTIONS_FLAGS.value, raise_error=False, )[1] ) errors.append( check_if_valid_input( - MetaKeys.LEARNING_CURVE.value, - input=meta_dict[MetaKeys.LEARNING_CURVE.value], + MachineLearningMetaKeys.LEARNING_CURVE.value, + input=meta_dict[MachineLearningMetaKeys.LEARNING_CURVE.value], options=Options.RUN_OPTIONS_FLAGS.value, raise_error=False, )[1] ) errors.append( check_if_valid_input( - MetaKeys.PRECISION_RECALL.value, - input=meta_dict[MetaKeys.PRECISION_RECALL.value], + MachineLearningMetaKeys.PRECISION_RECALL.value, + input=meta_dict[MachineLearningMetaKeys.PRECISION_RECALL.value], options=Options.RUN_OPTIONS_FLAGS.value, raise_error=False, )[1] ) - if MetaKeys.PARTIAL_DEPENDENCY.value in meta_dict.keys(): + if MachineLearningMetaKeys.PARTIAL_DEPENDENCY.value in meta_dict.keys(): errors.append( check_if_valid_input( - MetaKeys.PARTIAL_DEPENDENCY.value, - input=meta_dict[MetaKeys.PARTIAL_DEPENDENCY.value], + MachineLearningMetaKeys.PARTIAL_DEPENDENCY.value, + input=meta_dict[MachineLearningMetaKeys.PARTIAL_DEPENDENCY.value], 
options=Options.RUN_OPTIONS_FLAGS.value, raise_error=False, )[1] @@ -2429,29 +2428,29 @@ def check_validity_of_meta_files( raise_error=False, )[1] ) - if meta_dict[MetaKeys.RF_MAX_FEATURES.value] == Dtypes.NONE.value: - meta_dict[MetaKeys.RF_MAX_FEATURES.value] = None - if MetaKeys.TRAIN_TEST_SPLIT_TYPE.value not in meta_dict.keys(): + if meta_dict[MachineLearningMetaKeys.RF_MAX_FEATURES.value] == Dtypes.NONE.value: + meta_dict[MachineLearningMetaKeys.RF_MAX_FEATURES.value] = None + if MachineLearningMetaKeys.TRAIN_TEST_SPLIT_TYPE.value not in meta_dict.keys(): meta_dict[ - MetaKeys.TRAIN_TEST_SPLIT_TYPE.value + MachineLearningMetaKeys.TRAIN_TEST_SPLIT_TYPE.value ] = Methods.SPLIT_TYPE_FRAMES.value - if ConfigKey.CLASS_WEIGHTS.value in meta_dict.keys(): + if MachineLearningMetaKeys.CLASS_WEIGHTS.value in meta_dict.keys(): if ( - meta_dict[ConfigKey.CLASS_WEIGHTS.value] - not in Options.CLASS_WEIGHT_OPTIONS.value + meta_dict[MachineLearningMetaKeys.CLASS_WEIGHTS.value] + not in Options.CLASS_WEIGHT_OPTIONS.value ): - meta_dict[ConfigKey.CLASS_WEIGHTS.value] = None - if meta_dict[ConfigKey.CLASS_WEIGHTS.value] == "custom": - meta_dict[ConfigKey.CLASS_WEIGHTS.value] = ast.literal_eval( - meta_dict["class_custom_weights"] + meta_dict[MachineLearningMetaKeys.CLASS_WEIGHTS.value] = None + if meta_dict[MachineLearningMetaKeys.CLASS_WEIGHTS.value] == "custom": + meta_dict[MachineLearningMetaKeys.CLASS_WEIGHTS.value] = ast.literal_eval( + meta_dict[MachineLearningMetaKeys.CLASS_CUSTOM_WEIGHTS] ) - for k, v in meta_dict[ConfigKey.CLASS_WEIGHTS.value].items(): - meta_dict[ConfigKey.CLASS_WEIGHTS.value][k] = int(v) - if meta_dict[ConfigKey.CLASS_WEIGHTS.value] == Dtypes.NONE.value: - meta_dict[ConfigKey.CLASS_WEIGHTS.value] = None + for k, v in meta_dict[MachineLearningMetaKeys.CLASS_WEIGHTS.value].items(): + meta_dict[MachineLearningMetaKeys.CLASS_WEIGHTS.value][k] = int(v) + if meta_dict[MachineLearningMetaKeys.CLASS_WEIGHTS.value] == Dtypes.NONE.value: + meta_dict[MachineLearningMetaKeys.CLASS_WEIGHTS.value] = None else: - meta_dict[ConfigKey.CLASS_WEIGHTS.value] = None + meta_dict[MachineLearningMetaKeys.CLASS_WEIGHTS.value] = None if "classifier_map" in meta_dict.keys(): meta_dict["classifier_map"] = ast.literal_eval( @@ -2491,12 +2490,12 @@ def check_validity_of_meta_files( return meta_dicts def random_multiclass_frm_undersampler( - self, - data_df: pd.DataFrame, - target_field: str, - target_var: int, - sampling_ratio: Union[float, Dict[int, float]], - raise_error: bool = False, + self, + data_df: pd.DataFrame, + target_field: str, + target_var: int, + sampling_ratio: Union[float, Dict[int, float]], + raise_error: bool = False, ): """ Random multiclass undersampler. 
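The renamed keys keep the existing custom class-weight handling: when class_weights is set to "custom", the dictionary is parsed from the config string with ast.literal_eval and its values are cast to int before reaching scikit-learn. A small standalone sketch of that parsing step (the config snippet and section name are illustrative):

import ast
import configparser

config = configparser.ConfigParser()
config.read_string(
    "[create ensemble settings]\n"
    "class_weights = custom\n"
    "class_custom_weights = {0: 1, 1: 5}\n"
)

section = "create ensemble settings"
class_weights = config.get(section, "class_weights")
if class_weights == "custom":
    class_weights = ast.literal_eval(config.get(section, "class_custom_weights"))
    class_weights = {k: int(v) for k, v in class_weights.items()}
elif class_weights in ("None", ""):
    class_weights = None

print(class_weights)   # e.g. {0: 1, 1: 5}, handed to RandomForestClassifier(class_weight=...)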
@@ -2567,7 +2566,6 @@ def random_multiclass_frm_undersampler( return pd.concat(results_df_lst, axis=0) - # test = TrainModelMixin() # test.read_all_files_in_folder(file_paths=['/Users/simon/Desktop/envs/troubleshooting/jake/project_folder/csv/targets_inserted/22-437C_c3_2022-11-01_13-16-23_color.csv', '/Users/simon/Desktop/envs/troubleshooting/jake/project_folder/csv/targets_inserted/22-437D_c4_2022-11-01_13-16-39_color.csv'], # file_type='csv', classifier_names=['attack', 'non-agresive parallel swimming']) diff --git a/simba/model/train_rf.py b/simba/model/train_rf.py index 8a1b9b4a0..288ab0731 100644 --- a/simba/model/train_rf.py +++ b/simba/model/train_rf.py @@ -11,9 +11,10 @@ from simba.mixins.train_model_mixin import TrainModelMixin from simba.utils.checks import (check_float, check_if_filepath_list_is_empty, check_int) -from simba.utils.enums import ConfigKey, Dtypes, Methods, Options, TagNames -from simba.utils.printing import SimbaTimer, log_event, stdout_success +from simba.utils.enums import ConfigKey, Dtypes, Methods, Options, MachineLearningMetaKeys +from simba.utils.printing import SimbaTimer, stdout_success from simba.utils.read_write import read_config_entry +from imblearn.ensemble import BalancedRandomForestClassifier class TrainRandomForestClassifier(ConfigReader, TrainModelMixin): @@ -36,11 +37,6 @@ class TrainRandomForestClassifier(ConfigReader, TrainModelMixin): def __init__(self, config_path: Union[str, os.PathLike]): ConfigReader.__init__(self, config_path=config_path) TrainModelMixin.__init__(self) - log_event( - logger_name=str(self.__class__.__name__), - log_type=TagNames.CLASS_INIT.value, - msg=self.create_log_msg_from_init_args(locals=locals()), - ) self.model_dir_out = os.path.join( read_config_entry( self.config, @@ -58,25 +54,25 @@ def __init__(self, config_path: Union[str, os.PathLike]): self.clf_name = read_config_entry( self.config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.CLASSIFIER.value, + MachineLearningMetaKeys.CLASSIFIER.value, data_type=Dtypes.STR.value, ) self.tt_size = read_config_entry( self.config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.TT_SIZE.value, + MachineLearningMetaKeys.TT_SIZE.value, data_type=Dtypes.FLOAT.value, ) self.algo = read_config_entry( self.config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.MODEL_TO_RUN.value, + MachineLearningMetaKeys.MODEL_TO_RUN.value, data_type=Dtypes.STR.value, ) self.split_type = read_config_entry( self.config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.SPLIT_TYPE.value, + MachineLearningMetaKeys.TRAIN_TEST_SPLIT_TYPE.value, data_type=Dtypes.STR.value, options=Options.TRAIN_TEST_SPLIT.value, default_value=Methods.SPLIT_TYPE_FRAMES.value, @@ -85,7 +81,7 @@ def __init__(self, config_path: Union[str, os.PathLike]): read_config_entry( self.config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.UNDERSAMPLE_SETTING.value, + MachineLearningMetaKeys.UNDERSAMPLE_SETTING.value, data_type=Dtypes.STR.value, ) .lower() @@ -95,7 +91,7 @@ def __init__(self, config_path: Union[str, os.PathLike]): read_config_entry( self.config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.OVERSAMPLE_SETTING.value, + MachineLearningMetaKeys.OVERSAMPLE_SETTING.value, data_type=Dtypes.STR.value, ) .lower() @@ -105,27 +101,27 @@ def __init__(self, config_path: Union[str, os.PathLike]): self.under_sample_ratio = read_config_entry( self.config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.UNDERSAMPLE_RATIO.value, + MachineLearningMetaKeys.UNDERSAMPLE_RATIO.value, 
data_type=Dtypes.FLOAT.value, default_value=Dtypes.NAN.value, ) check_float( - name=ConfigKey.UNDERSAMPLE_RATIO.value, value=self.under_sample_ratio + name=MachineLearningMetaKeys.UNDERSAMPLE_RATIO.value, value=self.under_sample_ratio ) else: self.under_sample_ratio = Dtypes.NAN.value if (self.over_sample_setting == Methods.SMOTEENN.value.lower()) or ( - self.over_sample_setting == Methods.SMOTE.value.lower() + self.over_sample_setting == Methods.SMOTE.value.lower() ): self.over_sample_ratio = read_config_entry( self.config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.OVERSAMPLE_RATIO.value, + MachineLearningMetaKeys.OVERSAMPLE_RATIO.value, data_type=Dtypes.FLOAT.value, default_value=Dtypes.NAN.value, ) check_float( - name=ConfigKey.OVERSAMPLE_RATIO.value, value=self.over_sample_ratio + name=MachineLearningMetaKeys.OVERSAMPLE_RATIO.value, value=self.over_sample_ratio ) else: self.over_sample_ratio = Dtypes.NAN.value @@ -137,29 +133,32 @@ def __init__(self, config_path: Union[str, os.PathLike]): print( "Reading in {} annotated files...".format(str(len(self.target_file_paths))) ) - self.data_df = self.read_all_files_in_folder_mp_futures( - self.target_file_paths, self.file_type, [self.clf_name] - ) - self.data_df = self.check_raw_dataset_integrity( - df=self.data_df, logs_path=self.logs_path - ) - self.data_df_wo_cords = self.drop_bp_cords(df=self.data_df) - annotation_cols_to_remove = self.read_in_all_model_names_to_remove( + annotation_cols = self.read_in_all_model_names_to_remove( self.config, self.clf_cnt, self.clf_name ) - self.x_y_df = self.delete_other_annotation_columns( - self.data_df_wo_cords, list(annotation_cols_to_remove) + cls = [self.clf_name] + annotation_cols + self.data_df = self.read_and_concatenate_all_files_in_folder_mp_futures( + self.target_file_paths, self.features_dir, self.file_type, cls ) - self.class_names = ["Not_" + self.clf_name, self.clf_name] - self.x_df, self.y_df = self.split_df_to_x_y(self.x_y_df, self.clf_name) + # self.data_df = self.check_raw_dataset_integrity( + # df=self.data_df, logs_path=self.logs_path + # ) + self.data_df_wo_cords = self.drop_bp_cords(df=self.data_df) + if self.data_df_wo_cords is None: + self.data_df_wo_cords = self.data_df + + self.class_names = ["Not_" + self.clf_name] + cls + + self.x_df, self.y_df = self.split_df_to_x_y(self.data_df_wo_cords, cls) self.feature_names = self.x_df.columns self.check_sampled_dataset_integrity(x_df=self.x_df, y_df=self.y_df) print("Number of features in dataset: " + str(len(self.x_df.columns))) print( "Number of {} frames in dataset: {} ({}%)".format( self.clf_name, - str(self.y_df.sum()), - str(round(self.y_df.sum() / len(self.y_df), 4) * 100), + str(self.y_df[self.y_df == (cls.index(self.clf_name) + 1)].sum()), + str(round(self.y_df[self.y_df == (cls.index(self.clf_name) + 1)].sum() / len( + self.y_df[self.y_df == (cls.index(self.clf_name) + 1)]), 4) * 100), ) ) print("Training and evaluating model...") @@ -202,227 +201,225 @@ def train_model(self): """ Method for training single random forest model. 
""" + n_estimators = read_config_entry( + self.config, + ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, + MachineLearningMetaKeys.RF_ESTIMATORS.value, + data_type=Dtypes.INT.value, + ) + max_features = read_config_entry( + self.config, + ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, + MachineLearningMetaKeys.RF_MAX_FEATURES.value, + data_type=Dtypes.STR.value, + ) + max_depth = read_config_entry( + self.config, + ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, + MachineLearningMetaKeys.RF_MAX_DEPTH.value, + data_type=Dtypes.STR.value, + ) + if max_features == "None": + max_features = None + criterion = read_config_entry( + self.config, + ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, + MachineLearningMetaKeys.RF_CRITERION.value, + data_type=Dtypes.STR.value, + options=Options.CLF_CRITERION.value, + ) + min_sample_leaf = read_config_entry( + self.config, + ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, + MachineLearningMetaKeys.MIN_LEAF.value, + data_type=Dtypes.INT.value, + ) + compute_permutation_importance = read_config_entry( + self.config, + ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, + MachineLearningMetaKeys.PERMUTATION_IMPORTANCE.value, + data_type=Dtypes.STR.value, + default_value=False, + ) + generate_learning_curve = read_config_entry( + self.config, + ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, + MachineLearningMetaKeys.LEARNING_CURVE.value, + data_type=Dtypes.STR.value, + default_value=False, + ) + generate_precision_recall_curve = read_config_entry( + self.config, + ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, + MachineLearningMetaKeys.PRECISION_RECALL.value, + data_type=Dtypes.STR.value, + default_value=False, + ) + generate_example_decision_tree = read_config_entry( + self.config, + ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, + MachineLearningMetaKeys.EX_DECISION_TREE.value, + data_type=Dtypes.STR.value, + default_value=False, + ) + generate_classification_report = read_config_entry( + self.config, + ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, + MachineLearningMetaKeys.CLF_REPORT.value, + data_type=Dtypes.STR.value, + default_value=False, + ) + generate_features_importance_log = read_config_entry( + self.config, + ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, + MachineLearningMetaKeys.IMPORTANCE_LOG.value, + data_type=Dtypes.STR.value, + default_value=False, + ) + generate_features_importance_bar_graph = read_config_entry( + self.config, + ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, + MachineLearningMetaKeys.IMPORTANCE_LOG.value, + data_type=Dtypes.STR.value, + default_value=False, + ) + generate_example_decision_tree_fancy = read_config_entry( + self.config, + ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, + MachineLearningMetaKeys.EX_DECISION_TREE_FANCY.value, + data_type=Dtypes.STR.value, + default_value=False, + ) + generate_shap_scores = read_config_entry( + self.config, + ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, + MachineLearningMetaKeys.SHAP_SCORES.value, + data_type=Dtypes.STR.value, + default_value=False, + ) + save_meta_data = read_config_entry( + self.config, + ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, + MachineLearningMetaKeys.RF_METADATA.value, + data_type=Dtypes.STR.value, + default_value=False, + ) + compute_partial_dependency = read_config_entry( + self.config, + ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, + ConfigKey.PARTIAL_DEPENDENCY.value, + data_type=Dtypes.STR.value, + default_value=False, + ) - if self.algo == "RF": - n_estimators = read_config_entry( - self.config, - ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.RF_ESTIMATORS.value, - data_type=Dtypes.INT.value, - ) - max_features = 
read_config_entry( - self.config, - ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.RF_MAX_FEATURES.value, - data_type=Dtypes.STR.value, - ) - if max_features == "None": - max_features = None - criterion = read_config_entry( + if self.config.has_option( + ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, MachineLearningMetaKeys.CLASS_WEIGHTS.value + ): + class_weights = read_config_entry( self.config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.RF_CRITERION.value, + MachineLearningMetaKeys.CLASS_WEIGHTS.value, data_type=Dtypes.STR.value, - options=Options.CLF_CRITERION.value, + default_value=Dtypes.NONE.value, ) - min_sample_leaf = read_config_entry( + if class_weights == "custom": + class_weights = ast.literal_eval( + read_config_entry( + self.config, + ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, + MachineLearningMetaKeys.CLASS_CUSTOM_WEIGHTS.value, + data_type=Dtypes.STR.value, + ) + ) + for k, v in class_weights.items(): + class_weights[k] = int(v) + if class_weights == Dtypes.NONE.value: + class_weights = None + else: + class_weights = None + + if generate_learning_curve in Options.PERFORM_FLAGS.value: + shuffle_splits = read_config_entry( self.config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.MIN_LEAF.value, + MachineLearningMetaKeys.LEARNING_CURVE_K_SPLITS.value, data_type=Dtypes.INT.value, + default_value=Dtypes.NAN.value, ) - compute_permutation_importance = read_config_entry( + dataset_splits = read_config_entry( self.config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.PERMUTATION_IMPORTANCE.value, - data_type=Dtypes.STR.value, - default_value=False, + MachineLearningMetaKeys.LEARNING_DATA_SPLITS.value, + data_type=Dtypes.INT.value, + default_value=Dtypes.NAN.value, ) - generate_learning_curve = read_config_entry( - self.config, - ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.LEARNING_CURVE.value, - data_type=Dtypes.STR.value, - default_value=False, + check_int( + name=MachineLearningMetaKeys.LEARNING_CURVE_K_SPLITS.value, value=shuffle_splits ) - generate_precision_recall_curve = read_config_entry( - self.config, - ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.PRECISION_RECALL.value, - data_type=Dtypes.STR.value, - default_value=False, + check_int( + name=MachineLearningMetaKeys.LEARNING_DATA_SPLITS.value, value=dataset_splits ) - generate_example_decision_tree = read_config_entry( - self.config, - ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.EX_DECISION_TREE.value, - data_type=Dtypes.STR.value, - default_value=False, - ) - generate_classification_report = read_config_entry( + else: + shuffle_splits, dataset_splits = Dtypes.NAN.value, Dtypes.NAN.value + if generate_features_importance_bar_graph in Options.PERFORM_FLAGS.value: + feature_importance_bars = read_config_entry( self.config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.CLF_REPORT.value, - data_type=Dtypes.STR.value, - default_value=False, + MachineLearningMetaKeys.IMPORTANCE_BARS_N.value, + Dtypes.INT.value, + Dtypes.NAN.value, ) - generate_features_importance_log = read_config_entry( - self.config, - ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.IMPORTANCE_LOG.value, - data_type=Dtypes.STR.value, - default_value=False, + check_int( + name=MachineLearningMetaKeys.IMPORTANCE_BARS_N.value, + value=feature_importance_bars, + min_value=1, ) - generate_features_importance_bar_graph = read_config_entry( + else: + feature_importance_bars = Dtypes.NAN.value + shap_target_present_cnt, shap_target_absent_cnt, shap_save_n = ( + None, + None, + None, + ) + if 
generate_shap_scores in Options.PERFORM_FLAGS.value: + shap_target_present_cnt = read_config_entry( self.config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.IMPORTANCE_LOG.value, - data_type=Dtypes.STR.value, - default_value=False, + MachineLearningMetaKeys.SHAP_PRESENT.value, + data_type=Dtypes.INT.value, + default_value=0, ) - generate_example_decision_tree_fancy = read_config_entry( + shap_target_absent_cnt = read_config_entry( self.config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.EX_DECISION_TREE_FANCY.value, - data_type=Dtypes.STR.value, - default_value=False, + MachineLearningMetaKeys.SHAP_ABSENT.value, + data_type=Dtypes.INT.value, + default_value=0, ) - generate_shap_scores = read_config_entry( + shap_save_n = read_config_entry( self.config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.SHAP_SCORES.value, + MachineLearningMetaKeys.SHAP_SAVE_ITERATION.value, data_type=Dtypes.STR.value, - default_value=False, + default_value=Dtypes.NONE.value, ) - save_meta_data = read_config_entry( - self.config, - ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.RF_METADATA.value, - data_type=Dtypes.STR.value, - default_value=False, + try: + shap_save_n = int(shap_save_n) + except ValueError: + shap_save_n = shap_target_present_cnt + shap_target_absent_cnt + check_int( + name=MachineLearningMetaKeys.SHAP_PRESENT.value, value=shap_target_present_cnt ) - compute_partial_dependency = read_config_entry( - self.config, - ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.PARTIAL_DEPENDENCY.value, - data_type=Dtypes.STR.value, - default_value=False, + check_int( + name=MachineLearningMetaKeys.SHAP_ABSENT.value, value=shap_target_absent_cnt ) - - if self.config.has_option( - ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, ConfigKey.CLASS_WEIGHTS.value - ): - class_weights = read_config_entry( - self.config, - ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.CLASS_WEIGHTS.value, - data_type=Dtypes.STR.value, - default_value=Dtypes.NONE.value, - ) - if class_weights == "custom": - class_weights = ast.literal_eval( - read_config_entry( - self.config, - ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.CUSTOM_WEIGHTS.value, - data_type=Dtypes.STR.value, - ) - ) - for k, v in class_weights.items(): - class_weights[k] = int(v) - if class_weights == Dtypes.NONE.value: - class_weights = None - else: - class_weights = None - - if generate_learning_curve in Options.PERFORM_FLAGS.value: - shuffle_splits = read_config_entry( - self.config, - ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.LEARNING_CURVE_K_SPLITS.value, - data_type=Dtypes.INT.value, - default_value=Dtypes.NAN.value, - ) - dataset_splits = read_config_entry( - self.config, - ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.LEARNING_DATA_SPLITS.value, - data_type=Dtypes.INT.value, - default_value=Dtypes.NAN.value, - ) - check_int( - name=ConfigKey.LEARNING_CURVE_K_SPLITS.value, value=shuffle_splits - ) - check_int( - name=ConfigKey.LEARNING_DATA_SPLITS.value, value=dataset_splits - ) - else: - shuffle_splits, dataset_splits = Dtypes.NAN.value, Dtypes.NAN.value - if generate_features_importance_bar_graph in Options.PERFORM_FLAGS.value: - feature_importance_bars = read_config_entry( - self.config, - ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, - ConfigKey.IMPORTANCE_BARS_N.value, - Dtypes.INT.value, - Dtypes.NAN.value, - ) - check_int( - name=ConfigKey.IMPORTANCE_BARS_N.value, - value=feature_importance_bars, - min_value=1, - ) - else: - feature_importance_bars = Dtypes.NAN.value - ( - 
shap_target_present_cnt,
-            shap_target_absent_cnt,
-            shap_save_n,
-            shap_multiprocess,
-        ) = (None, None, None, None)
-        if generate_shap_scores in Options.PERFORM_FLAGS.value:
-            shap_target_present_cnt = read_config_entry(
-                self.config,
-                ConfigKey.CREATE_ENSEMBLE_SETTINGS.value,
-                ConfigKey.SHAP_PRESENT.value,
-                data_type=Dtypes.INT.value,
-                default_value=0,
-            )
-            shap_target_absent_cnt = read_config_entry(
-                self.config,
-                ConfigKey.CREATE_ENSEMBLE_SETTINGS.value,
-                ConfigKey.SHAP_ABSENT.value,
-                data_type=Dtypes.INT.value,
-                default_value=0,
-            )
-            shap_save_n = read_config_entry(
-                self.config,
-                ConfigKey.CREATE_ENSEMBLE_SETTINGS.value,
-                ConfigKey.SHAP_SAVE_ITERATION.value,
-                data_type=Dtypes.STR.value,
-                default_value=Dtypes.NONE.value,
-            )
-            shap_multiprocess = read_config_entry(
-                self.config,
-                ConfigKey.CREATE_ENSEMBLE_SETTINGS.value,
-                ConfigKey.SHAP_MULTIPROCESS.value,
-                data_type=Dtypes.STR.value,
-                default_value="False",
-            )
-            try:
-                shap_save_n = int(shap_save_n)
-            except ValueError:
-                shap_save_n = shap_target_present_cnt + shap_target_absent_cnt
-            check_int(
-                name=ConfigKey.SHAP_PRESENT.value, value=shap_target_present_cnt
-            )
-            check_int(
-                name=ConfigKey.SHAP_ABSENT.value, value=shap_target_absent_cnt
-            )
-
+        print(f"Fitting {self.clf_name} model...")
+        if self.algo == "RF":
             self.rf_clf = RandomForestClassifier(
                 n_estimators=n_estimators,
                 max_features=max_features,
                 n_jobs=-1,
+                max_depth=None if max_depth == "None" else int(max_depth),
                 criterion=criterion,
                 min_samples_leaf=min_sample_leaf,
                 bootstrap=True,
@@ -430,140 +427,140 @@ def train_model(self):
                 class_weight=class_weights,
             )
 
-        print(f"Fitting {self.clf_name} model...")
-        self.rf_clf = self.clf_fit(
-            clf=self.rf_clf, x_df=self.x_train, y_df=self.y_train
-        )
+            self.rf_clf = self.clf_fit(
+                clf=self.rf_clf, x_df=self.x_train, y_df=self.y_train
+            )
+        elif self.algo == "imbalanced_rf":
+            self.rf_clf = BalancedRandomForestClassifier(
+                n_estimators=n_estimators,
+                max_features=max_features,
+                max_depth=None if max_depth == "None" else int(max_depth),
+                n_jobs=-1,
+                criterion=criterion,
+                min_samples_leaf=min_sample_leaf,
+                bootstrap=True,
+                verbose=1,
+                class_weight=class_weights,
+            )
+            self.rf_clf = self.clf_fit(
+                clf=self.rf_clf, x_df=self.x_train, y_df=self.y_train
+            )
+        if compute_permutation_importance in Options.PERFORM_FLAGS.value:
+            self.calc_permutation_importance(
+                self.x_test,
+                self.y_test,
+                self.rf_clf,
+                self.feature_names,
+                self.clf_name,
+                self.eval_out_path,
+            )
+        if generate_learning_curve in Options.PERFORM_FLAGS.value:
+            self.calc_learning_curve(
+                x_y_df=self.x_y_df,
+                clf_name=self.clf_name,
+                shuffle_splits=shuffle_splits,
+                dataset_splits=dataset_splits,
+                tt_size=self.tt_size,
+                rf_clf=self.rf_clf,
+                save_dir=self.eval_out_path,
+            )
-        if compute_permutation_importance in Options.PERFORM_FLAGS.value:
-            self.calc_permutation_importance(
-                self.x_test,
-                self.y_test,
-                self.rf_clf,
-                self.feature_names,
-                self.clf_name,
-                self.eval_out_path,
-            )
-        if generate_learning_curve in Options.PERFORM_FLAGS.value:
-            self.calc_learning_curve(
-                x_y_df=self.x_y_df,
-                clf_name=self.clf_name,
-                shuffle_splits=shuffle_splits,
-                dataset_splits=dataset_splits,
-                tt_size=self.tt_size,
-                rf_clf=self.rf_clf,
-                save_dir=self.eval_out_path,
-            )
-
-        if generate_precision_recall_curve in Options.PERFORM_FLAGS.value:
-            self.calc_pr_curve(
-                self.rf_clf,
-                self.x_test,
-                self.y_test,
-                self.clf_name,
-                self.eval_out_path,
-            )
-        if generate_example_decision_tree in Options.PERFORM_FLAGS.value:
-            self.create_example_dt(
-                self.rf_clf,
-                self.clf_name,
-                self.feature_names,
-                self.class_names,
-                self.eval_out_path,
-            )
-        if generate_classification_report in Options.PERFORM_FLAGS.value:
-            
self.create_clf_report( - self.rf_clf, - self.x_test, - self.y_test, - self.class_names, - self.eval_out_path, - ) - if generate_features_importance_log in Options.PERFORM_FLAGS.value: - self.create_x_importance_log( - self.rf_clf, self.feature_names, self.clf_name, self.eval_out_path - ) - if generate_features_importance_bar_graph in Options.PERFORM_FLAGS.value: - self.create_x_importance_bar_chart( - self.rf_clf, - self.feature_names, - self.clf_name, - self.eval_out_path, - feature_importance_bars, - ) - if generate_example_decision_tree_fancy in Options.PERFORM_FLAGS.value: - self.dviz_classification_visualization( - self.x_train, - self.y_train, - self.clf_name, - self.class_names, - self.eval_out_path, - ) - if generate_shap_scores in Options.PERFORM_FLAGS.value: - if not shap_multiprocess in Options.PERFORM_FLAGS.value: - self.create_shap_log( - ini_file_path=self.config_path, - rf_clf=self.rf_clf, - x_df=self.x_train, - y_df=self.y_train, - x_names=self.feature_names, - clf_name=self.clf_name, - cnt_present=shap_target_present_cnt, - cnt_absent=shap_target_absent_cnt, - save_it=shap_save_n, - save_path=self.eval_out_path, - ) - else: - self.create_shap_log_mp( - ini_file_path=self.config_path, - rf_clf=self.rf_clf, - x_df=self.x_train, - y_df=self.y_train, - x_names=self.feature_names, - clf_name=self.clf_name, - cnt_present=shap_target_present_cnt, - cnt_absent=shap_target_absent_cnt, - save_path=self.eval_out_path, - ) - - if compute_partial_dependency in Options.PERFORM_FLAGS.value: - self.partial_dependence_calculator( - clf=self.rf_clf, - x_df=self.x_train, - clf_name=self.clf_name, - save_dir=self.eval_out_path, - ) + if generate_precision_recall_curve in Options.PERFORM_FLAGS.value: + self.calc_pr_curve( + self.rf_clf, + self.x_test, + self.y_test, + self.clf_name, + self.eval_out_path, + ) + if generate_example_decision_tree in Options.PERFORM_FLAGS.value: + self.create_example_dt( + self.rf_clf, + self.clf_name, + self.feature_names, + self.class_names, + self.eval_out_path, + ) + if generate_classification_report in Options.PERFORM_FLAGS.value: + self.create_clf_report( + self.rf_clf, + self.x_test, + self.y_test, + self.class_names, + self.eval_out_path, + ) + if generate_features_importance_log in Options.PERFORM_FLAGS.value: + self.create_x_importance_log( + self.rf_clf, self.feature_names, self.clf_name, self.eval_out_path + ) + if generate_features_importance_bar_graph in Options.PERFORM_FLAGS.value: + self.create_x_importance_bar_chart( + self.rf_clf, + self.feature_names, + self.clf_name, + self.eval_out_path, + feature_importance_bars, + ) + if generate_example_decision_tree_fancy in Options.PERFORM_FLAGS.value: + self.dviz_classification_visualization( + self.x_train, + self.y_train, + self.clf_name, + self.class_names, + self.eval_out_path, + ) + if generate_shap_scores in Options.PERFORM_FLAGS.value: + self.create_shap_log_mp( + ini_file_path=self.config_path, + rf_clf=self.rf_clf, + x_df=self.x_train, + y_df=self.y_train, + x_names=self.feature_names, + clf_name=self.clf_name, + cnt_present=shap_target_present_cnt, + cnt_absent=shap_target_absent_cnt, + save_it=shap_save_n, + save_path=self.eval_out_path, + ) - if save_meta_data in Options.PERFORM_FLAGS.value: - meta_data_lst = [ - self.clf_name, - criterion, - max_features, - min_sample_leaf, - n_estimators, - compute_permutation_importance, - generate_classification_report, - generate_example_decision_tree, - generate_features_importance_bar_graph, - generate_features_importance_log, - 
generate_precision_recall_curve, - save_meta_data, - generate_learning_curve, - dataset_splits, - shuffle_splits, - feature_importance_bars, - self.over_sample_ratio, - self.over_sample_setting, - self.tt_size, - self.split_type, - self.under_sample_ratio, - self.under_sample_setting, - str(class_weights), - ] + if compute_partial_dependency in Options.PERFORM_FLAGS.value: + self.partial_dependence_calculator( + clf=self.rf_clf, + x_df=self.x_train, + clf_name=self.clf_name, + save_dir=self.eval_out_path, + ) - self.create_meta_data_csv_training_one_model( - meta_data_lst, self.clf_name, self.eval_out_path - ) + if save_meta_data in Options.PERFORM_FLAGS.value: + meta_data_lst = [ + self.clf_name, + criterion, + max_features, + min_sample_leaf, + n_estimators, + compute_permutation_importance, + generate_classification_report, + generate_example_decision_tree, + generate_features_importance_bar_graph, + generate_features_importance_log, + generate_precision_recall_curve, + save_meta_data, + generate_learning_curve, + dataset_splits, + shuffle_splits, + feature_importance_bars, + self.over_sample_ratio, + self.over_sample_setting, + self.tt_size, + self.split_type, + self.under_sample_ratio, + self.under_sample_setting, + str(class_weights), + ] + + self.create_meta_data_csv_training_one_model( + meta_data_lst, self.clf_name, self.eval_out_path + ) def save_model(self) -> None: """ @@ -576,14 +573,11 @@ def save_model(self) -> None: stdout_success( msg=f"Classifier {self.clf_name} saved in models/generated_models directory", elapsed_time=self.timer.elapsed_time_str, - source=self.__class__.__name__, ) stdout_success( - msg=f"Evaluation files are in models/generated_models/model_evaluations folders", - source=self.__class__.__name__, + msg=f"Evaluation files are in models/generated_models/model_evaluations folders" ) - # test = TrainRandomForestClassifier(config_path='/Users/simon/Desktop/envs/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini') # test.perform_sampling() # test.train_model() diff --git a/simba/utils/enums.py b/simba/utils/enums.py index 3cff87f72..5fbebb64f 100644 --- a/simba/utils/enums.py +++ b/simba/utils/enums.py @@ -41,37 +41,6 @@ class ConfigKey(Enum): MULTI_ANIMAL_ID_SETTING = "Multi animal IDs" MULTI_ANIMAL_IDS = "ID_list" OUTLIER_SETTINGS = "Outlier settings" - CLASS_WEIGHTS = "class_weights" - CUSTOM_WEIGHTS = "custom_weights" - CLASSIFIER = "classifier" - TT_SIZE = "train_test_size" - MODEL_TO_RUN = "model_to_run" - UNDERSAMPLE_SETTING = "under_sample_setting" - OVERSAMPLE_SETTING = "over_sample_setting" - UNDERSAMPLE_RATIO = "under_sample_ratio" - OVERSAMPLE_RATIO = "over_sample_ratio" - RF_ESTIMATORS = "RF_n_estimators" - RF_MAX_FEATURES = "RF_max_features" - RF_CRITERION = "RF_criterion" - MIN_LEAF = "RF_min_sample_leaf" - PERMUTATION_IMPORTANCE = "compute_permutation_importance" - LEARNING_CURVE = "generate_learning_curve" - PRECISION_RECALL = "generate_precision_recall_curve" - EX_DECISION_TREE = "generate_example_decision_tree" - EX_DECISION_TREE_FANCY = "generate_example_decision_tree_fancy" - CLF_REPORT = "generate_classification_report" - IMPORTANCE_LOG = "generate_features_importance_log" - PARTIAL_DEPENDENCY = "partial_dependency" - IMPORTANCE_BAR_CHART = "generate_features_importance_bar_graph" - SHAP_SCORES = "generate_shap_scores" - RF_METADATA = "RF_meta_data" - LEARNING_CURVE_K_SPLITS = "LearningCurve_shuffle_k_splits" - LEARNING_DATA_SPLITS = "LearningCurve_shuffle_data_splits" - IMPORTANCE_BARS_N = 
"N_feature_importance_bars" - SHAP_PRESENT = "shap_target_present_no" - SHAP_ABSENT = "shap_target_absent_no" - SHAP_SAVE_ITERATION = "shap_save_iteration" - SHAP_MULTIPROCESS = "shap_multiprocess" POSE_SETTING = "pose_estimation_body_parts" RF_JOBS = "RF_n_jobs" VALIDATION_VIDEO = "generate_validation_video" @@ -80,7 +49,7 @@ class ConfigKey(Enum): ROI_ANIMAL_CNT = "no_of_animals" DISTANCE_MM = "distance_mm" SKLEARN_BP_PROB_THRESH = "bp_threshold_sklearn" - SPLIT_TYPE = "train_test_split_type" + SHAP_MULTIPROCESS = "shap_multiprocess" class Paths(Enum): @@ -177,7 +146,7 @@ class Formats(Enum): class Options(Enum): ROLLING_WINDOW_DIVISORS = [2, 5, 6, 7.5, 15] - CLF_MODELS = ["RF", "GBC", "XGBoost"] + CLF_MODELS = ["RF", "imbalanced_rf", "GBC", "XGBoost"] CLF_MAX_FEATURES = ["sqrt", "log", "None"] CLF_CRITERION = ["gini", "entropy"] UNDERSAMPLE_OPTIONS = [ @@ -287,8 +256,6 @@ class Options(Enum): SMOOTHING_OPTIONS_W_NONE = ["None", "Gaussian", "Savitzky Golay"] VIDEO_FORMAT_OPTIONS = ["mp4", "avi"] ALL_VIDEO_FORMAT_OPTIONS = (".avi", ".mp4", ".mov", ".flv", ".m4v") - ALL_IMAGE_FORMAT_OPTIONS = (".bmp", ".png", ".jpeg", ".jpg") - ALL_VIDEO_FORMAT_STR_OPTIONS = ".avi .mp4 .mov .flv .m4v" WORKFLOW_FILE_TYPE_OPTIONS = ["csv", "parquet"] TRACKING_TYPE_OPTIONS = ["Classic tracking", "Multi tracking", "3D tracking"] UNSUPERVISED_FEATURE_OPTIONS = [ @@ -381,7 +348,7 @@ class Defaults(Enum): LARGE_MAX_TASK_PER_CHILD = 1000 CHUNK_SIZE = 1 SPLASH_TIME = 2500 - WELCOME_MSG = f'Welcome fellow scientists! \n SimBA v.{pkg_resources.get_distribution("simba-uw-tf-dev").version} \n ' + WELCOME_MSG = f"Welcome fellow scientists! \n SimBA v.local-dev \n " BROWSE_FOLDER_BTN_TEXT = "Browse Folder" BROWSE_FILE_BTN_TEXT = "Browse File" NO_FILE_SELECTED_TEXT = "No file selected" @@ -468,13 +435,13 @@ class Methods(Enum): THIRD_PARTY_ANNOTATION_FILE_NOT_FOUND = "Annotations data file NOT FOUND" -class MetaKeys(Enum): - CLF_NAME = "classifier_name" +class MachineLearningMetaKeys(Enum): + CLASSIFIER = "classifier" RF_ESTIMATORS = "rf_n_estimators" - CRITERION = "rf_criterion" + RF_CRITERION = "rf_criterion" TT_SIZE = "train_test_size" MIN_LEAF = "rf_min_sample_leaf" - META_FILE = "generate_rf_model_meta_data_file" + RF_METADATA = "generate_rf_model_meta_data_file" EX_DECISION_TREE = "generate_example_decision_tree" CLF_REPORT = "generate_classification_report" IMPORTANCE_LOG = "generate_features_importance_log" @@ -483,6 +450,7 @@ class MetaKeys(Enum): LEARNING_CURVE = "generate_sklearn_learning_curves" PRECISION_RECALL = "generate_precision_recall_curves" RF_MAX_FEATURES = "rf_max_features" + RF_MAX_DEPTH = "rf_max_depth" LEARNING_CURVE_K_SPLITS = "learning_curve_k_splits" LEARNING_CURVE_DATA_SPLITS = "learning_curve_data_splits" N_FEATURE_IMPORTANCE_BARS = "n_feature_importance_bars" @@ -492,7 +460,17 @@ class MetaKeys(Enum): SHAP_SAVE_ITERATION = "shap_save_iteration" PARTIAL_DEPENDENCY = "partial_dependency" TRAIN_TEST_SPLIT_TYPE = "train_test_split_type" - SAVE_TRAIN_TEST_FRM_IDX = "save_train_test_frm_idx" + UNDERSAMPLE_SETTING = "under_sample_setting" + UNDERSAMPLE_RATIO = "under_sample_ratio" + OVERSAMPLE_SETTING = "over_sample_setting" + OVERSAMPLE_RATIO = "over_sample_ratio" + CLASS_WEIGHTS = "class_weights" + CLASS_CUSTOM_WEIGHTS = "class_custom_weights" + EX_DECISION_TREE_FANCY = "generate_example_decision_tree_fancy" + IMPORTANCE_BARS_N = "N_feature_importance_bars" + LEARNING_DATA_SPLITS = "LearningCurve_shuffle_data_splits" + MODEL_TO_RUN = "model_to_run" + class OS(Enum): WINDOWS = "Windows" From 
3899965c97506e4fa0d6afc933c26121d055aa48 Mon Sep 17 00:00:00 2001
From: tzuk polinsky
Date: Mon, 1 Jan 2024 14:49:01 +0200
Subject: [PATCH 13/13] fixing missing files paths for app visibility

---
 simba/SimBA.py       | 8 +++++++-
 simba/utils/enums.py | 5 ++++-
 2 files changed, 11 insertions(+), 2 deletions(-)

diff --git a/simba/SimBA.py b/simba/SimBA.py
index bb5cad577..f009f8de7 100644
--- a/simba/SimBA.py
+++ b/simba/SimBA.py
@@ -150,7 +150,7 @@
 # from simba.unsupervised.ui import UnsupervisedGUI
 
 
-sys.setrecursionlimit(10**6)
+sys.setrecursionlimit(10 ** 6)
 currentPlatform = platform.system()
 
 
@@ -1505,6 +1505,8 @@ def callback(self, url):
 class App(object):
     def __init__(self):
         bg_path = os.path.join(os.path.dirname(__file__), Paths.BG_IMG_PATH.value)
+        if not os.path.exists(bg_path):
+            bg_path = os.path.join(os.path.dirname(__file__), Paths.BG_IMG_PATH_DEFAULT.value)
         emojis = get_emojis()
         icon_path_windows = os.path.join(
             os.path.dirname(__file__), Paths.LOGO_ICON_WINDOWS_PATH.value
@@ -2025,6 +2027,10 @@ def __init__(self):
         splash_path = os.path.join(
             os.path.dirname(__file__), Paths.SPLASH_PATH_MOVIE.value
         )
+        if not os.path.exists(splash_path):
+            splash_path = os.path.join(
+                os.path.dirname(__file__), Paths.SPLASH_PATH_MOVIE_DEFAULT.value
+            )
         self.meta_ = get_video_meta_data(splash_path)
         self.cap = cv2.VideoCapture(splash_path)
         width, height = self.meta_["width"], self.meta_["height"]
diff --git a/simba/utils/enums.py b/simba/utils/enums.py
index bc2170094..d091b873d 100644
--- a/simba/utils/enums.py
+++ b/simba/utils/enums.py
@@ -1,5 +1,6 @@
 __author__ = "Simon Nilsson"
 
+import importlib.util
 import os
 import sys
 from enum import Enum
@@ -108,7 +109,9 @@ class Paths(Enum):
     SPLASH_PATH_WINDOWS = Path("assets/img/splash.png")
     SPLASH_PATH_LINUX = Path("assets/img/splash.PNG")
     SPLASH_PATH_MOVIE = Path("assets/img/splash_2024.mp4")
+    SPLASH_PATH_MOVIE_DEFAULT = Path("assets/img/splash.mp4")
     BG_IMG_PATH = Path("assets/img/bg_2024.png")
+    BG_IMG_PATH_DEFAULT = Path("assets/img/bg.png")
     LOGO_ICON_WINDOWS_PATH = Path("assets/icons/SimBA_logo.ico")
     LOGO_ICON_DARWIN_PATH = Path("assets/icons/SimBA_logo.png")
     UNSUPERVISED_MODEL_NAMES = Path("assets/lookups/model_names.parquet")
@@ -351,7 +354,7 @@ class Defaults(Enum):
     LARGE_MAX_TASK_PER_CHILD = 1000
     CHUNK_SIZE = 1
     SPLASH_TIME = 2500
-    WELCOME_MSG = f'Welcome fellow scientists! \n SimBA v.{pkg_resources.get_distribution("simba-uw-tf-dev").version} \n '
+    WELCOME_MSG = f'Welcome fellow scientists! \n SimBA v.{pkg_resources.get_distribution("simba-uw-tf-dev").version if importlib.util.find_spec("simba-uw-tf-dev") is not None else "dev"} \n '
     BROWSE_FOLDER_BTN_TEXT = "Browse Folder"
     BROWSE_FILE_BTN_TEXT = "Browse File"
     NO_FILE_SELECTED_TEXT = "No file selected"
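
The final commit above applies the same idea in each place it touches: prefer the new 2024-branded asset or the installed package metadata when it is available, and fall back to a bundled default otherwise. Below is a minimal, illustrative sketch of that pattern; the helper names (resolve_asset_path, get_simba_version) are assumptions made for this example and are not part of the SimBA code base.

    import os
    import pkg_resources


    def resolve_asset_path(root_dir: str, preferred: str, default: str) -> str:
        """Return the preferred asset if it exists on disk, otherwise the bundled default."""
        candidate = os.path.join(root_dir, preferred)
        if os.path.exists(candidate):
            return candidate
        return os.path.join(root_dir, default)


    def get_simba_version(dist_name: str = "simba-uw-tf-dev", fallback: str = "dev") -> str:
        """Return the installed distribution version, or a fallback string when the
        package metadata is missing (e.g. SimBA run from a source checkout)."""
        try:
            return pkg_resources.get_distribution(dist_name).version
        except pkg_resources.DistributionNotFound:
            return fallback


    # Example usage, mirroring the Paths values added in the patch:
    # splash = resolve_asset_path(os.path.dirname(__file__),
    #                             "assets/img/splash_2024.mp4", "assets/img/splash.mp4")
    # print(f"Welcome fellow scientists! SimBA v.{get_simba_version()}")

Catching pkg_resources.DistributionNotFound targets the actual failure mode (absent distribution metadata) more directly than probing importlib.util.find_spec, since find_spec reports whether a module name is importable rather than whether a distribution is installed.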