From d78e83c6975d2d5aca136099c3fdd91b37e4d28a Mon Sep 17 00:00:00 2001 From: nbbb24 Date: Thu, 17 Jul 2025 16:33:23 +0000 Subject: [PATCH 01/23] update readme --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index c419ba0..165490e 100644 --- a/README.md +++ b/README.md @@ -811,6 +811,8 @@ This is a list of TODOs for the repository. If you are interested in contributin - [ ] Curate higher quality instruction tuning and reasoning datasets for ELMs. - [ ] Expand upon current naive distributed training setting to include more efficient and explicit distributed training strategies (i.e., data, tensor, context, pipeline, and expert parallelism as noted in [here](https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=5d_parallelism_in_a_nutshell)). - [x] Add option for data mixing. +- [ ] Adjust feature selection for RAG. +- [ ] Apply normalization to RAG database. ## Acknowledgements This work is done in collaboration with the Mario Lemieux Center for Heart Rhythm Care at Allegheny General Hospital. From 868895eea1bd1989795758d91e0340f4a01a7569 Mon Sep 17 00:00:00 2001 From: nbbb24 Date: Sat, 19 Jul 2025 19:08:45 +0000 Subject: [PATCH 02/23] 1.reduce the #selected features. 2.better feature format. 3.add normalize_rag_feature arg when creating/loading rag index db --- ecg_bench/config.py | 1 + ecg_bench/utils/data_loader_utils.py | 14 +- ecg_bench/utils/preprocess_utils.py | 109 ++++++++++++++ ecg_bench/utils/rag_utils.py | 205 ++++++++++++++++++++------- 4 files changed, 273 insertions(+), 56 deletions(-) diff --git a/ecg_bench/config.py b/ecg_bench/config.py index edb7295..ae11f9a 100644 --- a/ecg_bench/config.py +++ b/ecg_bench/config.py @@ -60,6 +60,7 @@ def get_args(): mode_group.add_argument('--retrieved_information', type=str, default='combined', choices=['feature', 'report', 'combined'], help='Type of information to retrieve in output') mode_group.add_argument('--load_rag_db', type = str, default = None, help = 'Load a RAG database') mode_group.add_argument('--load_rag_db_idx', type = str, default = None, help = 'Load a RAG database index') + mode_group.add_argument('--normalized_rag_feature', action='store_true', default=True, help='Enable normalization for RAG features and signals') mode_group.add_argument('--dev', action='store_true', default=None, help='Development mode') mode_group.add_argument('--log', action='store_true', default=None, help='Enable logging') diff --git a/ecg_bench/utils/data_loader_utils.py b/ecg_bench/utils/data_loader_utils.py index 9ae9990..1196d35 100644 --- a/ecg_bench/utils/data_loader_utils.py +++ b/ecg_bench/utils/data_loader_utils.py @@ -110,7 +110,11 @@ def setup_conversation_template(self, signal = None): if self.args.retrieval_base in ['feature', 'combined']: if self.args.dev: print("🔍 DEBUG: Extracting features") - feature=self.rag_db.feature_extractor.extract_features(signal) + original_feature=self.rag_db.feature_extractor.extract_rag_features(signal) + feature=original_feature + if self.args.normalized_rag_feature: + feature = self.rag_db.query_feature_normalization(original_feature) + signal = self.rag_db.query_signal_lead_normalization(signal) if self.args.dev: print("🔍 DEBUG: Features extracted, shape: ", feature.shape) @@ -161,13 +165,17 @@ def append_messages_to_conv(self, conv, altered_text, signal=None): message_value = message_value.replace('', '') message_value = message_value.replace('image', 'signal').replace('Image', 'Signal') if self.args.retrieval_base in ['feature', 'combined'] or self.args.retrieved_information in ['feature','combined']: - feature=self.rag_db.feature_extractor.extract_features(signal) + original_feature=self.rag_db.feature_extractor.extract_rag_features(signal) + feature=original_feature + if self.args.normalized_rag_feature: + feature = self.rag_db.query_feature_normalization(original_feature) + signal = self.rag_db.query_signal_lead_normalization(signal) if is_human and count == 0: if self.args.rag and self.args.rag_prompt_mode == 'user_query': rag_results = self.rag_db.search_similar(query_features=feature, query_signal=signal, k=self.args.rag_k, mode=self.args.retrieval_base) filtered_rag_results = self.rag_db.format_search(rag_results,self.args.retrieved_information) if self.args.retrieved_information == 'combined': - message_value = f"\nFeature Information:\n{feature}\n\n{filtered_rag_results}\n{message_value}" + message_value = f"\nFeature Information:\n{self.rag_db.convert_features_to_structured(original_feature)}\n\n{filtered_rag_results}\n{message_value}" elif self.args.retrieved_information == 'report': message_value = f"\n{filtered_rag_results}\n{message_value}" else: diff --git a/ecg_bench/utils/preprocess_utils.py b/ecg_bench/utils/preprocess_utils.py index f6efc56..65666db 100644 --- a/ecg_bench/utils/preprocess_utils.py +++ b/ecg_bench/utils/preprocess_utils.py @@ -1047,6 +1047,61 @@ def extract_features(self, ecg): return np.array(features) + def extract_rag_features(self, ecg): + """ + Extract a subset of features for RAG applications. + Keeps only: max, min, dominant_frequency, total_power, spectral_centroid, + peak_frequency_power, Heart Rate Features, Wavelet Features, average_absolute_difference, root_mean_square_difference + """ + features = [] + + for lead in range(ecg.shape[0]): + lead_signal = ecg[lead, :] + + # Basic statistical features (only max and min) + features.extend([ + np.max(lead_signal), + np.min(lead_signal) + ]) + + # Frequency domain features + freqs, psd = signal.welch(lead_signal, fs=self.target_sf, nperseg=min(1024, len(lead_signal))) + total_power = np.sum(psd) + features.extend([ + total_power, # Total power + np.max(psd), # Peak frequency power + freqs[np.argmax(psd)], # Dominant frequency + ]) + + # Spectral centroid with NaN handling + if total_power > 0: + spectral_centroid = np.sum(freqs * psd) / total_power + else: + spectral_centroid = 0 + features.append(spectral_centroid) + + # Find peaks with robust thresholding + if np.max(lead_signal) != np.min(lead_signal): # Avoid division by zero + peak_height = 0.3 * (np.max(lead_signal) - np.min(lead_signal)) + np.min(lead_signal) + min_distance = max(int(0.2 * self.target_sf), 1) # Ensure positive distance + peaks, _ = signal.find_peaks(lead_signal, height=peak_height, distance=min_distance) + else: + peaks = [] + + # Heart rate features + heart_rate_features = self._calculate_heart_rate_features(lead_signal, peaks) + features.extend(heart_rate_features) + + # Wavelet features + wavelet_features = self._calculate_wavelet_features(lead_signal) + features.extend(wavelet_features) + + # Non-linear features + features.append(np.mean(np.abs(np.diff(lead_signal)))) # Average absolute difference + features.append(np.sqrt(np.mean(np.square(np.diff(lead_signal))))) # Root mean square of successive differences + + return np.array(features) + def _calculate_heart_rate_features(self, ecg, peaks): if len(peaks) > 1: # Heart rate @@ -1118,4 +1173,58 @@ def find_st_deviation(self, ecg, peaks): return ecg[st_point] - ecg[peaks[-1]] return 0 + def signal_lead_normalization(ecg): + """ + Normalize each lead individually using z-score normalization. + """ + if ecg.shape[0] == 12: + ecg = ecg.T + transpose_back = True + else: + transpose_back = False + + normalized_ecg = np.zeros_like(ecg, dtype=np.float32) + + for lead_idx in range(12): + lead_signal = ecg[:, lead_idx] + lead_mean = np.mean(lead_signal) + lead_std = np.std(lead_signal) + 1e-10 + normalized_ecg[:, lead_idx] = (lead_signal - lead_mean) / lead_std + + if transpose_back: + normalized_ecg = normalized_ecg.T + + return normalized_ecg + + def feature_normalization(self, rag_features): + """ + Normalize RAG features using z-score normalization. + """ + features_per_lead = len(self.ecg_feature_list) + expected_total_features = 12 * features_per_lead + + if rag_features.ndim != 1: + raise ValueError(f"Expected 1D array, got shape {rag_features.shape}") + + if len(rag_features) != expected_total_features: + raise ValueError(f"Expected {expected_total_features} features for 12-lead ECG, got {len(rag_features)}") + + normalized_features = np.zeros_like(rag_features, dtype=np.float32) + + for feature_idx, feature_name in enumerate(self.ecg_feature_list): + feature_values = [] + for lead_idx in range(12): + feature_pos = lead_idx * features_per_lead + feature_idx + feature_values.append(rag_features[feature_pos]) + + feature_values = np.array(feature_values) + + feature_mean = np.mean(feature_values) + feature_std = np.std(feature_values) + 1e-10 + + for lead_idx in range(12): + feature_pos = lead_idx * features_per_lead + feature_idx + normalized_features[feature_pos] = (rag_features[feature_pos] - feature_mean) / feature_std + + return normalized_features diff --git a/ecg_bench/utils/rag_utils.py b/ecg_bench/utils/rag_utils.py index 06ca8f8..7ee89ad 100644 --- a/ecg_bench/utils/rag_utils.py +++ b/ecg_bench/utils/rag_utils.py @@ -10,13 +10,13 @@ def __init__(self, args, fm): self.args = args self.fm = fm self.ecg_feature_list = [ - "mean", - "std", + # "mean", + # "std", "max", "min", - "median", - "25th percentile", - "75th percentile", + # "median", + # "25th percentile", + # "75th percentile", "total power", "peak frequency power", "dominant frequency", @@ -35,6 +35,7 @@ def __init__(self, args, fm): "average absolute difference", "root mean square difference" ] + print('Loading RAG database...') if self.args.create_rag_db: @@ -47,11 +48,13 @@ def __init__(self, args, fm): print('Loading RAG database from file...') self.metadata = self.fm.open_json(self.args.load_rag_db) self.feature_extractor = ECGFeatureExtractor(self.args.target_sf) - self.index = faiss.read_index(f"./data/mimic/combined.index") + self.index = faiss.read_index(f"./data/mimic/combined_normalized.index") if self.args.dev: print("🔍 DEBUG: Combined index loaded directly") - self.feature_dim = 288 + self.feature_dim = 12*len(self.ecg_feature_list) self.signal_dim = self.index.d - self.feature_dim + + if self.args.retrieval_base == 'signal': self.signal_index = faiss.read_index(self.args.load_rag_db_idx) if self.args.dev: @@ -69,7 +72,8 @@ def __init__(self, args, fm): print('Metadata loaded.') self.reports = [item['report'] for item in self.metadata] self.file_paths = [item['file_path'] for item in self.metadata] - + self.original_ecgs = [item['signal'] for item in self.metadata] + self.original_features = [item['features'] for item in self.metadata] print(f'RAG {self.args.retrieval_base} database loaded.') # print('Building sub-indices...') @@ -78,8 +82,64 @@ def __init__(self, args, fm): print('features dim:', self.feature_dim) print('signals dim:', self.signal_dim) print('total samples:', len(self.reports)) + print(f'Normalization enabled: {self.args.normalized_rag_feature}') print('Index loaded.') + def query_signal_lead_normalization(self, signal): + """ + Normalize each lead individually using z-score normalization. + """ + if signal.shape[0] == 12: + signal = signal.T + transpose_back = True + else: + transpose_back = False + + normalized_signal = np.zeros_like(signal, dtype=np.float32) + + for lead_idx in range(12): + lead_signal = signal[:, lead_idx] + lead_mean = np.mean(lead_signal) + lead_std = np.std(lead_signal) + 1e-10 + normalized_signal[:, lead_idx] = (lead_signal - lead_mean) / lead_std + + if transpose_back: + normalized_signal = normalized_signal.T + + return normalized_signal + + def query_feature_normalization(self, rag_features): + """ + Normalize RAG features using z-score normalization. + """ + features_per_lead = len(self.ecg_feature_list) + expected_total_features = 12 * features_per_lead + + if rag_features.ndim != 1: + raise ValueError(f"Expected 1D array, got shape {rag_features.shape}") + + if len(rag_features) != expected_total_features: + raise ValueError(f"Expected {expected_total_features} features for 12-lead ECG, got {len(rag_features)}") + + normalized_features = np.zeros_like(rag_features, dtype=np.float32) + + for feature_idx, feature_name in enumerate(self.ecg_feature_list): + feature_values = [] + for lead_idx in range(12): + feature_pos = lead_idx * features_per_lead + feature_idx + feature_values.append(rag_features[feature_pos]) + + feature_values = np.array(feature_values) + + feature_mean = np.mean(feature_values) + feature_std = np.std(feature_values) + 1e-10 + + for lead_idx in range(12): + feature_pos = lead_idx * features_per_lead + feature_idx + normalized_features[feature_pos] = (rag_features[feature_pos] - feature_mean) / feature_std + + return normalized_features + def _build_sub_indices(self): ntotal = self.index.ntotal nlist=min(100, max(1, ntotal // 30)) @@ -109,11 +169,10 @@ def _build_sub_indices(self): def create_and_save_db(self): print('Initializing RAG database creation...') metadata = [] - vectors_for_index = [] + combined_vectors = [] feature_vectors = [] signal_vectors = [] - - + npy_files = list(Path(self.preprocessed_dir).glob('*.npy')) if self.args.dev: npy_files = npy_files[:1000] @@ -123,6 +182,7 @@ def create_and_save_db(self): print(f'Toy mode: Processing {len(npy_files)} files') print(f'Found {len(npy_files)} files to process') + print(f'Normalization enabled: {self.args.normalized_rag_feature}') print('Starting feature extraction from ECG signals...') for file_path in tqdm(npy_files, desc="Extracting features"): @@ -130,33 +190,39 @@ def create_and_save_db(self): data = self.fm.open_npy(file_path) ecg = data['ecg'] report = data['report'] - features = self.feature_extractor.extract_features(ecg) - - # Store vectors for different indices - feature_vector = features.flatten() - signal_vector = ecg.flatten() - combined_vector = np.hstack([feature_vector, signal_vector]) - - feature_vectors.append(feature_vector) - signal_vectors.append(signal_vector) - vectors_for_index.append(combined_vector) - + features = self.feature_extractor.extract_rag_features(ecg).flatten() # Store only metadata in JSON metadata.append({ 'report': report, - 'file_path': str(file_path) + 'file_path': str(file_path), + 'features': features, + 'signal': ecg }) + + if not self.args.normalized_rag_feature: + signal_vector = ecg.flatten() + feature_vector = features.flatten() + + else: + signal_vector = self.query_signal_lead_normalization(ecg).flatten() + feature_vector = self.query_feature_normalization(features).flatten() + + combined_vector = np.hstack([feature_vector, signal_vector]) + signal_vectors.append(signal_vector) + feature_vectors.append(feature_vector) + combined_vectors.append(combined_vector) + + except Exception as e: print(f"Error processing {file_path}: {str(e)}") continue - + print(f'Successfully processed {len(metadata)} files') - print('Converting vectors to arrays...') - + # Convert to arrays feature_array = np.stack(feature_vectors) signal_array = np.stack(signal_vectors) - combined_array = np.stack(vectors_for_index) + combined_array = np.stack(combined_vectors) # Calculate optimal number of clusters based on dataset size ntotal = len(combined_array) @@ -173,7 +239,7 @@ def create_and_save_db(self): print('Adding vectors to feature index...') self.feature_index.add(feature_array) self.feature_index.make_direct_map() - feature_path = f"./data/{self.args.base_data}/feature.index" + feature_path = f"./data/{self.args.base_data}/feature{'_normalized' if self.args.normalized_rag_feature else ''}.index" print(f'Saving feature index to {feature_path}...') faiss.write_index(self.feature_index, feature_path) print('Feature index saved successfully!') @@ -187,7 +253,7 @@ def create_and_save_db(self): print('Adding vectors to signal index...') self.signal_index.add(signal_array) self.signal_index.make_direct_map() - signal_path = f"./data/{self.args.base_data}/signal.index" + signal_path = f"./data/{self.args.base_data}/signal{'_normalized' if self.args.normalized_rag_feature else 'unnormalized'}.index" print(f'Saving signal index to {signal_path}...') faiss.write_index(self.signal_index, signal_path) print('Signal index saved successfully!') @@ -201,7 +267,7 @@ def create_and_save_db(self): print('Adding vectors to combined index...') self.index.add(combined_array) self.index.make_direct_map() - combined_path = f"./data/{self.args.base_data}/combined.index" + combined_path = f"./data/{self.args.base_data}/combined{'_normalized' if self.args.normalized_rag_feature else 'unnormalized'}.index" print(f'Saving combined index to {combined_path}...') faiss.write_index(self.index, combined_path) print('Combined index saved successfully!') @@ -212,6 +278,7 @@ def create_and_save_db(self): self.fm.save_json(metadata, metadata_path) print('Metadata saved successfully!') + print('RAG database creation completed successfully!') print(f'Total samples: {len(metadata)}') print(f'Feature dimension: {feature_array.shape[1]}') @@ -268,13 +335,9 @@ def search_similar(self, query_features=None, query_signal=None, k=5, mode='sign # Prepare results using reconstructed vectors from index results = {} for i, (dist, idx) in enumerate(zip(distances[0], original_indices)): - full_vector = self.index.reconstruct(int(idx)) - features = full_vector[:self.feature_dim] - signal = full_vector[self.feature_dim:] - result_dict = { - 'signal': signal, - 'feature': features, + 'signal': self.original_ecgs[idx], + 'feature': self.original_features[idx], 'report': self.reports[idx], 'distance': float(dist), 'file_path': self.file_paths[idx] @@ -295,9 +358,11 @@ def format_search(self, results, retrieved_information='combined'): if retrieved_information == 'feature': output += "features. Utilize this information to further enhance your response.\n\n" elif retrieved_information == 'report': - output += "diagnosis. Utilize this information to further enhance your response.\n\n" + output += "diagnosis. Utilize this information to further enhance your response. The lead order is I, II, III, aVL, aVR, aVF, V1, V2, V3, V4, V5, V6.\n\n" else: # combined - output += "features and diagnosis. Utilize this information to further enhance your response.\n\n" + output += "features and diagnosis. Utilize this information to further enhance your response. The lead order is I, II, III, aVL, aVR, aVF, V1, V2, V3, V4, V5, V6.\n\n" + + features_per_lead = len(self.ecg_feature_list) for idx, res in results.items(): # Filter out entries where all feature values are zero @@ -309,23 +374,52 @@ def format_search(self, results, retrieved_information='combined'): # Include feature information based on retrieved_information if retrieved_information in ['feature', 'combined']: output += "Feature Information:\n" - # Zip through feature names and feature values to format each line. - for feature_name, feature_value in zip(self.ecg_feature_list, res['feature']): - output += f"{feature_name}: {str(round(float(feature_value), 6))}\n" + + # Organize features by feature type across all leads + for feature_idx, feature_name in enumerate(self.ecg_feature_list): + feature_values = [] + for lead_idx in range(12): + feature_pos = lead_idx * features_per_lead + feature_idx + feature_values.append(round(float(res['feature'][feature_pos]), 6)) + output += f"{feature_name}: {feature_values}\n" output += "\n" # Include diagnosis information based on retrieved_information if retrieved_information in ['report', 'combined']: output += "Diagnosis Information:\n" output += f"{res['report']}\n\n" - - - + if self.args.dev: print("🔍 DEBUG: First 300 characters of formatted output:") print(output[:300] + "..." if len(output) > 300 else output) return output + def convert_features_to_structured(self, feature_array): + """ + Convert a flat feature array into a formatted string organized by feature type. + + Args: + feature_array: numpy array of shape (228,) containing RAG features for 12 leads + + Returns: + formatted_string: formatted string with feature names and arrays of 12 values + """ + features_per_lead = len(self.ecg_feature_list) + + if len(feature_array) != 12 * features_per_lead: + raise ValueError(f"Expected {12 * features_per_lead} features, got {len(feature_array)}") + + formatted_output = "" + + for feature_idx, feature_name in enumerate(self.ecg_feature_list): + feature_values = [] + for lead_idx in range(12): + feature_pos = lead_idx * features_per_lead + feature_idx + feature_values.append(round(float(feature_array[feature_pos]), 6)) + formatted_output += f"{feature_name}: {feature_values}\n" + + return formatted_output + def filter_results(self, results): filtered_results = {} count = 0 @@ -349,26 +443,31 @@ def test_search(self): random_idx = np.random.randint(0, len(npy_files)) query_signal = self.fm.open_npy(npy_files[random_idx])['ecg'] print('query_signal', query_signal.shape) + # Flatten the signal to match the expected dimensions - query_signal = query_signal.flatten() + query_signal_flat = query_signal.flatten() + start_time = time.time() # Use retrieval_base parameter to determine search mode retrieval_base = getattr(self.args, 'retrieval_base', 'signal') if retrieval_base == 'feature': # Extract features for feature-based search - # Need to reshape back to 2D for feature extraction - query_signal_2d = query_signal.reshape(12, -1) - features = self.feature_extractor.extract_features(query_signal_2d) + features = self.feature_extractor.extract_rag_features(query_signal) + if self.args.normalized_rag_feature: + features = self.query_feature_normalization(features) results = self.search_similar(query_features=features, k=10, mode='feature') elif retrieval_base == 'combined': # Extract features for combined search - # Need to reshape back to 2D for feature extraction - query_signal_2d = query_signal.reshape(12, -1) - features = self.feature_extractor.extract_features(query_signal_2d) - results = self.search_similar(query_features=features, query_signal=query_signal, k=10, mode='combined') + features = self.feature_extractor.extract_rag_features(query_signal) + if self.args.normalized_rag_feature: + features = self.query_feature_normalization(features) + query_signal_flat = self.query_signal_lead_normalization(query_signal).flatten() + results = self.search_similar(query_features=features, query_signal=query_signal_flat, k=10, mode='combined') else: # signal mode (default) - results = self.search_similar(query_signal=query_signal, k=10, mode='signal') + if self.args.normalized_rag_feature: + query_signal_flat = self.query_signal_lead_normalization(query_signal).flatten() + results = self.search_similar(query_signal=query_signal_flat, k=10, mode='signal') formatted_results = self.format_search(results, retrieved_information=getattr(self.args, 'retrieved_information', 'combined')) print(formatted_results) From bcaa4a70d500c7098f6ad5e6e5ef77d4d22a146b Mon Sep 17 00:00:00 2001 From: nbbb24 Date: Sat, 19 Jul 2025 20:19:15 +0000 Subject: [PATCH 03/23] make sure json can handle numpy arrays --- ecg_bench/utils/dir_file_utils.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/ecg_bench/utils/dir_file_utils.py b/ecg_bench/utils/dir_file_utils.py index 35be8fc..baea46f 100644 --- a/ecg_bench/utils/dir_file_utils.py +++ b/ecg_bench/utils/dir_file_utils.py @@ -21,8 +21,18 @@ def open_json(path: Union[str, Path]) -> dict: @staticmethod def save_json(data: dict, path: Union[str, Path]): """Save a dictionary to a JSON file.""" + def convert_numpy(obj): + if isinstance(obj, np.ndarray): + return obj.tolist() + elif isinstance(obj, dict): + return {key: convert_numpy(value) for key, value in obj.items()} + elif isinstance(obj, list): + return [convert_numpy(item) for item in obj] + else: + return obj + with open(path, 'w') as f: - json.dump(data, f) + json.dump(convert_numpy(data), f) @staticmethod def get_system_prompt(system_prompt_path: Union[str, Path]): From 491b217b7372b4df79190c369bb42315dc6b56bc Mon Sep 17 00:00:00 2001 From: nbbb24 Date: Sun, 20 Jul 2025 07:32:02 +0000 Subject: [PATCH 04/23] apply normalization to 3 different retrieval base --- ecg_bench/config.py | 2 +- ecg_bench/utils/rag_utils.py | 74 +++++++++++++++++------------------- 2 files changed, 36 insertions(+), 40 deletions(-) diff --git a/ecg_bench/config.py b/ecg_bench/config.py index ae11f9a..8c55868 100644 --- a/ecg_bench/config.py +++ b/ecg_bench/config.py @@ -60,7 +60,7 @@ def get_args(): mode_group.add_argument('--retrieved_information', type=str, default='combined', choices=['feature', 'report', 'combined'], help='Type of information to retrieve in output') mode_group.add_argument('--load_rag_db', type = str, default = None, help = 'Load a RAG database') mode_group.add_argument('--load_rag_db_idx', type = str, default = None, help = 'Load a RAG database index') - mode_group.add_argument('--normalized_rag_feature', action='store_true', default=True, help='Enable normalization for RAG features and signals') + mode_group.add_argument('--normalized_rag_feature', action='store_true', default=None, help='Enable normalization for RAG features and signals') mode_group.add_argument('--dev', action='store_true', default=None, help='Development mode') mode_group.add_argument('--log', action='store_true', default=None, help='Enable logging') diff --git a/ecg_bench/utils/rag_utils.py b/ecg_bench/utils/rag_utils.py index 7ee89ad..a9734b6 100644 --- a/ecg_bench/utils/rag_utils.py +++ b/ecg_bench/utils/rag_utils.py @@ -41,6 +41,9 @@ def __init__(self, args, fm): if self.args.create_rag_db: self.preprocessed_dir = f"./data/{self.args.base_data}/preprocessed_{self.args.seg_len}_{self.args.target_sf}" self.feature_extractor = ECGFeatureExtractor(self.args.target_sf) + self.feature_dim = 12* len(self.ecg_feature_list) + self.signal_dim = 12*self.args.seg_len + self.feature_weight=np.sqrt(self.signal_dim/self.feature_dim) print('Creating RAG database...') print('No RAG database found. Creating new one...') self.metadata = self.create_and_save_db() @@ -48,11 +51,11 @@ def __init__(self, args, fm): print('Loading RAG database from file...') self.metadata = self.fm.open_json(self.args.load_rag_db) self.feature_extractor = ECGFeatureExtractor(self.args.target_sf) - self.index = faiss.read_index(f"./data/mimic/combined_normalized.index") if self.args.dev: print("🔍 DEBUG: Combined index loaded directly") self.feature_dim = 12*len(self.ecg_feature_list) - self.signal_dim = self.index.d - self.feature_dim + self.signal_dim = 12*self.args.seg_len + self.feature_weight=np.sqrt(self.signal_dim/self.feature_dim) if self.args.retrieval_base == 'signal': @@ -72,8 +75,6 @@ def __init__(self, args, fm): print('Metadata loaded.') self.reports = [item['report'] for item in self.metadata] self.file_paths = [item['file_path'] for item in self.metadata] - self.original_ecgs = [item['signal'] for item in self.metadata] - self.original_features = [item['features'] for item in self.metadata] print(f'RAG {self.args.retrieval_base} database loaded.') # print('Building sub-indices...') @@ -112,8 +113,7 @@ def query_feature_normalization(self, rag_features): """ Normalize RAG features using z-score normalization. """ - features_per_lead = len(self.ecg_feature_list) - expected_total_features = 12 * features_per_lead + expected_total_features = self.feature_dim if rag_features.ndim != 1: raise ValueError(f"Expected 1D array, got shape {rag_features.shape}") @@ -126,7 +126,7 @@ def query_feature_normalization(self, rag_features): for feature_idx, feature_name in enumerate(self.ecg_feature_list): feature_values = [] for lead_idx in range(12): - feature_pos = lead_idx * features_per_lead + feature_idx + feature_pos = lead_idx * len(self.ecg_feature_list) + feature_idx feature_values.append(rag_features[feature_pos]) feature_values = np.array(feature_values) @@ -135,7 +135,7 @@ def query_feature_normalization(self, rag_features): feature_std = np.std(feature_values) + 1e-10 for lead_idx in range(12): - feature_pos = lead_idx * features_per_lead + feature_idx + feature_pos = lead_idx * len(self.ecg_feature_list) + feature_idx normalized_features[feature_pos] = (rag_features[feature_pos] - feature_mean) / feature_std return normalized_features @@ -175,7 +175,7 @@ def create_and_save_db(self): npy_files = list(Path(self.preprocessed_dir).glob('*.npy')) if self.args.dev: - npy_files = npy_files[:1000] + npy_files = npy_files[:300] print(f'Development mode: Processing {len(npy_files)} files') if self.args.toy: npy_files = npy_files[:400000] @@ -191,12 +191,9 @@ def create_and_save_db(self): ecg = data['ecg'] report = data['report'] features = self.feature_extractor.extract_rag_features(ecg).flatten() - # Store only metadata in JSON metadata.append({ 'report': report, 'file_path': str(file_path), - 'features': features, - 'signal': ecg }) if not self.args.normalized_rag_feature: @@ -207,7 +204,7 @@ def create_and_save_db(self): signal_vector = self.query_signal_lead_normalization(ecg).flatten() feature_vector = self.query_feature_normalization(features).flatten() - combined_vector = np.hstack([feature_vector, signal_vector]) + combined_vector = np.hstack([feature_vector*self.feature_weight, signal_vector]) signal_vectors.append(signal_vector) feature_vectors.append(feature_vector) combined_vectors.append(combined_vector) @@ -239,7 +236,7 @@ def create_and_save_db(self): print('Adding vectors to feature index...') self.feature_index.add(feature_array) self.feature_index.make_direct_map() - feature_path = f"./data/{self.args.base_data}/feature{'_normalized' if self.args.normalized_rag_feature else ''}.index" + feature_path = f"./data/{self.args.base_data}/feature_{'normalized' if self.args.normalized_rag_feature else ''}.index" print(f'Saving feature index to {feature_path}...') faiss.write_index(self.feature_index, feature_path) print('Feature index saved successfully!') @@ -253,7 +250,7 @@ def create_and_save_db(self): print('Adding vectors to signal index...') self.signal_index.add(signal_array) self.signal_index.make_direct_map() - signal_path = f"./data/{self.args.base_data}/signal{'_normalized' if self.args.normalized_rag_feature else 'unnormalized'}.index" + signal_path = f"./data/{self.args.base_data}/signal_{'normalized' if self.args.normalized_rag_feature else 'unnormalized'}.index" print(f'Saving signal index to {signal_path}...') faiss.write_index(self.signal_index, signal_path) print('Signal index saved successfully!') @@ -267,7 +264,7 @@ def create_and_save_db(self): print('Adding vectors to combined index...') self.index.add(combined_array) self.index.make_direct_map() - combined_path = f"./data/{self.args.base_data}/combined{'_normalized' if self.args.normalized_rag_feature else 'unnormalized'}.index" + combined_path = f"./data/{self.args.base_data}/combined_{'normalized' if self.args.normalized_rag_feature else 'unnormalized'}.index" print(f'Saving combined index to {combined_path}...') faiss.write_index(self.index, combined_path) print('Combined index saved successfully!') @@ -325,22 +322,24 @@ def search_similar(self, query_features=None, query_signal=None, k=5, mode='sign else: # combined mode self.index.nprobe = nprobe - query_combined = np.hstack([query_features, query_signal]) + query_combined = np.hstack([query_features*self.feature_weight, query_signal]) query_combined = query_combined.reshape(1, -1) distances, indices = self.index.search(query_combined, k) original_indices = indices[0] - - # Prepare results using reconstructed vectors from index results = {} for i, (dist, idx) in enumerate(zip(distances[0], original_indices)): + file_path = self.file_paths[idx] + signal=self.fm.open_npy(file_path)['ecg'] + features=self.feature_extractor.extract_rag_features(signal) + result_dict = { - 'signal': self.original_ecgs[idx], - 'feature': self.original_features[idx], + 'signal': signal, + 'feature': features, 'report': self.reports[idx], 'distance': float(dist), - 'file_path': self.file_paths[idx] + 'file_path': file_path } results[i] = result_dict @@ -362,8 +361,6 @@ def format_search(self, results, retrieved_information='combined'): else: # combined output += "features and diagnosis. Utilize this information to further enhance your response. The lead order is I, II, III, aVL, aVR, aVF, V1, V2, V3, V4, V5, V6.\n\n" - features_per_lead = len(self.ecg_feature_list) - for idx, res in results.items(): # Filter out entries where all feature values are zero if np.all(np.array(res['feature']) == 0): @@ -371,6 +368,8 @@ def format_search(self, results, retrieved_information='combined'): output += f"Retrieved ECG {idx+1}\n" + if self.args.dev: + output+=f"Distance: {res['distance']}\n" # Include feature information based on retrieved_information if retrieved_information in ['feature', 'combined']: output += "Feature Information:\n" @@ -379,7 +378,7 @@ def format_search(self, results, retrieved_information='combined'): for feature_idx, feature_name in enumerate(self.ecg_feature_list): feature_values = [] for lead_idx in range(12): - feature_pos = lead_idx * features_per_lead + feature_idx + feature_pos = lead_idx * len(self.ecg_feature_list) + feature_idx feature_values.append(round(float(res['feature'][feature_pos]), 6)) output += f"{feature_name}: {feature_values}\n" output += "\n" @@ -388,10 +387,6 @@ def format_search(self, results, retrieved_information='combined'): if retrieved_information in ['report', 'combined']: output += "Diagnosis Information:\n" output += f"{res['report']}\n\n" - - if self.args.dev: - print("🔍 DEBUG: First 300 characters of formatted output:") - print(output[:300] + "..." if len(output) > 300 else output) return output def convert_features_to_structured(self, feature_array): @@ -404,17 +399,15 @@ def convert_features_to_structured(self, feature_array): Returns: formatted_string: formatted string with feature names and arrays of 12 values """ - features_per_lead = len(self.ecg_feature_list) - - if len(feature_array) != 12 * features_per_lead: - raise ValueError(f"Expected {12 * features_per_lead} features, got {len(feature_array)}") + if len(feature_array) != self.feature_dim: + raise ValueError(f"Expected {self.feature_dim} features, got {len(feature_array)}") formatted_output = "" for feature_idx, feature_name in enumerate(self.ecg_feature_list): feature_values = [] for lead_idx in range(12): - feature_pos = lead_idx * features_per_lead + feature_idx + feature_pos = lead_idx * len(self.ecg_feature_list) + feature_idx feature_values.append(round(float(feature_array[feature_pos]), 6)) formatted_output += f"{feature_name}: {feature_values}\n" @@ -439,9 +432,12 @@ def filter_results(self, results): def test_search(self): self.preprocessed_dir = f"./data/{self.args.base_data}/preprocessed_{self.args.seg_len}_{self.args.target_sf}" + rng = np.random.RandomState(42) npy_files = list(Path(self.preprocessed_dir).glob('*.npy')) - random_idx = np.random.randint(0, len(npy_files)) + random_idx = rng.randint(0, len(npy_files)) query_signal = self.fm.open_npy(npy_files[random_idx])['ecg'] + query_report = self.fm.open_npy(npy_files[random_idx])['report'] + print('query_report: /n', query_report) print('query_signal', query_signal.shape) # Flatten the signal to match the expected dimensions @@ -450,24 +446,24 @@ def test_search(self): start_time = time.time() # Use retrieval_base parameter to determine search mode - retrieval_base = getattr(self.args, 'retrieval_base', 'signal') + retrieval_base = getattr(self.args, 'retrieval_base', 'combined') if retrieval_base == 'feature': # Extract features for feature-based search features = self.feature_extractor.extract_rag_features(query_signal) if self.args.normalized_rag_feature: features = self.query_feature_normalization(features) - results = self.search_similar(query_features=features, k=10, mode='feature') + results = self.search_similar(query_features=features, k=3, mode='feature') elif retrieval_base == 'combined': # Extract features for combined search features = self.feature_extractor.extract_rag_features(query_signal) if self.args.normalized_rag_feature: features = self.query_feature_normalization(features) query_signal_flat = self.query_signal_lead_normalization(query_signal).flatten() - results = self.search_similar(query_features=features, query_signal=query_signal_flat, k=10, mode='combined') + results = self.search_similar(query_features=features, query_signal=query_signal_flat, k=3, mode='combined') else: # signal mode (default) if self.args.normalized_rag_feature: query_signal_flat = self.query_signal_lead_normalization(query_signal).flatten() - results = self.search_similar(query_signal=query_signal_flat, k=10, mode='signal') + results = self.search_similar(query_signal=query_signal_flat, k=3, mode='signal') formatted_results = self.format_search(results, retrieved_information=getattr(self.args, 'retrieved_information', 'combined')) print(formatted_results) From 7ea585bac432c6b84d79e51bb8d46e5698b0843a Mon Sep 17 00:00:00 2001 From: nbbb24 Date: Sun, 20 Jul 2025 07:34:59 +0000 Subject: [PATCH 05/23] keep the code clean --- ecg_bench/utils/data_loader_utils.py | 30 +--------------------------- ecg_bench/utils/rag_utils.py | 24 +--------------------- 2 files changed, 2 insertions(+), 52 deletions(-) diff --git a/ecg_bench/utils/data_loader_utils.py b/ecg_bench/utils/data_loader_utils.py index 1196d35..e1a643a 100644 --- a/ecg_bench/utils/data_loader_utils.py +++ b/ecg_bench/utils/data_loader_utils.py @@ -108,38 +108,20 @@ def setup_conversation_template(self, signal = None): conv = get_conv_template('gemma') feature=None if self.args.retrieval_base in ['feature', 'combined']: - if self.args.dev: - print("🔍 DEBUG: Extracting features") original_feature=self.rag_db.feature_extractor.extract_rag_features(signal) feature=original_feature if self.args.normalized_rag_feature: feature = self.rag_db.query_feature_normalization(original_feature) signal = self.rag_db.query_signal_lead_normalization(signal) - if self.args.dev: - print("🔍 DEBUG: Features extracted, shape: ", feature.shape) if 'gemma' not in self.args.model and ('qwen' in self.args.model or 'llama' in self.args.model): if self.args.rag and self.args.rag_prompt_mode == 'system_prompt': - if self.args.dev: - print("🔍 DEBUG: Setting ptompt for system_prompt modes") rag_results = self.rag_db.search_similar(query_features=feature, query_signal=signal, k=self.args.rag_k, mode=self.args.retrieval_base) - if self.args.dev: - print("🔍 DEBUG: RAG results retrieved") filtered_rag_results = self.rag_db.format_search(rag_results,self.args.retrieved_information) - if self.args.dev: - print("🔍 DEBUG: RAG results formatted") modified_system_prompt = f"{self.system_prompt}\n{filtered_rag_results}" - if self.args.dev: - print("🔍 DEBUG: Modified system prompt set") - print('filtered_rag_results', filtered_rag_results) - print('modified_system_prompt', modified_system_prompt) conv.set_system_message(modified_system_prompt) - if self.args.dev: - print("🔍 DEBUG: System prompt set!") else: conv.set_system_message(self.system_prompt) - if self.args.dev: - print("🔍 DEBUG: System prompt set!") return conv def process_altered_text(self, altered_text): @@ -182,8 +164,6 @@ def append_messages_to_conv(self, conv, altered_text, signal=None): message_value = f"\n{message_value}" count += 1 conv.append_message(role, message_value) - if self.args.dev: - print("🔍 DEBUG: Message appended to conv!") return conv def get_input_tokens(self, conv): @@ -390,16 +370,10 @@ def prepare_end2end_input(self, ecg_signal, altered_text): return self.prepare_inference_end2end(ecg_signal, altered_text) def prepare_training_end2end(self, ecg_signal, altered_text): - if self.args.dev: - print("🔍 DEBUG: Preparing training end2end input") conv = self.setup_conversation_template(signal=ecg_signal) - if self.args.dev: - print("🔍 DEBUG: Conversation template set!") altered_text = self.process_altered_text(altered_text) conv = self.append_messages_to_conv(conv, altered_text, ecg_signal) - if self.args.dev: - print("🔍 DEBUG: Messages appended to conv!") - + tokens_before, tokens_after = self.get_input_tokens(conv) symbol_signal = self.train_utils.ecg_tokenizer_utils._to_symbol_string(ecg_signal) @@ -424,8 +398,6 @@ def prepare_training_end2end(self, ecg_signal, altered_text): if len(input_ids) < self.args.pad_to_max: padding_length = self.args.pad_to_max - len(input_ids) input_ids = [self.llm_tokenizer.pad_token_id] * padding_length + input_ids - if self.args.dev: - print("🔍 DEBUG: About to call create_labels_from_responses") labels = self.create_labels_from_responses(input_ids, altered_text) if self.args.dev: diff --git a/ecg_bench/utils/rag_utils.py b/ecg_bench/utils/rag_utils.py index a9734b6..42703a9 100644 --- a/ecg_bench/utils/rag_utils.py +++ b/ecg_bench/utils/rag_utils.py @@ -51,22 +51,14 @@ def __init__(self, args, fm): print('Loading RAG database from file...') self.metadata = self.fm.open_json(self.args.load_rag_db) self.feature_extractor = ECGFeatureExtractor(self.args.target_sf) - if self.args.dev: - print("🔍 DEBUG: Combined index loaded directly") self.feature_dim = 12*len(self.ecg_feature_list) self.signal_dim = 12*self.args.seg_len self.feature_weight=np.sqrt(self.signal_dim/self.feature_dim) - if self.args.retrieval_base == 'signal': - self.signal_index = faiss.read_index(self.args.load_rag_db_idx) - if self.args.dev: - print("🔍 DEBUG: Signal index loaded directly") - + self.signal_index = faiss.read_index(self.args.load_rag_db_idx) elif self.args.retrieval_base == 'feature': self.feature_index = faiss.read_index(self.args.load_rag_db_idx) - if self.args.dev: - print("🔍 DEBUG: Feature index loaded directly") else: raise ValueError("Please provide a valid retrieval base.") else: @@ -302,23 +294,11 @@ def search_similar(self, query_features=None, query_signal=None, k=5, mode='sign query_features = query_features.reshape(1, self.feature_dim) distances, indices = self.feature_index.search(query_features, k) original_indices = indices[0] - if self.args.dev: - print("🔍 DEBUG: Feature index search completed") elif mode == 'signal': - if self.args.dev: - print("🔍 DEBUG: Signal index search started") self.signal_index.nprobe = nprobe - if self.args.dev: - print("🔍 DEBUG: Reshaping query signal") query_signal = query_signal.reshape(1, -1) - if self.args.dev: - print("🔍 DEBUG: Computing distances and indices") distances, indices = self.signal_index.search(query_signal, k) - if self.args.dev: - print("🔍 DEBUG: Signal index search completed") original_indices = indices[0] - if self.args.dev: - print("🔍 DEBUG: Original index search completed") else: # combined mode self.index.nprobe = nprobe @@ -349,8 +329,6 @@ def format_search(self, results, retrieved_information='combined'): if retrieved_information not in ['feature', 'report', 'combined']: raise ValueError("retrieved_information must be 'feature', 'report', or 'combined'") results = self.filter_results(results) - if self.args.dev: - print(f"🔍 DEBUG - Number of filtered results: {len(results)}") output = f"The following is the top {len(results)} retrieved ECGs and their corresponding " # Adjust the description based on retrieved_information From d92d91ed3abc03cb149ff7d0e131fa80275d1fce Mon Sep 17 00:00:00 2001 From: nbbb24 Date: Sun, 20 Jul 2025 08:03:14 +0000 Subject: [PATCH 06/23] update names of folders --- ecg_bench/main.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/ecg_bench/main.py b/ecg_bench/main.py index 6b4fd80..5598d0f 100644 --- a/ecg_bench/main.py +++ b/ecg_bench/main.py @@ -116,7 +116,8 @@ def create_save_path(args, fm): args.retrieval_base, args.retrieved_information, args.rag_k, - args.rag_prompt_mode + args.rag_prompt_mode, + args.normalized_rag_feature ]) model_params.append(encoder_in) @@ -257,7 +258,7 @@ def run_inference(model, test_loader, tokenizer, args, train_utils): # Construct filename based on args.rag if args.rag: - filename = f"seed_{seed}_{args.perturb}_{args.rag}_{args.retrieval_base}_{args.retrieved_information}_{args.rag_k}_{args.rag_prompt_mode}.json" + filename = f"seed_{seed}_{args.perturb}_{args.rag}_{args.retrieval_base}_{args.retrieved_information}_{args.rag_k}_{args.rag_prompt_mode}_{args.normalized_rag_feature}.json" else: filename = f"seed_{seed}_{args.perturb}_{args.rag}.json" @@ -273,7 +274,7 @@ def run_inference(model, test_loader, tokenizer, args, train_utils): # Update statistical results filename similarly if args.rag: - stat_filename = f"statistical_results_{args.perturb}_{args.rag}_{args.retrieval_base}_{args.retrieved_information}_{args.rag_k}_{args.rag_prompt_mode}.json" + stat_filename = f"statistical_results_{args.perturb}_{args.rag}_{args.retrieval_base}_{args.retrieved_information}_{args.rag_k}_{args.rag_prompt_mode}_{args.normalized_rag_feature}.json" else: stat_filename = f"statistical_results_{args.perturb}_{args.rag}.json" From 631952d69ece90926c70e824da42c868f8e307a5 Mon Sep 17 00:00:00 2001 From: nbbb24 Date: Sun, 20 Jul 2025 10:31:22 +0000 Subject: [PATCH 07/23] update naming the files --- ecg_bench/main.py | 3 +-- ecg_bench/utils/data_loader_utils.py | 6 +++--- ecg_bench/utils/rag_utils.py | 16 +++++++--------- 3 files changed, 11 insertions(+), 14 deletions(-) diff --git a/ecg_bench/main.py b/ecg_bench/main.py index 5598d0f..3ad8ae0 100644 --- a/ecg_bench/main.py +++ b/ecg_bench/main.py @@ -107,8 +107,7 @@ def create_save_path(args, fm): args.image, args.augment_image, args.train_encoder, - args.rag, - 'user_input' + args.rag ] if args.rag: diff --git a/ecg_bench/utils/data_loader_utils.py b/ecg_bench/utils/data_loader_utils.py index e1a643a..ab788f6 100644 --- a/ecg_bench/utils/data_loader_utils.py +++ b/ecg_bench/utils/data_loader_utils.py @@ -157,11 +157,11 @@ def append_messages_to_conv(self, conv, altered_text, signal=None): rag_results = self.rag_db.search_similar(query_features=feature, query_signal=signal, k=self.args.rag_k, mode=self.args.retrieval_base) filtered_rag_results = self.rag_db.format_search(rag_results,self.args.retrieved_information) if self.args.retrieved_information == 'combined': - message_value = f"\nFeature Information:\n{self.rag_db.convert_features_to_structured(original_feature)}\n\n{filtered_rag_results}\n{message_value}" + message_value = f"\nFeature Information:\n{self.rag_db.convert_features_to_structured(original_feature)}\n\n{filtered_rag_results}Question:\n{message_value}" elif self.args.retrieved_information == 'report': - message_value = f"\n{filtered_rag_results}\n{message_value}" + message_value = f"\n{filtered_rag_results}Question:\n{message_value}" else: - message_value = f"\n{message_value}" + message_value = f"Question:\n{message_value}" count += 1 conv.append_message(role, message_value) return conv diff --git a/ecg_bench/utils/rag_utils.py b/ecg_bench/utils/rag_utils.py index 42703a9..65c49f5 100644 --- a/ecg_bench/utils/rag_utils.py +++ b/ecg_bench/utils/rag_utils.py @@ -228,7 +228,7 @@ def create_and_save_db(self): print('Adding vectors to feature index...') self.feature_index.add(feature_array) self.feature_index.make_direct_map() - feature_path = f"./data/{self.args.base_data}/feature_{'normalized' if self.args.normalized_rag_feature else ''}.index" + feature_path = f"./data/{self.args.base_data}/feature_{'normalized' if self.args.normalized_rag_feature else 'unnormalized'}.index" print(f'Saving feature index to {feature_path}...') faiss.write_index(self.feature_index, feature_path) print('Feature index saved successfully!') @@ -328,14 +328,14 @@ def search_similar(self, query_features=None, query_signal=None, k=5, mode='sign def format_search(self, results, retrieved_information='combined'): if retrieved_information not in ['feature', 'report', 'combined']: raise ValueError("retrieved_information must be 'feature', 'report', or 'combined'") - results = self.filter_results(results) + # results = self.filter_results(results) output = f"The following is the top {len(results)} retrieved ECGs and their corresponding " # Adjust the description based on retrieved_information if retrieved_information == 'feature': - output += "features. Utilize this information to further enhance your response.\n\n" + output += "features. Utilize this information to further enhance your response.\n\nThe lead order is I, II, III, aVL, aVR, aVF, V1, V2, V3, V4, V5, V6.\n\n" elif retrieved_information == 'report': - output += "diagnosis. Utilize this information to further enhance your response. The lead order is I, II, III, aVL, aVR, aVF, V1, V2, V3, V4, V5, V6.\n\n" + output += "diagnosis. Utilize this information to further enhance your response. \n\n" else: # combined output += "features and diagnosis. Utilize this information to further enhance your response. The lead order is I, II, III, aVL, aVR, aVF, V1, V2, V3, V4, V5, V6.\n\n" @@ -395,13 +395,11 @@ def filter_results(self, results): filtered_results = {} count = 0 for idx, res in results.items(): - # Check if more than x% of values are exactly zero or if the sum is too small feature_array = np.array(res['feature']) zero_percentage = np.sum(np.abs(feature_array) < 1e-3) / len(feature_array) - total_magnitude = np.sum(np.abs(feature_array)) - - # Filter out entries that are mostly zeros or have very low total magnitude - if zero_percentage > 0.5 or total_magnitude < 0.5: + # total_magnitude = np.sum(np.abs(feature_array)) + + if zero_percentage > 0.6: continue filtered_results[count] = res From 99f41d173221e874abd5ce51c87e32ba7a81f6fd Mon Sep 17 00:00:00 2001 From: nbbb24 Date: Sun, 20 Jul 2025 16:55:48 +0000 Subject: [PATCH 08/23] update --- ecg_bench/utils/dir_file_utils.py | 12 +----------- ecg_bench/utils/rag_utils.py | 32 +++++++++++++++++++------------ 2 files changed, 21 insertions(+), 23 deletions(-) diff --git a/ecg_bench/utils/dir_file_utils.py b/ecg_bench/utils/dir_file_utils.py index baea46f..35be8fc 100644 --- a/ecg_bench/utils/dir_file_utils.py +++ b/ecg_bench/utils/dir_file_utils.py @@ -21,18 +21,8 @@ def open_json(path: Union[str, Path]) -> dict: @staticmethod def save_json(data: dict, path: Union[str, Path]): """Save a dictionary to a JSON file.""" - def convert_numpy(obj): - if isinstance(obj, np.ndarray): - return obj.tolist() - elif isinstance(obj, dict): - return {key: convert_numpy(value) for key, value in obj.items()} - elif isinstance(obj, list): - return [convert_numpy(item) for item in obj] - else: - return obj - with open(path, 'w') as f: - json.dump(convert_numpy(data), f) + json.dump(data, f) @staticmethod def get_system_prompt(system_prompt_path: Union[str, Path]): diff --git a/ecg_bench/utils/rag_utils.py b/ecg_bench/utils/rag_utils.py index 65c49f5..3049874 100644 --- a/ecg_bench/utils/rag_utils.py +++ b/ecg_bench/utils/rag_utils.py @@ -59,6 +59,8 @@ def __init__(self, args, fm): self.signal_index = faiss.read_index(self.args.load_rag_db_idx) elif self.args.retrieval_base == 'feature': self.feature_index = faiss.read_index(self.args.load_rag_db_idx) + elif self.args.retrieval_base == 'combined': + self.combined_index = faiss.read_index(self.args.load_rag_db_idx) else: raise ValueError("Please provide a valid retrieval base.") else: @@ -76,7 +78,7 @@ def __init__(self, args, fm): print('signals dim:', self.signal_dim) print('total samples:', len(self.reports)) print(f'Normalization enabled: {self.args.normalized_rag_feature}') - print('Index loaded.') + print(f'{self.args.retrieval_base} Index loaded.') def query_signal_lead_normalization(self, signal): """ @@ -133,14 +135,14 @@ def query_feature_normalization(self, rag_features): return normalized_features def _build_sub_indices(self): - ntotal = self.index.ntotal + ntotal = self.combined_index.ntotal nlist=min(100, max(1, ntotal // 30)) feature_vectors = np.zeros((ntotal, self.feature_dim), dtype=np.float32) signal_vectors = np.zeros((ntotal, self.signal_dim), dtype=np.float32) for i in range(ntotal): - full_vector = self.index.reconstruct(i) + full_vector = self.combined_index.reconstruct(i) feature_vectors[i] = full_vector[:self.feature_dim] signal_vectors[i] = full_vector[self.feature_dim:] @@ -250,15 +252,15 @@ def create_and_save_db(self): # Create and save combined index print('Creating combined index...') quantizer_combined = faiss.IndexFlatL2(combined_array.shape[1]) - self.index = faiss.IndexIVFFlat(quantizer_combined, combined_array.shape[1], nlist) + self.combined_index = faiss.IndexIVFFlat(quantizer_combined, combined_array.shape[1], nlist) print('Training combined index...') - self.index.train(combined_array) + self.combined_index.train(combined_array) print('Adding vectors to combined index...') - self.index.add(combined_array) - self.index.make_direct_map() + self.combined_index.add(combined_array) + self.combined_index.make_direct_map() combined_path = f"./data/{self.args.base_data}/combined_{'normalized' if self.args.normalized_rag_feature else 'unnormalized'}.index" print(f'Saving combined index to {combined_path}...') - faiss.write_index(self.index, combined_path) + faiss.write_index(self.combined_index, combined_path) print('Combined index saved successfully!') # Save metadata JSON @@ -301,10 +303,16 @@ def search_similar(self, query_features=None, query_signal=None, k=5, mode='sign original_indices = indices[0] else: # combined mode - self.index.nprobe = nprobe - query_combined = np.hstack([query_features*self.feature_weight, query_signal]) - query_combined = query_combined.reshape(1, -1) - distances, indices = self.index.search(query_combined, k) + self.combined_index.nprobe = nprobe + query_features = query_features.reshape(1, self.feature_dim) + query_signal = query_signal.reshape(1, -1) + query_combined = np.hstack([query_features*self.feature_weight, query_signal]).reshape(1, -1) + + print(f"Query combined shape: {query_combined.shape}") + print(f"Combined index dimension: {self.combined_index.d}") + print(f"Combined index total: {self.combined_index.ntotal}") + print(f"Query combined sample values: {query_combined[0, :5]}") + distances, indices = self.combined_index.search(query_combined, k) original_indices = indices[0] # Prepare results using reconstructed vectors from index From aa7ff798e2a97da641b1d462c021babf049cdfca Mon Sep 17 00:00:00 2001 From: nbbb24 Date: Sun, 20 Jul 2025 16:58:50 +0000 Subject: [PATCH 09/23] update --- ecg_bench/utils/rag_utils.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/ecg_bench/utils/rag_utils.py b/ecg_bench/utils/rag_utils.py index 3049874..5d7dfce 100644 --- a/ecg_bench/utils/rag_utils.py +++ b/ecg_bench/utils/rag_utils.py @@ -307,11 +307,6 @@ def search_similar(self, query_features=None, query_signal=None, k=5, mode='sign query_features = query_features.reshape(1, self.feature_dim) query_signal = query_signal.reshape(1, -1) query_combined = np.hstack([query_features*self.feature_weight, query_signal]).reshape(1, -1) - - print(f"Query combined shape: {query_combined.shape}") - print(f"Combined index dimension: {self.combined_index.d}") - print(f"Combined index total: {self.combined_index.ntotal}") - print(f"Query combined sample values: {query_combined[0, :5]}") distances, indices = self.combined_index.search(query_combined, k) original_indices = indices[0] From 521228ed0bee58acb174847142ad28c3380af6fd Mon Sep 17 00:00:00 2001 From: nbbb24 Date: Mon, 21 Jul 2025 14:31:54 +0000 Subject: [PATCH 10/23] update org --- ecg_bench/organize_results.py | 96 +++++++++++++++++++++++++------- ecg_bench/scripts/org_results.sh | 8 ++- 2 files changed, 82 insertions(+), 22 deletions(-) diff --git a/ecg_bench/organize_results.py b/ecg_bench/organize_results.py index dc95d92..29991d6 100644 --- a/ecg_bench/organize_results.py +++ b/ecg_bench/organize_results.py @@ -4,12 +4,55 @@ from ecg_bench.config import get_args def extract_file_info(file): - parts = file.split('_') - rag_used = parts[-2] == 'True' - rag_k = int(parts[-1].split('.')[0]) if rag_used else None - is_seed = 'seed' in file - seed_num = int(file.split('/')[-1].split('_')[1]) if is_seed else None - return rag_used, rag_k, is_seed, seed_num + filename = file.split('/')[-1] + parts = filename.split('_') + + if filename.startswith('seed_'): + # seed_{seed}_{perturb}_{rag}_{retrieval_base}_{retrieved_information}_{rag_k}_{rag_prompt_mode}_{normalized_rag_feature}.json + seed_num = int(parts[1]) + perturb = parts[2] + rag_used = parts[3] == 'True' + + if rag_used: + retrieval_base = parts[4] + retrieved_information = parts[5] + rag_k = int(parts[6]) + rag_prompt_mode = parts[7]+parts[8] + normalized_rag_feature = parts[9].split('.')[0] + else: + retrieval_base = retrieved_information = rag_prompt_mode = normalized_rag_feature = None + rag_k = None + + is_seed = True + else: + # statistical_results_{perturb}_{rag}_{retrieval_base}_{retrieved_information}_{rag_k}_{rag_prompt_mode}_{normalized_rag_feature}.json + perturb = parts[2] + rag_used = parts[3] == 'True' + + if rag_used: + retrieval_base = parts[4] + retrieved_information = parts[5] + rag_k = int(parts[6]) + rag_prompt_mode = parts[7]+parts[8] + normalized_rag_feature = parts[9].split('.')[0] + else: + retrieval_base = retrieved_information = rag_prompt_mode = normalized_rag_feature = None + rag_k = None + + is_seed = False + seed_num = None + + return { + 'rag_used': rag_used, + 'rag_k': rag_k, + 'is_seed': is_seed, + 'seed_num': seed_num, + 'perturb': perturb, + 'retrieval_base': retrieval_base, + 'retrieved_information': retrieved_information, + 'rag_prompt_mode': rag_prompt_mode, + 'normalized_rag_feature': normalized_rag_feature + } def process_seed_data(data): averages = data['averages'] @@ -28,31 +71,39 @@ def collect_results(json_files): statistical_no_rag = {} individual_seeds_rag = defaultdict(dict) statistical_rag = {} + config_info_no_rag = None + config_info_rag = {} for file in json_files: - rag_used, rag_k, is_seed, seed_num = extract_file_info(file) + info = extract_file_info(file) with open(file, 'r') as f: data = json.load(f) - if is_seed: + if info['is_seed']: metrics = process_seed_data(data) - if rag_used: - individual_seeds_rag[rag_k][seed_num] = metrics + if info['rag_used']: + individual_seeds_rag[info['rag_k']][info['seed_num']] = metrics + config_info_rag[info['rag_k']] = info else: - individual_seeds_no_rag[seed_num] = metrics + individual_seeds_no_rag[info['seed_num']] = metrics + config_info_no_rag = info else: - if rag_used: - statistical_rag[rag_k] = data + if info['rag_used']: + statistical_rag[info['rag_k']] = data + config_info_rag[info['rag_k']] = info else: statistical_no_rag = data + config_info_no_rag = info return (individual_seeds_no_rag, statistical_no_rag, - individual_seeds_rag, statistical_rag) + individual_seeds_rag, statistical_rag, config_info_no_rag, config_info_rag) -def print_seed_results(title, seed_dict): +def print_seed_results(title, seed_dict, config_info=None): if not seed_dict: return print(title) + if config_info: + print(f" Config: perturb={config_info['perturb']}, retrieval_base={config_info['retrieval_base']}, retrieved_info={config_info['retrieved_information']}, prompt_mode={config_info['rag_prompt_mode']}, normalized={config_info['normalized_rag_feature']}") for seed in sorted(seed_dict.keys()): print(f" Seed {seed}:") for metric in ['BLEU', 'METEOR', 'ROUGE', 'BERTSCORE', 'ACC']: @@ -60,10 +111,12 @@ def print_seed_results(title, seed_dict): print(f" {metric}: {value:.2f}") print('--------------------------------') -def print_statistical_results(title, stats_dict): +def print_statistical_results(title, stats_dict, config_info=None): if not stats_dict: return print(title) + if config_info: + print(f" Config: perturb={config_info['perturb']}, retrieval_base={config_info['retrieval_base']}, retrieved_info={config_info['retrieved_information']}, prompt_mode={config_info['rag_prompt_mode']}, normalized={config_info['normalized_rag_feature']}") for metric in ['BLEU', 'METEOR', 'ROUGE', 'BERTSCORE', 'ACC']: value = (stats_dict['ROUGE']['rouge-l'] if metric == 'ROUGE' else stats_dict['BERTSCORE']['hf-f1'] if metric == 'BERTSCORE' else @@ -89,14 +142,15 @@ def main(): return (individual_seeds_no_rag, statistical_no_rag, - individual_seeds_rag, statistical_rag) = collect_results(json_files) + individual_seeds_rag, statistical_rag, config_info_no_rag, config_info_rag) = collect_results(json_files) - print_seed_results("Individual Seed Results without RAG:", individual_seeds_no_rag) - print_statistical_results("Statistical Results without RAG:", statistical_no_rag) + print_seed_results("Individual Seed Results without RAG:", individual_seeds_no_rag, config_info_no_rag) + print_statistical_results("Statistical Results without RAG:", statistical_no_rag, config_info_no_rag) for k in sorted(individual_seeds_rag.keys()): - print_seed_results(f"Individual Seed Results with RAG k={k}:", individual_seeds_rag[k]) - print_statistical_results(f"Statistical Results with RAG k={k}:", statistical_rag.get(k, {})) + config_info = config_info_rag.get(k) + print_seed_results(f"Individual Seed Results with RAG k={k}:", individual_seeds_rag[k], config_info) + print_statistical_results(f"Statistical Results with RAG k={k}:", statistical_rag.get(k, {}), config_info) print('================================================') diff --git a/ecg_bench/scripts/org_results.sh b/ecg_bench/scripts/org_results.sh index d68d180..4857388 100644 --- a/ecg_bench/scripts/org_results.sh +++ b/ecg_bench/scripts/org_results.sh @@ -2,8 +2,14 @@ # data=("ecg-qa_ptbxl_mapped_1250" "pretrain_mimic_mapped_1250" "ecg_instruct_45k_mapped_1250" "ecg_instruct_pulse_mapped_1250" "ecg-qa_mimic-iv-ecg_mapped_1250") data=("ecg_instruct_45k_mapped_1250") +# retrieval_base="feature" +# retrieved_information="combined" +# rag_k=1 +# rag_prompt_mode="system_prompt" +# normalized_rag_features=True + checkpoints=( - "llama-3.2-3b-instruct_2_1_1024_0.0001_0.9_0.99_1e-08_500_0.01_True_None_None_None_True_combined_report_5_False" + 'llama-3.2-3b-instruct_2_1_1024_0.0001_0.9_0.99_1e-08_500_0.01_True_None_None_None_True_combined_combined_1_system_prompt_True_False' ) for d in "${data[@]}"; do From 6106799bd75e7f87b020342ab1716cec4b076505 Mon Sep 17 00:00:00 2001 From: nbbb24 Date: Mon, 21 Jul 2025 15:16:24 +0000 Subject: [PATCH 11/23] update data_loader by adding self.args.rag --- ecg_bench/scripts/org_results.sh | 5 ++++- ecg_bench/utils/data_loader_utils.py | 4 ++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/ecg_bench/scripts/org_results.sh b/ecg_bench/scripts/org_results.sh index 4857388..146bd08 100644 --- a/ecg_bench/scripts/org_results.sh +++ b/ecg_bench/scripts/org_results.sh @@ -9,7 +9,10 @@ data=("ecg_instruct_45k_mapped_1250") # normalized_rag_features=True checkpoints=( - 'llama-3.2-3b-instruct_2_1_1024_0.0001_0.9_0.99_1e-08_500_0.01_True_None_None_None_True_combined_combined_1_system_prompt_True_False' + 'llama-3.2-3b-instruct_2_1_1024_0.0001_0.9_0.99_1e-08_500_0.01_True_None_None_None_True_feature_combined_1_system_prompt_True_False' + 'llama-3.2-3b-instruct_2_1_1024_0.0001_0.9_0.99_1e-08_500_0.01_True_None_None_None_True_feature_combined_1_system_prompt_None_False' + 'llama-3.2-3b-instruct_2_1_1024_0.0001_0.9_0.99_1e-08_500_0.01_True_None_None_None_True_signal_combined_1_system_prompt_True_False' + 'llama-3.2-3b-instruct_2_1_1024_0.0001_0.9_0.99_1e-08_500_0.01_True_None_None_None_True_signal_combined_1_system_prompt_None_False' ) for d in "${data[@]}"; do diff --git a/ecg_bench/utils/data_loader_utils.py b/ecg_bench/utils/data_loader_utils.py index ab788f6..3428ca3 100644 --- a/ecg_bench/utils/data_loader_utils.py +++ b/ecg_bench/utils/data_loader_utils.py @@ -107,7 +107,7 @@ def setup_conversation_template(self, signal = None): elif 'gemma' in self.args.model: conv = get_conv_template('gemma') feature=None - if self.args.retrieval_base in ['feature', 'combined']: + if self.args.rag and self.args.retrieval_base in ['feature', 'combined']: original_feature=self.rag_db.feature_extractor.extract_rag_features(signal) feature=original_feature if self.args.normalized_rag_feature: @@ -146,7 +146,7 @@ def append_messages_to_conv(self, conv, altered_text, signal=None): message_value = message_value.replace('', '') message_value = message_value.replace('', '') message_value = message_value.replace('image', 'signal').replace('Image', 'Signal') - if self.args.retrieval_base in ['feature', 'combined'] or self.args.retrieved_information in ['feature','combined']: + if self.args.rag and (self.args.retrieval_base in ['feature', 'combined'] or self.args.retrieved_information in ['feature','combined']): original_feature=self.rag_db.feature_extractor.extract_rag_features(signal) feature=original_feature if self.args.normalized_rag_feature: From 741114181460b8a1c4b46c15a5840447f0fea858 Mon Sep 17 00:00:00 2001 From: nbbb24 Date: Tue, 22 Jul 2025 14:57:21 +0000 Subject: [PATCH 12/23] update --- ecg_bench/scripts/org_results.sh | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/ecg_bench/scripts/org_results.sh b/ecg_bench/scripts/org_results.sh index 146bd08..a87e390 100644 --- a/ecg_bench/scripts/org_results.sh +++ b/ecg_bench/scripts/org_results.sh @@ -9,10 +9,7 @@ data=("ecg_instruct_45k_mapped_1250") # normalized_rag_features=True checkpoints=( - 'llama-3.2-3b-instruct_2_1_1024_0.0001_0.9_0.99_1e-08_500_0.01_True_None_None_None_True_feature_combined_1_system_prompt_True_False' - 'llama-3.2-3b-instruct_2_1_1024_0.0001_0.9_0.99_1e-08_500_0.01_True_None_None_None_True_feature_combined_1_system_prompt_None_False' - 'llama-3.2-3b-instruct_2_1_1024_0.0001_0.9_0.99_1e-08_500_0.01_True_None_None_None_True_signal_combined_1_system_prompt_True_False' - 'llama-3.2-3b-instruct_2_1_1024_0.0001_0.9_0.99_1e-08_500_0.01_True_None_None_None_True_signal_combined_1_system_prompt_None_False' +'llama-3.2-3b-instruct_2_1_1024_0.0001_0.9_0.99_1e-08_500_0.01_True_None_None_None_None_False' ) for d in "${data[@]}"; do From 4597371afe017f3432a0815af871d75355d81cf9 Mon Sep 17 00:00:00 2001 From: nbbb24 Date: Tue, 22 Jul 2025 15:12:13 +0000 Subject: [PATCH 13/23] update org.py --- ecg_bench/organize_results.py | 32 ++++++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/ecg_bench/organize_results.py b/ecg_bench/organize_results.py index 29991d6..922cfd0 100644 --- a/ecg_bench/organize_results.py +++ b/ecg_bench/organize_results.py @@ -82,15 +82,30 @@ def collect_results(json_files): if info['is_seed']: metrics = process_seed_data(data) if info['rag_used']: - individual_seeds_rag[info['rag_k']][info['seed_num']] = metrics - config_info_rag[info['rag_k']] = info + rag_key = ( + info['rag_k'], + info['retrieval_base'], + info['retrieved_information'], + info['rag_prompt_mode'], + info['normalized_rag_feature'] + ) + individual_seeds_rag[rag_key][info['seed_num']] = metrics + config_info_rag[rag_key] = info else: individual_seeds_no_rag[info['seed_num']] = metrics config_info_no_rag = info else: if info['rag_used']: - statistical_rag[info['rag_k']] = data - config_info_rag[info['rag_k']] = info + rag_key = ( + info['rag_k'], + info['retrieval_base'], + info['retrieved_information'], + info['rag_prompt_mode'], + info['normalized_rag_feature'] + ) + statistical_rag[rag_key] = data + config_info_rag[rag_key] = info + else: statistical_no_rag = data config_info_no_rag = info @@ -147,10 +162,11 @@ def main(): print_seed_results("Individual Seed Results without RAG:", individual_seeds_no_rag, config_info_no_rag) print_statistical_results("Statistical Results without RAG:", statistical_no_rag, config_info_no_rag) - for k in sorted(individual_seeds_rag.keys()): - config_info = config_info_rag.get(k) - print_seed_results(f"Individual Seed Results with RAG k={k}:", individual_seeds_rag[k], config_info) - print_statistical_results(f"Statistical Results with RAG k={k}:", statistical_rag.get(k, {}), config_info) + for rag_key in sorted(individual_seeds_rag.keys()): + config_info = config_info_rag.get(rag_key) + print_seed_results(f"Individual Seed Results with RAG config={rag_key}:", individual_seeds_rag[rag_key], config_info) + print_statistical_results(f"Statistical Results with RAG config={rag_key}:", statistical_rag.get(rag_key, {}), config_info) + print('================================================') From 8fe8d83b29e794e07cf58d79652a2ad3440f4462 Mon Sep 17 00:00:00 2001 From: nbbb24 Date: Thu, 24 Jul 2025 00:55:57 +0000 Subject: [PATCH 14/23] update --- ecg_bench/scripts/org_results.sh | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/ecg_bench/scripts/org_results.sh b/ecg_bench/scripts/org_results.sh index a87e390..7e56b35 100644 --- a/ecg_bench/scripts/org_results.sh +++ b/ecg_bench/scripts/org_results.sh @@ -8,9 +8,8 @@ data=("ecg_instruct_45k_mapped_1250") # rag_prompt_mode="system_prompt" # normalized_rag_features=True -checkpoints=( -'llama-3.2-3b-instruct_2_1_1024_0.0001_0.9_0.99_1e-08_500_0.01_True_None_None_None_None_False' -) +checkpoints='llama-3.2-3b-instruct_2_1_1024_0.0001_0.9_0.99_1e-08_500_0.01_True_None_None_None_True_feature_report_1_system_prompt_None_False' + for d in "${data[@]}"; do if [ "$d" = "ecg_instruct_pulse_mapped_1250" ]; then From 18219f18804927759b1b98bd067c3e04477048cc Mon Sep 17 00:00:00 2001 From: nbbb24 Date: Wed, 30 Jul 2025 12:50:47 +0000 Subject: [PATCH 15/23] rag_k & rag prompt table filled --- ecg_bench/scripts/org_results.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ecg_bench/scripts/org_results.sh b/ecg_bench/scripts/org_results.sh index 7e56b35..36b9dbd 100644 --- a/ecg_bench/scripts/org_results.sh +++ b/ecg_bench/scripts/org_results.sh @@ -8,7 +8,7 @@ data=("ecg_instruct_45k_mapped_1250") # rag_prompt_mode="system_prompt" # normalized_rag_features=True -checkpoints='llama-3.2-3b-instruct_2_1_1024_0.0001_0.9_0.99_1e-08_500_0.01_True_None_None_None_True_feature_report_1_system_prompt_None_False' +checkpoints='llama-3.2-3b-instruct_2_1_1024_0.0001_0.9_0.99_1e-08_500_0.01_True_None_None_None_True_feature_report_10_system_prompt_None_False' for d in "${data[@]}"; do From e0dcb87fc78374a891f409e638f17899177f9c2d Mon Sep 17 00:00:00 2001 From: nbbb24 Date: Mon, 25 Aug 2025 20:26:59 +0000 Subject: [PATCH 16/23] update --- ecg_bench/scripts/org_results.sh | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/ecg_bench/scripts/org_results.sh b/ecg_bench/scripts/org_results.sh index 36b9dbd..d5da1ee 100644 --- a/ecg_bench/scripts/org_results.sh +++ b/ecg_bench/scripts/org_results.sh @@ -2,14 +2,17 @@ # data=("ecg-qa_ptbxl_mapped_1250" "pretrain_mimic_mapped_1250" "ecg_instruct_45k_mapped_1250" "ecg_instruct_pulse_mapped_1250" "ecg-qa_mimic-iv-ecg_mapped_1250") data=("ecg_instruct_45k_mapped_1250") +# data=("ecg-qa_mimic-iv-ecg_mapped_1250") +# data=("ecg-qa_ptbxl_mapped_1250") +# data=("pretrain_mimic_mapped_1250") # retrieval_base="feature" # retrieved_information="combined" # rag_k=1 # rag_prompt_mode="system_prompt" # normalized_rag_features=True -checkpoints='llama-3.2-3b-instruct_2_1_1024_0.0001_0.9_0.99_1e-08_500_0.01_True_None_None_None_True_feature_report_10_system_prompt_None_False' - +# checkpoints='llama-3.2-3b-instruct_2_1_1024_0.0001_0.9_0.99_1e-08_500_0.01_True_None_None_None_None_False' +checkpoints='qwen2.5-3b_2_1_1024_0.0001_0.9_0.99_1e-08_500_0.01_True_None_None_None_None_False' for d in "${data[@]}"; do if [ "$d" = "ecg_instruct_pulse_mapped_1250" ]; then From 3d1de8920274cc7a764ef66e5caa3c5e018a4c20 Mon Sep 17 00:00:00 2001 From: nbbb24 Date: Tue, 2 Sep 2025 02:31:12 +0000 Subject: [PATCH 17/23] update encoder methods and use original data loading logic --- ecg_bench/main.py | 55 ++++++++++++++++++++----------- ecg_bench/scripts/train_1st.sh | 4 +-- ecg_bench/utils/training_utils.py | 11 +++++++ transformers | 2 +- 4 files changed, 50 insertions(+), 22 deletions(-) diff --git a/ecg_bench/main.py b/ecg_bench/main.py index 27fe841..8f6e7cb 100644 --- a/ecg_bench/main.py +++ b/ecg_bench/main.py @@ -12,6 +12,7 @@ import torch.multiprocessing as mp from datasets import load_dataset from huggingface_hub import HfFolder, login +from torch.optim import Adam from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler @@ -295,25 +296,41 @@ def main(rank, world_size): print(f"Total number of parameters: {train_utils.count_parameters(model)}") - if args.train: - optimizer_class = train_utils.get_optimizer_class(args.optimizer) - optimizer = ScheduledOptim( - optimizer_class(filter(lambda x: x.requires_grad, model.parameters()), - betas=(args.beta1, args.beta2), eps=args.eps, lr=args.lr, weight_decay=args.weight_decay), - model_object["model_hidden_size"], args) - train_data = load_dataset(f"willxxy/{args.data}", split=f"fold{args.fold}_train").with_transform(fm.decode_batch) - print(f"Length of Train Data: {len(train_data)}") - elif args.inference: - test_data = load_dataset(f"willxxy/{args.data}", split=f"fold{args.fold}_test").with_transform(fm.decode_batch) - print(f"Length of Test Data: {len(test_data)}") - - if args.train == "first": - data = train_data.select(range(800000)) - elif args.train in ["second", "end2end"]: - data = train_data.select(range(400000)) - elif args.inference in ["second", "end2end"]: - data = test_data.select(range(20000)) - print("Length of Dataset Considered:", len(data)) + # if args.train: + # optimizer_class = train_utils.get_optimizer_class(args.optimizer) + # optimizer = ScheduledOptim( + # optimizer_class(filter(lambda x: x.requires_grad, model.parameters()), + # betas=(args.beta1, args.beta2), eps=args.eps, lr=args.lr, weight_decay=args.weight_decay), + # model_object["model_hidden_size"], args) + # train_data = load_dataset(f"willxxy/{args.data}", split=f"fold{args.fold}_train").with_transform(fm.decode_batch) + # print(f"Length of Train Data: {len(train_data)}") + # elif args.inference: + # test_data = load_dataset(f"willxxy/{args.data}", split=f"fold{args.fold}_test").with_transform(fm.decode_batch) + # print(f"Length of Test Data: {len(test_data)}") + + # if args.train == "first": + # data = train_data.select(range(800000)) + # elif args.train in ["second", "end2end"]: + # data = train_data.select(range(400000)) + # elif args.inference in ["second", "end2end"]: + # data = test_data.select(range(20000)) + # print("Length of Dataset Considered:", len(data)) + + optimizer = ScheduledOptim( + Adam(filter(lambda x: x.requires_grad, model.parameters()), + betas=(args.beta1, args.beta2), eps=args.eps, lr=args.lr, weight_decay=args.weight_decay), + model_object['model_hidden_size'], args) + + json_data_file = fm.open_json(f'./data/{args.data}.json') + train_data, test_data = train_utils.split_dataset(json_data_file) + + if args.train == 'first': + data = train_data[:800000] + elif args.train in ['second', 'end2end']: + data = train_data[:400000] + elif args.inference in ['second', 'end2end']: + data = test_data[:20000] + print('Length of Dataset:', len(data)) if args.train == "first": dataset = FirstStageECGDataset( diff --git a/ecg_bench/scripts/train_1st.sh b/ecg_bench/scripts/train_1st.sh index 665399e..26bfce2 100644 --- a/ecg_bench/scripts/train_1st.sh +++ b/ecg_bench/scripts/train_1st.sh @@ -2,13 +2,13 @@ # models=("stmem" "merl" "mlae" "mtae" "siglip" "clip" "vit") models=("merl") -data=("ecg-qa-mimic-iv-ecg-250-1250") +# data=("ecg-qa-mimic-iv-ecg-250-1250") # data=("ecg_instruct_45k_mapped_1250") ### MULTI GPU for model in "${models[@]}"; do python main.py \ - --data=$data \ + --data=ecg-qa_mimic-iv-ecg_mapped_1250 \ --model=$model \ --device=cuda:0 \ --train=first \ diff --git a/ecg_bench/utils/training_utils.py b/ecg_bench/utils/training_utils.py index 65121ef..42ec931 100644 --- a/ecg_bench/utils/training_utils.py +++ b/ecg_bench/utils/training_utils.py @@ -34,6 +34,17 @@ def __init__(self, args, fm, viz, device, ecg_tokenizer_utils=None): self.args, self.fm, self.viz, self.device = args, fm, viz, device self.ecg_tokenizer_utils = ecg_tokenizer_utils self.cache_dir = "../.huggingface" + + def split_dataset(self, data, train_ratio=0.7): + data = np.array(data) + n_samples = len(data) + indices = np.random.permutation(n_samples) + n_train = int(n_samples * train_ratio) + train_indices = indices[:n_train] + test_indices = indices[n_train:] + train_data = [data[i] for i in train_indices] + test_data = [data[i] for i in test_indices] + return train_data, test_data def save_config(self): args_dict = {k: v for k, v in vars(self.args).items() if not k.startswith("_")} diff --git a/transformers b/transformers index 51f94ea..241c04d 160000 --- a/transformers +++ b/transformers @@ -1 +1 @@ -Subproject commit 51f94ea06d19a6308c61bbb4dc97c40aabd12bad +Subproject commit 241c04d36867259cdf11dbb4e9d9a60f9cb65ebc From b7ce9d975e5029a1ff05625b9448f35168b8b9ad Mon Sep 17 00:00:00 2001 From: nbbb24 Date: Tue, 2 Sep 2025 23:26:02 +0000 Subject: [PATCH 18/23] update data loader with old mapping names --- ecg_bench/scripts/train_1st.sh | 4 +-- ecg_bench/scripts/train_2nd.sh | 45 +++++++++++++++------------- ecg_bench/utils/data_loader_utils.py | 9 ++++-- 3 files changed, 32 insertions(+), 26 deletions(-) diff --git a/ecg_bench/scripts/train_1st.sh b/ecg_bench/scripts/train_1st.sh index 26bfce2..a87f15b 100644 --- a/ecg_bench/scripts/train_1st.sh +++ b/ecg_bench/scripts/train_1st.sh @@ -1,7 +1,7 @@ #!/bin/bash # models=("stmem" "merl" "mlae" "mtae" "siglip" "clip" "vit") -models=("merl") +models=("siglip") # data=("ecg-qa-mimic-iv-ecg-250-1250") # data=("ecg_instruct_45k_mapped_1250") @@ -10,7 +10,7 @@ for model in "${models[@]}"; do python main.py \ --data=ecg-qa_mimic-iv-ecg_mapped_1250 \ --model=$model \ - --device=cuda:0 \ + --device=cuda:4 \ --train=first \ --batch_size=64 \ --seg_len=1250 \ diff --git a/ecg_bench/scripts/train_2nd.sh b/ecg_bench/scripts/train_2nd.sh index cfa31dd..eefd6df 100644 --- a/ecg_bench/scripts/train_2nd.sh +++ b/ecg_bench/scripts/train_2nd.sh @@ -1,9 +1,12 @@ #!/usr/bin/env bash # ------------------- CONFIGURABLE LISTS ------------------- -encoders=("stmem" "merl" "mlae" "mtae" "siglip" "clip" "vit") -encoders_checkpoints=("stmem_256_50_0.0001_0.9_0.99_1e-08_500_0.01_True_None_None_None_None" "merl_256_50_0.0001_0.9_0.99_1e-08_500_0.01_True_None_None_None_None" "mlae_256_50_0.0001_0.9_0.99_1e-08_500_0.01_True_None_None_None_None" "mtae_256_50_0.0001_0.9_0.99_1e-08_500_0.01_True_None_None_None_None") -llms=("gemma-2-2b-it" "llama-3.2-1b-instruct" "qwen2.5-1.5b-instruct") -datasets=("ecg-qa_ptbxl-250-1250" "ecg-qa-mimic-iv-ecg-250-1250" "ecg-instruct-45k-250-1250" "ecg-instruct-pulse-250-1250" "pretrain-mimic-250-1250") # add more datasets here +# encoders=("stmem" "merl" "mlae" "mtae" "siglip" "clip" "vit") +encoders=("merl") +encoders_checkpoints=("merl_adam_64_50_1024_0.0001_0.9_0.99_1e-08_500_0.01_True_None_None_None_None_1_None_None_False") +# llms=("gemma-2-2b-it" "llama-3.2-1b-instruct" "qwen2.5-1.5b-instruct") +llms=("llama-3.2-1b-instruct") +# datasets=("ecg-qa_ptbxl-250-1250" "ecg-qa-mimic-iv-ecg-250-1250" "ecg-instruct-45k-250-1250" "ecg-instruct-pulse-250-1250" "pretrain-mimic-250-1250") # add more datasets here +datasets=("ecg_instruct_45k_mapped_1250") # ---------------------------------------------------------- for data in "${datasets[@]}"; do @@ -26,7 +29,7 @@ for data in "${datasets[@]}"; do python main.py \ --data="$data" \ --model="${encoder}_${llm}" \ - --device=cuda:7 \ + --device=cuda:3 \ --train=second \ --batch_size=2 \ --seg_len=1250 \ @@ -37,25 +40,25 @@ for data in "${datasets[@]}"; do --attn_implementation=flash_attention_2 \ --system_prompt=./data/system_prompt_e2e.txt \ $([ -n "$checkpoint_path" ] && echo "--encoder_checkpoint=$checkpoint_path") \ - --dev + --log done done done -models=("vit" "clip" "siglip" ) +# models=("merl") -for model in "${models[@]}"; do - python main.py \ - --data=ecg-qa_mimic-iv-ecg_mapped_1250 \ - --model=$model \ - --device=cuda:6 \ - --train=first \ - --batch_size=8 \ - --seg_len=1250 \ - --epochs=2 \ - --instance_normalize \ - --attn_implementation=flash_attention_2 \ - --image \ - --log -done \ No newline at end of file +# for model in "${models[@]}"; do +# python main.py \ +# --data=ecg-qa_mimic-iv-ecg_mapped_1250 \ +# --model=$model \ +# --device=cuda:6 \ +# --train=first \ +# --batch_size=8 \ +# --seg_len=1250 \ +# --epochs=2 \ +# --instance_normalize \ +# --attn_implementation=flash_attention_2 \ +# --image \ +# --log +# done \ No newline at end of file diff --git a/ecg_bench/utils/data_loader_utils.py b/ecg_bench/utils/data_loader_utils.py index 6da2978..e3a526a 100644 --- a/ecg_bench/utils/data_loader_utils.py +++ b/ecg_bench/utils/data_loader_utils.py @@ -81,9 +81,11 @@ def create_position_ids(self, padded_sequence): return position_ids def get_qa(self, altered_text): - if self.args.data == f"pretrain-mimic-{self.args.target_sf}-{self.args.seg_len}": + # if self.args.data == f"pretrain-mimic-{self.args.target_sf}-{self.args.seg_len}": + if self.args.data == f"pretrain_mimic_mapped_{self.args.seg_len}": question, answer = altered_text[0]["value"].replace("\n", "").replace("", ""), altered_text[1]["value"] - elif self.args.data in [f"ecg-qa-mimic-iv-ecg-{self.args.target_sf}-{self.args.seg_len}", f"ecg-qa-ptbxl-{self.args.target_sf}-{self.args.seg_len}"]: + # elif self.args.data in [f"ecg-qa-mimic-iv-ecg-{self.args.target_sf}-{self.args.seg_len}", f"ecg-qa-ptbxl-{self.args.target_sf}-{self.args.seg_len}"]: + elif self.args.data in [f"ecg-qa_mimic-iv-ecg_mapped_{self.args.seg_len}", f"ecg-qa_ptbxl_mapped_{self.args.seg_len}"]: question_type, question, answer = altered_text[0], altered_text[1], altered_text[2] answer = " ".join(answer) if isinstance(answer, list) else answer return question, answer @@ -128,7 +130,8 @@ def setup_conversation_template(self, signal = None): return conv def process_altered_text(self, altered_text): - if self.args.data not in [f"ecg-instruct-45k-{self.args.target_sf}-{self.args.seg_len}", + if self.args.data not in [#f"ecg-instruct-45k-{self.args.target_sf}-{self.args.seg_len}", + f"ecg_instruct_45k_mapped_{self.args.seg_len}", f"ecg-instruct-pulse-{self.args.target_sf}-{self.args.seg_len}", f"ecg-bench-pulse-{self.args.target_sf}-{self.args.seg_len}"]: question, answer = self.get_qa(altered_text) From 98d8e4ff363932f74a48beaeeb0d77f3a8f4e946 Mon Sep 17 00:00:00 2001 From: nbbb24 Date: Fri, 5 Sep 2025 04:36:08 +0000 Subject: [PATCH 19/23] modify scripts for merl & siglip --- ecg_bench/scripts/train_1st.sh | 2 +- ecg_bench/scripts/train_2nd.sh | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/ecg_bench/scripts/train_1st.sh b/ecg_bench/scripts/train_1st.sh index a87f15b..2e91b7c 100644 --- a/ecg_bench/scripts/train_1st.sh +++ b/ecg_bench/scripts/train_1st.sh @@ -1,7 +1,7 @@ #!/bin/bash # models=("stmem" "merl" "mlae" "mtae" "siglip" "clip" "vit") -models=("siglip") +models=("merl") # data=("ecg-qa-mimic-iv-ecg-250-1250") # data=("ecg_instruct_45k_mapped_1250") diff --git a/ecg_bench/scripts/train_2nd.sh b/ecg_bench/scripts/train_2nd.sh index eefd6df..18d7687 100644 --- a/ecg_bench/scripts/train_2nd.sh +++ b/ecg_bench/scripts/train_2nd.sh @@ -6,7 +6,9 @@ encoders_checkpoints=("merl_adam_64_50_1024_0.0001_0.9_0.99_1e-08_500_0.01_True_ # llms=("gemma-2-2b-it" "llama-3.2-1b-instruct" "qwen2.5-1.5b-instruct") llms=("llama-3.2-1b-instruct") # datasets=("ecg-qa_ptbxl-250-1250" "ecg-qa-mimic-iv-ecg-250-1250" "ecg-instruct-45k-250-1250" "ecg-instruct-pulse-250-1250" "pretrain-mimic-250-1250") # add more datasets here -datasets=("ecg_instruct_45k_mapped_1250") +# datasets=("ecg_instruct_45k_mapped_1250") +datasets=("ecg-qa_mimic-iv-ecg_mapped_1250") +# datasets=("ecg-qa_mimic-iv-ecg_mapped_1250" "ecg-qa_ptbxl_mapped_1250") # ---------------------------------------------------------- for data in "${datasets[@]}"; do From a14c2b66b7fa4b6630235ff338bc7feb9b7b00a3 Mon Sep 17 00:00:00 2001 From: nbbb24 Date: Sat, 13 Sep 2025 03:22:12 +0000 Subject: [PATCH 20/23] update --- ecg_bench/runners/inference.py | 2 +- ecg_bench/scripts/org_results.sh | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ecg_bench/runners/inference.py b/ecg_bench/runners/inference.py index 60db82e..01987db 100644 --- a/ecg_bench/runners/inference.py +++ b/ecg_bench/runners/inference.py @@ -27,7 +27,7 @@ def tester_chat(model, dataloader, tokenizer, args, train_utils): signal_id_index = batch["signal_id_index"].item() offset = 0 for conv_turn in assistant_ranges: - print("conv_turn", conv_turn) + # print("conv_turn", conv_turn) start = conv_turn["start"] + 4 + offset end = conv_turn["end"] + 1 + offset curr_input_ids = chat_input_ids[:, :start] diff --git a/ecg_bench/scripts/org_results.sh b/ecg_bench/scripts/org_results.sh index 3d866f6..e14e0ec 100644 --- a/ecg_bench/scripts/org_results.sh +++ b/ecg_bench/scripts/org_results.sh @@ -11,7 +11,7 @@ data=("ecg-qa_ptbxl_mapped_1250") # rag_prompt_mode="system_prompt" # normalized_rag_features=True -checkpoints='llama-3.2-3b-instruct_2_1_1024_0.0001_0.9_0.99_1e-08_500_0.01_True_None_None_None_True_feature_report_1_system_prompt_None_False' +checkpoints='siglip_llama-3.2-1b-instruct_adam_2_1_1024_0.0001_0.9_0.99_1e-08_500_0.01_True_True_None_None_True_1_None_None_feature_report_1_system_prompt_None_False' # checkpoints='qwen2.5-3b_2_1_1024_0.0001_0.9_0.99_1e-08_500_0.01_True_None_None_None_None_False' for d in "${data[@]}"; do @@ -23,6 +23,6 @@ for d in "${data[@]}"; do for ckpt in "${checkpoints[@]}"; do python organize_results.py \ - --checkpoint=./runs/$d/0/$ckpt/ecgqa-mimic + --checkpoint=./runs/$d/0/$ckpt/ done done \ No newline at end of file From 965da0cbd4a604294b989ee6be7b87edfd77a97a Mon Sep 17 00:00:00 2001 From: nbbb24 Date: Mon, 15 Sep 2025 13:00:00 +0000 Subject: [PATCH 21/23] update before noise experiment --- ecg_bench/scripts/org_results.sh | 4 ++-- ecg_bench/scripts/train_2nd.sh | 17 ----------------- 2 files changed, 2 insertions(+), 19 deletions(-) diff --git a/ecg_bench/scripts/org_results.sh b/ecg_bench/scripts/org_results.sh index e14e0ec..a817cf8 100644 --- a/ecg_bench/scripts/org_results.sh +++ b/ecg_bench/scripts/org_results.sh @@ -2,8 +2,8 @@ # data=("ecg-qa_ptbxl_mapped_1250" "pretrain_mimic_mapped_1250" "ecg_instruct_45k_mapped_1250" "ecg_instruct_pulse_mapped_1250" "ecg-qa_mimic-iv-ecg_mapped_1250") # data=("ecg_instruct_45k_mapped_1250") -# data=("ecg-qa_mimic-iv-ecg_mapped_1250") -data=("ecg-qa_ptbxl_mapped_1250") +data=("ecg-qa_mimic-iv-ecg_mapped_1250") +# data=("ecg-qa_ptbxl_mapped_1250") # data=("pretrain_mimic_mapped_1250") # retrieval_base="feature" # retrieved_information="combined" diff --git a/ecg_bench/scripts/train_2nd.sh b/ecg_bench/scripts/train_2nd.sh index 18d7687..a757bba 100644 --- a/ecg_bench/scripts/train_2nd.sh +++ b/ecg_bench/scripts/train_2nd.sh @@ -47,20 +47,3 @@ for data in "${datasets[@]}"; do done done - -# models=("merl") - -# for model in "${models[@]}"; do -# python main.py \ -# --data=ecg-qa_mimic-iv-ecg_mapped_1250 \ -# --model=$model \ -# --device=cuda:6 \ -# --train=first \ -# --batch_size=8 \ -# --seg_len=1250 \ -# --epochs=2 \ -# --instance_normalize \ -# --attn_implementation=flash_attention_2 \ -# --image \ -# --log -# done \ No newline at end of file From 5b7a2ed72369727408759bb2b5bb6789aca34195 Mon Sep 17 00:00:00 2001 From: nbbb24 Date: Tue, 16 Sep 2025 04:15:50 +0000 Subject: [PATCH 22/23] update latest --- ecg_bench/scripts/org_results.sh | 6 +++--- ecg_bench/utils/rag_utils.py | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/ecg_bench/scripts/org_results.sh b/ecg_bench/scripts/org_results.sh index a817cf8..b2811f1 100644 --- a/ecg_bench/scripts/org_results.sh +++ b/ecg_bench/scripts/org_results.sh @@ -1,8 +1,8 @@ #!/bin/bash # data=("ecg-qa_ptbxl_mapped_1250" "pretrain_mimic_mapped_1250" "ecg_instruct_45k_mapped_1250" "ecg_instruct_pulse_mapped_1250" "ecg-qa_mimic-iv-ecg_mapped_1250") -# data=("ecg_instruct_45k_mapped_1250") -data=("ecg-qa_mimic-iv-ecg_mapped_1250") +data=("ecg_instruct_45k_mapped_1250") +# data=("ecg-qa_mimic-iv-ecg_mapped_1250") # data=("ecg-qa_ptbxl_mapped_1250") # data=("pretrain_mimic_mapped_1250") # retrieval_base="feature" @@ -11,7 +11,7 @@ data=("ecg-qa_mimic-iv-ecg_mapped_1250") # rag_prompt_mode="system_prompt" # normalized_rag_features=True -checkpoints='siglip_llama-3.2-1b-instruct_adam_2_1_1024_0.0001_0.9_0.99_1e-08_500_0.01_True_True_None_None_True_1_None_None_feature_report_1_system_prompt_None_False' +checkpoints='llama-3.2-1b-instruct_adam_2_1_1024_0.0001_0.9_0.99_1e-08_500_0.01_True_None_None_None_True_1_None_None_feature_report_1_system_prompt_None_False' # checkpoints='qwen2.5-3b_2_1_1024_0.0001_0.9_0.99_1e-08_500_0.01_True_None_None_None_None_False' for d in "${data[@]}"; do diff --git a/ecg_bench/utils/rag_utils.py b/ecg_bench/utils/rag_utils.py index 1673b7a..9c20856 100644 --- a/ecg_bench/utils/rag_utils.py +++ b/ecg_bench/utils/rag_utils.py @@ -369,7 +369,8 @@ def format_search(self, results, retrieved_information='combined'): # Include diagnosis information based on retrieved_information if retrieved_information in ["report", "combined"]: output += "Diagnosis Information:\n" - output += f"{res['report']}\n\n" + output +="--------------------------" + # output += f"{res['report']}\n\n" return output def convert_features_to_structured(self, feature_array): From 10c9b5fd026545a294816d2d69d2c79bb0be8c5d Mon Sep 17 00:00:00 2001 From: nbbb24 Date: Wed, 24 Sep 2025 20:44:24 +0000 Subject: [PATCH 23/23] clean up rag --- README.md | 5 +++-- ecg_bench/utils/rag_utils.py | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 0840ff4..35e38f2 100644 --- a/README.md +++ b/README.md @@ -841,8 +841,9 @@ This is a list of TODOs for the repository. If you are interested in contributin - [x] Add option for data mixing. - [x] Adjust feature selection for RAG. - [x] Apply normalization to RAG database. -- [x] Cross Dataset Ablation. -- [x] Add "Only feature" retrieval option for RAG +- [ ] Cross Dataset Ablation. +- [ ] Apply RAG with encoder methods +- [ ] Add "Only feature" retrieval option for RAG - [x] For preprocessing, stratify based on patient, such that no overlapping patients between train and test. - [x] Add official splits for benchmarking. - [x] Upload to huggingface datasets and use huggingface datasets data loading in main. diff --git a/ecg_bench/utils/rag_utils.py b/ecg_bench/utils/rag_utils.py index 9c20856..59c1036 100644 --- a/ecg_bench/utils/rag_utils.py +++ b/ecg_bench/utils/rag_utils.py @@ -369,8 +369,8 @@ def format_search(self, results, retrieved_information='combined'): # Include diagnosis information based on retrieved_information if retrieved_information in ["report", "combined"]: output += "Diagnosis Information:\n" - output +="--------------------------" - # output += f"{res['report']}\n\n" + # output +="--------------------------" + output += f"{res['report']}\n\n" return output def convert_features_to_structured(self, feature_array):