From de4e6cb1c293f89b51f730cc78678e5033b50f78 Mon Sep 17 00:00:00 2001 From: Yasel Date: Wed, 26 Apr 2023 11:29:48 +0200 Subject: [PATCH 1/5] Move DefineNewParameters and CheckParameters from main.cpp to a new file --- CMakeLists.txt | 1 + LinuxRelease/code/MurTree/Utilities/subdir.mk | 9 +- MurTree.vcxproj | 2 + MurTree.vcxproj.filters | 6 + code/MurTree/Utilities/parameters.cpp | 215 +++++++++++++++++ code/MurTree/Utilities/parameters.h | 9 + main.cpp | 221 +----------------- 7 files changed, 242 insertions(+), 221 deletions(-) create mode 100644 code/MurTree/Utilities/parameters.cpp create mode 100644 code/MurTree/Utilities/parameters.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 4dab451..6bb520c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -70,6 +70,7 @@ set(ENGINE_SRC set(UTILITIES_SRC code/MurTree/Utilities/file_reader.cpp code/MurTree/Utilities/parameter_handler.cpp + code/MurTree/Utilities/parameters.cpp ) # MurTree can be built as a: diff --git a/LinuxRelease/code/MurTree/Utilities/subdir.mk b/LinuxRelease/code/MurTree/Utilities/subdir.mk index b24e871..2e8abfc 100644 --- a/LinuxRelease/code/MurTree/Utilities/subdir.mk +++ b/LinuxRelease/code/MurTree/Utilities/subdir.mk @@ -5,15 +5,18 @@ # Add inputs and outputs from these tool invocations to the build variables CPP_SRCS += \ ../code/MurTree/Utilities/file_reader.cpp \ -../code/MurTree/Utilities/parameter_handler.cpp +../code/MurTree/Utilities/parameter_handler.cpp \ +../code/MurTree/Utilities/parameters.cpp OBJS += \ ./code/MurTree/Utilities/file_reader.o \ -./code/MurTree/Utilities/parameter_handler.o +./code/MurTree/Utilities/parameter_handler.o \ +./code/MurTree/Utilities/parameters.o CPP_DEPS += \ ./code/MurTree/Utilities/file_reader.d \ -./code/MurTree/Utilities/parameter_handler.d +./code/MurTree/Utilities/parameter_handler.d \ +./code/MurTree/Utilities/parameter.d # Each subdirectory must supply rules for building sources it contributes diff --git a/MurTree.vcxproj b/MurTree.vcxproj 
index 98c7d63..5b2d193 100644 --- a/MurTree.vcxproj +++ b/MurTree.vcxproj @@ -41,6 +41,7 @@ + @@ -76,6 +77,7 @@ + diff --git a/MurTree.vcxproj.filters b/MurTree.vcxproj.filters index 9e67b73..d49cbe9 100644 --- a/MurTree.vcxproj.filters +++ b/MurTree.vcxproj.filters @@ -68,6 +68,9 @@ Utilities + + Utilities + @@ -177,6 +180,9 @@ Utilities + + Utilities + Utilities diff --git a/code/MurTree/Utilities/parameters.cpp b/code/MurTree/Utilities/parameters.cpp new file mode 100644 index 0000000..c0727cc --- /dev/null +++ b/code/MurTree/Utilities/parameters.cpp @@ -0,0 +1,215 @@ +#include "parameters.h" + +namespace MurTree { + + ParameterHandler DefineParameters() { + + ParameterHandler parameters; + parameters.DefineNewCategory("Main Parameters"); + parameters.DefineNewCategory("Algorithmic Parameters"); + parameters.DefineNewCategory("Tuning Parameters"); + + parameters.DefineStringParameter + ( + "file", + "Location to the dataset.", + "", //default value + "Main Parameters" + ); + + parameters.DefineFloatParameter + ( + "time", + "Maximum runtime given in seconds.", + 600, //default value + "Main Parameters", + 0, //min value + INT32_MAX //max value + ); + + parameters.DefineIntegerParameter + ( + "max-depth", + "Maximum allowed depth of the tree, where the depth is defined as the largest number of *decision/feature nodes* from the root to any leaf. Depth greater than four is usually time consuming.", + 3, //default value + "Main Parameters", + 0, //min value + 20 //max value + ); + + parameters.DefineIntegerParameter + ( + "max-num-nodes", + "Maximum number of *decision/feature nodes* allowed. Note that a tree with k feature nodes has k+1 leaf nodes.", + 7, //default value + "Main Parameters", + 0, + INT32_MAX + ); + + parameters.DefineFloatParameter + ( + "sparse-coefficient", + "Assigns the penalty for using decision/feature nodes. 
Large sparse coefficients will result in smaller trees.", + 0.0, + "Main Parameters", + 0.0, + 1.0 + ); + + parameters.DefineBooleanParameter + ( + "verbose", + "Determines if the solver should print logging information to the standard output.", + true, + "Main Parameters" + ); + + parameters.DefineBooleanParameter + ( + "all-trees", + "Instructs the algorithm to compute trees using all allowed combinations of max-depth and max-num-nodes. Used to stress-test the algorithm.", + false, + "Main Parameters" + ); + + parameters.DefineStringParameter + ( + "result-file", + "The results of the algorithm are printed in the provided file, using for simple benchmarking. The output file contains the runtime, misclassification score, and number of cache entries. Leave blank to avoid printing.", + "", //default value + "Main Parameters" + ); + + //Internal algorithmic parameters----------- + + parameters.DefineBooleanParameter + ( + "incremental-frequency", + "Activate incremental frequency computation, which takes into account previously computed trees when recomputing the frequency. In our experiments proved to be effective on all datasets.", + true, + "Algorithmic Parameters" + ); + + parameters.DefineBooleanParameter + ( + "similarity-lower-bound", + "Activate similarity-based lower bounding. 
Disabling this option may be better for some benchmarks, but on most of our tested datasets keeping this on was beneficial.", + true, + "Algorithmic Parameters" + ); + + parameters.DefineStringParameter + ( + "node-selection", + "Node selection strategy used to decide on whether the algorithm should examine the left or right child node first.", + "dynamic", //default value + "Algorithmic Parameters", + { "dynamic", "post-order" } + ); + + parameters.DefineStringParameter + ( + "feature-ordering", + "Feature ordering strategy used to determine the order in which features will be inspected in each node.", + "in-order", //default value + "Algorithmic Parameters", + { "in-order", "random", "gini" } + ); + + parameters.DefineIntegerParameter + ( + "random-seed", + "Random seed used only if the feature-ordering is set to random. A seed of -1 assings the seed based on the current time.", + 3, + "Algorithmic Parameters", + -1, + INT32_MAX + ); + + parameters.DefineStringParameter + ( + "cache-type", + "Cache type used to store computed subtrees. \"Dataset\" is more powerful than \"branch\" but may required more computational time. Need to be determined experimentally. \"Closure\" is experimental and typically slower than other options.", + "dataset", //default value + "Algorithmic Parameters", + { "branch", "dataset", "closure" } + ); + + parameters.DefineIntegerParameter + ( + "duplicate-factor", + "Duplicates the instances the given amount of times. Used for stress-testing the algorithm, not a practical parameter.", + 1, + "Algorithmic Parameters", + 1, + INT32_MAX + ); + + parameters.DefineIntegerParameter + ( + "upper-bound", + "Initial upper bound.", + INT32_MAX, //default value + "Algorithmic Parameters", + 0, + INT32_MAX + ); + + //Tuning parameters + + parameters.DefineBooleanParameter + ( + "hyper-parameter-tuning", + "Activate hyper-parameter tuning using max-depth and max-num-nodes as the maximum values allowed. 
The splits need to be provided in the appropriate folder...see the code. todo", + false, + "Tuning Parameters" + ); + + parameters.DefineStringParameter + ( + "splits-location-prefix", + "Prefix to where the splits may be found. Used in combination with hyper-parameter-tuning. The splits need to be provided in the appropriate folder...see the code. todo", + "", //default value + "Tuning Parameters" + ); + + parameters.DefineStringParameter + ( + "hyper-parameter-stats-file", + "Location of the output file that contains information about the hyper-parameter procedure.", + "", //default value + "Tuning Parameters" + ); + + return parameters; + } + + void CheckParameters(ParameterHandler& parameters) { + if (parameters.GetStringParameter("file") == "") { + std::cout << "Error: No file given!\n"; + exit(1); + } + + if (parameters.GetIntegerParameter("max-depth") > parameters.GetIntegerParameter("max-num-nodes")) { + std::cout << "Error: The depth parameter is greater than the number of nodes!\n"; + exit(1); + } + + if (parameters.GetIntegerParameter("max-num-nodes") > (uint32_t(1) << parameters.GetIntegerParameter("max-depth")) - 1) { + std::cout << "Error: The number of nodes exceeds the limit imposed by the depth!\n"; + exit(1); + } + + if (parameters.GetBooleanParameter("hyper-parameter-tuning") && parameters.GetStringParameter("splits-location-prefix") == "") { + std::cout << "Error: hyper tuning specified but no splits given\n"; + exit(1); + } + + if (parameters.GetBooleanParameter("hyper-parameter-tuning") && parameters.GetStringParameter("hyper-parameter-stats-file") == "") { + std::cout << "Error: hyper tuning specified but no output file location given\n"; + exit(1); + } + } + +} \ No newline at end of file diff --git a/code/MurTree/Utilities/parameters.h b/code/MurTree/Utilities/parameters.h new file mode 100644 index 0000000..b1b2e77 --- /dev/null +++ b/code/MurTree/Utilities/parameters.h @@ -0,0 +1,9 @@ +#include "parameter_handler.h" + +namespace MurTree { 
+ + ParameterHandler DefineParameters(); + + void CheckParameters(ParameterHandler& parameters); + +} \ No newline at end of file diff --git a/main.cpp b/main.cpp index 9f87db7..3e3f0d1 100644 --- a/main.cpp +++ b/main.cpp @@ -8,6 +8,7 @@ //Authors: Emir Demirović, Anna Lukina, Emmanuel Hebrard, Jeffrey Chan, James Bailey, Christopher Leckie, Kotagiri Ramamohanarao, Peter J. Stuckey //For any issues related to the code, please feel free to contact Dr Emir Demirović, e.demirovic@tudelft.nl +#include "code/MurTree/Utilities/parameters.h" #include "code/MurTree/Engine/solver.h" #include "code/MurTree/Engine/dataset_cache.h" #include "code/MurTree/Engine/hyper_parameter_tuner.h" @@ -21,225 +22,9 @@ #include #include -MurTree::ParameterHandler DefineParameters() -{ - MurTree::ParameterHandler parameters; - - parameters.DefineNewCategory("Main Parameters"); - parameters.DefineNewCategory("Algorithmic Parameters"); - parameters.DefineNewCategory("Tuning Parameters"); - - parameters.DefineStringParameter - ( - "file", - "Location to the dataset.", - "", //default value - "Main Parameters" - ); - - parameters.DefineFloatParameter - ( - "time", - "Maximum runtime given in seconds.", - 600, //default value - "Main Parameters", - 0, //min value - INT32_MAX //max value - ); - - parameters.DefineIntegerParameter - ( - "max-depth", - "Maximum allowed depth of the tree, where the depth is defined as the largest number of *decision/feature nodes* from the root to any leaf. Depth greater than four is usually time consuming.", - 3, //default value - "Main Parameters", - 0, //min value - 20 //max value - ); - - parameters.DefineIntegerParameter - ( - "max-num-nodes", - "Maximum number of *decision/feature nodes* allowed. Note that a tree with k feature nodes has k+1 leaf nodes.", - 7, //default value - "Main Parameters", - 0, - INT32_MAX - ); - - parameters.DefineFloatParameter - ( - "sparse-coefficient", - "Assigns the penalty for using decision/feature nodes. 
Large sparse coefficients will result in smaller trees.", - 0.0, - "Main Parameters", - 0.0, - 1.0 - ); - - parameters.DefineBooleanParameter - ( - "verbose", - "Determines if the solver should print logging information to the standard output.", - true, - "Main Parameters" - ); - - parameters.DefineBooleanParameter - ( - "all-trees", - "Instructs the algorithm to compute trees using all allowed combinations of max-depth and max-num-nodes. Used to stress-test the algorithm.", - false, - "Main Parameters" - ); - - parameters.DefineStringParameter - ( - "result-file", - "The results of the algorithm are printed in the provided file, using for simple benchmarking. The output file contains the runtime, misclassification score, and number of cache entries. Leave blank to avoid printing.", - "", //default value - "Main Parameters" - ); - - //Internal algorithmic parameters----------- - - parameters.DefineBooleanParameter - ( - "incremental-frequency", - "Activate incremental frequency computation, which takes into account previously computed trees when recomputing the frequency. In our experiments proved to be effective on all datasets.", - true, - "Algorithmic Parameters" - ); - - parameters.DefineBooleanParameter - ( - "similarity-lower-bound", - "Activate similarity-based lower bounding. 
Disabling this option may be better for some benchmarks, but on most of our tested datasets keeping this on was beneficial.", - true, - "Algorithmic Parameters" - ); - - parameters.DefineStringParameter - ( - "node-selection", - "Node selection strategy used to decide on whether the algorithm should examine the left or right child node first.", - "dynamic", //default value - "Algorithmic Parameters", - { "dynamic", "post-order" } - ); - - parameters.DefineStringParameter - ( - "feature-ordering", - "Feature ordering strategy used to determine the order in which features will be inspected in each node.", - "in-order", //default value - "Algorithmic Parameters", - { "in-order", "random", "gini" } - ); - - parameters.DefineIntegerParameter - ( - "random-seed", - "Random seed used only if the feature-ordering is set to random. A seed of -1 assings the seed based on the current time.", - 3, - "Algorithmic Parameters", - -1, - INT32_MAX - ); - - parameters.DefineStringParameter - ( - "cache-type", - "Cache type used to store computed subtrees. \"Dataset\" is more powerful than \"branch\" but may required more computational time. Need to be determined experimentally. \"Closure\" is experimental and typically slower than other options.", - "dataset", //default value - "Algorithmic Parameters", - { "branch", "dataset", "closure" } - ); - - parameters.DefineIntegerParameter - ( - "duplicate-factor", - "Duplicates the instances the given amount of times. Used for stress-testing the algorithm, not a practical parameter.", - 1, - "Algorithmic Parameters", - 1, - INT32_MAX - ); - - parameters.DefineIntegerParameter - ( - "upper-bound", - "Initial upper bound.", - INT32_MAX, //default value - "Algorithmic Parameters", - 0, - INT32_MAX - ); - - //Tuning parameters - - parameters.DefineBooleanParameter - ( - "hyper-parameter-tuning", - "Activate hyper-parameter tuning using max-depth and max-num-nodes as the maximum values allowed. 
The splits need to be provided in the appropriate folder...see the code. todo", - false, - "Tuning Parameters" - ); - - parameters.DefineStringParameter - ( - "splits-location-prefix", - "Prefix to where the splits may be found. Used in combination with hyper-parameter-tuning. The splits need to be provided in the appropriate folder...see the code. todo", - "", //default value - "Tuning Parameters" - ); - - parameters.DefineStringParameter - ( - "hyper-parameter-stats-file", - "Location of the output file that contains information about the hyper-parameter procedure.", - "", //default value - "Tuning Parameters" - ); - - return parameters; -} - -void CheckParameters(MurTree::ParameterHandler& parameters) -{ - if (parameters.GetStringParameter("file") == "") - { - std::cout << "Error: No file given!\n"; exit(1); - } - - if (parameters.GetIntegerParameter("max-depth") > parameters.GetIntegerParameter("max-num-nodes")) - { - std::cout << "Error: The depth parameter is greater than the number of nodes!\n"; - exit(1); - } - - if (parameters.GetIntegerParameter("max-num-nodes") > (uint32_t(1) << parameters.GetIntegerParameter("max-depth")) - 1) - { - std::cout << "Error: The number of nodes exceeds the limit imposed by the depth!\n"; - exit(1); - } - - if (parameters.GetBooleanParameter("hyper-parameter-tuning") && parameters.GetStringParameter("splits-location-prefix") == "") - { - std::cout << "Error: hyper tuning specified but no splits given\n"; - exit(1); - } - - if (parameters.GetBooleanParameter("hyper-parameter-tuning") && parameters.GetStringParameter("hyper-parameter-stats-file") == "") - { - std::cout << "Error: hyper tuning specified but no output file location given\n"; - exit(1); - } -} - int internalMain(int argc, char* argv[]) { - MurTree::ParameterHandler parameters = DefineParameters(); + MurTree::ParameterHandler parameters = MurTree::DefineParameters(); bool manual_activation = false; bool single_parameter_set_tuning = false; @@ -318,7 +103,7 @@ int 
internalMain(int argc, char* argv[]) parameters.SetStringParameter("feature-ordering", "in-order"); } - CheckParameters(parameters); + MurTree::CheckParameters(parameters); if (parameters.GetBooleanParameter("verbose")) { parameters.PrintParameterValues(); } From cf7c56949ade87105579eb8450d15faf3d0f069c Mon Sep 17 00:00:00 2001 From: jurra Date: Fri, 5 May 2023 00:04:17 +0200 Subject: [PATCH 2/5] Here we implemented an approach to pass data as vector of vectors conditionally, so we can either now pass a file name, or actual data to the solver via the parameter handler. For this we created another parameter in the parameter handler that stores the vectors. --- code/MurTree/Engine/solver.cpp | 28 +++++++++++++--------- code/MurTree/Utilities/parameter_handler.h | 14 +++++++++++ 2 files changed, 31 insertions(+), 11 deletions(-) diff --git a/code/MurTree/Engine/solver.cpp b/code/MurTree/Engine/solver.cpp index ad63182..795eb7c 100644 --- a/code/MurTree/Engine/solver.cpp +++ b/code/MurTree/Engine/solver.cpp @@ -13,6 +13,7 @@ #include "../Utilities/runtime_assert.h" #include "../Utilities/file_reader.h" #include "../Data Structures/child_subtree_info.h" +// #include namespace MurTree { @@ -31,20 +32,25 @@ Solver::Solver(ParameterHandler& parameters): else { std::cout << "Unknown node selection strategy: '" << parameters.GetStringParameter("node-selection") << "'\n"; exit(1); } //start: read in the data - feature_vectors_ = FileReader::ReadDataDL(parameters.GetStringParameter("file"), parameters.GetIntegerParameter("duplicate-factor")); - num_labels_ = feature_vectors_.size(); - num_features_ = -1; - for (auto& v : feature_vectors_) { if (!v.empty()) { num_features_ = v[0].NumTotalFeatures(); } } //could do better checking of the data - runtime_assert(num_features_ > 0 && num_labels_ > 1); - - binary_data_ = new BinaryDataInternal(num_labels_, num_features_); - for (int label = 0; label < num_labels_; label++) + if (parameters.GetData().empty()) { - for (int i = 0; i < 
feature_vectors_[label].size(); i++) + feature_vectors_ = FileReader::ReadDataDL(parameters.GetStringParameter("file"), parameters.GetIntegerParameter("duplicate-factor")); + num_labels_ = feature_vectors_.size(); + num_features_ = -1; + for (auto& v : feature_vectors_) { if (!v.empty()) { num_features_ = v[0].NumTotalFeatures(); } } //could do better checking of the data + runtime_assert(num_features_ > 0 && num_labels_ > 1); + + binary_data_ = new BinaryDataInternal(num_labels_, num_features_); + for (int label = 0; label < num_labels_; label++) { - binary_data_->AddFeatureVector(&feature_vectors_[label][i], label); + for (int i = 0; i < feature_vectors_[label].size(); i++) + { + binary_data_->AddFeatureVector(&feature_vectors_[label][i], label); + } } } + + else { feature_vectors_ = parameters.GetData();} //end: read in the data for(int i = 0; i < 100; i++) { splits_data[i] = new SplitBinaryData(num_labels_, num_features_); } @@ -115,7 +121,7 @@ Solver::~Solver() for (int i = 0; i < feature_selectors_.size(); i++) { delete feature_selectors_[i]; } } -void Solver::ReplaceData(std::vector >& new_instances) +void Solver::ReplaceData(std::vector>& new_instances) { runtime_assert(new_instances.size() == binary_data_->NumLabels()); diff --git a/code/MurTree/Utilities/parameter_handler.h b/code/MurTree/Utilities/parameter_handler.h index 90af3de..89b37ad 100644 --- a/code/MurTree/Utilities/parameter_handler.h +++ b/code/MurTree/Utilities/parameter_handler.h @@ -9,6 +9,9 @@ #include #include +//import feature vector binary +#include "../Data Structures/feature_vector_binary.h" + //reserve special keyword for help //todo better error handling; also include these checks: @@ -20,6 +23,16 @@ namespace MurTree class ParameterHandler { public: + // This parameter is needed for the solver to not depend exclusively on a file read + // To be able to call the Solver, the default value is nullptr + void SetData(const std::vector>& feature_vectors_ = {} ) { + data_ = 
feature_vectors_; + } + + const std::vector>& GetData() const { + return data_; + } + void DefineNewCategory(const std::string& category_name, const std::string short_description = ""); void DefineStringParameter(const std::string& parameter_name, const std::string& short_description, const std::string& default_value, const std::string& category_name, const std::vector& allowed_values = std::vector()); //empty vector of allowed values means all values are allowed @@ -45,6 +58,7 @@ class ParameterHandler void PrintHelpSummary(std::ostream& out = std::cout); private: + std::vector> data_; void CheckStringParameter(const std::string& parameter_name, const std::string& value); void CheckIntegerParameter(const std::string& parameter_name, int64_t value); void CheckBooleanParameter(const std::string& parameter_name, bool value); From 1545434a53bb73d9cd2d6bc6d2ca3eee8bf31b3a Mon Sep 17 00:00:00 2001 From: jurra Date: Fri, 5 May 2023 14:12:01 +0200 Subject: [PATCH 3/5] fix bug when reading and storing data. 
The previous refactoring was wrong --- code/MurTree/Engine/solver.cpp | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/code/MurTree/Engine/solver.cpp b/code/MurTree/Engine/solver.cpp index 795eb7c..493eebc 100644 --- a/code/MurTree/Engine/solver.cpp +++ b/code/MurTree/Engine/solver.cpp @@ -33,24 +33,23 @@ Solver::Solver(ParameterHandler& parameters): //start: read in the data if (parameters.GetData().empty()) + {feature_vectors_ = FileReader::ReadDataDL(parameters.GetStringParameter("file"), parameters.GetIntegerParameter("duplicate-factor"));} + + else { feature_vectors_ = parameters.GetData();} + + num_labels_ = feature_vectors_.size(); + num_features_ = -1; + for (auto& v : feature_vectors_) { if (!v.empty()) { num_features_ = v[0].NumTotalFeatures(); } } //could do better checking of the data + runtime_assert(num_features_ > 0 && num_labels_ > 1); + + binary_data_ = new BinaryDataInternal(num_labels_, num_features_); + for (int label = 0; label < num_labels_; label++) { - feature_vectors_ = FileReader::ReadDataDL(parameters.GetStringParameter("file"), parameters.GetIntegerParameter("duplicate-factor")); - num_labels_ = feature_vectors_.size(); - num_features_ = -1; - for (auto& v : feature_vectors_) { if (!v.empty()) { num_features_ = v[0].NumTotalFeatures(); } } //could do better checking of the data - runtime_assert(num_features_ > 0 && num_labels_ > 1); - - binary_data_ = new BinaryDataInternal(num_labels_, num_features_); - for (int label = 0; label < num_labels_; label++) + for (int i = 0; i < feature_vectors_[label].size(); i++) { - for (int i = 0; i < feature_vectors_[label].size(); i++) - { - binary_data_->AddFeatureVector(&feature_vectors_[label][i], label); - } + binary_data_->AddFeatureVector(&feature_vectors_[label][i], label); } } - - else { feature_vectors_ = parameters.GetData();} //end: read in the data for(int i = 0; i < 100; i++) { splits_data[i] = new SplitBinaryData(num_labels_, num_features_); 
} From 025301e8dae3f58f98c5858b1bd9029e0cd7e554 Mon Sep 17 00:00:00 2001 From: jurra Date: Tue, 23 May 2023 16:19:17 +0200 Subject: [PATCH 4/5] We now have an alternative constructor that takes DLData as an argument. Now instead of passing data via the parameter handler, we do it via a Solver constructor. We also removed the file check from CheckParameters and added the new constructor to solver.h --- code/MurTree/Engine/solver.cpp | 126 ++++++++++++++++++++- code/MurTree/Engine/solver.h | 2 + code/MurTree/Utilities/parameter_handler.h | 14 --- code/MurTree/Utilities/parameters.cpp | 5 - 4 files changed, 122 insertions(+), 25 deletions(-) diff --git a/code/MurTree/Engine/solver.cpp b/code/MurTree/Engine/solver.cpp index 493eebc..024a09c 100644 --- a/code/MurTree/Engine/solver.cpp +++ b/code/MurTree/Engine/solver.cpp @@ -13,10 +13,128 @@ #include "../Utilities/runtime_assert.h" #include "../Utilities/file_reader.h" #include "../Data Structures/child_subtree_info.h" -// #include +// include vector +#include +// include feature vector binary +#include "../Data Structures/feature_vector_binary.h" namespace MurTree { +// We add an alternative solver constructor with feature_vectors as an argument +// This is required if we want to provide an interface to the solver that does not depend on a file read +// It is not the cleanest solution, but changing the constructor depending on files will cascade also +// to other dependencies of the solver, therefore we add this alternative constructor +Solver::Solver(ParameterHandler& parameters, const std::vector>& feature_vectors): + verbose_(parameters.GetBooleanParameter("verbose")), + cache_(0), + binary_data_(0), + feature_selectors_(100, 0), + splits_data(100, 0), + similarity_lower_bound_computer_(0), + specialised_solver1_(0), + specialised_solver2_(0) +{ + if (parameters.GetStringParameter("node-selection") == "dynamic") { dynamic_child_selection_ = true; } + else if (parameters.GetStringParameter("node-selection") == 
"post-order") { dynamic_child_selection_ = false; } + else { std::cout << "Unknown node selection strategy: '" << parameters.GetStringParameter("node-selection") << "'\n"; exit(1); } + + // start: read in the data + if (feature_vectors.empty()) + { + feature_vectors_ = FileReader::ReadDataDL(parameters.GetStringParameter("file"), parameters.GetIntegerParameter("duplicate-factor")); + } + + else { feature_vectors_ = feature_vectors; } + + num_labels_ = feature_vectors_.size(); + num_features_ = -1; + for (auto& v : feature_vectors_) // could do better checking of the data + { + if (!v.empty()) + { + num_features_ = v[0].NumTotalFeatures(); + } + } + runtime_assert(num_features_ > 0 && num_labels_ > 1); + + binary_data_ = new BinaryDataInternal(num_labels_, num_features_); + for (int label = 0; label < num_labels_; label++) + { + for (int i = 0; i < feature_vectors_[label].size(); i++) + { + binary_data_->AddFeatureVector(&feature_vectors_[label][i], label); + } + } + // end: read in the data + + binary_data_ = new BinaryDataInternal(num_labels_, num_features_); + for (int label = 0; label < num_labels_; label++) + { + for (int i = 0; i < feature_vectors_[label].size(); i++) + { + binary_data_->AddFeatureVector(&feature_vectors_[label][i], label); + } + } + //end: read in the data + + for(int i = 0; i < 100; i++) { splits_data[i] = new SplitBinaryData(num_labels_, num_features_); } + + for (int i = 0; i < 100; i++) + { + if (parameters.GetStringParameter("feature-ordering") == "in-order") { feature_selectors_[i] = new FeatureSelectorInOrder(num_features_); } + else if (parameters.GetStringParameter("feature-ordering") == "random") { feature_selectors_[i] = new FeatureSelectorRandom(num_features_); } + else if (parameters.GetStringParameter("feature-ordering") == "gini") { feature_selectors_[i] = new FeatureSelectorGini(num_labels_, num_features_); } + else { std::cout << "Unknown feature ordering strategy!\n"; exit(1); } + } + + similarity_lower_bound_computer_ = new 
SimilarityLowerBoundComputer(100, 100, binary_data_->Size()); + if (parameters.GetBooleanParameter("similarity-lower-bound") == false) { similarity_lower_bound_computer_->Disable(); } + + if (parameters.GetStringParameter("cache-type") == "branch") { cache_ = new BranchCache(100); } + else if (parameters.GetStringParameter("cache-type") == "dataset") { cache_ = new DatasetCache(binary_data_->Size()); } + else if (parameters.GetStringParameter("cache-type") == "closure") { cache_ = new ClosureCache(num_features_, binary_data_->Size()); } + else + { + std::cout << "Parameter error: unknown cache type: " << parameters.GetStringParameter("cache-type") << "\n"; + runtime_assert(1 == 2); + } + + if (binary_data_->NumLabels() == 2) + { + specialised_solver1_ = new SpecialisedBinaryClassificationDecisionTreeSolver + ( + num_features_, + binary_data_->Size(), + parameters.GetBooleanParameter("incremental-frequency") + ); + + specialised_solver2_ = new SpecialisedBinaryClassificationDecisionTreeSolver + ( + num_features_, + binary_data_->Size(), + parameters.GetBooleanParameter("incremental-frequency") + ); + } + else + { + specialised_solver1_ = new SpecialisedGeneralClassificationDecisionTreeSolver + ( + num_labels_, + num_features_, + binary_data_->Size(), + parameters.GetBooleanParameter("incremental-frequency") + ); + + specialised_solver2_ = new SpecialisedGeneralClassificationDecisionTreeSolver + ( + num_labels_, + num_features_, + binary_data_->Size(), + parameters.GetBooleanParameter("incremental-frequency") + ); + } +} + Solver::Solver(ParameterHandler& parameters): verbose_(parameters.GetBooleanParameter("verbose")), cache_(0), @@ -32,11 +150,7 @@ Solver::Solver(ParameterHandler& parameters): else { std::cout << "Unknown node selection strategy: '" << parameters.GetStringParameter("node-selection") << "'\n"; exit(1); } //start: read in the data - if (parameters.GetData().empty()) - {feature_vectors_ = FileReader::ReadDataDL(parameters.GetStringParameter("file"), 
parameters.GetIntegerParameter("duplicate-factor"));} - - else { feature_vectors_ = parameters.GetData();} - + feature_vectors_ = FileReader::ReadDataDL(parameters.GetStringParameter("file"), parameters.GetIntegerParameter("duplicate-factor")); num_labels_ = feature_vectors_.size(); num_features_ = -1; for (auto& v : feature_vectors_) { if (!v.empty()) { num_features_ = v[0].NumTotalFeatures(); } } //could do better checking of the data diff --git a/code/MurTree/Engine/solver.h b/code/MurTree/Engine/solver.h index 7db2989..23943b0 100644 --- a/code/MurTree/Engine/solver.h +++ b/code/MurTree/Engine/solver.h @@ -15,6 +15,7 @@ #include "../Utilities/solver_result.h" #include "../Utilities/statistics.h" #include "../Utilities/parameter_handler.h" +#include "../Data Structures/feature_vector_binary.h" namespace MurTree { @@ -22,6 +23,7 @@ class Solver { public: Solver(ParameterHandler &solver_parameters); + Solver(ParameterHandler &solver_parameters, const std::vector>& feature_vectors); ~Solver(); SolverResult Solve(ParameterHandler &runtime_parameters); diff --git a/code/MurTree/Utilities/parameter_handler.h b/code/MurTree/Utilities/parameter_handler.h index 89b37ad..90af3de 100644 --- a/code/MurTree/Utilities/parameter_handler.h +++ b/code/MurTree/Utilities/parameter_handler.h @@ -9,9 +9,6 @@ #include #include -//import feature vector binary -#include "../Data Structures/feature_vector_binary.h" - //reserve special keyword for help //todo better error handling; also include these checks: @@ -23,16 +20,6 @@ namespace MurTree class ParameterHandler { public: - // This parameter is needed for the solver to not depend exclusively on a file read - // To be able to call the Solver, the default value is nullptr - void SetData(const std::vector>& feature_vectors_ = {} ) { - data_ = feature_vectors_; - } - - const std::vector>& GetData() const { - return data_; - } - void DefineNewCategory(const std::string& category_name, const std::string short_description = ""); void 
DefineStringParameter(const std::string& parameter_name, const std::string& short_description, const std::string& default_value, const std::string& category_name, const std::vector& allowed_values = std::vector()); //empty vector of allowed values means all values are allowed @@ -58,7 +45,6 @@ class ParameterHandler void PrintHelpSummary(std::ostream& out = std::cout); private: - std::vector> data_; void CheckStringParameter(const std::string& parameter_name, const std::string& value); void CheckIntegerParameter(const std::string& parameter_name, int64_t value); void CheckBooleanParameter(const std::string& parameter_name, bool value); diff --git a/code/MurTree/Utilities/parameters.cpp b/code/MurTree/Utilities/parameters.cpp index c0727cc..b5bfb52 100644 --- a/code/MurTree/Utilities/parameters.cpp +++ b/code/MurTree/Utilities/parameters.cpp @@ -186,11 +186,6 @@ namespace MurTree { } void CheckParameters(ParameterHandler& parameters) { - if (parameters.GetStringParameter("file") == "") { - std::cout << "Error: No file given!\n"; - exit(1); - } - if (parameters.GetIntegerParameter("max-depth") > parameters.GetIntegerParameter("max-num-nodes")) { std::cout << "Error: The depth parameter is greater than the number of nodes!\n"; exit(1); From c408ee963958a79c6f6eb2552e99aac596baa9dc Mon Sep 17 00:00:00 2001 From: kjgm Date: Wed, 18 Oct 2023 14:52:28 +0200 Subject: [PATCH 5/5] missing imports for mingw compiler --- code/MurTree/Utilities/file_reader.h | 1 + code/MurTree/Utilities/parameter_handler.h | 1 + 2 files changed, 2 insertions(+) diff --git a/code/MurTree/Utilities/file_reader.h b/code/MurTree/Utilities/file_reader.h index 2296195..4933cdd 100644 --- a/code/MurTree/Utilities/file_reader.h +++ b/code/MurTree/Utilities/file_reader.h @@ -4,6 +4,7 @@ #pragma once #include +#include #include "../Data Structures/feature_vector_binary.h" #include "split_data.h" diff --git a/code/MurTree/Utilities/parameter_handler.h b/code/MurTree/Utilities/parameter_handler.h index 
90af3de..431f5df 100644 --- a/code/MurTree/Utilities/parameter_handler.h +++ b/code/MurTree/Utilities/parameter_handler.h @@ -3,6 +3,7 @@ #pragma once +#include #include #include #include