Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ set(ENGINE_SRC
set(UTILITIES_SRC
code/MurTree/Utilities/file_reader.cpp
code/MurTree/Utilities/parameter_handler.cpp
code/MurTree/Utilities/parameters.cpp
)

# MurTree can be built as a:
Expand Down
9 changes: 6 additions & 3 deletions LinuxRelease/code/MurTree/Utilities/subdir.mk
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,18 @@
# Add inputs and outputs from these tool invocations to the build variables
CPP_SRCS += \
../code/MurTree/Utilities/file_reader.cpp \
../code/MurTree/Utilities/parameter_handler.cpp
../code/MurTree/Utilities/parameter_handler.cpp \
../code/MurTree/Utilities/parameters.cpp

OBJS += \
./code/MurTree/Utilities/file_reader.o \
./code/MurTree/Utilities/parameter_handler.o
./code/MurTree/Utilities/parameter_handler.o \
./code/MurTree/Utilities/parameters.o

CPP_DEPS += \
./code/MurTree/Utilities/file_reader.d \
./code/MurTree/Utilities/parameter_handler.d
./code/MurTree/Utilities/parameter_handler.d \
./code/MurTree/Utilities/parameters.d


# Each subdirectory must supply rules for building sources it contributes
Expand Down
2 changes: 2 additions & 0 deletions MurTree.vcxproj
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
<ClCompile Include="code\MurTree\Engine\specialised_general_classification_decision_tree_solver.cpp" />
<ClCompile Include="code\MurTree\Utilities\file_reader.cpp" />
<ClCompile Include="code\MurTree\Utilities\parameter_handler.cpp" />
<ClCompile Include="code\MurTree\Utilities\parameters.cpp" />
<ClCompile Include="main.cpp" />
</ItemGroup>
<ItemGroup>
Expand Down Expand Up @@ -76,6 +77,7 @@
<ClInclude Include="code\MurTree\Engine\specialised_general_classification_decision_tree_solver.h" />
<ClInclude Include="code\MurTree\Utilities\file_reader.h" />
<ClInclude Include="code\MurTree\Utilities\parameter_handler.h" />
<ClInclude Include="code\MurTree\Utilities\parameters.h" />
<ClInclude Include="code\MurTree\Utilities\runtime_assert.h" />
<ClInclude Include="code\MurTree\Utilities\solver_result.h" />
<ClInclude Include="code\MurTree\Utilities\split_data.h" />
Expand Down
6 changes: 6 additions & 0 deletions MurTree.vcxproj.filters
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@
<ClCompile Include="code\MurTree\Utilities\parameter_handler.cpp">
<Filter>Utilities</Filter>
</ClCompile>
<ClCompile Include="code\MurTree\Utilities\parameters.cpp">
<Filter>Utilities</Filter>
</ClCompile>
</ItemGroup>
<ItemGroup>
<Filter Include="Data Structures">
Expand Down Expand Up @@ -177,6 +180,9 @@
<ClInclude Include="code\MurTree\Utilities\parameter_handler.h">
<Filter>Utilities</Filter>
</ClInclude>
<ClInclude Include="code\MurTree\Utilities\parameters.h">
<Filter>Utilities</Filter>
</ClInclude>
<ClInclude Include="code\MurTree\Utilities\runtime_assert.h">
<Filter>Utilities</Filter>
</ClInclude>
Expand Down
121 changes: 120 additions & 1 deletion code/MurTree/Engine/solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,128 @@
#include "../Utilities/runtime_assert.h"
#include "../Utilities/file_reader.h"
#include "../Data Structures/child_subtree_info.h"
// include vector
#include <vector>
// include feature vector binary
#include "../Data Structures/feature_vector_binary.h"

namespace MurTree
{
// Alternative solver constructor that takes the feature vectors as an argument.
// This provides an interface to the solver that does not depend on a file read.
// It is not the cleanest solution, but changing the file-based constructor would cascade
// to other dependencies of the solver, therefore this overload is added alongside it.
//
// parameters:      solver configuration. Reads "verbose", "node-selection", "feature-ordering",
//                  "similarity-lower-bound", "cache-type" and "incremental-frequency";
//                  "file" and "duplicate-factor" are only read when feature_vectors is empty.
// feature_vectors: feature vectors grouped by label (feature_vectors[label][i]). They are copied
//                  into feature_vectors_. If the outer vector is empty, the data is instead read
//                  through FileReader::ReadDataDL.
//
// Error handling: an unknown "node-selection" or "feature-ordering" value prints a message and
// calls exit(1); invalid data (no features, or fewer than two labels) and an unknown "cache-type"
// trigger runtime_assert.
Solver::Solver(ParameterHandler& parameters, const std::vector<std::vector<FeatureVectorBinary>>& feature_vectors):
    verbose_(parameters.GetBooleanParameter("verbose")),
    cache_(nullptr),
    binary_data_(nullptr),
    feature_selectors_(100, nullptr),
    splits_data(100, nullptr),
    similarity_lower_bound_computer_(nullptr),
    specialised_solver1_(nullptr),
    specialised_solver2_(nullptr)
{
    const std::string node_selection = parameters.GetStringParameter("node-selection");
    if (node_selection == "dynamic") { dynamic_child_selection_ = true; }
    else if (node_selection == "post-order") { dynamic_child_selection_ = false; }
    else { std::cout << "Unknown node selection strategy: '" << node_selection << "'\n"; exit(1); }

    // start: read in the data
    if (feature_vectors.empty())
    {
        // no data supplied by the caller: fall back to reading the file named in the parameters
        feature_vectors_ = FileReader::ReadDataDL(parameters.GetStringParameter("file"), parameters.GetIntegerParameter("duplicate-factor"));
    }
    else
    {
        feature_vectors_ = feature_vectors;
    }

    num_labels_ = feature_vectors_.size();
    num_features_ = -1;
    for (auto& v : feature_vectors_) // could do better checking of the data
    {
        if (!v.empty())
        {
            num_features_ = v[0].NumTotalFeatures();
        }
    }
    runtime_assert(num_features_ > 0 && num_labels_ > 1);

    // Build the internal data set exactly once. The previous version repeated this section,
    // which leaked the first BinaryDataInternal and registered every feature vector twice.
    // The stored pointers refer to elements of feature_vectors_ (the member copy), not to the
    // caller's argument.
    binary_data_ = new BinaryDataInternal(num_labels_, num_features_);
    for (int label = 0; label < num_labels_; label++)
    {
        for (auto& feature_vector : feature_vectors_[label])
        {
            binary_data_->AddFeatureVector(&feature_vector, label);
        }
    }
    // end: read in the data

    for (int i = 0; i < 100; i++) { splits_data[i] = new SplitBinaryData(num_labels_, num_features_); }

    // the ordering strategy does not change between iterations, so look it up once
    const std::string feature_ordering = parameters.GetStringParameter("feature-ordering");
    for (int i = 0; i < 100; i++)
    {
        if (feature_ordering == "in-order") { feature_selectors_[i] = new FeatureSelectorInOrder(num_features_); }
        else if (feature_ordering == "random") { feature_selectors_[i] = new FeatureSelectorRandom(num_features_); }
        else if (feature_ordering == "gini") { feature_selectors_[i] = new FeatureSelectorGini(num_labels_, num_features_); }
        else { std::cout << "Unknown feature ordering strategy!\n"; exit(1); }
    }

    const auto num_instances = binary_data_->Size();

    similarity_lower_bound_computer_ = new SimilarityLowerBoundComputer(100, 100, num_instances);
    if (!parameters.GetBooleanParameter("similarity-lower-bound")) { similarity_lower_bound_computer_->Disable(); }

    const std::string cache_type = parameters.GetStringParameter("cache-type");
    if (cache_type == "branch") { cache_ = new BranchCache(100); }
    else if (cache_type == "dataset") { cache_ = new DatasetCache(num_instances); }
    else if (cache_type == "closure") { cache_ = new ClosureCache(num_features_, num_instances); }
    else
    {
        std::cout << "Parameter error: unknown cache type: " << cache_type << "\n";
        runtime_assert(1 == 2);
    }

    // two-label data gets the binary-classification specialised solver, anything else the general one
    const bool incremental_frequency = parameters.GetBooleanParameter("incremental-frequency");
    if (binary_data_->NumLabels() == 2)
    {
        specialised_solver1_ = new SpecialisedBinaryClassificationDecisionTreeSolver
        (
            num_features_,
            num_instances,
            incremental_frequency
        );

        specialised_solver2_ = new SpecialisedBinaryClassificationDecisionTreeSolver
        (
            num_features_,
            num_instances,
            incremental_frequency
        );
    }
    else
    {
        specialised_solver1_ = new SpecialisedGeneralClassificationDecisionTreeSolver
        (
            num_labels_,
            num_features_,
            num_instances,
            incremental_frequency
        );

        specialised_solver2_ = new SpecialisedGeneralClassificationDecisionTreeSolver
        (
            num_labels_,
            num_features_,
            num_instances,
            incremental_frequency
        );
    }
}

Solver::Solver(ParameterHandler& parameters):
verbose_(parameters.GetBooleanParameter("verbose")),
cache_(0),
Expand Down Expand Up @@ -115,7 +234,7 @@ Solver::~Solver()
for (int i = 0; i < feature_selectors_.size(); i++) { delete feature_selectors_[i]; }
}

void Solver::ReplaceData(std::vector<std::vector<FeatureVectorBinary> >& new_instances)
void Solver::ReplaceData(std::vector<std::vector<FeatureVectorBinary>>& new_instances)
{
runtime_assert(new_instances.size() == binary_data_->NumLabels());

Expand Down
2 changes: 2 additions & 0 deletions code/MurTree/Engine/solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@
#include "../Utilities/solver_result.h"
#include "../Utilities/statistics.h"
#include "../Utilities/parameter_handler.h"
#include "../Data Structures/feature_vector_binary.h"

namespace MurTree
{
class Solver
{
public:
Solver(ParameterHandler &solver_parameters);
Solver(ParameterHandler &solver_parameters, const std::vector<std::vector<FeatureVectorBinary>>& feature_vectors);
~Solver();

SolverResult Solve(ParameterHandler &runtime_parameters);
Expand Down
1 change: 1 addition & 0 deletions code/MurTree/Utilities/file_reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#pragma once

#include <string>
#include <cstdint>

#include "../Data Structures/feature_vector_binary.h"
#include "split_data.h"
Expand Down
1 change: 1 addition & 0 deletions code/MurTree/Utilities/parameter_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#pragma once

#include <cstdint>
#include <string>
#include <vector>
#include <map>
Expand Down
Loading