Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 5 additions & 12 deletions SCLUDA_Houston.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,20 +67,13 @@ def clean_sampling_epoch(labels, probabilities, output):
sorted_index_method='normalized_margin',
)

true_labels_idx = [] # 置信样本的索引
all_labels_idx = [] # 所有标签的索引
# 亲测高效,老的版本不知道崩了多少次^_^
all_labels_idx = np.arange(len(labels))
error_set = set(ordered_label_errors)
true_labels_idx = [i for i in all_indices if i not in error_set]

print('len of all_lables', len(labels))
print('len of errors_lables', len(ordered_label_errors))
for i in range(len(labels)):
all_labels_idx.append(i)

if len(ordered_label_errors) == 0:
true_labels_idx = all_labels_idx
else:

for j in range(len(ordered_label_errors)):
all_labels_idx.remove(ordered_label_errors[j])
true_labels_idx = all_labels_idx
print('len of true_lables', len(true_labels_idx)) # 得到正确样本的索引
orig_class_count = np.bincount(labels, minlength=CLASS_NUM)
train_bool_mask = ~label_errors_bool
Expand Down