110 changes: 110 additions & 0 deletions include/graph/preprocessing.h
@@ -19,6 +19,7 @@ class LabelIndex {
virtual void Insert(NodeId node, std::vector<Label> label) = 0;
virtual void Load(const std::string& path) = 0;
virtual double PrecisionScore(torch::Tensor y_pred, torch::Tensor y_true) const = 0;
virtual double F1Score(torch::Tensor y_pred, torch::Tensor y_true, const std::string& option) const = 0;
};

class MultiLabelBinarizer: public LabelIndex {
@@ -77,6 +78,10 @@ class MultiLabelBinarizer: public LabelIndex {

}

virtual double F1Score(torch::Tensor y_pred, torch::Tensor y_true, const std::string& option) const override {
// Stub so the binarizer satisfies the LabelIndex interface; MultiLabelIndex provides the real implementation.
return 0.0;
}

protected:
std::vector<std::vector<Label>> labels_;
IndexLookupTable index_;
@@ -125,6 +130,59 @@ class MultiLabelIndex : public MultiLabelBinarizer {
return n_right * 1.0 / n_total;
}

virtual double F1Score(torch::Tensor y_pred, torch::Tensor y_true, const std::string& option) const override {
int64_t n_class = y_pred.size(1);
// Predicted class id per sample: argmax over the class dimension.
auto y = y_pred.argmax(1).to(torch::TensorOptions().dtype(torch::kInt64));
int64_t true_negatives = 0;
int64_t true_positives = 0;
int64_t false_negatives = 0;
int64_t false_positives = 0;
double precision = 0;
double recall = 0;
double average_f1score = 0;
for (int64_t i = 0; i < n_class; i++)
{
// Binary masks for class i: 1 where the ground truth is / is not class i.
auto this_class_mask = torch::full(y.sizes(), i, torch::TensorOptions().dtype(torch::kInt64));
auto positive_label_mask = y_true.eq(this_class_mask).to(torch::TensorOptions().dtype(torch::kInt64));
auto negative_label_mask = torch::sub(torch::ones_like(positive_label_mask), positive_label_mask);

// Per-class confusion-matrix counts.
int64_t tmp_true_negatives = torch::mul(torch::ne(y, this_class_mask).to(torch::TensorOptions().dtype(torch::kInt64)), negative_label_mask).sum().item().toLong();
true_negatives += tmp_true_negatives; // not used by either average, kept for completeness

int64_t tmp_true_positives = torch::mul(torch::eq(y, this_class_mask).to(torch::TensorOptions().dtype(torch::kInt64)), positive_label_mask).sum().item().toLong();
true_positives += tmp_true_positives;

int64_t tmp_false_negatives = torch::mul(torch::ne(y, this_class_mask).to(torch::TensorOptions().dtype(torch::kInt64)), positive_label_mask).sum().item().toLong();
false_negatives += tmp_false_negatives;

int64_t tmp_false_positives = torch::mul(torch::eq(y, this_class_mask).to(torch::TensorOptions().dtype(torch::kInt64)), negative_label_mask).sum().item().toLong();
false_positives += tmp_false_positives;

if (option == "macro") {
// Macro average: per-class F1, averaged over all classes below.
precision = double(tmp_true_positives) / (tmp_true_positives + tmp_false_positives + 1e-13);
recall = double(tmp_true_positives) / (tmp_true_positives + tmp_false_negatives + 1e-13);
average_f1score += 2 * recall * precision / (recall + precision + 1e-13);
}
}
if (option == "macro") {
return average_f1score / n_class;
}
else {
// Micro average: pool the counts over all classes, then compute a single F1.
precision = double(true_positives) / (true_positives + false_positives + 1e-13);
recall = double(true_positives) / (true_positives + false_negatives + 1e-13);
return 2 * recall * precision / (recall + precision + 1e-13);
}
}

private:
Label current_class_;
};
@@ -174,6 +232,58 @@ class SingleLabelIndex : public LabelIndex {
return n_right * 1.0 / n_total;
}

virtual double F1Score(torch::Tensor y_pred, torch::Tensor y_true, const std::string& option) const override {
int64_t n_class = y_pred.size(1);
// Predicted class id per sample: argmax over the class dimension.
auto y = y_pred.argmax(1).to(torch::TensorOptions().dtype(torch::kInt64));
int64_t true_negatives = 0;
int64_t true_positives = 0;
int64_t false_negatives = 0;
int64_t false_positives = 0;
double precision = 0;
double recall = 0;
double average_f1score = 0;
for (int64_t i = 0; i < n_class; i++)
{
// Binary masks for class i: 1 where the ground truth is / is not class i.
auto this_class_mask = torch::full(y.sizes(), i, torch::TensorOptions().dtype(torch::kInt64));
auto positive_label_mask = y_true.eq(this_class_mask).to(torch::TensorOptions().dtype(torch::kInt64));
auto negative_label_mask = torch::sub(torch::ones_like(positive_label_mask), positive_label_mask);

// Per-class confusion-matrix counts.
int64_t tmp_true_negatives = torch::mul(torch::ne(y, this_class_mask).to(torch::TensorOptions().dtype(torch::kInt64)), negative_label_mask).sum().item().toLong();
true_negatives += tmp_true_negatives; // not used by either average, kept for completeness

int64_t tmp_true_positives = torch::mul(torch::eq(y, this_class_mask).to(torch::TensorOptions().dtype(torch::kInt64)), positive_label_mask).sum().item().toLong();
true_positives += tmp_true_positives;

int64_t tmp_false_negatives = torch::mul(torch::ne(y, this_class_mask).to(torch::TensorOptions().dtype(torch::kInt64)), positive_label_mask).sum().item().toLong();
false_negatives += tmp_false_negatives;

int64_t tmp_false_positives = torch::mul(torch::eq(y, this_class_mask).to(torch::TensorOptions().dtype(torch::kInt64)), negative_label_mask).sum().item().toLong();
false_positives += tmp_false_positives;

if (option == "macro") {
// Macro average: per-class F1, averaged over all classes below.
precision = double(tmp_true_positives) / (tmp_true_positives + tmp_false_positives + 1e-13);
recall = double(tmp_true_positives) / (tmp_true_positives + tmp_false_negatives + 1e-13);
average_f1score += 2 * recall * precision / (recall + precision + 1e-13);
}
}
if (option == "macro") {
return average_f1score / n_class;
}
else {
// Micro average: pool the counts over all classes, then compute a single F1.
precision = double(true_positives) / (true_positives + false_positives + 1e-13);
recall = double(true_positives) / (true_positives + false_negatives + 1e-13);
return 2 * recall * precision / (recall + precision + 1e-13);
}
}

private:
std::unordered_map<NodeId, Label> labels_;
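For context, a minimal usage sketch of the new metric, not part of this diff: it assumes MultiLabelIndex is default-constructible and already populated with labels, and the names logits ([n_samples, n_class] float scores) and targets ([n_samples] int64 class ids) are illustrative only. Any option string other than "macro" falls through to the micro-averaged branch.

// Hypothetical example, not part of this PR.
torch::Tensor logits = torch::rand({8, 4});                 // scores for 8 samples over 4 classes
torch::Tensor targets = torch::randint(0, 4, {8}, torch::TensorOptions().dtype(torch::kInt64));
MultiLabelIndex index;                                      // assumes a usable default constructor
double macro_f1 = index.F1Score(logits, targets, "macro");  // mean of per-class F1 scores
double micro_f1 = index.F1Score(logits, targets, "micro");  // F1 over pooled counts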