diff --git a/SCLUDA_Houston.py b/SCLUDA_Houston.py
index 008eac1..ec91a3b 100644
--- a/SCLUDA_Houston.py
+++ b/SCLUDA_Houston.py
@@ -67,20 +67,13 @@ def clean_sampling_epoch(labels, probabilities, output):
         sorted_index_method='normalized_margin',
     )
 
-    true_labels_idx = [] # 置信样本的索引
-    all_labels_idx = [] # 所有标签的索引
+    # Keep every index not flagged as a label error; set lookup avoids the old repeated list.remove() scans.
+    all_labels_idx = np.arange(len(labels))
+    error_set = set(ordered_label_errors)
+    true_labels_idx = [i for i in all_labels_idx if i not in error_set]
+
     print('len of all_lables', len(labels))
     print('len of errors_lables', len(ordered_label_errors))
-    for i in range(len(labels)):
-        all_labels_idx.append(i)
-
-    if len(ordered_label_errors) == 0:
-        true_labels_idx = all_labels_idx
-    else:
-
-        for j in range(len(ordered_label_errors)):
-            all_labels_idx.remove(ordered_label_errors[j])
-        true_labels_idx = all_labels_idx
     print('len of true_lables', len(true_labels_idx)) # 得到正确样本的索引
     orig_class_count = np.bincount(labels, minlength=CLASS_NUM)
     train_bool_mask = ~label_errors_bool