From 4a991f960f53ff0d0f7eb07b016e161de195b550 Mon Sep 17 00:00:00 2001 From: bing199 <71914138+bing199@users.noreply.github.com> Date: Sun, 4 Jan 2026 16:47:49 +0800 Subject: [PATCH] Update SCLUDA_SH2HZ.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 使用set,以及pythonic --- SCLUDA_SH2HZ.py | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/SCLUDA_SH2HZ.py b/SCLUDA_SH2HZ.py index f8cd229..75174f4 100644 --- a/SCLUDA_SH2HZ.py +++ b/SCLUDA_SH2HZ.py @@ -67,20 +67,13 @@ def clean_sampling_epoch(labels, probabilities, output): sorted_index_method='normalized_margin', ) - true_labels_idx = [] # 置信样本的索引 - all_labels_idx = [] # 所有标签的索引 + # 亲测高效,老的版本不知道崩了多少次^_^ + all_labels_idx = np.arange(len(labels)) + error_set = set(ordered_label_errors) + true_labels_idx = [i for i in all_indices if i not in error_set] + print('len of all_lables', len(labels)) print('len of errors_lables', len(ordered_label_errors)) - for i in range(len(labels)): - all_labels_idx.append(i) - - if len(ordered_label_errors) == 0: - true_labels_idx = all_labels_idx - else: - - for j in range(len(ordered_label_errors)): - all_labels_idx.remove(ordered_label_errors[j]) - true_labels_idx = all_labels_idx print('len of true_lables', len(true_labels_idx)) # 得到正确样本的索引 # weights orig_class_count = np.bincount(labels, minlength=CLASS_NUM)