diff --git a/api/MWFunctions b/api/MWFunctions index b41efc1c3..24c562364 100644 --- a/api/MWFunctions +++ b/api/MWFunctions @@ -27,7 +27,7 @@ #include "trees/BoundingBox.h" #include "trees/FunctionTree.h" #include "trees/FunctionTreeVector.h" -#include "utils/ComplexFunction.h" +#include "utils/CompFunction.h" #include "core/InterpolatingBasis.h" #include "core/LegendreBasis.h" diff --git a/api/mrcpp_declarations.h b/api/mrcpp_declarations.h index f6501b726..8296da045 100644 --- a/api/mrcpp_declarations.h +++ b/api/mrcpp_declarations.h @@ -34,7 +34,7 @@ namespace mrcpp { class Timer; class Printer; -template class Plotter; +template class Plotter; template class Gaussian; template class GaussFunc; @@ -42,26 +42,28 @@ template class GaussPoly; template class GaussExp; template class BoundingBox; -template class NodeBox; +template class NodeBox; template class NodeIndex; template class NodeIndexComp; -class SharedMemory; +template class SharedMemory; class ScalingBasis; class LegendreBasis; class InterpolatingBasis; -template class RepresentableFunction; +template class RepresentableFunction; template class MultiResolutionAnalysis; -template class MWTree; -template class FunctionTree; +template class MWTree; +template class FunctionTree; class OperatorTree; -template class NodeAllocator; +template class NodeAllocator; -template class MWNode; -template class FunctionNode; +template class MWNode; +template class FunctionNode; +template class CompFunction; +class ComplexFunction; class OperatorNode; template class IdentityConvolution; @@ -79,31 +81,30 @@ template class DerivativeKernel; class PoissonKernel; class HelmholtzKernel; -template class TreeBuilder; -template class TreeCalculator; -template class DefaultCalculator; -template class ProjectionCalculator; -template class AdditionCalculator; -template class MultiplicationCalculator; -template class ConvolutionCalculator; -template class DerivativeCalculator; +template class TreeBuilder; +template class 
TreeCalculator; +template class DefaultCalculator; +template class ProjectionCalculator; +template class AdditionCalculator; +template class MultiplicationCalculator; +template class ConvolutionCalculator; +template class DerivativeCalculator; class CrossCorrelationCalculator; -template class TreeAdaptor; -template class AnalyticAdaptor; -template class WaveletAdaptor; -template class CopyAdaptor; +template class TreeAdaptor; +template class AnalyticAdaptor; +template class WaveletAdaptor; +template class CopyAdaptor; -template class TreeIterator; -template class IteratorNode; +template class TreeIterator; +template class IteratorNode; class BandWidth; -template class OperatorState; +template class OperatorState; template using Coord = std::array; -template using MWNodeVector = std::vector *>; +template using MWNodeVector = std::vector *>; -template using FMap_ = std::function; -typedef FMap_ FMap; +template using FMap = std::function; } // namespace mrcpp diff --git a/cmake/compiler_flags/CXXFlags.cmake b/cmake/compiler_flags/CXXFlags.cmake index a12df3d12..cbcf32898 100644 --- a/cmake/compiler_flags/CXXFlags.cmake +++ b/cmake/compiler_flags/CXXFlags.cmake @@ -13,7 +13,7 @@ option(ENABLE_ARCH_FLAGS "Enable architecture-specific compiler flags" ON) -set(CMAKE_CXX_STANDARD 14) +set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED TRUE) set(CMAKE_CXX_EXTENSIONS FALSE) set(CMAKE_EXPORT_COMPILE_COMMANDS TRUE) diff --git a/docs/mrcpp_api/mwfunctions.rst b/docs/mrcpp_api/mwfunctions.rst index cef911f6f..527497f5c 100644 --- a/docs/mrcpp_api/mwfunctions.rst +++ b/docs/mrcpp_api/mwfunctions.rst @@ -165,7 +165,7 @@ Constructing an MRA An MRA is defined in two steps, first the computational domain is given by a ``BoundingBox`` (D is the dimension), e.g. for a total domain of -:math:`[-32,32]^3` in three dimensions (eight root boxes of size :math:`[16]^3` +:math:`[-32,32]^3` in three dimensions (eight root boxes of size :math:`[32]^3` each): .. 
code-block:: cpp diff --git a/examples/derivative.cpp b/examples/derivative.cpp index bc0475db9..bd33bcfe7 100644 --- a/examples/derivative.cpp +++ b/examples/derivative.cpp @@ -51,8 +51,8 @@ int main(int argc, char **argv) { mrcpp::FunctionTree err_tree(MRA); // Projecting functions - mrcpp::project(prec, f_tree, f); - mrcpp::project(prec, df_tree, df); + mrcpp::project(prec, f_tree, f); + mrcpp::project(prec, df_tree, df); // Applying derivative operator mrcpp::apply(dg_tree, D_00, f_tree, 0); diff --git a/examples/mpi_matrix.cpp b/examples/mpi_matrix.cpp index f3580d158..536774c7f 100644 --- a/examples/mpi_matrix.cpp +++ b/examples/mpi_matrix.cpp @@ -54,7 +54,7 @@ int main(int argc, char **argv) { }; mrcpp::FunctionTree<3> *tree = new mrcpp::FunctionTree<3>(MRA); if (i % wsize == wrank) { - mrcpp::project<3>(prec, *tree, f); + mrcpp::project<3, double>(prec, *tree, f); tree->normalize(); } f_vec.push_back(std::make_tuple(1.0, tree)); diff --git a/examples/mpi_send_tree.cpp b/examples/mpi_send_tree.cpp index 44ffd0dec..aff8be379 100644 --- a/examples/mpi_send_tree.cpp +++ b/examples/mpi_send_tree.cpp @@ -55,7 +55,7 @@ int main(int argc, char **argv) { mrcpp::FunctionTree f_tree(MRA); // Only rank 0 projects the function - if (wrank == 0) mrcpp::project(prec, f_tree, f); + if (wrank == 0) mrcpp::project(prec, f_tree, f); { // Print data before send auto integral = f_tree.integrate(); diff --git a/examples/mpi_shared_tree.cpp b/examples/mpi_shared_tree.cpp index aa59d8204..ba7f7db5e 100644 --- a/examples/mpi_shared_tree.cpp +++ b/examples/mpi_shared_tree.cpp @@ -63,12 +63,12 @@ int main(int argc, char **argv) { }; // Initialize a shared memory tree, max 100MB - auto shared_mem = new mrcpp::SharedMemory(scomm, 100); + auto shared_mem = new mrcpp::SharedMemory(scomm, 100); mrcpp::FunctionTree f_tree(MRA, shared_mem); // Only first rank projects auto frank = 0; - if (srank == frank) mrcpp::project(prec, f_tree, f); + if (srank == frank) mrcpp::project(prec, f_tree, 
f); mrcpp::share_tree(f_tree, frank, 0, scomm); { // Print data after share diff --git a/examples/projection.cpp b/examples/projection.cpp index 92f1a7b53..9243485fb 100644 --- a/examples/projection.cpp +++ b/examples/projection.cpp @@ -37,7 +37,7 @@ int main(int argc, char **argv) { // Projecting function mrcpp::FunctionTree f_tree(MRA); - mrcpp::project(prec, f_tree, f, -1); + mrcpp::project(prec, f_tree, f, -1); auto integral = f_tree.integrate(); mrcpp::print::header(0, "Projecting analytic function"); diff --git a/examples/scf.cpp b/examples/scf.cpp index f91058f14..880830c91 100644 --- a/examples/scf.cpp +++ b/examples/scf.cpp @@ -21,17 +21,14 @@ void setupNuclearPotential(double Z, FunctionTree &V) { // Smoothing parameter auto c = 0.00435 * prec / std::pow(Z, 5); - auto u = [](double r) -> double { - return std::erf(r) / r + - 1.0 / (3.0 * std::sqrt(mrcpp::pi)) * (std::exp(-r * r) + 16.0 * std::exp(-4.0 * r * r)); - }; + auto u = [](double r) -> double { return std::erf(r) / r + 1.0 / (3.0 * std::sqrt(mrcpp::pi)) * (std::exp(-r * r) + 16.0 * std::exp(-4.0 * r * r)); }; auto f = [u, c, Z](const Coord<3> &r) -> double { auto x = std::sqrt(r[0] * r[0] + r[1] * r[1] + r[2] * r[2]); return -1.0 * Z * u(x / c) / c; }; // Projecting function - project(prec, V, f); + project(prec, V, f); print::footer(0, timer, 2); Printer::setPrintLevel(oldlevel); @@ -48,7 +45,7 @@ void setupInitialGuess(FunctionTree &phi) { }; // Projecting and normalizing function - project(prec, phi, f); + project(prec, phi, f); phi.normalize(); print::footer(0, timer, 2); diff --git a/examples/schrodinger_semigroup1d.cpp b/examples/schrodinger_semigroup1d.cpp index 2c1de2fa8..657d296f1 100644 --- a/examples/schrodinger_semigroup1d.cpp +++ b/examples/schrodinger_semigroup1d.cpp @@ -1,13 +1,11 @@ #include "MRCPP/MWFunctions" -#include -#include #include "MRCPP/Plotter" -#include -#include "operators/TimeEvolutionOperator.h" #include "functions/special_functions.h" +#include 
"operators/TimeEvolutionOperator.h" #include "treebuilders/complex_apply.h" - - +#include +#include +#include const auto min_scale = 0; const auto max_depth = 25; @@ -15,15 +13,14 @@ const auto max_depth = 25; const auto order = 4; const auto prec = 1.0e-7; -int finest_scale = 10; //for time evolution operator construction (not recommended to use more than 10) -int max_Jpower = 20; //the amount of J integrals to be used in construction (20 should be enough) +int finest_scale = 10; // for time evolution operator construction (not recommended to use more than 10) +int max_Jpower = 20; // the amount of J integrals to be used in construction (20 should be enough) // Time moments: -double t1 = 0.001; //initial time moment (not recommended to use more than 0.001) -double delta_t = 0.001; //time step (not recommended to use less than 0.001) -double t2 = delta_t + t1; //final time moment +double t1 = 0.001; // initial time moment (not recommended to use more than 0.001) +double delta_t = 0.001; // time step (not recommended to use less than 0.001) +double t2 = delta_t + t1; // final time moment - /** * @brief Exploring free-particle time evolution. * @details We check the time propagator. @@ -41,17 +38,16 @@ double t2 = delta_t + t1; //final time moment * \psi(x, t) = \sqrt{\frac{\sigma}{4it + \sigma}} e^{-\frac{(x - x_0)^2}{4it + \sigma}} * . 
* \f] - * + * */ -int main(int argc, char **argv) -{ +int main(int argc, char **argv) { auto timer = mrcpp::Timer(); // Initialize printing auto printlevel = 0; mrcpp::Printer::init(printlevel); mrcpp::print::environment(0); - + // Initialize world in the unit cube [0,1] auto basis = mrcpp::LegendreBasis(order); auto world = mrcpp::BoundingBox<1>(min_scale); @@ -74,87 +70,74 @@ int main(int argc, char **argv) double x0 = 0.5; // Functions f(x) = psi(x, t1) and g(x) = psi(x, t2) - auto Re_f = [sigma, x0, t=t1](const mrcpp::Coord<1> &r) -> double - { - return mrcpp::free_particle_analytical_solution(r[0], x0, t, sigma).real(); - }; - auto Im_f = [sigma, x0, t=t1](const mrcpp::Coord<1> &r) -> double - { - return mrcpp::free_particle_analytical_solution(r[0], x0, t, sigma).imag(); - }; - auto Re_g = [sigma, x0, t=t2](const mrcpp::Coord<1> &r) -> double - { - return mrcpp::free_particle_analytical_solution(r[0], x0, t, sigma).real(); - }; - auto Im_g = [sigma, x0, t=t2](const mrcpp::Coord<1> &r) -> double - { - return mrcpp::free_particle_analytical_solution(r[0], x0, t, sigma).imag(); - }; + auto Re_f = [sigma, x0, t = t1](const mrcpp::Coord<1> &r) -> double { return mrcpp::free_particle_analytical_solution(r[0], x0, t, sigma).real(); }; + auto Im_f = [sigma, x0, t = t1](const mrcpp::Coord<1> &r) -> double { return mrcpp::free_particle_analytical_solution(r[0], x0, t, sigma).imag(); }; + auto Re_g = [sigma, x0, t = t2](const mrcpp::Coord<1> &r) -> double { return mrcpp::free_particle_analytical_solution(r[0], x0, t, sigma).real(); }; + auto Im_g = [sigma, x0, t = t2](const mrcpp::Coord<1> &r) -> double { return mrcpp::free_particle_analytical_solution(r[0], x0, t, sigma).imag(); }; // Projecting functions mrcpp::FunctionTree<1> Re_f_tree(MRA); - mrcpp::project<1>(prec, Re_f_tree, Re_f); + mrcpp::project<1, double>(prec, Re_f_tree, Re_f); mrcpp::FunctionTree<1> Im_f_tree(MRA); - mrcpp::project<1>(prec, Im_f_tree, Im_f); + mrcpp::project<1, double>(prec, Im_f_tree, 
Im_f); mrcpp::FunctionTree<1> Re_g_tree(MRA); - mrcpp::project<1>(prec, Re_g_tree, Re_g); + mrcpp::project<1, double>(prec, Re_g_tree, Re_g); mrcpp::FunctionTree<1> Im_g_tree(MRA); - mrcpp::project<1>(prec, Im_g_tree, Im_g); + mrcpp::project<1, double>(prec, Im_g_tree, Im_g); // Output function trees mrcpp::FunctionTree<1> Re_fout_tree(MRA); mrcpp::FunctionTree<1> Im_fout_tree(MRA); - + // Complex objects for use in apply() - mrcpp::ComplexObject< mrcpp::ConvolutionOperator<1> > E(ReExp, ImExp); - mrcpp::ComplexObject< mrcpp::FunctionTree<1> > input(Re_f_tree, Im_f_tree); - mrcpp::ComplexObject< mrcpp::FunctionTree<1> > output(Re_fout_tree, Im_fout_tree); + mrcpp::ComplexObject> E(ReExp, ImExp); + mrcpp::ComplexObject> input(Re_f_tree, Im_f_tree); + mrcpp::ComplexObject> output(Re_fout_tree, Im_fout_tree); mrcpp::print::header(0, "Applying operator"); mrcpp::print::footer(0, timer, 2); // Apply operator Exp(delta_t) f(x) mrcpp::apply(prec, output, E, input); - + mrcpp::print::header(0, "Checking the result on analytical solution"); mrcpp::print::footer(0, timer, 2); // Check g(x) = Exp(delta_t) f(x) - mrcpp::FunctionTree<1> Re_error(MRA); // = Re_fout_tree - Re_g_tree - mrcpp::FunctionTree<1> Im_error(MRA); // = Im_fout_tree - Im_g_tree - + mrcpp::FunctionTree<1> Re_error(MRA); // = Re_fout_tree - Re_g_tree + mrcpp::FunctionTree<1> Im_error(MRA); // = Im_fout_tree - Im_g_tree + // Re_error = Re_fout_tree - Re_g_tree add(prec, Re_error, 1.0, Re_fout_tree, -1.0, Re_g_tree); auto Re_integral = Re_error.integrate(); auto Re_sq_norm = Re_error.getSquareNorm(); mrcpp::print::value(0, "Integral of Re(Exp(delta_t) f(x) - g(x)) =", Re_integral); mrcpp::print::value(0, "Square norm of Re(Exp(delta_t) f(x) - g(x)) =", Re_sq_norm); - + // Im_error = Im_fout_tree - Im_g_tree add(prec, Im_error, 1.0, Im_fout_tree, -1.0, Im_g_tree); auto Im_integral = Im_error.integrate(); auto Im_sq_norm = Im_error.getSquareNorm(); mrcpp::print::value(0, "Integral of Im(Exp(delta_t) f(x) - g(x)) 
=", Im_integral); mrcpp::print::value(0, "Square norm of Im(Exp(delta_t) f(x) - g(x)) =", Im_sq_norm); - + mrcpp::print::header(0, "Saving plots to files"); mrcpp::print::footer(0, timer, 2); // Set plotting parameters - int nPts = 1000; - mrcpp::Coord<1> o{0.0}; - mrcpp::Coord<1> a{1.0}; - mrcpp::Plotter<1> plot(o); + int nPts = 1000; + mrcpp::Coord<1> o{0.0}; + mrcpp::Coord<1> a{1.0}; + mrcpp::Plotter<1> plot(o); plot.setRange(a); - plot.linePlot({nPts}, Re_error, "Re_error"); // Write to file Re_error.line - plot.linePlot({nPts}, Im_error, "Im_error"); // Write to file Im_error.line - plot.linePlot({nPts}, Re_f_tree, "Re_f_tree"); // Write to file Re_f_tree.line - plot.linePlot({nPts}, Im_f_tree, "Im_f_tree"); // Write to file Im_f_tree.line - plot.linePlot({nPts}, Re_g_tree, "Re_g_tree"); // Write to file Re_g_tree.line - plot.linePlot({nPts}, Im_g_tree, "Im_g_tree"); // Write to file Im_g_tree.line - + plot.linePlot({nPts}, Re_error, "Re_error"); // Write to file Re_error.line + plot.linePlot({nPts}, Im_error, "Im_error"); // Write to file Im_error.line + plot.linePlot({nPts}, Re_f_tree, "Re_f_tree"); // Write to file Re_f_tree.line + plot.linePlot({nPts}, Im_f_tree, "Im_f_tree"); // Write to file Im_f_tree.line + plot.linePlot({nPts}, Re_g_tree, "Re_g_tree"); // Write to file Re_g_tree.line + plot.linePlot({nPts}, Im_g_tree, "Im_g_tree"); // Write to file Im_g_tree.line + mrcpp::print::footer(0, timer, 2); return 0; } - diff --git a/examples/tree_cleaner.cpp b/examples/tree_cleaner.cpp index 6d970c5e3..350f98e40 100644 --- a/examples/tree_cleaner.cpp +++ b/examples/tree_cleaner.cpp @@ -9,6 +9,7 @@ const auto order = 7; const auto prec = 1.0e-5; const auto D = 3; + int main(int argc, char **argv) { auto timer = mrcpp::Timer(); @@ -42,14 +43,14 @@ int main(int argc, char **argv) { auto iter = 0; auto n_nodes = 1; while (n_nodes > 0) { - mrcpp::project(-1.0, f_tree, f); // Projecting on fixed grid + mrcpp::project(-1.0, f_tree, f); // Projecting on fixed grid 
n_nodes = mrcpp::refine_grid(f_tree, prec); // Refine grid mrcpp::clear_grid(f_tree); // Clear MW coefs printout(0, " iter " << std::setw(3) << iter++ << std::setw(45)); printout(0, " n_nodes " << std::setw(5) << n_nodes << std::endl); } // Projecting on final converged grid - mrcpp::project(-1.0, f_tree, f); + mrcpp::project(-1.0, f_tree, f); auto integral = f_tree.integrate(); auto sq_norm = f_tree.getSquareNorm(); diff --git a/src/functions/AnalyticFunction.h b/src/functions/AnalyticFunction.h index abf0fcbd6..aca20285b 100644 --- a/src/functions/AnalyticFunction.h +++ b/src/functions/AnalyticFunction.h @@ -32,29 +32,27 @@ namespace mrcpp { -template class AnalyticFunction : public RepresentableFunction { +template class AnalyticFunction : public RepresentableFunction { public: AnalyticFunction() = default; ~AnalyticFunction() override = default; - AnalyticFunction(std::function &r)> f, const double *a = nullptr, const double *b = nullptr) - : RepresentableFunction(a, b) + AnalyticFunction(std::function &r)> f, const double *a = nullptr, const double *b = nullptr) + : RepresentableFunction(a, b) , func(f) {} - AnalyticFunction(std::function &r)> f, - const std::vector &a, - const std::vector &b) + AnalyticFunction(std::function &r)> f, const std::vector &a, const std::vector &b) : AnalyticFunction(f, a.data(), b.data()) {} - void set(std::function &r)> f) { this->func = f; } + void set(std::function &r)> f) { this->func = f; } - double evalf(const Coord &r) const override { - double val = 0.0; + T evalf(const Coord &r) const override { + T val = 0.0; if (not this->outOfBounds(r)) val = this->func(r); return val; } protected: - std::function &r)> func; + std::function &r)> func; }; } // namespace mrcpp diff --git a/src/functions/BoysFunction.cpp b/src/functions/BoysFunction.cpp index 71b705139..7b9f1ddb5 100644 --- a/src/functions/BoysFunction.cpp +++ b/src/functions/BoysFunction.cpp @@ -32,7 +32,7 @@ namespace mrcpp { BoysFunction::BoysFunction(int n, double p) 
- : RepresentableFunction<1>() + : RepresentableFunction<1, double>() , order(n) , prec(p) , MRA(BoundingBox<1>(), InterpolatingBasis(13)) {} @@ -50,8 +50,8 @@ double BoysFunction::evalf(const Coord<1> &r) const { return std::exp(-xt_2) * t_2n; }; - FunctionTree<1> tree(this->MRA); - mrcpp::project<1>(this->prec, tree, f); + FunctionTree<1, double> tree(this->MRA); + mrcpp::project<1, double>(this->prec, tree, f); double result = tree.integrate(); Printer::setPrintLevel(oldlevel); diff --git a/src/functions/BoysFunction.h b/src/functions/BoysFunction.h index 4dc76bd72..f8b8824d1 100644 --- a/src/functions/BoysFunction.h +++ b/src/functions/BoysFunction.h @@ -30,7 +30,7 @@ namespace mrcpp { -class BoysFunction final : public RepresentableFunction<1> { +class BoysFunction final : public RepresentableFunction<1, double> { public: BoysFunction(int n, double prec = 1.0e-10); diff --git a/src/functions/GaussExp.h b/src/functions/GaussExp.h index aa6ad4da3..a4315e381 100644 --- a/src/functions/GaussExp.h +++ b/src/functions/GaussExp.h @@ -51,7 +51,7 @@ namespace mrcpp { * */ -template class GaussExp : public RepresentableFunction { +template class GaussExp : public RepresentableFunction { public: GaussExp(int nTerms = 0, double prec = GAUSS_EXP_PREC); GaussExp(const GaussExp &gExp); diff --git a/src/functions/Gaussian.h b/src/functions/Gaussian.h index d02cc43b1..ddb039202 100644 --- a/src/functions/Gaussian.h +++ b/src/functions/Gaussian.h @@ -40,7 +40,7 @@ namespace mrcpp { -template class Gaussian : public RepresentableFunction { +template class Gaussian : public RepresentableFunction { public: Gaussian(double a, double c, const Coord &r, const std::array &p); Gaussian(const std::array &a, double c, const Coord &r, const std::array &p); diff --git a/src/functions/JpowerIntegrals.cpp b/src/functions/JpowerIntegrals.cpp index 0d0d43181..179f6fcc6 100644 --- a/src/functions/JpowerIntegrals.cpp +++ b/src/functions/JpowerIntegrals.cpp @@ -24,47 +24,37 @@ */ #include 
"JpowerIntegrals.h" -#include // std::find_if_not - +#include // std::find_if_not namespace mrcpp { - -JpowerIntegrals::JpowerIntegrals(double a, int scaling, int M, double threshold) -{ +JpowerIntegrals::JpowerIntegrals(double a, int scaling, int M, double threshold) { this->scaling = scaling; int N = 1 << scaling; - for(int l = 0; l < N; l++ ) - integrals.push_back( calculate_J_power_integrals(l, a, M, threshold) ); - for(int l = 1 - N; l < 0; l++ ) - integrals.push_back( calculate_J_power_integrals(l, a, M, threshold) ); + for (int l = 0; l < N; l++) integrals.push_back(calculate_J_power_integrals(l, a, M, threshold)); + for (int l = 1 - N; l < 0; l++) integrals.push_back(calculate_J_power_integrals(l, a, M, threshold)); } - /// @brief in progress /// @param index - interger lying in the interval \f$ [ -2^n + 1, \ldots, 2^n - 1 ] \f$. /// @return in progress -std::vector> & JpowerIntegrals::operator[](int index) -{ - if( index < 0 ) index += integrals.size(); +std::vector> &JpowerIntegrals::operator[](int index) { + if (index < 0) index += integrals.size(); return integrals[index]; } -std::vector> JpowerIntegrals::calculate_J_power_integrals(int l, double a, int M, double threshold) -{ +std::vector> JpowerIntegrals::calculate_J_power_integrals(int l, double a, int M, double threshold) { using namespace std::complex_literals; std::complex J_0 = 0.25 * std::exp(-0.25i * M_PI) / std::sqrt(M_PI * a) * std::exp(0.25i * static_cast(l * l) / a); std::complex beta(0, 0.5 / a); auto alpha = static_cast(l) * beta; - + std::vector> J = {0.0, J_0}; - for (int m = 0; m < M; m++) - { + for (int m = 0; m < M; m++) { std::complex term1 = J[J.size() - 1] * alpha; - std::complex term2 - = J[J.size() - 2] * beta * static_cast(m) / static_cast(m + 2); + std::complex term2 = J[J.size() - 2] * beta * static_cast(m) / static_cast(m + 2); std::complex last = (term1 + term2) / static_cast(m + 3); J.push_back(last); } @@ -73,14 +63,10 @@ std::vector> 
JpowerIntegrals::calculate_J_power_integrals(i return J; } - /// @details Removes negligible elements in \b J until it reaches a considerable value. -void JpowerIntegrals::crop(std::vector> & J, double threshold) -{ +void JpowerIntegrals::crop(std::vector> &J, double threshold) { // Lambda function to check if an element is negligible - auto isNegligible = [threshold](const std::complex& c) { - return std::abs(c.real()) < threshold && std::abs(c.imag()) < threshold; - }; + auto isNegligible = [threshold](const std::complex &c) { return std::abs(c.real()) < threshold && std::abs(c.imag()) < threshold; }; // Remove negligible elements from the end of the vector J.erase(std::find_if_not(J.rbegin(), J.rend(), isNegligible).base(), J.end()); } diff --git a/src/functions/Polynomial.cpp b/src/functions/Polynomial.cpp index 397b4e268..c54acc148 100644 --- a/src/functions/Polynomial.cpp +++ b/src/functions/Polynomial.cpp @@ -45,7 +45,7 @@ namespace mrcpp { /** Construct polynomial of order zero with given size and bounds. * Includes default constructor. 
*/ Polynomial::Polynomial(int k, const double *a, const double *b) - : RepresentableFunction<1>(a, b) { + : RepresentableFunction<1, double>(a, b) { assert(k >= 0); this->N = 1.0; this->L = 0.0; @@ -88,8 +88,8 @@ Polynomial &Polynomial::operator=(const Polynomial &poly) { /** Evaluate scaled and translated polynomial */ double Polynomial::evalf(double x) const { if (isBounded()) { - if (x < this->getScaledLowerBound() ) return 0.0; - if (x > this->getScaledUpperBound() ) return 0.0; + if (x < this->getScaledLowerBound()) return 0.0; + if (x > this->getScaledUpperBound()) return 0.0; } double xp = 1.0; double y = 0.0; @@ -146,12 +146,8 @@ Polynomial &Polynomial::operator*=(double c) { /** Calculate P = P*Q */ Polynomial &Polynomial::operator*=(const Polynomial &Q) { Polynomial &P = *this; - if (std::abs(P.getDilation() - Q.getDilation()) > MachineZero) { - MSG_ERROR("Polynomials not defined on same scale."); - } - if (std::abs(P.getTranslation() - Q.getTranslation()) > MachineZero) { - MSG_ERROR("Polynomials not defined on same translation."); - } + if (std::abs(P.getDilation() - Q.getDilation()) > MachineZero) { MSG_ERROR("Polynomials not defined on same scale."); } + if (std::abs(P.getTranslation() - Q.getTranslation()) > MachineZero) { MSG_ERROR("Polynomials not defined on same translation."); } int P_order = P.getOrder(); int Q_order = Q.getOrder(); @@ -197,12 +193,8 @@ Polynomial &Polynomial::operator-=(const Polynomial &Q) { /** Calculate P = P + c*Q. 
*/ void Polynomial::addInPlace(double c, const Polynomial &Q) { Polynomial &P = *this; - if (std::abs(P.getDilation() - Q.getDilation()) > MachineZero) { - MSG_ERROR("Polynomials not defined on same scale."); - } - if (std::abs(P.getTranslation() - Q.getTranslation()) > MachineZero) { - MSG_ERROR("Polynomials not defined on same translation."); - } + if (std::abs(P.getDilation() - Q.getDilation()) > MachineZero) { MSG_ERROR("Polynomials not defined on same scale."); } + if (std::abs(P.getTranslation() - Q.getTranslation()) > MachineZero) { MSG_ERROR("Polynomials not defined on same translation."); } int P_order = P.getOrder(); int Q_order = Q.getOrder(); diff --git a/src/functions/Polynomial.h b/src/functions/Polynomial.h index e1c23e4a6..93e3ec77d 100644 --- a/src/functions/Polynomial.h +++ b/src/functions/Polynomial.h @@ -44,7 +44,7 @@ namespace mrcpp { -class Polynomial : public RepresentableFunction<1> { +class Polynomial : public RepresentableFunction<1, double> { public: Polynomial(int k = 0, const double *a = nullptr, const double *b = nullptr); Polynomial(int k, const std::vector &a, const std::vector &b) @@ -74,7 +74,7 @@ class Polynomial : public RepresentableFunction<1> { void setDilation(double n) { this->N = n; } void setTranslation(double l) { this->L = l; } void dilate(double n) { this->N *= n; } - void translate(double l) { this->L += this->N*l; } + void translate(double l) { this->L += this->N * l; } int size() const { return this->coefs.size(); } ///< Length of coefs vector int getOrder() const; diff --git a/src/functions/RepresentableFunction.cpp b/src/functions/RepresentableFunction.cpp index 8687297c7..3c55ac92b 100644 --- a/src/functions/RepresentableFunction.cpp +++ b/src/functions/RepresentableFunction.cpp @@ -38,7 +38,7 @@ namespace mrcpp { -template RepresentableFunction::RepresentableFunction(const double *a, const double *b) { +template RepresentableFunction::RepresentableFunction(const double *a, const double *b) { if (a == nullptr or b 
== nullptr) { this->bounded = false; this->A = nullptr; @@ -56,7 +56,7 @@ template RepresentableFunction::RepresentableFunction(const double *a } /** Constructs a new function with same bounds as the input function */ -template RepresentableFunction::RepresentableFunction(const RepresentableFunction &func) { +template RepresentableFunction::RepresentableFunction(const RepresentableFunction &func) { if (func.isBounded()) { this->bounded = true; this->A = new double[D]; @@ -74,11 +74,11 @@ template RepresentableFunction::RepresentableFunction(const Represent /** Copies function, not bounds. Use copy constructor if you want an * identical function. */ -template RepresentableFunction &RepresentableFunction::operator=(const RepresentableFunction &func) { +template RepresentableFunction &RepresentableFunction::operator=(const RepresentableFunction &func) { return *this; } -template RepresentableFunction::~RepresentableFunction() { +template RepresentableFunction::~RepresentableFunction() { if (this->isBounded()) { delete[] this->A; delete[] this->B; @@ -87,7 +87,7 @@ template RepresentableFunction::~RepresentableFunction() { this->B = nullptr; } -template void RepresentableFunction::setBounds(const double *a, const double *b) { +template void RepresentableFunction::setBounds(const double *a, const double *b) { if (a == nullptr or b == nullptr) { MSG_ERROR("Invalid arguments"); } if (not isBounded()) { this->bounded = true; @@ -101,7 +101,7 @@ template void RepresentableFunction::setBounds(const double *a, const } } -template bool RepresentableFunction::outOfBounds(const Coord &r) const { +template bool RepresentableFunction::outOfBounds(const Coord &r) const { if (not isBounded()) { return false; } for (int d = 0; d < D; d++) { if (r[d] < getLowerBound(d)) return true; @@ -110,8 +110,11 @@ template bool RepresentableFunction::outOfBounds(const Coord &r) c return false; } -template class RepresentableFunction<1>; -template class RepresentableFunction<2>; -template class 
RepresentableFunction<3>; +template class RepresentableFunction<1, double>; +template class RepresentableFunction<2, double>; +template class RepresentableFunction<3, double>; +template class RepresentableFunction<1, ComplexDouble>; +template class RepresentableFunction<2, ComplexDouble>; +template class RepresentableFunction<3, ComplexDouble>; } // namespace mrcpp diff --git a/src/functions/RepresentableFunction.h b/src/functions/RepresentableFunction.h index 2d6998812..6123e3051 100644 --- a/src/functions/RepresentableFunction.h +++ b/src/functions/RepresentableFunction.h @@ -37,21 +37,22 @@ #include "MRCPP/constants.h" #include "MRCPP/mrcpp_declarations.h" +#include "MRCPP/utils/math_utils.h" #include "trees/NodeIndex.h" namespace mrcpp { -template class RepresentableFunction { +template class RepresentableFunction { public: RepresentableFunction(const double *a = nullptr, const double *b = nullptr); RepresentableFunction(const std::vector &a, const std::vector &b) : RepresentableFunction(a.data(), b.data()) {} - RepresentableFunction(const RepresentableFunction &func); - RepresentableFunction &operator=(const RepresentableFunction &func); + RepresentableFunction(const RepresentableFunction &func); + RepresentableFunction &operator=(const RepresentableFunction &func); virtual ~RepresentableFunction(); /** @returns Function value in a point @param[in] r: Cartesian coordinate */ - virtual double evalf(const Coord &r) const = 0; + virtual T evalf(const Coord &r) const = 0; void setBounds(const double *a, const double *b); void clearBounds(); @@ -65,7 +66,7 @@ template class RepresentableFunction { const double *getLowerBounds() const { return this->A; } const double *getUpperBounds() const { return this->B; } - friend class AnalyticAdaptor; + friend class AnalyticAdaptor; protected: bool bounded; diff --git a/src/functions/function_utils.cpp b/src/functions/function_utils.cpp index 60b287e1c..598c9b12a 100644 --- a/src/functions/function_utils.cpp +++ 
b/src/functions/function_utils.cpp @@ -33,9 +33,7 @@ double ObaraSaika_ab(int power_a, int power_b, double pos_a, double pos_b, doubl template double function_utils::calc_overlap(const GaussFunc &a, const GaussFunc &b) { double S = 1.0; - for (int d = 0; d < D; d++) { - S *= ObaraSaika_ab(a.getPower()[d], b.getPower()[d], a.getPos()[d], b.getPos()[d], a.getExp()[d], b.getExp()[d]); - } + for (int d = 0; d < D; d++) { S *= ObaraSaika_ab(a.getPower()[d], b.getPower()[d], a.getPos()[d], b.getPos()[d], a.getExp()[d], b.getExp()[d]); } S *= a.getCoef() * b.getCoef(); return S; } @@ -117,4 +115,5 @@ double function_utils::ObaraSaika_ab(int power_a, int power_b, double pos_a, dou template double function_utils::calc_overlap<1>(const GaussFunc<1> &a, const GaussFunc<1> &b); template double function_utils::calc_overlap<2>(const GaussFunc<2> &a, const GaussFunc<2> &b); template double function_utils::calc_overlap<3>(const GaussFunc<3> &a, const GaussFunc<3> &b); + } // namespace mrcpp diff --git a/src/operators/ABGVOperator.cpp b/src/operators/ABGVOperator.cpp index ca7d34580..05525405e 100644 --- a/src/operators/ABGVOperator.cpp +++ b/src/operators/ABGVOperator.cpp @@ -49,8 +49,7 @@ ABGVOperator::ABGVOperator(const MultiResolutionAnalysis &mra, double a, d initialize(a, b); } -template -void ABGVOperator::initialize(double a, double b) { +template void ABGVOperator::initialize(double a, double b) { int bw = 0; // Operator bandwidth if (std::abs(a) > MachineZero) bw = 1; if (std::abs(b) > MachineZero) bw = 1; diff --git a/src/operators/CartesianConvolution.cpp b/src/operators/CartesianConvolution.cpp index 432d3c02b..64ac5491d 100644 --- a/src/operators/CartesianConvolution.cpp +++ b/src/operators/CartesianConvolution.cpp @@ -28,8 +28,8 @@ #include "core/InterpolatingBasis.h" #include "core/LegendreBasis.h" -#include "functions/Gaussian.h" #include "functions/GaussExp.h" +#include "functions/Gaussian.h" #include "treebuilders/CrossCorrelationCalculator.h" #include 
"treebuilders/OperatorAdaptor.h" @@ -68,9 +68,9 @@ CartesianConvolution::CartesianConvolution(const MultiResolutionAnalysis<3> &mra } void CartesianConvolution::setCartesianComponents(int x, int y, int z) { - int x_shift = x*this->sep_rank; - int y_shift = y*this->sep_rank; - int z_shift = z*this->sep_rank; + int x_shift = x * this->sep_rank; + int y_shift = y * this->sep_rank; + int z_shift = z * this->sep_rank; for (int i = 0; i < this->sep_rank; i++) this->assign(i, 0, this->raw_exp[x_shift + i].get()); for (int i = 0; i < this->sep_rank; i++) this->assign(i, 1, this->raw_exp[y_shift + i].get()); diff --git a/src/operators/ConvolutionOperator.cpp b/src/operators/ConvolutionOperator.cpp index 26bf44e72..9d37929aa 100644 --- a/src/operators/ConvolutionOperator.cpp +++ b/src/operators/ConvolutionOperator.cpp @@ -28,8 +28,8 @@ #include "core/InterpolatingBasis.h" #include "core/LegendreBasis.h" -#include "functions/Gaussian.h" #include "functions/GaussExp.h" +#include "functions/Gaussian.h" #include "treebuilders/CrossCorrelationCalculator.h" #include "treebuilders/OperatorAdaptor.h" @@ -75,8 +75,7 @@ ConvolutionOperator::ConvolutionOperator(const MultiResolutionAnalysis &mr Printer::setPrintLevel(oldlevel); } -template -void ConvolutionOperator::initialize(GaussExp<1> &kernel, double k_prec, double o_prec) { +template void ConvolutionOperator::initialize(GaussExp<1> &kernel, double k_prec, double o_prec) { auto k_mra = this->getKernelMRA(); auto o_mra = this->getOperatorMRA(); @@ -86,10 +85,10 @@ void ConvolutionOperator::initialize(GaussExp<1> &kernel, double k_prec, doub for (int i = 0; i < kernel.size(); i++) { // Rescale Gaussian for D-dim application auto *k_func = kernel.getFunc(i).copy(); - k_func->setCoef( std::copysign( std::pow(std::abs(k_func->getCoef()), 1.0/D), k_func->getCoef() ) ); + k_func->setCoef(std::copysign(std::pow(std::abs(k_func->getCoef()), 1.0 / D), k_func->getCoef())); FunctionTree<1> k_tree(k_mra); - mrcpp::build_grid(k_tree, *k_func); 
// Generate empty grid to hold narrow Gaussian + mrcpp::build_grid(k_tree, *k_func); // Generate empty grid to hold narrow Gaussian mrcpp::project(k_prec, k_tree, *k_func); // Project Gaussian starting from the empty grid delete k_func; @@ -108,8 +107,7 @@ void ConvolutionOperator::initialize(GaussExp<1> &kernel, double k_prec, doub } } -template -MultiResolutionAnalysis<1> ConvolutionOperator::getKernelMRA() const { +template MultiResolutionAnalysis<1> ConvolutionOperator::getKernelMRA() const { const BoundingBox &box = this->MRA.getWorldBox(); const ScalingBasis &basis = this->MRA.getScalingBasis(); diff --git a/src/operators/ConvolutionOperator.h b/src/operators/ConvolutionOperator.h index c9879d2a2..33d254e9d 100644 --- a/src/operators/ConvolutionOperator.h +++ b/src/operators/ConvolutionOperator.h @@ -32,7 +32,7 @@ namespace mrcpp { /** @class ConvolutionOperator * * @brief Convolution defined by a Gaussian expansion - * + * * @details Represents the operator * \f[ * T = \sum_{m=1}^M @@ -51,13 +51,13 @@ namespace mrcpp { * \sum_{m=1}^M \alpha_m \exp \left( - \beta_m |x|^2 \right) * \f] * which is passed as a parameter to the first two constructors. - * + * * @note Every \f$ T_d \left( \beta_m, \sqrt[D]{| \alpha_m |} \right) \f$ is the same * operator associated with the one-dimensional variable \f$ x_d \f$ for \f$ d = 1, \ldots, D \f$. - * + * * \todo: One may want to change the logic so that \f$ D \f$-root is evaluated on the previous step, * namely, when \f$ \alpha_m, \beta_m \f$ are calculated. 
- * + * */ template class ConvolutionOperator : public MWOperator { public: @@ -71,9 +71,9 @@ template class ConvolutionOperator : public MWOperator { protected: ConvolutionOperator(const MultiResolutionAnalysis &mra) - : MWOperator(mra, mra.getRootScale(), -10) {} + : MWOperator(mra, mra.getRootScale(), -10) {} ConvolutionOperator(const MultiResolutionAnalysis &mra, int root, int reach) - : MWOperator(mra, root, reach) {} + : MWOperator(mra, root, reach) {} void initialize(GaussExp<1> &kernel, double k_prec, double o_prec); void setBuildPrec(double prec) { this->build_prec = prec; } diff --git a/src/operators/HeatKernel.h b/src/operators/HeatKernel.h index b0303eee8..bc5a8adba 100644 --- a/src/operators/HeatKernel.h +++ b/src/operators/HeatKernel.h @@ -49,7 +49,7 @@ namespace mrcpp { * t > 0 * . * \f] - * + * */ template class HeatKernel final : public GaussExp<1> { public: diff --git a/src/operators/HeatOperator.cpp b/src/operators/HeatOperator.cpp index 4cd980c54..cad3d9139 100644 --- a/src/operators/HeatOperator.cpp +++ b/src/operators/HeatOperator.cpp @@ -36,7 +36,7 @@ namespace mrcpp { * @param[in] prec: Build precision * @details This will project a kernel of a single gaussian with * exponent \f$ 1/(4t) \f$. - * + * */ template HeatOperator::HeatOperator(const MultiResolutionAnalysis &mra, double t, double prec) @@ -64,11 +64,11 @@ HeatOperator::HeatOperator(const MultiResolutionAnalysis &mra, double t, d * @details This will project a kernel of a single gaussian with * exponent \f$ 1/(4t) \f$. * This version of the constructor - * is used for calculations within periodic boundary conditions (PBC). + * is used for calculations within periodic boundary conditions (PBC). * The \a root parameter is the coarsest negative scale at wich the operator * is applied. The \a reach parameter is the bandwidth of the operator at * the root scale. 
For details see \ref MWOperator - * + * */ template HeatOperator::HeatOperator(const MultiResolutionAnalysis &mra, double t, double prec, int root, int reach) diff --git a/src/operators/HeatOperator.h b/src/operators/HeatOperator.h index aabc60658..f96560a81 100644 --- a/src/operators/HeatOperator.h +++ b/src/operators/HeatOperator.h @@ -54,7 +54,7 @@ namespace mrcpp { * t > 0 * . * \f] - * + * */ template class HeatOperator final : public ConvolutionOperator { public: diff --git a/src/operators/IdentityConvolution.cpp b/src/operators/IdentityConvolution.cpp index 5b8bde3af..038d076cc 100644 --- a/src/operators/IdentityConvolution.cpp +++ b/src/operators/IdentityConvolution.cpp @@ -60,7 +60,7 @@ IdentityConvolution::IdentityConvolution(const MultiResolutionAnalysis &mr * @param[in] reach: width at root scale (applies to periodic boundary conditions) * @details This will project a kernel of a single gaussian with * exponent sqrt(10/build_prec). This version of the constructor - * is used for calculations within periodic boundary conditions (PBC). + * is used for calculations within periodic boundary conditions (PBC). * The \a root parameter is the coarsest negative scale at wich the operator * is applied. The \a reach parameter is the bandwidth of the operator at * the root scale. 
For details see \ref MWOperator diff --git a/src/operators/MWOperator.cpp b/src/operators/MWOperator.cpp index 428e5fc1f..225108f48 100644 --- a/src/operators/MWOperator.cpp +++ b/src/operators/MWOperator.cpp @@ -32,8 +32,7 @@ using namespace Eigen; namespace mrcpp { -template -void MWOperator::initOperExp(int M) { +template void MWOperator::initOperExp(int M) { if (this->raw_exp.size() < M) MSG_ABORT("Incompatible raw expansion"); this->oper_exp.clear(); for (int m = 0; m < M; m++) { @@ -47,24 +46,21 @@ void MWOperator::initOperExp(int M) { for (int d = 0; d < D; d++) assign(i, d, this->raw_exp[i].get()); } -template -OperatorTree &MWOperator::getComponent(int i, int d) { +template OperatorTree &MWOperator::getComponent(int i, int d) { if (i < 0 or i >= this->oper_exp.size()) MSG_ERROR("Index out of bounds"); if (d < 0 or d >= D) MSG_ERROR("Dimension out of bounds"); if (this->oper_exp[i][d] == nullptr) MSG_ERROR("Invalid component"); return *this->oper_exp[i][d]; } -template -const OperatorTree &MWOperator::getComponent(int i, int d) const { +template const OperatorTree &MWOperator::getComponent(int i, int d) const { if (i < 0 or i >= this->oper_exp.size()) MSG_ERROR("Index out of bounds"); if (d < 0 or d >= D) MSG_ERROR("Dimension out of bounds"); if (this->oper_exp[i][d] == nullptr) MSG_ERROR("Invalid component"); return *this->oper_exp[i][d]; } -template -int MWOperator::getMaxBandWidth(int depth) const { +template int MWOperator::getMaxBandWidth(int depth) const { int maxWidth = -1; if (depth < 0) { maxWidth = *std::max_element(this->band_max.begin(), this->band_max.end()); @@ -74,14 +70,12 @@ int MWOperator::getMaxBandWidth(int depth) const { return maxWidth; } -template -void MWOperator::clearBandWidths() { +template void MWOperator::clearBandWidths() { for (auto &i : this->oper_exp) for (int d = 0; d < D; d++) i[d]->clearBandWidth(); } -template -void MWOperator::calcBandWidths(double prec) { +template void MWOperator::calcBandWidths(double prec) { int 
maxDepth = 0; // First compute BandWidths and find depth of the deepest component for (auto &i : this->oper_exp) { @@ -113,8 +107,7 @@ void MWOperator::calcBandWidths(double prec) { println(20, std::endl); } -template -MultiResolutionAnalysis<2> MWOperator::getOperatorMRA() const { +template MultiResolutionAnalysis<2> MWOperator::getOperatorMRA() const { const BoundingBox &box = this->MRA.getWorldBox(); const ScalingBasis &basis = this->MRA.getScalingBasis(); diff --git a/src/operators/MWOperator.h b/src/operators/MWOperator.h index 4e3962fdc..2dcad2b32 100644 --- a/src/operators/MWOperator.h +++ b/src/operators/MWOperator.h @@ -39,8 +39,7 @@ namespace mrcpp { * @details Fixme * */ -template -class MWOperator { +template class MWOperator { public: MWOperator(const MultiResolutionAnalysis &mra, int root, int reach) : oper_root(root) @@ -63,8 +62,8 @@ class MWOperator { OperatorTree &getComponent(int i, int d); const OperatorTree &getComponent(int i, int d) const; - std::array &operator[](int i) { return this->oper_exp[i]; } - const std::array &operator[](int i) const { return this->oper_exp[i]; } + std::array &operator[](int i) { return this->oper_exp[i]; } + const std::array &operator[](int i) const { return this->oper_exp[i]; } protected: int oper_root; @@ -78,7 +77,6 @@ class MWOperator { void initOperExp(int M); void assign(int i, int d, OperatorTree *oper) { this->oper_exp[i][d] = oper; } - }; } // namespace mrcpp diff --git a/src/operators/OperatorState.h b/src/operators/OperatorState.h index 245d9f70f..677375632 100644 --- a/src/operators/OperatorState.h +++ b/src/operators/OperatorState.h @@ -42,9 +42,9 @@ namespace mrcpp { #define GET_OP_IDX(FT, GT, ID) (2 * ((GT >> ID) & 1) + ((FT >> ID) & 1)) -template class OperatorState final { +template class OperatorState final { public: - OperatorState(MWNode &gn, double *scr1) + OperatorState(MWNode &gn, T *scr1) : gNode(&gn) { this->kp1 = this->gNode->getKp1(); this->kp1_d = this->gNode->getKp1_d(); @@ -53,7 +53,7 
@@ template class OperatorState final { this->gData = this->gNode->getCoefs(); this->maxDeltaL = -1; - double *scr2 = scr1 + this->kp1_d; + T *scr2 = scr1 + this->kp1_d; for (int i = 1; i < D; i++) { if (IS_ODD(i)) { @@ -64,9 +64,9 @@ template class OperatorState final { } } - OperatorState(MWNode &gn, std::vector scr1) + OperatorState(MWNode &gn, std::vector scr1) : OperatorState(gn, scr1.data()) {} - void setFNode(MWNode &fn) { + void setFNode(MWNode &fn) { this->fNode = &fn; this->fData = this->fNode->getCoefs(); } @@ -86,15 +86,16 @@ template class OperatorState final { int getMaxDeltaL() const { return this->maxDeltaL; } int getOperIndex(int i) const { return GET_OP_IDX(this->ft, this->gt, i); } - double **getAuxData() { return this->aux; } + T **getAuxData() { return this->aux; } double **getOperData() { return this->oData; } - friend class ConvolutionCalculator; - friend class DerivativeCalculator; + friend class ConvolutionCalculator; + friend class DerivativeCalculator; private: int ft; int gt; + int maxDeltaL; double fThreshold; double gThreshold; @@ -104,13 +105,13 @@ template class OperatorState final { int kp1_d; int kp1_dm1; - MWNode *gNode; - MWNode *fNode; + MWNode *gNode; + MWNode *fNode; NodeIndex *fIdx; - double *aux[D + 1]; - double *gData; - double *fData; + T *aux[D + 1]; + T *gData; + T *fData; double *oData[D]; void calcMaxDeltaL() { diff --git a/src/operators/OperatorStatistics.cpp b/src/operators/OperatorStatistics.cpp index d542e88f5..f58ae2b0d 100644 --- a/src/operators/OperatorStatistics.cpp +++ b/src/operators/OperatorStatistics.cpp @@ -30,8 +30,7 @@ using namespace Eigen; namespace mrcpp { -template -OperatorStatistics::OperatorStatistics() +OperatorStatistics::OperatorStatistics() : nThreads(mrcpp_get_max_threads()) , totFCount(0) , totGCount(0) @@ -58,7 +57,7 @@ OperatorStatistics::OperatorStatistics() } } -template OperatorStatistics::~OperatorStatistics() { +OperatorStatistics::~OperatorStatistics() { for (int i = 0; i < 
this->nThreads; i++) { delete this->compCount[i]; } delete[] this->compCount; delete[] this->fCount; @@ -68,7 +67,7 @@ template OperatorStatistics::~OperatorStatistics() { } /** Sum all node counters from all threads. */ -template void OperatorStatistics::flushNodeCounters() { +void OperatorStatistics::flushNodeCounters() { for (int i = 0; i < this->nThreads; i++) { this->totFCount += this->fCount[i]; this->totGCount += this->gCount[i]; @@ -82,20 +81,20 @@ template void OperatorStatistics::flushNodeCounters() { } /** Increment g-node usage counter. Needed for load balancing. */ -template void OperatorStatistics::incrementGNodeCounters(const MWNode &gNode) { +template void OperatorStatistics::incrementGNodeCounters(const MWNode &gNode) { int thread = mrcpp_get_thread_num(); this->gCount[thread]++; } /** Increment operator application counter. */ -template void OperatorStatistics::incrementFNodeCounters(const MWNode &fNode, int ft, int gt) { +template void OperatorStatistics::incrementFNodeCounters(const MWNode &fNode, int ft, int gt) { int thread = mrcpp_get_thread_num(); this->fCount[thread]++; (*this->compCount[thread])(ft, gt) += 1; if (fNode.isGenNode()) { this->genCount[thread]++; } } -template std::ostream &OperatorStatistics::print(std::ostream &o) const { +std::ostream &OperatorStatistics::print(std::ostream &o) const { o << std::setw(8); o << "*OperatorFunc statistics: " << std::endl << std::endl; o << " Total calculated gNodes : " << this->totGCount << std::endl; @@ -105,8 +104,17 @@ template std::ostream &OperatorStatistics::print(std::ostream &o) con return o; } -template class OperatorStatistics<1>; -template class OperatorStatistics<2>; -template class OperatorStatistics<3>; +template void OperatorStatistics::incrementFNodeCounters<1, double>(const MWNode<1, double> &fNode, int ft, int gt); +template void OperatorStatistics::incrementFNodeCounters<2, double>(const MWNode<2, double> &fNode, int ft, int gt); +template void 
OperatorStatistics::incrementFNodeCounters<3, double>(const MWNode<3, double> &fNode, int ft, int gt); +template void OperatorStatistics::incrementFNodeCounters<1, ComplexDouble>(const MWNode<1, ComplexDouble> &fNode, int ft, int gt); +template void OperatorStatistics::incrementFNodeCounters<2, ComplexDouble>(const MWNode<2, ComplexDouble> &fNode, int ft, int gt); +template void OperatorStatistics::incrementFNodeCounters<3, ComplexDouble>(const MWNode<3, ComplexDouble> &fNode, int ft, int gt); +template void OperatorStatistics::incrementGNodeCounters<1, double>(const MWNode<1, double> &gNode); +template void OperatorStatistics::incrementGNodeCounters<2, double>(const MWNode<2, double> &gNode); +template void OperatorStatistics::incrementGNodeCounters<3, double>(const MWNode<3, double> &gNode); +template void OperatorStatistics::incrementGNodeCounters<1, ComplexDouble>(const MWNode<1, ComplexDouble> &gNode); +template void OperatorStatistics::incrementGNodeCounters<2, ComplexDouble>(const MWNode<2, ComplexDouble> &gNode); +template void OperatorStatistics::incrementGNodeCounters<3, ComplexDouble>(const MWNode<3, ComplexDouble> &gNode); } // namespace mrcpp diff --git a/src/operators/OperatorStatistics.h b/src/operators/OperatorStatistics.h index 395a5d62a..9a51728c0 100644 --- a/src/operators/OperatorStatistics.h +++ b/src/operators/OperatorStatistics.h @@ -32,14 +32,14 @@ namespace mrcpp { -template class OperatorStatistics final { +class OperatorStatistics final { public: OperatorStatistics(); ~OperatorStatistics(); void flushNodeCounters(); - void incrementFNodeCounters(const MWNode &fNode, int ft, int gt); - void incrementGNodeCounters(const MWNode &gNode); + template void incrementFNodeCounters(const MWNode &fNode, int ft, int gt); + template void incrementGNodeCounters(const MWNode &gNode); friend std::ostream &operator<<(std::ostream &o, const OperatorStatistics &os) { return os.print(o); } diff --git a/src/operators/TimeEvolutionOperator.cpp 
b/src/operators/TimeEvolutionOperator.cpp index 90f1f1ccd..09913a591 100644 --- a/src/operators/TimeEvolutionOperator.cpp +++ b/src/operators/TimeEvolutionOperator.cpp @@ -26,12 +26,11 @@ #include "TimeEvolutionOperator.h" //#include "MRCPP/MWOperators" - #include "core/InterpolatingBasis.h" #include "core/LegendreBasis.h" -#include "functions/Gaussian.h" #include "functions/GaussExp.h" +#include "functions/Gaussian.h" #include "treebuilders/CrossCorrelationCalculator.h" #include "treebuilders/DefaultCalculator.h" @@ -42,8 +41,8 @@ #include "treebuilders/project.h" #include "trees/BandWidth.h" -#include "trees/FunctionTreeVector.h" #include "trees/CornerOperatorTree.h" +#include "trees/FunctionTreeVector.h" #include "utils/Printer.h" #include "utils/Timer.h" @@ -55,10 +54,8 @@ #include "trees/OperatorNode.h" - namespace mrcpp { - /** @brief A uniform constructor for TimeEvolutionOperator class. * * @param[in] mra: MRA. @@ -72,23 +69,21 @@ namespace mrcpp { * */ template -TimeEvolutionOperator::TimeEvolutionOperator -(const MultiResolutionAnalysis &mra, double prec, double time, int finest_scale, bool imaginary, int max_Jpower) - : ConvolutionOperator(mra, mra.getRootScale(), -10) //One can use ConvolutionOperator instead as well +TimeEvolutionOperator::TimeEvolutionOperator(const MultiResolutionAnalysis &mra, double prec, double time, int finest_scale, bool imaginary, int max_Jpower) + : ConvolutionOperator(mra, mra.getRootScale(), -10) // One can use ConvolutionOperator instead as well { int oldlevel = Printer::setPrintLevel(0); this->setBuildPrec(prec); - SchrodingerEvolution_CrossCorrelation cross_correlation(30, mra.getOrder(), mra.getScalingBasis().getScalingType() ); + SchrodingerEvolution_CrossCorrelation cross_correlation(30, mra.getOrder(), mra.getScalingBasis().getScalingType()); this->cross_correlation = &cross_correlation; - initialize(time, finest_scale, imaginary, max_Jpower); //will go outside of the constructor in future + initialize(time, 
finest_scale, imaginary, max_Jpower); // will go outside of the constructor in future - this->initOperExp(1); //this turns out to be important + this->initOperExp(1); // this turns out to be important Printer::setPrintLevel(oldlevel); } - /** @brief An adaptive constructor for TimeEvolutionOperator class. * * @param[in] mra: MRA. @@ -105,24 +100,21 @@ TimeEvolutionOperator::TimeEvolutionOperator * */ template -TimeEvolutionOperator::TimeEvolutionOperator -(const MultiResolutionAnalysis &mra, double prec, double time, bool imaginary, int max_Jpower) - : ConvolutionOperator(mra, mra.getRootScale(), -10) //One can use ConvolutionOperator instead as well +TimeEvolutionOperator::TimeEvolutionOperator(const MultiResolutionAnalysis &mra, double prec, double time, bool imaginary, int max_Jpower) + : ConvolutionOperator(mra, mra.getRootScale(), -10) // One can use ConvolutionOperator instead as well { int oldlevel = Printer::setPrintLevel(0); this->setBuildPrec(prec); - SchrodingerEvolution_CrossCorrelation cross_correlation(30, mra.getOrder(), mra.getScalingBasis().getScalingType() ); + SchrodingerEvolution_CrossCorrelation cross_correlation(30, mra.getOrder(), mra.getScalingBasis().getScalingType()); this->cross_correlation = &cross_correlation; - initialize(time, imaginary, max_Jpower); //will go outside of the constructor in future + initialize(time, imaginary, max_Jpower); // will go outside of the constructor in future - this->initOperExp(1); //this turns out to be important + this->initOperExp(1); // this turns out to be important Printer::setPrintLevel(oldlevel); } - - /** @brief Creates Re or Im of operator * * @details Adaptive down to scale \f$ N = 18 \f$. @@ -132,9 +124,7 @@ TimeEvolutionOperator::TimeEvolutionOperator * only needed ones, while building the tree (in progress). 
* */ -template -void TimeEvolutionOperator::initialize(double time, bool imaginary, int max_Jpower) -{ +template void TimeEvolutionOperator::initialize(double time, bool imaginary, int max_Jpower) { int N = 18; double o_prec = this->build_prec; @@ -142,8 +132,7 @@ void TimeEvolutionOperator::initialize(double time, bool imaginary, int max_J auto o_tree = std::make_unique(o_mra, o_prec); std::map J; - for( int n = 0; n <= N+1; n ++ ) - J[n] = new JpowerIntegrals(time * std::pow(4, n), n, max_Jpower); + for (int n = 0; n <= N + 1; n++) J[n] = new JpowerIntegrals(time * std::pow(4, n), n, max_Jpower); TimeEvolution_CrossCorrelationCalculator calculator(J, this->cross_correlation, imaginary); OperatorAdaptor adaptor(o_prec, o_mra.getMaxScale(), true); @@ -155,7 +144,7 @@ void TimeEvolutionOperator::initialize(double time, bool imaginary, int max_J Timer trans_t; o_tree->mwTransform(BottomUp); o_tree->removeRoughScaleNoise(); - //o_tree->clearSquareNorm(); //does not affect printing + // o_tree->clearSquareNorm(); //does not affect printing o_tree->calcSquareNorm(); o_tree->setupOperNodeCache(); @@ -164,8 +153,7 @@ void TimeEvolutionOperator::initialize(double time, bool imaginary, int max_J this->raw_exp.push_back(std::move(o_tree)); - for( int n = 0; n <= N+1; n ++ ) - delete J[n]; + for (int n = 0; n <= N + 1; n++) delete J[n]; } /** @brief Creates Re or Im of operator @@ -173,9 +161,7 @@ void TimeEvolutionOperator::initialize(double time, bool imaginary, int max_J * @details Uniform down to finest scale. 
* */ -template -void TimeEvolutionOperator::initialize(double time, int finest_scale, bool imaginary, int max_Jpower) -{ +template void TimeEvolutionOperator::initialize(double time, int finest_scale, bool imaginary, int max_Jpower) { double o_prec = this->build_prec; auto o_mra = this->getOperatorMRA(); @@ -186,12 +172,11 @@ void TimeEvolutionOperator::initialize(double time, int finest_scale, bool im int N = finest_scale; double threshold = o_prec / 1000.0; std::map J; - for( int n = 0; n <= N+1; n ++ ) - J[n] = new JpowerIntegrals(time * std::pow(4, n), n, max_Jpower, threshold); + for (int n = 0; n <= N + 1; n++) J[n] = new JpowerIntegrals(time * std::pow(4, n), n, max_Jpower, threshold); TimeEvolution_CrossCorrelationCalculator calculator(J, this->cross_correlation, imaginary); auto o_tree = std::make_unique(o_mra, o_prec); - builder.build(*o_tree, calculator, uniform, N ); // Expand 1D kernel into 2D operator + builder.build(*o_tree, calculator, uniform, N); // Expand 1D kernel into 2D operator // Postprocess to make the operator functional Timer trans_t; @@ -203,11 +188,9 @@ void TimeEvolutionOperator::initialize(double time, int finest_scale, bool im this->raw_exp.push_back(std::move(o_tree)); - for( int n = 0; n <= N+1; n ++ ) - delete J[n]; + for (int n = 0; n <= N + 1; n++) delete J[n]; } - /** @brief Creates Re or Im of operator (in progress) * * @details Tree construction starts uniformly and then continues adaptively down to scale \f$ N = 18 \f$. @@ -216,8 +199,7 @@ void TimeEvolutionOperator::initialize(double time, int finest_scale, bool im * @note This method is not ready for use and should not be used (in progress). 
* */ -template void TimeEvolutionOperator::initializeSemiUniformly(double time, bool imaginary, int max_Jpower) -{ +template void TimeEvolutionOperator::initializeSemiUniformly(double time, bool imaginary, int max_Jpower) { MSG_ERROR("Not implemented yet method."); double o_prec = this->build_prec; @@ -234,8 +216,7 @@ template void TimeEvolutionOperator::initializeSemiUniformly(double t double threshold = o_prec / 1000.0; std::map J; - for( int n = 0; n <= N+1; n ++ ) - J[n] = new mrcpp::JpowerIntegrals(time * std::pow(4, n), n, max_Jpower, threshold); + for (int n = 0; n <= N + 1; n++) J[n] = new mrcpp::JpowerIntegrals(time * std::pow(4, n), n, max_Jpower, threshold); mrcpp::TimeEvolution_CrossCorrelationCalculator calculator(J, this->cross_correlation, imaginary); OperatorAdaptor adaptor(o_prec, o_mra.getMaxScale()); @@ -252,11 +233,9 @@ template void TimeEvolutionOperator::initializeSemiUniformly(double t this->raw_exp.push_back(std::move(o_tree)); - for( int n = 0; n <= N+1; n ++ ) - delete J[n]; + for (int n = 0; n <= N + 1; n++) delete J[n]; } - template class TimeEvolutionOperator<1>; template class TimeEvolutionOperator<2>; template class TimeEvolutionOperator<3>; diff --git a/src/operators/TimeEvolutionOperator.h b/src/operators/TimeEvolutionOperator.h index 9e8623fcd..839ba7b40 100644 --- a/src/operators/TimeEvolutionOperator.h +++ b/src/operators/TimeEvolutionOperator.h @@ -25,30 +25,30 @@ #pragma once -#include "MWOperator.h" #include "ConvolutionOperator.h" +#include "MWOperator.h" #include "core/SchrodingerEvolution_CrossCorrelation.h" namespace mrcpp { - /** @class TimeEvolutionOperator * * @brief Semigroup of the free-particle Schrodinger equation - * + * * @details Represents the semigroup * \f$ * \exp \left( i t \partial_x^2 \right) * . * \f$ * Matrix elements (actual operator tree) of the operator can be obtained by calling getComponent(0, 0). - * + * * @note So far implementation is done for Legendre scaling functions in 1d. 
- * + * * \todo: Extend to D dimensinal on a general interval [a, b] in the future. - * + * */ -template class TimeEvolutionOperator : public ConvolutionOperator //One can use ConvolutionOperator instead as well +template +class TimeEvolutionOperator : public ConvolutionOperator // One can use ConvolutionOperator instead as well { public: TimeEvolutionOperator(const MultiResolutionAnalysis &mra, double prec, double time, int finest_scale, bool imaginary, int max_Jpower = 30); @@ -63,12 +63,11 @@ template class TimeEvolutionOperator : public ConvolutionOperator / void initialize(double time, int finest_scale, bool imaginary, int max_Jpower); void initialize(double time, bool imaginary, int max_Jpower); void initializeSemiUniformly(double time, bool imaginary, int max_Jpower); - + void setBuildPrec(double prec) { this->build_prec = prec; } double build_prec{-1.0}; SchrodingerEvolution_CrossCorrelation *cross_correlation{nullptr}; }; - } // namespace mrcpp diff --git a/src/treebuilders/AdditionCalculator.h b/src/treebuilders/AdditionCalculator.h index 431600192..9223f1ae6 100644 --- a/src/treebuilders/AdditionCalculator.h +++ b/src/treebuilders/AdditionCalculator.h @@ -30,26 +30,36 @@ namespace mrcpp { -template class AdditionCalculator final : public TreeCalculator { +template class AdditionCalculator final : public TreeCalculator { public: - AdditionCalculator(const FunctionTreeVector &inp) - : sum_vec(inp) {} + AdditionCalculator(const FunctionTreeVector &inp, bool conjugate = false) + : sum_vec(inp) + , conj(conjugate) {} private: - FunctionTreeVector sum_vec; + FunctionTreeVector sum_vec; + bool conj; - void calcNode(MWNode &node_o) override { + void calcNode(MWNode &node_o) override { node_o.zeroCoefs(); const NodeIndex &idx = node_o.getNodeIndex(); - double *coefs_o = node_o.getCoefs(); + T *coefs_o = node_o.getCoefs(); for (int i = 0; i < this->sum_vec.size(); i++) { - double c_i = get_coef(this->sum_vec, i); - FunctionTree &func_i = get_func(this->sum_vec, 
i); + T c_i = get_coef(this->sum_vec, i); + FunctionTree &func_i = get_func(this->sum_vec, i); // This generates missing nodes - const MWNode &node_i = func_i.getNode(idx); - const double *coefs_i = node_i.getCoefs(); + const MWNode &node_i = func_i.getNode(idx); + const T *coefs_i = node_i.getCoefs(); int n_coefs = node_i.getNCoefs(); - for (int j = 0; j < n_coefs; j++) { coefs_o[j] += c_i * coefs_i[j]; } + if constexpr (std::is_same::value) { + if (func_i.conjugate() xor conj) { + for (int j = 0; j < n_coefs; j++) { coefs_o[j] += c_i * std::conj(coefs_i[j]); } + } else { + for (int j = 0; j < n_coefs; j++) { coefs_o[j] += c_i * coefs_i[j]; } + } + } else { + for (int j = 0; j < n_coefs; j++) { coefs_o[j] += c_i * coefs_i[j]; } + } } node_o.setHasCoefs(); node_o.calcNorms(); diff --git a/src/treebuilders/AnalyticAdaptor.h b/src/treebuilders/AnalyticAdaptor.h index 45f73b4cd..3e9ca0613 100644 --- a/src/treebuilders/AnalyticAdaptor.h +++ b/src/treebuilders/AnalyticAdaptor.h @@ -30,16 +30,16 @@ namespace mrcpp { -template class AnalyticAdaptor final : public TreeAdaptor { +template class AnalyticAdaptor final : public TreeAdaptor { public: - AnalyticAdaptor(const RepresentableFunction &f, int ms) - : TreeAdaptor(ms) + AnalyticAdaptor(const RepresentableFunction &f, int ms) + : TreeAdaptor(ms) , func(&f) {} private: - const RepresentableFunction *func; + const RepresentableFunction *func; - bool splitNode(const MWNode &node) const override { + bool splitNode(const MWNode &node) const override { int scale = node.getScale(); int nQuadPts = node.getKp1(); if (this->func->isVisibleAtScale(scale, nQuadPts)) return false; diff --git a/src/treebuilders/ConvolutionCalculator.cpp b/src/treebuilders/ConvolutionCalculator.cpp index 668c86dbf..497fe0dd8 100644 --- a/src/treebuilders/ConvolutionCalculator.cpp +++ b/src/treebuilders/ConvolutionCalculator.cpp @@ -46,8 +46,8 @@ using Eigen::MatrixXi; namespace mrcpp { -template -ConvolutionCalculator::ConvolutionCalculator(double p, 
ConvolutionOperator &o, FunctionTree &f, int depth) +template +ConvolutionCalculator::ConvolutionCalculator(double p, ConvolutionOperator &o, FunctionTree &f, int depth) : maxDepth(depth) , prec(p) , oper(&o) @@ -57,14 +57,14 @@ ConvolutionCalculator::ConvolutionCalculator(double p, ConvolutionOperator initTimers(); } -template ConvolutionCalculator::~ConvolutionCalculator() { +template ConvolutionCalculator::~ConvolutionCalculator() { clearTimers(); this->operStat.flushNodeCounters(); println(10, this->operStat); for (int i = 0; i < this->bandSizes.size(); i++) { delete this->bandSizes[i]; } } -template void ConvolutionCalculator::initTimers() { +template void ConvolutionCalculator::initTimers() { int nThreads = mrcpp_get_max_threads(); for (int i = 0; i < nThreads; i++) { this->band_t.push_back(new Timer(false)); @@ -73,7 +73,7 @@ template void ConvolutionCalculator::initTimers() { } } -template void ConvolutionCalculator::clearTimers() { +template void ConvolutionCalculator::clearTimers() { int nThreads = mrcpp_get_max_threads(); for (int i = 0; i < nThreads; i++) { delete this->band_t[i]; @@ -85,7 +85,7 @@ template void ConvolutionCalculator::clearTimers() { this->norm_t.clear(); } -template void ConvolutionCalculator::printTimers() const { +template void ConvolutionCalculator::printTimers() const { int oldprec = Printer::setPrecision(1); int nThreads = mrcpp_get_max_threads(); printout(20, "\n\nthread "); @@ -102,7 +102,7 @@ template void ConvolutionCalculator::printTimers() const { /** Initialize the number of nodes formally within the bandwidth of an operator. The band size is used for thresholding. */ -template void ConvolutionCalculator::initBandSizes() { +template void ConvolutionCalculator::initBandSizes() { for (int i = 0; i < this->oper->size(); i++) { // IMPORTANT: only 0-th dimension! const OperatorTree &oTree = this->oper->getComponent(i, 0); @@ -118,7 +118,7 @@ template void ConvolutionCalculator::initBandSizes() { * of an operator. 
Currently this routine ignores the fact that * there are edges on the world box, and thus over estimates * the number of nodes. This is different from the previous version. */ -template void ConvolutionCalculator::calcBandSizeFactor(MatrixXi &bs, int depth, const BandWidth &bw) { +template void ConvolutionCalculator::calcBandSizeFactor(MatrixXi &bs, int depth, const BandWidth &bw) { for (int gt = 0; gt < this->nComp; gt++) { for (int ft = 0; ft < this->nComp; ft++) { int k = gt * this->nComp + ft; @@ -139,8 +139,8 @@ template void ConvolutionCalculator::calcBandSizeFactor(MatrixXi &bs, } /** Return a vector of nodes in F affected by O, given a node in G */ -template MWNodeVector *ConvolutionCalculator::makeOperBand(const MWNode &gNode, std::vector> &idx_band) { - auto *band = new MWNodeVector; +template MWNodeVector *ConvolutionCalculator::makeOperBand(const MWNode &gNode, std::vector> &idx_band) { + auto *band = new MWNodeVector; int o_depth = gNode.getScale() - this->oper->getOperatorRoot(); int g_depth = gNode.getDepth(); @@ -150,7 +150,7 @@ template MWNodeVector *ConvolutionCalculator::makeOperBand(const M int reach = this->oper->getOperatorReach(); if (width >= 0) { - const NodeBox &fWorld = this->fTree->getRootBox(); + const NodeBox &fWorld = this->fTree->getRootBox(); const NodeIndex &cIdx = fWorld.getCornerIndex(); const NodeIndex &gIdx = gNode.getNodeIndex(); @@ -180,7 +180,7 @@ template MWNodeVector *ConvolutionCalculator::makeOperBand(const M } /** Recursively retrieve all reachable f-nodes within the bandwidth. 
*/ -template void ConvolutionCalculator::fillOperBand(MWNodeVector *band, std::vector> &idx_band, NodeIndex &idx, const int *nbox, int dim) { +template void ConvolutionCalculator::fillOperBand(MWNodeVector *band, std::vector> &idx_band, NodeIndex &idx, const int *nbox, int dim) { int l_start = idx[dim]; for (int j = 0; j < nbox[dim]; j++) { // Recurse until dim == 0 @@ -190,7 +190,7 @@ template void ConvolutionCalculator::fillOperBand(MWNodeVector *ba continue; } if (not manipulateOperator) { - MWNode &fNode = this->fTree->getNode(idx); + MWNode &fNode = this->fTree->getNode(idx); idx_band.push_back(idx); band->push_back(&fNode); @@ -198,18 +198,18 @@ template void ConvolutionCalculator::fillOperBand(MWNodeVector *ba const auto oper_scale = this->oper->getOperatorRoot(); if (oper_scale == 0) { if (periodic::in_unit_cell(idx) and onUnitcell) { - MWNode &fNode = this->fTree->getNode(idx); + MWNode &fNode = this->fTree->getNode(idx); idx_band.push_back(idx); band->push_back(&fNode); } if (not periodic::in_unit_cell(idx) and not onUnitcell) { - MWNode &fNode = this->fTree->getNode(idx); + MWNode &fNode = this->fTree->getNode(idx); idx_band.push_back(idx); band->push_back(&fNode); } } else if (oper_scale < 0) { if (periodic::in_unit_cell(idx) and onUnitcell) { - MWNode &fNode = this->fTree->getNode(idx); + MWNode &fNode = this->fTree->getNode(idx); idx_band.push_back(idx); band->push_back(&fNode); } @@ -222,23 +222,23 @@ template void ConvolutionCalculator::fillOperBand(MWNodeVector *ba idx[dim] = l_start; } -template void ConvolutionCalculator::calcNode(MWNode &node) { - auto &gNode = static_cast &>(node); +template void ConvolutionCalculator::calcNode(MWNode &node) { + auto &gNode = static_cast &>(node); gNode.zeroCoefs(); int o_depth = gNode.getScale() - this->oper->getOperatorRoot(); if (manipulateOperator and this->oper->getOperatorRoot() < 0) o_depth = gNode.getDepth(); - double tmpCoefs[gNode.getNCoefs()]; - OperatorState os(gNode, tmpCoefs); + T 
tmpCoefs[gNode.getNCoefs()]; + OperatorState os(gNode, tmpCoefs); this->operStat.incrementGNodeCounters(gNode); // Get all nodes in f within the bandwith of O in g this->band_t[mrcpp_get_thread_num()]->resume(); std::vector> idx_band; - MWNodeVector *fBand = makeOperBand(gNode, idx_band); + MWNodeVector *fBand = makeOperBand(gNode, idx_band); this->band_t[mrcpp_get_thread_num()]->stop(); - MWTree &gTree = gNode.getMWTree(); + MWTree &gTree = gNode.getMWTree(); double gThrs = gTree.getSquareNorm(); if (gThrs > 0.0) { auto nTerms = static_cast(this->oper->size()); @@ -250,7 +250,7 @@ template void ConvolutionCalculator::calcNode(MWNode &node) { this->calc_t[mrcpp_get_thread_num()]->resume(); for (int n = 0; n < fBand->size(); n++) { - MWNode &fNode = *(*fBand)[n]; + MWNode &fNode = *(*fBand)[n]; NodeIndex &fIdx = idx_band[n]; os.setFNode(fNode); os.setFIndex(fIdx); @@ -275,7 +275,7 @@ template void ConvolutionCalculator::calcNode(MWNode &node) { } /** Apply each component (term) of the operator expansion to a node in f */ -template void ConvolutionCalculator::applyOperComp(OperatorState &os) { +template void ConvolutionCalculator::applyOperComp(OperatorState &os) { double fNorm = os.fNode->getComponentNorm(os.ft); int o_depth = os.fNode->getScale() - this->oper->getOperatorRoot(); for (int i = 0; i < this->oper->size(); i++) { @@ -288,17 +288,17 @@ template void ConvolutionCalculator::applyOperComp(OperatorState & } } - /** @brief Apply a single operator component (term) to a single f-node. - * + * * @details Apply a single operator component (term) to a single f-node. * Whether the operator actualy is applied is determined by a screening threshold. * Here we make use of the sparcity of matrices \f$ A, B, C \f$. 
- * + * */ -template void ConvolutionCalculator::applyOperator(int i, OperatorState &os) { - MWNode &gNode = *os.gNode; - MWNode &fNode = *os.fNode; +template void ConvolutionCalculator::applyOperator(int i, OperatorState &os) { + MWNode &gNode = *os.gNode; + MWNode &fNode = *os.fNode; + const NodeIndex &fIdx = *os.fIdx; const NodeIndex &gIdx = gNode.getNodeIndex(); int o_depth = gNode.getScale() - this->oper->getOperatorRoot(); @@ -315,7 +315,7 @@ template void ConvolutionCalculator::applyOperator(int i, OperatorSta int a = (os.gt & (1 << d)) >> d; int b = (os.ft & (1 << d)) >> d; int idx = (a << 1) + b; - if ( oTree.isOutsideBand(oTransl, o_depth, idx) ) { return; } + if (oTree.isOutsideBand(oTransl, o_depth, idx)) { return; } const OperatorNode &oNode = oTree.getNode(o_depth, oTransl); int oIdx = os.getOperIndex(d); @@ -331,9 +331,10 @@ template void ConvolutionCalculator::applyOperator(int i, OperatorSta /** Perorm the required linear algebra operations in order to apply an operator component to a f-node in a n-dimensional tesor space. 
*/ -template void ConvolutionCalculator::tensorApplyOperComp(OperatorState &os) { - double **aux = os.getAuxData(); +template void ConvolutionCalculator::tensorApplyOperComp(OperatorState &os) { + T **aux = os.getAuxData(); double **oData = os.getOperData(); + /* #ifdef HAVE_BLAS double mult = 0.0; for (int i = 0; i < D; i++) { @@ -358,9 +359,10 @@ template void ConvolutionCalculator::tensorApplyOperComp(OperatorStat } } #else + */ for (int i = 0; i < D; i++) { - Eigen::Map f(aux[i], os.kp1, os.kp1_dm1); - Eigen::Map g(aux[i + 1], os.kp1_dm1, os.kp1); + Eigen::Map> f(aux[i], os.kp1, os.kp1_dm1); + Eigen::Map> g(aux[i + 1], os.kp1_dm1, os.kp1); if (oData[i] != nullptr) { Eigen::Map op(oData[i], os.kp1, os.kp1); if (i == D - 1) { // Last dir: Add up into g @@ -377,10 +379,10 @@ template void ConvolutionCalculator::tensorApplyOperComp(OperatorStat } } } -#endif + //#endif } -template void ConvolutionCalculator::touchParentNodes(MWTree &tree) const { +template void ConvolutionCalculator::touchParentNodes(MWTree &tree) const { if (not manipulateOperator) { const auto oper_scale = this->oper->getOperatorRoot(); auto car_prod = math_utils::cartesian_product(std::vector{-1, 0}, D); @@ -396,15 +398,19 @@ template void ConvolutionCalculator::touchParentNodes(MWTree &tree } } -template MWNodeVector *ConvolutionCalculator::getInitialWorkVector(MWTree &tree) const { - auto *nodeVec = new MWNodeVector; +template MWNodeVector *ConvolutionCalculator::getInitialWorkVector(MWTree &tree) const { + auto *nodeVec = new MWNodeVector; if (tree.isPeriodic()) touchParentNodes(tree); tree_utils::make_node_table(tree, *nodeVec); return nodeVec; } -template class ConvolutionCalculator<1>; -template class ConvolutionCalculator<2>; -template class ConvolutionCalculator<3>; +template class ConvolutionCalculator<1, double>; +template class ConvolutionCalculator<2, double>; +template class ConvolutionCalculator<3, double>; + +template class ConvolutionCalculator<1, ComplexDouble>; +template class 
ConvolutionCalculator<2, ComplexDouble>; +template class ConvolutionCalculator<3, ComplexDouble>; } // namespace mrcpp diff --git a/src/treebuilders/ConvolutionCalculator.h b/src/treebuilders/ConvolutionCalculator.h index 3b88cb9b1..8ac4b5d34 100644 --- a/src/treebuilders/ConvolutionCalculator.h +++ b/src/treebuilders/ConvolutionCalculator.h @@ -33,12 +33,12 @@ namespace mrcpp { -template class ConvolutionCalculator final : public TreeCalculator { +template class ConvolutionCalculator final : public TreeCalculator { public: - ConvolutionCalculator(double p, ConvolutionOperator &o, FunctionTree &f, int depth = MaxDepth); + ConvolutionCalculator(double p, ConvolutionOperator &o, FunctionTree &f, int depth = MaxDepth); ~ConvolutionCalculator() override; - MWNodeVector *getInitialWorkVector(MWTree &tree) const override; + MWNodeVector *getInitialWorkVector(MWTree &tree) const override; void setPrecFunction(const std::function &idx)> &prec_func) { this->precFunc = prec_func; } void startManipulateOperator(bool excUnit) { @@ -52,45 +52,45 @@ template class ConvolutionCalculator final : public TreeCalculator { bool manipulateOperator{false}; bool onUnitcell{false}; ConvolutionOperator *oper; - FunctionTree *fTree; + FunctionTree *fTree; std::vector band_t; std::vector calc_t; std::vector norm_t; - OperatorStatistics operStat; + OperatorStatistics operStat; std::vector bandSizes; std::function &idx)> precFunc = [](const NodeIndex &idx) { return 1.0; }; static const int nComp = (1 << D); static const int nComp2 = (1 << D) * (1 << D); - MWNodeVector *makeOperBand(const MWNode &gNode, std::vector> &idx_band); - void fillOperBand(MWNodeVector *band, std::vector> &idx_band, NodeIndex &idx, const int *nbox, int dim); + MWNodeVector *makeOperBand(const MWNode &gNode, std::vector> &idx_band); + void fillOperBand(MWNodeVector *band, std::vector> &idx_band, NodeIndex &idx, const int *nbox, int dim); void initTimers(); void clearTimers(); void printTimers() const; void 
initBandSizes(); - int getBandSizeFactor(int i, int depth, const OperatorState &os) const { + int getBandSizeFactor(int i, int depth, const OperatorState &os) const { int k = os.gt * this->nComp + os.ft; return (*this->bandSizes[i])(depth, k); } void calcBandSizeFactor(Eigen::MatrixXi &bs, int depth, const BandWidth &bw); - void calcNode(MWNode &node) override; + void calcNode(MWNode &node) override; void postProcess() override { printTimers(); clearTimers(); initTimers(); } - void applyOperComp(OperatorState &os); - void applyOperator(int i, OperatorState &os); - void tensorApplyOperComp(OperatorState &os); + void applyOperComp(OperatorState &os); + void applyOperator(int i, OperatorState &os); + void tensorApplyOperComp(OperatorState &os); - void touchParentNodes(MWTree &tree) const; + void touchParentNodes(MWTree &tree) const; }; } // namespace mrcpp diff --git a/src/treebuilders/CopyAdaptor.cpp b/src/treebuilders/CopyAdaptor.cpp index 4017c6e5e..8312ebb0f 100644 --- a/src/treebuilders/CopyAdaptor.cpp +++ b/src/treebuilders/CopyAdaptor.cpp @@ -29,21 +29,21 @@ namespace mrcpp { -template -CopyAdaptor::CopyAdaptor(FunctionTree &t, int ms, int *bw) - : TreeAdaptor(ms) { +template +CopyAdaptor::CopyAdaptor(FunctionTree &t, int ms, int *bw) + : TreeAdaptor(ms) { setBandWidth(bw); tree_vec.push_back(std::make_tuple(1.0, &t)); } -template -CopyAdaptor::CopyAdaptor(FunctionTreeVector &t, int ms, int *bw) - : TreeAdaptor(ms) +template +CopyAdaptor::CopyAdaptor(FunctionTreeVector &t, int ms, int *bw) + : TreeAdaptor(ms) , tree_vec(t) { setBandWidth(bw); } -template void CopyAdaptor::setBandWidth(int *bw) { +template void CopyAdaptor::setBandWidth(int *bw) { for (int d = 0; d < D; d++) { if (bw != nullptr) { this->bandWidth[d] = bw[d]; @@ -53,7 +53,7 @@ template void CopyAdaptor::setBandWidth(int *bw) { } } -template bool CopyAdaptor::splitNode(const MWNode &node) const { +template bool CopyAdaptor::splitNode(const MWNode &node) const { const NodeIndex &idx = 
node.getNodeIndex(); for (int c = 0; c < node.getTDim(); c++) { for (int d = 0; d < D; d++) { @@ -61,8 +61,8 @@ template bool CopyAdaptor::splitNode(const MWNode &node) const { NodeIndex bwIdx = idx.child(c); bwIdx[d] += bw; for (int i = 0; i < this->tree_vec.size(); i++) { - const FunctionTree &func_i = get_func(tree_vec, i); - const MWNode *node_i = func_i.findNode(bwIdx); + const FunctionTree &func_i = get_func(tree_vec, i); + const MWNode *node_i = func_i.findNode(bwIdx); if (node_i != nullptr) return true; } } @@ -71,8 +71,12 @@ template bool CopyAdaptor::splitNode(const MWNode &node) const { return false; } -template class CopyAdaptor<1>; -template class CopyAdaptor<2>; -template class CopyAdaptor<3>; +template class CopyAdaptor<1, double>; +template class CopyAdaptor<2, double>; +template class CopyAdaptor<3, double>; + +template class CopyAdaptor<1, ComplexDouble>; +template class CopyAdaptor<2, ComplexDouble>; +template class CopyAdaptor<3, ComplexDouble>; } // namespace mrcpp diff --git a/src/treebuilders/CopyAdaptor.h b/src/treebuilders/CopyAdaptor.h index c9451e599..a7825cca0 100644 --- a/src/treebuilders/CopyAdaptor.h +++ b/src/treebuilders/CopyAdaptor.h @@ -30,17 +30,17 @@ namespace mrcpp { -template class CopyAdaptor final : public TreeAdaptor { +template class CopyAdaptor final : public TreeAdaptor { public: - CopyAdaptor(FunctionTree &t, int ms, int *bw); - CopyAdaptor(FunctionTreeVector &t, int ms, int *bw); + CopyAdaptor(FunctionTree &t, int ms, int *bw); + CopyAdaptor(FunctionTreeVector &t, int ms, int *bw); private: int bandWidth[D]; - FunctionTreeVector tree_vec; + FunctionTreeVector tree_vec; void setBandWidth(int *bw); - bool splitNode(const MWNode &node) const override; + bool splitNode(const MWNode &node) const override; }; } // namespace mrcpp diff --git a/src/treebuilders/CrossCorrelationCalculator.cpp b/src/treebuilders/CrossCorrelationCalculator.cpp index efe9a3390..a5eef945d 100644 --- a/src/treebuilders/CrossCorrelationCalculator.cpp 
+++ b/src/treebuilders/CrossCorrelationCalculator.cpp @@ -77,8 +77,7 @@ template void CrossCorrelationCalculator::applyCcc(MWNode<2> &node, Cros const MWNode<1> &node_a = this->kernel->getNode(idx_a); const MWNode<1> &node_b = this->kernel->getNode(idx_b); - VectorXd vec_a; - VectorXd vec_b; + VectorXd vec_a, vec_b; node_a.getCoefs(vec_a); node_b.getCoefs(vec_b); @@ -91,7 +90,7 @@ template void CrossCorrelationCalculator::applyCcc(MWNode<2> &node, Cros for (int i = 0; i < t_dim * kp1_d; i++) { auto scaling_factor = node.getMWTree().getMRA().getWorldBox().getScalingFactor(0); // This is only implemented for unifrom scaling factors - // hence the zero TODO: make it work for non-unifrom scaling + // hence the zero TODO: make it work for non-uniform scaling coefs[i] = std::sqrt(scaling_factor) * two_n * vec_o(i); } } diff --git a/src/treebuilders/DefaultCalculator.h b/src/treebuilders/DefaultCalculator.h index 13f698162..4a1a4ce54 100644 --- a/src/treebuilders/DefaultCalculator.h +++ b/src/treebuilders/DefaultCalculator.h @@ -29,16 +29,16 @@ namespace mrcpp { -template class DefaultCalculator final : public TreeCalculator { +template class DefaultCalculator final : public TreeCalculator { public: // Reimplementation without OpenMP, the default is faster this way - void calcNodeVector(MWNodeVector &nodeVec) override { + void calcNodeVector(MWNodeVector &nodeVec) override { int nNodes = nodeVec.size(); for (int n = 0; n < nNodes; n++) { calcNode(*nodeVec[n]); } } private: - void calcNode(MWNode &node) override { + void calcNode(MWNode &node) override { node.clearHasCoefs(); node.clearNorms(); } diff --git a/src/treebuilders/DerivativeCalculator.cpp b/src/treebuilders/DerivativeCalculator.cpp index a5acdc297..b298d1b6e 100644 --- a/src/treebuilders/DerivativeCalculator.cpp +++ b/src/treebuilders/DerivativeCalculator.cpp @@ -42,8 +42,8 @@ using Eigen::MatrixXd; namespace mrcpp { -template -DerivativeCalculator::DerivativeCalculator(int dir, DerivativeOperator &o, 
FunctionTree &f) +template +DerivativeCalculator::DerivativeCalculator(int dir, DerivativeOperator &o, FunctionTree &f) : applyDir(dir) , fTree(&f) , oper(&o) { @@ -51,12 +51,12 @@ DerivativeCalculator::DerivativeCalculator(int dir, DerivativeOperator &o, initTimers(); } -template DerivativeCalculator::~DerivativeCalculator() { +template DerivativeCalculator::~DerivativeCalculator() { this->operStat.flushNodeCounters(); println(10, this->operStat); } -template void DerivativeCalculator::initTimers() { +template void DerivativeCalculator::initTimers() { int nThreads = mrcpp_get_max_threads(); for (int i = 0; i < nThreads; i++) { this->band_t.push_back(Timer(false)); @@ -65,13 +65,13 @@ template void DerivativeCalculator::initTimers() { } } -template void DerivativeCalculator::clearTimers() { +template void DerivativeCalculator::clearTimers() { this->band_t.clear(); this->calc_t.clear(); this->norm_t.clear(); } -template void DerivativeCalculator::printTimers() const { +template void DerivativeCalculator::printTimers() const { int oldprec = Printer::setPrecision(1); int nThreads = mrcpp_get_max_threads(); printout(20, "\n\nthread "); @@ -86,12 +86,12 @@ template void DerivativeCalculator::printTimers() const { Printer::setPrecision(oldprec); } -template void DerivativeCalculator::calcNode(MWNode &inpNode, MWNode &outNode) { - //if (this->oper->getMaxBandWidth() > 1) MSG_ABORT("Only implemented for zero bw"); +template void DerivativeCalculator::calcNode(MWNode &inpNode, MWNode &outNode) { + // if (this->oper->getMaxBandWidth() > 1) MSG_ABORT("Only implemented for zero bw"); outNode.zeroCoefs(); int nComp = (1 << D); - double tmpCoefs[outNode.getNCoefs()]; - OperatorState os(outNode, tmpCoefs); + T tmpCoefs[outNode.getNCoefs()]; + OperatorState os(outNode, tmpCoefs); os.setFNode(inpNode); os.setFIndex(inpNode.nodeIndex); @@ -102,36 +102,34 @@ template void DerivativeCalculator::calcNode(MWNode &inpNode, MWNo for (int gt = 0; gt < nComp; gt++) { os.setGComponent(gt); 
applyOperator_bw0(os); - } + } } - // Multiply appropriate scaling factor. TODO: Could be included elsewhere - const double scaling_factor = - 1.0/std::pow(outNode.getMWTree().getMRA().getWorldBox().getScalingFactor(this->applyDir), oper->getOrder()); - if(abs(scaling_factor-1.0)>MachineZero){ + // Multiply appropriate scaling factor. TODO: Could be included elsewhere + const double scaling_factor = 1.0 / std::pow(outNode.getMWTree().getMRA().getWorldBox().getScalingFactor(this->applyDir), oper->getOrder()); + if (abs(scaling_factor - 1.0) > MachineZero) { for (int i = 0; i < outNode.getNCoefs(); i++) outNode.getCoefs()[i] *= scaling_factor; } - outNode.calcNorms(); //TODO:required? norms are not used for now + outNode.calcNorms(); // TODO:required? norms are not used for now } - -template void DerivativeCalculator::calcNode(MWNode &gNode) { +template void DerivativeCalculator::calcNode(MWNode &gNode) { gNode.zeroCoefs(); int nComp = (1 << D); - double tmpCoefs[gNode.getNCoefs()]; - OperatorState os(gNode, tmpCoefs); + T tmpCoefs[gNode.getNCoefs()]; + OperatorState os(gNode, tmpCoefs); this->operStat.incrementGNodeCounters(gNode); // Get all nodes in f within the bandwith of O in g this->band_t[mrcpp_get_thread_num()].resume(); std::vector> idx_band; - MWNodeVector fBand = makeOperBand(gNode, idx_band); + MWNodeVector fBand = makeOperBand(gNode, idx_band); this->band_t[mrcpp_get_thread_num()].stop(); this->calc_t[mrcpp_get_thread_num()].resume(); for (int n = 0; n < fBand.size(); n++) { - MWNode &fNode = *fBand[n]; + MWNode &fNode = *fBand[n]; NodeIndex &fIdx = idx_band[n]; os.setFNode(fNode); os.setFIndex(fIdx); @@ -146,8 +144,7 @@ template void DerivativeCalculator::calcNode(MWNode &gNode) { } } // Multiply appropriate scaling factor - const double scaling_factor = - std::pow(gNode.getMWTree().getMRA().getWorldBox().getScalingFactor(this->applyDir), oper->getOrder()); + const double scaling_factor = 
std::pow(gNode.getMWTree().getMRA().getWorldBox().getScalingFactor(this->applyDir), oper->getOrder()); for (int i = 0; i < gNode.getNCoefs(); i++) gNode.getCoefs()[i] /= scaling_factor; this->calc_t[mrcpp_get_thread_num()].stop(); @@ -157,12 +154,11 @@ template void DerivativeCalculator::calcNode(MWNode &gNode) { } /** Return a vector of nodes in F affected by O, given a node in G */ -template -MWNodeVector DerivativeCalculator::makeOperBand(const MWNode &gNode, std::vector> &idx_band) { +template MWNodeVector DerivativeCalculator::makeOperBand(const MWNode &gNode, std::vector> &idx_band) { assert(this->applyDir >= 0); assert(this->applyDir < D); - MWNodeVector band; + MWNodeVector band; const NodeIndex &idx_0 = gNode.getNodeIndex(); // Assumes given width only in applyDir, otherwise width = 0 @@ -182,10 +178,10 @@ MWNodeVector DerivativeCalculator::makeOperBand(const MWNode &gNode, st } /** Apply a single operator component (term) to a single f-node assuming zero bandwidth */ -template void DerivativeCalculator::applyOperator_bw0(OperatorState &os) { - //cout<<" applyOperator "< &gNode = *os.gNode; - MWNode &fNode = *os.fNode; +template void DerivativeCalculator::applyOperator_bw0(OperatorState &os) { + // cout<<" applyOperator "< &gNode = *os.gNode; + MWNode &fNode = *os.fNode; const NodeIndex &fIdx = *os.fIdx; const NodeIndex &gIdx = gNode.getNodeIndex(); int depth = gNode.getDepth(); @@ -213,12 +209,11 @@ template void DerivativeCalculator::applyOperator_bw0(OperatorState void DerivativeCalculator::applyOperator(OperatorState &os) { - MWNode &gNode = *os.gNode; - MWNode &fNode = *os.fNode; +template void DerivativeCalculator::applyOperator(OperatorState &os) { + MWNode &gNode = *os.gNode; + MWNode &fNode = *os.fNode; const NodeIndex &fIdx = *os.fIdx; const NodeIndex &gIdx = gNode.getNodeIndex(); int depth = gNode.getDepth(); @@ -261,49 +256,12 @@ template void DerivativeCalculator::applyOperator(OperatorState &o /** Perform the required linear algebra 
operations in order to apply an operator component to a f-node in a n-dimensional tensor space. */ -template void DerivativeCalculator::tensorApplyOperComp(OperatorState &os) { - double **aux = os.getAuxData(); +template void DerivativeCalculator::tensorApplyOperComp(OperatorState &os) { + T **aux = os.getAuxData(); double **oData = os.getOperData(); -#ifdef HAVE_BLAS - double mult = 0.0; - for (int i = 0; i < D; i++) { - if (oData[i] != 0) { - if (i == D - 1) { // Last dir: Add up into g - mult = 1.0; - } - const double *f = aux[i]; - double *g = const_cast(aux[i + 1]); - cblas_dgemm(CblasColMajor, - CblasTrans, - CblasNoTrans, - os.kp1_dm1, - os.kp1, - os.kp1, - 1.0, - f, - os.kp1, - oData[i], - os.kp1, - mult, - g, - os.kp1_dm1); - } else { - // Identity operator in direction i - Eigen::Map f(aux[i], os.kp1, os.kp1_dm1); - Eigen::Map g(aux[i + 1], os.kp1_dm1, os.kp1); - if (oData[i] == 0) { - if (i == D - 1) { // Last dir: Add up into g - g += f.transpose(); - } else { - g = f.transpose(); - } - } - } - } -#else for (int i = 0; i < D; i++) { - Eigen::Map f(aux[i], os.kp1, os.kp1_dm1); - Eigen::Map g(aux[i + 1], os.kp1_dm1, os.kp1); + Eigen::Map> f(aux[i], os.kp1, os.kp1_dm1); + Eigen::Map> g(aux[i + 1], os.kp1_dm1, os.kp1); if (oData[i] != nullptr) { Eigen::Map op(oData[i], os.kp1, os.kp1); if (i == D - 1) { // Last dir: Add up into g @@ -320,15 +278,18 @@ template void DerivativeCalculator::tensorApplyOperComp(OperatorState } } } -#endif } -template MWNodeVector *DerivativeCalculator::getInitialWorkVector(MWTree &tree) const { +template MWNodeVector *DerivativeCalculator::getInitialWorkVector(MWTree &tree) const { return tree.copyEndNodeTable(); } -template class DerivativeCalculator<1>; -template class DerivativeCalculator<2>; -template class DerivativeCalculator<3>; +template class DerivativeCalculator<1, double>; +template class DerivativeCalculator<2, double>; +template class DerivativeCalculator<3, double>; + +template class DerivativeCalculator<1, 
ComplexDouble>; +template class DerivativeCalculator<2, ComplexDouble>; +template class DerivativeCalculator<3, ComplexDouble>; } // namespace mrcpp diff --git a/src/treebuilders/DerivativeCalculator.h b/src/treebuilders/DerivativeCalculator.h index 5d4d28716..347554a46 100644 --- a/src/treebuilders/DerivativeCalculator.h +++ b/src/treebuilders/DerivativeCalculator.h @@ -30,40 +30,40 @@ namespace mrcpp { -template class DerivativeCalculator final : public TreeCalculator { +template class DerivativeCalculator final : public TreeCalculator { public: - DerivativeCalculator(int dir, DerivativeOperator &o, FunctionTree &f); + DerivativeCalculator(int dir, DerivativeOperator &o, FunctionTree &f); ~DerivativeCalculator() override; - MWNodeVector *getInitialWorkVector(MWTree &tree) const override; - void calcNode(MWNode &fNode, MWNode &gNode); + MWNodeVector *getInitialWorkVector(MWTree &tree) const override; + void calcNode(MWNode &fNode, MWNode &gNode); private: int applyDir; - FunctionTree *fTree; + FunctionTree *fTree; DerivativeOperator *oper; std::vector band_t; std::vector calc_t; std::vector norm_t; - OperatorStatistics operStat; + OperatorStatistics operStat; - MWNodeVector makeOperBand(const MWNode &gNode, std::vector> &idx_band); + MWNodeVector makeOperBand(const MWNode &gNode, std::vector> &idx_band); void initTimers(); void clearTimers(); void printTimers() const; - void calcNode(MWNode &node) override; + void calcNode(MWNode &node) override; void postProcess() override { printTimers(); clearTimers(); initTimers(); } - void applyOperator(OperatorState &os); - void applyOperator_bw0(OperatorState &os); - void tensorApplyOperComp(OperatorState &os); + void applyOperator(OperatorState &os); + void applyOperator_bw0(OperatorState &os); + void tensorApplyOperComp(OperatorState &os); }; } // namespace mrcpp diff --git a/src/treebuilders/MapCalculator.h b/src/treebuilders/MapCalculator.h index 492c1f440..33f799ee9 100644 --- a/src/treebuilders/MapCalculator.h +++ 
b/src/treebuilders/MapCalculator.h @@ -29,24 +29,24 @@ namespace mrcpp { -template class MapCalculator final : public TreeCalculator { +template class MapCalculator final : public TreeCalculator { public: - MapCalculator(FMap fm, FunctionTree &inp) + MapCalculator(FMap fm, FunctionTree &inp) : func(&inp) , fmap(std::move(fm)) {} private: - FunctionTree *func; - FMap fmap; - void calcNode(MWNode &node_o) override { + FunctionTree *func; + FMap fmap; + void calcNode(MWNode &node_o) override { const NodeIndex &idx = node_o.getNodeIndex(); int n_coefs = node_o.getNCoefs(); - double *coefs_o = node_o.getCoefs(); + T *coefs_o = node_o.getCoefs(); // This generates missing nodes - MWNode node_i = func->getNode(idx); // Copy node + MWNode node_i = func->getNode(idx); // Copy node node_i.mwTransform(Reconstruction); node_i.cvTransform(Forward); - const double *coefs_i = node_i.getCoefs(); + const T *coefs_i = node_i.getCoefs(); for (int j = 0; j < n_coefs; j++) { coefs_o[j] = fmap(coefs_i[j]); } node_o.cvTransform(Backward); node_o.mwTransform(Compression); diff --git a/src/treebuilders/MultiplicationAdaptor.h b/src/treebuilders/MultiplicationAdaptor.h index ff0fe992d..9637ac055 100644 --- a/src/treebuilders/MultiplicationAdaptor.h +++ b/src/treebuilders/MultiplicationAdaptor.h @@ -31,19 +31,19 @@ namespace mrcpp { -template class MultiplicationAdaptor : public TreeAdaptor { +template class MultiplicationAdaptor : public TreeAdaptor { public: - MultiplicationAdaptor(double pr, int ms, FunctionTreeVector &t) - : TreeAdaptor(ms) + MultiplicationAdaptor(double pr, int ms, FunctionTreeVector &t) + : TreeAdaptor(ms) , prec(pr) , trees(t) {} ~MultiplicationAdaptor() override = default; protected: double prec; - mutable FunctionTreeVector trees; + mutable FunctionTreeVector trees; - bool splitNode(const MWNode &node) const override { + bool splitNode(const MWNode &node) const override { if (this->trees.size() != 2) MSG_ERROR("Invalid tree vec size: " << this->trees.size()); auto 
&pNode0 = get_func(trees, 0).getNode(node.getNodeIndex()); auto &pNode1 = get_func(trees, 1).getNode(node.getNodeIndex()); diff --git a/src/treebuilders/MultiplicationCalculator.h b/src/treebuilders/MultiplicationCalculator.h index ba5669f4d..49fa67948 100644 --- a/src/treebuilders/MultiplicationCalculator.h +++ b/src/treebuilders/MultiplicationCalculator.h @@ -30,28 +30,38 @@ namespace mrcpp { -template class MultiplicationCalculator final : public TreeCalculator { +template class MultiplicationCalculator final : public TreeCalculator { public: - MultiplicationCalculator(const FunctionTreeVector &inp) - : prod_vec(inp) {} + MultiplicationCalculator(const FunctionTreeVector &inp, bool conjugate = false) + : prod_vec(inp) + , conj(conjugate) {} private: - FunctionTreeVector prod_vec; + FunctionTreeVector prod_vec; + bool conj; - void calcNode(MWNode &node_o) override { + void calcNode(MWNode &node_o) { const NodeIndex &idx = node_o.getNodeIndex(); - double *coefs_o = node_o.getCoefs(); + T *coefs_o = node_o.getCoefs(); for (int j = 0; j < node_o.getNCoefs(); j++) { coefs_o[j] = 1.0; } for (int i = 0; i < this->prod_vec.size(); i++) { - double c_i = get_coef(this->prod_vec, i); - FunctionTree &func_i = get_func(this->prod_vec, i); + T c_i = get_coef(this->prod_vec, i); + FunctionTree &func_i = get_func(this->prod_vec, i); // This generates missing nodes - MWNode node_i = func_i.getNode(idx); // Copy node + MWNode node_i = func_i.getNode(idx); // Copy node node_i.mwTransform(Reconstruction); node_i.cvTransform(Forward); - const double *coefs_i = node_i.getCoefs(); + const T *coefs_i = node_i.getCoefs(); int n_coefs = node_i.getNCoefs(); - for (int j = 0; j < n_coefs; j++) { coefs_o[j] *= c_i * coefs_i[j]; } + if constexpr (std::is_same::value) { + if (func_i.conjugate() xor (conj and i == 0)) { // NB: take complex conjugate of "bra" + for (int j = 0; j < n_coefs; j++) { coefs_o[j] *= c_i * std::conj(coefs_i[j]); } + } else { + for (int j = 0; j < n_coefs; j++) { 
coefs_o[j] *= c_i * coefs_i[j]; } + } + } else { + for (int j = 0; j < n_coefs; j++) { coefs_o[j] *= c_i * coefs_i[j]; } + } } node_o.cvTransform(Backward); node_o.mwTransform(Compression); diff --git a/src/treebuilders/PowerCalculator.h b/src/treebuilders/PowerCalculator.h index bb2124b73..79147fc4b 100644 --- a/src/treebuilders/PowerCalculator.h +++ b/src/treebuilders/PowerCalculator.h @@ -29,25 +29,25 @@ namespace mrcpp { -template class PowerCalculator final : public TreeCalculator { +template class PowerCalculator final : public TreeCalculator { public: - PowerCalculator(FunctionTree &inp, double pow) + PowerCalculator(FunctionTree &inp, double pow) : power(pow) , func(&inp) {} private: double power; - FunctionTree *func; + FunctionTree *func; - void calcNode(MWNode &node_o) override { + void calcNode(MWNode &node_o) override { const NodeIndex &idx = node_o.getNodeIndex(); int n_coefs = node_o.getNCoefs(); - double *coefs_o = node_o.getCoefs(); + T *coefs_o = node_o.getCoefs(); // This generates missing nodes - MWNode node_i = func->getNode(idx); // Copy node + MWNode node_i = func->getNode(idx); // Copy node node_i.mwTransform(Reconstruction); node_i.cvTransform(Forward); - const double *coefs_i = node_i.getCoefs(); + const T *coefs_i = node_i.getCoefs(); for (int j = 0; j < n_coefs; j++) { coefs_o[j] = std::pow(coefs_i[j], this->power); } node_o.cvTransform(Backward); node_o.mwTransform(Compression); diff --git a/src/treebuilders/ProjectionCalculator.cpp b/src/treebuilders/ProjectionCalculator.cpp index 46335d092..931733232 100644 --- a/src/treebuilders/ProjectionCalculator.cpp +++ b/src/treebuilders/ProjectionCalculator.cpp @@ -30,18 +30,19 @@ using Eigen::MatrixXd; namespace mrcpp { -template void ProjectionCalculator::calcNode(MWNode &node) { +template void ProjectionCalculator::calcNode(MWNode &node) { MatrixXd exp_pts; node.getExpandedChildPts(exp_pts); assert(exp_pts.cols() == node.getNCoefs()); Coord r; - double *coefs = node.getCoefs(); + T *coefs = 
node.getCoefs(); for (int i = 0; i < node.getNCoefs(); i++) { for (int d = 0; d < D; d++) { r[d] = scaling_factor[d] * exp_pts(d, i); } coefs[i] = this->func->evalf(r); } + node.cvTransform(Backward); node.mwTransform(Compression); node.setHasCoefs(); @@ -50,7 +51,7 @@ template void ProjectionCalculator::calcNode(MWNode &node) { /* Old interpolating version, somewhat faster template -void ProjectionCalculator::calcNode(MWNode &node) { +void ProjectionCalculator::calcNode(MWNode &node) { const ScalingBasis &sf = node.getMWTree().getMRA().getScalingBasis(); if (sf.getScalingType() != Interpol) { NOT_IMPLEMENTED_ABORT; @@ -104,8 +105,12 @@ void ProjectionCalculator::calcNode(MWNode &node) { } */ -template class ProjectionCalculator<1>; -template class ProjectionCalculator<2>; -template class ProjectionCalculator<3>; +template class ProjectionCalculator<1, double>; +template class ProjectionCalculator<2, double>; +template class ProjectionCalculator<3, double>; + +template class ProjectionCalculator<1, ComplexDouble>; +template class ProjectionCalculator<2, ComplexDouble>; +template class ProjectionCalculator<3, ComplexDouble>; } // namespace mrcpp diff --git a/src/treebuilders/ProjectionCalculator.h b/src/treebuilders/ProjectionCalculator.h index 2fbbb09fe..067c41422 100644 --- a/src/treebuilders/ProjectionCalculator.h +++ b/src/treebuilders/ProjectionCalculator.h @@ -29,16 +29,16 @@ namespace mrcpp { -template class ProjectionCalculator final : public TreeCalculator { +template class ProjectionCalculator final : public TreeCalculator { public: - ProjectionCalculator(const RepresentableFunction &inp_func, const std::array &sf) + ProjectionCalculator(const RepresentableFunction &inp_func, const std::array &sf) : func(&inp_func) , scaling_factor(sf) {} private: - const RepresentableFunction *func; + const RepresentableFunction *func; const std::array scaling_factor; - void calcNode(MWNode &node) override; + void calcNode(MWNode &node) override; }; } // namespace mrcpp 
diff --git a/src/treebuilders/SplitAdaptor.h b/src/treebuilders/SplitAdaptor.h index b9d50fe8b..7e81bbe8b 100644 --- a/src/treebuilders/SplitAdaptor.h +++ b/src/treebuilders/SplitAdaptor.h @@ -29,16 +29,16 @@ namespace mrcpp { -template class SplitAdaptor final : public TreeAdaptor { +template class SplitAdaptor final : public TreeAdaptor { public: SplitAdaptor(int ms, bool sp) - : TreeAdaptor(ms) + : TreeAdaptor(ms) , split(sp) {} private: bool split; - bool splitNode(const MWNode &node) const override { return this->split; } + bool splitNode(const MWNode &node) const override { return this->split; } }; } // namespace mrcpp diff --git a/src/treebuilders/SquareCalculator.h b/src/treebuilders/SquareCalculator.h index e9bb0f8d3..015b90f82 100644 --- a/src/treebuilders/SquareCalculator.h +++ b/src/treebuilders/SquareCalculator.h @@ -29,24 +29,42 @@ namespace mrcpp { -template class SquareCalculator final : public TreeCalculator { +template class SquareCalculator final : public TreeCalculator { public: - SquareCalculator(FunctionTree &inp) - : func(&inp) {} + SquareCalculator(FunctionTree &inp, bool conjugate = false) + : func(&inp) + , conj(conjugate) {} private: - FunctionTree *func; + FunctionTree *func; + bool conj; - void calcNode(MWNode &node_o) override { + void calcNode(MWNode &node_o) { const NodeIndex &idx = node_o.getNodeIndex(); int n_coefs = node_o.getNCoefs(); - double *coefs_o = node_o.getCoefs(); + T *coefs_o = node_o.getCoefs(); // This generates missing nodes - MWNode node_i = func->getNode(idx); // Copy node + MWNode node_i = func->getNode(idx); // Copy node node_i.mwTransform(Reconstruction); node_i.cvTransform(Forward); - const double *coefs_i = node_i.getCoefs(); - for (int j = 0; j < n_coefs; j++) { coefs_o[j] = coefs_i[j] * coefs_i[j]; } + const T *coefs_i = node_i.getCoefs(); + if constexpr (std::is_same::value) { + if (func->conjugate()) { + if (conj) { + for (int j = 0; j < n_coefs; j++) { coefs_o[j] = std::conj(coefs_i[j]) * coefs_i[j]; } + 
} else { + for (int j = 0; j < n_coefs; j++) { coefs_o[j] = std::conj(coefs_i[j]) * std::conj(coefs_i[j]); } + } + } else { + if (conj) { + for (int j = 0; j < n_coefs; j++) { coefs_o[j] = coefs_i[j] * std::conj(coefs_i[j]); } + } else { + for (int j = 0; j < n_coefs; j++) { coefs_o[j] = coefs_i[j] * coefs_i[j]; } + } + } + } else { + for (int j = 0; j < n_coefs; j++) { coefs_o[j] = coefs_i[j] * coefs_i[j]; } + } node_o.cvTransform(Backward); node_o.mwTransform(Compression); node_o.setHasCoefs(); diff --git a/src/treebuilders/TimeEvolution_CrossCorrelationCalculator.cpp b/src/treebuilders/TimeEvolution_CrossCorrelationCalculator.cpp index a73f9a0d5..844f952d9 100644 --- a/src/treebuilders/TimeEvolution_CrossCorrelationCalculator.cpp +++ b/src/treebuilders/TimeEvolution_CrossCorrelationCalculator.cpp @@ -33,16 +33,14 @@ using Eigen::VectorXd; namespace mrcpp { - /** @param[in] node: ... * @details This will ... (work in progress) - * - * - * - * + * + * + * + * */ -void TimeEvolution_CrossCorrelationCalculator::calcNode(MWNode<2> &node) -{ +void TimeEvolution_CrossCorrelationCalculator::calcNode(MWNode<2> &node) { node.zeroCoefs(); int type = node.getMWTree().getMRA().getScalingBasis().getScalingType(); switch (type) { @@ -63,49 +61,43 @@ void TimeEvolution_CrossCorrelationCalculator::calcNode(MWNode<2> &node) node.calcNorms(); } - - /** @param[in] node: ... * @details This will ... 
(work in progress) - * - * - * - * + * + * + * + * */ -//template -void TimeEvolution_CrossCorrelationCalculator::applyCcc(MWNode<2> &node) -{ - //std::cout << node; - // The scale of J power integrals: - //int scale = node.getScale() + 1; //scale = n = (n - 1) + 1 - - int t_dim = node.getTDim(); //t_dim = 4 - int kp1_d = node.getKp1_d(); //kp1_d = (k + 1)^2 +// template +void TimeEvolution_CrossCorrelationCalculator::applyCcc(MWNode<2> &node) { + // std::cout << node; + // The scale of J power integrals: + // int scale = node.getScale() + 1; //scale = n = (n - 1) + 1 + + int t_dim = node.getTDim(); // t_dim = 4 + int kp1_d = node.getKp1_d(); // kp1_d = (k + 1)^2 VectorXd vec_o = VectorXd::Zero(t_dim * kp1_d); const NodeIndex<2> &idx = node.getNodeIndex(); - auto & J_power_inetgarls = *this->J_power_inetgarls[node.getScale() + 1]; - - for (int i = 0; i < t_dim; i++) - { + auto &J_power_inetgarls = *this->J_power_inetgarls[node.getScale() + 1]; + + for (int i = 0; i < t_dim; i++) { NodeIndex<2> l = idx.child(i); int l_b = l[1] - l[0]; int vec_o_segment_index = 0; - for( int p = 0; p <= node.getOrder(); p++ ) - for( int j = 0; j <= node.getOrder(); j++ ) - { - //std::min(M, N) could be used for breaking the following loop - //this->cross_correlation->Matrix.size() should be big enough a priori - for( int k = 0; 2*k + p + j < J_power_inetgarls[l_b].size(); k++ ) - { + for (int p = 0; p <= node.getOrder(); p++) + for (int j = 0; j <= node.getOrder(); j++) { + // std::min(M, N) could be used for breaking the following loop + // this->cross_correlation->Matrix.size() should be big enough a priori + for (int k = 0; 2 * k + p + j < J_power_inetgarls[l_b].size(); k++) { double J; - if( this->imaginary ) J = J_power_inetgarls[l_b][2*k + p + j].imag(); - else J = J_power_inetgarls[l_b][2*k + p + j].real(); - vec_o.segment(i * kp1_d, kp1_d)(vec_o_segment_index) - += - J * cross_correlation->Matrix[k](p, j); //by default eigen library reads a transpose matrix from a file + if 
(this->imaginary) + J = J_power_inetgarls[l_b][2 * k + p + j].imag(); + else + J = J_power_inetgarls[l_b][2 * k + p + j].real(); + vec_o.segment(i * kp1_d, kp1_d)(vec_o_segment_index) += J * cross_correlation->Matrix[k](p, j); // by default eigen library reads a transpose matrix from a file } vec_o_segment_index++; } @@ -113,9 +105,9 @@ void TimeEvolution_CrossCorrelationCalculator::applyCcc(MWNode<2> &node) double *coefs = node.getCoefs(); for (int i = 0; i < t_dim * kp1_d; i++) { - //auto scaling_factor = node.getMWTree().getMRA().getWorldBox().getScalingFactor(0); + // auto scaling_factor = node.getMWTree().getMRA().getWorldBox().getScalingFactor(0); coefs[i] = vec_o(i); - //std::cout<< "coefs[i] = " << coefs[i] << std::endl; + // std::cout<< "coefs[i] = " << coefs[i] << std::endl; } } diff --git a/src/treebuilders/TimeEvolution_CrossCorrelationCalculator.h b/src/treebuilders/TimeEvolution_CrossCorrelationCalculator.h index b2d6d0542..f2a68295f 100644 --- a/src/treebuilders/TimeEvolution_CrossCorrelationCalculator.h +++ b/src/treebuilders/TimeEvolution_CrossCorrelationCalculator.h @@ -30,38 +30,36 @@ #include "core/SchrodingerEvolution_CrossCorrelation.h" #include "functions/JpowerIntegrals.h" - namespace mrcpp { - /** @class TimeEvolution_CrossCorrelationCalculator * * @brief An efficient way to calculate ... (work in progress) * * @details An efficient way to calculate ... 
having the form * \f$ \ldots = \ldots \f$ - * - * - * + * + * + * */ -class TimeEvolution_CrossCorrelationCalculator final : public TreeCalculator<2> -{ +class TimeEvolution_CrossCorrelationCalculator final : public TreeCalculator<2> { public: - TimeEvolution_CrossCorrelationCalculator - (std::map & J, SchrodingerEvolution_CrossCorrelation *cross_correlation, bool imaginary) - : J_power_inetgarls(J), cross_correlation(cross_correlation), imaginary(imaginary){} -//private: + TimeEvolution_CrossCorrelationCalculator(std::map &J, SchrodingerEvolution_CrossCorrelation *cross_correlation, bool imaginary) + : J_power_inetgarls(J) + , cross_correlation(cross_correlation) + , imaginary(imaginary) {} + // private: std::map J_power_inetgarls; SchrodingerEvolution_CrossCorrelation *cross_correlation; - + /// @brief If False then the calculator is using th real part of integrals, otherwise - the imaginary part. bool imaginary; void calcNode(MWNode<2> &node) override; - //template + // template void applyCcc(MWNode<2> &node); - //template void applyCcc(MWNode<2> &node, CrossCorrelationCache &ccc); + // template void applyCcc(MWNode<2> &node, CrossCorrelationCache &ccc); }; } // namespace mrcpp diff --git a/src/treebuilders/TreeAdaptor.h b/src/treebuilders/TreeAdaptor.h index a46bab648..80cecb09e 100644 --- a/src/treebuilders/TreeAdaptor.h +++ b/src/treebuilders/TreeAdaptor.h @@ -30,7 +30,7 @@ namespace mrcpp { -template class TreeAdaptor { +template class TreeAdaptor { public: TreeAdaptor(int ms) : maxScale(ms) {} @@ -38,9 +38,9 @@ template class TreeAdaptor { void setMaxScale(int ms) { this->maxScale = ms; } - void splitNodeVector(MWNodeVector &out, MWNodeVector &inp) const { + void splitNodeVector(MWNodeVector &out, MWNodeVector &inp) const { for (int n = 0; n < inp.size(); n++) { - MWNode &node = *inp[n]; + MWNode &node = *inp[n]; // Can be BranchNode in operator application if (node.isBranchNode()) continue; if (node.getScale() + 2 > this->maxScale) continue; @@ -54,7 +54,7 
@@ template class TreeAdaptor { protected: int maxScale; - virtual bool splitNode(const MWNode &node) const = 0; + virtual bool splitNode(const MWNode &node) const = 0; }; } // namespace mrcpp diff --git a/src/treebuilders/TreeBuilder.cpp b/src/treebuilders/TreeBuilder.cpp index 223d94794..ba0e5d973 100644 --- a/src/treebuilders/TreeBuilder.cpp +++ b/src/treebuilders/TreeBuilder.cpp @@ -35,13 +35,12 @@ namespace mrcpp { -template -void TreeBuilder::build(MWTree &tree, TreeCalculator &calculator, TreeAdaptor &adaptor, int maxIter) const { +template void TreeBuilder::build(MWTree &tree, TreeCalculator &calculator, TreeAdaptor &adaptor, int maxIter) const { Timer calc_t(false), split_t(false), norm_t(false); println(10, " == Building tree"); - MWNodeVector *newVec = nullptr; - MWNodeVector *workVec = calculator.getInitialWorkVector(tree); + MWNodeVector *newVec = nullptr; + MWNodeVector *workVec = calculator.getInitialWorkVector(tree); double sNorm = 0.0; double wNorm = 0.0; @@ -69,7 +68,7 @@ void TreeBuilder::build(MWTree &tree, TreeCalculator &calculator, TreeA norm_t.stop(); split_t.resume(); - newVec = new MWNodeVector; + newVec = new MWNodeVector; if (iter >= maxIter and maxIter >= 0) workVec->clear(); adaptor.splitNodeVector(*newVec, *workVec); split_t.stop(); @@ -87,11 +86,11 @@ void TreeBuilder::build(MWTree &tree, TreeCalculator &calculator, TreeA print::time(10, "Time split", split_t); } -template void TreeBuilder::clear(MWTree &tree, TreeCalculator &calculator) const { +template void TreeBuilder::clear(MWTree &tree, TreeCalculator &calculator) const { println(10, " == Clearing tree"); Timer clean_t; - MWNodeVector nodeVec; + MWNodeVector nodeVec; tree_utils::make_node_table(tree, nodeVec); calculator.calcNodeVector(nodeVec); // clear all coefficients clean_t.stop(); @@ -104,16 +103,16 @@ template void TreeBuilder::clear(MWTree &tree, TreeCalculator & print::separator(10, ' '); } -template int TreeBuilder::split(MWTree &tree, TreeAdaptor &adaptor, bool 
passCoefs) const { +template int TreeBuilder::split(MWTree &tree, TreeAdaptor &adaptor, bool passCoefs) const { println(10, " == Refining tree"); Timer split_t; - MWNodeVector newVec; - MWNodeVector *workVec = tree.copyEndNodeTable(); + MWNodeVector newVec; + MWNodeVector *workVec = tree.copyEndNodeTable(); adaptor.splitNodeVector(newVec, *workVec); if (passCoefs) { for (int i = 0; i < workVec->size(); i++) { - MWNode &node = *(*workVec)[i]; + MWNode &node = *(*workVec)[i]; if (node.isBranchNode()) { node.giveChildrenCoefs(true); } } } @@ -131,11 +130,11 @@ template int TreeBuilder::split(MWTree &tree, TreeAdaptor &adap return newVec.size(); } -template void TreeBuilder::calc(MWTree &tree, TreeCalculator &calculator) const { +template void TreeBuilder::calc(MWTree &tree, TreeCalculator &calculator) const { println(10, " == Calculating tree"); Timer calc_t; - MWNodeVector *workVec = calculator.getInitialWorkVector(tree); + MWNodeVector *workVec = calculator.getInitialWorkVector(tree); calculator.calcNodeVector(*workVec); printout(10, " -- #" << std::setw(3) << 0 << ": Calculated "); printout(10, std::setw(6) << workVec->size() << " nodes "); @@ -148,26 +147,30 @@ template void TreeBuilder::calc(MWTree &tree, TreeCalculator &c print::time(10, "Time calc", calc_t); } -template double TreeBuilder::calcScalingNorm(const MWNodeVector &vec) const { +template double TreeBuilder::calcScalingNorm(const MWNodeVector &vec) const { double sNorm = 0.0; for (int i = 0; i < vec.size(); i++) { - const MWNode &node = *vec[i]; + const MWNode &node = *vec[i]; if (node.getDepth() >= 0) sNorm += node.getScalingNorm(); } return sNorm; } -template double TreeBuilder::calcWaveletNorm(const MWNodeVector &vec) const { +template double TreeBuilder::calcWaveletNorm(const MWNodeVector &vec) const { double wNorm = 0.0; for (int i = 0; i < vec.size(); i++) { - const MWNode &node = *vec[i]; + const MWNode &node = *vec[i]; if (node.getDepth() >= 0) wNorm += node.getWaveletNorm(); } return wNorm; } 
-template class TreeBuilder<1>; -template class TreeBuilder<2>; -template class TreeBuilder<3>; +template class TreeBuilder<1, double>; +template class TreeBuilder<2, double>; +template class TreeBuilder<3, double>; + +template class TreeBuilder<1, ComplexDouble>; +template class TreeBuilder<2, ComplexDouble>; +template class TreeBuilder<3, ComplexDouble>; } // namespace mrcpp diff --git a/src/treebuilders/TreeBuilder.h b/src/treebuilders/TreeBuilder.h index 313e9f4f4..81c32afe6 100644 --- a/src/treebuilders/TreeBuilder.h +++ b/src/treebuilders/TreeBuilder.h @@ -29,16 +29,16 @@ namespace mrcpp { -template class TreeBuilder final { +template class TreeBuilder final { public: - void build(MWTree &tree, TreeCalculator &calculator, TreeAdaptor &adaptor, int maxIter) const; - void clear(MWTree &tree, TreeCalculator &calculator) const; - void calc(MWTree &tree, TreeCalculator &calculator) const; - int split(MWTree &tree, TreeAdaptor &adaptor, bool passCoefs) const; + void build(MWTree &tree, TreeCalculator &calculator, TreeAdaptor &adaptor, int maxIter) const; + void clear(MWTree &tree, TreeCalculator &calculator) const; + void calc(MWTree &tree, TreeCalculator &calculator) const; + int split(MWTree &tree, TreeAdaptor &adaptor, bool passCoefs) const; private: - double calcScalingNorm(const MWNodeVector &vec) const; - double calcWaveletNorm(const MWNodeVector &vec) const; + double calcScalingNorm(const MWNodeVector &vec) const; + double calcWaveletNorm(const MWNodeVector &vec) const; }; } // namespace mrcpp diff --git a/src/treebuilders/TreeCalculator.h b/src/treebuilders/TreeCalculator.h index acd9f00f8..1bf41f407 100644 --- a/src/treebuilders/TreeCalculator.h +++ b/src/treebuilders/TreeCalculator.h @@ -29,20 +29,20 @@ namespace mrcpp { -template class TreeCalculator { +template class TreeCalculator { public: TreeCalculator() = default; virtual ~TreeCalculator() = default; - virtual MWNodeVector *getInitialWorkVector(MWTree &tree) const { return tree.copyEndNodeTable(); 
} + virtual MWNodeVector *getInitialWorkVector(MWTree &tree) const { return tree.copyEndNodeTable(); } - virtual void calcNodeVector(MWNodeVector &nodeVec) { + virtual void calcNodeVector(MWNodeVector &nodeVec) { #pragma omp parallel shared(nodeVec) num_threads(mrcpp_get_num_threads()) { int nNodes = nodeVec.size(); #pragma omp for schedule(guided) for (int n = 0; n < nNodes; n++) { - MWNode &node = *nodeVec[n]; + MWNode &node = *nodeVec[n]; calcNode(node); } } @@ -50,7 +50,7 @@ template class TreeCalculator { } protected: - virtual void calcNode(MWNode &node) = 0; + virtual void calcNode(MWNode &node) = 0; virtual void postProcess() {} }; diff --git a/src/treebuilders/WaveletAdaptor.h b/src/treebuilders/WaveletAdaptor.h index 15da130e1..829039bf4 100644 --- a/src/treebuilders/WaveletAdaptor.h +++ b/src/treebuilders/WaveletAdaptor.h @@ -31,18 +31,16 @@ namespace mrcpp { -template class WaveletAdaptor : public TreeAdaptor { +template class WaveletAdaptor : public TreeAdaptor { public: WaveletAdaptor(double pr, int ms, bool ap = false, double sf = 1.0) - : TreeAdaptor(ms) + : TreeAdaptor(ms) , absPrec(ap) , prec(pr) , splitFac(sf) {} ~WaveletAdaptor() override = default; - void setPrecFunction(const std::function &idx)> &prec_func) { - this->precFunc = prec_func; - } + void setPrecFunction(const std::function &idx)> &prec_func) { this->precFunc = prec_func; } protected: bool absPrec; @@ -50,7 +48,7 @@ template class WaveletAdaptor : public TreeAdaptor { double splitFac; std::function &idx)> precFunc = [](const NodeIndex &idx) { return 1.0; }; - bool splitNode(const MWNode &node) const override { + bool splitNode(const MWNode &node) const override { auto precFac = this->precFunc(node.getNodeIndex()); // returns 1.0 by default return tree_utils::split_check(node, this->prec * precFac, this->splitFac, this->absPrec); } diff --git a/src/treebuilders/add.cpp b/src/treebuilders/add.cpp index 86b7f30a7..4ee28cff6 100644 --- a/src/treebuilders/add.cpp +++ 
b/src/treebuilders/add.cpp @@ -61,19 +61,11 @@ namespace mrcpp { * no coefs). * */ -template -void add(double prec, - FunctionTree &out, - double a, - FunctionTree &inp_a, - double b, - FunctionTree &inp_b, - int maxIter, - bool absPrec) { - FunctionTreeVector tmp_vec; +template void add(double prec, FunctionTree &out, T a, FunctionTree &inp_a, T b, FunctionTree &inp_b, int maxIter, bool absPrec, bool conjugate) { + FunctionTreeVector tmp_vec; tmp_vec.push_back(std::make_tuple(a, &inp_a)); tmp_vec.push_back(std::make_tuple(b, &inp_b)); - add(prec, out, tmp_vec, maxIter, absPrec); + add(prec, out, tmp_vec, maxIter, absPrec, conjugate); } /** @brief Addition of several MW function representations, adaptive grid @@ -98,25 +90,26 @@ void add(double prec, * no coefs). * */ -template void add(double prec, FunctionTree &out, FunctionTreeVector &inp, int maxIter, bool absPrec) { +template void add(double prec, FunctionTree &out, FunctionTreeVector &inp, int maxIter, bool absPrec, bool conjugate) { for (auto i = 0; i < inp.size(); i++) if (out.getMRA() != get_func(inp, i).getMRA()) MSG_ABORT("Incompatible MRA"); int maxScale = out.getMRA().getMaxScale(); - TreeBuilder builder; - WaveletAdaptor adaptor(prec, maxScale, absPrec); - AdditionCalculator calculator(inp); + TreeBuilder builder; + WaveletAdaptor adaptor(prec, maxScale, absPrec); + AdditionCalculator calculator(inp, conjugate); builder.build(out, calculator, adaptor, maxIter); Timer trans_t; out.mwTransform(BottomUp); out.calcSquareNorm(); + trans_t.stop(); Timer clean_t; for (int i = 0; i < inp.size(); i++) { - FunctionTree &tree = get_func(inp, i); + FunctionTree &tree = get_func(inp, i); tree.deleteGenerated(); } clean_t.stop(); @@ -126,67 +119,61 @@ template void add(double prec, FunctionTree &out, FunctionTreeVector< print::separator(10, ' '); } -template void add(double prec, FunctionTree &out, std::vector *> &inp, int maxIter, bool absPrec) { - FunctionTreeVector inp_vec; +template void add(double prec, 
FunctionTree &out, std::vector *> &inp, int maxIter, bool absPrec, bool conjugate) { + FunctionTreeVector inp_vec; for (auto &t : inp) inp_vec.push_back({1.0, t}); - add(prec, out, inp_vec, maxIter, absPrec); + add(prec, out, inp_vec, maxIter, absPrec, conjugate); } -template void add<1>(double prec, - FunctionTree<1> &out, - double a, - FunctionTree<1> &tree_a, - double b, - FunctionTree<1> &tree_b, - int maxIter, - bool absPrec); -template void add<2>(double prec, - FunctionTree<2> &out, - double a, - FunctionTree<2> &tree_a, - double b, - FunctionTree<2> &tree_b, - int maxIter, - bool absPrec); -template void add<3>(double prec, - FunctionTree<3> &out, - double a, - FunctionTree<3> &tree_a, - double b, - FunctionTree<3> &tree_b, - int maxIter, - bool absPrec); - -template void add<1>(double prec, - FunctionTree<1> &out, - FunctionTreeVector<1> &inp, - int maxIter, - bool absPrec); -template void add<2>(double prec, - FunctionTree<2> &out, - FunctionTreeVector<2> &inp, - int maxIter, - bool absPrec); -template void add<3>(double prec, - FunctionTree<3> &out, - FunctionTreeVector<3> &inp, - int maxIter, - bool absPrec); - -template void add<1>(double prec, - FunctionTree<1> &out, - std::vector *> &inp, - int maxIter, - bool absPrec); -template void add<2>(double prec, - FunctionTree<2> &out, - std::vector *> &inp, - int maxIter, - bool absPrec); -template void add<3>(double prec, - FunctionTree<3> &out, - std::vector *> &inp, - int maxIter, - bool absPrec); +template void +add<1, double>(double prec, FunctionTree<1, double> &out, double a, FunctionTree<1, double> &tree_a, double b, FunctionTree<1, double> &tree_b, int maxIter, bool absPrec, bool conjugate); +template void +add<2, double>(double prec, FunctionTree<2, double> &out, double a, FunctionTree<2, double> &tree_a, double b, FunctionTree<2, double> &tree_b, int maxIter, bool absPrec, bool conjugate); +template void +add<3, double>(double prec, FunctionTree<3, double> &out, double a, FunctionTree<3, double> 
&tree_a, double b, FunctionTree<3, double> &tree_b, int maxIter, bool absPrec, bool conjugate); + +template void add<1, double>(double prec, FunctionTree<1, double> &out, FunctionTreeVector<1, double> &inp, int maxIter, bool absPrec, bool conjugate); +template void add<2, double>(double prec, FunctionTree<2, double> &out, FunctionTreeVector<2, double> &inp, int maxIter, bool absPrec, bool conjugate); +template void add<3, double>(double prec, FunctionTree<3, double> &out, FunctionTreeVector<3, double> &inp, int maxIter, bool absPrec, bool conjugate); + +template void add<1, double>(double prec, FunctionTree<1, double> &out, std::vector *> &inp, int maxIter, bool absPrec, bool conjugate); +template void add<2, double>(double prec, FunctionTree<2, double> &out, std::vector *> &inp, int maxIter, bool absPrec, bool conjugate); +template void add<3, double>(double prec, FunctionTree<3, double> &out, std::vector *> &inp, int maxIter, bool absPrec, bool conjugate); + +template void add<1, ComplexDouble>(double prec, + FunctionTree<1, ComplexDouble> &out, + ComplexDouble a, + FunctionTree<1, ComplexDouble> &tree_a, + ComplexDouble b, + FunctionTree<1, ComplexDouble> &tree_b, + int maxIter, + bool absPrec, + bool conjugate); +template void add<2, ComplexDouble>(double prec, + FunctionTree<2, ComplexDouble> &out, + ComplexDouble a, + FunctionTree<2, ComplexDouble> &tree_a, + ComplexDouble b, + FunctionTree<2, ComplexDouble> &tree_b, + int maxIter, + bool absPrec, + bool conjugate); +template void add<3, ComplexDouble>(double prec, + FunctionTree<3, ComplexDouble> &out, + ComplexDouble a, + FunctionTree<3, ComplexDouble> &tree_a, + ComplexDouble b, + FunctionTree<3, ComplexDouble> &tree_b, + int maxIter, + bool absPrec, + bool conjugate); + +template void add<1, ComplexDouble>(double prec, FunctionTree<1, ComplexDouble> &out, FunctionTreeVector<1, ComplexDouble> &inp, int maxIter, bool absPrec, bool conjugate); +template void add<2, ComplexDouble>(double prec, FunctionTree<2, 
ComplexDouble> &out, FunctionTreeVector<2, ComplexDouble> &inp, int maxIter, bool absPrec, bool conjugate); +template void add<3, ComplexDouble>(double prec, FunctionTree<3, ComplexDouble> &out, FunctionTreeVector<3, ComplexDouble> &inp, int maxIter, bool absPrec, bool conjugate); + +template void add<1, ComplexDouble>(double prec, FunctionTree<1, ComplexDouble> &out, std::vector *> &inp, int maxIter, bool absPrec, bool conjugate); +template void add<2, ComplexDouble>(double prec, FunctionTree<2, ComplexDouble> &out, std::vector *> &inp, int maxIter, bool absPrec, bool conjugate); +template void add<3, ComplexDouble>(double prec, FunctionTree<3, ComplexDouble> &out, std::vector *> &inp, int maxIter, bool absPrec, bool conjugate); } // namespace mrcpp diff --git a/src/treebuilders/add.h b/src/treebuilders/add.h index 105245829..1f94dbc9d 100644 --- a/src/treebuilders/add.h +++ b/src/treebuilders/add.h @@ -25,26 +25,11 @@ #pragma once - namespace mrcpp { -template void add(double prec, - FunctionTree &out, - double a, - FunctionTree &tree_a, - double b, - FunctionTree &tree_b, - int maxIter = -1, - bool absPrec = false); -template void add(double prec, - FunctionTree &out, - FunctionTreeVector &inp, - int maxIter = -1, - bool absPrec = false); -template void add(double prec, - FunctionTree &out, - std::vector *> &inp, - int maxIter = -1, - bool absPrec = false); +template +void add(double prec, FunctionTree &out, T a, FunctionTree &tree_a, T b, FunctionTree &tree_b, int maxIter = -1, bool absPrec = false, bool conjugate = false); +template void add(double prec, FunctionTree &out, FunctionTreeVector &inp, int maxIter = -1, bool absPrec = false, bool conjugate = false); +template void add(double prec, FunctionTree &out, std::vector *> &inp, int maxIter = -1, bool absPrec = false, bool conjugate = false); } // namespace mrcpp diff --git a/src/treebuilders/apply.cpp b/src/treebuilders/apply.cpp index 3dc49de3c..cfe17b86f 100644 --- a/src/treebuilders/apply.cpp +++ 
b/src/treebuilders/apply.cpp @@ -38,10 +38,11 @@ #include "trees/FunctionTree.h" #include "utils/Printer.h" #include "utils/Timer.h" +#include namespace mrcpp { -template void apply_on_unit_cell(bool inside, double prec, FunctionTree &out, ConvolutionOperator &oper, FunctionTree &inp, int maxIter, bool absPrec); +template void apply_on_unit_cell(bool inside, double prec, FunctionTree &out, ConvolutionOperator &oper, FunctionTree &inp, int maxIter, bool absPrec); /** @brief Application of MW integral convolution operator * @@ -64,16 +65,16 @@ template void apply_on_unit_cell(bool inside, double prec, FunctionTree< * no coefs). * */ -template void apply(double prec, FunctionTree &out, ConvolutionOperator &oper, FunctionTree &inp, int maxIter, bool absPrec) { +template void apply(double prec, FunctionTree &out, ConvolutionOperator &oper, FunctionTree &inp, int maxIter, bool absPrec) { if (out.getMRA() != inp.getMRA()) MSG_ABORT("Incompatible MRA"); Timer pre_t; oper.calcBandWidths(prec); int maxScale = out.getMRA().getMaxScale(); - WaveletAdaptor adaptor(prec, maxScale, absPrec); - ConvolutionCalculator calculator(prec, oper, inp); + WaveletAdaptor adaptor(prec, maxScale, absPrec); + ConvolutionCalculator calculator(prec, oper, inp); pre_t.stop(); - TreeBuilder builder; + TreeBuilder builder; builder.build(out, calculator, adaptor, maxIter); Timer post_t; @@ -91,6 +92,49 @@ template void apply(double prec, FunctionTree &out, ConvolutionOperat print::separator(10, ' '); } +/** @brief Application of MW integral convolution operator on Four component + * + * @param[in] prec: Build precision of output function + * @param[out] out: Output function to be built + * @param[in] oper: Convolution operator to apply + * @param[in] inp: Input function + * @param[in] metric: 4x4 array with coefficients that relates the in and out components + * @param[in] maxIter: Maximum number of refinement iterations in output tree, default -1 + * @param[in] absPrec: Build output tree based on 
absolute precision, default false + * + * @details The output function will be computed using the general algorithm: + * - For each input component apply the operator + * - Compute MW coefs on current grid + * - Refine grid where necessary based on `prec` + * - Repeat until convergence or `maxIter` is reached + * - `prec < 0` or `maxIter = 0` means NO refinement + * - `maxIter < 0` means no bound + * - After application multiply by metric coefficient, and put in relevant output component + * + * @note This algorithm will start at whatever grid is present in the `out` + * tree when the function is called (this grid should however be EMPTY, e.i. + * no coefs). + * + */ +template void apply(double prec, CompFunction &out, ConvolutionOperator &oper, const CompFunction &inp, ComplexDouble metric[4][4], int maxIter, bool absPrec) { + + for (int icomp = 0; icomp < inp.Ncomp(); icomp++) { + for (int ocomp = 0; ocomp < 4; ocomp++) { + if (std::norm(metric[icomp][ocomp]) > MachinePrec) { + if (inp.isreal()) { + if (out.CompD[ocomp] == nullptr) out.alloc_comp(ocomp); + apply(prec, *out.CompD[ocomp], oper, *inp.CompD[icomp], maxIter, absPrec); + if (abs(metric[icomp][ocomp] - 1.0) > MachinePrec) { out.CompD[ocomp]->rescale(metric[icomp][ocomp].real()); } + } else { + if (out.CompC[ocomp] == nullptr) out.alloc_comp(ocomp); + apply(prec, *out.CompC[ocomp], oper, *inp.CompC[icomp], maxIter, absPrec); + if (abs(metric[icomp][ocomp] - 1.0) > MachinePrec) { out.CompC[ocomp]->rescale(metric[icomp][ocomp]); } + } + } + } + } +} + /** @brief Application of MW integral convolution operator * * @param[in] inside: Use points inside (true) or outside (false) the unitcell @@ -113,18 +157,18 @@ template void apply(double prec, FunctionTree &out, ConvolutionOperat * no coefs). 
* */ -template void apply_on_unit_cell(bool inside, double prec, FunctionTree &out, ConvolutionOperator &oper, FunctionTree &inp, int maxIter, bool absPrec) { +template void apply_on_unit_cell(bool inside, double prec, FunctionTree &out, ConvolutionOperator &oper, FunctionTree &inp, int maxIter, bool absPrec) { if (out.getMRA() != inp.getMRA()) MSG_ABORT("Incompatible MRA"); Timer pre_t; oper.calcBandWidths(prec); int maxScale = out.getMRA().getMaxScale(); - WaveletAdaptor adaptor(prec, maxScale, absPrec); - ConvolutionCalculator calculator(prec, oper, inp); + WaveletAdaptor adaptor(prec, maxScale, absPrec); + ConvolutionCalculator calculator(prec, oper, inp); calculator.startManipulateOperator(inside); pre_t.stop(); - TreeBuilder builder; + TreeBuilder builder; builder.build(out, calculator, adaptor, maxIter); Timer post_t; @@ -166,7 +210,7 @@ template void apply_on_unit_cell(bool inside, double prec, FunctionTree< * no coefs). * */ -template void apply(double prec, FunctionTree &out, ConvolutionOperator &oper, FunctionTree &inp, FunctionTreeVector &precTrees, int maxIter, bool absPrec) { +template void apply(double prec, FunctionTree &out, ConvolutionOperator &oper, FunctionTree &inp, FunctionTreeVector &precTrees, int maxIter, bool absPrec) { Timer pre_t; oper.calcBandWidths(prec); int maxScale = out.getMRA().getMaxScale(); @@ -183,13 +227,13 @@ template void apply(double prec, FunctionTree &out, ConvolutionOperat return 1.0 / maxNorm; }; - WaveletAdaptor adaptor(prec, maxScale, absPrec); + WaveletAdaptor adaptor(prec, maxScale, absPrec); adaptor.setPrecFunction(precFunc); - ConvolutionCalculator calculator(prec, oper, inp); + ConvolutionCalculator calculator(prec, oper, inp); calculator.setPrecFunction(precFunc); pre_t.stop(); - TreeBuilder builder; + TreeBuilder builder; builder.build(out, calculator, adaptor, maxIter); Timer post_t; @@ -205,6 +249,24 @@ template void apply(double prec, FunctionTree &out, ConvolutionOperat print::separator(10, ' '); } 
+template +void apply(double prec, CompFunction &out, ConvolutionOperator &oper, CompFunction &inp, FunctionTreeVector *precTrees, ComplexDouble metric[4][4], int maxIter, bool absPrec) { + + for (int icomp = 0; icomp < inp.Ncomp(); icomp++) { + for (int ocomp = 0; ocomp < 4; ocomp++) { + if (std::norm(metric[icomp][ocomp]) > MachinePrec) { + if (inp.isreal()) { + apply(prec, *out.CompD[ocomp], oper, *inp.CompD[icomp], precTrees[icomp], maxIter, absPrec); + if (abs(metric[icomp][ocomp] - 1.0) > MachinePrec) { out.CompD[ocomp]->rescale(metric[icomp][ocomp]); } + } else { + apply(prec, *out.CompC[ocomp], oper, *inp.CompC[icomp], precTrees[icomp], maxIter, absPrec); + if (abs(metric[icomp][ocomp] - 1.0) > MachinePrec) { out.CompC[ocomp]->rescale(metric[icomp][ocomp]); } + } + } + } + } +} + /** @brief Application of MW integral convolution operator on a periodic cell, excluding contributions inside the unit cell. * @@ -227,10 +289,29 @@ template void apply(double prec, FunctionTree &out, ConvolutionOperat * no coefs). 
* */ -template void apply_far_field(double prec, FunctionTree &out, ConvolutionOperator &oper, FunctionTree &inp, int maxIter, bool absPrec) { +template void apply_far_field(double prec, FunctionTree &out, ConvolutionOperator &oper, FunctionTree &inp, int maxIter, bool absPrec) { apply_on_unit_cell(false, prec, out, oper, inp, maxIter, absPrec); } +template void apply_far_field(double prec, CompFunction &out, ConvolutionOperator &oper, CompFunction &inp, ComplexDouble metric[4][4], int maxIter, bool absPrec) { + + for (int icomp = 0; icomp < 4; icomp++) { + if (inp.Comp[icomp] != nullptr) { + for (int ocomp = 0; ocomp < 4; ocomp++) { + if (std::norm(metric[icomp][ocomp]) > MachinePrec) { + if (inp.isreal()) { + apply_on_unit_cell(false, prec, *out.CompD[ocomp], oper, *inp.CompD[icomp], maxIter, absPrec); + if (abs(metric[icomp][ocomp] - 1.0) > MachinePrec) { out.CompD[ocomp]->rescale(metric[icomp][ocomp]); } + } else { + apply_on_unit_cell(false, prec, *out.CompC[ocomp], oper, *inp.CompC[icomp], maxIter, absPrec); + if (abs(metric[icomp][ocomp] - 1.0) > MachinePrec) { out.CompC[ocomp]->rescale(metric[icomp][ocomp]); } + } + } + } + } + } +} + /** @brief Application of MW integral convolution operator on a periodic cell, excluding contributions outside the unit cell. * @@ -253,10 +334,29 @@ template void apply_far_field(double prec, FunctionTree &out, Convolu * no coefs). 
* */ -template void apply_near_field(double prec, FunctionTree &out, ConvolutionOperator &oper, FunctionTree &inp, int maxIter, bool absPrec) { +template void apply_near_field(double prec, FunctionTree &out, ConvolutionOperator &oper, FunctionTree &inp, int maxIter, bool absPrec) { apply_on_unit_cell(true, prec, out, oper, inp, maxIter, absPrec); } +template void apply_near_field(double prec, CompFunction &out, ConvolutionOperator &oper, CompFunction &inp, ComplexDouble metric[4][4], int maxIter, bool absPrec) { + + for (int icomp = 0; icomp < 4; icomp++) { + if (inp.Comp[icomp] != nullptr) { + for (int ocomp = 0; ocomp < 4; ocomp++) { + if (std::norm(metric[icomp][ocomp]) > MachinePrec) { + if (inp.isreal()) { + apply_on_unit_cell(true, prec, *out.CompD[ocomp], oper, *inp.CompD[icomp], maxIter, absPrec); + if (abs(metric[icomp][ocomp] - 1.0) > MachinePrec) { out.CompD[ocomp]->rescale(metric[icomp][ocomp]); } + } else { + apply_on_unit_cell(true, prec, *out.CompC[ocomp], oper, *inp.CompC[icomp], maxIter, absPrec); + if (abs(metric[icomp][ocomp] - 1.0) > MachinePrec) { out.CompC[ocomp]->rescale(metric[icomp][ocomp]); } + } + } + } + } + } +} + /** @brief Application of MW derivative operator * * @param[out] out: Output function to be built @@ -273,9 +373,9 @@ template void apply_near_field(double prec, FunctionTree &out, Convol * @note The output function should contain only empty root nodes at entry. 
* */ -template void apply(FunctionTree &out, DerivativeOperator &oper, FunctionTree &inp, int dir) { +template void apply(FunctionTree &out, DerivativeOperator &oper, FunctionTree &inp, int dir) { if (out.getMRA() != inp.getMRA()) MSG_ABORT("Incompatible MRA"); - TreeBuilder builder; + TreeBuilder builder; int maxScale = out.getMRA().getMaxScale(); int bw[D]; // Operator bandwidth in [x,y,z] @@ -285,14 +385,14 @@ template void apply(FunctionTree &out, DerivativeOperator &oper, F Timer pre_t; oper.calcBandWidths(1.0); // Fixed 0 or 1 for derivatives bw[dir] = oper.getMaxBandWidth(); - CopyAdaptor pre_adaptor(inp, maxScale, bw); - DefaultCalculator pre_calculator; + CopyAdaptor pre_adaptor(inp, maxScale, bw); + DefaultCalculator pre_calculator; builder.build(out, pre_calculator, pre_adaptor, -1); pre_t.stop(); // Apply operator on fixed expanded grid - SplitAdaptor apply_adaptor(maxScale, false); // Splits no nodes - DerivativeCalculator apply_calculator(dir, oper, inp); + SplitAdaptor apply_adaptor(maxScale, false); // Splits no nodes + DerivativeCalculator apply_calculator(dir, oper, inp); builder.build(out, apply_calculator, apply_adaptor, 0); if (out.isPeriodic()) out.rescale(std::pow(2.0, -oper.getOperatorRoot())); @@ -308,6 +408,37 @@ template void apply(FunctionTree &out, DerivativeOperator &oper, F print::separator(10, ' '); } +template void apply(CompFunction &out, DerivativeOperator &oper, CompFunction &inp, int dir, ComplexDouble metric[4][4]) { + // TODO: sums and not only each components independently, when concrete examples with non diagonal metric are tested + + for (int icomp = 0; icomp < inp.Ncomp(); icomp++) { + for (int ocomp = 0; ocomp < 4; ocomp++) { + if (std::norm(metric[icomp][ocomp]) > MachinePrec) { + if (inp.isreal() and (std::imag(metric[icomp][ocomp]) < MachinePrec or inp.Ncomp() == 1)) { + apply(*out.CompD[ocomp], oper, *inp.CompD[icomp], dir); + if (std::norm(metric[icomp][ocomp] - 1.0) > MachinePrec) { + if 
(std::imag(metric[icomp][ocomp]) < MachinePrec) + out.CompD[ocomp]->rescale(std::real(metric[icomp][ocomp])); + else + out.func_ptr->data.c1[ocomp] *= metric[icomp][ocomp]; // To consider: multiply c1 in rescale? + } + out.func_ptr->isreal = 1; + } else { + if (inp.isreal()) { + apply(*out.CompD[ocomp], oper, *inp.CompD[icomp], dir); + out.CompD[icomp]->CopyTreeToComplex(out.CompC[ocomp]); + out.func_ptr->isreal = 0; + out.func_ptr->iscomplex = 1; + } else { + apply(*out.CompC[ocomp], oper, *inp.CompC[icomp], dir); + } + if (abs(metric[icomp][ocomp] - 1.0) > MachinePrec) { out.CompC[ocomp]->rescale(metric[icomp][ocomp]); } + } + } + } + } +} + /** @brief Calculation of gradient vector of a function * * @param[in] oper: Derivative operator to apply @@ -320,16 +451,46 @@ template void apply(FunctionTree &out, DerivativeOperator &oper, F * @note The length of the output vector will be the template dimension D. * */ -template FunctionTreeVector gradient(DerivativeOperator &oper, FunctionTree &inp) { - FunctionTreeVector out; +template FunctionTreeVector gradient(DerivativeOperator &oper, FunctionTree &inp) { + FunctionTreeVector out; for (int d = 0; d < D; d++) { - auto *grad_d = new FunctionTree(inp.getMRA()); + auto *grad_d = new FunctionTree(inp.getMRA()); apply(*grad_d, oper, inp, d); out.push_back({1.0, grad_d}); } return out; } +std::vector *> gradient(DerivativeOperator<3> &oper, CompFunction<3> &inp, ComplexDouble metric[4][4]) { + std::vector *> out; + + for (int d = 0; d < 3; d++) { + CompFunction<3> *grad_d = new CompFunction<3>(); + for (int icomp = 0; icomp < inp.Ncomp(); icomp++) { + for (int ocomp = 0; ocomp < 4; ocomp++) { + if (std::norm(metric[icomp][ocomp]) > MachinePrec) { + grad_d->func_ptr->Ncomp = ocomp; + if (inp.isreal()) { + grad_d->func_ptr->isreal = 1; + grad_d->func_ptr->iscomplex = 0; + grad_d->CompD[ocomp] = new FunctionTree<3, double>(inp.CompD[0]->getMRA()); + apply(*(grad_d->CompD[ocomp]), oper, *inp.CompD[icomp], d); + if 
(abs(metric[icomp][ocomp] - 1.0) > MachinePrec) { grad_d->CompD[ocomp]->rescale((metric[icomp][ocomp]).real()); } + } else { + grad_d->func_ptr->isreal = 0; + grad_d->func_ptr->iscomplex = 1; + grad_d->CompC[ocomp] = new FunctionTree<3, ComplexDouble>(inp.CompC[0]->getMRA()); + apply(*(grad_d->CompC[ocomp]), oper, *inp.CompC[icomp], d); + if (abs(metric[icomp][ocomp] - 1.0) > MachinePrec) { grad_d->CompC[ocomp]->rescale(metric[icomp][ocomp]); } + } + } + } + } + out.push_back(grad_d); + } + return out; +} + /** @brief Calculation of divergence of a function vector * * @param[out] out: Output function @@ -346,16 +507,16 @@ template FunctionTreeVector gradient(DerivativeOperator &oper, Fun * - The output function should contain only empty root nodes at entry. * */ -template void divergence(FunctionTree &out, DerivativeOperator &oper, FunctionTreeVector &inp) { +template void divergence(FunctionTree &out, DerivativeOperator &oper, FunctionTreeVector &inp) { if (inp.size() != D) MSG_ABORT("Dimension mismatch"); for (auto i = 0; i < inp.size(); i++) if (out.getMRA() != get_func(inp, i).getMRA()) MSG_ABORT("Incompatible MRA"); - FunctionTreeVector tmp_vec; + FunctionTreeVector tmp_vec; for (int d = 0; d < D; d++) { - double coef_d = get_coef(inp, d); - FunctionTree &func_d = get_func(inp, d); - auto *out_d = new FunctionTree(func_d.getMRA()); + T coef_d = get_coef(inp, d); + FunctionTree &func_d = get_func(inp, d); + auto *out_d = new FunctionTree(func_d.getMRA()); apply(*out_d, oper, func_d, d); tmp_vec.push_back(std::make_tuple(coef_d, out_d)); } @@ -364,35 +525,94 @@ template void divergence(FunctionTree &out, DerivativeOperator &op clear(tmp_vec, true); } -template void divergence(FunctionTree &out, DerivativeOperator &oper, std::vector *> &inp) { - FunctionTreeVector inp_vec; +template void divergence(CompFunction &out, DerivativeOperator &oper, FunctionTreeVector *inp, ComplexDouble metric[4][4]) { + MSG_ABORT("not implemented"); +} + +template void 
divergence(FunctionTree &out, DerivativeOperator &oper, std::vector *> &inp) { + FunctionTreeVector inp_vec; for (auto &t : inp) inp_vec.push_back({1.0, t}); divergence(out, oper, inp_vec); } +template void divergence(CompFunction &out, DerivativeOperator &oper, std::vector *> *inp, ComplexDouble metric[4][4]) { + MSG_ABORT("not implemented"); +} -template void apply<1>(double prec, FunctionTree<1> &out, ConvolutionOperator<1> &oper, FunctionTree<1> &inp, int maxIter, bool absPrec); -template void apply<2>(double prec, FunctionTree<2> &out, ConvolutionOperator<2> &oper, FunctionTree<2> &inp, int maxIter, bool absPrec); -template void apply<3>(double prec, FunctionTree<3> &out, ConvolutionOperator<3> &oper, FunctionTree<3> &inp, int maxIter, bool absPrec); -template void apply<1>(double prec, FunctionTree<1> &out, ConvolutionOperator<1> &oper, FunctionTree<1> &inp, FunctionTreeVector<1> &precTrees, int maxIter, bool absPrec); -template void apply<2>(double prec, FunctionTree<2> &out, ConvolutionOperator<2> &oper, FunctionTree<2> &inp, FunctionTreeVector<2> &precTrees, int maxIter, bool absPrec); -template void apply<3>(double prec, FunctionTree<3> &out, ConvolutionOperator<3> &oper, FunctionTree<3> &inp, FunctionTreeVector<3> &precTrees, int maxIter, bool absPrec); -template void apply_far_field<1>(double prec, FunctionTree<1> &out, ConvolutionOperator<1> &oper, FunctionTree<1> &inp, int maxIter, bool absPrec); -template void apply_far_field<2>(double prec, FunctionTree<2> &out, ConvolutionOperator<2> &oper, FunctionTree<2> &inp, int maxIter, bool absPrec); -template void apply_far_field<3>(double prec, FunctionTree<3> &out, ConvolutionOperator<3> &oper, FunctionTree<3> &inp, int maxIter, bool absPrec); -template void apply_near_field<1>(double prec, FunctionTree<1> &out, ConvolutionOperator<1> &oper, FunctionTree<1> &inp, int maxIter, bool absPrec); -template void apply_near_field<2>(double prec, FunctionTree<2> &out, ConvolutionOperator<2> &oper, FunctionTree<2> 
&inp, int maxIter, bool absPrec); -template void apply_near_field<3>(double prec, FunctionTree<3> &out, ConvolutionOperator<3> &oper, FunctionTree<3> &inp, int maxIter, bool absPrec); -template void apply<1>(FunctionTree<1> &out, DerivativeOperator<1> &oper, FunctionTree<1> &inp, int dir); -template void apply<2>(FunctionTree<2> &out, DerivativeOperator<2> &oper, FunctionTree<2> &inp, int dir); -template void apply<3>(FunctionTree<3> &out, DerivativeOperator<3> &oper, FunctionTree<3> &inp, int dir); -template void divergence<1>(FunctionTree<1> &out, DerivativeOperator<1> &oper, FunctionTreeVector<1> &inp); -template void divergence<2>(FunctionTree<2> &out, DerivativeOperator<2> &oper, FunctionTreeVector<2> &inp); -template void divergence<3>(FunctionTree<3> &out, DerivativeOperator<3> &oper, FunctionTreeVector<3> &inp); -template void divergence<1>(FunctionTree<1> &out, DerivativeOperator<1> &oper, std::vector *> &inp); -template void divergence<2>(FunctionTree<2> &out, DerivativeOperator<2> &oper, std::vector *> &inp); -template void divergence<3>(FunctionTree<3> &out, DerivativeOperator<3> &oper, std::vector *> &inp); -template FunctionTreeVector<1> gradient<1>(DerivativeOperator<1> &oper, FunctionTree<1> &inp); -template FunctionTreeVector<2> gradient<2>(DerivativeOperator<2> &oper, FunctionTree<2> &inp); -template FunctionTreeVector<3> gradient<3>(DerivativeOperator<3> &oper, FunctionTree<3> &inp); +template void apply<1, double>(double prec, FunctionTree<1, double> &out, ConvolutionOperator<1> &oper, FunctionTree<1, double> &inp, int maxIter, bool absPrec); +template void apply<2, double>(double prec, FunctionTree<2, double> &out, ConvolutionOperator<2> &oper, FunctionTree<2, double> &inp, int maxIter, bool absPrec); +template void apply<3, double>(double prec, FunctionTree<3, double> &out, ConvolutionOperator<3> &oper, FunctionTree<3, double> &inp, int maxIter, bool absPrec); +template void apply<1>(double prec, CompFunction<1> &out, ConvolutionOperator<1> 
&oper, const CompFunction<1> &inp, ComplexDouble metric[4][4] = nullptr, int maxIter = -1, bool absPrec = false); +template void apply<2>(double prec, CompFunction<2> &out, ConvolutionOperator<2> &oper, const CompFunction<2> &inp, ComplexDouble metric[4][4] = nullptr, int maxIter = -1, bool absPrec = false); +template void apply<3>(double prec, CompFunction<3> &out, ConvolutionOperator<3> &oper, const CompFunction<3> &inp, ComplexDouble metric[4][4] = nullptr, int maxIter = -1, bool absPrec = false); +template void +apply<1, double>(double prec, FunctionTree<1, double> &out, ConvolutionOperator<1> &oper, FunctionTree<1, double> &inp, FunctionTreeVector<1, double> &precTrees, int maxIter, bool absPrec); +template void +apply<2, double>(double prec, FunctionTree<2, double> &out, ConvolutionOperator<2> &oper, FunctionTree<2, double> &inp, FunctionTreeVector<2, double> &precTrees, int maxIter, bool absPrec); +template void +apply<3, double>(double prec, FunctionTree<3, double> &out, ConvolutionOperator<3> &oper, FunctionTree<3, double> &inp, FunctionTreeVector<3, double> &precTrees, int maxIter, bool absPrec); +template void apply_far_field<1, double>(double prec, FunctionTree<1, double> &out, ConvolutionOperator<1> &oper, FunctionTree<1, double> &inp, int maxIter, bool absPrec); +template void apply_far_field<2, double>(double prec, FunctionTree<2, double> &out, ConvolutionOperator<2> &oper, FunctionTree<2, double> &inp, int maxIter, bool absPrec); +template void apply_far_field<3, double>(double prec, FunctionTree<3, double> &out, ConvolutionOperator<3> &oper, FunctionTree<3, double> &inp, int maxIter, bool absPrec); +template void apply_near_field<1, double>(double prec, FunctionTree<1, double> &out, ConvolutionOperator<1> &oper, FunctionTree<1, double> &inp, int maxIter, bool absPrec); +template void apply_near_field<2, double>(double prec, FunctionTree<2, double> &out, ConvolutionOperator<2> &oper, FunctionTree<2, double> &inp, int maxIter, bool absPrec); 
+template void apply_near_field<3, double>(double prec, FunctionTree<3, double> &out, ConvolutionOperator<3> &oper, FunctionTree<3, double> &inp, int maxIter, bool absPrec); +template void apply<1, double>(FunctionTree<1, double> &out, DerivativeOperator<1> &oper, FunctionTree<1, double> &inp, int dir); +template void apply<2, double>(FunctionTree<2, double> &out, DerivativeOperator<2> &oper, FunctionTree<2, double> &inp, int dir); +template void apply<3, double>(FunctionTree<3, double> &out, DerivativeOperator<3> &oper, FunctionTree<3, double> &inp, int dir); +template void divergence<1, double>(FunctionTree<1, double> &out, DerivativeOperator<1> &oper, FunctionTreeVector<1, double> &inp); +template void divergence<2, double>(FunctionTree<2, double> &out, DerivativeOperator<2> &oper, FunctionTreeVector<2, double> &inp); +template void divergence<3, double>(FunctionTree<3, double> &out, DerivativeOperator<3> &oper, FunctionTreeVector<3, double> &inp); +template void divergence<1, double>(FunctionTree<1, double> &out, DerivativeOperator<1> &oper, std::vector *> &inp); +template void divergence<2, double>(FunctionTree<2, double> &out, DerivativeOperator<2> &oper, std::vector *> &inp); +template void divergence<3, double>(FunctionTree<3, double> &out, DerivativeOperator<3> &oper, std::vector *> &inp); +template FunctionTreeVector<1, double> gradient<1>(DerivativeOperator<1> &oper, FunctionTree<1, double> &inp); +template FunctionTreeVector<2, double> gradient<2>(DerivativeOperator<2> &oper, FunctionTree<2, double> &inp); +template FunctionTreeVector<3, double> gradient<3>(DerivativeOperator<3> &oper, FunctionTree<3, double> &inp); + +template void apply<1, ComplexDouble>(double prec, FunctionTree<1, ComplexDouble> &out, ConvolutionOperator<1> &oper, FunctionTree<1, ComplexDouble> &inp, int maxIter, bool absPrec); +template void apply<2, ComplexDouble>(double prec, FunctionTree<2, ComplexDouble> &out, ConvolutionOperator<2> &oper, FunctionTree<2, ComplexDouble> &inp, 
int maxIter, bool absPrec); +template void apply<3, ComplexDouble>(double prec, FunctionTree<3, ComplexDouble> &out, ConvolutionOperator<3> &oper, FunctionTree<3, ComplexDouble> &inp, int maxIter, bool absPrec); + +template void apply<1, ComplexDouble>(double prec, + FunctionTree<1, ComplexDouble> &out, + ConvolutionOperator<1> &oper, + FunctionTree<1, ComplexDouble> &inp, + FunctionTreeVector<1, ComplexDouble> &precTrees, + int maxIter, + bool absPrec); +template void apply<2, ComplexDouble>(double prec, + FunctionTree<2, ComplexDouble> &out, + ConvolutionOperator<2> &oper, + FunctionTree<2, ComplexDouble> &inp, + FunctionTreeVector<2, ComplexDouble> &precTrees, + int maxIter, + bool absPrec); +template void apply<3, ComplexDouble>(double prec, + FunctionTree<3, ComplexDouble> &out, + ConvolutionOperator<3> &oper, + FunctionTree<3, ComplexDouble> &inp, + FunctionTreeVector<3, ComplexDouble> &precTrees, + int maxIter, + bool absPrec); +template void apply_far_field<1, ComplexDouble>(double prec, FunctionTree<1, ComplexDouble> &out, ConvolutionOperator<1> &oper, FunctionTree<1, ComplexDouble> &inp, int maxIter, bool absPrec); +template void apply_far_field<2, ComplexDouble>(double prec, FunctionTree<2, ComplexDouble> &out, ConvolutionOperator<2> &oper, FunctionTree<2, ComplexDouble> &inp, int maxIter, bool absPrec); +template void apply_far_field<3, ComplexDouble>(double prec, FunctionTree<3, ComplexDouble> &out, ConvolutionOperator<3> &oper, FunctionTree<3, ComplexDouble> &inp, int maxIter, bool absPrec); +template void apply_near_field<1, ComplexDouble>(double prec, FunctionTree<1, ComplexDouble> &out, ConvolutionOperator<1> &oper, FunctionTree<1, ComplexDouble> &inp, int maxIter, bool absPrec); +template void apply_near_field<2, ComplexDouble>(double prec, FunctionTree<2, ComplexDouble> &out, ConvolutionOperator<2> &oper, FunctionTree<2, ComplexDouble> &inp, int maxIter, bool absPrec); +template void apply_near_field<3, ComplexDouble>(double prec, FunctionTree<3, 
ComplexDouble> &out, ConvolutionOperator<3> &oper, FunctionTree<3, ComplexDouble> &inp, int maxIter, bool absPrec); +template void apply<1, ComplexDouble>(FunctionTree<1, ComplexDouble> &out, DerivativeOperator<1> &oper, FunctionTree<1, ComplexDouble> &inp, int dir); +template void apply<2, ComplexDouble>(FunctionTree<2, ComplexDouble> &out, DerivativeOperator<2> &oper, FunctionTree<2, ComplexDouble> &inp, int dir); +template void apply<3, ComplexDouble>(FunctionTree<3, ComplexDouble> &out, DerivativeOperator<3> &oper, FunctionTree<3, ComplexDouble> &inp, int dir); +template void divergence<1, ComplexDouble>(FunctionTree<1, ComplexDouble> &out, DerivativeOperator<1> &oper, FunctionTreeVector<1, ComplexDouble> &inp); +template void divergence<2, ComplexDouble>(FunctionTree<2, ComplexDouble> &out, DerivativeOperator<2> &oper, FunctionTreeVector<2, ComplexDouble> &inp); +template void divergence<3, ComplexDouble>(FunctionTree<3, ComplexDouble> &out, DerivativeOperator<3> &oper, FunctionTreeVector<3, ComplexDouble> &inp); +template void divergence<1, ComplexDouble>(FunctionTree<1, ComplexDouble> &out, DerivativeOperator<1> &oper, std::vector *> &inp); +template void divergence<2, ComplexDouble>(FunctionTree<2, ComplexDouble> &out, DerivativeOperator<2> &oper, std::vector *> &inp); +template void divergence<3, ComplexDouble>(FunctionTree<3, ComplexDouble> &out, DerivativeOperator<3> &oper, std::vector *> &inp); +template FunctionTreeVector<1, ComplexDouble> gradient<1>(DerivativeOperator<1> &oper, FunctionTree<1, ComplexDouble> &inp); +template FunctionTreeVector<2, ComplexDouble> gradient<2>(DerivativeOperator<2> &oper, FunctionTree<2, ComplexDouble> &inp); +template FunctionTreeVector<3, ComplexDouble> gradient<3>(DerivativeOperator<3> &oper, FunctionTree<3, ComplexDouble> &inp); + +template void apply(CompFunction<3> &out, DerivativeOperator<3> &oper, CompFunction<3> &inp, int dir = -1, ComplexDouble metric[4][4] = nullptr); } // namespace mrcpp diff --git 
a/src/treebuilders/apply.h b/src/treebuilders/apply.h index ae38e96ad..3bc9c8267 100644 --- a/src/treebuilders/apply.h +++ b/src/treebuilders/apply.h @@ -26,22 +26,35 @@ #pragma once #include "trees/FunctionTreeVector.h" +#include "utils/CompFunction.h" namespace mrcpp { // clang-format off -template class FunctionTree; + +template class FunctionTree; template class DerivativeOperator; template class ConvolutionOperator; -template void apply(double prec, FunctionTree &out, ConvolutionOperator &oper, FunctionTree &inp, int maxIter = -1, bool absPrec = false); -template void apply(double prec, FunctionTree &out, ConvolutionOperator &oper, FunctionTree &inp, FunctionTreeVector &precTrees, int maxIter = -1, bool absPrec = false); -template void apply_far_field(double prec, FunctionTree &out, ConvolutionOperator &oper, FunctionTree &inp, int maxIter = -1, bool absPrec = false); -template void apply_near_field(double prec, FunctionTree &out, ConvolutionOperator &oper, FunctionTree &inp, int maxIter = -1, bool absPrec = false); -template void apply(FunctionTree &out, DerivativeOperator &oper, FunctionTree &inp, int dir = -1); -template void divergence(FunctionTree &out, DerivativeOperator &oper, FunctionTreeVector &inp); -template void divergence(FunctionTree &out, DerivativeOperator &oper, std::vector *> &inp); -template FunctionTreeVector gradient(DerivativeOperator &oper, FunctionTree &inp); +const ComplexDouble defaultMetric [4][4] ={{1,0,0,0},{0,1,0,0},{0,0,1,0},{0,0,0,1}}; + +template void apply(double prec, FunctionTree &out, ConvolutionOperator &oper, FunctionTree &inp, int maxIter = -1, bool absPrec = false); +template void apply(double prec, CompFunction &out, ConvolutionOperator &oper, const CompFunction &inp, ComplexDouble metric[4][4] = defaultMetric, int maxIter = -1, bool absPrec = false); +template void apply(double prec, FunctionTree &out, ConvolutionOperator &oper, FunctionTree &inp, FunctionTreeVector &precTrees, int maxIter = -1, bool absPrec = false); 
+template void apply(double prec, CompFunction &out, ConvolutionOperator &oper, CompFunction &inp, FunctionTreeVector *precTrees, ComplexDouble metric[4][4] = nullptr, int maxIter = -1, bool absPrec = false); +template void apply_far_field(double prec, FunctionTree &out, ConvolutionOperator &oper, FunctionTree &inp, int maxIter = -1, bool absPrec = false); +template void apply_far_field(double prec, CompFunction &out, ConvolutionOperator &oper, CompFunction &inp, ComplexDouble metric[4][4] = nullptr, int maxIter = -1, bool absPrec = false); +template void apply_near_field(double prec, FunctionTree &out, ConvolutionOperator &oper, FunctionTree &inp, int maxIter = -1, bool absPrec = false); +template void apply_near_field(double prec, CompFunction &out, ConvolutionOperator &oper, CompFunction &inp, ComplexDouble metric[4][4] = nullptr, int maxIter = -1, bool absPrec = false); +template void apply(FunctionTree &out, DerivativeOperator &oper, FunctionTree &inp, int dir = -1); +template void apply(CompFunction &out, DerivativeOperator &oper, CompFunction &inp, int dir = -1, ComplexDouble metric[4][4] = nullptr); +template void divergence(FunctionTree &out, DerivativeOperator &oper, FunctionTreeVector &inp); +template void divergence(CompFunction &out, DerivativeOperator &oper, FunctionTreeVector *inp, ComplexDouble metric[4][4]); +template void divergence(FunctionTree &out, DerivativeOperator &oper, std::vector *> &inp); +template void divergence(CompFunction &out, DerivativeOperator &oper, std::vector *> *inp, ComplexDouble metric[4][4] = nullptr); +template FunctionTreeVector gradient(DerivativeOperator &oper, FunctionTree &inp); +// template +std::vector*> gradient(DerivativeOperator<3> &oper, CompFunction<3> &inp, ComplexDouble metric[4][4] = nullptr); // clang-format on } // namespace mrcpp diff --git a/src/treebuilders/complex_apply.cpp b/src/treebuilders/complex_apply.cpp index ab410244e..5cf0e3b08 100644 --- a/src/treebuilders/complex_apply.cpp +++ 
b/src/treebuilders/complex_apply.cpp @@ -24,7 +24,6 @@ */ #include "complex_apply.h" -#include "apply.h" #include "ConvolutionCalculator.h" #include "CopyAdaptor.h" #include "DefaultCalculator.h" @@ -33,6 +32,7 @@ #include "TreeBuilder.h" #include "WaveletAdaptor.h" #include "add.h" +#include "apply.h" #include "grid.h" #include "operators/ConvolutionOperator.h" #include "operators/DerivativeOperator.h" @@ -42,7 +42,6 @@ namespace mrcpp { - /** @brief Application of MW integral convolution operator (complex version) * * @param[in] prec: Build precision of output function @@ -58,7 +57,7 @@ namespace mrcpp { * - Repeat until convergence or `maxIter` is reached * - `prec < 0` or `maxIter = 0` means NO refinement * - `maxIter < 0` means no bound - * + * * The default is to work with relative precision * (stop when the wavelet coefficients are below a given (small) fraction of * function norm. @@ -74,43 +73,25 @@ namespace mrcpp { * tree when the function is called (this grid should however be EMPTY, e.i. * no coefs). * \todo !!! Here should be given a method for greed cleaning !!! 
- * + * * */ -template -void apply -( - double prec, ComplexObject< FunctionTree > &out, - ComplexObject< ConvolutionOperator > &oper, ComplexObject< FunctionTree > &inp, - int maxIter, bool absPrec -) -{ - FunctionTree temp1( inp.real->getMRA() ); - FunctionTree temp2( inp.real->getMRA() ); +template void apply(double prec, ComplexObject> &out, ComplexObject> &oper, ComplexObject> &inp, int maxIter, bool absPrec) { + FunctionTree temp1(inp.real->getMRA()); + FunctionTree temp2(inp.real->getMRA()); apply(prec, temp1, *oper.real, *inp.real, maxIter, absPrec); apply(prec, temp2, *oper.imaginary, *inp.imaginary, maxIter, absPrec); add(prec, *out.real, 1.0, temp1, -1.0, temp2); - //temp1.setZero(); - //temp2.setZero(); + // temp1.setZero(); + // temp2.setZero(); apply(prec, temp1, *oper.imaginary, *inp.real, maxIter, absPrec); apply(prec, temp2, *oper.real, *inp.imaginary, maxIter, absPrec); add(prec, *out.imaginary, 1.0, temp1, 1.0, temp2); } - - - - -template -void apply <1> -( - double prec, ComplexObject< FunctionTree<1> > &out, - ComplexObject< ConvolutionOperator<1> > &oper, ComplexObject< FunctionTree<1> > &inp, - int maxIter, bool absPrec -); - +template void apply<1>(double prec, ComplexObject> &out, ComplexObject> &oper, ComplexObject> &inp, int maxIter, bool absPrec); } // namespace mrcpp diff --git a/src/treebuilders/complex_apply.h b/src/treebuilders/complex_apply.h index 88aa96ee5..8ed9a0f17 100644 --- a/src/treebuilders/complex_apply.h +++ b/src/treebuilders/complex_apply.h @@ -30,15 +30,14 @@ namespace mrcpp { /// @brief Stores pointers to real and imaginary parts of tree objects. 
-/// @tparam MWClass -template -struct ComplexObject -{ - MWClass* real; - MWClass* imaginary; - - ComplexObject(MWClass& realPart, MWClass& imaginaryPart) - : real(&realPart), imaginary(&imaginaryPart) {} +/// @tparam MWClass +template struct ComplexObject { + MWClass *real; + MWClass *imaginary; + + ComplexObject(MWClass &realPart, MWClass &imaginaryPart) + : real(&realPart) + , imaginary(&imaginaryPart) {} }; // clang-format off @@ -54,5 +53,4 @@ void apply ); // clang-format on - } // namespace mrcpp diff --git a/src/treebuilders/grid.cpp b/src/treebuilders/grid.cpp index fb9e65a91..0e7fb968b 100644 --- a/src/treebuilders/grid.cpp +++ b/src/treebuilders/grid.cpp @@ -48,11 +48,11 @@ namespace mrcpp { * @note This algorithm will start at whatever grid is present in the `out` * tree when the function is called. */ -template void build_grid(FunctionTree &out, int scales) { +template void build_grid(FunctionTree &out, int scales) { auto maxScale = out.getMRA().getMaxScale(); - TreeBuilder builder; - DefaultCalculator calculator; - SplitAdaptor adaptor(maxScale, true); // Splits all nodes + TreeBuilder builder; + DefaultCalculator calculator; + SplitAdaptor adaptor(maxScale, true); // Splits all nodes for (auto n = 0; n < scales; n++) builder.build(out, calculator, adaptor, 1); } @@ -75,11 +75,11 @@ template void build_grid(FunctionTree &out, int scales) { * particular `RepresentableFunction`. 
* */ -template void build_grid(FunctionTree &out, const RepresentableFunction &inp, int maxIter) { +template void build_grid(FunctionTree &out, const RepresentableFunction &inp, int maxIter) { auto maxScale = out.getMRA().getMaxScale(); - TreeBuilder builder; - AnalyticAdaptor adaptor(inp, maxScale); - DefaultCalculator calculator; + TreeBuilder builder; + AnalyticAdaptor adaptor(inp, maxScale); + DefaultCalculator calculator; builder.build(out, calculator, adaptor, maxIter); print::separator(10, ' '); } @@ -142,12 +142,12 @@ template void build_grid(FunctionTree &out, const GaussExp &inp, i * but NOT vice versa. * */ -template void build_grid(FunctionTree &out, FunctionTree &inp, int maxIter) { +template void build_grid(FunctionTree &out, FunctionTree &inp, int maxIter) { if (out.getMRA() != inp.getMRA()) MSG_ABORT("Incompatible MRA"); auto maxScale = out.getMRA().getMaxScale(); - TreeBuilder builder; - CopyAdaptor adaptor(inp, maxScale, nullptr); - DefaultCalculator calculator; + TreeBuilder builder; + CopyAdaptor adaptor(inp, maxScale, nullptr); + DefaultCalculator calculator; builder.build(out, calculator, adaptor, maxIter); print::separator(10, ' '); } @@ -171,20 +171,20 @@ template void build_grid(FunctionTree &out, FunctionTree &inp, int * `maxIter` is reached). 
* */ -template void build_grid(FunctionTree &out, FunctionTreeVector &inp, int maxIter) { +template void build_grid(FunctionTree &out, FunctionTreeVector &inp, int maxIter) { for (auto i = 0; i < inp.size(); i++) if (out.getMRA() != get_func(inp, i).getMRA()) MSG_ABORT("Incompatible MRA"); auto maxScale = out.getMRA().getMaxScale(); - TreeBuilder builder; - CopyAdaptor adaptor(inp, maxScale, nullptr); - DefaultCalculator calculator; + TreeBuilder builder; + CopyAdaptor adaptor(inp, maxScale, nullptr); + DefaultCalculator calculator; builder.build(out, calculator, adaptor, maxIter); print::separator(10, ' '); } -template void build_grid(FunctionTree &out, std::vector *> &inp, int maxIter) { - FunctionTreeVector inp_vec; +template void build_grid(FunctionTree &out, std::vector *> &inp, int maxIter) { + FunctionTreeVector inp_vec; for (auto *t : inp) inp_vec.push_back({1.0, t}); build_grid(out, inp_vec, maxIter); } @@ -202,8 +202,8 @@ template void build_grid(FunctionTree &out, std::vector void copy_func(FunctionTree &out, FunctionTree &inp) { - FunctionTreeVector tmp_vec; +template void copy_func(FunctionTree &out, FunctionTree &inp) { + FunctionTreeVector tmp_vec; tmp_vec.push_back(std::make_tuple(1.0, &inp)); add(-1.0, out, tmp_vec); } @@ -218,12 +218,32 @@ template void copy_func(FunctionTree &out, FunctionTree &inp) { * will _extend_ the existing grid. * */ -template void copy_grid(FunctionTree &out, FunctionTree &inp) { +template void copy_grid(FunctionTree &out, FunctionTree &inp) { if (out.getMRA() != inp.getMRA()) MSG_ABORT("Incompatible MRA") out.clear(); build_grid(out, inp); } +/** @brief Build empty grid that is identical to another MW grid for every component + * + * @param[out] out: Output to be built + * @param[in] inp: Input + * + * @note The difference from the corresponding `build_grid` function is that + * this will first clear the grid of the `out` function, while `build_grid` + * will _extend_ the existing grid. 
+ * + */ +template void copy_grid(CompFunction &out, CompFunction &inp) { + out.free(); + out.func_ptr->data = inp.func_ptr->data; + out.alloc(inp.Ncomp()); + for (int i = 0; i < inp.Ncomp(); i++) { + if (inp.isreal()) build_grid(*out.CompD[i], *inp.CompD[i]); + if (inp.iscomplex()) build_grid(*out.CompC[i], *inp.CompC[i]); + } +} + /** @brief Clear the MW coefficients of a function representation * * @param[in,out] out: Output function to be cleared @@ -233,9 +253,9 @@ template void copy_grid(FunctionTree &out, FunctionTree &inp) { * grid refinement as well. * */ -template void clear_grid(FunctionTree &out) { - TreeBuilder builder; - DefaultCalculator calculator; +template void clear_grid(FunctionTree &out) { + TreeBuilder builder; + DefaultCalculator calculator; builder.clear(out, calculator); } @@ -250,11 +270,11 @@ template void clear_grid(FunctionTree &out) { * the function representation unchanged, but on a larger grid. * */ -template int refine_grid(FunctionTree &out, int scales) { +template int refine_grid(FunctionTree &out, int scales) { auto nSplit = 0; auto maxScale = out.getMRA().getMaxScale(); - TreeBuilder builder; - SplitAdaptor adaptor(maxScale, true); // Splits all nodes + TreeBuilder builder; + SplitAdaptor adaptor(maxScale, true); // Splits all nodes for (auto n = 0; n < scales; n++) { nSplit += builder.split(out, adaptor, true); // Transfers coefs to children } @@ -274,10 +294,10 @@ template int refine_grid(FunctionTree &out, int scales) { * unchanged, but (possibly) on a larger grid. 
* */ -template int refine_grid(FunctionTree &out, double prec, bool absPrec) { +template int refine_grid(FunctionTree &out, double prec, bool absPrec) { int maxScale = out.getMRA().getMaxScale(); - TreeBuilder builder; - WaveletAdaptor adaptor(prec, maxScale, absPrec); + TreeBuilder builder; + WaveletAdaptor adaptor(prec, maxScale, absPrec); int nSplit = builder.split(out, adaptor, true); return nSplit; } @@ -294,11 +314,11 @@ template int refine_grid(FunctionTree &out, double prec, bool absPrec * leaving the function representation unchanged, but on a larger grid. * */ -template int refine_grid(FunctionTree &out, FunctionTree &inp) { +template int refine_grid(FunctionTree &out, FunctionTree &inp) { if (out.getMRA() != inp.getMRA()) MSG_ABORT("Incompatible MRA") auto maxScale = out.getMRA().getMaxScale(); - TreeBuilder builder; - CopyAdaptor adaptor(inp, maxScale, nullptr); + TreeBuilder builder; + CopyAdaptor adaptor(inp, maxScale, nullptr); auto nSplit = builder.split(out, adaptor, true); return nSplit; } @@ -316,52 +336,93 @@ template int refine_grid(FunctionTree &out, FunctionTree &inp) { * is implemented in the particular `RepresentableFunction`. 
* */ -template int refine_grid(FunctionTree &out, const RepresentableFunction &inp) { +template int refine_grid(FunctionTree &out, const RepresentableFunction &inp) { auto maxScale = out.getMRA().getMaxScale(); - TreeBuilder builder; - AnalyticAdaptor adaptor(inp, maxScale); + TreeBuilder builder; + AnalyticAdaptor adaptor(inp, maxScale); int nSplit = builder.split(out, adaptor, true); return nSplit; } -template void build_grid<1>(FunctionTree<1> &out, int scales); -template void build_grid<2>(FunctionTree<2> &out, int scales); -template void build_grid<3>(FunctionTree<3> &out, int scales); +template void copy_grid(CompFunction<1> &out, CompFunction<1> &inp); +template void copy_grid(CompFunction<2> &out, CompFunction<2> &inp); +template void copy_grid(CompFunction<3> &out, CompFunction<3> &inp); + +template void build_grid<1, double>(FunctionTree<1, double> &out, int scales); +template void build_grid<2, double>(FunctionTree<2, double> &out, int scales); +template void build_grid<3, double>(FunctionTree<3, double> &out, int scales); template void build_grid<1>(FunctionTree<1> &out, const GaussExp<1> &inp, int maxIter); template void build_grid<2>(FunctionTree<2> &out, const GaussExp<2> &inp, int maxIter); template void build_grid<3>(FunctionTree<3> &out, const GaussExp<3> &inp, int maxIter); -template void build_grid<1>(FunctionTree<1> &out, const RepresentableFunction<1> &inp, int maxIter); -template void build_grid<2>(FunctionTree<2> &out, const RepresentableFunction<2> &inp, int maxIter); -template void build_grid<3>(FunctionTree<3> &out, const RepresentableFunction<3> &inp, int maxIter); -template void build_grid<1>(FunctionTree<1> &out, FunctionTree<1> &inp, int maxIter); -template void build_grid<2>(FunctionTree<2> &out, FunctionTree<2> &inp, int maxIter); -template void build_grid<3>(FunctionTree<3> &out, FunctionTree<3> &inp, int maxIter); -template void build_grid<1>(FunctionTree<1> &out, FunctionTreeVector<1> &inp, int maxIter); -template void 
build_grid<2>(FunctionTree<2> &out, FunctionTreeVector<2> &inp, int maxIter); -template void build_grid<3>(FunctionTree<3> &out, FunctionTreeVector<3> &inp, int maxIter); -template void build_grid<1>(FunctionTree<1> &out, std::vector *> &inp, int maxIter); -template void build_grid<2>(FunctionTree<2> &out, std::vector *> &inp, int maxIter); -template void build_grid<3>(FunctionTree<3> &out, std::vector *> &inp, int maxIter); -template void copy_func<1>(FunctionTree<1> &out, FunctionTree<1> &inp); -template void copy_func<2>(FunctionTree<2> &out, FunctionTree<2> &inp); -template void copy_func<3>(FunctionTree<3> &out, FunctionTree<3> &inp); -template void copy_grid<1>(FunctionTree<1> &out, FunctionTree<1> &inp); -template void copy_grid<2>(FunctionTree<2> &out, FunctionTree<2> &inp); -template void copy_grid<3>(FunctionTree<3> &out, FunctionTree<3> &inp); -template void clear_grid<1>(FunctionTree<1> &out); -template void clear_grid<2>(FunctionTree<2> &out); -template void clear_grid<3>(FunctionTree<3> &out); -template int refine_grid<1>(FunctionTree<1> &out, int scales); -template int refine_grid<2>(FunctionTree<2> &out, int scales); -template int refine_grid<3>(FunctionTree<3> &out, int scales); -template int refine_grid<1>(FunctionTree<1> &out, double prec, bool absPrec); -template int refine_grid<2>(FunctionTree<2> &out, double prec, bool absPrec); -template int refine_grid<3>(FunctionTree<3> &out, double prec, bool absPrec); -template int refine_grid<1>(FunctionTree<1> &out, FunctionTree<1> &inp); -template int refine_grid<2>(FunctionTree<2> &out, FunctionTree<2> &inp); -template int refine_grid<3>(FunctionTree<3> &out, FunctionTree<3> &inp); -template int refine_grid<1>(FunctionTree<1> &out, const RepresentableFunction<1> &inp); -template int refine_grid<2>(FunctionTree<2> &out, const RepresentableFunction<2> &inp); -template int refine_grid<3>(FunctionTree<3> &out, const RepresentableFunction<3> &inp); +template void build_grid<1, double>(FunctionTree<1, 
double> &out, const RepresentableFunction<1, double> &inp, int maxIter); +template void build_grid<2, double>(FunctionTree<2, double> &out, const RepresentableFunction<2, double> &inp, int maxIter); +template void build_grid<3, double>(FunctionTree<3, double> &out, const RepresentableFunction<3, double> &inp, int maxIter); +template void build_grid<1, double>(FunctionTree<1, double> &out, FunctionTree<1, double> &inp, int maxIter); +template void build_grid<2, double>(FunctionTree<2, double> &out, FunctionTree<2, double> &inp, int maxIter); +template void build_grid<3, double>(FunctionTree<3, double> &out, FunctionTree<3, double> &inp, int maxIter); +template void build_grid<1, double>(FunctionTree<1, double> &out, FunctionTreeVector<1, double> &inp, int maxIter); +template void build_grid<2, double>(FunctionTree<2, double> &out, FunctionTreeVector<2, double> &inp, int maxIter); +template void build_grid<3, double>(FunctionTree<3, double> &out, FunctionTreeVector<3, double> &inp, int maxIter); +template void build_grid<1, double>(FunctionTree<1, double> &out, std::vector *> &inp, int maxIter); +template void build_grid<2, double>(FunctionTree<2, double> &out, std::vector *> &inp, int maxIter); +template void build_grid<3, double>(FunctionTree<3, double> &out, std::vector *> &inp, int maxIter); +template void copy_func<1, double>(FunctionTree<1, double> &out, FunctionTree<1, double> &inp); +template void copy_func<2, double>(FunctionTree<2, double> &out, FunctionTree<2, double> &inp); +template void copy_func<3, double>(FunctionTree<3, double> &out, FunctionTree<3, double> &inp); +template void copy_grid<1, double>(FunctionTree<1, double> &out, FunctionTree<1, double> &inp); +template void copy_grid<2, double>(FunctionTree<2, double> &out, FunctionTree<2, double> &inp); +template void copy_grid<3, double>(FunctionTree<3, double> &out, FunctionTree<3, double> &inp); +template void clear_grid<1, double>(FunctionTree<1, double> &out); +template void clear_grid<2, 
double>(FunctionTree<2, double> &out); +template void clear_grid<3, double>(FunctionTree<3, double> &out); +template int refine_grid<1, double>(FunctionTree<1, double> &out, int scales); +template int refine_grid<2, double>(FunctionTree<2, double> &out, int scales); +template int refine_grid<3, double>(FunctionTree<3, double> &out, int scales); +template int refine_grid<1, double>(FunctionTree<1, double> &out, double prec, bool absPrec); +template int refine_grid<2, double>(FunctionTree<2, double> &out, double prec, bool absPrec); +template int refine_grid<3, double>(FunctionTree<3, double> &out, double prec, bool absPrec); +template int refine_grid<1, double>(FunctionTree<1, double> &out, FunctionTree<1, double> &inp); +template int refine_grid<2, double>(FunctionTree<2, double> &out, FunctionTree<2, double> &inp); +template int refine_grid<3, double>(FunctionTree<3, double> &out, FunctionTree<3, double> &inp); +template int refine_grid<1, double>(FunctionTree<1, double> &out, const RepresentableFunction<1, double> &inp); +template int refine_grid<2, double>(FunctionTree<2, double> &out, const RepresentableFunction<2, double> &inp); +template int refine_grid<3, double>(FunctionTree<3, double> &out, const RepresentableFunction<3, double> &inp); + +template void build_grid<1, ComplexDouble>(FunctionTree<1, ComplexDouble> &out, int scales); +template void build_grid<2, ComplexDouble>(FunctionTree<2, ComplexDouble> &out, int scales); +template void build_grid<3, ComplexDouble>(FunctionTree<3, ComplexDouble> &out, int scales); +template void build_grid<1, ComplexDouble>(FunctionTree<1, ComplexDouble> &out, const RepresentableFunction<1, ComplexDouble> &inp, int maxIter); +template void build_grid<2, ComplexDouble>(FunctionTree<2, ComplexDouble> &out, const RepresentableFunction<2, ComplexDouble> &inp, int maxIter); +template void build_grid<3, ComplexDouble>(FunctionTree<3, ComplexDouble> &out, const RepresentableFunction<3, ComplexDouble> &inp, int maxIter); +template 
void build_grid<1, ComplexDouble>(FunctionTree<1, ComplexDouble> &out, FunctionTree<1, ComplexDouble> &inp, int maxIter); +template void build_grid<2, ComplexDouble>(FunctionTree<2, ComplexDouble> &out, FunctionTree<2, ComplexDouble> &inp, int maxIter); +template void build_grid<3, ComplexDouble>(FunctionTree<3, ComplexDouble> &out, FunctionTree<3, ComplexDouble> &inp, int maxIter); +template void build_grid<1, ComplexDouble>(FunctionTree<1, ComplexDouble> &out, FunctionTreeVector<1, ComplexDouble> &inp, int maxIter); +template void build_grid<2, ComplexDouble>(FunctionTree<2, ComplexDouble> &out, FunctionTreeVector<2, ComplexDouble> &inp, int maxIter); +template void build_grid<3, ComplexDouble>(FunctionTree<3, ComplexDouble> &out, FunctionTreeVector<3, ComplexDouble> &inp, int maxIter); +template void build_grid<1, ComplexDouble>(FunctionTree<1, ComplexDouble> &out, std::vector *> &inp, int maxIter); +template void build_grid<2, ComplexDouble>(FunctionTree<2, ComplexDouble> &out, std::vector *> &inp, int maxIter); +template void build_grid<3, ComplexDouble>(FunctionTree<3, ComplexDouble> &out, std::vector *> &inp, int maxIter); +template void copy_func<1, ComplexDouble>(FunctionTree<1, ComplexDouble> &out, FunctionTree<1, ComplexDouble> &inp); +template void copy_func<2, ComplexDouble>(FunctionTree<2, ComplexDouble> &out, FunctionTree<2, ComplexDouble> &inp); +template void copy_func<3, ComplexDouble>(FunctionTree<3, ComplexDouble> &out, FunctionTree<3, ComplexDouble> &inp); +template void copy_grid<1, ComplexDouble>(FunctionTree<1, ComplexDouble> &out, FunctionTree<1, ComplexDouble> &inp); +template void copy_grid<2, ComplexDouble>(FunctionTree<2, ComplexDouble> &out, FunctionTree<2, ComplexDouble> &inp); +template void copy_grid<3, ComplexDouble>(FunctionTree<3, ComplexDouble> &out, FunctionTree<3, ComplexDouble> &inp); +template void clear_grid<1, ComplexDouble>(FunctionTree<1, ComplexDouble> &out); +template void clear_grid<2, ComplexDouble>(FunctionTree<2, 
ComplexDouble> &out); +template void clear_grid<3, ComplexDouble>(FunctionTree<3, ComplexDouble> &out); +template int refine_grid<1, ComplexDouble>(FunctionTree<1, ComplexDouble> &out, int scales); +template int refine_grid<2, ComplexDouble>(FunctionTree<2, ComplexDouble> &out, int scales); +template int refine_grid<3, ComplexDouble>(FunctionTree<3, ComplexDouble> &out, int scales); +template int refine_grid<1, ComplexDouble>(FunctionTree<1, ComplexDouble> &out, double prec, bool absPrec); +template int refine_grid<2, ComplexDouble>(FunctionTree<2, ComplexDouble> &out, double prec, bool absPrec); +template int refine_grid<3, ComplexDouble>(FunctionTree<3, ComplexDouble> &out, double prec, bool absPrec); +template int refine_grid<1, ComplexDouble>(FunctionTree<1, ComplexDouble> &out, FunctionTree<1, ComplexDouble> &inp); +template int refine_grid<2, ComplexDouble>(FunctionTree<2, ComplexDouble> &out, FunctionTree<2, ComplexDouble> &inp); +template int refine_grid<3, ComplexDouble>(FunctionTree<3, ComplexDouble> &out, FunctionTree<3, ComplexDouble> &inp); +template int refine_grid<1, ComplexDouble>(FunctionTree<1, ComplexDouble> &out, const RepresentableFunction<1, ComplexDouble> &inp); +template int refine_grid<2, ComplexDouble>(FunctionTree<2, ComplexDouble> &out, const RepresentableFunction<2, ComplexDouble> &inp); +template int refine_grid<3, ComplexDouble>(FunctionTree<3, ComplexDouble> &out, const RepresentableFunction<3, ComplexDouble> &inp); } // namespace mrcpp diff --git a/src/treebuilders/grid.h b/src/treebuilders/grid.h index 1f4c3e4f5..1d7021f8b 100644 --- a/src/treebuilders/grid.h +++ b/src/treebuilders/grid.h @@ -28,19 +28,21 @@ #include "functions/RepresentableFunction.h" #include "trees/FunctionTree.h" #include "trees/FunctionTreeVector.h" +#include "utils/CompFunction.h" namespace mrcpp { -template void build_grid(FunctionTree &out, int scales); +template void build_grid(FunctionTree &out, int scales); template void build_grid(FunctionTree &out, 
const GaussExp &inp, int maxIter = -1); -template void build_grid(FunctionTree &out, const RepresentableFunction &inp, int maxIter = -1); -template void build_grid(FunctionTree &out, FunctionTree &inp, int maxIter = -1); -template void build_grid(FunctionTree &out, FunctionTreeVector &inp, int maxIter = -1); -template void build_grid(FunctionTree &out, std::vector *> &inp, int maxIter = -1); -template void copy_func(FunctionTree &out, FunctionTree &inp); -template void copy_grid(FunctionTree &out, FunctionTree &inp); -template void clear_grid(FunctionTree &out); -template int refine_grid(FunctionTree &out, int scales); -template int refine_grid(FunctionTree &out, double prec, bool absPrec = false); -template int refine_grid(FunctionTree &out, FunctionTree &inp); -template int refine_grid(FunctionTree &out, const RepresentableFunction &inp); +template void build_grid(FunctionTree &out, const RepresentableFunction &inp, int maxIter = -1); +template void build_grid(FunctionTree &out, FunctionTree &inp, int maxIter = -1); +template void build_grid(FunctionTree &out, FunctionTreeVector &inp, int maxIter = -1); +template void build_grid(FunctionTree &out, std::vector *> &inp, int maxIter = -1); +template void copy_func(FunctionTree &out, FunctionTree &inp); +template void copy_grid(FunctionTree &out, FunctionTree &inp); +template void copy_grid(CompFunction &out, CompFunction &inp); +template void clear_grid(FunctionTree &out); +template int refine_grid(FunctionTree &out, int scales); +template int refine_grid(FunctionTree &out, double prec, bool absPrec = false); +template int refine_grid(FunctionTree &out, FunctionTree &inp); +template int refine_grid(FunctionTree &out, const RepresentableFunction &inp); } // namespace mrcpp diff --git a/src/treebuilders/map.cpp b/src/treebuilders/map.cpp index 98824d002..b363bf806 100644 --- a/src/treebuilders/map.cpp +++ b/src/treebuilders/map.cpp @@ -65,13 +65,12 @@ namespace mrcpp { * no coefs). 
* */ -template -void map(double prec, FunctionTree &out, FunctionTree &inp, FMap fmap, int maxIter, bool absPrec) { +template void map(double prec, FunctionTree &out, FunctionTree &inp, FMap fmap, int maxIter, bool absPrec) { int maxScale = out.getMRA().getMaxScale(); - TreeBuilder builder; - WaveletAdaptor adaptor(prec, maxScale, absPrec); - MapCalculator calculator(fmap, inp); + TreeBuilder builder; + WaveletAdaptor adaptor(prec, maxScale, absPrec); + MapCalculator calculator(fmap, inp); builder.build(out, calculator, adaptor, maxIter); @@ -89,8 +88,8 @@ void map(double prec, FunctionTree &out, FunctionTree &inp, FMap fmap, int print::separator(10, ' '); } -template void map<1>(double prec, FunctionTree<1> &out, FunctionTree<1> &inp, FMap fmap, int maxIter, bool absPrec); -template void map<2>(double prec, FunctionTree<2> &out, FunctionTree<2> &inp, FMap fmap, int maxIter, bool absPrec); -template void map<3>(double prec, FunctionTree<3> &out, FunctionTree<3> &inp, FMap fmap, int maxIter, bool absPrec); +template void map<1>(double prec, FunctionTree<1, double> &out, FunctionTree<1, double> &inp, FMap fmap, int maxIter, bool absPrec); +template void map<2>(double prec, FunctionTree<2, double> &out, FunctionTree<2, double> &inp, FMap fmap, int maxIter, bool absPrec); +template void map<3>(double prec, FunctionTree<3, double> &out, FunctionTree<3, double> &inp, FMap fmap, int maxIter, bool absPrec); } // Namespace mrcpp diff --git a/src/treebuilders/map.h b/src/treebuilders/map.h index 1c54dac32..d1f86e201 100644 --- a/src/treebuilders/map.h +++ b/src/treebuilders/map.h @@ -28,10 +28,8 @@ #include "trees/FunctionTreeVector.h" namespace mrcpp { -template class RepresentableFunction; -template class FunctionTree; +template class FunctionTree; -template -void map(double prec, FunctionTree &out, FunctionTree &inp, FMap fmap, int maxIter = -1, bool absPrec = false); +template void map(double prec, FunctionTree &out, FunctionTree &inp, FMap fmap, int maxIter = -1, bool 
absPrec = false); } // namespace mrcpp diff --git a/src/treebuilders/multiply.cpp b/src/treebuilders/multiply.cpp index a21e539ab..4e046126e 100644 --- a/src/treebuilders/multiply.cpp +++ b/src/treebuilders/multiply.cpp @@ -62,25 +62,19 @@ namespace mrcpp { * - Repeat until convergence or `maxIter` is reached * - `prec < 0` or `maxIter = 0` means NO refinement * - `maxIter < 0` means no bound + * - conjugate is applied on inp_b * * @note This algorithm will start at whatever grid is present in the `out` * tree when the function is called (this grid should however be EMPTY, e.i. * no coefs). * */ -template -void multiply(double prec, - FunctionTree &out, - double c, - FunctionTree &inp_a, - FunctionTree &inp_b, - int maxIter, - bool absPrec, - bool useMaxNorms) { - FunctionTreeVector tmp_vec; +template +void multiply(double prec, FunctionTree &out, T c, FunctionTree &inp_a, FunctionTree &inp_b, int maxIter, bool absPrec, bool useMaxNorms, bool conjugate) { + FunctionTreeVector tmp_vec; tmp_vec.push_back({c, &inp_a}); tmp_vec.push_back({1.0, &inp_b}); - multiply(prec, out, tmp_vec, maxIter, absPrec, useMaxNorms); + multiply(prec, out, tmp_vec, maxIter, absPrec, useMaxNorms, conjugate); } /** @brief Multiplication of several MW function representations, adaptive grid @@ -100,32 +94,27 @@ void multiply(double prec, * - Repeat until convergence or `maxIter` is reached * - `prec < 0` or `maxIter = 0` means NO refinement * - `maxIter < 0` means no bound + * - conjugate is applied on all the trees in inp, except the first * * @note This algorithm will start at whatever grid is present in the `out` * tree when the function is called (this grid should however be EMPTY, e.i. * no coefs). 
* */ -template -void multiply(double prec, - FunctionTree &out, - FunctionTreeVector &inp, - int maxIter, - bool absPrec, - bool useMaxNorms) { +template void multiply(double prec, FunctionTree &out, FunctionTreeVector &inp, int maxIter, bool absPrec, bool useMaxNorms, bool conjugate) { for (auto i = 0; i < inp.size(); i++) if (out.getMRA() != get_func(inp, i).getMRA()) MSG_ABORT("Incompatible MRA"); int maxScale = out.getMRA().getMaxScale(); - TreeBuilder builder; - MultiplicationCalculator calculator(inp); + TreeBuilder builder; + MultiplicationCalculator calculator(inp, conjugate); if (useMaxNorms) { for (int i = 0; i < inp.size(); i++) get_func(inp, i).makeMaxSquareNorms(); - MultiplicationAdaptor adaptor(prec, maxScale, inp); + MultiplicationAdaptor adaptor(prec, maxScale, inp); builder.build(out, calculator, adaptor, maxIter); } else { - WaveletAdaptor adaptor(prec, maxScale, absPrec); + WaveletAdaptor adaptor(prec, maxScale, absPrec); builder.build(out, calculator, adaptor, maxIter); } @@ -136,7 +125,7 @@ void multiply(double prec, Timer clean_t; for (int i = 0; i < inp.size(); i++) { - FunctionTree &tree = get_func(inp, i); + FunctionTree &tree = get_func(inp, i); tree.deleteGenerated(); } clean_t.stop(); @@ -146,16 +135,10 @@ void multiply(double prec, print::separator(10, ' '); } -template -void multiply(double prec, - FunctionTree &out, - std::vector *> &inp, - int maxIter, - bool absPrec, - bool useMaxNorms) { - FunctionTreeVector inp_vec; +template void multiply(double prec, FunctionTree &out, std::vector *> &inp, int maxIter, bool absPrec, bool useMaxNorms, bool conjugate) { + FunctionTreeVector inp_vec; for (auto &t : inp) inp_vec.push_back({1.0, t}); - multiply(prec, out, inp_vec, maxIter, absPrec, useMaxNorms); + multiply(prec, out, inp_vec, maxIter, absPrec, useMaxNorms, conjugate); } /** @brief Out-of-place square of MW function representations, adaptive grid @@ -179,13 +162,13 @@ void multiply(double prec, * no coefs). 
* */ -template void square(double prec, FunctionTree &out, FunctionTree &inp, int maxIter, bool absPrec) { +template void square(double prec, FunctionTree &out, FunctionTree &inp, int maxIter, bool absPrec, bool conjugate) { if (out.getMRA() != inp.getMRA()) MSG_ABORT("Incompatible MRA"); int maxScale = out.getMRA().getMaxScale(); - TreeBuilder builder; - WaveletAdaptor adaptor(prec, maxScale, absPrec); - SquareCalculator calculator(inp); + TreeBuilder builder; + WaveletAdaptor adaptor(prec, maxScale, absPrec); + SquareCalculator calculator(inp, conjugate); builder.build(out, calculator, adaptor, maxIter); @@ -225,14 +208,14 @@ template void square(double prec, FunctionTree &out, FunctionTree * no coefs). * */ -template -void power(double prec, FunctionTree &out, FunctionTree &inp, double p, int maxIter, bool absPrec) { +template void power(double prec, FunctionTree &out, FunctionTree &inp, double p, int maxIter, bool absPrec) { if (out.getMRA() != inp.getMRA()) MSG_ABORT("Incompatible MRA"); + if (inp.conjugate()) MSG_ABORT("Not implemented"); int maxScale = out.getMRA().getMaxScale(); - TreeBuilder builder; - WaveletAdaptor adaptor(prec, maxScale, absPrec); - PowerCalculator calculator(inp, p); + TreeBuilder builder; + WaveletAdaptor adaptor(prec, maxScale, absPrec); + PowerCalculator calculator(inp, p); builder.build(out, calculator, adaptor, maxIter); @@ -267,24 +250,19 @@ void power(double prec, FunctionTree &out, FunctionTree &inp, double p, in * @note The length of the input vectors must be the same. 
* */ -template -void dot(double prec, - FunctionTree &out, - FunctionTreeVector &inp_a, - FunctionTreeVector &inp_b, - int maxIter, - bool absPrec) { +template void dot(double prec, FunctionTree &out, FunctionTreeVector &inp_a, FunctionTreeVector &inp_b, int maxIter, bool absPrec) { if (inp_a.size() != inp_b.size()) MSG_ABORT("Input length mismatch"); - FunctionTreeVector tmp_vec; + FunctionTreeVector tmp_vec; for (int d = 0; d < inp_a.size(); d++) { - double coef_a = get_coef(inp_a, d); - double coef_b = get_coef(inp_b, d); - FunctionTree &tree_a = get_func(inp_a, d); - FunctionTree &tree_b = get_func(inp_b, d); - auto *out_d = new FunctionTree(out.getMRA()); + T coef_a = get_coef(inp_a, d); + T coef_b = get_coef(inp_b, d); + FunctionTree &tree_a = get_func(inp_a, d); + FunctionTree &tree_b = get_func(inp_b, d); + auto *out_d = new FunctionTree(out.getMRA()); build_grid(*out_d, out); - multiply(prec, *out_d, 1.0, tree_a, tree_b, maxIter, absPrec); + T One = 1.0; + multiply(prec, *out_d, One, tree_a, tree_b, maxIter, absPrec, true); tmp_vec.push_back({coef_a * coef_b, out_d}); } build_grid(out, tmp_vec); @@ -305,19 +283,18 @@ void dot(double prec, * grids overlap. * */ -template double dot(FunctionTree &bra, FunctionTree &ket) { +template V dot(FunctionTree &bra, FunctionTree &ket) { if (bra.getMRA() != ket.getMRA()) MSG_ABORT("Trees not compatible"); - - MWNodeVector nodeTable; - TreeIterator it(bra); + MWNodeVector nodeTable; + TreeIterator it(bra); it.setReturnGenNodes(false); while (it.next()) { - MWNode &node = it.getNode(); + MWNode &node = it.getNode(); nodeTable.push_back(&node); } int nNodes = nodeTable.size(); - double result = 0.0; - double locResult = 0.0; + V result = 0.0; + V locResult = 0.0; // OMP is disabled in order to get EXACT results (to the very last digit), the // order of summation makes the result different beyond the 14th digit or so. // OMP does improve the performace, but its not worth it for the time being. 
@@ -326,17 +303,17 @@ template double dot(FunctionTree &bra, FunctionTree &ket) { // { //#pragma omp for schedule(guided) for (int n = 0; n < nNodes; n++) { - const auto &braNode = static_cast &>(*nodeTable[n]); - const MWNode *mwNode = ket.findNode(braNode.getNodeIndex()); + const auto &braNode = static_cast &>(*nodeTable[n]); + const MWNode *mwNode = ket.findNode(braNode.getNodeIndex()); if (mwNode == nullptr) continue; - const auto &ketNode = static_cast &>(*mwNode); + const auto &ketNode = static_cast &>(*mwNode); if (braNode.isRootNode()) locResult += dot_scaling(braNode, ketNode); locResult += dot_wavelet(braNode, ketNode); } //#pragma omp critical result += locResult; - // } + return result; } @@ -352,30 +329,30 @@ template double dot(FunctionTree &bra, FunctionTree &ket) { * distribution within the node. * If the product is zero, the functions are disjoints. */ -template double node_norm_dot(FunctionTree &bra, FunctionTree &ket, bool exact) { +template double node_norm_dot(FunctionTree &bra, FunctionTree &ket, bool exact) { if (bra.getMRA() != ket.getMRA()) MSG_ABORT("Incompatible MRA"); double result = 0.0; int ncoef = bra.getKp1_d() * bra.getTDim(); - double valA[ncoef]; - double valB[ncoef]; + T valA[ncoef]; + T valB[ncoef]; int nNodes = bra.getNEndNodes(); for (int n = 0; n < nNodes; n++) { - FunctionNode &node = bra.getEndFuncNode(n); + FunctionNode &node = bra.getEndFuncNode(n); const NodeIndex idx = node.getNodeIndex(); if (exact) { // convert to interpolating coef, take abs, convert back - FunctionNode *mwNode = static_cast *>(ket.findNode(idx)); + FunctionNode *mwNode = static_cast *>(ket.findNode(idx)); if (mwNode == nullptr) MSG_ABORT("Trees must have same grid"); node.getAbsCoefs(valA); mwNode->getAbsCoefs(valB); - for (int i = 0; i < ncoef; i++) result += valA[i] * valB[i]; + for (int i = 0; i < ncoef; i++) result += std::norm(valA[i] * valB[i]); } else { // approximate by product of node norms int rIdx = ket.getRootBox().getBoxIndex(idx); 
assert(rIdx >= 0); - const MWNode &root = ket.getRootBox().getNode(rIdx); + const MWNode &root = ket.getRootBox().getNode(rIdx); result += std::sqrt(node.getSquareNorm()) * root.getNodeNorm(idx); } } @@ -383,124 +360,107 @@ template double node_norm_dot(FunctionTree &bra, FunctionTree &ket return result; } -template void multiply<1>(double prec, - FunctionTree<1> &out, - double c, - FunctionTree<1> &tree_a, - FunctionTree<1> &tree_b, - int maxIter, - bool absPrec, - bool useMaxNorms); -template void multiply<2>(double prec, - FunctionTree<2> &out, - double c, - FunctionTree<2> &tree_a, - FunctionTree<2> &tree_b, - int maxIter, - bool absPrec, - bool useMaxNorms); -template void multiply<3>(double prec, - FunctionTree<3> &out, - double c, - FunctionTree<3> &tree_a, - FunctionTree<3> &tree_b, - int maxIter, - bool absPrec, - bool useMaxNorms); -template void multiply<1>(double prec, - FunctionTree<1> &out, - FunctionTreeVector<1> &inp, - int maxIter, - bool absPrec, - bool useMaxNorms); -template void multiply<2>(double prec, - FunctionTree<2> &out, - FunctionTreeVector<2> &inp, - int maxIter, - bool absPrec, - bool useMaxNorms); -template void multiply<3>(double prec, - FunctionTree<3> &out, - FunctionTreeVector<3> &inp, - int maxIter, - bool absPrec, - bool useMaxNorms); -template void multiply<1>(double prec, - FunctionTree<1> &out, - std::vector *> &inp, - int maxIter, - bool absPrec, - bool useMaxNorms); -template void multiply<2>(double prec, - FunctionTree<2> &out, - std::vector *> &inp, - int maxIter, - bool absPrec, - bool useMaxNorms); -template void multiply<3>(double prec, - FunctionTree<3> &out, - std::vector *> &inp, - int maxIter, - bool absPrec, - bool useMaxNorms); -template void power<1>(double prec, - FunctionTree<1> &out, - FunctionTree<1> &tree, - double pow, - int maxIter, - bool absPrec); -template void power<2>(double prec, - FunctionTree<2> &out, - FunctionTree<2> &tree, - double pow, - int maxIter, - bool absPrec); -template void 
power<3>(double prec, - FunctionTree<3> &out, - FunctionTree<3> &tree, - double pow, - int maxIter, - bool absPrec); -template void square<1>(double prec, - FunctionTree<1> &out, - FunctionTree<1> &tree, - int maxIter, - bool absPrec); -template void square<2>(double prec, - FunctionTree<2> &out, - FunctionTree<2> &tree, - int maxIter, - bool absPrec); -template void square<3>(double prec, - FunctionTree<3> &out, - FunctionTree<3> &tree, - int maxIter, - bool absPrec); -template void dot<1>(double prec, - FunctionTree<1> &out, - FunctionTreeVector<1> &inp_a, - FunctionTreeVector<1> &inp_b, - int maxIter, - bool absPrec); -template void dot<2>(double prec, - FunctionTree<2> &out, - FunctionTreeVector<2> &inp_a, - FunctionTreeVector<2> &inp_b, - int maxIter, - bool absPrec); -template void dot<3>(double prec, - FunctionTree<3> &out, - FunctionTreeVector<3> &inp_a, - FunctionTreeVector<3> &inp_b, - int maxIter, - bool absPrec); - -template double dot<1>(FunctionTree<1> &bra, FunctionTree<1> &ket); -template double dot<2>(FunctionTree<2> &bra, FunctionTree<2> &ket); -template double dot<3>(FunctionTree<3> &bra, FunctionTree<3> &ket); - -template double node_norm_dot<1>(FunctionTree<1> &bra, FunctionTree<1> &ket, bool exact); -template double node_norm_dot<2>(FunctionTree<2> &bra, FunctionTree<2> &ket, bool exact); -template double node_norm_dot<3>(FunctionTree<3> &bra, FunctionTree<3> &ket, bool exact); +template void +multiply<1, double>(double prec, FunctionTree<1, double> &out, double c, FunctionTree<1, double> &tree_a, FunctionTree<1, double> &tree_b, int maxIter, bool absPrec, bool useMaxNorms, bool conjugate); +template void +multiply<2, double>(double prec, FunctionTree<2, double> &out, double c, FunctionTree<2, double> &tree_a, FunctionTree<2, double> &tree_b, int maxIter, bool absPrec, bool useMaxNorms, bool conjugate); +template void +multiply<3, double>(double prec, FunctionTree<3, double> &out, double c, FunctionTree<3, double> &tree_a, FunctionTree<3, 
double> &tree_b, int maxIter, bool absPrec, bool useMaxNorms, bool conjugate); +template void multiply<1, double>(double prec, FunctionTree<1, double> &out, FunctionTreeVector<1, double> &inp, int maxIter, bool absPrec, bool useMaxNorms, bool conjugate); +template void multiply<2, double>(double prec, FunctionTree<2, double> &out, FunctionTreeVector<2, double> &inp, int maxIter, bool absPrec, bool useMaxNorms, bool conjugate); +template void multiply<3, double>(double prec, FunctionTree<3, double> &out, FunctionTreeVector<3, double> &inp, int maxIter, bool absPrec, bool useMaxNorms, bool conjugate); +template void multiply<1, double>(double prec, FunctionTree<1, double> &out, std::vector *> &inp, int maxIter, bool absPrec, bool useMaxNorms, bool conjugate); +template void multiply<2, double>(double prec, FunctionTree<2, double> &out, std::vector *> &inp, int maxIter, bool absPrec, bool useMaxNorms, bool conjugate); +template void multiply<3, double>(double prec, FunctionTree<3, double> &out, std::vector *> &inp, int maxIter, bool absPrec, bool useMaxNorms, bool conjugate); +template void power<1, double>(double prec, FunctionTree<1, double> &out, FunctionTree<1, double> &tree, double pow, int maxIter, bool absPrec); +template void power<2, double>(double prec, FunctionTree<2, double> &out, FunctionTree<2, double> &tree, double pow, int maxIter, bool absPrec); +template void power<3, double>(double prec, FunctionTree<3, double> &out, FunctionTree<3, double> &tree, double pow, int maxIter, bool absPrec); +template void square<1, double>(double prec, FunctionTree<1, double> &out, FunctionTree<1, double> &tree, int maxIter, bool absPrec, bool conjugate); +template void square<2, double>(double prec, FunctionTree<2, double> &out, FunctionTree<2, double> &tree, int maxIter, bool absPrec, bool conjugate); +template void square<3, double>(double prec, FunctionTree<3, double> &out, FunctionTree<3, double> &tree, int maxIter, bool absPrec, bool conjugate); +template void 
dot<1, double>(double prec, FunctionTree<1, double> &out, FunctionTreeVector<1, double> &inp_a, FunctionTreeVector<1, double> &inp_b, int maxIter, bool absPrec); +template void dot<2, double>(double prec, FunctionTree<2, double> &out, FunctionTreeVector<2, double> &inp_a, FunctionTreeVector<2, double> &inp_b, int maxIter, bool absPrec); +template void dot<3, double>(double prec, FunctionTree<3, double> &out, FunctionTreeVector<3, double> &inp_a, FunctionTreeVector<3, double> &inp_b, int maxIter, bool absPrec); +template double node_norm_dot<1, double>(FunctionTree<1, double> &bra, FunctionTree<1, double> &ket, bool exact); +template double node_norm_dot<2, double>(FunctionTree<2, double> &bra, FunctionTree<2, double> &ket, bool exact); +template double node_norm_dot<3, double>(FunctionTree<3, double> &bra, FunctionTree<3, double> &ket, bool exact); + +template void multiply<1, ComplexDouble>(double prec, + FunctionTree<1, ComplexDouble> &out, + ComplexDouble c, + FunctionTree<1, ComplexDouble> &tree_a, + FunctionTree<1, ComplexDouble> &tree_b, + int maxIter, + bool absPrec, + bool useMaxNorms, + bool conjugate); +template void multiply<2, ComplexDouble>(double prec, + FunctionTree<2, ComplexDouble> &out, + ComplexDouble c, + FunctionTree<2, ComplexDouble> &tree_a, + FunctionTree<2, ComplexDouble> &tree_b, + int maxIter, + bool absPrec, + bool useMaxNorms, + bool conjugate); +template void multiply<3, ComplexDouble>(double prec, + FunctionTree<3, ComplexDouble> &out, + ComplexDouble c, + FunctionTree<3, ComplexDouble> &tree_a, + FunctionTree<3, ComplexDouble> &tree_b, + int maxIter, + bool absPrec, + bool useMaxNorms, + bool conjugate); +template void multiply<1, ComplexDouble>(double prec, FunctionTree<1, ComplexDouble> &out, FunctionTreeVector<1, ComplexDouble> &inp, int maxIter, bool absPrec, bool useMaxNorms, bool conjugate); +template void multiply<2, ComplexDouble>(double prec, FunctionTree<2, ComplexDouble> &out, FunctionTreeVector<2, ComplexDouble> &inp, int 
maxIter, bool absPrec, bool useMaxNorms, bool conjugate); +template void multiply<3, ComplexDouble>(double prec, FunctionTree<3, ComplexDouble> &out, FunctionTreeVector<3, ComplexDouble> &inp, int maxIter, bool absPrec, bool useMaxNorms, bool conjugate); +template void +multiply<1, ComplexDouble>(double prec, FunctionTree<1, ComplexDouble> &out, std::vector *> &inp, int maxIter, bool absPrec, bool useMaxNorms, bool conjugate); +template void +multiply<2, ComplexDouble>(double prec, FunctionTree<2, ComplexDouble> &out, std::vector *> &inp, int maxIter, bool absPrec, bool useMaxNorms, bool conjugate); +template void +multiply<3, ComplexDouble>(double prec, FunctionTree<3, ComplexDouble> &out, std::vector *> &inp, int maxIter, bool absPrec, bool useMaxNorms, bool conjugate); +template void power<1, ComplexDouble>(double prec, FunctionTree<1, ComplexDouble> &out, FunctionTree<1, ComplexDouble> &tree, double pow, int maxIter, bool absPrec); +template void power<2, ComplexDouble>(double prec, FunctionTree<2, ComplexDouble> &out, FunctionTree<2, ComplexDouble> &tree, double pow, int maxIter, bool absPrec); +template void power<3, ComplexDouble>(double prec, FunctionTree<3, ComplexDouble> &out, FunctionTree<3, ComplexDouble> &tree, double pow, int maxIter, bool absPrec); +template void square<1, ComplexDouble>(double prec, FunctionTree<1, ComplexDouble> &out, FunctionTree<1, ComplexDouble> &tree, int maxIter, bool absPrec, bool conjugate); +template void square<2, ComplexDouble>(double prec, FunctionTree<2, ComplexDouble> &out, FunctionTree<2, ComplexDouble> &tree, int maxIter, bool absPrec, bool conjugate); +template void square<3, ComplexDouble>(double prec, FunctionTree<3, ComplexDouble> &out, FunctionTree<3, ComplexDouble> &tree, int maxIter, bool absPrec, bool conjugate); +template void dot<1, ComplexDouble>(double prec, + FunctionTree<1, ComplexDouble> &out, + FunctionTreeVector<1, ComplexDouble> &inp_a, + FunctionTreeVector<1, ComplexDouble> &inp_b, + int maxIter, + 
bool absPrec); +template void dot<2, ComplexDouble>(double prec, + FunctionTree<2, ComplexDouble> &out, + FunctionTreeVector<2, ComplexDouble> &inp_a, + FunctionTreeVector<2, ComplexDouble> &inp_b, + int maxIter, + bool absPrec); +template void dot<3, ComplexDouble>(double prec, + FunctionTree<3, ComplexDouble> &out, + FunctionTreeVector<3, ComplexDouble> &inp_a, + FunctionTreeVector<3, ComplexDouble> &inp_b, + int maxIter, + bool absPrec); + +template double dot<1, double, double>(FunctionTree<1, double> &bra, FunctionTree<1, double> &ket); +template double dot<2, double, double>(FunctionTree<2, double> &bra, FunctionTree<2, double> &ket); +template double dot<3, double, double>(FunctionTree<3, double> &bra, FunctionTree<3, double> &ket); +template ComplexDouble dot<1, ComplexDouble, double>(FunctionTree<1, ComplexDouble> &bra, FunctionTree<1, double> &ket); +template ComplexDouble dot<2, ComplexDouble, double>(FunctionTree<2, ComplexDouble> &bra, FunctionTree<2, double> &ket); +template ComplexDouble dot<3, ComplexDouble, double>(FunctionTree<3, ComplexDouble> &bra, FunctionTree<3, double> &ket); +template ComplexDouble dot<1, double, ComplexDouble>(FunctionTree<1, double> &bra, FunctionTree<1, ComplexDouble> &ket); +template ComplexDouble dot<2, double, ComplexDouble>(FunctionTree<2, double> &bra, FunctionTree<2, ComplexDouble> &ket); +template ComplexDouble dot<3, double, ComplexDouble>(FunctionTree<3, double> &bra, FunctionTree<3, ComplexDouble> &ket); +template ComplexDouble dot<1, ComplexDouble, ComplexDouble>(FunctionTree<1, ComplexDouble> &bra, FunctionTree<1, ComplexDouble> &ket); +template ComplexDouble dot<2, ComplexDouble, ComplexDouble>(FunctionTree<2, ComplexDouble> &bra, FunctionTree<2, ComplexDouble> &ket); +template ComplexDouble dot<3, ComplexDouble, ComplexDouble>(FunctionTree<3, ComplexDouble> &bra, FunctionTree<3, ComplexDouble> &ket); + +template double node_norm_dot<1, ComplexDouble>(FunctionTree<1, ComplexDouble> &bra, FunctionTree<1, 
ComplexDouble> &ket, bool exact); +template double node_norm_dot<2, ComplexDouble>(FunctionTree<2, ComplexDouble> &bra, FunctionTree<2, ComplexDouble> &ket, bool exact); +template double node_norm_dot<3, ComplexDouble>(FunctionTree<3, ComplexDouble> &bra, FunctionTree<3, ComplexDouble> &ket, bool exact); } // namespace mrcpp diff --git a/src/treebuilders/multiply.h b/src/treebuilders/multiply.h index 54947bf78..316066483 100644 --- a/src/treebuilders/multiply.h +++ b/src/treebuilders/multiply.h @@ -28,57 +28,34 @@ #include "trees/FunctionTreeVector.h" namespace mrcpp { -template class RepresentableFunction; -template class FunctionTree; +template class RepresentableFunction; +template class FunctionTree; -template void dot(double prec, - FunctionTree &out, - FunctionTreeVector &inp_a, - FunctionTreeVector &inp_b, - int maxIter = -1, - bool absPrec = false); +template () * std::declval())> V dot(FunctionTree &bra, FunctionTree &ket); -template double dot(FunctionTree &bra, - FunctionTree &ket); +template void dot(double prec, FunctionTree &out, FunctionTreeVector &inp_a, FunctionTreeVector &inp_b, int maxIter = -1, bool absPrec = false); -template double node_norm_dot(FunctionTree &bra, - FunctionTree &ket, - bool exact = false); +template double node_norm_dot(FunctionTree &bra, FunctionTree &ket, bool exact = false); -template void multiply(double prec, - FunctionTree &out, - double c, - FunctionTree &inp_a, - FunctionTree &inp_b, - int maxIter = -1, - bool absPrec = false, - bool useMaxNorms = false); +template +void multiply(double prec, + FunctionTree &out, + T c, + FunctionTree &inp_a, + FunctionTree &inp_b, + int maxIter = -1, + bool absPrec = false, + bool useMaxNorms = false, + bool conjugate = false); -template void multiply(double prec, - FunctionTree &out, - std::vector *> &inp, - int maxIter = -1, - bool absPrec = false, - bool useMaxNorms = false); +template +void multiply(double prec, FunctionTree &out, std::vector *> &inp, int maxIter = -1, bool 
absPrec = false, bool useMaxNorms = false, bool conjugate = false); -template void multiply(double prec, - FunctionTree &out, - FunctionTreeVector &inp, - int maxIter = -1, - bool absPrec = false, - bool useMaxNorms = false); +template +void multiply(double prec, FunctionTree &out, FunctionTreeVector &inp, int maxIter = -1, bool absPrec = false, bool useMaxNorms = false, bool conjugate = false); -template void power(double prec, - FunctionTree &out, - FunctionTree &inp, - double p, - int maxIter = -1, - bool absPrec = false); +template void power(double prec, FunctionTree &out, FunctionTree &inp, double p, int maxIter = -1, bool absPrec = false); -template void square(double prec, - FunctionTree &out, - FunctionTree &inp, - int maxIter = -1, - bool absPrec = false); +template void square(double prec, FunctionTree &out, FunctionTree &inp, int maxIter = -1, bool absPrec = false, bool conjugate = false); } // namespace mrcpp diff --git a/src/treebuilders/project.cpp b/src/treebuilders/project.cpp index c22f22ec8..7eea89416 100644 --- a/src/treebuilders/project.cpp +++ b/src/treebuilders/project.cpp @@ -56,8 +56,9 @@ namespace mrcpp { * no coefs). 
* */ -template void project(double prec, FunctionTree &out, std::function &r)> func, int maxIter, bool absPrec) { - AnalyticFunction inp(func); +template void project(double prec, FunctionTree &out, std::function &r)> func, int maxIter, bool absPrec) { + AnalyticFunction inp(func); + mrcpp::project(prec, out, inp, maxIter, absPrec); } @@ -81,14 +82,13 @@ template void project(double prec, FunctionTree &out, std::function void project(double prec, FunctionTree &out, RepresentableFunction &inp, int maxIter, bool absPrec) { - +template void project(double prec, FunctionTree &out, RepresentableFunction &inp, int maxIter, bool absPrec) { int maxScale = out.getMRA().getMaxScale(); const auto scaling_factor = out.getMRA().getWorldBox().getScalingFactors(); - TreeBuilder builder; - WaveletAdaptor adaptor(prec, maxScale, absPrec); + TreeBuilder builder; + WaveletAdaptor adaptor(prec, maxScale, absPrec); - ProjectionCalculator calculator(inp, scaling_factor); + ProjectionCalculator calculator(inp, scaling_factor); builder.build(out, calculator, adaptor, maxIter); @@ -121,19 +121,31 @@ template void project(double prec, FunctionTree &out, RepresentableFu * no coefs). 
* */ -template void project(double prec, FunctionTreeVector &out, std::vector &r)>> func, int maxIter, bool absPrec) { +template void project(double prec, FunctionTreeVector &out, std::vector &r)>> func, int maxIter, bool absPrec) { if (out.size() != func.size()) MSG_ABORT("Size mismatch"); for (auto j = 0; j < D; j++) mrcpp::project(prec, get_func(out, j), func[j], maxIter, absPrec); } -template void project<1>(double prec, FunctionTree<1> &out, RepresentableFunction<1> &inp, int maxIter, bool absPrec); -template void project<2>(double prec, FunctionTree<2> &out, RepresentableFunction<2> &inp, int maxIter, bool absPrec); -template void project<3>(double prec, FunctionTree<3> &out, RepresentableFunction<3> &inp, int maxIter, bool absPrec); +template void project<1, double>(double prec, FunctionTree<1, double> &out, RepresentableFunction<1, double> &inp, int maxIter, bool absPrec); +template void project<2, double>(double prec, FunctionTree<2, double> &out, RepresentableFunction<2, double> &inp, int maxIter, bool absPrec); +template void project<3, double>(double prec, FunctionTree<3, double> &out, RepresentableFunction<3, double> &inp, int maxIter, bool absPrec); + +template void project<1, double>(double prec, FunctionTree<1, double> &out, std::function &r)> func, int maxIter, bool absPrec); +template void project<2, double>(double prec, FunctionTree<2, double> &out, std::function &r)> func, int maxIter, bool absPrec); +template void project<3, double>(double prec, FunctionTree<3, double> &out, std::function &r)> func, int maxIter, bool absPrec); +template void project<1, double>(double prec, FunctionTreeVector<1, double> &out, std::vector &r)>> inp, int maxIter, bool absPrec); +template void project<2, double>(double prec, FunctionTreeVector<2, double> &out, std::vector &r)>> inp, int maxIter, bool absPrec); +template void project<3, double>(double prec, FunctionTreeVector<3, double> &out, std::vector &r)>> inp, int maxIter, bool absPrec); + +template void 
project<1, ComplexDouble>(double prec, FunctionTree<1, ComplexDouble> &out, RepresentableFunction<1, ComplexDouble> &inp, int maxIter, bool absPrec); +template void project<2, ComplexDouble>(double prec, FunctionTree<2, ComplexDouble> &out, RepresentableFunction<2, ComplexDouble> &inp, int maxIter, bool absPrec); +template void project<3, ComplexDouble>(double prec, FunctionTree<3, ComplexDouble> &out, RepresentableFunction<3, ComplexDouble> &inp, int maxIter, bool absPrec); + +template void project<1, ComplexDouble>(double prec, FunctionTree<1, ComplexDouble> &out, std::function &r)> func, int maxIter, bool absPrec); +template void project<2, ComplexDouble>(double prec, FunctionTree<2, ComplexDouble> &out, std::function &r)> func, int maxIter, bool absPrec); +template void project<3, ComplexDouble>(double prec, FunctionTree<3, ComplexDouble> &out, std::function &r)> func, int maxIter, bool absPrec); +template void project<1, ComplexDouble>(double prec, FunctionTreeVector<1, ComplexDouble> &out, std::vector &r)>> inp, int maxIter, bool absPrec); +template void project<2, ComplexDouble>(double prec, FunctionTreeVector<2, ComplexDouble> &out, std::vector &r)>> inp, int maxIter, bool absPrec); +template void project<3, ComplexDouble>(double prec, FunctionTreeVector<3, ComplexDouble> &out, std::vector &r)>> inp, int maxIter, bool absPrec); -template void project<1>(double prec, FunctionTree<1> &out, std::function &r)> func, int maxIter, bool absPrec); -template void project<2>(double prec, FunctionTree<2> &out, std::function &r)> func, int maxIter, bool absPrec); -template void project<3>(double prec, FunctionTree<3> &out, std::function &r)> func, int maxIter, bool absPrec); -template void project<1>(double prec, FunctionTreeVector<1> &out, std::vector &r)>> inp, int maxIter, bool absPrec); -template void project<2>(double prec, FunctionTreeVector<2> &out, std::vector &r)>> inp, int maxIter, bool absPrec); -template void project<3>(double prec, FunctionTreeVector<3> 
&out, std::vector &r)>> inp, int maxIter, bool absPrec); } // namespace mrcpp diff --git a/src/treebuilders/project.h b/src/treebuilders/project.h index 790914a4b..f9e070ef2 100644 --- a/src/treebuilders/project.h +++ b/src/treebuilders/project.h @@ -30,7 +30,7 @@ #include namespace mrcpp { -template void project(double prec, FunctionTree &out, RepresentableFunction &inp, int maxIter = -1, bool absPrec = false); -template void project(double prec, FunctionTree &out, std::function &r)> func, int maxIter = -1, bool absPrec = false); -template void project(double prec, FunctionTreeVector &out, std::vector &r)>> func, int maxIter = -1, bool absPrec = false); +template void project(double prec, FunctionTree &out, RepresentableFunction &inp, int maxIter = -1, bool absPrec = false); +template void project(double prec, FunctionTree &out, std::function &r)> func, int maxIter = -1, bool absPrec = false); +template void project(double prec, FunctionTreeVector &out, std::vector &r)>> func, int maxIter = -1, bool absPrec = false); } // namespace mrcpp diff --git a/src/trees/BandWidth.cpp b/src/trees/BandWidth.cpp index 530d738b8..a79814d2f 100644 --- a/src/trees/BandWidth.cpp +++ b/src/trees/BandWidth.cpp @@ -44,7 +44,6 @@ void BandWidth::setWidth(int depth, int index, int wd) { if (wd > this->widths(depth, 4)) { this->widths(depth, 4) = wd; } } - std::ostream &BandWidth::print(std::ostream &o) const { o << " *BandWidths:" << std::endl; o << " n T C B A | max " << std::endl; diff --git a/src/trees/BandWidth.h b/src/trees/BandWidth.h index 85a2d5a43..b4ee49e8d 100644 --- a/src/trees/BandWidth.h +++ b/src/trees/BandWidth.h @@ -51,7 +51,7 @@ class BandWidth final { int getMaxWidth(int depth) const { return (depth > getDepth()) ? -1 : this->widths(depth, 4); } int getWidth(int depth, int index) const { return (depth > getDepth()) ? 
-1 : this->widths(depth, index); } void setWidth(int depth, int index, int wd); - + friend std::ostream &operator<<(std::ostream &o, const BandWidth &bw) { return bw.print(o); } private: diff --git a/src/trees/CornerOperatorTree.cpp b/src/trees/CornerOperatorTree.cpp index 9b7ecb24b..6de235dd3 100644 --- a/src/trees/CornerOperatorTree.cpp +++ b/src/trees/CornerOperatorTree.cpp @@ -24,24 +24,23 @@ */ #include "CornerOperatorTree.h" +#include "BandWidth.h" #include "OperatorNode.h" #include "utils/Printer.h" -#include "BandWidth.h" using namespace Eigen; namespace mrcpp { - /** @brief Calculates band widths of the non-standard form matrices. * * @param[in] prec: Precision used for thresholding - * + * * @details It is starting from \f$ l = 2^n \f$ and updating the band width value each time we encounter * considerable value while keeping decreasing down to \f$ l = 0 \f$, that stands for the distance to the diagonal. * This procedure is repeated for each matrix \f$ A, B \f$ and \f$ C \f$. - * - */ + * + */ void CornerOperatorTree::calcBandWidth(double prec) { if (this->bandWidth == nullptr) clearBandWidth(); this->bandWidth = new BandWidth(getDepth()); @@ -50,11 +49,10 @@ void CornerOperatorTree::calcBandWidth(double prec) { getMaxTranslations(max_transl); if (prec < 0.0) prec = this->normPrec; - double thrs = std::max(MachinePrec, prec / 10.0); //should be enough due to oscillating behaviour of corner matrix elements (it's affected by polynomial order) - - for (int depth = 0; depth < this->getDepth(); depth++) - { - int l = (1<getDepth(); depth++) { + int l = (1 << depth) - 1; this->bandWidth->setWidth(depth, 0, l); bool done = false; @@ -62,7 +60,7 @@ void CornerOperatorTree::calcBandWidth(double prec) { done = true; MWNode<2> *node = findNode(NodeIndex<2>(depth, {l, 0})); for (int k = 1; k < 4; k++) { - if ( (node != nullptr) && (node->getComponentNorm(k) > thrs)) { + if ((node != nullptr) && (node->getComponentNorm(k) > thrs)) { this->bandWidth->setWidth(depth, k, 
l); done = false; } @@ -73,20 +71,17 @@ void CornerOperatorTree::calcBandWidth(double prec) { println(100, "\nOperator BandWidth" << *this->bandWidth); } - /** @brief Checks if the distance to diagonal is lesser than the operator band width. * * @param[in] oTransl: distance to diagonal * @param[in] o_depth: scaling order * @param[in] idx: index corresponding to one of the matrices \f$ A, B, C \f$ or \f$ T \f$. - * - * @returns True if \b oTransl is outside of the corner band (close to diagonal) and False otherwise. - * - */ -bool CornerOperatorTree::isOutsideBand(int oTransl, int o_depth, int idx) -{ + * + * @returns True if \b oTransl is outside of the corner band (close to diagonal) and False otherwise. + * + */ +bool CornerOperatorTree::isOutsideBand(int oTransl, int o_depth, int idx) { return abs(oTransl) < this->bandWidth->getWidth(o_depth, idx); } - } // namespace mrcpp diff --git a/src/trees/CornerOperatorTree.h b/src/trees/CornerOperatorTree.h index 06f6f6136..0ac2ad5bd 100644 --- a/src/trees/CornerOperatorTree.h +++ b/src/trees/CornerOperatorTree.h @@ -29,7 +29,6 @@ namespace mrcpp { - /** @class CornerOperatorTree * * @brief Special case of OperatorTree class diff --git a/src/trees/FunctionNode.cpp b/src/trees/FunctionNode.cpp index c839e2b57..ff23fb394 100644 --- a/src/trees/FunctionNode.cpp +++ b/src/trees/FunctionNode.cpp @@ -44,7 +44,7 @@ namespace mrcpp { /** Function evaluation. * Evaluate all polynomials defined on the node. 
*/ -template double FunctionNode::evalf(Coord r) { +template T FunctionNode::evalf(Coord r) { if (not this->hasCoefs()) MSG_ERROR("Evaluating node without coefs"); // The 1.0 appearing in the if tests comes from the period is always 1.0 @@ -57,7 +57,7 @@ template double FunctionNode::evalf(Coord r) { return getFuncChild(cIdx).evalScaling(r); } -template double FunctionNode::evalScaling(const Coord &r) const { +template T FunctionNode::evalScaling(const Coord &r) const { if (not this->hasCoefs()) MSG_ERROR("Evaluating node without coefs"); double arg[D]; @@ -72,10 +72,10 @@ template double FunctionNode::evalScaling(const Coord &r) const { const ScalingBasis &basis = this->getMWTree().getMRA().getScalingBasis(); basis.evalf(arg, val); - double result = 0.0; + T result = 0.0; //#pragma omp parallel for shared(fact) reduction(+:result) num_threads(mrcpp_get_num_threads()) for (int i = 0; i < this->getKp1_d(); i++) { - double temp = this->coefs[i]; + T temp = this->coefs[i]; for (int j = 0; j < D; j++) { int k = (i % fact[j + 1]) / fact[j]; temp *= val(k, j); @@ -92,7 +92,7 @@ template double FunctionNode::evalScaling(const Coord &r) const { * Wrapper for function integration, that requires different methods depending * on scaling type. Integrates the function represented on the node on the * full support of the node. */ -template double FunctionNode::integrate() const { +template T FunctionNode::integrate() const { if (not this->hasCoefs()) { return 0.0; } switch (this->getScalingType()) { case Legendre: @@ -115,7 +115,7 @@ template double FunctionNode::integrate() const { * s_i = int f(x)phi_i(x)dx * and since the first Legendre function is the constant 1, the first * coefficient is simply the integral of f(x). 
*/ -template double FunctionNode::integrateLegendre() const { +template T FunctionNode::integrateLegendre() const { double n = (D * this->getScale()) / 2.0; double two_n = std::pow(2.0, -n); return two_n * this->getCoefs()[0]; @@ -126,7 +126,7 @@ template double FunctionNode::integrateLegendre() const { * Integrates the function represented on the node on the full support of the * node. A bit more involved than in the Legendre basis, as is requires some * coupling of quadrature weights. */ -template double FunctionNode::integrateInterpolating() const { +template T FunctionNode::integrateInterpolating() const { int qOrder = this->getKp1(); getQuadratureCache(qc); const VectorXd &weights = qc.getWeights(qOrder); @@ -136,7 +136,7 @@ template double FunctionNode::integrateInterpolating() const { int kp1_p[D]; for (int i = 0; i < D; i++) kp1_p[i] = math_utils::ipow(qOrder, i); - VectorXd coefs; + Eigen::Matrix coefs; this->getCoefs(coefs); for (int p = 0; p < D; p++) { @@ -152,7 +152,7 @@ template double FunctionNode::integrateInterpolating() const { } double n = (D * this->getScale()) / 2.0; double two_n = std::pow(2.0, -n); - double sum = coefs.segment(0, this->getKp1_d()).sum(); + T sum = coefs.segment(0, this->getKp1_d()).sum(); return two_n * sum; } @@ -162,48 +162,53 @@ template double FunctionNode::integrateInterpolating() const { * Integrates the function represented on the node on the full support of the * node. A bit more involved than in the Legendre basis, as is requires some * coupling of quadrature weights. 
*/ -template double FunctionNode::integrateValues() const { +template T FunctionNode::integrateValues() const { int qOrder = this->getKp1(); getQuadratureCache(qc); const VectorXd &weights = qc.getWeights(qOrder); - VectorXd coefs; + Eigen::Matrix coefs; this->getCoefs(coefs); int ncoefs = coefs.size(); - int ncoefChild = ncoefs/(1< 3) MSG_ABORT("Not Implemented") + T sum = 0.0; + if (D > 3) + MSG_ABORT("Not Implemented") else if (D == 3) { for (int i = 0; i < qOrder; i++) { - double sumj = 0.0; + T sumj = 0.0; for (int j = 0; j < qOrder; j++) { - double sumk = 0.0; + T sumk = 0.0; for (int k = 0; k < qOrder; k++) sumk += cc[nc++] * weights[k]; sumj += sumk * weights[j]; } sum += sumj * weights[i]; } - } else if (D==2) { + } else if (D == 2) { for (int j = 0; j < qOrder; j++) { - double sumk = 0.0; - for (int k = 0; k < qOrder; k++) sumk += cc[nc++] * weights[k]; - sum += sumk * weights[j]; + T sumk = 0.0; + for (int k = 0; k < qOrder; k++) sumk += cc[nc++] * weights[k]; + sum += sumk * weights[j]; } - } else if (D==1) for (int k = 0; k < qOrder; k++) sum += cc[nc++] * weights[k]; - - int n = D * (this->getScale() + 1) ; // NB: one extra scale - int two_n = (1<0)sum/=two_n; - else sum*=two_n; + } else if (D == 1) + for (int k = 0; k < qOrder; k++) sum += cc[nc++] * weights[k]; + + int n = D * (this->getScale() + 1); // NB: one extra scale + int two_n = (1 << abs(n)); // 2**n; + if (n > 0) + sum /= two_n; + else + sum *= two_n; return sum; } -template void FunctionNode::setValues(const VectorXd &vec) { +template void FunctionNode::setValues(const Matrix &vec) { this->zeroCoefs(); this->setCoefBlock(0, vec.size(), vec.data()); this->cvTransform(Backward); @@ -212,15 +217,15 @@ template void FunctionNode::setValues(const VectorXd &vec) { this->calcNorms(); } -template void FunctionNode::getValues(VectorXd &vec) { +template void FunctionNode::getValues(Matrix &vec) { if (this->isGenNode()) { - MWNode copy(*this); - vec = Eigen::VectorXd::Zero(copy.getNCoefs()); + 
MWNode copy(*this); + vec = Eigen::Matrix::Zero(copy.getNCoefs()); copy.mwTransform(Reconstruction); copy.cvTransform(Forward); for (int i = 0; i < this->n_coefs; i++) vec(i) = copy.getCoefs()[i]; } else { - vec = VectorXd::Zero(this->n_coefs); + vec = Eigen::Matrix::Zero(this->n_coefs); this->mwTransform(Reconstruction); this->cvTransform(Forward); for (int i = 0; i < this->n_coefs; i++) vec(i) = this->coefs[i]; @@ -231,20 +236,23 @@ template void FunctionNode::getValues(VectorXd &vec) { /** get coefficients corresponding to absolute value of function * - * Leaves the original coefficients unchanged. */ -template void FunctionNode::getAbsCoefs(double *absCoefs) { - double *coefsTmp = this->coefs; + * Leaves the original coefficients unchanged. + * Note that we mus use T and not double, even if the norms are double, because + * the transforms expect T types. + */ +template void FunctionNode::getAbsCoefs(T *absCoefs) { + T *coefsTmp = this->coefs; for (int i = 0; i < this->n_coefs; i++) absCoefs[i] = coefsTmp[i]; // copy this->coefs = absCoefs; // swap coefs this->mwTransform(Reconstruction); this->cvTransform(Forward); - for (int i = 0; i < this->n_coefs; i++) this->coefs[i] = std::abs(this->coefs[i]); + for (int i = 0; i < this->n_coefs; i++) this->coefs[i] = std::norm(this->coefs[i]); this->cvTransform(Backward); this->mwTransform(Compression); this->coefs = coefsTmp; // restore original array (same address) } -template void FunctionNode::createChildren(bool coefs) { +template void FunctionNode::createChildren(bool coefs) { if (this->isBranchNode()) MSG_ABORT("Node already has children"); auto &allocator = this->getFuncTree().getNodeAllocator(); @@ -258,7 +266,7 @@ template void FunctionNode::createChildren(bool coefs) { this->childSerialIx = sIdx; for (int cIdx = 0; cIdx < nChildren; cIdx++) { // construct into allocator memory - new (child_p) FunctionNode(this, cIdx); + new (child_p) FunctionNode(this, cIdx); this->children[cIdx] = child_p; child_p->serialIx = 
sIdx; @@ -282,7 +290,7 @@ template void FunctionNode::createChildren(bool coefs) { this->clearIsEndNode(); } -template void FunctionNode::genChildren() { +template void FunctionNode::genChildren() { if (this->isBranchNode()) MSG_ABORT("Node already has children"); auto &allocator = this->getFuncTree().getGenNodeAllocator(); @@ -296,7 +304,7 @@ template void FunctionNode::genChildren() { this->childSerialIx = sIdx; for (int cIdx = 0; cIdx < nChildren; cIdx++) { // construct into allocator memory - new (child_p) FunctionNode(this, cIdx); + new (child_p) FunctionNode(this, cIdx); this->children[cIdx] = child_p; child_p->serialIx = sIdx; @@ -319,7 +327,7 @@ template void FunctionNode::genChildren() { this->setIsBranchNode(); } -template void FunctionNode::genParent() { +template void FunctionNode::genParent() { if (this->parent != nullptr) MSG_ABORT("Node is not an orphan"); auto &allocator = this->getFuncTree().getNodeAllocator(); @@ -332,7 +340,7 @@ template void FunctionNode::genParent() { this->parentSerialIx = sIdx; // construct into allocator memory - new (parent_p) FunctionNode(this->tree, this->getNodeIndex().parent()); + new (parent_p) FunctionNode(this->tree, this->getNodeIndex().parent()); this->parent = parent_p; @@ -351,12 +359,12 @@ template void FunctionNode::genParent() { this->getMWTree().incrementNodeCount(parent_p->getScale()); } -template void FunctionNode::deleteChildren() { - MWNode::deleteChildren(); +template void FunctionNode::deleteChildren() { + MWNode::deleteChildren(); this->setIsEndNode(); } -template void FunctionNode::dealloc() { +template void FunctionNode::dealloc() { int sIdx = this->serialIx; this->serialIx = -1; this->parentSerialIx = -1; @@ -376,8 +384,8 @@ template void FunctionNode::dealloc() { /** Update the coefficients of the node by a mw transform of the scaling * coefficients of the children. Option to overwrite or add up existing * coefficients. Specialized for D=3 below. 
*/ -template void FunctionNode::reCompress() { - MWNode::reCompress(); +template void FunctionNode::reCompress() { + MWNode::reCompress(); } template <> void FunctionNode<3>::reCompress() { @@ -405,8 +413,10 @@ template <> void FunctionNode<3>::reCompress() { * Integrates the product of the functions represented by the scaling basis on * the node on the full support of the nodes. The scaling basis is fully * orthonormal, and the inner product is simply the dot product of the - * coefficient vectors. Assumes the nodes have identical support. */ -template double dot_scaling(const FunctionNode &bra, const FunctionNode &ket) { + * coefficient vectors. Assumes the nodes have identical support. + * NB: will take conjugate of bra in case of complex values. + */ +template double dot_scaling(const FunctionNode &bra, const FunctionNode &ket) { assert(bra.hasCoefs()); assert(ket.hasCoefs()); @@ -423,13 +433,101 @@ template double dot_scaling(const FunctionNode &bra, const FunctionNo #endif } +/** Inner product of the functions represented by the scaling basis of the nodes. + * + * Integrates the product of the functions represented by the scaling basis on + * the node on the full support of the nodes. The scaling basis is fully + * orthonormal, and the inner product is simply the dot product of the + * coefficient vectors. Assumes the nodes have identical support. + * NB: will take conjugate of bra in case of complex values. 
+ */ +template ComplexDouble dot_scaling(const FunctionNode &bra, const FunctionNode &ket) { + assert(bra.hasCoefs()); + assert(ket.hasCoefs()); + + const ComplexDouble *a = bra.getCoefs(); + const ComplexDouble *b = ket.getCoefs(); + + int size = bra.getKp1_d(); + ComplexDouble result = 0.0; + // note that bra is conjugated by default + if (bra.getMWTree().conjugate()) { + if (ket.getMWTree().conjugate()) { + for (int i = 0; i < size; i++) result += a[i] * std::conj(b[i]); + } else { + for (int i = 0; i < size; i++) result += a[i] * b[i]; + } + } else { + if (ket.getMWTree().conjugate()) { + for (int i = 0; i < size; i++) result += std::conj(a[i]) * std::conj(b[i]); + } else { + for (int i = 0; i < size; i++) result += std::conj(a[i]) * b[i]; + } + } + return result; +} + +/** Inner product of the functions represented by the scaling basis of the nodes. + * + * Integrates the product of the functions represented by the scaling basis on + * the node on the full support of the nodes. The scaling basis is fully + * orthonormal, and the inner product is simply the dot product of the + * coefficient vectors. Assumes the nodes have identical support. + * NB: will take conjugate of bra in case of complex values. + */ +template ComplexDouble dot_scaling(const FunctionNode &bra, const FunctionNode &ket) { + assert(bra.hasCoefs()); + assert(ket.hasCoefs()); + + const ComplexDouble *a = bra.getCoefs(); + const double *b = ket.getCoefs(); + + int size = bra.getKp1_d(); + ComplexDouble result = 0.0; + // note that bra is conjugated by default + if (bra.getMWTree().conjugate()) { + for (int i = 0; i < size; i++) result += a[i] * b[i]; + } else { + for (int i = 0; i < size; i++) result += std::conj(a[i]) * b[i]; + } + return result; +} + +/** Inner product of the functions represented by the scaling basis of the nodes. + * + * Integrates the product of the functions represented by the scaling basis on + * the node on the full support of the nodes. 
The scaling basis is fully + * orthonormal, and the inner product is simply the dot product of the + * coefficient vectors. Assumes the nodes have identical support. + * NB: will take conjugate of bra in case of complex values. + */ +template ComplexDouble dot_scaling(const FunctionNode &bra, const FunctionNode &ket) { + assert(bra.hasCoefs()); + assert(ket.hasCoefs()); + + const double *a = bra.getCoefs(); + const ComplexDouble *b = ket.getCoefs(); + + int size = bra.getKp1_d(); + ComplexDouble result = 0.0; + // note that bra is conjugated by default + if (ket.getMWTree().conjugate()) { + for (int i = 0; i < size; i++) result += a[i] * std::conj(b[i]); + } else { + for (int i = 0; i < size; i++) result += a[i] * b[i]; + } + return result; +} + /** Inner product of the functions represented by the wavelet basis of the nodes. * * Integrates the product of the functions represented by the wavelet basis on * the node on the full support of the nodes. The wavelet basis is fully * orthonormal, and the inner product is simply the dot product of the - * coefficient vectors. Assumes the nodes have identical support. */ -template double dot_wavelet(const FunctionNode &bra, const FunctionNode &ket) { + * coefficient vectors. Assumes the nodes have identical support. + * NB: will take conjugate of bra in case of complex values. 
+ */ +template double dot_wavelet(const FunctionNode &bra, const FunctionNode &ket) { if (bra.isGenNode() or ket.isGenNode()) return 0.0; assert(bra.hasCoefs()); @@ -449,15 +547,132 @@ template double dot_wavelet(const FunctionNode &bra, const FunctionNo #endif } -template double dot_scaling(const FunctionNode<1> &bra, const FunctionNode<1> &ket); -template double dot_scaling(const FunctionNode<2> &bra, const FunctionNode<2> &ket); -template double dot_scaling(const FunctionNode<3> &bra, const FunctionNode<3> &ket); -template double dot_wavelet(const FunctionNode<1> &bra, const FunctionNode<1> &ket); -template double dot_wavelet(const FunctionNode<2> &bra, const FunctionNode<2> &ket); -template double dot_wavelet(const FunctionNode<3> &bra, const FunctionNode<3> &ket); +/** Inner product of the functions represented by the wavelet basis of the nodes. + * + * Integrates the product of the functions represented by the wavelet basis on + * the node on the full support of the nodes. The wavelet basis is fully + * orthonormal, and the inner product is simply the dot product of the + * coefficient vectors. Assumes the nodes have identical support. + * NB: will take conjugate of bra in case of complex values. 
+ */ +template ComplexDouble dot_wavelet(const FunctionNode &bra, const FunctionNode &ket) { + if (bra.isGenNode() or ket.isGenNode()) return 0.0; + + assert(bra.hasCoefs()); + assert(ket.hasCoefs()); + + const ComplexDouble *a = bra.getCoefs(); + const ComplexDouble *b = ket.getCoefs(); + + int start = bra.getKp1_d(); + int size = (bra.getTDim() - 1) * start; + ComplexDouble result = 0.0; + if (bra.getMWTree().conjugate()) { + if (ket.getMWTree().conjugate()) { + for (int i = 0; i < size; i++) result += a[start + i] * std::conj(b[start + i]); + } else { + for (int i = 0; i < size; i++) result += a[start + i] * b[start + i]; + } + } else { + if (ket.getMWTree().conjugate()) { + for (int i = 0; i < size; i++) result += std::conj(a[start + i]) * std::conj(b[start + i]); + } else { + for (int i = 0; i < size; i++) result += std::conj(a[start + i]) * b[start + i]; + } + } + return result; +} + +/** Inner product of the functions represented by the wavelet basis of the nodes. + * + * Integrates the product of the functions represented by the wavelet basis on + * the node on the full support of the nodes. The wavelet basis is fully + * orthonormal, and the inner product is simply the dot product of the + * coefficient vectors. Assumes the nodes have identical support. + * NB: will take conjugate of bra in case of complex values. 
+ */ +template ComplexDouble dot_wavelet(const FunctionNode &bra, const FunctionNode &ket) { + if (bra.isGenNode() or ket.isGenNode()) return 0.0; + + assert(bra.hasCoefs()); + assert(ket.hasCoefs()); + + const ComplexDouble *a = bra.getCoefs(); + const double *b = ket.getCoefs(); + + int start = bra.getKp1_d(); + int size = (bra.getTDim() - 1) * start; + ComplexDouble result = 0.0; + if (bra.getMWTree().conjugate()) { + for (int i = 0; i < size; i++) result += a[start + i] * b[start + i]; + } else { + for (int i = 0; i < size; i++) result += std::conj(a[start + i]) * b[start + i]; + } + return result; +} + +/** Inner product of the functions represented by the wavelet basis of the nodes. + * + * Integrates the product of the functions represented by the wavelet basis on + * the node on the full support of the nodes. The wavelet basis is fully + * orthonormal, and the inner product is simply the dot product of the + * coefficient vectors. Assumes the nodes have identical support. + * NB: will take conjugate of bra in case of complex values. 
+ */ +template ComplexDouble dot_wavelet(const FunctionNode &bra, const FunctionNode &ket) { + if (bra.isGenNode() or ket.isGenNode()) return 0.0; + + assert(bra.hasCoefs()); + assert(ket.hasCoefs()); + + const double *a = bra.getCoefs(); + const ComplexDouble *b = ket.getCoefs(); + + int start = bra.getKp1_d(); + int size = (bra.getTDim() - 1) * start; + ComplexDouble result = 0.0; + if (ket.getMWTree().conjugate()) { + for (int i = 0; i < size; i++) result += a[start + i] * std::conj(b[start + i]); + } else { + for (int i = 0; i < size; i++) result += a[start + i] * b[start + i]; + } + return result; +} -template class FunctionNode<1>; -template class FunctionNode<2>; -template class FunctionNode<3>; +template double dot_scaling(const FunctionNode<1, double> &bra, const FunctionNode<1, double> &ket); +template double dot_scaling(const FunctionNode<2, double> &bra, const FunctionNode<2, double> &ket); +template double dot_scaling(const FunctionNode<3, double> &bra, const FunctionNode<3, double> &ket); +template double dot_wavelet(const FunctionNode<1, double> &bra, const FunctionNode<1, double> &ket); +template double dot_wavelet(const FunctionNode<2, double> &bra, const FunctionNode<2, double> &ket); +template double dot_wavelet(const FunctionNode<3, double> &bra, const FunctionNode<3, double> &ket); + +template class FunctionNode<1, double>; +template class FunctionNode<2, double>; +template class FunctionNode<3, double>; + +template class FunctionNode<1, ComplexDouble>; +template class FunctionNode<2, ComplexDouble>; +template class FunctionNode<3, ComplexDouble>; + +template ComplexDouble dot_scaling(const FunctionNode<1, ComplexDouble> &bra, const FunctionNode<1, ComplexDouble> &ket); +template ComplexDouble dot_scaling(const FunctionNode<2, ComplexDouble> &bra, const FunctionNode<2, ComplexDouble> &ket); +template ComplexDouble dot_scaling(const FunctionNode<3, ComplexDouble> &bra, const FunctionNode<3, ComplexDouble> &ket); +template ComplexDouble 
dot_wavelet(const FunctionNode<1, ComplexDouble> &bra, const FunctionNode<1, ComplexDouble> &ket); +template ComplexDouble dot_wavelet(const FunctionNode<2, ComplexDouble> &bra, const FunctionNode<2, ComplexDouble> &ket); +template ComplexDouble dot_wavelet(const FunctionNode<3, ComplexDouble> &bra, const FunctionNode<3, ComplexDouble> &ket); + +template ComplexDouble dot_scaling(const FunctionNode<1, double> &bra, const FunctionNode<1, ComplexDouble> &ket); +template ComplexDouble dot_scaling(const FunctionNode<2, double> &bra, const FunctionNode<2, ComplexDouble> &ket); +template ComplexDouble dot_scaling(const FunctionNode<3, double> &bra, const FunctionNode<3, ComplexDouble> &ket); +template ComplexDouble dot_wavelet(const FunctionNode<1, double> &bra, const FunctionNode<1, ComplexDouble> &ket); +template ComplexDouble dot_wavelet(const FunctionNode<2, double> &bra, const FunctionNode<2, ComplexDouble> &ket); +template ComplexDouble dot_wavelet(const FunctionNode<3, double> &bra, const FunctionNode<3, ComplexDouble> &ket); + +template ComplexDouble dot_scaling(const FunctionNode<1, ComplexDouble> &bra, const FunctionNode<1, double> &ket); +template ComplexDouble dot_wavelet(const FunctionNode<1, ComplexDouble> &bra, const FunctionNode<1, double> &ket); +template ComplexDouble dot_scaling(const FunctionNode<2, ComplexDouble> &bra, const FunctionNode<2, double> &ket); +template ComplexDouble dot_wavelet(const FunctionNode<2, ComplexDouble> &bra, const FunctionNode<2, double> &ket); +template ComplexDouble dot_scaling(const FunctionNode<3, ComplexDouble> &bra, const FunctionNode<3, double> &ket); +template ComplexDouble dot_wavelet(const FunctionNode<3, ComplexDouble> &bra, const FunctionNode<3, double> &ket); } // namespace mrcpp diff --git a/src/trees/FunctionNode.h b/src/trees/FunctionNode.h index 97a3d74d3..d1bfaaa31 100644 --- a/src/trees/FunctionNode.h +++ b/src/trees/FunctionNode.h @@ -32,55 +32,63 @@ namespace mrcpp { -template class FunctionNode final : 
public MWNode { +template class FunctionNode final : public MWNode { public: - FunctionTree &getFuncTree() { return static_cast &>(*this->tree); } - FunctionNode &getFuncParent() { return static_cast &>(*this->parent); } - FunctionNode &getFuncChild(int i) { return static_cast &>(*this->children[i]); } + FunctionTree &getFuncTree() { return static_cast &>(*this->tree); } + FunctionNode &getFuncParent() { return static_cast &>(*this->parent); } + FunctionNode &getFuncChild(int i) { return static_cast &>(*this->children[i]); } - const FunctionTree &getFuncTree() const { return static_cast &>(*this->tree); } - const FunctionNode &getFuncParent() const { return static_cast &>(*this->parent); } - const FunctionNode &getFuncChild(int i) const { return static_cast &>(*this->children[i]); } + const FunctionTree &getFuncTree() const { return static_cast &>(*this->tree); } + const FunctionNode &getFuncParent() const { return static_cast &>(*this->parent); } + const FunctionNode &getFuncChild(int i) const { return static_cast &>(*this->children[i]); } void createChildren(bool coefs) override; void genChildren() override; void genParent() override; void deleteChildren() override; - double integrate() const; + T integrate() const; - void setValues(const Eigen::VectorXd &vec); - void getValues(Eigen::VectorXd &vec); - void getAbsCoefs(double *absCoefs); + void setValues(const Eigen::Matrix &vec); + void getValues(Eigen::Matrix &vec); + void getAbsCoefs(T *absCoefs); - friend class FunctionTree; - friend class NodeAllocator; + friend class FunctionTree; + friend class NodeAllocator; protected: FunctionNode() - : MWNode() {} - FunctionNode(MWTree *tree, int rIdx) - : MWNode(tree, rIdx) {} - FunctionNode(MWNode *parent, int cIdx) - : MWNode(parent, cIdx) {} - FunctionNode(MWTree *tree, const NodeIndex &idx) - : MWNode(tree, idx) {} - FunctionNode(const FunctionNode &node) = delete; - FunctionNode &operator=(const FunctionNode &node) = delete; + : MWNode() {} + FunctionNode(MWTree 
*tree, int rIdx) + : MWNode(tree, rIdx) {} + FunctionNode(MWNode *parent, int cIdx) + : MWNode(parent, cIdx) {} + FunctionNode(MWTree *tree, const NodeIndex &idx) + : MWNode(tree, idx) {} + FunctionNode(const FunctionNode &node) = delete; + FunctionNode &operator=(const FunctionNode &node) = delete; ~FunctionNode() = default; - double evalf(Coord r); - double evalScaling(const Coord &r) const; + T evalf(Coord r); + T evalScaling(const Coord &r) const; void dealloc() override; void reCompress() override; - double integrateLegendre() const; - double integrateInterpolating() const; - double integrateValues() const; + T integrateLegendre() const; + T integrateInterpolating() const; + T integrateValues() const; }; +template double dot_scaling(const FunctionNode &bra, const FunctionNode &ket); +template double dot_wavelet(const FunctionNode &bra, const FunctionNode &ket); -template double dot_scaling(const FunctionNode &bra, const FunctionNode &ket); -template double dot_wavelet(const FunctionNode &bra, const FunctionNode &ket); +template ComplexDouble dot_scaling(const FunctionNode &bra, const FunctionNode &ket); +template ComplexDouble dot_wavelet(const FunctionNode &bra, const FunctionNode &ket); + +template ComplexDouble dot_scaling(const FunctionNode &bra, const FunctionNode &ket); +template ComplexDouble dot_wavelet(const FunctionNode &bra, const FunctionNode &ket); + +template ComplexDouble dot_scaling(const FunctionNode &bra, const FunctionNode &ket); +template ComplexDouble dot_wavelet(const FunctionNode &bra, const FunctionNode &ket); } // namespace mrcpp diff --git a/src/trees/FunctionTree.cpp b/src/trees/FunctionTree.cpp index 47614a933..a41581692 100644 --- a/src/trees/FunctionTree.cpp +++ b/src/trees/FunctionTree.cpp @@ -30,6 +30,7 @@ #include "FunctionNode.h" #include "NodeAllocator.h" +#include "treebuilders/grid.h" #include "utils/Bank.h" #include "utils/Printer.h" #include "utils/Timer.h" @@ -50,20 +51,21 @@ namespace mrcpp { * If a shared memory 
pointer is provided the tree will be allocated in this * shared memory window, otherwise it will be local to each MPI process. */ -template -FunctionTree::FunctionTree(const MultiResolutionAnalysis &mra, SharedMemory *sh_mem, const std::string &name) - : MWTree(mra, name) - , RepresentableFunction(mra.getWorldBox().getLowerBounds().data(), mra.getWorldBox().getUpperBounds().data()) { +template +FunctionTree::FunctionTree(const MultiResolutionAnalysis &mra, SharedMemory *sh_mem, const std::string &name) + : MWTree(mra, name) + , RepresentableFunction(mra.getWorldBox().getLowerBounds().data(), mra.getWorldBox().getUpperBounds().data()) { int nodesPerChunk = 2048; // Large chunks are required for not leading to memory fragmentation (32 MB on "Betzy" 2023) + // nodesPerChunk is same for real and complex trees: the size (in MB) of the complex chunks are twice as large int coefsGenNodes = this->getKp1_d(); int coefsRegNodes = this->getTDim() * this->getKp1_d(); - this->nodeAllocator_p = std::make_unique>(this, sh_mem, coefsRegNodes, nodesPerChunk); - this->genNodeAllocator_p = std::make_unique>(this, nullptr, coefsGenNodes, nodesPerChunk); + this->nodeAllocator_p = std::make_unique>(this, sh_mem, coefsRegNodes, nodesPerChunk); + this->genNodeAllocator_p = std::make_unique>(this, nullptr, coefsGenNodes, nodesPerChunk); this->allocRootNodes(); this->resetEndNodeTable(); } -template void FunctionTree::allocRootNodes() { +template void FunctionTree::allocRootNodes() { auto &allocator = this->getNodeAllocator(); auto &rootbox = this->getRootBox(); @@ -74,10 +76,10 @@ template void FunctionTree::allocRootNodes() { auto *coef_p = allocator.getCoef_p(sIdx); auto *root_p = allocator.getNode_p(sIdx); - MWNode **roots = rootbox.getNodes(); + MWNode **roots = rootbox.getNodes(); for (int rIdx = 0; rIdx < nRoots; rIdx++) { // construct into allocator memory - new (root_p) FunctionNode(this, rIdx); + new (root_p) FunctionNode(this, rIdx); roots[rIdx] = root_p; root_p->serialIx = sIdx; 
@@ -101,21 +103,266 @@ template void FunctionTree::allocRootNodes() { } // FunctionTree destructor -template FunctionTree::~FunctionTree() { - this->deleteRootNodes(); +template FunctionTree::~FunctionTree() { + if (this->getNNodes() > 0) this->deleteRootNodes(); } +/** @brief Read a previously stored tree assuming text/ASCII format, + * in a representation using MADNESS conventions for n, l and index order. + * @param[in] file: File name + * @note This tree must have the exact same MRA the one that was saved(?) + */ +template void FunctionTree::loadTreeTXT(const std::string &file) { + std::ifstream in(file); + int NDIM, k; + in >> NDIM; + if (NDIM != D) NOT_IMPLEMENTED_ABORT; + double coord[D][2]; + for (int d = 0; d < D; d++) in >> coord[d][0] >> coord[d][1]; + + int p = 1; + int rscale = this->getRootScale(); // root scale of target MRA (MRChem) . NB: negative + for (int i = rscale; i < 0; i++) p *= 2; + int L = p; // NB for now we assume the world as a cube going from -L to +L and L is a power of 2 + // We require that the world box size is identical and a power of 2 + double TXT_thres = 1.0e-14; // threshold for differences in scaling factors + for (int d = 0; d < D; d++) { + if (std::abs(coord[d][0] + L) > TXT_thres) std::cout << coord[d][0] << " " << L << std::endl; + if (std::abs(coord[d][0] + L) > TXT_thres) NOT_IMPLEMENTED_ABORT; + if (std::abs(coord[d][1] - L) > TXT_thres) std::cout << coord[d][1] << " " << L << std::endl; + if (std::abs(coord[d][1] - L) > TXT_thres) NOT_IMPLEMENTED_ABORT; + } + + int nChildren = 1; + for (int d = 0; d < D; d++) nChildren *= 2; + + int nmax = 0; // deppeset scale in TXT + in >> k; + if (k != this->getKp1()) NOT_IMPLEMENTED_ABORT; + k--; // MRChem defines k as highest polynomial order. 
MADNESS as number of polynomials + + int ncoefs = 1; // number of coefficents for one single node (not a full MRChem MWnode which stores 2**D of them) + for (int i = 0; i < D; i++) ncoefs *= k + 1; + + std::vector *>> NodeTable(50); // to store all the nodes pointers + std::map mp; // to store the number of children stored in each parent node + // MRChem and MADNESS do not use the same indices order for the qudrature points + // We read MADNESS convention (note that mapMRC[mapMRC[i]]=i for all i) + std::vector mapMRC; // mapping vector + int kx = k; + int ky = k; + int kz = k; + if (D < 3) kz = 0; + if (D < 2) ky = 0; + int kp1 = k + 1; + // MADNESS: zyx and i=k,k-1,k-2... MRChem: xyz, i=0,1,2,3 ... + for (int x = kx; x >= 0; x--) { + for (int y = ky; y >= 0; y--) { + for (int z = kz; z >= 0; z--) { mapMRC.push_back(z * kp1 * kp1 + y * kp1 + x); } + } + } + + MWNode **roots = this->getRootBox().getNodes(); + for (int rIdx = 0; rIdx < nChildren; rIdx++) { + roots[rIdx]->deleteChildren(); + roots[rIdx]->zeroCoefs(); + } + this->clearEndNodeTable(); + + int nread; // number of nodes to read + in >> nread; + while (nread-- > 0) { + // NB: MRChem stores quadrature points values in the PARENT node. 2**D nodes are stored in the same parent + int n; // TXT scale + int n_in; // MRChem scale + in >> n_in; + n = n_in + rscale - 1; // MRChem does not define root scale as zero. + + std::array l_in; // translation index TXT + std::array l; // translation index MRChem + std::array lp; // translation index MRChem, parent + + for (int i = 0; i < D; i++) in >> l_in[i]; + + // MRChem defines smallest l as -(2**n)*L , where -L is smallest world coordinate. 
+ // note that root scale has 2**D nodes (if range is -L,L) + for (int i = 0; i < D; i++) { + l[i] = l_in[i] - std::pow(2, n) * L; + lp[i] = l_in[i] / 2 - std::pow(2, n - 1) * L; // for parent + } + NodeIndex idx_p(n - 1, lp); // index of parent node + MWNode *node = &this->getNode(idx_p, true); + // note that node is not necesssarily an endnode, but they children are always endnodes + // must find to which child of the parent node it corresponds + int c_ix = 0; // child index in the parent + int p = 1; + for (int i = 0; i < D; i++) { + if (abs(l[i]) % 2 == 1) c_ix += p; + p *= 2; + } + T *values = node->getCoefs(); + if (mp[node->getSerialIx()] == 0) { + // init to zero + node->zeroCoefs(); + if (not node->isRootNode()) { + // also set siblings to zero if not set yet + MWNode *parent = &node->getMWParent(); + for (int cIdx = 0; cIdx < nChildren; cIdx++) { + if (mp[parent->getMWChild(cIdx).getSerialIx()] == 0) parent->getMWChild(cIdx).zeroCoefs(); + } + } + } + values += c_ix * ncoefs; // repoint to the right child position (ncoefs is for one child only) + for (int i = 0; i < ncoefs; i++) in >> values[mapMRC[i]]; // the indice i is mapped + mp[node->getSerialIx()]++; // counts the number of children included + nmax = std::max(nmax, n_in); // deepest scale in TXT + if (mp[node->getSerialIx()] == 1) NodeTable[n_in].push_back(node); + } + in.close(); + // transform all nodes from quadrature point values to scaling coefficients + for (int n = nmax; n > -1; n--) { + for (int i = 0; i < NodeTable[n].size(); i++) { + MWNode *node = NodeTable[n][i]; + node->cvTransform(Backward); + node->calcNorms(); + } + } + // now tree has only scaling coefficients or zeros on end nodes + + // Transform into scaling and wavelets, starting by leaf nodes and copying scaling into parents + for (int n = nmax; n > -1; n--) { + for (int i = 0; i < NodeTable[n].size(); i++) { + MWNode *node = NodeTable[n][i]; + if (mp[node->getSerialIx()] == nChildren) { + // node complete: transform into 
scaling and wavelets + if (node->isEndNode()) { + node->mwTransform(Compression); + node->setHasCoefs(); + node->calcNorms(); + this->endNodeTable.push_back(node); + } else { + // MRCPP requires that all nodes that have no children are end nodes + // and all nodes are groups of 2**D siblings + T *pcoefs = node->getCoefs(); // parent coefficients + for (int cIdx = 0; cIdx < nChildren; cIdx++) { + MWNode *cnode = &node->getMWChild(cIdx); + if (mp[cnode->getSerialIx()] != nChildren) { + // This child is not defined. must take scaling from parent + if (mp[cnode->getSerialIx()] > 0) std::cout << "accounting error " << std::endl; + T *ccoefs = cnode->getCoefs(); // child coefficients + for (int j = 0; j < ncoefs; j++) ccoefs[j] = pcoefs[j + cIdx * ncoefs]; + for (int j = ncoefs; j < ncoefs * nChildren; j++) ccoefs[j] = 0.0; // the remainder are set to zero + this->endNodeTable.push_back(cnode); // add to the list of nodes + cnode->setHasCoefs(); + cnode->calcNorms(); + } + } + node->mwTransform(Compression); + node->setHasCoefs(); + node->calcNorms(); + } + if (not node->isRootNode()) { + // and copy the new scaling parts into parent + MWNode *parent = &node->getMWParent(); + // check if parent exist already, and put in the list if not. + if (mp[parent->getSerialIx()] == 0) NodeTable[n - 1].push_back(parent); + int my_ix = -1; + // find index among siblings + for (int cIdx = 0; cIdx < nChildren; cIdx++) { + if (&parent->getMWChild(cIdx) == node) my_ix = cIdx; + } + if (my_ix < 0) std::cout << " DID NOT FIND INDEX" << std::endl; + T *ccoefs = node->getCoefs(); + T *pcoefs = parent->getCoefs(); + for (int j = 0; j < ncoefs; j++) pcoefs[j + my_ix * ncoefs] = ccoefs[j]; + mp[parent->getSerialIx()]++; + } + } else { + std::cout << " WARNING: found incomplete node " << std::endl; + } + } + } + this->calcSquareNorm(); +} + +/** @brief Write the tree to disk in text/ASCII format in a representation + * using MADNESS conventions for n, l and index order. 
+ * @param[in] file: File name + */ +template void FunctionTree::saveTreeTXT(const std::string &fname) { + int nRoots = this->getRootBox().size(); + MWNode **roots = this->getRootBox().getNodes(); + + std::ofstream out(fname); + out << std::setprecision(14); + out << D << std::endl; + int rscale = this->getRootScale(); + std::array sf = this->getMRA().getWorldBox().getScalingFactors(); + double LMRChem = 1.0; + for (int i = 0; i > rscale; i--) LMRChem *= 2; // we assume world is from -L to L, and a cube with 2 root nodes in each direction + for (int d = 0; d < D; d++) { out << -sf[d] * LMRChem << " " << sf[d] * LMRChem << std::endl; } + int kp1 = this->getKp1(); + out << kp1 << std::endl; + int ncoefs = 1; + for (int d = 0; d < D; d++) ncoefs *= kp1; + int Tdim = std::pow(2, D); + + int nout = this->endNodeTable.size(); + out << Tdim * nout << std::endl; // could output only scaling coeff? + + // MRChem and MADNESS do not use the same indices order for the qudrature points + // We write into MADNESS convention (note that mapMRC[mapMRC[i]]=i for all i) + std::vector mapMRC; // mapping vector + int kx = kp1 - 1; + int ky = kp1 - 1; + int kz = kp1 - 1; + if (D < 3) kz = 0; + if (D < 2) ky = 0; + // MADNESS: zyx and i=k,k-1,k-2... MRChem: xyz, i=0,1,2,3 ... 
+ for (int x = kx; x >= 0; x--) { + for (int y = ky; y >= 0; y--) { + for (int z = kz; z >= 0; z--) { mapMRC.push_back(z * kp1 * kp1 + y * kp1 + x); } + } + } + + int L = std::pow(2, -rscale); + int count = -1; + while (++count < nout) { + std::array l; + NodeIndex idx = this->endNodeTable[count]->getNodeIndex(); + MWNode *node = &(this->getNode(idx, false)); + T *values = node->getCoefs(); + int n = idx.getScale(); + node->mwTransform(Reconstruction); + node->cvTransform(Forward); + // we write for each children nodes separately + for (int i = 0; i < D; i++) { + // l in interval [0, max], while in MRCPP it is defined in [-max/2, max/2-1] + l[i] = 2 * (idx.getTranslation(i) + std::pow(2, n) * L); // first child + } + for (int cix = 0; cix < Tdim; cix++) { + out << n - rscale + 2 << " "; // scales start at zero. NB: children are one scale larger than node + for (int i = 0; i < D; i++) { + int p = (cix >> i) & 1; // shift by one for odd child indices + out << l[i] + p << " "; + } + out << std::endl; + for (int i = 0; i < ncoefs; i++) out << values[cix * ncoefs + mapMRC[i]] << " "; + out << std::endl; + } + } + out.close(); +} /** @brief Write the tree structure to disk, for later use * @param[in] file: File name, will get ".tree" extension */ -template void FunctionTree::saveTree(const std::string &file) { +template void FunctionTree::saveTree(const std::string &file) { Timer t1; this->deleteGenerated(); auto &allocator = this->getNodeAllocator(); std::stringstream fname; fname << file << ".tree"; - std::fstream f; f.open(fname.str(), std::ios::out | std::ios::binary); if (not f.is_open()) MSG_ERROR("Unable to open file"); @@ -130,6 +377,7 @@ template void FunctionTree::saveTree(const std::string &file) { f.write((char *)allocator.getCoefChunk(iChunk), allocator.getCoefChunkSize()); } f.close(); + this->saveTreeTXT("MRC.dat"); print::time(10, "Time write", t1); } @@ -137,8 +385,9 @@ template void FunctionTree::saveTree(const std::string &file) { * @param[in] file: 
File name, will get ".tree" extension * @note This tree must have the exact same MRA the one that was saved */ -template void FunctionTree::loadTree(const std::string &file) { +template void FunctionTree::loadTree(const std::string &file) { Timer t1; + std::stringstream fname; fname << file << ".tree"; @@ -164,15 +413,16 @@ template void FunctionTree::loadTree(const std::string &file) { Timer t2; allocator.reassemble(); this->resetEndNodeTable(); + this->calcSquareNorm(true); print::time(10, "Time rewrite pointers", t2); } /** @returns Integral of the function over the entire computational domain */ -template double FunctionTree::integrate() const { +template T FunctionTree::integrate() const { - double result = 0.0; + T result = 0.0; for (int i = 0; i < this->rootBox.size(); i++) { - const FunctionNode &fNode = getRootFuncNode(i); + const FunctionNode &fNode = getRootFuncNode(i); result += fNode.integrate(); } @@ -186,11 +436,10 @@ template double FunctionTree::integrate() const { return jacobian * result; } - /** @returns Integral of a representable function over the grid given by the tree */ -template <> double FunctionTree<3>::integrateEndNodes(RepresentableFunction_M &f) { - //traverse tree, and treat end nodes only - std::vector *> stack; // node from this +template <> double FunctionTree<3, double>::integrateEndNodes(RepresentableFunction_M &f) { + // traverse tree, and treat end nodes only + std::vector *> stack; // node from this for (int i = 0; i < this->getRootBox().size(); i++) stack.push_back(&(this->getRootFuncNode(i))); int basis = getMRA().getScalingBasis().getScalingType(); double result = 0.0; @@ -200,13 +449,13 @@ template <> double FunctionTree<3>::integrateEndNodes(RepresentableFunction_M &f if (Node->getNChildren() > 0) { for (int i = 0; i < Node->getNChildren(); i++) stack.push_back(&(Node->getFuncChild(i))); } else { - //end nodes + // end nodes Eigen::MatrixXd fmat = f.evalf(Node->nodeIndex); double *coefs = Node->getCoefs(); // save 
position of coeff, but do not use them! // The data in fmat is not organized so that two consecutive points are stored after each other in memory, so needs to copy before mwtransform, cannot use memory adress directly. - int nc=fmat.cols(); + int nc = fmat.cols(); double cc[nc]; - for (int i = 0; i < nc; i++)cc[i]=fmat(0,i); + for (int i = 0; i < nc; i++) cc[i] = fmat(0, i); Node->attachCoefs(cc); result += Node->integrateValues(); Node->attachCoefs(coefs); // put back original coeff @@ -236,7 +485,7 @@ template <> double FunctionTree<3>::integrateEndNodes(RepresentableFunction_M &f * the MW grid by one level before evaluating, using * `mrcpp::refine_grid(tree, 1)` */ -template double FunctionTree::evalf(const Coord &r) const { +template T FunctionTree::evalf(const Coord &r) const { // Handle potential scaling const auto scaling_factor = this->getMRA().getWorldBox().getScalingFactors(); auto arg = r; @@ -249,8 +498,8 @@ template double FunctionTree::evalf(const Coord &r) const { // Function is zero outside the domain for non-periodic functions if (this->outOfBounds(arg) and not this->getRootBox().isPeriodic()) return 0.0; - const MWNode &mw_node = this->getNodeOrEndNode(arg); - auto &f_node = static_cast &>(mw_node); + const MWNode &mw_node = this->getNodeOrEndNode(arg); + auto &f_node = static_cast &>(mw_node); auto result = f_node.evalScaling(arg); // Adjust for scaling factor included in basis @@ -270,7 +519,7 @@ template double FunctionTree::evalf(const Coord &r) const { * need fast evaluation, use refine_grid(tree, 1) first, and then * evalf. 
*/ -template double FunctionTree::evalf_precise(const Coord &r) { +template T FunctionTree::evalf_precise(const Coord &r) { // Handle potential scaling const auto scaling_factor = this->getMRA().getWorldBox().getScalingFactors(); auto arg = r; @@ -283,8 +532,8 @@ template double FunctionTree::evalf_precise(const Coord &r) { // Function is zero outside the domain for non-periodic functions if (this->outOfBounds(arg) and not this->getRootBox().isPeriodic()) return 0.0; - MWNode &mw_node = this->getNodeOrEndNode(arg); - auto &f_node = static_cast &>(mw_node); + MWNode &mw_node = this->getNodeOrEndNode(arg); + auto &f_node = static_cast &>(mw_node); auto result = f_node.evalf(arg); this->deleteGenerated(); @@ -301,7 +550,7 @@ template double FunctionTree::evalf_precise(const Coord &r) { * squared, no grid refinement. * */ -template void FunctionTree::square() { +template void FunctionTree::square() { if (this->getNGenNodes() != 0) MSG_ABORT("GenNodes not cleared"); #pragma omp parallel num_threads(mrcpp_get_num_threads()) @@ -310,10 +559,10 @@ template void FunctionTree::square() { int nCoefs = this->getTDim() * this->getKp1_d(); #pragma omp for schedule(guided) for (int n = 0; n < nNodes; n++) { - MWNode &node = *this->endNodeTable[n]; + MWNode &node = *this->endNodeTable[n]; node.mwTransform(Reconstruction); node.cvTransform(Forward); - double *coefs = node.getCoefs(); + T *coefs = node.getCoefs(); for (int i = 0; i < nCoefs; i++) { coefs[i] *= coefs[i]; } node.cvTransform(Backward); node.mwTransform(Compression); @@ -332,7 +581,7 @@ template void FunctionTree::square() { * to the given power, no grid refinement. 
* */ -template void FunctionTree::power(double p) { +template void FunctionTree::power(double p) { if (this->getNGenNodes() != 0) MSG_ABORT("GenNodes not cleared"); #pragma omp parallel num_threads(mrcpp_get_num_threads()) @@ -341,10 +590,10 @@ template void FunctionTree::power(double p) { int nCoefs = this->getTDim() * this->getKp1_d(); #pragma omp for schedule(guided) for (int n = 0; n < nNodes; n++) { - MWNode &node = *this->endNodeTable[n]; + MWNode &node = *this->endNodeTable[n]; node.mwTransform(Reconstruction); node.cvTransform(Forward); - double *coefs = node.getCoefs(); + T *coefs = node.getCoefs(); for (int i = 0; i < nCoefs; i++) { coefs[i] = std::pow(coefs[i], p); } node.cvTransform(Backward); node.mwTransform(Compression); @@ -363,7 +612,7 @@ template void FunctionTree::power(double p) { * in-place multiplied by the given coefficient, no grid refinement. * */ -template void FunctionTree::rescale(double c) { +template void FunctionTree::rescale(T c) { if (this->getNGenNodes() != 0) MSG_ABORT("GenNodes not cleared"); #pragma omp parallel firstprivate(c) num_threads(mrcpp_get_num_threads()) { @@ -371,9 +620,9 @@ template void FunctionTree::rescale(double c) { int nCoefs = this->getTDim() * this->getKp1_d(); #pragma omp for schedule(guided) for (int i = 0; i < nNodes; i++) { - MWNode &node = *this->endNodeTable[i]; + MWNode &node = *this->endNodeTable[i]; if (not node.hasCoefs()) MSG_ABORT("No coefs"); - double *coefs = node.getCoefs(); + T *coefs = node.getCoefs(); for (int j = 0; j < nCoefs; j++) { coefs[j] *= c; } node.calcNorms(); } @@ -383,7 +632,7 @@ template void FunctionTree::rescale(double c) { } /** @brief In-place rescaling by a function norm \f$ ||f||^{-1} \f$, fixed grid */ -template void FunctionTree::normalize() { +template void FunctionTree::normalize() { if (this->getNGenNodes() != 0) MSG_ABORT("GenNodes not cleared"); double sq_norm = this->getSquareNorm(); if (sq_norm < 0.0) MSG_ERROR("Normalizing uninitialized function"); @@ -399,7 
+648,7 @@ template void FunctionTree::normalize() { * the function, i.e. no further grid refinement. * */ -template void FunctionTree::add(double c, FunctionTree &inp) { +template void FunctionTree::add(T c, FunctionTree &inp) { if (this->getMRA() != inp.getMRA()) MSG_ABORT("Incompatible MRA"); if (this->getNGenNodes() != 0) MSG_ABORT("GenNodes not cleared"); #pragma omp parallel firstprivate(c) shared(inp) num_threads(mrcpp_get_num_threads()) @@ -407,10 +656,10 @@ template void FunctionTree::add(double c, FunctionTree &inp) { int nNodes = this->getNEndNodes(); #pragma omp for schedule(guided) for (int n = 0; n < nNodes; n++) { - MWNode &out_node = *this->endNodeTable[n]; - MWNode &inp_node = inp.getNode(out_node.getNodeIndex()); - double *out_coefs = out_node.getCoefs(); - const double *inp_coefs = inp_node.getCoefs(); + MWNode &out_node = *this->endNodeTable[n]; + MWNode &inp_node = inp.getNode(out_node.getNodeIndex()); + T *out_coefs = out_node.getCoefs(); + const T *inp_coefs = inp_node.getCoefs(); for (int i = 0; i < inp_node.getNCoefs(); i++) { out_coefs[i] += c * inp_coefs[i]; } out_node.calcNorms(); } @@ -419,6 +668,37 @@ template void FunctionTree::add(double c, FunctionTree &inp) { this->calcSquareNorm(); inp.deleteGenerated(); } +/** @brief In-place addition with MW function representations, fixed grid + * + * @param[in] c: Numerical coefficient of input function + * @param[in] inp: Input function to add + * + * @details The input function will be added to the union of the current grid of + * and input the function grid. 
+ * + */ +template void FunctionTree::add_inplace(T c, FunctionTree &inp) { + if (this->getMRA() != inp.getMRA()) MSG_ABORT("Incompatible MRA"); + if (this->getNGenNodes() != 0) MSG_ABORT("GenNodes not cleared"); + while (refine_grid(*this, inp)) {}; +#pragma omp parallel firstprivate(c) shared(inp) num_threads(mrcpp_get_num_threads()) + { + int nNodes = this->getNEndNodes(); +#pragma omp for schedule(guided) + for (int n = 0; n < nNodes; n++) { + MWNode &out_node = *this->endNodeTable[n]; + MWNode &inp_node = inp.getNode(out_node.getNodeIndex()); + T *out_coefs = out_node.getCoefs(); + const T *inp_coefs = inp_node.getCoefs(); + for (int i = 0; i < inp_node.getNCoefs(); i++) { out_coefs[i] += c * inp_coefs[i]; } + out_node.calcNorms(); + } + } + this->mwTransform(BottomUp); + this->calcSquareNorm(); + inp.deleteGenerated(); +} + /** @brief In-place addition of absolute values of MW function representations * * @param[in] c Numerical coefficient of input function @@ -428,22 +708,22 @@ template void FunctionTree::add(double c, FunctionTree &inp) { * function, i.e. no further grid refinement. 
* */ -template void FunctionTree::absadd(double c, FunctionTree &inp) { +template void FunctionTree::absadd(T c, FunctionTree &inp) { if (this->getNGenNodes() != 0) MSG_ABORT("GenNodes not cleared"); #pragma omp parallel firstprivate(c) shared(inp) num_threads(mrcpp_get_num_threads()) { int nNodes = this->getNEndNodes(); #pragma omp for schedule(guided) for (int n = 0; n < nNodes; n++) { - MWNode &out_node = *this->endNodeTable[n]; - MWNode inp_node = inp.getNode(out_node.getNodeIndex()); // Full copy + MWNode &out_node = *this->endNodeTable[n]; + MWNode inp_node = inp.getNode(out_node.getNodeIndex()); // Full copy out_node.mwTransform(Reconstruction); out_node.cvTransform(Forward); inp_node.mwTransform(Reconstruction); inp_node.cvTransform(Forward); - double *out_coefs = out_node.getCoefs(); - const double *inp_coefs = inp_node.getCoefs(); - for (int i = 0; i < inp_node.getNCoefs(); i++) { out_coefs[i] = abs(out_coefs[i]) + c * abs(inp_coefs[i]); } + T *out_coefs = out_node.getCoefs(); + const T *inp_coefs = inp_node.getCoefs(); + for (int i = 0; i < inp_node.getNCoefs(); i++) { out_coefs[i] = std::norm(out_coefs[i]) + std::norm(c * inp_coefs[i]); } out_node.cvTransform(Backward); out_node.mwTransform(Compression); out_node.calcNorms(); @@ -463,7 +743,7 @@ template void FunctionTree::absadd(double c, FunctionTree &inp) { * of the function, i.e. no further grid refinement. 
* */ -template void FunctionTree::multiply(double c, FunctionTree &inp) { +template void FunctionTree::multiply(T c, FunctionTree &inp) { if (this->getMRA() != inp.getMRA()) MSG_ABORT("Incompatible MRA"); if (this->getNGenNodes() != 0) MSG_ABORT("GenNodes not cleared"); #pragma omp parallel firstprivate(c) shared(inp) num_threads(mrcpp_get_num_threads()) @@ -471,14 +751,14 @@ template void FunctionTree::multiply(double c, FunctionTree &inp) int nNodes = this->getNEndNodes(); #pragma omp for schedule(guided) for (int n = 0; n < nNodes; n++) { - MWNode &out_node = *this->endNodeTable[n]; - MWNode inp_node = inp.getNode(out_node.getNodeIndex()); // Full copy + MWNode &out_node = *this->endNodeTable[n]; + MWNode inp_node = inp.getNode(out_node.getNodeIndex()); // Full copy out_node.mwTransform(Reconstruction); out_node.cvTransform(Forward); inp_node.mwTransform(Reconstruction); inp_node.cvTransform(Forward); - double *out_coefs = out_node.getCoefs(); - const double *inp_coefs = inp_node.getCoefs(); + T *out_coefs = out_node.getCoefs(); + const T *inp_coefs = inp_node.getCoefs(); for (int i = 0; i < inp_node.getNCoefs(); i++) { out_coefs[i] *= c * inp_coefs[i]; } out_node.cvTransform(Backward); out_node.mwTransform(Compression); @@ -498,16 +778,16 @@ template void FunctionTree::multiply(double c, FunctionTree &inp) * of the function, i.e. no further grid refinement. 
* */ -template void FunctionTree::map(FMap fmap) { +template void FunctionTree::map(FMap fmap) { if (this->getNGenNodes() != 0) MSG_ABORT("GenNodes not cleared"); { int nNodes = this->getNEndNodes(); #pragma omp parallel for schedule(guided) num_threads(mrcpp_get_num_threads()) for (int n = 0; n < nNodes; n++) { - MWNode &node = *this->endNodeTable[n]; + MWNode &node = *this->endNodeTable[n]; node.mwTransform(Reconstruction); node.cvTransform(Forward); - double *coefs = node.getCoefs(); + T *coefs = node.getCoefs(); for (int i = 0; i < node.getNCoefs(); i++) { coefs[i] = fmap(coefs[i]); } node.cvTransform(Backward); node.mwTransform(Compression); @@ -518,29 +798,29 @@ template void FunctionTree::map(FMap fmap) { this->calcSquareNorm(); } -template void FunctionTree::getEndValues(VectorXd &data) { +template void FunctionTree::getEndValues(Eigen::Matrix &data) { if (this->getNGenNodes() != 0) MSG_ABORT("GenNodes not cleared"); int nNodes = this->getNEndNodes(); int nCoefs = this->getTDim() * this->getKp1_d(); data = VectorXd::Zero(nNodes * nCoefs); for (int n = 0; n < nNodes; n++) { - MWNode &node = getEndFuncNode(n); + MWNode &node = getEndFuncNode(n); node.mwTransform(Reconstruction); node.cvTransform(Forward); - const double *c = node.getCoefs(); + const T *c = node.getCoefs(); for (int i = 0; i < nCoefs; i++) { data(n * nCoefs + i) = c[i]; } node.cvTransform(Backward); node.mwTransform(Compression); } } -template void FunctionTree::setEndValues(VectorXd &data) { +template void FunctionTree::setEndValues(Eigen::Matrix &data) { if (this->getNGenNodes() != 0) MSG_ABORT("GenNodes not cleared"); int nNodes = this->getNEndNodes(); int nCoefs = this->getTDim() * this->getKp1_d(); for (int i = 0; i < nNodes; i++) { - MWNode &node = getEndFuncNode(i); - const double *c = data.segment(i * nCoefs, nCoefs).data(); + MWNode &node = getEndFuncNode(i); + const T *c = data.segment(i * nCoefs, nCoefs).data(); node.setCoefBlock(0, nCoefs, c); node.cvTransform(Backward); 
node.mwTransform(Compression); @@ -551,10 +831,10 @@ template void FunctionTree::setEndValues(VectorXd &data) { this->calcSquareNorm(); } -template std::ostream &FunctionTree::print(std::ostream &o) const { +template std::ostream &FunctionTree::print(std::ostream &o) const { o << std::endl << "*FunctionTree: " << this->name << std::endl; o << " genNodes: " << getNGenNodes() << std::endl; - return MWTree::print(o); + return MWTree::print(o); } /** @brief Reduce the precision of the tree by deleting nodes @@ -571,9 +851,9 @@ template std::ostream &FunctionTree::print(std::ostream &o) const { * \f$ ||w|| < 2^{-sn/2} ||f|| \epsilon \f$. In principal, `s` should be equal * to the dimension; in practice, it is set to `s=1`. */ -template int FunctionTree::crop(double prec, double splitFac, bool absPrec) { +template int FunctionTree::crop(double prec, double splitFac, bool absPrec) { for (int i = 0; i < this->rootBox.size(); i++) { - MWNode &root = this->getRootMWNode(i); + MWNode &root = this->getRootMWNode(i); root.crop(prec, splitFac, absPrec); } int nChunks = this->getNodeAllocator().compress(); @@ -586,22 +866,22 @@ template int FunctionTree::crop(double prec, double splitFac, bool ab * Also returns an array with the corresponding indices defined as the * values of serialIx in refTree, and an array with the indices of the parent. 
* Set index -1 for nodes that are not present in refTree */ -template -void FunctionTree::makeCoeffVector(std::vector &coefs, - std::vector &indices, - std::vector &parent_indices, - std::vector &scalefac, - int &max_index, - MWTree &refTree, - std::vector *> *refNodes) { +template +void FunctionTree::makeCoeffVector(std::vector &coefs, + std::vector &indices, + std::vector &parent_indices, + std::vector &scalefac, + int &max_index, + MWTree &refTree, + std::vector *> *refNodes) { coefs.clear(); indices.clear(); parent_indices.clear(); max_index = 0; int sizecoeff = (1 << refTree.getDim()) * refTree.getKp1_d(); int sizecoeffW = ((1 << refTree.getDim()) - 1) * refTree.getKp1_d(); - std::vector *> refstack; // nodes from refTree - std::vector *> thisstack; // nodes from this Tree + std::vector *> refstack; // nodes from refTree + std::vector *> thisstack; // nodes from this Tree for (int rIdx = 0; rIdx < this->getRootBox().size(); rIdx++) { refstack.push_back(refTree.getRootBox().getNodes()[rIdx]); thisstack.push_back(this->getRootBox().getNodes()[rIdx]); @@ -609,8 +889,8 @@ void FunctionTree::makeCoeffVector(std::vector &coefs, int stack_p = 0; while (thisstack.size() > stack_p) { // refNode and thisNode are the same node in space, but on different trees - MWNode *thisNode = thisstack[stack_p]; - MWNode *refNode = refstack[stack_p++]; + MWNode *thisNode = thisstack[stack_p]; + MWNode *refNode = refstack[stack_p++]; coefs.push_back(thisNode->getCoefs()); if (refNodes != nullptr) refNodes->push_back(refNode); if (refNode != nullptr) { @@ -640,26 +920,26 @@ void FunctionTree::makeCoeffVector(std::vector &coefs, * reference tree and a list of coefficients. * It is the reference tree (refTree) which is traversed, but one does not descend * into children if the norm of the tree is smaller than absPrec. 
*/ -template void FunctionTree::makeTreefromCoeff(MWTree &refTree, std::vector coefpVec, std::map &ix2coef, double absPrec, const std::string &mode) { - std::vector *> stack; - std::map *> ix2node; // gives the nodes in this tree for a given ix +template void FunctionTree::makeTreefromCoeff(MWTree &refTree, std::vector coefpVec, std::map &ix2coef, double absPrec, const std::string &mode) { + std::vector *> stack; + std::map *> ix2node; // gives the nodes in this tree for a given ix int sizecoef = (1 << this->getDim()) * this->getKp1_d(); int sizecoefW = ((1 << this->getDim()) - 1) * this->getKp1_d(); this->squareNorm = 0.0; this->clearEndNodeTable(); for (int rIdx = 0; rIdx < refTree.getRootBox().size(); rIdx++) { - MWNode *refNode = refTree.getRootBox().getNodes()[rIdx]; + MWNode *refNode = refTree.getRootBox().getNodes()[rIdx]; stack.push_back(refNode); int ix = ix2coef[refNode->getSerialIx()]; ix2node[ix] = this->getRootBox().getNodes()[rIdx]; } while (stack.size() > 0) { - MWNode *refNode = stack.back(); // node in the reference tree refTree + MWNode *refNode = stack.back(); // node in the reference tree refTree stack.pop_back(); assert(ix2coef.count(refNode->getSerialIx()) > 0); int ix = ix2coef[refNode->getSerialIx()]; - MWNode *node = ix2node[ix]; // corresponding node in this tree + MWNode *node = ix2node[ix]; // corresponding node in this tree // copy coefficients into this tree int size = sizecoefW; if (refNode->isRootNode() or mode == "copy") { @@ -701,8 +981,8 @@ template void FunctionTree::makeTreefromCoeff(MWTree &refTree, std } else if ((absPrec < 0 or tree_utils::split_check(*node, absPrec, 1.0, true)) and refNode->getNChildren() > 0) { // include children in tree node->createChildren(true); - double *inp = node->getCoefs(); - double *out = node->getMWChild(0).getCoefs(); + T *inp = node->getCoefs(); + T *out = node->getMWChild(0).getCoefs(); tree_utils::mw_transform(*this, inp, out, false, sizecoef, true); // make the scaling part for (int i = 0; i 
< refNode->getNChildren(); i++) { stack.push_back(refNode->children[i]); // means we continue to traverse the reference tree @@ -716,10 +996,51 @@ template void FunctionTree::makeTreefromCoeff(MWTree &refTree, std } } +/** Traverse tree using DFS and append same nodes as another tree, without coefficients + * Note that we do not use coefficients, so it does not matter what is real or complex + */ +template void FunctionTree::appendTreeNoCoeff(MWTree &inTree) { + std::vector *> instack; // node from inTree + std::vector *> thisstack; // node from this Tree + this->clearEndNodeTable(); + for (int rIdx = 0; rIdx < inTree.getRootBox().size(); rIdx++) { + instack.push_back(inTree.getRootBox().getNodes()[rIdx]); + thisstack.push_back(this->getRootBox().getNodes()[rIdx]); + } + while (thisstack.size() > 0) { + // inNode and thisNode are the same node in space, but on different trees + MWNode *thisNode = thisstack.back(); + thisstack.pop_back(); + MWNode *inNode = instack.back(); + instack.pop_back(); + if (inNode->getNChildren() > 0) { + thisNode->clearIsEndNode(); + if (thisNode->getNChildren() < inNode->getNChildren()) thisNode->createChildren(false); + for (int i = 0; i < inNode->getNChildren(); i++) { + instack.push_back(inNode->children[i]); + thisstack.push_back(thisNode->children[i]); + } + } else { + // construct EndNodeTable for "This", starting from this branch + // This could be done more efficiently, if it proves to be time consuming + std::vector *> branchstack; // local stack starting from this branch + branchstack.push_back(thisNode); + while (branchstack.size() > 0) { + MWNode *branchNode = branchstack.back(); + branchstack.pop_back(); + if (branchNode->getNChildren() > 0) { + for (int i = 0; i < branchNode->getNChildren(); i++) { branchstack.push_back(branchNode->children[i]); } + } else + this->endNodeTable.push_back(branchNode); + } + } + } +} + /** Traverse tree using DFS and append same nodes as another tree, without coefficients */ -template void 
FunctionTree::appendTreeNoCoeff(MWTree &inTree) { - std::vector *> instack; // node from inTree - std::vector *> thisstack; // node from this Tree +template void FunctionTree::appendTreeNoCoeff(MWTree &inTree) { + std::vector *> instack; // node from inTree + std::vector *> thisstack; // node from this Tree this->clearEndNodeTable(); for (int rIdx = 0; rIdx < inTree.getRootBox().size(); rIdx++) { instack.push_back(inTree.getRootBox().getNodes()[rIdx]); @@ -727,9 +1048,9 @@ template void FunctionTree::appendTreeNoCoeff(MWTree &inTree) { } while (thisstack.size() > 0) { // inNode and thisNode are the same node in space, but on different trees - MWNode *thisNode = thisstack.back(); + MWNode *thisNode = thisstack.back(); thisstack.pop_back(); - MWNode *inNode = instack.back(); + MWNode *inNode = instack.back(); instack.pop_back(); if (inNode->getNChildren() > 0) { thisNode->clearIsEndNode(); @@ -741,10 +1062,10 @@ template void FunctionTree::appendTreeNoCoeff(MWTree &inTree) { } else { // construct EndNodeTable for "This", starting from this branch // This could be done more efficiently, if it proves to be time consuming - std::vector *> branchstack; // local stack starting from this branch + std::vector *> branchstack; // local stack starting from this branch branchstack.push_back(thisNode); while (branchstack.size() > 0) { - MWNode *branchNode = branchstack.back(); + MWNode *branchNode = branchstack.back(); branchstack.pop_back(); if (branchNode->getNChildren() > 0) { for (int i = 0; i < branchNode->getNChildren(); i++) { branchstack.push_back(branchNode->children[i]); } @@ -755,24 +1076,24 @@ template void FunctionTree::appendTreeNoCoeff(MWTree &inTree) { } } -template void FunctionTree::deleteGenerated() { +template void FunctionTree::deleteGenerated() { for (int n = 0; n < this->getNEndNodes(); n++) this->getEndMWNode(n).deleteGenerated(); } -template void FunctionTree::deleteGeneratedParents() { +template void FunctionTree::deleteGeneratedParents() { for (int n = 
0; n < this->getRootBox().size(); n++) this->getRootMWNode(n).deleteParent(); } -template <> int FunctionTree<3>::saveNodesAndRmCoeff() { +template <> int FunctionTree<3, double>::saveNodesAndRmCoeff() { if (this->isLocal) MSG_INFO("Tree is already in local representation"); NodesCoeff = new BankAccount; // NB: must be a collective call! int stack_p = 0; if (mpi::wrk_rank == 0) { int sizecoeff = (1 << 3) * this->getKp1_d(); - std::vector *> stack; // nodes from this Tree + std::vector *> stack; // nodes from this Tree for (int rIdx = 0; rIdx < this->getRootBox().size(); rIdx++) { stack.push_back(this->getRootBox().getNodes()[rIdx]); } while (stack.size() > stack_p) { - MWNode<3> *Node = stack[stack_p++]; + MWNode<3, double> *Node = stack[stack_p++]; int id = 0; NodesCoeff->put_data(Node->getNodeIndex(), sizecoeff, Node->getCoefs()); for (int i = 0; i < Node->getNChildren(); i++) { stack.push_back(Node->children[i]); } @@ -785,8 +1106,254 @@ template <> int FunctionTree<3>::saveNodesAndRmCoeff() { return this->NodeIndex2serialIx.size(); } -template class FunctionTree<1>; -template class FunctionTree<2>; -template class FunctionTree<3>; +template <> int FunctionTree<3, ComplexDouble>::saveNodesAndRmCoeff() { + if (this->isLocal) MSG_INFO("Tree is already in local representation"); + NodesCoeff = new BankAccount; // NB: must be a collective call! + int stack_p = 0; + if (mpi::wrk_rank == 0) { + int sizecoeff = (1 << 3) * this->getKp1_d(); + sizecoeff *= 2; // double->ComplexDouble. 
Saved as twice as many doubles + std::vector *> stack; // nodes from this Tree + for (int rIdx = 0; rIdx < this->getRootBox().size(); rIdx++) { stack.push_back(this->getRootBox().getNodes()[rIdx]); } + while (stack.size() > stack_p) { + MWNode<3, ComplexDouble> *Node = stack[stack_p++]; + int id = 0; + NodesCoeff->put_data(Node->getNodeIndex(), sizecoeff, Node->getCoefs()); + for (int i = 0; i < Node->getNChildren(); i++) { stack.push_back(Node->children[i]); } + } + } + this->nodeAllocator_p->deallocAllCoeff(); + mpi::broadcast_Tree_noCoeff(*this, mpi::comm_wrk); + this->isLocal = true; + assert(this->NodeIndex2serialIx.size() == getNNodes()); + return this->NodeIndex2serialIx.size(); +} + +/** @brief Deep copy of tree + * + * @details Exact copy without any binding between old and new tree + */ +template void FunctionTree::deep_copy(FunctionTree *out) { + copy_grid(*out, *this); + copy_func(*out, *this); +} + +/** @brief New tree with only real part + */ +template FunctionTree *FunctionTree::Real() { + FunctionTree *out = new FunctionTree(this->getMRA(), this->getName()); +#pragma omp parallel num_threads(mrcpp_get_num_threads()) + { + int nNodes = this->getNEndNodes(); +#pragma omp for schedule(guided) + for (int n = 0; n < nNodes; n++) { + MWNode &inp_node = *this->endNodeTable[n]; + MWNode out_node = out->getNode(out_node.getNodeIndex()); // Full copy + double *out_coefs = out_node.getCoefs(); + const T *inp_coefs = inp_node.getCoefs(); + for (int i = 0; i < inp_node.getNCoefs(); i++) { out_coefs[i] = std::real(inp_coefs[i]); } + out_node.calcNorms(); + } + } + out->mwTransform(BottomUp); + out->calcSquareNorm(); + return out; +} + +/** @brief New tree with only imaginary part + */ +template FunctionTree *FunctionTree::Imag() { + FunctionTree *out = new FunctionTree(this->getMRA(), this->getName()); +#pragma omp parallel num_threads(mrcpp_get_num_threads()) + { + int nNodes = this->getNEndNodes(); +#pragma omp for schedule(guided) + for (int n = 0; n < nNodes; 
n++) { + MWNode &inp_node = *this->endNodeTable[n]; + MWNode out_node = out->getNode(out_node.getNodeIndex()); // Full copy + double *out_coefs = out_node.getCoefs(); + const T *inp_coefs = inp_node.getCoefs(); + for (int i = 0; i < inp_node.getNCoefs(); i++) { out_coefs[i] = std::imag(inp_coefs[i]); } + out_node.calcNorms(); + } + } + out->mwTransform(BottomUp); + out->calcSquareNorm(); + return out; +} + +/* + * From real to complex tree. Copy everything, and convert double to ComplexDouble for the coefficents. + * Should use a deep_copy if generalized in the future. + */ + +template <> void FunctionTree<3, double>::CopyTreeToComplex(FunctionTree<3, ComplexDouble> *&outTree) { + delete outTree; + double ref = 0.0; + outTree = new FunctionTree<3, ComplexDouble>(this->getMRA()); + std::vector *> instack; // node from this + std::vector *> outstack; // node from outTree + outTree->clearEndNodeTable(); + for (int rIdx = 0; rIdx < this->getRootBox().size(); rIdx++) { + instack.push_back(this->getRootBox().getNodes()[rIdx]); + outstack.push_back(outTree->getRootBox().getNodes()[rIdx]); + } + int nNodes = std::min(this->getNNodes(), this->getNodeAllocator().getMaxNodesPerChunk()); + int ncoefs = this->getNodeAllocator().getNCoefs(); + while (instack.size() > 0) { + // inNode and outNode are the same node in space, but on different trees + MWNode<3, ComplexDouble> *outNode = outstack.back(); + outstack.pop_back(); + MWNode<3, double> *inNode = instack.back(); + instack.pop_back(); + // copy coefficients: + double *incoefs = inNode->getCoefs(); + ComplexDouble *outcoefs = outNode->getCoefs(); + for (int i = 0; i < ncoefs; i++) outcoefs[i] = incoefs[i]; + outNode->setHasCoefs(); + outNode->calcNorms(); + + if (inNode->getNChildren() > 0) { + if (outNode->getNChildren() < inNode->getNChildren()) outNode->createChildren(true); + for (int i = 0; i < inNode->getNChildren(); i++) { + instack.push_back(inNode->children[i]); + outstack.push_back(outNode->children[i]); + } + } 
else { + outTree->endNodeTable.push_back(outNode); + } + } + outTree->calcSquareNorm(); + outTree->calcSquareNorm(true); +} + +template <> void FunctionTree<2, double>::CopyTreeToComplex(FunctionTree<2, ComplexDouble> *&outTree) { + delete outTree; + double ref = 0.0; + outTree = new FunctionTree<2, ComplexDouble>(this->getMRA()); + std::vector *> instack; // node from this + std::vector *> outstack; // node from outTree + outTree->clearEndNodeTable(); + for (int rIdx = 0; rIdx < this->getRootBox().size(); rIdx++) { + instack.push_back(this->getRootBox().getNodes()[rIdx]); + outstack.push_back(outTree->getRootBox().getNodes()[rIdx]); + } + int nNodes = std::min(this->getNNodes(), this->getNodeAllocator().getMaxNodesPerChunk()); + int ncoefs = this->getNodeAllocator().getNCoefs(); + while (instack.size() > 0) { + // inNode and outNode are the same node in space, but on different trees + MWNode<2, ComplexDouble> *outNode = outstack.back(); + outstack.pop_back(); + MWNode<2, double> *inNode = instack.back(); + instack.pop_back(); + // copy coefficients: + double *incoefs = inNode->getCoefs(); + ComplexDouble *outcoefs = outNode->getCoefs(); + for (int i = 0; i < ncoefs; i++) outcoefs[i] = incoefs[i]; + outNode->setHasCoefs(); + outNode->calcNorms(); + + if (inNode->getNChildren() > 0) { + if (outNode->getNChildren() < inNode->getNChildren()) outNode->createChildren(true); + for (int i = 0; i < inNode->getNChildren(); i++) { + instack.push_back(inNode->children[i]); + outstack.push_back(outNode->children[i]); + } + } else { + outTree->endNodeTable.push_back(outNode); + } + } + outTree->calcSquareNorm(); + outTree->calcSquareNorm(true); +} + +template <> void FunctionTree<1, double>::CopyTreeToComplex(FunctionTree<1, ComplexDouble> *&outTree) { + delete outTree; + double ref = 0.0; + outTree = new FunctionTree<1, ComplexDouble>(this->getMRA()); + std::vector *> instack; // node from this + std::vector *> outstack; // node from outTree + outTree->clearEndNodeTable(); + 
for (int rIdx = 0; rIdx < this->getRootBox().size(); rIdx++) { + instack.push_back(this->getRootBox().getNodes()[rIdx]); + outstack.push_back(outTree->getRootBox().getNodes()[rIdx]); + } + int nNodes = std::min(this->getNNodes(), this->getNodeAllocator().getMaxNodesPerChunk()); + int ncoefs = this->getNodeAllocator().getNCoefs(); + while (instack.size() > 0) { + // inNode and outNode are the same node in space, but on different trees + MWNode<1, ComplexDouble> *outNode = outstack.back(); + outstack.pop_back(); + MWNode<1, double> *inNode = instack.back(); + instack.pop_back(); + // copy coefficients: + double *incoefs = inNode->getCoefs(); + ComplexDouble *outcoefs = outNode->getCoefs(); + for (int i = 0; i < ncoefs; i++) outcoefs[i] = incoefs[i]; + outNode->setHasCoefs(); + outNode->calcNorms(); + + if (inNode->getNChildren() > 0) { + if (outNode->getNChildren() < inNode->getNChildren()) outNode->createChildren(true); + for (int i = 0; i < inNode->getNChildren(); i++) { + instack.push_back(inNode->children[i]); + outstack.push_back(outNode->children[i]); + } + } else { + outTree->endNodeTable.push_back(outNode); + } + } + outTree->calcSquareNorm(); + outTree->calcSquareNorm(true); +} + +// for testing +template <> void FunctionTree<3, double>::CopyTreeToReal(FunctionTree<3, double> *&outTree) { + delete outTree; + double ref = 0.0; + // FunctionTree<3, double>* inTree = this; + outTree = new FunctionTree<3, double>(this->getMRA()); + std::vector *> instack; // node from this + std::vector *> outstack; // node from outTree + outTree->clearEndNodeTable(); + for (int rIdx = 0; rIdx < this->getRootBox().size(); rIdx++) { + instack.push_back(this->getRootBox().getNodes()[rIdx]); + outstack.push_back(outTree->getRootBox().getNodes()[rIdx]); + } + int nNodes = std::min(this->getNNodes(), this->getNodeAllocator().getMaxNodesPerChunk()); + int ncoefs = this->getNodeAllocator().getNCoefs(); + while (instack.size() > 0) { + // inNode and outNode are the same node in space, 
but on different trees + MWNode<3, double> *outNode = outstack.back(); + outstack.pop_back(); + MWNode<3, double> *inNode = instack.back(); + instack.pop_back(); + // copy coefficients: + double *incoefs = inNode->getCoefs(); + double *outcoefs = outNode->getCoefs(); + for (int i = 0; i < ncoefs; i++) outcoefs[i] = incoefs[i]; + outNode->setHasCoefs(); + outNode->calcNorms(); + + if (inNode->getNChildren() > 0) { + outNode->clearIsEndNode(); + if (outNode->getNChildren() < inNode->getNChildren()) outNode->createChildren(true); + for (int i = 0; i < inNode->getNChildren(); i++) { + instack.push_back(inNode->children[i]); + outstack.push_back(outNode->children[i]); + } + } else { + outTree->endNodeTable.push_back(outNode); + } + } +} + +template class FunctionTree<1, double>; +template class FunctionTree<2, double>; +template class FunctionTree<3, double>; + +template class FunctionTree<1, ComplexDouble>; +template class FunctionTree<2, ComplexDouble>; +template class FunctionTree<3, ComplexDouble>; } // namespace mrcpp diff --git a/src/trees/FunctionTree.h b/src/trees/FunctionTree.h index 0be9563ea..9d976d6be 100644 --- a/src/trees/FunctionTree.h +++ b/src/trees/FunctionTree.h @@ -52,69 +52,81 @@ namespace mrcpp { * uninitialized, and its square norm will be negative (minus one). 
*/ -template class FunctionTree final : public MWTree, public RepresentableFunction { +template class FunctionTree final : public MWTree, public RepresentableFunction { public: FunctionTree(const MultiResolutionAnalysis &mra, const std::string &name) : FunctionTree(mra, nullptr, name) {} - FunctionTree(const MultiResolutionAnalysis &mra, SharedMemory *sh_mem = nullptr, const std::string &name = "nn"); - FunctionTree(const FunctionTree &tree) = delete; - FunctionTree &operator=(const FunctionTree &tree) = delete; + FunctionTree(const MultiResolutionAnalysis &mra, SharedMemory *sh_mem = nullptr, const std::string &name = "nn"); + FunctionTree(const FunctionTree &tree) = delete; + FunctionTree &operator=(const FunctionTree &tree) = delete; ~FunctionTree() override; - double integrate() const; + T integrate() const; double integrateEndNodes(RepresentableFunction_M &f); - double evalf_precise(const Coord &r); - double evalf(const Coord &r) const override; + T evalf_precise(const Coord &r); + T evalf(const Coord &r) const override; int getNGenNodes() const { return getGenNodeAllocator().getNNodes(); } - void getEndValues(Eigen::VectorXd &data); - void setEndValues(Eigen::VectorXd &data); + void getEndValues(Eigen::Matrix &data); + void setEndValues(Eigen::Matrix &data); void saveTree(const std::string &file); + void saveTreeTXT(const std::string &file); void loadTree(const std::string &file); + void loadTreeTXT(const std::string &file); // In place operations void square(); void power(double p); - void rescale(double c); + void rescale(T c); void normalize(); - void add(double c, FunctionTree &inp); - void absadd(double c, FunctionTree &inp); - void multiply(double c, FunctionTree &inp); - void map(FMap fmap); + void add(T c, FunctionTree &inp); + void add_inplace(T c, FunctionTree &inp); + void absadd(T c, FunctionTree &inp); + void multiply(T c, FunctionTree &inp); + void map(FMap fmap); int getNChunks() { return this->getNodeAllocator().getNChunks(); } int 
getNChunksUsed() { return this->getNodeAllocator().getNChunksUsed(); } int crop(double prec, double splitFac = 1.0, bool absPrec = true); - FunctionNode &getEndFuncNode(int i) { return static_cast &>(this->getEndMWNode(i)); } - FunctionNode &getRootFuncNode(int i) { return static_cast &>(this->rootBox.getNode(i)); } + FunctionNode &getEndFuncNode(int i) { return static_cast &>(this->getEndMWNode(i)); } + FunctionNode &getRootFuncNode(int i) { return static_cast &>(this->rootBox.getNode(i)); } - NodeAllocator &getGenNodeAllocator() { return *this->genNodeAllocator_p; } - const NodeAllocator &getGenNodeAllocator() const { return *this->genNodeAllocator_p; } + NodeAllocator &getGenNodeAllocator() { return *this->genNodeAllocator_p; } + const NodeAllocator &getGenNodeAllocator() const { return *this->genNodeAllocator_p; } - const FunctionNode &getEndFuncNode(int i) const { return static_cast &>(this->getEndMWNode(i)); } - const FunctionNode &getRootFuncNode(int i) const { return static_cast &>(this->rootBox.getNode(i)); } + const FunctionNode &getEndFuncNode(int i) const { return static_cast &>(this->getEndMWNode(i)); } + const FunctionNode &getRootFuncNode(int i) const { return static_cast &>(this->rootBox.getNode(i)); } void deleteGenerated(); void deleteGeneratedParents(); - void makeCoeffVector(std::vector &coefs, + void makeCoeffVector(std::vector &coefs, std::vector &indices, std::vector &parent_indices, std::vector &scalefac, int &max_index, - MWTree &refTree, - std::vector *> *refNodes = nullptr); - void makeTreefromCoeff(MWTree &refTree, std::vector coefpVec, std::map &ix2coef, double absPrec, const std::string &mode = "adaptive"); - void appendTreeNoCoeff(MWTree &inTree); - + MWTree &refTree, + std::vector *> *refNodes = nullptr); + void makeTreefromCoeff(MWTree &refTree, std::vector coefpVec, std::map &ix2coef, double absPrec, const std::string &mode = "adaptive"); + void appendTreeNoCoeff(MWTree &inTree); + void appendTreeNoCoeff(MWTree &inTree); + void 
CopyTree(FunctionTree &inTree); // tools for use of local (nodes are stored in Bank) representation int saveNodesAndRmCoeff(); // put all nodes coefficients in Bank and delete all coefficients + void deep_copy(FunctionTree *out); + FunctionTree *Real(); + FunctionTree *Imag(); + void CopyTreeToComplex(FunctionTree<3, ComplexDouble> *&out); + void CopyTreeToComplex(FunctionTree<2, ComplexDouble> *&out); + void CopyTreeToComplex(FunctionTree<1, ComplexDouble> *&out); + void CopyTreeToReal(FunctionTree<3, double> *&out); // for testing + protected: - std::unique_ptr> genNodeAllocator_p{nullptr}; + std::unique_ptr> genNodeAllocator_p{nullptr}; std::ostream &print(std::ostream &o) const override; void allocRootNodes(); diff --git a/src/trees/FunctionTreeVector.h b/src/trees/FunctionTreeVector.h index d73005cd8..142113e1f 100644 --- a/src/trees/FunctionTreeVector.h +++ b/src/trees/FunctionTreeVector.h @@ -32,18 +32,19 @@ namespace mrcpp { -template using CoefsFunctionTree = std::tuple *>; -template using FunctionTreeVector = std::vector>; +template using CoefsFunctionTree = std::tuple *>; +template using FunctionTreeVector = std::vector>; /** @brief Remove all entries in the vector * @param[in] fs: Vector to clear * @param[in] dealloc: Option to free FunctionTree pointer before clearing */ -template void clear(FunctionTreeVector &fs, bool dealloc = false) { +template void clear(FunctionTreeVector &fs, bool dealloc = false) { if (dealloc) { for (auto &t : fs) { auto f = std::get<1>(t); if (f != nullptr) delete f; + f = nullptr; } } fs.clear(); @@ -52,7 +53,7 @@ template void clear(FunctionTreeVector &fs, bool dealloc = false) { /** @returns Total number of nodes of all trees in the vector * @param[in] fs: Vector to fetch from */ -template int get_n_nodes(const FunctionTreeVector &fs) { +template int get_n_nodes(const FunctionTreeVector &fs) { int nNodes = 0; for (const auto &t : fs) { auto f = std::get<1>(t); @@ -64,7 +65,7 @@ template int get_n_nodes(const 
FunctionTreeVector &fs) { /** @returns Total size of all trees in the vector, in kB * @param[in] fs: Vector to fetch from */ -template int get_size_nodes(const FunctionTreeVector &fs) { +template int get_size_nodes(const FunctionTreeVector &fs) { int sNodes = 0; for (const auto &t : fs) { auto f = std::get<1>(t); @@ -77,7 +78,7 @@ template int get_size_nodes(const FunctionTreeVector &fs) { * @param[in] fs: Vector to fetch from * @param[in] i: Position in vector */ -template double get_coef(const FunctionTreeVector &fs, int i) { +template T get_coef(const FunctionTreeVector &fs, int i) { return std::get<0>(fs[i]); } @@ -85,7 +86,7 @@ template double get_coef(const FunctionTreeVector &fs, int i) { * @param[in] fs: Vector to fetch from * @param[in] i: Position in vector */ -template FunctionTree &get_func(FunctionTreeVector &fs, int i) { +template FunctionTree &get_func(FunctionTreeVector &fs, int i) { return *(std::get<1>(fs[i])); } @@ -93,7 +94,7 @@ template FunctionTree &get_func(FunctionTreeVector &fs, int i) { * @param[in] fs: Vector to fetch from * @param[in] i: Position in vector */ -template const FunctionTree &get_func(const FunctionTreeVector &fs, int i) { +template const FunctionTree &get_func(const FunctionTreeVector &fs, int i) { return *(std::get<1>(fs[i])); } } // namespace mrcpp diff --git a/src/trees/MWNode.cpp b/src/trees/MWNode.cpp index d15c0939f..2d521b468 100644 --- a/src/trees/MWNode.cpp +++ b/src/trees/MWNode.cpp @@ -45,8 +45,8 @@ namespace mrcpp { * * @details Should be used only by NodeAllocator to obtain * virtual table pointers for the derived classes. 
*/ -template -MWNode::MWNode() +template +MWNode::MWNode() : tree(nullptr) , parent(nullptr) , nodeIndex() @@ -66,8 +66,8 @@ MWNode::MWNode() * * @details Constructor for an empty node, given the corresponding MWTree and NodeIndex */ -template -MWNode::MWNode(MWTree *tree, const NodeIndex &idx) +template +MWNode::MWNode(MWTree *tree, const NodeIndex &idx) : tree(tree) , parent(nullptr) , nodeIndex(idx) @@ -87,8 +87,8 @@ MWNode::MWNode(MWTree *tree, const NodeIndex &idx) * @details Constructor for root nodes. It requires the corresponding * MWTree and an integer to fetch the right NodeIndex */ -template -MWNode::MWNode(MWTree *tree, int rIdx) +template +MWNode::MWNode(MWTree *tree, int rIdx) : tree(tree) , parent(nullptr) , nodeIndex(tree->getRootBox().getNodeIndex(rIdx)) @@ -108,8 +108,8 @@ MWNode::MWNode(MWTree *tree, int rIdx) * @details Constructor for leaf nodes. It requires the corresponding * parent and an integer to identify the correct child. */ -template -MWNode::MWNode(MWNode *parent, int cIdx) +template +MWNode::MWNode(MWNode *parent, int cIdx) : tree(parent->tree) , parent(parent) , nodeIndex(parent->getNodeIndex().child(cIdx)) @@ -130,8 +130,8 @@ MWNode::MWNode(MWNode *parent, int cIdx) * does not "belong" to the tree: it cannot be accessed by traversing * the tree. */ -template -MWNode::MWNode(const MWNode &node, bool allocCoef, bool SetCoef) +template +MWNode::MWNode(const MWNode &node, bool allocCoef, bool SetCoef) : tree(node.tree) , parent(nullptr) , nodeIndex(node.nodeIndex) @@ -163,7 +163,7 @@ MWNode::MWNode(const MWNode &node, bool allocCoef, bool SetCoef) * * @details Recursive deallocation of a node and all its decendants */ -template MWNode::~MWNode() { +template MWNode::~MWNode() { if (this->isLooseNode()) this->freeCoefs(); MRCPP_DESTROY_OMP_LOCK(); } @@ -174,7 +174,7 @@ template MWNode::~MWNode() { * called (derived classes must implement their own version). This was * to avoid having pure virtual methods in the base class. 
*/ -template void MWNode::dealloc() { +template void MWNode::dealloc() { NOT_REACHED_ABORT; } @@ -184,13 +184,13 @@ template void MWNode::dealloc() { * are not treated by the NodeAllocator class. * */ -template void MWNode::allocCoefs(int n_blocks, int block_size) { +template void MWNode::allocCoefs(int n_blocks, int block_size) { if (this->n_coefs != 0) MSG_ABORT("n_coefs should be zero"); if (this->isAllocated()) MSG_ABORT("Coefs already allocated"); if (not this->isLooseNode()) MSG_ABORT("Only loose nodes here!"); this->n_coefs = n_blocks * block_size; - this->coefs = new double[this->n_coefs]; + this->coefs = new T[this->n_coefs]; this->clearHasCoefs(); this->setIsAllocated(); @@ -202,7 +202,7 @@ template void MWNode::allocCoefs(int n_blocks, int block_size) { * are not treated by the NodeAllocator class. * */ -template void MWNode::freeCoefs() { +template void MWNode::freeCoefs() { if (not this->isLooseNode()) MSG_ABORT("Only loose nodes here!"); if (this->coefs != nullptr) delete[] this->coefs; @@ -216,7 +216,7 @@ template void MWNode::freeCoefs() { /** @brief Printout of node coefficients */ -template void MWNode::printCoefs() const { +template void MWNode::printCoefs() const { if (not this->isAllocated()) MSG_ABORT("Node is not allocated"); println(0, "\nMW coefs"); int kp1_d = this->getKp1_d(); @@ -228,18 +228,18 @@ template void MWNode::printCoefs() const { /** @brief wraps the MW coefficients into an eigen vector object */ -template void MWNode::getCoefs(Eigen::VectorXd &c) const { +template void MWNode::getCoefs(Eigen::Matrix &c) const { if (not this->isAllocated()) MSG_ABORT("Node is not allocated"); if (not this->hasCoefs()) MSG_ABORT("Node has no coefs"); if (this->n_coefs == 0) MSG_ABORT("ncoefs == 0"); - c = VectorXd::Map(this->coefs, this->n_coefs); + c = Eigen::Matrix::Map(this->coefs, this->n_coefs); } /** @brief sets all MW coefficients and the norms to zero * */ -template void MWNode::zeroCoefs() { +template void MWNode::zeroCoefs() { if (not 
this->isAllocated()) MSG_ABORT("Coefs not allocated " << *this); for (int i = 0; i < this->n_coefs; i++) { this->coefs[i] = 0.0; } @@ -249,7 +249,7 @@ template void MWNode::zeroCoefs() { /** @brief Attach a set of coefs to this node. Only used locally (the tree is not aware of this). */ -template void MWNode::attachCoefs(double *coefs) { +template void MWNode::attachCoefs(T *coefs) { this->coefs = coefs; this->setHasCoefs(); } @@ -264,7 +264,7 @@ template void MWNode::attachCoefs(double *coefs) { * (given scaling/wavelet in each direction). Its size is then \f$ * (k+1)^D \f$ and the index is between 0 and \f$ 2^D-1 \f$. */ -template void MWNode::setCoefBlock(int block, int block_size, const double *c) { +template void MWNode::setCoefBlock(int block, int block_size, const T *c) { if (not this->isAllocated()) MSG_ABORT("Coefs not allocated"); for (int i = 0; i < block_size; i++) { this->coefs[block * block_size + i] = c[i]; } } @@ -279,7 +279,7 @@ template void MWNode::setCoefBlock(int block, int block_size, const d * (given scaling/wavelet in each direction). Its size is then \f$ * (k+1)^D \f$ and the index is between 0 and \f$ 2^D-1 \f$. */ -template void MWNode::addCoefBlock(int block, int block_size, const double *c) { +template void MWNode::addCoefBlock(int block, int block_size, const T *c) { if (not this->isAllocated()) MSG_ABORT("Coefs not allocated"); for (int i = 0; i < block_size; i++) { this->coefs[block * block_size + i] += c[i]; } } @@ -293,7 +293,7 @@ template void MWNode::addCoefBlock(int block, int block_size, const d * (given scaling/wavelet in each direction). Its size is then \f$ * (k+1)^D \f$ and the index is between 0 and \f$ 2^D-1 \f$. 
*/ -template void MWNode::zeroCoefBlock(int block, int block_size) { +template void MWNode::zeroCoefBlock(int block, int block_size) { if (not this->isAllocated()) MSG_ABORT("Coefs not allocated"); for (int i = 0; i < block_size; i++) { this->coefs[block * block_size + i] = 0.0; } } @@ -309,7 +309,7 @@ template void MWNode::zeroCoefBlock(int block, int block_size) { * already be present and its memory allocated for this to work * properly. */ -template void MWNode::giveChildrenCoefs(bool overwrite) { +template void MWNode::giveChildrenCoefs(bool overwrite) { assert(this->isBranchNode()); if (not this->isAllocated()) MSG_ABORT("Not allocated!"); if (not this->hasCoefs()) MSG_ABORT("No coefficients!"); @@ -320,8 +320,8 @@ template void MWNode::giveChildrenCoefs(bool overwrite) { // coeff of child should be have been allocated already here int stride = getMWChild(0).getNCoefs(); - double *inp = getCoefs(); - double *out = getMWChild(0).getCoefs(); + T *inp = getCoefs(); + T *out = getMWChild(0).getCoefs(); bool readOnlyScaling = false; if (this->isGenNode()) readOnlyScaling = true; @@ -345,9 +345,9 @@ template void MWNode::giveChildrenCoefs(bool overwrite) { * node. The scaling coefficients of the selected child are then * copied/summed in the correct child node. 
*/ -template void MWNode::giveChildCoefs(int cIdx, bool overwrite) { +template void MWNode::giveChildCoefs(int cIdx, bool overwrite) { - MWNode node_i = *this; + MWNode node_i = *this; node_i.mwTransform(Reconstruction); @@ -355,7 +355,7 @@ template void MWNode::giveChildCoefs(int cIdx, bool overwrite) { int nChildren = this->getTDim(); if (this->children[cIdx] == nullptr) MSG_ABORT("Child does not exist!"); - MWNode &child = getMWChild(cIdx); + MWNode &child = getMWChild(cIdx); if (overwrite) { child.setCoefBlock(0, kp1_d, &node_i.getCoefs()[cIdx * kp1_d]); } else { @@ -371,12 +371,12 @@ template void MWNode::giveChildCoefs(int cIdx, bool overwrite) { * * \warning This routine is only used in connection with Periodic Boundary Conditions */ -template void MWNode::giveParentCoefs(bool overwrite) { - MWNode node = *this; - MWNode &parent = getMWParent(); +template void MWNode::giveParentCoefs(bool overwrite) { + MWNode node = *this; + MWNode &parent = getMWParent(); int kp1_d = this->getKp1_d(); if (node.getScale() == 0) { - NodeBox &box = this->getMWTree().getRootBox(); + NodeBox &box = this->getMWTree().getRootBox(); auto reverse = getTDim() - 1; for (auto i = 0; i < getTDim(); i++) { parent.setCoefBlock(i, kp1_d, &box.getNode(reverse - i).getCoefs()[0]); } } else { @@ -393,17 +393,17 @@ template void MWNode::giveParentCoefs(bool overwrite) { * them consecutively in the corresponding block of the parent, * following the usual bitwise notation. 
*/ -template void MWNode::copyCoefsFromChildren() { +template void MWNode::copyCoefsFromChildren() { int kp1_d = this->getKp1_d(); int nChildren = this->getTDim(); for (int cIdx = 0; cIdx < nChildren; cIdx++) { - MWNode &child = getMWChild(cIdx); + MWNode &child = getMWChild(cIdx); if (not child.hasCoefs()) MSG_ABORT("Child has no coefs"); setCoefBlock(cIdx, kp1_d, child.getCoefs()); } } -/** @brief Generates scaling cofficients of children +/** @brief Generates scaling coefficients of children * * @details If the node is a leafNode, it takes the scaling&wavelet * coefficients of the parent and it generates the scaling @@ -411,7 +411,7 @@ template void MWNode::copyCoefsFromChildren() { * them consecutively in the corresponding block of the parent, * following the usual bitwise notation. */ -template void MWNode::threadSafeGenChildren() { +template void MWNode::threadSafeGenChildren() { if (tree->isLocal) { NOT_IMPLEMENTED_ABORT; } MRCPP_SET_OMP_LOCK(); if (isLeafNode()) { @@ -421,6 +421,24 @@ template void MWNode::threadSafeGenChildren() { MRCPP_UNSET_OMP_LOCK(); } +/** @brief Creates scaling coefficients of children + * + * @details If the node is a leafNode, it takes the scaling&wavelet + * coefficients of the parent and it generates the scaling + * coefficients for the children and stores + * them consecutively in the corresponding block of the parent, + * following the usual bitwise notation. The new node is permanently added to the tree. 
+ */ +template void MWNode::threadSafeCreateChildren() { + if (tree->isLocal) { NOT_IMPLEMENTED_ABORT; } + MRCPP_SET_OMP_LOCK(); + if (isLeafNode()) { + createChildren(true); + giveChildrenCoefs(); + } + MRCPP_UNSET_OMP_LOCK(); +} + /** @brief Coefficient-Value transform * * @details This routine transforms the scaling coefficients of the node to the @@ -431,7 +449,7 @@ template void MWNode::threadSafeGenChildren() { * NOTE: this routine assumes a 0/1 (scaling on child 0 and 1) * representation, instead of s/d (scaling and wavelet). */ -template void MWNode::cvTransform(int operation) { +template void MWNode::cvTransform(int operation, bool firstchild) { int kp1 = this->getKp1(); int kp1_dm1 = math_utils::ipow(kp1, D - 1); int kp1_d = this->getKp1_d(); @@ -439,17 +457,19 @@ template void MWNode::cvTransform(int operation) { auto sb = this->getMWTree().getMRA().getScalingBasis(); const MatrixXd &S = sb.getCVMap(operation); - double o_vec[nCoefs]; - double *out_vec = o_vec; - double *in_vec = this->coefs; + T o_vec[nCoefs]; + T *out_vec = o_vec; + T *in_vec = this->coefs; + int nChildren = this->getTDim(); + if (firstchild) nChildren = 1; for (int i = 0; i < D; i++) { - for (int t = 0; t < this->getTDim(); t++) { - double *out = out_vec + t * kp1_d; - double *in = in_vec + t * kp1_d; + for (int t = 0; t < nChildren; t++) { + T *out = out_vec + t * kp1_d; + T *in = in_vec + t * kp1_d; math_utils::apply_filter(out, in, S, kp1, kp1_dm1, 0.0); } - double *tmp = in_vec; + T *tmp = in_vec; in_vec = out_vec; out_vec = tmp; } @@ -473,8 +493,8 @@ template void MWNode::cvTransform(int operation) { } } /* Old interpolating version, somewhat faster -template -void MWNode::cvTransform(int operation) { +template +void MWNode::cvTransform(int operation) { const ScalingBasis &sf = this->getMWTree().getMRA().getScalingBasis(); if (sf.getScalingType() != Interpol) { NOT_IMPLEMENTED_ABORT; @@ -520,25 +540,25 @@ void MWNode::cvTransform(int operation) { */ /** @brief Multiwavelet 
transform - * - * @details Application of the filters on one node to pass from a 0/1 (scaling - * on child 0 and 1) representation to an s/d (scaling and - * wavelet) representation. Bit manipulation is used in order to - * determine the correct filters and whether to apply them or just - * pass to the next couple of indexes. The starting coefficients are - * preserved until the application is terminated, then they are - * overwritten. With minor modifications this code can also be used - * for the inverse mw transform (just use the transpose filters) or - * for the application of an operator (using A, B, C and T parts of an - * operator instead of G1, G0, H1, H0). This is the version where the - * three directions are operated one after the other. Although this - * is formally faster than the other algorithm, the separation of the - * three dimensions prevent the possibility to use the norm of the - * operator in order to discard a priori negligible contributions. - * - * * @param[in] operation: compression (s0,s1->s,d) or reconstruction (s,d->s0,s1). - */ -template void MWNode::mwTransform(int operation) { + * + * @details Application of the filters on one node to pass from a 0/1 (scaling + * on child 0 and 1) representation to an s/d (scaling and + * wavelet) representation. Bit manipulation is used in order to + * determine the correct filters and whether to apply them or just + * pass to the next couple of indexes. The starting coefficients are + * preserved until the application is terminated, then they are + * overwritten. With minor modifications this code can also be used + * for the inverse mw transform (just use the transpose filters) or + * for the application of an operator (using A, B, C and T parts of an + * operator instead of G1, G0, H1, H0). This is the version where the + * three directions are operated one after the other. 
Although this + * is formally faster than the other algorithm, the separation of the + * three dimensions prevent the possibility to use the norm of the + * operator in order to discard a priori negligible contributions. + * + * * @param[in] operation: compression (s0,s1->s,d) or reconstruction (s,d->s0,s1). + */ +template void MWNode::mwTransform(int operation) { int kp1 = this->getKp1(); int kp1_dm1 = math_utils::ipow(kp1, D - 1); int kp1_d = this->getKp1_d(); @@ -546,20 +566,20 @@ template void MWNode::mwTransform(int operation) { const MWFilter &filter = getMWTree().getMRA().getFilter(); double overwrite = 0.0; - double o_vec[nCoefs]; - double *out_vec = o_vec; - double *in_vec = this->coefs; + T o_vec[nCoefs]; + T *out_vec = o_vec; + T *in_vec = this->coefs; for (int i = 0; i < D; i++) { int mask = 1 << i; for (int gt = 0; gt < this->getTDim(); gt++) { - double *out = out_vec + gt * kp1_d; + T *out = out_vec + gt * kp1_d; for (int ft = 0; ft < this->getTDim(); ft++) { /* Operate in direction i only if the bits along other * directions are identical. The bit of the direction we * operate on determines the appropriate filter/operator */ if ((gt | mask) == (ft | mask)) { - double *in = in_vec + ft * kp1_d; + T *in = in_vec + ft * kp1_d; int fIdx = 2 * ((gt >> i) & 1) + ((ft >> i) & 1); const MatrixXd &oper = filter.getSubFilter(fIdx, operation); math_utils::apply_filter(out, in, oper, kp1, kp1_dm1, overwrite); @@ -568,7 +588,7 @@ template void MWNode::mwTransform(int operation) { } overwrite = 0.0; } - double *tmp = in_vec; + T *tmp = in_vec; in_vec = out_vec; out_vec = tmp; } @@ -578,19 +598,19 @@ template void MWNode::mwTransform(int operation) { } /** @brief Set all norms to Undefined. */ -template void MWNode::clearNorms() { +template void MWNode::clearNorms() { this->squareNorm = -1.0; for (int i = 0; i < this->getTDim(); i++) { this->componentNorms[i] = -1.0; } } /** @brief Set all norms to zero. 
*/ -template void MWNode::zeroNorms() { +template void MWNode::zeroNorms() { this->squareNorm = 0.0; for (int i = 0; i < this->getTDim(); i++) { this->componentNorms[i] = 0.0; } } /** @brief Calculate and store square norm and component norms, if allocated. */ -template void MWNode::calcNorms() { +template void MWNode::calcNorms() { this->squareNorm = 0.0; for (int i = 0; i < this->getTDim(); i++) { double norm_i = calcComponentNorm(i); @@ -600,7 +620,7 @@ template void MWNode::calcNorms() { } /** @brief Calculate and return the squared scaling norm. */ -template double MWNode::getScalingNorm() const { +template double MWNode::getScalingNorm() const { double sNorm = this->getComponentNorm(0); if (sNorm >= 0.0) { return sNorm * sNorm; @@ -610,7 +630,7 @@ template double MWNode::getScalingNorm() const { } /** @brief Calculate and return the squared wavelet norm. */ -template double MWNode::getWaveletNorm() const { +template double MWNode::getWaveletNorm() const { double wNorm = 0.0; for (int i = 1; i < this->getTDim(); i++) { double norm_i = this->getComponentNorm(i); @@ -624,28 +644,24 @@ template double MWNode::getWaveletNorm() const { } /** @brief Calculate the norm of one component (NOT the squared norm!). */ -template double MWNode::calcComponentNorm(int i) const { +template double MWNode::calcComponentNorm(int i) const { if (this->isGenNode() and i != 0) return 0.0; assert(this->isAllocated()); assert(this->hasCoefs()); - const double *c = this->getCoefs(); + const T *c = this->getCoefs(); int size = this->getKp1_d(); int start = i * size; double sq_norm = 0.0; -#ifdef HAVE_BLAS - sq_norm = cblas_ddot(size, &c[start], 1, &c[start], 1); -#else - for (int i = start; i < start + size; i++) { sq_norm += c[i] * c[i]; } -#endif + for (int i = start; i < start + size; i++) { sq_norm += std::norm(c[i]); } return std::sqrt(sq_norm); } /** @brief Update the coefficients of the node by a mw transform of the scaling * coefficients of the children. 
*/ -template void MWNode::reCompress() { +template void MWNode::reCompress() { if (this->isGenNode()) NOT_IMPLEMENTED_ABORT; if (this->isBranchNode()) { if (not this->isAllocated()) MSG_ABORT("Coefs not allocated"); @@ -662,12 +678,12 @@ template void MWNode::reCompress() { * @param[in] splitFac: factor used in the split check (larger factor means tighter threshold for finer nodes) * @param[in] absPrec: flag to switch from relative (false) to absolute (true) precision. */ -template bool MWNode::crop(double prec, double splitFac, bool absPrec) { +template bool MWNode::crop(double prec, double splitFac, bool absPrec) { if (this->isEndNode()) { return true; } else { for (int i = 0; i < this->getTDim(); i++) { - MWNode &child = *this->children[i]; + MWNode &child = *this->children[i]; if (child.crop(prec, splitFac, absPrec)) { if (tree_utils::split_check(*this, prec, splitFac, absPrec) == false) { this->deleteChildren(); @@ -679,15 +695,15 @@ template bool MWNode::crop(double prec, double splitFac, bool absPrec return false; } -template void MWNode::createChildren(bool coefs) { +template void MWNode::createChildren(bool coefs) { NOT_REACHED_ABORT; } -template void MWNode::genChildren() { +template void MWNode::genChildren() { NOT_REACHED_ABORT; } -template void MWNode::genParent() { +template void MWNode::genParent() { NOT_REACHED_ABORT; } @@ -696,11 +712,11 @@ template void MWNode::genParent() { * @details * Leaves node as LeafNode and children[] as null pointer. */ -template void MWNode::deleteChildren() { +template void MWNode::deleteChildren() { if (this->isLeafNode()) return; for (int cIdx = 0; cIdx < getTDim(); cIdx++) { if (this->children[cIdx] != nullptr) { - MWNode &child = getMWChild(cIdx); + MWNode &child = getMWChild(cIdx); child.deleteChildren(); child.dealloc(); } @@ -711,18 +727,17 @@ template void MWNode::deleteChildren() { } /** @brief Recursive deallocation of parent and all their forefathers. 
*/ -template void MWNode::deleteParent() { +template void MWNode::deleteParent() { if (this->parent == nullptr) return; - MWNode &parent = getMWParent(); + MWNode &parent = getMWParent(); parent.deleteParent(); parent.dealloc(); this->parentSerialIx = -1; this->parent = nullptr; } - /** @brief Deallocation of all generated nodes . */ -template void MWNode::deleteGenerated() { +template void MWNode::deleteGenerated() { if (this->isBranchNode()) { if (this->isEndNode()) { this->deleteChildren(); @@ -733,7 +748,7 @@ template void MWNode::deleteGenerated() { } /** @brief returns the coordinates of the centre of the node */ -template Coord MWNode::getCenter() const { +template Coord MWNode::getCenter() const { auto two_n = std::pow(2.0, -getScale()); auto scaling_factor = getMWTree().getMRA().getWorldBox().getScalingFactors(); auto &l = getNodeIndex(); @@ -743,7 +758,7 @@ template Coord MWNode::getCenter() const { } /** @brief returns the upper bounds of the D-interval defining the node */ -template Coord MWNode::getUpperBounds() const { +template Coord MWNode::getUpperBounds() const { auto two_n = std::pow(2.0, -getScale()); auto scaling_factor = getMWTree().getMRA().getWorldBox().getScalingFactors(); auto &l = getNodeIndex(); @@ -753,7 +768,7 @@ template Coord MWNode::getUpperBounds() const { } /** @brief returns the lower bounds of the D-interval defining the node */ -template Coord MWNode::getLowerBounds() const { +template Coord MWNode::getLowerBounds() const { auto two_n = std::pow(2.0, -getScale()); auto scaling_factor = getMWTree().getMRA().getWorldBox().getScalingFactors(); auto &l = getNodeIndex(); @@ -770,7 +785,7 @@ template Coord MWNode::getLowerBounds() const { * to be followed at the current scale in oder to get to the requested * node at the final scale. The result is the index of the child needed. * The index is obtained by bit manipulation of of the translation indices. 
*/ -template int MWNode::getChildIndex(const NodeIndex &nIdx) const { +template int MWNode::getChildIndex(const NodeIndex &nIdx) const { assert(isAncestor(nIdx)); int cIdx = 0; int diffScale = nIdx.getScale() - getScale() - 1; @@ -790,7 +805,7 @@ template int MWNode::getChildIndex(const NodeIndex &nIdx) const { * * @detailsGiven a point in space, determines which child should be followed * to get to the corresponding terminal node. */ -template int MWNode::getChildIndex(const Coord &r) const { +template int MWNode::getChildIndex(const Coord &r) const { assert(hasCoord(r)); int cIdx = 0; double sFac = std::pow(2.0, -getScale()); @@ -815,7 +830,7 @@ template int MWNode::getChildIndex(const Coord &r) const { * grid of quadrature points. * */ -template void MWNode::getPrimitiveQuadPts(MatrixXd &pts) const { +template void MWNode::getPrimitiveQuadPts(MatrixXd &pts) const { int kp1 = this->getKp1(); pts = MatrixXd::Zero(D, kp1); @@ -840,7 +855,7 @@ template void MWNode::getPrimitiveQuadPts(MatrixXd &pts) const { * nodes. * */ -template void MWNode::getPrimitiveChildPts(MatrixXd &pts) const { +template void MWNode::getPrimitiveChildPts(MatrixXd &pts) const { int kp1 = this->getKp1(); pts = MatrixXd::Zero(D, 2 * kp1); @@ -865,7 +880,7 @@ template void MWNode::getPrimitiveChildPts(MatrixXd &pts) const { * vectors of quadrature points. * */ -template void MWNode::getExpandedQuadPts(Eigen::MatrixXd &pts) const { +template void MWNode::getExpandedQuadPts(Eigen::MatrixXd &pts) const { MatrixXd prim_pts; getPrimitiveQuadPts(prim_pts); @@ -881,7 +896,7 @@ template void MWNode::getExpandedQuadPts(Eigen::MatrixXd &pts) const /** @brief Returns the quadrature points in a given node * - * @param[in,out] pts: expanded quadrature points in a \f$ d \times + * @param[in,out] pts: expanded quadrature points in a \f$ d \times * 2^d(k+1)^d \f$ matrix form. 
* * @details The primitive quadrature points of the children are used to obtain a @@ -889,7 +904,7 @@ template void MWNode::getExpandedQuadPts(Eigen::MatrixXd &pts) const * vectors of quadrature points. * */ -template void MWNode::getExpandedChildPts(MatrixXd &pts) const { +template void MWNode::getExpandedChildPts(MatrixXd &pts) const { MatrixXd prim_pts; getPrimitiveChildPts(prim_pts); @@ -923,7 +938,7 @@ template void MWNode::getExpandedChildPts(MatrixXd &pts) const { * the node does not exist, or if it is a GenNode. Recursion starts at at this * node and ASSUMES the requested node is in fact decending from this node. */ -template const MWNode *MWNode::retrieveNodeNoGen(const NodeIndex &idx) const { +template const MWNode *MWNode::retrieveNodeNoGen(const NodeIndex &idx) const { if (getScale() == idx.getScale()) { // we're done assert(getNodeIndex() == idx); return this; @@ -947,7 +962,7 @@ template const MWNode *MWNode::retrieveNodeNoGen(const NodeIndex MWNode *MWNode::retrieveNodeNoGen(const NodeIndex &idx) { +template MWNode *MWNode::retrieveNodeNoGen(const NodeIndex &idx) { if (getScale() == idx.getScale()) { // we're done assert(getNodeIndex() == idx); return this; @@ -973,7 +988,7 @@ template MWNode *MWNode::retrieveNodeNoGen(const NodeIndex &idx * this node and ASSUMES the requested node is in fact decending from * this node. 
*/ -template const MWNode *MWNode::retrieveNodeOrEndNode(const Coord &r, int depth) const { +template const MWNode *MWNode::retrieveNodeOrEndNode(const Coord &r, int depth) const { if (getDepth() == depth or this->isEndNode()) { return this; } int cIdx = getChildIndex(r); assert(this->children[cIdx] != nullptr); @@ -992,7 +1007,7 @@ template const MWNode *MWNode::retrieveNodeOrEndNode(const Coord MWNode *MWNode::retrieveNodeOrEndNode(const Coord &r, int depth) { +template MWNode *MWNode::retrieveNodeOrEndNode(const Coord &r, int depth) { if (getDepth() == depth or this->isEndNode()) { return this; } int cIdx = getChildIndex(r); assert(this->children[cIdx] != nullptr); @@ -1010,7 +1025,7 @@ template MWNode *MWNode::retrieveNodeOrEndNode(const Coord &r, * this node and ASSUMES the requested node is in fact decending from * this node. */ -template const MWNode *MWNode::retrieveNodeOrEndNode(const NodeIndex &idx) const { +template const MWNode *MWNode::retrieveNodeOrEndNode(const NodeIndex &idx) const { if (getScale() == idx.getScale()) { // we're done assert(getNodeIndex() == idx); return this; @@ -1036,7 +1051,7 @@ template const MWNode *MWNode::retrieveNodeOrEndNode(const NodeInd * this node and ASSUMES the requested node is in fact decending from * this node. */ -template MWNode *MWNode::retrieveNodeOrEndNode(const NodeIndex &idx) { +template MWNode *MWNode::retrieveNodeOrEndNode(const NodeIndex &idx) { if (getScale() == idx.getScale()) { // we're done assert(getNodeIndex() == idx); return this; @@ -1061,7 +1076,7 @@ template MWNode *MWNode::retrieveNodeOrEndNode(const NodeIndex * that does not exist. Recursion starts at this node and ASSUMES the * requested node is in fact decending from this node. 
*/ -template MWNode *MWNode::retrieveNode(const Coord &r, int depth) { +template MWNode *MWNode::retrieveNode(const Coord &r, int depth) { if (depth < 0) MSG_ABORT("Invalid argument"); if (getDepth() == depth) { return this; } @@ -1081,24 +1096,31 @@ template MWNode *MWNode::retrieveNode(const Coord &r, int depth * routine always returns the appropriate node, and will generate nodes that * does not exist. Recursion starts at this node and ASSUMES the requested * node is in fact descending from this node. + * If create = true, the nodes are permanently added to the tree. */ -template MWNode *MWNode::retrieveNode(const NodeIndex &idx) { +template MWNode *MWNode::retrieveNode(const NodeIndex &idx, bool create) { if (getScale() == idx.getScale()) { // we're done if (tree->isLocal) { + NOT_IMPLEMENTED_ABORT; // has to fetch coeff in Bank. NOT USED YET - int ncoefs = (1 << D) * this->getKp1_d(); - coefs = new double[ncoefs]; // TODO must be cleaned at some stage - tree->getNodeCoeff(idx, coefs); + // int ncoefs = (1 << D) * this->getKp1_d(); + // coefs = new double[ncoefs]; // TODO must be cleaned at some stage + // coefs = new double[ncoefs]; // TODO must be cleaned at some stage + // tree->getNodeCoeff(idx, coefs); } assert(getNodeIndex() == idx); return this; } assert(isAncestor(idx)); - threadSafeGenChildren(); + if (create) { + threadSafeCreateChildren(); + } else { + threadSafeGenChildren(); + } int cIdx = getChildIndex(idx); assert(this->children[cIdx] != nullptr); - return this->children[cIdx]->retrieveNode(idx); + return this->children[cIdx]->retrieveNode(idx, create); } /** Node retriever that ALWAYS returns the requested node. @@ -1113,7 +1135,7 @@ template MWNode *MWNode::retrieveNode(const NodeIndex &idx) { * does not exist. Recursion starts at this node and ASSUMES the requested * node is in fact related to this node. 
*/ -template MWNode *MWNode::retrieveParent(const NodeIndex &idx) { +template MWNode *MWNode::retrieveParent(const NodeIndex &idx) { if (getScale() < idx.getScale()) MSG_ABORT("Scale error") if (getScale() == idx.getScale()) return this; if (this->parent == nullptr) { @@ -1132,7 +1154,7 @@ template MWNode *MWNode::retrieveParent(const NodeIndex &idx) { * found, do not generate any new node, but rather give the value of the norm * assuming the function is uniformly distributed within the node. */ -template double MWNode::getNodeNorm(const NodeIndex &idx) const { +template double MWNode::getNodeNorm(const NodeIndex &idx) const { if (this->getScale() == idx.getScale()) { // we're done assert(getNodeIndex() == idx); return std::sqrt(this->squareNorm); @@ -1150,7 +1172,7 @@ template double MWNode::getNodeNorm(const NodeIndex &idx) const { * * @param[in] r: point coordinates */ -template bool MWNode::hasCoord(const Coord &r) const { +template bool MWNode::hasCoord(const Coord &r) const { double sFac = std::pow(2.0, -getScale()); const NodeIndex &l = getNodeIndex(); // println(1, "[" << r[0] << "," << r[1] << "," << r[2] << "]"); @@ -1168,7 +1190,7 @@ template bool MWNode::hasCoord(const Coord &r) const { /** Testing if nodes are compatible wrt NodeIndex and Tree (order, rootScale, * relPrec, etc). 
*/ -template bool MWNode::isCompatible(const MWNode &node) { +template bool MWNode::isCompatible(const MWNode &node) { NOT_IMPLEMENTED_ABORT; // if (nodeIndex != node.nodeIndex) { // println(0, "nodeIndex mismatch" << std::endl); @@ -1186,7 +1208,7 @@ template bool MWNode::isCompatible(const MWNode &node) { * * @param[in] idx: the NodeIndex of the requested node */ -template bool MWNode::isAncestor(const NodeIndex &idx) const { +template bool MWNode::isAncestor(const NodeIndex &idx) const { int relScale = idx.getScale() - getScale(); if (relScale < 0) return false; const NodeIndex &l = getNodeIndex(); @@ -1197,7 +1219,7 @@ template bool MWNode::isAncestor(const NodeIndex &idx) const { return true; } -template bool MWNode::isDecendant(const NodeIndex &idx) const { +template bool MWNode::isDecendant(const NodeIndex &idx) const { NOT_IMPLEMENTED_ABORT; } @@ -1205,7 +1227,7 @@ template bool MWNode::isDecendant(const NodeIndex &idx) const { * * @param[in] o: the output stream */ -template std::ostream &MWNode::print(std::ostream &o) const { +template std::ostream &MWNode::print(std::ostream &o) const { std::string flags = " "; o << getNodeIndex(); if (isRootNode()) flags[0] = 'R'; @@ -1236,14 +1258,14 @@ template std::ostream &MWNode::print(std::ostream &o) const { * normalization is such that a constant function gives constant value, * i.e. 
*not* same normalization as a squareNorm */ -template void MWNode::setMaxSquareNorm() { +template void MWNode::setMaxSquareNorm() { auto n = this->getScale(); this->maxWSquareNorm = calcScaledWSquareNorm(); this->maxSquareNorm = calcScaledSquareNorm(); if (not this->isEndNode()) { for (int i = 0; i < this->getTDim(); i++) { - MWNode &child = *this->children[i]; + MWNode &child = *this->children[i]; child.setMaxSquareNorm(); this->maxSquareNorm = std::max(this->maxSquareNorm, child.maxSquareNorm); this->maxWSquareNorm = std::max(this->maxWSquareNorm, child.maxWSquareNorm); @@ -1252,20 +1274,23 @@ template void MWNode::setMaxSquareNorm() { } /** @brief recursively reset maxSquaredNorm and maxWSquareNorm of parent and descendants to value -1 */ -template void MWNode::resetMaxSquareNorm() { +template void MWNode::resetMaxSquareNorm() { auto n = this->getScale(); this->maxSquareNorm = -1.0; this->maxWSquareNorm = -1.0; if (not this->isEndNode()) { for (int i = 0; i < this->getTDim(); i++) { - MWNode &child = *this->children[i]; + MWNode &child = *this->children[i]; child.resetMaxSquareNorm(); } } } -template class MWNode<1>; -template class MWNode<2>; -template class MWNode<3>; +template class MWNode<1, double>; +template class MWNode<2, double>; +template class MWNode<3, double>; +template class MWNode<1, ComplexDouble>; +template class MWNode<2, ComplexDouble>; +template class MWNode<3, ComplexDouble>; } // namespace mrcpp diff --git a/src/trees/MWNode.h b/src/trees/MWNode.h index d7d7a18a7..f86313846 100644 --- a/src/trees/MWNode.h +++ b/src/trees/MWNode.h @@ -23,12 +23,12 @@ * */ - #pragma once #include #include "MRCPP/macros.h" +#include "utils/math_utils.h" #include "utils/omp_utils.h" #include "HilbertPath.h" @@ -49,12 +49,12 @@ namespace mrcpp { * translation index, the norm, pointers to parent node and child * nodes, pointer to the corresponding MWTree etc... See member and * data descriptions for details. 
- * + * */ -template class MWNode { +template class MWNode { public: - MWNode(const MWNode &node, bool allocCoef = true, bool SetCoef = true); - MWNode &operator=(const MWNode &node) = delete; + MWNode(const MWNode &node, bool allocCoef = true, bool SetCoef = true); + MWNode &operator=(const MWNode &node) = delete; virtual ~MWNode(); int getKp1() const { return getMWTree().getKp1(); } @@ -76,7 +76,7 @@ template class MWNode { Coord getLowerBounds() const; bool hasCoord(const Coord &r) const; - bool isCompatible(const MWNode &node); + bool isCompatible(const MWNode &node); bool isAncestor(const NodeIndex &idx) const; bool isDecendant(const NodeIndex &idx) const; @@ -89,30 +89,30 @@ template class MWNode { double getComponentNorm(int i) const { return this->componentNorms[i]; } int getNCoefs() const { return this->n_coefs; } - void getCoefs(Eigen::VectorXd &c) const; + void getCoefs(Eigen::Matrix &c) const; void printCoefs() const; - double *getCoefs() { return this->coefs; } - const double *getCoefs() const { return this->coefs; } + T *getCoefs() { return this->coefs; } + const T *getCoefs() const { return this->coefs; } void getPrimitiveQuadPts(Eigen::MatrixXd &pts) const; void getPrimitiveChildPts(Eigen::MatrixXd &pts) const; void getExpandedQuadPts(Eigen::MatrixXd &pts) const; void getExpandedChildPts(Eigen::MatrixXd &pts) const; - MWTree &getMWTree() { return static_cast &>(*this->tree); } - MWNode &getMWParent() { return static_cast &>(*this->parent); } - MWNode &getMWChild(int i) { return static_cast &>(*this->children[i]); } + MWTree &getMWTree() { return static_cast &>(*this->tree); } + MWNode &getMWParent() { return static_cast &>(*this->parent); } + MWNode &getMWChild(int i) { return static_cast &>(*this->children[i]); } - const MWTree &getMWTree() const { return static_cast &>(*this->tree); } - const MWNode &getMWParent() const { return static_cast &>(*this->parent); } - const MWNode &getMWChild(int i) const { return static_cast &>(*this->children[i]); } 
+ const MWTree &getMWTree() const { return static_cast &>(*this->tree); } + const MWNode &getMWParent() const { return static_cast &>(*this->parent); } + const MWNode &getMWChild(int i) const { return static_cast &>(*this->children[i]); } void zeroCoefs(); - void setCoefBlock(int block, int block_size, const double *c); - void addCoefBlock(int block, int block_size, const double *c); + void setCoefBlock(int block, int block_size, const T *c); + void addCoefBlock(int block, int block_size, const T *c); void zeroCoefBlock(int block, int block_size); - void attachCoefs(double *coefs); + void attachCoefs(T *coefs); void calcNorms(); void zeroNorms(); @@ -124,7 +124,7 @@ template class MWNode { virtual void deleteChildren(); virtual void deleteParent(); - virtual void cvTransform(int kind); + virtual void cvTransform(int kind, bool firstchild = false); virtual void mwTransform(int kind); double getNodeNorm(const NodeIndex &idx) const; @@ -154,47 +154,50 @@ template class MWNode { void clearIsRootNode() { CLEAR_BITS(status, FlagRootNode); } void clearIsAllocated() { CLEAR_BITS(status, FlagAllocated); } - friend std::ostream &operator<<(std::ostream &o, const MWNode &nd) { return nd.print(o); } + friend std::ostream &operator<<(std::ostream &o, const MWNode &nd) { return nd.print(o); } - friend class TreeBuilder; - friend class MultiplicationCalculator; - friend class NodeAllocator; - friend class MWTree; - friend class FunctionTree; + friend class TreeBuilder; + friend class MultiplicationCalculator; + friend class NodeAllocator; + friend class MWTree; + friend class FunctionTree; friend class OperatorTree; - friend class FunctionNode; + friend class FunctionNode; friend class OperatorNode; - friend class DerivativeCalculator; + friend class DerivativeCalculator; + bool isComplex = false; // TODO put as one of the flags + friend class FunctionTree; // required if a ComplexDouble tree access a double node from another tree! 
+ friend class FunctionTree; + int childSerialIx{-1}; ///< index of first child in serial Tree, or -1 for leafnodes/endnodes protected: - MWTree *tree{nullptr}; ///< Tree the node belongs to - MWNode *parent{nullptr}; ///< Parent node - MWNode *children[1 << D]; ///< 2^D children + MWTree *tree{nullptr}; ///< Tree the node belongs to + MWNode *parent{nullptr}; ///< Parent node + MWNode *children[1 << D]; ///< 2^D children double squareNorm{-1.0}; ///< Squared norm of all 2^D (k+1)^D coefficients double componentNorms[1 << D]; ///< Squared norms of the separeted 2^D components double maxSquareNorm{-1.0}; ///< Largest squared norm among itself and descendants. double maxWSquareNorm{-1.0}; ///< Largest wavelet squared norm among itself and descendants. ///< NB: must be set before used. - double *coefs{nullptr}; ///< the 2^D (k+1)^D MW coefficients - ///< For example, in case of a one dimensional function \f$ f \f$ - ///< this array equals \f$ s_0, \ldots, s_k, d_0, \ldots, d_k \f$, - ///< where scaling coefficients \f$ s_j = s_{jl}^n(f) \f$ - ///< and wavelet coefficients \f$ d_j = d_{jl}^n(f) \f$. - ///< Here \f$ n, l \f$ are unique for every node. + T *coefs{nullptr}; ///< the 2^D (k+1)^D MW coefficients + ///< For example, in case of a one dimensional function \f$ f \f$ + ///< this array equals \f$ s_0, \ldots, s_k, d_0, \ldots, d_k \f$, + ///< where scaling coefficients \f$ s_j = s_{jl}^n(f) \f$ + ///< and wavelet coefficients \f$ d_j = d_{jl}^n(f) \f$. + ///< Here \f$ n, l \f$ are unique for every node. 
int n_coefs{0}; int serialIx{-1}; ///< index in serial Tree int parentSerialIx{-1}; ///< index of parent in serial Tree, or -1 for roots - int childSerialIx{-1}; ///< index of first child in serial Tree, or -1 for leafnodes/endnodes NodeIndex nodeIndex; ///< Scale and translation of the node HilbertPath hilbertPath; ///< To be documented MWNode(); - MWNode(MWTree *tree, int rIdx); - MWNode(MWTree *tree, const NodeIndex &idx); - MWNode(MWNode *parent, int cIdx); + MWNode(MWTree *tree, int rIdx); + MWNode(MWTree *tree, const NodeIndex &idx); + MWNode(MWNode *parent, int cIdx); virtual void dealloc(); @@ -219,21 +222,22 @@ template class MWNode { int getChildIndex(const NodeIndex &nIdx) const; int getChildIndex(const Coord &r) const; - bool diffBranch(const MWNode &rhs) const; + bool diffBranch(const MWNode &rhs) const; - MWNode *retrieveNode(const Coord &r, int depth); - MWNode *retrieveNode(const NodeIndex &idx); - MWNode *retrieveParent(const NodeIndex &idx); + MWNode *retrieveNode(const Coord &r, int depth); + MWNode *retrieveNode(const NodeIndex &idx, bool create = false); + MWNode *retrieveParent(const NodeIndex &idx); - const MWNode *retrieveNodeNoGen(const NodeIndex &idx) const; - MWNode *retrieveNodeNoGen(const NodeIndex &idx); + const MWNode *retrieveNodeNoGen(const NodeIndex &idx) const; + MWNode *retrieveNodeNoGen(const NodeIndex &idx); - const MWNode *retrieveNodeOrEndNode(const Coord &r, int depth) const; - MWNode *retrieveNodeOrEndNode(const Coord &r, int depth); + const MWNode *retrieveNodeOrEndNode(const Coord &r, int depth) const; + MWNode *retrieveNodeOrEndNode(const Coord &r, int depth); - const MWNode *retrieveNodeOrEndNode(const NodeIndex &idx) const; - MWNode *retrieveNodeOrEndNode(const NodeIndex &idx); + const MWNode *retrieveNodeOrEndNode(const NodeIndex &idx) const; + MWNode *retrieveNodeOrEndNode(const NodeIndex &idx); + void threadSafeCreateChildren(); void threadSafeGenChildren(); void deleteGenerated(); diff --git a/src/trees/MWTree.cpp 
b/src/trees/MWTree.cpp index 583fb1fc1..6a646d33f 100644 --- a/src/trees/MWTree.cpp +++ b/src/trees/MWTree.cpp @@ -26,10 +26,10 @@ #include "MWTree.h" #include "MWNode.h" -#include "NodeIndex.h" -#include "TreeIterator.h" #include "MultiResolutionAnalysis.h" #include "NodeAllocator.h" +#include "NodeIndex.h" +#include "TreeIterator.h" #include "utils/Bank.h" #include "utils/Printer.h" #include "utils/math_utils.h" @@ -49,11 +49,11 @@ namespace mrcpp { * root nodes. The information for the root node configuration to use * is in the mra object which is passed to the constructor. */ -template -MWTree::MWTree(const MultiResolutionAnalysis &mra, const std::string &n) +template +MWTree::MWTree(const MultiResolutionAnalysis &mra, const std::string &n) : MRA(mra) - , order(mra.getOrder()) /// polynomial order - , kp1_d(math_utils::ipow(mra.getOrder() + 1, D)) ///nr of scaling coefficients \f$ (k+1)^D \f$ + , order(mra.getOrder()) /// polynomial order + , kp1_d(math_utils::ipow(mra.getOrder() + 1, D)) /// nr of scaling coefficients \f$ (k+1)^D \f$ , name(n) , squareNorm(-1.0) , rootBox(mra.getWorldBox()) { @@ -61,21 +61,21 @@ MWTree::MWTree(const MultiResolutionAnalysis &mra, const std::string &n) } /** @brief MWTree destructor. */ -template MWTree::~MWTree() { +template MWTree::~MWTree() { this->endNodeTable.clear(); if (this->nodesAtDepth.size() != 1) MSG_ERROR("Nodes at depth != 1 -> " << this->nodesAtDepth.size()); if (this->nodesAtDepth[0] != 0) MSG_ERROR("Nodes at depth 0 != 0 -> " << this->nodesAtDepth[0]); } /** @brief Deletes all the nodes in the tree - * - * @details This method will recursively delete all the nodes, - * including the root nodes. Derived classes will call this method - * when the object is deleted. - */ -template void MWTree::deleteRootNodes() { + * + * @details This method will recursively delete all the nodes, + * including the root nodes. Derived classes will call this method + * when the object is deleted. 
+ */ +template void MWTree::deleteRootNodes() { for (int i = 0; i < this->rootBox.size(); i++) { - MWNode &root = this->getRootMWNode(i); + MWNode &root = this->getRootMWNode(i); root.deleteChildren(); root.dealloc(); this->rootBox.clearNode(i); @@ -90,9 +90,9 @@ template void MWTree::deleteRootNodes() { * nodes, (nodeChunks in NodeAllocator) is NOT released, but is * immediately available to the new function. */ -template void MWTree::clear() { +template void MWTree::clear() { for (int i = 0; i < this->rootBox.size(); i++) { - MWNode &root = this->getRootMWNode(i); + MWNode &root = this->getRootMWNode(i); root.deleteChildren(); root.clearHasCoefs(); root.clearNorms(); @@ -106,10 +106,11 @@ template void MWTree::clear() { * @details The norm is calculated using endNodes only. The specific * type of norm which is computed will depend on the derived class */ -template void MWTree::calcSquareNorm() { +template void MWTree::calcSquareNorm(bool deep) { double treeNorm = 0.0; for (int n = 0; n < this->getNEndNodes(); n++) { - const MWNode &node = getEndMWNode(n); + MWNode &node = getEndMWNode(n); + if (deep) node.calcNorms(); assert(node.hasCoefs()); treeNorm += node.getSquareNorm(); } @@ -126,9 +127,9 @@ template void MWTree::calcSquareNorm() { * @details It performs a Multiwavlet transform of the whole tree. The * input parameters will specify the direction (upwards or downwards) * and whether the result is added to the coefficients or it - * overwrites them. See the documentation for the #mwTransformUp + * overwrites them. See the documentation for the #mwTransformUp * and #mwTransformDown for details. 
- * \f[ + * \f[ * \pmatrix{ * s_{nl}\\ * d_{nl} @@ -139,7 +140,7 @@ template void MWTree::calcSquareNorm() { * } * \f] */ -template void MWTree::mwTransform(int type, bool overwrite) { +template void MWTree::mwTransform(int type, bool overwrite) { switch (type) { case TopDown: mwTransformDown(overwrite); @@ -157,13 +158,13 @@ template void MWTree::mwTransform(int type, bool overwrite) { * * @details It starts at the bottom of the tree (scaling coefficients * of the leaf nodes) and it generates the scaling and wavelet - * coefficients if the parent node. It then proceeds recursively all the + * coefficients of the parent node. It then proceeds recursively all the * way up to the root nodes. This is generally used after a function * projection to purify the coefficients obtained by quadrature at * coarser scales which are therefore not precise enough. */ -template void MWTree::mwTransformUp() { - std::vector> nodeTable; +template void MWTree::mwTransformUp() { + std::vector> nodeTable; tree_utils::make_node_table(*this, nodeTable); #pragma omp parallel shared(nodeTable) num_threads(mrcpp_get_num_threads()) { @@ -172,7 +173,7 @@ template void MWTree::mwTransformUp() { int nNodes = nodeTable[n].size(); #pragma omp for schedule(guided) for (int i = 0; i < nNodes; i++) { - MWNode &node = *nodeTable[n][i]; + MWNode &node = *nodeTable[n][i]; if (node.isBranchNode()) { node.reCompress(); } } } @@ -190,8 +191,8 @@ template void MWTree::mwTransformUp() { * operation is generally used after the operator application. 
* */ -template void MWTree::mwTransformDown(bool overwrite) { - std::vector> nodeTable; +template void MWTree::mwTransformDown(bool overwrite) { + std::vector> nodeTable; tree_utils::make_node_table(*this, nodeTable); #pragma omp parallel shared(nodeTable) num_threads(mrcpp_get_num_threads()) { @@ -199,7 +200,7 @@ template void MWTree::mwTransformDown(bool overwrite) { int n_nodes = nodeTable[n].size(); #pragma omp for schedule(guided) for (int i = 0; i < n_nodes; i++) { - MWNode &node = *nodeTable[n][i]; + MWNode &node = *nodeTable[n][i]; if (node.isBranchNode()) { if (this->getRootScale() > node.getScale()) { int reverse = n_nodes - 1; @@ -215,15 +216,15 @@ template void MWTree::mwTransformDown(bool overwrite) { } /** @brief Set the MW coefficients to zero, keeping the same tree structure - * + * * @details Keeps the node structure of the tree, even though the zero * function is representable at depth zero. One should then use \ref cropTree to remove * unnecessary nodes. */ -template void MWTree::setZero() { - TreeIterator it(*this); +template void MWTree::setZero() { + TreeIterator it(*this); while (it.next()) { - MWNode &node = it.getNode(); + MWNode &node = it.getNode(); node.zeroCoefs(); } this->squareNorm = 0.0; @@ -236,7 +237,7 @@ template void MWTree::setZero() { * safe, and must NEVER be called outside a critical region in parallel. * It's way. way too expensive to lock the tree, so don't even think * about it. */ -template void MWTree::incrementNodeCount(int scale) { +template void MWTree::incrementNodeCount(int scale) { int depth = scale - getRootScale(); if (depth < 0) { int n = this->nodesAtNegativeDepth.size(); @@ -261,7 +262,7 @@ template void MWTree::incrementNodeCount(int scale) { * It's way. way too expensive to lock the tree, so don't even think * about it. 
*/ -template void MWTree::decrementNodeCount(int scale) { +template void MWTree::decrementNodeCount(int scale) { int depth = scale - getRootScale(); if (depth < 0) { assert(-depth - 1 < this->nodesAtNegativeDepth.size()); @@ -280,7 +281,7 @@ template void MWTree::decrementNodeCount(int scale) { * * @param[in] depth: Tree depth (0 depth is the coarsest scale) to count. */ -template int MWTree::getNNodesAtDepth(int depth) const { +template int MWTree::getNNodesAtDepth(int depth) const { int N = 0; if (depth < 0) { if (this->nodesAtNegativeDepth.size() >= -depth) N = this->nodesAtNegativeDepth[-depth]; @@ -291,9 +292,9 @@ template int MWTree::getNNodesAtDepth(int depth) const { } /** @returns Size of all MW coefs in the tree, in kB */ -template int MWTree::getSizeNodes() const { +template int MWTree::getSizeNodes() const { auto nCoefs = 1ll * getNNodes() * getTDim() * getKp1_d(); - return sizeof(double) * nCoefs / 1024; + return sizeof(T) * nCoefs / 1024; } /** @brief Finds and returns the node pointer with the given \ref NodeIndex, const version. @@ -303,11 +304,11 @@ template int MWTree::getSizeNodes() const { * pointer if the node does not exist, or if it is a * GenNode. Recursion starts at the appropriate rootNode. */ -template const MWNode *MWTree::findNode(NodeIndex idx) const { +template const MWNode *MWTree::findNode(NodeIndex idx) const { if (getRootBox().isPeriodic()) { periodic::index_manipulation(idx, getRootBox().getPeriodic()); } int rIdx = getRootBox().getBoxIndex(idx); if (rIdx < 0) return nullptr; - const MWNode &root = this->rootBox.getNode(rIdx); + const MWNode &root = this->rootBox.getNode(rIdx); assert(root.isAncestor(idx)); return root.retrieveNodeNoGen(idx); } @@ -319,11 +320,11 @@ template const MWNode *MWTree::findNode(NodeIndex idx) const { * pointer if the node does not exist, or if it is a * GenNode. Recursion starts at the appropriate rootNode. 
*/ -template MWNode *MWTree::findNode(NodeIndex idx) { +template MWNode *MWTree::findNode(NodeIndex idx) { if (getRootBox().isPeriodic()) { periodic::index_manipulation(idx, getRootBox().getPeriodic()); } int rIdx = getRootBox().getBoxIndex(idx); if (rIdx < 0) return nullptr; - MWNode &root = this->rootBox.getNode(rIdx); + MWNode &root = this->rootBox.getNode(rIdx); assert(root.isAncestor(idx)); return root.retrieveNodeNoGen(idx); } @@ -334,17 +335,18 @@ template MWNode *MWTree::findNode(NodeIndex idx) { * node does not exist, it will be generated by MW * transform. Recursion starts at the appropriate rootNode and descends * from this. + * The nodes are permanently added to the tree if create = true */ -template MWNode &MWTree::getNode(NodeIndex idx) { +template MWNode &MWTree::getNode(NodeIndex idx, bool create) { if (getRootBox().isPeriodic()) periodic::index_manipulation(idx, getRootBox().getPeriodic()); - MWNode *out = nullptr; - MWNode &root = getRootBox().getNode(idx); + MWNode *out = nullptr; + MWNode &root = getRootBox().getNode(idx); if (idx.getScale() < getRootScale()) { #pragma omp critical(gen_parent) out = root.retrieveParent(idx); } else { - out = root.retrieveNode(idx); + out = root.retrieveNode(idx, create); } return *out; } @@ -357,9 +359,9 @@ template MWNode &MWTree::getNode(NodeIndex idx) { * GenNodes. Recursion starts at the appropriate rootNode and decends * from this. */ -template MWNode &MWTree::getNodeOrEndNode(NodeIndex idx) { +template MWNode &MWTree::getNodeOrEndNode(NodeIndex idx) { if (getRootBox().isPeriodic()) { periodic::index_manipulation(idx, getRootBox().getPeriodic()); } - MWNode &root = getRootBox().getNode(idx); + MWNode &root = getRootBox().getNode(idx); assert(root.isAncestor(idx)); return *root.retrieveNodeOrEndNode(idx); } @@ -371,9 +373,9 @@ template MWNode &MWTree::getNodeOrEndNode(NodeIndex idx) { * transform. Recursion starts at the appropriate rootNode and decends * from this. 
*/ -template const MWNode &MWTree::getNodeOrEndNode(NodeIndex idx) const { +template const MWNode &MWTree::getNodeOrEndNode(NodeIndex idx) const { if (getRootBox().isPeriodic()) { periodic::index_manipulation(idx, getRootBox().getPeriodic()); } - const MWNode &root = getRootBox().getNode(idx); + const MWNode &root = getRootBox().getNode(idx); assert(root.isAncestor(idx)); return *root.retrieveNodeOrEndNode(idx); } @@ -387,8 +389,8 @@ template const MWNode &MWTree::getNodeOrEndNode(NodeIndex idx) * generate nodes that do not exist. Recursion starts at the * appropriate rootNode and decends from this. */ -template MWNode &MWTree::getNode(Coord r, int depth) { - MWNode &root = getRootBox().getNode(r); +template MWNode &MWTree::getNode(Coord r, int depth) { + MWNode &root = getRootBox().getNode(r); if (depth >= 0) { return *root.retrieveNode(r, depth); } else { @@ -405,11 +407,11 @@ template MWNode &MWTree::getNode(Coord r, int depth) { * the path to the requested node, and will never create or return GenNodes. * Recursion starts at the appropriate rootNode and decends from this. */ -template MWNode &MWTree::getNodeOrEndNode(Coord r, int depth) { +template MWNode &MWTree::getNodeOrEndNode(Coord r, int depth) { if (getRootBox().isPeriodic()) { periodic::coord_manipulation(r, getRootBox().getPeriodic()); } - MWNode &root = getRootBox().getNode(r); + MWNode &root = getRootBox().getNode(r); return *root.retrieveNodeOrEndNode(r, depth); } @@ -422,22 +424,22 @@ template MWNode &MWTree::getNodeOrEndNode(Coord r, int depth) { * the path to the requested node, and will never create or return GenNodes. * Recursion starts at the appropriate rootNode and decends from this. 
*/ -template const MWNode &MWTree::getNodeOrEndNode(Coord r, int depth) const { +template const MWNode &MWTree::getNodeOrEndNode(Coord r, int depth) const { if (getRootBox().isPeriodic()) { periodic::coord_manipulation(r, getRootBox().getPeriodic()); } - const MWNode &root = getRootBox().getNode(r); + const MWNode &root = getRootBox().getNode(r); return *root.retrieveNodeOrEndNode(r, depth); } /** @brief Returns the list of all EndNodes * * @details copies the list of all EndNode pointers into a new vector - * and retunrs it. + * and returns it. */ -template MWNodeVector *MWTree::copyEndNodeTable() { - auto *nVec = new MWNodeVector; +template MWNodeVector *MWTree::copyEndNodeTable() { + auto *nVec = new MWNodeVector; for (int n = 0; n < getNEndNodes(); n++) { - MWNode &node = getEndMWNode(n); + MWNode &node = getEndMWNode(n); nVec->push_back(&node); } return nVec; @@ -447,29 +449,28 @@ template MWNodeVector *MWTree::copyEndNodeTable() { * * @details the endNodeTable is first deleted and then rebuilt from * scratch. It makes use of the TreeIterator to traverse the tree. 
- * + * */ -template void MWTree::resetEndNodeTable() { +template void MWTree::resetEndNodeTable() { clearEndNodeTable(); - TreeIterator it(*this, TopDown, Hilbert); + TreeIterator it(*this, TopDown, Hilbert); it.setReturnGenNodes(false); while (it.next()) { - MWNode &node = it.getNode(); + MWNode &node = it.getNode(); if (node.isEndNode()) { this->endNodeTable.push_back(&node); } } } - -template int MWTree::countBranchNodes(int depth) { +template int MWTree::countBranchNodes(int depth) { NOT_IMPLEMENTED_ABORT; } -template int MWTree::countLeafNodes(int depth) { +template int MWTree::countLeafNodes(int depth) { NOT_IMPLEMENTED_ABORT; // int nNodes = 0; - // TreeIterator it(*this); + // TreeIterator it(*this); // while (it.next()) { - // MWNode &node = it.getNode(); + // MWNode &node = it.getNode(); // if (node.getDepth() == depth or depth < 0) { // if (node.isLeafNode()) { // nNodes++; @@ -480,12 +481,12 @@ template int MWTree::countLeafNodes(int depth) { } /* Traverse tree and count nodes belonging to this rank. */ -template int MWTree::countNodes(int depth) { +template int MWTree::countNodes(int depth) { NOT_IMPLEMENTED_ABORT; - // TreeIterator it(*this); + // TreeIterator it(*this); // int count = 0; // while (it.next()) { - // MWNode &node = it.getNode(); + // MWNode &node = it.getNode(); // if (node.isGenNode()) { // continue; // } @@ -497,12 +498,12 @@ template int MWTree::countNodes(int depth) { } /* Traverse tree and count nodes with allocated coefficients. 
*/ -template int MWTree::countAllocNodes(int depth) { +template int MWTree::countAllocNodes(int depth) { NOT_IMPLEMENTED_ABORT; - // TreeIterator it(*this); + // TreeIterator it(*this); // int count = 0; // while (it.next()) { - // MWNode &node = it.getNode(); + // MWNode &node = it.getNode(); // if (node.isGenNode()) { // continue; // } @@ -515,7 +516,7 @@ template int MWTree::countAllocNodes(int depth) { /** @brief Prints a summary of the tree structure on the output file */ -template std::ostream &MWTree::print(std::ostream &o) const { +template std::ostream &MWTree::print(std::ostream &o) const { o << " square norm: " << this->squareNorm << std::endl; o << " root scale: " << this->getRootScale() << std::endl; o << " order: " << this->order << std::endl; @@ -532,9 +533,9 @@ template std::ostream &MWTree::print(std::ostream &o) const { * @details it defines the upper bound of the squared norm \f$ * ||f||^2_{\ldots} \f$ in this node or its descendents */ -template void MWTree::makeMaxSquareNorms() { - NodeBox &rBox = this->getRootBox(); - MWNode **roots = rBox.getNodes(); +template void MWTree::makeMaxSquareNorms() { + NodeBox &rBox = this->getRootBox(); + MWNode **roots = rBox.getNodes(); for (int rIdx = 0; rIdx < rBox.size(); rIdx++) { // recursively set value of children and descendants roots[rIdx]->setMaxSquareNorm(); @@ -543,15 +544,18 @@ template void MWTree::makeMaxSquareNorms() { /** @brief gives serialIx of a node from its NodeIndex * - * @details Peter will document this! 
+ * @details gives a unique integer for each node corresponding to the position + * of the node in the serialized representation */ -template int MWTree::getIx(NodeIndex nIdx) { +template int MWTree::getIx(NodeIndex nIdx) { if (this->isLocal == false) MSG_ERROR("getIx only implemented in local representation"); - if(NodeIndex2serialIx.count(nIdx) == 0) return -1; - else return NodeIndex2serialIx[nIdx]; + if (NodeIndex2serialIx.count(nIdx) == 0) + return -1; + else + return NodeIndex2serialIx[nIdx]; } -template void MWTree::getNodeCoeff(NodeIndex nIdx, double *data) { +template void MWTree::getNodeCoeff(NodeIndex nIdx, T *data) { assert(this->isLocal); int size = (1 << D) * kp1_d; int id = 0; @@ -559,8 +563,12 @@ template void MWTree::getNodeCoeff(NodeIndex nIdx, double *data) { this->NodesCoeff->get_data(id, size, data); } -template class MWTree<1>; -template class MWTree<2>; -template class MWTree<3>; +template class MWTree<1, double>; +template class MWTree<2, double>; +template class MWTree<3, double>; + +template class MWTree<1, ComplexDouble>; +template class MWTree<2, ComplexDouble>; +template class MWTree<3, ComplexDouble>; } // namespace mrcpp diff --git a/src/trees/MWTree.h b/src/trees/MWTree.h index 51cfe3eed..b0261aca6 100644 --- a/src/trees/MWTree.h +++ b/src/trees/MWTree.h @@ -26,8 +26,8 @@ #pragma once #include -#include #include +#include #include "MRCPP/mrcpp_declarations.h" #include "utils/omp_utils.h" @@ -61,11 +61,11 @@ class BankAccount; * present. See specific methods for details.
* */ -template class MWTree { +template class MWTree { public: MWTree(const MultiResolutionAnalysis &mra, const std::string &n); - MWTree(const MWTree &tree) = delete; - MWTree &operator=(const MWTree &tree) = delete; + MWTree(const MWTree &tree) = delete; + MWTree &operator=(const MWTree &tree) = delete; virtual ~MWTree(); void setZero(); @@ -73,7 +73,7 @@ template class MWTree { /** @returns Squared L2 norm of the function */ double getSquareNorm() const { return this->squareNorm; } - void calcSquareNorm(); + void calcSquareNorm(bool deep = false); void clearSquareNorm() { this->squareNorm = -1.0; } int getOrder() const { return this->order; } @@ -90,8 +90,8 @@ template class MWTree { int getSizeNodes() const; /** @returns */ - NodeBox &getRootBox() { return this->rootBox; } - const NodeBox &getRootBox() const { return this->rootBox; } + NodeBox &getRootBox() { return this->rootBox; } + const NodeBox &getRootBox() const { return this->rootBox; } const MultiResolutionAnalysis &getMRA() const { return this->MRA; } void mwTransform(int type, bool overwrite = true); @@ -102,28 +102,28 @@ template class MWTree { int getRootIndex(Coord r) const { return this->rootBox.getBoxIndex(r); } int getRootIndex(NodeIndex nIdx) const { return this->rootBox.getBoxIndex(nIdx); } - MWNode *findNode(NodeIndex nIdx); - const MWNode *findNode(NodeIndex nIdx) const; + MWNode *findNode(NodeIndex nIdx); + const MWNode *findNode(NodeIndex nIdx) const; - MWNode &getNode(NodeIndex nIdx); - MWNode &getNodeOrEndNode(NodeIndex nIdx); - const MWNode &getNodeOrEndNode(NodeIndex nIdx) const; + MWNode &getNode(NodeIndex nIdx, bool create = false); + MWNode &getNodeOrEndNode(NodeIndex nIdx); + const MWNode &getNodeOrEndNode(NodeIndex nIdx) const; - MWNode &getNode(Coord r, int depth = -1); - MWNode &getNodeOrEndNode(Coord r, int depth = -1); - const MWNode &getNodeOrEndNode(Coord r, int depth = -1) const; + MWNode &getNode(Coord r, int depth = -1); + MWNode &getNodeOrEndNode(Coord r, int depth = 
-1); + const MWNode &getNodeOrEndNode(Coord r, int depth = -1) const; int getNEndNodes() const { return this->endNodeTable.size(); } int getNRootNodes() const { return this->rootBox.size(); } - MWNode &getEndMWNode(int i) { return *this->endNodeTable[i]; } - MWNode &getRootMWNode(int i) { return this->rootBox.getNode(i); } - const MWNode &getEndMWNode(int i) const { return *this->endNodeTable[i]; } - const MWNode &getRootMWNode(int i) const { return this->rootBox.getNode(i); } + MWNode &getEndMWNode(int i) { return *this->endNodeTable[i]; } + MWNode &getRootMWNode(int i) { return this->rootBox.getNode(i); } + const MWNode &getEndMWNode(int i) const { return *this->endNodeTable[i]; } + const MWNode &getRootMWNode(int i) const { return this->rootBox.getNode(i); } bool isPeriodic() const { return this->MRA.getWorldBox().isPeriodic(); } - MWNodeVector *copyEndNodeTable(); - MWNodeVector *getEndNodeTable() { return &this->endNodeTable; } + MWNodeVector *copyEndNodeTable(); + MWNodeVector *getEndNodeTable() { return &this->endNodeTable; } void deleteRootNodes(); void resetEndNodeTable(); @@ -133,24 +133,26 @@ template class MWTree { int countLeafNodes(int depth = -1); int countAllocNodes(int depth = -1); int countNodes(int depth = -1); - bool isLocal = false; // to know whether the tree coeffcients are stored in the Bank + bool isLocal = false; // to know whether the tree coeffcients are stored in the Bank int getIx(NodeIndex nIdx); // gives serialIx of a stored node from its NodeIndex if isLocal void makeMaxSquareNorms(); // sets values for maxSquareNorm and maxWSquareNorm in all nodes - NodeAllocator &getNodeAllocator() { return *this->nodeAllocator_p; } - const NodeAllocator &getNodeAllocator() const { return *this->nodeAllocator_p; } - MWNodeVector endNodeTable; ///< Final projected nodes + NodeAllocator &getNodeAllocator() { return *this->nodeAllocator_p; } + const NodeAllocator &getNodeAllocator() const { return *this->nodeAllocator_p; } + MWNodeVector 
endNodeTable; ///< Final projected nodes - void getNodeCoeff(NodeIndex nIdx, double *data); // fetch coefficient from a specific node stored in Bank + void getNodeCoeff(NodeIndex nIdx, T *data); // fetch coefficient from a specific node stored in Bank + bool conjugate() const { return this->conj; } + void setConjugate(bool conjug) { this->conj = conjug; } - friend std::ostream &operator<<(std::ostream &o, const MWTree &tree) { return tree.print(o); } + friend std::ostream &operator<<(std::ostream &o, const MWTree &tree) { return tree.print(o); } - friend class MWNode; - friend class FunctionNode; + friend class MWNode; + friend class FunctionNode; friend class OperatorNode; - friend class TreeBuilder; - friend class NodeAllocator; + friend class TreeBuilder; + friend class NodeAllocator; protected: // Parameters that are set in construction and should never change @@ -165,11 +167,11 @@ template class MWTree { // Parameters that are dynamic and can be set by user std::string name; - std::unique_ptr> nodeAllocator_p{nullptr}; + std::unique_ptr> nodeAllocator_p{nullptr}; // Tree data double squareNorm; - NodeBox rootBox; ///< The actual container of nodes + NodeBox rootBox; ///< The actual container of nodes std::vector nodesAtDepth; ///< Node counter std::vector nodesAtNegativeDepth; ///< Node counter @@ -180,8 +182,8 @@ template class MWTree { void decrementNodeCount(int scale); BankAccount *NodesCoeff = nullptr; + bool conj{false}; virtual std::ostream &print(std::ostream &o) const; }; - } // namespace mrcpp diff --git a/src/trees/MultiResolutionAnalysis.cpp b/src/trees/MultiResolutionAnalysis.cpp index 6eaabb120..43b39c32d 100644 --- a/src/trees/MultiResolutionAnalysis.cpp +++ b/src/trees/MultiResolutionAnalysis.cpp @@ -106,7 +106,8 @@ MultiResolutionAnalysis::MultiResolutionAnalysis(const MultiResolutionAnalysi * @param[in] sb: Polynomial basis (MW) as a ScalingBasis object * @param[in] depth: Maximum allowed resolution depth, relative to root scale * - * 
@details Creates a MRA object from pre-existing BoundingBox and ScalingBasis objects. These objects are taken as reference. For more details about the constructor itself, see the first constructor. + * @details Creates a MRA object from pre-existing BoundingBox and ScalingBasis objects. These objects are taken as reference. For more details about the constructor itself, see the first + * constructor. */ template MultiResolutionAnalysis::MultiResolutionAnalysis(const BoundingBox &bb, const ScalingBasis &sb, int depth) @@ -124,9 +125,9 @@ MultiResolutionAnalysis::MultiResolutionAnalysis(const BoundingBox &bb, co * * @param[in] mra: MRA object, taken by constant reference * - * @details Equality operator for the MultiResolutionAnalysis class, returns true if both MRAs have the same polynomial basis represented by a BoundingBox object, computational domain (ScalingBasis object) and maximum depth (integer), and false otherwise. - * Computations on different MRA cannot be combined, this operator can be used to make sure that the multiple MRAs are compatible. - * For more information about the meaning of equality for BoundingBox and ScalingBasis objets, see their respective classes. + * @details Equality operator for the MultiResolutionAnalysis class, returns true if both MRAs have the same polynomial basis represented by a BoundingBox object, computational domain (ScalingBasis + * object) and maximum depth (integer), and false otherwise. Computations on different MRA cannot be combined, this operator can be used to make sure that the multiple MRAs are compatible. For more + * information about the meaning of equality for BoundingBox and ScalingBasis objets, see their respective classes. 
*/ template bool MultiResolutionAnalysis::operator==(const MultiResolutionAnalysis &mra) const { if (this->basis != mra.basis) return false; @@ -141,12 +142,19 @@ template bool MultiResolutionAnalysis::operator==(const MultiResoluti * * @param[in] mra: MRA object, taken by constant reference * - * @details Inequality operator for the MultiResolutionAnalysis class, returns true if both MRAs have the same polynomial basis represented by a BoundingBox object, computational domain (ScalingBasis object) and maximum depth (integer), and false otherwise. - * Opposite of the == operator. - * For more information about the meaning of equality for BoundingBox and ScalingBasis objets, see their respective classes. + * @details Inequality operator for the MultiResolutionAnalysis class, returns true if both MRAs have the same polynomial basis represented by a BoundingBox object, computational domain (ScalingBasis + * object) and maximum depth (integer), and false otherwise. Opposite of the == operator. For more information about the meaning of equality for BoundingBox and ScalingBasis objets, see their + * respective classes. 
*/ template bool MultiResolutionAnalysis::operator!=(const MultiResolutionAnalysis &mra) const { - return !(*this == mra); + if (this->basis != mra.basis) return true; + if (this->world != mra.world) + std::cout << "diff world " << this->world << std::endl + << "and " + << " " << mra.world << std::endl; + if (this->world != mra.world) return true; + if (this->maxDepth != mra.maxDepth) return true; + return false; } /** diff --git a/src/trees/NodeAllocator.cpp b/src/trees/NodeAllocator.cpp index 9ca79f0b4..5079459aa 100644 --- a/src/trees/NodeAllocator.cpp +++ b/src/trees/NodeAllocator.cpp @@ -27,18 +27,19 @@ #include -#include "MWNode.h" -#include "FunctionTree.h" #include "FunctionNode.h" -#include "OperatorTree.h" +#include "FunctionTree.h" +#include "MWNode.h" #include "OperatorNode.h" +#include "OperatorTree.h" #include "utils/Printer.h" #include "utils/mpi_utils.h" namespace mrcpp { -template NodeAllocator::NodeAllocator(FunctionTree *tree, SharedMemory *mem, int coefsPerNode, int nodesPerChunk) +template +NodeAllocator::NodeAllocator(FunctionTree *tree, SharedMemory *mem, int coefsPerNode, int nodesPerChunk) : coefsPerNode(coefsPerNode) , maxNodesPerChunk(nodesPerChunk) , tree_p(tree) @@ -47,14 +48,15 @@ template NodeAllocator::NodeAllocator(FunctionTree *tree, SharedMe this->nodeChunks.reserve(100); this->coefChunks.reserve(100); - FunctionNode tmp; + FunctionNode tmp; this->cvptr = *(char **)(&tmp); - this->sizeOfNode = sizeof(FunctionNode); + this->sizeOfNode = sizeof(FunctionNode); MRCPP_INIT_OMP_LOCK(); } -template <> NodeAllocator<2>::NodeAllocator(OperatorTree *tree, SharedMemory *mem, int coefsPerNode, int nodesPerChunk) +template <> +NodeAllocator<2>::NodeAllocator(OperatorTree *tree, SharedMemory *mem, int coefsPerNode, int nodesPerChunk) : coefsPerNode(coefsPerNode) , maxNodesPerChunk(nodesPerChunk) , tree_p(tree) @@ -70,11 +72,11 @@ template <> NodeAllocator<2>::NodeAllocator(OperatorTree *tree, SharedMemory *me MRCPP_INIT_OMP_LOCK(); } -template 
NodeAllocator::NodeAllocator(OperatorTree *tree, SharedMemory *mem, int coefsPerNode, int nodesPerChunk) { +template NodeAllocator::NodeAllocator(OperatorTree *tree, SharedMemory *mem, int coefsPerNode, int nodesPerChunk) { NOT_REACHED_ABORT; } -template NodeAllocator::~NodeAllocator() { +template NodeAllocator::~NodeAllocator() { for (auto &chunk : this->nodeChunks) delete[](char *) chunk; if (not isShared()) // if the data is shared, it must be freed by MPI_Win_free for (auto &chunk : this->coefChunks) delete[] chunk; @@ -82,38 +84,37 @@ template NodeAllocator::~NodeAllocator() { MRCPP_DESTROY_OMP_LOCK(); } -template MWNode * NodeAllocator::getNode_p(int sIdx) { +template MWNode *NodeAllocator::getNode_p(int sIdx) { MRCPP_SET_OMP_LOCK(); auto *node = getNodeNoLock(sIdx); MRCPP_UNSET_OMP_LOCK(); return node; } -template double * NodeAllocator::getCoef_p(int sIdx) { +template T *NodeAllocator::getCoef_p(int sIdx) { MRCPP_SET_OMP_LOCK(); auto *coefs = getCoefNoLock(sIdx); MRCPP_UNSET_OMP_LOCK(); return coefs; } -template MWNode * NodeAllocator::getNodeNoLock(int sIdx) { +template MWNode *NodeAllocator::getNodeNoLock(int sIdx) { if (sIdx < 0 or sIdx >= this->stackStatus.size()) return nullptr; int chunk = sIdx / this->maxNodesPerChunk; // which chunk int cIdx = sIdx % this->maxNodesPerChunk; // position in chunk return this->nodeChunks[chunk] + cIdx; } -template double * NodeAllocator::getCoefNoLock(int sIdx) { +template T *NodeAllocator::getCoefNoLock(int sIdx) { if (sIdx < 0 or sIdx >= this->stackStatus.size()) return nullptr; int chunk = sIdx / this->maxNodesPerChunk; // which chunk int idx = sIdx % this->maxNodesPerChunk; // position in chunk return this->coefChunks[chunk] + idx * this->coefsPerNode; } -template int NodeAllocator::alloc(int nNodes, bool coefs) { +template int NodeAllocator::alloc(int nNodes, bool coefs) { MRCPP_SET_OMP_LOCK(); if (nNodes <= 0 or nNodes > this->maxNodesPerChunk) MSG_ABORT("Cannot allocate " << nNodes << " nodes"); - // move 
topstack to start of next chunk if current chunk is too small int cIdx = this->topStack % (this->maxNodesPerChunk); bool chunkOverflow = ((cIdx + nNodes) > this->maxNodesPerChunk); @@ -127,6 +128,10 @@ template int NodeAllocator::alloc(int nNodes, bool coefs) { // return value is index of first new node auto sIdx = this->topStack; + // we require that the index for first child is a multiple of 2**D + // so that we can find the sibling rank using rank=sIdx%(2**D) + if (sIdx % nNodes != 0) MSG_ERROR(" node allocate error"); + // fill stack status auto &status = this->stackStatus; for (int i = sIdx; i < sIdx + nNodes; i++) { @@ -143,13 +148,13 @@ template int NodeAllocator::alloc(int nNodes, bool coefs) { return sIdx; } -template void NodeAllocator::dealloc(int sIdx) { +template void NodeAllocator::dealloc(int sIdx) { MRCPP_SET_OMP_LOCK(); if (sIdx < 0 or sIdx >= this->stackStatus.size()) MSG_ABORT("Invalid serial index: " << sIdx); auto *node_p = getNodeNoLock(sIdx); node_p->~MWNode(); - this->stackStatus[sIdx] = 0; // mark as available - if (sIdx == this->topStack - 1) { // top of stack + this->stackStatus[sIdx] = 0; // mark as available + if (sIdx == this->topStack - 1) { // top of stack while (this->stackStatus[this->topStack - 1] == 0) { this->topStack--; if (this->topStack < 1) break; @@ -161,16 +166,16 @@ template void NodeAllocator::dealloc(int sIdx) { MRCPP_UNSET_OMP_LOCK(); } -template void NodeAllocator::deallocAllCoeff() { +template void NodeAllocator::deallocAllCoeff() { if (not this->isShared()) for (auto &chunk : this->coefChunks) delete[] chunk; - else delete this->shmem_p; + else + delete this->shmem_p; this->shmem_p = nullptr; this->coefChunks.clear(); - } -template void NodeAllocator::init(int nChunks, bool coefs) { +template void NodeAllocator::init(int nChunks, bool coefs) { MRCPP_SET_OMP_LOCK(); if (nChunks <= 0) MSG_ABORT("Invalid number of chunks: " << nChunks); for (int i = getNChunks(); i < nChunks; i++) appendChunk(coefs); @@ -182,10 +187,10 
@@ template void NodeAllocator::init(int nChunks, bool coefs) { MRCPP_UNSET_OMP_LOCK(); } -template void NodeAllocator::appendChunk(bool coefs) { +template void NodeAllocator::appendChunk(bool coefs) { // make coeff chunk if (coefs) { - double *c_chunk = nullptr; + T *c_chunk = nullptr; if (this->isShared()) { // for coefficients, take from the shared memory block c_chunk = this->shmem_p->sh_end_ptr; @@ -193,13 +198,13 @@ template void NodeAllocator::appendChunk(bool coefs) { // may increase size dynamically in the future if (this->shmem_p->sh_max_ptr < this->shmem_p->sh_end_ptr) MSG_ABORT("Shared block too small"); } else { - c_chunk = new double[getCoefChunkSize() / sizeof(double)]; + c_chunk = new T[getCoefChunkSize() / sizeof(T)]; } this->coefChunks.push_back(c_chunk); } // make node chunk - auto n_chunk = (MWNode *)new char[getNodeChunkSize()]; + auto n_chunk = (MWNode *)new char[getNodeChunkSize()]; for (int i = 0; i < this->maxNodesPerChunk; i++) { n_chunk[i].serialIx = -1; n_chunk[i].parentSerialIx = -1; @@ -215,11 +220,10 @@ template void NodeAllocator::appendChunk(bool coefs) { } /** Fill all holes in the chunks with occupied nodes, then remove all empty chunks */ -template int NodeAllocator::compress() { +template int NodeAllocator::compress() { MRCPP_SET_OMP_LOCK(); int nNodes = (1 << D); - if (this->maxNodesPerChunk * this->nodeChunks.size() <= - getTree().getNNodes() + this->maxNodesPerChunk + nNodes - 1) { + if (this->maxNodesPerChunk * this->nodeChunks.size() <= getTree().getNNodes() + this->maxNodesPerChunk + nNodes - 1) { MRCPP_UNSET_OMP_LOCK(); return 0; // nothing to compress } @@ -249,17 +253,17 @@ template int NodeAllocator::compress() { return nChunksDeleted; } -template int NodeAllocator::deleteUnusedChunks() { +template int NodeAllocator::deleteUnusedChunks() { // number of occupied chunks int nChunksTotal = getNChunks(); int nChunksUsed = getNChunksUsed(); - if(nChunksTotal == nChunksUsed) return 0; // no unused chunks + if (nChunksTotal 
== nChunksUsed) return 0; // no unused chunks assert(nChunksTotal >= nChunksUsed); for (int i = nChunksUsed; i < nChunksTotal; i++) delete[](char *)(this->nodeChunks[i]); if (isShared()) { // shared coefficients cannot be fully deallocated, only pointer is moved. - getMemory().sh_end_ptr -= (nChunksTotal - nChunksUsed) * this->coefsPerNode * this->maxNodesPerChunk; + getMemory().sh_end_ptr -= (nChunksTotal - nChunksUsed) * this->coefsPerNode * this->maxNodesPerChunk; } else { for (int i = nChunksUsed; i < nChunksTotal; i++) delete[] this->coefChunks[i]; } @@ -271,7 +275,7 @@ template int NodeAllocator::deleteUnusedChunks() { return nChunksTotal - nChunksUsed; } -template void NodeAllocator::moveNodes(int nNodes, int srcIdx, int dstIdx) { +template void NodeAllocator::moveNodes(int nNodes, int srcIdx, int dstIdx) { assert(nNodes > 0); assert(nNodes <= this->maxNodesPerChunk); @@ -288,7 +292,7 @@ template void NodeAllocator::moveNodes(int nNodes, int srcIdx, int ds for (int i = 0; i < nNodes * this->sizeOfNode; i++) ((char *)dstNode)[i] = ((char *)srcNode)[i]; // coefs have new adresses - double *coefs_p = getCoefNoLock(dstIdx); + T *coefs_p = getCoefNoLock(dstIdx); if (coefs_p == nullptr) NOT_IMPLEMENTED_ABORT; // Nodes without coefs not handled atm for (int i = 0; i < nNodes; i++) (dstNode + i)->coefs = coefs_p + i * getNCoefs(); @@ -325,7 +329,7 @@ template void NodeAllocator::moveNodes(int nNodes, int srcIdx, int ds } // Last positions on a chunk cannot be used if there is no place for nNodes siblings on the same chunk -template int NodeAllocator::findNextAvailable(int sIdx, int nNodes) const { +template int NodeAllocator::findNextAvailable(int sIdx, int nNodes) const { assert(sIdx >= 0); assert(sIdx < this->stackStatus.size()); assert(nNodes >= 0); @@ -343,7 +347,7 @@ template int NodeAllocator::findNextAvailable(int sIdx, int nNodes) c return sIdx; } -template int NodeAllocator::findNextOccupied(int sIdx) const { +template int 
NodeAllocator::findNextOccupied(int sIdx) const { assert(sIdx >= 0); assert(sIdx < this->stackStatus.size()); bool endOfStack = (sIdx >= this->topStack); @@ -359,17 +363,17 @@ template int NodeAllocator::findNextOccupied(int sIdx) const { } /** Traverse tree and redefine pointer, counter and tables. */ -template void NodeAllocator::reassemble() { +template void NodeAllocator::reassemble() { MRCPP_SET_OMP_LOCK(); this->nNodes = 0; getTree().nodesAtDepth.clear(); getTree().squareNorm = 0.0; getTree().clearEndNodeTable(); - NodeBox &rootbox = getTree().getRootBox(); - MWNode **roots = rootbox.getNodes(); + NodeBox &rootbox = getTree().getRootBox(); + MWNode **roots = rootbox.getNodes(); - std::stack *> stack; + std::stack *> stack; for (int rIdx = 0; rIdx < rootbox.size(); rIdx++) { auto *root_p = getNodeNoLock(rIdx); assert(root_p != nullptr); @@ -414,7 +418,7 @@ template void NodeAllocator::reassemble() { MRCPP_UNSET_OMP_LOCK(); } -template void NodeAllocator::print() const { +template void NodeAllocator::print() const { int n = 0; for (int iChunk = 0; iChunk < getNChunks(); iChunk++) { int iShift = iChunk * this->maxNodesPerChunk; @@ -436,8 +440,12 @@ template void NodeAllocator::print() const { } } -template class NodeAllocator<1>; -template class NodeAllocator<2>; -template class NodeAllocator<3>; +template class NodeAllocator<1, double>; +template class NodeAllocator<2, double>; +template class NodeAllocator<3, double>; + +template class NodeAllocator<1, ComplexDouble>; +template class NodeAllocator<2, ComplexDouble>; +template class NodeAllocator<3, ComplexDouble>; } // namespace mrcpp diff --git a/src/trees/NodeAllocator.h b/src/trees/NodeAllocator.h index 38e4ba7eb..7e33b7e21 100644 --- a/src/trees/NodeAllocator.h +++ b/src/trees/NodeAllocator.h @@ -40,12 +40,12 @@ namespace mrcpp { -template class NodeAllocator final { +template class NodeAllocator final { public: - NodeAllocator(OperatorTree *tree, SharedMemory *mem, int coefsPerNode, int nodesPerChunk); - 
NodeAllocator(FunctionTree *tree, SharedMemory *mem, int coefsPerNode, int nodesPerChunk); - NodeAllocator(const NodeAllocator &tree) = delete; - NodeAllocator &operator=(const NodeAllocator &tree) = delete; + NodeAllocator(OperatorTree *tree, SharedMemory *mem, int coefsPerNode, int nodesPerChunk); + NodeAllocator(FunctionTree *tree, SharedMemory *mem, int coefsPerNode, int nodesPerChunk); + NodeAllocator(const NodeAllocator &tree) = delete; + NodeAllocator &operator=(const NodeAllocator &tree) = delete; ~NodeAllocator(); int alloc(int nNodes, bool coefs = true); @@ -63,38 +63,39 @@ template class NodeAllocator final { int getNChunks() const { return this->nodeChunks.size(); } int getNChunksUsed() const { return (this->topStack + this->maxNodesPerChunk - 1) / this->maxNodesPerChunk; } int getNodeChunkSize() const { return this->maxNodesPerChunk * this->sizeOfNode; } - int getCoefChunkSize() const { return this->maxNodesPerChunk * this->coefsPerNode * sizeof(double); } + int getCoefChunkSize() const { return this->maxNodesPerChunk * this->coefsPerNode * sizeof(T); } + int getMaxNodesPerChunk() const { return this->maxNodesPerChunk; } - double * getCoef_p(int sIdx); - MWNode * getNode_p(int sIdx); + T *getCoef_p(int sIdx); + MWNode *getNode_p(int sIdx); - double * getCoefChunk(int i) { return this->coefChunks[i]; } - MWNode * getNodeChunk(int i) { return this->nodeChunks[i]; } + T *getCoefChunk(int i) { return this->coefChunks[i]; } + MWNode *getNodeChunk(int i) { return this->nodeChunks[i]; } void print() const; protected: - int nNodes{0}; // number of nodes actually in use - int topStack{0}; // index of last node on stack - int sizeOfNode{0}; // sizeof(NodeType) - int coefsPerNode{0}; // number of coef for one node - int maxNodesPerChunk{0}; // max number of nodes per allocation + int nNodes{0}; // number of nodes actually in use + int topStack{0}; // index of last node on stack + int sizeOfNode{0}; // sizeof(NodeType) + int coefsPerNode{0}; // number of coef for 
one node + int maxNodesPerChunk{0}; // max number of nodes per allocation std::vector stackStatus{}; - std::vector coefChunks{}; - std::vector *> nodeChunks{}; + std::vector coefChunks{}; + std::vector *> nodeChunks{}; - char *cvptr{nullptr}; // pointer to virtual table - MWNode *last_p{nullptr}; // pointer just after the last active node, i.e. where to put next node - MWTree *tree_p{nullptr}; // pointer to external object - SharedMemory *shmem_p{nullptr}; // pointer to external object + char *cvptr{nullptr}; // pointer to virtual table + MWNode *last_p{nullptr}; // pointer just after the last active node, i.e. where to put next node + MWTree *tree_p{nullptr}; // pointer to external object + SharedMemory *shmem_p{nullptr}; // pointer to external object bool isShared() const { return (this->shmem_p != nullptr); } - MWTree &getTree() { return *this->tree_p; } - SharedMemory &getMemory() { return *this->shmem_p; } + MWTree &getTree() { return *this->tree_p; } + SharedMemory &getMemory() { return *this->shmem_p; } - double * getCoefNoLock(int sIdx); - MWNode * getNodeNoLock(int sIdx); + T *getCoefNoLock(int sIdx); + MWNode *getNodeNoLock(int sIdx); void moveNodes(int nNodes, int srcIdx, int dstIdx); void appendChunk(bool coefs); @@ -107,5 +108,4 @@ template class NodeAllocator final { #endif }; - } // namespace mrcpp diff --git a/src/trees/NodeBox.cpp b/src/trees/NodeBox.cpp index cc247d58e..bf747c4fc 100644 --- a/src/trees/NodeBox.cpp +++ b/src/trees/NodeBox.cpp @@ -36,50 +36,50 @@ namespace mrcpp { -template -NodeBox::NodeBox(const NodeIndex &idx, const std::array &nb) +template +NodeBox::NodeBox(const NodeIndex &idx, const std::array &nb) : BoundingBox(idx, nb) , nOccupied(0) , nodes(nullptr) { allocNodePointers(); } -template -NodeBox::NodeBox(const BoundingBox &box) +template +NodeBox::NodeBox(const BoundingBox &box) : BoundingBox(box) , nOccupied(0) , nodes(nullptr) { allocNodePointers(); } -template -NodeBox::NodeBox(const NodeBox &box) +template 
+NodeBox::NodeBox(const NodeBox &box) : BoundingBox(box) , nOccupied(0) , nodes(nullptr) { allocNodePointers(); } -template void NodeBox::allocNodePointers() { +template void NodeBox::allocNodePointers() { assert(this->nodes == nullptr); int nNodes = this->size(); - this->nodes = new MWNode *[nNodes]; + this->nodes = new MWNode *[nNodes]; for (int n = 0; n < nNodes; n++) { this->nodes[n] = nullptr; } this->nOccupied = 0; } -template NodeBox::~NodeBox() { +template NodeBox::~NodeBox() { deleteNodes(); } -template void NodeBox::deleteNodes() { +template void NodeBox::deleteNodes() { if (this->nodes == nullptr) { return; } for (int n = 0; n < this->size(); n++) { clearNode(n); } delete[] this->nodes; this->nodes = nullptr; } -template void NodeBox::setNode(int bIdx, MWNode **node) { +template void NodeBox::setNode(int bIdx, MWNode **node) { assert(bIdx >= 0); assert(bIdx < this->totBoxes); clearNode(bIdx); @@ -89,44 +89,48 @@ template void NodeBox::setNode(int bIdx, MWNode **node) { *node = nullptr; } -template MWNode &NodeBox::getNode(NodeIndex nIdx) { +template MWNode &NodeBox::getNode(NodeIndex nIdx) { int bIdx = this->getBoxIndex(nIdx); return getNode(bIdx); } -template MWNode &NodeBox::getNode(Coord r) { +template MWNode &NodeBox::getNode(Coord r) { int bIdx = this->getBoxIndex(r); if (bIdx < 0) MSG_ERROR("Coord out of bounds"); return getNode(bIdx); } -template MWNode &NodeBox::getNode(int bIdx) { +template MWNode &NodeBox::getNode(int bIdx) { assert(bIdx >= 0); assert(bIdx < this->totBoxes); assert(this->nodes[bIdx] != nullptr); return *this->nodes[bIdx]; } -template const MWNode &NodeBox::getNode(NodeIndex nIdx) const { +template const MWNode &NodeBox::getNode(NodeIndex nIdx) const { int bIdx = this->getBoxIndex(nIdx); return getNode(bIdx); } -template const MWNode &NodeBox::getNode(Coord r) const { +template const MWNode &NodeBox::getNode(Coord r) const { int bIdx = this->getBoxIndex(r); if (bIdx < 0) MSG_ERROR("Coord out of bounds"); return getNode(bIdx); } 
-template const MWNode &NodeBox::getNode(int bIdx) const { +template const MWNode &NodeBox::getNode(int bIdx) const { assert(bIdx >= 0); assert(bIdx < this->totBoxes); assert(this->nodes[bIdx] != nullptr); return *this->nodes[bIdx]; } -template class NodeBox<1>; -template class NodeBox<2>; -template class NodeBox<3>; +template class NodeBox<1, double>; +template class NodeBox<2, double>; +template class NodeBox<3, double>; + +template class NodeBox<1, ComplexDouble>; +template class NodeBox<2, ComplexDouble>; +template class NodeBox<3, ComplexDouble>; } // namespace mrcpp diff --git a/src/trees/NodeBox.h b/src/trees/NodeBox.h index dfb0dc20c..7a7fc086e 100644 --- a/src/trees/NodeBox.h +++ b/src/trees/NodeBox.h @@ -30,31 +30,31 @@ namespace mrcpp { -template class NodeBox final : public BoundingBox { +template class NodeBox final : public BoundingBox { public: NodeBox(const NodeIndex &idx, const std::array &nb = {}); - NodeBox(const NodeBox &box); + NodeBox(const NodeBox &box); NodeBox(const BoundingBox &box); - NodeBox &operator=(const NodeBox &box) = delete; + NodeBox &operator=(const NodeBox &box) = delete; ~NodeBox() override; - void setNode(int idx, MWNode **node); + void setNode(int idx, MWNode **node); void clearNode(int idx) { this->nodes[idx] = nullptr; } - MWNode &getNode(NodeIndex idx); - MWNode &getNode(Coord r); - MWNode &getNode(int i = 0); + MWNode &getNode(NodeIndex idx); + MWNode &getNode(Coord r); + MWNode &getNode(int i = 0); - const MWNode &getNode(NodeIndex idx) const; - const MWNode &getNode(Coord r) const; - const MWNode &getNode(int i = 0) const; + const MWNode &getNode(NodeIndex idx) const; + const MWNode &getNode(Coord r) const; + const MWNode &getNode(int i = 0) const; int getNOccupied() const { return this->nOccupied; } - MWNode **getNodes() { return this->nodes; } + MWNode **getNodes() { return this->nodes; } protected: - int nOccupied; ///< Number of non-zero pointers in box - MWNode **nodes; ///< Container of nodes + int nOccupied; 
///< Number of non-zero pointers in box + MWNode **nodes; ///< Container of nodes void allocNodePointers(); void deleteNodes(); diff --git a/src/trees/NodeIndex.h b/src/trees/NodeIndex.h index 866f3bdb2..f73ded001 100644 --- a/src/trees/NodeIndex.h +++ b/src/trees/NodeIndex.h @@ -31,8 +31,8 @@ #pragma once -#include #include +#include namespace mrcpp { @@ -92,7 +92,7 @@ template class NodeIndex final { } private: - short int N{0}; ///< Length scale index 2^N + short int N{0}; ///< Length scale index 2^N std::array L{}; ///< Translation index [x,y,z,...] }; diff --git a/src/trees/OperatorNode.cpp b/src/trees/OperatorNode.cpp index a0e09aac5..37f576eac 100644 --- a/src/trees/OperatorNode.cpp +++ b/src/trees/OperatorNode.cpp @@ -42,16 +42,16 @@ void OperatorNode::dealloc() { this->tree->getNodeAllocator().dealloc(sIdx); } -/** +/** * @brief Calculate one specific component norm of the OperatorNode (TODO: needs to be specified more). - * + * * @param[in] i: TODO: deens to be specified * * @details OperatorNorms are defined as matrix 2-norms that are expensive to calculate. * Thus we calculate some cheaper upper bounds for this norm for thresholding. * First a simple vector norm, then a product of the 1- and infinity-norm. * (TODO: needs to be more presiced). 
- * + * */ double OperatorNode::calcComponentNorm(int i) const { int depth = getDepth(); @@ -64,7 +64,7 @@ double OperatorNode::calcComponentNorm(int i) const { int kp1 = this->getKp1(); int kp1_d = this->getKp1_d(); const VectorXd &comp_vec = coef_vec.segment(i * kp1_d, kp1_d); - const MatrixXd comp_mat = MatrixXd::Map(comp_vec.data(), kp1, kp1); //one can use MatrixXd OperatorNode::getComponent(int i) + const MatrixXd comp_mat = MatrixXd::Map(comp_vec.data(), kp1, kp1); // one can use MatrixXd OperatorNode::getComponent(int i) double norm = 0.0; double vecNorm = comp_vec.norm(); @@ -79,7 +79,6 @@ double OperatorNode::calcComponentNorm(int i) const { return norm; } - /** @brief Matrix elements of the non-standard form. * * @param[in] i: Index enumerating the matrix type in the non-standard form. @@ -92,10 +91,9 @@ double OperatorNode::calcComponentNorm(int i) const { * One of these matrices is returned by the method according to the choice of the index parameter * \f$ i = 0, 1, 2, 3 \f$, respectively. * For example, \f$ \alpha_l^n = \text{getComponent}(3) \f$. - * + * */ -MatrixXd OperatorNode::getComponent(int i) -{ +MatrixXd OperatorNode::getComponent(int i) { int depth = getDepth(); double prec = getOperTree().getNormPrecision(); double thrs = std::max(MachinePrec, prec / (8.0 * (1 << depth))); diff --git a/src/trees/OperatorTree.cpp b/src/trees/OperatorTree.cpp index 44963d465..890f2677c 100644 --- a/src/trees/OperatorTree.cpp +++ b/src/trees/OperatorTree.cpp @@ -25,9 +25,9 @@ #include "OperatorTree.h" #include "BandWidth.h" -#include "TreeIterator.h" #include "NodeAllocator.h" #include "OperatorNode.h" +#include "TreeIterator.h" #include "utils/Printer.h" #include "utils/tree_utils.h" @@ -98,15 +98,14 @@ void OperatorTree::clearBandWidth() { this->bandWidth = nullptr; } - /** @brief Calculates band widths of the non-standard form matrices. 
* * @param[in] prec: Precision used for thresholding - * + * * @details It is starting from \f$ l = 0 \f$ and updating the band width value each time we encounter - * considerable value while keeping increasing \f$ l \f$, that stands for the distance to the diagonal. - * - */ + * considerable value while keeping increasing \f$ l \f$, that stands for the distance to the diagonal. + * + */ void OperatorTree::calcBandWidth(double prec) { if (this->bandWidth == nullptr) clearBandWidth(); this->bandWidth = new BandWidth(getDepth()); @@ -134,61 +133,45 @@ void OperatorTree::calcBandWidth(double prec) { println(100, "\nOperator BandWidth" << *this->bandWidth); } - /** @brief Checks if the distance to diagonal is bigger than the operator band width. * * @param[in] oTransl: distance to diagonal * @param[in] o_depth: scaling order * @param[in] idx: index corresponding to one of the matrices \f$ A, B, C \f$ or \f$ T \f$. - * - * @returns True if \b oTransl is outside of the band and False otherwise. - * - */ -bool OperatorTree::isOutsideBand(int oTransl, int o_depth, int idx) -{ + * + * @returns True if \b oTransl is outside of the band and False otherwise. + * + */ +bool OperatorTree::isOutsideBand(int oTransl, int o_depth, int idx) { return abs(oTransl) > this->bandWidth->getWidth(o_depth, idx); } - /** @brief Cleans up end nodes. * * @param[in] trust_scale: there is no cleaning down below \b trust_scale (it speeds up operator building). - * + * * @details Traverses the tree and rewrites end nodes having branch node twins, * i. e. identical with respect to scale and translation. * This method is very handy, when an adaptive operator construction * can make a significunt noise at low scaling depth. * Its need comes from the fact that mwTransform up cannot override * rubbish that can potentially stick to end nodes at a particular level, - * and as a result spread further up to the root with mwTransform. - * + * and as a result spread further up to the root with mwTransform. 
+ * */ -void OperatorTree::removeRoughScaleNoise(int trust_scale) -{ - MWNode<2> *p_rubbish; //possibly inexact end node - MWNode<2> *p_counterpart; //exact branch node - for( int n = (this->getDepth() - 2 < trust_scale) ? this->getDepth() - 2 : trust_scale; n > this->getRootScale(); n--) - { - int N = 1<findNode( NodeIndex<2>(n, {m, l}) ); - if( p_rubbish != nullptr && p_rubbish->isEndNode() ) - { - for( int m1 = 0; m1 < N; m1++ ) - for( int l1 = 0; l1 < N; l1++ ) - if - ( - (m1 - l1 == m - l) - && - ( p_counterpart = this->findNode( NodeIndex<2>(n, {m1, l1}) ) ) != nullptr - && - p_counterpart->isBranchNode() - ) - { - for(int i = 0; i < p_counterpart->n_coefs; i++) - p_rubbish->coefs[i] = p_counterpart->coefs[i]; +void OperatorTree::removeRoughScaleNoise(int trust_scale) { + MWNode<2> *p_rubbish; // possibly inexact end node + MWNode<2> *p_counterpart; // exact branch node + for (int n = (this->getDepth() - 2 < trust_scale) ? this->getDepth() - 2 : trust_scale; n > this->getRootScale(); n--) { + int N = 1 << n; + for (int m = 0; m < N; m++) + for (int l = 0; l < N; l++) { + p_rubbish = this->findNode(NodeIndex<2>(n, {m, l})); + if (p_rubbish != nullptr && p_rubbish->isEndNode()) { + for (int m1 = 0; m1 < N; m1++) + for (int l1 = 0; l1 < N; l1++) + if ((m1 - l1 == m - l) && (p_counterpart = this->findNode(NodeIndex<2>(n, {m1, l1}))) != nullptr && p_counterpart->isBranchNode()) { + for (int i = 0; i < p_counterpart->n_coefs; i++) p_rubbish->coefs[i] = p_counterpart->coefs[i]; } } } @@ -196,8 +179,6 @@ void OperatorTree::removeRoughScaleNoise(int trust_scale) } } - - void OperatorTree::getMaxTranslations(VectorXi &maxTransl) { int nScales = this->nodesAtDepth.size(); maxTransl = VectorXi::Zero(nScales); diff --git a/src/trees/OperatorTree.h b/src/trees/OperatorTree.h index dcc2e09a8..83be4789a 100644 --- a/src/trees/OperatorTree.h +++ b/src/trees/OperatorTree.h @@ -50,8 +50,10 @@ class OperatorTree : public MWTree<2> { BandWidth &getBandWidth() { return 
*this->bandWidth; } const BandWidth &getBandWidth() const { return *this->bandWidth; } - OperatorNode &getNode(int n, int l) { return *nodePtrAccess[n][l]; } ///< TODO: It has to be specified more. - ///< \b l is distance to the diagonal. + OperatorNode &getNode(int n, int l) { + return *nodePtrAccess[n][l]; + } ///< TODO: It has to be specified more. + ///< \b l is distance to the diagonal. const OperatorNode &getNode(int n, int l) const { return *nodePtrAccess[n][l]; } void mwTransformDown(bool overwrite) override; diff --git a/src/trees/TreeIterator.cpp b/src/trees/TreeIterator.cpp index 9bf9fb054..b9d4aee85 100644 --- a/src/trees/TreeIterator.cpp +++ b/src/trees/TreeIterator.cpp @@ -29,7 +29,8 @@ namespace mrcpp { -template TreeIterator::TreeIterator(int traverse, int iterator) +template +TreeIterator::TreeIterator(int traverse, int iterator) : root(0) , nRoots(0) , mode(traverse) @@ -38,7 +39,8 @@ template TreeIterator::TreeIterator(int traverse, int iterator) , state(nullptr) , initialState(nullptr) {} -template TreeIterator::TreeIterator(MWTree &tree, int traverse, int iterator) +template +TreeIterator::TreeIterator(MWTree &tree, int traverse, int iterator) : root(0) , nRoots(0) , mode(traverse) @@ -49,23 +51,23 @@ template TreeIterator::TreeIterator(MWTree &tree, int traverse, in init(tree); } -template TreeIterator::~TreeIterator() { +template TreeIterator::~TreeIterator() { if (this->initialState != nullptr) delete this->initialState; } -template int TreeIterator::getChildIndex(int i) const { - const MWNode &node = *this->state->node; +template int TreeIterator::getChildIndex(int i) const { + const MWNode &node = *this->state->node; const HilbertPath &h = node.getHilbertPath(); // Legesgue type returns i, Hilbert type returns Hilbert index return (this->type == Hilbert) ? 
h.getZIndex(i) : i; } -template bool TreeIterator::next() { +template bool TreeIterator::next() { if (not this->state) return false; if (this->mode == TopDown) { if (this->tryNode()) return true; } - MWNode &node = *this->state->node; + MWNode &node = *this->state->node; if (checkDepth(node) and checkGenerated(node)) { const int nChildren = 1 << D; for (int i = 0; i < nChildren; i++) { @@ -80,12 +82,12 @@ template bool TreeIterator::next() { this->removeState(); return next(); } -template bool TreeIterator::nextParent() { +template bool TreeIterator::nextParent() { if (not this->state) return false; if (this->mode == BottomUp) { if (this->tryNode()) return true; } - MWNode &node = *this->state->node; + MWNode &node = *this->state->node; if (this->tryNextRootParent()) return true; if (checkDepth(node)) { if (this->tryParent()) return true; @@ -97,73 +99,73 @@ template bool TreeIterator::nextParent() { return nextParent(); } -template void TreeIterator::init(MWTree &tree) { +template void TreeIterator::init(MWTree &tree) { this->root = 0; this->maxDepth = -1; this->nRoots = tree.getRootBox().size(); - this->state = new IteratorNode(&tree.getRootBox().getNode(this->root)); + this->state = new IteratorNode(&tree.getRootBox().getNode(this->root)); // Save the first state so it can be properly deleted later this->initialState = this->state; } -template bool TreeIterator::tryNode() { +template bool TreeIterator::tryNode() { if (not this->state) { return false; } if (this->state->doneNode) { return false; } this->state->doneNode = true; return true; } -template bool TreeIterator::tryChild(int i) { +template bool TreeIterator::tryChild(int i) { if (not this->state) { return false; } if (this->state->doneChild[i]) { return false; } this->state->doneChild[i] = true; if (this->state->node->isLeafNode()) { return false; } - MWNode *child = &this->state->node->getMWChild(i); - this->state = new IteratorNode(child, this->state); + MWNode *child = 
&this->state->node->getMWChild(i); + this->state = new IteratorNode(child, this->state); return next(); } -template bool TreeIterator::tryParent() { +template bool TreeIterator::tryParent() { if (not this->state) return false; if (this->state->doneParent) return false; this->state->doneParent = true; if (not this->state->node->hasParent()) return false; - MWNode *parent = &this->state->node->getMWParent(); - this->state = new IteratorNode(parent, this->state); + MWNode *parent = &this->state->node->getMWParent(); + this->state = new IteratorNode(parent, this->state); return nextParent(); } -template bool TreeIterator::tryNextRoot() { +template bool TreeIterator::tryNextRoot() { if (not this->state) { return false; } if (not this->state->node->isRootNode()) { return false; } this->root++; if (this->root >= this->nRoots) { return false; } - MWNode *nextRoot = &state->node->getMWTree().getRootBox().getNode(root); - this->state = new IteratorNode(nextRoot, this->state); + MWNode *nextRoot = &state->node->getMWTree().getRootBox().getNode(root); + this->state = new IteratorNode(nextRoot, this->state); return next(); } -template bool TreeIterator::tryNextRootParent() { +template bool TreeIterator::tryNextRootParent() { if (not this->state) { return false; } if (not this->state->node->isRootNode()) { return false; } this->root++; if (this->root >= this->nRoots) { return false; } - MWNode *nextRoot = &state->node->getMWTree().getRootBox().getNode(root); - this->state = new IteratorNode(nextRoot, this->state); + MWNode *nextRoot = &state->node->getMWTree().getRootBox().getNode(root); + this->state = new IteratorNode(nextRoot, this->state); return nextParent(); } -template void TreeIterator::removeState() { +template void TreeIterator::removeState() { if (this->state == this->initialState) { this->initialState = nullptr; } if (this->state != nullptr) { - IteratorNode *spare = this->state; + IteratorNode *spare = this->state; this->state = spare->next; spare->next = nullptr; 
delete spare; } } -template void TreeIterator::setTraverse(int traverse) { +template void TreeIterator::setTraverse(int traverse) { switch (traverse) { case TopDown: this->mode = TopDown; @@ -177,7 +179,7 @@ template void TreeIterator::setTraverse(int traverse) { } } -template void TreeIterator::setIterator(int iterator) { +template void TreeIterator::setIterator(int iterator) { switch (iterator) { case Lebesgue: this->type = Lebesgue; @@ -191,7 +193,7 @@ template void TreeIterator::setIterator(int iterator) { } } -template bool TreeIterator::checkDepth(const MWNode &node) const { +template bool TreeIterator::checkDepth(const MWNode &node) const { if (this->maxDepth < 0) { return true; } else if (node.getDepth() < this->maxDepth) { @@ -201,7 +203,7 @@ template bool TreeIterator::checkDepth(const MWNode &node) const { } } -template bool TreeIterator::checkGenerated(const MWNode &node) const { +template bool TreeIterator::checkGenerated(const MWNode &node) const { if (node.isEndNode() and not this->returnGenNodes) { return false; } else { @@ -209,8 +211,8 @@ template bool TreeIterator::checkGenerated(const MWNode &node) con } } -template -IteratorNode::IteratorNode(MWNode *nd, IteratorNode *nx) +template +IteratorNode::IteratorNode(MWNode *nd, IteratorNode *nx) : node(nd) , next(nx) , doneNode(false) @@ -219,8 +221,12 @@ IteratorNode::IteratorNode(MWNode *nd, IteratorNode *nx) for (int i = 0; i < nChildren; i++) { this->doneChild[i] = false; } } -template class TreeIterator<1>; -template class TreeIterator<2>; -template class TreeIterator<3>; +template class TreeIterator<1, double>; +template class TreeIterator<2, double>; +template class TreeIterator<3, double>; + +template class TreeIterator<1, ComplexDouble>; +template class TreeIterator<2, ComplexDouble>; +template class TreeIterator<3, ComplexDouble>; } // namespace mrcpp diff --git a/src/trees/TreeIterator.h b/src/trees/TreeIterator.h index d82db82a0..82ea49eb9 100644 --- a/src/trees/TreeIterator.h +++ 
b/src/trees/TreeIterator.h @@ -30,10 +30,10 @@ namespace mrcpp { -template class TreeIterator { +template class TreeIterator { public: TreeIterator(int traverse = TopDown, int iterator = Lebesgue); - TreeIterator(MWTree &tree, int traverse = TopDown, int iterator = Lebesgue); + TreeIterator(MWTree &tree, int traverse = TopDown, int iterator = Lebesgue); virtual ~TreeIterator(); void setReturnGenNodes(bool i = true) { this->returnGenNodes = i; } @@ -41,12 +41,12 @@ template class TreeIterator { void setTraverse(int traverse); void setIterator(int iterator); - void init(MWTree &tree); + void init(MWTree &tree); bool next(); bool nextParent(); - MWNode &getNode() { return *this->state->node; } + MWNode &getNode() { return *this->state->node; } - friend class IteratorNode; + friend class IteratorNode; protected: int root; @@ -55,8 +55,8 @@ template class TreeIterator { int type; int maxDepth; bool returnGenNodes{true}; - IteratorNode *state; - IteratorNode *initialState; + IteratorNode *state; + IteratorNode *initialState; int getChildIndex(int i) const; @@ -66,19 +66,19 @@ template class TreeIterator { bool tryNextRoot(); bool tryNextRootParent(); void removeState(); - bool checkDepth(const MWNode &node) const; - bool checkGenerated(const MWNode &node) const; + bool checkDepth(const MWNode &node) const; + bool checkGenerated(const MWNode &node) const; }; -template class IteratorNode final { +template class IteratorNode final { public: - MWNode *node; - IteratorNode *next; + MWNode *node; + IteratorNode *next; bool doneNode; bool doneParent; bool doneChild[1 << D]; - IteratorNode(MWNode *nd, IteratorNode *nx = nullptr); + IteratorNode(MWNode *nd, IteratorNode *nx = nullptr); ~IteratorNode() { delete this->next; } }; diff --git a/src/utils/Bank.cpp b/src/utils/Bank.cpp index a774c44ff..f8c111a53 100644 --- a/src/utils/Bank.cpp +++ b/src/utils/Bank.cpp @@ -17,7 +17,7 @@ Bank::~Bank() { struct Blockdata_struct { std::vector data; // to store the incoming data. 
One column for each orbital on the same node. - int N_rows = 0; // the number of coefficients in one column of the block. + int N_rows = 0; // the number of coefficients in one column of the block. std::map id2data; // internal index of the data in the block std::vector id; // the id of each column. Either nodeid, or orbid }; @@ -29,33 +29,33 @@ struct OrbBlock_struct { }; struct mem_struct { std::vector chunk_p; // vector with allocated chunks - int p = -1; // position of next available memory (not allocated if < 0) - //on Betzy 1024*1024*4 ok, 1024*1024*2 NOT ok: leads to memory fragmentation (on "Betzy" 2023) - int chunk_size = 1024*1024*4; // chunksize (in number of doubles). data_p[i]+chunk_size is end of chunk i - int account=-1; - double * get_mem(int size){ - if(p<0 or size > chunk_size or p + size > chunk_size){ //allocate new chunk of memory - if(size > 1024*1024){ - //make a special chunk just for this - double * m_p = new double[size]; + int p = -1; // position of next available memory (not allocated if < 0) + // on Betzy 1024*1024*4 ok, 1024*1024*2 NOT ok: leads to memory fragmentation (on "Betzy" 2023) + int chunk_size = 1024 * 1024 * 4; // chunksize (in number of doubles). 
data_p[i]+chunk_size is end of chunk i + int account = -1; + double *get_mem(int size) { + if (p < 0 or size > chunk_size or p + size > chunk_size) { // allocate new chunk of memory + if (size > 1024 * 1024) { + // make a special chunk just for this + double *m_p = new double[size]; chunk_p.push_back(m_p); - p=-1; + p = -1; return m_p; } else { - double * m_p = new double[chunk_size]; + double *m_p = new double[chunk_size]; chunk_p.push_back(m_p); - p=0; + p = 0; } } - double * m_p = chunk_p[chunk_p.size()-1] + p; + double *m_p = chunk_p[chunk_p.size() - 1] + p; p += size; return m_p; } }; std::map *> get_nodeid2block; // to get block from its nodeid (all coeff for one node) -std::map *> get_orbid2block; // to get block from its orbid +std::map *> get_orbid2block; // to get block from its orbid -std::map mem; +std::map mem; int const MIN_SCALE = -999; // Smaller than smallest scale int naccounts = 0; @@ -115,7 +115,7 @@ void Bank::open() { get_readytasks[account] = new std::map>; currentsize[account] = 0; mem[account] = new mem_struct; - mem[account]->account=account; + mem[account]->account = account; MPI_Send(&account, 1, MPI_INT, status.MPI_SOURCE, 1, comm_bank); continue; } @@ -153,8 +153,8 @@ void Bank::open() { this->clear_bank(); for (auto const &block : nodeid2block) { if (block.second.data.size() > 0) { - currentsize[account] -= block.second.N_rows * block.second.data.size()/ 128; // converted into kB - totcurrentsize -= block.second.N_rows * block.second.data.size()/ 128; // converted into kB + currentsize[account] -= block.second.N_rows * block.second.data.size() / 128; // converted into kB + totcurrentsize -= block.second.N_rows * block.second.data.size() / 128; // converted into kB } } nodeid2block.clear(); @@ -171,9 +171,9 @@ void Bank::open() { int dataindex = 0; // internal index of the data in the block int size = 0; if (message == GET_NODEDATA) { - int orbid = messages[3]; // which part of the block to fetch + int orbid = messages[3]; // which 
part of the block to fetch dataindex = block.id2data[orbid]; // column of the data in the block - size = block.N_rows; // number of doubles to fetch + size = block.N_rows; // number of doubles to fetch if (size != messages[4]) std::cout << "ERROR nodedata has wrong size" << std::endl; double *data_p = block.data[dataindex]; if (size > 0) MPI_Send(data_p, size, MPI_DOUBLE, status.MPI_SOURCE, 3, comm_bank); @@ -186,12 +186,12 @@ void Bank::open() { if (printinfo) std::cout << " rewrite into superblock " << block.data.size() << " " << block.N_rows << " nodeid " << nodeid << std::endl; for (int j = 0; j < block.data.size(); j++) { for (int i = 0; i < block.N_rows; i++) { DataBlock(i, j) = block.data[j][i]; } - } + } dataindex = 0; // start from first column // send info about the size of the superblock - metadata_block[0] = nodeid; // nodeid + metadata_block[0] = nodeid; // nodeid metadata_block[1] = block.data.size(); // number of columns - metadata_block[2] = size; // total size = rows*columns + metadata_block[2] = size; // total size = rows*columns MPI_Send(metadata_block, size_metadata, MPI_INT, status.MPI_SOURCE, 1, comm_bank); // send info about the id of each column MPI_Send(block.id.data(), metadata_block[1], MPI_INT, status.MPI_SOURCE, 2, comm_bank); @@ -242,7 +242,7 @@ void Bank::open() { // send info about the size of the superblock metadata_block[0] = orbid; metadata_block[1] = block.data.size(); // number of columns - metadata_block[2] = size; // total size = rows*columns + metadata_block[2] = size; // total size = rows*columns MPI_Send(metadata_block, size_metadata, MPI_INT, status.MPI_SOURCE, 1, comm_bank); MPI_Send(block.id.data(), metadata_block[1], MPI_INT, status.MPI_SOURCE, 2, comm_bank); MPI_Send(coeff.data(), size, MPI_DOUBLE, status.MPI_SOURCE, 3, comm_bank); @@ -301,9 +301,9 @@ void Bank::open() { } send_function(*deposits[ix].orb, status.MPI_SOURCE, 1, comm_bank); if (message == GET_FUNCTION_AND_DELETE) { - currentsize[account] -= 
deposits[ix].orb->getSizeNodes(NUMBER::Total); - totcurrentsize -= deposits[ix].orb->getSizeNodes(NUMBER::Total); - deposits[ix].orb->free(NUMBER::Total); + currentsize[account] -= deposits[ix].orb->getSizeNodes(); + totcurrentsize -= deposits[ix].orb->getSizeNodes(); + deposits[ix].orb->free(); id2ix[id] = 0; } } @@ -319,20 +319,20 @@ void Bank::open() { // append the incoming data Blockdata_struct &block = nodeid2block[nodeid]; block.id2data[orbid] = nodeid2block[nodeid].data.size(); // internal index of the data in the block - double *data_p = mem[account]->get_mem(size);//new double[size]; - currentsize[account] += size / 128; // converted into kB - totcurrentsize += size / 128; // converted into kB + double *data_p = mem[account]->get_mem(size); // new double[size]; + currentsize[account] += size / 128; // converted into kB + totcurrentsize += size / 128; // converted into kB this->maxsize = std::max(totcurrentsize, this->maxsize); block.data.push_back(data_p); block.id.push_back(orbid); - if (block.N_rows > 0 and block.N_rows != size) cout<<" ERROR block size incompatible " < 0 and block.N_rows != size) cout << " ERROR block size incompatible " << block.N_rows << " " << size << endl; block.N_rows = size; OrbBlock_struct &orbblock = orbid2block[orbid]; orbblock.id2data[nodeid] = orbblock.data.size(); // internal index of the data in the block orbblock.data.push_back(data_p); orbblock.id.push_back(nodeid); - //orbblock.N_rows.push_back(size); + // orbblock.N_rows.push_back(size); MPI_Recv(data_p, size, MPI_DOUBLE, status.MPI_SOURCE, 1, comm_bank, &status); if (printinfo) std::cout << " written block " << nodeid << " id " << orbid << " subblocks " << nodeid2block[nodeid].data.size() << std::endl; @@ -370,12 +370,12 @@ void Bank::open() { } else { ix = deposits.size(); // NB: ix is now index of last element + 1 deposits.resize(ix + 1); - if (message == SAVE_FUNCTION) deposits[ix].orb = new ComplexFunction(0); + if (message == SAVE_FUNCTION) deposits[ix].orb = new 
CompFunction<3>(0); if (message == SAVE_DATA) { datasize = messages[3]; - deposits[ix].data = mem[account]->get_mem(datasize);//new double[datasize]; - currentsize[account] += datasize / 128; // converted into kB - totcurrentsize += datasize / 128; // converted into kB + deposits[ix].data = mem[account]->get_mem(datasize); // new double[datasize]; + currentsize[account] += datasize / 128; // converted into kB + totcurrentsize += datasize / 128; // converted into kB this->maxsize = std::max(totcurrentsize, this->maxsize); deposits[ix].hasdata = true; } @@ -385,10 +385,9 @@ void Bank::open() { deposits[ix].source = status.MPI_SOURCE; if (message == SAVE_FUNCTION) { recv_function(*deposits[ix].orb, deposits[ix].source, 1, comm_bank); - cout<<"recv ORB size "<getSizeNodes(NUMBER::Total)<getSizeNodes(NUMBER::Total); - totcurrentsize += deposits[ix].orb->getSizeNodes(NUMBER::Total); + currentsize[account] += deposits[ix].orb->getSizeNodes(); + totcurrentsize += deposits[ix].orb->getSizeNodes(); this->maxsize = std::max(totcurrentsize, this->maxsize); } } @@ -480,13 +479,13 @@ void Bank::remove_account(int account) { } std::vector &deposits = *get_deposits[account]; for (int ix = 1; ix < deposits.size(); ix++) { - if (deposits[ix].orb != nullptr) deposits[ix].orb->free(NUMBER::Total); - if (deposits[ix].hasdata) { - currentsize[account] -= deposits[ix].datasize / 128; - totcurrentsize -= deposits[ix].datasize / 128; - } - if (deposits[ix].hasdata) (*get_id2ix[account])[deposits[ix].id] = 0; // indicate that it does not exist - deposits[ix].hasdata = false; + if (deposits[ix].orb != nullptr) deposits[ix].orb->free(); + if (deposits[ix].hasdata) { + currentsize[account] -= deposits[ix].datasize / 128; + totcurrentsize -= deposits[ix].datasize / 128; + } + if (deposits[ix].hasdata) (*get_id2ix[account])[deposits[ix].id] = 0; // indicate that it does not exist + deposits[ix].hasdata = false; } deposits.clear(); get_deposits.erase(account); @@ -503,8 +502,8 @@ void 
Bank::remove_account(int account) { std::map &orbid2block = *get_orbid2block[account]; for (auto const &block : nodeid2block) { - currentsize[account] -= block.second.N_rows * block.second.data.size()/ 128; // converted into kB - totcurrentsize -= block.second.N_rows * block.second.data.size()/ 128; // converted into kB + currentsize[account] -= block.second.N_rows * block.second.data.size() / 128; // converted into kB + totcurrentsize -= block.second.N_rows * block.second.data.size() / 128; // converted into kB } nodeid2block.clear(); orbid2block.clear(); @@ -512,7 +511,7 @@ void Bank::remove_account(int account) { get_nodeid2block.erase(account); get_orbid2block.erase(account); - for (double* c_p : mem[account]->chunk_p) delete [] c_p; + for (double *c_p : mem[account]->chunk_p) delete[] c_p; mem.erase(account); currentsize.erase(account); #endif @@ -642,7 +641,7 @@ std::vector Bank::get_totalsize() { // get orbital with identity id. // If wait=0, return immediately with value zero if not available (default) // else, wait until available -int BankAccount::get_func(int id, ComplexFunction &func, int wait) { +int BankAccount::get_func(int id, CompFunction<3> &func, int wait) { #ifdef MRCPP_HAS_MPI MPI_Status status; int messages[message_size]; @@ -670,7 +669,7 @@ int BankAccount::get_func(int id, ComplexFunction &func, int wait) { // get orbital with identity id, and delete from bank. 
// return immediately with value zero if not available -int BankAccount::get_func_del(int id, ComplexFunction &orb) { +int BankAccount::get_func_del(int id, CompFunction<3> &orb) { #ifdef MRCPP_HAS_MPI MPI_Status status; int messages[message_size]; @@ -691,7 +690,7 @@ int BankAccount::get_func_del(int id, ComplexFunction &orb) { } // save function in Bank with identity id -int BankAccount::put_func(int id, ComplexFunction &func) { +int BankAccount::put_func(int id, CompFunction<3> &func) { #ifdef MRCPP_HAS_MPI // for now we distribute according to id int messages[message_size]; @@ -721,6 +720,23 @@ int BankAccount::put_data(int id, int size, double *data) { return 1; } +// save complex data in Bank with identity id; size counts ComplexDouble elements, stored as 2*size doubles. datasize MUST have been set already. NB:not tested +int BankAccount::put_data(int id, int size, ComplexDouble *data) { +#ifdef MRCPP_HAS_MPI + // for now we distribute according to id + int messages[message_size]; + + messages[0] = SAVE_DATA; + messages[1] = account_id; + messages[2] = id; + messages[3] = size * 2; // save as twice as many doubles + messages[4] = MIN_SCALE; // to indicate that it is defined by id + MPI_Send(messages, 5, MPI_INT, bankmaster[id % bank_size], 0, comm_bank); + MPI_Send(data, size * 2, MPI_DOUBLE, bankmaster[id % bank_size], 1, comm_bank); // payload count must match messages[3]: two doubles per ComplexDouble +#endif + return 1; +} + // save data in Bank with identity nIdx. datasize MUST have been set already. NB:not tested int BankAccount::put_data(NodeIndex<3> nIdx, int size, double *data) { #ifdef MRCPP_HAS_MPI @@ -740,6 +756,25 @@ int BankAccount::put_data(NodeIndex<3> nIdx, int size, double *data) { return 1; } +// save data in Bank with identity nIdx. datasize MUST have been set already.
NB:not tested +int BankAccount::put_data(NodeIndex<3> nIdx, int size, ComplexDouble *data) { +#ifdef MRCPP_HAS_MPI + // for now we distribute according to id + int messages[message_size]; + messages[0] = SAVE_DATA; + messages[1] = account_id; + messages[2] = nIdx.getTranslation(0); + messages[3] = size * 2; // save as twice as many doubles + messages[4] = nIdx.getScale(); + messages[5] = nIdx.getTranslation(1); + messages[6] = nIdx.getTranslation(2); + int id = std::abs(nIdx.getTranslation(0) + nIdx.getTranslation(1) + nIdx.getTranslation(2)); + MPI_Send(messages, 7, MPI_INT, bankmaster[id % bank_size], 0, comm_bank); + MPI_Send(data, size * 2, MPI_DOUBLE, bankmaster[id % bank_size], 1, comm_bank); // payload count must match messages[3]: two doubles per ComplexDouble +#endif + return 1; +} + // get data with identity id int BankAccount::get_data(int id, int size, double *data) { #ifdef MRCPP_HAS_MPI @@ -755,6 +790,22 @@ int BankAccount::get_data(int id, int size, double *data) { return 1; } +// get data with identity id +int BankAccount::get_data(int id, int size, ComplexDouble *data) { +#ifdef MRCPP_HAS_MPI + MPI_Status status; + int messages[message_size]; + messages[0] = GET_DATA; + messages[1] = account_id; + messages[2] = id; + messages[3] = MIN_SCALE; + MPI_Send(messages, 4, MPI_INT, bankmaster[id % bank_size], 0, comm_bank); + // fetch as twice as many doubles + MPI_Recv(data, size * 2, MPI_DOUBLE, bankmaster[id % bank_size], 1, comm_bank, &status); +#endif + return 1; +} + // get data with identity id int BankAccount::get_data(NodeIndex<3> nIdx, int size, double *data) { #ifdef MRCPP_HAS_MPI @@ -774,6 +825,26 @@ int BankAccount::get_data(NodeIndex<3> nIdx, int size, double *data) { return 1; } +// get data with identity id +int BankAccount::get_data(NodeIndex<3> nIdx, int size, ComplexDouble *data) { +#ifdef MRCPP_HAS_MPI + MPI_Status status; + int messages[message_size]; + int id = std::abs(nIdx.getTranslation(0) + nIdx.getTranslation(1) + nIdx.getTranslation(2)); + messages[0] = GET_DATA; + messages[1] = account_id; +
messages[2] = id; + messages[3] = nIdx.getScale(); + messages[4] = nIdx.getTranslation(0); + messages[5] = nIdx.getTranslation(1); + messages[6] = nIdx.getTranslation(2); + MPI_Send(messages, 7, MPI_INT, bankmaster[id % bank_size], 0, comm_bank); + // fetch as twice as many doubles + MPI_Recv(data, size * 2, MPI_DOUBLE, bankmaster[id % bank_size], 1, comm_bank, &status); +#endif + return 1; +} + // save data in Bank with identity id as part of block with identity nodeid. int BankAccount::put_nodedata(int id, int nodeid, int size, double *data) { #ifdef MRCPP_HAS_MPI @@ -790,6 +861,23 @@ int BankAccount::put_nodedata(int id, int nodeid, int size, double *data) { return 1; } +// save data in Bank with identity id as part of block with identity nodeid. +// NB: Complex is stored as two doubles +int BankAccount::put_nodedata(int id, int nodeid, int size, ComplexDouble *data) { +#ifdef MRCPP_HAS_MPI + // for now we distribute according to nodeid + int messages[message_size]; + messages[0] = SAVE_NODEDATA; + messages[1] = account_id; + messages[2] = nodeid; // which block + messages[3] = id; // id within block + messages[4] = 2 * size; // size of this data + MPI_Send(messages, 5, MPI_INT, bankmaster[nodeid % bank_size], 0, comm_bank); + MPI_Send(data, 2 * size, MPI_DOUBLE, bankmaster[nodeid % bank_size], 1, comm_bank); +#endif + return 1; +} + // get data with identity id int BankAccount::get_nodedata(int id, int nodeid, int size, double *data, std::vector &idVec) { #ifdef MRCPP_HAS_MPI @@ -807,6 +895,23 @@ int BankAccount::get_nodedata(int id, int nodeid, int size, double *data, std::v return 1; } +// get data with identity id +int BankAccount::get_nodedata(int id, int nodeid, int size, ComplexDouble *data, std::vector &idVec) { +#ifdef MRCPP_HAS_MPI + MPI_Status status; + // get the column with identity id + int messages[message_size]; + messages[0] = GET_NODEDATA; + messages[1] = account_id; + messages[2] = nodeid; // which block + messages[3] = id; // id within block. 
+ messages[4] = 2 * size; // expected size of data, in doubles (put_nodedata stores two per ComplexDouble) + MPI_Send(messages, 5, MPI_INT, bankmaster[nodeid % bank_size], 0, comm_bank); + MPI_Recv(data, 2 * size, MPI_DOUBLE, bankmaster[nodeid % bank_size], 3, comm_bank, &status); // fetch as twice as many doubles +#endif + return 1; +} + // get all data for nodeid (same nodeid, different orbitals) int BankAccount::get_nodeblock(int nodeid, double *data, std::vector &idVec) { #ifdef MRCPP_HAS_MPI @@ -827,6 +932,26 @@ int BankAccount::get_nodeblock(int nodeid, double *data, std::vector &idVec return 1; } +// get all data for nodeid (same nodeid, different orbitals) +int BankAccount::get_nodeblock(int nodeid, ComplexDouble *data, std::vector &idVec) { +#ifdef MRCPP_HAS_MPI + MPI_Status status; + // get the entire superblock and also the id of each column + int messages[message_size]; + messages[0] = GET_NODEBLOCK; + messages[1] = account_id; + messages[2] = nodeid; + + MPI_Send(messages, 3, MPI_INT, bankmaster[nodeid % bank_size], 0, comm_bank); + MPI_Recv(metadata_block, size_metadata, MPI_INT, bankmaster[nodeid % bank_size], 1, comm_bank, &status); + idVec.resize(metadata_block[1]); + int size = metadata_block[2]; + if (size > 0) MPI_Recv(idVec.data(), metadata_block[1], MPI_INT, bankmaster[nodeid % bank_size], 2, comm_bank, &status); + if (size > 0) MPI_Recv(data, size, MPI_DOUBLE, bankmaster[nodeid % bank_size], 3, comm_bank, &status); +#endif + return 1; +} + // get all data with identity orbid (same orbital, different nodes) int BankAccount::get_orbblock(int orbid, double *&data, std::vector &nodeidVec, int bankstart) { #ifdef MRCPP_HAS_MPI @@ -848,6 +973,27 @@ int BankAccount::get_orbblock(int orbid, double *&data, std::vector &nodeid return 1; } +// get all data with identity orbid (same orbital, different nodes) +int BankAccount::get_orbblock(int orbid, ComplexDouble *&data, std::vector &nodeidVec, int bankstart) { +#ifdef MRCPP_HAS_MPI + MPI_Status status; + int nodeid = wrk_rank + bankstart; + // get the entire superblock and also the nodeid of each
column + int messages[message_size]; + messages[0] = GET_ORBBLOCK; + messages[1] = account_id; + messages[2] = orbid; + MPI_Send(messages, 3, MPI_INT, bankmaster[nodeid % bank_size], 0, comm_bank); + MPI_Recv(metadata_block, size_metadata, MPI_INT, bankmaster[nodeid % bank_size], 1, comm_bank, &status); + nodeidVec.resize(metadata_block[1]); + int totsize = metadata_block[2]; + if (totsize > 0) MPI_Recv(nodeidVec.data(), metadata_block[1], MPI_INT, bankmaster[nodeid % bank_size], 2, comm_bank, &status); + data = new ComplexDouble[totsize / 2]; + if (totsize > 0) MPI_Recv(data, totsize, MPI_DOUBLE, bankmaster[nodeid % bank_size], 3, comm_bank, &status); +#endif + return 1; +} + // creator. NB: collective BankAccount::BankAccount(int iclient, MPI_Comm comm) { this->account_id = dataBank.openAccount(iclient, comm); diff --git a/src/utils/Bank.h b/src/utils/Bank.h index 501faa7a0..69719c530 100644 --- a/src/utils/Bank.h +++ b/src/utils/Bank.h @@ -1,6 +1,6 @@ #pragma once -#include "ComplexFunction.h" +#include "CompFunction.h" #include "parallel.h" #include "trees/NodeIndex.h" @@ -9,7 +9,7 @@ namespace mrcpp { using namespace mpi; struct deposit { - ComplexFunction *orb; + CompFunction<3> *orb; double *data; // for pure data arrays bool hasdata; int datasize; @@ -96,17 +96,25 @@ class BankAccount { void clear(int i = wrk_rank, MPI_Comm comm = comm_wrk); // int put_orb(int id, ComplexFunction &orb); // int get_orb(int id, ComplexFunction &orb, int wait = 0); - int get_func_del(int id, ComplexFunction &orb); - int put_func(int id, ComplexFunction &func); - int get_func(int id, ComplexFunction &func, int wait = 0); + int get_func_del(int id, CompFunction<3> &orb); + int put_func(int id, CompFunction<3> &func); + int get_func(int id, CompFunction<3> &func, int wait = 0); int put_data(int id, int size, double *data); + int put_data(int id, int size, ComplexDouble *data); int get_data(int id, int size, double *data); + int get_data(int id, int size, ComplexDouble *data); int 
put_data(NodeIndex<3> nIdx, int size, double *data); + int put_data(NodeIndex<3> nIdx, int size, ComplexDouble *data); int get_data(NodeIndex<3> nIdx, int size, double *data); + int get_data(NodeIndex<3> nIdx, int size, ComplexDouble *data); int put_nodedata(int id, int nodeid, int size, double *data); + int put_nodedata(int id, int nodeid, int size, ComplexDouble *data); int get_nodedata(int id, int nodeid, int size, double *data, std::vector &idVec); + int get_nodedata(int id, int nodeid, int size, ComplexDouble *data, std::vector &idVec); int get_nodeblock(int nodeid, double *data, std::vector &idVec); + int get_nodeblock(int nodeid, ComplexDouble *data, std::vector &idVec); int get_orbblock(int orbid, double *&data, std::vector &nodeidVec, int bankstart); + int get_orbblock(int orbid, ComplexDouble *&data, std::vector &nodeidVec, int bankstart); }; class TaskManager { diff --git a/src/utils/CMakeLists.txt b/src/utils/CMakeLists.txt index bfaa4e0ba..c3238abec 100644 --- a/src/utils/CMakeLists.txt +++ b/src/utils/CMakeLists.txt @@ -11,7 +11,7 @@ target_sources(mrcpp ${CMAKE_CURRENT_SOURCE_DIR}/tree_utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/Bank.cpp ${CMAKE_CURRENT_SOURCE_DIR}/parallel.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/ComplexFunction.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/CompFunction.cpp ) get_filename_component(_dirname ${CMAKE_CURRENT_LIST_DIR} NAME) @@ -28,7 +28,7 @@ list(APPEND ${_dirname}_h ${CMAKE_CURRENT_SOURCE_DIR}/tree_utils.h ${CMAKE_CURRENT_SOURCE_DIR}/Bank.h ${CMAKE_CURRENT_SOURCE_DIR}/parallel.h - ${CMAKE_CURRENT_SOURCE_DIR}/ComplexFunction.h + ${CMAKE_CURRENT_SOURCE_DIR}/CompFunction.h ) # Sets install directory for all the headers in the list diff --git a/src/utils/CompFunction.cpp b/src/utils/CompFunction.cpp new file mode 100644 index 000000000..a8ce70799 --- /dev/null +++ b/src/utils/CompFunction.cpp @@ -0,0 +1,2771 @@ +#include "CompFunction.h" +#include "Bank.h" +#include "Printer.h" +#include "parallel.h" +#include "treebuilders/add.h" +#include 
"treebuilders/grid.h" +#include "treebuilders/multiply.h" +#include "treebuilders/project.h" +#include "trees/FunctionNode.h" +#include + +/* Some rules for CompFunction: + * NComp is the number of components. If Ncomp>0, the corresponding trees must exist (can be only empty roots). + * The other trees should be set to nullptr. + * The trees and data can be shared among several CompFunction; this is managed automatically by "std::make_shared" + * Normally the CompFunction must be eiher real or complex (or none if noe is defined anyway). + * Though it is allowed in some cases to have both and the code should preferably allow this. (It is used temporary + * when we need a Complex type, but the trees are real: the tree is then copied as a complex tree in the same CompFunction). + * TreePtr (aka func_ptr) is the part potentially shared with others with "std::make_shared". It contains the pointers to the trees. + * The static data (number of components, real/complex, conjugaison, integers used for spin etc.) are store in func_ptr.data. 
+ */ + +namespace mrcpp { + +template MultiResolutionAnalysis *defaultCompMRA = nullptr; // Global MRA + +template CompFunction::CompFunction(MultiResolutionAnalysis &mra) { + defaultCompMRA = &mra; + func_ptr = std::make_shared>(false); + CompD = func_ptr->real; + CompC = func_ptr->cplx; + for (int i = 0; i < 4; i++) CompD[i] = nullptr; + for (int i = 0; i < 4; i++) CompC[i] = nullptr; +} + +template CompFunction::CompFunction() { + func_ptr = std::make_shared>(false); + CompD = func_ptr->real; + CompC = func_ptr->cplx; + for (int i = 0; i < 4; i++) CompD[i] = nullptr; + for (int i = 0; i < 4; i++) CompC[i] = nullptr; +} + +/* + * Empty functions (no components defined) + */ +template CompFunction::CompFunction(int n1) { + func_ptr = std::make_shared>(false); + CompD = func_ptr->real; + CompC = func_ptr->cplx; + for (int i = 0; i < 4; i++) CompD[i] = nullptr; + for (int i = 0; i < 4; i++) CompC[i] = nullptr; + func_ptr->data.n1[0] = n1; + func_ptr->data.n2[0] = -1; + func_ptr->data.n3[0] = 0; + func_ptr->rank = 0; + func_ptr->isreal = 1; + func_ptr->iscomplex = 0; + func_ptr->data.shared = false; +} + +/* + * Empty functions (no components defined) + */ +template CompFunction::CompFunction(int n1, bool share) { + func_ptr = std::make_shared>(share); + CompD = func_ptr->real; + CompC = func_ptr->cplx; + for (int i = 0; i < 4; i++) CompD[i] = nullptr; + for (int i = 0; i < 4; i++) CompC[i] = nullptr; + func_ptr->data.n1[0] = n1; + func_ptr->data.n2[0] = -1; + func_ptr->data.n3[0] = 0; + func_ptr->rank = 0; + func_ptr->isreal = 1; + func_ptr->iscomplex = 0; + func_ptr->data.shared = share; +} + +/* + * Empty functions (trees defined but zero) + */ +template CompFunction::CompFunction(const CompFunctionData &indata, bool alloc) { + func_ptr = std::make_shared>(indata.shared); + func_ptr->data = indata; + CompD = func_ptr->real; + CompC = func_ptr->cplx; + if (alloc) + this->alloc(Ncomp()); + else + this->free(); +} + +/** @brief Copy constructor + * + * Shallow copy: 
meta data is copied along with the component pointers, + * NO transfer of ownership. + */ +template CompFunction::CompFunction(const CompFunction &compfunc) { + func_ptr = compfunc.func_ptr; + CompD = func_ptr->real; + CompC = func_ptr->cplx; +} + +/** @brief Copy assignment + * + * Shallow copy: the shared func_ptr is taken over from compfunc and CompD/CompC are re-pointed at its tree arrays; self-assignment is a no-op. + * NO transfer of ownership. + */ +template CompFunction &CompFunction::operator=(const CompFunction &compfunc) { + if (this != &compfunc) { + func_ptr = compfunc.func_ptr; + CompD = func_ptr->real; + CompC = func_ptr->cplx; + } + return *this; +} + +template +/** @brief Parameter copy + * + * Returns a copy without defined trees. + */ +CompFunction CompFunction::paramCopy(bool alloc) const { + return CompFunction(func_ptr->data, alloc); +} + +template void CompFunction::flushMRAData() { + const auto &box = defaultCompMRA<3>->getWorldBox(); + func_ptr->data.type = defaultCompMRA<3>->getScalingBasis().getScalingType(); + func_ptr->data.order = defaultCompMRA<3>->getOrder(); + func_ptr->data.depth = defaultCompMRA<3>->getMaxDepth(); + func_ptr->data.scale = box.getScale(); + func_ptr->data.boxes[0] = box.size(0); + func_ptr->data.boxes[1] = box.size(1); + func_ptr->data.boxes[2] = box.size(2); + func_ptr->data.corner[0] = box.getCornerIndex().getTranslation(0); + func_ptr->data.corner[1] = box.getCornerIndex().getTranslation(1); + func_ptr->data.corner[2] = box.getCornerIndex().getTranslation(2); +} + +template void CompFunction::flushFuncData() { + if (D == 3) flushMRAData(); + for (int i = 0; i < Ncomp(); i++) { + if (isreal()) { + func_ptr->Nchunks[i] = CompD[i]->getNChunksUsed(); + } else { + func_ptr->Nchunks[i] = CompC[i]->getNChunksUsed(); + } + } + for (int i = Ncomp(); i < 4; i++) func_ptr->Nchunks[i] = 0; +} + +template CompFunctionData CompFunction::getFuncData() const { + CompFunctionData outdata; + const auto &box = defaultCompMRA<3>->getWorldBox(); + outdata.type =
defaultCompMRA<3>->getScalingBasis().getScalingType(); + outdata.order = defaultCompMRA<3>->getOrder(); + outdata.depth = defaultCompMRA<3>->getMaxDepth(); + outdata.scale = box.getScale(); + outdata.boxes[0] = box.size(0); + outdata.boxes[1] = box.size(1); + outdata.boxes[2] = box.size(2); + outdata.corner[0] = box.getCornerIndex().getTranslation(0); + outdata.corner[1] = box.getCornerIndex().getTranslation(1); + outdata.corner[2] = box.getCornerIndex().getTranslation(2); + for (int i = 0; i < Ncomp(); i++) { + if (isreal()) { + outdata.Nchunks[i] = CompD[i]->getNChunksUsed(); + } else { + outdata.Nchunks[i] = CompC[i]->getNChunksUsed(); + } + } + for (int i = Ncomp(); i < 4; i++) outdata.Nchunks[i] = 0; + return outdata; +} + +template ComplexDouble CompFunction::integrate() const { + ComplexDouble integral; + if (isreal()) + integral = CompD[0]->integrate(); + else + integral = CompC[0]->integrate(); + return integral; +} + +template double CompFunction::norm() const { + double norm = getSquareNorm(); + if (norm > 0.0) norm = std::sqrt(norm); + return norm; +} +template double CompFunction::getSquareNorm() const { + double norm = 0.0; + for (int i = 0; i < Ncomp(); i++) { + if (isreal() and CompD[i] != nullptr) { + norm += CompD[i]->getSquareNorm(); + } else if (iscomplex() and CompC[i] != nullptr) { + norm += CompC[i]->getSquareNorm(); + } + } + return norm; +} + +// Allocate empty trees. The tree must be defined as real or complex already. +// Allocates all ialloc trees, with indices 0,...ialloc-1 +// nalloc is the number of components allocated. ialloc=1 allocates one tree. +// deletes all old trees if found. 
+template void CompFunction::alloc(int nalloc, bool zero) { + if (defaultCompMRA == nullptr) MSG_ABORT("Default MRA not yet defined"); + if (isreal() == 0 and iscomplex() == 0) MSG_ABORT("Function must be defined either real or complex"); + for (int i = 0; i < nalloc; i++) { + delete CompD[i]; + delete CompC[i]; + CompD[i] = nullptr; + CompC[i] = nullptr; + if (isreal()) { + CompD[i] = new FunctionTree(*defaultCompMRA, func_ptr->shared_mem_real); + if (zero) CompD[i]->setZero(); + } + if (iscomplex()) { + CompC[i] = new FunctionTree(*defaultCompMRA, func_ptr->shared_mem_cplx); + if (zero) CompC[i]->setZero(); + } + func_ptr->Ncomp = std::max(Ncomp(), i + 1); + } + for (int i = nalloc; i < Ncomp(); i++) { + // delete possible remaining components + delete CompD[i]; + delete CompC[i]; + CompD[i] = nullptr; + CompC[i] = nullptr; + } +} + +// Allocate one empty trees for one specific component. +// The tree must be defined as real or complex already. +// ialloc is index allocated. ialloc=0 allocates the tree with index zero. +// deletes old tree if found. 
+template void CompFunction::alloc_comp(int ialloc) { + if (defaultCompMRA == nullptr) MSG_ABORT("Default MRA not yet defined"); + if (isreal() == 0 and iscomplex() == 0) MSG_ABORT("Function must be defined either real or complex"); + int i = ialloc; + delete CompD[i]; + delete CompC[i]; + CompD[i] = nullptr; + CompC[i] = nullptr; + if (isreal()) { + CompD[i] = new FunctionTree(*defaultCompMRA, func_ptr->shared_mem_real); + CompD[i]->setZero(); + } + if (iscomplex()) { + CompC[i] = new FunctionTree(*defaultCompMRA, func_ptr->shared_mem_cplx); + CompC[i]->setZero(); + } + func_ptr->Ncomp = std::max(Ncomp(), i + 1); +} + +template void CompFunction::free() { + for (int i = 0; i < Ncomp(); i++) { + if (CompD[i] != nullptr) delete CompD[i]; + if (CompC[i] != nullptr) delete CompC[i]; + CompD[i] = nullptr; + CompC[i] = nullptr; + } + if (this->func_ptr->shared_mem_real) this->func_ptr->shared_mem_real->clear(); + if (this->func_ptr->shared_mem_cplx) this->func_ptr->shared_mem_cplx->clear(); + func_ptr->Ncomp = 0; +} + +template int CompFunction::getSizeNodes() const { + int size_mb = 0; // Memory size in kB + for (int i = 0; i < Ncomp(); i++) { + if (isreal() and CompD[i] != nullptr) size_mb += CompD[i]->getSizeNodes(); + if (iscomplex() and CompC[i] != nullptr) size_mb += CompC[i]->getSizeNodes(); + } + return size_mb; +} + +template int CompFunction::getNNodes() const { + int nNodes = 0; + for (int i = 0; i < Ncomp(); i++) { + if (isreal() and CompD[i] != nullptr) nNodes += CompD[i]->getNNodes(); + if (iscomplex() and CompC[i] != nullptr) nNodes += CompC[i]->getNNodes(); + } + return nNodes; +} + +/** @brief Soft complex conjugate + * + * Will use complex conjugate in operations (add, multiply etc.) + * Does change the state (conj flag), but does not actively change all coefficients. 
+ */ +template void CompFunction::dagger() { + func_ptr->data.conj = not(func_ptr->data.conj); + for (int i = 0; i < Ncomp(); i++) { + if (CompC[i] != nullptr) CompC[i]->setConjugate(func_ptr->data.conj); + } +} + +template FunctionTree &CompFunction::real(int i) { + if (!isreal()) MSG_ABORT("not real function"); + if (CompD[i] == nullptr) alloc_comp(i); + return *CompD[i]; +} +template // NB: should return CompC in the future +FunctionTree &CompFunction::imag(int i) { + MSG_ABORT("Must choose real or complex"); + if (!iscomplex()) MSG_ABORT("not complex function"); + return *CompD[i]; +} + +template FunctionTree &CompFunction::complex(int i) { + if (!iscomplex()) MSG_ABORT("not marked as a complex function"); + if (CompC[i] == nullptr) alloc_comp(i); + return *CompC[i]; +} + +template const FunctionTree &CompFunction::real(int i) const { + if (!isreal()) MSG_ABORT("not real function"); + return *CompD[i]; +} +template // NB: should use complex or real +const FunctionTree &CompFunction::imag(int i) const { + MSG_ABORT("Must choose real or complex"); + if (!iscomplex()) MSG_ABORT("not complex function"); + return *CompD[i]; +} +template const FunctionTree &CompFunction::complex(int i) const { + if (!iscomplex()) MSG_ABORT("not marked as a complex function"); + return *CompC[i]; +} + +/* for backwards compatibility */ +template void CompFunction::setReal(FunctionTree *tree, int i) { + func_ptr->isreal = 1; + // if (CompD[i] != nullptr) delete CompD[i]; + CompD[i] = tree; + if (tree != nullptr) { + func_ptr->Ncomp = std::max(Ncomp(), i + 1); + } else { + func_ptr->Ncomp = std::min(Ncomp(), i); + } +} + +template void CompFunction::setCplx(FunctionTree *tree, int i) { + func_ptr->iscomplex = 1; + // if (CompC[i] != nullptr) delete CompC[i]; + CompC[i] = tree; + if (tree != nullptr) { + func_ptr->Ncomp = std::max(Ncomp(), i + 1); + } else { + func_ptr->Ncomp = std::min(Ncomp(), i); + } +} + +/** @brief In place addition. + * + * Output is extended to union grid. 
+ * + */ +template void CompFunction::add(ComplexDouble c, CompFunction inp) { + + if (Ncomp() < inp.Ncomp()) { + func_ptr->data = inp.func_ptr->data; + alloc(inp.Ncomp(), true); + } + + for (int i = 0; i < inp.Ncomp(); i++) { + if (inp.isreal() and c.imag() < MachineZero) { + CompD[i]->add_inplace(c.real(), *inp.CompD[i]); + } else { + if (this->isreal()) { + CompD[i]->CopyTreeToComplex(CompC[i]); + delete CompD[i]; + CompD[i] = nullptr; + func_ptr->iscomplex = true; + func_ptr->isreal = false; + } + CompC[i]->add_inplace(c, *inp.CompC[i]); + } + } +} + +template int CompFunction::crop(double prec) { + if (prec < 0.0) return 0; + int nChunksremoved = 0; + for (int i = 0; i < Ncomp(); i++) { + if (isreal()) { + nChunksremoved += CompD[i]->crop(prec, 1.0, false); + } else { + nChunksremoved += CompC[i]->crop(prec, 1.0, false); + } + } + return nChunksremoved; +} + +/** @brief In place multiply with scalar. Fully in-place.*/ +template void CompFunction::rescale(ComplexDouble c) { + bool need_to_rescale = not(isShared()) or mpi::share_master(); + if (need_to_rescale) { + for (int i = 0; i < Ncomp(); i++) { + if (iscomplex()) { + CompC[i]->rescale(c); + } else { + if (abs(c.imag()) > MachineZero) { // works only only for NComp==1) + CompD[i]->CopyTreeToComplex(CompC[i]); + delete CompD[i]; + CompD[i] = nullptr; + func_ptr->iscomplex = true; + func_ptr->isreal = false; + CompC[i]->rescale(c); + } else { + CompD[i]->rescale(c.real()); + } + } + } + } else + MSG_ABORT("Not implemented"); +} + +template class MultiResolutionAnalysis<1>; +template class MultiResolutionAnalysis<2>; +template class MultiResolutionAnalysis<3>; +template class CompFunction<1>; +template class CompFunction<2>; +template class CompFunction<3>; + +/** @brief Deep copy + * + * Deep copy: meta data is copied along with the content of each component. 
+ */ +template void deep_copy(CompFunction *out, const CompFunction &inp) { + out->func_ptr->data = inp.func_ptr->data; + out->alloc(inp.Ncomp()); + for (int i = 0; i < inp.Ncomp(); i++) { + if (inp.isreal()) { + inp.CompD[i]->deep_copy(out->CompD[i]); + } else { + inp.CompC[i]->deep_copy(out->CompC[i]); + } + } +} + +/** @brief Deep copy + * + * Deep copy: meta func_ptr->data is copied along with the content of each component. + */ +template void deep_copy(CompFunction &out, const CompFunction &inp) { + out.func_ptr->data = inp.func_ptr->data; + out.alloc(inp.Ncomp()); + for (int i = 0; i < inp.Ncomp(); i++) { + if (inp.isreal()) { + inp.CompD[i]->deep_copy(out.CompD[i]); + } else { + inp.CompC[i]->deep_copy(out.CompC[i]); + } + } +} + +/** @brief out = a*inp_a + b*inp_b + * + * Recast into linear_combination. + * + */ +template void add(CompFunction &out, ComplexDouble a, CompFunction inp_a, ComplexDouble b, CompFunction inp_b, double prec, bool conjugate) { + std::vector coefs(2); + coefs[0] = a; + coefs[1] = b; + + std::vector> funcs; // NB: not a CompFunctionVector, because not run in parallel! + funcs.push_back(inp_a); + funcs.push_back(inp_b); + + linear_combination(out, coefs, funcs, prec, conjugate); +} + +/** @brief out = c_0*inp_0 + c_1*inp_1 + ... 
+ c_N*inp_N + * + * OMP parallel, but not MPI parallel + */ +template void linear_combination(CompFunction &out, const std::vector &c, std::vector> &inp, double prec, bool conjugate) { + double thrs = MachineZero; + bool need_to_add = not(out.isShared()) or mpi::share_master(); + bool share = out.isShared(); + out.func_ptr->data = inp[0].func_ptr->data; + out.func_ptr->data.shared = share; // we don' inherit the shareness + bool iscomplex = false; + for (int i = 0; i < inp.size(); i++) + if (inp[i].iscomplex() or c[i].imag() > MachineZero) iscomplex = true; + if (iscomplex) { + out.func_ptr->data.iscomplex = 1; + out.func_ptr->data.isreal = 0; + } + out.alloc(out.Ncomp()); + for (int comp = 0; comp < inp[0].Ncomp(); comp++) { + if (not iscomplex) { + FunctionTreeVector fvec; // one component vector + for (int i = 0; i < inp.size(); i++) { + if (std::norm(c[i]) < thrs) continue; + if (inp[i].getNNodes() == 0 or inp[i].CompD[comp]->getSquareNorm() < thrs) continue; + fvec.push_back(std::make_tuple(c[i].real(), inp[i].CompD[comp])); + } + if (need_to_add) { + if (fvec.size() > 0) { + if (prec < 0.0) { + build_grid(*out.CompD[comp], fvec); + mrcpp::add(prec, *out.CompD[comp], fvec, 0); + } else { + mrcpp::add(prec, *out.CompD[comp], fvec); + } + } else if (out.isreal()) { + out.CompD[comp]->setZero(); + } + } + } else { + FunctionTreeVector fvec; // one component vector + for (int i = 0; i < inp.size(); i++) { + if (inp[i].isreal()) { + inp[i].CompD[comp]->CopyTreeToComplex(inp[i].CompC[comp]); + delete inp[i].CompD[comp]; + inp[i].CompD[comp] = nullptr; + inp[i].func_ptr->iscomplex = true; + inp[i].func_ptr->isreal = false; + } + if (std::norm(c[i]) < thrs) continue; + if (inp[i].getNNodes() == 0 or inp[i].CompC[comp]->getSquareNorm() < thrs) continue; + fvec.push_back(std::make_tuple(c[i], inp[i].CompC[comp])); + } + if (need_to_add) { + if (fvec.size() > 0) { + if (prec < 0.0) { + build_grid(*out.CompC[comp], fvec); + mrcpp::add(prec, *out.CompC[comp], fvec, 0, 
false, conjugate); + } else { + mrcpp::add(prec, *out.CompC[comp], fvec, -1, false, conjugate); + } + } else if (out.iscomplex()) { + out.CompC[comp]->setZero(); + } + } + } + mpi::share_function(out, 0, 9911, mpi::comm_share); + } +} + +/** @brief out = inp_a * inp_b + * + */ +template void multiply(CompFunction &out, CompFunction inp_a, CompFunction inp_b, double prec, bool absPrec, bool useMaxNorms, bool conjugate) { + multiply(prec, out, 1.0, inp_a, inp_b, -1, absPrec, useMaxNorms, conjugate); +} + +/** @brief out = inp_a * inp_b + * Takes conjugate of inp_a if conjugate=true + * In case of mixed real/complex inputs, the real functions are converted into complex functions. + */ +template void multiply(double prec, CompFunction &out, double coef, CompFunction inp_a, CompFunction inp_b, int maxIter, bool absPrec, bool useMaxNorms, bool conjugate) { + if (inp_b.func_ptr->conj) MSG_ABORT("Not implemented"); + if (inp_a.func_ptr->conj) conjugate = (not conjugate); + bool need_to_multiply = not(out.isShared()) or mpi::share_master(); + bool out_allocated = true; + if (out.Ncomp() == 0) out_allocated = false; + bool share = out.isShared(); + out.func_ptr->data = inp_a.func_ptr->data; + out.func_ptr->data.shared = share; // we don't inherit the shareness + out.func_ptr->conj = false; // we don't inherit conjugaison + for (int comp = 0; comp < inp_a.Ncomp(); comp++) { + out.func_ptr->data.c1[comp] = inp_a.func_ptr->data.c1[comp] * inp_b.func_ptr->data.c1[comp]; // we could put this is coef if everything is real? 
+ if (inp_a.isreal() and inp_b.isreal()) { + if (need_to_multiply) { + if (!out_allocated) out.alloc(out.Ncomp()); + if (prec < 0.0) { + // Union grid + build_grid(*out.CompD[comp], *inp_a.CompD[comp]); + build_grid(*out.CompD[comp], *inp_b.CompD[comp]); + mrcpp::multiply(prec, *out.CompD[comp], coef, *inp_a.CompD[comp], *inp_b.CompD[comp], 0, false, false, conjugate); + } else { + // Adaptive grid + mrcpp::multiply(prec, *out.CompD[comp], coef, *inp_a.CompD[comp], *inp_b.CompD[comp], maxIter, absPrec, useMaxNorms, conjugate); + } + } + } else { + // if one of the input is real, we simply make a new complex copy of it + bool inp_aisReal = inp_a.isreal(); + bool inp_bisReal = inp_b.isreal(); + if (inp_aisReal) { + inp_a.CompD[comp]->CopyTreeToComplex(inp_a.CompC[comp]); + inp_a.func_ptr->iscomplex = true; + inp_a.func_ptr->isreal = false; + } + if (inp_bisReal) { + inp_b.CompD[comp]->CopyTreeToComplex(inp_b.CompC[comp]); + inp_b.func_ptr->iscomplex = true; + inp_b.func_ptr->isreal = false; + } + ComplexDouble coef = 1.0; + if (need_to_multiply) { + if (prec < 0.0) { + // Union grid + out.func_ptr->iscomplex = 1; + out.func_ptr->isreal = 0; + delete out.CompD[comp]; + delete out.CompC[comp]; + if (!out_allocated) out.alloc(out.Ncomp()); + build_grid(*out.CompC[comp], *inp_a.CompC[comp]); + build_grid(*out.CompC[comp], *inp_b.CompC[comp]); + mrcpp::multiply(prec, *out.CompC[comp], coef, *inp_a.CompC[comp], *inp_b.CompC[comp], 0, false, false, conjugate); + } else { // note that this assumes Ncomp=1 + // Adaptive grid + if (out.CompD[comp] != nullptr) { // NB: func_ptr has alreadybeen overwritten! 
+ if (out.CompD[comp]->getNNodes() > 0) { + out.CompD[comp]->CopyTreeToComplex(out.CompC[comp]); + out.func_ptr->iscomplex = 1; + out.func_ptr->isreal = 0; + delete out.CompD[comp]; + out.CompD[comp] = nullptr; + } else { + out.func_ptr->iscomplex = 1; + out.func_ptr->isreal = 0; + out.alloc(out.Ncomp()); + } + } else { + out.func_ptr->iscomplex = 1; + out.func_ptr->isreal = 0; + if (!out_allocated) out.alloc(out.Ncomp()); + } + mrcpp::multiply(prec, *out.CompC[comp], coef, *inp_a.CompC[comp], *inp_b.CompC[comp], maxIter, absPrec, useMaxNorms, conjugate); + } + } + // restore original tree + if (inp_aisReal) { + delete inp_a.CompC[comp]; + inp_a.CompC[comp] = nullptr; + inp_a.func_ptr->iscomplex = false; + inp_a.func_ptr->isreal = true; + } + if (inp_bisReal) { + delete inp_b.CompC[comp]; + inp_b.CompC[comp] = nullptr; + inp_b.func_ptr->iscomplex = false; + inp_b.func_ptr->isreal = true; + } + } + } + mpi::share_function(out, 0, 9911, mpi::comm_share); +} + +/** @brief out = inp_a * f + * + * Only one component is multiplied + */ +template void multiply(CompFunction &out, CompFunction &inp_a, RepresentableFunction &f, double prec, int nrefine, bool conjugate) { + if (inp_a.Ncomp() > 1) MSG_ABORT("Not implemented"); + if (inp_a.isreal() != 1) MSG_ABORT("Not implemented"); + if (conjugate) MSG_ABORT("Not implemented"); + CompFunctionVector CompVec; // Should use vector? 
+ CompVec.push_back(inp_a); + CompFunctionVector CompVecOut; + CompVecOut = multiply(CompVec, f, prec, nullptr, nrefine, true); + out = CompVecOut[0]; + // multiply(out, *inp_a.CompD[0], f, prec, nrefine, conjugate); +} + +/** @brief out = inp_a * f + * + * Only one component is multiplied + */ +template void multiply(CompFunction &out, CompFunction &inp_a, RepresentableFunction &f, double prec, int nrefine, bool conjugate) { + MSG_ABORT("Not implemented"); + if (inp_a.Ncomp() > 1) MSG_ABORT("Not implemented"); + if (inp_a.iscomplex() != 1) MSG_ABORT("Not implemented"); + if (conjugate) MSG_ABORT("Not implemented"); + CompFunctionVector CompVec; // Should use vector? + CompVec.push_back(inp_a); + CompFunctionVector CompVecOut; + // CompVecOut = multiply(CompVec, f, prec, nrefine, true); + out = CompVecOut[0]; +} + +/** @brief out = inp_a * f + * + */ +template void multiply(CompFunction &out, FunctionTree &inp_a, RepresentableFunction &f, double prec, int nrefine, bool conjugate) { + CompFunction func_a; + func_a.func_ptr->isreal = 1; + func_a.func_ptr->iscomplex = 0; + func_a.alloc(1); + func_a.CompD[0] = &inp_a; + multiply(out, func_a, f, prec, nrefine, conjugate); + func_a.CompD[0] = nullptr; +} +template void multiply(CompFunction &out, FunctionTree &inp_a, RepresentableFunction &f, double prec, int nrefine, bool conjugate) { + CompFunction func_a(1); + func_a.func_ptr->isreal = 0; + func_a.func_ptr->iscomplex = 1; + func_a.CompC[0] = &inp_a; + multiply(out, func_a, f, prec, nrefine, conjugate); + func_a.CompC[0] = nullptr; +} + +/** @brief Compute = int bra^\dag(r) * ket(r) dr. + * + * Sum of component dots. 
+ * Notice that the ComplexDouble dot(CompFunction bra, CompFunction ket) { + if (bra.func_ptr->conj or ket.func_ptr->conj) MSG_ABORT("Not implemented"); + ComplexDouble dotprodtot = 0.0; + for (int comp = 0; comp < bra.Ncomp(); comp++) { + ComplexDouble dotprod = 0.0; + if (bra.func_ptr->data.n1[0] != ket.func_ptr->data.n1[0] and bra.func_ptr->data.n1[0] != 0 and ket.func_ptr->data.n1[0] != 0) continue; + if (bra.isreal() and ket.isreal()) { + dotprod += mrcpp::dot(*bra.CompD[comp], *ket.CompD[comp]); + } else if (bra.isreal() and ket.iscomplex()) { + dotprod += mrcpp::dot(*bra.CompD[comp], *ket.CompC[comp]); + } else if (bra.iscomplex() and ket.isreal()) { + dotprod += mrcpp::dot(*bra.CompC[comp], *ket.CompD[comp]); + } else { + dotprod += mrcpp::dot(*bra.CompC[comp], *ket.CompC[comp]); + } + dotprod *= bra.func_ptr->data.c1[comp] * ket.func_ptr->data.c1[comp]; + dotprodtot += dotprod; + } + if (bra.isreal() and ket.isreal()) { + return dotprodtot.real(); + } else { + return dotprodtot; + } +} + +/** @brief Compute = int |bra^\dag(r)| * |ket(r)| dr. 
+ * + * sum of components + */ +template double node_norm_dot(CompFunction bra, CompFunction ket) { + double dotprodtot = 0.0; + for (int comp = 0; comp < bra.Ncomp(); comp++) { + double dotprod = 0.0; + if (bra.isreal() and ket.isreal()) { + dotprod += mrcpp::node_norm_dot(*bra.CompD[comp], *ket.CompD[comp]); + } else if (bra.isreal() and ket.iscomplex()) { + MSG_ABORT("Not implemented"); + } else if (bra.iscomplex() and ket.isreal()) { + MSG_ABORT("Not implemented"); + } else { + dotprod += mrcpp::node_norm_dot(*bra.CompC[comp], *ket.CompC[comp]); + } + dotprod *= std::norm(bra.func_ptr->data.c1[comp]) * std::norm(ket.func_ptr->data.c1[comp]); // for fully complex values this does not really give the norm + dotprodtot += dotprod; + } + return dotprodtot; +} + +void project(CompFunction<3> &out, std::function &r)> f, double prec) { + bool need_to_project = not(out.isShared()) or mpi::share_master(); + out.func_ptr->isreal = 1; + out.func_ptr->iscomplex = 0; + if (out.Ncomp() < 1) out.alloc(1); + if (need_to_project) mrcpp::project<3>(prec, *out.CompD[0], f); + mpi::share_function(out, 0, 123123, mpi::comm_share); +} + +// template +void project(CompFunction<3> &out, std::function &r)> f, double prec) { + bool need_to_project = not(out.isShared()) or mpi::share_master(); + out.func_ptr->isreal = 0; + out.func_ptr->iscomplex = 1; + if (out.Ncomp() < 1) out.alloc(1); + if (need_to_project) mrcpp::project<3>(prec, *out.CompC[0], f); + mpi::share_function(out, 0, 123123, mpi::comm_share); +} + +template void project(CompFunction &out, RepresentableFunction &f, double prec) { + bool need_to_project = not(out.isShared()) or mpi::share_master(); + out.func_ptr->isreal = 1; + out.func_ptr->iscomplex = 0; + if (out.Ncomp() < 1) out.alloc(1); + if (need_to_project) mrcpp::project(prec, *out.CompD[0], f); + mpi::share_function(out, 0, 132231, mpi::comm_share); +} +template void project(CompFunction &out, RepresentableFunction &f, double prec) { + bool need_to_project = 
not(out.isShared()) or mpi::share_master(); + out.func_ptr->isreal = 0; + out.func_ptr->iscomplex = 1; + if (out.Ncomp() < 1) out.alloc(1); + if (need_to_project) mrcpp::project(prec, *out.CompC[0], f); + mpi::share_function(out, 0, 132231, mpi::comm_share); +} + +// CompFunctionVector + +CompFunctionVector::CompFunctionVector(int N) + : std::vector>(N) { + for (int i = 0; i < N; i++) (*this)[i].func_ptr->rank = i; + vecMRA = defaultCompMRA<3>; +} +void CompFunctionVector::distribute() { + for (int i = 0; i < this->size(); i++) (*this)[i].func_ptr->rank = i; +} + +/** @brief Make a linear combination of functions + * + * Uses "local" representation: treats one node at a time. + * For each node, all functions are transformed simultaneously + * by a dense matrix multiplication. + * Phi input functions, Psi output functions + * Phi and Psi are complex. + */ +void rotate_cplx(CompFunctionVector &Phi, const ComplexMatrix &U, CompFunctionVector &Psi, double prec) { + + // The principle of this routine is that nodes for all orbitals are rotated one by one using matrix multiplication. + // The routine does avoid when possible to move data, but uses pointers and indices manipulation. + // MPI version does not use OMP yet, Serial version uses OMP + // size of input is N, size of output is M + bool serial = mpi::wrk_size == 1; // flag for serial/MPI switch + int N = Phi.size(); + int M = Psi.size(); + for (int i = 0; i < M; i++) { + for (int j; j < 4; j++) delete Psi[i].CompD[j]; + Psi[i].func_ptr->isreal = 0; + Psi[i].func_ptr->iscomplex = 1; + } + for (int i = 0; i < N; i++) { + if (Phi[i].func_ptr->conj) MSG_ABORT("Conjugaison not implemneted for rotations"); + } + if (U.rows() < N) MSG_ABORT("Incompatible number of rows for U matrix"); + if (U.cols() < M) MSG_ABORT("Incompatible number of columns for U matrix"); + + // 1) make union tree without coefficients. 
Note that the ref tree is always real (in fact it has no coeff) + FunctionTree<3> refTree(*Phi.vecMRA); + mpi::allreduce_Tree_noCoeff(refTree, Phi, mpi::comm_wrk); + + int sizecoeff = (1 << refTree.getDim()) * refTree.getKp1_d(); + int sizecoeffW = ((1 << refTree.getDim()) - 1) * refTree.getKp1_d(); + std::vector scalefac_ref; + std::vector coeffVec_ref; // not used! + std::vector indexVec_ref; // serialIx of the nodes + std::vector parindexVec_ref; // serialIx of the parent nodes + int max_ix; + // get a list of all nodes in union tree, identified by their serialIx indices + refTree.makeCoeffVector(coeffVec_ref, indexVec_ref, parindexVec_ref, scalefac_ref, max_ix, refTree); + int max_n = indexVec_ref.size(); + + for (int j = 0; j < N; j++) { + if (!mpi::my_func(j)) continue; + if (Phi[j].isreal()) MSG_ABORT("This function only use complex input"); + } + + for (int i = 0; i < M; i++) { + Psi[i].func_ptr->data.isreal = 0; + Psi[i].func_ptr->data.iscomplex = 1; + } + + // 3) In the serial case we store the coeff pointers in coeffVec. 
In the mpi case the coeff are stored in the bank + + BankAccount nodesPhi; // to put the original nodes + BankAccount nodesRotated; // to put the rotated nodes + + // used for serial only: + std::vector> coeffVec(N); + std::vector> indexVec(N); // serialIx of the nodes + std::map> node2orbVec; // for each node index, gives a vector with the indices of the orbitals using this node + std::vector> orb2node(N); // for a given orbital and a given node, gives the node index in the + // orbital given the node index in the reference tree + if (serial) { + // make list of all coefficients (coeffVec), and their reference indices (indexVec) + std::vector parindexVec; // serialIx of the parent nodes + std::vector scalefac; + for (int j = 0; j < N; j++) { + // make vector with all coef pointers and their indices in the union grid + Phi[j].complex().makeCoeffVector(coeffVec[j], indexVec[j], parindexVec, scalefac, max_ix, refTree); + // make a map that gives j from indexVec + int orb_node_ix = 0; + for (int ix : indexVec[j]) { + orb2node[j][ix] = orb_node_ix++; + if (ix < 0) continue; + node2orbVec[ix].push_back(j); + } + } + } else { // MPI case + // send own nodes to bank, identifying them through the serialIx of refTree + save_nodes(Phi, refTree, nodesPhi); + mpi::barrier(mpi::comm_wrk); // required for now, as the blockdata functionality has no queue yet. + } + + // 4) rotate all the nodes + IntMatrix split_serial; // in the serial case all split are stored in one array + std::vector> coeffpVec(M); // to put pointers to the rotated coefficient for each orbital in serial case + std::vector> ix2coef(M); // to find the index in for example rotCoeffVec[] corresponding to a serialIx + int csize; // size of the current coefficients (different for roots and branches) + std::vector rotatedCoeffVec; // just to ensure that the data from rotatedCoeff is not deleted, since we point to it. 
+ // j indices are for unrotated orbitals, i indices are for rotated orbitals + if (serial) { + std::map ix2coef_ref; // to find the index n corresponding to a serialIx + split_serial.resize(M, max_n); // not use in the MPI case + for (int n = 0; n < max_n; n++) { + int node_ix = indexVec_ref[n]; // SerialIx for this node in the reference tree + ix2coef_ref[node_ix] = n; + for (int i = 0; i < M; i++) split_serial(i, n) = 1; + } + std::vector nodeReady(max_n, 0); // To indicate to OMP threads that the parent is ready (for splits) + // assumes the nodes are ordered such that parent are treated before children. BFS or DFS ok. + // NB: the n must be traversed approximately in right order: Thread n may have to wait until som other preceding + // n is finished. +#pragma omp parallel for schedule(dynamic) + for (int n = 0; n < max_n; n++) { + int csize; + int node_ix = indexVec_ref[n]; // SerialIx for this node in the reference tree + // 4a) make a dense contiguous matrix with the coefficient from all the orbitals using node n + std::vector orbjVec; // to remember which orbital correspond to each orbVec.size(); + if (node2orbVec[node_ix].size() <= 0) continue; + csize = sizecoeffW; + if (parindexVec_ref[n] < 0) csize = sizecoeff; // for root nodes we include scaling coeff + + int shift = sizecoeff - sizecoeffW; // to copy only wavelet part + if (parindexVec_ref[n] < 0) shift = 0; + ComplexMatrix coeffBlock(csize, node2orbVec[node_ix].size()); + for (int j : node2orbVec[node_ix]) { // loop over indices of the orbitals using this node + int orb_node_ix = orb2node[j][node_ix]; + for (int k = 0; k < csize; k++) coeffBlock(k, orbjVec.size()) = coeffVec[j][orb_node_ix][k + shift]; + orbjVec.push_back(j); + } + + // 4b) make a list of rotated orbitals needed for this node + // OMP must wait until parent is ready + while (parindexVec_ref[n] >= 0 and nodeReady[ix2coef_ref[parindexVec_ref[n]]] == 0) { +#pragma omp flush + }; + + std::vector orbiVec; + for (int i = 0; i < M; i++) { 
// loop over all rotated orbitals + if (parindexVec_ref[n] >= 0 and split_serial(i, ix2coef_ref[parindexVec_ref[n]]) == 0) continue; // parent node has too small wavelets + orbiVec.push_back(i); + } + + // 4c) rotate this node + ComplexMatrix Un(orbjVec.size(), orbiVec.size()); // chunk of U, with reorganized indices + for (int i = 0; i < orbiVec.size(); i++) { // loop over rotated orbitals + for (int j = 0; j < orbjVec.size(); j++) { Un(j, i) = U(orbjVec[j], orbiVec[i]); } + } + ComplexMatrix rotatedCoeff(csize, orbiVec.size()); + // HERE IT HAPPENS! + // TODO: conjugaison + rotatedCoeff.noalias() = coeffBlock * Un; // Matrix mutiplication + + // 4d) store and make rotated node pointers + // for now we allocate in buffer, in future could be directly allocated in the final trees + double thres = prec * prec * scalefac_ref[n] * scalefac_ref[n]; + // make all norms: + for (int i = 0; i < orbiVec.size(); i++) { + // check if parent must be split + if (parindexVec_ref[n] == -1 or split_serial(orbiVec[i], ix2coef_ref[parindexVec_ref[n]])) { + // mark this node for this orbital for later split +#pragma omp critical + { + ix2coef[orbiVec[i]][node_ix] = coeffpVec[orbiVec[i]].size(); + coeffpVec[orbiVec[i]].push_back(&(rotatedCoeff(0, i))); // list of coefficient pointers + } + // check norms for split + double wnorm = 0.0; // rotatedCoeff(k, i) is already in cache here + int kstart = 0; + if (parindexVec_ref[n] < 0) kstart = sizecoeff - sizecoeffW; // do not include scaling, even for roots + for (int k = kstart; k < csize; k++) wnorm += std::real(rotatedCoeff(k, i) * std::conj(rotatedCoeff(k, i))); + if (thres < wnorm or prec < 0) + split_serial(orbiVec[i], n) = 1; + else + split_serial(orbiVec[i], n) = 0; + } else { + ix2coef[orbiVec[i]][node_ix] = max_n + 1; // should not be used + split_serial(orbiVec[i], n) = 0; // do not split if parent does not need to be split + } + } + nodeReady[n] = 1; +#pragma omp critical + { + // this ensures that rotatedCoeff is not deleted, 
when getting out of scope + rotatedCoeffVec.push_back(std::move(rotatedCoeff)); + } + } + } else { // MPI case + + // TODO? rotate in bank, so that we do not get and put. Requires clever handling of splits. + std::vector split(M, -1.0); // which orbitals need splitting (at a given node). For now double for compatibilty with bank + std::vector needsplit(M, 1.0); // which orbitals need splitting + BankAccount nodeSplits; + mpi::barrier(mpi::comm_wrk); // required for now, as the blockdata functionality has no queue yet. + + ComplexMatrix coeffBlock(sizecoeff, N); + max_ix++; // largest node index + 1. to store rotated orbitals with different id + TaskManager tasks(max_n); + for (int nn = 0; nn < max_n; nn++) { + int n = tasks.next_task(); + if (n < 0) break; + double thres = prec * prec * scalefac_ref[n] * scalefac_ref[n]; + // 4a) make list of orbitals that should split the parent node, i.e. include this node + int parentid = parindexVec_ref[n]; + if (parentid == -1) { + // root node, split if output needed + for (int i = 0; i < M; i++) { split[i] = 1.0; } + csize = sizecoeff; + } else { + // note that it will wait until data is available + nodeSplits.get_data(parentid, M, split.data()); + csize = sizecoeffW; + } + std::vector orbiVec; + std::vector orbjVec; + for (int i = 0; i < M; i++) { // loop over rotated orbitals + if (split[i] < 0.0) continue; // parent node has too small wavelets + orbiVec.push_back(i); + } + + // 4b) rotate this node + ComplexMatrix coeffBlock(csize, N); // largest possible used size + nodesPhi.get_nodeblock(indexVec_ref[n], coeffBlock.data(), orbjVec); + coeffBlock.conservativeResize(Eigen::NoChange, orbjVec.size()); // keep only used part + + // chunk of U, with reorganized indices and separate blocks for real and imag: + ComplexMatrix Un(orbjVec.size(), orbiVec.size()); + ComplexMatrix rotatedCoeff(csize, orbiVec.size()); + + for (int i = 0; i < orbiVec.size(); i++) { // loop over included rotated real and imag part of orbitals + for 
(int j = 0; j < orbjVec.size(); j++) { // loop over input orbital, possibly imaginary parts + Un(j, i) = U(orbjVec[j], orbiVec[i]); + } + } + + // HERE IT HAPPENS + // TODO conjugaison + rotatedCoeff.noalias() = coeffBlock * Un; // Matrix mutiplication + + // 3c) find which orbitals need to further refine this node, and store rotated node (after each other while + // in cache). + for (int i = 0; i < orbiVec.size(); i++) { // loop over rotated orbitals + needsplit[orbiVec[i]] = -1.0; // default, do not split + // check if this node/orbital needs further refinement + double wnorm = 0.0; + int kwstart = csize - sizecoeffW; // do not include scaling + for (int k = kwstart; k < csize; k++) wnorm += std::real(rotatedCoeff.col(i)[k] * std::conj(rotatedCoeff.col(i)[k])); + if (thres < wnorm or prec < 0) needsplit[orbiVec[i]] = 1.0; + nodesRotated.put_nodedata(orbiVec[i], indexVec_ref[n] + max_ix, csize, rotatedCoeff.col(i).data()); + } + nodeSplits.put_data(indexVec_ref[n], M, needsplit.data()); + } + mpi::barrier(mpi::comm_wrk); // wait until all rotated nodes are ready + } + + // 5) reconstruct trees using rotated nodes. + + // only serial case can use OMP, because MPI cannot be used by threads + if (serial) { + // OMP parallelized, but does not scale well, because the total memory bandwidth is a bottleneck. 
(the main + // operation is writing the coefficient into the tree) + +#pragma omp parallel for schedule(static) + for (int j = 0; j < M; j++) { + if (coeffpVec[j].size() == 0) continue; + Psi[j].alloc(1); // All data is stored in coeffpVec[j] + Psi[j].complex().makeTreefromCoeff(refTree, coeffpVec[j], ix2coef[j], prec); + } + } else { // MPI case + for (int j = 0; j < M; j++) { + if (not mpi::my_func(j)) continue; + // traverse possible nodes, and stop descending when norm is zero (leaf in out[j]) + std::vector coeffpVec; // + std::map ix2coef; // to find the index in coeffVec[] corresponding to a serialIx + int ix = 0; + std::vector pointerstodelete; // list of temporary arrays to clean up + for (int ibank = 0; ibank < mpi::bank_size; ibank++) { + std::vector nodeidVec; + ComplexDouble *dataVec; // will be allocated by bank + nodesRotated.get_orbblock(j, dataVec, nodeidVec, ibank); + if (nodeidVec.size() > 0) pointerstodelete.push_back(dataVec); + int shift = 0; + for (int n = 0; n < nodeidVec.size(); n++) { + assert(nodeidVec[n] - max_ix >= 0); // unrotated nodes have been deleted + assert(ix2coef.count(nodeidVec[n] - max_ix) == 0); // each nodeid treated once + ix2coef[nodeidVec[n] - max_ix] = ix++; + csize = sizecoeffW; + if (parindexVec_ref[nodeidVec[n] - max_ix] < 0) csize = sizecoeff; + coeffpVec.push_back(&dataVec[shift]); // list of coeff pointers + shift += csize; + } + } + + Psi[j].alloc(1); + Psi[j].complex().makeTreefromCoeff(refTree, coeffpVec, ix2coef, prec); + + for (ComplexDouble *p : pointerstodelete) delete[] p; + pointerstodelete.clear(); + } + } +} + +/** @brief Make a linear combination of functions + * + * Uses "local" representation: treats one node at a time. + * For each node, all functions are transformed simultaneously + * by a dense matrix multiplication. 
+ * Phi input functions, Psi output functions + * Only the real part of U is applied in this routine; if Phi[0] is complex the work is delegated to rotate_cplx. + * prec sets the wavelet-norm threshold (prec^2 * scalefac^2) deciding whether a rotated node is split further; prec < 0 marks every treated node for splitting. + * + */ +void rotate(CompFunctionVector &Phi, const ComplexMatrix &U, CompFunctionVector &Psi, double prec) { + + if (Phi[0].iscomplex()) { + rotate_cplx(Phi, U, Psi, prec); + return; + } + + // The principle of this routine is that nodes are rotated one by one using matrix multiplication. + // The routine does avoid when possible to move data, but uses pointers and indices manipulation. + // MPI version does not use OMP yet, Serial version uses OMP + // size of input is N, size of output is M + int N = Phi.size(); + int M = Psi.size(); + if (U.rows() < N) MSG_ABORT("Incompatible number of rows for U matrix"); + if (U.cols() < M) MSG_ABORT("Incompatible number of columns for U matrix"); + + // 1) make union tree without coefficients. Note that the ref tree is always real (in fact it has no coeff) + FunctionTree<3> refTree(*Phi.vecMRA); + mpi::allreduce_Tree_noCoeff(refTree, Phi, mpi::comm_wrk); + + int sizecoeff = (1 << refTree.getDim()) * refTree.getKp1_d(); + int sizecoeffW = ((1 << refTree.getDim()) - 1) * refTree.getKp1_d(); + std::vector scalefac_ref; + std::vector coeffVec_ref; // not used! + std::vector indexVec_ref; // serialIx of the nodes + std::vector parindexVec_ref; // serialIx of the parent nodes + int max_ix; + // get a list of all nodes in union tree, identified by their serialIx indices + refTree.makeCoeffVector(coeffVec_ref, indexVec_ref, parindexVec_ref, scalefac_ref, max_ix, refTree); + int max_n = indexVec_ref.size(); + for (int i = 0; i < M; i++) { + Psi[i].func_ptr->data.isreal = 1; + Psi[i].func_ptr->data.iscomplex = 0; + } + + // 3) In the serial case we store the coeff pointers in coeffVec. 
In the mpi case the coeff are stored in the bank + + bool serial = mpi::wrk_size == 1; // flag for serial/MPI switch + BankAccount nodesPhi; // to put the original nodes + BankAccount nodesRotated; // to put the rotated nodes + + // used for serial only: + std::vector> coeffVec(N); + std::vector> indexVec(N); // serialIx of the nodes + std::map> node2orbVec; // for each node index, gives a vector with the indices of the orbitals using this node + std::vector> orb2node(N); // for a given orbital and a given node, gives the node index in the + // orbital given the node index in the reference tree + if (serial) { + + // make list of all coefficients (coeffVec), and their reference indices (indexVec) + std::vector parindexVec; // serialIx of the parent nodes + std::vector scalefac; + for (int j = 0; j < N; j++) { + // make vector with all coef pointers and their indices in the union grid + Phi[j].real().makeCoeffVector(coeffVec[j], indexVec[j], parindexVec, scalefac, max_ix, refTree); + // make a map that gives j from indexVec + int orb_node_ix = 0; + for (int ix : indexVec[j]) { + orb2node[j][ix] = orb_node_ix++; + if (ix < 0) continue; + node2orbVec[ix].push_back(j); + } + } + } else { // MPI case + // send own nodes to bank, identifying them through the serialIx of refTree + save_nodes(Phi, refTree, nodesPhi); + mpi::barrier(mpi::comm_wrk); // required for now, as the blockdata functionality has no queue yet. + } + + // 4) rotate all the nodes + IntMatrix split_serial; // in the serial case all split are stored in one array + std::vector> coeffpVec(M); // to put pointers to the rotated coefficient for each orbital in serial case + std::vector> ix2coef(M); // to find the index in for example rotCoeffVec[] corresponding to a serialIx + int csize; // size of the current coefficients (different for roots and branches) + std::vector rotatedCoeffVec; // just to ensure that the data from rotatedCoeff is not deleted, since we point to it. 
+ // j indices are for unrotated orbitals, i indices are for rotated orbitals + if (serial) { + std::map ix2coef_ref; // to find the index n corresponding to a serialIx + split_serial.resize(M, max_n); // not used in the MPI case + for (int n = 0; n < max_n; n++) { + int node_ix = indexVec_ref[n]; // SerialIx for this node in the reference tree + ix2coef_ref[node_ix] = n; + for (int i = 0; i < M; i++) split_serial(i, n) = 1; + } + + std::vector nodeReady(max_n, 0); // To indicate to OMP threads that the parent is ready (for splits) + + // assumes the nodes are ordered such that parent are treated before children. BFS or DFS ok. + // NB: the n must be traversed approximately in right order: Thread n may have to wait until some other preceding + // n is finished. +#pragma omp parallel for schedule(dynamic) + for (int n = 0; n < max_n; n++) { + int csize; + int node_ix = indexVec_ref[n]; // SerialIx for this node in the reference tree + // 4a) make a dense contiguous matrix with the coefficient from all the orbitals using node n + std::vector orbjVec; // to remember which orbital correspond to each orbVec.size(); + if (node2orbVec[node_ix].size() <= 0) continue; + csize = sizecoeffW; + if (parindexVec_ref[n] < 0) csize = sizecoeff; // for root nodes we include scaling coeff + + int shift = sizecoeff - sizecoeffW; // to copy only wavelet part + if (parindexVec_ref[n] < 0) shift = 0; + DoubleMatrix coeffBlock(csize, node2orbVec[node_ix].size()); + for (int j : node2orbVec[node_ix]) { // loop over indices of the orbitals using this node + int orb_node_ix = orb2node[j][node_ix]; + for (int k = 0; k < csize; k++) coeffBlock(k, orbjVec.size()) = coeffVec[j][orb_node_ix][k + shift]; + orbjVec.push_back(j); + } + + // 4b) make a list of rotated orbitals needed for this node + // OMP must wait until parent is ready + while (parindexVec_ref[n] >= 0 and nodeReady[ix2coef_ref[parindexVec_ref[n]]] == 0) { +#pragma omp flush + }; + + std::vector orbiVec; + for (int i = 0; i < M; i++) 
{ // loop over all rotated orbitals + if (parindexVec_ref[n] >= 0 and split_serial(i, ix2coef_ref[parindexVec_ref[n]]) == 0) continue; // parent node has too small wavelets + orbiVec.push_back(i); + } + + // 4c) rotate this node + DoubleMatrix Un(orbjVec.size(), orbiVec.size()); // chunk of U, with reorganized indices + for (int i = 0; i < orbiVec.size(); i++) { // loop over rotated orbitals + for (int j = 0; j < orbjVec.size(); j++) { Un(j, i) = std::real(U(orbjVec[j], orbiVec[i])); } + } + DoubleMatrix rotatedCoeff(csize, orbiVec.size()); + // HERE IT HAPPENS! + rotatedCoeff.noalias() = coeffBlock * Un; // Matrix multiplication + + // 4d) store and make rotated node pointers + // for now we allocate in buffer, in future could be directly allocated in the final trees + double thres = prec * prec * scalefac_ref[n] * scalefac_ref[n]; + // make all norms: + for (int i = 0; i < orbiVec.size(); i++) { + // check if parent must be split + if (parindexVec_ref[n] == -1 or split_serial(orbiVec[i], ix2coef_ref[parindexVec_ref[n]])) { + // mark this node for this orbital for later split +#pragma omp critical + { + ix2coef[orbiVec[i]][node_ix] = coeffpVec[orbiVec[i]].size(); + coeffpVec[orbiVec[i]].push_back(&(rotatedCoeff(0, i))); // list of coefficient pointers + } + // check norms for split + double wnorm = 0.0; // rotatedCoeff(k, i) is already in cache here + int kstart = 0; + if (parindexVec_ref[n] < 0) kstart = sizecoeff - sizecoeffW; // do not include scaling, even for roots + for (int k = kstart; k < csize; k++) wnorm += rotatedCoeff(k, i) * rotatedCoeff(k, i); + if (thres < wnorm or prec < 0) + split_serial(orbiVec[i], n) = 1; + else + split_serial(orbiVec[i], n) = 0; + } else { + ix2coef[orbiVec[i]][node_ix] = max_n + 1; // should not be used + split_serial(orbiVec[i], n) = 0; // do not split if parent does not need to be split + } + } + nodeReady[n] = 1; +#pragma omp critical + { + // this ensures that rotatedCoeff is not deleted, when getting out of scope + 
rotatedCoeffVec.push_back(std::move(rotatedCoeff)); + } + } + } else { // MPI case + + // TODO? rotate in bank, so that we do not get and put. Requires clever handling of splits. + std::vector split(M, -1.0); // which orbitals need splitting (at a given node). For now double for compatibility with bank + std::vector needsplit(M, 1.0); // which orbitals need splitting + BankAccount nodeSplits; + mpi::barrier(mpi::comm_wrk); // required for now, as the blockdata functionality has no queue yet. + + DoubleMatrix coeffBlock(sizecoeff, N); + max_ix++; // largest node index + 1. to store rotated orbitals with different id + TaskManager tasks(max_n); + for (int nn = 0; nn < max_n; nn++) { + int n = tasks.next_task(); + if (n < 0) break; + double thres = prec * prec * scalefac_ref[n] * scalefac_ref[n]; + // 4a) make list of orbitals that should split the parent node, i.e. include this node + int parentid = parindexVec_ref[n]; + if (parentid == -1) { + // root node, split if output needed + for (int i = 0; i < M; i++) { split[i] = 1.0; } + csize = sizecoeff; + } else { + // note that it will wait until data is available + nodeSplits.get_data(parentid, M, split.data()); + csize = sizecoeffW; + } + std::vector orbiVec; + std::vector orbjVec; + for (int i = 0; i < M; i++) { // loop over rotated orbitals + if (split[i] < 0.0) continue; // parent node has too small wavelets + orbiVec.push_back(i); + } + + // 4b) rotate this node + DoubleMatrix coeffBlock(csize, N); // largest possible used size + nodesPhi.get_nodeblock(indexVec_ref[n], coeffBlock.data(), orbjVec); + coeffBlock.conservativeResize(Eigen::NoChange, orbjVec.size()); // keep only used part + + // chunk of U, with reorganized indices and separate blocks for real and imag: + DoubleMatrix Un(orbjVec.size(), orbiVec.size()); + DoubleMatrix rotatedCoeff(csize, orbiVec.size()); + + for (int i = 0; i < orbiVec.size(); i++) { // loop over included rotated real and imag part of orbitals + for (int j = 0; j < orbjVec.size(); 
j++) { // loop over input orbital, possibly imaginary parts + Un(j, i) = std::real(U(orbjVec[j], orbiVec[i])); + } + } + + // HERE IT HAPPENS + rotatedCoeff.noalias() = coeffBlock * Un; // Matrix multiplication + + // 4c) find which orbitals need to further refine this node, and store rotated node (after each other while + // in cache). + for (int i = 0; i < orbiVec.size(); i++) { // loop over rotated orbitals + needsplit[orbiVec[i]] = -1.0; // default, do not split + // check if this node/orbital needs further refinement + double wnorm = 0.0; + int kwstart = csize - sizecoeffW; // do not include scaling + for (int k = kwstart; k < csize; k++) wnorm += rotatedCoeff.col(i)[k] * rotatedCoeff.col(i)[k]; + if (thres < wnorm or prec < 0) needsplit[orbiVec[i]] = 1.0; + nodesRotated.put_nodedata(orbiVec[i], indexVec_ref[n] + max_ix, csize, rotatedCoeff.col(i).data()); + } + nodeSplits.put_data(indexVec_ref[n], M, needsplit.data()); + } + mpi::barrier(mpi::comm_wrk); // wait until all rotated nodes are ready + } + + // 5) reconstruct trees using rotated nodes. + + // only serial case can use OMP, because MPI cannot be used by threads + if (serial) { + // OMP parallelized, but does not scale well, because the total memory bandwidth is a bottleneck. 
(the main + // operation is writing the coefficient into the tree) + +#pragma omp parallel for schedule(static) + for (int j = 0; j < M; j++) { + if (coeffpVec[j].size() == 0) continue; + Psi[j].alloc(1); + Psi[j].real().clear(); + Psi[j].real().makeTreefromCoeff(refTree, coeffpVec[j], ix2coef[j], prec); + } + + } else { // MPI case + + for (int j = 0; j < M; j++) { + if (not mpi::my_func(j)) continue; + // traverse possible nodes, and stop descending when norm is zero (leaf in out[j]) + std::vector coeffpVec; // + std::map ix2coef; // to find the index in coeffVec[] corresponding to a serialIx + int ix = 0; + std::vector pointerstodelete; // list of temporary arrays to clean up + for (int ibank = 0; ibank < mpi::bank_size; ibank++) { + std::vector nodeidVec; + double *dataVec; // will be allocated by bank + nodesRotated.get_orbblock(j, dataVec, nodeidVec, ibank); + if (nodeidVec.size() > 0) pointerstodelete.push_back(dataVec); + int shift = 0; + for (int n = 0; n < nodeidVec.size(); n++) { + assert(nodeidVec[n] - max_ix >= 0); // unrotated nodes have been deleted + assert(ix2coef.count(nodeidVec[n] - max_ix) == 0); // each nodeid treated once + ix2coef[nodeidVec[n] - max_ix] = ix++; + csize = sizecoeffW; + if (parindexVec_ref[nodeidVec[n] - max_ix] < 0) csize = sizecoeff; + coeffpVec.push_back(&dataVec[shift]); // list of coeff pointers + shift += csize; + } + } + Psi[j].alloc(1); + Psi[j].real().makeTreefromCoeff(refTree, coeffpVec, ix2coef, prec); + + for (double *p : pointerstodelete) delete[] p; + pointerstodelete.clear(); + } + } +} + +/** @brief In-place variant: rotates Phi by U and stores the result back in Phi + */ +void rotate(CompFunctionVector &Phi, const ComplexMatrix &U, double prec) { + rotate(Phi, U, Phi, prec); + return; +} + +/** @brief Save all nodes in bank; identify them using serialIx from refTree + * Only the functions for which mpi::my_func(j) is true are sent. 
+ * sizes: if positive, every node is stored with exactly this many coefficients; + * otherwise root nodes store sizecoeff coefficients and all other nodes only their sizecoeffW wavelet coefficients. + */ +void save_nodes(CompFunctionVector &Phi, FunctionTree<3> &refTree, BankAccount &account, int sizes) { + int sizecoeff = (1 << refTree.getDim()) * refTree.getKp1_d(); + int sizecoeffW = ((1 
<< refTree.getDim()) - 1) * refTree.getKp1_d(); + int max_nNodes = refTree.getNNodes(); + std::vector coeffVec; + std::vector coeffVec_cplx; + std::vector scalefac; + std::vector indexVec; // SerialIx of the node in refOrb + std::vector parindexVec; // SerialIx of the parent node + int N = Phi.size(); + int max_ix; + for (int j = 0; j < N; j++) { + if (not mpi::my_func(j)) continue; + // make vector with all coef address and their index in the union grid + if (Phi[j].isreal()) { + Phi[j].real().makeCoeffVector(coeffVec, indexVec, parindexVec, scalefac, max_ix, refTree); + int max_n = indexVec.size(); + // send node coefs from Phi[j] to bank + // except for the root nodes, only wavelets are sent + for (int i = 0; i < max_n; i++) { + if (indexVec[i] < 0) continue; // nodes that are not in refOrb + int csize = sizecoeffW; + if (parindexVec[i] < 0) csize = sizecoeff; + if (sizes > 0) { // fixed size + account.put_nodedata(j, indexVec[i], sizes, coeffVec[i]); + } else { + account.put_nodedata(j, indexVec[i], csize, &(coeffVec[i][sizecoeff - csize])); + } + } + } + // Complex components + if (Phi[j].iscomplex()) { + Phi[j].complex().makeCoeffVector(coeffVec_cplx, indexVec, parindexVec, scalefac, max_ix, refTree); + int max_n = indexVec.size(); + // send node coefs from Phi[j] to bank + for (int i = 0; i < max_n; i++) { + if (indexVec[i] < 0) continue; // nodes that are not in refOrb + // NOTE(review): an earlier remark here said indexVec[i] must be shifted to avoid colliding with the real-part nodes, but no shift is applied below - confirm that real and complex nodes are never stored for the same function + int csize = sizecoeffW; + if (parindexVec[i] < 0) csize = sizecoeff; + if (sizes > 0) { // fixed size + account.put_nodedata(j, indexVec[i], sizes, coeffVec_cplx[i]); + } else { + account.put_nodedata(j, indexVec[i], csize, &(coeffVec_cplx[i][sizecoeff - csize])); + } + } + } + } +} + +/** @brief Multiply all orbitals with a function + * + * @param Phi: orbitals to multiply + 
* @param f : function to multiply + * + * Computes the product of each orbital with a function + * in parallel 
using a local representation. + * Input trees are extended by one scale at most. + */ +CompFunctionVector multiply(CompFunctionVector &Phi, RepresentableFunction<3> &f, double prec, CompFunction<3> *Func, int nrefine, bool all) { + int N = Phi.size(); + const int D = 3; + bool serial = mpi::wrk_size == 1; // flag for serial/MPI switch + // 1a) extend grid where f is large (around nuclei) + // TODO: do it in save_nodes + refTree, only saving the extra nodes, without keeping them permanently. Or refine refTree? + + for (int i = 0; i < N; i++) { + if (!mpi::my_func(i)) continue; + int irefine = 0; + while (Phi[i].isreal() and irefine < nrefine and refine_grid(Phi[i].real(), f) > 0) irefine++; + if (Phi[i].iscomplex()) MSG_ABORT("Not yet implemented"); + irefine = 0; + // while (Phi[i].iscomplex() and irefine < nrefine and refine_grid(Phi[i].complex(), f) > 0) irefine++; + } + + // 1b) make union tree without coefficients + FunctionTree refTree(*Phi.vecMRA); + // refine_grid(refTree, f); //to test + mpi::allreduce_Tree_noCoeff(refTree, Phi, mpi::comm_wrk); + + int kp1 = refTree.getKp1(); + int kp1_d = refTree.getKp1_d(); + int nCoefs = refTree.getTDim() * kp1_d; + + IntVector PsihasReIm = IntVector::Zero(2); + for (int i = 0; i < N; i++) { + if (!mpi::my_func(i)) continue; + PsihasReIm[0] = (Phi[i].hasReal()) ? 1 : 0; + PsihasReIm[1] = (Phi[i].hasImag()) ? 1 : 0; + } + mpi::allreduce_vector(PsihasReIm, mpi::comm_wrk); + CompFunctionVector out(N); + for (int i = 0; i < N; i++) { out[0] = Phi[i].paramCopy(); } + if (not PsihasReIm[0] and not PsihasReIm[1]) { + return out; // do nothing + } + + std::vector scalefac_ref; + std::vector coeffVec_ref; // not used! 
+ std::vector indexVec_ref; // serialIx of the nodes + std::vector parindexVec_ref; // serialIx of the parent nodes + std::vector *> refNodes; // pointers to nodes + int max_ix; + // get a list of all nodes in union tree, identified by their serialIx indices + refTree.makeCoeffVector(coeffVec_ref, indexVec_ref, parindexVec_ref, scalefac_ref, max_ix, refTree, &refNodes); + int max_n = indexVec_ref.size(); + std::map ix2n; // for a given serialIx, give index in vectors + for (int nn = 0; nn < max_n; nn++) ix2n[indexVec_ref[nn]] = nn; + + // 2a) send own nodes to bank, identifying them through the serialIx of refTree + BankAccount nodesPhi; // to put the original nodes + BankAccount nodesMultiplied; // to put the multiplied nodes + + // used for serial only: + std::vector> coeffVec(N); + std::vector> indexVec(N); // serialIx of the nodes + std::map> node2orbVec; // for each node index, gives a vector with the indices of the orbitals using this node + std::vector> orb2node(N); // for a given orbital and a given node, gives the node index in the + // orbital given the node index in the reference tree + if (serial) { + // make list of all coefficients (coeffVec), and their reference indices (indexVec) + std::vector parindexVec; // serialIx of the parent nodes + std::vector scalefac; + for (int j = 0; j < N; j++) { + // make vector with all coef pointers and their indices in the union grid + if (Phi[j].hasReal()) { + Phi[j].real().makeCoeffVector(coeffVec[j], indexVec[j], parindexVec, scalefac, max_ix, refTree); + // make a map that gives j from indexVec + int orb_node_ix = 0; + for (int ix : indexVec[j]) { + orb2node[j][ix] = orb_node_ix++; + if (ix < 0) continue; + node2orbVec[ix].push_back(j); + } + } + if (Phi[j].hasImag()) { + Phi[j].imag().makeCoeffVector(coeffVec[j + N], indexVec[j + N], parindexVec, scalefac, max_ix, refTree); + // make a map that gives j from indexVec + int orb_node_ix = 0; + for (int ix : indexVec[j + N]) { + orb2node[j + N][ix] = orb_node_ix++; 
+ if (ix < 0) continue; + node2orbVec[ix].push_back(j + N); + } + } + } + } else { + save_nodes(Phi, refTree, nodesPhi, nCoefs); + mpi::barrier(mpi::comm_wrk); // required for now, as the blockdata functionality has no queue yet. + } + + // 2b) save Func in bank and remove its coefficients + if (Func != nullptr and !serial) { + // put Func in local representation if not already done + if (!Func->real().isLocal) { Func->real().saveNodesAndRmCoeff(); } + } + + // 3) mutiply for each node + std::vector> coeffpVec(N); // to put pointers to the multiplied coefficient for each orbital in serial case + std::vector multipliedCoeffVec; // just to ensure that the data from multipliedCoeff is not deleted, since we point to it. + std::vector> ix2coef(N); // to find the index in for example rotCoeffVec[] corresponding to a serialIx + DoubleVector NODEP = DoubleVector::Zero(nCoefs); + DoubleVector NODEF = DoubleVector::Zero(nCoefs); + + if (serial) { +#pragma omp parallel for schedule(dynamic) + for (int n = 0; n < max_n; n++) { + MWNode node(*(refNodes[n]), false); + int node_ix = indexVec_ref[n]; // SerialIx for this node in the reference tree + + // 3a) make values for f at this node + // 3a1) get coordinates of quadrature points for this node + Eigen::MatrixXd pts; // Eigen::Zero(D, nCoefs); + double fval[nCoefs]; + Coord r; + double *originalCoef = nullptr; + MWNode<3> *Fnode = nullptr; + if (Func == nullptr) { + node.getExpandedChildPts(pts); // TODO: use getPrimitiveChildPts (less cache). + for (int j = 0; j < nCoefs; j++) { + for (int d = 0; d < D; d++) r[d] = pts(d, j); //*scaling_factor[d]? + fval[j] = f.evalf(r); + } + } else { + Fnode = Func->real().findNode(node.getNodeIndex()); + if (Fnode == nullptr) { + node.getExpandedChildPts(pts); // TODO: use getPrimitiveChildPts (less cache). + for (int j = 0; j < nCoefs; j++) { + for (int d = 0; d < D; d++) r[d] = pts(d, j); //*scaling_factor[d]? 
+ fval[j] = f.evalf(r); + } + } else { + originalCoef = Fnode->getCoefs(); + for (int j = 0; j < nCoefs; j++) fval[j] = originalCoef[j]; + Fnode->attachCoefs(fval); // note that each thread has its own copy + Fnode->mwTransform(Reconstruction); + Fnode->cvTransform(Forward); + } + } + DoubleMatrix multipliedCoeff(nCoefs, node2orbVec[node_ix].size()); + int i = 0; + // 3b) fetch all orbitals at this node + std::vector orbjVec; // to remember which orbital correspond to each orbVec.size(); + for (int j : node2orbVec[node_ix]) { // loop over indices of the orbitals using this node + int orb_node_ix = orb2node[j][node_ix]; + orbjVec.push_back(j); + for (int k = 0; k < nCoefs; k++) multipliedCoeff(k, i) = coeffVec[j][orb_node_ix][k]; + // 3c) transform to grid + node.attachCoefs(&(multipliedCoeff(0, i))); + node.mwTransform(Reconstruction); + node.cvTransform(Forward); + // 3d) multiply + for (int k = 0; k < nCoefs; k++) multipliedCoeff(k, i) *= fval[k]; // replace by Matrix vector multiplication? 
+ // 3e) transform back to mw + node.cvTransform(Backward); + node.mwTransform(Compression); + i++; + } + if (Func != nullptr and originalCoef != nullptr) { + // restablish original values + Fnode->attachCoefs(originalCoef); + } + + // 3f) save multiplied nodes + for (int i = 0; i < orbjVec.size(); i++) { +#pragma omp critical + { + ix2coef[orbjVec[i]][node_ix] = coeffpVec[orbjVec[i]].size(); + coeffpVec[orbjVec[i]].push_back(&(multipliedCoeff(0, i))); // list of coefficient pointers + } + } +#pragma omp critical + { + // this ensures that multipliedCoeff is not deleted, when getting out of scope + multipliedCoeffVec.push_back(std::move(multipliedCoeff)); + } + node.attachCoefs(nullptr); // to avoid deletion of valid multipliedCoeff by destructor + } + } else { + // MPI + int count1 = 0; + int count2 = 0; + TaskManager tasks(max_n); + for (int nn = 0; nn < max_n; nn++) { + int n = tasks.next_task(); + if (n < 0) break; + MWNode node(*(refNodes[n]), false); + // 3a) make values for f + // 3a1) get coordinates of quadrature points for this node + Eigen::MatrixXd pts; // Eigen::Zero(D, nCoefs); + node.getExpandedChildPts(pts); // TODO: use getPrimitiveChildPts (less cache). + double fval[nCoefs]; + Coord r; + MWNode Fnode(*(refNodes[n]), false); + if (Func == nullptr) { + for (int j = 0; j < nCoefs; j++) { + for (int d = 0; d < D; d++) r[d] = pts(d, j); //*scaling_factor[d]? 
+ fval[j] = f.evalf(r); + } + } else { + int nIdx = Func->real().getIx(node.getNodeIndex()); + count1++; + if (nIdx < 0) { + // use the function f instead of Func + count2++; + for (int j = 0; j < nCoefs; j++) { + for (int d = 0; d < D; d++) r[d] = pts(d, j); + fval[j] = f.evalf(r); + } + } else { + Func->real().getNodeCoeff(nIdx, fval); // fetch coef from Bank + Fnode.attachCoefs(fval); + Fnode.mwTransform(Reconstruction); + Fnode.cvTransform(Forward); + } + } + + // 3b) fetch all orbitals at this node + DoubleMatrix coeffBlock(nCoefs, N); // largest possible used size + std::vector orbjVec; + nodesPhi.get_nodeblock(indexVec_ref[n], coeffBlock.data(), orbjVec); + coeffBlock.conservativeResize(Eigen::NoChange, orbjVec.size()); // keep only used part + DoubleMatrix MultipliedCoeff(nCoefs, orbjVec.size()); + // 3c) transform to grid + for (int j = 0; j < orbjVec.size(); j++) { // TODO: transform all j at once ? + // TODO: select only nodes that are end nodes? + node.attachCoefs(coeffBlock.col(j).data()); + node.mwTransform(Reconstruction); + node.cvTransform(Forward); + // 3d) multiply + double *coefs = node.getCoefs(); + for (int i = 0; i < nCoefs; i++) coefs[i] *= fval[i]; + // 3e) transform back to mw + node.cvTransform(Backward); + node.mwTransform(Compression); + // 3f) save multiplied nodes + nodesMultiplied.put_nodedata(orbjVec[j], indexVec_ref[n] + max_ix, nCoefs, coefs); + } + node.attachCoefs(nullptr); // to avoid deletion of valid multipliedCoeff by destructor + Fnode.attachCoefs(nullptr); // to avoid deletion of valid multipliedCoeff by destructor + } + mrcpp::mpi::barrier(mrcpp::mpi::comm_wrk); // wait until everything is stored before fetching! + } + + // 5) reconstruct trees using multiplied nodes. + + // only serial case can use OMP, because MPI cannot be used by threads + if (serial) { + // OMP parallelized, but does not scale well, because the total memory bandwidth is a bottleneck. 
(the main + // operation is writing the coefficient into the tree) + +#pragma omp parallel for schedule(static) + for (int j = 0; j < N; j++) { + if (j < N) { + if (Phi[j].hasReal()) { + out[j].alloc(1); + out[j].real().clear(); + out[j].real().makeTreefromCoeff(refTree, coeffpVec[j], ix2coef[j], -1.0, "copy"); + // 6) reconstruct trees from end nodes + out[j].real().mwTransform(BottomUp); + out[j].real().calcSquareNorm(); + } + } else { + if (Phi[j].hasImag()) { + out[j].alloc(1); + out[j].imag().clear(); + out[j].imag().makeTreefromCoeff(refTree, coeffpVec[j], ix2coef[j], -1.0, "copy"); + out[j].imag().mwTransform(BottomUp); + out[j].imag().calcSquareNorm(); + } + } + } + } else { + for (int j = 0; j < N; j++) { + if (not mpi::my_func(j) and not all) continue; + // traverse possible nodes, and stop descending when norm is zero (leaf in out[j]) + std::vector coeffpVec; // + std::map ix2coef; // to find the index in coeffVec[] corresponding to a serialIx in refTree + int ix = 0; + std::vector pointerstodelete; // list of temporary arrays to clean up + + for (int ibank = 0; ibank < mpi::bank_size; ibank++) { + std::vector nodeidVec; + double *dataVec; // will be allocated by bank + nodesMultiplied.get_orbblock(j, dataVec, nodeidVec, ibank); + if (nodeidVec.size() > 0) pointerstodelete.push_back(dataVec); + int shift = 0; + for (int n = 0; n < nodeidVec.size(); n++) { + assert(nodeidVec[n] - max_ix >= 0); // unmultiplied nodes have been deleted + assert(ix2coef.count(nodeidVec[n] - max_ix) == 0); // each nodeid treated once + ix2coef[nodeidVec[n] - max_ix] = ix++; + coeffpVec.push_back(&dataVec[shift]); // list of coeff pointers + shift += nCoefs; + } + } + if (j < N) { + if (Phi[j].hasReal()) { + out[j].alloc(1); + out[j].real().clear(); + out[j].real().makeTreefromCoeff(refTree, coeffpVec, ix2coef, -1.0, "copy"); + // 6) reconstruct trees from end nodes + out[j].real().mwTransform(BottomUp); + out[j].real().calcSquareNorm(); + out[j].real().resetEndNodeTable(); + 
// out[j].real().crop(prec, 1.0, false); //bad convergence if out is cropped + if (nrefine > 0) Phi[j].real().crop(prec, 1.0, false); // restablishes original Phi + } + } else { + if (Phi[j].hasImag()) { + out[j].alloc(1); + out[j].imag().clear(); + out[j].imag().makeTreefromCoeff(refTree, coeffpVec, ix2coef, -1.0, "copy"); + out[j].imag().mwTransform(BottomUp); + out[j].imag().calcSquareNorm(); + // out[j].imag().crop(prec, 1.0, false); + if (nrefine > 0) Phi[j].imag().crop(prec, 1.0, false); + } + } + + for (double *p : pointerstodelete) delete[] p; + pointerstodelete.clear(); + } + } + return out; +} + +void SetdefaultMRA(MultiResolutionAnalysis<3> *MRA) { + defaultCompMRA<3> = MRA; +} + +ComplexVector dot(CompFunctionVector &Bra, CompFunctionVector &Ket) { + int N = Bra.size(); + ComplexVector result = ComplexVector::Zero(N); + for (int i = 0; i < N; i++) { + // The bra is sent to the owner of the ket + if (my_func(Bra[i]) != my_func(Ket[i])) { MSG_ABORT("same indices should have same ownership"); } + result[i] = dot(Bra[i], Ket[i]); + if (not mrcpp::mpi::my_func(i)) Bra[i].free(); + } + mrcpp::mpi::allreduce_vector(result, mrcpp::mpi::comm_wrk); + return result; +} + +/** @brief Compute Löwdin orthonormalization matrix + * + * @param Phi: orbitals to orthonomalize + * + * Computes the inverse square root of the orbital overlap matrix S^(-1/2) + */ +ComplexMatrix calc_lowdin_matrix(CompFunctionVector &Phi) { + ComplexMatrix S_tilde = calc_overlap_matrix(Phi); + ComplexMatrix S_m12 = math_utils::hermitian_matrix_pow(S_tilde, -1.0 / 2.0); + return S_m12; +} + +/** @brief Orbital transformation out_j = sum_i inp_i*U_ij + * + * NOTE: OrbitalVector is considered a ROW vector, so rotation + * means matrix multiplication from the right + * + * MPI: Rank distribution of output vector is the same as input vector + * + */ +ComplexMatrix calc_overlap_matrix_cplx(CompFunctionVector &BraKet) { + int N = BraKet.size(); + ComplexMatrix S = ComplexMatrix::Zero(N, N); + 
DoubleMatrix Sreal = S.real(); + MultiResolutionAnalysis<3> *mra = BraKet.vecMRA; + + // 1) make union tree without coefficients + mrcpp::FunctionTree<3> refTree(*mra); + mpi::allreduce_Tree_noCoeff(refTree, BraKet, mpi::comm_wrk); + + int sizecoeff = (1 << refTree.getDim()) * refTree.getKp1_d(); + int sizecoeffW = ((1 << refTree.getDim()) - 1) * refTree.getKp1_d(); + + // get a list of all nodes in union grid, as defined by their indices + std::vector scalefac; + std::vector coeffVec_ref; + std::vector indexVec_ref; // serialIx of the nodes + std::vector parindexVec_ref; // serialIx of the parent nodes + int max_ix; // largest index value (not used here) + + refTree.makeCoeffVector(coeffVec_ref, indexVec_ref, parindexVec_ref, scalefac, max_ix, refTree); + int max_n = indexVec_ref.size(); + + // only used for serial case: + std::vector> coeffVec(N); + std::map> node2orbVec; // for each node index, gives a vector with the indices of the orbitals using this node + std::vector> orb2node(N); // for a given orbital and a given node, gives the node index in + // the orbital given the node index in the reference tree + + bool serial = mrcpp::mpi::wrk_size == 1; // flag for serial/MPI switch + mrcpp::BankAccount nodesBraKet; + + // In the serial case we store the coeff pointers in coeffVec. 
In the mpi case the coeff are stored in the bank + if (serial) { + // 2) make list of all coefficients, and their reference indices + // for different orbitals, indexVec will give the same index for the same node in space + std::vector parindexVec; // serialIx of the parent nodes + std::vector indexVec; // serialIx of the nodes + for (int j = 0; j < N; j++) { + // make vector with all coef pointers and their indices in the union grid + BraKet[j].complex().makeCoeffVector(coeffVec[j], indexVec, parindexVec, scalefac, max_ix, refTree); + // make a map that gives j from indexVec + int orb_node_ix = 0; + for (int ix : indexVec) { + orb2node[j][ix] = orb_node_ix++; + if (ix < 0) continue; + node2orbVec[ix].push_back(j); + } + } + } else { // MPI case + // 2) send own nodes to bank, identifying them through the serialIx of refTree + save_nodes(BraKet, refTree, nodesBraKet); + mrcpp::mpi::barrier(mrcpp::mpi::comm_wrk); // wait until everything is stored before fetching! + } + + // 3) make dot product for all the nodes and accumulate into S + int ibank = 0; +#pragma omp parallel if (serial) + { + ComplexMatrix S_omp = ComplexMatrix::Zero(N, N); // copy for each thread + +#pragma omp for schedule(dynamic) + for (int n = 0; n < max_n; n++) { + if (n % mrcpp::mpi::wrk_size != mrcpp::mpi::wrk_rank) continue; + int csize; + int node_ix = indexVec_ref[n]; // SerialIx for this node in the reference tree + std::vector orbVec; // identifies which orbitals use this node + if (serial and node2orbVec[node_ix].size() <= 0) continue; + if (parindexVec_ref[n] < 0) + csize = sizecoeff; + else + csize = sizecoeffW; + + // In the serial case we copy the coeff coeffBlock. 
In the mpi case coeffBlock is provided by the bank + if (serial) { + int shift = sizecoeff - sizecoeffW; // to copy only wavelet part + if (parindexVec_ref[n] < 0) shift = 0; + ComplexMatrix coeffBlock(csize, node2orbVec[node_ix].size()); + for (int j : node2orbVec[node_ix]) { // loop over indices of the orbitals using this node + int orb_node_ix = orb2node[j][node_ix]; + for (int k = 0; k < csize; k++) coeffBlock(k, orbVec.size()) = coeffVec[j][orb_node_ix][k + shift]; + orbVec.push_back(j); + } + if (orbVec.size() > 0) { + ComplexMatrix S_temp(orbVec.size(), orbVec.size()); + S_temp.noalias() = coeffBlock.transpose().conjugate() * coeffBlock; + for (int i = 0; i < orbVec.size(); i++) { + for (int j = 0; j < orbVec.size(); j++) { + if (BraKet[orbVec[i]].func_ptr->data.n1[0] != BraKet[orbVec[j]].func_ptr->data.n1[0] and BraKet[orbVec[i]].func_ptr->data.n1[0] != 0 and + BraKet[orbVec[j]].func_ptr->data.n1[0] != 0) + continue; + S_omp(orbVec[i], orbVec[j]) += S_temp(i, j); + } + } + } + } else { // MPI case + ComplexMatrix coeffBlock(csize, N); + nodesBraKet.get_nodeblock(indexVec_ref[n], coeffBlock.data(), orbVec); + + if (orbVec.size() > 0) { + ComplexMatrix S_temp(orbVec.size(), orbVec.size()); + coeffBlock.conservativeResize(Eigen::NoChange, orbVec.size()); + S_temp.noalias() = coeffBlock.transpose().conjugate() * coeffBlock; + for (int i = 0; i < orbVec.size(); i++) { + for (int j = 0; j < orbVec.size(); j++) { + if (BraKet[orbVec[i]].func_ptr->data.n1[0] != BraKet[orbVec[j]].func_ptr->data.n1[0] and BraKet[orbVec[i]].func_ptr->data.n1[0] != 0 and + BraKet[orbVec[j]].func_ptr->data.n1[0] != 0) + continue; + S_omp(orbVec[i], orbVec[j]) += S_temp(i, j); + } + } + } + } + } + if (serial) { +#pragma omp critical + for (int i = 0; i < N; i++) { + for (int j = 0; j < N; j++) { S(i, j) += S_omp(i, j); } + } + } + } + + for (int i = 0; i < N; i++) { + for (int j = 0; j <= i; j++) { + if (i != j) S(j, i) = std::conj(S(i, j)); // ensure exact symmetri + } + } + + // 
Assumes linearity: result is sum of all nodes contributions + mrcpp::mpi::allreduce_matrix(S, mrcpp::mpi::comm_wrk); + // multiply by CompFunction multiplicative factor + + ComplexVector Fac = ComplexVector::Zero(N); + for (int i = 0; i < N; i++) { + if (!mrcpp::mpi::my_func(BraKet[i])) continue; + Fac[i] = BraKet[i].func_ptr->data.c1[0]; + } + + mrcpp::mpi::allreduce_vector(Fac, mrcpp::mpi::comm_wrk); + for (int i = 0; i < N; i++) { + for (int j = 0; j < N; j++) { S(i, j) *= std::conj(Fac[i]) * Fac[j]; } + } + + return S; +} +ComplexMatrix calc_overlap_matrix(CompFunctionVector &BraKet) { + // NB: should be spinseparated at this point! + if (BraKet[0].iscomplex()) { return calc_overlap_matrix_cplx(BraKet); } + + int N = BraKet.size(); + ComplexMatrix S = ComplexMatrix::Zero(N, N); + + MultiResolutionAnalysis<3> *mra = BraKet.vecMRA; + + // 1) make union tree without coefficients + mrcpp::FunctionTree<3> refTree(*mra); + mpi::allreduce_Tree_noCoeff(refTree, BraKet, mpi::comm_wrk); + + int sizecoeff = (1 << refTree.getDim()) * refTree.getKp1_d(); + int sizecoeffW = ((1 << refTree.getDim()) - 1) * refTree.getKp1_d(); + + // get a list of all nodes in union grid, as defined by their indices + std::vector scalefac; + std::vector coeffVec_ref; + std::vector indexVec_ref; // serialIx of the nodes + std::vector parindexVec_ref; // serialIx of the parent nodes + int max_ix; // largest index value (not used here) + + refTree.makeCoeffVector(coeffVec_ref, indexVec_ref, parindexVec_ref, scalefac, max_ix, refTree); + int max_n = indexVec_ref.size(); + + // only used for serial case: + std::vector> coeffVec(N); + std::map> node2orbVec; // for each node index, gives a vector with the indices of the orbitals using this node + std::vector> orb2node(N); // for a given orbital and a given node, gives the node index in + // the orbital given the node index in the reference tree + + bool serial = mrcpp::mpi::wrk_size == 1; // flag for serial/MPI switch + mrcpp::BankAccount 
nodesBraKet; + + // In the serial case we store the coeff pointers in coeffVec. In the mpi case the coeff are stored in the bank + if (serial) { + // 2) make list of all coefficients, and their reference indices + // for different orbitals, indexVec will give the same index for the same node in space + std::vector parindexVec; // serialIx of the parent nodes + std::vector indexVec; // serialIx of the nodes + for (int j = 0; j < N; j++) { + // make vector with all coef pointers and their indices in the union grid + BraKet[j].real().makeCoeffVector(coeffVec[j], indexVec, parindexVec, scalefac, max_ix, refTree); + // make a map that gives j from indexVec + int orb_node_ix = 0; + for (int ix : indexVec) { + orb2node[j][ix] = orb_node_ix++; + if (ix < 0) continue; + node2orbVec[ix].push_back(j); + } + } + } else { // MPI case + // 2) send own nodes to bank, identifying them through the serialIx of refTree + save_nodes(BraKet, refTree, nodesBraKet); + mrcpp::mpi::barrier(mrcpp::mpi::comm_wrk); // wait until everything is stored before fetching! + } + + // 3) make dot product for all the nodes and accumulate into S + int ibank = 0; +#pragma omp parallel if (serial) + { + ComplexMatrix S_omp = ComplexMatrix::Zero(N, N); // copy for each thread + +#pragma omp for schedule(dynamic) + for (int n = 0; n < max_n; n++) { + if (n % mrcpp::mpi::wrk_size != mrcpp::mpi::wrk_rank) continue; + int csize; + int node_ix = indexVec_ref[n]; // SerialIx for this node in the reference tree + std::vector orbVec; // identifies which orbitals use this node + if (serial and node2orbVec[node_ix].size() <= 0) continue; + if (parindexVec_ref[n] < 0) + csize = sizecoeff; + else + csize = sizecoeffW; + + // In the serial case we copy the coeff coeffBlock. 
In the mpi case coeffBlock is provided by the bank + if (serial) { + int shift = sizecoeff - sizecoeffW; // to copy only wavelet part + if (parindexVec_ref[n] < 0) shift = 0; + DoubleMatrix coeffBlock(csize, node2orbVec[node_ix].size()); + for (int j : node2orbVec[node_ix]) { // loop over indices of the orbitals using this node + int orb_node_ix = orb2node[j][node_ix]; + for (int k = 0; k < csize; k++) coeffBlock(k, orbVec.size()) = coeffVec[j][orb_node_ix][k + shift]; + orbVec.push_back(j); + } + if (orbVec.size() > 0) { + ComplexMatrix S_temp(orbVec.size(), orbVec.size()); + S_temp.noalias() = coeffBlock.transpose() * coeffBlock; + for (int i = 0; i < orbVec.size(); i++) { + for (int j = 0; j < orbVec.size(); j++) { + if (BraKet[orbVec[i]].func_ptr->data.n1[0] != BraKet[orbVec[j]].func_ptr->data.n1[0] and BraKet[orbVec[i]].func_ptr->data.n1[0] != 0 and + BraKet[orbVec[j]].func_ptr->data.n1[0] != 0) + continue; + S_omp(orbVec[i], orbVec[j]) += S_temp(i, j); + } + } + } + } else { // MPI case + DoubleMatrix coeffBlock(csize, N); + nodesBraKet.get_nodeblock(indexVec_ref[n], coeffBlock.data(), orbVec); + + if (orbVec.size() > 0) { + DoubleMatrix S_temp(orbVec.size(), orbVec.size()); + coeffBlock.conservativeResize(Eigen::NoChange, orbVec.size()); + S_temp.noalias() = coeffBlock.transpose() * coeffBlock; + for (int i = 0; i < orbVec.size(); i++) { + for (int j = 0; j < orbVec.size(); j++) { + if (BraKet[orbVec[i]].func_ptr->data.n1[0] != BraKet[orbVec[j]].func_ptr->data.n1[0] and BraKet[orbVec[i]].func_ptr->data.n1[0] != 0 and + BraKet[orbVec[j]].func_ptr->data.n1[0] != 0) + continue; + S(orbVec[i], orbVec[j]) += S_temp(i, j); + } + } + } + } + } + if (serial) { +#pragma omp critical + for (int i = 0; i < N; i++) { + for (int j = 0; j < N; j++) { S(i, j) += S_omp(i, j); } + } + } + } + + for (int i = 0; i < N; i++) { + for (int j = 0; j <= i; j++) { + if (i != j) S(j, i) = std::conj(S(i, j)); // ensure exact symmetri + } + } + + // Assumes linearity: result is sum of 
all nodes contributions + mrcpp::mpi::allreduce_matrix(S, mrcpp::mpi::comm_wrk); + + // multiply by CompFunction multiplicative factor + ComplexVector Fac = ComplexVector::Zero(N); + for (int i = 0; i < N; i++) { + if (!mrcpp::mpi::my_func(BraKet[i])) continue; + Fac[i] = BraKet[i].func_ptr->data.c1[0]; + } + mrcpp::mpi::allreduce_vector(Fac, mrcpp::mpi::comm_wrk); + for (int i = 0; i < N; i++) { + for (int j = 0; j < N; j++) { S(i, j) *= std::conj(Fac[i]) * Fac[j]; } + } + + return S; +} + +/** @brief Compute the overlap matrix S_ij = + * + * Will take the conjugate of bra before integrating + */ +ComplexMatrix calc_overlap_matrix_cplx(CompFunctionVector &Bra, CompFunctionVector &Ket) { + mrcpp::mpi::barrier(mrcpp::mpi::comm_wrk); // for consistent timings + bool braisreal = !Bra[0].iscomplex(); + bool ketisreal = !Ket[0].iscomplex(); + if (braisreal or ketisreal) { + // temporary solution: copy as complex trees + if (braisreal) { + for (int i = 0; i < Bra.size(); i++) { + Bra[i].CompD[0]->CopyTreeToComplex(Bra[i].CompC[0]); + Bra[i].func_ptr->iscomplex = 1; + } + } + if (ketisreal) { + for (int i = 0; i < Ket.size(); i++) { + Ket[i].CompD[0]->CopyTreeToComplex(Ket[i].CompC[0]); + Ket[i].func_ptr->iscomplex = 1; + } + } + } + MultiResolutionAnalysis<3> *mra = Bra.vecMRA; + + int N = Bra.size(); + int M = Ket.size(); + ComplexMatrix S = ComplexMatrix::Zero(N, M); + + IntVector conjMatBra = IntVector::Zero(N); + for (int i = 0; i < N; i++) { + if (!mrcpp::mpi::my_func(Bra[i])) continue; + conjMatBra[i] = (Bra[i].conjugate()) ? 1 : 0; + } + mrcpp::mpi::allreduce_vector(conjMatBra, mrcpp::mpi::comm_wrk); + IntVector conjMatKet = IntVector::Zero(M); + for (int i = 0; i < M; i++) { + if (!mrcpp::mpi::my_func(Ket[i])) continue; + conjMatKet[i] = (Ket[i].conjugate()) ? 
1 : 0; + } + mrcpp::mpi::allreduce_vector(conjMatKet, mrcpp::mpi::comm_wrk); + + // 1) make union tree without coefficients for Bra (supposed smallest) + mrcpp::FunctionTree<3> refTree(*mra); + mrcpp::mpi::allreduce_Tree_noCoeff(refTree, Bra, mpi::comm_wrk); + // note that Ket is not part of union grid: if a node is in ket but not in Bra, the dot product is zero. + + int sizecoeff = (1 << refTree.getDim()) * refTree.getKp1_d(); + int sizecoeffW = ((1 << refTree.getDim()) - 1) * refTree.getKp1_d(); + + // get a list of all nodes in union grid, as defined by their indices + std::vector coeffVec_ref; + std::vector indexVec_ref; // serialIx of the nodes + std::vector parindexVec_ref; // serialIx of the parent nodes + std::vector scalefac; + int max_ix; + + refTree.makeCoeffVector(coeffVec_ref, indexVec_ref, parindexVec_ref, scalefac, max_ix, refTree); + int max_n = indexVec_ref.size(); + max_ix++; + + bool serial = mrcpp::mpi::wrk_size == 1; // flag for serial/MPI switch + + // only used for serial case: + std::vector> coeffVecBra(N); + std::map> node2orbVecBra; // for each node index, gives a vector with the indices of the orbitals using this node + std::vector> orb2nodeBra(N); // for a given orbital and a given node, gives the node index in + // the orbital given the node index in the reference tree + std::vector> coeffVecKet(M); + std::map> node2orbVecKet; // for each node index, gives a vector with the indices of the orbitals using this node + std::vector> orb2nodeKet(M); // for a given orbital and a given node, gives the node index in + // the orbital given the node index in the reference tree + mrcpp::BankAccount nodesBra; + mrcpp::BankAccount nodesKet; + + // In the serial case we store the coeff pointers in coeffVec. In the mpi case the coeff are stored in the bank + if (serial) { + // 2) make list of all coefficients, and their reference indices + // for different orbitals, indexVec will give the same index for the same node in space + // TODO? 
: do not copy coefficients, but use directly the pointers + // could OMP parallelize, but is fast anyway + std::vector parindexVec; // serialIx of the parent nodes + std::vector indexVec; // serialIx of the nodes + for (int j = 0; j < N; j++) { + // make vector with all coef pointers and their indices in the union grid + Bra[j].complex().makeCoeffVector(coeffVecBra[j], indexVec, parindexVec, scalefac, max_ix, refTree); + // make a map that gives j from indexVec + int orb_node_ix = 0; + for (int ix : indexVec) { + orb2nodeBra[j][ix] = orb_node_ix++; + if (ix < 0) continue; + node2orbVecBra[ix].push_back(j); + } + } + for (int j = 0; j < M; j++) { + Ket[j].complex().makeCoeffVector(coeffVecKet[j], indexVec, parindexVec, scalefac, max_ix, refTree); + // make a map that gives j from indexVec + int orb_node_ix = 0; + for (int ix : indexVec) { + orb2nodeKet[j][ix] = orb_node_ix++; + if (ix < 0) continue; + node2orbVecKet[ix].push_back(j); + } + } + + } else { // MPI case + // 2) send own nodes to bank, identifying them through the serialIx of refTree + save_nodes(Bra, refTree, nodesBra); + save_nodes(Ket, refTree, nodesKet); + mrcpp::mpi::barrier(mrcpp::mpi::comm_wrk); // wait until everything is stored before fetching! + } + + // 3) make dot product for all the nodes and accumulate into S + int totsiz = 0; + int totget = 0; + int mxtotsiz = 0; + int ibank = 0; + // the omp crashes sometime for unknown reasons? 
+#pragma omp parallel if (serial) + { + ComplexMatrix S_omp = ComplexMatrix::Zero(N, M); // copy for each thread + +#pragma omp for schedule(dynamic) + for (int n = 0; n < max_n; n++) { + if (n % mrcpp::mpi::wrk_size != mrcpp::mpi::wrk_rank) continue; + int csize; + std::vector orbVecBra; // identifies which Bra orbitals use this node + std::vector orbVecKet; // identifies which Ket orbitals use this node + if (parindexVec_ref[n] < 0) + csize = sizecoeff; + else + csize = sizecoeffW; + if (serial) { + int node_ix = indexVec_ref[n]; // SerialIx for this node in the reference tree + int shift = sizecoeff - sizecoeffW; // to copy only wavelet part + ComplexMatrix coeffBlockBra(csize, node2orbVecBra[node_ix].size()); + ComplexMatrix coeffBlockKet(csize, node2orbVecKet[node_ix].size()); + if (parindexVec_ref[n] < 0) shift = 0; + + for (int j : node2orbVecBra[node_ix]) { // loop over indices of the orbitals using this node + int orb_node_ix = orb2nodeBra[j][node_ix]; + for (int k = 0; k < csize; k++) coeffBlockBra(k, orbVecBra.size()) = coeffVecBra[j][orb_node_ix][k + shift]; + orbVecBra.push_back(j); + } + for (int j : node2orbVecKet[node_ix]) { // loop over indices of the orbitals using this node + int orb_node_ix = orb2nodeKet[j][node_ix]; + for (int k = 0; k < csize; k++) coeffBlockKet(k, orbVecKet.size()) = coeffVecKet[j][orb_node_ix][k + shift]; + orbVecKet.push_back(j); + } + + if (orbVecBra.size() > 0 and orbVecKet.size() > 0) { + ComplexMatrix S_temp(orbVecBra.size(), orbVecKet.size()); + if (not conjMatBra[0] and not conjMatBra[0]) { + S_temp.noalias() = coeffBlockBra.transpose().conjugate() * coeffBlockKet; + } else if (conjMatBra[0] and not conjMatBra[0]) { + S_temp.noalias() = coeffBlockBra.transpose() * coeffBlockKet; + } else if (not conjMatBra[0] and conjMatBra[0]) { + S_temp.noalias() = coeffBlockBra.transpose() * coeffBlockKet.transpose(); + } else if (conjMatBra[0] and conjMatBra[0]) { + S_temp.noalias() = coeffBlockBra * coeffBlockKet.transpose(); + } 
else + MSG_ABORT("Unexpected case"); + for (int i = 0; i < orbVecBra.size(); i++) { + for (int j = 0; j < orbVecKet.size(); j++) { + if (Bra[orbVecBra[i]].func_ptr->data.n1[0] != Ket[orbVecKet[j]].func_ptr->data.n1[0] and Bra[orbVecBra[i]].func_ptr->data.n1[0] != 0 and + Ket[orbVecKet[j]].func_ptr->data.n1[0] != 0) + continue; + S_omp(orbVecBra[i], orbVecKet[j]) += S_temp(i, j); + } + } + } + } else { // MPI case + + ComplexMatrix coeffBlockBra(csize, N); + ComplexMatrix coeffBlockKet(csize, M); + nodesBra.get_nodeblock(indexVec_ref[n], coeffBlockBra.data(), orbVecBra); // get Bra parts + nodesKet.get_nodeblock(indexVec_ref[n], coeffBlockKet.data(), orbVecKet); // get Ket parts + totsiz += orbVecBra.size() * orbVecKet.size(); + mxtotsiz += N * M; + totget += orbVecBra.size() + orbVecKet.size(); + if (orbVecBra.size() > 0 and orbVecKet.size() > 0) { + ComplexMatrix S_temp(orbVecBra.size(), orbVecKet.size()); + coeffBlockBra.conservativeResize(Eigen::NoChange, orbVecBra.size()); + coeffBlockKet.conservativeResize(Eigen::NoChange, orbVecKet.size()); + if (not conjMatBra[0] and not conjMatBra[0]) { + S_temp.noalias() = coeffBlockBra.transpose().conjugate() * coeffBlockKet; + } else if (conjMatBra[0] and not conjMatBra[0]) { + S_temp.noalias() = coeffBlockBra.transpose() * coeffBlockKet; + } else if (not conjMatBra[0] and conjMatBra[0]) { + S_temp.noalias() = coeffBlockBra.transpose() * coeffBlockKet.transpose(); + } else if (conjMatBra[0] and conjMatBra[0]) { + S_temp.noalias() = coeffBlockBra * coeffBlockKet.transpose(); + } else + MSG_ABORT("Unexpected case"); + + for (int i = 0; i < orbVecBra.size(); i++) { + for (int j = 0; j < orbVecKet.size(); j++) { + if (Bra[orbVecBra[i]].func_ptr->data.n1[0] != Ket[orbVecKet[j]].func_ptr->data.n1[0] and Bra[orbVecBra[i]].func_ptr->data.n1[0] != 0 and + Ket[orbVecKet[j]].func_ptr->data.n1[0] != 0) + continue; + S(orbVecBra[i], orbVecKet[j]) += S_temp(i, j); + } + } + } + } + } + if (serial) { +#pragma omp critical + for (int i 
= 0; i < N; i++) { + for (int j = 0; j < M; j++) { S(i, j) += S_omp(i, j); } + } + } + } + + // 4) collect results from all MPI. Linearity: result is sum of all node contributions + + mrcpp::mpi::allreduce_matrix(S, mrcpp::mpi::comm_wrk); + + // multiply by CompFunction multiplicative factor + ComplexVector FacBra = ComplexVector::Zero(N); + ComplexVector FacKet = ComplexVector::Zero(M); + for (int i = 0; i < N; i++) { + if (!mrcpp::mpi::my_func(Bra[i])) continue; + FacBra[i] = Bra[i].func_ptr->data.c1[0]; + } + for (int i = 0; i < M; i++) { + if (!mrcpp::mpi::my_func(Ket[i])) continue; + FacKet[i] = Ket[i].func_ptr->data.c1[0]; + } + mrcpp::mpi::allreduce_vector(FacBra, mrcpp::mpi::comm_wrk); + mrcpp::mpi::allreduce_vector(FacKet, mrcpp::mpi::comm_wrk); + for (int i = 0; i < N; i++) { + for (int j = 0; j < M; j++) { S(i, j) *= std::conj(FacBra[i]) * FacKet[j]; } + } + + // restore input + if (braisreal) { + for (int i = 0; i < Bra.size(); i++) { + delete Bra[i].CompC[0]; + Bra[i].CompC[0] = nullptr; + Bra[i].func_ptr->iscomplex = 0; + Bra[i].func_ptr->isreal = 1; + } + } + if (ketisreal) { + for (int i = 0; i < Ket.size(); i++) { + delete Ket[i].CompC[0]; + Ket[i].CompC[0] = nullptr; + Ket[i].func_ptr->iscomplex = 0; + Ket[i].func_ptr->isreal = 1; + } + } + return S; +} + +/** @brief Compute the overlap matrix S_ij = + * + */ +ComplexMatrix calc_overlap_matrix(CompFunctionVector &Bra, CompFunctionVector &Ket) { + + if (Bra[0].iscomplex() or Ket[0].iscomplex()) { return calc_overlap_matrix_cplx(Bra, Ket); } + + mrcpp::mpi::barrier(mrcpp::mpi::comm_wrk); // for consistent timings + + MultiResolutionAnalysis<3> *mra = Bra.vecMRA; + + int N = Bra.size(); + int M = Ket.size(); + ComplexMatrix S = ComplexMatrix::Zero(N, M); + + // 1) make union tree without coefficients for Bra (supposed smallest) + mrcpp::FunctionTree<3> refTree(*mra); + mrcpp::mpi::allreduce_Tree_noCoeff(refTree, Bra, mpi::comm_wrk); + // note that Ket is not part of union grid: if a node is in ket 
but not in Bra, the dot product is zero. + + int sizecoeff = (1 << refTree.getDim()) * refTree.getKp1_d(); + int sizecoeffW = ((1 << refTree.getDim()) - 1) * refTree.getKp1_d(); + + // get a list of all nodes in union grid, as defined by their indices + std::vector coeffVec_ref; + std::vector indexVec_ref; // serialIx of the nodes + std::vector parindexVec_ref; // serialIx of the parent nodes + std::vector scalefac; + int max_ix; + + refTree.makeCoeffVector(coeffVec_ref, indexVec_ref, parindexVec_ref, scalefac, max_ix, refTree); + int max_n = indexVec_ref.size(); + max_ix++; + + bool serial = mrcpp::mpi::wrk_size == 1; // flag for serial/MPI switch + + // only used for serial case: + std::vector> coeffVecBra(N); + std::map> node2orbVecBra; // for each node index, gives a vector with the indices of the orbitals using this node + std::vector> orb2nodeBra(N); // for a given orbital and a given node, gives the node index in + // the orbital given the node index in the reference tree + std::vector> coeffVecKet(M); + std::map> node2orbVecKet; // for each node index, gives a vector with the indices of the orbitals using this node + std::vector> orb2nodeKet(M); // for a given orbital and a given node, gives the node index in + // the orbital given the node index in the reference tree + mrcpp::BankAccount nodesBra; + mrcpp::BankAccount nodesKet; + // In the serial case we store the coeff pointers in coeffVec. In the mpi case the coeff are stored in the bank + if (serial) { + // 2) make list of all coefficients, and their reference indices + // for different orbitals, indexVec will give the same index for the same node in space + // TODO? 
: do not copy coefficients, but use directly the pointers + // could OMP parallelize, but is fast anyway + std::vector parindexVec; // serialIx of the parent nodes + std::vector indexVec; // serialIx of the nodes + for (int j = 0; j < N; j++) { + // make vector with all coef pointers and their indices in the union grid + Bra[j].real().makeCoeffVector(coeffVecBra[j], indexVec, parindexVec, scalefac, max_ix, refTree); + // make a map that gives j from indexVec + int orb_node_ix = 0; + for (int ix : indexVec) { + orb2nodeBra[j][ix] = orb_node_ix++; + if (ix < 0) continue; + node2orbVecBra[ix].push_back(j); + } + } + for (int j = 0; j < M; j++) { + Ket[j].real().makeCoeffVector(coeffVecKet[j], indexVec, parindexVec, scalefac, max_ix, refTree); + // make a map that gives j from indexVec + int orb_node_ix = 0; + for (int ix : indexVec) { + orb2nodeKet[j][ix] = orb_node_ix++; + if (ix < 0) continue; + node2orbVecKet[ix].push_back(j); + } + } + + } else { // MPI case + // 2) send own nodes to bank, identifying them through the serialIx of refTree + save_nodes(Bra, refTree, nodesBra); + save_nodes(Ket, refTree, nodesKet); + mrcpp::mpi::barrier(mrcpp::mpi::comm_wrk); // wait until everything is stored before fetching! + } + + // 3) make dot product for all the nodes and accumulate into S + int totsiz = 0; + int totget = 0; + int mxtotsiz = 0; + int ibank = 0; +#pragma omp parallel if (serial) + { + DoubleMatrix S_omp = DoubleMatrix::Zero(N, M); // copy for each thread + // NB: dynamic does give strange errors? 
+#pragma omp for schedule(static) + for (int n = 0; n < max_n; n++) { + if (n % mrcpp::mpi::wrk_size != mrcpp::mpi::wrk_rank) continue; + int csize; + std::vector orbVecBra; // identifies which Bra orbitals use this node + std::vector orbVecKet; // identifies which Ket orbitals use this node + if (parindexVec_ref[n] < 0) + csize = sizecoeff; + else + csize = sizecoeffW; + if (serial) { + int node_ix = indexVec_ref[n]; // SerialIx for this node in the reference tree + int shift = sizecoeff - sizecoeffW; // to copy only wavelet part + DoubleMatrix coeffBlockBra(csize, node2orbVecBra[node_ix].size()); + DoubleMatrix coeffBlockKet(csize, node2orbVecKet[node_ix].size()); + if (parindexVec_ref[n] < 0) shift = 0; + + for (int j : node2orbVecBra[node_ix]) { // loop over indices of the orbitals using this node + int orb_node_ix = orb2nodeBra[j][node_ix]; + for (int k = 0; k < csize; k++) coeffBlockBra(k, orbVecBra.size()) = coeffVecBra[j][orb_node_ix][k + shift]; + orbVecBra.push_back(j); + } + for (int j : node2orbVecKet[node_ix]) { // loop over indices of the orbitals using this node + int orb_node_ix = orb2nodeKet[j][node_ix]; + for (int k = 0; k < csize; k++) coeffBlockKet(k, orbVecKet.size()) = coeffVecKet[j][orb_node_ix][k + shift]; + orbVecKet.push_back(j); + } + + if (orbVecBra.size() > 0 and orbVecKet.size() > 0) { + DoubleMatrix S_temp(orbVecBra.size(), orbVecKet.size()); + S_temp.noalias() = coeffBlockBra.transpose() * coeffBlockKet; + + for (int i = 0; i < orbVecBra.size(); i++) { + for (int j = 0; j < orbVecKet.size(); j++) { + if (Bra[orbVecBra[i]].func_ptr->data.n1[0] != Ket[orbVecKet[j]].func_ptr->data.n1[0] and Bra[orbVecBra[i]].func_ptr->data.n1[0] != 0 and + Ket[orbVecKet[j]].func_ptr->data.n1[0] != 0) + continue; + S_omp(orbVecBra[i], orbVecKet[j]) += S_temp(i, j); + } + } + } + } else { // MPI case + + DoubleMatrix coeffBlockBra(csize, N); + DoubleMatrix coeffBlockKet(csize, M); + nodesBra.get_nodeblock(indexVec_ref[n], coeffBlockBra.data(), orbVecBra); 
// get Bra parts + nodesKet.get_nodeblock(indexVec_ref[n], coeffBlockKet.data(), orbVecKet); // get Ket parts + totsiz += orbVecBra.size() * orbVecKet.size(); + mxtotsiz += N * M; + totget += orbVecBra.size() + orbVecKet.size(); + if (orbVecBra.size() > 0 and orbVecKet.size() > 0) { + DoubleMatrix S_temp(orbVecBra.size(), orbVecKet.size()); + coeffBlockBra.conservativeResize(Eigen::NoChange, orbVecBra.size()); + coeffBlockKet.conservativeResize(Eigen::NoChange, orbVecKet.size()); + S_temp.noalias() = coeffBlockBra.transpose() * coeffBlockKet; + for (int i = 0; i < orbVecBra.size(); i++) { + for (int j = 0; j < orbVecKet.size(); j++) { + if (Bra[orbVecBra[i]].func_ptr->data.n1[0] != Ket[orbVecKet[j]].func_ptr->data.n1[0] and Bra[orbVecBra[i]].func_ptr->data.n1[0] != 0 and + Ket[orbVecKet[j]].func_ptr->data.n1[0] != 0) + continue; + S(orbVecBra[i], orbVecKet[j]) += S_temp(i, j); + } + } + } + } + } + if (serial) { +#pragma omp critical + for (int i = 0; i < N; i++) { + for (int j = 0; j < M; j++) { S(i, j) += S_omp(i, j); } + } + } + } + + // 4) collect results from all MPI. 
Linearity: result is sum of all node contributions + + mrcpp::mpi::allreduce_matrix(S, mrcpp::mpi::comm_wrk); + + // multiply by CompFunction multiplicative factor + ComplexVector FacBra = ComplexVector::Zero(N); + ComplexVector FacKet = ComplexVector::Zero(M); + for (int i = 0; i < N; i++) { + if (!mrcpp::mpi::my_func(Bra[i])) continue; + FacBra[i] = Bra[i].func_ptr->data.c1[0]; + } + for (int i = 0; i < M; i++) { + if (!mrcpp::mpi::my_func(Ket[i])) continue; + FacKet[i] = Ket[i].func_ptr->data.c1[0]; + } + mrcpp::mpi::allreduce_vector(FacBra, mrcpp::mpi::comm_wrk); + mrcpp::mpi::allreduce_vector(FacKet, mrcpp::mpi::comm_wrk); + for (int i = 0; i < N; i++) { + for (int j = 0; j < M; j++) { S(i, j) *= std::conj(FacBra[i]) * FacKet[j]; } + } + + return S; +} + +/** @brief Compute the overlap matrix of the absolute value of the functions S_ij = <|bra_i|||ket_j|> + * + */ +DoubleMatrix calc_norm_overlap_matrix(CompFunctionVector &BraKet) { + int N = BraKet.size(); + DoubleMatrix S = DoubleMatrix::Zero(N, N); + DoubleMatrix Sreal = DoubleMatrix::Zero(N, N); // same as S, but stored as 4 blocks, rr,ri,ir,ii + MultiResolutionAnalysis<3> *mra = BraKet.vecMRA; + + // 1) make union tree without coefficients + mrcpp::FunctionTree<3> refTree(*mra); + mrcpp::mpi::allreduce_Tree_noCoeff(refTree, BraKet, mpi::comm_wrk); + + int sizecoeff = (1 << refTree.getDim()) * refTree.getKp1_d(); + int sizecoeffW = ((1 << refTree.getDim()) - 1) * refTree.getKp1_d(); + + // get a list of all nodes in union grid, as defined by their indices + std::vector scalefac; + std::vector coeffVec_ref; + std::vector indexVec_ref; // serialIx of the nodes + std::vector parindexVec_ref; // serialIx of the parent nodes + int max_ix; // largest index value (not used here) + + refTree.makeCoeffVector(coeffVec_ref, indexVec_ref, parindexVec_ref, scalefac, max_ix, refTree); + int max_n = indexVec_ref.size(); + + // only used for serial case: + std::vector> coeffVec(N); + std::map> node2orbVec; // for each node 
index, gives a vector with the indices of the orbitals using this node + std::vector> orb2node(N); // for a given orbital and a given node, gives the node index in + // the orbital given the node index in the reference tree + + bool serial = mrcpp::mpi::wrk_size == 1; // flag for serial/MPI switch + mrcpp::BankAccount nodesBraKet; + + // In the serial case we store the coeff pointers in coeffVec. In the mpi case the coeff are stored in the bank + if (serial) { + // 2) make list of all coefficients, and their reference indices + // for different orbitals, indexVec will give the same index for the same node in space + std::vector parindexVec; // serialIx of the parent nodes + std::vector indexVec; // serialIx of the nodes + for (int j = 0; j < N; j++) { + // make vector with all coef pointers and their indices in the union grid + if (BraKet[j].hasReal()) { + BraKet[j].real().makeCoeffVector(coeffVec[j], indexVec, parindexVec, scalefac, max_ix, refTree); + // make a map that gives j from indexVec + int orb_node_ix = 0; + for (int ix : indexVec) { + orb2node[j][ix] = orb_node_ix++; + if (ix < 0) continue; + node2orbVec[ix].push_back(j); + } + } + if (BraKet[j].hasImag()) { + BraKet[j].imag().makeCoeffVector(coeffVec[j + N], indexVec, parindexVec, scalefac, max_ix, refTree); + // make a map that gives j from indexVec + int orb_node_ix = 0; + for (int ix : indexVec) { + orb2node[j + N][ix] = orb_node_ix++; + if (ix < 0) continue; + node2orbVec[ix].push_back(j + N); + } + } + } + } else { // MPI case + // 2) send own nodes to bank, identifying them through the serialIx of refTree + save_nodes(BraKet, refTree, nodesBraKet); + mrcpp::mpi::barrier(mrcpp::mpi::comm_wrk); // wait until everything is stored before fetching! 
+ } + + // 3) make dot product for all the nodes and accumulate into S + + int ibank = 0; +#pragma omp parallel for schedule(dynamic) if (serial) + for (int n = 0; n < max_n; n++) { + if (n % mrcpp::mpi::wrk_size != mrcpp::mpi::wrk_rank) continue; + int csize; + int node_ix = indexVec_ref[n]; // SerialIx for this node in the reference tree + std::vector orbVec; // identifies which orbitals use this node + if (serial and node2orbVec[node_ix].size() <= 0) continue; + if (parindexVec_ref[n] < 0) + csize = sizecoeff; + else + csize = sizecoeffW; + // In the serial case we copy the coeff coeffBlock. In the mpi case coeffBlock is provided by the bank + if (serial) { + int shift = sizecoeff - sizecoeffW; // to copy only wavelet part + if (parindexVec_ref[n] < 0) shift = 0; + DoubleMatrix coeffBlock(csize, node2orbVec[node_ix].size()); + for (int j : node2orbVec[node_ix]) { // loop over indices of the orbitals using this node + int orb_node_ix = orb2node[j][node_ix]; + for (int k = 0; k < csize; k++) coeffBlock(k, orbVec.size()) = coeffVec[j][orb_node_ix][k + shift]; + orbVec.push_back(j); + } + if (orbVec.size() > 0) { + DoubleMatrix S_temp(orbVec.size(), orbVec.size()); + coeffBlock = coeffBlock.cwiseAbs(); + S_temp.noalias() = coeffBlock.transpose() * coeffBlock; + for (int i = 0; i < orbVec.size(); i++) { + for (int j = 0; j < orbVec.size(); j++) { + if (BraKet[orbVec[i]].func_ptr->data.n1[0] != BraKet[orbVec[j]].func_ptr->data.n1[0] and BraKet[orbVec[i]].func_ptr->data.n1[0] != 0 and + BraKet[orbVec[j]].func_ptr->data.n1[0] != 0) + continue; + double &Srealij = Sreal(orbVec[i], orbVec[j]); + double &Stempij = S_temp(i, j); +#pragma omp atomic + Srealij += Stempij; + } + } + } + } else { // MPI case + DoubleMatrix coeffBlock(csize, N); + nodesBraKet.get_nodeblock(indexVec_ref[n], coeffBlock.data(), orbVec); + + if (orbVec.size() > 0) { + DoubleMatrix S_temp(orbVec.size(), orbVec.size()); + coeffBlock.conservativeResize(Eigen::NoChange, orbVec.size()); + coeffBlock = 
coeffBlock.cwiseAbs(); + S_temp.noalias() = coeffBlock.transpose() * coeffBlock; + for (int i = 0; i < orbVec.size(); i++) { + for (int j = 0; j < orbVec.size(); j++) { + if (BraKet[orbVec[i]].func_ptr->data.n1[0] != BraKet[orbVec[j]].func_ptr->data.n1[0] and BraKet[orbVec[i]].func_ptr->data.n1[0] != 0 and + BraKet[orbVec[j]].func_ptr->data.n1[0] != 0) + continue; + Sreal(orbVec[i], orbVec[j]) += S_temp(i, j); + } + } + } + } + } + + IntVector conjMat = IntVector::Zero(N); + for (int i = 0; i < N; i++) { + if (!mrcpp::mpi::my_func(i)) continue; + conjMat[i] = (BraKet[i].conjugate()) ? -1 : 1; + } + mrcpp::mpi::allreduce_vector(conjMat, mrcpp::mpi::comm_wrk); + + for (int i = 0; i < N; i++) { + for (int j = 0; j <= i; j++) { + S(i, j) = Sreal(i, j) + conjMat[i] * conjMat[j] * Sreal(i + N, j + N) + conjMat[j] * Sreal(i, j + N) - conjMat[i] * Sreal(i + N, j); + S(j, i) = S(i, j); + } + } + + // Assumes linearity: result is sum of all nodes contributions + mrcpp::mpi::allreduce_matrix(S, mrcpp::mpi::comm_wrk); + // multiply by CompFunction multiplicative factor + ComplexVector Fac = ComplexVector::Zero(N); + for (int i = 0; i < N; i++) { + if (!mrcpp::mpi::my_func(BraKet[i])) continue; + Fac[i] = BraKet[i].func_ptr->data.c1[0]; + } + mrcpp::mpi::allreduce_vector(Fac, mrcpp::mpi::comm_wrk); + for (int i = 0; i < N; i++) { + for (int j = 0; j < N; j++) { S(i, j) *= std::norm(std::conj(Fac[i])) * std::norm(Fac[j]); } + } + return S; +} + +/** @brief Orthogonalize the functions in Bra against all orbitals in Ket + * + */ +void orthogonalize(double prec, CompFunctionVector &Bra, CompFunctionVector &Ket) { + // TODO: generalize for cases where Ket functions are not orthogonal to each other? 
+ ComplexMatrix S = calc_overlap_matrix(Bra, Ket); + int N = Bra.size(); + int M = Ket.size(); + DoubleVector Ketnorms = DoubleVector::Zero(M); + for (int i = 0; i < M; i++) { + if (mpi::my_func(Ket[i])) Ketnorms(i) = Ket[i].getSquareNorm(); + } + mrcpp::mpi::allreduce_vector(Ketnorms, mrcpp::mpi::comm_wrk); + ComplexMatrix rmat = ComplexMatrix::Zero(M, N); + for (int j = 0; j < N; j++) { + for (int i = 0; i < M; i++) { rmat(i, j) = 0.0 - S.conjugate()(j, i) / Ketnorms(i); } + } + CompFunctionVector rotatedKet(N); + rotate(Ket, rmat, rotatedKet, prec / M); + for (int j = 0; j < N; j++) { + if (my_func(Bra[j])) Bra[j].add(1.0, rotatedKet[j]); + } +} + +/** @brief Orthogonalize the Bra against Ket + * + */ +template void orthogonalize(double prec, CompFunction &Bra, CompFunction &Ket) { + ComplexDouble overlap = dot(Bra, Ket); + double sq_norm = Ket.getSquareNorm(); + for (int i = 0; i < Bra.Ncomp(); i++) { + if (Bra.isreal()) { + if (abs(overlap.imag()) > MachineZero) MSG_ABORT("NOT IMPLEMENTED"); + Bra.CompD[i]->add_inplace(-overlap.real() / sq_norm, *Ket.CompD[i]); + } else { + if (Ket.isreal()) MSG_ABORT("NOT IMPLEMENTED"); + Bra.CompC[i]->add_inplace(-std::conj(overlap / sq_norm), *Ket.CompC[i]); + overlap = dot(Bra, Ket); + } + } +} + +template ComplexDouble dot(CompFunction<3> bra, CompFunction<3> ket); +template void project(CompFunction<3> &out, RepresentableFunction<3, double> &f, double prec); +template void project(CompFunction<3> &out, RepresentableFunction<3, ComplexDouble> &f, double prec); +template void multiply(CompFunction<3> &out, CompFunction<3> inp_a, CompFunction<3> inp_b, double prec, bool absPrec, bool useMaxNorms, bool conjugate); +template void multiply(CompFunction<3> &out, FunctionTree<3, double> &inp_a, RepresentableFunction<3, double> &f, double prec, int nrefine = 0, bool conjugate); +template void multiply(CompFunction<3> &out, FunctionTree<3, ComplexDouble> &inp_a, RepresentableFunction<3, ComplexDouble> &f, double prec, int nrefine 
= 0, bool conjugate); +template void multiply(CompFunction<3> &out, CompFunction<3> &inp_a, RepresentableFunction<3, double> &f, double prec, int nrefine = 0, bool conjugate); +template void multiply(CompFunction<3> &out, CompFunction<3> &inp_a, RepresentableFunction<3, ComplexDouble> &f, double prec, int nrefine = 0, bool conjugate); +template void deep_copy(CompFunction<3> *out, const CompFunction<3> &inp); +template void deep_copy(CompFunction<3> &out, const CompFunction<3> &inp); +template void add(CompFunction<3> &out, ComplexDouble a, CompFunction<3> inp_a, ComplexDouble b, CompFunction<3> inp_b, double prec, bool conjugate); +template void linear_combination(CompFunction<3> &out, const std::vector &c, std::vector> &inp, double prec, bool conjugate); +template double node_norm_dot(CompFunction<3> bra, CompFunction<3> ket); +template void orthogonalize(double prec, CompFunction<3> &Bra, CompFunction<3> &Ket); + +} // namespace mrcpp diff --git a/src/utils/CompFunction.h b/src/utils/CompFunction.h new file mode 100644 index 000000000..96ac057ca --- /dev/null +++ b/src/utils/CompFunction.h @@ -0,0 +1,202 @@ +#pragma once + +#include "mpi_utils.h" +#include "trees/FunctionTreeVector.h" + +using namespace Eigen; + +namespace mrcpp { + +template struct CompFunctionData { + // additional data that describe the overall multicomponent function (defined by user): + // occupancy, quantum number, norm, etc. + int Ncomp{0}; // number of components defined + int rank{-1}; // rank (index) if part of a vector + int conj{0}; // soft conjugate (all components) + int CompFn1{0}; + int CompFn2{0}; + int isreal{0}; // trees are defined for T=double + int iscomplex{0}; // trees are defined for T=DoubleComplex + double CompFd1{0.0}; + double CompFd2{0.0}; + double CompFd3{0.0}; + // additional data that describe each component (defined by user): + // occupancy, quantum number, norm, etc. 
+ // Note: defined with fixed size to ease copying and MPI send + int n1[4]{0, 0, 0, 0}; // 0: neutral. otherwise different values are orthogonal to each other (product = 0) + int n2[4]{0, 0, 0, 0}; + int n3[4]{0, 0, 0, 0}; + int n4[4]{0, 0, 0, 0}; + // multiplicative scalar for the function. So far only actively used to take care of imag factor in momentum operator. + ComplexDouble c1[4]{{1.0, 0.0}, {1.0, 0.0}, {1.0, 0.0}, {1.0, 0.0}}; + double d1[4]{0.0, 0.0, 0.0, 0.0}; + double d2[4]{0.0, 0.0, 0.0, 0.0}; + double d3[4]{0.0, 0.0, 0.0, 0.0}; + // used for storage on disk + int type{0}; + int order{1}; + int scale{0}; + int depth{0}; + int boxes[3] = {0, 0, 0}; + int corner[3] = {0, 0, 0}; + + // used internally + int shared{0}; + int Nchunks[4]{0, 0, 0, 0}; // number of chunks of each component tree +}; + +template class TreePtr final { +public: + explicit TreePtr(bool share) + : shared_mem_real(nullptr) + , shared_mem_cplx(nullptr) { + for (int i = 0; i < 4; i++) real[i] = nullptr; + for (int i = 0; i < 4; i++) cplx[i] = nullptr; + is_shared = share; + if (is_shared and mpi::share_size > 1) { + // Memory size in MB defined in input. Virtual memory, does not cost anything if not used. 
+#ifdef MRCPP_HAS_MPI + this->shared_mem_real = new mrcpp::SharedMemory(mpi::comm_share, mpi::shared_memory_size); + this->shared_mem_cplx = new mrcpp::SharedMemory(mpi::comm_share, mpi::shared_memory_size); +#endif + } + } + + ~TreePtr() { + if (this->shared_mem_real != nullptr) delete this->shared_mem_real; + if (this->shared_mem_cplx != nullptr) delete this->shared_mem_cplx; + for (int i = 0; i < 4; i++) { + if (this->real[i] != nullptr) delete this->real[i]; + if (this->cplx[i] != nullptr) delete this->cplx[i]; + this->real[i] = nullptr; + this->cplx[i] = nullptr; + } + } + CompFunctionData data; + int &Ncomp = data.Ncomp; // number of components defined + int &rank = data.rank; // rank (index) if part of a vector + int &conj = data.conj; // soft conjugate + int &isreal = data.isreal; // T=double + int &iscomplex = data.iscomplex; // T=DoubleComplex + int &share = data.shared; + int *Nchunks = data.Nchunks; + + bool is_shared = false; + friend class CompFunction; + +protected: + FunctionTree *real[4]; // Real function + FunctionTree *cplx[4]; // Complex function + SharedMemory *shared_mem_real; + SharedMemory *shared_mem_cplx; +}; + +template class CompFunction { +public: + CompFunction(MultiResolutionAnalysis &mra); + CompFunction(); + CompFunction(int n1); + CompFunction(int n1, bool share); + CompFunction(const CompFunctionData &indata, bool alloc = false); + CompFunction(const CompFunction &compfunc); + CompFunction(CompFunction &&compfunc); + CompFunction &operator=(const CompFunction &compfunc); + virtual ~CompFunction() = default; + + FunctionTree **CompD; // = func_ptr->real so that we can use name CompD instead of func_ptr.real + FunctionTree **CompC; // = func_ptr->cplx + + std::string name; + + // additional data that describe each component (defined by user): + CompFunctionData data() const { return func_ptr->data; } + int Ncomp() const { return func_ptr->data.Ncomp; } // number of components defined + int rank() const { return func_ptr->data.rank; 
} // rank (index) if part of a vector + int conj() const { return func_ptr->data.conj; } // soft conjugate + int isreal() const { return func_ptr->data.isreal; } // T=double + int iscomplex() const { return func_ptr->data.iscomplex; } // T=DoubleComplex + void defreal() { func_ptr->data.isreal = 1; } // define as real + void defcomplex() { func_ptr->data.iscomplex = 1; } // define as complex + int share() const { return func_ptr->data.shared; } + int *Nchunks() const { return func_ptr->data.Nchunks; } // number of chunks of each component tree + + CompFunction paramCopy(bool alloc = false) const; + ComplexDouble integrate() const; + double norm() const; + double getSquareNorm() const; + void alloc(int nalloc = 1, bool zero = true); + void alloc_comp(int i = 0); // allocate one specific component + void setReal(FunctionTree *tree, int i = 0); + void setCplx(FunctionTree *tree, int i = 0); + void setRank(int i) { func_ptr->rank = i; }; + const int getRank() const { return func_ptr->rank; }; + void add(ComplexDouble c, CompFunction inp); + + int crop(double prec); + void rescale(ComplexDouble c); + void free(); + int getSizeNodes() const; + int getNNodes() const; + void flushMRAData(); + void flushFuncData(); + CompFunctionData getFuncData() const; + FunctionTree &real(int i = 0); + FunctionTree &complex(int i = 0); + const FunctionTree &real(int i = 0) const; + const FunctionTree &complex(int i = 0) const; + + // NB: All below should be revised. 
Now only for backwards compatibility to ComplexFunction class + + void free(int type) { free(); } + bool hasReal() const { return isreal(); } + bool hasImag() const { return iscomplex(); } + bool isShared() const { return share(); } + bool conjugate() const { return conj(); } + void dagger(); + FunctionTree &imag(int i = 0); // does not make sense now + const FunctionTree &imag(int i = 0) const; // does not make sense now + std::shared_ptr> func_ptr; +}; + +template void deep_copy(CompFunction *out, const CompFunction &inp); +template void deep_copy(CompFunction &out, const CompFunction &inp); +template void add(CompFunction &out, ComplexDouble a, CompFunction inp_a, ComplexDouble b, CompFunction inp_b, double prec, bool conjugate = false); +template void linear_combination(CompFunction &out, const std::vector &c, std::vector> &inp, double prec, bool conjugate = false); +template void multiply(CompFunction &out, CompFunction inp_a, CompFunction inp_b, double prec, bool absPrec = false, bool useMaxNorms = false, bool conjugate = false); +template +void multiply(double prec, CompFunction &out, double coef, CompFunction inp_a, CompFunction inp_b, int maxIter = -1, bool absPrec = false, bool useMaxNorms = false, bool conjugate = false); +template void multiply(CompFunction &out, CompFunction inp_a, CompFunction inp_b, bool absPrec = false, bool useMaxNorms = false, bool conjugate = false); +template void multiply(CompFunction &out, CompFunction &inp_a, RepresentableFunction &f, double prec, int nrefine = 0, bool conjugate = false); +template void multiply(CompFunction &out, CompFunction &inp_a, RepresentableFunction &f, double prec, int nrefine = 0, bool conjugate = false); +template void multiply(CompFunction &out, FunctionTree &inp_a, RepresentableFunction &f, double prec, int nrefine = 0, bool conjugate = false); +template void multiply(CompFunction &out, FunctionTree &inp_a, RepresentableFunction &f, double prec, int nrefine = 0, bool conjugate = false); +template 
ComplexDouble dot(CompFunction bra, CompFunction ket); +template double node_norm_dot(CompFunction bra, CompFunction ket); +void project(CompFunction<3> &out, std::function &r)> f, double prec); +void project(CompFunction<3> &out, std::function &r)> f, double prec); +template void project(CompFunction &out, RepresentableFunction &f, double prec); +template void project(CompFunction &out, RepresentableFunction &f, double prec); +template void orthogonalize(double prec, CompFunction &Bra, CompFunction &Ket); + +class CompFunctionVector : public std::vector> { +public: + CompFunctionVector(int N = 0); + MultiResolutionAnalysis<3> *vecMRA; + void distribute(); +}; + +void rotate(CompFunctionVector &Phi, const ComplexMatrix &U, double prec = -1.0); +void rotate(CompFunctionVector &Phi, const ComplexMatrix &U, CompFunctionVector &Psi, double prec = -1.0); +// void rotate_cplx(CompFunctionVector &Phi, const ComplexMatrix &U, CompFunctionVector &Psi, double prec = -1.0); +void save_nodes(CompFunctionVector &Phi, mrcpp::FunctionTree<3, double> &refTree, BankAccount &account, int sizes = -1); +CompFunctionVector multiply(CompFunctionVector &Phi, RepresentableFunction<3> &f, double prec = -1.0, CompFunction<3> *Func = nullptr, int nrefine = 1, bool all = false); +void SetdefaultMRA(MultiResolutionAnalysis<3> *MRA); +ComplexVector dot(CompFunctionVector &Bra, CompFunctionVector &Ket); +ComplexMatrix calc_lowdin_matrix(CompFunctionVector &Phi); +ComplexMatrix calc_overlap_matrix(CompFunctionVector &BraKet); +ComplexMatrix calc_overlap_matrix(CompFunctionVector &Bra, CompFunctionVector &Ket); +// ComplexMatrix calc_overlap_matrix_cplx(CompFunctionVector &Bra, CompFunctionVector &Ket); +DoubleMatrix calc_norm_overlap_matrix(CompFunctionVector &BraKet); +void orthogonalize(double prec, CompFunctionVector &Bra, CompFunctionVector &Ket); + +} // namespace mrcpp diff --git a/src/utils/ComplexFunction.cpp b/src/utils/ComplexFunction.cpp deleted file mode 100644 index 
11613780e..000000000 --- a/src/utils/ComplexFunction.cpp +++ /dev/null @@ -1,2016 +0,0 @@ -#include "ComplexFunction.h" -#include "Bank.h" -#include "Printer.h" -#include "Timer.h" -#include "parallel.h" -#include "treebuilders/grid.h" -#include "treebuilders/multiply.h" -#include "treebuilders/project.h" -#include "trees/FunctionNode.h" -#include "treebuilders/add.h" - -using mrcpp::Timer; - -namespace mrcpp { - -MultiResolutionAnalysis<3> *defaultMRA; // Global MRA - -ComplexFunction::ComplexFunction(std::shared_ptr funcptr) - : funcMRA(defaultMRA) - , func_ptr(funcptr) {} - -ComplexFunction::ComplexFunction(const ComplexFunction &func) - : funcMRA(func.funcMRA) - , conj(func.conj) - , func_ptr(func.func_ptr) - , rank(func.rank) {} - -ComplexFunction &ComplexFunction::operator=(const ComplexFunction &func) { - if (this != &func) { - this->conj = func.conj; - this->func_ptr = func.func_ptr; - this->funcMRA = func.funcMRA; - this->rank = func.rank; - } - return *this; -} - -/** @brief Constructor - * - * @param spin: electron spin (SPIN::Alpha/Beta/Paired) - * @param occ: occupation - * @param rank: MPI ownership (-1 means all MPI ranks) - * - * Initializes the mrcpp::ComplexFunction with NULL pointers for both real and imaginary part. - */ -ComplexFunction::ComplexFunction(int spin, double occ, int rank, bool share) - : funcMRA(defaultMRA) - , func_ptr(std::make_shared(share)) - , rank(rank) { - this->getFunctionData().spin = spin; - this->getFunctionData().occ = occ; - if (this->spin() < 0) INVALID_ARG_ABORT; - if (this->occ() < 0) { - if (this->spin() == SPIN::Paired) this->getFunctionData().occ = 2; - if (this->spin() == SPIN::Alpha) this->getFunctionData().occ = 1; - if (this->spin() == SPIN::Beta) this->getFunctionData().occ = 1; - } -} - -/** @brief Parameter copy - * - * Returns a new ComplexFunction with the same spin, occupation and rank_id as *this. 
- */ -ComplexFunction ComplexFunction::paramCopy() const { - return ComplexFunction(this->spin(), this->occ(), this->getRank()); -} - -MPI_FuncVector::MPI_FuncVector(int N) - : std::vector(N) { - for (int i = 0; i < N; i++) (*this)[i].setRank(i); - vecMRA = defaultMRA; -} -void MPI_FuncVector::distribute() { - for (int i = 0; i < this->size(); i++) (*this)[i].setRank(i); -} - -/** @brief Returns the orbital meta data - * - * Tree sizes (nChunks) are flushed before return. - */ -FunctionData &ComplexFunction::getFunctionData() { - this->func_ptr->flushFuncData(); - return this->func_ptr->func_data; -} - -ComplexFunction ComplexFunction::dagger() { - ComplexFunction out(*this); - out.conj = not(this->conj); - return out; // Return shallow copy -} - -void ComplexFunction::setReal(FunctionTree<3> *tree) { - if (isShared()) MSG_ABORT("Cannot set in shared function"); - this->func_ptr->re = tree; -} - -void ComplexFunction::setImag(FunctionTree<3> *tree) { - if (isShared()) MSG_ABORT("Cannot set in shared function"); - this->func_ptr->im = tree; -} - -void ComplexFunction::alloc(int type, MultiResolutionAnalysis<3> *mra) { - if (mra == nullptr) mra = funcMRA; - if (mra == nullptr) MSG_ABORT("Invalid argument"); - if (type == NUMBER::Real or type == NUMBER::Total) { - if (hasReal()) MSG_ABORT("Real part already allocated"); - this->func_ptr->re = new FunctionTree<3>(*mra, this->func_ptr->shared_mem_re); - } - if (type == NUMBER::Imag or type == NUMBER::Total) { - if (hasImag()) MSG_ABORT("Imaginary part already allocated"); - this->func_ptr->im = new FunctionTree<3>(*mra, this->func_ptr->shared_mem_im); - } -} - -void ComplexFunction::free(int type) { - if (type == NUMBER::Real or type == NUMBER::Total) { - if (hasReal()) delete this->func_ptr->re; - this->func_ptr->re = nullptr; - if (this->func_ptr->shared_mem_re) this->func_ptr->shared_mem_re->clear(); - } - if (type == NUMBER::Imag or type == NUMBER::Total) { - if (hasImag()) delete this->func_ptr->im; - 
this->func_ptr->im = nullptr; - if (this->func_ptr->shared_mem_im) this->func_ptr->shared_mem_im->clear(); - } -} - -int ComplexFunction::getSizeNodes(int type) const { - int size_mb = 0; // Memory size in kB - if (type == NUMBER::Real or type == NUMBER::Total) { - if (hasReal()) size_mb += real().getSizeNodes(); - } - if (type == NUMBER::Imag or type == NUMBER::Total) { - if (hasImag()) size_mb += imag().getSizeNodes(); - } - return size_mb; -} - -int ComplexFunction::getNNodes(int type) const { - int nNodes = 0; - if (type == NUMBER::Real or type == NUMBER::Total) { - if (hasReal()) nNodes += real().getNNodes(); - } - if (type == NUMBER::Imag or type == NUMBER::Total) { - if (hasImag()) nNodes += imag().getNNodes(); - } - return nNodes; -} - -int ComplexFunction::crop(double prec) { - if (prec < 0.0) return 0; - bool need_to_crop = not(isShared()) or mpi::share_master(); - int nChunksremoved = 0; - if (need_to_crop) { - if (hasReal()) nChunksremoved = real().crop(prec, 1.0, false); - if (hasImag()) nChunksremoved += imag().crop(prec, 1.0, false); - } - mpi::share_function(*this, 0, 7744, mpi::comm_share); - return nChunksremoved; -} - -ComplexDouble ComplexFunction::integrate() const { - double int_r = 0.0; - double int_i = 0.0; - if (hasReal()) int_r = real().integrate(); - if (hasImag()) int_i = imag().integrate(); - return ComplexDouble(int_r, int_i); -} - -/** @brief Returns the norm of the orbital */ -double ComplexFunction::norm() const { - double norm = squaredNorm(); - if (norm > 0.0) norm = std::sqrt(norm); - return norm; -} - -/** @brief Returns the squared norm of the orbital */ -double ComplexFunction::squaredNorm() const { - double sq_r = -1.0; - double sq_i = -1.0; - if (hasReal()) sq_r = real().getSquareNorm(); - if (hasImag()) sq_i = imag().getSquareNorm(); - - double sq_norm = 0.0; - if (sq_r < 0.0 and sq_i < 0.0) { - sq_norm = -1.0; - } else { - if (sq_r >= 0.0) sq_norm += sq_r; - if (sq_i >= 0.0) sq_norm += sq_i; - } - return sq_norm; -} - -/** 
@brief In place addition. - * - * Output is extended to union grid. - * - */ -void ComplexFunction::add(ComplexDouble c, ComplexFunction inp) { - double thrs = MachineZero; - bool cHasReal = (std::abs(c.real()) > thrs); - bool cHasImag = (std::abs(c.imag()) > thrs); - bool outNeedsReal = (cHasReal and inp.hasReal()) or (cHasImag and inp.hasImag()); - bool outNeedsImag = (cHasReal and inp.hasImag()) or (cHasImag and inp.hasReal()); - - ComplexFunction &out = *this; - bool clearReal(false), clearImag(false); - if (outNeedsReal and not(out.hasReal())) { - out.alloc(NUMBER::Real); - clearReal = true; - } - - if (outNeedsImag and not(out.hasImag())) { - out.alloc(NUMBER::Imag); - clearImag = true; - } - - bool need_to_add = not(out.isShared()) or mpi::share_master(); - if (need_to_add) { - if (clearReal) out.real().setZero(); - if (clearImag) out.imag().setZero(); - if (cHasReal and inp.hasReal()) { - while (refine_grid(out.real(), inp.real())) {} - out.real().add(c.real(), inp.real()); - } - if (cHasReal and inp.hasImag()) { - double conj = (inp.conjugate()) ? -1.0 : 1.0; - while (refine_grid(out.imag(), inp.imag())) {} - out.imag().add(conj * c.real(), inp.imag()); - } - if (cHasImag and inp.hasReal()) { - while (refine_grid(out.imag(), inp.real())) {} - out.imag().add(c.imag(), inp.real()); - } - if (cHasImag and inp.hasImag()) { - double conj = (inp.conjugate()) ? -1.0 : 1.0; - while (refine_grid(out.real(), inp.imag())) {} - out.real().add(-1.0 * conj * c.imag(), inp.imag()); - } - } - mpi::share_function(out, 0, 9911, mpi::comm_share); -} - -/** @brief In place addition of absolute values. - * - * Output is extended to union grid. 
- * - */ -void ComplexFunction::absadd(ComplexDouble c, ComplexFunction inp) { - double thrs = MachineZero; - bool cHasReal = (std::abs(c.real()) > thrs); - bool cHasImag = (std::abs(c.imag()) > thrs); - bool outNeedsReal = (cHasReal and inp.hasReal()) or (cHasImag and inp.hasImag()); - bool outNeedsImag = (cHasReal and inp.hasImag()) or (cHasImag and inp.hasReal()); - - ComplexFunction &out = *this; - bool clearReal(false), clearImag(false); - if (outNeedsReal and not(out.hasReal())) { - out.alloc(NUMBER::Real); - clearReal = true; - } - - if (outNeedsImag and not(out.hasImag())) { - out.alloc(NUMBER::Imag); - clearImag = true; - } - - bool need_to_add = not(out.isShared()) or mpi::share_master(); - if (need_to_add) { - if (clearReal) out.real().setZero(); - if (clearImag) out.imag().setZero(); - if (cHasReal and inp.hasReal()) { - while (refine_grid(out.real(), inp.real())) {} - out.real().absadd(c.real(), inp.real()); - } - if (cHasReal and inp.hasImag()) { - double conj = (inp.conjugate()) ? -1.0 : 1.0; - while (refine_grid(out.imag(), inp.imag())) {} - out.imag().absadd(conj * c.real(), inp.imag()); - } - if (cHasImag and inp.hasReal()) { - while (refine_grid(out.imag(), inp.real())) {} - out.imag().absadd(c.imag(), inp.real()); - } - if (cHasImag and inp.hasImag()) { - double conj = (inp.conjugate()) ? -1.0 : 1.0; - while (refine_grid(out.real(), inp.imag())) {} - out.real().absadd(-1.0 * conj * c.imag(), inp.imag()); - } - } - mpi::share_function(out, 0, 9912, mpi::comm_share); -} - -/** @brief In place multiply with real scalar. Fully in-place.*/ -void ComplexFunction::rescale(double c) { - bool need_to_rescale = not(isShared()) or mpi::share_master(); - if (need_to_rescale) { - if (hasReal()) real().rescale(c); - if (hasImag()) imag().rescale(c); - } - mpi::share_function(*this, 0, 5543, mpi::comm_share); -} - -/** @brief In place multiply with complex scalar. 
Involves a deep copy.*/ -void ComplexFunction::rescale(ComplexDouble c) { - ComplexFunction &out = *this; - ComplexFunction tmp(spin(), occ(), rank, isShared()); - cplxfunc::deep_copy(tmp, out); - out.free(NUMBER::Total); - out.add(c, tmp); -} - -/** @brief Returns a character representing the spin (a/b/p) */ -char ComplexFunction::printSpin() const { - char sp = 'u'; - if (this->spin() == SPIN::Paired) sp = 'p'; - if (this->spin() == SPIN::Alpha) sp = 'a'; - if (this->spin() == SPIN::Beta) sp = 'b'; - return sp; -} - -void cplxfunc::SetdefaultMRA(MultiResolutionAnalysis<3> *MRA) { - defaultMRA = MRA; -} - -/** @brief Compute = int bra^\dag(r) * ket(r) dr. - * - * Notice that the = int |bra^\dag(r)| * |ket(r)| dr. - * - */ -ComplexDouble cplxfunc::node_norm_dot(ComplexFunction bra, ComplexFunction ket, bool exact) { - double rr(0.0), ri(0.0), ir(0.0), ii(0.0); - if (bra.hasReal() and ket.hasReal()) rr = mrcpp::node_norm_dot(bra.real(), ket.real(), exact); - if (bra.hasReal() and ket.hasImag()) ri = mrcpp::node_norm_dot(bra.real(), ket.imag(), exact); - if (bra.hasImag() and ket.hasReal()) ir = mrcpp::node_norm_dot(bra.imag(), ket.real(), exact); - if (bra.hasImag() and ket.hasImag()) ii = mrcpp::node_norm_dot(bra.imag(), ket.imag(), exact); - - double bra_conj = (bra.conjugate()) ? -1.0 : 1.0; - double ket_conj = (ket.conjugate()) ? -1.0 : 1.0; - - double real_part = rr + bra_conj * ket_conj * ii; - double imag_part = ket_conj * ri - bra_conj * ir; - return ComplexDouble(real_part, imag_part); -} - -/** @brief Deep copy - * - * Returns a new function which is a full blueprint copy of the input function. - * This is achieved by building a new grid for the real and imaginary parts and - * copying. 
- */ -void cplxfunc::deep_copy(ComplexFunction &out, ComplexFunction &inp) { - bool need_to_copy = not(out.isShared()) or mpi::share_master(); - out.funcMRA = inp.funcMRA; - out.setRank(inp.getRank()); - if (inp.hasReal()) { - if (not out.hasReal()) out.alloc(NUMBER::Real); - if (need_to_copy) { - copy_grid(out.real(), inp.real()); - copy_func(out.real(), inp.real()); - } - } - if (inp.hasImag()) { - if (not out.hasImag()) out.alloc(NUMBER::Imag); - if (need_to_copy) { - copy_grid(out.imag(), inp.imag()); - copy_func(out.imag(), inp.imag()); - if (out.conjugate()) out.imag().rescale(-1.0); - } - } - mpi::share_function(out, 0, 1324, mpi::comm_share); -} - -void cplxfunc::project(ComplexFunction &out, std::function &r)> f, int type, double prec) { - bool need_to_project = not(out.isShared()) or mpi::share_master(); - if (type == NUMBER::Real or type == NUMBER::Total) { - if (not out.hasReal()) out.alloc(NUMBER::Real); - if (need_to_project) mrcpp::project<3>(prec, out.real(), f); - } - if (type == NUMBER::Imag or type == NUMBER::Total) { - if (not out.hasImag()) out.alloc(NUMBER::Imag); - if (need_to_project) mrcpp::project<3>(prec, out.imag(), f); - } - mpi::share_function(out, 0, 123123, mpi::comm_share); -} - -void cplxfunc::project(ComplexFunction &out, RepresentableFunction<3> &f, int type, double prec) { - bool need_to_project = not(out.isShared()) or mpi::share_master(); - if (type == NUMBER::Real or type == NUMBER::Total) { - if (not out.hasReal()) out.alloc(NUMBER::Real); - if (need_to_project) build_grid(out.real(), f); - if (need_to_project) mrcpp::project<3>(prec, out.real(), f); - } - if (type == NUMBER::Imag or type == NUMBER::Total) { - if (not out.hasImag()) out.alloc(NUMBER::Imag); - if (need_to_project) build_grid(out.imag(), f); - if (need_to_project) mrcpp::project<3>(prec, out.imag(), f); - } - mpi::share_function(out, 0, 132231, mpi::comm_share); -} - -/** @brief out = a*inp_a + b*inp_b - * - * Recast into linear_combination. 
- * - */ -void cplxfunc::add(ComplexFunction &out, ComplexDouble a, ComplexFunction inp_a, ComplexDouble b, ComplexFunction inp_b, double prec) { - ComplexVector coefs(2); - coefs(0) = a; - coefs(1) = b; - - std::vector funcs; // NB: not a ComplexFunctionVector, because not run in parallel! - funcs.push_back(inp_a); - funcs.push_back(inp_b); - - cplxfunc::linear_combination(out, coefs, funcs, prec); -} - -/** @brief out = inp_a * inp_b - * - */ -void cplxfunc::multiply(ComplexFunction &out, ComplexFunction inp_a, ComplexFunction inp_b, double prec, bool absPrec, bool useMaxNorms) { - multiply_real(out, inp_a, inp_b, prec, absPrec, useMaxNorms); - multiply_imag(out, inp_a, inp_b, prec, absPrec, useMaxNorms); -} - -/** @brief out = inp_a * f - * - */ -void cplxfunc::multiply(ComplexFunction &out, ComplexFunction &inp_a, RepresentableFunction<3> &f, double prec, int nrefine) { - // uses the mpifuncvec multiply - MPI_FuncVector mpi_funcvec_a; - mpi_funcvec_a.push_back(inp_a); - MPI_FuncVector mpi_funcvec_out; - mpi_funcvec_out = mpifuncvec::multiply(mpi_funcvec_a, f, prec, nullptr, nrefine, true); - out = mpi_funcvec_out[0]; -} - -/** @brief out = inp_a * f - * - */ -void cplxfunc::multiply(ComplexFunction &out, FunctionTree<3> &inp_a, RepresentableFunction<3> &f, double prec, int nrefine) { - ComplexFunction cplxfunc_a; - cplxfunc_a.setReal(&inp_a); - cplxfunc::multiply(out, cplxfunc_a, f, prec, nrefine); - cplxfunc_a.setReal(nullptr); // otherwise inp_a is deleted by cplxfunc_a destructor -} - -/** @brief out = c_0*inp_0 + c_1*inp_1 + ... + c_N*inp_N - * - */ -void cplxfunc::linear_combination(ComplexFunction &out, const ComplexVector &c, std::vector &inp, double prec) { - FunctionTreeVector<3> rvec; - FunctionTreeVector<3> ivec; - - double thrs = MachineZero; - for (int i = 0; i < inp.size(); i++) { - double sign = (inp[i].conjugate()) ? 
-1.0 : 1.0; - - bool cHasReal = (std::abs(c[i].real()) > thrs); - bool cHasImag = (std::abs(c[i].imag()) > thrs); - - if (cHasReal and inp[i].hasReal()) rvec.push_back(std::make_tuple(c[i].real(), &inp[i].real())); - if (cHasImag and inp[i].hasImag()) rvec.push_back(std::make_tuple(-sign * c[i].imag(), &inp[i].imag())); - - if (cHasImag and inp[i].hasReal()) ivec.push_back(std::make_tuple(c[i].imag(), &inp[i].real())); - if (cHasReal and inp[i].hasImag()) ivec.push_back(std::make_tuple(sign * c[i].real(), &inp[i].imag())); - } - - if (rvec.size() > 0 and not out.hasReal()) out.alloc(NUMBER::Real); - if (ivec.size() > 0 and not out.hasImag()) out.alloc(NUMBER::Imag); - - bool need_to_add = not(out.isShared()) or mpi::share_master(); - if (need_to_add) { - if (rvec.size() > 0) { - if (prec < 0.0) { - build_grid(out.real(), rvec); - mrcpp::add(prec, out.real(), rvec, 0); - } else { - mrcpp::add(prec, out.real(), rvec); - } - } else if (out.hasReal()) { - out.real().setZero(); - } - if (ivec.size() > 0) { - if (prec < 0.0) { - build_grid(out.imag(), ivec); - mrcpp::add(prec, out.imag(), ivec, 0); - } else { - mrcpp::add(prec, out.imag(), ivec); - } - } else if (out.hasImag()) { - out.imag().setZero(); - } - } - mpi::share_function(out, 0, 9911, mpi::comm_share); -} - -/** @brief out = Re(inp_a * inp_b) - * - */ -void cplxfunc::multiply_real(ComplexFunction &out, ComplexFunction inp_a, ComplexFunction inp_b, double prec, bool absPrec, bool useMaxNorms) { - double conj_a = (inp_a.conjugate()) ? -1.0 : 1.0; - double conj_b = (inp_b.conjugate()) ? 
-1.0 : 1.0; - - bool need_to_multiply = not(out.isShared()) or mpi::share_master(); - - FunctionTreeVector<3> vec; - if (inp_a.hasReal() and inp_b.hasReal()) { - auto *tree = new FunctionTree<3>(inp_a.real().getMRA()); - if (need_to_multiply) { - double coef = 1.0; - if (prec < 0.0) { - // Union grid - build_grid(*tree, inp_a.real()); - build_grid(*tree, inp_b.real()); - mrcpp::multiply(prec, *tree, coef, inp_a.real(), inp_b.real(), 0); - } else { - // Adaptive grid - mrcpp::multiply(prec, *tree, coef, inp_a.real(), inp_b.real(), -1, absPrec, useMaxNorms); - } - } - vec.push_back(std::make_tuple(1.0, tree)); - } - if (inp_a.hasImag() and inp_b.hasImag()) { - auto *tree = new FunctionTree<3>(inp_a.imag().getMRA()); - if (need_to_multiply) { - double coef = -1.0 * conj_a * conj_b; - if (prec < 0.0) { - // Union grid - build_grid(*tree, inp_a.imag()); - build_grid(*tree, inp_b.imag()); - mrcpp::multiply(prec, *tree, coef, inp_a.imag(), inp_b.imag(), 0); - } else { - // Adaptive grid - mrcpp::multiply(prec, *tree, coef, inp_a.imag(), inp_b.imag(), -1, absPrec, useMaxNorms); - } - } - vec.push_back(std::make_tuple(1.0, tree)); - } - - if (vec.size() > 0) { - if (out.hasReal()) { - if (need_to_multiply) out.real().clear(); - } else { - // All sharing procs must allocate - out.alloc(NUMBER::Real); - } - } - - if (need_to_multiply) { - if (vec.size() == 1) { - FunctionTree<3> &func_0 = get_func(vec, 0); - copy_grid(out.real(), func_0); - copy_func(out.real(), func_0); - clear(vec, true); - } else if (vec.size() == 2) { - build_grid(out.real(), vec); - mrcpp::add(prec, out.real(), vec, 0); - clear(vec, true); - } else if (out.hasReal()) { - out.real().setZero(); - } - } - mpi::share_function(out, 0, 9191, mpi::comm_share); -} - -/** @brief out = Im(inp_a * inp_b) - * - */ -void cplxfunc::multiply_imag(ComplexFunction &out, ComplexFunction inp_a, ComplexFunction inp_b, double prec, bool absPrec, bool useMaxNorms) { - double conj_a = (inp_a.conjugate()) ? 
-1.0 : 1.0; - double conj_b = (inp_b.conjugate()) ? -1.0 : 1.0; - bool need_to_multiply = not(out.isShared()) or mpi::share_master(); - - FunctionTreeVector<3> vec; - if (inp_a.hasReal() and inp_b.hasImag()) { - auto *tree = new FunctionTree<3>(inp_a.real().getMRA()); - if (need_to_multiply) { - double coef = conj_b; - if (prec < 0.0) { - // Union grid - build_grid(*tree, inp_a.real()); - build_grid(*tree, inp_b.imag()); - mrcpp::multiply(prec, *tree, coef, inp_a.real(), inp_b.imag(), 0); - } else { - // Adaptive grid - mrcpp::multiply(prec, *tree, coef, inp_a.real(), inp_b.imag(), -1, absPrec, useMaxNorms); - } - } - vec.push_back(std::make_tuple(1.0, tree)); - } - if (inp_a.hasImag() and inp_b.hasReal()) { - auto *tree = new FunctionTree<3>(inp_a.imag().getMRA()); - if (need_to_multiply) { - double coef = conj_a; - if (prec < 0.0) { - // Union grid - build_grid(*tree, inp_a.imag()); - build_grid(*tree, inp_b.real()); - mrcpp::multiply(prec, *tree, coef, inp_a.imag(), inp_b.real(), 0); - } else { - // Adaptive grid - mrcpp::multiply(prec, *tree, coef, inp_a.imag(), inp_b.real(), -1, absPrec, useMaxNorms); - } - } - vec.push_back(std::make_tuple(1.0, tree)); - } - - if (vec.size() > 0) { - if (out.hasImag()) { - if (need_to_multiply) out.imag().clear(); - } else { - // All sharing procs must allocate - out.alloc(NUMBER::Imag); - } - } - - if (need_to_multiply) { - if (vec.size() == 1) { - FunctionTree<3> &func_0 = get_func(vec, 0); - copy_grid(out.imag(), func_0); - copy_func(out.imag(), func_0); - clear(vec, true); - } else if (vec.size() == 2) { - build_grid(out.imag(), vec); - mrcpp::add(prec, out.imag(), vec, 0); - clear(vec, true); - } else if (out.hasImag()) { - out.imag().setZero(); - } - } - mpi::share_function(out, 0, 9292, mpi::comm_share); -} - -namespace mpifuncvec { - - -/** @brief Make a linear combination of functions - * - * Uses "local" representation: treats one node at a time. 
- * For each node, all functions are transformed simultaneously - * by a dense matrix multiplication. - * Phi input functions, Psi output functions - * - */ -void rotate(MPI_FuncVector &Phi, const ComplexMatrix &U, MPI_FuncVector &Psi, double prec) { - - // The principle of this routine is that nodes are rotated one by one using matrix multiplication. - // The routine does avoid when possible to move data, but uses pointers and indices manipulation. - // MPI version does not use OMP yet, Serial version uses OMP - // size of input is N, size of output is M - int N = Phi.size(); - int M = Psi.size(); - if (U.rows() < N) MSG_ABORT("Incompatible number of rows for U matrix"); - if (U.cols() < M) MSG_ABORT("Incompatible number of columns for U matrix"); - - // 1) make union tree without coefficients - FunctionTree<3> refTree(*Phi.vecMRA); - mpi::allreduce_Tree_noCoeff(refTree, Phi, mpi::comm_wrk); - - int sizecoeff = (1 << refTree.getDim()) * refTree.getKp1_d(); - int sizecoeffW = ((1 << refTree.getDim()) - 1) * refTree.getKp1_d(); - std::vector scalefac_ref; - std::vector coeffVec_ref; // not used! - std::vector indexVec_ref; // serialIx of the nodes - std::vector parindexVec_ref; // serialIx of the parent nodes - int max_ix; - // get a list of all nodes in union tree, identified by their serialIx indices - refTree.makeCoeffVector(coeffVec_ref, indexVec_ref, parindexVec_ref, scalefac_ref, max_ix, refTree); - int max_n = indexVec_ref.size(); - - // 2) We work with real numbers only. Make real blocks for U matrix - bool UhasReal = false; - bool UhasImag = false; - for (int i = 0; i < N; i++) { - for (int j = 0; j < M; j++) { - if (std::abs(U(i, j).real()) > 10*MachineZero) UhasReal = true; - if (std::abs(U(i, j).imag()) > 10*MachineZero) UhasImag = true; - } - } - - IntVector PsihasReIm = IntVector::Zero(2); - for (int j = 0; j < N; j++) { - if (!mpi::my_orb(j)) continue; - PsihasReIm[0] = (Phi[j].hasReal()) ? 1 : 0; - PsihasReIm[1] = (Phi[j].hasImag()) ? 
1 : 0; - } - mpi::allreduce_vector(PsihasReIm, mpi::comm_wrk); - if (not PsihasReIm[0] and not PsihasReIm[1]) { - return; // do nothing - } - - bool makeReal = (UhasReal and PsihasReIm[0]) or (UhasImag and PsihasReIm[1]); - bool makeImag = (UhasReal and PsihasReIm[1]) or (UhasImag and PsihasReIm[0]); - - for (int j = 0; j < M; j++) { - if (!mpi::my_orb(j)) continue; - if (not makeReal and Psi[j].hasReal()) Psi[j].free(NUMBER::Real); - if (not makeImag and Psi[j].hasImag()) Psi[j].free(NUMBER::Imag); - } - - if (not makeReal and not makeImag) { return; } - - int Neff = N; // effective number of input orbitals - int Meff = M; // effective number of output orbitals - if (makeImag) Neff = 2 * N; // Imag and Real treated independently. We always use real part of U - if (makeImag) Meff = 2 * M; // Imag and Real treated independently. We always use real part of U - - IntVector conjMat = IntVector::Zero(Neff); - for (int j = 0; j < Neff; j++) { - if (!mpi::my_orb(j % N)) continue; - conjMat[j] = (Phi[j % N].conjugate()) ? -1 : 1; - } - mpi::allreduce_vector(conjMat, mpi::comm_wrk); - - // we make a real matrix = U, but organized as one or four real blocks - // out_r = U_rr*in_r - U_ir*in_i*conjMat - // out_i = U_ri*in_r - U_ii*in_i*conjMat - // the first index of U is the one used on input Phi - DoubleMatrix Ureal(Neff, Meff); // four blocks, for rr ri ir ii - for (int j = 0; j < Neff; j++) { - for (int i = 0; i < Meff; i++) { - double sign = 1.0; - if (j < N and i < M) { - // real U applied on real Phi - Ureal(j, i) = U.real()(j % N, i % M); - } else if (j >= N and i >= M) { - // real U applied on imag Phi - Ureal(j, i) = conjMat[j] * U.real()(j % N, i % M); - } else if (j < N and i >= M) { - // imag U applied on real Phi - Ureal(j, i) = U.imag()(j % N, i % M); - } else { - // imag U applied on imag Phi - Ureal(j, i) = -1.0 * conjMat[j] * U.imag()(j % N, i % M); - } - } - } - - // 3) In the serial case we store the coeff pointers in coeffVec. 
In the mpi case the coeff are stored in the bank - - bool serial = mpi::wrk_size == 1; // flag for serial/MPI switch - BankAccount nodesPhi; // to put the original nodes - BankAccount nodesRotated; // to put the rotated nodes - - // used for serial only: - std::vector> coeffVec(Neff); - std::vector> indexVec(Neff); // serialIx of the nodes - std::map> node2orbVec; // for each node index, gives a vector with the indices of the orbitals using this node - std::vector> orb2node(Neff); // for a given orbital and a given node, gives the node index in the - // orbital given the node index in the reference tree - if (serial) { - - // make list of all coefficients (coeffVec), and their reference indices (indexVec) - std::vector parindexVec; // serialIx of the parent nodes - std::vector scalefac; - for (int j = 0; j < N; j++) { - // make vector with all coef pointers and their indices in the union grid - if (Phi[j].hasReal()) { - Phi[j].real().makeCoeffVector(coeffVec[j], indexVec[j], parindexVec, scalefac, max_ix, refTree); - // make a map that gives j from indexVec - int orb_node_ix = 0; - for (int ix : indexVec[j]) { - orb2node[j][ix] = orb_node_ix++; - if (ix < 0) continue; - node2orbVec[ix].push_back(j); - } - } - if (Phi[j].hasImag()) { - Phi[j].imag().makeCoeffVector(coeffVec[j + N], indexVec[j + N], parindexVec, scalefac, max_ix, refTree); - // make a map that gives j from indexVec - int orb_node_ix = 0; - for (int ix : indexVec[j + N]) { - orb2node[j + N][ix] = orb_node_ix++; - if (ix < 0) continue; - node2orbVec[ix].push_back(j + N); - } - } - } - } else { // MPI case - - // send own nodes to bank, identifying them through the serialIx of refTree - mpifuncvec::save_nodes(Phi, refTree, nodesPhi); - mpi::barrier(mpi::comm_wrk); // required for now, as the blockdata functionality has no queue yet. 
- } - - // 4) rotate all the nodes - IntMatrix split_serial; // in the serial case all split are stored in one array - std::vector> coeffpVec(Meff); // to put pointers to the rotated coefficient for each orbital in serial case - std::vector> ix2coef(Meff); // to find the index in for example rotCoeffVec[] corresponding to a serialIx - int csize; // size of the current coefficients (different for roots and branches) - std::vector rotatedCoeffVec; // just to ensure that the data from rotatedCoeff is not deleted, since we point to it. - // j indices are for unrotated orbitals, i indices are for rotated orbitals - if (serial) { - std::map ix2coef_ref; // to find the index n corresponding to a serialIx - split_serial.resize(Meff, max_n); // not use in the MPI case - for (int n = 0; n < max_n; n++) { - int node_ix = indexVec_ref[n]; // SerialIx for this node in the reference tree - ix2coef_ref[node_ix] = n; - for (int i = 0; i < Meff; i++) split_serial(i, n) = 1; - } - - std::vector nodeReady(max_n, 0); // To indicate to OMP threads that the parent is ready (for splits) - - // assumes the nodes are ordered such that parent are treated before children. BFS or DFS ok. - // NB: the n must be traversed approximately in right order: Thread n may have to wait until som other preceding - // n is finished. 
-#pragma omp parallel for schedule(dynamic) - for (int n = 0; n < max_n; n++) { - int csize; - int node_ix = indexVec_ref[n]; // SerialIx for this node in the reference tree - // 4a) make a dense contiguous matrix with the coefficient from all the orbitals using node n - std::vector orbjVec; // to remember which orbital correspond to each orbVec.size(); - if (node2orbVec[node_ix].size() <= 0) continue; - csize = sizecoeffW; - if (parindexVec_ref[n] < 0) csize = sizecoeff; // for root nodes we include scaling coeff - - int shift = sizecoeff - sizecoeffW; // to copy only wavelet part - if (parindexVec_ref[n] < 0) shift = 0; - DoubleMatrix coeffBlock(csize, node2orbVec[node_ix].size()); - for (int j : node2orbVec[node_ix]) { // loop over indices of the orbitals using this node - int orb_node_ix = orb2node[j][node_ix]; - for (int k = 0; k < csize; k++) coeffBlock(k, orbjVec.size()) = coeffVec[j][orb_node_ix][k + shift]; - orbjVec.push_back(j); - } - - // 4b) make a list of rotated orbitals needed for this node - // OMP must wait until parent is ready - while (parindexVec_ref[n] >= 0 and nodeReady[ix2coef_ref[parindexVec_ref[n]]] == 0) { -#pragma omp flush - }; - - std::vector orbiVec; - for (int i = 0; i < Meff; i++) { // loop over all rotated orbitals - if (not makeReal and i < M) continue; - if (not makeImag and i >= M) continue; - if (parindexVec_ref[n] >= 0 and split_serial(i, ix2coef_ref[parindexVec_ref[n]]) == 0) continue; // parent node has too small wavelets - orbiVec.push_back(i); - } - - // 4c) rotate this node - DoubleMatrix Un(orbjVec.size(), orbiVec.size()); // chunk of U, with reorganized indices - for (int i = 0; i < orbiVec.size(); i++) { // loop over rotated orbitals - for (int j = 0; j < orbjVec.size(); j++) { Un(j, i) = Ureal(orbjVec[j], orbiVec[i]); } - } - DoubleMatrix rotatedCoeff(csize, orbiVec.size()); - // HERE IT HAPPENS! 
- rotatedCoeff.noalias() = coeffBlock * Un; // Matrix mutiplication - - // 4d) store and make rotated node pointers - // for now we allocate in buffer, in future could be directly allocated in the final trees - double thres = prec * prec * scalefac_ref[n] * scalefac_ref[n]; - // make all norms: - for (int i = 0; i < orbiVec.size(); i++) { - // check if parent must be split - if (parindexVec_ref[n] == -1 or split_serial(orbiVec[i], ix2coef_ref[parindexVec_ref[n]])) { - // mark this node for this orbital for later split -#pragma omp critical - { - ix2coef[orbiVec[i]][node_ix] = coeffpVec[orbiVec[i]].size(); - coeffpVec[orbiVec[i]].push_back(&(rotatedCoeff(0, i))); // list of coefficient pointers - } - // check norms for split - double wnorm = 0.0; // rotatedCoeff(k, i) is already in cache here - int kstart = 0; - if (parindexVec_ref[n] < 0) kstart = sizecoeff - sizecoeffW; // do not include scaling, even for roots - for (int k = kstart; k < csize; k++) wnorm += rotatedCoeff(k, i) * rotatedCoeff(k, i); - if (thres < wnorm or prec < 0) - split_serial(orbiVec[i], n) = 1; - else - split_serial(orbiVec[i], n) = 0; - } else { - ix2coef[orbiVec[i]][node_ix] = max_n + 1; // should not be used - split_serial(orbiVec[i], n) = 0; // do not split if parent does not need to be split - } - } - nodeReady[n] = 1; -#pragma omp critical - { - // this ensures that rotatedCoeff is not deleted, when getting out of scope - rotatedCoeffVec.push_back(std::move(rotatedCoeff)); - } - } - } else { // MPI case - - // TODO? rotate in bank, so that we do not get and put. Requires clever handling of splits. - std::vector split(Meff, -1.0); // which orbitals need splitting (at a given node). For now double for compatibilty with bank - std::vector needsplit(Meff, 1.0); // which orbitals need splitting - BankAccount nodeSplits; - mpi::barrier(mpi::comm_wrk); // required for now, as the blockdata functionality has no queue yet. 
- - DoubleMatrix coeffBlock(sizecoeff, Neff); - max_ix++; // largest node index + 1. to store rotated orbitals with different id - TaskManager tasks(max_n); - for (int nn = 0; nn < max_n; nn++) { - int n = tasks.next_task(); - if (n < 0) break; - double thres = prec * prec * scalefac_ref[n] * scalefac_ref[n]; - // 4a) make list of orbitals that should split the parent node, i.e. include this node - int parentid = parindexVec_ref[n]; - if (parentid == -1) { - // root node, split if output needed - for (int i = 0; i < M; i++) { - if (makeReal) - split[i] = 1.0; - else - split[i] = -1.0; - } - for (int i = N; i < Meff; i++) { - if (makeImag) - split[i] = 1.0; - else - split[i] = -1.0; - } - csize = sizecoeff; - } else { - // note that it will wait until data is available - nodeSplits.get_data(parentid, Meff, split.data()); - csize = sizecoeffW; - } - std::vector orbiVec; - std::vector orbjVec; - for (int i = 0; i < Meff; i++) { // loop over rotated orbitals - if (split[i] < 0.0) continue; // parent node has too small wavelets - orbiVec.push_back(i); - } - - // 4b) rotate this node - DoubleMatrix coeffBlock(csize, Neff); // largest possible used size - nodesPhi.get_nodeblock(indexVec_ref[n], coeffBlock.data(), orbjVec); - coeffBlock.conservativeResize(Eigen::NoChange, orbjVec.size()); // keep only used part - - // chunk of U, with reorganized indices and separate blocks for real and imag: - DoubleMatrix Un(orbjVec.size(), orbiVec.size()); - DoubleMatrix rotatedCoeff(csize, orbiVec.size()); - - for (int i = 0; i < orbiVec.size(); i++) { // loop over included rotated real and imag part of orbitals - for (int j = 0; j < orbjVec.size(); j++) { // loop over input orbital, possibly imaginary parts - Un(j, i) = Ureal(orbjVec[j], orbiVec[i]); - } - } - - // HERE IT HAPPENS - rotatedCoeff.noalias() = coeffBlock * Un; // Matrix mutiplication - - // 3c) find which orbitals need to further refine this node, and store rotated node (after each other while - // in cache). 
- for (int i = 0; i < orbiVec.size(); i++) { // loop over rotated orbitals - needsplit[orbiVec[i]] = -1.0; // default, do not split - // check if this node/orbital needs further refinement - double wnorm = 0.0; - int kwstart = csize - sizecoeffW; // do not include scaling - for (int k = kwstart; k < csize; k++) wnorm += rotatedCoeff.col(i)[k] * rotatedCoeff.col(i)[k]; - if (thres < wnorm or prec < 0) needsplit[orbiVec[i]] = 1.0; - nodesRotated.put_nodedata(orbiVec[i], indexVec_ref[n] + max_ix, csize, rotatedCoeff.col(i).data()); - } - nodeSplits.put_data(indexVec_ref[n], Meff, needsplit.data()); - } - mpi::barrier(mpi::comm_wrk); // wait until all rotated nodes are ready - } - - // 5) reconstruct trees using rotated nodes. - - // only serial case can use OMP, because MPI cannot be used by threads - if (serial) { - // OMP parallelized, but does not scale well, because the total memory bandwidth is a bottleneck. (the main - // operation is writing the coefficient into the tree) - -#pragma omp parallel for schedule(static) - for (int j = 0; j < Meff; j++) { - if (coeffpVec[j].size()==0) continue; - if (j < M) { - if (!Psi[j].hasReal()) Psi[j].alloc(NUMBER::Real); - Psi[j].real().clear(); - Psi[j].real().makeTreefromCoeff(refTree, coeffpVec[j], ix2coef[j], prec); - } else { - if (!Psi[j % M].hasImag()) Psi[j % M].alloc(NUMBER::Imag); - Psi[j % M].imag().clear(); - Psi[j % M].imag().makeTreefromCoeff(refTree, coeffpVec[j], ix2coef[j], prec); - } - } - - } else { // MPI case - - for (int j = 0; j < Meff; j++) { - if (not mpi::my_orb(j % M)) continue; - // traverse possible nodes, and stop descending when norm is zero (leaf in out[j]) - std::vector coeffpVec; // - std::map ix2coef; // to find the index in coeffVec[] corresponding to a serialIx - int ix = 0; - std::vector pointerstodelete; // list of temporary arrays to clean up - for (int ibank = 0; ibank < mpi::bank_size; ibank++) { - std::vector nodeidVec; - double *dataVec; // will be allocated by bank - 
nodesRotated.get_orbblock(j, dataVec, nodeidVec, ibank); - if (nodeidVec.size() > 0) pointerstodelete.push_back(dataVec); - int shift = 0; - for (int n = 0; n < nodeidVec.size(); n++) { - assert(nodeidVec[n] - max_ix >= 0); // unrotated nodes have been deleted - assert(ix2coef.count(nodeidVec[n] - max_ix) == 0); // each nodeid treated once - ix2coef[nodeidVec[n] - max_ix] = ix++; - csize = sizecoeffW; - if (parindexVec_ref[nodeidVec[n] - max_ix] < 0) csize = sizecoeff; - coeffpVec.push_back(&dataVec[shift]); // list of coeff pointers - shift += csize; - } - } - if (j < M) { - // Real part - if (!Psi[j].hasReal()) Psi[j].alloc(NUMBER::Real); - Psi[j].real().clear(); - Psi[j].real().makeTreefromCoeff(refTree, coeffpVec, ix2coef, prec); - } else { - // Imag part - if (!Psi[j % M].hasImag()) Psi[j % M].alloc(NUMBER::Imag); - Psi[j % M].imag().clear(); - Psi[j % M].imag().makeTreefromCoeff(refTree, coeffpVec, ix2coef, prec); - } - for (double *p : pointerstodelete) delete[] p; - pointerstodelete.clear(); - } - } -} - - -void rotate(MPI_FuncVector &Phi, const ComplexMatrix &U, double prec) { - rotate(Phi, U, Phi, prec); - return; -} - -/** @brief Save all nodes in bank; identify them using serialIx from refTree - * shift is a shift applied in the id - */ -void save_nodes(MPI_FuncVector &Phi, FunctionTree<3> &refTree, BankAccount &account, int sizes) { - int sizecoeff = (1 << refTree.getDim()) * refTree.getKp1_d(); - int sizecoeffW = ((1 << refTree.getDim()) - 1) * refTree.getKp1_d(); - int max_nNodes = refTree.getNNodes(); - std::vector coeffVec; - std::vector scalefac; - std::vector indexVec; // SerialIx of the node in refOrb - std::vector parindexVec; // SerialIx of the parent node - int N = Phi.size(); - int max_ix; - for (int j = 0; j < N; j++) { - if (not mpi::my_orb(j)) continue; - // make vector with all coef address and their index in the union grid - if (Phi[j].hasReal()) { - Phi[j].real().makeCoeffVector(coeffVec, indexVec, parindexVec, scalefac, max_ix, 
refTree); - int max_n = indexVec.size(); - // send node coefs from Phi[j] to bank - // except for the root nodes, only wavelets are sent - for (int i = 0; i < max_n; i++) { - if (indexVec[i] < 0) continue; // nodes that are not in refOrb - int csize = sizecoeffW; - if (parindexVec[i] < 0) csize = sizecoeff; - if (sizes > 0) { // fixed size - account.put_nodedata(j, indexVec[i], sizes, coeffVec[i]); - } else { - account.put_nodedata(j, indexVec[i], csize, &(coeffVec[i][sizecoeff - csize])); - } - } - } - // Imaginary parts are considered as orbitals with an orbid shifted by N - if (Phi[j].hasImag()) { - Phi[j].imag().makeCoeffVector(coeffVec, indexVec, parindexVec, scalefac, max_ix, refTree); - int max_n = indexVec.size(); - // send node coefs from Phi[j] to bank - for (int i = 0; i < max_n; i++) { - if (indexVec[i] < 0) continue; // nodes that are not in refOrb - // NB: the identifier (indexVec[i]) must be shifted for not colliding with the nodes from the real part - int csize = sizecoeffW; - if (parindexVec[i] < 0) csize = sizecoeff; - if (sizes > 0) { // fixed size - account.put_nodedata(j + N, indexVec[i], sizes, coeffVec[i]); - } else { - account.put_nodedata(j + N, indexVec[i], csize, &(coeffVec[i][sizecoeff - csize])); - } - } - } - } -} - -/** @brief Multiply all orbitals with a function - * - * @param Phi: orbitals to multiply - * @param f : function to multiply - * - * Computes the product of each orbital with a function - * in parallel using a local representation. - * Input trees are extended by one scale at most. - */ -MPI_FuncVector multiply(MPI_FuncVector &Phi, RepresentableFunction<3> &f, double prec, ComplexFunction *Func, int nrefine, bool all) { - - int N = Phi.size(); - const int D = 3; - bool serial = mpi::wrk_size == 1; // flag for serial/MPI switch - - // 1a) extend grid where f is large (around nuclei) - // TODO: do it in save_nodes + refTree, only saving the extra nodes, without keeping them permanently. Or refine refTree? 
- - for (int i = 0; i < N; i++) { - if (!mpi::my_orb(i)) continue; - int irefine = 0; - while (Phi[i].hasReal() and irefine < nrefine and refine_grid(Phi[i].real(), f) > 0) irefine++; - irefine = 0; - while (Phi[i].hasImag() and irefine < nrefine and refine_grid(Phi[i].imag(), f) > 0) irefine++; - } - - // 1b) make union tree without coefficients - FunctionTree refTree(*Phi.vecMRA); - // refine_grid(refTree, f); //to test - mpi::allreduce_Tree_noCoeff(refTree, Phi, mpi::comm_wrk); - - int kp1 = refTree.getKp1(); - int kp1_d = refTree.getKp1_d(); - int nCoefs = refTree.getTDim() * kp1_d; - - IntVector PsihasReIm = IntVector::Zero(2); - for (int i = 0; i < N; i++) { - if (!mpi::my_orb(i)) continue; - PsihasReIm[0] = (Phi[i].hasReal()) ? 1 : 0; - PsihasReIm[1] = (Phi[i].hasImag()) ? 1 : 0; - } - mpi::allreduce_vector(PsihasReIm, mpi::comm_wrk); - MPI_FuncVector out(N); - MPI_FuncVector outtest(N); - if (not PsihasReIm[0] and not PsihasReIm[1]) { - return out; // do nothing - } - - int Neff = N; - if (PsihasReIm[1]) Neff = 2 * N; // Imag and Real treated independently. We always treat real part of Psi - - std::vector scalefac_ref; - std::vector coeffVec_ref; // not used! 
- std::vector indexVec_ref; // serialIx of the nodes - std::vector parindexVec_ref; // serialIx of the parent nodes - std::vector *> refNodes; // pointers to nodes - int max_ix; - // get a list of all nodes in union tree, identified by their serialIx indices - refTree.makeCoeffVector(coeffVec_ref, indexVec_ref, parindexVec_ref, scalefac_ref, max_ix, refTree, &refNodes); - int max_n = indexVec_ref.size(); - std::map ix2n; // for a given serialIx, give index in vectors - for (int nn = 0; nn < max_n; nn++) ix2n[indexVec_ref[nn]] = nn; - - // 2a) send own nodes to bank, identifying them through the serialIx of refTree - BankAccount nodesPhi; // to put the original nodes - BankAccount nodesMultiplied; // to put the multiplied nodes - - // used for serial only: - std::vector> coeffVec(Neff); - std::vector> indexVec(Neff); // serialIx of the nodes - std::map> node2orbVec; // for each node index, gives a vector with the indices of the orbitals using this node - std::vector> orb2node(Neff); // for a given orbital and a given node, gives the node index in the - // orbital given the node index in the reference tree - if (serial) { - // make list of all coefficients (coeffVec), and their reference indices (indexVec) - std::vector parindexVec; // serialIx of the parent nodes - std::vector scalefac; - for (int j = 0; j < N; j++) { - // make vector with all coef pointers and their indices in the union grid - if (Phi[j].hasReal()) { - Phi[j].real().makeCoeffVector(coeffVec[j], indexVec[j], parindexVec, scalefac, max_ix, refTree); - // make a map that gives j from indexVec - int orb_node_ix = 0; - for (int ix : indexVec[j]) { - orb2node[j][ix] = orb_node_ix++; - if (ix < 0) continue; - node2orbVec[ix].push_back(j); - } - } - if (Phi[j].hasImag()) { - Phi[j].imag().makeCoeffVector(coeffVec[j + N], indexVec[j + N], parindexVec, scalefac, max_ix, refTree); - // make a map that gives j from indexVec - int orb_node_ix = 0; - for (int ix : indexVec[j + N]) { - orb2node[j + N][ix] = 
orb_node_ix++; - if (ix < 0) continue; - node2orbVec[ix].push_back(j + N); - } - } - } - } else { - mpifuncvec::save_nodes(Phi, refTree, nodesPhi, nCoefs); - mpi::barrier(mpi::comm_wrk); // required for now, as the blockdata functionality has no queue yet. - } - - // 2b) save Func in bank and remove its coefficients - if (Func != nullptr and !serial) { - // put Func in local representation if not already done - if (!Func->real().isLocal) { Func->real().saveNodesAndRmCoeff(); } - } - - // 3) mutiply for each node - std::vector> coeffpVec(Neff); // to put pointers to the multiplied coefficient for each orbital in serial case - std::vector multipliedCoeffVec; // just to ensure that the data from multipliedCoeff is not deleted, since we point to it. - std::vector> ix2coef(Neff); // to find the index in for example rotCoeffVec[] corresponding to a serialIx - DoubleVector NODEP = DoubleVector::Zero(nCoefs); - DoubleVector NODEF = DoubleVector::Zero(nCoefs); - - if (serial) { -#pragma omp parallel for schedule(dynamic) - for (int n = 0; n < max_n; n++) { - MWNode node(*(refNodes[n]), false); - int node_ix = indexVec_ref[n]; // SerialIx for this node in the reference tree - - // 3a) make values for f at this node - // 3a1) get coordinates of quadrature points for this node - Eigen::MatrixXd pts; // Eigen::Zero(D, nCoefs); - double fval[nCoefs]; - Coord r; - double *originalCoef = nullptr; - MWNode<3> *Fnode = nullptr; - if (Func == nullptr) { - node.getExpandedChildPts(pts); // TODO: use getPrimitiveChildPts (less cache). - for (int j = 0; j < nCoefs; j++) { - for (int d = 0; d < D; d++) r[d] = pts(d, j); //*scaling_factor[d]? - fval[j] = f.evalf(r); - } - } else { - Fnode = Func->real().findNode(node.getNodeIndex()); - if (Fnode == nullptr) { - node.getExpandedChildPts(pts); // TODO: use getPrimitiveChildPts (less cache). - for (int j = 0; j < nCoefs; j++) { - for (int d = 0; d < D; d++) r[d] = pts(d, j); //*scaling_factor[d]? 
- fval[j] = f.evalf(r); - } - } else { - originalCoef = Fnode->getCoefs(); - for (int j = 0; j < nCoefs; j++) fval[j] = originalCoef[j]; - Fnode->attachCoefs(fval); // note that each thread has its own copy - Fnode->mwTransform(Reconstruction); - Fnode->cvTransform(Forward); - } - } - DoubleMatrix multipliedCoeff(nCoefs, node2orbVec[node_ix].size()); - int i = 0; - // 3b) fetch all orbitals at this node - std::vector orbjVec; // to remember which orbital correspond to each orbVec.size(); - for (int j : node2orbVec[node_ix]) { // loop over indices of the orbitals using this node - int orb_node_ix = orb2node[j][node_ix]; - orbjVec.push_back(j); - for (int k = 0; k < nCoefs; k++) multipliedCoeff(k, i) = coeffVec[j][orb_node_ix][k]; - // 3c) transform to grid - node.attachCoefs(&(multipliedCoeff(0, i))); - node.mwTransform(Reconstruction); - node.cvTransform(Forward); - // 3d) multiply - for (int k = 0; k < nCoefs; k++) multipliedCoeff(k, i) *= fval[k]; // replace by Matrix vector multiplication? 
- // 3e) transform back to mw - node.cvTransform(Backward); - node.mwTransform(Compression); - i++; - } - if (Func != nullptr and originalCoef != nullptr) { - // restablish original values - Fnode->attachCoefs(originalCoef); - } - - // 3f) save multiplied nodes - for (int i = 0; i < orbjVec.size(); i++) { -#pragma omp critical - { - ix2coef[orbjVec[i]][node_ix] = coeffpVec[orbjVec[i]].size(); - coeffpVec[orbjVec[i]].push_back(&(multipliedCoeff(0, i))); // list of coefficient pointers - } - } -#pragma omp critical - { - // this ensures that multipliedCoeff is not deleted, when getting out of scope - multipliedCoeffVec.push_back(std::move(multipliedCoeff)); - } - node.attachCoefs(nullptr); // to avoid deletion of valid multipliedCoeff by destructor - } - } else { - // MPI - int count1 = 0; - int count2 = 0; - TaskManager tasks(max_n); - for (int nn = 0; nn < max_n; nn++) { - int n = tasks.next_task(); - if (n < 0) break; - MWNode node(*(refNodes[n]), false); - // 3a) make values for f - // 3a1) get coordinates of quadrature points for this node - Eigen::MatrixXd pts; // Eigen::Zero(D, nCoefs); - node.getExpandedChildPts(pts); // TODO: use getPrimitiveChildPts (less cache). - double fval[nCoefs]; - Coord r; - MWNode Fnode(*(refNodes[n]), false); - if (Func == nullptr) { - for (int j = 0; j < nCoefs; j++) { - for (int d = 0; d < D; d++) r[d] = pts(d, j); //*scaling_factor[d]? 
- fval[j] = f.evalf(r); - } - } else { - int nIdx = Func->real().getIx(node.getNodeIndex()); - count1++; - if (nIdx < 0) { - // use the function f instead of Func - count2++; - for (int j = 0; j < nCoefs; j++) { - for (int d = 0; d < D; d++) r[d] = pts(d, j); - fval[j] = f.evalf(r); - } - } else { - Func->real().getNodeCoeff(nIdx, fval); // fetch coef from Bank - Fnode.attachCoefs(fval); - Fnode.mwTransform(Reconstruction); - Fnode.cvTransform(Forward); - } - } - - // 3b) fetch all orbitals at this node - DoubleMatrix coeffBlock(nCoefs, Neff); // largest possible used size - std::vector orbjVec; - nodesPhi.get_nodeblock(indexVec_ref[n], coeffBlock.data(), orbjVec); - coeffBlock.conservativeResize(Eigen::NoChange, orbjVec.size()); // keep only used part - DoubleMatrix MultipliedCoeff(nCoefs, orbjVec.size()); - // 3c) transform to grid - for (int j = 0; j < orbjVec.size(); j++) { // TODO: transform all j at once ? - // TODO: select only nodes that are end nodes? - node.attachCoefs(coeffBlock.col(j).data()); - node.mwTransform(Reconstruction); - node.cvTransform(Forward); - // 3d) multiply - double *coefs = node.getCoefs(); - for (int i = 0; i < nCoefs; i++) coefs[i] *= fval[i]; - // 3e) transform back to mw - node.cvTransform(Backward); - node.mwTransform(Compression); - // 3f) save multiplied nodes - nodesMultiplied.put_nodedata(orbjVec[j], indexVec_ref[n] + max_ix, nCoefs, coefs); - } - node.attachCoefs(nullptr); // to avoid deletion of valid multipliedCoeff by destructor - Fnode.attachCoefs(nullptr); // to avoid deletion of valid multipliedCoeff by destructor - } - mrcpp::mpi::barrier(mrcpp::mpi::comm_wrk); // wait until everything is stored before fetching! - } - - // 5) reconstruct trees using multiplied nodes. - - // only serial case can use OMP, because MPI cannot be used by threads - if (serial) { - // OMP parallelized, but does not scale well, because the total memory bandwidth is a bottleneck. 
(the main - // operation is writing the coefficient into the tree) - -#pragma omp parallel for schedule(static) - for (int j = 0; j < Neff; j++) { - if (j < N) { - if (Phi[j].hasReal()) { - out[j].alloc(NUMBER::Real); - out[j].real().clear(); - out[j].real().makeTreefromCoeff(refTree, coeffpVec[j], ix2coef[j], -1.0, "copy"); - // 6) reconstruct trees from end nodes - out[j].real().mwTransform(BottomUp); - out[j].real().calcSquareNorm(); - } - } else { - if (Phi[j % N].hasImag()) { - out[j % N].alloc(NUMBER::Imag); - out[j % N].imag().clear(); - out[j % N].imag().makeTreefromCoeff(refTree, coeffpVec[j], ix2coef[j], -1.0, "copy"); - out[j].imag().mwTransform(BottomUp); - out[j].imag().calcSquareNorm(); - } - } - } - } else { - for (int j = 0; j < Neff; j++) { - if (not mpi::my_orb(j % N) and not all) continue; - // traverse possible nodes, and stop descending when norm is zero (leaf in out[j]) - std::vector coeffpVec; // - std::map ix2coef; // to find the index in coeffVec[] corresponding to a serialIx in refTree - int ix = 0; - std::vector pointerstodelete; // list of temporary arrays to clean up - - for (int ibank = 0; ibank < mpi::bank_size; ibank++) { - std::vector nodeidVec; - double *dataVec; // will be allocated by bank - nodesMultiplied.get_orbblock(j, dataVec, nodeidVec, ibank); - if (nodeidVec.size() > 0) pointerstodelete.push_back(dataVec); - int shift = 0; - for (int n = 0; n < nodeidVec.size(); n++) { - assert(nodeidVec[n] - max_ix >= 0); // unmultiplied nodes have been deleted - assert(ix2coef.count(nodeidVec[n] - max_ix) == 0); // each nodeid treated once - ix2coef[nodeidVec[n] - max_ix] = ix++; - coeffpVec.push_back(&dataVec[shift]); // list of coeff pointers - shift += nCoefs; - } - } - if (j < N) { - if (Phi[j].hasReal()) { - out[j].alloc(NUMBER::Real); - out[j].real().clear(); - out[j].real().makeTreefromCoeff(refTree, coeffpVec, ix2coef, -1.0, "copy"); - // 6) reconstruct trees from end nodes - out[j].real().mwTransform(BottomUp); - 
out[j].real().calcSquareNorm(); - out[j].real().resetEndNodeTable(); - // out[j].real().crop(prec, 1.0, false); //bad convergence if out is cropped - if (nrefine > 0) Phi[j].real().crop(prec, 1.0, false); // restablishes original Phi - } - } else { - if (Phi[j % N].hasImag()) { - out[j % N].alloc(NUMBER::Imag); - out[j % N].imag().clear(); - out[j % N].imag().makeTreefromCoeff(refTree, coeffpVec, ix2coef, -1.0, "copy"); - out[j % N].imag().mwTransform(BottomUp); - out[j % N].imag().calcSquareNorm(); - // out[j % N].imag().crop(prec, 1.0, false); - if (nrefine > 0) Phi[j % N].imag().crop(prec, 1.0, false); - } - } - - for (double *p : pointerstodelete) delete[] p; - pointerstodelete.clear(); - } - } - return out; -} - -ComplexVector dot(MPI_FuncVector &Bra, MPI_FuncVector &Ket) { - int N = Bra.size(); - ComplexVector result = ComplexVector::Zero(N); - for (int i = 0; i < N; i++) { - // The bra is sent to the owner of the ket - if (my_orb(Bra[i]) != my_orb(Ket[i])) { MSG_ABORT("same indices should have same ownership"); } - result[i] = cplxfunc::dot(Bra[i], Ket[i]); - if (not mrcpp::mpi::my_orb(i)) Bra[i].free(NUMBER::Total); - } - mrcpp::mpi::allreduce_vector(result, mrcpp::mpi::comm_wrk); - return result; -} - -/** @brief Compute Löwdin orthonormalization matrix - * - * @param Phi: orbitals to orthonomalize - * - * Computes the inverse square root of the orbital overlap matrix S^(-1/2) - */ -ComplexMatrix calc_lowdin_matrix(MPI_FuncVector &Phi) { - ComplexMatrix S_tilde = mpifuncvec::calc_overlap_matrix(Phi); - ComplexMatrix S_m12 = math_utils::hermitian_matrix_pow(S_tilde, -1.0 / 2.0); - return S_m12; -} - -/** @brief Orbital transformation out_j = sum_i inp_i*U_ij - * - * NOTE: OrbitalVector is considered a ROW vector, so rotation - * means matrix multiplication from the right - * - * MPI: Rank distribution of output vector is the same as input vector - * - */ -ComplexMatrix calc_overlap_matrix(MPI_FuncVector &BraKet) { - // NB: must be spinseparated at this 
point! - - int N = BraKet.size(); - ComplexMatrix S = ComplexMatrix::Zero(N, N); - DoubleMatrix Sreal = DoubleMatrix::Zero(2 * N, 2 * N); // same as S, but stored as 4 blocks, rr,ri,ir,ii - MultiResolutionAnalysis<3> *mra = BraKet.vecMRA; - - // 1) make union tree without coefficients - mrcpp::FunctionTree<3> refTree(*mra); - mpi::allreduce_Tree_noCoeff(refTree, BraKet, mpi::comm_wrk); - - int sizecoeff = (1 << refTree.getDim()) * refTree.getKp1_d(); - int sizecoeffW = ((1 << refTree.getDim()) - 1) * refTree.getKp1_d(); - - // get a list of all nodes in union grid, as defined by their indices - std::vector scalefac; - std::vector coeffVec_ref; - std::vector indexVec_ref; // serialIx of the nodes - std::vector parindexVec_ref; // serialIx of the parent nodes - int max_ix; // largest index value (not used here) - - refTree.makeCoeffVector(coeffVec_ref, indexVec_ref, parindexVec_ref, scalefac, max_ix, refTree); - int max_n = indexVec_ref.size(); - - // only used for serial case: - std::vector> coeffVec(2 * N); - std::map> node2orbVec; // for each node index, gives a vector with the indices of the orbitals using this node - std::vector> orb2node(2 * N); // for a given orbital and a given node, gives the node index in - // the orbital given the node index in the reference tree - - bool serial = mrcpp::mpi::wrk_size == 1; // flag for serial/MPI switch - mrcpp::BankAccount nodesBraKet; - - // In the serial case we store the coeff pointers in coeffVec. 
In the mpi case the coeff are stored in the bank - if (serial) { - // 2) make list of all coefficients, and their reference indices - // for different orbitals, indexVec will give the same index for the same node in space - std::vector parindexVec; // serialIx of the parent nodes - std::vector indexVec; // serialIx of the nodes - for (int j = 0; j < N; j++) { - // make vector with all coef pointers and their indices in the union grid - if (BraKet[j].hasReal()) { - BraKet[j].real().makeCoeffVector(coeffVec[j], indexVec, parindexVec, scalefac, max_ix, refTree); - // make a map that gives j from indexVec - int orb_node_ix = 0; - for (int ix : indexVec) { - orb2node[j][ix] = orb_node_ix++; - if (ix < 0) continue; - node2orbVec[ix].push_back(j); - } - } - if (BraKet[j].hasImag()) { - BraKet[j].imag().makeCoeffVector(coeffVec[j + N], indexVec, parindexVec, scalefac, max_ix, refTree); - // make a map that gives j from indexVec - int orb_node_ix = 0; - for (int ix : indexVec) { - orb2node[j + N][ix] = orb_node_ix++; - if (ix < 0) continue; - node2orbVec[ix].push_back(j + N); - } - } - } - } else { // MPI case - // 2) send own nodes to bank, identifying them through the serialIx of refTree - save_nodes(BraKet, refTree, nodesBraKet); - mrcpp::mpi::barrier(mrcpp::mpi::comm_wrk); // wait until everything is stored before fetching! - } - - // 3) make dot product for all the nodes and accumulate into S - - int ibank = 0; -#pragma omp parallel for schedule(dynamic) if (serial) - for (int n = 0; n < max_n; n++) { - if (n % mrcpp::mpi::wrk_size != mrcpp::mpi::wrk_rank) continue; - int csize; - int node_ix = indexVec_ref[n]; // SerialIx for this node in the reference tree - std::vector orbVec; // identifies which orbitals use this node - if (serial and node2orbVec[node_ix].size() <= 0) continue; - if (parindexVec_ref[n] < 0) - csize = sizecoeff; - else - csize = sizecoeffW; - - // In the serial case we copy the coeff coeffBlock. 
In the mpi case coeffBlock is provided by the bank - if (serial) { - int shift = sizecoeff - sizecoeffW; // to copy only wavelet part - if (parindexVec_ref[n] < 0) shift = 0; - DoubleMatrix coeffBlock(csize, node2orbVec[node_ix].size()); - for (int j : node2orbVec[node_ix]) { // loop over indices of the orbitals using this node - int orb_node_ix = orb2node[j][node_ix]; - for (int k = 0; k < csize; k++) coeffBlock(k, orbVec.size()) = coeffVec[j][orb_node_ix][k + shift]; - orbVec.push_back(j); - } - if (orbVec.size() > 0) { - DoubleMatrix S_temp(orbVec.size(), orbVec.size()); - S_temp.noalias() = coeffBlock.transpose() * coeffBlock; - for (int i = 0; i < orbVec.size(); i++) { - for (int j = 0; j < orbVec.size(); j++) { - if (BraKet[orbVec[i] % N].spin() == SPIN::Alpha and BraKet[orbVec[j] % N].spin() == SPIN::Beta) - continue; - if (BraKet[orbVec[i] % N].spin() == SPIN::Beta and BraKet[orbVec[j] % N].spin() == SPIN::Alpha) - continue; - double &Srealij = Sreal(orbVec[i], orbVec[j]); - double &Stempij = S_temp(i, j); -#pragma omp atomic - Srealij += Stempij; - } - } - } - } else { // MPI case - DoubleMatrix coeffBlock(csize, 2 * N); - nodesBraKet.get_nodeblock(indexVec_ref[n], coeffBlock.data(), orbVec); - - if (orbVec.size() > 0) { - DoubleMatrix S_temp(orbVec.size(), orbVec.size()); - coeffBlock.conservativeResize(Eigen::NoChange, orbVec.size()); - S_temp.noalias() = coeffBlock.transpose() * coeffBlock; - for (int i = 0; i < orbVec.size(); i++) { - for (int j = 0; j < orbVec.size(); j++) { - if (BraKet[orbVec[i] % N].spin() == SPIN::Alpha and BraKet[orbVec[j] % N].spin() == SPIN::Beta) - continue; - if (BraKet[orbVec[i] % N].spin() == SPIN::Beta and BraKet[orbVec[j] % N].spin() == SPIN::Alpha) - continue; - Sreal(orbVec[i], orbVec[j]) += S_temp(i, j); - } - } - } - } - } - IntVector conjMat = IntVector::Zero(N); - for (int i = 0; i < N; i++) { - if (!mrcpp::mpi::my_orb(BraKet[i])) continue; - conjMat[i] = (BraKet[i].conjugate()) ? 
-1 : 1; - } - mrcpp::mpi::allreduce_vector(conjMat, mrcpp::mpi::comm_wrk); - - for (int i = 0; i < N; i++) { - for (int j = 0; j <= i; j++) { - S.real()(i, j) = Sreal(i, j) + conjMat[i] * conjMat[j] * Sreal(i + N, j + N); - S.imag()(i, j) = conjMat[j] * Sreal(i, j + N) - conjMat[i] * Sreal(i + N, j); - if (i != j) S(j, i) = std::conj(S(i, j)); // ensure exact symmetri - } - } - - // Assumes linearity: result is sum of all nodes contributions - mrcpp::mpi::allreduce_matrix(S, mrcpp::mpi::comm_wrk); - - return S; -} - -/** @brief Compute the overlap matrix S_ij = - * - */ -ComplexMatrix calc_overlap_matrix(MPI_FuncVector &Bra, MPI_FuncVector &Ket) { - mrcpp::mpi::barrier(mrcpp::mpi::comm_wrk); // for consistent timings - - MultiResolutionAnalysis<3> *mra = Bra.vecMRA; - - int N = Bra.size(); - int M = Ket.size(); - ComplexMatrix S = ComplexMatrix::Zero(N, M); - DoubleMatrix Sreal = DoubleMatrix::Zero(2 * N, 2 * M); // same as S, but stored as 4 blocks, rr,ri,ir,ii - - // 1) make union tree without coefficients for Bra (supposed smallest) - mrcpp::FunctionTree<3> refTree(*mra); - mrcpp::mpi::allreduce_Tree_noCoeff(refTree, Bra, mpi::comm_wrk); - // note that Ket is not part of union grid: if a node is in ket but not in Bra, the dot product is zero. 
- - int sizecoeff = (1 << refTree.getDim()) * refTree.getKp1_d(); - int sizecoeffW = ((1 << refTree.getDim()) - 1) * refTree.getKp1_d(); - - // get a list of all nodes in union grid, as defined by their indices - std::vector coeffVec_ref; - std::vector indexVec_ref; // serialIx of the nodes - std::vector parindexVec_ref; // serialIx of the parent nodes - std::vector scalefac; - int max_ix; - - refTree.makeCoeffVector(coeffVec_ref, indexVec_ref, parindexVec_ref, scalefac, max_ix, refTree); - int max_n = indexVec_ref.size(); - max_ix++; - - bool serial = mrcpp::mpi::wrk_size == 1; // flag for serial/MPI switch - - // only used for serial case: - std::vector> coeffVecBra(2 * N); - std::map> node2orbVecBra; // for each node index, gives a vector with the indices of the orbitals using this node - std::vector> orb2nodeBra(2 * N); // for a given orbital and a given node, gives the node index in - // the orbital given the node index in the reference tree - std::vector> coeffVecKet(2 * M); - std::map> node2orbVecKet; // for each node index, gives a vector with the indices of the orbitals using this node - std::vector> orb2nodeKet(2 * M); // for a given orbital and a given node, gives the node index in - // the orbital given the node index in the reference tree - mrcpp::BankAccount nodesBra; - mrcpp::BankAccount nodesKet; - - // In the serial case we store the coeff pointers in coeffVec. In the mpi case the coeff are stored in the bank - if (serial) { - // 2) make list of all coefficients, and their reference indices - // for different orbitals, indexVec will give the same index for the same node in space - // TODO? 
: do not copy coefficients, but use directly the pointers - // could OMP parallelize, but is fast anyway - std::vector parindexVec; // serialIx of the parent nodes - std::vector indexVec; // serialIx of the nodes - for (int j = 0; j < N; j++) { - // make vector with all coef pointers and their indices in the union grid - if (Bra[j].hasReal()) { - Bra[j].real().makeCoeffVector(coeffVecBra[j], indexVec, parindexVec, scalefac, max_ix, refTree); - // make a map that gives j from indexVec - int orb_node_ix = 0; - for (int ix : indexVec) { - orb2nodeBra[j][ix] = orb_node_ix++; - if (ix < 0) continue; - node2orbVecBra[ix].push_back(j); - } - } - if (Bra[j].hasImag()) { - Bra[j].imag().makeCoeffVector(coeffVecBra[j + N], indexVec, parindexVec, scalefac, max_ix, refTree); - // make a map that gives j from indexVec - int orb_node_ix = 0; - for (int ix : indexVec) { - orb2nodeBra[j + N][ix] = orb_node_ix++; - if (ix < 0) continue; - node2orbVecBra[ix].push_back(j + N); - } - } - } - for (int j = 0; j < M; j++) { - if (Ket[j].hasReal()) { - Ket[j].real().makeCoeffVector(coeffVecKet[j], indexVec, parindexVec, scalefac, max_ix, refTree); - // make a map that gives j from indexVec - int orb_node_ix = 0; - for (int ix : indexVec) { - orb2nodeKet[j][ix] = orb_node_ix++; - if (ix < 0) continue; - node2orbVecKet[ix].push_back(j); - } - } - if (Ket[j].hasImag()) { - Ket[j].imag().makeCoeffVector(coeffVecKet[j + M], indexVec, parindexVec, scalefac, max_ix, refTree); - // make a map that gives j from indexVec - int orb_node_ix = 0; - for (int ix : indexVec) { - orb2nodeKet[j + M][ix] = orb_node_ix++; - if (ix < 0) continue; - node2orbVecKet[ix].push_back(j + M); - } - } - } - - } else { // MPI case - // 2) send own nodes to bank, identifying them through the serialIx of refTree - save_nodes(Bra, refTree, nodesBra); - save_nodes(Ket, refTree, nodesKet); - mrcpp::mpi::barrier(mrcpp::mpi::comm_wrk); // wait until everything is stored before fetching! 
- } - - // 3) make dot product for all the nodes and accumulate into S - int totsiz = 0; - int totget = 0; - int mxtotsiz = 0; - int ibank = 0; - //For some unknown reason the h2_mag_lda test sometimes fails when schedule(dynamic) is chosen -#pragma omp parallel for schedule(static) if (serial) - for (int n = 0; n < max_n; n++) { - if (n % mrcpp::mpi::wrk_size != mrcpp::mpi::wrk_rank) continue; - int csize; - std::vector orbVecBra; // identifies which Bra orbitals use this node - std::vector orbVecKet; // identifies which Ket orbitals use this node - if (parindexVec_ref[n] < 0) - csize = sizecoeff; - else - csize = sizecoeffW; - if (serial) { - int node_ix = indexVec_ref[n]; // SerialIx for this node in the reference tree - int shift = sizecoeff - sizecoeffW; // to copy only wavelet part - DoubleMatrix coeffBlockBra(csize, node2orbVecBra[node_ix].size()); - DoubleMatrix coeffBlockKet(csize, node2orbVecKet[node_ix].size()); - if (parindexVec_ref[n] < 0) shift = 0; - - for (int j : node2orbVecBra[node_ix]) { // loop over indices of the orbitals using this node - int orb_node_ix = orb2nodeBra[j][node_ix]; - for (int k = 0; k < csize; k++) coeffBlockBra(k, orbVecBra.size()) = coeffVecBra[j][orb_node_ix][k + shift]; - orbVecBra.push_back(j); - } - for (int j : node2orbVecKet[node_ix]) { // loop over indices of the orbitals using this node - int orb_node_ix = orb2nodeKet[j][node_ix]; - for (int k = 0; k < csize; k++) coeffBlockKet(k, orbVecKet.size()) = coeffVecKet[j][orb_node_ix][k + shift]; - orbVecKet.push_back(j); - } - - if (orbVecBra.size() > 0 and orbVecKet.size() > 0) { - DoubleMatrix S_temp(orbVecBra.size(), orbVecKet.size()); - S_temp.noalias() = coeffBlockBra.transpose() * coeffBlockKet; - for (int i = 0; i < orbVecBra.size(); i++) { - for (int j = 0; j < orbVecKet.size(); j++) { - if (Bra[orbVecBra[i] % N].spin() == SPIN::Alpha and Ket[orbVecKet[j] % M].spin() == SPIN::Beta) - continue; - if (Bra[orbVecBra[i] % N].spin() == SPIN::Beta and Ket[orbVecKet[j] % 
M].spin() == SPIN::Alpha) - continue; - // must ensure that threads are not competing - double &Srealij = Sreal(orbVecBra[i], orbVecKet[j]); - double &Stempij = S_temp(i, j); -#pragma omp atomic - Srealij += Stempij; - } - } - } - } else { - - DoubleMatrix coeffBlockBra(csize, 2 * N); - DoubleMatrix coeffBlockKet(csize, 2 * M); - nodesBra.get_nodeblock(indexVec_ref[n], coeffBlockBra.data(), orbVecBra); // get Bra parts - nodesKet.get_nodeblock(indexVec_ref[n], coeffBlockKet.data(), orbVecKet); // get Ket parts - totsiz += orbVecBra.size() * orbVecKet.size(); - mxtotsiz += N * M; - totget += orbVecBra.size() + orbVecKet.size(); - if (orbVecBra.size() > 0 and orbVecKet.size() > 0) { - DoubleMatrix S_temp(orbVecBra.size(), orbVecKet.size()); - coeffBlockBra.conservativeResize(Eigen::NoChange, orbVecBra.size()); - coeffBlockKet.conservativeResize(Eigen::NoChange, orbVecKet.size()); - S_temp.noalias() = coeffBlockBra.transpose() * coeffBlockKet; - for (int i = 0; i < orbVecBra.size(); i++) { - for (int j = 0; j < orbVecKet.size(); j++) { - if (Bra[orbVecBra[i] % N].spin() == SPIN::Alpha and Ket[orbVecKet[j] % M].spin() == SPIN::Beta) - continue; - if (Bra[orbVecBra[i] % N].spin() == SPIN::Beta and Ket[orbVecKet[j] % M].spin() == SPIN::Alpha) - continue; - Sreal(orbVecBra[i], orbVecKet[j]) += S_temp(i, j); - } - } - } - } - } - - IntVector conjMatBra = IntVector::Zero(N); - for (int i = 0; i < N; i++) { - if (!mrcpp::mpi::my_orb(Bra[i])) continue; - conjMatBra[i] = (Bra[i].conjugate()) ? -1 : 1; - } - mrcpp::mpi::allreduce_vector(conjMatBra, mrcpp::mpi::comm_wrk); - IntVector conjMatKet = IntVector::Zero(M); - for (int i = 0; i < M; i++) { - if (!mrcpp::mpi::my_orb(Ket[i])) continue; - conjMatKet[i] = (Ket[i].conjugate()) ? 
-1 : 1; - } - mrcpp::mpi::allreduce_vector(conjMatKet, mrcpp::mpi::comm_wrk); - - for (int i = 0; i < N; i++) { - for (int j = 0; j < M; j++) { - S.real()(i, j) = Sreal(i, j) + conjMatBra[i] * conjMatKet[j] * Sreal(i + N, j + M); - S.imag()(i, j) = conjMatKet[j] * Sreal(i, j + M) - conjMatBra[i] * Sreal(i + N, j); - } - } - - // 4) collect results from all MPI. Linearity: result is sum of all node contributions - - mrcpp::mpi::allreduce_matrix(S, mrcpp::mpi::comm_wrk); - - return S; -} - -/** @brief Compute the overlap matrix of the absolute value of the functions S_ij = <|bra_i|||ket_j|> - * - */ -DoubleMatrix calc_norm_overlap_matrix(MPI_FuncVector &BraKet) { - int N = BraKet.size(); - DoubleMatrix S = DoubleMatrix::Zero(N, N); - DoubleMatrix Sreal = DoubleMatrix::Zero(2 * N, 2 * N); // same as S, but stored as 4 blocks, rr,ri,ir,ii - MultiResolutionAnalysis<3> *mra = BraKet.vecMRA; - - // 1) make union tree without coefficients - mrcpp::FunctionTree<3> refTree(*mra); - mrcpp::mpi::allreduce_Tree_noCoeff(refTree, BraKet, mpi::comm_wrk); - - int sizecoeff = (1 << refTree.getDim()) * refTree.getKp1_d(); - int sizecoeffW = ((1 << refTree.getDim()) - 1) * refTree.getKp1_d(); - - // get a list of all nodes in union grid, as defined by their indices - std::vector scalefac; - std::vector coeffVec_ref; - std::vector indexVec_ref; // serialIx of the nodes - std::vector parindexVec_ref; // serialIx of the parent nodes - int max_ix; // largest index value (not used here) - - refTree.makeCoeffVector(coeffVec_ref, indexVec_ref, parindexVec_ref, scalefac, max_ix, refTree); - int max_n = indexVec_ref.size(); - - // only used for serial case: - std::vector> coeffVec(2 * N); - std::map> node2orbVec; // for each node index, gives a vector with the indices of the orbitals using this node - std::vector> orb2node(2 * N); // for a given orbital and a given node, gives the node index in - // the orbital given the node index in the reference tree - - bool serial = mrcpp::mpi::wrk_size 
== 1; // flag for serial/MPI switch - mrcpp::BankAccount nodesBraKet; - - // In the serial case we store the coeff pointers in coeffVec. In the mpi case the coeff are stored in the bank - if (serial) { - // 2) make list of all coefficients, and their reference indices - // for different orbitals, indexVec will give the same index for the same node in space - std::vector parindexVec; // serialIx of the parent nodes - std::vector indexVec; // serialIx of the nodes - for (int j = 0; j < N; j++) { - // make vector with all coef pointers and their indices in the union grid - if (BraKet[j].hasReal()) { - BraKet[j].real().makeCoeffVector(coeffVec[j], indexVec, parindexVec, scalefac, max_ix, refTree); - // make a map that gives j from indexVec - int orb_node_ix = 0; - for (int ix : indexVec) { - orb2node[j][ix] = orb_node_ix++; - if (ix < 0) continue; - node2orbVec[ix].push_back(j); - } - } - if (BraKet[j].hasImag()) { - BraKet[j].imag().makeCoeffVector(coeffVec[j + N], indexVec, parindexVec, scalefac, max_ix, refTree); - // make a map that gives j from indexVec - int orb_node_ix = 0; - for (int ix : indexVec) { - orb2node[j + N][ix] = orb_node_ix++; - if (ix < 0) continue; - node2orbVec[ix].push_back(j + N); - } - } - } - } else { // MPI case - // 2) send own nodes to bank, identifying them through the serialIx of refTree - save_nodes(BraKet, refTree, nodesBraKet); - mrcpp::mpi::barrier(mrcpp::mpi::comm_wrk); // wait until everything is stored before fetching! 
- } - - // 3) make dot product for all the nodes and accumulate into S - - int ibank = 0; -#pragma omp parallel for schedule(dynamic) if (serial) - for (int n = 0; n < max_n; n++) { - if (n % mrcpp::mpi::wrk_size != mrcpp::mpi::wrk_rank) continue; - int csize; - int node_ix = indexVec_ref[n]; // SerialIx for this node in the reference tree - std::vector orbVec; // identifies which orbitals use this node - if (serial and node2orbVec[node_ix].size() <= 0) continue; - if (parindexVec_ref[n] < 0) - csize = sizecoeff; - else - csize = sizecoeffW; - // In the serial case we copy the coeff coeffBlock. In the mpi case coeffBlock is provided by the bank - if (serial) { - int shift = sizecoeff - sizecoeffW; // to copy only wavelet part - if (parindexVec_ref[n] < 0) shift = 0; - DoubleMatrix coeffBlock(csize, node2orbVec[node_ix].size()); - for (int j : node2orbVec[node_ix]) { // loop over indices of the orbitals using this node - int orb_node_ix = orb2node[j][node_ix]; - for (int k = 0; k < csize; k++) coeffBlock(k, orbVec.size()) = coeffVec[j][orb_node_ix][k + shift]; - orbVec.push_back(j); - } - if (orbVec.size() > 0) { - DoubleMatrix S_temp(orbVec.size(), orbVec.size()); - coeffBlock = coeffBlock.cwiseAbs(); - S_temp.noalias() = coeffBlock.transpose() * coeffBlock; - for (int i = 0; i < orbVec.size(); i++) { - for (int j = 0; j < orbVec.size(); j++) { - if (BraKet[orbVec[i] % N].spin() == SPIN::Alpha and BraKet[orbVec[j] % N].spin() == SPIN::Beta) - continue; - if (BraKet[orbVec[i] % N].spin() == SPIN::Beta and BraKet[orbVec[j] % N].spin() == SPIN::Alpha) - continue; - double &Srealij = Sreal(orbVec[i], orbVec[j]); - double &Stempij = S_temp(i, j); -#pragma omp atomic - Srealij += Stempij; - } - } - } - } else { // MPI case - DoubleMatrix coeffBlock(csize, 2 * N); - nodesBraKet.get_nodeblock(indexVec_ref[n], coeffBlock.data(), orbVec); - - if (orbVec.size() > 0) { - DoubleMatrix S_temp(orbVec.size(), orbVec.size()); - coeffBlock.conservativeResize(Eigen::NoChange, 
orbVec.size()); - coeffBlock = coeffBlock.cwiseAbs(); - S_temp.noalias() = coeffBlock.transpose() * coeffBlock; - for (int i = 0; i < orbVec.size(); i++) { - for (int j = 0; j < orbVec.size(); j++) { - if (BraKet[orbVec[i] % N].spin() == SPIN::Alpha and BraKet[orbVec[j] % N].spin() == SPIN::Beta) - continue; - if (BraKet[orbVec[i] % N].spin() == SPIN::Beta and BraKet[orbVec[j] % N].spin() == SPIN::Alpha) - continue; - Sreal(orbVec[i], orbVec[j]) += S_temp(i, j); - } - } - } - } - } - - IntVector conjMat = IntVector::Zero(N); - for (int i = 0; i < N; i++) { - if (!mrcpp::mpi::my_orb(i)) continue; - conjMat[i] = (BraKet[i].conjugate()) ? -1 : 1; - } - mrcpp::mpi::allreduce_vector(conjMat, mrcpp::mpi::comm_wrk); - - for (int i = 0; i < N; i++) { - for (int j = 0; j <= i; j++) { - S(i, j) = Sreal(i, j) + conjMat[i] * conjMat[j] * Sreal(i + N, j + N) + conjMat[j] * Sreal(i, j + N) - conjMat[i] * Sreal(i + N, j); - S(j, i) = S(i, j); - } - } - - // Assumes linearity: result is sum of all nodes contributions - mrcpp::mpi::allreduce_matrix(S, mrcpp::mpi::comm_wrk); - return S; -} - -/** @brief Orthogonalize the functions in Bra against all orbitals in Ket - * - */ -void orthogonalize(double prec, MPI_FuncVector &Bra, MPI_FuncVector &Ket) { - // TODO: generalize for cases where Ket functions are not orthogonal to each other? 
- ComplexMatrix S = mpifuncvec::calc_overlap_matrix(Bra, Ket); - int N = Bra.size(); - int M = Ket.size(); - DoubleVector Ketnorms = DoubleVector::Zero(M); - for (int i = 0; i < M; i++) { - if (mpi::my_orb(Ket[i])) Ketnorms(i) = Ket[i].squaredNorm(); - } - mrcpp::mpi::allreduce_vector(Ketnorms, mrcpp::mpi::comm_wrk); - ComplexMatrix rmat = ComplexMatrix::Zero(M, N); - for (int j = 0; j < N; j++) { - for (int i = 0; i < M; i++) { - rmat(i,j) = 0.0 - S.conjugate()(j,i)/Ketnorms(i); - } - } - MPI_FuncVector rotatedKet(N); - mpifuncvec::rotate(Ket, rmat, rotatedKet, prec / M); - for (int j = 0; j < N; j++) { - if(my_orb(Bra[j]))Bra[j].add(1.0,rotatedKet[j]); - } -} -} // namespace mpifuncvec -} // namespace mrcpp diff --git a/src/utils/ComplexFunction.h b/src/utils/ComplexFunction.h deleted file mode 100644 index 699bbfcfb..000000000 --- a/src/utils/ComplexFunction.h +++ /dev/null @@ -1,199 +0,0 @@ -#pragma once - -#include "functions/RepresentableFunction.h" -#include "math_utils.h" -#include "mpi_utils.h" -#include "trees/FunctionTree.h" -#include "trees/MultiResolutionAnalysis.h" -#include - -using namespace Eigen; - -using IntVector = Eigen::VectorXi; -using DoubleVector = Eigen::VectorXd; -using ComplexVector = Eigen::VectorXcd; - -using IntMatrix = Eigen::MatrixXi; -using DoubleMatrix = Eigen::MatrixXd; -using ComplexMatrix = Eigen::MatrixXcd; - -class MPI_FuncVector; - -namespace mrcpp { - -class BankAccount; -template class FunctionTree; -template class MultiResolutionAnalysis; - -using ComplexDouble = std::complex; -namespace NUMBER { -enum type { Total, Real, Imag }; -} -namespace SPIN { -enum type { Paired, Alpha, Beta }; -} - -struct FunctionData { - int type{0}; - int order{1}; - int scale{0}; - int depth{0}; - int boxes[3] = {0, 0, 0}; - int corner[3] = {0, 0, 0}; - int real_size{0}; - int imag_size{0}; - bool is_shared{false}; - int spin{0}; - double occ{0}; -}; - -class TreePtr final { -public: - explicit TreePtr(bool share) - : shared_mem_re(nullptr) - 
, shared_mem_im(nullptr) - , re(nullptr) - , im(nullptr) { - this->func_data.is_shared = share; - if (this->func_data.is_shared and mpi::share_size > 1) { - // Memory size in MB defined in input. Virtual memory, does not cost anything if not used. -#ifdef MRCPP_HAS_MPI - this->shared_mem_re = new mrcpp::SharedMemory(mpi::comm_share, mpi::shared_memory_size); - this->shared_mem_im = new mrcpp::SharedMemory(mpi::comm_share, mpi::shared_memory_size); -#endif - } - } - - ~TreePtr() { - if (this->shared_mem_re != nullptr) delete this->shared_mem_re; - if (this->shared_mem_im != nullptr) delete this->shared_mem_im; - if (this->re != nullptr) delete this->re; - if (this->im != nullptr) delete this->im; - } - - friend class ComplexFunction; - -private: - FunctionData func_data; - mrcpp::SharedMemory *shared_mem_re; - mrcpp::SharedMemory *shared_mem_im; - mrcpp::FunctionTree<3> *re; ///< Real part of function - mrcpp::FunctionTree<3> *im; ///< Imaginary part of function - - void flushFuncData() { - this->func_data.real_size = 0; - this->func_data.imag_size = 0; - if (this->re != nullptr) { - this->func_data.real_size = this->re->getNChunksUsed(); - flushMRAData(this->re->getMRA()); - } - if (this->im != nullptr) { - this->func_data.imag_size = this->im->getNChunksUsed(); - flushMRAData(this->im->getMRA()); - } - } - - void flushMRAData(const mrcpp::MultiResolutionAnalysis<3> &mra) { - const auto &box = mra.getWorldBox(); - this->func_data.type = mra.getScalingBasis().getScalingType(); - this->func_data.order = mra.getOrder(); - this->func_data.depth = mra.getMaxDepth(); - this->func_data.scale = box.getScale(); - this->func_data.boxes[0] = box.size(0); - this->func_data.boxes[1] = box.size(1); - this->func_data.boxes[2] = box.size(2); - this->func_data.corner[0] = box.getCornerIndex().getTranslation(0); - this->func_data.corner[1] = box.getCornerIndex().getTranslation(1); - this->func_data.corner[2] = box.getCornerIndex().getTranslation(2); - } -}; - -class ComplexFunction 
{ -public: - ComplexFunction(std::shared_ptr funcptr); - ComplexFunction(const ComplexFunction &func); - ComplexFunction(int spin = 0, double occ = -1, int rank = -1, bool share = false); - ComplexFunction &operator=(const ComplexFunction &func); - ComplexFunction paramCopy() const; - bool isShared() const { return this->func_ptr->func_data.is_shared; } - bool hasReal() const { return (this->func_ptr->re == nullptr) ? false : true; } - bool hasImag() const { return (this->func_ptr->im == nullptr) ? false : true; } - FunctionData &getFunctionData(); - double occ() const { return this->func_ptr->func_data.occ; } - int spin() const { return this->func_ptr->func_data.spin; } - FunctionTree<3> &real() { return *this->func_ptr->re; } - FunctionTree<3> &imag() { return *this->func_ptr->im; } - const FunctionTree<3> &real() const { return *this->func_ptr->re; } - const FunctionTree<3> &imag() const { return *this->func_ptr->im; } - void release() { this->func_ptr.reset(); } - bool conjugate() const { return this->conj; } - MultiResolutionAnalysis<3> *funcMRA = nullptr; - int getRank() const { return rank; } - void setRank(int rank) { (*this).rank = rank; } - void setOcc(double occ) { this->getFunctionData().occ = occ; } - void setSpin(int spin) { this->getFunctionData().spin = spin; } - ComplexFunction dagger(); - virtual ~ComplexFunction() = default; - - void alloc(int type, mrcpp::MultiResolutionAnalysis<3> *mra = nullptr); - void free(int type); - - int getSizeNodes(int type) const; - int getNNodes(int type) const; - - void setReal(mrcpp::FunctionTree<3> *tree); - void setImag(mrcpp::FunctionTree<3> *tree); - - double norm() const; - double squaredNorm() const; - ComplexDouble integrate() const; - - int crop(double prec); - void rescale(double c); - void rescale(ComplexDouble c); - void add(ComplexDouble c, ComplexFunction inp); - void absadd(ComplexDouble c, ComplexFunction inp); - char printSpin() const; - -protected: - bool conj{false}; - std::shared_ptr func_ptr; - 
int rank = -1; // index in vector -}; - -namespace cplxfunc { -void SetdefaultMRA(MultiResolutionAnalysis<3> *MRA); -ComplexDouble dot(ComplexFunction bra, ComplexFunction ket); -ComplexDouble node_norm_dot(ComplexFunction bra, ComplexFunction ket, bool exact); -void deep_copy(ComplexFunction &out, ComplexFunction &inp); -void add(ComplexFunction &out, ComplexDouble a, ComplexFunction inp_a, ComplexDouble b, ComplexFunction inp_b, double prec); -void project(ComplexFunction &out, std::function &r)> f, int type, double prec); -void project(ComplexFunction &out, RepresentableFunction<3> &f, int type, double prec); -void multiply(ComplexFunction &out, ComplexFunction inp_a, ComplexFunction inp_b, double prec, bool absPrec = false, bool useMaxNorms = false); -void multiply_real(ComplexFunction &out, ComplexFunction inp_a, ComplexFunction inp_b, double prec, bool absPrec = false, bool useMaxNorms = false); -void multiply_imag(ComplexFunction &out, ComplexFunction inp_a, ComplexFunction inp_b, double prec, bool absPrec = false, bool useMaxNorms = false); -void multiply(ComplexFunction &out, ComplexFunction &inp_a, RepresentableFunction<3> &f, double prec, int nrefine = 0); -void multiply(ComplexFunction &out, FunctionTree<3> &inp_a, RepresentableFunction<3> &f, double prec, int nrefine = 0); -void linear_combination(ComplexFunction &out, const ComplexVector &c, std::vector &inp, double prec); -} // namespace cplxfunc - -class MPI_FuncVector : public std::vector { -public: - MPI_FuncVector(int N = 0); - MultiResolutionAnalysis<3> *vecMRA; - void distribute(); -}; - -namespace mpifuncvec { -void rotate(MPI_FuncVector &Phi, const ComplexMatrix &U, double prec = -1.0); -void rotate(MPI_FuncVector &Phi, const ComplexMatrix &U, MPI_FuncVector &Psi, double prec = -1.0); -void save_nodes(MPI_FuncVector &Phi, mrcpp::FunctionTree<3> &refTree, BankAccount &account, int sizes = -1); -MPI_FuncVector multiply(MPI_FuncVector &Phi, RepresentableFunction<3> &f, double prec = -1.0, 
ComplexFunction *Func = nullptr, int nrefine = 1, bool all = false); -ComplexVector dot(MPI_FuncVector &Bra, MPI_FuncVector &Ket); -ComplexMatrix calc_lowdin_matrix(MPI_FuncVector &Phi); -ComplexMatrix calc_overlap_matrix(MPI_FuncVector &BraKet); -ComplexMatrix calc_overlap_matrix(MPI_FuncVector &Bra, MPI_FuncVector &Ket); -DoubleMatrix calc_norm_overlap_matrix(MPI_FuncVector &BraKet); -void orthogonalize(double prec, MPI_FuncVector &Bra, MPI_FuncVector &Ket); -} // namespace mpifuncvec -} // namespace mrcpp diff --git a/src/utils/Plotter.cpp b/src/utils/Plotter.cpp index 455bb57e6..c29b3ee2e 100644 --- a/src/utils/Plotter.cpp +++ b/src/utils/Plotter.cpp @@ -37,24 +37,24 @@ namespace mrcpp { * * @param[in] o: Plot origin, default `(0, 0, ... , 0)` */ -template -Plotter::Plotter(const Coord &o) +template +Plotter::Plotter(const Coord &o) : O(o) { - setSuffix(Plotter::Line, ".line"); - setSuffix(Plotter::Surface, ".surf"); - setSuffix(Plotter::Cube, ".cube"); - setSuffix(Plotter::Grid, ".grid"); + setSuffix(Plotter::Line, ".line"); + setSuffix(Plotter::Surface, ".surf"); + setSuffix(Plotter::Cube, ".cube"); + setSuffix(Plotter::Grid, ".grid"); } /** @brief Set file extension for output file * - * @param[in] t: Plot type (`Plotter::Line`, `::Surface`, `::Cube`, `::Grid`) + * @param[in] t: Plot type (`Plotter::Line`, `::Surface`, `::Cube`, `::Grid`) * @param[in] s: Extension string, default `.line`, `.surf`, `.cube`, `.grid` * * @details The file name you decide for the output will get a predefined * suffix that differentiates between different types of plot. */ -template void Plotter::setSuffix(int t, const std::string &s) { +template void Plotter::setSuffix(int t, const std::string &s) { this->suffix.insert(std::pair(t, s)); } @@ -62,7 +62,7 @@ template void Plotter::setSuffix(int t, const std::string &s) { * * @param[in] o: Plot origin, default `(0, 0, ... 
, 0)` */ -template void Plotter::setOrigin(const Coord &o) { +template void Plotter::setOrigin(const Coord &o) { this->O = o; } @@ -72,7 +72,7 @@ template void Plotter::setOrigin(const Coord &o) { * @param[in] b: B vector * @param[in] c: C vector */ -template void Plotter::setRange(const Coord &a, const Coord &b, const Coord &c) { +template void Plotter::setRange(const Coord &a, const Coord &b, const Coord &c) { this->A = a; this->B = b; this->C = c; @@ -89,10 +89,10 @@ template void Plotter::setRange(const Coord &a, const Coord &b, * separate file, and will print only nodes owned by itself (pluss the * rootNodes). */ -template void Plotter::gridPlot(const MWTree &tree, const std::string &fname) { +template void Plotter::gridPlot(const MWTree &tree, const std::string &fname) { println(20, "----------Grid Plot-----------"); std::stringstream file; - file << fname << this->suffix[Plotter::Grid]; + file << fname << this->suffix[Plotter::Grid]; openPlot(file.str()); writeGrid(tree); closePlot(); @@ -109,16 +109,13 @@ template void Plotter::gridPlot(const MWTree &tree, const std::str * vector A starting from the origin O to a file named fname + file extension * (".line" as default). 
*/ -template -void Plotter::linePlot(const std::array &npts, - const RepresentableFunction &func, - const std::string &fname) { +template void Plotter::linePlot(const std::array &npts, const RepresentableFunction &func, const std::string &fname) { println(20, "----------Line Plot-----------"); std::stringstream file; - file << fname << this->suffix[Plotter::Line]; + file << fname << this->suffix[Plotter::Line]; if (verifyRange(1)) { // Verifies only A vector Eigen::MatrixXd coords = calcLineCoordinates(npts[0]); - Eigen::VectorXd values = evaluateFunction(func, coords); + Eigen::Matrix values = evaluateFunction(func, coords); openPlot(file.str()); writeData(coords, values); closePlot(); @@ -138,16 +135,13 @@ void Plotter::linePlot(const std::array &npts, * vectors A (npts[0] points) and B (npts[1] points), starting from the * origin O, to a file named fname + file extension (".surf" as default). */ -template -void Plotter::surfPlot(const std::array &npts, - const RepresentableFunction &func, - const std::string &fname) { +template void Plotter::surfPlot(const std::array &npts, const RepresentableFunction &func, const std::string &fname) { println(20, "--------Surface Plot----------"); std::stringstream file; - file << fname << this->suffix[Plotter::Surface]; + file << fname << this->suffix[Plotter::Surface]; if (verifyRange(2)) { // Verifies A and B vectors Eigen::MatrixXd coords = calcSurfCoordinates(npts[0], npts[1]); - Eigen::VectorXd values = evaluateFunction(func, coords); + Eigen::Matrix values = evaluateFunction(func, coords); openPlot(file.str()); writeData(coords, values); closePlot(); @@ -168,16 +162,13 @@ void Plotter::surfPlot(const std::array &npts, * starting from the origin O, to a file named fname + file extension * (".cube" as default). 
*/ -template -void Plotter::cubePlot(const std::array &npts, - const RepresentableFunction &func, - const std::string &fname) { +template void Plotter::cubePlot(const std::array &npts, const RepresentableFunction &func, const std::string &fname) { println(20, "----------Cube Plot-----------"); std::stringstream file; - file << fname << this->suffix[Plotter::Cube]; + file << fname << this->suffix[Plotter::Cube]; if (verifyRange(3)) { // Verifies A, B and C vectors Eigen::MatrixXd coords = calcCubeCoordinates(npts[0], npts[1], npts[2]); - Eigen::VectorXd values = evaluateFunction(func, coords); + Eigen::Matrix values = evaluateFunction(func, coords); openPlot(file.str()); writeCube(npts, values); closePlot(); @@ -192,7 +183,7 @@ void Plotter::cubePlot(const std::array &npts, * @details Generating a vector of pts_a equidistant coordinates that makes * up the vector A in D dimensions, starting from the origin O. */ -template Eigen::MatrixXd Plotter::calcLineCoordinates(int pts_a) const { +template Eigen::MatrixXd Plotter::calcLineCoordinates(int pts_a) const { MatrixXd coords; if (pts_a > 0) { Coord a = calcStep(this->A, pts_a); @@ -211,7 +202,7 @@ template Eigen::MatrixXd Plotter::calcLineCoordinates(int pts_a) cons * @details Generating a vector of equidistant coordinates that makes up the * area spanned by vectors A and B in D dimensions, starting from the origin O. */ -template Eigen::MatrixXd Plotter::calcSurfCoordinates(int pts_a, int pts_b) const { +template Eigen::MatrixXd Plotter::calcSurfCoordinates(int pts_a, int pts_b) const { if (D < 2) MSG_ERROR("Cannot surfPlot less than 2D"); MatrixXd coords; @@ -240,7 +231,7 @@ template Eigen::MatrixXd Plotter::calcSurfCoordinates(int pts_a, int * volume spanned by vectors A, B and C in D dimensions, starting from * the origin O. 
*/ -template Eigen::MatrixXd Plotter::calcCubeCoordinates(int pts_a, int pts_b, int pts_c) const { +template Eigen::MatrixXd Plotter::calcCubeCoordinates(int pts_a, int pts_b, int pts_c) const { if (D < 3) MSG_ERROR("Cannot cubePlot less than 3D function"); MatrixXd coords; @@ -272,12 +263,10 @@ template Eigen::MatrixXd Plotter::calcCubeCoordinates(int pts_a, int * this routine evaluates the function in these points and stores the results * in the vector "values". */ -template -Eigen::VectorXd Plotter::evaluateFunction(const RepresentableFunction &func, - const Eigen::MatrixXd &coords) const { +template Eigen::Matrix Plotter::evaluateFunction(const RepresentableFunction &func, const Eigen::MatrixXd &coords) const { auto npts = coords.rows(); if (npts == 0) MSG_ERROR("Empty coordinates"); - Eigen::VectorXd values = VectorXd::Zero(npts); + Eigen::Matrix values = Eigen::Matrix::Zero(npts); #pragma omp parallel for schedule(static) num_threads(mrcpp_get_num_threads()) for (auto i = 0; i < npts; i++) { Coord r{}; @@ -294,7 +283,7 @@ Eigen::VectorXd Plotter::evaluateFunction(const RepresentableFunction &fun * point number (between 0 and nPoints), coordinates 1 through D and the * function value. 
*/ -template void Plotter::writeData(const Eigen::MatrixXd &coords, const Eigen::VectorXd &values) { +template void Plotter::writeData(const Eigen::MatrixXd &coords, const Eigen::Matrix &values) { if (coords.rows() != values.size()) INVALID_ARG_ABORT; std::ofstream &o = *this->fout; for (auto i = 0; i < values.size(); i++) { @@ -308,17 +297,17 @@ template void Plotter::writeData(const Eigen::MatrixXd &coords, const } // Specialized for D=3 below -template void Plotter::writeCube(const std::array &npts, const Eigen::VectorXd &values) { +template void Plotter::writeCube(const std::array &npts, const Eigen::Matrix &values) { NOT_IMPLEMENTED_ABORT } // Specialized for D=3 below -template void Plotter::writeNodeGrid(const MWNode &node, const std::string &color) { +template void Plotter::writeNodeGrid(const MWNode &node, const std::string &color) { NOT_IMPLEMENTED_ABORT } // Specialized for D=3 below -template void Plotter::writeGrid(const MWTree &tree) { +template void Plotter::writeGrid(const MWTree &tree) { NOT_IMPLEMENTED_ABORT } @@ -326,7 +315,7 @@ template void Plotter::writeGrid(const MWTree &tree) { * * @details Opens a file output stream fout for file named fname. */ -template void Plotter::openPlot(const std::string &fname) { +template void Plotter::openPlot(const std::string &fname) { if (fname.empty()) { if (this->fout == nullptr) { MSG_ERROR("Plot file not set!"); @@ -350,7 +339,7 @@ template void Plotter::openPlot(const std::string &fname) { * * @details Closes the file output stream fout. 
*/ -template void Plotter::closePlot() { +template void Plotter::closePlot() { if (this->fout != nullptr) this->fout->close(); this->fout = nullptr; } @@ -412,31 +401,22 @@ template <> void Plotter<3>::writeNodeGrid(const MWNode<3> &node, const std::str for (int d = 0; d < 3; d++) origin[d] = node.getNodeIndex()[d] * length; - o << origin[0] << " " << origin[1] << " " << origin[2] << " " << color << origin[0] << " " << origin[1] << " " - << origin[2] + length << " " << color << origin[0] << " " << origin[1] + length << " " << origin[2] + length - << " " << color << origin[0] << " " << origin[1] + length << " " << origin[2] << color << std::endl; - - o << origin[0] << " " << origin[1] << " " << origin[2] << " " << color << origin[0] << " " << origin[1] << " " - << origin[2] + length << " " << color << origin[0] + length << " " << origin[1] << " " << origin[2] + length - << " " << color << origin[0] + length << " " << origin[1] << " " << origin[2] << color << std::endl; - o << origin[0] << " " << origin[1] << " " << origin[2] << " " << color << origin[0] << " " << origin[1] + length - << " " << origin[2] << " " << color << origin[0] + length << " " << origin[1] + length << " " << origin[2] << " " - << color << origin[0] + length << " " << origin[1] << " " << origin[2] << color << std::endl; - - o << origin[0] + length << " " << origin[1] + length << " " << origin[2] + length << " " << color - << origin[0] + length << " " << origin[1] + length << " " << origin[2] << " " << color << origin[0] + length - << " " << origin[1] << " " << origin[2] << " " << color << origin[0] + length << " " << origin[1] << " " - << origin[2] + length << color << std::endl; - - o << origin[0] + length << " " << origin[1] + length << " " << origin[2] + length << " " << color - << origin[0] + length << " " << origin[1] + length << " " << origin[2] << " " << color << origin[0] << " " - << origin[1] + length << " " << origin[2] << " " << color << origin[0] << " " << origin[1] + length << " " - 
<< origin[2] + length << color << std::endl; - - o << origin[0] + length << " " << origin[1] + length << " " << origin[2] + length << " " << color - << origin[0] + length << " " << origin[1] << " " << origin[2] + length << " " << color << origin[0] << " " - << origin[1] << " " << origin[2] + length << " " << color << origin[0] << " " << origin[1] + length << " " - << origin[2] + length << color << std::endl; + o << origin[0] << " " << origin[1] << " " << origin[2] << " " << color << origin[0] << " " << origin[1] << " " << origin[2] + length << " " << color << origin[0] << " " << origin[1] + length + << " " << origin[2] + length << " " << color << origin[0] << " " << origin[1] + length << " " << origin[2] << color << std::endl; + + o << origin[0] << " " << origin[1] << " " << origin[2] << " " << color << origin[0] << " " << origin[1] << " " << origin[2] + length << " " << color << origin[0] + length << " " << origin[1] + << " " << origin[2] + length << " " << color << origin[0] + length << " " << origin[1] << " " << origin[2] << color << std::endl; + o << origin[0] << " " << origin[1] << " " << origin[2] << " " << color << origin[0] << " " << origin[1] + length << " " << origin[2] << " " << color << origin[0] + length << " " + << origin[1] + length << " " << origin[2] << " " << color << origin[0] + length << " " << origin[1] << " " << origin[2] << color << std::endl; + + o << origin[0] + length << " " << origin[1] + length << " " << origin[2] + length << " " << color << origin[0] + length << " " << origin[1] + length << " " << origin[2] << " " << color + << origin[0] + length << " " << origin[1] << " " << origin[2] << " " << color << origin[0] + length << " " << origin[1] << " " << origin[2] + length << color << std::endl; + + o << origin[0] + length << " " << origin[1] + length << " " << origin[2] + length << " " << color << origin[0] + length << " " << origin[1] + length << " " << origin[2] << " " << color + << origin[0] << " " << origin[1] + length << " " << 
origin[2] << " " << color << origin[0] << " " << origin[1] + length << " " << origin[2] + length << color << std::endl; + + o << origin[0] + length << " " << origin[1] + length << " " << origin[2] + length << " " << color << origin[0] + length << " " << origin[1] << " " << origin[2] + length << " " << color + << origin[0] << " " << origin[1] << " " << origin[2] + length << " " << color << origin[0] << " " << origin[1] + length << " " << origin[2] + length << color << std::endl; } /** @brief Writing grid data to file @@ -462,7 +442,7 @@ template <> void Plotter<3>::writeGrid(const MWTree<3> &tree) { } /** @brief Checks the validity of the plotting range */ -template bool Plotter::verifyRange(int dim) const { +template bool Plotter::verifyRange(int dim) const { auto is_len_zero = [](Coord vec) { double vec_sq = 0.0; @@ -483,14 +463,18 @@ template bool Plotter::verifyRange(int dim) const { } /** @brief Compute step length to cover vector with `pts` points, including edges */ -template Coord Plotter::calcStep(const Coord &vec, int pts) const { +template Coord Plotter::calcStep(const Coord &vec, int pts) const { Coord step; for (auto d = 0; d < D; d++) step[d] = vec[d] / (pts - 1.0); return step; } -template class Plotter<1>; -template class Plotter<2>; -template class Plotter<3>; +template class Plotter<1, double>; +template class Plotter<2, double>; +template class Plotter<3, double>; + +template class Plotter<1, ComplexDouble>; +template class Plotter<2, ComplexDouble>; +template class Plotter<3, ComplexDouble>; } // namespace mrcpp diff --git a/src/utils/Plotter.h b/src/utils/Plotter.h index d38941b27..9612dedec 100644 --- a/src/utils/Plotter.h +++ b/src/utils/Plotter.h @@ -56,7 +56,7 @@ namespace mrcpp { * */ -template class Plotter { +template class Plotter { public: explicit Plotter(const Coord &o = {}); virtual ~Plotter() = default; @@ -65,10 +65,10 @@ template class Plotter { void setOrigin(const Coord &o); void setRange(const Coord &a, const Coord &b = {}, 
const Coord &c = {}); - void gridPlot(const MWTree &tree, const std::string &fname); - void linePlot(const std::array &npts, const RepresentableFunction &func, const std::string &fname); - void surfPlot(const std::array &npts, const RepresentableFunction &func, const std::string &fname); - void cubePlot(const std::array &npts, const RepresentableFunction &func, const std::string &fname); + void gridPlot(const MWTree &tree, const std::string &fname); + void linePlot(const std::array &npts, const RepresentableFunction &func, const std::string &fname); + void surfPlot(const std::array &npts, const RepresentableFunction &func, const std::string &fname); + void cubePlot(const std::array &npts, const RepresentableFunction &func, const std::string &fname); enum type { Line, Surface, Cube, Grid }; @@ -86,13 +86,13 @@ template class Plotter { Eigen::MatrixXd calcSurfCoordinates(int pts_a, int pts_b) const; Eigen::MatrixXd calcCubeCoordinates(int pts_a, int pts_b, int pts_c) const; - Eigen::VectorXd evaluateFunction(const RepresentableFunction &func, const Eigen::MatrixXd &coords) const; + Eigen::Matrix evaluateFunction(const RepresentableFunction &func, const Eigen::MatrixXd &coords) const; - void writeData(const Eigen::MatrixXd &coords, const Eigen::VectorXd &values); - virtual void writeCube(const std::array &npts, const Eigen::VectorXd &values); + void writeData(const Eigen::MatrixXd &coords, const Eigen::Matrix &values); + virtual void writeCube(const std::array &npts, const Eigen::Matrix &values); - void writeGrid(const MWTree &tree); - void writeNodeGrid(const MWNode &node, const std::string &color); + void writeGrid(const MWTree &tree); + void writeNodeGrid(const MWNode &node, const std::string &color); private: bool verifyRange(int dim) const; diff --git a/src/utils/Printer.cpp b/src/utils/Printer.cpp index d9d04f4bd..24585feb3 100644 --- a/src/utils/Printer.cpp +++ b/src/utils/Printer.cpp @@ -265,7 +265,7 @@ void print::tree(int level, const std::string &txt, int 
n, int m, double t) { * @param[in] tree: Tree to be printed * @param[in] timer: Timer to be evaluated */ -template void print::tree(int level, const std::string &txt, const MWTree &tree, const Timer &timer) { +template void print::tree(int level, const std::string &txt, const MWTree &tree, const Timer &timer) { if (level > Printer::getPrintLevel()) return; auto n = tree.getNNodes(); diff --git a/src/utils/Printer.h b/src/utils/Printer.h index dc4935aa8..c021155e8 100644 --- a/src/utils/Printer.h +++ b/src/utils/Printer.h @@ -39,7 +39,7 @@ namespace mrcpp { class Timer; -template class MWTree; +template class MWTree; /** @class Printer * @@ -128,7 +128,7 @@ void memory(int level, const std::string &txt); void value(int level, const std::string &txt, double v, const std::string &unit = "", int p = -1, bool sci = true); void time(int level, const std::string &txt, const Timer &timer); void tree(int level, const std::string &txt, int n, int m, double t); -template void tree(int level, const std::string &txt, const MWTree &tree, const Timer &timer); +template void tree(int level, const std::string &txt, const MWTree &tree, const Timer &timer); } // namespace print // clang-format off diff --git a/src/utils/math_utils.cpp b/src/utils/math_utils.cpp index 8506be298..69a13f300 100644 --- a/src/utils/math_utils.cpp +++ b/src/utils/math_utils.cpp @@ -171,18 +171,26 @@ void math_utils::tensor_self_product(const VectorXd &A, MatrixXd &tprod) { for (int i = 0; i < Ar; i++) { tprod.block(i, 0, 1, Ar) = A(i) * A; } } -void math_utils::apply_filter(double *out, double *in, const MatrixXd &filter, int kp1, int kp1_dm1, double fac) { -#ifdef HAVE_BLAS - cblas_dgemm(CblasColMajor, CblasTrans, CblasNoTrans, kp1_dm1, kp1, kp1, 1.0, in, kp1, filter.data(), kp1, fac, out, kp1_dm1); -#else - Map f(in, kp1, kp1_dm1); - Map g(out, kp1_dm1, kp1); - if (fac < MachineZero) { - g.noalias() = f.transpose() * filter; - } else { - g.noalias() += f.transpose() * filter; - } -#endif +/** Matrix 
multiplication of the filter with the input coefficients */ +template void math_utils::apply_filter(T *out, T *in, const MatrixXd &filter, int kp1, int kp1_dm1, double fac) { + if constexpr (std::is_same::value) { + Map f(in, kp1, kp1_dm1); + Map g(out, kp1_dm1, kp1); + if (fac < MachineZero) { + g.noalias() = f.transpose() * filter; + } else { + g.noalias() += f.transpose() * filter; + } + } else if constexpr (std::is_same::value) { + Map f(in, kp1, kp1_dm1); + Map g(out, kp1_dm1, kp1); + if (fac < MachineZero) { + g.noalias() = f.transpose() * filter; + } else { + g.noalias() += f.transpose() * filter; + } + } else + NOT_IMPLEMENTED_ABORT; } /** Make a nD-representation from 1D-representations of separable functions. @@ -226,7 +234,6 @@ void math_utils::tensor_expand_coords_3D(int kp1, const MatrixXd &primitive, Mat } } - /** @brief Compute the eigenvalues and eigenvectors of a Hermitian matrix * * @param A: matrix to diagonalize (not modified) @@ -327,6 +334,9 @@ template std::vector> math_utils::cartesian_product(std return output; } +template void math_utils::apply_filter(double *out, double *in, const Eigen::MatrixXd &filter, int kp1, int kp1_dm1, double fac); +template void math_utils::apply_filter(ComplexDouble *out, ComplexDouble *in, const Eigen::MatrixXd &filter, int kp1, int kp1_dm1, double fac); + template double math_utils::calc_distance<1>(const Coord<1> &a, const Coord<1> &b); template double math_utils::calc_distance<2>(const Coord<2> &a, const Coord<2> &b); template double math_utils::calc_distance<3>(const Coord<3> &a, const Coord<3> &b); diff --git a/src/utils/math_utils.h b/src/utils/math_utils.h index 9c371aa51..3eacfa10b 100644 --- a/src/utils/math_utils.h +++ b/src/utils/math_utils.h @@ -66,7 +66,7 @@ double matrix_norm_inf(const Eigen::MatrixXd &M); double matrix_norm_1(const Eigen::MatrixXd &M); double matrix_norm_2(const Eigen::MatrixXd &M); -void apply_filter(double *out, double *in, const Eigen::MatrixXd &filter, int kp1, int kp1_dm1, 
double fac); +template void apply_filter(T *out, T *in, const Eigen::MatrixXd &filter, int kp1, int kp1_dm1, double fac); void tensor_expand_coefs(int dim, int dir, int kp1, int kp1_d, const Eigen::MatrixXd &primitive, Eigen::VectorXd &expanded); diff --git a/src/utils/mpi_utils.cpp b/src/utils/mpi_utils.cpp index d61f2bd23..77526375b 100644 --- a/src/utils/mpi_utils.cpp +++ b/src/utils/mpi_utils.cpp @@ -36,7 +36,8 @@ namespace mrcpp { * @param[in] comm: Communicator sharing resources * @param[in] sh_size: Memory size, in MB */ -SharedMemory::SharedMemory(mrcpp::mpi_comm comm, int sh_size) +template +SharedMemory::SharedMemory(mrcpp::mpi_comm comm, int sh_size) : sh_start_ptr(nullptr) , sh_end_ptr(nullptr) , sh_max_ptr(nullptr) @@ -57,18 +58,18 @@ SharedMemory::SharedMemory(mrcpp::mpi_comm comm, int sh_size) int qdisp = 0; MPI_Win_shared_query(this->sh_win, 0, &qsize, &qdisp, &this->sh_start_ptr); MPI_Win_fence(0, this->sh_win); - this->sh_max_ptr = this->sh_start_ptr + qsize / sizeof(double); + this->sh_max_ptr = this->sh_start_ptr + qsize / sizeof(T); this->sh_end_ptr = this->sh_start_ptr; #endif } -void SharedMemory::clear() { +template void SharedMemory::clear() { #ifdef MRCPP_HAS_MPI this->sh_end_ptr = this->sh_start_ptr; #endif } -SharedMemory::~SharedMemory() { +template SharedMemory::~SharedMemory() { #ifdef MRCPP_HAS_MPI // deallocates the memory block MPI_Win_free(&this->sh_win); @@ -88,7 +89,7 @@ SharedMemory::~SharedMemory() { * to speed up communication, otherwise it will be communicated in a separate * step before the main communication. 
*/ -template void send_tree(FunctionTree &tree, int dst, int tag, mrcpp::mpi_comm comm, int nChunks, bool coeff) { +template void send_tree(FunctionTree &tree, int dst, int tag, mrcpp::mpi_comm comm, int nChunks, bool coeff) { #ifdef MRCPP_HAS_MPI auto &allocator = tree.getNodeAllocator(); @@ -101,8 +102,7 @@ template void send_tree(FunctionTree &tree, int dst, int tag, mrcpp:: Timer t1; for (int iChunk = 0; iChunk < nChunks; iChunk++) { MPI_Send(allocator.getNodeChunk(iChunk), allocator.getNodeChunkSize(), MPI_BYTE, dst, tag + iChunk + 1, comm); - if (coeff) - MPI_Send(allocator.getCoefChunk(iChunk), allocator.getCoefChunkSize(), MPI_BYTE, dst, tag + iChunk + 1001, comm); + if (coeff) MPI_Send(allocator.getCoefChunk(iChunk), allocator.getCoefChunkSize(), MPI_BYTE, dst, tag + iChunk + 1001, comm); } println(10, " Time send " << std::setw(30) << t1.elapsed()); #endif @@ -121,7 +121,7 @@ template void send_tree(FunctionTree &tree, int dst, int tag, mrcpp:: * to speed up communication, otherwise it will be communicated in a separate * step before the main communication. 
*/ -template void recv_tree(FunctionTree &tree, int src, int tag, mrcpp::mpi_comm comm, int nChunks, bool coeff) { +template void recv_tree(FunctionTree &tree, int src, int tag, mrcpp::mpi_comm comm, int nChunks, bool coeff) { #ifdef MRCPP_HAS_MPI MPI_Status status; auto &allocator = tree.getNodeAllocator(); @@ -136,8 +136,7 @@ template void recv_tree(FunctionTree &tree, int src, int tag, mrcpp:: allocator.init(nChunks, coeff); for (int iChunk = 0; iChunk < nChunks; iChunk++) { MPI_Recv(allocator.getNodeChunk(iChunk), allocator.getNodeChunkSize(), MPI_BYTE, src, tag + iChunk + 1, comm, &status); - if (coeff) - MPI_Recv(allocator.getCoefChunk(iChunk), allocator.getCoefChunkSize(), MPI_BYTE, src, tag + iChunk + 1001, comm, &status); + if (coeff) MPI_Recv(allocator.getCoefChunk(iChunk), allocator.getCoefChunkSize(), MPI_BYTE, src, tag + iChunk + 1001, comm, &status); } println(10, " Time receive " << std::setw(30) << t1.elapsed()); @@ -157,7 +156,7 @@ template void recv_tree(FunctionTree &tree, int src, int tag, mrcpp:: * @details This function should be called every time a shared function is * updated, in order to update the local memory of each MPI process. 
*/ -template void share_tree(FunctionTree &tree, int src, int tag, mrcpp::mpi_comm comm) { +template void share_tree(FunctionTree &tree, int src, int tag, mrcpp::mpi_comm comm) { #ifdef MRCPP_HAS_MPI Timer t1; auto &allocator = tree.getNodeAllocator(); @@ -197,15 +196,27 @@ template void share_tree(FunctionTree &tree, int src, int tag, mrcpp: println(10, " Time share " << std::setw(30) << t1.elapsed()); #endif } - -template void send_tree<1>(FunctionTree<1> &tree, int dst, int tag, mrcpp::mpi_comm comm, int nChunks, bool coeff); -template void send_tree<2>(FunctionTree<2> &tree, int dst, int tag, mrcpp::mpi_comm comm, int nChunks, bool coeff); -template void send_tree<3>(FunctionTree<3> &tree, int dst, int tag, mrcpp::mpi_comm comm, int nChunks, bool coeff); -template void recv_tree<1>(FunctionTree<1> &tree, int src, int tag, mrcpp::mpi_comm comm, int nChunks, bool coeff); -template void recv_tree<2>(FunctionTree<2> &tree, int src, int tag, mrcpp::mpi_comm comm, int nChunks, bool coeff); -template void recv_tree<3>(FunctionTree<3> &tree, int src, int tag, mrcpp::mpi_comm comm, int nChunks, bool coeff); -template void share_tree<1>(FunctionTree<1> &tree, int src, int tag, mrcpp::mpi_comm comm); -template void share_tree<2>(FunctionTree<2> &tree, int src, int tag, mrcpp::mpi_comm comm); -template void share_tree<3>(FunctionTree<3> &tree, int src, int tag, mrcpp::mpi_comm comm); +template class SharedMemory; +template class SharedMemory; + +template void send_tree<1>(FunctionTree<1, double> &tree, int dst, int tag, mrcpp::mpi_comm comm, int nChunks, bool coeff); +template void send_tree<2>(FunctionTree<2, double> &tree, int dst, int tag, mrcpp::mpi_comm comm, int nChunks, bool coeff); +template void send_tree<3>(FunctionTree<3, double> &tree, int dst, int tag, mrcpp::mpi_comm comm, int nChunks, bool coeff); +template void recv_tree<1>(FunctionTree<1, double> &tree, int src, int tag, mrcpp::mpi_comm comm, int nChunks, bool coeff); +template void 
recv_tree<2>(FunctionTree<2, double> &tree, int src, int tag, mrcpp::mpi_comm comm, int nChunks, bool coeff); +template void recv_tree<3>(FunctionTree<3, double> &tree, int src, int tag, mrcpp::mpi_comm comm, int nChunks, bool coeff); +template void share_tree<1>(FunctionTree<1, double> &tree, int src, int tag, mrcpp::mpi_comm comm); +template void share_tree<2>(FunctionTree<2, double> &tree, int src, int tag, mrcpp::mpi_comm comm); +template void share_tree<3>(FunctionTree<3, double> &tree, int src, int tag, mrcpp::mpi_comm comm); + +template void send_tree<1>(FunctionTree<1, ComplexDouble> &tree, int dst, int tag, mrcpp::mpi_comm comm, int nChunks, bool coeff); +template void send_tree<2>(FunctionTree<2, ComplexDouble> &tree, int dst, int tag, mrcpp::mpi_comm comm, int nChunks, bool coeff); +template void send_tree<3>(FunctionTree<3, ComplexDouble> &tree, int dst, int tag, mrcpp::mpi_comm comm, int nChunks, bool coeff); +template void recv_tree<1>(FunctionTree<1, ComplexDouble> &tree, int src, int tag, mrcpp::mpi_comm comm, int nChunks, bool coeff); +template void recv_tree<2>(FunctionTree<2, ComplexDouble> &tree, int src, int tag, mrcpp::mpi_comm comm, int nChunks, bool coeff); +template void recv_tree<3>(FunctionTree<3, ComplexDouble> &tree, int src, int tag, mrcpp::mpi_comm comm, int nChunks, bool coeff); +template void share_tree<1>(FunctionTree<1, ComplexDouble> &tree, int src, int tag, mrcpp::mpi_comm comm); +template void share_tree<2>(FunctionTree<2, ComplexDouble> &tree, int src, int tag, mrcpp::mpi_comm comm); +template void share_tree<3>(FunctionTree<3, ComplexDouble> &tree, int src, int tag, mrcpp::mpi_comm comm); } // namespace mrcpp diff --git a/src/utils/mpi_utils.h b/src/utils/mpi_utils.h index 1b94b0dc9..062d1affa 100644 --- a/src/utils/mpi_utils.h +++ b/src/utils/mpi_utils.h @@ -51,7 +51,9 @@ extern int sh_group_rank; extern int is_bank; extern int is_bankclient; extern int bank_size; +extern int bank_per_node; extern int omp_threads; +extern 
int use_omp_num_threads; extern int tot_bank_size; extern int max_tag; extern int task_bank; @@ -61,7 +63,7 @@ extern MPI_Comm comm_share; extern MPI_Comm comm_sh_group; extern MPI_Comm comm_bank; -}// namespace mpi +} // namespace mpi } // namespace mrcpp namespace mrcpp { @@ -74,26 +76,26 @@ namespace mrcpp { * communicator. In order to allocate a FunctionTree in shared memory, * simply pass a SharedMemory object to the FunctionTree constructor. */ -class SharedMemory { +template class SharedMemory { public: SharedMemory(mrcpp::mpi_comm comm, int sh_size); SharedMemory(const SharedMemory &mem) = delete; - SharedMemory &operator=(const SharedMemory &mem) = delete; + SharedMemory &operator=(const SharedMemory &mem) = delete; ~SharedMemory(); void clear(); // show shared memory as entirely available - double *sh_start_ptr; // start of shared block - double *sh_end_ptr; // end of used part - double *sh_max_ptr; // end of shared block + T *sh_start_ptr; // start of shared block + T *sh_end_ptr; // end of used part + T *sh_max_ptr; // end of shared block mrcpp::mpi_win sh_win; // MPI window object int rank; // rank among shared group }; -template class FunctionTree; +template class FunctionTree; -template void send_tree(FunctionTree &tree, int dst, int tag, mrcpp::mpi_comm comm, int nChunks = -1, bool coeff = true); -template void recv_tree(FunctionTree &tree, int src, int tag, mrcpp::mpi_comm comm, int nChunks = -1, bool coeff = true); -template void share_tree(FunctionTree &tree, int src, int tag, mrcpp::mpi_comm comm); +template void send_tree(FunctionTree &tree, int dst, int tag, mrcpp::mpi_comm comm, int nChunks = -1, bool coeff = true); +template void recv_tree(FunctionTree &tree, int src, int tag, mrcpp::mpi_comm comm, int nChunks = -1, bool coeff = true); +template void share_tree(FunctionTree &tree, int src, int tag, mrcpp::mpi_comm comm); } // namespace mrcpp diff --git a/src/utils/parallel.cpp b/src/utils/parallel.cpp index 2877d12e9..332d0fb5d 100644 --- 
a/src/utils/parallel.cpp +++ b/src/utils/parallel.cpp @@ -1,18 +1,17 @@ #include #include #include -#include #include +#include #include "Bank.h" -#include "ComplexFunction.h" #include "omp_utils.h" #include "parallel.h" #include "trees/FunctionTree.h" #ifdef MRCPP_HAS_OMP #define mrcpp_get_max_threads() omp_get_max_threads() -#define mrcpp_get_num_procs() omp_get_num_procs()/2 +#define mrcpp_get_num_procs() omp_get_num_procs() #define mrcpp_set_dynamic(n) omp_set_dynamic(n) #else #define mrcpp_get_max_threads() 1 @@ -55,9 +54,11 @@ int is_centralbank = 0; int is_bankclient = 1; int is_bankmaster = 0; // only one bankmaster is_bankmaster int bank_size = 0; -int omp_threads = -1; // can be set to force number of threads -int tot_bank_size = 0; // size of bank, including the task manager -int max_tag = 0; // max value allowed by MPI +int bank_per_node = 0; +int omp_threads = -1; // can be set to force number of threads +int use_omp_num_threads = -1; // can be set to use number of threads from env +int tot_bank_size = 0; // size of bank, including the task manager +int max_tag = 0; // max value allowed by MPI vector bankmaster; int task_bank = -1; // world rank of the task manager @@ -66,89 +67,100 @@ MPI_Comm comm_share; MPI_Comm comm_sh_group; MPI_Comm comm_bank; -} // namespace mpi - int id_shift; // to ensure that nodes, orbitals and functions do not collide extern int metadata_block[3]; // can add more metadata in future extern int const size_metadata = 3; -void mpi::initialize() { +void initialize() { Eigen::setNbThreads(1); mrcpp_set_dynamic(0); #ifdef MRCPP_HAS_MPI MPI_Init(nullptr, nullptr); - MPI_Comm_size(MPI_COMM_WORLD, &mpi::world_size); - MPI_Comm_rank(MPI_COMM_WORLD, &mpi::world_rank); + MPI_Comm_size(MPI_COMM_WORLD, &world_size); + MPI_Comm_rank(MPI_COMM_WORLD, &world_rank); // divide the world into groups // each group has its own group communicator definition + // count the number of process per node + MPI_Comm node_comm; + 
MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, &node_comm); + int node_rank, node_size; + MPI_Comm_rank(node_comm, &node_rank); + MPI_Comm_size(node_comm, &node_size); + // define independent group of MPI processes, that are not part of comm_wrk // for now the new group does not include comm_share - mpi::comm_bank = MPI_COMM_WORLD; // clients and master - MPI_Comm comm_remainder; // clients only + comm_bank = MPI_COMM_WORLD; // clients and master + MPI_Comm comm_remainder; // clients only // set bank_size automatically if not defined by user - if (mpi::world_size < 2) { - mpi::bank_size = 0; - } else if (mpi::bank_size < 0) { - mpi::bank_size = max(mpi::world_size / 3, 1); + if (world_size < 2) { + bank_size = 0; + } else if (bank_size < 0) { + if (bank_per_node >= 0) { + bank_size = node_size * bank_per_node; + } else { + bank_size = max(world_size / 3, 1); + } + } else if (bank_size >= 0 and bank_per_node >= 0) { + if (bank_size != node_size * bank_per_node and world_rank == 0) std::cout << "WARNING: bank_size and bank_per_node are incompatible " << bank_size << " " << bank_per_node << std::endl; } - if (mpi::world_size - mpi::bank_size < 1) MSG_ABORT("No MPI ranks left for working!"); - if (mpi::bank_size < 1 and mpi::world_size > 1) MSG_ABORT("Bank size must be at least one when using MPI!"); + if (world_size - bank_size < 1) MSG_ABORT("No MPI ranks left for working!"); + if (bank_size < 1 and world_size > 1) MSG_ABORT("Bank size must be at least one when using MPI!"); - mpi::bankmaster.resize(mpi::bank_size); - for (int i = 0; i < mpi::bank_size; i++) { - mpi::bankmaster[i] = mpi::world_size - i - 1; // rank of the bankmasters + bankmaster.resize(bank_size); + for (int i = 0; i < bank_size; i++) { + bankmaster[i] = world_size - i - 1; // rank of the bankmasters } - if (mpi::world_rank < mpi::world_size - mpi::bank_size) { + if (world_rank < world_size - bank_size) { // everything which is left - mpi::is_bank = 0; - 
mpi::is_centralbank = 0; - mpi::is_bankclient = 1; + is_bank = 0; + is_centralbank = 0; + is_bankclient = 1; } else { // special group of centralbankmasters - mpi::is_bank = 1; - mpi::is_centralbank = 1; - mpi::is_bankclient = 0; - if (mpi::world_rank == mpi::world_size - mpi::bank_size) mpi::is_bankmaster = 1; + is_bank = 1; + is_centralbank = 1; + is_bankclient = 0; + if (world_rank == world_size - bank_size) is_bankmaster = 1; } - MPI_Comm_split(MPI_COMM_WORLD, mpi::is_bankclient, mpi::world_rank, &comm_remainder); + MPI_Comm_split(MPI_COMM_WORLD, is_bankclient, world_rank, &comm_remainder); // split world into groups that can share memory - MPI_Comm_split_type(comm_remainder, MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, &mpi::comm_share); + MPI_Comm_split_type(comm_remainder, MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, &comm_share); - MPI_Comm_rank(mpi::comm_share, &mpi::share_rank); - MPI_Comm_size(mpi::comm_share, &mpi::share_size); + MPI_Comm_rank(comm_share, &share_rank); + MPI_Comm_size(comm_share, &share_size); // define a rank of the group - MPI_Comm_split(comm_remainder, mpi::share_rank, mpi::world_rank, &mpi::comm_sh_group); + MPI_Comm_split(comm_remainder, share_rank, world_rank, &comm_sh_group); // mpiShRank is color (same color->in same group) // MPI_worldrank is key (orders rank within the groups) // we define a new orbital rank, so that the orbitals within // a shared memory group, have consecutive ranks - MPI_Comm_rank(mpi::comm_sh_group, &mpi::sh_group_rank); + MPI_Comm_rank(comm_sh_group, &sh_group_rank); - mpi::wrk_rank = mpi::share_rank + mpi::sh_group_rank * mpi::world_size; - MPI_Comm_split(comm_remainder, 0, mpi::wrk_rank, &mpi::comm_wrk); + wrk_rank = share_rank + sh_group_rank * world_size; + MPI_Comm_split(comm_remainder, 0, wrk_rank, &comm_wrk); // 0 is color (same color->in same group) // mpiOrbRank is key (orders rank in the group) - MPI_Comm_rank(mpi::comm_wrk, &mpi::wrk_rank); - MPI_Comm_size(mpi::comm_wrk, &mpi::wrk_size); + 
MPI_Comm_rank(comm_wrk, &wrk_rank); + MPI_Comm_size(comm_wrk, &wrk_size); // if bank_size is large enough, we reserve one as "task manager" - mpi::tot_bank_size = mpi::bank_size; - if (mpi::bank_size <= 2 and mpi::bank_size > 0) { + tot_bank_size = bank_size; + if (bank_size <= 2 and bank_size > 0) { // use the first bank as task manager - mpi::task_bank = mpi::bankmaster[0]; - } else if (mpi::bank_size > 1) { + task_bank = bankmaster[0]; + } else if (bank_size > 1) { // reserve one bank for task management only - mpi::bank_size--; - mpi::task_bank = mpi::bankmaster[mpi::bank_size]; // the last rank is reserved as task manager + bank_size--; + task_bank = bankmaster[bank_size]; // the last rank is reserved as task manager } // determine the maximum value alowed for mpi tags @@ -158,77 +170,89 @@ void mpi::initialize() { max_tag = *(int *)val / 2; id_shift = max_tag / 2; // half is reserved for non orbital. - // determine the number of threads we can assign to each mpi worker. - // mrcpp_get_num_procs is total number of hardware logical threads accessible by this mpi - // We assume that half of them are physical cores. - // mrcpp_get_max_threads is OMP_NUM_THREADS (environment variable). - // omp_threads_available is the total number of logical threads available on this compute-node - // We assume that half of them are physical cores. 
- // - // six conditions should be satisfied: - // 1) no one use more than mrcpp_get_num_procs()/2 - // 2) NOT ENFORCED: no one use more than mrcpp_get_max_threads, as defined by rank 0 - // 3) the total number of threads used on the compute-node must not exceed omp_threads_available/2 - // 4) Bank needs only one thread - // 5) workers need as many threads as possible - // 6) at least one thread - - MPI_Comm comm_share_world;//all that share the memory + MPI_Comm comm_share_world; // all that share the memory MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, &comm_share_world); - int n_bank_thisnode; //number of banks on this node + int n_bank_thisnode; // number of banks on this node MPI_Allreduce(&is_bank, &n_bank_thisnode, 1, MPI_INT, MPI_SUM, comm_share_world); - int n_wrk_thisnode; //number of workers on this node + int n_wrk_thisnode; // number of workers on this node MPI_Allreduce(&is_bankclient, &n_wrk_thisnode, 1, MPI_INT, MPI_SUM, comm_share_world); int omp_threads_available = thread::hardware_concurrency(); - int nthreads = 1; - if (is_bankclient) nthreads = (omp_threads_available/2-n_bank_thisnode)/n_wrk_thisnode; // 3) and 5) - - // do not exceed total number of cores accessible (assumed to be half the number of logical threads) - nthreads = min(nthreads, mrcpp_get_num_procs()); // 1) - // NB: we do not use OMP_NUM_THREADS. Use all cores accessible. Could change this in the future - // if OMP_NUM_THREADS is set, do not exceed - // we enforce that all compute nodes use the same OMP_NUM_THREADS. Rank 0 decides. 
- /* int my_OMP_NUM_THREADS = mrcpp_get_max_threads(); + int nthreads = 1; + int my_OMP_NUM_THREADS = mrcpp_get_max_threads(); MPI_Bcast(&my_OMP_NUM_THREADS, 1, MPI_INT, 0, MPI_COMM_WORLD); - - if (my_OMP_NUM_THREADS > 0) nthreads = min(nthreads, my_OMP_NUM_THREADS); // 2) - */ - - nthreads = max(1, nthreads); // 6) - - if (is_bank) nthreads = 1; // 4) - - //cout< 0) { - if (omp_threads != nthreads and world_rank == 0) { - cout<<"Warning: recommended number of threads is "< 0) { + if (omp_threads != nthreads and world_rank == 0) { + cout << "Warning: recommended number of threads is " << nthreads << endl; + cout << "setting number of threads to omp_threads, " << max(1, omp_threads) << endl; + } + nthreads = omp_threads; } - nthreads = omp_threads; + } + nthreads = max(1, nthreads); // 5) + + if (nthreads * n_wrk_thisnode + n_bank_thisnode < omp_threads_available / 3 and world_rank == 0) { + std::cout << "WARNING: only " << nthreads * n_wrk_thisnode + n_bank_thisnode << " threads used per node while " << omp_threads_available << " logical cpus are accessible " << std::endl; } + if (nthreads > mrcpp_get_num_procs() / 2) { std::cout << "WARNING: MPI rank " << world_rank << " will use " << nthreads << " but only " << mrcpp_get_num_procs() / 2 << " procs are accessible" << std::endl; } + omp::n_threads = nthreads; mrcpp::set_max_threads(nthreads); - if (mpi::is_bank) { + if (is_bank) { // bank is open until end of program - if (mpi::is_centralbank) { dataBank.open(); } - mpi::finalize(); + if (is_centralbank) { dataBank.open(); } + finalize(); exit(EXIT_SUCCESS); } #else - mpi::bank_size = 0; + bank_size = 0; mrcpp::set_max_threads(omp::n_threads); #endif } -void mpi::finalize() { +void finalize() { #ifdef MRCPP_HAS_MPI - if (mpi::bank_size > 0 and mpi::grand_master()) { + if (bank_size > 0 and grand_master()) { println(4, " max data in bank " << dataBank.get_maxtotalsize() << " MB "); dataBank.close(); } @@ -237,7 +261,7 @@ void mpi::finalize() { #endif } -void 
mpi::barrier(MPI_Comm comm) { +void barrier(MPI_Comm comm) { #ifdef MRCPP_HAS_MPI MPI_Barrier(comm); #endif @@ -247,33 +271,38 @@ void mpi::barrier(MPI_Comm comm) { * Orbital related MPI functions * *********************************/ -bool mpi::grand_master() { - return (mpi::world_rank == 0 and is_bankclient) ? true : false; +bool grand_master() { + return (world_rank == 0 and is_bankclient) ? true : false; +} + +bool share_master() { + return (share_rank == 0) ? true : false; } -bool mpi::share_master() { - return (mpi::share_rank == 0) ? true : false; +/** @brief Test if function belongs to this MPI rank */ +bool my_func(int j) { + return ((j) % wrk_size == wrk_rank) ? true : false; } -/** @brief Test if orbital belongs to this MPI rank (or is common)*/ -bool mpi::my_orb(int j) { - return ((j) % mpi::wrk_size == mpi::wrk_rank) ? true : false; +/** @brief Test if function belongs to this MPI rank */ +bool my_func(const CompFunction<3> &func) { + return my_func(func.rank()); } -/** @brief Test if orbital belongs to this MPI rank (or is common)*/ -bool mpi::my_orb(ComplexFunction orbj) { - return my_orb(orbj.getRank()); +/** @brief Test if function belongs to this MPI rank */ +bool my_func(CompFunction<3> *func) { + return my_func(func->rank()); } /** @brief Free all function pointers not belonging to this MPI rank */ -void mpi::free_foreign(MPI_FuncVector &Phi) { - for (ComplexFunction &i : Phi) { - if (not mpi::my_orb(i)) i.free(NUMBER::Total); +void free_foreign(CompFunctionVector &Phi) { + for (CompFunction<3> &i : Phi) { + if (not my_func(i)) i.free(); } } /** @brief Add up each entry of the vector with contributions from all MPI ranks */ -void mpi::allreduce_vector(IntVector &vec, MPI_Comm comm) { +void allreduce_vector(IntVector &vec, MPI_Comm comm) { #ifdef MRCPP_HAS_MPI int N = vec.size(); MPI_Allreduce(MPI_IN_PLACE, vec.data(), N, MPI_INT, MPI_SUM, comm); @@ -281,7 +310,7 @@ void mpi::allreduce_vector(IntVector &vec, MPI_Comm comm) { } /** @brief Add up 
each entry of the vector with contributions from all MPI ranks */ -void mpi::allreduce_vector(DoubleVector &vec, MPI_Comm comm) { +void allreduce_vector(DoubleVector &vec, MPI_Comm comm) { #ifdef MRCPP_HAS_MPI int N = vec.size(); MPI_Allreduce(MPI_IN_PLACE, vec.data(), N, MPI_DOUBLE, MPI_SUM, comm); @@ -289,7 +318,7 @@ void mpi::allreduce_vector(DoubleVector &vec, MPI_Comm comm) { } /** @brief Add up each entry of the vector with contributions from all MPI ranks */ -void mpi::allreduce_vector(ComplexVector &vec, MPI_Comm comm) { +void allreduce_vector(ComplexVector &vec, MPI_Comm comm) { #ifdef MRCPP_HAS_MPI int N = vec.size(); MPI_Allreduce(MPI_IN_PLACE, vec.data(), N, MPI_C_DOUBLE_COMPLEX, MPI_SUM, comm); @@ -297,7 +326,7 @@ void mpi::allreduce_vector(ComplexVector &vec, MPI_Comm comm) { } /** @brief Add up each entry of the matrix with contributions from all MPI ranks */ -void mpi::allreduce_matrix(IntMatrix &mat, MPI_Comm comm) { +void allreduce_matrix(IntMatrix &mat, MPI_Comm comm) { #ifdef MRCPP_HAS_MPI int N = mat.size(); MPI_Allreduce(MPI_IN_PLACE, mat.data(), N, MPI_INT, MPI_SUM, comm); @@ -305,7 +334,7 @@ void mpi::allreduce_matrix(IntMatrix &mat, MPI_Comm comm) { } /** @brief Add up each entry of the matrix with contributions from all MPI ranks */ -void mpi::allreduce_matrix(DoubleMatrix &mat, MPI_Comm comm) { +void allreduce_matrix(DoubleMatrix &mat, MPI_Comm comm) { #ifdef MRCPP_HAS_MPI int N = mat.size(); MPI_Allreduce(MPI_IN_PLACE, mat.data(), N, MPI_DOUBLE, MPI_SUM, comm); @@ -313,58 +342,65 @@ void mpi::allreduce_matrix(DoubleMatrix &mat, MPI_Comm comm) { } /** @brief Add up each entry of the matrix with contributions from all MPI ranks */ -void mpi::allreduce_matrix(ComplexMatrix &mat, MPI_Comm comm) { +void allreduce_matrix(ComplexMatrix &mat, MPI_Comm comm) { #ifdef MRCPP_HAS_MPI int N = mat.size(); MPI_Allreduce(MPI_IN_PLACE, mat.data(), N, MPI_C_DOUBLE_COMPLEX, MPI_SUM, comm); #endif } -// send a function with MPI -void 
mpi::send_function(ComplexFunction &func, int dst, int tag, MPI_Comm comm) { +// send a component function with MPI +void send_function(const CompFunction<3> &func, int dst, int tag, MPI_Comm comm) { #ifdef MRCPP_HAS_MPI - if (func.isShared()) MSG_WARN("Sending a shared function is not recommended"); - FunctionData &funcinfo = func.getFunctionData(); - MPI_Send(&funcinfo, sizeof(FunctionData), MPI_BYTE, dst, 0, comm); - if (func.hasReal()) mrcpp::send_tree(func.real(), dst, tag, comm, funcinfo.real_size); - if (func.hasImag()) mrcpp::send_tree(func.imag(), dst, tag + 10000, comm, funcinfo.imag_size); + for (int i = 0; i < func.Ncomp(); i++) { + // make sure that Nchunks is up to date + if (func.isreal()) + func.Nchunks()[i] = func.CompD[i]->getNChunks(); + else + func.Nchunks()[i] = func.CompC[i]->getNChunks(); + } + MPI_Send(&func.func_ptr->data, sizeof(CompFunctionData<3>), MPI_BYTE, dst, 0, comm); + for (int i = 0; i < func.Ncomp(); i++) { + if (func.isreal()) + mrcpp::send_tree(*func.CompD[i], dst, tag, comm, func.Nchunks()[i]); + else + mrcpp::send_tree(*func.CompC[i], dst, tag, comm, func.Nchunks()[i]); + } #endif } -// receive a function with MPI -void mpi::recv_function(ComplexFunction &func, int src, int tag, MPI_Comm comm) { +// receive a component function with MPI +void recv_function(CompFunction<3> &func, int src, int tag, MPI_Comm comm) { #ifdef MRCPP_HAS_MPI - if (func.isShared()) MSG_WARN("Receiving a shared function is not recommended"); MPI_Status status; - - FunctionData &funcinfo = func.getFunctionData(); - MPI_Recv(&funcinfo, sizeof(FunctionData), MPI_BYTE, src, 0, comm, &status); - if (funcinfo.real_size > 0) { - // We must have a tree defined for receiving nodes. Define one: - if (not func.hasReal()) func.alloc(NUMBER::Real); - mrcpp::recv_tree(func.real(), src, tag, comm, funcinfo.real_size); - } - - if (funcinfo.imag_size > 0) { - // We must have a tree defined for receiving nodes. 
Define one: - if (not func.hasImag()) func.alloc(NUMBER::Imag); - mrcpp::recv_tree(func.imag(), src, tag + 10000, comm, funcinfo.imag_size); + int func_ncomp_in = func.Ncomp(); + MPI_Recv(&func.func_ptr->data, sizeof(CompFunctionData<3>), MPI_BYTE, src, 0, comm, &status); + for (int i = 0; i < func.Ncomp(); i++) { + if (func_ncomp_in <= i) func.alloc(i + 1); + if (func.isreal()) + mrcpp::recv_tree(*func.CompD[i], src, tag, comm, func.Nchunks()[i]); + else + mrcpp::recv_tree(*func.CompC[i], src, tag, comm, func.Nchunks()[i]); } #endif } /** Update a shared function after it has been changed by one of the MPI ranks. */ -void mpi::share_function(ComplexFunction &func, int src, int tag, MPI_Comm comm) { +void share_function(CompFunction<3> &func, int src, int tag, MPI_Comm comm) { if (func.isShared()) { #ifdef MRCPP_HAS_MPI - if (func.hasReal()) mrcpp::share_tree(func.real(), src, tag, comm); - if (func.hasImag()) mrcpp::share_tree(func.imag(), src, 2 * tag, comm); + for (int comp = 0; comp < func.Ncomp(); comp++) { + if (func.isreal()) + mrcpp::share_tree(*func.CompD[comp], src, tag, comm); + else + mrcpp::share_tree(*func.CompC[comp], src, tag, comm); + } #endif } } /** @brief Add all mpi function into rank zero */ -void mpi::reduce_function(double prec, ComplexFunction &func, MPI_Comm comm) { +void reduce_function(double prec, CompFunction<3> &func, MPI_Comm comm) { /* 1) Each odd rank send to the left rank 2) All odd ranks are "deleted" (can exit routine) 3) new "effective" ranks are defined within the non-deleted ranks @@ -383,9 +419,9 @@ void mpi::reduce_function(double prec, ComplexFunction &func, MPI_Comm comm) { // receive int src = comm_rank + fac; if (src < comm_size) { - ComplexFunction func_i(false); + CompFunction<3> func_i; int tag = 3333 + src; - mpi::recv_function(func_i, src, tag, comm); + recv_function(func_i, src, tag, comm); func.add(1.0, func_i); // add in place using union grid func.crop(prec); } @@ -395,7 +431,7 @@ void 
mpi::reduce_function(double prec, ComplexFunction &func, MPI_Comm comm) { int dest = comm_rank - fac; if (dest >= 0) { int tag = 3333 + comm_rank; - mpi::send_function(func, dest, tag, comm); + send_function(func, dest, tag, comm); break; // once data is sent we are done } } @@ -406,7 +442,7 @@ void mpi::reduce_function(double prec, ComplexFunction &func, MPI_Comm comm) { } /** @brief make union tree and send into rank zero */ -void mpi::reduce_Tree_noCoeff(mrcpp::FunctionTree<3> &tree, MPI_Comm comm) { +template void reduce_Tree_noCoeff(mrcpp::FunctionTree<3, T> &tree, MPI_Comm comm) { /* 1) Each odd rank send to the left rank 2) All odd ranks are "deleted" (can exit routine) 3) new "effective" ranks are defined within the non-deleted ranks @@ -426,7 +462,7 @@ void mpi::reduce_Tree_noCoeff(mrcpp::FunctionTree<3> &tree, MPI_Comm comm) { int src = comm_rank + fac; if (src < comm_size) { int tag = 3333 + src; - mrcpp::FunctionTree<3> tree_i(tree.getMRA()); + mrcpp::FunctionTree<3, T> tree_i(tree.getMRA()); mrcpp::recv_tree(tree_i, src, tag, comm, -1, false); tree.appendTreeNoCoeff(tree_i); // make union grid } @@ -447,9 +483,9 @@ void mpi::reduce_Tree_noCoeff(mrcpp::FunctionTree<3> &tree, MPI_Comm comm) { } /** @brief make union tree without coeff and send to all - * Include both real and imaginary parts */ -void mpi::allreduce_Tree_noCoeff(mrcpp::FunctionTree<3> &tree, vector &Phi, MPI_Comm comm) { +template void allreduce_Tree_noCoeff(mrcpp::FunctionTree<3, T> &tree, vector> &Phi, MPI_Comm comm) { +#ifdef MRCPP_HAS_MPI /* 1) make union grid of own orbitals 2) make union grid with others orbitals (sent to rank zero) 3) rank zero broadcast func to everybody @@ -457,16 +493,36 @@ void mpi::allreduce_Tree_noCoeff(mrcpp::FunctionTree<3> &tree, vector void allreduce_Tree_noCoeff(mrcpp::FunctionTree<3, T> &tree, vector> &Phi, MPI_Comm comm) { +#ifdef MRCPP_HAS_MPI + /* 1) make union grid of own orbitals + 2) make union grid with others orbitals (sent to rank zero) + 3) 
rank zero broadcast func to everybody + */ + + int N = Phi.size(); + for (int j = 0; j < N; j++) { + if (not my_func(j)) continue; + if (Phi[j].isreal()) tree.appendTreeNoCoeff(*Phi[j].CompD[0]); + if (Phi[j].iscomplex()) tree.appendTreeNoCoeff(*Phi[j].CompC[0]); + } + mrcpp::mpi::reduce_Tree_noCoeff(tree, comm_wrk); + mrcpp::mpi::broadcast_Tree_noCoeff(tree, comm_wrk); +#endif } /** @brief Distribute rank zero function to all ranks */ -void mpi::broadcast_function(ComplexFunction &func, MPI_Comm comm) { +void broadcast_function(CompFunction<3> &func, MPI_Comm comm) { /* use same strategy as a reduce, but in reverse order */ #ifdef MRCPP_HAS_MPI int comm_size, comm_rank; @@ -483,13 +539,13 @@ void mpi::broadcast_function(ComplexFunction &func, MPI_Comm comm) { // receive int src = comm_rank - fac; int tag = 4334 + comm_rank; - mpi::recv_function(func, src, tag, comm); + recv_function(func, src, tag, comm); } if (comm_rank % fac == 0 and (comm_rank / fac) % 2 == 0) { // send int dst = comm_rank + fac; int tag = 4334 + dst; - if (dst < comm_size) mpi::send_function(func, dst, tag, comm); + if (dst < comm_size) send_function(func, dst, tag, comm); } fac /= 2; } @@ -498,7 +554,7 @@ void mpi::broadcast_function(ComplexFunction &func, MPI_Comm comm) { } /** @brief Distribute rank zero function to all ranks */ -void mpi::broadcast_Tree_noCoeff(mrcpp::FunctionTree<3> &tree, MPI_Comm comm) { +template void broadcast_Tree_noCoeff(mrcpp::FunctionTree<3, T> &tree, MPI_Comm comm) { /* use same strategy as a reduce, but in reverse order */ #ifdef MRCPP_HAS_MPI int comm_size, comm_rank; @@ -529,4 +585,15 @@ void mpi::broadcast_Tree_noCoeff(mrcpp::FunctionTree<3> &tree, MPI_Comm comm) { #endif } +template void reduce_Tree_noCoeff(mrcpp::FunctionTree<3, double> &tree, MPI_Comm comm); +template void allreduce_Tree_noCoeff(mrcpp::FunctionTree<3, double> &tree, std::vector> &Phi, MPI_Comm comm); +template void broadcast_Tree_noCoeff(mrcpp::FunctionTree<3, double> &tree, MPI_Comm 
comm); +template void allreduce_Tree_noCoeff(mrcpp::FunctionTree<3, double> &tree, std::vector> &Phi, MPI_Comm comm); + +template void reduce_Tree_noCoeff(mrcpp::FunctionTree<3, ComplexDouble> &tree, MPI_Comm comm); +template void allreduce_Tree_noCoeff(mrcpp::FunctionTree<3, ComplexDouble> &tree, std::vector> &Phi, MPI_Comm comm); +template void broadcast_Tree_noCoeff(mrcpp::FunctionTree<3, ComplexDouble> &tree, MPI_Comm comm); +template void allreduce_Tree_noCoeff(mrcpp::FunctionTree<3, ComplexDouble> &tree, std::vector> &Phi, MPI_Comm comm); + +} // namespace mpi } // namespace mrcpp diff --git a/src/utils/parallel.h b/src/utils/parallel.h index 78a3e2fd9..395cc1174 100644 --- a/src/utils/parallel.h +++ b/src/utils/parallel.h @@ -2,7 +2,7 @@ #include -#include "ComplexFunction.h" +#include "CompFunction.h" #include "mpi_utils.h" #include "trees/MultiResolutionAnalysis.h" #include @@ -10,8 +10,6 @@ // define a class for things that can be sent with MPI -template class MultiResolutionAnalysis; - using namespace Eigen; using IntVector = Eigen::VectorXi; @@ -41,22 +39,26 @@ void barrier(MPI_Comm comm); bool grand_master(); bool share_master(); -bool my_orb(int j); -bool my_orb(ComplexFunction orbj); + +bool my_func(int j); +bool my_func(const CompFunction<3> &func); +bool my_func(CompFunction<3> *func); // bool my_unique_orb(const Orbital &orb); -void free_foreign(MPI_FuncVector &Phi); +void free_foreign(CompFunctionVector &Phi); + +void send_function(const CompFunction<3> &func, int dst, int tag, MPI_Comm comm = mpi::comm_wrk); +void recv_function(CompFunction<3> &func, int src, int tag, MPI_Comm comm = mpi::comm_wrk); +void share_function(CompFunction<3> &func, int src, int tag, MPI_Comm comm); -void send_function(ComplexFunction &func, int dst, int tag, MPI_Comm comm = mpi::comm_wrk); -void recv_function(ComplexFunction &func, int src, int tag, MPI_Comm comm = mpi::comm_wrk); -void share_function(ComplexFunction &func, int src, int tag, MPI_Comm comm); +void 
reduce_function(double prec, CompFunction<3> &func, MPI_Comm comm); +void broadcast_function(CompFunction<3> &func, MPI_Comm comm); -void reduce_function(double prec, ComplexFunction &func, MPI_Comm comm); -void broadcast_function(ComplexFunction &func, MPI_Comm comm); +template void reduce_Tree_noCoeff(mrcpp::FunctionTree<3, T> &tree, MPI_Comm comm); +template void allreduce_Tree_noCoeff(mrcpp::FunctionTree<3, T> &tree, std::vector> &Phi, MPI_Comm comm); +template void broadcast_Tree_noCoeff(mrcpp::FunctionTree<3, T> &tree, MPI_Comm comm); -void reduce_Tree_noCoeff(mrcpp::FunctionTree<3> &tree, MPI_Comm comm); -void allreduce_Tree_noCoeff(mrcpp::FunctionTree<3> &tree, std::vector &Phi, MPI_Comm comm); -void broadcast_Tree_noCoeff(mrcpp::FunctionTree<3> &tree, MPI_Comm comm); +template void allreduce_Tree_noCoeff(mrcpp::FunctionTree<3, T> &tree, std::vector> &Phi, MPI_Comm comm); void allreduce_vector(IntVector &vec, MPI_Comm comm); void allreduce_vector(DoubleVector &vec, MPI_Comm comm); diff --git a/src/utils/tree_utils.cpp b/src/utils/tree_utils.cpp index 523d3e263..333544f6e 100644 --- a/src/utils/tree_utils.cpp +++ b/src/utils/tree_utils.cpp @@ -44,7 +44,7 @@ namespace mrcpp { * Calculates the threshold that has to be met in the wavelet norm in order to * guarantee the precision in the function representation. Depends on the * square norm of the function and the requested relative accuracy. */ -template bool tree_utils::split_check(const MWNode &node, double prec, double split_fac, bool abs_prec) { +template bool tree_utils::split_check(const MWNode &node, double prec, double split_fac, bool abs_prec) { bool split = false; if (prec > 0.0) { double t_norm = 1.0; @@ -66,40 +66,40 @@ template bool tree_utils::split_check(const MWNode &node, double prec /** Traverse tree along the Hilbert path and find nodes of any rankId. * Returns one nodeVector for the whole tree. GenNodes disregarded. 
*/ -template void tree_utils::make_node_table(MWTree &tree, MWNodeVector &table) { - TreeIterator it(tree, TopDown, Hilbert); +template void tree_utils::make_node_table(MWTree &tree, MWNodeVector &table) { + TreeIterator it(tree, TopDown, Hilbert); it.setReturnGenNodes(false); while (it.nextParent()) { - MWNode &node = it.getNode(); + MWNode &node = it.getNode(); if (node.getDepth() == 0) continue; table.push_back(&node); } it.init(tree); while (it.next()) { - MWNode &node = it.getNode(); + MWNode &node = it.getNode(); table.push_back(&node); } } /** Traverse tree along the Hilbert path and find nodes of any rankId. * Returns one nodeVector per scale. GenNodes disregarded. */ -template void tree_utils::make_node_table(MWTree &tree, std::vector> &table) { - TreeIterator it(tree, TopDown, Hilbert); +template void tree_utils::make_node_table(MWTree &tree, std::vector> &table) { + TreeIterator it(tree, TopDown, Hilbert); it.setReturnGenNodes(false); while (it.nextParent()) { - MWNode &node = it.getNode(); + MWNode &node = it.getNode(); if (node.getDepth() == 0) continue; int depth = node.getDepth() + tree.getNNegScales(); // Add one more element - if (depth + 1 > table.size()) table.push_back(MWNodeVector()); + if (depth + 1 > table.size()) table.push_back(MWNodeVector()); table[depth].push_back(&node); } it.init(tree); while (it.next()) { - MWNode &node = it.getNode(); + MWNode &node = it.getNode(); int depth = node.getDepth() + tree.getNNegScales(); // Add one more element - if (depth + 1 > table.size()) table.push_back(MWNodeVector()); + if (depth + 1 > table.size()) table.push_back(MWNodeVector()); table[depth].push_back(&node); } } @@ -110,7 +110,7 @@ template void tree_utils::make_node_table(MWTree &tree, std::vector void tree_utils::mw_transform(const MWTree &tree, double *coeff_in, double *coeff_out, bool readOnlyScaling, int stride, bool b_overwrite) { +template void tree_utils::mw_transform(const MWTree &tree, T *coeff_in, T *coeff_out, bool readOnlyScaling, 
int stride, bool b_overwrite) { int operation = Reconstruction; int kp1 = tree.getKp1(); int kp1_d = tree.getKp1_d(); @@ -118,8 +118,8 @@ template void tree_utils::mw_transform(const MWTree &tree, double *co int kp1_dm1 = math_utils::ipow(kp1, D - 1); const MWFilter &filter = tree.getMRA().getFilter(); double overwrite = 0.0; - double tmpcoeff[kp1_d * tDim]; - double tmpcoeff2[kp1_d * tDim]; + T tmpcoeff[kp1_d * tDim]; + T tmpcoeff2[kp1_d * tDim]; int ftlim = tDim; int ftlim2 = tDim; int ftlim3 = tDim; @@ -135,13 +135,13 @@ template void tree_utils::mw_transform(const MWTree &tree, double *co int i = 0; int mask = 1; for (int gt = 0; gt < tDim; gt++) { - double *out = tmpcoeff + gt * kp1_d; + T *out = tmpcoeff + gt * kp1_d; for (int ft = 0; ft < ftlim; ft++) { // Operate in direction i only if the bits along other // directions are identical. The bit of the direction we // operate on determines the appropriate filter/operator if ((gt | mask) == (ft | mask)) { - double *in = coeff_in + ft * kp1_d; + T *in = coeff_in + ft * kp1_d; int filter_index = 2 * ((gt >> i) & 1) + ((ft >> i) & 1); const Eigen::MatrixXd &oper = filter.getSubFilter(filter_index, operation); @@ -155,13 +155,13 @@ template void tree_utils::mw_transform(const MWTree &tree, double *co i++; mask = 2; // 1 << i; for (int gt = 0; gt < tDim; gt++) { - double *out = tmpcoeff2 + gt * kp1_d; + T *out = tmpcoeff2 + gt * kp1_d; for (int ft = 0; ft < ftlim2; ft++) { // Operate in direction i only if the bits along other // directions are identical. 
The bit of the direction we // operate on determines the appropriate filter/operator if ((gt | mask) == (ft | mask)) { - double *in = tmpcoeff + ft * kp1_d; + T *in = tmpcoeff + ft * kp1_d; int filter_index = 2 * ((gt >> i) & 1) + ((ft >> i) & 1); const Eigen::MatrixXd &oper = filter.getSubFilter(filter_index, operation); @@ -178,13 +178,13 @@ template void tree_utils::mw_transform(const MWTree &tree, double *co i++; mask = 4; // 1 << i; for (int gt = 0; gt < tDim; gt++) { - double *out = coeff_out + gt * stride; // write right into children + T *out = coeff_out + gt * stride; // write right into children for (int ft = 0; ft < ftlim3; ft++) { // Operate in direction i only if the bits along other // directions are identical. The bit of the direction we // operate on determines the appropriate filter/operator if ((gt | mask) == (ft | mask)) { - double *in = tmpcoeff2 + ft * kp1_d; + T *in = tmpcoeff2 + ft * kp1_d; int filter_index = 2 * ((gt >> i) & 1) + ((ft >> i) & 1); const Eigen::MatrixXd &oper = filter.getSubFilter(filter_index, operation); @@ -200,7 +200,7 @@ template void tree_utils::mw_transform(const MWTree &tree, double *co if (D > 3) MSG_ABORT("D>3 NOT IMPLEMENTED for S_mwtransform"); if (D < 3) { - double *out; + T *out; if (D == 1) out = tmpcoeff; if (D == 2) out = tmpcoeff2; if (b_overwrite) { @@ -216,9 +216,9 @@ template void tree_utils::mw_transform(const MWTree &tree, double *co } // Specialized for D=3 below. -template void tree_utils::mw_transform_back(MWTree &tree, double *coeff_in, double *coeff_out, int stride) { - NOT_IMPLEMENTED_ABORT; -} +// template void tree_utils::mw_transform_back(MWTree &tree, double *coeff_in, double *coeff_out, int stride) { +// NOT_IMPLEMENTED_ABORT; +//} /** Make parent from children scaling coefficients * Other node info are not used/set @@ -226,7 +226,7 @@ template void tree_utils::mw_transform_back(MWTree &tree, double *coe * The output is read directly from the 8 children scaling coefficients. 
* NB: ASSUMES that the children coefficients are separated by Children_Stride! */ -template <> void tree_utils::mw_transform_back<3>(MWTree<3> &tree, double *coeff_in, double *coeff_out, int stride) { +template void tree_utils::mw_transform_back(MWTree<3, T> &tree, T *coeff_in, T *coeff_out, int stride) { int operation = Compression; int kp1 = tree.getKp1(); int kp1_d = tree.getKp1_d(); @@ -234,7 +234,7 @@ template <> void tree_utils::mw_transform_back<3>(MWTree<3> &tree, double *coeff int kp1_dm1 = math_utils::ipow(kp1, 2); const MWFilter &filter = tree.getMRA().getFilter(); double overwrite = 0.0; - double tmpcoeff[kp1_d * tDim]; + T tmpcoeff[kp1_d * tDim]; int ftlim = tDim; int ftlim2 = tDim; @@ -243,13 +243,13 @@ template <> void tree_utils::mw_transform_back<3>(MWTree<3> &tree, double *coeff int i = 0; int mask = 1; for (int gt = 0; gt < tDim; gt++) { - double *out = coeff_out + gt * kp1_d; + T *out = coeff_out + gt * kp1_d; for (int ft = 0; ft < ftlim; ft++) { // Operate in direction i only if the bits along other // directions are identical. The bit of the direction we // operate on determines the appropriate filter/operator if ((gt | mask) == (ft | mask)) { - double *in = coeff_in + ft * stride; + T *in = coeff_in + ft * stride; int filter_index = 2 * ((gt >> i) & 1) + ((ft >> i) & 1); const Eigen::MatrixXd &oper = filter.getSubFilter(filter_index, operation); @@ -262,13 +262,13 @@ template <> void tree_utils::mw_transform_back<3>(MWTree<3> &tree, double *coeff i++; mask = 2; // 1 << i; for (int gt = 0; gt < tDim; gt++) { - double *out = tmpcoeff + gt * kp1_d; + T *out = tmpcoeff + gt * kp1_d; for (int ft = 0; ft < ftlim2; ft++) { // Operate in direction i only if the bits along other // directions are identical. 
The bit of the direction we // operate on determines the appropriate filter/operator if ((gt | mask) == (ft | mask)) { - double *in = coeff_out + ft * kp1_d; + T *in = coeff_out + ft * kp1_d; int filter_index = 2 * ((gt >> i) & 1) + ((ft >> i) & 1); const Eigen::MatrixXd &oper = filter.getSubFilter(filter_index, operation); @@ -281,14 +281,14 @@ template <> void tree_utils::mw_transform_back<3>(MWTree<3> &tree, double *coeff i++; mask = 4; // 1 << i; for (int gt = 0; gt < tDim; gt++) { - double *out = coeff_out + gt * kp1_d; - // double *out = coeff_out + gt * N_coeff; + T *out = coeff_out + gt * kp1_d; + // T *out = coeff_out + gt * N_coeff; for (int ft = 0; ft < ftlim3; ft++) { // Operate in direction i only if the bits along other // directions are identical. The bit of the direction we // operate on determines the appropriate filter/operator if ((gt | mask) == (ft | mask)) { - double *in = tmpcoeff + ft * kp1_d; + T *in = tmpcoeff + ft * kp1_d; int filter_index = 2 * ((gt >> i) & 1) + ((ft >> i) & 1); const Eigen::MatrixXd &oper = filter.getSubFilter(filter_index, operation); @@ -300,24 +300,44 @@ template <> void tree_utils::mw_transform_back<3>(MWTree<3> &tree, double *coeff } } -template bool tree_utils::split_check<1>(const MWNode<1> &node, double prec, double split_fac, bool abs_prec); -template bool tree_utils::split_check<2>(const MWNode<2> &node, double prec, double split_fac, bool abs_prec); -template bool tree_utils::split_check<3>(const MWNode<3> &node, double prec, double split_fac, bool abs_prec); +template void tree_utils::make_node_table<1, double>(MWTree<1, double> &tree, MWNodeVector<1, double> &table); +template void tree_utils::make_node_table<2, double>(MWTree<2, double> &tree, MWNodeVector<2, double> &table); +template void tree_utils::make_node_table<3, double>(MWTree<3, double> &tree, MWNodeVector<3, double> &table); + +template void tree_utils::make_node_table<1, double>(MWTree<1, double> &tree, std::vector> &table); +template void 
tree_utils::make_node_table<2, double>(MWTree<2, double> &tree, std::vector> &table); +template void tree_utils::make_node_table<3, double>(MWTree<3, double> &tree, std::vector> &table); + +template bool tree_utils::split_check<1, double>(const MWNode<1, double> &node, double prec, double split_fac, bool abs_prec); +template bool tree_utils::split_check<2, double>(const MWNode<2, double> &node, double prec, double split_fac, bool abs_prec); +template bool tree_utils::split_check<3, double>(const MWNode<3, double> &node, double prec, double split_fac, bool abs_prec); + +template void tree_utils::mw_transform<1, double>(const MWTree<1, double> &tree, double *coeff_in, double *coeff_out, bool readOnlyScaling, int stride, bool b_overwrite); +template void tree_utils::mw_transform<2, double>(const MWTree<2, double> &tree, double *coeff_in, double *coeff_out, bool readOnlyScaling, int stride, bool b_overwrite); +template void tree_utils::mw_transform<3, double>(const MWTree<3, double> &tree, double *coeff_in, double *coeff_out, bool readOnlyScaling, int stride, bool b_overwrite); + +// template void tree_utils::mw_transform_back<1, double>(MWTree<1, double> &tree, double *coeff_in, double *coeff_out, int stride); +// template void tree_utils::mw_transform_back<2, double>(MWTree<2, double> &tree, double *coeff_in, double *coeff_out, int stride); +template void tree_utils::mw_transform_back(MWTree<3, double> &tree, double *coeff_in, double *coeff_out, int stride); + +template void tree_utils::make_node_table<1, ComplexDouble>(MWTree<1, ComplexDouble> &tree, MWNodeVector<1, ComplexDouble> &table); +template void tree_utils::make_node_table<2, ComplexDouble>(MWTree<2, ComplexDouble> &tree, MWNodeVector<2, ComplexDouble> &table); +template void tree_utils::make_node_table<3, ComplexDouble>(MWTree<3, ComplexDouble> &tree, MWNodeVector<3, ComplexDouble> &table); -template void tree_utils::make_node_table<1>(MWTree<1> &tree, MWNodeVector<1> &table); -template void 
tree_utils::make_node_table<2>(MWTree<2> &tree, MWNodeVector<2> &table); -template void tree_utils::make_node_table<3>(MWTree<3> &tree, MWNodeVector<3> &table); +template void tree_utils::make_node_table<1, ComplexDouble>(MWTree<1, ComplexDouble> &tree, std::vector> &table); +template void tree_utils::make_node_table<2, ComplexDouble>(MWTree<2, ComplexDouble> &tree, std::vector> &table); +template void tree_utils::make_node_table<3, ComplexDouble>(MWTree<3, ComplexDouble> &tree, std::vector> &table); -template void tree_utils::make_node_table<1>(MWTree<1> &tree, std::vector> &table); -template void tree_utils::make_node_table<2>(MWTree<2> &tree, std::vector> &table); -template void tree_utils::make_node_table<3>(MWTree<3> &tree, std::vector> &table); +template bool tree_utils::split_check<1, ComplexDouble>(const MWNode<1, ComplexDouble> &node, double prec, double split_fac, bool abs_prec); +template bool tree_utils::split_check<2, ComplexDouble>(const MWNode<2, ComplexDouble> &node, double prec, double split_fac, bool abs_prec); +template bool tree_utils::split_check<3, ComplexDouble>(const MWNode<3, ComplexDouble> &node, double prec, double split_fac, bool abs_prec); -template void tree_utils::mw_transform<1>(const MWTree<1> &tree, double *coeff_in, double *coeff_out, bool readOnlyScaling, int stride, bool b_overwrite); -template void tree_utils::mw_transform<2>(const MWTree<2> &tree, double *coeff_in, double *coeff_out, bool readOnlyScaling, int stride, bool b_overwrite); -template void tree_utils::mw_transform<3>(const MWTree<3> &tree, double *coeff_in, double *coeff_out, bool readOnlyScaling, int stride, bool b_overwrite); +template void tree_utils::mw_transform<1, ComplexDouble>(const MWTree<1, ComplexDouble> &tree, ComplexDouble *coeff_in, ComplexDouble *coeff_out, bool readOnlyScaling, int stride, bool b_overwrite); +template void tree_utils::mw_transform<2, ComplexDouble>(const MWTree<2, ComplexDouble> &tree, ComplexDouble *coeff_in, ComplexDouble 
*coeff_out, bool readOnlyScaling, int stride, bool b_overwrite); +template void tree_utils::mw_transform<3, ComplexDouble>(const MWTree<3, ComplexDouble> &tree, ComplexDouble *coeff_in, ComplexDouble *coeff_out, bool readOnlyScaling, int stride, bool b_overwrite); -template void tree_utils::mw_transform_back<1>(MWTree<1> &tree, double *coeff_in, double *coeff_out, int stride); -template void tree_utils::mw_transform_back<2>(MWTree<2> &tree, double *coeff_in, double *coeff_out, int stride); -template void tree_utils::mw_transform_back<3>(MWTree<3> &tree, double *coeff_in, double *coeff_out, int stride); +// template void tree_utils::mw_transform_back<1, ComplexDouble>(MWTree<1, ComplexDouble &tree, ComplexDouble *coeff_in, ComplexDouble *coeff_out, int stride); +// template void tree_utils::mw_transform_back<2, ComplexDouble>(MWTree<2, ComplexDouble &tree, ComplexDouble *coeff_in, ComplexDouble *coeff_out, int stride); +template void tree_utils::mw_transform_back(MWTree<3, ComplexDouble> &tree, ComplexDouble *coeff_in, ComplexDouble *coeff_out, int stride); } // namespace mrcpp diff --git a/src/utils/tree_utils.h b/src/utils/tree_utils.h index 8f2c4220a..56c8c7d79 100644 --- a/src/utils/tree_utils.h +++ b/src/utils/tree_utils.h @@ -26,17 +26,19 @@ #pragma once #include "MRCPP/mrcpp_declarations.h" +#include "utils/math_utils.h" namespace mrcpp { namespace tree_utils { -template bool split_check(const MWNode &node, double prec, double split_fac, bool abs_prec); +template bool split_check(const MWNode &node, double prec, double split_fac, bool abs_prec); -template void make_node_table(MWTree &tree, MWNodeVector &table); -template void make_node_table(MWTree &tree, std::vector> &table); +template void make_node_table(MWTree &tree, MWNodeVector &table); +template void make_node_table(MWTree &tree, std::vector> &table); -template void mw_transform(const MWTree &tree, double *coeff_in, double *coeff_out, bool readOnlyScaling, int stride, bool overwrite = true); -template 
void mw_transform_back(MWTree &tree, double *coeff_in, double *coeff_out, int stride); +template void mw_transform(const MWTree &tree, T *coeff_in, T *coeff_out, bool readOnlyScaling, int stride, bool overwrite = true); +// template void mw_transform_back(MWTree &tree, T *coeff_in, T *coeff_out, int stride); +template void mw_transform_back(MWTree<3, T> &tree, T *coeff_in, T *coeff_out, int stride); } // namespace tree_utils } // namespace mrcpp diff --git a/tests/operators/derivative_operator.cpp b/tests/operators/derivative_operator.cpp index 5f4400d1b..b6d73fd8c 100644 --- a/tests/operators/derivative_operator.cpp +++ b/tests/operators/derivative_operator.cpp @@ -102,10 +102,10 @@ template void testDifferentiationABGV(double a, double b) { }; FunctionTree f_tree(*mra); - project(prec / 10, f_tree, f); + project(prec / 10, f_tree, f); FunctionTree df_tree(*mra); - project(prec / 10, df_tree, df); + project(prec / 10, df_tree, df); FunctionTree dg_tree(*mra); apply(dg_tree, diff, f_tree, 0); @@ -122,6 +122,48 @@ template void testDifferentiationABGV(double a, double b) { delete mra; } +/* trees are defined as complex trees */ +template void testDifferentiationCplxABGV(double a, double b) { + MultiResolutionAnalysis *mra = initializeMRA(); + + double prec = 1.0e-3; + ABGVOperator diff(*mra, a, b); + ComplexDouble s = {1.1, 1.3}; // NB: Complex + + Coord r_0; + for (auto &x : r_0) x = pi; + + auto f = [r_0, s](const Coord &r) { + double R = math_utils::calc_distance(r, r_0); + return std::exp(-R * R * s); + }; + + auto df = [r_0, s](const Coord &r) { // analytical derivative of f + double R = math_utils::calc_distance(r, r_0); + return -2.0 * s * std::exp(-R * R * s) * (r[0] - r_0[0]); + }; + + FunctionTree f_tree(*mra); + project(prec / 10, f_tree, f); + + FunctionTree df_tree(*mra); + project(prec / 10, df_tree, df); + + FunctionTree dg_tree(*mra); // MW derivative of f + apply(dg_tree, diff, f_tree, 0); + + FunctionTree err_tree(*mra); + add(-1.0, err_tree, {1.0, 
0.0}, df_tree, {-1.0, 0.0}, dg_tree); // difference between analytical and MW derivative of f. + + double df_norm = std::sqrt(df_tree.getSquareNorm()); + double abs_err = std::sqrt(err_tree.getSquareNorm()); + double rel_err = abs_err / df_norm; + + REQUIRE(rel_err == Catch::Approx(0.0).margin(prec)); + + delete mra; +} + template void testDifferentiationPH(int order) { MultiResolutionAnalysis *mra = initializeMRA(); @@ -143,10 +185,10 @@ template void testDifferentiationPH(int order) { }; FunctionTree f_tree(*mra); - project(prec / 10, f_tree, f); + project(prec / 10, f_tree, f); FunctionTree df_tree(*mra); - project(prec / 10, df_tree, df); + project(prec / 10, df_tree, df); FunctionTree dg_tree(*mra); apply(dg_tree, diff, f_tree, 0); @@ -174,7 +216,7 @@ template void testDifferentiationPeriodicABGV(double a, double b) { FunctionTree g_tree(*mra); FunctionTree dg_tree(*mra); - project(prec, g_tree, g_func); + project(prec, g_tree, g_func); apply(dg_tree, diff, g_tree, 0); refine_grid(dg_tree, 1); // for accurate evalf @@ -202,7 +244,7 @@ template void testDifferentiationPeriodicPH(int order) { FunctionTree g_tree(*mra); FunctionTree dg_tree(*mra); - project(prec, g_tree, g_func); + project(prec, g_tree, g_func); apply(dg_tree, diff, g_tree, 0); refine_grid(dg_tree, 1); // for accurate evalf @@ -237,10 +279,10 @@ template void testDifferentiationBS(int order) { }; FunctionTree f_tree(*mra); - project(prec / 10, f_tree, f); + project(prec / 10, f_tree, f); FunctionTree df_tree(*mra); - project(prec / 10, df_tree, df); + project(prec / 10, df_tree, df); FunctionTree dg_tree(*mra); apply(dg_tree, diff, f_tree, 0); @@ -259,56 +301,117 @@ template void testDifferentiationBS(int order) { TEST_CASE("ABGV differentiantion central difference", "[derivative_operator], [central_difference]") { // 0.5,0.5 specifies central difference - SECTION("1D derivative test") { testDifferentiationABGV<1>(0.5, 0.5); } - SECTION("2D derivative test") { testDifferentiationABGV<2>(0.5, 
0.5); } - SECTION("3D derivative test") { testDifferentiationABGV<3>(0.5, 0.5); } + SECTION("1D derivative test") { + testDifferentiationABGV<1>(0.5, 0.5); + } + SECTION("2D derivative test") { + testDifferentiationABGV<2>(0.5, 0.5); + } + SECTION("3D derivative test") { + testDifferentiationABGV<3>(0.5, 0.5); + } } TEST_CASE("ABGV differentiantion center difference", "[derivative_operator], [center_difference]") { // 0,0 specifies center difference - SECTION("1D derivative test") { testDifferentiationABGV<1>(0, 0); } - SECTION("2D derivative test") { testDifferentiationABGV<2>(0, 0); } - SECTION("3D derivative test") { testDifferentiationABGV<3>(0, 0); } + SECTION("1D derivative test") { + testDifferentiationABGV<1>(0, 0); + } + SECTION("2D derivative test") { + testDifferentiationABGV<2>(0, 0); + } + SECTION("3D derivative test") { + testDifferentiationABGV<3>(0, 0); + } +} + +TEST_CASE("ABGV differentiantion of Complex function", "[derivative_operator], [Complex]") { + // 0.5,0.5 specifies central difference + SECTION("1D derivative test") { + testDifferentiationCplxABGV<1>(0.5, 0.5); + } + SECTION("2D derivative test") { + testDifferentiationCplxABGV<2>(0.5, 0.5); + } + SECTION("3D derivative test") { + testDifferentiationCplxABGV<3>(0.5, 0.5); + } } TEST_CASE("PH differentiantion first order", "[derivative_operator], [PH_first_order]") { - SECTION("1D derivative test") { testDifferentiationPH<1>(1); } - SECTION("2D derivative test") { testDifferentiationPH<2>(1); } - SECTION("3D derivative test") { testDifferentiationPH<3>(1); } + SECTION("1D derivative test") { + testDifferentiationPH<1>(1); + } + SECTION("2D derivative test") { + testDifferentiationPH<2>(1); + } + SECTION("3D derivative test") { + testDifferentiationPH<3>(1); + } } TEST_CASE("PH differentiantion second order", "[derivative_operator], [PH_second_order]") { - SECTION("1D second order derivative test") { testDifferentiationPH<1>(2); } - SECTION("2D second order derivative test") { 
testDifferentiationPH<2>(2); } - SECTION("3D second order derivative test") { testDifferentiationPH<3>(2); } + SECTION("1D second order derivative test") { + testDifferentiationPH<1>(2); + } + SECTION("2D second order derivative test") { + testDifferentiationPH<2>(2); + } + SECTION("3D second order derivative test") { + testDifferentiationPH<3>(2); + } } TEST_CASE("Periodic ABGV differentiantion central difference", "[periodic_derivative],[derivative_operator], [central_difference], [ABGV_periodic]") { // 0.5,0.5 specifies central difference - SECTION("3D periodic derivative test") { testDifferentiationPeriodicABGV<3>(0.5, 0.5); } + SECTION("3D periodic derivative test") { + testDifferentiationPeriodicABGV<3>(0.5, 0.5); + } } TEST_CASE("Periodic PH differentiantion", "[periodic_derivative], [derivative_operator], [PH_periodic]") { - SECTION("3D first order periodic derivative test") { testDifferentiationPeriodicPH<3>(1); } - SECTION("3D first order periodic derivative test") { testDifferentiationPeriodicPH<3>(2); } + SECTION("3D first order periodic derivative test") { + testDifferentiationPeriodicPH<3>(1); + } + SECTION("3D first order periodic derivative test") { + testDifferentiationPeriodicPH<3>(2); + } } TEST_CASE("BS differentiantion first order", "[derivative_operator], [BS_first_order]") { - SECTION("1D derivative test") { testDifferentiationBS<1>(1); } - SECTION("2D derivative test") { testDifferentiationBS<2>(1); } - SECTION("3D derivative test") { testDifferentiationBS<3>(1); } + SECTION("1D derivative test") { + testDifferentiationBS<1>(1); + } + SECTION("2D derivative test") { + testDifferentiationBS<2>(1); + } + SECTION("3D derivative test") { + testDifferentiationBS<3>(1); + } } TEST_CASE("BS differentiantion second order", "[derivative_operator], [BS_second_order]") { - SECTION("1D derivative test") { testDifferentiationBS<1>(2); } - SECTION("2D derivative test") { testDifferentiationBS<2>(2); } - SECTION("3D derivative test") { 
testDifferentiationBS<3>(2); } + SECTION("1D derivative test") { + testDifferentiationBS<1>(2); + } + SECTION("2D derivative test") { + testDifferentiationBS<2>(2); + } + SECTION("3D derivative test") { + testDifferentiationBS<3>(2); + } } TEST_CASE("BS differentiantion third order", "[derivative_operator], [BS_third_order]") { - SECTION("1D derivative test") { testDifferentiationBS<1>(3); } - SECTION("2D derivative test") { testDifferentiationBS<2>(3); } - SECTION("3D derivative test") { testDifferentiationBS<3>(3); } + SECTION("1D derivative test") { + testDifferentiationBS<1>(3); + } + SECTION("2D derivative test") { + testDifferentiationBS<2>(3); + } + SECTION("3D derivative test") { + testDifferentiationBS<3>(3); + } } TEST_CASE("Gradient operator", "[derivative_operator], [gradient_operator]") { @@ -335,7 +438,7 @@ TEST_CASE("Gradient operator", "[derivative_operator], [gradient_operator]") { }; FunctionTree<3> f_tree(*mra); - project<3>(prec, f_tree, f); + project<3, double>(prec, f_tree, f); auto grad_f = gradient(diff, f_tree); REQUIRE(grad_f.size() == 3); @@ -373,7 +476,7 @@ TEST_CASE("Divergence operator", "[derivative_operator], [divergence_operator]") }; FunctionTree<3> f_tree(*mra); - project<3>(prec, f_tree, f); + project<3, double>(prec, f_tree, f); FunctionTreeVector<3> f_vec; f_vec.push_back(std::make_tuple(1.0, &f_tree)); f_vec.push_back(std::make_tuple(2.0, &f_tree)); diff --git a/tests/operators/heat_evolution_operator.cpp b/tests/operators/heat_evolution_operator.cpp index fb6fd7649..09ddcdcbb 100644 --- a/tests/operators/heat_evolution_operator.cpp +++ b/tests/operators/heat_evolution_operator.cpp @@ -41,7 +41,6 @@ #include "trees/BandWidth.h" #include "operators/HeatOperator.h" #include "functions/special_functions.h" -#include "treebuilders/complex_apply.h" #include "treebuilders/add.h" //using namespace mrcpp; @@ -56,7 +55,7 @@ TEST_CASE("Apply heat evolution operator", "[apply_heat_evolution], [heat_evolut const auto order = 5; const auto 
prec = 1.0e-8; - + // Time moment: double delta_t = 0.0005; @@ -67,7 +66,7 @@ TEST_CASE("Apply heat evolution operator", "[apply_heat_evolution], [heat_evolut // Time evolution operatror Exp(delta_t) mrcpp::HeatOperator<1> H(MRA, delta_t, prec); - + // Analytical solution parameters for psi(x, t) double sigma = 0.001; double x0 = 0.5; @@ -81,7 +80,7 @@ TEST_CASE("Apply heat evolution operator", "[apply_heat_evolution], [heat_evolut mrcpp::project<1>(prec, f_tree, f); mrcpp::FunctionTree<1> g_tree(MRA); mrcpp::project<1>(prec, g_tree, g); - + // Apply operator H = Exp(delta_t) f(x) mrcpp::FunctionTree<1> output(MRA); mrcpp::apply(prec, output, H, f_tree); @@ -97,4 +96,4 @@ TEST_CASE("Apply heat evolution operator", "[apply_heat_evolution], [heat_evolut } -} // namespace schrodinger_evolution_operator \ No newline at end of file +} // namespace schrodinger_evolution_operator diff --git a/tests/operators/helmholtz_operator.cpp b/tests/operators/helmholtz_operator.cpp index 7a0dc0243..8f570691d 100644 --- a/tests/operators/helmholtz_operator.cpp +++ b/tests/operators/helmholtz_operator.cpp @@ -169,14 +169,14 @@ TEST_CASE("Apply Helmholtz' operator", "[apply_helmholtz], [helmholtz_operator], return R_0 * Y_00; }; FunctionTree<3> psi_n(MRA); - project<3>(proj_prec, psi_n, hFunc); + project<3, double>(proj_prec, psi_n, hFunc); auto f = [Z](const Coord<3> &r) -> double { double x = std::sqrt(r[0] * r[0] + r[1] * r[1] + r[2] * r[2]); return -Z / x; }; FunctionTree<3> V(MRA); - project<3>(proj_prec, V, f); + project<3, double>(proj_prec, V, f); FunctionTree<3> Vpsi(MRA); copy_grid(Vpsi, psi_n); @@ -222,7 +222,7 @@ TEST_CASE("Apply Periodic Helmholtz' operator", "[apply_periodic_helmholtz], [he auto source = [mu](const mrcpp::Coord<3> &r) { return 3.0 * cos(r[0]) * cos(r[1]) * cos(r[2]) / (4.0 * pi) + mu * mu * cos(r[0]) * cos(r[1]) * cos(r[2]) / (4.0 * pi); }; FunctionTree<3> source_tree(MRA); - project<3>(proj_prec, source_tree, source); + project<3, double>(proj_prec, 
source_tree, source); FunctionTree<3> sol_tree(MRA); FunctionTree<3> in_tree(MRA); @@ -265,7 +265,7 @@ TEST_CASE("Apply negative scale Helmholtz' operator", "[apply_periodic_helmholtz auto source = [mu](const mrcpp::Coord<3> &r) { return 3.0 * cos(r[0]) * cos(r[1]) * cos(r[2]) / (4.0 * pi) + mu * mu * cos(r[0]) * cos(r[1]) * cos(r[2]) / (4.0 * pi); }; FunctionTree<3> source_tree(MRA); - project<3>(proj_prec, source_tree, source); + project<3, double>(proj_prec, source_tree, source); FunctionTree<3> sol_tree(MRA); @@ -274,4 +274,5 @@ TEST_CASE("Apply negative scale Helmholtz' operator", "[apply_periodic_helmholtz REQUIRE(sol_tree.evalf({0.0, 0.0, 0.0}) == Catch::Approx(1.0).epsilon(apply_prec)); REQUIRE(sol_tree.evalf({pi, 0.0, 0.0}) == Catch::Approx(-1.0).epsilon(apply_prec)); } + } // namespace helmholtz_operator diff --git a/tests/operators/poisson_operator.cpp b/tests/operators/poisson_operator.cpp index df841a625..23bb22a06 100644 --- a/tests/operators/poisson_operator.cpp +++ b/tests/operators/poisson_operator.cpp @@ -187,7 +187,7 @@ TEST_CASE("Apply Periodic Poisson' operator", "[apply_periodic_Poisson], [poisso auto source = [](const mrcpp::Coord<3> &r) { return 3.0 * cos(r[0]) * cos(r[1]) * cos(r[2]) / (4.0 * pi); }; FunctionTree<3> source_tree(MRA); - project<3>(proj_prec, source_tree, source); + project<3, double>(proj_prec, source_tree, source); FunctionTree<3> sol_tree(MRA); diff --git a/tests/operators/schrodinger_evolution_operator.cpp b/tests/operators/schrodinger_evolution_operator.cpp index c986ec756..e6e416f09 100644 --- a/tests/operators/schrodinger_evolution_operator.cpp +++ b/tests/operators/schrodinger_evolution_operator.cpp @@ -28,17 +28,15 @@ #include "factory_functions.h" #include "functions/GaussFunc.h" +#include "functions/special_functions.h" #include "operators/MWOperator.h" -#include "treebuilders/project.h" #include "operators/TimeEvolutionOperator.h" -#include "functions/special_functions.h" -#include "treebuilders/complex_apply.h" 
#include "treebuilders/add.h" - +#include "treebuilders/complex_apply.h" +#include "treebuilders/project.h" namespace schrodinger_evolution_operator { - TEST_CASE("Apply Schrodinger's evolution operator", "[apply_schrodinger_evolution], [schrodinger_evolution_operator], [mw_operator]") { const auto min_scale = 0; const auto max_depth = 25; @@ -46,13 +44,13 @@ TEST_CASE("Apply Schrodinger's evolution operator", "[apply_schrodinger_evolutio const auto order = 4; const auto prec = 1.0e-7; - int finest_scale = 7; //for time evolution operator construction (not recommended to use more than 10) - //int max_Jpower = 20; //the amount of J integrals to be used in construction (20 should be enough) + int finest_scale = 7; // for time evolution operator construction (not recommended to use more than 10) + // int max_Jpower = 20; //the amount of J integrals to be used in construction (20 should be enough) // Time moments: - double t1 = 0.001; //initial time moment (not recommended to use more than 0.001) - double delta_t = 0.03; //time step (not recommended to use less than 0.001) - double t2 = delta_t + t1; //final time moment + double t1 = 0.001; // initial time moment (not recommended to use more than 0.001) + double delta_t = 0.03; // time step (not recommended to use less than 0.001) + double t2 = delta_t + t1; // final time moment // Initialize world in the unit cube [0,1] auto basis = mrcpp::LegendreBasis(order); @@ -62,72 +60,59 @@ TEST_CASE("Apply Schrodinger's evolution operator", "[apply_schrodinger_evolutio // Time evolution operatror Exp(delta_t) mrcpp::TimeEvolutionOperator<1> ReExp(MRA, prec, delta_t, finest_scale, false); mrcpp::TimeEvolutionOperator<1> ImExp(MRA, prec, delta_t, finest_scale, true); - + // Analytical solution parameters for psi(x, t) double sigma = 0.001; double x0 = 0.5; // Functions f(x) = psi(x, t1) and g(x) = psi(x, t2) - auto Re_f = [sigma, x0, t=t1](const mrcpp::Coord<1> &r) -> double - { - return 
mrcpp::free_particle_analytical_solution(r[0], x0, t, sigma).real(); - }; - auto Im_f = [sigma, x0, t=t1](const mrcpp::Coord<1> &r) -> double - { - return mrcpp::free_particle_analytical_solution(r[0], x0, t, sigma).imag(); - }; - auto Re_g = [sigma, x0, t=t2](const mrcpp::Coord<1> &r) -> double - { - return mrcpp::free_particle_analytical_solution(r[0], x0, t, sigma).real(); - }; - auto Im_g = [sigma, x0, t=t2](const mrcpp::Coord<1> &r) -> double - { - return mrcpp::free_particle_analytical_solution(r[0], x0, t, sigma).imag(); - }; + auto Re_f = [sigma, x0, t = t1](const mrcpp::Coord<1> &r) -> double { return mrcpp::free_particle_analytical_solution(r[0], x0, t, sigma).real(); }; + auto Im_f = [sigma, x0, t = t1](const mrcpp::Coord<1> &r) -> double { return mrcpp::free_particle_analytical_solution(r[0], x0, t, sigma).imag(); }; + auto Re_g = [sigma, x0, t = t2](const mrcpp::Coord<1> &r) -> double { return mrcpp::free_particle_analytical_solution(r[0], x0, t, sigma).real(); }; + auto Im_g = [sigma, x0, t = t2](const mrcpp::Coord<1> &r) -> double { return mrcpp::free_particle_analytical_solution(r[0], x0, t, sigma).imag(); }; // Projecting functions mrcpp::FunctionTree<1> Re_f_tree(MRA); - mrcpp::project<1>(prec, Re_f_tree, Re_f); + mrcpp::project<1, double>(prec, Re_f_tree, Re_f); mrcpp::FunctionTree<1> Im_f_tree(MRA); - mrcpp::project<1>(prec, Im_f_tree, Im_f); + mrcpp::project<1, double>(prec, Im_f_tree, Im_f); mrcpp::FunctionTree<1> Re_g_tree(MRA); - mrcpp::project<1>(prec, Re_g_tree, Re_g); + mrcpp::project<1, double>(prec, Re_g_tree, Re_g); mrcpp::FunctionTree<1> Im_g_tree(MRA); - mrcpp::project<1>(prec, Im_g_tree, Im_g); + mrcpp::project<1, double>(prec, Im_g_tree, Im_g); // Output function trees mrcpp::FunctionTree<1> Re_fout_tree(MRA); mrcpp::FunctionTree<1> Im_fout_tree(MRA); - + // Complex objects for use in apply() - mrcpp::ComplexObject< mrcpp::ConvolutionOperator<1> > E(ReExp, ImExp); - mrcpp::ComplexObject< mrcpp::FunctionTree<1> > input(Re_f_tree, 
Im_f_tree); - mrcpp::ComplexObject< mrcpp::FunctionTree<1> > output(Re_fout_tree, Im_fout_tree); + mrcpp::ComplexObject> E(ReExp, ImExp); + mrcpp::ComplexObject> input(Re_f_tree, Im_f_tree); + mrcpp::ComplexObject> output(Re_fout_tree, Im_fout_tree); // Apply operator Exp(delta_t) f(x) mrcpp::apply(prec, output, E, input); - + // Check g(x) = Exp(delta_t) f(x) - mrcpp::FunctionTree<1> Re_error(MRA); // = Re_fout_tree - Re_g_tree - mrcpp::FunctionTree<1> Im_error(MRA); // = Im_fout_tree - Im_g_tree - + mrcpp::FunctionTree<1> Re_error(MRA); // = Re_fout_tree - Re_g_tree + mrcpp::FunctionTree<1> Im_error(MRA); // = Im_fout_tree - Im_g_tree + // Re_error = Re_fout_tree - Re_g_tree mrcpp::add(prec, Re_error, 1.0, Re_fout_tree, -1.0, Re_g_tree); - auto Re_sq_norm = Re_error.getSquareNorm(); //1.7e-16 - + auto Re_sq_norm = Re_error.getSquareNorm(); // 1.7e-16 + // Im_error = Im_fout_tree - Im_g_tree mrcpp::add(prec, Im_error, 1.0, Im_fout_tree, -1.0, Im_g_tree); - auto Im_sq_norm = Im_error.getSquareNorm(); //1.7e-17 - - double tolerance = prec * prec / 50.0; //2.0e-16 - - //std::cout << "Re_sq_norm = " << Re_sq_norm << std::endl; - //std::cout << "Im_sq_norm = " << Im_sq_norm << std::endl; - //std::cout << "tolerance = " << tolerance << std::endl; - + auto Im_sq_norm = Im_error.getSquareNorm(); // 1.7e-17 + + double tolerance = prec * prec / 50.0; // 2.0e-16 + + // std::cout << "Re_sq_norm = " << Re_sq_norm << std::endl; + // std::cout << "Im_sq_norm = " << Im_sq_norm << std::endl; + // std::cout << "tolerance = " << tolerance << std::endl; + REQUIRE(Re_sq_norm == Catch::Approx(0.0).margin(tolerance)); REQUIRE(Im_sq_norm == Catch::Approx(0.0).margin(tolerance)); } - -} // namespace schrodinger_evolution_operator \ No newline at end of file +} // namespace schrodinger_evolution_operator diff --git a/tests/treebuilders/map.cpp b/tests/treebuilders/map.cpp index 0745db2e6..6be2153f0 100644 --- a/tests/treebuilders/map.cpp +++ b/tests/treebuilders/map.cpp @@ -40,9 +40,15 @@ 
namespace mapping { template void testMapping(); SCENARIO("Map a MW tree", "[map], [tree_builder]") { - GIVEN("One MW functions in 1D") { testMapping<1>(); } - GIVEN("One MW functions in 2D") { testMapping<2>(); } - GIVEN("One MW functions in 3D") { testMapping<3>(); } + GIVEN("One MW functions in 1D") { + testMapping<1>(); + } + GIVEN("One MW functions in 2D") { + testMapping<2>(); + } + GIVEN("One MW functions in 3D") { + testMapping<3>(); + } } template void testMapping() { @@ -77,7 +83,7 @@ template void testMapping() { const double inp_int = inp_tree.integrate(); const double inp_norm = inp_tree.getSquareNorm(); - auto fmap = [](double val) { return val * val; }; + FMap fmap = [](double val) { return val * val; }; WHEN("the function is mapped") { FunctionTree out_tree(*mra); diff --git a/tests/treebuilders/multiplication.cpp b/tests/treebuilders/multiplication.cpp index 5fa2178e1..c8d35e723 100644 --- a/tests/treebuilders/multiplication.cpp +++ b/tests/treebuilders/multiplication.cpp @@ -41,9 +41,15 @@ template void testMultiplication(); template void testSquare(); SCENARIO("Multiplying MW trees", "[multiplication], [tree_builder]") { - GIVEN("Two MW functions in 1D") { testMultiplication<1>(); } - GIVEN("Two MW functions in 2D") { testMultiplication<2>(); } - GIVEN("Two MW functions in 3D") { testMultiplication<3>(); } + GIVEN("Two MW functions in 1D") { + testMultiplication<1>(); + } + GIVEN("Two MW functions in 2D") { + testMultiplication<2>(); + } + GIVEN("Two MW functions in 3D") { + testMultiplication<3>(); + } } template void testMultiplication() { @@ -116,9 +122,15 @@ template void testMultiplication() { } SCENARIO("Squaring MW trees", "[square], [tree_builder]") { - GIVEN("A MW function in 1D") { testSquare<1>(); } - GIVEN("A MW function in 2D") { testSquare<2>(); } - GIVEN("A MW function in 3D") { testSquare<3>(); } + GIVEN("A MW function in 1D") { + testSquare<1>(); + } + GIVEN("A MW function in 2D") { + testSquare<2>(); + } + GIVEN("A MW function 
in 3D") { + testSquare<3>(); + } } template void testSquare() { @@ -226,9 +238,9 @@ TEST_CASE("Dot product FunctionTreeVectors", "[multiplication], [tree_vector_dot FunctionTree<3> fy_tree(*mra); FunctionTree<3> fz_tree(*mra); - project<3>(prec, fx_tree, fx); - project<3>(prec, fy_tree, fy); - project<3>(prec, fz_tree, fz); + project<3, double>(prec, fx_tree, fx); + project<3, double>(prec, fy_tree, fy); + project<3, double>(prec, fz_tree, fz); FunctionTreeVector<3> vec_a; vec_a.push_back(std::make_tuple(1.0, &fx_tree));