From 5193ca6d3b67aa129673f07256178d82b8db82f5 Mon Sep 17 00:00:00 2001 From: bing199 <71914138+bing199@users.noreply.github.com> Date: Sun, 4 Jan 2026 17:02:19 +0800 Subject: [PATCH] Update SCLUDA_Houston.py 1 --- SCLUDA_Houston.py | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/SCLUDA_Houston.py b/SCLUDA_Houston.py index 008eac1..ec91a3b 100644 --- a/SCLUDA_Houston.py +++ b/SCLUDA_Houston.py @@ -67,20 +67,13 @@ def clean_sampling_epoch(labels, probabilities, output): sorted_index_method='normalized_margin', ) - true_labels_idx = [] # 置信样本的索引 - all_labels_idx = [] # 所有标签的索引 + # 亲测高效,老的版本不知道崩了多少次^_^ + all_labels_idx = np.arange(len(labels)) + error_set = set(ordered_label_errors) + true_labels_idx = [i for i in all_indices if i not in error_set] + print('len of all_lables', len(labels)) print('len of errors_lables', len(ordered_label_errors)) - for i in range(len(labels)): - all_labels_idx.append(i) - - if len(ordered_label_errors) == 0: - true_labels_idx = all_labels_idx - else: - - for j in range(len(ordered_label_errors)): - all_labels_idx.remove(ordered_label_errors[j]) - true_labels_idx = all_labels_idx print('len of true_lables', len(true_labels_idx)) # 得到正确样本的索引 orig_class_count = np.bincount(labels, minlength=CLASS_NUM) train_bool_mask = ~label_errors_bool