diff --git a/CMakeLists.txt b/CMakeLists.txt
index 4dab451..6bb520c 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -70,6 +70,7 @@ set(ENGINE_SRC
set(UTILITIES_SRC
code/MurTree/Utilities/file_reader.cpp
code/MurTree/Utilities/parameter_handler.cpp
+ code/MurTree/Utilities/parameters.cpp
)
# MurTree can be built as a:
diff --git a/LinuxRelease/code/MurTree/Utilities/subdir.mk b/LinuxRelease/code/MurTree/Utilities/subdir.mk
index b24e871..2e8abfc 100644
--- a/LinuxRelease/code/MurTree/Utilities/subdir.mk
+++ b/LinuxRelease/code/MurTree/Utilities/subdir.mk
@@ -5,15 +5,18 @@
# Add inputs and outputs from these tool invocations to the build variables
CPP_SRCS += \
../code/MurTree/Utilities/file_reader.cpp \
-../code/MurTree/Utilities/parameter_handler.cpp
+../code/MurTree/Utilities/parameter_handler.cpp \
+../code/MurTree/Utilities/parameters.cpp
OBJS += \
./code/MurTree/Utilities/file_reader.o \
-./code/MurTree/Utilities/parameter_handler.o
+./code/MurTree/Utilities/parameter_handler.o \
+./code/MurTree/Utilities/parameters.o
CPP_DEPS += \
./code/MurTree/Utilities/file_reader.d \
-./code/MurTree/Utilities/parameter_handler.d
+./code/MurTree/Utilities/parameter_handler.d \
+./code/MurTree/Utilities/parameter.d
# Each subdirectory must supply rules for building sources it contributes
diff --git a/MurTree.vcxproj b/MurTree.vcxproj
index 98c7d63..5b2d193 100644
--- a/MurTree.vcxproj
+++ b/MurTree.vcxproj
@@ -41,6 +41,7 @@
+
@@ -76,6 +77,7 @@
+
diff --git a/MurTree.vcxproj.filters b/MurTree.vcxproj.filters
index 9e67b73..d49cbe9 100644
--- a/MurTree.vcxproj.filters
+++ b/MurTree.vcxproj.filters
@@ -68,6 +68,9 @@
Utilities
+
+ Utilities
+
@@ -177,6 +180,9 @@
Utilities
+
+ Utilities
+
Utilities
diff --git a/code/MurTree/Engine/solver.cpp b/code/MurTree/Engine/solver.cpp
index ad63182..024a09c 100644
--- a/code/MurTree/Engine/solver.cpp
+++ b/code/MurTree/Engine/solver.cpp
@@ -13,9 +13,128 @@
#include "../Utilities/runtime_assert.h"
#include "../Utilities/file_reader.h"
#include "../Data Structures/child_subtree_info.h"
+// include vector
+#include
+// include feature vector binary
+#include "../Data Structures/feature_vector_binary.h"
namespace MurTree
{
+// We add an alternative solver constructor with feature_vectors as an argument
+// This is required if we want to provide an interface to the solver that does not depend on a file read
+// It is not the cleanest solution, but changing the constructor depending on files will cascade also
+// to other dependencies of the solver, therefore we add this alternative constructor
+Solver::Solver(ParameterHandler& parameters, const std::vector>& feature_vectors):
+ verbose_(parameters.GetBooleanParameter("verbose")),
+ cache_(0),
+ binary_data_(0),
+ feature_selectors_(100, 0),
+ splits_data(100, 0),
+ similarity_lower_bound_computer_(0),
+ specialised_solver1_(0),
+ specialised_solver2_(0)
+{
+ if (parameters.GetStringParameter("node-selection") == "dynamic") { dynamic_child_selection_ = true; }
+ else if (parameters.GetStringParameter("node-selection") == "post-order") { dynamic_child_selection_ = false; }
+ else { std::cout << "Unknown node selection strategy: '" << parameters.GetStringParameter("node-selection") << "'\n"; exit(1); }
+
+ // start: read in the data
+ if (feature_vectors.empty())
+ {
+ feature_vectors_ = FileReader::ReadDataDL(parameters.GetStringParameter("file"), parameters.GetIntegerParameter("duplicate-factor"));
+ }
+
+ else { feature_vectors_ = feature_vectors; }
+
+ num_labels_ = feature_vectors_.size();
+ num_features_ = -1;
+ for (auto& v : feature_vectors_) // could do better checking of the data
+ {
+ if (!v.empty())
+ {
+ num_features_ = v[0].NumTotalFeatures();
+ }
+ }
+ runtime_assert(num_features_ > 0 && num_labels_ > 1);
+
+ binary_data_ = new BinaryDataInternal(num_labels_, num_features_);
+ for (int label = 0; label < num_labels_; label++)
+ {
+ for (int i = 0; i < feature_vectors_[label].size(); i++)
+ {
+ binary_data_->AddFeatureVector(&feature_vectors_[label][i], label);
+ }
+ }
+ // end: read in the data
+
+ binary_data_ = new BinaryDataInternal(num_labels_, num_features_);
+ for (int label = 0; label < num_labels_; label++)
+ {
+ for (int i = 0; i < feature_vectors_[label].size(); i++)
+ {
+ binary_data_->AddFeatureVector(&feature_vectors_[label][i], label);
+ }
+ }
+ //end: read in the data
+
+ for(int i = 0; i < 100; i++) { splits_data[i] = new SplitBinaryData(num_labels_, num_features_); }
+
+ for (int i = 0; i < 100; i++)
+ {
+ if (parameters.GetStringParameter("feature-ordering") == "in-order") { feature_selectors_[i] = new FeatureSelectorInOrder(num_features_); }
+ else if (parameters.GetStringParameter("feature-ordering") == "random") { feature_selectors_[i] = new FeatureSelectorRandom(num_features_); }
+ else if (parameters.GetStringParameter("feature-ordering") == "gini") { feature_selectors_[i] = new FeatureSelectorGini(num_labels_, num_features_); }
+ else { std::cout << "Unknown feature ordering strategy!\n"; exit(1); }
+ }
+
+ similarity_lower_bound_computer_ = new SimilarityLowerBoundComputer(100, 100, binary_data_->Size());
+ if (parameters.GetBooleanParameter("similarity-lower-bound") == false) { similarity_lower_bound_computer_->Disable(); }
+
+ if (parameters.GetStringParameter("cache-type") == "branch") { cache_ = new BranchCache(100); }
+ else if (parameters.GetStringParameter("cache-type") == "dataset") { cache_ = new DatasetCache(binary_data_->Size()); }
+ else if (parameters.GetStringParameter("cache-type") == "closure") { cache_ = new ClosureCache(num_features_, binary_data_->Size()); }
+ else
+ {
+ std::cout << "Parameter error: unknown cache type: " << parameters.GetStringParameter("cache-type") << "\n";
+ runtime_assert(1 == 2);
+ }
+
+ if (binary_data_->NumLabels() == 2)
+ {
+ specialised_solver1_ = new SpecialisedBinaryClassificationDecisionTreeSolver
+ (
+ num_features_,
+ binary_data_->Size(),
+ parameters.GetBooleanParameter("incremental-frequency")
+ );
+
+ specialised_solver2_ = new SpecialisedBinaryClassificationDecisionTreeSolver
+ (
+ num_features_,
+ binary_data_->Size(),
+ parameters.GetBooleanParameter("incremental-frequency")
+ );
+ }
+ else
+ {
+ specialised_solver1_ = new SpecialisedGeneralClassificationDecisionTreeSolver
+ (
+ num_labels_,
+ num_features_,
+ binary_data_->Size(),
+ parameters.GetBooleanParameter("incremental-frequency")
+ );
+
+ specialised_solver2_ = new SpecialisedGeneralClassificationDecisionTreeSolver
+ (
+ num_labels_,
+ num_features_,
+ binary_data_->Size(),
+ parameters.GetBooleanParameter("incremental-frequency")
+ );
+ }
+}
+
Solver::Solver(ParameterHandler& parameters):
verbose_(parameters.GetBooleanParameter("verbose")),
cache_(0),
@@ -115,7 +234,7 @@ Solver::~Solver()
for (int i = 0; i < feature_selectors_.size(); i++) { delete feature_selectors_[i]; }
}
-void Solver::ReplaceData(std::vector >& new_instances)
+void Solver::ReplaceData(std::vector>& new_instances)
{
runtime_assert(new_instances.size() == binary_data_->NumLabels());
diff --git a/code/MurTree/Engine/solver.h b/code/MurTree/Engine/solver.h
index 7db2989..23943b0 100644
--- a/code/MurTree/Engine/solver.h
+++ b/code/MurTree/Engine/solver.h
@@ -15,6 +15,7 @@
#include "../Utilities/solver_result.h"
#include "../Utilities/statistics.h"
#include "../Utilities/parameter_handler.h"
+#include "../Data Structures/feature_vector_binary.h"
namespace MurTree
{
@@ -22,6 +23,7 @@ class Solver
{
public:
Solver(ParameterHandler &solver_parameters);
+ Solver(ParameterHandler &solver_parameters, const std::vector>& feature_vectors);
~Solver();
SolverResult Solve(ParameterHandler &runtime_parameters);
diff --git a/code/MurTree/Utilities/file_reader.h b/code/MurTree/Utilities/file_reader.h
index 2296195..4933cdd 100644
--- a/code/MurTree/Utilities/file_reader.h
+++ b/code/MurTree/Utilities/file_reader.h
@@ -4,6 +4,7 @@
#pragma once
#include
+#include
#include "../Data Structures/feature_vector_binary.h"
#include "split_data.h"
diff --git a/code/MurTree/Utilities/parameter_handler.h b/code/MurTree/Utilities/parameter_handler.h
index 90af3de..431f5df 100644
--- a/code/MurTree/Utilities/parameter_handler.h
+++ b/code/MurTree/Utilities/parameter_handler.h
@@ -3,6 +3,7 @@
#pragma once
+#include
#include
#include
#include