diff --git a/SCLUDA_UP2PC.py b/SCLUDA_UP2PC.py index 491b001..5ebbead 100644 --- a/SCLUDA_UP2PC.py +++ b/SCLUDA_UP2PC.py @@ -74,18 +74,13 @@ def clean_sampling_epoch(labels, probabilities, output): sorted_index_method='normalized_margin', ) - true_labels_idx = [] # 置信样本的索引 - all_labels_idx = [] # 所有标签的索引 + # 亲测高效,老的版本不知道崩了多少次^_^ + all_labels_idx = np.arange(len(labels)) + error_set = set(ordered_label_errors) + true_labels_idx = [i for i in all_indices if i not in error_set] + print('len of all_lables', len(labels)) print('len of errors_lables', len(ordered_label_errors)) - for i in range(len(labels)): - all_labels_idx.append(i) - if len(ordered_label_errors) == 0: - true_labels_idx = all_labels_idx - else: - for j in range(len(ordered_label_errors)): - all_labels_idx.remove(ordered_label_errors[j]) - true_labels_idx = all_labels_idx print('len of true_lables', len(true_labels_idx)) # weights orig_class_count = np.bincount(labels, minlength=CLASS_NUM)