diff --git a/src/core/CrossCorrelation.h b/src/core/CrossCorrelation.h index aa674b4ca..b12fa8b58 100644 --- a/src/core/CrossCorrelation.h +++ b/src/core/CrossCorrelation.h @@ -34,29 +34,134 @@ namespace mrcpp { +/** + * @class CrossCorrelation + * @brief Container/loader for multiwavelet cross-correlation coefficient tables. + * + * This class encapsulates the left/right cross-correlation matrices associated + * with a chosen multiwavelet filter family and polynomial order. + * + * • The filter "family" is identified by an integer @c type + * (e.g., Interpolatory or Legendre; concrete codes are defined elsewhere + * and validated in the implementation). + * + * • The polynomial @c order is k ≥ 1. We use K = k + 1 for dimensions. + * + * • Two dense matrices are held: + * Left ∈ ℝ^{(K·K) × (2K)}, + * Right ∈ ℝ^{(K·K) × (2K)}. + * Each row corresponds to a flattened (i,j) pair with i,j∈{0..K−1}; + * each row stores a 2K-wide correlation stencil. + * + * Objects can be constructed by loading the binary coefficient files from disk + * (constructor #1) or by adopting matrices already residing in memory + * (constructor #2). Accessors expose the type/order and const references to + * the matrices; there are no mutating public methods by design. + * + * Invariants (enforced in the implementation): + * - 1 ≤ order ≤ MaxOrder + * - Left.cols() == Right.cols() == 2K where K = order + 1 + * - Left.rows() == Right.rows() == K*K + * + * Thread-safety: the class is a simple value holder once constructed. + * Concurrent reads are safe; concurrent writes are not supported. + */ class CrossCorrelation final { public: + /** + * @brief Construct by loading coefficient tables from the filter library. + * + * The library path is discovered internally (see details::find_filters()). + * Files are chosen based on @p t and @p k and read into #Left/#Right. + * + * @param k Polynomial order (k ≥ 1). Sets K = k + 1 for dimensions. 
+ * @param t Filter family/type code (e.g., Interpol, Legendre). + * + * @throws abort/error (via MRCPP messaging) on invalid @p k/@p t or if the + * required binary files cannot be opened. + */ CrossCorrelation(int k, int t); + + /** + * @brief Construct from in-memory matrices (no file I/O). + * + * The order is inferred from the column count: 2K columns ⇒ order = K−1. + * The two matrices must be shape-compatible (same size). + * + * @param t Filter family/type code. + * @param ldata Left matrix, size (K*K) × (2K). + * @param rdata Right matrix, size (K*K) × (2K). + * + * @throws abort/error if dimensions are inconsistent or the type is invalid. + */ CrossCorrelation(int t, const Eigen::MatrixXd &ldata, const Eigen::MatrixXd &rdata); + /** @return The filter family/type code associated with this object. */ int getType() const { return this->type; } + + /** @return The polynomial order k (so K = k + 1). */ int getOrder() const { return this->order; } + + /** @return Const reference to the left cross-correlation matrix. */ const Eigen::MatrixXd &getLMatrix() const { return this->Left; } + + /** @return Const reference to the right cross-correlation matrix. */ const Eigen::MatrixXd &getRMatrix() const { return this->Right; } protected: + /** + * @brief Filter family/type code. + * + * The meaning of this integer is validated against known families + * (e.g., Interpolatory / Legendre) in the implementation. Kept as @c int + * here to avoid header dependencies on the specific enum. + */ int type; + + /** + * @brief Polynomial order k (k ≥ 1; K = k + 1). + * + * Controls the matrix dimensions: + * rows = K*K, cols = 2K. + */ int order; + /** + * @brief Left cross-correlation coefficient matrix. + * Size: (K*K) × (2K), where K = order + 1. + */ Eigen::MatrixXd Left; + + /** + * @brief Right cross-correlation coefficient matrix. + * Size: (K*K) × (2K), where K = order + 1. 
+ */ Eigen::MatrixXd Right; private: + /** + * @brief Compose on-disk file paths for the left/right tables. + * + * Uses the discovered filter library root @p lib and the current + * @c type / @c order to set #L_path and #R_path to the expected filenames. + * (Naming convention is family-specific; see implementation.) + */ void setCCCPaths(const std::string &lib); + + /** + * @brief Read the binary coefficient tables into #Left/#Right. + * + * Expects two files (left/right). Populates matrices with dimensions + * (K*K) × (2K). Very small magnitudes may be zeroed for numerical + * cleanliness (implementation detail). + */ void readCCCBin(); + /** @brief Full path to the left coefficient file (resolved at runtime). */ std::string L_path; + + /** @brief Full path to the right coefficient file (resolved at runtime). */ std::string R_path; }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/core/CrossCorrelationCache.h b/src/core/CrossCorrelationCache.h index 198d830b2..4deafb629 100644 --- a/src/core/CrossCorrelationCache.h +++ b/src/core/CrossCorrelationCache.h @@ -33,29 +33,110 @@ namespace mrcpp { +/** + * @def getCrossCorrelationCache(T, X) + * @brief Convenience macro to obtain a named reference to the singleton cache. + * + * Expands to: + * CrossCorrelationCache &X = CrossCorrelationCache::getInstance() + * + * Example: + * getCrossCorrelationCache(Interpol, ccc); + * const auto& L = ccc.getLMatrix(order); + */ #define getCrossCorrelationCache(T, X) CrossCorrelationCache &X = CrossCorrelationCache::getInstance() +/** + * @class CrossCorrelationCache + * @brief Thread-safe cache for @ref CrossCorrelation objects, keyed by order. + * + * This cache avoids repeatedly loading the (potentially large) left/right + * cross-correlation matrices from disk. One cache instance exists per filter + * family, realized as a template parameter @p T (e.g., Interpol or Legendre). 
+ * + * Design notes: + * - Singleton pattern (Meyers singleton) per @p T via getInstance(). + * - Inherits from @ref ObjectCache, which provides the + * generic cache interface (load/get/hasId etc.). + * - Actual loading and synchronization details are implemented in the + * corresponding .cpp; OpenMP locks guard first-time insertions. + * + * @tparam T Filter family tag (int constant), e.g. Interpol or Legendre. + */ template class CrossCorrelationCache final : public ObjectCache { public: + /** + * @brief Access the unique cache instance for the template family @p T. + * + * Uses a function-local static (Meyers singleton). Thread-safe in C++11+. + */ static CrossCorrelationCache &getInstance() { static CrossCorrelationCache theCrossCorrelationCache; return theCrossCorrelationCache; } + + /** + * @brief Ensure that the entry for @p order is present in the cache. + * + * If absent, constructs a new @ref CrossCorrelation(order, type) and + * inserts it. See .cpp for locking and memory accounting. + */ void load(int order) override; + + /** + * @brief Retrieve the cached @ref CrossCorrelation for @p order. + * + * Loads on demand if missing. Returns a reference owned by the cache. + */ CrossCorrelation &get(int order) override; + /** + * @brief Convenience accessor for the Left matrix of a given order. + * + * Triggers lazy load if needed, then returns a const reference. + */ const Eigen::MatrixXd &getLMatrix(int order); + + /** + * @brief Convenience accessor for the Right matrix of a given order. + * + * Triggers lazy load if needed, then returns a const reference. + */ const Eigen::MatrixXd &getRMatrix(int order); + /** + * @brief Filter family/type code associated with this cache. + * + * Set in the private constructor based on the template parameter @p T. + * (E.g., Interpol or Legendre.) + */ int getType() const { return this->type; } protected: + /** + * @brief Filter family/type code (matches template parameter @p T). 
+ */ int type; + + /** + * @brief Base path to filter/correlation library on disk. + * + * Reserved for potential use by loaders. Actual path resolution is + * currently handled inside CrossCorrelation (see details::find_filters()). + */ std::string libPath; ///< Base path to filter library + private: + /** + * @brief Private constructor enforces the singleton pattern. + * + * Initializes @ref type based on T; see .cpp for validation. + */ CrossCorrelationCache(); + + // Non-copyable / non-assignable — keeps the singleton unique. CrossCorrelationCache(CrossCorrelationCache const &ccc) = delete; CrossCorrelationCache &operator=(CrossCorrelationCache const &ccc) = delete; }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/core/FilterCache.h b/src/core/FilterCache.h index 89127ef06..80c0220b3 100644 --- a/src/core/FilterCache.h +++ b/src/core/FilterCache.h @@ -24,15 +24,30 @@ */ /* + * Overview + * -------- + * FilterCache provides a process-wide cache for multiwavelet filter banks + * (MWFilter) so that the same filter for a given polynomial order is created + * and loaded exactly once and then reused. This avoids repeated I/O and setup. * - * \breif FilterCache is a static class taking care of loading and - * unloading MultiWavelet filters, and their tensor counter parts. + * Design highlights: + * - There are different *families* of filters (e.g. Legendre vs Interpolating). + * We want caches for both, alive simultaneously. To achieve this, the + * concrete cache is a class template FilterCache, where T encodes the + * family. Each T gets its own singleton instance. * - * All data in FilterCache is static, and thus shared amongst all - * instance objects. The type of filter, Legendre or Interpolating is - * determined by a template variable so that both types of filters can - * co-exist. + * - The cache is keyed by the *order* (polynomial order k). 
Loading an entry + * constructs MWFilter(order, type) and stores it internally for reuse. * + * - Thread-safety and the actual load/get logic are implemented in the .cpp + * using OpenMP locks (MRCPP_SET_OMP_LOCK / MRCPP_UNSET_OMP_LOCK). + * + * About this header: + * - Declares a tiny abstract façade (BaseFilterCache) to allow use via a + * non-templated base pointer/reference when the family is not known at + * compile time. + * - Declares the templated FilterCache singleton with the minimal API: + * load(order), get(order), and getFilterMatrix(order). */ #pragma once @@ -45,38 +60,106 @@ namespace mrcpp { +/** + * @def getFilterCache(T, X) + * @brief Create a named reference @p X bound to the singleton FilterCache. + * + * Usage: + * getFilterCache(Interpol, cache); + * const auto& H = cache.getFilterMatrix(order); + * + * @def getLegendreFilterCache(X) + * @brief Convenience macro for FilterCache. + * + * @def getInterpolatingFilterCache(X) + * @brief Convenience macro for FilterCache. + */ #define getFilterCache(T, X) FilterCache &X = FilterCache::getInstance() #define getLegendreFilterCache(X) FilterCache &X = FilterCache::getInstance() #define getInterpolatingFilterCache(X) FilterCache &X = FilterCache::getInstance() -/** This class is an abstract base class for the various filter caches. - * It's needed in order to be able to use the actual filter caches - * without reference to the actual filter types */ +/** + * @class BaseFilterCache + * @brief Abstract façade over the templated filter cache. + * + * Rationale: + * Callers that do not know the filter family T at compile time can still + * interact with a cache through this non-templated interface. Concrete + * implementations are provided by FilterCache. + * + * Notes: + * - Inherits from ObjectCache to reuse generic cache plumbing. + * - Pure virtual methods delegate to the concrete implementation in + * FilterCache. 
+ */ class BaseFilterCache : public ObjectCache { public: + /// Ensure the filter for @p order exists in the cache (lazy load if needed). void load(int order) override = 0; + + /// Retrieve the cached MWFilter for @p order (loads it on demand). MWFilter &get(int order) override = 0; + + /// Convenience accessor: return the filter matrix (const) for @p order. virtual const Eigen::MatrixXd &getFilterMatrix(int order) = 0; }; +/** + * @class FilterCache + * @tparam T Integer tag selecting the filter family (e.g., Interpol, Legendre). + * @brief Singleton cache of MWFilter objects for a specific filter family. + * + * Key properties: + * - One singleton instance per family T (Meyers singleton via getInstance()). + * - API mirrors BaseFilterCache and ObjectCache. + * - The constructor is private to enforce the singleton pattern. + * - Copy/assignment are deleted to prevent accidental duplication. + * + * Thread-safety: + * - The .cpp guards first-time loads with OpenMP locks. + * - Reads after an entry exists are lock-free through the base cache API. + */ template class FilterCache final : public BaseFilterCache { public: + /** + * @brief Access the singleton cache for the template family T. + * + * The instance is created on first use and lives until program exit. + */ static FilterCache &getInstance() { static FilterCache theFilterCache; return theFilterCache; } + /// Ensure entry for @p order exists; loads it if missing (see .cpp). void load(int order) override; + + /// Retrieve the MWFilter for @p order; loads it if missing. MWFilter &get(int order) override; + + /// Convenience accessor returning a const reference to the filter matrix. const Eigen::MatrixXd &getFilterMatrix(int order) override; protected: + /** + * @brief Runtime family/type code corresponding to template parameter T. + * + * Initialized in the private constructor; used to construct MWFilter(order, type). + */ int type; private: + /** + * @brief Private constructor enforces the singleton pattern. 
+ * + * Sets #type based on T and performs any minimal family-specific setup. + * (Validation happens in the .cpp.) + */ FilterCache(); + + // Non-copyable and non-assignable to maintain single instance semantics. FilterCache(FilterCache const &fc) = delete; FilterCache &operator=(FilterCache const &fc) = delete; }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/core/GaussQuadrature.h b/src/core/GaussQuadrature.h index 147b4faf4..638111722 100644 --- a/src/core/GaussQuadrature.h +++ b/src/core/GaussQuadrature.h @@ -31,49 +31,177 @@ namespace mrcpp { +/** + * @brief Maximum supported Gauss-Legendre order (per sub-interval). + * + * Implementation uses Newton iterations on Legendre polynomials in double + * precision and is tuned for numerical stability up to this limit. + */ const int MaxGaussOrder = 42; + +/** + * @brief Convergence tolerance for Newton's method when locating roots. + */ static const double EPS = 3.0e-12; + +/** + * @brief Safety cap on Newton iterations per root. + */ static const int NewtonMaxIter = 10; + +/** + * @brief Hard cap for a not-yet-implemented generic N-D integrator scaffold. + * (Kept for legacy/planning; current code provides explicit 1D/2D/3D.) + */ static const int MaxQuadratureDim = 7; +/** + * @class GaussQuadrature + * @brief Composite Gauss–Legendre quadrature on [A,B] with equal sub-intervals. + * + * What it represents + * ------------------ + * A parameterized Gauss–Legendre rule over a (possibly partitioned) interval: + * - order : number of Gauss nodes per sub-interval, + * - intervals : number of equal pieces tiling [A,B], + * - roots : all nodes over [A,B] for the composite rule (size npts), + * - weights : corresponding weights (size npts). + * + * In addition, it stores the canonical (unscaled) Gauss nodes/weights on [-1,1] + * so the rule can be remapped quickly if [A,B] or 'intervals' changes. 
+ * + * Typical usage + * ------------- + * GaussQuadrature g(q, a, b, m); // q = order, [a,b] bounds, m sub-intervals + * auto val1 = g.integrate(f1D); + * auto val2 = g.integrate(f2D); // tensor-product rule (q*m in each axis) + * + * Notes + * ----- + * - “Composite” means we replicate the same order-q rule on each of the + * 'intervals' equal sub-intervals, then sum the contributions. + * - setBounds() / setIntervals() preserve the canonical [-1,1] rule and + * rebuild the scaled (roots,weights) for the new configuration. + */ class GaussQuadrature final { public: + /** + * @brief Construct a Gauss–Legendre quadrature rule. + * @param k Order (nodes per sub-interval), 0 ≤ k ≤ MaxGaussOrder. + * @param a Lower bound A (default -1). + * @param b Upper bound B (default 1). + * @param inter Number of equal sub-intervals (default 1, must be ≥ 1). + * + * Effects (see .cpp): + * - Builds canonical nodes/weights on [-1,1] via Newton’s method. + * - Scales/replicates them to fill (roots,weights) over [A,B]. + */ GaussQuadrature(int k, double a = -1.0, double b = 1.0, int inter = 1); + /** + * @name Tensor-product integration helpers + * @{ + * @brief Integrate a RepresentableFunction using the prepared rule. + * + * 1D: ∑_i w_i f(x_i) + * 2D: ∑_i ∑_j w_i w_j f(x_i, x_j) + * 3D: ∑_i ∑_j ∑_k w_i w_j w_k f(x_i, x_j, x_k) + */ double integrate(RepresentableFunction<1> &func) const; double integrate(RepresentableFunction<2> &func) const; double integrate(RepresentableFunction<3> &func) const; + /** @} */ + /** + * @brief Set the number of equal sub-intervals and rebuild (roots,weights). + * @param i New number of sub-intervals (≥ 1). + * + * Reallocates the composite #roots/#weights vectors to size npts = order * intervals and remaps + * the canonical [-1,1] rule accordingly. + */ void setIntervals(int i); + + /** + * @brief Set integration bounds [a,b] and rebuild (roots,weights). 
+ * @param a Lower bound + * @param b Upper bound (must satisfy a < b) + */ void setBounds(double a, double b); + /** @return Number of sub-intervals tiling [A,B]. */ int getIntervals() const { return this->intervals; } + /** @return Upper bound B. */ double getUpperBound() const { return this->B; } + /** @return Lower bound A. */ double getLowerBound() const { return this->A; } + /** @return Composite-rule nodes over [A,B] (size npts). */ const Eigen::VectorXd &getRoots() const { return this->roots; } + /** @return Composite-rule weights over [A,B] (size npts). */ const Eigen::VectorXd &getWeights() const { return this->weights; } + + /** @return Canonical Gauss nodes on [-1,1] (size order). */ const Eigen::VectorXd &getUnscaledRoots() const { return this->unscaledRoots; } + /** @return Canonical Gauss weights on [-1,1] (size order). */ const Eigen::VectorXd &getUnscaledWeights() const { return this->unscaledWeights; } protected: - int order; - double A; - double B; - int intervals; - int npts; - Eigen::VectorXd roots; - Eigen::VectorXd weights; - Eigen::VectorXd unscaledRoots; - Eigen::VectorXd unscaledWeights; + // ---- Parameters describing the rule ---- + int order; ///< Nodes per sub-interval (q) + double A; ///< Lower integration bound + double B; ///< Upper integration bound + int intervals; ///< Number of equal sub-intervals tiling [A,B] + int npts; ///< Total nodes = order * intervals + // ---- Scaled (composite) rule on [A,B] ---- + Eigen::VectorXd roots; ///< All nodes over [A,B] (size npts) + Eigen::VectorXd weights; ///< All weights over [A,B] (size npts) + + // ---- Canonical rule on [-1,1] ---- + Eigen::VectorXd unscaledRoots; ///< Nodes on [-1,1] (size 'order') + Eigen::VectorXd unscaledWeights; ///< Weights on [-1,1] (size 'order') + + /** + * @brief Map canonical nodes onto [a,b] replicated over @p inter blocks. + * @param rts Output vector of length inter*order. + * @param a,b Interval bounds. 
+ * @param inter Number of sub-intervals (default 1). + * + * Each block is an affine image of [-1,1] with width (b-a)/inter. + */ void rescaleRoots(Eigen::VectorXd &rts, double a, double b, int inter = 1) const; + + /** + * @brief Map canonical weights onto [a,b] replicated over @p inter blocks. + * @param wgts Output vector of length inter*order. + * @param a,b Interval bounds. + * @param inter Number of sub-intervals (default 1). + * + * Weights scale by the Jacobian of the affine mapping (factor 0.5*(b-a)/inter). + */ void rescaleWeights(Eigen::VectorXd &wgts, double a, double b, int inter = 1) const; + /** + * @brief Rebuild (roots,weights) on [A,B] for the current 'intervals'. + * + * Uses the stored canonical (unscaled) rule and performs replication. + */ void calcScaledPtsWgts(); + + /** + * @brief Compute canonical Gauss–Legendre nodes/weights on [-1,1]. + * @return 1 on success; 0 if Newton iteration failed to converge. + * + * Uses Newton’s method on Legendre polynomials with symmetric pairing: + * computes half the roots in (0,1) and reflects them about 0. + */ int calcGaussPtsWgts(); + /** + * @brief Planned recursive N-D integration (not implemented). + * @return No return; aborts at runtime if called. + */ double integrate_nd(RepresentableFunction<3> &func, int axis = 0) const; }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/core/InterpolatingBasis.h b/src/core/InterpolatingBasis.h index 12251f54f..926946179 100644 --- a/src/core/InterpolatingBasis.h +++ b/src/core/InterpolatingBasis.h @@ -33,12 +33,49 @@ namespace mrcpp { * * @brief Interpolating scaling functions as defined by Alpert etal, * J Comp Phys 182, 149-190 (2002). + * + * High-level overview + * ------------------- + * InterpolatingBasis represents the *interpolatory scaling functions* used in + * the multiwavelet framework. 
These functions are constructed so that: + * • they interpolate at Gaussian quadrature nodes (cardinal property), + * • the quadrature-induced inner product is simple/diagonal, + * • they form the scaling space for the chosen polynomial order. + * + * Relationship to the hierarchy: + * - Inherits from ScalingBasis, which provides common functionality for + * scaling-function families (orders, quadrature data, storage for basis + * polynomials, value/coefficient maps, etc.). + * - The constructor finalizes initialization by calling three private + * helpers: + * 1) initScalingBasis() — build the interpolating polynomials, + * 2) calcQuadratureValues() — fill values at quadrature nodes, + * 3) calcCVMaps() — build coefficient↔value diagonal maps. + * + * Mathematical context (very short): + * - Follows the construction in Alpert (2002) for interpolatory multiwavelets, + * where basis functions {I_k} satisfy I_k(x_j) = δ_{k,j} at quadrature nodes + * {x_j}. This makes projection/evaluation particularly efficient. */ class InterpolatingBasis final : public ScalingBasis { public: /** @returns New InterpolatingBasis object * @param[in] k: Polynomial order of basis, `1 < k < 40` + * + * What happens in the constructor: + * - Calls the ScalingBasis base constructor with (k, Interpol), which + * sets the family/type to “Interpolating”. + * - initScalingBasis(): constructs the set of interpolating polynomials + * (stored in the base's internal container, typically `funcs`). + * - calcQuadratureValues(): sets the basis evaluation matrix at nodes to + * the identity (cardinality property). + * - calcCVMaps(): builds diagonal conversion maps between coefficient + * vectors and values at quadrature nodes using the quadrature weights. + * + * Precondition: + * - k must be within the supported range of the library (checked by the + * base class). Typical limits are 1 < k < 40 as noted here. 
*/ InterpolatingBasis(int k) : ScalingBasis(k, Interpol) { @@ -48,9 +85,34 @@ class InterpolatingBasis final : public ScalingBasis { } private: + /** + * @brief Construct the interpolatory scaling polynomials {I_k}. + * + * Implementation details (in .cpp): + * - Uses Gaussian quadrature roots/weights of order q. + * - Expands I_k in a Legendre polynomial basis and enforces I_k(x_j)=δ_{kj}. + * - Applies sqrt(weight) normalization so that the induced inner product + * is diagonal and the cv/vc maps become simple scalings. + */ void initScalingBasis(); + + /** + * @brief Fill the basis-at-nodes matrix. + * + * For an interpolating basis, evaluating the k-th basis at node j yields + * δ_{kj}. The implementation sets the diagonal entries to 1 (identity). + */ void calcQuadratureValues(); + + /** + * @brief Build coefficient↔value diagonal maps using quadrature weights. + * + * - cvMap(k,k) = sqrt(1 / w_k) (coefficients → values at nodes) + * - vcMap(k,k) = sqrt(w_k) (values at nodes → coefficients) + * + * These maps are exact inverses due to the chosen normalization. + */ void calcCVMaps(); }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/core/LegendreBasis.h b/src/core/LegendreBasis.h index c26aa4fed..75554ce21 100644 --- a/src/core/LegendreBasis.h +++ b/src/core/LegendreBasis.h @@ -35,12 +35,54 @@ namespace mrcpp { * * @brief Legendre scaling functions as defined by Alpert, * SIAM J Math Anal 24 (1), 246 (1993). + * + * High-level overview + * ------------------- + * LegendreBasis represents the *Legendre scaling functions* used as a scaling + * space in the multiwavelet framework. In contrast to an *interpolating* basis, + * here the basis functions are (shifted/scaled) Legendre polynomials with + * exact L² normalization. This choice leads to dense coefficient↔value maps + * (built from evaluations at quadrature nodes), but offers orthogonality and + * well-understood approximation properties. 
+ * + * Relationship to the class hierarchy + * ----------------------------------- + * - Inherits from @ref ScalingBasis, which provides: + * • storage for basis polynomials (e.g. `funcs`), + * • quadrature order and data, + * • matrices for basis evaluated at quadrature nodes (`quadVals`), + * • conversion maps between coefficient and nodal value spaces + * (`cvMap` and `vcMap`). + * + * What the constructor does + * ------------------------- + * The constructor takes the polynomial order `k` (with typical bounds 1 < k < 40) + * and: + * 1) calls the base `ScalingBasis(k, Legendre)` to set the family/tag, + * 2) `initScalingBasis()` to build the list of normalized Legendre polynomials + * up to degree `k`, + * 3) `calcQuadratureValues()` to evaluate the basis at quadrature nodes, + * 4) `calcCVMaps()` to assemble value→coefficient (`vcMap`) using quadrature + * weights and then compute coefficient→value (`cvMap`) as its inverse. + * + * Notes + * ----- + * - The actual construction details are implemented in the corresponding .cpp: + * • `initScalingBasis()` multiplies P_k by √(2k+1) for exact normalization. + * • `calcQuadratureValues()` fills `quadVals(i,k) = P_k(x_i)`. + * • `calcCVMaps()` sets `vcMap(i,k) = P_k(x_i) * w_i` and inverts it. */ class LegendreBasis final : public ScalingBasis { public: /** @returns New LegendreBasis object * @param[in] k: Polynomial order of basis, `1 < k < 40` + * + * Construction sequence: + * - `ScalingBasis(k, Legendre)` tags this as a Legendre-family scaling basis. + * - `initScalingBasis()` builds normalized Legendre polynomials {P_0..P_k}. + * - `calcQuadratureValues()` evaluates the basis at Gaussian nodes. + * - `calcCVMaps()` creates value↔coefficient maps using quadrature weights. */ LegendreBasis(int k) : ScalingBasis(k, Legendre) { @@ -50,9 +92,12 @@ class LegendreBasis final : public ScalingBasis { } private: + /** @brief Build and store the normalized Legendre polynomials up to degree k. 
*/ void initScalingBasis(); + /** @brief Fill the matrix of basis values at quadrature nodes. */ void calcQuadratureValues(); + /** @brief Assemble value→coefficient map and its inverse (coeff→value). */ void calcCVMaps(); }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/core/MWFilter.h b/src/core/MWFilter.h index 8daf43c54..099fceb4f 100644 --- a/src/core/MWFilter.h +++ b/src/core/MWFilter.h @@ -34,29 +34,144 @@ namespace mrcpp { +/** + * @class MWFilter + * @brief Container for a 2K×2K multiwavelet filter bank and its block views. + * + * High-level + * ---------- + * An MWFilter represents the matrix of a 1D multiwavelet transform for a given + * polynomial order and family (type). With K = order + 1, the full transform + * matrix has size 2K × 2K and is organized into four K × K blocks: + * + * filter = [ G0 G1 ] (top row: scaling channel) + * [ H0 H1 ] (bottom row: wavelet channel) + * + * In the implementation (.cpp), G0/H0 are loaded from binary tables and + * G1/H1 are derived by family-specific symmetry relations. Transposes of the + * four blocks are also precomputed for the compression direction. + * + * Usage model + * ----------- + * - Construct from (order, type) → loads data from disk and builds blocks. + * - Construct from a given 2K×2K matrix → slices into blocks (no I/O). + * - Multiply vectors/matrices with the transform or its transpose using + * apply()/applyInverse(). + * - Query individual K×K subfilters for compression or reconstruction. + * + * Notes on 'type' + * --------------- + * 'type' identifies the filter family (e.g., Interpol or Legendre). The exact + * integer codes are defined elsewhere in MRCPP and validated in the .cpp. + * + * Dimension conventions + * --------------------- + * - order = k, K = k + 1 + * - Full transform: 2K × 2K (acts on 2K-length vectors / 2K-row matrices). 
+ */ class MWFilter final { public: + /** + * @brief Construct from order and family type; loads blocks from disk. + * @param k Polynomial order (k ≥ 0; with library-defined upper bound). + * @param t Filter family/type tag (e.g., Interpol or Legendre). + * + * Side effects (see .cpp): + * - Locates binary tables on disk (family+order dependent). + * - Reads G0 and H0, synthesizes G1 and H1 by symmetry. + * - Assembles the full 2K×2K matrix 'filter'. + */ MWFilter(int k, int t); + + /** + * @brief Construct directly from a full 2K×2K matrix (no I/O). + * @param t Filter family/type tag. + * @param data Full transform matrix of size 2K×2K. + * + * The order is inferred as K = data.cols()/2, order = K - 1. + * The four K×K blocks (and their transposes) are sliced from @p data. + */ MWFilter(int t, const Eigen::MatrixXd &data); + /** + * @name Apply the transform / its transpose + * @{ + * + * @brief Apply the forward/reconstruction transform: data ← filter * data. + * Overloads exist for Eigen::MatrixXd and Eigen::VectorXd. + * + * @brief Apply the inverse/compression transform: data ← filter^T * data. + * Overloads exist for Eigen::MatrixXd and Eigen::VectorXd. + * + * Precondition: + * - data.rows() must equal filter.cols() (i.e., 2K). + */ void apply(Eigen::MatrixXd &data) const; void apply(Eigen::VectorXd &data) const; void applyInverse(Eigen::MatrixXd &data) const; void applyInverse(Eigen::VectorXd &data) const; + /** @} */ + /** @return Polynomial order k (so K = k + 1). */ int getOrder() const { return this->order; } + + /** @return Filter family/type code. */ int getType() const { return this->type; } + /** @return Const reference to the full 2K×2K transform matrix. */ const Eigen::MatrixXd &getFilter() const { return this->filter; } + + /** + * @brief Return one of the four K×K subfilters. + * @param i Block index in the chosen operation's order (0..3). + * @param oper Operation selector (direction), defaults to 0. 
+ * + * Semantics (see .cpp): + * - For Reconstruction: blocks returned in order (H0, G0, H1, G1). + * - For Compression: transposed blocks (H0^T, H1^T, G0^T, G1^T). + * + * The actual enum/integer values for 'oper' (e.g., Reconstruction/Compression) + * are defined elsewhere (constants header). This method aborts on invalid + * @p i or @p oper. + */ const Eigen::MatrixXd &getSubFilter(int i, int oper = 0) const; + + /** + * @brief Shorthand: return the i-th compression subfilter (transposed form). + * Order: i=0→H0^T, 1→H1^T, 2→G0^T, 3→G1^T. + */ const Eigen::MatrixXd &getCompressionSubFilter(int i) const; + + /** + * @brief Shorthand: return the i-th reconstruction subfilter (direct form). + * Order: i=0→H0, 1→G0, 2→H1, 3→G1. + */ const Eigen::MatrixXd &getReconstructionSubFilter(int i) const; protected: + /** + * @brief Filter family/type tag (e.g., Interpol, Legendre). + */ int type; + + /** + * @brief Polynomial order k (K = k + 1). + */ int order; + + /** + * @brief Auxiliary dimension (reserved; may be unused in current code). + */ int dim; + /** + * @name Stored matrices + * @{ + * @brief Full transform and its K×K sub-blocks (+ transposes). + * + * Layout: filter = [ G0 G1 ] + * [ H0 H1 ] + */ Eigen::MatrixXd filter; ///< Full MW-transformation matrix Eigen::MatrixXd G0; Eigen::MatrixXd G1; @@ -67,14 +182,34 @@ class MWFilter final { Eigen::MatrixXd G1t; Eigen::MatrixXd H0t; Eigen::MatrixXd H1t; + /** @} */ private: + /** + * @brief Compose on-disk paths to H0/G0 tables (family- and order-specific). + * + * Implemented in the .cpp; uses a discovered filter library root. + * Sets #H_path and #G_path accordingly. + */ void setFilterPaths(const std::string &lib); + + /** + * @brief Slice #filter into sub-blocks and compute their transposes. + * + * Used in the constructor that takes a full matrix. + */ void fillFilterBlocks(); + + /** + * @brief Load H0/G0 from disk, synthesize H1/G1 by symmetry, and + * precompute transposes. 
Populates #G0,#G1,#H0,#H1 and their ^T. + */ void generateBlocks(); + /** @brief Absolute file path to the H0 table. */ std::string H_path; + /** @brief Absolute file path to the G0 table. */ std::string G_path; }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/core/ObjectCache.h b/src/core/ObjectCache.h index 2323ae681..bdac0fe50 100644 --- a/src/core/ObjectCache.h +++ b/src/core/ObjectCache.h @@ -32,46 +32,152 @@ namespace mrcpp { +/** + * @def getObjectCache(T, X) + * @brief Convenience macro to bind a local reference @p X to the singleton + * instance of ObjectCache. + * + * Expands to: + * ObjectCache &X = ObjectCache::getInstance(); + * + * Example: + * getObjectCache(MyType, cache); + * if (!cache.hasId(id)) cache.load(id); + * MyType& obj = cache.get(id); + */ #define getObjectCache(T, X) ObjectCache &X = ObjectCache::getInstance(); +/** + * @class ObjectCache + * @tparam T The object type to be cached (owned via raw pointer). + * + * @brief A lightweight, index-addressed cache with singleton access, + * optional OpenMP locking, and simple memory accounting. + * + * High-level + * ---------- + * - Stores pointers to objects of type T in a sparse, integer-indexed array. + * - One global instance per T (Meyers singleton via getInstance()). + * - Provides virtual hooks `load(id)`, `unload(id)`, `get(id)` for derived + * caches to specialize on-demand construction and retrieval. + * - Tracks approximate memory usage per entry and in total. + * + * Thread-safety + * ------------- + * - The base class initializes an OpenMP lock (if MRCPP_HAS_OMP), and its + * destructor clears under that lock. However, *load/get/unload* here are not + * automatically locked; derived classes are expected to guard first-time + * construction (see FilterCache, ScalingCache, etc.). + * + * Ownership & lifetime + * -------------------- + * - The cache owns stored objects: `unload(id)` deletes them. + * - `clear()` unloads all present entries. 
+ * - Copy/assignment are deleted to enforce singleton semantics. + * + * Indexing model + * -------------- + * - `highWaterMark` tracks the largest index ever seen. + * - Vectors `objs` and `mem` grow to accommodate new ids; gaps are filled with + * `nullptr` and `0`. Presence is tested with `hasId(id)`. + * + * Memory accounting + * ----------------- + * - `mem[id]` holds an approximate byte size for entry `id` (provided by the + * caller when inserting via `load(id, T*, memory)`). + * - `memLoaded` sums the sizes of currently loaded entries. + */ template class ObjectCache { public: + /** @brief Singleton accessor (one cache per T). */ static ObjectCache &getInstance(); + /** @brief Unload and delete all loaded objects. */ virtual void clear(); + /** + * @brief On-demand loader hook. Default impl is a no-op; derived caches + * should override to construct and insert the object for @p id. + */ virtual void load(int id); + + /** + * @brief Insert an already-constructed object pointer at index @p id. + * @param id Integer key. + * @param new_o Ownership-transferred pointer to T. + * @param memory Approximate size in bytes (for accounting). + * + * Expands internal storage if needed. If an object is already present + * at @p id, this is a no-op. + */ void load(int id, T *new_o, int memory); + + /** + * @brief Remove and delete the object at @p id (if present). + * Updates memory accounting. Virtual to allow specialization. + */ virtual void unload(int id); + /** + * @brief Retrieve a reference to the loaded object at @p id. + * Emits errors if @p id is invalid or if no object is loaded. + */ virtual T &get(int id); + + /** + * @brief Check whether an object is present at @p id. + * @return true if id ≤ highWaterMark and objs[id] != nullptr. + */ bool hasId(int id); + /** @return Number of slots allocated (including empty/null slots). */ int getNObjs() { return this->objs.size(); } + /** @return Total accounted memory over loaded entries. 
*/ int getMem() { return this->memLoaded; } + /** @return Accounted memory for a specific @p id (0 if empty). */ int getMem(int id) { return this->mem[id]; } protected: + /** + * @brief Protected ctor initializes slot 0, memory 0, and OMP lock. + * + * Slot 0 is reserved/initialized so that valid ids can start at 1 if + * desired, but the cache also happily accepts id=0. + */ ObjectCache() { this->objs.push_back(nullptr); this->mem.push_back(0); MRCPP_INIT_OMP_LOCK(); } + /** + * @brief Destructor clears the cache under lock and destroys the lock. + * + * Ensures that concurrent threads do not race during teardown. + */ virtual ~ObjectCache() { MRCPP_SET_OMP_LOCK(); clear(); MRCPP_UNSET_OMP_LOCK(); MRCPP_DESTROY_OMP_LOCK(); } + + // Non-copyable singleton. ObjectCache(ObjectCache const &oc) = delete; ObjectCache &operator=(ObjectCache const &oc) = delete; + #ifdef MRCPP_HAS_OMP + /** @brief OpenMP lock for derived-class synchronized sections. */ omp_lock_t omp_lock; #endif + private: + /** @brief Largest index ever used (inclusive). */ int highWaterMark{0}; + /** @brief Sum of accounted memory over loaded entries. */ int memLoaded{0}; ///< memory occupied by loaded objects + /** @brief Sparse vector of owned pointers; nullptr denotes empty slot. */ std::vector objs; ///< objects store + /** @brief Per-slot memory accounting (0 if empty). */ std::vector mem; ///< mem per object }; diff --git a/src/core/QuadratureCache.h b/src/core/QuadratureCache.h index 5e7f11feb..f0a6abaf8 100644 --- a/src/core/QuadratureCache.h +++ b/src/core/QuadratureCache.h @@ -32,38 +32,157 @@ namespace mrcpp { +/** + * @def getQuadratureCache(X) + * @brief Convenience macro to bind a local reference @p X to the global + * (singleton) QuadratureCache instance. 
+ * + * Expands to: + * QuadratureCache &X = QuadratureCache::getInstance() + * + * Example: + * getQuadratureCache(qc); + * const auto& w = qc.getWeights(order); + */ #define getQuadratureCache(X) QuadratureCache &X = QuadratureCache::getInstance() +/** + * @class QuadratureCache + * @brief Process-wide cache for Gaussian quadrature rules (roots & weights). + * + * High-level + * ---------- + * Gaussian quadrature (Gauss-Legendre in MRCPP) is parameterized by: + * • order (number of nodes/weights), + * • integration domain [A, B], + * • optional replication over multiple equal sub-intervals ("intervals"). + * + * Constructing GaussQuadrature objects repeatedly can be costly; this cache + * stores one instance per order (and current domain/interval settings) and + * hands out references on demand. + * + * Design + * ------ + * - Singleton per process (Meyers' singleton via getInstance()). + * - Inherits from ObjectCache which provides basic + * load/unload/get plumbing indexed by an integer id (here: order). + * - Domain control: + * setBounds(a,b) → set global integration bounds [A,B] + * setIntervals(i) → split [A,B] into @p i equal sub-intervals (if used) + * These settings influence how GaussQuadrature is created in load(order). + * + * Thread-safety + * ------------- + * The base ObjectCache does not synchronize by itself; specialized caches + * typically guard first-time loads with OpenMP locks in the .cpp. Users should + * assume the cache is safe to read concurrently after an entry is present. + * + * Typical usage + * ------------- + * auto& qc = QuadratureCache::getInstance(); + * qc.setBounds(-1.0, 1.0); + * const Eigen::VectorXd& x = qc.getRoots(quad_order); + * const Eigen::VectorXd& w = qc.getWeights(quad_order); + */ class QuadratureCache final : public ObjectCache { public: + /** + * @brief Access the singleton instance. 
+ */ static QuadratureCache &getInstance() { static QuadratureCache theQuadratureCache; return theQuadratureCache; } + /** + * @brief Ensure the quadrature of a given @p order is loaded. + * + * Implemented in the .cpp: constructs/initializes a GaussQuadrature that + * reflects the current @ref A, @ref B, and @ref intervals settings and + * inserts it into the underlying ObjectCache if absent. + */ void load(int order); + + /** + * @brief Retrieve the cached quadrature for @p order (lazy-loads if needed). + * @return Reference to the GaussQuadrature object owned by the cache. + */ GaussQuadrature &get(int order); + /** + * @name Convenience accessors (fetch vectors directly) + * @{ + * @brief Get the vector of abscissas (roots) for a given order. + */ const Eigen::VectorXd &getRoots(int i) { return get(i).getRoots(); } + + /** + * @brief Get the vector of weights for a given order. + */ const Eigen::VectorXd &getWeights(int i) { return get(i).getWeights(); } + /** @} */ + /** + * @brief Set the number of equal sub-intervals for composite quadrature. + * + * Interpretation: + * - If intervals > 1, the base interval [A,B] can be partitioned into + * `intervals` equal pieces and the quadrature replicated/shifted. + * - Exact semantics depend on GaussQuadrature; this cache records the + * value so that new loads honor it. + */ void setIntervals(int i); + + /** + * @brief Set the integration bounds used by subsequently loaded rules. + * @param a Lower bound A + * @param b Upper bound B + * + * Newly created GaussQuadrature objects will target [A,B]. Existing + * cached entries are unaffected until explicitly unloaded/reloaded. + */ void setBounds(double a, double b); + /** @return Current number of sub-intervals recorded in the cache. */ int getIntervals() const { return this->intervals; } + + /** @return Current upper integration bound B. */ double getUpperBound() const { return this->B; } + + /** @return Current lower integration bound A. 
*/ double getLowerBound() const { return this->A; } private: + /** + * @brief Lower and upper bounds of the integration domain. + * + * Defaults are set in the private constructor (see .cpp). Changing these + * affects only future loads; existing cached rules remain as created. + */ double A; double B; + + /** + * @brief Number of equal sub-intervals used to tile [A,B]. + * + * When >1, the cache can generate composite quadrature by replicating the + * base rule on each sub-interval (implementation in .cpp / GaussQuadrature). + */ int intervals; + /** + * @brief Private constructor initializes default bounds/intervals. + * + * Enforces the singleton pattern; use getInstance() to access the cache. + */ QuadratureCache(); + + /// Private destructor; cache cleans up its owned objects via ObjectCache. ~QuadratureCache(); + // Non-copyable / non-assignable to maintain singleton semantics. QuadratureCache(QuadratureCache const &qc) = delete; QuadratureCache &operator=(QuadratureCache const &qc) = delete; }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/core/ScalingBasis.h b/src/core/ScalingBasis.h index 2ea48fba0..c44942877 100644 --- a/src/core/ScalingBasis.h +++ b/src/core/ScalingBasis.h @@ -33,37 +33,129 @@ namespace mrcpp { +/** + * @class ScalingBasis + * @brief Abstract base for scaling-function families (Legendre, Interpolating). + * + * What this class represents + * -------------------------- + * A *scaling basis* is a finite set of 1D polynomials {φ_k}_{k=0..order} + * that span the scaling space at level 0 for a given multiwavelet family. + * Concrete families (e.g., LegendreBasis, InterpolatingBasis) derive from + * this class and: + * • construct and store the polynomials in `funcs`, + * • populate the evaluation matrix at quadrature nodes `quadVals`, + * • build coefficient↔value conversion maps `cvMap` / `vcMap`. 
+ * + * Dimensions and conventions + * -------------------------- + * - order := polynomial degree cutoff (≥ 0). + * - Quadrature order q = order + 1 (one node per basis function). + * - `quadVals` is q×q with layout: rows = nodes, cols = basis index. + * - `cvMap` maps coefficient vectors → nodal values (Forward). + * - `vcMap` maps nodal values → coefficient vectors (Backward). + * + * Responsibilities provided here + * ------------------------------ + * - Store family `type` (Legendre or Interpol, defined in constants.h) and `order`. + * - Provide access to basis polynomials and to the conversion matrices. + * - Offer a generic evaluator to sample the basis at arbitrary points. + * - Define equality operators (same family and order). + * + * Notes for implementers of derived classes + * ----------------------------------------- + * - Call the base ctor with (k, t). It sizes `quadVals`, `cvMap`, `vcMap` + * to q×q zeros; you must fill them in your implementation (.cpp). + * - Push back exactly q polynomials into `funcs` in the order k = 0..order. + */ class ScalingBasis { public: + /** + * @brief Construct a base scaling space descriptor. + * @param k Polynomial order (k ≥ 0). + * @param t Family tag (e.g., Legendre or Interpol). + * + * Effects (implemented in the .cpp): + * - Stores @p t, @p k. + * - Allocates q×q zero matrices for `quadVals`, `cvMap`, and `vcMap`, + * where q = k + 1. + * - Derived classes then fill these structures. + */ ScalingBasis(int k, int t); virtual ~ScalingBasis() = default; + /** + * @brief Evaluate all basis polynomials at D sample points. + * @param r Pointer to array of D abscissas. + * @param vals Output matrix of size (q × D) with + * vals(k, d) = φ_k( r[d] ), k = 0..q-1. + * + * Precondition: + * - vals.rows() == funcs.size() == q. + * + * Remarks: + * - Column-major Eigen storage is irrelevant here; we just fill entries. + * - Useful for projecting/evaluating on arbitrary nodes (not only quadrature). 
+ */ void evalf(const double *r, Eigen::MatrixXd &vals) const; + /** @return Mutable reference to the k-th basis polynomial φ_k. */ Polynomial &getFunc(int k) { return this->funcs[k]; } + /** @return Const reference to the k-th basis polynomial φ_k. */ const Polynomial &getFunc(int k) const { return this->funcs[k]; } + /** @return The type of scaling basis (Legendre or Interpol; see MRCPP/constants.h) */ int getScalingType() const { return this->type; } + /** @return Polynomial order k. */ int getScalingOrder() const { return this->order; } + /** @return Quadrature order q = k + 1 (one node per basis function). */ int getQuadratureOrder() const { return this->order + 1; } + /** @return Matrix of basis values at quadrature nodes (q × q). */ const Eigen::MatrixXd &getQuadratureValues() const { return this->quadVals; } + + /** + * @brief Access the coefficient/value conversion map. + * @param operation Use `Forward` (from constants.h) for coeff→value, + * anything else selects value→coeff. + * @return const reference to `cvMap` (Forward) or `vcMap` (Backward). + */ const Eigen::MatrixXd &getCVMap(int operation) const; + /** @brief Equality iff same family type and polynomial order. */ bool operator==(const ScalingBasis &basis) const; + /** @brief Inequality iff family type or polynomial order differs. */ bool operator!=(const ScalingBasis &basis) const; + /** + * @brief Stream print helper (delegates to the non-virtual print()). + * Prints order and a human-readable family name. + */ friend std::ostream &operator<<(std::ostream &o, const ScalingBasis &bas) { return bas.print(o); } protected: + /** @brief Family tag (Legendre or Interpol). */ const int type; + /** @brief Polynomial order k. */ const int order; + + /** @brief Basis values at quadrature points: quadVals(i,k) = φ_k(x_i). */ Eigen::MatrixXd quadVals; // function values at quadrature pts + + /** @brief Coefficient → value (at nodes) linear map (q × q). 
*/ Eigen::MatrixXd cvMap; // coef-value transformation matrix + + /** @brief Value (at nodes) → coefficient linear map (q × q). */ Eigen::MatrixXd vcMap; // value-coef transformation matrix + + /** @brief List of basis polynomials φ_0..φ_k (size q). */ std::vector funcs; + /** + * @brief Pretty-printer called by operator<<. + * Declared non-virtual, so derived bases cannot override it polymorphically. + */ std::ostream &print(std::ostream &o) const; }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/core/ScalingCache.h b/src/core/ScalingCache.h index 6ea795b6e..03efc734e 100644 --- a/src/core/ScalingCache.h +++ b/src/core/ScalingCache.h @@ -29,16 +29,76 @@ namespace mrcpp { +/** + * @def getLegendreScalingCache(X) + * @brief Convenience macro to bind a local reference @p X to the singleton + * ScalingCache specialized for LegendreBasis. + * + * Usage: + * getLegendreScalingCache(cache); + * auto& B = cache.get(order); + */ #define getLegendreScalingCache(X) ScalingCache &X = ScalingCache::getInstance() + +/** + * @def getInterpolatingScalingCache(X) + * @brief Convenience macro to bind a local reference @p X to the singleton + * ScalingCache specialized for InterpolatingBasis. + * + * Usage: + * getInterpolatingScalingCache(cache); + * auto& B = cache.get(order); + */ #define getInterpolatingScalingCache(X) \ ScalingCache &X = ScalingCache::getInstance() +/** + * @class ScalingCache + * @tparam P A concrete scaling-basis type (e.g., LegendreBasis, InterpolatingBasis). + * @brief Thread-safe singleton cache for scaling bases keyed by polynomial order. + * + * Motivation + * ---------- + * Constructing a scaling basis of order `k` (which internally prepares + * polynomials, quadrature-derived maps, etc.) can be relatively expensive. + * This cache guarantees that for a given template parameter P (basis family) + * and a given order, exactly one instance is created and then reused. 
+ * + * Design + * ------ + * - Inherits from @ref ObjectCache<P>, which provides sparse indexed storage, + * memory accounting, and basic get/load/unload primitives. + * - Singleton per `P` (Meyers singleton via getInstance()) so that all parts + * of the program share the same cache for the same basis family. + * - Thread-safety: the first-time construction/insertion is protected by + * MRCPP_SET_OMP_LOCK / MRCPP_UNSET_OMP_LOCK. Reads after presence are fast. + * + * Memory accounting + * ----------------- + * The `memo` value passed to ObjectCache is a *rough* byte estimate: + * memo ≈ 2 * (order+1)^2 * sizeof(double) + * The constant factor “2” approximates two q×q matrices stored by a basis + * (e.g., cvMap and vcMap), where q = order + 1. It is not an exact footprint, + * but suffices for simple bookkeeping. + */ template class ScalingCache final : public ObjectCache<P> { public: + /** + * @brief Access the singleton instance for the template parameter P. + * + * One instance per concrete basis family exists process-wide. + */ static ScalingCache &getInstance() { static ScalingCache theScalingCache; return theScalingCache; } + + /** + * @brief Ensure the basis of a given @p order is present in the cache. + * + * If absent, constructs a new P(order) under an OpenMP lock and inserts it + * into the underlying ObjectCache with a rough memory estimate. + */ void load(int order) { MRCPP_SET_OMP_LOCK(); if (not this->hasId(order)) { @@ -49,15 +109,21 @@ template class ScalingCache final : public ObjectCache<P> { MRCPP_UNSET_OMP_LOCK(); } + /** + * @brief Retrieve the cached basis of a given @p order (lazy-loads if needed). + * @return Reference to the basis object owned by the cache. + */ P &get(int order) { if (not this->hasId(order)) { load(order); } return ObjectCache<P>::get(order); } private: + /// Private constructor enforces the singleton pattern. ScalingCache() {} + // Non-copyable / non-assignable. ScalingCache(const ScalingCache<P> &sc) = delete; ScalingCache<P> &operator=(const ScalingCache<P>
&sc) = delete; }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/functions/AnalyticFunction.h b/src/functions/AnalyticFunction.h index aca20285b..938718d72 100644 --- a/src/functions/AnalyticFunction.h +++ b/src/functions/AnalyticFunction.h @@ -32,27 +32,69 @@ namespace mrcpp { -template class AnalyticFunction : public RepresentableFunction { +/** + * @class AnalyticFunction + * @tparam D Spatial dimension (1, 2, or 3) + * @tparam T Coefficient type (e.g. double, ComplexDouble) + * + * @brief Implementation of @ref RepresentableFunction for the datatype double + */ +template +class AnalyticFunction : public RepresentableFunction { public: + /** @brief Default constructor; leaves the callable empty */ AnalyticFunction() = default; + + /** @brief Virtual destructor to match the base class interface */ ~AnalyticFunction() override = default; - AnalyticFunction(std::function &r)> f, const double *a = nullptr, const double *b = nullptr) + /** + * @brief Constructor with raw pointers for the bounds + * + * @param f The analytic function which is evaluated in this class + * @param a Optional raw pointer to an array of D lower bounds (can be nullptr) + * @param b Optional raw pointer to an array of D upper bounds (can be nullptr) + */ + AnalyticFunction(std::function &r)> f, + const double *a = nullptr, + const double *b = nullptr) : RepresentableFunction(a, b) , func(f) {} - AnalyticFunction(std::function &r)> f, const std::vector &a, const std::vector &b) + + /** + * @brief Overload constructor with std::vector for the bounds + * + * @param f The analytic function which is evaluated in this class + * @param a Vector of D lower bounds. + * @param b Vector of D upper bounds. 
+ */ + AnalyticFunction(std::function &r)> f, + const std::vector &a, + const std::vector &b) : AnalyticFunction(f, a.data(), b.data()) {} + /** + * @brief Set the analytic function to be evaluated + * @param f New analytic function + */ void set(std::function &r)> f) { this->func = f; } + /** + * @brief Evaluate the analytic function at coordinate @p r. + * @param r Coordinate where to evaluate the function + * + * @details Checks if the point is within bounds before evaluating + * + * @return The function value at point @p r + */ T evalf(const Coord &r) const override { - T val = 0.0; + T val = T(0); if (not this->outOfBounds(r)) val = this->func(r); return val; } protected: - std::function &r)> func; + std::function &r)> func; ///< User-provided analytic function }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/functions/BoysFunction.h b/src/functions/BoysFunction.h index f8b8824d1..3bd29f89a 100644 --- a/src/functions/BoysFunction.h +++ b/src/functions/BoysFunction.h @@ -30,16 +30,38 @@ namespace mrcpp { +/** + * @class BoysFunction + * @brief Adaptive multiresolution evaluator for the Boys function + * \f$ F_n(x) = \int_{0}^{1} t^{2n}\mathrm{e}^{-xt^2}\mathrm{d}t \f$, + * where \f$ x\ge0 \f$ and \f$ n\ge0 \f$ + */ class BoysFunction final : public RepresentableFunction<1, double> { public: + /** + * @brief Construct an evaluator for \f$ F_n(x) \f$ + * @param n The order (\f$ \ge0 \f$) of the Boys function + * @param prec Projection precision for the adaptive MRA (default \f$ 10^{-10} \f$) + * + * @details The `MRA` member is initialised in the .cpp with a default 1D bounding + * box and a fixed scaling basis (currently an interpolating basis of order 13); + * this header does not constrain that choice. 
+ */ BoysFunction(int n, double prec = 1.0e-10); + /** + * @brief Evaluate \f$ F_n(x) \f$ at the given abscissa + * @param r Coordinate container with a single component: \f$ x = r[0] \f$ + * @return The numerical value of \f$ F_n(x) \f$ + */ double evalf(const Coord<1> &r) const override; private: const int order; + const double prec; + MultiResolutionAnalysis<1> MRA; }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/functions/GaussExp.cpp b/src/functions/GaussExp.cpp index a57fe6708..521a94d41 100644 --- a/src/functions/GaussExp.cpp +++ b/src/functions/GaussExp.cpp @@ -333,11 +333,6 @@ template std::ostream &GaussExp::print(std::ostream &o) const { return o; } -/** @returns Coulomb repulsion energy between all pairs in GaussExp, including self-interaction - * - * @note Each Gaussian must be normalized to unit charge - * \f$ c = (\alpha/\pi)^{D/2} \f$ for this to be correct! - */ template double GaussExp::calcCoulombEnergy() const { NOT_IMPLEMENTED_ABORT } diff --git a/src/functions/GaussExp.h b/src/functions/GaussExp.h index a4315e381..79d87309c 100644 --- a/src/functions/GaussExp.h +++ b/src/functions/GaussExp.h @@ -37,10 +37,13 @@ namespace mrcpp { #define GAUSS_EXP_PREC 1.e-10 -/** @class GaussExp +/** + * @class GaussExp + * @tparam D Spatial dimension (1, 2, or 3) * * @brief Gaussian expansion in D dimensions * + * @details * - Monodimensional Gaussian expansion: * * \f$ g(x) = \sum_{m=1}^M g_m(x) = \sum_{m=1}^M \alpha_m e^{-\beta (x-x^0)^2} \f$ @@ -49,87 +52,302 @@ namespace mrcpp { * * \f$ G(x) = \sum_{m=1}^M G_m(x) = \sum_{m=1}^M \prod_{d=1}^D g_m^d(x^d) \f$ * + * Each Gaussian-type function (GTF) is represented by a @ref Gaussian + * (base class) and is concretely either a pure Gaussian @ref GaussFunc or a + * Gaussian times a Cartesian polynomial @ref GaussPoly. 
*/ - template class GaussExp : public RepresentableFunction { public: + /** + * @brief Construct a Gaussian expansion and initialize each GTF to `nullptr` + * + * @param nTerms Number of GTFs (default 0) + * @param prec Unused here + * + * @note After construction, populate GTFs via @ref setFunc or @ref append. + */ GaussExp(int nTerms = 0, double prec = GAUSS_EXP_PREC); + + /// @brief Deep-copy constructor (clones every GTF via virtual copy()) GaussExp(const GaussExp &gExp); + + /// @brief Deep-copy assignment (existing GTFs are discarded then cloned) GaussExp &operator=(const GaussExp &gExp); - ~GaussExp() override; - auto begin() { return funcs.begin(); } - auto end() { return funcs.end(); } + /// @brief Destructor: deletes all owned GTFs and clears pointers + ~GaussExp() override; - const auto begin() const { return funcs.begin(); } - const auto end() const { return funcs.end(); } + auto begin() { return funcs.begin(); } ///< @return An iterator pointing to the first GTF + auto end() { return funcs.end(); } ///< @return An iterator pointing to the past-the-end GTF + const auto begin() const { return funcs.begin(); } ///< @return A const iterator pointing to the first GTF + const auto end() const { return funcs.end(); } ///< @return A const iterator pointing to the past-the-end GTF + /** + * @brief Coulomb repulsion energy between all pairs in the Gaussian expansion, including self-interaction + * @note Each GTF must be normalized to unit charge + * \f$ c = (\alpha/\pi)^{D/2} \f$ for this to be correct! + * Currently this function is only implemented for `D=3`. + */ double calcCoulombEnergy() const; + + /** + * @brief Compute the squared L2 norm of the expansion + * @details Use \f$ \| \sum_i f_i \|_2^2 = \sum_i \|f_i\|^2 + 2\sum_{i::calcScreening. 
+ */ void calcScreening(double nStdDev = defaultScreening); + /** + * @brief Evaluate the Gaussian expansion at a D-dimensional coordinate + * @param r Point (Coord) in physical space in the MRA box + * @return Gaussian expansion value at the point + */ double evalf(const Coord &r) const override; + /** + * @brief Generates a Gaussian expansion that is semi-periodic around a unit-cell + * @param[in] period The period of the unit cell + * @param[in] nStdDev Number of standard deviations covered in each direction. Default 4.0 + * @return Semi-periodic version of a Gaussian expansion around a unit-cell + * + * @note See the implementation of each GTF in @ref Gaussian::periodify. + */ GaussExp periodify(const std::array &period, double nStdDev = 4.0) const; + + /** + * @brief Analytic derivative d/dx_dir (Cartesian direction) of the Gaussian expansion + * @param dir Axis index in [0, D-1] + * + * @return A GaussExp representing the derivative + */ GaussExp differentiate(int dir) const; + /** + * @brief Build a new Gaussian expansion that is the combination of this expansion and the other + * @param g The other Gaussian expansion + * @return The new Gaussian expansion with all GTFs from this expansion and the other + */ GaussExp add(GaussExp &g); + + /** + * @brief Build a new Gaussian expansion by appending a single GTF to this expansion + * @param g The single GTF + * @return The new Gaussian expansion with GTFs from this expansion and the single GTF + */ GaussExp add(Gaussian &g); + + /** + * @brief Build a new Gaussian expansion by multiplying this expansion and the other + * @param g The other Gaussian expansion + * @return The new Gaussian expansion \f$ (\sum_{i} f_{i})(\sum_{j} g_{j}) = \sum_{ij} f_{i} g_{j} \f$ + */ GaussExp mult(GaussExp &g); + + /** + * @brief Build a new Gaussian expansion by multiplying this expansion and a single GTF + * @param g The single GTF + * @return The new Gaussian expansion \f$ (\sum_{i} f_{i}) g = \sum_{i} f_{i} g \f$ + */ 
GaussExp mult(GaussFunc &g); + + /** + * @brief Build a new Gaussian expansion by multiplying this expansion and a single @ref GaussPoly + * @param g The single @ref GaussPoly + * @return The new Gaussian expansion \f$ (\sum_{i} f_{i}) g = \sum_{i} g f_{i} \f$ + */ GaussExp mult(GaussPoly &g); + + /** + * @brief Build a new Gaussian expansion whose coefficients are scaled by a scalar + * @param d The scalar + * @return The new Gaussian expansion */ GaussExp mult(double d); + + /** + * @brief Scale coefficients of this expansion in place by a scalar + * @param d The scalar */ void multInPlace(double d); + /** + * @brief Overload the + operator to return a new Gaussian expansion formed by combining this expansion with the other + * @param g The other Gaussian expansion + * @return The new Gaussian expansion \f$ \sum_{i} f_{i} + \sum_{j} g_{j} \f$ containing all GTFs from both expansions + */ GaussExp operator+(GaussExp &g) { return this->add(g); } + /** + * @brief Overload the + operator to return a new Gaussian expansion formed by appending a single GTF to this expansion + * @param g The single GTF + * @return The new Gaussian expansion with GTFs from this expansion and the single GTF + */ GaussExp operator+(Gaussian &g) { return this->add(g); } + /** + * @brief Overload the * operator to return a new Gaussian expansion formed by multiplying this expansion and the other + * @param g The other Gaussian expansion + * @return The new Gaussian expansion \f$ (\sum_{i} f_{i})(\sum_{j} g_{j}) = \sum_{ij} f_{i} g_{j} \f$ + */ GaussExp operator*(GaussExp &g) { return this->mult(g); } + /** + * @brief Overload the * operator to return a new Gaussian expansion formed by multiplying this expansion and a single GTF + * @param g The single GTF + * @return The new Gaussian expansion \f$ (\sum_{i} f_{i}) g = \sum_{i} f_{i} g \f$ + */ GaussExp operator*(GaussFunc &g) { return this->mult(g); } + /** + * @brief Overload the * operator to return a new Gaussian expansion formed by multiplying this 
expansion and a single @ref GaussPoly + * @param g The single @ref GaussPoly + * @return The new Gaussian expansion \f$ (\sum_{i} f_{i}) g = \sum_{i} g f_{i} \f$ + */ GaussExp operator*(GaussPoly &g) { return this->mult(g); } + /** + * @brief Overload the * operator to return a new Gaussian expansion whose coefficients are scaled by a scalar + * @param d The scalar + * @return The new Gaussian expansion */ GaussExp operator*(double d) { return this->mult(d); } + /** + * @brief Overload the *= operator to scale coefficients of this expansion in place by a scalar + * @param d The scalar */ void operator*=(double d) { this->multInPlace(d); } + /// @brief Get screening parameter double getScreening() const { return screening; } + + /// @brief Get the Gaussian exponents for the i-th GTF std::array getExp(int i) const { return this->funcs[i]->getExp(); } + + /// @brief Get coefficient for the i-th GTF double getCoef(int i) const { return this->funcs[i]->getCoef(); } + + /// @brief Get powers for the i-th GTF const std::array &getPower(int i) const { return this->funcs[i]->getPower(); } + + /// @brief Get position for the i-th GTF const std::array &getPos(int i) const { return this->funcs[i]->getPos(); } + /// @brief Get number of GTFs in the expansion int size() const { return this->funcs.size(); } + + /// @brief Get mutable access to the i-th GTF Gaussian &getFunc(int i) { return *this->funcs[i]; } + + /// @brief Get const access to the i-th GTF const Gaussian &getFunc(int i) const { return *this->funcs[i]; } + /// @brief Get mutable pointer access to the i-th GTF Gaussian *operator[](int i) { return this->funcs[i]; } + + /// @brief Get const pointer access to the i-th GTF const Gaussian *operator[](int i) const { return this->funcs[i]; } + /** + * @brief Set a @ref GaussPoly for the i-th GTF in the expansion and scale its coefficient by a scalar + * @param i The i-th GTF + * @param g The @ref GaussPoly + * @param c The scalar + * @note Existing i-th GTF will be 
deleted + */ void setFunc(int i, const GaussPoly &g, double c = 1.0); + + /** + * @brief Set a single GTF for the i-th GTF in the expansion and scale its coefficient by a scalar + * @param i The i-th GTF + * @param g The @ref GaussFunc + * @param c The scalar + * @note Existing i-th GTF will be deleted + */ void setFunc(int i, const GaussFunc &g, double c = 1.0); + /// @brief Set global default screening for the Gaussian expansion void setDefaultScreening(double screen); + + /** + * @brief Enable/disable screening for this expansion and forward to all GTFs + * @details Conventionally, a positive @ref screening means "enabled" and + * a negative value means "disabled". + */ void setScreen(bool screen); + + /** + * @brief Set (isotropic) exponent for the i-th GTF + * @param i The i-th GTF + * @param a The (isotropic) exponent + */ void setExp(int i, double a) { this->funcs[i]->setExp(a); } + + /** + * @brief Set coefficient for the i-th GTF + * @param i The i-th GTF + * @param b The coefficient + */ void setCoef(int i, double b) { this->funcs[i]->setCoef(b); } + + /** + * @brief Set powers for the i-th GTF + * @param i The i-th GTF + * @param power The powers + */ void setPow(int i, const std::array &power) { this->funcs[i]->setPow(power); } + + /** + * @brief Set center coordinates for the i-th GTF + * @param i The i-th GTF + * @param pos The center coordinates + */ void setPos(int i, const std::array &pos) { this->funcs[i]->setPos(pos); } - /** @brief Append Gaussian to expansion */ + /** + * @brief Append a single GTF to the end of the expansion + * @param g The single GTF + */ void append(const Gaussian &g); - /** @brief Append GaussExp to expansion */ + + /** + * @brief Append all GTFs from the other expansion + * @param g The other expansion + */ void append(const GaussExp &g); + /** @brief Stream pretty-printer (delegates to protected function @ref GaussExp::print) */ friend std::ostream &operator<<(std::ostream &o, const GaussExp &gExp) { return gExp.print(o); } 
+ friend class Gaussian; protected: std::vector *> funcs; + static double defaultScreening; + double screening{0.0}; + /** + * @brief Implementation of stream printing (called by operator<<) + * @param o The output stream + * @return The output stream + */ std::ostream &print(std::ostream &o) const; + /** + * @brief Heuristic visibility vs. resolution scale and quadrature sampling + * @param scale Dyadic scale (tile size ~ 2^{-scale}) + * @param nPts Number of quadrature points per tile edge + * @return false if any GTF declares itself not visible, true otherwise + */ bool isVisibleAtScale(int scale, int nPts) const override; + + /** + * @brief Quick check whether the expansion is essentially zero on [lb,ub] per axis + * @param lb Lower bounds array of length D + * @param ub Upper bounds array of length D + * @return true only if each GTF is effectively zero on [lb,ub] + */ bool isZeroOnInterval(const double *lb, const double *ub) const override; }; diff --git a/src/functions/GaussFunc.cpp b/src/functions/GaussFunc.cpp index 28736be58..06fa8904b 100644 --- a/src/functions/GaussFunc.cpp +++ b/src/functions/GaussFunc.cpp @@ -143,11 +143,6 @@ template void GaussFunc::multInPlace(const GaussFunc &rhs) { this->setPow(newPow); } -/** @brief Multiply two GaussFuncs - * @param[in] this: Left hand side of multiply - * @param[in] rhs: Right hand side of multiply - * @returns New GaussPoly - */ template GaussPoly GaussFunc::mult(const GaussFunc &rhs) { GaussFunc &lhs = *this; GaussPoly result; @@ -163,10 +158,6 @@ template GaussPoly GaussFunc::mult(const GaussFunc &rhs) { return result; } -/** @brief Multiply GaussFunc by scalar - * @param[in] c: Scalar to multiply - * @returns New GaussFunc - */ template GaussFunc GaussFunc::mult(double c) { GaussFunc g = *this; g.coef *= c; @@ -195,14 +186,6 @@ template std::ostream &GaussFunc::print(std::ostream &o) const { return o; } -/** @brief Compute Coulomb repulsion energy between two GaussFuncs - * @param[in] this: Left hand 
GaussFunc - * @param[in] rhs: Right hand GaussFunc - * @returns Coulomb energy - * - * @note Both Gaussians must be normalized to unit charge - * \f$ \alpha = (\beta/\pi)^{D/2} \f$ for this to be correct! - */ template double GaussFunc::calcCoulombEnergy(const GaussFunc &gf) const { NOT_IMPLEMENTED_ABORT; } diff --git a/src/functions/GaussFunc.h b/src/functions/GaussFunc.h index 874bb3850..42b1be70a 100644 --- a/src/functions/GaussFunc.h +++ b/src/functions/GaussFunc.h @@ -32,60 +32,159 @@ namespace mrcpp { -/** @class GaussFunc +/** + * @class GaussFunc + * @tparam D Spatial dimension (1, 2, or 3) * * @brief Gaussian function in D dimensions with a simple monomial in front * * - Monodimensional Gaussian (GaussFunc<1>): - * * \f$ g(x) = \alpha (x-x_0)^a e^{-\beta (x-x_0)^2} \f$ * * - Multidimensional Gaussian (GaussFunc): - * * \f$ G(x) = \prod_{d=1}^D g^d(x^d) \f$ */ - template class GaussFunc : public Gaussian { public: - /** @returns New GaussFunc object - * @param[in] beta: Exponent, \f$ e^{-\beta r^2} \f$ - * @param[in] alpha: Coefficient, \f$ \alpha e^{-r^2} \f$ - * @param[in] pos: Position \f$ (x - pos[0]), (y - pos[1]), ... \f$ - * @param[in] pow: Monomial power, \f$ x^{pow[0]}, y^{pow[1]}, ... \f$ + /** + * @brief Constructor which forwards to the Gaussian constructor + * @param beta Exponent, \f$ e^{-\beta r^2} \f$ + * @param alpha Coefficient, \f$ \alpha e^{-r^2} \f$ + * @param[in] pos Position \f$ (x - pos[0]), (y - pos[1]), ... \f$ + * @param[in] pow Monomial power, \f$ x^{pow[0]}, y^{pow[1]}, ... \f$ */ GaussFunc(double beta, double alpha, const Coord &pos = {}, const std::array &pow = {}) : Gaussian(beta, alpha, pos, pow) {} + + /** + * @brief Constructor which forwards to the Gaussian constructor + * @param[in] beta List of exponents, \f$ e^{-\beta r^2} \f$ + * @param alpha Coefficient, \f$ \alpha e^{-r^2} \f$ + * @param[in] pos Position \f$ (x - pos[0]), (y - pos[1]), ... \f$ + * @param[in] pow Monomial power, \f$ x^{pow[0]}, y^{pow[1]}, ... 
\f$ + */ GaussFunc(const std::array &beta, double alpha, const Coord &pos = {}, const std::array &pow = {}) : Gaussian(beta, alpha, pos, pow) {} + + /// @brief Copy constructor. GaussFunc(const GaussFunc &gf) : Gaussian(gf) {} + GaussFunc &operator=(const GaussFunc &rhs) = delete; - Gaussian *copy() const override; + /** + * @brief Performs a deep copy + * @return Pointer to a new GaussFunc copy of this instance + */ + Gaussian *copy() const override; + + /** + * @brief Compute Coulomb repulsion energy between this GaussFunc and another + * @param rhs Other GaussFunc + * @return Coulomb energy + * + * @note Implemented only for D = 3 + * @note Both Gaussians must be normalized to unit charge + * \f$ \alpha = (\beta/\pi)^{D/2} \f$ for this to be correct! + */ double calcCoulombEnergy(const GaussFunc &rhs) const; + + /** + * @brief Calculates the squared norm of this GaussFunc + * @return The squared norm + */ double calcSquareNorm() const override; + /** + * @brief Evaluate the gaussian f(r) at a D-dimensional coordinate + * @param r Point (Coord) in physical space in the MRA box + * @return Function value f(r). + */ double evalf(const Coord &r) const override; + + /** + * @brief Evaluate the *1D* separable factor along axis @p dir + * @param r Coordinate along axis @p dir + * @param dir Axis index in [0, D-1]. + * + * @return The value of the 1D Gaussian factor g_dir(r), dir = {0, .., D-1} -> x,y,z... 
+ */ double evalf1D(double r, int dir) const override; + /** + * @brief Convert this GaussFunc to a GaussExp object + * @return A GaussExp representing this GaussFunc + */ GaussExp asGaussExp() const override; + + /** + * @brief Analytic derivative d/dx_dir (Cartesian direction) of the GaussFunc + * @param dir Axis index in [0, D-1] + * + * @return A GaussPoly representing the derivative (polynomial×Gaussian) + */ GaussPoly differentiate(int dir) const override; + /** + * @brief Multiplies this GaussFunc in-place with another GaussFunc + * @param rhs The GaussFunc to multiply with + * @note The result is stored in this GaussFunc, thus overwriting its previous values + */ void multInPlace(const GaussFunc &rhs); + + /** + * @brief Operator overload forwarding to multInPlace + * @param rhs The GaussFunc to multiply with + */ void operator*=(const GaussFunc &rhs) { multInPlace(rhs); } + + /** + * @brief Multiply another GaussFunc with this GaussFunc + * @param rhs Other GaussFunc + * @return Resulting GaussPoly + */ GaussPoly mult(const GaussFunc &rhs); + + /** + * @brief Multiply this GaussFunc with a scalar + * @param c Scalar to multiply + * @returns Resulting GaussFunc + */ GaussFunc mult(double c); + + /** + * @brief Operator overload forwarding to mult + * @param rhs The GaussFunc to multiply with + * @return Resulting GaussPoly + */ GaussPoly operator*(const GaussFunc &rhs) { return this->mult(rhs); } + + /** + * @brief Operator overload forwarding to mult + * @param c Scalar to multiply with + * @return Resulting GaussFunc + */ GaussFunc operator*(double c) { return this->mult(c); } + /** + * @brief Set the power in dimension d + * @param d Dimension index + * @param power Power to set + */ void setPow(int d, int power) override { this->power[d] = power; } + + /** + * @brief Set the powers in all dimensions + * @param power Array of powers to set + */ void setPow(const std::array &power) override { this->power = power; } private: + /// @brief Print GaussFunc 
to output stream std::ostream &print(std::ostream &o) const override; }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/functions/GaussPoly.cpp b/src/functions/GaussPoly.cpp index 0dfeaf2cd..da37171ba 100644 --- a/src/functions/GaussPoly.cpp +++ b/src/functions/GaussPoly.cpp @@ -37,12 +37,6 @@ using namespace Eigen; namespace mrcpp { -/** @returns New GaussPoly object - * @param[in] beta: Exponent, \f$ e^{-\beta r^2} \f$ - * @param[in] alpha: Coefficient, \f$ \alpha e^{-r^2} \f$ - * @param[in] pos: Position \f$ (x - pos[0]), (y - pos[1]), ... \f$ - * @param[in] pow: Max polynomial degree, \f$ P_0(x), P_1(y), ... \f$ - */ template GaussPoly::GaussPoly(double beta, double alpha, const Coord &pos, const std::array &power) : Gaussian(beta, alpha, pos, power) { @@ -261,10 +255,6 @@ template GaussPoly GaussPoly::mult(const GaussPoly &rhs) { */ } -/** @brief Multiply GaussPoly by scalar - * @param[in] c: Scalar to multiply - * @returns New GaussPoly - */ template GaussPoly GaussPoly::mult(double c) { GaussPoly g = *this; g.coef *= c; @@ -283,11 +273,6 @@ template void GaussPoly::setPow(const std::array &pow) { } } -/** @brief Set polynomial in given dimension - * - * @param[in] d: Cartesian direction - * @param[in] poly: Polynomial to set - */ template void GaussPoly::setPoly(int d, Polynomial &poly) { if (this->poly[d] != nullptr) { delete this->poly[d]; } this->poly[d] = new Polynomial(poly); diff --git a/src/functions/GaussPoly.h b/src/functions/GaussPoly.h index 97ed6f47d..052d7e307 100644 --- a/src/functions/GaussPoly.h +++ b/src/functions/GaussPoly.h @@ -7,8 +7,8 @@ * This file is part of MRCPP. 
* * MRCPP is free software: you can redistribute it and/or modify - * it under the terms of the GNU Lesser General Public License as published by - * the Free Software Foundation, either version 3 of the License, or + * it under the terms of the GNU Lesser General Public License + * as published by the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * MRCPP is distributed in the hope that it will be useful, @@ -35,8 +35,10 @@ namespace mrcpp { -/** @class GaussPoly - * +/** + * @class GaussPoly + * @tparam D Spatial dimension (1, 2, or 3) + * * @brief Gaussian function in D dimensions with a general polynomial in front * * - Monodimensional Gaussian (GaussPoly<1>): @@ -50,51 +52,184 @@ namespace mrcpp { template class GaussPoly : public Gaussian { public: + /** + * @brief Constructor + * + * @param alpha Exponent, \f$ e^{-\alpha r^2} \f$ + * @param coef Coefficient \f$ c \f$ in \f$ c\, e^{-\alpha r^2} \f$ + * @param[in] pos Position \f$ (x - pos[0]), (y - pos[1]), ... \f$ + * @param[in] power Max polynomial degree, \f$ P_0(x), P_1(y), ... \f$ + */ GaussPoly(double alpha = 0.0, double coef = 1.0, const Coord &pos = {}, const std::array &power = {}); + + /** + * @brief Constructor + * + * @param[in] alpha List of exponents, \f$ e^{-\alpha_d r_d^2} \f$ + * @param coef Coefficient \f$ c \f$ in front of the Gaussian + * @param[in] pos Position \f$ (x - pos[0]), (y - pos[1]), ... \f$ + * @param[in] power Max polynomial degree, \f$ P_0(x), P_1(y), ... \f$ + */ GaussPoly(const std::array &alpha, double coef, const Coord &pos = {}, const std::array &power = {}); + + /// @brief Copy constructor. 
GaussPoly(const GaussPoly &gp); + + /** + * @brief Construct from a GaussFunc + * @param[in] gf GaussFunc to convert + */ GaussPoly(const GaussFunc &gf); + GaussPoly &operator=(const GaussPoly &gp) = delete; + + /** + * @brief Performs a deep copy + * @return Pointer to a new GaussPoly copy of this instance + */ Gaussian *copy() const override; + ~GaussPoly(); + /** + * @brief Calculates the squared norm of this GaussPoly + * @return The squared norm + */ double calcSquareNorm() const override; + /** + * @brief Evaluate the gaussian f(r) at a D-dimensional coordinate + * @param r Point (Coord) in physical space in the MRA box + * @return Function value f(r). + */ double evalf(const Coord &r) const override; + + /** + * @brief Evaluate the *1D* separable factor along axis @p dim + * @param r Coordinate along axis @p dim + * @param dim Axis index in [0, D-1]. + * + * @return The value of the 1D Gaussian factor g_dim(r), dim = {0, .., D-1} -> x,y,z... + */ double evalf1D(double r, int dim) const override; + /** + * @brief Convert this GaussPoly to a GaussExp object + * @return A GaussExp representing this GaussPoly + */ GaussExp asGaussExp() const override; + + /// @warning This method is currently not implemented. GaussPoly differentiate(int dir) const override; + /// @warning This method is currently not implemented. void multInPlace(const GaussPoly &rhs); + + /** @brief In-place product operator (delegates to @ref multInPlace). */ void operator*=(const GaussPoly &rhs) { multInPlace(rhs); } + + /// @warning This method is currently not implemented. GaussPoly mult(const GaussPoly &rhs); + + /** + * @brief Multiply this GaussPoly with a scalar + * @param c Scalar to multiply + * @returns Resulting GaussPoly + */ GaussPoly mult(double c); + + /** + * @brief Operator overload forwarding to mult + * @param rhs The GaussPoly to multiply with + * @return Resulting GaussPoly + * @warning @ref mult is currently not implemented. 
+ */ GaussPoly operator*(const GaussPoly &rhs) { return mult(rhs); } + + /** + * @brief Operator overload forwarding to mult + * @param c Scalar to multiply with + * @return Resulting GaussPoly + */ GaussPoly operator*(double c) { return mult(c); } + /** + * @brief Returns the polynomial coefficients in a specified dimension + * @param i Dimension index + * @return The Eigen vector of coefficients + */ const Eigen::VectorXd &getPolyCoefs(int i) const { return poly[i]->getCoefs(); } + + /** + * @brief Returns the polynomial coefficients in a specified dimension + * @param i Dimension index + * @return The Eigen vector of coefficients + */ Eigen::VectorXd &getPolyCoefs(int i) { return poly[i]->getCoefs(); } + + /** + * @brief Returns the Polynomial in a specified dimension + * @param i Dimension index + * @return The Polynomial reference + */ const Polynomial &getPoly(int i) const { return *poly[i]; } + + /** + * @brief Returns the Polynomial in a specified dimension + * @param i Dimension index + * @return The Polynomial reference + */ Polynomial &getPoly(int i) { return *poly[i]; } + /** + * @brief Set the power in dimension d + * @param d Dimension index + * @param pow Power to set + */ void setPow(int d, int pow) override; + + /** + * @brief Set the powers in all dimensions + * @param pow Array of powers to set + */ void setPow(const std::array &pow) override; - void setPoly(int d, Polynomial &poly); + /** + * @brief Set polynomial in given dimension + * @param d Cartesian direction + * @param[in] poly Polynomial to set + */ + void setPoly(int d, Polynomial &poly); private: - Polynomial *poly[D]; - + Polynomial *poly[D]; ///< Per-axis polynomial factors + + /** + * @brief Recursive helper function to fill coefficient and power vectors for all terms + * @param[out] coefs Vector to fill with coefficients + * @param[out] power Vector to fill with power arrays + * @param pow Current power array being built + * @param dir Current dimension being processed + */ 
void fillCoefPowVector(std::vector &coefs, std::vector &power, int pow[D], int dir) const; + + /** + * @brief Recursive helper function to fill coefficient and power vectors for all terms + * @param[out] coefs Vector to fill with coefficients + * @param[out] power Vector to fill with power arrays + * @param pow Current power array being built + * @param dir Current dimension being processed + */ void fillCoefPowVector(std::vector &coefs, std::vector &power, std::array &pow, int dir) const; + + /// @brief Print GaussPoly to output stream std::ostream &print(std::ostream &o) const override; }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/functions/Gaussian.cpp b/src/functions/Gaussian.cpp index 6dbfa7c5b..e6431d971 100644 --- a/src/functions/Gaussian.cpp +++ b/src/functions/Gaussian.cpp @@ -171,16 +171,7 @@ template double Gaussian::calcOverlap(const Gaussian &inp) const { return S; } -/** @brief Generates a GaussExp that is semi-periodic around a unit-cell - * - * @returns Semi-periodic version of a Gaussian around a unit-cell - * @param[in] period: The period of the unit cell - * @param[in] nStdDev: Number of standard diviations covered in each direction. Default 4.0 - * - * @details nStdDev = 1, 2, 3 and 4 ensures atleast 68.27%, 95.45%, 99.73% and 99.99% of the - * integral is conserved with respect to the integration limits. - * - */ + template GaussExp Gaussian::periodify(const std::array &period, double nStdDev) const { GaussExp gauss_exp; auto pos_vec = std::vector>(); diff --git a/src/functions/Gaussian.h b/src/functions/Gaussian.h index ddb039202..7683f4aca 100644 --- a/src/functions/Gaussian.h +++ b/src/functions/Gaussian.h @@ -23,11 +23,6 @@ * */ -/** - * - * Base class for Gaussian type functions - */ - #pragma once #include @@ -40,75 +35,241 @@ namespace mrcpp { +/** + * @class Gaussian + * + * @tparam D Spatial dimension (1, 2, or 3) + * @tparam T Coefficient type (e.g. 
double, ComplexDouble) + * + * @brief Represent and manipulate Gaussian-type functions (GTFs) + * + * @details The Gaussian class is an abstract base class for + * representing Gaussian-type functions (GTFs) in D dimensions. + * A GTF is defined as + * \f$ f(\mathbf{r}) = C \prod_{d=0}^{D-1} g_d(r_d) \f$, where + * \f$ g_d(r_d) = (r_d - p_d)^{\alpha_d} \exp[-\beta_d (r_d - p_d)^2] \f$. + * Here, C is a global coefficient, p_d is the center position along axis d, + * \alpha_d is the exponent for the monomial term, and \beta_d is the exponent for the Gaussian envelope. + * + * The class provides methods for evaluating the function at given points, + * computing overlap integrals with other Gaussian functions, differentiating + * the function, and converting to Gaussian expansions suitable for pairwise operations. + */ template class Gaussian : public RepresentableFunction { public: + /** + * @brief Isotropic constructor (same exponent on all axes) + * @param a Exponent value α to be replicated on all axes (α[d] = a) + * @param c Global scalar coefficient + * @param r Center position (Coord), defaults to origin + * @param p Per-axis monomial powers (non-negative), stored as array + * + * @warning This ctor does not check positivity of @p a; callers are expected + * to pass α>0 (required for square integrability and σ = 1/√(2α)). + */ Gaussian(double a, double c, const Coord &r, const std::array &p); + + /** + * @brief Anisotropic constructor (different set of coefficients and exponents per each axis) + * @param a Exponent ARRAY α[d] per axis. + * @param c Global scalar coefficient. + * @param r Center position (Coord). + * @param p Per-axis monomial powers (non-negative). + */ Gaussian(const std::array &a, double c, const Coord &r, const std::array &p); - Gaussian &operator=(const Gaussian &gp) = delete; - virtual Gaussian *copy() const = 0; + + Gaussian &operator=(const Gaussian &gp) = delete; ///< Non-assignable; use clones. 
+ virtual Gaussian *copy() const = 0; ///< Virtual copy (clone). virtual ~Gaussian() = default; + /** @name Evaluation API (to be implemented by subclasses) */ + ///@{ + /** + * @brief Evaluate the gaussian f(r) at a D-dimensional coordinate + * @param[in] r Point (Coord) in physical space in the MRA box + * @return Function value f(r). + */ virtual double evalf(const Coord &r) const = 0; + + /** + * @brief Evaluate the *1D* separable factor along axis @p dim + * @param r Coordinate along axis @p dim + * @param dim Axis index in [0, D-1]. + * + * @return The value of the 1D Gaussian factor g_dim(r), dim = {0, .., D-1} -> x,y,z... + */ virtual double evalf1D(double r, int dim) const = 0; + + /** + * @brief Evaluate a set of points in D dimensions, arranged in the matrix form + * @param[in] points Matrix (N×D): column d holds all coordinates along axis d. + * @param[out] values Matrix (N×D): on return, values(i,d) = evalf1D(points(i,d), d). + * + * @note This does *not* multiply across dimensions; it only fills the + * per-axis factors column-wise for later tensor products. 
+ */ void evalf(const Eigen::MatrixXd &points, Eigen::MatrixXd &values) const; + ///@} + /** @name Integral properties and expansions */ + ///@{ + /** + * @brief Overlap ⟨this|inp⟩ of two gaussians + * + * @param[in] inp The other Gaussian instance + * + * @return The value of the overlap integral ⟨this|inp⟩ as a double + */ double calcOverlap(const Gaussian &inp) const; + + /// @return Exact L2 norm squared ∥f∥² (implemented by subclass) virtual double calcSquareNorm() const = 0; + + + /// @brief Represent as a sum of Gaussians (pure or polynomial-times-Gaussian), suitable for pairwise operations; implemented by subclass virtual GaussExp asGaussExp() const = 0; + + /** @brief Generates a GaussExp that is semi-periodic around a unit-cell + * + * @returns Semi-periodic version of a Gaussian around a unit-cell + * @param[in] period The period of the unit cell + * @param[in] nStdDev Number of standard deviations covered in each direction. Default 4.0 + * + * @details nStdDev = 1, 2, 3 and 4 ensure at least 68.27%, 95.45%, 99.73% and 99.99% of the + * integral is conserved with respect to the integration limits. 
+ */ GaussExp periodify(const std::array &period, double nStdDev = 4.0) const; + ///@} - /** @brief Compute analytic derivative of Gaussian - * @param[in] dir: Cartesian direction of derivative - * @returns New GaussPoly + /** @name Differential operators */ + ///@{ + /** + * @brief Analytic derivative d/dx_dir (Cartesian direction) of the Gaussian + * @param dir Axis index in [0, D-1] + * + * @return A GaussPoly representing the derivative (polynomial×Gaussian) */ virtual GaussPoly differentiate(int dir) const = 0; + ///@} + /** @name Screening and normalization */ + ///@{ + /** + * @brief Build ±nσ bounds around the center on each axis and enable screening + * + * @param stdDeviations Number of standard deviations n used for the box + * + * @note Used to cheaply cull tiles/intervals that cannot contribute + */ void calcScreening(double stdDeviations); - /** @brief Rescale function by its norm \f$ ||f||^{-1} \f$ */ + /** + * @brief Rescale the Gaussian so that its L2 norm equals 1. + * @note Calls calcSquareNorm() from the derived class + */ void normalize() { double norm = std::sqrt(calcSquareNorm()); multConstInPlace(1.0 / norm); } + ///@} + + /** @name Algebra on the pure Gaussian core */ + ///@{ + /** + * @brief Complete-the-square product of two *pure* Gaussians into *this*. + * Polynomial factors are handled in derived types (GaussFunc/GaussPoly). 
+ */ void multPureGauss(const Gaussian &lhs, const Gaussian &rhs); + + /// @brief Scale the global coefficient by a constant void multConstInPlace(double c) { this->coef *= c; } + + /// @brief Shorthand for multConstInPlace void operator*=(double c) { multConstInPlace(c); } + ///@} + /** @name Screening access */ + ///@{ bool getScreen() const { return screen; } + /** + * @brief Tile-level culling test for dyadic box at scale n and translation l + * @return True if the box is completely outside the screening bounds and can be skipped + */ bool checkScreen(int n, const int *l) const; + ///@} + - int getPower(int i) const { return power[i]; } - double getCoef() const { return coef; } - double getExp(int i) const { return alpha[i]; } - const std::array &getPower() const { return power; } - const std::array &getPos() const { return pos; } - std::array getExp() const { return alpha; } - - virtual void setPow(const std::array &power) = 0; - virtual void setPow(int d, int power) = 0; - void setScreen(bool _screen) { this->screen = _screen; } - void setCoef(double cf) { this->coef = cf; } - void setExp(double _alpha) { this->alpha.fill(_alpha); } - void setExp(const std::array &_alpha) { this->alpha = _alpha; } - void setPos(const std::array &r) { this->pos = r; } + // some getters and setters + /** @name Parameter accessors */ + ///@{ + int getPower(int i) const { return power[i]; } ///< Get monomial power on axis i + double getCoef() const { return coef; } ///< Get monomial coefficient + double getExp(int i) const { return alpha[i]; } ///< Get monomial exponent on axis i + const std::array &getPower() const { return power; } ///< Get monomial powers on the axis in an array + const std::array &getPos() const { return pos; } ///< Get monomial positions on the axis in an array + std::array getExp() const { return alpha; } ///< Get monomial exponent on the axis in an array + ///@} + /** @name Parameter mutators */ + ///@{ + virtual void setPow(const std::array &power) = 0; 
///< Set all monomial powers. + virtual void setPow(int d, int power) = 0; ///< Set monomial power on axis d. + void setScreen(bool _screen) { this->screen = _screen; } ///< Enable/disable screening flag. + void setCoef(double cf) { this->coef = cf; } ///< Set global coefficient. + void setExp(double _alpha) { this->alpha.fill(_alpha); } ///< Set isotropic exponent α[d]=_alpha. + void setExp(const std::array &_alpha) { this->alpha = _alpha; } ///< Set per-axis exponents. + void setPos(const std::array &r) { this->pos = r; } ///< Set center coordinates. + ///@} + + /** @brief Stream pretty-printer (delegates to virtual print()). */ friend std::ostream &operator<<(std::ostream &o, const Gaussian &gauss) { return gauss.print(o); } - friend class GaussExp; + friend class GaussExp; ///< Allows GaussExp to access internals when assembling expansions. protected: - bool screen; - double coef; /**< constant factor */ - std::array power; /**< max power in each dim */ - std::array alpha; /**< exponent */ - Coord pos; /**< center */ + /** @name Core parameters (POD) */ + ///@{ + bool screen; ///< If true, use [A,B] screening in fast checks (set via calcScreening / setScreen) + double coef; ///< Global scale factor (α in the docs above) + std::array power; ///< Monomial powers per axis (non-negative integers) + std::array alpha; ///< Exponents per axis (>0). Controls width: σ_d = 1/√(2 α_d) + Coord pos; ///< Center coordinates + ///@} + /** @name Visibility / culling helpers used by trees and projection */ + ///@{ + /** + * @brief Heuristic visibility vs. 
resolution scale and quadrature sampling + * @param scale Dyadic scale (tile size ~ 2^{-scale}) + * @param nQuadPts Number of quadrature points per tile edge + * @return false if the Gaussian is “too narrow” to be represented at this scale + */ bool isVisibleAtScale(int scale, int nQuadPts) const; + + /** + * @brief Quick check whether the function is essentially zero on [a,b] per axis, + * using a ±5σ bounding rule (implementation in the .cpp) + * @param a Lower bounds array of length D + * @param b Upper bounds array of length D + * @return true if the function is effectively zero on [a,b] + */ bool isZeroOnInterval(const double *a, const double *b) const; + ///@} + /** + * @brief Maximum standard deviation across axes: max_d 1/√(2 α_d). + * @details Used by periodify() to decide how many neighboring images to include + * + * @return The maximum standard deviation among all axes + */ double getMaximumStandardDiviation() const; + /** + * @brief Subclass hook for stream output; should print parameters in a readable way + * @param o The output stream + * @return The output stream + */ virtual std::ostream &print(std::ostream &o) const = 0; }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/functions/JpowerIntegrals.h b/src/functions/JpowerIntegrals.h index dea01305f..eab08a50a 100644 --- a/src/functions/JpowerIntegrals.h +++ b/src/functions/JpowerIntegrals.h @@ -7,8 +7,8 @@ * This file is part of MRCPP. * * MRCPP is free software: you can redistribute it and/or modify - * it under the terms of the GNU Lesser General Public License as published by - * the Free Software Foundation, either version 3 of the License, or + * it under the terms of the GNU Lesser General Public License + * as published by the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. 
* * MRCPP is distributed in the hope that it will be useful, @@ -34,116 +34,143 @@ namespace mrcpp { /** @class JpowerIntegrals * - * @brief A class needed for construction Schrodinger time evolution operator + * @brief Precompute and store families of power–type integrals \f$ \{\widetilde J_m(l,a)\}_{m\ge 0} \f$ + * for integer shifts \f$ l \f$, used in the Schrödinger time–evolution operator. * - * @details A two dimensional array consisting of integrals \f$ J_m \f$ as follows. - * Our main operator has the following expansion + * @details + * This helper class generates the sequences + * \f$ \big(\widetilde J_m(l,a)\big)_{m=0}^{M} \f$ for a finite set of integer + * translations \f$ l \in \{-(2^n-1),\ldots,-1,0,1,\ldots,2^n-1\} \f$, where + * \f$ n=\texttt{scaling} \f$ and \f$ a>0 \f$ is the time–scaled parameter + * (typically \f$ a = t\,\mathfrak N^2 = t\,4^{\mathfrak n} \f$). + * + * The integrals appear in the expansion of the (matrix–valued) operator + * \f[ + * \big[ \sigma_l^{\mathfrak n} \big]_{pj}(a) + * = + * \sum_{k=0}^{\infty} C_{jp}^{2k}\, + * \widetilde J_{\,2k + j + p}(l,a), + * \f] + * where the scalar building blocks are * \f[ - * \left[ \sigma_l^{\mathfrak n} \right]_{pj} - * (a) + * \widetilde J_m(l,a) * = - * \sum_{k = 0}^{\infty} - * C_{jp}^{2k} - * \widetilde J_{2k + j + p}(l, a) - * , + * \frac{e^{i\frac{\pi}{4}(m-1)}}{2\pi\,(m+2)!} + * \int_{\mathbb R} + * \exp\!\Big( + * \rho\,l\,e^{i\pi/4} - a\,\rho^2 + * \Big)\, + * \rho^m\, d\rho . * \f] - * where \f$ a = t \mathfrak N^2 = t 4^{\mathfrak n} \f$ - * and + * + * In the code, \f$ \widetilde J_m \f$ are produced by the three–term recurrence + * (valid for \f$ m=0,1,2,\ldots \f$) * \f[ - * \widetilde J_m - * = - * \frac - * { - * I_m - * e^{ i \frac {\pi}4 (m - 1) } - * } - * { - * 2 \pi ( m + 2 )! - * } - * = - * \frac - * { - * e^{ i \frac {\pi}4 (m - 1) } - * } - * { - * 2 \pi ( m + 2 )! 
- * } - * \int_{\mathbb R} - * \exp - * \left( - * \rho l \exp \left( i \frac \pi 4 \right) - a \rho^2 - * \right) - * \rho^m - * d \rho + * \widetilde J_{m+1} + * = + * \frac{i}{2a\,(m+3)}\left( + * l\,\widetilde J_m + \frac{m}{m+2}\,\widetilde J_{m-1} + * \right), + * \qquad \widetilde J_{-1}=0, * \f] - * satisfying the following relation + * with the closed–form seed * \f[ - * \widetilde J_{m+1} - * = - * \frac - * { - * il - * } - * { - * 2a (m + 3) - * } - * \widetilde J_m - * + - * \frac {im}{2a(m + 2)(m + 3)} - * \widetilde J_{m-1} - * = - * \frac - * { - * i - * } - * { - * 2a (m + 3) - * } - * \left( - * l - * \widetilde J_m - * + - * \frac {m}{(m + 2)} - * \widetilde J_{m-1} - * \right) - * , \quad - * m = 0, 1, 2, \ldots, + * \widetilde J_0(l,a) + * = + * \frac{e^{-i\pi/4}}{4\sqrt{\pi a}}\, + * \exp\!\left(\frac{i\,l^2}{4a}\right). * \f] - * with \f$ \widetilde J_{-1} = 0 \f$ and + * + * ### Storage layout + * For convenience of iteration, the container `integrals` is filled + * in the following order: * \f[ - * \label{power_integral_0} - * \widetilde J_0 - * = - * \frac{ e^{ -i \frac{\pi}4 } }{ 4 \sqrt{ \pi a } } - * \exp - * \left( - * \frac{il^2}{4a} - * \right) - * . + * l = 0, 1, \ldots, 2^n-1,\; 1-2^n, 2-2^n, \ldots, -2, -1. * \f] + * Each entry is a vector + * \code + * integrals[k] == { J_0(l), J_1(l), ..., J_M(l) } + * \endcode + * of complex values for a fixed shift \f$ l \f$. + * + * ### Intended use + * - Construct once for a given \f$ a \f$, \f$ n \f$ and \f$ M \f$. + * - Access the sequence for a particular shift via `operator[](l)`. + * - Combine with precomputed correlation coefficients \f$ C_{jp}^{2k} \f$ + * to assemble \f$ [\sigma_l^{\mathfrak n}]_{pj}(a) \f$. * - * + * @note The class offers an internal @ref crop routine to trim negligible + * tail entries of a sequence (based on a magnitude threshold). 
Whether and + * when cropping is used is an implementation detail; sequences are always + * returned in full length \f$ M\!+\!1 \f$ from the constructor path. */ class JpowerIntegrals { public: - /// @brief creates an array of power integrals - /// @param a : parameter a - /// @param scaling : scaling level - /// @param M : maximum amount of integrals for each l - /// @param threshold : lower limit for neglecting the integrals - /// @details The array is orginised as a vector ordered as \f$l = 0, 1, 2, \ldots, 2^n - 1, 1 - 2^n, 2 - 2^n, \ldots, -2, -1 \f$. + /// @brief Construct and precompute all \f$ \widetilde J_m(l,a) \f$ for + /// \f$ l\in[-(2^n-1),\ldots,2^n-1] \f$ and \f$ m=0,\ldots,M \f$. + /// + /// @param a Time–scaled parameter (typically \f$ a=t\,4^{\mathfrak n} \f$), must be positive. + /// @param scaling Level \f$ n \f$ that defines \f$ N=2^n \f$ distinct nonnegative shifts + /// (the negative ones are added symmetrically after them). + /// @param M Highest power index in the sequence (inclusive). Each stored vector + /// has length \f$ M+1 \f$ starting at \f$ m=0 \f$. + /// @param threshold Magnitude cutoff used by the private @ref crop routine to remove + /// negligible tail entries (if cropping is applied internally). + /// + /// @details + /// The internal ordering of the outer container is + /// \f$ l=0,1,\ldots,2^n-1, 1-2^n,\ldots,-1 \f$. This ordering is mirrored by + /// the @ref operator[] which will map negative indices to the appropriate + /// position of the storage. JpowerIntegrals(double a, int scaling, int M, double threshold = 1.0e-15); //JpowerIntegrals(const JpowerIntegrals& other); - - int scaling; //it is probably not used + /// @brief Scaling level \f$ n \f$ (kept for reference; not used directly in lookups). + int scaling; + + /// @brief Container of sequences \f$ \{\widetilde J_m(l,a)\}_{m=0}^M \f$ for all shifts \f$ l \f$. 
+ /// Each element is a vector of length \f$ M+1 \f$: + /// \code + /// integrals[idx_for_l] = { J_0(l), J_1(l), ..., J_M(l) } + /// \endcode std::vector>> integrals; + /// @brief Mutable access to the precomputed sequence for a given shift \f$ l \f$. + /// + /// @param index Integer shift \f$ l \in [-(2^n-1), \ldots, 2^n-1] \f$. + /// @return Reference to the vector \f$ [J_0(l), \ldots, J_M(l)] \f$. + /// + /// @details + /// Negative indices are transparently remapped to the internal storage order + /// (see the constructor’s documentation). This allows natural use like `obj[-3]`. std::vector> & operator[](int index); + private: + /// @brief Build one full sequence \f$ \{\widetilde J_m(l,a)\}_{m=0}^M \f$ for a fixed shift @p l. + /// + /// @param l Shift index. + /// @param a Time–scaled parameter (positive). + /// @param M Highest power index. + /// @param threshold Magnitude cutoff passed to @ref crop (if enabled). + /// + /// @return A vector with entries \f$ [J_0(l), J_1(l), \ldots, J_M(l)] \f$. + /// + /// @details + /// The routine uses the closed–form seed \f$ \widetilde J_0(l,a) \f$ and the + /// recurrence relation to fill the sequence up to \f$ m=M \f$. std::vector> calculate_J_power_integrals(int l, double a, int M, double threshold); + + /// @brief Remove negligible tail entries from a sequence in place. + /// + /// @param J The sequence to be cropped (modified in place). + /// @param threshold Entries with both real and imaginary parts below @p threshold + /// in absolute value are considered negligible. + /// + /// @details + /// Cropping can be used to shrink \f$ [J_0,\ldots,J_M] \f$ to + /// \f$ [J_0,\ldots,J_{m^\*}] \f$ once the tail has decayed under the requested tolerance. 
void crop(std::vector> & J, double threshold); }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/functions/LegendrePoly.cpp b/src/functions/LegendrePoly.cpp index 4dce049b0..e785449f3 100644 --- a/src/functions/LegendrePoly.cpp +++ b/src/functions/LegendrePoly.cpp @@ -42,12 +42,10 @@ namespace mrcpp { using LegendreCache = ObjectCache; -/** Legendre polynomial constructed on [-1,1] and - * scaled by n and translated by l */ LegendrePoly::LegendrePoly(int k, double n, double l) : Polynomial(k) { // Since we create Legendre polynomials recursively on [-1,1] - // we cache all lower order polynomilas for future use. + // we cache all lower order polynomials for future use. LegendreCache &Cache = LegendreCache::getInstance(); if (k >= 1) { if (not Cache.hasId(k - 1)) { @@ -63,7 +61,6 @@ LegendrePoly::LegendrePoly(int k, double n, double l) dilate(n); } -/** Compute Legendre polynomial coefs on interval [-1,1] */ void LegendrePoly::computeLegendrePolynomial(int k) { assert(this->size() >= k); if (k == 0) { @@ -91,9 +88,6 @@ void LegendrePoly::computeLegendrePolynomial(int k) { } } -/** Calculate the value of an n:th order Legendre polynominal in x, including - * the first derivative. - */ Vector2d LegendrePoly::firstDerivative(double x) const { double c1, c2, c4, ym, yp, y; double dy, dyp, dym; @@ -139,9 +133,6 @@ Vector2d LegendrePoly::firstDerivative(double x) const { return val; } -/** Calculate the value of an n:th order Legendre polynominal in x, including - * first and second derivatives. - */ Vector3d LegendrePoly::secondDerivative(double x) const { NOT_IMPLEMENTED_ABORT; double c1, c2, c4, ym, yp, y, d2y; diff --git a/src/functions/LegendrePoly.h b/src/functions/LegendrePoly.h index 8e1bbd2a9..c719bbe8f 100644 --- a/src/functions/LegendrePoly.h +++ b/src/functions/LegendrePoly.h @@ -7,8 +7,8 @@ * This file is part of MRCPP. 
* * MRCPP is free software: you can redistribute it and/or modify - * it under the terms of the GNU Lesser General Public License as published by - * the Free Software Foundation, either version 3 of the License, or + * it under the terms of the GNU Lesser General Public License + * as published by the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * MRCPP is distributed in the hope that it will be useful, @@ -29,15 +29,40 @@ namespace mrcpp { + /** + * @class LegendrePoly + * @brief Class defining a Legendre polynomial of degree k + */ class LegendrePoly final : public Polynomial { public: + /** + * @brief Construct and compute a Legendre polynomial of degree k + * @param k Degree (order) of the Legendre polynomial + * @param n Dilation factor (applied after translation) + * @param l Translation (applied before dilation) + */ LegendrePoly(int k, double n = 1.0, double l = 0.0); + /** + * @brief Evaluate value and first derivative of this Legendre polynomial in x + * @param x External evaluation point + * @return Value and first derivative as an Eigen::Vector2d + */ Eigen::Vector2d firstDerivative(double x) const; + + /** + * @brief Evaluate second derivative of this Legendre polynomial in x + * @param x External evaluation point + * @return Value, first and second derivative as an Eigen::Vector3d + */ Eigen::Vector3d secondDerivative(double x) const; private: + /** + * @brief Recursively compute the Legendre polynomial of order k on interval [-1,1] + * @param k Order of the Legendre polynomial + */ void computeLegendrePolynomial(int k); }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/functions/Polynomial.cpp b/src/functions/Polynomial.cpp index c54acc148..08d8cdfaf 100644 --- a/src/functions/Polynomial.cpp +++ b/src/functions/Polynomial.cpp @@ -24,12 +24,9 @@ */ /** - * * \date Jun 7, 2009 * \author Jonas Juselius \n * CTCC, University of Tromsø - * - * */ #include 
@@ -42,8 +39,6 @@ using namespace Eigen; namespace mrcpp { -/** Construct polynomial of order zero with given size and bounds. - * Includes default constructor. */ Polynomial::Polynomial(int k, const double *a, const double *b) : RepresentableFunction<1, double>(a, b) { assert(k >= 0); @@ -60,7 +55,6 @@ Polynomial::Polynomial(double c, int k, const double *a, const double *b) for (int i = 0; i <= k; i++) { this->coefs[i] *= std::pow(c, k - i); } } -/** Construct polynomial with given coefficient vector and bounds. */ Polynomial::Polynomial(const VectorXd &c, const double *a, const double *b) : RepresentableFunction<1>(a, b) { this->N = 1.0; @@ -68,7 +62,6 @@ Polynomial::Polynomial(const VectorXd &c, const double *a, const double *b) setCoefs(c); } -/** Makes a complete copy of the polynomial */ Polynomial::Polynomial(const Polynomial &poly) : RepresentableFunction<1>(poly) { this->N = poly.N; @@ -76,7 +69,6 @@ Polynomial::Polynomial(const Polynomial &poly) this->coefs = poly.coefs; } -/** Copies only the function, not its bounds */ Polynomial &Polynomial::operator=(const Polynomial &poly) { RepresentableFunction<1>::operator=(poly); this->N = poly.N; @@ -85,7 +77,6 @@ Polynomial &Polynomial::operator=(const Polynomial &poly) { return *this; } -/** Evaluate scaled and translated polynomial */ double Polynomial::evalf(double x) const { if (isBounded()) { if (x < this->getScaledLowerBound()) return 0.0; @@ -100,35 +91,28 @@ double Polynomial::evalf(double x) const { return y; } -/** This returns the actual scaled lower bound */ double Polynomial::getScaledLowerBound() const { if (not isBounded()) MSG_ERROR("Unbounded polynomial"); return (1.0 / this->N * (this->A[0] + this->L)); } -/** This returns the actual scaled upper bound */ double Polynomial::getScaledUpperBound() const { if (not isBounded()) MSG_ERROR("Unbounded polynomial"); return (1.0 / this->N * (this->B[0] + this->L)); } -/** Divide by norm of (bounded) polynomial. 
*/ void Polynomial::normalize() { double sqNorm = calcSquareNorm(); if (sqNorm < 0.0) MSG_ABORT("Cannot normalize polynomial"); (*this) *= 1.0 / std::sqrt(sqNorm); } -/** Compute the squared L2-norm of the (bounded) polynomial. - * Unbounded polynomials return -1.0. */ double Polynomial::calcSquareNorm() { double sqNorm = -1.0; if (isBounded()) { sqNorm = this->innerProduct(*this); } return sqNorm; } -/** Returns the order of the highest non-zero coef. - * NB: Not the length of the coefs vector. */ int Polynomial::getOrder() const { int n = 0; for (int i = 0; i < this->coefs.size(); i++) { @@ -137,13 +121,11 @@ int Polynomial::getOrder() const { return n; } -/** Calculate P = c*P */ Polynomial &Polynomial::operator*=(double c) { this->coefs = c * this->coefs; return *this; } -/** Calculate P = P*Q */ Polynomial &Polynomial::operator*=(const Polynomial &Q) { Polynomial &P = *this; if (std::abs(P.getDilation() - Q.getDilation()) > MachineZero) { MSG_ERROR("Polynomials not defined on same scale."); } @@ -160,7 +142,6 @@ Polynomial &Polynomial::operator*=(const Polynomial &Q) { return P; } -/** Calculate Q = c*P */ Polynomial Polynomial::operator*(double c) const { const Polynomial &P = *this; Polynomial Q(P); @@ -168,8 +149,6 @@ Polynomial Polynomial::operator*(double c) const { return Q; } -/** Calculate R = P*Q. - * Returns unbounded polynomial. */ Polynomial Polynomial::operator*(const Polynomial &Q) const { const Polynomial &P = *this; Polynomial R; @@ -178,19 +157,16 @@ Polynomial Polynomial::operator*(const Polynomial &Q) const { return R; } -/** Calculate P = P + Q. */ Polynomial &Polynomial::operator+=(const Polynomial &Q) { this->addInPlace(1.0, Q); return *this; } -/** Calculate P = P - Q. */ Polynomial &Polynomial::operator-=(const Polynomial &Q) { this->addInPlace(-1.0, Q); return *this; } -/** Calculate P = P + c*Q. 
*/ void Polynomial::addInPlace(double c, const Polynomial &Q) { Polynomial &P = *this; if (std::abs(P.getDilation() - Q.getDilation()) > MachineZero) { MSG_ERROR("Polynomials not defined on same scale."); } @@ -208,8 +184,6 @@ void Polynomial::addInPlace(double c, const Polynomial &Q) { P.setCoefs(newCoefs); } -/** Calculate R = P + c*Q, with a default c = 1.0. - * Returns unbounded polynomial. */ Polynomial Polynomial::add(double c, const Polynomial &Q) const { const Polynomial &P = *this; Polynomial R; @@ -218,7 +192,6 @@ Polynomial Polynomial::add(double c, const Polynomial &Q) const { return R; } -/** Calculate Q = dP/dx */ Polynomial Polynomial::calcDerivative() const { const Polynomial &P = *this; Polynomial Q(P); @@ -226,7 +199,6 @@ Polynomial Polynomial::calcDerivative() const { return Q; } -/** Calculate P = dP/dx */ void Polynomial::calcDerivativeInPlace() { Polynomial &P = *this; int P_order = P.getOrder(); @@ -236,7 +208,6 @@ void Polynomial::calcDerivativeInPlace() { P.setCoefs(newCoefs); } -/** Calculate indefinite integral Q = \int dP dx, integration constant set to zero */ Polynomial Polynomial::calcAntiDerivative() const { const Polynomial &P = *this; Polynomial Q(P); @@ -244,7 +215,6 @@ Polynomial Polynomial::calcAntiDerivative() const { return Q; } -/** Calculate indefinite integral P = \int dP dx, integration constant set to zero */ void Polynomial::calcAntiDerivativeInPlace() { Polynomial &P = *this; int P_order = P.getOrder(); @@ -256,7 +226,6 @@ void Polynomial::calcAntiDerivativeInPlace() { P.setCoefs(newCoefs); } -/** Integrate the polynomial P on [a,b] analytically */ double Polynomial::integrate(const double *a, const double *b) const { double lb = -DBL_MAX, ub = DBL_MAX; if (this->isBounded()) { @@ -275,7 +244,6 @@ double Polynomial::integrate(const double *a, const double *b) const { return sfac * (antidiff.evalf(ub) - antidiff.evalf(lb)); } -/** Compute analytically on interval defined by the calling polynomial. 
*/ double Polynomial::innerProduct(const Polynomial &Q) const { const Polynomial &P = *this; if (not P.isBounded()) MSG_ERROR("Unbounded polynomial"); diff --git a/src/functions/Polynomial.h b/src/functions/Polynomial.h index 93e3ec77d..ecdcfdab3 100644 --- a/src/functions/Polynomial.h +++ b/src/functions/Polynomial.h @@ -23,17 +23,6 @@ * */ -/** - * - * Base class for general polynomials with reasonably advanced - * properties. The Polynomial class(es) are not implemented in the - * most efficient manner, because they are only evaluated a fixed - * number of times in a few predefined points, and all other - * evaluations are done by linear transformations. PolynomialCache - * implements the fast, and static const versions of the various - * 4Polynomials. - */ - #pragma once #include @@ -44,72 +33,237 @@ namespace mrcpp { +/** + * @class Polynomial + * + * @brief Base class for general polynomials + * + * @details The Polynomial class(es) are not implemented in the + * most efficient manner, because they are only evaluated a fixed + * number of times in a few predefined points, and all other + * evaluations are done by linear transformations. PolynomialCache + * implements the fast, and static const versions of the various + * Polynomials.
+ */ class Polynomial : public RepresentableFunction<1, double> { public: + /** + * @brief Construct polynomial of order zero with given bounds + * @param k Order of the polynomial + * @param a Lower bound in x as raw pointer + * @param b Upper bound in x as raw pointer + */ Polynomial(int k = 0, const double *a = nullptr, const double *b = nullptr); + + /** + * @brief Construct polynomial of order k with given bounds + * @param k Order of the polynomial + * @param a Lower bound in x as vector + * @param b Upper bound in x as vector + */ Polynomial(int k, const std::vector &a, const std::vector &b) : Polynomial(k, a.data(), b.data()) {} - Polynomial(const Eigen::VectorXd &c, const double *a = nullptr, const double *b = nullptr); - Polynomial(const Eigen::VectorXd &c, const std::vector &a, const std::vector &b) - : Polynomial(c, a.data(), b.data()) {} + + /** + * @brief Construct polynomial with given coefficient, order and bounds + * @param c Coefficient of the polynomial + * @param k Order of the polynomial + * @param a Lower bound in x as raw pointer + * @param b Upper bound in x as raw pointer + */ Polynomial(double c, int k = 0, const double *a = nullptr, const double *b = nullptr); + + /** + * @brief Construct polynomial with given coefficient, order and bounds + * @param c Coefficient of the polynomial + * @param k Order of the polynomial + * @param a Lower bound in x as vector + * @param b Upper bound in x as vector + */ Polynomial(double c, int k, const std::vector &a, const std::vector &b) : Polynomial(c, k, a.data(), b.data()) {} + + /** + * @brief Construct polynomial with given coefficient vector and bounds + * @param c Coefficient vector + * @param a Lower bound in x as raw pointer + * @param b Upper bound in x as raw pointer + */ + Polynomial(const Eigen::VectorXd &c, const double *a = nullptr, const double *b = nullptr); + + /** + * @brief Construct polynomial with given coefficient vector and bounds + * @param c Coefficient vector + * @param a Lower 
bound in x as vector + * @param b Upper bound in x as vector + */ + Polynomial(const Eigen::VectorXd &c, const std::vector &a, const std::vector &b) + : Polynomial(c, a.data(), b.data()) {} + + /** @brief Copy constructor */ Polynomial(const Polynomial &poly); + /** @brief Assignment operator, copies only the function, not its bounds */ Polynomial &operator=(const Polynomial &poly); + /** @brief Virtual destructor */ virtual ~Polynomial() = default; + /** + * @brief Evaluate scaled and translated polynomial + * @param x External evaluation point + * @return The polynomial value at x + */ double evalf(double x) const; + + /** + * @brief Evaluate scaled and translated polynomial at a given point + * @param r 1D-Cartesian coordinate + * @return The polynomial value at r + */ double evalf(const Coord<1> &r) const { return evalf(r[0]); } - double getScaledLowerBound() const; - double getScaledUpperBound() const; + double getScaledLowerBound() const; ///< @return The actual scaled lower bound + double getScaledUpperBound() const; ///< @return The actual scaled upper bound + + void normalize(); ///< @brief Divide by norm of (bounded) polynomial - void normalize(); + /** + * @brief Calculate the squared L2 norm of the (bounded) polynomial + * @return Squared L2 norm, -1 if unbounded + */ double calcSquareNorm(); - double getTranslation() const { return this->L; } - double getDilation() const { return this->N; } + double getTranslation() const { return this->L; } ///< @return Current translation + double getDilation() const { return this->N; } ///< @return Current dilation - void setDilation(double n) { this->N = n; } - void setTranslation(double l) { this->L = l; } - void dilate(double n) { this->N *= n; } - void translate(double l) { this->L += this->N * l; } + void setDilation(double n) { this->N = n; } ///< @brief Set dilation factor N + void setTranslation(double l) { this->L = l; } ///< @brief Set translation L + void dilate(double n) { this->N *= n; } ///< @brief Dilate
by factor n + void translate(double l) { this->L += this->N * l; } ///< @brief Translate by l - int size() const { return this->coefs.size(); } ///< Length of coefs vector - int getOrder() const; - void clearCoefs() { this->coefs = Eigen::VectorXd::Zero(1); } - void setZero() { this->coefs = Eigen::VectorXd::Zero(this->coefs.size()); } - void setCoefs(const Eigen::VectorXd &c) { this->coefs = c; } + int size() const { return this->coefs.size(); } ///< @return The size of the coefficient vector + int getOrder() const; ///< @return The order of the highest non-zero coefficient + + void clearCoefs() { this->coefs = Eigen::VectorXd::Zero(1); } ///< @brief Clear all coefficients + void setZero() { this->coefs = Eigen::VectorXd::Zero(this->coefs.size()); } ///< @brief Set all coefficients to zero + void setCoefs(const Eigen::VectorXd &c) { this->coefs = c; } ///< @brief Replace the coefficient vector with a new one - Eigen::VectorXd &getCoefs() { return this->coefs; } - const Eigen::VectorXd &getCoefs() const { return this->coefs; } + Eigen::VectorXd &getCoefs() { return this->coefs; } ///< @return The coefficient vector + const Eigen::VectorXd &getCoefs() const { return this->coefs; } ///< @return The coefficient vector (const version) + /** + * @brief Calculates the derivative \f$ Q = dP/dx \f$ of this polynomial + * @return The derivative polynomial Q + */ Polynomial calcDerivative() const; + + /** + * @brief Calculates the indefinite integral \f$ Q = \int P\,dx \f$ of this polynomial, with constant = 0 + * @return The indefinite integral polynomial Q + */ Polynomial calcAntiDerivative() const; + /** + * @brief Calculates the derivative \f$ P \leftarrow dP/dx \f$ of this polynomial in-place + * @details Replaces the current polynomial with its derivative, i.e. \f$ P \leftarrow dP/dx \f$. 
+ */ void calcDerivativeInPlace(); + + /** + * @brief Calculates the indefinite integral \f$ P \leftarrow \int P\,dx \f$ of this polynomial in-place, with constant = 0 + * @details Replaces the current polynomial with its indefinite integral, i.e. \f$ P \leftarrow \int P\,dx \f$, with integration constant set to zero. + */ void calcAntiDerivativeInPlace(); + /** + * @brief Calculates the analytical integral of P on [a, b] + * @param a Lower bound of the integration interval, defaults to the polynomial's lower bound + * @param b Upper bound of the integration interval, defaults to the polynomial's upper bound + * @return The integral of P on [a, b] + */ double integrate(const double *a = 0, const double *b = 0) const; - double innerProduct(const Polynomial &p) const; + + /** + * @brief Analytically calculates the inner product of this polynomial with another one + * @param Q The other polynomial + * @return The inner product of the two polynomials + */ + double innerProduct(const Polynomial &Q) const; + /** + * @brief In-place sum \f$ P \leftarrow P + c\,Q \f$. + * @param c Scalar multiplier for Q + * @param Q The polynomial to be added to P + */ void addInPlace(double c, const Polynomial &Q); + + /** + * @brief Sum \f$ R = P + c\,Q \f$. 
+ * @param c Scalar multiplier for Q + * @param Q The polynomial to be added to P + * @return The resulting polynomial + */ Polynomial add(double c, const Polynomial &Q) const; + /** + * @brief Scalar product of Polynomial with c + * @param c The scalar multiplier + * @return The resulting polynomial + */ Polynomial operator*(double c) const; + + /** + * @brief Product of two Polynomials + * @param Q The other polynomial + * @return The resulting (unbounded) polynomial + */ Polynomial operator*(const Polynomial &Q) const; + + /** + * @brief Sum two Polynomials + * @param Q The other polynomial + * @return The resulting polynomial + */ Polynomial operator+(const Polynomial &Q) const { return add(1.0, Q); } + + /** + * @brief Difference of two Polynomials + * @param Q The other polynomial + * @return The resulting polynomial + */ Polynomial operator-(const Polynomial &Q) const { return add(-1.0, Q); } + + /** + * @brief In-place scalar product. + * @param c The scalar multiplier + * @return Reference to the modified polynomial + */ Polynomial &operator*=(double c); + + /** + * @brief In-place product of two Polynomials + * @param Q The other polynomial + * @return Reference to the modified polynomial + */ Polynomial &operator*=(const Polynomial &Q); + + /** + * @brief In-place sum of two Polynomials + * @param Q The other polynomial + * @return Reference to the modified polynomial + */ Polynomial &operator+=(const Polynomial &Q); + + /** + * @brief In-place difference of two Polynomials + * @param Q The other polynomial + * @return Reference to the modified polynomial + */ Polynomial &operator-=(const Polynomial &Q); protected: - double N; ///< Dilation coeff - double L; ///< Translation coeff + double N; ///< Dilation coefficient + double L; ///< Translation coefficient Eigen::VectorXd coefs; ///< Expansion coefficients }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/functions/RepresentableFunction.h 
b/src/functions/RepresentableFunction.h index 6123e3051..dbadeb5a1 100644 --- a/src/functions/RepresentableFunction.h +++ b/src/functions/RepresentableFunction.h @@ -24,9 +24,28 @@ */ /* + * # RepresentableFunction (interface) * - * Base class of functions that is representable in the mw basis. - * This includes gaussians, expansions, polynomials and even function trees. + * Base interface for objects that can be **represented/evaluated** in the + * multiresolution (multiwavelet) framework. Typical implementations include + * analytic functors, Gaussian(-like) functions/expansions, polynomials and + * function trees. + * + * ## Bounding box semantics + * A function may be marked **bounded** on a Cartesian product of *half-open* + * intervals: + * + * Π_d [ A_d, B_d ) + * + * The half-open convention prevents double counting on shared cell faces and + * is used consistently by `outOfBounds()`. If a function is **unbounded**, its + * bounds pointers are `nullptr` and containment checks always succeed. + * + * ## Lifetime & copying + * - Bounds (arrays `A`, `B` of length `D`) are owned by the instance when set. + * - The copy constructor **deep-copies** the bounds (if any). + * - The assignment operator in the base class returns `*this` (does not copy + * bounds), leaving copying policy to derived classes if needed. */ #pragma once @@ -42,50 +61,158 @@ namespace mrcpp { +/** + * @tparam D Spatial dimension (1, 2, 3, …). + * @tparam T Value type returned by the function (e.g. `double`, + * complex types, etc.). + * + * @brief Abstract base class for functions evaluable in the multiwavelet basis. + * + * The class provides **optional bounding boxes** and related helpers, while + * deferring the actual evaluation to @ref evalf implemented by derived types. + */ template class RepresentableFunction { public: + /** + * @name Construction & assignment + * @{ + */ + + /** + * @brief Construct with optional bounds. 
+ * + * If either `a` or `b` is `nullptr`, the function is created unbounded. + * Otherwise, `A[d]=a[d]` and `B[d]=b[d]` are deep-copied and the function + * becomes bounded. Each dimension is validated to satisfy `a[d] ≤ b[d]`. + * + * @param a Lower bounds array of length `D` or `nullptr`. + * @param b Upper bounds array of length `D` or `nullptr`. + */ RepresentableFunction(const double *a = nullptr, const double *b = nullptr); + + /// Convenience constructor from `std::vector` bounds. RepresentableFunction(const std::vector &a, const std::vector &b) : RepresentableFunction(a.data(), b.data()) {} + + /** + * @brief Copy-construct, including bounds if present. + * + * Deep-copies `A` and `B` when `func` is bounded; otherwise remains unbounded. + */ RepresentableFunction(const RepresentableFunction &func); + + /** + * @brief Assignment operator (base). + * + * The base implementation **does not** copy bounds and simply returns `*this`. + * Derived classes may extend this behavior to copy additional state. + */ RepresentableFunction &operator=(const RepresentableFunction &func); - virtual ~RepresentableFunction(); - /** @returns Function value in a point @param[in] r: Cartesian coordinate */ + /// Virtual destructor releases bound storage if allocated. + virtual ~RepresentableFunction(); + /** @} */ + + /** + * @brief Evaluate the function at a given point. + * @param r Cartesian coordinate (length-`D`). + * @returns The function value at `r`. + * + * Derived classes should usually check @ref outOfBounds before performing + * expensive work and return a zero value outside the active domain. + */ virtual T evalf(const Coord &r) const = 0; + /** + * @name Bounds management + * @{ + */ + + /** + * @brief Set (or overwrite) bounds. + * + * Allocates and stores deep copies of `a` and `b` (length `D`) if not already + * bounded. Validates that `a[d] ≤ b[d]` for all `d`. 
+ */ void setBounds(const double *a, const double *b); + + /** + * @brief Clear bounds and mark the function unbounded. + * + * After this call, @ref isBounded returns `false` and @ref outOfBounds + * will always return `false`. + */ void clearBounds(); + /// @returns `true` if the function has active bounds, `false` otherwise. bool isBounded() const { return this->bounded; } + + /** + * @brief Test whether a point lies outside the active bounds. + * + * Implements the **half-open** check for each coordinate: + * `r[d] < A[d] || r[d] >= B[d]`. If the function is unbounded, + * this always returns `false`. + */ bool outOfBounds(const Coord &r) const; + /// @returns Lower bound in dimension `d` (requires @ref isBounded). double getLowerBound(int d) const { return this->A[d]; } + /// @returns Upper bound in dimension `d` (requires @ref isBounded). double getUpperBound(int d) const { return this->B[d]; } + /// @returns Pointer to the lower bounds array (length `D`) or `nullptr` if unbounded. const double *getLowerBounds() const { return this->A; } + /// @returns Pointer to the upper bounds array (length `D`) or `nullptr` if unbounded. const double *getUpperBounds() const { return this->B; } + /** @} */ + /// @note Bridge/adapter that may require direct access to bounds. friend class AnalyticAdaptor; protected: - bool bounded; - double *A; ///< Lower bound, NULL if unbounded - double *B; ///< Upper bound, Null if unbounded - + /** @name Internal state + * @{ + */ + bool bounded; ///< `true` if the function is currently bounded. + double *A; ///< Lower bounds (owned; `nullptr` if unbounded). + double *B; ///< Upper bounds (owned; `nullptr` if unbounded). + /** @} */ + + /** + * @brief Optional visibility hint used by some projection routines. + * @returns `true` when the function is expected to contribute at a given scale. 
+ */ virtual bool isVisibleAtScale(int scale, int nQuadPts) const { return true; } + + /** + * @brief Optional fast zero-test on an interval (per dimension). + * @returns `true` if the function is provably zero on `[a,b]` (component-wise). + */ virtual bool isZeroOnInterval(const double *a, const double *b) const { return false; } }; -/* - * Same as RepresentableFunction, but output a matrix of values - * for all points in a node, given its NodeIndex. +/** + * @brief Matrix-valued evaluation interface. * + * A companion interface that asks an object to produce a **batch evaluation** + * over all quadrature points associated with a tree node, returning a matrix + * whose layout is decided by the concrete implementation. + * + * This is useful for high-throughput projection steps where per-point + * overhead must be minimized. */ class RepresentableFunction_M { public: RepresentableFunction_M() {} + + /** + * @brief Evaluate at all points described by a node index. + * @param nIdx Node index (scale and translation), typically defines the + * evaluation grid/points. + * @returns A matrix of values (shape and semantics are implementation-defined). 
+ */ virtual Eigen::MatrixXd evalf(mrcpp::NodeIndex<3> nIdx) const = 0; }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/functions/function_utils.cpp b/src/functions/function_utils.cpp index 598c9b12a..23806285d 100644 --- a/src/functions/function_utils.cpp +++ b/src/functions/function_utils.cpp @@ -27,10 +27,6 @@ namespace mrcpp { -namespace function_utils { -double ObaraSaika_ab(int power_a, int power_b, double pos_a, double pos_b, double expo_a, double expo_b); -} // namespace function_utils - template double function_utils::calc_overlap(const GaussFunc &a, const GaussFunc &b) { double S = 1.0; for (int d = 0; d < D; d++) { S *= ObaraSaika_ab(a.getPower()[d], b.getPower()[d], a.getPos()[d], b.getPos()[d], a.getExp()[d], b.getExp()[d]); } @@ -38,29 +34,10 @@ template double function_utils::calc_overlap(const GaussFunc &a, cons return S; } -/** Compute the monodimensional overlap integral between two - gaussian distributions by means of the Obara-Saika recursiive - scheme - - \f[ S_{ij} = \int_{-\infty}^{+\infty} \,\mathrm{d} x - (x-x_a)^{p_a} - (x-x_b)^{p_b} - e^{-c_a (x-x_a)^2} - e^{-c_b (x-x_b)^2}\f] - - @param power_a \f$ p_a \f$ - @param power_b \f$ p_b \f$ - @param pos_a \f$ x_a \f$ - @param pos_b \f$ x_b \f$ - @param expo_a \f$ c_a \f$ - @param expo_b \f$ c_b \f$ - - */ double function_utils::ObaraSaika_ab(int power_a, int power_b, double pos_a, double pos_b, double expo_a, double expo_b) { int i, j; double expo_p, mu, pos_p, x_ab, x_pa, x_pb, s_00; - /* The highest angular momentum combination is l=20 for a and b - * simulatnelusly */ + // The highest angular momentum combination is l=20 for a and b simultaneously double s_coeff[64]; // if (out_of_bounds(power_a, 0, MAX_GAUSS_POWER) || @@ -70,36 +47,36 @@ double function_utils::ObaraSaika_ab(int power_a, int power_b, double pos_a, dou // INVALID_ARG_EXIT; // } - /* initialization of a hell of a lot of coefficients.... 
*/ - expo_p = expo_a + expo_b; /* total exponent */ - mu = expo_a * expo_b / (expo_a + expo_b); /* reduced exponent */ - pos_p = (expo_a * pos_a + expo_b * pos_b) / expo_p; /* center of charge */ - x_ab = pos_a - pos_b; /* X_{AB} */ - x_pa = pos_p - pos_a; /* X_{PA} */ - x_pb = pos_p - pos_b; /* X_{PB} */ + // initialization of a hell of a lot of coefficients.... + expo_p = expo_a + expo_b; // total exponent + mu = expo_a * expo_b / (expo_a + expo_b); // reduced exponent + pos_p = (expo_a * pos_a + expo_b * pos_b) / expo_p; // center of charge + x_ab = pos_a - pos_b; // X_{AB} + x_pa = pos_p - pos_a; // X_{PA} + x_pb = pos_p - pos_b; // X_{PB} s_00 = pi / expo_p; - s_00 = std::sqrt(s_00) * std::exp(-mu * x_ab * x_ab); /* overlap of two spherical gaussians */ - // int n_0j_coeff = 1 + power_b; /* n. of 0j coefficients needed */ - // int n_ij_coeff = 2 * power_a; /* n. of ij coefficients needed (i > 0) */ + s_00 = std::sqrt(s_00) * std::exp(-mu * x_ab * x_ab); // overlap of two spherical gaussians + // int n_0j_coeff = 1 + power_b; // n. of 0j coefficients needed + // int n_ij_coeff = 2 * power_a; // n. of ij coefficients needed (i > 0) - /* we add 3 coeffs. to avoid a hell of a lot of if statements */ - /* n_tot_coeff = n_0j_coeff + n_ij_coeff + 3; */ - /* s_coeff = (double *) calloc(n_tot_coeff, sizeof(double));*/ + // we add 3 coeffs. 
to avoid a hell of a lot of if statements + // n_tot_coeff = n_0j_coeff + n_ij_coeff + 3; + // s_coeff = (double *) calloc(n_tot_coeff, sizeof(double)); - /* generate first two coefficients */ + // generate first two coefficients s_coeff[0] = s_00; s_coeff[1] = x_pb * s_00; j = 1; - /* generate the rest of the first row */ + // generate the rest of the first row while (j < power_b) { s_coeff[j + 1] = x_pb * s_coeff[j] + j * s_coeff[j - 1] / (2.0 * expo_p); j++; } - /* generate the first two coefficients with i > 0 */ + // generate the first two coefficients with i > 0 s_coeff[j + 1] = s_coeff[j] - x_ab * s_coeff[j - 1]; s_coeff[j + 2] = x_pa * s_coeff[j] + j * s_coeff[j - 1] / (2.0 * expo_p); i = 1; - /* generate the remaining coefficients with i > 0 */ + // generate the remaining coefficients with i > 0 while (i < power_a) { int i_l = j + 2 * i + 1; int i_r = j + 2 * i + 2; @@ -108,7 +85,7 @@ double function_utils::ObaraSaika_ab(int power_a, int power_b, double pos_a, dou i++; } - /* free(s_coeff);*/ + // free(s_coeff); return s_coeff[power_b + 2 * power_a]; } diff --git a/src/functions/function_utils.h b/src/functions/function_utils.h index 896c06257..c04baff22 100644 --- a/src/functions/function_utils.h +++ b/src/functions/function_utils.h @@ -27,7 +27,41 @@ #include "Gaussian.h" namespace mrcpp { + +// Forward declaration only: definition is provided in function_utils.cpp. +// Keeping this here avoids heavy includes and potential include cycles. 
namespace function_utils { +/** + * @brief Compute the monodimensional overlap integral between two + * gaussian distributions by means of the Obara-Saika recursive + * scheme + * + * \f[ S_{ij} = \int_{-\infty}^{+\infty} \,\mathrm{d} x + * (x-x_a)^{p_a} + * (x-x_b)^{p_b} + * e^{-c_a (x-x_a)^2} + * e^{-c_b (x-x_b)^2} \f] + * + * @param power_a \f$ p_a \f$ + * @param power_b \f$ p_b \f$ + * @param pos_a \f$ x_a \f$ + * @param pos_b \f$ x_b \f$ + * @param expo_a \f$ c_a \f$ + * @param expo_b \f$ c_b \f$ + * + * @return The value of the overlap integral as a double + */ +double ObaraSaika_ab(int power_a, int power_b, double pos_a, double pos_b, double expo_a, double expo_b); + +/** + * @brief Compute the overlap integral between two Gaussian functions. + * + * @param[in] a The first Gaussian function + * @param[in] b The second Gaussian function + * + * @return The value of the overlap integral + */ template double calc_overlap(const GaussFunc &a, const GaussFunc &b); } // namespace function_utils -} // namespace mrcpp + +} // namespace mrcpp \ No newline at end of file diff --git a/src/functions/special_functions.cpp b/src/functions/special_functions.cpp index 555528a58..8c8c21972 100644 --- a/src/functions/special_functions.cpp +++ b/src/functions/special_functions.cpp @@ -29,31 +29,7 @@ namespace mrcpp { -/** @brief Free-particle time evolution on real line. - * - * @param[in] x: space coordinate in \f$ \mathbb R \f$. - * @param[in] x0: \f$ x_0 \f$ center of gaussian function at zero time moment. - * @param[in] t: time moment. - * @param[in] sigma: \f$ \sigma \f$ width of the initial gaussian wave. - * - * @details Analytical solution of a one dimensional free-particle - * movement - * \f[ - * \psi(x, t) - * = - * \sqrt{ - * \frac{ \sigma }{ 4it + \sigma } - * } - * e^{ - \frac { (x - x_0)^2 }{ 4it + \sigma } } - * \f] - * where \f$ t, \sigma > 0 \f$. 
- * - * @returns The complex-valued wave function - * \f$ \psi(x, t) \f$ - * at the specified space coordinate and time. - * - * - */ + std::complex free_particle_analytical_solution(double x, double x0, double t, double sigma) { std::complex i(0.0, 1.0); // Imaginary unit @@ -64,35 +40,12 @@ std::complex free_particle_analytical_solution(double x, double x0, doub return std::sqrt(sigma) / sqrt_denom * std::exp(exponent); } - - -/** @brief A smooth compactly supported non-negative function. - * - * @param[in] x: space coordinate in \f$ \mathbb R \f$. - * @param[in] a: the left support boundary. - * @param[in] b: the right support boundary. - * - * @details Smooth function on the real line \f$ \mathbb R \f$ - * defined by the formula - * \f[ - * g_{a,b} (x) = \exp \left( - \frac{b - a}{(x - a)(b - x)} \right) - * , \quad - * a < x < b - * \f] - * and \f$ g_{a,b} (x) = 0 \f$ elsewhere. - * - * @returns The non-negative value - * \f$ g_{a,b} (x) \f$ - * at the specified space coordinate \f$ x \in \mathbb R \f$. - * - * - */ double smooth_compact_function(double x, double a, double b) { double res = 0; if (a < x && x < b) { res = exp((a - b) / (x - a) / (b - x)); } return res; -} +} } // namespace mrcpp \ No newline at end of file diff --git a/src/functions/special_functions.h b/src/functions/special_functions.h index 4c2f68ac3..d2043cb20 100644 --- a/src/functions/special_functions.h +++ b/src/functions/special_functions.h @@ -26,13 +26,58 @@ #pragma once #include -#include - +#include namespace mrcpp { +/** @brief Free-particle time evolution on real line. + * + * @param[in] x: space coordinate in \f$ \mathbb R \f$. + * @param[in] x0: \f$ x_0 \f$ center of gaussian function at zero time moment. + * @param[in] t: time moment. + * @param[in] sigma: \f$ \sigma \f$ width of the initial gaussian wave. 
+ * + * @details Analytical solution of a one dimensional free-particle + * movement + * \f[ + * \psi(x, t) + * = + * \sqrt{ + * \frac{ \sigma }{ 4it + \sigma } + * } + * e^{ - \frac { (x - x_0)^2 }{ 4it + \sigma } } + * \f] + * where \f$ t, \sigma > 0 \f$. + * + * @returns The complex-valued wave function + * \f$ \psi(x, t) \f$ + * at the specified space coordinate and time. + * + * + */ std::complex free_particle_analytical_solution(double x, double x0, double t, double sigma); +/** @brief A smooth compactly supported non-negative function. + * + * @param[in] x: space coordinate in \f$ \mathbb R \f$. + * @param[in] a: the left support boundary. + * @param[in] b: the right support boundary. + * + * @details Smooth function on the real line \f$ \mathbb R \f$ + * defined by the formula + * \f[ + * g_{a,b} (x) = \exp \left( - \frac{b - a}{(x - a)(b - x)} \right) + * , \quad + * a < x < b + * \f] + * and \f$ g_{a,b} (x) = 0 \f$ elsewhere. + * + * @returns The non-negative value + * \f$ g_{a,b} (x) \f$ + * at the specified space coordinate \f$ x \in \mathbb R \f$. + * + * + */ double smooth_compact_function(double x, double a = 0, double b = 1); } // namespace mrcpp \ No newline at end of file diff --git a/src/operators/ABGVOperator.h b/src/operators/ABGVOperator.h index 3cf85bbaa..ab76d882b 100644 --- a/src/operators/ABGVOperator.h +++ b/src/operators/ABGVOperator.h @@ -29,23 +29,80 @@ namespace mrcpp { -/** @class ABGVOperator +/** + * @class ABGVOperator + * @brief Multiresolution first-derivative operator of Alpert–Beylkin–Gines–Vozovoi. * - * @brief Derivative operator as defined by Alpert, Beylkin, Ginez and Vozovoi, - * J Comp Phys 182, 149-190 (2002). + * This class builds a **first-order differential operator** in the + * multiresolution (MR) basis defined by a given + * #mrcpp::MultiResolutionAnalysis. 
The discrete representation follows + * the construction in: * - * NOTE: This is the recommended derivative operator for "cuspy" or discontinuous - * functions. The BSOperator is recommended for smooth functions. + * - B. Alpert, G. Beylkin, D. Gines, and L. Vozovoi, + * *Adaptive Solution of Partial Differential Equations in Multiwavelet Bases*, + * J. Comput. Phys. **182** (2002) 149–190. + * + * ### When to use this operator + * - **Recommended** for functions with **cusps, kinks, or discontinuities**, + * where strictly smooth (BS) operators tend to produce Gibbs-type artifacts. + * - For **smooth** functions, prefer #mrcpp::BSOperator for slightly better + * accuracy/efficiency with smooth stencils. + * + * ### Boundary/stencil parameters \p a and \p b + * The parameters `(a, b)` control the local stencil asymmetry at element + * interfaces. Common choices: + * + * - `a = 0.0`, `b = 0.0` → strictly local “center” rule (bandwidth 0) + * - `a = 0.5`, `b = 0.5` → semi-local **central** difference (bandwidth 1) + * - `a = 1.0`, `b = 0.0` → semi-local **forward** difference (bandwidth 1) + * - `a = 0.0`, `b = 1.0` → semi-local **backward** difference (bandwidth 1) + * + * Any non-zero `a` or `b` widens the coupling to nearest neighbors (bandwidth = 1) + * across scales; this is enforced during assembly. + * + * ### Assembly and application + * Internally, the operator is assembled once into an #mrcpp::OperatorTree + * (stored in the base #mrcpp::DerivativeOperator). After construction, + * applying the operator to MR coefficient vectors is cheap and can be done + * repeatedly. + * + * @tparam D Spatial dimension (1, 2, or 3). */ - template class ABGVOperator final : public DerivativeOperator { public: + /** + * @brief Construct the ABGV derivative operator on a given MRA. + * + * The constructor triggers an internal `initialize(a, b)` routine that: + * 1. Decides the operator bandwidth from `(a, b)`. + * 2. 
Builds the operator matrix blocks using the MRA’s scaling basis. + * 3. Assembles an #mrcpp::OperatorTree with a bandwidth adaptor. + * 4. Finalizes and caches the representation for fast application. + * + * @param mra Multiresolution analysis defining the domain, basis and scales. + * @param a Left-side boundary/stencil parameter (see class docs). + * @param b Right-side boundary/stencil parameter (see class docs). + * + * @note The operator is built at the MRA’s **root scale** and is valid for + * coefficient vectors defined on the same MRA. + */ ABGVOperator(const MultiResolutionAnalysis &mra, double a, double b); + ABGVOperator(const ABGVOperator &oper) = delete; ABGVOperator &operator=(const ABGVOperator &oper) = delete; protected: + /** + * @brief Internal assembly routine (called by the constructor). + * + * Decides sparsity (bandwidth) from `(a, b)`, constructs the calculator + * implementing the ABGV derivative in the given scaling basis, and uses a + * `TreeBuilder` + `BandWidthAdaptor` to assemble and cache an operator tree. + * + * @param a Left boundary/stencil parameter. + * @param b Right boundary/stencil parameter. + */ void initialize(double a, double b); }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/operators/BSOperator.h b/src/operators/BSOperator.h index 873b5a3d8..d5aacb615 100644 --- a/src/operators/BSOperator.h +++ b/src/operators/BSOperator.h @@ -30,21 +30,84 @@ namespace mrcpp { /** @class BSOperator + * @ingroup operators * - * @brief B-spline derivative operator as defined by Anderson etal, J Comp Phys X 4, 100033 (2019). + * @brief Smooth multiresolution derivative operator (“BS” operator). * - * NOTE: This is the recommended derivative operator only for _smooth_ functions. - * Use the ABGVOperator if the function has known cusps or discontinuities. + * This class builds a derivative operator in the multiresolution scaling basis + * tailored for **smooth** functions. 
The discrete stencil is compact (nearest- + * neighbor bandwidth) and its local blocks are generated by the *BS* scheme + * (see Anderson *et al.*, J. Comp. Phys. X 4, 100033 (2019)). + * + * ### When to use + * - Prefer this operator when the target function is sufficiently smooth at + * the scales of interest (e.g. no strong cusps or jump discontinuities). + * - For functions with cusps/discontinuities, use #mrcpp::ABGVOperator instead, + * which is more robust in the non-smooth regime. + * + * ### What it builds internally + * The constructor triggers an assembly pipeline (via a hidden `initialize()`) + * that: + * 1. Creates a sparse, bandwidth-1 operator tree on the provided + * #mrcpp::MultiResolutionAnalysis (MRA). + * 2. Uses a calculator (BS formulation) to fill local operator blocks for the + * requested derivative order. + * 3. Finalizes and caches per-node data for fast application. + * + * ### Complexity & reuse + * - **Build**: one-time cost per (MRA, derivative order). + * - **Apply**: fast, cache-friendly application to MR coefficient vectors. + * + * @tparam D Spatial dimension (1, 2, or 3). + * + * @see mrcpp::ABGVOperator + * @see mrcpp::DerivativeOperator + * @see mrcpp::OperatorTree + * @see mrcpp::MultiResolutionAnalysis */ - template class BSOperator final : public DerivativeOperator { public: + /** + * @brief Construct a BS derivative operator on a given MRA. + * + * The operator is anchored to the MRA’s root scale (handled by the + * #mrcpp::DerivativeOperator base class) and immediately assembled. The + * derivative order typically supports 1, 2, or 3 (as provided by the BS + * calculator implementation). + * + * @param mra Multiresolution analysis defining basis, domain, and scales. + * @param order Derivative order (e.g., 1, 2, or 3). + * + * @note This operator assumes smoothness; if your target function has + * strong non-smooth features, consider #mrcpp::ABGVOperator. 
+ * + * @code + * MultiResolutionAnalysis<1> mra(...); + * BSOperator<1> Dx(mra, 1); // first derivative in 1D + * // apply Dx to a function tree / coefficient vector later... + * @endcode + */ explicit BSOperator(const MultiResolutionAnalysis &mra, int order); + + /// Deleted copy constructor: operators are heavyweight and own caches. explicit BSOperator(const BSOperator &oper) = delete; + /// Deleted assignment. BSOperator &operator=(const BSOperator &oper) = delete; protected: + /** + * @brief Assemble and cache the operator (implementation detail). + * + * Internal steps (performed once at construction): + * - Choose a compact bandwidth (nearest-neighbor coupling). + * - Use a BS-based calculator to generate local blocks for the requested + * derivative order. + * - Build a sparse #mrcpp::OperatorTree on the provided MRA. + * - Precompute norms and per-node caches for fast application. + * + * @warning This is not intended to be called by users directly. + */ void initialize(); }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/operators/CartesianConvolution.h b/src/operators/CartesianConvolution.h index 63d70a53a..626ccbf03 100644 --- a/src/operators/CartesianConvolution.h +++ b/src/operators/CartesianConvolution.h @@ -29,17 +29,106 @@ namespace mrcpp { +// Forward declaration to avoid pulling the full header into users of this file. +template class GaussExp; + +/** + * @class CartesianConvolution + * @brief 3D separable convolution operator assembled from a 1D Gaussian expansion. + * + * This operator represents a Cartesian, rank-R, separable convolution in 3D, + * where the separation rank R equals the number of terms in a provided 1D + * Gaussian expansion, `GaussExp<1>`. + * + * ### How it is constructed (see .cpp) + * The implementation builds three *blocks* of 1D operator trees from the + * same Gaussian expansion, corresponding to monomial prefactors of degree + * 0, 1 and 2 (i.e. 
powers `{0}`, `{1}`, `{2}`), and stores them + * contiguously. These blocks can then be assigned independently to the + * x/y/z axes, enabling vector/tensor kernels that differ only by the + * Cartesian polynomial factor. + * + * After construction, the total number of internally stored operator trees is + * `3 * sep_rank` (three monomial blocks, each of size `sep_rank`). + * + * ### Choosing the Cartesian components + * Use setCartesianComponents(x, y, z) to select which monomial block + * (0 → degree 0, 1 → degree 1, 2 → degree 2) is used along each axis. This + * *rewires* the already built 1D factors—no rebuilding occurs. + * + * ### Precision + * The constructor accepts a single build precision `prec`. The .cpp implementation + * employs a slightly stricter precision for fitting the 1D kernel terms so that + * the final composed 3D operator meets the requested tolerance. + * + * ### Ownership / lifetime + * The class does not take ownership of the input `GaussExp<1>`; it only reads it + * during construction. Internally created operator trees are owned by this object. + * + * ### Copy semantics + * Copying is disabled (non-copyable) because the underlying operator trees are + * heavy and managed resources. Move is not provided. + * + * ### Example + * @code + * MultiResolutionAnalysis<3> mra(...); + * GaussExp<1> kernel = ...; // ∑_{r=1}^R α_r e^{-β_r (x - x_r)^2} + * double prec = 1e-8; + * + * CartesianConvolution conv(mra, kernel, prec); + * // Use degree-1 along x, degree-0 along y, degree-2 along z: + * conv.setCartesianComponents(/* x = * / 1, /* y = * / 0, /* z = * / 2); + * + * // conv can now be applied as a separable 3D convolution operator. + * @endcode + */ class CartesianConvolution : public ConvolutionOperator<3> { public: + /** + * @brief Construct a 3D separable convolution operator from a 1D Gaussian expansion. + * + * @param mra Multiresolution analysis defining the 3D basis/domain. 
+ * @param kernel 1D Gaussian expansion; its length sets the separation rank R. + * The implementation temporarily adjusts the monomial power of each + * Gaussian term to build three internal blocks (degrees 0, 1, 2), + * but does not take ownership of @p kernel. + * @param prec Target build precision for the assembled operator. + * + * @details + * Internally, three batches of operator trees are built (for polynomial degrees + * 0/1/2), each of size R, and stored contiguously. The final separable operator + * exposes rank R; per-axis assignment of the blocks is deferred to + * setCartesianComponents(). + */ CartesianConvolution(const MultiResolutionAnalysis<3> &mra, GaussExp<1> &kernel, double prec); + CartesianConvolution(const CartesianConvolution &oper) = delete; CartesianConvolution &operator=(const CartesianConvolution &oper) = delete; virtual ~CartesianConvolution() = default; + /** + * @brief Select which monomial block is used on each Cartesian axis. + * + * @param x Block index for x-axis (0 → degree 0, 1 → degree 1, 2 → degree 2). + * @param y Block index for y-axis (same convention). + * @param z Block index for z-axis (same convention). + * + * @details + * - This operation is O(R) for each axis and **does not rebuild** the operator; + * it remaps the already constructed 1D operator trees into the separable slots. + * - Valid indices are {0,1,2}. Using the same block on multiple axes is allowed. + */ void setCartesianComponents(int x, int y, int z); protected: + /** + * @brief Separation rank R of the operator (number of terms in the input 1D kernel). + * + * @details + * The internal storage contains 3·R operator trees (three monomial blocks), + * but the exposed separable rank for downstream composition is R. 
+ */ int sep_rank; }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/operators/ConvolutionOperator.h b/src/operators/ConvolutionOperator.h index 33d254e9d..51f8a2736 100644 --- a/src/operators/ConvolutionOperator.h +++ b/src/operators/ConvolutionOperator.h @@ -30,56 +30,155 @@ namespace mrcpp { /** @class ConvolutionOperator + * @ingroup operators * - * @brief Convolution defined by a Gaussian expansion + * @brief D-dimensional separable convolution operator built from a 1D Gaussian expansion. + * + * @tparam D Spatial dimension of the target operator. + * + * @details + * This operator represents a separable convolution constructed from a sum of + * one–dimensional Gaussian factors: * - * @details Represents the operator * \f[ - * T = \sum_{m=1}^M - * \text{sign} (\alpha_m) \bigotimes \limits_{d = 1}^D T_d - * \left( \beta_m, \sqrt[D]{| \alpha_m |} \right) - * , + * T \;=\; \sum_{m=1}^{M} + * \operatorname{sign}(\alpha_m) + * \bigotimes_{d=1}^{D} + * T_d\!\left(\beta_m,\;\sqrt[D]{|\alpha_m|}\right), * \f] - * where each - * \f$ T_d \left( \beta, \alpha \right) \f$ - * is the convolution operator with one-dimensional Gaussian kernel - * \f$ k(x_d) = \alpha \exp \left( - \beta x_d^2 \right) \f$. - * Operator - * \f$ T \f$ - * is obtained from the Gaussian expansion + * + * where each \f$ T_d(\beta,\alpha) \f$ is the 1D convolution with kernel + * \f$ k(x_d) = \alpha \exp(-\beta x_d^2) \f$ along coordinate \f$ x_d \f$. + * The separable rank of the constructed operator equals the number of terms + * \f$ M \f$ in the 1D Gaussian expansion: + * * \f[ - * \sum_{m=1}^M \alpha_m \exp \left( - \beta_m |x|^2 \right) + * \sum_{m=1}^{M} \alpha_m \exp(-\beta_m |x|^2). * \f] - * which is passed as a parameter to the first two constructors. * - * @note Every \f$ T_d \left( \beta_m, \sqrt[D]{| \alpha_m |} \right) \f$ is the same - * operator associated with the one-dimensional variable \f$ x_d \f$ for \f$ d = 1, \ldots, D \f$. 
+ * ### Construction strategy (high level) + * 1. For each Gaussian term \f$ \alpha_m e^{-\beta_m x^2} \f$ in the 1D expansion, + * we rescale its coefficient to \f$ \sqrt[D]{|\alpha_m|} \f$ and keep the sign, + * so that the D-fold separable composition recovers the desired amplitude. + * 2. Each 1D term is projected to a 1D multiresolution function tree using the + * same scaling family as the D-dimensional operator (interpolating or Legendre). + * 3. Cross-correlation machinery lifts each 1D factor into a 2D operator block + * (per axis-pair) and the @ref MWOperator backbone assembles the full D-dimensional, + * separable operator. + * + * ### Precision control + * The *build precision* (see @ref setBuildPrec and @ref getBuildPrec) governs: + * - the tolerance used when projecting 1D kernel terms to their function trees, and + * - the tolerance for assembling/thresholding operator trees. + * Implementations typically use a tighter internal precision for the kernel + * projection than for the operator assembly to keep the total error within the + * requested target. * - * \todo: One may want to change the logic so that \f$ D \f$-root is evaluated on the previous step, - * namely, when \f$ \alpha_m, \beta_m \f$ are calculated. + * @note All constructors are *non-owning* with respect to the input expansion; the + * implementation copies kernel terms as needed for projection, and keeps only the + * operator trees internally. * + * @see ConvolutionOperator::initialize + * @see ConvolutionOperator::getKernelMRA + * @see MWOperator */ template class ConvolutionOperator : public MWOperator { public: + /** + * @brief Build a separable convolution operator on the default operator root/extent. + * + * @param mra D-dimensional @ref MultiResolutionAnalysis defining the domain and basis. + * @param kernel 1D Gaussian expansion providing the separable factors (rank = kernel.size()). + * @param prec Target build precision used to steer kernel projection and operator assembly. 
+ * + * @details Uses the operator's default root scale (@c mra.getRootScale()) and a + * reach chosen by the implementation. For more control over root/reach, use the + * other constructor. + */ ConvolutionOperator(const MultiResolutionAnalysis &mra, GaussExp<1> &kernel, double prec); + + /** + * @brief Build a separable convolution operator with explicit root scale and reach. + * + * @param mra D-dimensional @ref MultiResolutionAnalysis. + * @param kernel 1D Gaussian expansion (rank = kernel.size()). + * @param prec Target build precision. + * @param root Operator root scale (level) to anchor the construction. + * @param reach Operator reach (half-width in levels). Negative value means + * *auto*—deduced from the world box extents. + * + * @details This variant allows advanced users to control the scale window spanned + * by the operator representation, which may be useful when coupling to other + * operators or enforcing boundary extents. + */ ConvolutionOperator(const MultiResolutionAnalysis &mra, GaussExp<1> &kernel, double prec, int root, int reach); + ConvolutionOperator(const ConvolutionOperator &oper) = delete; ConvolutionOperator &operator=(const ConvolutionOperator &oper) = delete; virtual ~ConvolutionOperator() = default; + /// @brief Retrieve the user-requested build precision associated with this operator. double getBuildPrec() const { return this->build_prec; } protected: + /** + * @brief Protected convenience constructor for subclasses that defer initialization. + * + * @param mra D-dimensional @ref MultiResolutionAnalysis. + * + * @details Initializes the @ref MWOperator base with default root and reach. + * Subclasses must call @ref initialize to populate the separable expansion. + */ ConvolutionOperator(const MultiResolutionAnalysis &mra) : MWOperator(mra, mra.getRootScale(), -10) {} + + /** + * @brief Protected convenience constructor with explicit root and reach. + * + * @param mra D-dimensional @ref MultiResolutionAnalysis. 
+ * @param root Root level. + * @param reach Reach (half-width in levels). Negative = auto. + */ ConvolutionOperator(const MultiResolutionAnalysis &mra, int root, int reach) : MWOperator(mra, root, reach) {} + /** + * @brief Core build routine that projects 1D kernel terms and assembles operator trees. + * + * @param kernel 1D Gaussian expansion (input rank M). + * @param k_prec Precision used when projecting each 1D Gaussian term into a + * 1D function tree (typically tighter than @p o_prec). + * @param o_prec Precision used when expanding to operator trees and performing + * wavelet transforms/thresholding. + * + * @details For each term in @p kernel: + * - Coefficient is rescaled to \f$ \sqrt[D]{|\alpha|} \f$ with the original sign, + * ensuring the D-fold separable product reproduces the intended amplitude. + * - The analytic 1D Gaussian is projected to a 1D @ref FunctionTree with tolerance + * @p k_prec. + * - A @ref CrossCorrelationCalculator lifts the 1D representation to a 2D operator + * block; bottom-up wavelet transforms and caching finalize each block. + * The set of blocks is stored in the @ref MWOperator base and exposed as a + * separable expansion of rank @c kernel.size(). + */ void initialize(GaussExp<1> &kernel, double k_prec, double o_prec); + + /// @brief Store the user-requested build precision (used for reporting/inspection). void setBuildPrec(double prec) { this->build_prec = prec; } + /** + * @brief Build a 1D @ref MultiResolutionAnalysis to discretize kernel factors. + * + * @return A 1D MRA whose scaling family matches the D-dimensional operator MRA (Interpolating or Legendre), + * with an order chosen as \f$ 2s+1 \f$ where \f$ s \f$ is the operator scaling order. + * + * @details The 1D box uses the operator root. Its reach is the operator reach + 1 + * (or derived from the world box if reach is negative) to ensure kernel support + * covers the correlations used during lifting. 
+ */ MultiResolutionAnalysis<1> getKernelMRA() const; + /// Target precision requested at construction time; used to steer sub-steps in the build. double build_prec{-1.0}; }; diff --git a/src/operators/DerivativeConvolution.h b/src/operators/DerivativeConvolution.h index 71ade7025..2d19e28ed 100644 --- a/src/operators/DerivativeConvolution.h +++ b/src/operators/DerivativeConvolution.h @@ -30,24 +30,96 @@ namespace mrcpp { /** @class DerivativeConvolution + * @ingroup operators * - * @brief Convolution with a derivative kernel + * @brief Separable convolution operator that approximates a spatial derivative + * using a differentiated Gaussian kernel. * - * @details Derivative operator written as a convolution. The derivative kernel (derivative of - * Dirac's delta function) is approximated by the derivative of a narrow Gaussian function: - * \f$ D^x(r-r') = \frac{d}{dx}\delta(r-r') \approx \frac{d}{dx} \alpha e^{-\beta (r-r')^2} \f$ + * @tparam D Spatial dimension of the target operator (1, 2, or 3). * - * NOTE: This is _not_ the recommended derivative operator for practial calculations, it's - * a proof-of-concept operator. Use the ABGVOperator for "cuspy" functions and the - * BSOperator for smooth functions. + * @details + * This class implements a *proof-of-concept* derivative as a convolution with the + * derivative of a narrow Gaussian. In distributional terms one would like to have + * \f$ \partial_x \delta \f$; numerically, we approximate it by + * a derivative-of-Gaussian (DoG) kernel that is narrow enough to capture the + * local derivative while remaining representable on the multiwavelet grid. + * + * Formally, for one Cartesian component: + * \f[ + * (D^x f)(\mathbf r) + * \;\approx\; + * \int_{\mathbb R^D} + * \frac{\partial}{\partial x} + * \left[\alpha\, e^{-\beta \lvert \mathbf r - \mathbf r' \rvert^2}\right] + * f(\mathbf r') \, d\mathbf r' + * \;=\; (k'_x * f)(\mathbf r). 
+ * \f] + * In MRCPP this is realized as a @ref ConvolutionOperator with a 1D DoG kernel; + * the D-dimensional operator is assembled as a separable tensor product across + * coordinates and lifted to the multiwavelet basis. + * + * ### Responsibilities and division of labor + * - **This class**: chooses a derivative-like kernel and exposes convenient + * constructors. It does not change application logic. + * - **@ref ConvolutionOperator**: projects the 1D kernel to a function tree, + * lifts to operator trees via cross-correlation, transforms/caches in the MW basis, + * and manages rank/separability. + * + * ### Precision handling + * The constructors accept a *build precision* that governs the narrowness of the + * DoG kernel and the tolerances used during kernel projection and operator assembly: + * tighter precision ⇒ narrower kernel ⇒ closer to an ideal derivative, but with + * higher resolution demands and potentially larger operator bandwidth. + * + * @note This operator is primarily for validation/experiments. For production + * use consider: + * - @ref ABGVOperator for cuspy/discontinuous functions. + * - @ref BSOperator for sufficiently smooth functions. + * + * @see ConvolutionOperator, ABGVOperator, BSOperator */ - template class DerivativeConvolution final : public ConvolutionOperator { public: + /** + * @brief Build a derivative-convolution operator on the default root/reach. + * + * @param mra D-dimensional @ref MultiResolutionAnalysis that defines the domain + * and the scaling basis used by the operator. + * @param prec Target build precision controlling kernel narrowness and + * assembly tolerances (tighter ⇒ narrower DoG). 
+ * + * @details + * Internally constructs a single-term derivative kernel (derivative of a Gaussian) + * with parameters derived from @p prec, then delegates to + * @ref ConvolutionOperator::initialize to: + * - project the 1D kernel to a function tree, + * - lift it to separable operator blocks via cross-correlation, + * - transform and cache in the multiwavelet basis. + * + * @warning Very small @p prec values produce *very* narrow kernels that may + * require deeper trees and higher-order bases to avoid under-resolution. + */ DerivativeConvolution(const MultiResolutionAnalysis &mra, double prec); + + /** + * @brief Build a derivative-convolution operator with an explicit scale window. + * + * @param mra D-dimensional @ref MultiResolutionAnalysis. + * @param prec Target build precision (as above). + * @param root Operator root level (coarsest active scale). + * @param reach Operator reach in levels; negative values trigger auto-detection + * from the domain extents. + * + * @details + * Use this overload to constrain the operator to a chosen scale window; useful for + * benchmarking, domain-decomposition experiments, or when composing multiple + * operators with controlled supports. Kernel choice and assembly otherwise mirror + * the simpler constructor. 
+ */ DerivativeConvolution(const MultiResolutionAnalysis &mra, double prec, int root, int reach); - DerivativeConvolution(const DerivativeConvolution &oper) = delete; - DerivativeConvolution &operator=(const DerivativeConvolution &oper) = delete; + + DerivativeConvolution(const DerivativeConvolution &oper) = delete; ///< Non-copyable + DerivativeConvolution &operator=(const DerivativeConvolution &oper) = delete; ///< Non-assignable }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/operators/DerivativeKernel.h b/src/operators/DerivativeKernel.h index 1888c2e0d..11557d5d1 100644 --- a/src/operators/DerivativeKernel.h +++ b/src/operators/DerivativeKernel.h @@ -31,14 +31,82 @@ namespace mrcpp { +/** + * @class DerivativeKernel + * @ingroup operators + * + * @brief One–dimensional *derivative-of-Gaussian* (DoG) kernel packaged as a + * Gaussian expansion of rank 1, suitable for building separable + * convolution-based derivative operators in D dimensions. + * + * @tparam D Spatial dimensionality of the *target* operator that will use this kernel. + * The class itself stores a 1D kernel (inherits from @ref GaussExp\<1\>), but + * uses @p D to choose a normalization consistent with a D-fold separable tensor + * product (see notes below). + * + * @details + * The constructor creates a single 1D Gaussian + * \f[ + * g(x) \;=\; c \, e^{-\alpha x^2},\qquad + * \alpha \equiv \frac{1}{\varepsilon}, + * \f] + * then analytically differentiates it once in @c x to obtain a polynomial–Gaussian + * (a @ref GaussPoly) and appends that single term to this expansion. + * + * ### Normalization and separability + * - The coefficient is chosen as + * \f[ + * c \;=\; \Big(\tfrac{\alpha}{\pi}\Big)^{D/2}, + * \f] + * which corresponds to the *D-dimensional* unit-charge normalization of the isotropic + * Gaussian \f$ c \exp(-\alpha \lvert \mathbf r \rvert^2) \f$. 
+ * - When this 1D kernel is lifted to D dimensions as a separable product, + * MRCPP’s convolution machinery (@ref ConvolutionOperator) rescales each 1D factor + * by the D-th root of the magnitude of its coefficient so that the tensor product + * has the intended overall normalization. In effect, each axis receives + * \f$ (\alpha/\pi)^{1/2} \f$ and the product recovers \f$ (\alpha/\pi)^{D/2} \f$. + * + * ### Width control + * The user-provided @p epsilon controls the width via \f$ \alpha = 1/\varepsilon \f$: + * - Small \f$ \varepsilon \Rightarrow \alpha \gg 1 \Rightarrow \f$ very narrow kernel, + * closer to a distributional derivative, but harder to resolve numerically. + * - Large \f$ \varepsilon \Rightarrow \alpha \ll 1 \Rightarrow \f$ broad kernel, + * smoother but less localized derivative approximation. + * + * ### Usage + * Typically constructed internally by derivative-style convolution operators + * (e.g., @ref DerivativeConvolution) and not used directly. If used directly, + * pass it to a @ref ConvolutionOperator builder which will project, lift, and + * cache the corresponding multiwavelet operator blocks. + */ template class DerivativeKernel final : public GaussExp<1> { public: + /** + * @brief Construct a rank-1 1D derivative-of-Gaussian kernel. + * + * @param epsilon Width control parameter; the Gaussian exponent is set to + * \f$ \alpha = 1/\varepsilon \f$. + * + * @post The expansion contains a single @ref GaussPoly term equal to + * \f$ \frac{d}{dx}\big[c \exp(-\alpha x^2)\big] \f$ with + * \f$ c = (\alpha/\pi)^{D/2} \f$. + */ DerivativeKernel(double epsilon) : GaussExp<1>() { + // Exponent (narrowness): alpha = 1 / epsilon double alpha = 1.0 / epsilon; + + // D-dimensional normalization chosen up-front. + // ConvolutionOperator later redistributes this across dimensions (D-th root per axis). 
double coef = std::pow(alpha / mrcpp::pi, D / 2.0); + + // Start from a pure 1D Gaussian g(x) = coef * exp(-alpha x^2) GaussFunc<1> g(alpha, coef); + + // Differentiate analytically to obtain a polynomial–Gaussian (DoG) and store it GaussPoly<1> dg = g.differentiate(0); + + // Single-term expansion: { dg } this->append(dg); } }; diff --git a/src/operators/DerivativeOperator.h b/src/operators/DerivativeOperator.h index 8c902581c..b2f339bca 100644 --- a/src/operators/DerivativeOperator.h +++ b/src/operators/DerivativeOperator.h @@ -29,18 +29,62 @@ namespace mrcpp { +/** + * @class DerivativeOperator + * @ingroup operators + * + * @brief Common base for derivative-type multiwavelet operators. + * + * @tparam D Spatial dimension of the operator (1, 2, or 3). + * + * @details + * This abstract helper stores metadata and provides a thin interface for + * operators that represent spatial derivatives in the multiwavelet (MW) + * framework. It derives from @ref MWOperator and adds a single piece of + * state: the derivative @ref order, which subclasses set appropriately + * (e.g., 1 for first derivative, 2 for Laplacian-like second derivative + * components, etc.). + * + * The constructor simply forwards the *scale window* to the base: + * - @p root : the coarsest scale at which the operator is anchored, + * - @p reach : the number of levels (half-width) the operator spans + * around @p root (default = 1). + * + * Concrete implementations such as @ref ABGVOperator and @ref BSOperator + * specialize initialization, bandwidth, and stencil construction, while + * reusing this small common interface. + * + * @see MWOperator, ABGVOperator, BSOperator + */ template class DerivativeOperator : public MWOperator { public: + /** + * @brief Construct a derivative operator shell on a given scale window. + * + * @param mra D-dimensional @ref MultiResolutionAnalysis that defines the + * domain and scaling basis for the operator. 
+ * @param root Root scale (coarsest level) of the operator. + * @param reach Scale reach around @p root (default = 1). A reach of @c r + * allows interaction across \f$ 2r+1 \f$ adjacent levels. + * + * @note This constructor does not assemble any stencil; subclasses call + * their own initialization routines and may update @ref order. + */ DerivativeOperator(const MultiResolutionAnalysis &mra, int root, int reach = 1) : MWOperator(mra, root, reach) {} - DerivativeOperator(const DerivativeOperator &oper) = delete; - DerivativeOperator &operator=(const DerivativeOperator &oper) = delete; + DerivativeOperator(const DerivativeOperator &oper) = delete; ///< Non-copyable + DerivativeOperator &operator=(const DerivativeOperator &oper) = delete; ///< Non-assignable ~DerivativeOperator() override = default; + /** + * @brief Return the derivative order encoded by this operator. + * @returns Integer derivative order (1 by default; subclasses may set 2, 3, ...). + */ int getOrder() const { return order; } protected: + /** @brief Derivative order metadata (default = 1). Subclasses should set this. */ int order{1}; }; diff --git a/src/operators/HeatKernel.h b/src/operators/HeatKernel.h index bc5a8adba..173ae3171 100644 --- a/src/operators/HeatKernel.h +++ b/src/operators/HeatKernel.h @@ -30,36 +30,78 @@ namespace mrcpp { -/** @class HeatKernel. +/** + * @class HeatKernel + * @ingroup functions * - * @brief Heat kernel in \f$ \mathbb R^D \f$. + * @brief Single-term Gaussian expansion that represents the \(D\)-dimensional + * heat kernel \(K_t(\mathbf x)\) at diffusion time \(t>0\). * - * @details In \f$ \mathbb R^D \f$ the heat kernel has the form + * @tparam D Spatial dimension of the kernel to be modeled (1, 2, or 3). + * + * @details + * The continuous heat kernel in \(\mathbb R^D\) is + * \f[ + * K_t(\mathbf x) + * \;=\; + * \frac{1}{(4\pi t)^{D/2}} + * \exp\!\left(-\frac{\lVert \mathbf x\rVert^2}{4t}\right), + * \qquad t>0. 
+ * \f] + * + * In MRCPP, separable operators are commonly assembled from 1D building blocks. + * This class therefore inherits from @ref GaussExp "GaussExp<1>" and stores a + * *single* 1D Gaussian term whose exponent and coefficient are chosen so that, + * when used inside separable constructions (e.g., @ref ConvolutionOperator), + * the resulting operator corresponds to the \(D\)-dimensional heat kernel. + * + * Concretely, with + * \f$ \beta = \frac{1}{4t} \f$ and \f$ \alpha = \big(\beta/\pi\big)^{D/2} \f$, + * we append the 1D Gaussian * \f[ - * K_t(x) - * = - * \frac 1{ (4 \pi t)^{D/2} } - * \exp - * \left( - * - \frac{ |x|^2 }{4t} - * \right) - * , \quad - * x \in \mathbb R^D - * \text{ and } - * t > 0 - * . + * g(x) = \alpha\, e^{-\beta x^2}, * \f] + * and the higher-dimensional operator logic (tensor products across coordinates) + * recovers the isotropic \(D\)-dimensional kernel. + * + * ### Notes + * - The constructor does not enforce \(t>0\) at runtime; pass a strictly positive + * value to avoid nonsensical parameters. + * - The class is intentionally minimal: it only sets up the Gaussian parameters + * and leaves projection/assembly to the caller (e.g., convolution operators). * + * ### Example + * @code + * MultiResolutionAnalysis<3> mra(box, basis); + * HeatKernel<3> Kt(0.05); // 3D heat kernel at t = 0.05 + * // Use Kt as a kernel for a ConvolutionOperator<3>, etc. + * @endcode */ template class HeatKernel final : public GaussExp<1> { public: + /** + * @brief Construct a heat kernel at diffusion time @p t. + * + * @param t Diffusion time (\f$ t>0 \f$). Smaller values yield narrower + * Gaussians (more localized kernels). + * + * @details + * Sets the Gaussian exponent to \f$ \beta = \frac{1}{4t} \f$ and the + * coefficient to \f$ \alpha = \big(\beta/\pi\big)^{D/2} \f$, then appends a + * single @ref GaussFunc "GaussFunc<1>" to this @ref GaussExp "GaussExp<1>". 
+ */ HeatKernel(double t) : GaussExp<1>() { + // Exponent β = 1/(4t) double expo = 0.25 / t; + + // Amplitude α = (β/π)^{D/2} so that the separable product matches (4πt)^{-D/2} double coef = std::pow(expo / mrcpp::pi, D / 2.0); + + // Build the 1D Gaussian term and register it in the expansion GaussFunc<1> gFunc(expo, coef); this->append(gFunc); } }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/operators/HeatOperator.h b/src/operators/HeatOperator.h index f96560a81..b54f00575 100644 --- a/src/operators/HeatOperator.h +++ b/src/operators/HeatOperator.h @@ -29,39 +29,85 @@ namespace mrcpp { -/** @class HeatOperator semigroup +/** + * @file HeatOperator.h + * @brief Declaration of a separable convolution operator that realizes the heat + * semigroup \( e^{t\Delta} \) in \(D\) dimensions. * - * @brief Convolution with a heat kernel - * - * @details The exponential heat operator - * \f$ - * \exp \left( t \partial_x^2 \right) - * \f$ - * can be regarded as a convolution operator in \f$ L^2(\mathbb R) \f$ - * of the form + * @details + * In \f$\mathbb{R}^D\f$, the heat propagator at time \(t>0\) is a Gaussian + * convolution * \f[ - * \exp \left( t \partial_x^2 \right) - * f(x) + * (e^{t\Delta} f)(\mathbf x) + * = + * \int_{\mathbb{R}^D} + * K_t(\mathbf x-\mathbf y)\, f(\mathbf y)\, d\mathbf y, + * \qquad + * K_t(\mathbf r) * = - * \frac 1{ \sqrt{4 \pi t} } - * \int_{ \mathbb R } - * \exp - * \left( - * - \frac{ (x - y)^2 }{4t} - * \right) - * f(y) dy - * , \quad - * t > 0 - * . + * \frac{1}{(4\pi t)^{D/2}} + * \exp\!\left(-\frac{\|\mathbf r\|^2}{4t}\right). * \f] + * This class builds a rank-1 separable @ref ConvolutionOperator using a single 1D + * Gaussian kernel in each coordinate and assembles the \(D\)-dimensional operator + * as their tensor product. The amplitude/exponent are chosen so the overall kernel + * matches \(K_t\). 
+ * + * Construction delegates to the base class to: + * - project the 1D kernel to a function tree on a 1D MRA, + * - lift it to operator trees via cross-correlation, + * - transforms and caches the result in the multiwavelet domain. + * + * The overload with explicit @p root and @p reach is useful for periodic boundary + * conditions (PBC) or when the operator must be confined to a specific scale window. + * + * @see ConvolutionOperator, HeatKernel + */ + +/** + * @class HeatOperator + * @ingroup operators + * @brief D-dimensional heat semigroup as a separable Gaussian convolution. + * + * @tparam D Spatial dimension (1, 2, or 3). * + * @note The kernel is normalized so that \(\int_{\mathbb{R}^D} K_t = 1\) and + * the map \(f \mapsto e^{t\Delta}f\) is positivity-preserving and + * \(L^1\)-contractive in the continuous setting. */ template class HeatOperator final : public ConvolutionOperator { public: + /** + * @brief Construct the heat operator \(e^{t\Delta}\) on the default scale window. + * + * @param mra D-dimensional @ref MultiResolutionAnalysis defining domain and basis. + * @param t Diffusion time; must be strictly positive. Smaller @p t yields a + * narrower Gaussian and requires finer resolution. + * @param prec Target build precision used while projecting the kernel and assembling + * the operator. + * + * @pre @p t > 0. + * @see ConvolutionOperator + */ HeatOperator(const MultiResolutionAnalysis &mra, double t, double prec); + + /** + * @brief Construct the heat operator with an explicit operator scale window. + * + * @param mra D-dimensional @ref MultiResolutionAnalysis. + * @param t Diffusion time; must be strictly positive. + * @param prec Target build precision. + * @param root Operator root (coarsest) scale. + * @param reach Operator bandwidth (half-width in levels) at @p root; useful for + * periodic boundary conditions or domain tiling. Defaults to 1. + * + * @pre @p t > 0. 
+ * @see MWOperator, ConvolutionOperator + */ HeatOperator(const MultiResolutionAnalysis &mra, double t, double prec, int root, int reach = 1); - HeatOperator(const HeatOperator &oper) = delete; - HeatOperator &operator=(const HeatOperator &oper) = delete; + + HeatOperator(const HeatOperator &oper) = delete; ///< Non-copyable + HeatOperator &operator=(const HeatOperator &oper) = delete; ///< Non-assignable }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/operators/HelmholtzKernel.h b/src/operators/HelmholtzKernel.h index f5092d4a6..642158a7c 100644 --- a/src/operators/HelmholtzKernel.h +++ b/src/operators/HelmholtzKernel.h @@ -23,14 +23,68 @@ * */ +/** + * @file HelmholtzKernel.h + * @brief Declaration of a Gaussian expansion approximating the 3D Helmholtz (screened Coulomb) kernel. + * + * @details + * This header declares @c HelmholtzKernel, a convenience wrapper that builds a + * separable Gaussian expansion + * \f[ + * K_\mu(r) \;\approx\; \sum_{m=1}^{M} \beta_m\, e^{-\alpha_m r^2}, + * \f] + * that approximates the radial 3D Helmholtz/Yukawa kernel on a finite interval + * \f$[r_\min,r_\max]\f$ with a target relative accuracy \f$\varepsilon\f$. + * The class derives from @ref mrcpp::GaussExp "GaussExp<1>" and therefore can be + * used anywhere a one–dimensional Gaussian expansion is expected (e.g. to form + * convolution operators). + */ + #pragma once #include "functions/GaussExp.h" namespace mrcpp { +/** + * @class HelmholtzKernel + * @brief Gaussian expansion of the 3D Helmholtz (screened Coulomb / Yukawa) kernel. + * + * @details + * Constructs a 1D Gaussian expansion (in the radial variable) by sampling an + * integral representation of the Helmholtz kernel in a logarithmic parameter and + * applying a trapezoidal quadrature. The resulting set of Gaussian terms + * \f$\{\alpha_m,\beta_m\}\f$ is rescaled to the requested physical interval + * \f$[r_\min,r_\max]\f$. 
+ * + * Typical usage: + * @code + * double mu = 1.0; // screening parameter + * double eps = 1e-8; // target relative accuracy + * double rmin = 1e-3, rmax = 10.0; + * mrcpp::HelmholtzKernel kernel(mu, eps, rmin, rmax); + * // 'kernel' is a GaussExp<1> and can be used to build convolution operators + * @endcode + * + * @note The actual separation rank @f$M@f$ depends on @p epsilon and the interval + * size. Extremely tight tolerances or very wide intervals may require a rank + * larger than the internal limit (see @c MaxSepRank in the implementation). + */ class HelmholtzKernel final : public GaussExp<1> { public: + /** + * @brief Build a Gaussian expansion of the Helmholtz kernel on \f$[r_\min,r_\max]\f$. + * + * @param mu Screening parameter \f$\mu > 0\f$ (Yukawa wavenumber). + * @param epsilon Target relative accuracy \f$0 < \varepsilon < 1\f$ for the expansion. + * @param r_min Lower radius bound (must satisfy \f$0 < r_\min < r_\max\f$). + * @param r_max Upper radius bound. + * + * @details + * The constructor fills the underlying @ref GaussExp "GaussExp<1>" with + * \f$M\f$ Gaussian terms determined by a trapezoidal discretization in a + * logarithmic variable. Endpoints are weighted with half-quadrature weights. + */ HelmholtzKernel(double mu, double epsilon, double r_min, double r_max); }; diff --git a/src/operators/HelmholtzOperator.h b/src/operators/HelmholtzOperator.h index fe5cf1d8b..b3e056bec 100644 --- a/src/operators/HelmholtzOperator.h +++ b/src/operators/HelmholtzOperator.h @@ -23,28 +23,93 @@ * */ +/** + * @file HelmholtzOperator.h + * @brief Declaration of a 3D separable convolution operator for the Helmholtz/Yukawa kernel. + * + * @details + * This header declares @ref mrcpp::HelmholtzOperator, a specialized + * @ref ConvolutionOperator that applies the screened Coulomb (Yukawa) Green's function + * in three spatial dimensions via a Gaussian expansion. 
The radial kernel + * \f$ e^{-\mu r}/r \f$ is approximated as a finite sum of 1D Gaussians, enabling + * separated application across Cartesian coordinates in the MRCPP multiwavelet basis. + */ + #pragma once #include "ConvolutionOperator.h" namespace mrcpp { -/** @class HelmholtzOperator +/** + * @class HelmholtzOperator + * @ingroup operators + * + * @brief Separable 3D convolution approximating the Helmholtz (Yukawa) Green's function. * - * @brief Convolution with the Helmholtz Green's function kernel + * @details + * The continuous kernel + * \f[ + * H(\mathbf r - \mathbf r') = \frac{e^{-\mu \lvert \mathbf r - \mathbf r' \rvert}} + * {\lvert \mathbf r - \mathbf r' \rvert} + * \f] + * is approximated by a Gaussian sum + * \f[ + * H(\mathbf r - \mathbf r') + * \;\approx\; + * \sum_{m=1}^{M} \alpha_m \exp\!\big( -\beta_m \lvert \mathbf r - \mathbf r' \rvert^2 \big), + * \f] + * which admits a *separable* representation in Cartesian coordinates, allowing the + * operator to be assembled as a tensor product of 1D convolution blocks within the + * MRCPP framework. The expansion coefficients \f$ \alpha_m, \beta_m \f$ and the + * separation rank \f$ M \f$ are chosen internally based on the requested build + * precision. * - * @details The Helmholtz kernel is approximated as a sum of gaussian functions - * in order to allow for separated application of the operator in the Cartesian - * directions: - * \f$ H(r-r') = \frac{e^{-\mu|r-r'|}}{|r-r'|} \approx \sum_m^M \alpha_m e^{-\beta_m (r-r')^2} \f$ + * ### Usage notes + * - This class is a convenience wrapper that constructs the Gaussian expansion and + * the corresponding multiwavelet operator trees; application is handled by the + * @ref ConvolutionOperator base. + * - For periodic worlds and explicit scale control, use the constructor that accepts + * @p root and @p reach (see below). 
+ * + * @see ConvolutionOperator, HelmholtzKernel */ - class HelmholtzOperator final : public ConvolutionOperator<3> { public: + /** + * @brief Build a Helmholtz (Yukawa) convolution operator on the default scale window. + * + * @param mra 3D @ref MultiResolutionAnalysis defining domain and basis. + * @param m Screening parameter \f$ \mu > 0 \f$ of the Yukawa kernel. + * @param prec Target build precision controlling the Gaussian expansion accuracy + * and operator assembly tolerances. + * + * @details + * Internally: + * 1. Estimates admissible radial bounds from @p mra. + * 2. Constructs a Gaussian expansion for \f$ e^{-\mu r}/r \f$ at the requested accuracy. + * 3. Lifts the 1D kernels to separable operator blocks and caches them. + */ HelmholtzOperator(const MultiResolutionAnalysis<3> &mra, double m, double prec); + + /** + * @brief Build a Helmholtz operator with explicit root scale and reach (useful for PBC). + * + * @param mra 3D @ref MultiResolutionAnalysis. + * @param m Screening parameter \f$ \mu \f$ (same as above). + * @param prec Target build precision. + * @param root Operator root level (coarsest scale at which the operator is defined). + * @param reach Operator half-bandwidth at @p root (controls extent; relevant for periodic worlds). + * + * @details + * This overload confines the operator to a specified scale window and adjusts the + * radial extent accordingly—suitable for periodic boundary conditions and scenarios + * requiring strict bandwidth control. 
+ */ HelmholtzOperator(const MultiResolutionAnalysis<3> &mra, double m, double prec, int root, int reach = 1); - HelmholtzOperator(const HelmholtzOperator &oper) = delete; - HelmholtzOperator &operator=(const HelmholtzOperator &oper) = delete; + + HelmholtzOperator(const HelmholtzOperator &oper) = delete; ///< Non-copyable + HelmholtzOperator &operator=(const HelmholtzOperator &oper) = delete; ///< Non-assignable }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/operators/IdentityConvolution.h b/src/operators/IdentityConvolution.h index 8fd5f0956..db661f1e6 100644 --- a/src/operators/IdentityConvolution.h +++ b/src/operators/IdentityConvolution.h @@ -23,27 +23,88 @@ * */ +/** + * @file IdentityConvolution.h + * @brief Separable convolution operator that approximates the identity (Dirac delta) + * using a narrow Gaussian kernel. + * + * @details + * This header declares @ref mrcpp::IdentityConvolution, a thin convenience wrapper + * around @ref mrcpp::ConvolutionOperator that realizes an identity-like operator via + * a single Gaussian kernel, separably assembled in D dimensions. + * + * The kernel approximation is + * \f[ + * \delta(\mathbf r - \mathbf r') + * \;\approx\; + * \alpha \exp\!\bigl(-\beta\lVert \mathbf r - \mathbf r' \rVert^2\bigr), + * \f] + * which, in the MRCPP framework, is projected to a 1D function tree and lifted to + * D-dimensional operator blocks by cross-correlation. The resulting operator is + * bandwidth-limited and numerically stable for use in multiresolution workflows. + * + * The constructor takes a *build precision* that governs the kernel’s narrowness + * and the tolerances used during projection and assembly. Tighter precision yields + * a Gaussian closer to a true delta (hence a better identity approximation), at the + * cost of higher resolution demands. 
+ */ + #pragma once #include "ConvolutionOperator.h" namespace mrcpp { -/** @class IdentityConvolution +/** + * @class IdentityConvolution + * @ingroup operators + * @brief Convolution with an identity (delta-like) kernel. * - * @brief Convolution with an identity kernel + * @tparam D Spatial dimension of the target operator (1, 2, or 3). * - * @details The identity kernel (Dirac's delta function) is approximated by a - * narrow Gaussian function: - * \f$ I(r-r') = \delta(r-r') \approx \alpha e^{-\beta (r-r')^2} \f$ + * @details + * The operator is represented as a separable sum (rank-1 in the default realization) + * of 1D Gaussian convolutions identical along each Cartesian direction. It is mainly + * intended for diagnostics and algorithmic baselines; for strict identity action, + * prefer direct coefficient transfers when applicable. + * + * The underlying kernel is the Gaussian surrogate of the Dirac delta, + * \f$ I(\mathbf r-\mathbf r') \approx \alpha e^{-\beta \lVert \mathbf r-\mathbf r' \rVert^2} \f$, + * with parameters chosen from the requested build precision. + * + * @see ConvolutionOperator */ - template class IdentityConvolution final : public ConvolutionOperator { public: + /** + * @brief Build an identity-like convolution operator on the default root/reach. + * + * @param mra D-dimensional @ref MultiResolutionAnalysis defining domain and basis. + * @param prec Target build precision controlling the closeness to the delta function + * (narrowness of the Gaussian) and assembly tolerances. + * + * @details + * Internally constructs a single-term Gaussian kernel and invokes + * @ref ConvolutionOperator::initialize to assemble the separable operator blocks. + */ IdentityConvolution(const MultiResolutionAnalysis &mra, double prec); + + /** + * @brief Build an identity-like convolution operator with explicit scale window. + * + * @param mra D-dimensional @ref MultiResolutionAnalysis. + * @param prec Target build precision (as above). 
+ * @param root Operator root level (coarsest scale at which the operator resides). + * @param reach Operator half-bandwidth at @p root (useful for periodic domains). + * + * @details + * Use this overload to confine the operator to a specific scale window—particularly + * helpful under periodic boundary conditions or when coordinating multiple operators. + */ IdentityConvolution(const MultiResolutionAnalysis &mra, double prec, int root, int reach = 1); - IdentityConvolution(const IdentityConvolution &oper) = delete; - IdentityConvolution &operator=(const IdentityConvolution &oper) = delete; + + IdentityConvolution(const IdentityConvolution &oper) = delete; ///< Non-copyable + IdentityConvolution &operator=(const IdentityConvolution &oper) = delete; ///< Non-assignable }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/operators/IdentityKernel.h b/src/operators/IdentityKernel.h index 962ef2e7b..757566fa4 100644 --- a/src/operators/IdentityKernel.h +++ b/src/operators/IdentityKernel.h @@ -23,6 +23,32 @@ * */ +/** + * @file IdentityKernel.h + * @brief Gaussian surrogate of the Dirac delta kernel for use in separable + * convolution operators. + * + * @details + * This header declares @ref mrcpp::IdentityKernel, a convenience wrapper that + * builds a one-term @ref mrcpp::GaussExp "Gaussian expansion" approximating the + * identity (Dirac delta) kernel in \f$ \mathbb{R}^D \f$: + * \f[ + * \delta(x) \;\approx\; \alpha \, e^{-\beta x^2}, + * \f] + * with parameters chosen from a requested *narrowness* (precision) \f$ \varepsilon \f$. + * Concretely, + * \f[ + * \beta = \sqrt{\tfrac{1}{\varepsilon}}, \qquad + * \alpha = \left( \frac{\beta}{\pi} \right)^{D/2}. + * \f] + * + * The resulting object is a rank-1 @ref GaussExp<1> suitable for constructing + * separable, bandwidth-limited identity-like convolution operators; see + * @ref mrcpp::IdentityConvolution. 
+ * + * @see IdentityConvolution, ConvolutionOperator, GaussExp, GaussFunc + */ + #pragma once #include "functions/GaussExp.h" @@ -30,15 +56,34 @@ namespace mrcpp { +/** + * @class IdentityKernel + * @ingroup kernels + * @brief Single-term Gaussian expansion approximating the Dirac delta in \f$ \mathbb{R}^D \f$. + * + * @tparam D Spatial dimension for the normalization of the Gaussian surrogate. + * + * @details + * Constructs a one-dimensional Gaussian \f$ \alpha e^{-\beta x^2} \f$ with + * \f$ \beta=\sqrt{1/\varepsilon} \f$ and + * \f$ \alpha=(\beta/\pi)^{D/2} \f$, + * then appends it to the underlying @ref GaussExp container. The parameter + * \p epsilon controls the narrowness of the surrogate: smaller values yield + * narrower Gaussians (closer to a true delta) but demand more resolution. + */ template class IdentityKernel final : public GaussExp<1> { public: + /** + * @brief Build a delta-like Gaussian kernel from a target narrowness \p epsilon. + * @param epsilon Positive parameter controlling the kernel width; smaller ⇒ narrower. + */ IdentityKernel(double epsilon) : GaussExp<1>() { - double expo = std::sqrt(1.0 / epsilon); - double coef = std::pow(expo / mrcpp::pi, D / 2.0); + double expo = std::sqrt(1.0 / epsilon); // β + double coef = std::pow(expo / mrcpp::pi, D / 2.0); // α = (β/π)^{D/2} GaussFunc<1> gFunc(expo, coef); this->append(gFunc); } }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/operators/MWOperator.h b/src/operators/MWOperator.h index 2dcad2b32..61c2db1cb 100644 --- a/src/operators/MWOperator.h +++ b/src/operators/MWOperator.h @@ -32,51 +32,144 @@ namespace mrcpp { -/** @class MWOperator +/** + * @class MWOperator + * @brief Base class for multiwavelet (MW) operators with separated expansions. * - * @brief Fixme + * @tparam D Spatial dimension of the function space (1, 2, or 3). 
* - * @details Fixme + * @details + * An MW operator is represented as a (typically low-rank) separated expansion + * whose per-term, per-dimension components are stored as pointers to + * @ref OperatorTree objects. This class provides: + * - bookkeeping for the operator’s *root* scale and *reach* (bandwidth), + * - storage for the raw operator terms and their per-dimension assignments, + * - utilities for bandwidth analysis and component access, and + * - construction of the 2D operator-domain @ref MultiResolutionAnalysis used + * by @ref OperatorTree. * + * Derived classes are responsible for building/populating @c raw_exp and then + * calling @ref initOperExp() to map raw terms into directional components. */ template class MWOperator { public: + /** + * @brief Construct an MW operator wrapper. + * + * @param mra D-dimensional analysis describing the function space/domain. + * @param root Operator root level (coarsest level at which the operator lives). + * @param reach Operator reach (half-width in levels at the root). Negative values + * can be interpreted by implementations as “auto”. + */ MWOperator(const MultiResolutionAnalysis &mra, int root, int reach) : oper_root(root) , oper_reach(reach) , MRA(mra) {} - MWOperator(const MWOperator &oper) = delete; - MWOperator &operator=(const MWOperator &oper) = delete; + + MWOperator(const MWOperator &oper) = delete; ///< Non-copyable + MWOperator &operator=(const MWOperator &oper) = delete; ///< Non-assignable virtual ~MWOperator() = default; + /** + * @brief Number of separated terms currently active in the operator. + */ int size() const { return this->oper_exp.size(); } + + /** + * @brief Maximum effective bandwidth at a given depth (scale). + * @param depth Depth index; if negative, returns the maximum over all depths. + * @return The maximum bandwidth, or -1 if @p depth is invalid. + */ int getMaxBandWidth(int depth = -1) const; + + /** + * @brief Vector of maximum effective bandwidths per depth. 
+ * @return Reference to internal cache of maximum bandwidths. + */ const std::vector &getMaxBandWidths() const { return this->band_max; } + /** + * @brief Compute effective bandwidths for all components at all depths. + * @param prec Numeric tolerance used in bandwidth estimation. + */ void calcBandWidths(double prec); + + /** + * @brief Clear cached bandwidth information in all components. + */ void clearBandWidths(); + /** + * @brief Root level (coarsest scale) of the operator domain. + */ int getOperatorRoot() const { return this->oper_root; } + + /** + * @brief Operator reach (half-width at the root level). + */ int getOperatorReach() const { return this->oper_reach; } + /** + * @brief Mutable access to the @p i-th separated term, @p d-th dimension component. + * @param i Separated term index. + * @param d Cartesian direction index (0..D-1). + * @return Reference to the requested @ref OperatorTree. + */ OperatorTree &getComponent(int i, int d); + + /** + * @brief Const access to the @p i-th separated term, @p d-th dimension component. + * @param i Separated term index. + * @param d Cartesian direction index (0..D-1). + * @return Const reference to the requested @ref OperatorTree. + */ const OperatorTree &getComponent(int i, int d) const; + /** + * @brief Direct access to the array of D components for the @p i-th term. + */ std::array &operator[](int i) { return this->oper_exp[i]; } + + /** + * @brief Const direct access to the array of D components for the @p i-th term. + */ const std::array &operator[](int i) const { return this->oper_exp[i]; } protected: - int oper_root; - int oper_reach; - MultiResolutionAnalysis MRA; - std::vector> oper_exp; - std::vector> raw_exp; - std::vector band_max; + /** @name Operator geometry */ + ///@{ + int oper_root; ///< Operator root level (coarsest scale). + int oper_reach; ///< Operator reach (half-width in levels at the root). + MultiResolutionAnalysis MRA; ///< Function-space analysis (domain and basis). 
+ ///@} + /** @name Operator storage */ + ///@{ + std::vector> oper_exp; ///< Active separated terms by dimension. + std::vector> raw_exp; ///< Owned raw operator terms (before assignment). + std::vector band_max; ///< Maximum bandwidth per depth. + ///@} + + /** + * @brief Build the 2D operator-domain MRA used by @ref OperatorTree. + * @details Operators act on a 2D lattice (row/column) even when the function space is D-dimensional. + */ MultiResolutionAnalysis<2> getOperatorMRA() const; + /** + * @brief Initialize @ref oper_exp with @p M separated terms. + * @details By default, assigns the first @p M raw terms isotropically across all D dimensions. + * @param M Number of separated terms to activate. + */ void initOperExp(int M); + + /** + * @brief Assign a particular operator component for term @p i and direction @p d. + * @param i Term index in the separated expansion. + * @param d Cartesian direction index (0..D-1). + * @param oper Pointer to the @ref OperatorTree to be used for this component. + */ void assign(int i, int d, OperatorTree *oper) { this->oper_exp[i][d] = oper; } }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/operators/OperatorState.h b/src/operators/OperatorState.h index 677375632..f6602da12 100644 --- a/src/operators/OperatorState.h +++ b/src/operators/OperatorState.h @@ -23,10 +23,23 @@ * */ -/** OperatorState is a simple helper class for operator application. - * It keeps track of various state dependent variables and memory - * regions. We cannot have some of this information directly in OperatorFunc - * because of multi-threading issues. +/** + * @file + * @brief Lightweight state holder used during operator application. + * + * @details + * The operator application kernels (e.g., convolution/derivative calculators) + * are performance-critical and multi-threaded. 
To avoid sharing mutable + * state between threads, this helper encapsulates: + * - pointers to the current *source* (g) and *destination* (f) MW nodes, + * - precomputed size/stride quantities (`kp1`, `kp1_d`, …), + * - addresses of coefficient blocks for selected components (ft/gt), + * - temporary scratch buffers laid out for cache-friendly sweeps, + * - and small per-call metadata such as the maximum index offset + * (`maxDeltaL`) between the active nodes. + * + * It is deliberately simple (POD-like) and header-only to enable aggressive + * inlining by the compiler. */ #pragma once @@ -40,21 +53,66 @@ namespace mrcpp { +/** + * @def GET_OP_IDX(FT, GT, ID) + * @brief Build a 2-bit operator index for dimension @p ID from component flags. + * + * @details + * Encodes the *from* (FT) and *to* (GT) component bits at position @p ID into + * an index in the set {0,1,2,3}: + * \f[ + * \mathrm{idx} = 2 \cdot \big( (GT \gg ID) \& 1 \big) + * + \big( (FT \gg ID) \& 1 \big). + * \f] + * This compact index is used to select per-dimension operator blocks. + */ #define GET_OP_IDX(FT, GT, ID) (2 * ((GT >> ID) & 1) + ((FT >> ID) & 1)) +/** + * @class OperatorState + * @brief Thread-local state for applying an MW operator to node data. + * + * @tparam D Spatial dimension of the node (1–3). + * @tparam T Coefficient value type (e.g., double or std::complex). + * + * @details + * The class provides: + * - Binding of a *g-node* (source) at construction time. + * - Late binding of an *f-node* (destination) and its @ref NodeIndex. + * - Selection of component blocks (ft/gt) via @ref setFComponent and + * @ref setGComponent, exposing the corresponding coefficient slices. + * - Access to alternating scratch buffers arranged as + * `aux[0] = f-comp`, `aux[1..D-1]` ping-pong across `scr1`/`scr2`, + * and `aux[D] = g-comp`. + * + * The scratch layout avoids reallocation and reduces cache conflicts during + * dimension-by-dimension tensor sweeps. 
+ */ template class OperatorState final { public: + /** + * @brief Construct with a source node and a raw scratch buffer. + * + * @param gn Source (g) node whose coefficients are read. + * @param scr1 Pointer to a scratch buffer that must extend past `scr1 + kp1_d` when D > 1 (presumably `2 * kp1_d` elements in total; confirm against callers). + * + * @details + * Two consecutive scratch regions are used: `scr1` and `scr2 = scr1 + kp1_d`. + * For each interior dimension `i=1..D-1`, the buffer alternates between + * these two regions by parity of `i` to enable out-of-place 1D transforms. + */ OperatorState(MWNode &gn, T *scr1) : gNode(&gn) { - this->kp1 = this->gNode->getKp1(); - this->kp1_d = this->gNode->getKp1_d(); - this->kp1_2 = math_utils::ipow(this->kp1, 2); - this->kp1_dm1 = math_utils::ipow(this->kp1, D - 1); - this->gData = this->gNode->getCoefs(); + this->kp1 = this->gNode->getKp1(); // basis points per dim + this->kp1_d = this->gNode->getKp1_d(); // total points (kp1^D) + this->kp1_2 = math_utils::ipow(this->kp1, 2); // kp1^2 + this->kp1_dm1 = math_utils::ipow(this->kp1, D - 1); // kp1^(D-1) + this->gData = this->gNode->getCoefs(); this->maxDeltaL = -1; T *scr2 = scr1 + this->kp1_d; + // Assign alternating aux buffers for interior dimensions for (int i = 1; i < D; i++) { if (IS_ODD(i)) { this->aux[i] = scr2; @@ -64,56 +122,116 @@ template class OperatorState final { } } + /** + * @brief Convenience ctor: scratch storage provided as a std::vector. + * @param gn Source (g) node. + * @param scr1 Vector whose data pointer is used as scratch. + * + * @warning NOTE(review): as shown in this hunk, @p scr1 is taken by value, so `scr1.data()` points into the parameter copy, which is destroyed when the constructor returns; keeping the caller's vector alive does not by itself keep the stored scratch pointers valid. Confirm whether a reference parameter was intended. + */ OperatorState(MWNode &gn, std::vector scr1) : OperatorState(gn, scr1.data()) {} + + /** + * @brief Bind the destination (f) node and cache its coefficient pointer. + */ void setFNode(MWNode &fn) { this->fNode = &fn; this->fData = this->fNode->getCoefs(); } + + /** + * @brief Bind the destination node index and update @ref maxDeltaL. + * @param idx Destination node index in the tree. 
+ * + * @details + * The maximum per-dimension index offset \f$\max_d |f_l[d] - g_l[d]|\f$ is used to + * select scale-dependent operator stencils/bandwidths. + */ void setFIndex(NodeIndex &idx) { this->fIdx = &idx; calcMaxDeltaL(); } + + /** + * @brief Select the source (g) component and expose its coefficient slice. + * @param gt Component bitfield (typically 0/1 per dimension). + * + * @details Offsets the base pointer by `gt * kp1_d` and stores it in + * `aux[D]`, which operator kernels read as the final stage input. + */ void setGComponent(int gt) { this->aux[D] = this->gData + gt * this->kp1_d; this->gt = gt; } + + /** + * @brief Select the destination (f) component and expose its coefficient slice. + * @param ft Component bitfield (typically 0/1 per dimension). + * + * @details Offsets the base pointer by `ft * kp1_d` and stores it in + * `aux[0]`, which operator kernels use as the first stage buffer. + */ void setFComponent(int ft) { this->aux[0] = this->fData + ft * this->kp1_d; this->ft = ft; } + /** + * @brief Maximum per-dimension index offset between the bound f/g nodes (-1 until setFIndex() has been called). + */ int getMaxDeltaL() const { return this->maxDeltaL; } + + /** + * @brief Build a compact operator index for dimension @p i (0..D-1). + * + * @details Uses @ref GET_OP_IDX on the currently bound @ref ft and @ref gt. + */ int getOperIndex(int i) const { return GET_OP_IDX(this->ft, this->gt, i); } + /** + * @brief Access the array of auxiliary data pointers used by kernels. + * @return `aux[0] = f-comp`, `aux[1..D-1]` scratch, `aux[D] = g-comp`. + */ T **getAuxData() { return this->aux; } + + /** + * @brief Access per-dimension operator data blocks (set by calculators). + */ double **getOperData() { return this->oData; } + // Calculator kernels are declared as friends to allow fast access. 
friend class ConvolutionCalculator; friend class DerivativeCalculator; private: - int ft; - int gt; - - int maxDeltaL; - double fThreshold; - double gThreshold; - // Shorthands - int kp1; - int kp1_2; - int kp1_d; - int kp1_dm1; - - MWNode *gNode; - MWNode *fNode; - NodeIndex *fIdx; - - T *aux[D + 1]; - T *gData; - T *fData; - double *oData[D]; - + // Current component selectors (bitfields) + int ft{0}; + int gt{0}; + + // Geometry / thresholds + int maxDeltaL; ///< max_d |f_l[d] - g_l[d]|; computed in calcMaxDeltaL() + double fThreshold; ///< (optional) threshold for f (may be set by calculators) + double gThreshold; ///< (optional) threshold for g (may be set by calculators) + + // Shorthands derived from the bound node + int kp1; ///< #points per dimension + int kp1_2; ///< kp1^2 + int kp1_d; ///< kp1^D (total points in a component block) + int kp1_dm1; ///< kp1^(D-1) + + // Bound nodes and indices + MWNode *gNode{nullptr}; + MWNode *fNode{nullptr}; + NodeIndex *fIdx{nullptr}; + + // Data pointers + T *aux[D + 1]{}; ///< [0]=f-comp, [1..D-1]=scratch, [D]=g-comp + T *gData{nullptr}; + T *fData{nullptr}; + double *oData[D]{}; ///< Per-dimension operator-specific metadata + + /// @brief Compute @ref maxDeltaL from the currently bound f/g nodes. void calcMaxDeltaL() { const auto &gl = this->gNode->getNodeIndex(); const auto &fl = *this->fIdx; @@ -126,4 +244,4 @@ template class OperatorState final { } }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/operators/OperatorStatistics.h b/src/operators/OperatorStatistics.h index 9a51728c0..4e8fc6313 100644 --- a/src/operators/OperatorStatistics.h +++ b/src/operators/OperatorStatistics.h @@ -23,6 +23,27 @@ * */ +/** + * @file OperatorStatistics.h + * @brief Thread-aware counters and summaries for multiwavelet operator application. + * + * @details + * This helper aggregates lightweight statistics collected while applying + * operators to multiwavelet nodes. 
For performance and thread-safety, counts + * are first accumulated in per-thread storage and later merged into global + * totals using @ref flushNodeCounters(). + * + * Tracked quantities: + * - Total number of destination (*f*) nodes where an operator was applied. + * - Total number of source (*g*) nodes evaluated. + * - Total number of destination nodes marked as “generalized”. + * - An 8×8 histogram of component-pair usages (indexed by `(ft, gt)`). + * + * The class intentionally avoids synchronization primitives inside hot loops; + * callers should invoke @ref flushNodeCounters() at safe points to consolidate + * results and reset per-thread buffers. + */ + #pragma once #include @@ -32,29 +53,75 @@ namespace mrcpp { +/** + * @class OperatorStatistics + * @brief Collects and reports counters during operator application. + * + * @note + * - Per-thread counters are sized using @c mrcpp_get_max_threads(). + * - Use the stream operator to print a human-readable summary. + */ class OperatorStatistics final { public: + /// Construct an empty statistics object with per-thread accumulators. OperatorStatistics(); + + /// Release all dynamically allocated per-thread buffers and histograms. ~OperatorStatistics(); + /** + * @brief Consolidate per-thread counters into global totals and reset locals. + * + * @details + * After calling this, @c totFCount, @c totGCount, @c totGenCount and + * @c totCompCount reflect all work since the previous flush, and the + * per-thread buffers are zeroed. + */ void flushNodeCounters(); - template void incrementFNodeCounters(const MWNode &fNode, int ft, int gt); - template void incrementGNodeCounters(const MWNode &gNode); + /** + * @brief Increment destination (*f*)-node counters for the current thread. + * @tparam D Spatial dimension of the node. + * @tparam T Coefficient/value type stored by the node. + * @param fNode Destination node being updated. + * @param ft Destination component bitfield (0–7). 
+ * @param gt Source component bitfield (0–7). + * + * @details + * Increments the per-thread f-node count, updates the (ft,gt) entry of the + * per-thread 8×8 histogram, and increments the generalized-node count if + * @c fNode.isGenNode() returns true. + */ + template + void incrementFNodeCounters(const MWNode &fNode, int ft, int gt); + + /** + * @brief Increment source (*g*)-node counters for the current thread. + * @tparam D Spatial dimension of the node. + * @tparam T Coefficient/value type stored by the node. + * @param gNode Source node being processed (unused; for interface symmetry). + */ + template + void incrementGNodeCounters(const MWNode &gNode); + + /// Print a summary of accumulated totals and the component histogram. friend std::ostream &operator<<(std::ostream &o, const OperatorStatistics &os) { return os.print(o); } protected: - int nThreads; - int totFCount; - int totGCount; - int totGenCount; - int *fCount; - int *gCount; - int *genCount; - Eigen::Matrix *totCompCount; - Eigen::Matrix **compCount; + int nThreads; ///< Number of worker threads. + int totFCount; ///< Global total of applied *f*-nodes. + int totGCount; ///< Global total of processed *g*-nodes. + int totGenCount; ///< Global total of applied generalized nodes. + + int *fCount; ///< Per-thread *f*-node counters (size = nThreads). + int *gCount; ///< Per-thread *g*-node counters (size = nThreads). + int *genCount; ///< Per-thread generalized-node counters (size = nThreads). + + Eigen::Matrix *totCompCount; ///< Global (ft,gt) 8×8 usage histogram. + Eigen::Matrix **compCount; ///< Per-thread 8×8 usage histograms. + /// Internal pretty-printer used by the stream operator. 
std::ostream &print(std::ostream &o) const; }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/operators/PHOperator.h b/src/operators/PHOperator.h index d5f58bf29..a041a4f11 100644 --- a/src/operators/PHOperator.h +++ b/src/operators/PHOperator.h @@ -23,33 +23,75 @@ * */ +/** + * @file PHOperator.h + * @brief Declaration of a Holoborodko-style smoothing derivative operator. + * + * @details + * This header declares @ref mrcpp::PHOperator, a lightweight derivative operator + * constructed from the smooth, low-noise differentiators introduced by + * Pavel Holoborodko (see + * + * reference link). + * + * The operator is assembled in the multiwavelet framework and is intended + * primarily for experimentation/validation with smoothing differentiators. + * For robust production work: + * - use @ref mrcpp::ABGVOperator for functions with cusps/discontinuities, + * - or @ref mrcpp::BSOperator for sufficiently smooth functions. + * + * @see mrcpp::DerivativeOperator, mrcpp::ABGVOperator, mrcpp::BSOperator + */ + #pragma once #include "DerivativeOperator.h" namespace mrcpp { -/** @class PHOperator +/** + * @class PHOperator + * @ingroup operators * - * @brief Derivative operator based on the smoothing derivative of - * - * Pavel Holoborodko - * . + * @brief Derivative operator based on Holoborodko’s smooth, low-noise differentiators. * - * NOTE: This is _not_ the recommended derivative operator for practial calculations, it's - * a proof-of-concept operator. Use the ABGVOperator for "cuspy" functions and the - * BSOperator for smooth functions. + * @tparam D Spatial dimension (1, 2, or 3). + * + * @details + * This class derives from @ref DerivativeOperator and provides a separable, + * single-component derivative approximation whose stencil is defined by the + * Holoborodko differentiators. Internally, the concrete operator blocks are + * produced by a PH-specific calculator and stored in an @ref OperatorTree. 
+ * + * @note This is **not** the recommended operator for general calculations. Prefer + * @ref ABGVOperator for non-smooth data and @ref BSOperator for smooth data. */ - template class PHOperator final : public DerivativeOperator { public: + /** + * @brief Construct a PH-based derivative operator. + * + * @param mra MultiResolutionAnalysis defining the domain and basis. + * @param order Derivative order (typically 1 or 2). + * + * @warning Orders beyond those implemented by the underlying calculator + * are not supported. + */ PHOperator(const MultiResolutionAnalysis &mra, int order); - PHOperator(const PHOperator &oper) = delete; - PHOperator &operator=(const PHOperator &oper) = delete; + + PHOperator(const PHOperator &oper) = delete; ///< Non-copyable + PHOperator &operator=(const PHOperator &oper) = delete; ///< Non-assignable protected: + /** + * @brief Build and cache the internal operator representation. + * + * @details + * Creates the PH calculator for the current scaling basis and requested order, + * assembles an @ref OperatorTree with bandwidth control, transforms it to the + * multiwavelet domain, and initializes the separable expansion. + */ void initialize(); }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/operators/PoissonKernel.h b/src/operators/PoissonKernel.h index 725a6d68b..88718e678 100644 --- a/src/operators/PoissonKernel.h +++ b/src/operators/PoissonKernel.h @@ -23,15 +23,60 @@ * */ +/** + * @file PoissonKernel.h + * @brief Declaration of a Gaussian-expansion approximation to the 3D Poisson kernel. + */ + #pragma once #include "functions/GaussExp.h" namespace mrcpp { +/** + * @class PoissonKernel + * @brief Gaussian expansion of the radial Poisson kernel \f$ 1/r \f$ on a bounded interval. 
+ * + * @details + * Builds a separated, finite Gaussian expansion that approximates the 3D Poisson kernel + * \f[ + * \frac{1}{\lvert \mathbf r \rvert} \;\approx\; \sum_{m=1}^{M} \beta_m \, e^{-\alpha_m r^2}, + * \f] + * valid for radii \f$ r \in [r_{\min},\, r_{\max}] \f$. The coefficients + * \f$ \{\alpha_m,\beta_m\}_{m=1}^M \f$ are produced by truncating and discretizing + * (via a trapezoidal rule in logarithmic variables) a continuous representation of + * \f$ 1/r \f$, with the truncation window and step size chosen to meet a target + * relative tolerance \p epsilon on the *normalized* interval + * \f$ [r_{\min}/r_{\max},\,1] \f$ and then rescaled back to \f$ [r_{\min}, r_{\max}] \f$. + * + * The resulting object is a one-dimensional @ref GaussExp "GaussExp<1>" whose entries + * can be used by separable convolution operators to assemble higher-dimensional + * kernels and operators. + * + * @note + * - Requires \f$ r_{\min} > 0 \f$ and \f$ r_{\max} > r_{\min} \f$. + * - The number of Gaussian terms \f$ M \f$ is bounded internally (see `MaxSepRank`); + * exceeding this bound will abort construction in the implementation. + * + * @see GaussExp, GaussFunc + */ class PoissonKernel final : public GaussExp<1> { public: + /** + * @brief Construct a Gaussian expansion of \f$ 1/r \f$ on \f$ [r_{\min}, r_{\max}] \f$. + * + * @param epsilon Target relative accuracy (heuristic; smaller ⇒ more terms). + * @param r_min Lower radius of validity, must satisfy \f$ r_{\min} > 0 \f$. + * @param r_max Upper radius of validity, must satisfy \f$ r_{\max} > r_{\min} \f$. + * + * @details + * Populates this @ref GaussExp with terms \f$ (\alpha_m,\beta_m) \f$ so that + * \f$ \sum_m \beta_m e^{-\alpha_m r^2} \approx 1/r \f$ over the requested interval. + * Coefficients are ordered according to the underlying quadrature and include + * standard endpoint weighting for the trapezoidal rule. 
+ */ PoissonKernel(double epsilon, double r_min, double r_max); }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/operators/PoissonOperator.h b/src/operators/PoissonOperator.h index ad65e9cae..c03b4e85f 100644 --- a/src/operators/PoissonOperator.h +++ b/src/operators/PoissonOperator.h @@ -23,28 +23,72 @@ * */ +/** + * @file PoissonOperator.h + * @brief Separable multiwavelet convolution operator for the 3D Poisson kernel. + * + * The operator realizes a fast approximation of the Green's function + * \f$ P(\mathbf r-\mathbf r') = 1/\lvert \mathbf r-\mathbf r'\rvert \f$ + * by expanding it into a finite sum of Gaussians, + * \f[ + * \frac{1}{\lvert \mathbf r-\mathbf r'\rvert} + * \;\approx\; + * \sum_{m=1}^{M} \alpha_m \exp\!\big(-\beta_m \lvert \mathbf r-\mathbf r'\rvert^2\big), + * \f] + * which enables a tensor–separable application along Cartesian axes in the + * multiwavelet framework. See @ref ConvolutionOperator for assembly details and + * @ref PoissonOperator (implementation) for construction mechanics. + */ + #pragma once #include "ConvolutionOperator.h" namespace mrcpp { -/** @class PoissonOperator +/** + * @class PoissonOperator + * @brief Convolution with the Poisson Green's function kernel in 3D. + * + * @details + * The Poisson kernel is approximated by a Gaussian expansion, allowing the operator + * to be applied as a separated product over Cartesian directions: + * \f[ + * P(\mathbf r-\mathbf r') + * = \frac{1}{\lvert \mathbf r-\mathbf r'\rvert} + * \;\approx\; \sum_{m=1}^{M} \alpha_m \exp\!\big(-\beta_m \lvert \mathbf r-\mathbf r'\rvert^2\big). + * \f] + * Each 1D Gaussian term is projected to a function tree and lifted to operator blocks + * via cross-correlation; the full 3D operator is then cached for efficient application. * - * @brief Convolution with the Poisson Green's function kernel + * The expansion accuracy and kernel width are controlled by the requested build precision. 
+ * An overload with explicit @p root/@p reach can confine the operator to a chosen scale window + * (useful for periodic-style setups or domain-decomposition experiments). * - * @details The Poisson kernel is approximated as a sum of Gaussian - * functions in order to allow for separated application of the operator - * in the Cartesian directions: - * \f$ P(r-r') = \frac{1}{|r-r'|} \approx \sum_m^M \alpha_m e^{-\beta_m (r-r')^2} \f$ + * @see ConvolutionOperator, PoissonKernel */ - class PoissonOperator final : public ConvolutionOperator<3> { public: + /** + * @brief Construct a Poisson operator on the default root/reach of the provided MRA. + * + * @param mra 3D @ref MultiResolutionAnalysis defining the domain and scaling basis. + * @param prec Target build precision controlling the Gaussian expansion (smaller ⇒ tighter/longer rank). + */ PoissonOperator(const MultiResolutionAnalysis<3> &mra, double prec); + + /** + * @brief Construct a Poisson operator with an explicit scale window. + * + * @param mra 3D @ref MultiResolutionAnalysis. + * @param prec Target build precision. + * @param root Operator root level (coarsest scale where the operator resides). + * @param reach Operator reach (half-width, in levels) around @p root; affects bandwidth/PBC-like extent. 
+ */ PoissonOperator(const MultiResolutionAnalysis<3> &mra, double prec, int root, int reach = 1); - PoissonOperator(const PoissonOperator &oper) = delete; - PoissonOperator &operator=(const PoissonOperator &oper) = delete; + + PoissonOperator(const PoissonOperator &oper) = delete; ///< Non-copyable + PoissonOperator &operator=(const PoissonOperator &oper) = delete; ///< Non-assignable }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/operators/TimeEvolutionOperator.h b/src/operators/TimeEvolutionOperator.h index 839ba7b40..a33f4745f 100644 --- a/src/operators/TimeEvolutionOperator.h +++ b/src/operators/TimeEvolutionOperator.h @@ -23,6 +23,25 @@ * */ +/** + * @file TimeEvolutionOperator.h + * @brief Interface for a separable multiwavelet representation of the + * free-particle Schrödinger time-evolution semigroup. + * + * The operator approximates (real or imaginary parts of) + * \f[ + * U(t) \;=\; e^{\, i\,t\,\Delta} + * \f] + * by building an operator tree via cross-correlations between scaling functions + * and a kernel whose coefficients are expressed through power integrals + * \f$ \widetilde J_m \f$. Two construction modes are exposed: + * - **Uniform** to a user-specified finest scale. + * - **Adaptive** down to a fixed scale (bounded work in power integrals). + * + * See the .cpp for build details and post-processing steps (MW transform, + * rough-scale filtering, caching, etc.). + */ + #pragma once #include "ConvolutionOperator.h" @@ -31,43 +50,118 @@ namespace mrcpp { -/** @class TimeEvolutionOperator +/** + * @class TimeEvolutionOperator + * @ingroup operators + * + * @brief Multiwavelet operator for the free-particle Schrödinger semigroup. * - * @brief Semigroup of the free-particle Schrodinger equation + * @tparam D Spatial dimensionality (1, 2, or 3). * - * @details Represents the semigroup - * \f$ - * \exp \left( i t \partial_x^2 \right) - * . 
- * \f$ - * Matrix elements (actual operator tree) of the operator can be obtained by calling getComponent(0, 0). + * @details + * Provides a separable @ref ConvolutionOperator-like interface that assembles + * the matrix elements of + * \f$ U(t) = e^{\, i\,t\,\Delta} \f$ + * (or its real/imaginary part) in a multi-resolution setting. The actual + * operator blocks can be accessed via + * @code + * getComponent(0, 0) + * @endcode + * after construction (rank-1 expansion in current implementation). * - * @note So far implementation is done for Legendre scaling functions in 1d. + * Internally, coefficients are generated from per-scale power integrals + * \f$ \widetilde J_m \f$ and a dedicated cross-correlation calculator suited + * for the Schrödinger kernel. * - * \todo: Extend to D dimensinal on a general interval [a, b] in the future. + * @note Current implementation targets Legendre scaling functions; practical + * use has primarily focused on 1D, but the interface is templated in @p D. * + * @todo Extend to general dimension on arbitrary intervals \f$[a,b]\f$. */ template class TimeEvolutionOperator : public ConvolutionOperator // One can use ConvolutionOperator instead as well { public: - TimeEvolutionOperator(const MultiResolutionAnalysis &mra, double prec, double time, int finest_scale, bool imaginary, int max_Jpower = 30); - TimeEvolutionOperator(const MultiResolutionAnalysis &mra, double prec, double time, bool imaginary, int max_Jpower = 30); - TimeEvolutionOperator(const TimeEvolutionOperator &oper) = delete; - TimeEvolutionOperator &operator=(const TimeEvolutionOperator &oper) = delete; + /** + * @brief Construct a **uniform** time-evolution operator. + * + * @param mra Target @ref MultiResolutionAnalysis (domain/basis). + * @param prec Build precision controlling pruning and tolerances. + * @param time Time parameter \f$ t \f$. + * @param finest_scale Finest (uniform) scale to which the operator is built. 
+ * @param imaginary If `true` build the imaginary part; otherwise real part. + * @param max_Jpower Maximum number of power-integral terms (default: 30). + */ + TimeEvolutionOperator(const MultiResolutionAnalysis &mra, + double prec, + double time, + int finest_scale, + bool imaginary, + int max_Jpower = 30); + + /** + * @brief Construct an **adaptive** time-evolution operator. + * + * @param mra Target @ref MultiResolutionAnalysis (domain/basis). + * @param prec Build precision controlling pruning and tolerances. + * @param time Time parameter \f$ t \f$. + * @param imaginary If `true` build the imaginary part; otherwise real part. + * @param max_Jpower Maximum number of power-integral terms (default: 30). + * + * @details + * The adaptive build proceeds down to a fixed scale to bound the number of + * required power integrals; see the source for the current depth choice. + */ + TimeEvolutionOperator(const MultiResolutionAnalysis &mra, + double prec, + double time, + bool imaginary, + int max_Jpower = 30); + + TimeEvolutionOperator(const TimeEvolutionOperator &oper) = delete; ///< Non-copyable + TimeEvolutionOperator &operator=(const TimeEvolutionOperator &oper) = delete; ///< Non-assignable virtual ~TimeEvolutionOperator() = default; + /// @return The build precision used to assemble the operator. double getBuildPrec() const { return this->build_prec; } protected: + /** @name Builder entry points (implementation detail) + * Internal construction routines used by the public constructors. + */ + ///@{ + /** + * @brief Uniform build to @p finest_scale. + * @param time Time parameter \f$ t \f$. + * @param finest_scale Finest scale to which the tree is constructed. + * @param imaginary Build imaginary (true) or real (false) part. + * @param max_Jpower Maximum number of power-integral terms. + */ void initialize(double time, int finest_scale, bool imaginary, int max_Jpower); + + /** + * @brief Adaptive build (fixed maximum depth). 
+ * @param time Time parameter \f$ t \f$. + * @param imaginary Build imaginary (true) or real (false) part. + * @param max_Jpower Maximum number of power-integral terms. + */ void initialize(double time, bool imaginary, int max_Jpower); + + /** + * @brief Semi-uniform prototype (not implemented). + * @param time Time parameter \f$ t \f$. + * @param imaginary Build imaginary (true) or real (false) part. + * @param max_Jpower Maximum number of power-integral terms. + * @warning This method is a placeholder and aborts if called. + */ void initializeSemiUniformly(double time, bool imaginary, int max_Jpower); + ///@} + /// Set the build precision recorded by this operator. void setBuildPrec(double prec) { this->build_prec = prec; } - double build_prec{-1.0}; - SchrodingerEvolution_CrossCorrelation *cross_correlation{nullptr}; + double build_prec{-1.0}; ///< Build precision (assembly/pruning). + SchrodingerEvolution_CrossCorrelation *cross_correlation{nullptr}; ///< Per-dimension cross-correlation engine. }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/treebuilders/AdditionCalculator.h b/src/treebuilders/AdditionCalculator.h index 9223f1ae6..b820810cd 100644 --- a/src/treebuilders/AdditionCalculator.h +++ b/src/treebuilders/AdditionCalculator.h @@ -24,46 +24,130 @@ */ #pragma once +/** + * @file AdditionCalculator.h + * @brief Node-wise accumulator used during adaptive construction to sum + * multiresolution (MW) functions with optional conjugation. + * + * @details + * This header defines #mrcpp::AdditionCalculator, a lightweight + * #mrcpp::TreeCalculator that, for each target node, fetches the + * corresponding node from every input function in a + * #mrcpp::FunctionTreeVector and accumulates a weighted sum of their + * coefficients. No refinement policy is implemented here; pair this + * calculator with a #mrcpp::TreeBuilder and a suitable #mrcpp::TreeAdaptor. 
+ * + * Complex handling: + * - For complex `T`, each term uses either the raw coefficients or their + * complex conjugate according to the XOR of the input tree's own + * `conjugate()` flag and the calculator-wide `conj` flag. + */ + +#include // std::is_same +#include // std::conj #include "TreeCalculator.h" #include "trees/FunctionTreeVector.h" namespace mrcpp { -template class AdditionCalculator final : public TreeCalculator { +/** + * @class AdditionCalculator + * @brief Node-wise accumulator for adaptive sums of multiresolution functions. + * + * @tparam D Spatial dimension (1, 2, or 3). + * @tparam T Coefficient type (e.g., `double`, `ComplexDouble`). + * + * @details + * For each target node \p node_o (identified by its NodeIndex), this calculator + * gathers the corresponding node from every input function in a + * #mrcpp::FunctionTreeVector and accumulates the weighted coefficients into + * \p node_o: + * + * \f[ + * \mathbf{c}_o \;=\; \sum_i \alpha_i \,\mathbf{c}_i . + * \f] + * + * If \p T is complex, optional conjugation is applied according to the XOR of + * the per-tree conjugation flag and the calculator-wide @ref conj flag; i.e. + * a term uses \f$\overline{\mathbf{c}_i}\f$ iff exactly one of the two flags is set. + * + * This class performs **no grid refinement** or transforms; it only writes + * coefficients, marks presence, and updates node norms. It is intended to be + * used inside the adaptive loop driven by #mrcpp::TreeBuilder together with an + * appropriate adaptor. + */ +template +class AdditionCalculator final : public TreeCalculator { public: + /** + * @brief Construct an addition calculator over a set of input trees. + * + * @param[in] inp Vector of (coefficient, tree) pairs to be summed. + * @param[in] conjugate Global conjugation toggle for complex types. For + * each input tree, the effective conjugation applied + * is `tree.conjugate() XOR conjugate`. 
+ * + * @note All input trees are assumed to share an MRA compatible with the + * output tree provided to the builder. + */ AdditionCalculator(const FunctionTreeVector &inp, bool conjugate = false) : sum_vec(inp) , conj(conjugate) {} private: + /// Vector of weighted input trees to sum. FunctionTreeVector sum_vec; + /// Global conjugation toggle for complex accumulation (see ctor docs). bool conj; + /** + * @brief Accumulate coefficients for a single output node. + * + * @param[in,out] node_o Target node whose coefficients are overwritten + * by the weighted sum of matching input nodes. + * + * @details + * Steps: + * 1. Zero \p node_o coefficients. + * 2. For each entry \f$(\alpha_i, f_i)\f$ in @ref sum_vec: + * - Fetch (and create if needed) the input node with the same index as \p node_o. + * - Accumulate \f$\alpha_i \cdot \mathbf{c}_i\f$ (or its conjugate for complex, + * following the XOR rule) into \p node_o. + * 3. Mark coefficients present and update norms. + * + * No transforms are performed here; coefficients are assumed to be in the + * same representation across all trees. 
+ */ void calcNode(MWNode &node_o) override { node_o.zeroCoefs(); const NodeIndex &idx = node_o.getNodeIndex(); T *coefs_o = node_o.getCoefs(); + for (int i = 0; i < this->sum_vec.size(); i++) { T c_i = get_coef(this->sum_vec, i); FunctionTree &func_i = get_func(this->sum_vec, i); - // This generates missing nodes + + // This generates the node if missing in func_i const MWNode &node_i = func_i.getNode(idx); const T *coefs_i = node_i.getCoefs(); int n_coefs = node_i.getNCoefs(); + if constexpr (std::is_same::value) { - if (func_i.conjugate() xor conj) { - for (int j = 0; j < n_coefs; j++) { coefs_o[j] += c_i * std::conj(coefs_i[j]); } + const bool use_conj = (func_i.conjugate() xor conj); + if (use_conj) { + for (int j = 0; j < n_coefs; j++) coefs_o[j] += c_i * std::conj(coefs_i[j]); } else { - for (int j = 0; j < n_coefs; j++) { coefs_o[j] += c_i * coefs_i[j]; } + for (int j = 0; j < n_coefs; j++) coefs_o[j] += c_i * coefs_i[j]; } } else { - for (int j = 0; j < n_coefs; j++) { coefs_o[j] += c_i * coefs_i[j]; } + for (int j = 0; j < n_coefs; j++) coefs_o[j] += c_i * coefs_i[j]; } } + node_o.setHasCoefs(); node_o.calcNorms(); } }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/treebuilders/AnalyticAdaptor.h b/src/treebuilders/AnalyticAdaptor.h index 3e9ca0613..a7359a2e9 100644 --- a/src/treebuilders/AnalyticAdaptor.h +++ b/src/treebuilders/AnalyticAdaptor.h @@ -30,15 +30,69 @@ namespace mrcpp { -template class AnalyticAdaptor final : public TreeAdaptor { +/** + * @class AnalyticAdaptor + * @brief Refinement policy that consults an analytic (representable) function. + * + * @tparam D Spatial dimension (1, 2, or 3). + * @tparam T Coefficient/value type (e.g., `double`, `ComplexDouble`). + * + * @details + * This adaptor requests refinement of a node when the provided analytic + * function is **not yet visible** at the node's current scale and is **not + * identically zero** on the node's cell. 
Concretely: + * + * - If `func->isVisibleAtScale(scale, kp1)` returns **true**, the node is + * considered sufficiently resolved at this scale → **do not split**. + * - Else, if `func->isZeroOnInterval(lb, ub)` returns **true**, the function + * vanishes on the cell → **do not split**. + * - Otherwise, the feature likely requires more resolution → **split**. + * + * The visibility test uses the node’s polynomial order `k+1` (via `getKp1()`) + * as the quadrature/collocation count hint for the analytic oracle. + * + * ### Requirements on the analytic function + * The `RepresentableFunction` passed in must implement: + * - `bool isVisibleAtScale(int scale, int nQuadPts) const;` + * - `bool isZeroOnInterval(const double* lower, const double* upper) const;` + * + * ### Typical usage + * @code{.cpp} + * AnalyticFunction<3,double> f(...); // implements the required interface + * AnalyticAdaptor<3,double> adapt(f, mra.getMaxScale()); + * TreeBuilder<3,double> builder; + * DefaultCalculator<3,double> calc; + * builder.build(tree, calc, adapt, -1); // maxIter: unbounded + * @endcode + */ +template +class AnalyticAdaptor final : public TreeAdaptor { public: + /** + * @brief Construct an analytic-driven adaptor. + * @param f Analytic (representable) function used as refinement oracle. + * @param ms Maximum allowed scale for splitting (forwarded to TreeAdaptor). + */ AnalyticAdaptor(const RepresentableFunction &f, int ms) : TreeAdaptor(ms) , func(&f) {} private: + /// Pointer to the refinement oracle (not owned). const RepresentableFunction *func; + /** + * @brief Decide whether a node should be split. + * + * @param node Candidate node to test. + * @return `true` if refinement is requested; `false` otherwise. + * + * @details + * Uses the two-step logic described in the class documentation: + * 1) skip split if visible at current scale, + * 2) skip split if identically zero on the node's interval, + * 3) otherwise split. 
+ */ bool splitNode(const MWNode &node) const override { int scale = node.getScale(); int nQuadPts = node.getKp1(); @@ -50,4 +104,4 @@ template class AnalyticAdaptor final : public TreeAdaptor class ConvolutionCalculator final : public TreeCalculator { +/** + * @class ConvolutionCalculator + * @brief Performs adaptive convolution of a function tree with a convolution operator. + * + * @tparam D Spatial dimensionality (1–3). + * @tparam T Coefficient scalar type (`double` or `ComplexDouble`). + * + * @details + * The calculator traverses the output tree (owned by the base + * @ref TreeCalculator) and, for each node, applies the convolution operator to + * the relevant neighborhood (an operator *band*). Band sizes are derived from + * the operator bandwidth and the current tree depth, and can be further tuned + * by a user-supplied per-node precision function @ref setPrecFunction. + * + * The implementation records timing and operator statistics per band/component + * to aid profiling, and can optionally manipulate the operator prior to + * application (see @ref startManipulateOperator). + * + * ### Lifetime / ownership + * - @ref ConvolutionCalculator does **not** own the operator nor the input + * function tree; it stores non-owning pointers. + * - Timers and internal matrices are allocated and cleared by + * @ref initTimers / @ref clearTimers . + */ +template +class ConvolutionCalculator final : public TreeCalculator { public: + /** + * @brief Construct a calculator for \f$ g = \mathcal{O}\{f\} \f$. + * + * @param p Target accuracy (relative or absolute depending on usage). + * @param o Convolution operator to apply. + * @param f Input function tree \f$ f \f$. + * @param depth Maximum traversal depth for the output tree + * (defaults to @c MaxDepth for the MRA). + * + * @pre @p o and @p f must remain valid for the lifetime of the calculator. 
+ */ ConvolutionCalculator(double p, ConvolutionOperator &o, FunctionTree &f, int depth = MaxDepth); + + /// @brief Destructor. Releases timers and internal band-size tables. ~ConvolutionCalculator() override; - MWNodeVector *getInitialWorkVector(MWTree &tree) const override; + /** + * @brief Produce the initial work vector of nodes for the output tree. + * + * @param tree Output tree that will receive the convolution result. + * @return Pointer to a heap-allocated vector of nodes to start from. + * + * @details + * The initial set typically includes end nodes (or generator nodes in + * banded neighborhoods) where the operator action is non-zero. + * The caller (base class) assumes ownership of the returned vector. + */ + MWNodeVector* getInitialWorkVector(MWTree &tree) const override; + /** + * @brief Set a per-node precision function. + * + * @param prec_func A functor returning the local tolerance for a node index. + * + * @details + * When provided, the calculator uses @p prec_func(idx) to refine the target + * precision locally (e.g., tighter near singularities), typically in + * conjunction with the global precision passed to the constructor. + */ void setPrecFunction(const std::function &idx)> &prec_func) { this->precFunc = prec_func; } + + /** + * @brief Enable operator manipulation prior to application. + * + * @param excUnit If `true`, manipulate on the unit cell (periodic contexts). + * + * @details + * When enabled the operator may be preconditioned, symmetrized, or mapped + * to a fundamental domain before application. Exact behavior depends on + * the associated @ref ConvolutionOperator. + */ void startManipulateOperator(bool excUnit) { this->manipulateOperator = true; this->onUnitcell = excUnit; } private: + // ---- Configuration / inputs ------------------------------------------------ + + /// @brief Maximum output depth to visit. int maxDepth; + + /// @brief Global target precision (interpreted by implementation). 
double prec; + + /// @brief Toggle for pre-application manipulation of the operator. bool manipulateOperator{false}; + + /// @brief Toggle for unit-cell manipulation in periodic problems. bool onUnitcell{false}; + + /// @brief Non-owning pointer to the convolution operator. ConvolutionOperator *oper; + + /// @brief Non-owning pointer to the input function tree f(r). FunctionTree *fTree; - std::vector band_t; - std::vector calc_t; - std::vector norm_t; + // ---- Instrumentation ------------------------------------------------------- + + /// @brief Per-band timers for operator band building. + std::vector band_t; + + /// @brief Per-band timers for the main convolution kernels. + std::vector calc_t; + + /// @brief Per-band timers for norm/threshold checks. + std::vector norm_t; + + /// @brief Aggregate operator statistics (bandwidths, touches, flops estimates). OperatorStatistics operStat; - std::vector bandSizes; + + // ---- Band-size modeling ---------------------------------------------------- + + /** + * @brief Precomputed band-size factors per depth/component. + * + * @details + * Each matrix has shape `(maxDepth+1) × nComp2`, where `nComp = 2^D` + * and `nComp2 = nComp * nComp`. Linearized index + * `k = gt * nComp + ft` maps from generator (`gt`) and father (`ft`) + * component pairs to a band-size factor at a given depth. + */ + std::vector bandSizes; + + /** + * @brief Optional local precision override. + * + * @details + * Defaults to a neutral functor returning 1.0. When set by + * @ref setPrecFunction, it scales or replaces the global precision on a + * per-node basis. + */ std::function &idx)> precFunc = [](const NodeIndex &idx) { return 1.0; }; + /// @brief Number of component blocks (2^D) in a multiwavelet tensor. static const int nComp = (1 << D); + + /// @brief Number of component-pair interactions ( (2^D) × (2^D) ). 
static const int nComp2 = (1 << D) * (1 << D); - MWNodeVector *makeOperBand(const MWNode &gNode, std::vector> &idx_band); + // ---- Band construction helpers -------------------------------------------- + + /** + * @brief Build an operator band (list of neighbor nodes) around @p gNode. + * + * @param gNode Generator node (output-space anchor). + * @param idx_band Output: collected node indices forming the band. + * @return Heap-allocated node vector corresponding to @p idx_band. + * + * @details + * The band is determined by the operator bandwidth at the scale of + * @p gNode and the precomputed band-size factors. The returned vector + * contains concrete node handles in traversal order. + */ + MWNodeVector* makeOperBand(const MWNode &gNode, std::vector> &idx_band); + + /** + * @brief Recursive fill of an operator band. + * + * @param band Destination node vector to append to. + * @param idx_band Indices to materialize. + * @param idx Current index under construction. + * @param nbox Periodic-box replication vector per dimension. + * @param dim Current dimension (0..D-1) being expanded. + */ void fillOperBand(MWNodeVector *band, std::vector> &idx_band, NodeIndex &idx, const int *nbox, int dim); + // ---- Timing / statistics lifecycle ----------------------------------------- + + /// @brief Allocate and start per-band timers. void initTimers(); + + /// @brief Stop and free per-band timers. void clearTimers(); + + /// @brief Print a compact timing breakdown per component/band. void printTimers() const; + // ---- Band-size factors ----------------------------------------------------- + + /// @brief Allocate @ref bandSizes tables. void initBandSizes(); + + /** + * @brief Lookup the band-size factor for a component pair at a depth. + * + * @param i Which table (implementation-defined band decomposition). + * @param depth Tree depth. + * @param os Current operator-state (provides `gt` and `ft`). + * @return Precomputed size factor. 
+ */ int getBandSizeFactor(int i, int depth, const OperatorState &os) const { int k = os.gt * this->nComp + os.ft; return (*this->bandSizes[i])(depth, k); } + /** + * @brief Compute the band-size factors for all component pairs at a depth. + * + * @param bs Destination matrix (size `(maxDepth+1) × nComp2`). + * @param depth Target depth to (re)compute. + * @param bw Operator bandwidth descriptor. + */ void calcBandSizeFactor(Eigen::MatrixXi &bs, int depth, const BandWidth &bw); + // ---- Core calculation hooks (TreeCalculator overrides) --------------------- + + /** + * @brief Compute the output contribution for a single node. + * + * @param node Output node to update. + * + * @details + * Builds the relevant operator band around @p node, applies the operator to + * the input tree restricted to that band, and accumulates the result into + * @p node's coefficients. Precision is controlled by @ref prec and + * @ref precFunc. + */ void calcNode(MWNode &node) override; + + /** + * @brief Post-processing after a full tree sweep. + * + * @details + * Prints per-band timing information, clears timers, and re-initializes + * the timing infrastructure for possible subsequent sweeps. + */ void postProcess() override { printTimers(); clearTimers(); initTimers(); } + // ---- Operator application kernels ------------------------------------------ + + /** + * @brief Apply a single operator component to the current band. + * + * @param os Operator state (component indices, buffers, thresholds, etc.). + */ void applyOperComp(OperatorState &os); + + /** + * @brief Apply the full operator (all components) for a given band index. + * + * @param i Band table index / decomposition slot. + * @param os Operator state for the current output node. + */ void applyOperator(int i, OperatorState &os); + + /** + * @brief Tensor-kernel variant of @ref applyOperComp (blocked/tensor form). + * + * @param os Operator state for the current output node. 
+ * + * @details + * May use batched/tensorized multiply-adds for better cache locality when + * the component layout allows it. + */ void tensorApplyOperComp(OperatorState &os); + // ---- Tree maintenance ------------------------------------------------------- + + /** + * @brief Ensure parent nodes are materialized/touched before children writes. + * + * @param tree Output tree to touch/wake parents in. + * + * @details + * Some backends require parent nodes to exist to safely commit child + * contributions (e.g., for allocation, normalization, or boundary handling). + */ void touchParentNodes(MWTree &tree) const; }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/treebuilders/CopyAdaptor.h b/src/treebuilders/CopyAdaptor.h index a7825cca0..f26fbf354 100644 --- a/src/treebuilders/CopyAdaptor.h +++ b/src/treebuilders/CopyAdaptor.h @@ -25,22 +25,113 @@ #pragma once +/** + * @file + * @brief Adaptor that copies data from one or more source trees into a target tree. + * + * @details + * Declares @ref mrcpp::CopyAdaptor, a lightweight @ref TreeAdaptor that + * drives adaptive traversal/refinement for copy operations. The adaptor + * decides whether to split/visit nodes based on a per-dimension band–width + * window around the regions populated in the source tree(s) and on a + * max-scale constraint provided at construction. + */ + #include "TreeAdaptor.h" #include "trees/FunctionTreeVector.h" namespace mrcpp { -template class CopyAdaptor final : public TreeAdaptor { +/** + * @class CopyAdaptor + * @brief Adaptor for reproducing (copying) function-tree structure/coefficients. + * + * @tparam D Spatial dimensionality (1–3). + * @tparam T Coefficient scalar type. 
+ * + * @details + * A @ref TreeAdaptor used by generic tree algorithms to: + * - restrict refinement to a maximum scale (`ms`), and + * - gate splitting to nodes that fall within a per-dimension *band width* + * neighborhood around the non-empty region of one or more **source trees**. + * + * This enables efficient copying/subsetting of a function tree (or a + * @ref FunctionTreeVector) into a new tree while avoiding unnecessary + * refinement outside the area of interest. + */ +template +class CopyAdaptor final : public TreeAdaptor { public: + /** + * @brief Construct an adaptor using a single source tree. + * + * @param t Source tree to mirror/copy from. + * @param ms Maximum scale (depth) allowed for refinement in the target. + * @param bw Pointer to an array of length @c D with per-dimension band + * half-widths (in node/grid units). Values control how far + * from the source support we keep refining; non-positive + * entries are treated as zero. + * + * @note The adaptor stores an internal vector view that contains @p t. + */ CopyAdaptor(FunctionTree &t, int ms, int *bw); + + /** + * @brief Construct an adaptor using multiple source trees. + * + * @param t Collection of source trees whose union of supports guides + * refinement/visitation. + * @param ms Maximum scale (depth) allowed for refinement in the target. + * @param bw Pointer to an array of length @c D with per-dimension band + * half-widths (in node/grid units). See the single-tree + * constructor for interpretation. + */ CopyAdaptor(FunctionTreeVector &t, int ms, int *bw); private: + /** + * @brief Per-dimension refinement band half-widths. + * + * @details + * For dimension @c d, only nodes whose index lies within + * @c bandWidth[d] boxes of the populated region of the source will be + * considered for splitting. A value of zero limits refinement strictly to + * the currently populated footprint. + */ int bandWidth[D]; + + /** + * @brief Source tree collection used to drive the copy operation. 
+ * + * @details + * When constructed from a single tree, this vector contains exactly one + * entry referencing that tree; otherwise it aliases the user-provided + * vector. No ownership transfer takes place. + */ FunctionTreeVector tree_vec; + /** + * @brief Initialize the @ref bandWidth array from a user buffer. + * + * @param bw Pointer to an array of length @c D; negative values are clamped + * to zero. + */ void setBandWidth(int *bw); + + /** + * @brief Decide whether a node should be split during traversal. + * + * @param node Node under consideration in the *target* tree. + * @return `true` if the node lies within the refinement window and the + * max-scale policy permits further subdivision; `false` otherwise. + * + * @details + * The decision combines: + * - the maximum allowed scale passed at construction, and + * - the per-dimension band width around the union support of the source + * tree(s) stored in @ref tree_vec. + */ bool splitNode(const MWNode &node) const override; }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/treebuilders/CrossCorrelationCalculator.h b/src/treebuilders/CrossCorrelationCalculator.h index b97e5e6da..af80cdaa7 100644 --- a/src/treebuilders/CrossCorrelationCalculator.h +++ b/src/treebuilders/CrossCorrelationCalculator.h @@ -25,22 +25,109 @@ #pragma once +/** + * @file + * @brief Cross-correlation tree calculator for 2D functions with a 1D kernel. + * + * @details + * Declares @ref mrcpp::CrossCorrelationCalculator, a concrete + * @ref TreeCalculator that evaluates a (discrete) cross–correlation between a + * two–dimensional multiresolution function and a one–dimensional kernel, + * typically along one axis of each 2D node. The implementation leverages a + * @ref CrossCorrelationCache to reuse banded operator data and reduce + * per–node setup costs across the traversal. 
+ */ + #include "TreeCalculator.h" #include "core/CrossCorrelationCache.h" namespace mrcpp { +/** + * @class CrossCorrelationCalculator + * @brief Applies a cached cross–correlation with a 1D kernel to a 2D tree. + * + * @details + * This calculator specializes @ref TreeCalculator for 2D nodes + * (`TreeCalculator<2>`). During traversal, @ref calcNode pulls the relevant + * coefficient band(s) from the current node, applies a cross–correlation with + * the supplied 1D @ref kernel, and writes the result back to the destination + * tree/state managed by the base calculator. + * + * ### Design notes + * - The kernel is provided as a `FunctionTree<1>` and is **not owned** by the + * calculator (the caller must guarantee its lifetime). + * - Internally, @ref applyCcc parametrizes on the scalar coefficient type + * (`double`, `std::complex`, …) via the template parameter `T` of + * @ref CrossCorrelationCache, enabling reuse for both real and complex trees. + * - A @ref CrossCorrelationCache is used to memoize structure- and + * bandwidth-dependent intermediates (e.g., band shapes, transforms) so that + * repeated applications across many nodes are efficient. + */ class CrossCorrelationCalculator final : public TreeCalculator<2> { public: + /** + * @brief Construct a calculator using a given 1D kernel. + * + * @param k Reference to a 1D function tree representing the correlation + * kernel. The pointer is stored; the object must outlive the + * calculator. + * + * @note No ownership is transferred; `k` must remain valid for the entire + * calculation. + */ CrossCorrelationCalculator(FunctionTree<1> &k) : kernel(&k) {} private: + /** + * @brief Non-owning pointer to the 1D kernel used in the cross–correlation. + */ FunctionTree<1> *kernel; + /** + * @brief Compute the cross–correlated output for a single 2D node. + * + * @param node The node to process within the current output tree. The node's + * scale/index determine the coefficient bands to read/write. 
+ * + * @details + * This override fetches the node's input coefficients (from the source tree + * configured in the base @ref TreeCalculator), prepares/cache-reuses the + * operator band via a @ref CrossCorrelationCache, and applies the + * correlation along the appropriate axis. The resulting coefficients are + * accumulated into the node's output buffer. + * + * @warning The method assumes that the base calculator has already + * orchestrated any required refinement and that input/output tree + * storage is valid for @p node. + */ void calcNode(MWNode<2> &node) override; - template void applyCcc(MWNode<2> &node, CrossCorrelationCache &ccc); + /** + * @brief Apply the cached cross–correlation to a node with a concrete scalar type. + * + * @tparam T Coefficient scalar type of the node (e.g., `double`, + * `std::complex`). + * @param node The target node to which the operator is applied. + * @param ccc A cross–correlation cache specialized for @p T that provides + * band sizes, temporary buffers, and any preassembled operator + * pieces required for efficient application. + * + * @details + * Performs the type-specific core computation: + * - obtains the relevant coefficient band(s) from @p node, + * - uses @p ccc to assemble or retrieve the required operator slice + * derived from @ref kernel, + * - applies the correlation (with the proper band width and alignment), + * - writes/accumulates the result into the node's output coefficients. + * + * The separation into this templated helper allows the public + * @ref calcNode to dispatch based on the underlying node coefficient type + * without duplicating logic. 
+ */ + template + void applyCcc(MWNode<2> &node, CrossCorrelationCache &ccc); }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/treebuilders/DefaultCalculator.h b/src/treebuilders/DefaultCalculator.h index 4a1a4ce54..90074ccdf 100644 --- a/src/treebuilders/DefaultCalculator.h +++ b/src/treebuilders/DefaultCalculator.h @@ -25,23 +25,77 @@ #pragma once +/** + * @file + * @brief Trivial calculator that clears coefficients and norms on each node. + * + * @details + * Declares @ref mrcpp::DefaultCalculator, a minimal + * @ref TreeCalculator implementation that: + * - iterates over a vector of nodes **sequentially** (no OpenMP), + * - for each node, clears coefficient/flag state via + * @ref MWNode::clearHasCoefs and resets stored norms via + * @ref MWNode::clearNorms. + * + * This is useful as a baseline or as a final cleanup pass when no numerical + * operator needs to be applied but tree state must be normalized. + */ + #include "TreeCalculator.h" namespace mrcpp { -template class DefaultCalculator final : public TreeCalculator { +/** + * @class DefaultCalculator + * @brief Minimal calculator that performs per-node cleanup. + * + * @tparam D Spatial dimension of the tree. + * @tparam T Scalar coefficient type (`double`, `std::complex`, …). + * + * @details + * The calculator eschews OpenMP parallelism for its node-vector traversal + * because the work per node is trivial and sequential iteration is typically + * faster (lower overhead). If parallel traversal is desired, use or derive + * from an alternative calculator that enables OpenMP in + * `calcNodeVector`. + */ +template +class DefaultCalculator final : public TreeCalculator { public: - // Reimplementation without OpenMP, the default is faster this way + /** + * @brief Process a vector of nodes sequentially. + * + * @param nodeVec Container of node pointers to process. + * + * @details + * Calls @ref calcNode on each entry in order. 
The method deliberately + * avoids OpenMP to minimize overhead for very small, constant-time work. + * + * @complexity Linear in `nodeVec.size()`. + */ void calcNodeVector(MWNodeVector &nodeVec) override { int nNodes = nodeVec.size(); for (int n = 0; n < nNodes; n++) { calcNode(*nodeVec[n]); } } private: + /** + * @brief Clear coefficient presence flags and stored norms for a node. + * + * @param node The node whose local state will be reset. + * + * @details + * - @ref MWNode::clearHasCoefs marks that the node no longer has valid + * coefficients. + * - @ref MWNode::clearNorms removes any cached norm values. + * + * This does **not** modify topology (no splitting/merging) and does not + * change coefficient arrays beyond clearing the "has coefs" state. + */ void calcNode(MWNode &node) override { node.clearHasCoefs(); node.clearNorms(); } }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/treebuilders/DerivativeCalculator.h b/src/treebuilders/DerivativeCalculator.h index 347554a46..b89d97b4d 100644 --- a/src/treebuilders/DerivativeCalculator.h +++ b/src/treebuilders/DerivativeCalculator.h @@ -25,45 +25,187 @@ #pragma once +/** + * @file + * @brief Derivative calculator on multiresolution function trees. + * + * @details + * Declares @ref mrcpp::DerivativeCalculator, a @ref TreeCalculator that applies a + * directional differential operator to an input @ref FunctionTree and writes the + * result into the calculator's target tree. The implementation constructs a + * scale-aware operator “band” around each output node and evaluates the + * derivative using tensorized component applications while optionally collecting + * timing and bandwidth statistics. + */ + #include "TreeCalculator.h" #include "operators/OperatorStatistics.h" namespace mrcpp { -template class DerivativeCalculator final : public TreeCalculator { +/** + * @class DerivativeCalculator + * @brief Applies a directional derivative operator to a function tree. 
+ * + * @tparam D Spatial dimension of the tree (1–3 typical). + * @tparam T Scalar coefficient type (e.g., `double`, `std::complex`). + * + * @details + * The calculator computes \f$ g = \partial_{x_{dir}}(f) \f$ where: + * - `dir` selects the Cartesian direction \f$0 \le dir < D\f$, + * - `oper` encapsulates the discretized derivative stencils/filters, + * - `fTree` is the source tree and the calculator’s target is the destination. + * + * The traversal is driven by the base @ref TreeCalculator; for each output + * node, a localized operator band is constructed (see @ref makeOperBand) and + * the operator is applied in a tensorized fashion (see + * @ref tensorApplyOperComp). Optional timing is gathered per phase (band + * building, application, norm updates) and summarized on completion. + */ +template +class DerivativeCalculator final : public TreeCalculator { public: + /** + * @brief Construct a derivative calculator. + * + * @param dir Direction index of the derivative (0-based, `< D`). + * @param o Reference to the derivative operator to apply. + * @param f Reference to the **input/source** function tree. + * + * @note The destination/output tree is owned by the base + * @ref TreeCalculator (`this->outTree()`). + */ DerivativeCalculator(int dir, DerivativeOperator &o, FunctionTree &f); + + /// @brief Virtual destructor; prints and clears timers in @ref postProcess. ~DerivativeCalculator() override; + /** + * @brief Provide the initial work vector for traversal. + * + * @param tree Output tree on which work will be scheduled. + * @return Pointer to a newly created vector of nodes to process first. + * + * @details + * The default strategy is to populate the initial vector with those nodes + * of @p tree that require operator application (implementation-dependent). + */ MWNodeVector *getInitialWorkVector(MWTree &tree) const override; + + /** + * @brief Compute the derivative for a node pair. + * + * @param fNode Source node from @ref fTree (input). 
+ * @param gNode Destination node on the output tree (result of \f$\partial_{dir} f\f$). + * + * @details + * Builds a local band around @p gNode, gathers contributions from @p fNode + * within the operator bandwidth, then accumulates into @p gNode. + */ void calcNode(MWNode &fNode, MWNode &gNode); private: - int applyDir; - FunctionTree *fTree; - DerivativeOperator *oper; + // --- Configuration and inputs ------------------------------------------------- + + int applyDir; ///< Direction index along which to differentiate. + FunctionTree *fTree{nullptr};///< Source function tree (input). + DerivativeOperator *oper{nullptr}; ///< Differential operator to apply. - std::vector band_t; - std::vector calc_t; - std::vector norm_t; - OperatorStatistics operStat; + // --- Timing/statistics -------------------------------------------------------- + std::vector band_t; ///< Timers for band construction per depth or phase. + std::vector calc_t; ///< Timers for operator application per depth or phase. + std::vector norm_t; ///< Timers for norm/cleanup updates per depth or phase. + OperatorStatistics operStat;///< Aggregate operator and bandwidth statistics. + + // --- Work preparation --------------------------------------------------------- + + /** + * @brief Build the operator "band" (neighborhood) for an output node. + * + * @param gNode Output (destination) node. + * @param idx_band Output: list of source node indices involved by bandwidth. + * @return A vector of node pointers representing the band to be processed. + * + * @details + * The band captures the set of input nodes that may contribute to @p gNode + * under the derivative operator’s bandwidth model across scales and + * spatial adjacency. + */ MWNodeVector makeOperBand(const MWNode &gNode, std::vector> &idx_band); + /// @brief Initialize per-phase timers based on tree depth/layout. void initTimers(); + + /// @brief Stop/clear all timers and release related resources. 
void clearTimers(); + + /// @brief Print a concise timing summary and collected operator statistics. void printTimers() const; + // --- TreeCalculator interface ------------------------------------------------- + + /** + * @brief Per-node callback from the traversal engine. + * + * @param node Output node to compute; pulls required input contributions. + * + * @details + * For each output @p node, constructs the operator band (via + * @ref makeOperBand), then delegates the actual stencil application to + * @ref applyOperator / @ref tensorApplyOperComp. Norm and flag updates are + * performed as needed. + */ void calcNode(MWNode &node) override; + + /** + * @brief Hook invoked after a traversal pass. + * + * @details Prints timing statistics, clears timers, and re-initializes + * them to be ready for subsequent passes. + */ void postProcess() override { printTimers(); clearTimers(); initTimers(); } + // --- Operator application ----------------------------------------------------- + + /** + * @brief Apply the derivative operator to the current band/state. + * + * @param os Operator state for the current output node and component pair. + * + * @details + * Chooses an application path depending on operator bandwidth and node + * configuration, then accumulates results into the output node. + */ void applyOperator(OperatorState &os); + + /** + * @brief Specialized path for zero-bandwidth (local) derivative application. + * + * @param os Operator state for the current output node and component pair. + * + * @details + * When the derivative is strictly local in the discretization (bandwidth 0), + * this fast path avoids neighborhood assembly and directly applies the + * local stencil. + */ void applyOperator_bw0(OperatorState &os); + + /** + * @brief Tensorized component application of the operator. + * + * @param os Operator state (contains gt/ft component ids, indices, buffers). 
+ * + * @details + * Performs dimension-wise application of the derivative operator using + * separable tensor components, respecting the grid/scale layout in + * @ref OperatorState. + */ void tensorApplyOperComp(OperatorState &os); }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/treebuilders/MapCalculator.h b/src/treebuilders/MapCalculator.h index 33f799ee9..8af401713 100644 --- a/src/treebuilders/MapCalculator.h +++ b/src/treebuilders/MapCalculator.h @@ -25,29 +25,129 @@ #pragma once +/** + * @file + * @brief Node-wise nonlinear mapping calculator for multiresolution trees. + * + * @details + * This header defines @ref mrcpp::MapCalculator, a concrete + * @ref mrcpp::TreeCalculator that applies a *pointwise* value mapping + * (nonlinear allowed) to the coefficients of @ref mrcpp::FunctionTree nodes. + * + * The calculator: + * 1. Locates the corresponding input node (creating a *temporary copy* when + * needed) via `FunctionTree::getNode(idx)`. + * 2. Brings that input node to **value space**: inverse multi-wavelet (MW) + * reconstruction followed by a forward coefficient transform (CV) to obtain + * per-point values. + * 3. Applies the user-supplied functor `fmap : T → T` elementwise. + * 4. Transforms the output node back to the compressed representation + * (CV backward, MW compression), sets flags, and updates norms. + * + * The calculator operates node-by-node and is typically orchestrated by the + * surrounding @ref TreeCalculator driver (which handles traversal, work queues, + * and adaptive refinement). + */ + #include "TreeCalculator.h" namespace mrcpp { -template class MapCalculator final : public TreeCalculator { +/** + * @class MapCalculator + * @brief Node-local nonlinear mapping (value transform) on a function tree. + * + * @tparam D Spatial dimension (e.g., 1, 2, 3). + * @tparam T Coefficient/value scalar type (e.g., `double`, `ComplexDouble`). 
+ * + * @details + * `MapCalculator` evaluates a user-provided mapping functor `fmap` on the node + * samples corresponding to the input tree @p inp and writes the transformed + * values into the output tree managed by the base @ref TreeCalculator. + * + * ### Transform pipeline per node + * For a given output node `node_o` at index `idx`: + * - Acquire an **input** node copy: `MWNode node_i = func->getNode(idx)`. + * - Convert coefficients to point samples: + * - `node_i.mwTransform(Reconstruction);` + * - `node_i.cvTransform(Forward);` + * - Apply `fmap` to each sample: `coefs_o[j] = fmap(coefs_i[j]);` + * - Convert back to the compressed representation: + * - `node_o.cvTransform(Backward);` + * - `node_o.mwTransform(Compression);` + * - Finalize bookkeeping: `setHasCoefs()` and `calcNorms()`. + * + * @note + * The calculator assumes the **input** and **output** trees are compatible + * (same MRA order/domain and node layouts as scheduled by the driver). Any + * missing input nodes are generated on-the-fly by `getNode(idx)` as a *copy*. + * + * @warning + * The mapping functor @p fm **must** be thread-safe and side-effect free, + * as nodes can be processed in parallel by the base calculator. + */ +template +class MapCalculator final : public TreeCalculator { public: + /** + * @brief Construct a node-mapping calculator. + * + * @param fm Elementwise mapping functor `T → T` (copied/moved in). + * @param inp Input function tree providing source coefficients. + * + * @pre @p inp is initialized on a valid MRA compatible with the target + * output tree managed by the driver. + */ MapCalculator(FMap fm, FunctionTree &inp) : func(&inp) , fmap(std::move(fm)) {} private: + /// Pointer to the input function tree (non-owning). FunctionTree *func; + + /// Elementwise mapping functor applied to node samples. FMap fmap; + + /** + * @brief Compute mapped coefficients for one output node. + * + * @param node_o Output node to fill (topology assumed prepared by driver). 
+ * + * @details + * - Fetch input node at the same index (creates a temporary node copy). + * - Perform MW reconstruction and CV forward transforms on the input copy. + * - Apply @ref fmap to each sample and write into @p node_o. + * - Restore compressed representation of @p node_o (CV backward, MW compress). + * - Mark coefficients present and update node norms. + * + * @complexity + * \f$O(n_\text{coef})\f$ per node (excluding transform costs), where + * \f$n_\text{coef}\f$ is the number of local coefficients. + * + * @thread_safety + * Independent across nodes when the driver runs in parallel. The functor + * @ref fmap must be reentrant. + */ void calcNode(MWNode &node_o) override { const NodeIndex &idx = node_o.getNodeIndex(); - int n_coefs = node_o.getNCoefs(); + const int n_coefs = node_o.getNCoefs(); T *coefs_o = node_o.getCoefs(); - // This generates missing nodes - MWNode node_i = func->getNode(idx); // Copy node + + // Obtain a temporary input node copy at the same index. + MWNode node_i = func->getNode(idx); + + // Bring input node to value space (reconstruction → forward CV). node_i.mwTransform(Reconstruction); node_i.cvTransform(Forward); + + // Apply the non-linear map pointwise. const T *coefs_i = node_i.getCoefs(); - for (int j = 0; j < n_coefs; j++) { coefs_o[j] = fmap(coefs_i[j]); } + for (int j = 0; j < n_coefs; ++j) { + coefs_o[j] = fmap(coefs_i[j]); + } + + // Return to compressed representation and finalize bookkeeping. node_o.cvTransform(Backward); node_o.mwTransform(Compression); node_o.setHasCoefs(); @@ -55,4 +155,4 @@ template class MapCalculator final : public TreeCalculator class MultiplicationAdaptor : public TreeAdaptor { +/** + * @class MultiplicationAdaptor + * @brief Refinement rule for the product of two function trees. + * + * @tparam D Spatial dimension (e.g., 1, 2, 3). + * @tparam T Coefficient/value scalar type (e.g., `double`, `ComplexDouble`). 
+ * + * @details + * The adaptor is typically used to drive construction of an output grid for + * \f$f_0 \cdot f_1\f$. At each node index it: + * 1. Retrieves the corresponding nodes from both input trees. + * 2. Computes the square-rooted maximum scaling and wavelet norms + * \f$(S_i, W_i)\f$. + * 3. Forms the estimate + * \f$\text{multNorm} = W_0 S_1 + W_1 S_0 + W_0 W_1\f$. + * 4. Requests a split if `multNorm > prec` and at least one input node is not + * a leaf. This effectively avoids refining deeper than either input grid, + * because when both inputs have zero wavelet contribution at a node, + * `multNorm == 0` and no further refinement is triggered. + * + * @note The input vector @ref trees must contain **exactly two** trees; a + * runtime error is emitted otherwise. + * + * @warning The adaptor reads norms from the input trees during `splitNode`. + * The member @ref trees is `mutable` to allow this from a `const` context. + */ +template +class MultiplicationAdaptor : public TreeAdaptor { public: + /** + * @brief Construct the multiplication refinement rule. + * + * @param pr Refinement threshold \f$\text{prec}\f$ for the multNorm estimate. + * @param ms Maximum scale/depth hint passed to @ref TreeAdaptor base. + * @param t The pair of input trees as a @ref FunctionTreeVector. + * + * @pre `t.size() == 2` + */ MultiplicationAdaptor(double pr, int ms, FunctionTreeVector &t) : TreeAdaptor(ms) , prec(pr) , trees(t) {} + ~MultiplicationAdaptor() override = default; protected: + /// Refinement threshold used against `multNorm`. double prec; + + /** + * @brief Input trees used to estimate the product's local complexity. + * + * @details + * Marked `mutable` so that `splitNode` can retrieve node views from a + * `const` context without implying logical modification. + */ mutable FunctionTreeVector trees; + /** + * @brief Decide whether an output node should be split. 
+ * + * @param node The (prospective) output node whose index determines + * which input nodes are inspected. + * @return `true` if `multNorm > prec` **and** at least one of the input + * nodes is not a leaf; otherwise `false`. + * + * @details + * - Retrieves the corresponding nodes from both input trees at the same + * @ref NodeIndex. + * - Computes + * \f$S_i=\sqrt{\text{max scaling square norm}},\; + * W_i=\sqrt{\text{max wavelet square norm}}\f$. + * - Forms \f$\text{multNorm}=W_0 S_1 + W_1 S_0 + W_0 W_1\f$. + * - Triggers refinement when the estimate exceeds @ref prec, except when + * both inputs are already leaf nodes at this index. + * + * @throws Emits `MSG_ERROR` if `trees.size() != 2`. + */ bool splitNode(const MWNode &node) const override { if (this->trees.size() != 2) MSG_ERROR("Invalid tree vec size: " << this->trees.size()); + auto &pNode0 = get_func(trees, 0).getNode(node.getNodeIndex()); auto &pNode1 = get_func(trees, 1).getNode(node.getNodeIndex()); + + // Square roots convert stored square norms to norms. double maxW0 = std::sqrt(pNode0.getMaxWSquareNorm()); double maxW1 = std::sqrt(pNode1.getMaxWSquareNorm()); double maxS0 = std::sqrt(pNode0.getMaxSquareNorm()); double maxS1 = std::sqrt(pNode1.getMaxSquareNorm()); - // The wavelet contribution (in the product of node0 and node1) can be approximated as + // Estimated wavelet contribution in the product node. double multNorm = maxW0 * maxS1 + maxW1 * maxS0 + maxW0 * maxW1; - // Note: this never refine deeper than one scale more than input tree grids, because when wavelets are zero - // for both input trees, multPrec=0 In addition, we force not to refine deeper than input tree grids - if (multNorm > this->prec and not(pNode0.isLeafNode() and pNode1.isLeafNode())) { + // Never refine beyond both input grids' leaf level. 
+ if (multNorm > this->prec && !(pNode0.isLeafNode() && pNode1.isLeafNode())) { return true; } else { return false; @@ -65,4 +159,4 @@ template class MultiplicationAdaptor : public TreeAdaptor class MultiplicationCalculator final : public TreeCalculator { +/** + * @class MultiplicationCalculator + * @brief Computes the pointwise product of several input trees into an output tree. + * + * @tparam D Spatial dimension (e.g., 1, 2, 3). + * @tparam T Coefficient scalar type (`double` or `ComplexDouble`). + * + * @details + * Let \f$\{f_i\}\f$ be the input trees (with optional scalar prefactors + * provided externally via `get_coef(prod_vec, i)`) and let + * \f$g = \prod_i f_i\f$ denote the pointwise product. This calculator fills + * the coefficients of the output node corresponding to a given + * @ref NodeIndex by: + * + * \f[ + * \mathbf{c}^{(g)} \leftarrow + * \prod_i \left( c_i \; \mathbf{c}^{(f_i)} \right), + * \f] + * + * where \f$c_i\f$ is the scalar multiplier returned by `get_coef` and + * \f$\mathbf{c}^{(f_i)}\f$ are the (reconstructed, forward-transformed) + * coefficients of the input node. When `T` is complex, \f$\mathbf{c}^{(f_i)}\f$ + * may be conjugated as described below. + * + * ### Conjugation rules (complex case) + * - If `func_i.conjugate()` is true, that input’s coefficients are conjugated. + * - Additionally, if the calculator is constructed with `conjugate=true`, + * the **first** input (index 0) is conjugated. The two conditions are XOR’d + * (`xor`), so a per-tree conjugation flag can cancel the global one. + * + * @note Missing input nodes are generated on demand by `FunctionTree::getNode`. + */ +template +class MultiplicationCalculator final : public TreeCalculator { public: + /** + * @brief Construct a product calculator. + * + * @param inp Vector of input trees to be multiplied. + * @param conjugate If `true`, apply complex conjugation to the **first** + * input factor (useful for ⟨bra|ket⟩-like operations). + * Ignored for real types. 
+ */ MultiplicationCalculator(const FunctionTreeVector &inp, bool conjugate = false) : prod_vec(inp) , conj(conjugate) {} private: + /// Collection of input trees and (optionally) their scalar prefactors. FunctionTreeVector prod_vec; + + /// Global conjugation switch for the first factor (complex types only). bool conj; + /** + * @brief Compute coefficients for one output node by multiplying inputs. + * + * @param node_o Output node to be filled/updated. + * + * @details + * Steps performed: + * 1. Initialize output coefficients to unity. + * 2. For each input tree `i`: + * - Fetch scalar factor `c_i = get_coef(prod_vec, i)`. + * - Materialize copy of matching input node (`getNode(idx)`), + * reconstruct (`mwTransform(Reconstruction)`), + * and forward transform to value space, i.e. point samples (`cvTransform(Forward)`). + * - Multiply output coefficients element-wise by + * `c_i * coefs_i[j]` (or `c_i * conj(coefs_i[j])` per rules above). + * 3. Transform output node back (`cvTransform(Backward)`), + * compress (`mwTransform(Compression)`), mark as having coefficients, + * and update norms. + * + * @note Uses helper functions `get_func(prod_vec, i)` and + * `get_coef(prod_vec, i)` provided by the @ref FunctionTreeVector API. + */ void calcNode(MWNode &node_o) { const NodeIndex &idx = node_o.getNodeIndex(); T *coefs_o = node_o.getCoefs(); - for (int j = 0; j < node_o.getNCoefs(); j++) { coefs_o[j] = 1.0; } + + // 1) Initialize output coefficients to multiplicative identity. + for (int j = 0; j < node_o.getNCoefs(); j++) { coefs_o[j] = static_cast(1.0); } + + // 2) Multiply contributions from each input factor. for (int i = 0; i < this->prod_vec.size(); i++) { T c_i = get_coef(this->prod_vec, i); FunctionTree &func_i = get_func(this->prod_vec, i); - // This generates missing nodes - MWNode node_i = func_i.getNode(idx); // Copy node + + // Materialize and prepare input node coefficients.
+ MWNode node_i = func_i.getNode(idx); // copy/materialize node_i.mwTransform(Reconstruction); node_i.cvTransform(Forward); + const T *coefs_i = node_i.getCoefs(); int n_coefs = node_i.getNCoefs(); + if constexpr (std::is_same::value) { - if (func_i.conjugate() xor (conj and i == 0)) { // NB: take complex conjugate of "bra" + // Conjugate rule: per-tree flag XOR global-first-factor flag. + bool do_conj = func_i.conjugate() xor (conj && i == 0); + if (do_conj) { for (int j = 0; j < n_coefs; j++) { coefs_o[j] *= c_i * std::conj(coefs_i[j]); } } else { for (int j = 0; j < n_coefs; j++) { coefs_o[j] *= c_i * coefs_i[j]; } @@ -63,6 +168,8 @@ template class MultiplicationCalculator final : public TreeC for (int j = 0; j < n_coefs; j++) { coefs_o[j] *= c_i * coefs_i[j]; } } } + + // 3) Finalize output node state. node_o.cvTransform(Backward); node_o.mwTransform(Compression); node_o.setHasCoefs(); @@ -70,4 +177,4 @@ template class MultiplicationCalculator final : public TreeC } }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/treebuilders/OperatorAdaptor.h b/src/treebuilders/OperatorAdaptor.h index d4bd0d1d8..524fcf7a3 100644 --- a/src/treebuilders/OperatorAdaptor.h +++ b/src/treebuilders/OperatorAdaptor.h @@ -25,27 +25,85 @@ #pragma once +/** + * @file + * @brief Adaptor that targets operator-sensitive regions for refinement. + * + * @details + * This adaptor specializes @c WaveletAdaptor in 2D to refine only those nodes + * that are (i) aligned with a coordinate axis (translation index 0 along @f$x@f$ + * or @f$y@f$) and (ii) exhibit non-zero **wavelet** content. This pattern is + * useful when applying kernels/operators whose singular support or strongest + * variation is concentrated along axes (e.g., banded/operator stencils), so + * we avoid refining benign regions. + * + * The refinement trigger is implemented in @ref OperatorAdaptor::splitNode. 
+ */ + #include "WaveletAdaptor.h" namespace mrcpp { +/** + * @class OperatorAdaptor + * @brief Wavelet-driven 2D refinement biased to axis-aligned nodes. + * + * @details + * A node is marked for splitting iff: + * - **Axis proximity:** its translation index satisfies @c idx[0]==0 or + * @c idx[1]==0 (i.e., the node touches the @f$x@f$- or @f$y@f$-axis at its scale). + * - **Wavelet energy present:** at least one non-scaling component has a + * positive norm. In 2D, component indices are conventionally: + * - 0: scaling (S), + * - 1..3: wavelet components (W). + * + * Combining these filters refines only the regions that are both “near” the + * axes and relevant to operator action (non-zero wavelet content), keeping the + * mesh compact elsewhere. + * + * @see WaveletAdaptor + */ class OperatorAdaptor final : public WaveletAdaptor<2> { public: + /** + * @brief Construct an adaptor with precision and depth controls. + * + * @param pr Target precision/tolerance forwarded to the base adaptor. + * @param ms Maximum scale (upper bound on refinement depth) forwarded to the base adaptor. + * @param ap If @c true, enable the base adaptor's optional parent-aware + * behavior (propagation specifics depend on @ref WaveletAdaptor). + */ OperatorAdaptor(double pr, int ms, bool ap = false) : WaveletAdaptor<2>(pr, ms, ap) {} protected: + /** + * @brief Decide whether a node should be split. + * + * @param node The candidate node. + * @return @c true if the node lies on an axis (either translation index + * is zero) **and** has non-zero wavelet component norm; otherwise + * @c false. + * + * @details + * - **Axis check:** @c idx[0]==0 || idx[1]==0. + * - **Wavelet check:** any component @c i in {1,2,3} has + * @c node.getComponentNorm(i) > 0.0. + * + * Component 0 (scaling) is intentionally ignored in the wavelet check to + * avoid refining nodes that carry only low-frequency/scaling content.
+ */ bool splitNode(const MWNode<2> &node) const override { const auto &idx = node.getNodeIndex(); - bool chkTransl = (idx[0] == 0 or idx[1] == 0); + bool chkTransl = (idx[0] == 0 || idx[1] == 0); bool chkCompNorm = false; for (int i = 1; i < 4; i++) { if (node.getComponentNorm(i) > 0.0) chkCompNorm = true; } - return chkTransl and chkCompNorm; + return chkTransl && chkCompNorm; } }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/treebuilders/PHCalculator.h b/src/treebuilders/PHCalculator.h index 2c3b18510..affb1848f 100644 --- a/src/treebuilders/PHCalculator.h +++ b/src/treebuilders/PHCalculator.h @@ -25,24 +25,132 @@ #pragma once +/** + * @file + * @brief 2D calculator that applies a three-point stencil in a scaling basis. + * + * @details + * This component specializes the generic @ref TreeCalculator for 2D trees to + * perform operations that can be expressed through **shifted overlap/stencil + * matrices** of a scaling basis. The operator is represented by three + * precomputed banded blocks + * @f$S_{-1}, S_{0}, S_{+1}@f$, which correspond to interactions with the + * left, center, and right neighbor positions along a given axis in the + * multiresolution grid. These blocks are assembled from a provided + * @ref ScalingBasis and then applied node-wise in @ref calcNode. + * + * Typical use-cases include discrete differential operators (e.g., first/second + * derivatives) or narrow-band filters that can be written as a three-point + * stencil in the scaling-function coefficient space. The effective stencil + * "width" (derivative order) is indicated by @ref diff_order. + */ + #include #include "TreeCalculator.h" namespace mrcpp { +/** + * @class PHCalculator + * @brief Applies a scaling-basis three-point stencil on a 2D multiresolution tree. 
+ * + * @details + * The calculator preloads three overlap/stencil matrices derived from a + * @ref ScalingBasis: + * - @ref S_m1 : shifted block for offset @f$-1@f$, + * - @ref S_0 : central block for offset @f$0@f$, + * - @ref S_p1 : shifted block for offset @f$+1@f$. + * + * During @ref calcNode, these blocks are combined to transform the node's + * coefficients according to the chosen stencil (e.g., a centered finite + * difference of order @ref diff_order). The exact algebra depends on the + * basis; see the implementation of @ref readSMatrix. + * + * The class is marked @c final because it provides a complete node-level + * implementation tailored to a 3-band stencil and is not intended for further + * subclassing. + */ class PHCalculator final : public TreeCalculator<2> { public: + /** + * @brief Construct the calculator and preload stencil blocks. + * + * @param basis Scaling basis from which the banded overlap/stencil + * matrices are derived. The basis determines support, + * moments, and thus the actual entries of the @f$S@f$ + * blocks. + * @param n Nominal stencil/derivative order (e.g., 1 for first + * derivative, 2 for second). The value is stored in + * @ref diff_order and may influence how the three + * blocks are combined inside @ref calcNode. + * + * @post + * - @ref diff_order is set from @p n. + * - @ref S_m1, @ref S_0, @ref S_p1 are populated via @ref readSMatrix(). + */ PHCalculator(const ScalingBasis &basis, int n); private: + /** + * @brief Logical order of the stencil/differential operator to apply. + * + * @details + * This does not change the *size* of the precomputed blocks, but can alter + * how @ref S_m1, @ref S_0, and @ref S_p1 are linearly combined inside + * @ref calcNode (e.g., centered first vs. second derivative weights). + */ const int diff_order; + + /// @name Precomputed banded overlap/stencil blocks + /// @{ + /// Block corresponding to a shift of @f$-1@f$ grid unit(s). Eigen::MatrixXd S_m1; + /// Central (unshifted) block. 
Eigen::MatrixXd S_0; + /// Block corresponding to a shift of @f$+1@f$ grid unit(s). Eigen::MatrixXd S_p1; + /// @} + /** + * @brief Node-level application of the three-point stencil. + * + * @param node Target 2D node whose coefficient vector is transformed + * in-place according to the assembled operator. The method + * is invoked by the traversal implemented in the base + * @ref TreeCalculator. + * + * @details + * Conceptually, this computes (in scaling space) + * @f[ + * \mathbf{c}_{\text{out}} \;\leftarrow\; + * w_{-1}\,S_{-1}\,\mathbf{c}_{-1} \;+\; + * w_{0}\, S_{0}\, \mathbf{c}_{0} \;+\; + * w_{+1}\,S_{+1}\,\mathbf{c}_{+1}, + * @f] + * where weights @f$w_{\cdot}@f$ depend on @ref diff_order and the chosen + * discrete scheme, and @f$\mathbf{c}_{k}@f$ denotes the coefficient vector + * at relative offset @f$k \in \{-1,0,+1\}@f$. The exact assembly is + * implementation-specific and consistent with the supplied @ref ScalingBasis. + */ void calcNode(MWNode<2> &node); + + /** + * @brief Populate one of the stencil matrices from the scaling basis. + * + * @param basis Scaling basis providing overlap and shift relations. + * @param n Selector for the matrix to load/build: + * chooses among @f$S_{-1}@f$, @f$S_{0}@f$, or @f$S_{+1}@f$. + * (The accepted values and encoding are implementation-defined, + * but conceptually map to offsets -1, 0, and +1.) + * + * @details + * Extracts or assembles the band-limited matrix corresponding to a given + * neighbor offset with respect to the scaling-function grid induced by + * @p basis. The resulting block is stored into one of @ref S_m1, @ref S_0, + * or @ref S_p1 depending on @p n. 
+ */ void readSMatrix(const ScalingBasis &basis, char n); }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/treebuilders/PowerCalculator.h b/src/treebuilders/PowerCalculator.h index 79147fc4b..e35b1d4a7 100644 --- a/src/treebuilders/PowerCalculator.h +++ b/src/treebuilders/PowerCalculator.h @@ -25,30 +25,128 @@ #pragma once +/** + * @file + * @brief Node-wise power transform calculator for multiresolution trees. + * + * @details + * This header defines @ref mrcpp::PowerCalculator, a concrete + * @ref TreeCalculator that raises the coefficients of an input function tree + * to a scalar power, writing results into an output tree during the standard + * calculator traversal. + * + * The transform is applied **locally per node** in scaling space: + * 1. Fetch/copy the corresponding input node (creating it if missing). + * 2. Apply multiresolution reconstruction (wavelet → scaling) and then a + * forward coefficient transform to obtain coefficient values suitable + * for pointwise operations. + * 3. Compute @c coefs_out[j] = pow(coefs_in[j], power) for each coefficient. + * 4. Apply the inverse coefficient transform and multiresolution compression. + * 5. Mark coefficients present and update node norms. + * + * The scalar @ref power is a real number. For complex-valued trees, + * @c std::pow(std::complex<>, double) is used. + */ + #include "TreeCalculator.h" namespace mrcpp { -template class PowerCalculator final : public TreeCalculator { +/** + * @class PowerCalculator + * @brief Raises node coefficients of an input tree to a fixed power. + * + * @tparam D Spatial dimension of the tree (e.g., 1, 2, or 3). + * @tparam T Coefficient scalar type (e.g., @c double or @c ComplexDouble). + * + * @details + * This calculator implements a **pointwise power** operation in coefficient + * space. 
For each visited node in the output tree, it pulls the corresponding + * node from the input tree (creating it if necessary), reconstructs to scaling + * space, and applies: + * @f[ + * \forall j:\quad c^{\text{out}}_j \leftarrow \big(c^{\text{in}}_j\big)^{\,p} + * @f] + * where @f$p@f$ is the exponent stored in @ref power. + * + * ### Precision & grid handling + * The class delegates traversal, splitting, and precision handling to the + * base @ref TreeCalculator. It performs only the node-local algebra and the + * required forward/backward transforms. + * + * ### Complex inputs + * For complex-valued @p T, the standard overload @c std::pow(T,double) is used. + * Note that if @p T is real and coefficients are negative while @ref power is + * non-integer, the result may be @c NaN; this calculator does not alter that + * behavior. + */ +template +class PowerCalculator final : public TreeCalculator { public: + /** + * @brief Construct a power calculator for a given input tree. + * + * @param inp Reference to the input function tree whose node coefficients + * are the base values in the power operation. + * @param pow Exponent @f$p@f$ to apply pointwise to all node coefficients. + * + * @note The calculator does not own @p inp; the caller must ensure that + * the referenced tree remains valid for the calculator's lifetime. + */ PowerCalculator(FunctionTree &inp, double pow) : power(pow) , func(&inp) {} private: + /** + * @brief Scalar exponent used in @c std::pow for all coefficients. + */ double power; + + /** + * @brief Non-owning pointer to the input function tree. + */ FunctionTree *func; + /** + * @brief Node-level power application. + * + * @param node_o Output node whose coefficients are overwritten with + * @f$\big(c^{\text{in}}\big)^{\,p}@f$ at the corresponding + * location in the input tree. + * + * @details + * Steps performed: + * - Retrieve the corresponding input node @c node_i from @ref func + * using the same @ref NodeIndex (this may create a missing node).
+ * - @c node_i.mwTransform(Reconstruction) to convert wavelet → scaling. + * - @c node_i.cvTransform(Forward) to access coefficient array in the + * appropriate local basis. + * - For each coefficient index @c j, compute: + * @code + * coefs_o[j] = std::pow(coefs_i[j], power); + * @endcode + * - Apply @c node_o.cvTransform(Backward) and + * @c node_o.mwTransform(Compression) to restore representation. + * - Set the "has coefficients" flag and refresh node norms. + * + * @warning If @p T is a real type and @ref power is non-integer, negative + * input coefficients can lead to @c NaN. This is the standard + * behavior of @c std::pow and is not intercepted here. + */ void calcNode(MWNode &node_o) override { const NodeIndex &idx = node_o.getNodeIndex(); int n_coefs = node_o.getNCoefs(); T *coefs_o = node_o.getCoefs(); - // This generates missing nodes - MWNode node_i = func->getNode(idx); // Copy node + + // Generate/copy input node at the same index. + MWNode node_i = func->getNode(idx); node_i.mwTransform(Reconstruction); node_i.cvTransform(Forward); + const T *coefs_i = node_i.getCoefs(); for (int j = 0; j < n_coefs; j++) { coefs_o[j] = std::pow(coefs_i[j], this->power); } + node_o.cvTransform(Backward); node_o.mwTransform(Compression); node_o.setHasCoefs(); @@ -56,4 +154,4 @@ template class PowerCalculator final : public TreeCalculator } }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/treebuilders/ProjectionCalculator.h b/src/treebuilders/ProjectionCalculator.h index 067c41422..a5d5708b1 100644 --- a/src/treebuilders/ProjectionCalculator.h +++ b/src/treebuilders/ProjectionCalculator.h @@ -25,20 +25,103 @@ #pragma once +/** + * @file + * @brief Node-wise projector from an analytic/representable function to a + * multiresolution function tree. 
+ * + * @details + * This calculator implements the core **projection kernel** that takes a + * user-provided @ref RepresentableFunction and, node by node, produces + * multiresolution coefficients for an output tree managed by the surrounding + * @ref TreeCalculator pipeline. + * + * The calculator is agnostic to the scheduling/refinement policy: it only + * defines how a *single* node is computed (see @ref calcNode). The driving + * logic (initial work list, adaptors, termination) is handled by + * @ref TreeCalculator and its collaborators. + * + * ### Coordinate scaling + * A per-dimension @p scaling_factor is supplied at construction. It is applied + * consistently during node evaluation to support anisotropic grid scalings, + * unit conversions, or jacobian-like preconditioning. Use a vector of ones to + * disable scaling. + */ + #include "TreeCalculator.h" namespace mrcpp { -template class ProjectionCalculator final : public TreeCalculator { +// Forward declaration; the concrete definition is provided by MRCPP headers. +template class RepresentableFunction; + +/** + * @class ProjectionCalculator + * @brief Projects a @ref RepresentableFunction onto the active output tree. + * + * @tparam D Spatial dimension. + * @tparam T Coefficient scalar type (e.g., `double`, `ComplexDouble`). + * + * @details + * For each node requested by the @ref TreeCalculator scheduler, this class: + * 1. Builds or fetches the corresponding input stencil/quadrature on the + * node’s support. + * 2. Evaluates the supplied @ref RepresentableFunction at those points, + * applying the provided per-axis @ref scaling_factor. + * 3. Computes the node’s scaling/wavelet coefficients, writes them to the + * output tree, and updates norms/metadata. + * + * The calculator itself does **not** decide where to refine; use an adaptor + * (e.g., wavelet- or operator-based) with @ref TreeCalculator to drive + * adaptivity from residuals or norm estimates. 
+ */ +template +class ProjectionCalculator final : public TreeCalculator { public: - ProjectionCalculator(const RepresentableFunction &inp_func, const std::array &sf) + /** + * @brief Construct a projector from an analytic/representable function. + * + * @param[in] inp_func Function to be projected. The pointer is + * stored and must remain valid for the lifetime + * of the calculator. + * @param[in] sf Per-dimension scaling factors applied to local + * coordinates before evaluating @p inp_func. + * Set to `{1, …, 1}` for no scaling. + * + * @note The output target (tree) and traversal policy are provided by the + * owning @ref TreeCalculator context; this constructor only binds the + * callable and the evaluation scaling. + */ + ProjectionCalculator(const RepresentableFunction &inp_func, + const std::array &sf) : func(&inp_func) , scaling_factor(sf) {} private: + /// Source function to be sampled on each node’s stencil. const RepresentableFunction *func; + + /// Per-axis multiplicative coordinate scaling used during evaluation. const std::array scaling_factor; + + /** + * @brief Compute a single node of the output tree. + * + * @param[in,out] node Target node whose coefficients and norms are produced. + * + * @details + * The typical implementation flow is: + * - derive the node’s physical coordinates from its @ref NodeIndex and + * apply @ref scaling_factor, + * - evaluate @ref func at the required sample points, + * - assemble scaling/wavelet coefficients and write them into @p node, + * - set “has coefficients” flags and update per-component norms. + * + * Thread-safety: the method only mutates @p node and uses read-only access + * to @ref func and @ref scaling_factor, so it is safe under the usual + * per-node parallel scheduling employed by @ref TreeCalculator. 
+ */ void calcNode(MWNode &node) override; }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/treebuilders/SplitAdaptor.h b/src/treebuilders/SplitAdaptor.h index 7e81bbe8b..d1c73e679 100644 --- a/src/treebuilders/SplitAdaptor.h +++ b/src/treebuilders/SplitAdaptor.h @@ -25,19 +25,73 @@ #pragma once +/** + * @file + * @brief Minimal adaptor that unconditionally (or never) splits nodes. + * + * @details + * This adaptor implements the @ref TreeAdaptor interface with a constant + * split policy: depending on a boolean flag passed at construction, + * every node presented by the traversal will either: + * - be **always split** (when `split == true`), or + * - be **never split** (when `split == false`). + * + * This is useful for: + * - unit tests and benchmarks (forcing a fixed refinement pattern), + * - creating a uniform grid up to a maximum scale, + * - disabling refinement while still running a calculator over an existing grid. + */ + #include "TreeAdaptor.h" namespace mrcpp { -template class SplitAdaptor final : public TreeAdaptor { +/** + * @class SplitAdaptor + * @brief Constant split/no-split adaptor for tree refinement. + * + * @tparam D Spatial dimension of the tree. + * @tparam T Scalar coefficient type (defaults to `double`). + * + * @details + * The adaptor inherits the depth/scale controls from @ref TreeAdaptor (e.g., + * the *maximum scale* passed to the base constructor). The split decision + * itself is independent of node content and simply mirrors the `split` flag. + * + * ### Refinement semantics + * - If `split == true`, any node that the base class allows to be refined + * (e.g., below the maximum scale) will be marked for splitting. + * - If `split == false`, no node will be split by this adaptor (even if below + * the maximum scale). + */ +template +class SplitAdaptor final : public TreeAdaptor { public: + /** + * @brief Construct a constant split adaptor. 
+ * + * @param[in] ms Maximum scale (or equivalent depth limit) forwarded to + * @ref TreeAdaptor. Nodes at or beyond this scale will not be + * refined by the base logic regardless of @p sp. + * @param[in] sp Split policy: `true` to always split (subject to base + * constraints), `false` to never split. + */ SplitAdaptor(int ms, bool sp) : TreeAdaptor(ms) , split(sp) {} private: + /// Constant split policy applied to every visited node. bool split; + /** + * @brief Decide whether to split a node. + * + * @param[in] node Node under consideration (unused). + * @return `true` if this adaptor is configured to split; otherwise `false`. + * + * @note The base class may still veto refinement (e.g., beyond max scale). + */ bool splitNode(const MWNode &node) const override { return this->split; } }; diff --git a/src/treebuilders/SquareCalculator.h b/src/treebuilders/SquareCalculator.h index 015b90f82..63c33f94d 100644 --- a/src/treebuilders/SquareCalculator.h +++ b/src/treebuilders/SquareCalculator.h @@ -25,46 +25,154 @@ #pragma once +/** + * @file + * @brief Element-wise squaring of function-tree coefficients (with optional complex conjugation). + * + * @details + * This calculator evaluates one of the following pointwise operations on an input function + * represented by a multiresolution @ref FunctionTree: + * + * - **Algebraic square**: \f$ g(\mathbf r) = f(\mathbf r)^2 \f$ + * - **Squared magnitude** (Hermitian square): \f$ g(\mathbf r) = f(\mathbf r)\,f(\mathbf r)^* = |f(\mathbf r)|^2 \f$ + * + * The choice is controlled by the constructor's `conjugate` flag (see below). For real + * coefficient types the two definitions coincide. + * + * Implementation sketch per node: + * 1. Pull (or generate) the input node at the same index as the output node. + * 2. Convert to scaling coefficients (multiwavelet reconstruction), then to coefficient + * vector space. + * 3. Apply the element-wise operation (square or squared magnitude). + * 4. 
Transform coefficients back, compress, mark as having coefficients, and refresh norms. + */ + #include "TreeCalculator.h" namespace mrcpp { -template class SquareCalculator final : public TreeCalculator { +/** + * @class SquareCalculator + * @brief Per-node square / squared-magnitude operator for function trees. + * + * @tparam D Spatial dimension of the tree. + * @tparam T Scalar coefficient type (e.g., `double` or `ComplexDouble`). + * + * @details + * Let \f$f\f$ be the input function represented by `func`. This calculator writes + * an output tree \f$g\f$ such that, for each node and each basis coefficient: + * + * - If `conjugate == false`: + * - Real `T`: \f$ g = f^2 \f$ + * - Complex `T`: \f$ g = f^2 \f$ + * - If `conjugate == true`: + * - Real `T`: \f$ g = f^2 \f$ (same as above) + * - Complex `T`: \f$ g = f\,\overline{f} = |f|^2 \f$ + * + * Additionally, if the input tree is marked internally as "conjugated" (via its + * soft-conjugation flag), the implementation respects that view such that + * `conjugate == true` still produces \f$|f|^2\f$ and `conjugate == false` produces + * \f$(\overline{f})^2\f$ for complex `T`. See the truth table in @ref calcNode. + * + * ### Transform pipeline + * Each output node is computed by: + * - reconstructing the corresponding input node to scaling space + * (`mwTransform(Reconstruction)`), + * - converting to coefficient space (`cvTransform(Forward)`), + * - applying the element-wise operation on the coefficient array, + * - mapping back (`cvTransform(Backward)`, `mwTransform(Compression)`), + * - finalizing (`setHasCoefs()`, `calcNorms()`). + * + * @note The calculator is stateless across nodes and can be scheduled in parallel by + * the tree execution engine as long as nodes are independent. + */ +template +class SquareCalculator final : public TreeCalculator { public: + /** + * @brief Construct a square (or squared-magnitude) calculator. + * + * @param[in] inp Input function tree \f$f\f$. 
+ * @param[in] conjugate If `true` and `T` is complex, compute the squared magnitude + * \f$|f|^2\f$ (i.e., multiply by the complex conjugate). If + * `false`, compute the algebraic square \f$f^2\f$. + * + * @note For real `T`, `conjugate` has no effect; \f$|f|^2 = f^2\f$. + */ SquareCalculator(FunctionTree &inp, bool conjugate = false) : func(&inp) , conj(conjugate) {} private: + /// Pointer to the input function tree \f$f\f$. FunctionTree *func; + /// Operation switch: `false` ⇒ \f$f^2\f$; `true` ⇒ (for complex) \f$|f|^2\f$. bool conj; + /** + * @brief Compute one output node by squaring the corresponding input node. + * + * @param[in,out] node_o Output node to be written at the current index. + * + * @details + * Steps: + * 1. Acquire a copy of the input node at the same index: `node_i = func->getNode(idx)`. + * 2. Transform `node_i` to coefficient space (`mwTransform(Reconstruction)`, + * then `cvTransform(Forward)`). + * 3. For each coefficient \f$c_j\f$: + * - If `T` is real: \f$c_j \leftarrow c_j^2\f$. + * - If `T` is complex: + * - Respect the input tree's soft conjugation flag (`func->conjugate()`). + * - Apply the following table to compute `coefs_o[j]`: + * + * | `func->conjugate()` | `conj` | result | + * |:--------------------:|:------:|:---------------------------------------| + * | false | false | \f$c_j \cdot c_j = c_j^2\f$ | + * | false | true | \f$c_j \cdot \overline{c_j} = |c_j|^2\f$ | + * | true | false | \f$\overline{c_j}\cdot \overline{c_j} = (\overline{c_j})^2\f$ | + * | true | true | \f$\overline{c_j}\cdot c_j = |c_j|^2\f$ | + * + * 4. Map the result back (`cvTransform(Backward)`, `mwTransform(Compression)`), + * then finalize flags and norms. + * + * @complexity Linear in the number of coefficients of the node: \f$O(n_{\text{coefs}})\f$. 
+ */ void calcNode(MWNode &node_o) { const NodeIndex &idx = node_o.getNodeIndex(); int n_coefs = node_o.getNCoefs(); T *coefs_o = node_o.getCoefs(); - // This generates missing nodes - MWNode node_i = func->getNode(idx); // Copy node + + // Acquire / materialize the input node at the same index + MWNode node_i = func->getNode(idx); // Copy node (may generate missing nodes) node_i.mwTransform(Reconstruction); node_i.cvTransform(Forward); + const T *coefs_i = node_i.getCoefs(); + if constexpr (std::is_same::value) { if (func->conjugate()) { if (conj) { + // |f|^2: conj(c) * c for (int j = 0; j < n_coefs; j++) { coefs_o[j] = std::conj(coefs_i[j]) * coefs_i[j]; } } else { + // (conj f)^2: conj(c) * conj(c) for (int j = 0; j < n_coefs; j++) { coefs_o[j] = std::conj(coefs_i[j]) * std::conj(coefs_i[j]); } } } else { if (conj) { + // |f|^2: c * conj(c) for (int j = 0; j < n_coefs; j++) { coefs_o[j] = coefs_i[j] * std::conj(coefs_i[j]); } } else { + // f^2: c * c for (int j = 0; j < n_coefs; j++) { coefs_o[j] = coefs_i[j] * coefs_i[j]; } } } } else { + // Real case: f^2 for (int j = 0; j < n_coefs; j++) { coefs_o[j] = coefs_i[j] * coefs_i[j]; } } + + // Map back and finalize node_o.cvTransform(Backward); node_o.mwTransform(Compression); node_o.setHasCoefs(); @@ -72,4 +180,4 @@ template class SquareCalculator final : public TreeCalculato } }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/treebuilders/TimeEvolution_CrossCorrelationCalculator.h b/src/treebuilders/TimeEvolution_CrossCorrelationCalculator.h index f2a68295f..917f00fed 100644 --- a/src/treebuilders/TimeEvolution_CrossCorrelationCalculator.h +++ b/src/treebuilders/TimeEvolution_CrossCorrelationCalculator.h @@ -25,6 +25,33 @@ #pragma once +/** + * @file + * @brief Time-evolution calculator based on cross-correlation kernels. 
+ * + * @details + * This header declares a node-local calculator that evaluates contributions + * required for (imaginary or real time) Schrödinger evolution using + * precomputed *J-power* integrals and a cross-correlation driver. + * + * Conceptually, for a two-dimensional function tree \f$f(\mathbf r)\f$ the + * calculator applies (per node) a correlation-type update of the form + * \f[ + * g(\mathbf r) \;=\; \big(K * f\big)(\mathbf r) + * \;=\; \int_{\mathbb R^2} K(\mathbf r - \mathbf r')\, f(\mathbf r')\, d\mathbf r' , + * \f] + * where the kernel \f$K\f$ and its (power) moments are provided through + * the cross-correlation infrastructure and the \c JpowerIntegrals table. + * + * The boolean switch #imaginary selects which component of the complex-valued + * kernel (or of the assembled integral) is used: + * - `imaginary == false` → **real** part; + * - `imaginary == true` → **imaginary** part. + * + * The class is invoked by the tree execution engine (see @ref TreeCalculator) + * and operates independently on each @ref MWNode. + */ + #include "TreeCalculator.h" #include "core/CrossCorrelationCache.h" #include "core/SchrodingerEvolution_CrossCorrelation.h" @@ -32,34 +59,111 @@ namespace mrcpp { -/** @class TimeEvolution_CrossCorrelationCalculator - * - * @brief An efficient way to calculate ... (work in progress) +/** + * @class TimeEvolution_CrossCorrelationCalculator + * @brief Node calculator for Schrödinger time evolution via cross-correlation. * - * @details An efficient way to calculate ... having the form - * \f$ \ldots = \ldots \f$ + * @details + * This calculator evaluates nodewise contributions needed for real- or + * imaginary-time propagation in 2D using a cross-correlation representation + * of the evolution operator. Precomputed integrals of the form + * \f$ J_m = \int x^m\,K(x)\,dx \f$ (and higher-dimensional analogs) are + * supplied through a map of @ref JpowerIntegrals instances indexed by + * the power/order. 
* + * ### Responsibilities + * - Pull required cross-correlation data from a + * @ref SchrodingerEvolution_CrossCorrelation instance. + * - Select **real** or **imaginary** contribution according to #imaginary. + * - Assemble the per-node update and write the result to the output node. * + * ### Threading / Parallelism + * The class itself holds only non-owning pointers and simple references + * to shared, read-only tables. It is thus re-entrant across nodes. + * Synchronization and scheduling are handled at the @ref TreeCalculator layer. * + * @note All pointer members are **non-owning**; the caller must ensure they + * remain valid for the lifetime of the calculator. */ class TimeEvolution_CrossCorrelationCalculator final : public TreeCalculator<2> { public: - TimeEvolution_CrossCorrelationCalculator(std::map &J, SchrodingerEvolution_CrossCorrelation *cross_correlation, bool imaginary) + /** + * @brief Construct the calculator with auxiliary integral tables and a driver. + * + * @param[in] J + * Map from power/order (e.g., \f$m\f$) to the corresponding + * @ref JpowerIntegrals table. The calculator does **not** take ownership. + * @param[in] cross_correlation + * Pointer to a @ref SchrodingerEvolution_CrossCorrelation driver that + * exposes kernel accessors / caches needed to assemble the correlation + * at node level. Non-owning. + * @param[in] imaginary + * If `true`, use the **imaginary part** of the accumulated contribution; + * otherwise use the **real part**. + * + * @warning The map and the driver pointer must outlive this calculator. 
+ */ + TimeEvolution_CrossCorrelationCalculator(std::map &J, + SchrodingerEvolution_CrossCorrelation *cross_correlation, + bool imaginary) : J_power_inetgarls(J) , cross_correlation(cross_correlation) , imaginary(imaginary) {} - // private: - std::map J_power_inetgarls; - SchrodingerEvolution_CrossCorrelation *cross_correlation; - - /// @brief If False then the calculator is using th real part of integrals, otherwise - the imaginary part. - bool imaginary; + /** + * @brief Compute the contribution for one output node. + * + * @param[in,out] node + * The node to be written. The implementation typically: + * 1) gathers the necessary kernel moments / cache entries, + * 2) accumulates the cross-correlation at the node resolution, + * 3) commits coefficients and refreshes norms/flags. + * + * @note The exact algebra (e.g., reconstruction/compression steps) is + * implemented in the corresponding source file. + */ void calcNode(MWNode<2> &node) override; - // template + /** + * @brief Apply the cross-correlation operator at the granularity of a single node. + * + * @param[in,out] node The node to which the operator is applied. + * + * @details + * This helper encapsulates the node-local application of the correlation + * kernel using the caches provided by #cross_correlation and the moment + * tables from #J_power_inetgarls. The #imaginary flag governs whether + * the real or imaginary component of the final integral is extracted. + * + * @see calcNode + */ void applyCcc(MWNode<2> &node); - // template void applyCcc(MWNode<2> &node, CrossCorrelationCache &ccc); + + // --------------------------------------------------------------------- + // Public state (non-owning) — kept public to match existing interfaces. + // --------------------------------------------------------------------- + + /** + * @brief Precomputed kernel moment/integral tables, indexed by power. + * + * @note Non-owning pointers; the map must remain valid externally. 
+ */ + std::map J_power_inetgarls; + + /** + * @brief Cross-correlation driver (non-owning). + * + * Provides access to kernel caches and auxiliary data needed to assemble + * the correlation at a given node. + */ + SchrodingerEvolution_CrossCorrelation *cross_correlation; + + /** + * @brief Component selector for complex contributions. + * + * If `false`, the **real** part is used; if `true`, the **imaginary** part. + */ + bool imaginary; }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/treebuilders/TreeAdaptor.h b/src/treebuilders/TreeAdaptor.h index 80cecb09e..8698034cf 100644 --- a/src/treebuilders/TreeAdaptor.h +++ b/src/treebuilders/TreeAdaptor.h @@ -25,36 +25,137 @@ #pragma once +/** + * @file + * @brief Generic adapter that decides whether tree nodes should be refined (split). + * + * @details + * `TreeAdaptor` provides a lightweight, policy-style interface used by the tree + * execution engine to determine which nodes of a @ref MWTree should be split + * (refined). Concrete adaptors implement the decision rule in the protected + * pure-virtual @ref splitNode method. + * + * Typical workflow: + * 1. Construct a concrete adaptor (e.g., one that inspects norms, wavelet + * content, error estimates, etc.). + * 2. Optionally set a maximum refinement scale with @ref setMaxScale. + * 3. Call @ref splitNodeVector to examine an input list of nodes and append any + * newly-created children to an output list for further processing. + * + * The adaptor enforces two built-in guards in @ref splitNodeVector: + * - **Branch nodes** (internal/structural nodes without coefficients) are + * skipped. + * - Nodes deeper than the allowed scale threshold are skipped: + * `node.getScale() + 2 > maxScale`. + * + * @note The `+2` slack prevents overshooting the configured refinement ceiling + * when subsequent passes may still need headroom (implementation detail). 
+ */ + #include "MRCPP/mrcpp_declarations.h" #include "trees/MWNode.h" namespace mrcpp { -template class TreeAdaptor { +/** + * @class TreeAdaptor + * @brief Abstract base class for node-refinement policies. + * + * @tparam D Spatial dimension of the tree. + * @tparam T Coefficient value type (e.g., `double`, `ComplexDouble`). + * + * @details + * Concrete adaptors derive from this class and implement @ref splitNode to + * express the criterion that decides whether a given leaf node should be + * refined. The base class owns the *maximum scale* guard and the helper that + * performs splitting and collects children. + */ +template +class TreeAdaptor { public: - TreeAdaptor(int ms) + /** + * @brief Construct with an initial maximum refinement scale. + * @param ms Maximum scale (depth) allowed for refinement. + * + * Nodes at scales for which `node.getScale() + 2 > ms` will **not** be + * split by @ref splitNodeVector, regardless of the policy decision. + */ + explicit TreeAdaptor(int ms) : maxScale(ms) {} + + /// Virtual destructor (polymorphic base). virtual ~TreeAdaptor() = default; + /** + * @brief Change the maximum refinement scale. + * @param ms New ceiling for refinement depth. + * + * @see maxScale + */ void setMaxScale(int ms) { this->maxScale = ms; } + /** + * @brief Apply the refinement policy to a batch of nodes and collect children. + * + * @param[out] out + * Vector that will receive pointers to the **newly created children** + * (across all nodes that are decided to be split). + * @param[in] inp + * Vector of candidate nodes to be tested for splitting. + * + * @details + * For each node in @p inp the routine: + * - skips the node if it is a **branch node** (see `MWNode::isBranchNode`); + * - enforces the scale guard `node.getScale() + 2 > maxScale`; + * - calls the policy @ref splitNode; if `true`, it creates the children + * (`node.createChildren(true)`) and appends them to @p out. 
+ * + * @note Ownership of nodes remains with the tree; this function only pushes + * pointers to existing/newly created nodes into @p out. + */ void splitNodeVector(MWNodeVector &out, MWNodeVector &inp) const { for (int n = 0; n < inp.size(); n++) { MWNode &node = *inp[n]; - // Can be BranchNode in operator application + + // Skip branch nodes (these can appear in the work list during operator application) if (node.isBranchNode()) continue; + + // Enforce maximum scale guard with a +2 safety margin if (node.getScale() + 2 > this->maxScale) continue; + + // Delegate the decision to the concrete adaptor if (splitNode(node)) { node.createChildren(true); - for (int i = 0; i < node.getNChildren(); i++) out.push_back(&node.getMWChild(i)); + for (int i = 0; i < node.getNChildren(); i++) { + out.push_back(&node.getMWChild(i)); + } } } } protected: + /** + * @brief Maximum allowed refinement scale (depth) for newly created nodes. + * + * Nodes for which `node.getScale() + 2 > maxScale` are not considered for + * splitting within @ref splitNodeVector. + */ int maxScale; + /** + * @brief Decide whether a given leaf node should be refined. + * + * @param node Candidate node (guaranteed non-branch and within scale guard). + * @return `true` if the node must be split, `false` otherwise. + * + * @details + * Derived classes implement this method to express an application-specific + * refinement criterion (e.g., wavelet-norm threshold, operator bandwidth, + * error estimator, etc.). This method must be **pure** (no side-effects) + * with respect to the tree topology; @ref splitNodeVector performs the + * actual splitting.
+ */ virtual bool splitNode(const MWNode &node) const = 0; }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/treebuilders/TreeBuilder.h b/src/treebuilders/TreeBuilder.h index 81c32afe6..8e89590fe 100644 --- a/src/treebuilders/TreeBuilder.h +++ b/src/treebuilders/TreeBuilder.h @@ -29,16 +29,123 @@ namespace mrcpp { -template class TreeBuilder final { +/** + * @class TreeBuilder + * @brief Orchestrates adaptive construction and refinement of @ref MWTree objects. + * + * @tparam D Spatial dimension of the tree. + * @tparam T Coefficient value type (e.g., `double`, `ComplexDouble`). + * + * @details + * `TreeBuilder` coordinates three roles during adaptive computations: + * - a **calculator** (@ref TreeCalculator) that evaluates node data + * (coefficients, norms, metadata) on the current grid, + * - an **adaptor** (@ref TreeAdaptor) that decides which nodes should + * be refined (split), + * - the **tree** (@ref MWTree) that stores topology and coefficients. + * + * A typical adaptive build loop is: + * 1. @ref calc to populate coefficients/norms on the current grid, + * 2. @ref split to refine nodes selected by the adaptor, + * 3. repeat (1–2) until no more splits occur or `maxIter` is reached. + * + * Some calculators maintain internal statistics/timers and may need a final + * post-processing step; @ref build calls into the calculator appropriately. + */ +template +class TreeBuilder final { public: - void build(MWTree &tree, TreeCalculator &calculator, TreeAdaptor &adaptor, int maxIter) const; +/** + * @brief Adaptive build: iterate (calc → split) up to @p maxIter times. + * + * @param[in,out] tree Target tree to (re)build/refine. + * @param[in,out] calculator Calculator used to fill coefficients/norms on the current grid. + * @param[in,out] adaptor Refinement policy deciding which nodes to split. + * @param[in] maxIter Upper bound on calc/split passes (0 means no refinement passes; a negative value means unbounded).
+ * + * @details + * The method performs: + * - an initial calc pass, + * - up to @p maxIter refinement passes, each performing: + * - split(tree, adaptor, passCoefs=true) + * - calc(tree, calculator) + * - any calculator post-processing hooks. + * + * Implementations typically stop early when split returns 0 (no new nodes). + */ + void build(MWTree &tree, + TreeCalculator &calculator, + TreeAdaptor &adaptor, + int maxIter) const; + + /** + * @brief Clear node data in @p tree using the provided @p calculator policy. + * + * @param[in,out] tree Tree whose nodes should be cleared. + * @param[in,out] calculator Calculator that defines how to reset per-node state. + * + * @details + * Resets coefficient flags and cached norms to a consistent "empty" state. + * This is useful before reusing a tree structure for another computation. + */ void clear(MWTree &tree, TreeCalculator &calculator) const; + + /** + * @brief Compute/refresh coefficients and norms on the current grid. + * + * @param[in,out] tree Tree to evaluate. + * @param[in,out] calculator Calculator that implements per-node computation. + * + * @details + * Traverses the active nodes (calculator-dependent strategy) and ensures + * each leaf has consistent coefficients (scaling/wavelet) and derived norms. + */ void calc(MWTree &tree, TreeCalculator &calculator) const; + + /** + * @brief Refine the tree topology according to @p adaptor policy. + * + * @param[in,out] tree Tree subject to refinement. + * @param[in,out] adaptor Adaptor deciding which nodes to split. + * @param[in] passCoefs + * If `true`, propagate or initialize child coefficients immediately + * (calculator-dependent behavior); if `false`, only topology changes + * are performed and coefficients are left for a subsequent @ref calc. + * + * @return Number of **new nodes** created (sum of all children inserted). + * + * @details + * The method collects candidate leaves, applies the adaptor’s + * `splitNodeVector`, and updates the tree topology. 
Implementations may + * perform light-weight coefficient seeding when `passCoefs==true` to + * improve the next calculation pass. + */ int split(MWTree &tree, TreeAdaptor &adaptor, bool passCoefs) const; private: + /** + * @brief Aggregate the total scaling-norm over a set of nodes. + * + * @param vec Vector of node pointers to be reduced. + * @return Sum (or calculator-defined aggregation) of scaling coefficients' norm. + * + * @details + * Utility used by build loops for convergence checks and diagnostics. + * The precise definition of “scaling norm” follows the node’s basis. + */ double calcScalingNorm(const MWNodeVector &vec) const; + + /** + * @brief Aggregate the total wavelet-norm over a set of nodes. + * + * @param vec Vector of node pointers to be reduced. + * @return Sum (or calculator-defined aggregation) of wavelet coefficients' norm. + * + * @details + * Utility used by build loops for refinement heuristics and stopping criteria. + * The precise definition of “wavelet norm” follows the node’s basis. + */ double calcWaveletNorm(const MWNodeVector &vec) const; }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/treebuilders/TreeCalculator.h b/src/treebuilders/TreeCalculator.h index 1bf41f407..5d10e3c10 100644 --- a/src/treebuilders/TreeCalculator.h +++ b/src/treebuilders/TreeCalculator.h @@ -29,13 +29,69 @@ namespace mrcpp { -template class TreeCalculator { +/** + * @class TreeCalculator + * @brief Abstract base for per-node computations on @ref MWTree. + * + * @tparam D Spatial dimension of the multiwavelet tree. + * @tparam T Coefficient value type (e.g. `double`, `ComplexDouble`). + * + * @details + * A `TreeCalculator` defines how to **evaluate/update a single node** + * (via the pure-virtual @ref calcNode) and provides utilities to apply + * that logic over a set of nodes, possibly in parallel. 
+ * + * Typical usage (in conjunction with @ref TreeBuilder): + * - derive a calculator and implement @ref calcNode, + * - obtain a worklist with @ref getInitialWorkVector, + * - call @ref calcNodeVector to process all nodes, + * - optionally override @ref postProcess for statistics/timers. + * + * ### Parallelism + * @ref calcNodeVector uses OpenMP (if enabled) with `schedule(guided)` + * and a thread count provided by the `mrcpp_get_num_threads()` macro. + * Implementations of @ref calcNode must be **thread-safe** w.r.t. other + * nodes in the worklist. Avoid shared mutable state unless properly + * synchronized. + */ +template +class TreeCalculator { public: + /// @brief Default constructor. TreeCalculator() = default; + + /// @brief Virtual destructor. virtual ~TreeCalculator() = default; - virtual MWNodeVector *getInitialWorkVector(MWTree &tree) const { return tree.copyEndNodeTable(); } + /** + * @brief Build the initial list of nodes to process. + * + * @param[in,out] tree The tree whose nodes should be evaluated. + * @return Heap-allocated vector of node pointers representing the initial + * work set (typically the **current leaf nodes**). + * + * @details + * The default implementation returns a copy of the tree's end-node table + * (`tree.copyEndNodeTable()`). Callers are responsible for deleting the + * returned container when done. + */ + virtual MWNodeVector* getInitialWorkVector(MWTree &tree) const { + return tree.copyEndNodeTable(); + } + /** + * @brief Evaluate all nodes in @p nodeVec (parallelized when available). + * + * @param[in,out] nodeVec Container of node pointers to be processed. + * + * @details + * Invokes @ref calcNode for each entry. Uses OpenMP with guided scheduling + * and `mrcpp_get_num_threads()` to determine the thread count. + * After processing all nodes, calls @ref postProcess once. + * + * @note The container is treated as read-only regarding its topology; + * implementations of @ref calcNode should not insert/remove nodes. 
+ */ virtual void calcNodeVector(MWNodeVector &nodeVec) { #pragma omp parallel shared(nodeVec) num_threads(mrcpp_get_num_threads()) { @@ -50,8 +106,27 @@ template class TreeCalculator { } protected: + /** + * @brief Perform the calculator's core work on a single node. + * + * @param[in,out] node Target node. Implementations typically: + * - ensure transforms are in the correct space (MW/CV) as needed, + * - compute/update coefficients and derived norms/flags, + * - leave the node in a consistent state for subsequent passes. + * + * @warning This method is called concurrently on different nodes. + * Do not mutate shared global state without synchronization. + */ virtual void calcNode(MWNode &node) = 0; + + /** + * @brief Optional hook executed once after @ref calcNodeVector finishes. + * + * @details + * Override to flush accumulators, update statistics, print timers, etc. + * Default implementation is a no-op. + */ virtual void postProcess() {} }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/treebuilders/WaveletAdaptor.h b/src/treebuilders/WaveletAdaptor.h index 829039bf4..231c16f7d 100644 --- a/src/treebuilders/WaveletAdaptor.h +++ b/src/treebuilders/WaveletAdaptor.h @@ -31,27 +31,114 @@ namespace mrcpp { -template class WaveletAdaptor : public TreeAdaptor { +/** + * @class WaveletAdaptor + * @brief Refinement policy based on wavelet-norm error indicators. + * + * @tparam D Spatial dimension of the multiwavelet tree. + * @tparam T Coefficient value type (e.g., double, ComplexDouble). + * + * @details + * This adaptor decides whether a node should be *split* (refined) by + * comparing its wavelet contribution against a precision target. 
+ * Internally it relies on @ref mrcpp::tree_utils::split_check, which + * examines a node's (accumulated) wavelet norm relative to: + * + * - a global precision @ref prec (optionally absolute via @ref absPrec), + * - a user-provided, index-dependent scaling @ref precFunc (defaults to 1), + * - an extra scale-dependent attenuation factor @ref splitFac + * (used to bias refinement with depth). + * + * If the threshold is exceeded, @ref splitNode requests refinement. + * + * @code{.cpp} + * // Typical usage: + * WaveletAdaptor<3, double> adapt(1e-6, 20); // prec, maxScale + * adapt.setPrecFunction([](const NodeIndex<3>&){ return 0.5; }); // tighten locally (factor < 1 lowers the threshold) + * TreeBuilder<3, double> builder; + * builder.split(tree, adapt, false); // passCoefs = false + * @endcode + */ +template +class WaveletAdaptor : public TreeAdaptor { public: + /** + * @brief Construct a wavelet-based adaptor. + * + * @param pr Global target precision (relative unless @p ap is true). + * @param ms Maximum refinement scale (forwarded to @ref TreeAdaptor). + * @param ap If true, interpret @p pr as an **absolute** tolerance; + * otherwise use a **relative** tolerance w.r.t. function norm. + * @param sf Split-factor controlling depth bias (≥ 0). When > 0, + * the threshold is scaled by \f$2^{-0.5\,sf\,(s+1)}\f$ at scale s, + * encouraging deeper refinement only when warranted. + */ WaveletAdaptor(double pr, int ms, bool ap = false, double sf = 1.0) : TreeAdaptor(ms) , absPrec(ap) , prec(pr) , splitFac(sf) {} + + /// @brief Virtual destructor. ~WaveletAdaptor() override = default; - void setPrecFunction(const std::function &idx)> &prec_func) { this->precFunc = prec_func; } + /** + * @brief Provide a spatially varying precision multiplier. + * + * @param prec_func Function returning a factor (default 1.0) for + * a given node index. The effective threshold becomes + * `prec * prec_func(idx)` (plus depth scaling via @ref splitFac).
+ * + * @note Use this to tighten or relax refinement in specific regions, + * e.g. around features of interest. + */ + void setPrecFunction(const std::function &idx)> &prec_func) { + this->precFunc = prec_func; + } protected: + /// @brief If true, treat @ref prec as an absolute tolerance; otherwise relative. bool absPrec; + + /// @brief Base precision target used by the wavelet thresholding rule. double prec; + + /** + * @brief Scale-dependent attenuation of the threshold. + * + * @details A positive value reduces the threshold with depth, making + * refinement stricter at finer scales. Set to 0.0 to disable. + */ double splitFac; - std::function &idx)> precFunc = [](const NodeIndex &idx) { return 1.0; }; + /** + * @brief Per-node precision multiplier (defaults to identity). + * + * @details The effective threshold is `prec * precFunc(idx)` before + * applying the depth-dependent @ref splitFac scaling. + */ + std::function &idx)> precFunc = + [](const NodeIndex & /*idx*/) { return 1.0; }; + + /** + * @brief Decide whether a node should be split. + * + * @param node The candidate node. + * @return `true` if the node's wavelet norm exceeds the computed threshold. + * + * @details + * Computes a local tolerance as: + * \f[ + * \tau = \text{prec} \times \text{precFunc}(\text{idx}) + * \f] + * (relative to the function norm unless @ref absPrec is set), + * then applies an additional depth-dependent factor governed by + * @ref splitFac, and finally compares against the node's wavelet norm. 
+ */ bool splitNode(const MWNode &node) const override { auto precFac = this->precFunc(node.getNodeIndex()); // returns 1.0 by default return tree_utils::split_check(node, this->prec * precFac, this->splitFac, this->absPrec); } }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/treebuilders/add.h b/src/treebuilders/add.h index 1f94dbc9d..ecc38a07d 100644 --- a/src/treebuilders/add.h +++ b/src/treebuilders/add.h @@ -25,11 +25,116 @@ #pragma once +/** + * @file add.h + * @brief Adaptive linear combination of multiresolution (MW) function trees. + * + * @details + * These routines build an output MW function as a weighted sum of one or more + * input trees on an adaptively refined grid. The refinement loop is driven by a + * precision target: + * - **Relative precision** (default): refine while local wavelet details are + * not small compared to the local function norm. + * - **Absolute precision** (`absPrec = true`): refine until local details + * fall below a fixed absolute threshold. + * + * Unless noted otherwise, all input trees and the output must share the same + * `MultiResolutionAnalysis` (domain, basis, scales). The output grid is + * extended as needed; it is not cleared automatically. + */ + namespace mrcpp { +/** + * @brief Adaptive sum of two MW functions with scalar weights. + * + * @tparam D Spatial dimension (1, 2, or 3). + * @tparam T Coefficient type (e.g., `double`, `ComplexDouble`). + * + * @param[in] prec Target build precision (relative by default; see @p absPrec). + * @param[out] out Output tree to construct (its grid is extended as needed). + * @param[in] a Scalar weight multiplying @p tree_a. + * @param[in] tree_a First input function tree. + * @param[in] b Scalar weight multiplying @p tree_b. + * @param[in] tree_b Second input function tree. + * @param[in] maxIter Maximum refinement iterations; negative means unbounded. + * @param[in] absPrec If true, interpret @p prec as absolute; else relative. 
+ * @param[in] conjugate If true and @p T is complex, apply complex conjugation + * to the second operand (@p tree_b) during accumulation. + * + * @details + * Builds + * \f[ + * \text{out} \leftarrow a\,\text{tree\_a} \;+\; + * b\,(\,\text{conjugate}\ ?\ \overline{\text{tree\_b}}:\text{tree\_b}\,), + * \f] + * on a grid refined to meet @p prec under the chosen precision policy. + */ +template +void add(double prec, + FunctionTree &out, + T a, + FunctionTree &tree_a, + T b, + FunctionTree &tree_b, + int maxIter = -1, + bool absPrec = false, + bool conjugate = false); + +/** + * @brief Adaptive linear combination from a vector of (coefficient, tree) pairs. + * + * @tparam D Spatial dimension (1, 2, or 3). + * @tparam T Coefficient type. + * + * @param[in] prec Target build precision (relative by default; see @p absPrec). + * @param[out] out Output tree to construct. + * @param[in] inp Vector of pairs \f$(\alpha_k, f_k)\f$ (type `FunctionTreeVector`). + * @param[in] maxIter Maximum refinement iterations; negative means unbounded. + * @param[in] absPrec If true, interpret @p prec as absolute; else relative. + * @param[in] conjugate If true and @p T is complex, apply complex conjugation + * to all trees except the first one during accumulation. + * + * @details + * Builds + * \f[ + * \text{out} \leftarrow \sum_k \alpha_k\, g_k, + * \f] + * where \f$g_k = \overline{f_k}\f$ if @p conjugate is true (and \f$k>0\f$ in the + * complex case), otherwise \f$g_k = f_k\f$. The grid is refined adaptively + * to satisfy @p prec. + */ +template +void add(double prec, + FunctionTree &out, + FunctionTreeVector &inp, + int maxIter = -1, + bool absPrec = false, + bool conjugate = false); + +/** + * @brief Convenience overload: adaptive sum of a list of trees with unit weights. + * + * @tparam D Spatial dimension (1, 2, or 3). + * @tparam T Coefficient type. + * + * @param[in] prec Target build precision (relative by default; see @p absPrec). + * @param[out] out Output tree to construct. 
+ * @param[in] inp List of tree pointers; each term is taken with weight 1. + * @param[in] maxIter Maximum refinement iterations; negative means unbounded. + * @param[in] absPrec If true, interpret @p prec as absolute; else relative. + * @param[in] conjugate If true and @p T is complex, apply complex conjugation + * to all trees except the first during accumulation. + * + * @details + * Equivalent to the `FunctionTreeVector` overload with all coefficients set to 1. + */ template -void add(double prec, FunctionTree &out, T a, FunctionTree &tree_a, T b, FunctionTree &tree_b, int maxIter = -1, bool absPrec = false, bool conjugate = false); -template void add(double prec, FunctionTree &out, FunctionTreeVector &inp, int maxIter = -1, bool absPrec = false, bool conjugate = false); -template void add(double prec, FunctionTree &out, std::vector *> &inp, int maxIter = -1, bool absPrec = false, bool conjugate = false); +void add(double prec, + FunctionTree &out, + std::vector *> &inp, + int maxIter = -1, + bool absPrec = false, + bool conjugate = false); } // namespace mrcpp diff --git a/src/treebuilders/apply.h b/src/treebuilders/apply.h index fa5a43661..abdb030f6 100644 --- a/src/treebuilders/apply.h +++ b/src/treebuilders/apply.h @@ -24,6 +24,28 @@ */ #pragma once +/** + * @file apply.h + * @brief Adaptive application of convolution/derivative operators to + * multiresolution (MW) functions and composite (multi-component) functions. + * + * @details + * This header declares a family of routines that: + * - apply **separable convolution operators** (near-/far-field or full) to MW trees, + * - apply **derivative operators** to scalar or vector fields, + * - compute **divergence** of vector fields, and + * - compute **gradients**. + * + * Overloads exist for scalar MW trees (`FunctionTree`) and for composite + * multi-component fields (`CompFunction`). 
For composite variants a + * 4×4 complex **metric** can be supplied to define the componentwise inner + * product / mixing; by default the identity metric is used. + * + * Precision and adaptivity: + * - `prec` is the target build precision used by the adaptive refinement loop. + * - `absPrec = false` → relative criterion; `true` → absolute threshold. + * - `maxIter < 0` removes the iteration cap. + */ #include "trees/FunctionTreeVector.h" #include "utils/CompFunction.h" @@ -36,24 +58,313 @@ template class FunctionTree; template class DerivativeOperator; template class ConvolutionOperator; +/** + * @brief Default 4×4 complex metric (identity). + * + * @details + * Used by composite-function overloads to define component coupling / inner + * product when applying operators. The default is the identity, i.e., no + * cross-component mixing. + */ constexpr ComplexDouble defaultMetric [4][4] ={{1,0,0,0},{0,1,0,0},{0,0,1,0},{0,0,0,1}}; -template void apply(double prec, FunctionTree &out, ConvolutionOperator &oper, FunctionTree &inp, int maxIter = -1, bool absPrec = false); -template void apply(double prec, CompFunction &out, ConvolutionOperator &oper, const CompFunction &inp, const ComplexDouble (*metric)[4] = defaultMetric, int maxIter = -1, bool absPrec = false); -template void apply(double prec, FunctionTree &out, ConvolutionOperator &oper, FunctionTree &inp, FunctionTreeVector &precTrees, int maxIter = -1, bool absPrec = false); -template void apply(double prec, CompFunction &out, ConvolutionOperator &oper, CompFunction &inp, FunctionTreeVector *precTrees, ComplexDouble (*metric)[4] = nullptr, int maxIter = -1, bool absPrec = false); -template void apply_far_field(double prec, FunctionTree &out, ConvolutionOperator &oper, FunctionTree &inp, int maxIter = -1, bool absPrec = false); -template void apply_far_field(double prec, CompFunction &out, ConvolutionOperator &oper, CompFunction &inp, const ComplexDouble (*metric)[4] = defaultMetric, int maxIter = -1, bool absPrec 
= false); -template void apply_near_field(double prec, FunctionTree &out, ConvolutionOperator &oper, FunctionTree &inp, int maxIter = -1, bool absPrec = false); -template void apply_near_field(double prec, CompFunction &out, ConvolutionOperator &oper, CompFunction &inp, const ComplexDouble (*metric)[4] = defaultMetric, int maxIter = -1, bool absPrec = false); -template void apply(FunctionTree &out, DerivativeOperator &oper, FunctionTree &inp, int dir = -1); -template void apply(CompFunction &out, DerivativeOperator &oper, CompFunction &inp, int dir = -1, const ComplexDouble (*metric)[4] = defaultMetric); -template void divergence(FunctionTree &out, DerivativeOperator &oper, FunctionTreeVector &inp); -template void divergence(CompFunction &out, DerivativeOperator &oper, FunctionTreeVector *inp, const ComplexDouble (*metric)[4] = defaultMetric); -template void divergence(FunctionTree &out, DerivativeOperator &oper, std::vector *> &inp); -template void divergence(CompFunction &out, DerivativeOperator &oper, std::vector *> *inp, const ComplexDouble (*metric)[4] = defaultMetric); -template FunctionTreeVector gradient(DerivativeOperator &oper, FunctionTree &inp); -std::vector*> gradient(DerivativeOperator<3> &oper, CompFunction<3> &inp, const ComplexDouble (*metric)[4] = defaultMetric); +/** + * @name Convolution application (scalar FunctionTree) + * @{ + */ + +/** + * @brief Apply a separable convolution operator adaptively. + * + * @tparam D Spatial dimension (1, 2, or 3). + * @tparam T Coefficient type (e.g., double, ComplexDouble). + * + * @param[in] prec Target precision for the adaptive build. + * @param[out] out Output function tree (built/extended adaptively). + * @param[in] oper Convolution operator to apply. + * @param[in] inp Input function tree. + * @param[in] maxIter Maximum refinement iterations (-1 = unbounded). + * @param[in] absPrec Use absolute (true) or relative (false) precision. 
+ */ +template +void apply(double prec, FunctionTree &out, ConvolutionOperator &oper, + FunctionTree &inp, int maxIter = -1, bool absPrec = false); + +/** + * @brief Apply a convolution operator with **per-node precision modulation**. + * + * @tparam D Spatial dimension. + * @tparam T Coefficient type. + * + * @param[in] prec Base precision. + * @param[out] out Output function tree. + * @param[in] oper Convolution operator. + * @param[in] inp Input function tree. + * @param[in] precTrees Vector of trees used to modulate local precision + * (e.g., via node-wise scaling factors). + * @param[in] maxIter Maximum refinement iterations. + * @param[in] absPrec Absolute vs. relative precision. + */ +template +void apply(double prec, FunctionTree &out, ConvolutionOperator &oper, + FunctionTree &inp, FunctionTreeVector &precTrees, + int maxIter = -1, bool absPrec = false); + +/** + * @brief Apply only the **far-field** contribution of a convolution operator. + * + * @tparam D Spatial dimension. + * @tparam T Coefficient type. + * + * @param[in] prec Target precision. + * @param[out] out Output function tree. + * @param[in] oper Convolution operator (far-field path will be used). + * @param[in] inp Input function tree. + * @param[in] maxIter Maximum refinement iterations. + * @param[in] absPrec Absolute vs. relative precision. + */ +template +void apply_far_field(double prec, FunctionTree &out, ConvolutionOperator &oper, + FunctionTree &inp, int maxIter = -1, bool absPrec = false); + +/** + * @brief Apply only the **near-field** contribution of a convolution operator. + * + * @tparam D Spatial dimension. + * @tparam T Coefficient type. + * + * @param[in] prec Target precision. + * @param[out] out Output function tree. + * @param[in] oper Convolution operator (near-field path will be used). + * @param[in] inp Input function tree. + * @param[in] maxIter Maximum refinement iterations. + * @param[in] absPrec Absolute vs. relative precision. 
+ */ +template +void apply_near_field(double prec, FunctionTree &out, ConvolutionOperator &oper, + FunctionTree &inp, int maxIter = -1, bool absPrec = false); + +/** @} */ + +/** + * @name Convolution application (composite CompFunction) + * @{ + */ + +/** + * @brief Apply a convolution operator to a composite function with a metric. + * + * @tparam D Spatial dimension. + * + * @param[in] prec Target precision. + * @param[out] out Output composite function. + * @param[in] oper Convolution operator. + * @param[in] inp Input composite function. + * @param[in] metric Optional 4×4 complex metric (defaults to identity). + * @param[in] maxIter Maximum refinement iterations (-1 = unbounded). + * @param[in] absPrec Absolute vs. relative precision. + * + * @note Components can be coupled via @p metric during accumulation. + */ +template +void apply(double prec, CompFunction &out, ConvolutionOperator &oper, + const CompFunction &inp, const ComplexDouble (*metric)[4] = defaultMetric, + int maxIter = -1, bool absPrec = false); + +/** + * @brief Apply a convolution operator to a composite function with + * precision-modulating trees and optional metric. + * + * @tparam D Spatial dimension. + * @tparam T Coefficient type used in @p precTrees. + * + * @param[in] prec Base precision. + * @param[out] out Output composite function. + * @param[in] oper Convolution operator. + * @param[in] inp Input composite function. + * @param[in] precTrees Optional per-node precision modulators (may be nullptr). + * @param[in] metric Optional 4×4 complex metric (may be nullptr → identity). + * @param[in] maxIter Maximum refinement iterations. + * @param[in] absPrec Absolute vs. relative precision. + */ +template +void apply(double prec, CompFunction &out, ConvolutionOperator &oper, + CompFunction &inp, FunctionTreeVector *precTrees, + ComplexDouble (*metric)[4] = nullptr, int maxIter = -1, bool absPrec = false); + +/** + * @brief Apply only the **far-field** part to a composite function. 
+ * + * @tparam D Spatial dimension. + * + * @param[in] prec Target precision. + * @param[out] out Output composite function. + * @param[in] oper Convolution operator. + * @param[in] inp Input composite function. + * @param[in] metric Optional 4×4 complex metric (defaults to identity). + * @param[in] maxIter Maximum refinement iterations. + * @param[in] absPrec Absolute vs. relative precision. + */ +template +void apply_far_field(double prec, CompFunction &out, ConvolutionOperator &oper, + CompFunction &inp, const ComplexDouble (*metric)[4] = defaultMetric, + int maxIter = -1, bool absPrec = false); + +/** + * @brief Apply only the **near-field** part to a composite function. + * + * @tparam D Spatial dimension. + * + * @param[in] prec Target precision. + * @param[out] out Output composite function. + * @param[in] oper Convolution operator. + * @param[in] inp Input composite function. + * @param[in] metric Optional 4×4 complex metric (defaults to identity). + * @param[in] maxIter Maximum refinement iterations. + * @param[in] absPrec Absolute vs. relative precision. + */ +template +void apply_near_field(double prec, CompFunction &out, ConvolutionOperator &oper, + CompFunction &inp, const ComplexDouble (*metric)[4] = defaultMetric, + int maxIter = -1, bool absPrec = false); + +/** @} */ + +/** + * @name Derivative application + * @{ + */ + +/** + * @brief Apply a derivative operator to a scalar MW function. + * + * @tparam D Spatial dimension. + * @tparam T Coefficient type. + * + * @param[out] out Output tree (derivative result). + * @param[in] oper Derivative operator. + * @param[in] inp Input function. + * @param[in] dir Application direction (0..D-1). If negative, use the + * operator’s internal direction. + */ +template +void apply(FunctionTree &out, DerivativeOperator &oper, + FunctionTree &inp, int dir = -1); + +/** + * @brief Apply a derivative operator to a composite function with a metric. + * + * @tparam D Spatial dimension. 
+ * + * @param[out] out Output composite function. + * @param[in] oper Derivative operator. + * @param[in] inp Input composite function. + * @param[in] dir Application direction (0..D-1). If negative, use operator’s default. + * @param[in] metric Optional 4×4 complex metric (defaults to identity). + */ +template +void apply(CompFunction &out, DerivativeOperator &oper, + CompFunction &inp, int dir = -1, + const ComplexDouble (*metric)[4] = defaultMetric); + +/** @} */ + +/** + * @name Divergence + * @{ + */ + +/** + * @brief Divergence of a vector field given as separate component trees. + * + * @tparam D Spatial dimension. + * @tparam T Coefficient type. + * + * @param[out] out Output scalar field (divergence). + * @param[in] oper Derivative operator (used per direction). + * @param[in] inp Vector of component trees (size D expected). + */ +template +void divergence(FunctionTree &out, DerivativeOperator &oper, + FunctionTreeVector &inp); + +/** + * @brief Divergence of a composite vector field with metric. + * + * @tparam D Spatial dimension. + * @tparam T Coefficient type of the trees in @p inp. + * + * @param[out] out Output scalar composite function (divergence). + * @param[in] oper Derivative operator (used per direction). + * @param[in] inp Pointer to a `FunctionTreeVector` holding the component trees. + * @param[in] metric Optional 4×4 complex metric. + */ +template +void divergence(CompFunction &out, DerivativeOperator &oper, + FunctionTreeVector *inp, + const ComplexDouble (*metric)[4] = defaultMetric); + +/** + * @brief Divergence of a vector field given as a raw list of component pointers. + * + * @tparam D Spatial dimension. + * @tparam T Coefficient type. + * + * @param[out] out Output scalar field (divergence). + * @param[in] oper Derivative operator. + * @param[in] inp Vector of pointers to component trees (size D expected).
+ */ +template +void divergence(FunctionTree &out, DerivativeOperator &oper, + std::vector *> &inp); + +/** + * @brief Divergence for composite fields given as raw component pointers with metric. + * + * @tparam D Spatial dimension. + * @tparam T Coefficient type used by components. + * + * @param[out] out Output scalar composite function. + * @param[in] oper Derivative operator. + * @param[in] inp Pointer to vector of component tree pointers. + * @param[in] metric Optional 4×4 complex metric. + */ +template +void divergence(CompFunction &out, DerivativeOperator &oper, + std::vector *> *inp, + const ComplexDouble (*metric)[4] = defaultMetric); + +/** @} */ + +/** + * @name Gradient + * @{ + */ + +/** + * @brief Gradient of a scalar field (returns D component trees). + * + * @tparam D Spatial dimension. + * @tparam T Coefficient type. + * + * @param[in] oper Derivative operator (used per direction). + * @param[in] inp Input scalar field. + * @return Vector of D component trees with directional derivatives. + */ +template +FunctionTreeVector gradient(DerivativeOperator &oper, FunctionTree &inp); + +/** + * @brief Gradient (3D) for composite fields, returning heap-allocated components. + * + * @param[in] oper 3D derivative operator. + * @param[in] inp Input composite function. + * @param[in] metric Optional 4×4 complex metric (defaults to identity). + * @return Vector of pointers to newly allocated component composite functions + * representing the gradient. The caller owns and must delete them. 
+ */ +std::vector*> gradient(DerivativeOperator<3> &oper, CompFunction<3> &inp, + const ComplexDouble (*metric)[4] = defaultMetric); // clang-format on -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/treebuilders/complex_apply.h b/src/treebuilders/complex_apply.h index 8ed9a0f17..3d634e5a9 100644 --- a/src/treebuilders/complex_apply.h +++ b/src/treebuilders/complex_apply.h @@ -29,28 +29,100 @@ namespace mrcpp { -/// @brief Stores pointers to real and imaginary parts of tree objects. -/// @tparam MWClass -template struct ComplexObject { - MWClass *real; - MWClass *imaginary; - - ComplexObject(MWClass &realPart, MWClass &imaginaryPart) - : real(&realPart) - , imaginary(&imaginaryPart) {} +/** + * @file + * @brief Complex wrapper utilities for multiwavelet trees and operators. + * + * @details + * This header declares a lightweight wrapper, @ref ComplexObject, that groups + * pointers to the real and imaginary parts of an object (e.g., a function + * tree or a convolution operator). It also declares an `apply` routine that + * applies a (possibly complex) convolution operator to a (possibly complex) + * function, writing the result to a (possibly complex) output. + * + * The pattern keeps real and imaginary parts as separate objects for memory + * locality and to reuse existing real-valued kernels, while allowing users to + * orchestrate complex arithmetic at a higher level. + */ + +/** + * @brief Aggregates pointers to the real and imaginary parts of an object. + * + * @tparam MWClass Underlying class type of the wrapped objects + * (e.g., `FunctionTree` or `ConvolutionOperator`). + * + * @details + * The struct is a non-owning pair of pointers. It does **not** manage + * lifetime—callers must ensure both referenced objects outlive the wrapper. + * + * @note + * The members are intentionally public for ergonomic access in kernels. + */ +template +struct ComplexObject { + /** @brief Pointer to the real component (non-owning). 
*/ + MWClass* real; + /** @brief Pointer to the imaginary component (non-owning). */ + MWClass* imaginary; + + /** + * @brief Construct from lvalue references to the real and imaginary parts. + * @param realPart Reference to the real component. + * @param imaginaryPart Reference to the imaginary component. + */ + ComplexObject(MWClass& realPart, MWClass& imaginaryPart) + : real(&realPart) + , imaginary(&imaginaryPart) {} }; // clang-format off //template class FunctionTree; //template class ConvolutionOperator; +/** + * @brief Apply a (complex) convolution operator to a (complex) function. + * + * @tparam D Spatial dimensionality of the multiwavelet representation. + * + * @param prec Target accuracy. If `absPrec == false`, this is interpreted + * as a **relative** tolerance; otherwise as an **absolute** tolerance. + * @param out Destination complex function trees (real/imag). On return, + * contains \f$ \text{oper} \{\text{inp}\} \f$ within the requested + * accuracy. + * @param oper Complex convolution operator (real/imag components). + * @param inp Input complex function trees to be transformed. + * @param maxIter Optional cap on internal refinement/iteration steps. + * Use `-1` (default) for the implementation’s automatic choice. + * @param absPrec When `true`, treat `prec` as absolute; when `false`, as relative. + * + * @pre + * - `out.real`, `out.imaginary`, `inp.real`, `inp.imaginary`, + * `oper.real`, and `oper.imaginary` are non-null and represent + * consistent discretizations (same MRA/order/domain). + * + * @post + * - `out` holds the complex result. Implementations typically compute: + * \f[ + * \Re(\text{out}) = \Re(\text{oper})\Re(\text{inp}) + * - \Im(\text{oper})\Im(\text{inp}), + * \qquad + * \Im(\text{out}) = \Re(\text{oper})\Im(\text{inp}) + * + \Im(\text{oper})\Re(\text{inp}), + * \f] + * with adaptive refinement to honor `prec`. + * + * @note + * The exact refinement strategy and stopping criteria are backend-dependent. 
+ * For reproducibility across runs/nodes, set the relevant MPI/OpenMP controls + * prior to calling. + */ template void apply ( - double prec, ComplexObject< FunctionTree > &out, - ComplexObject< ConvolutionOperator > &oper, ComplexObject< FunctionTree > &inp, + double prec, ComplexObject< FunctionTree >& out, + ComplexObject< ConvolutionOperator >& oper, ComplexObject< FunctionTree >& inp, int maxIter = -1, bool absPrec = false ); // clang-format on -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/treebuilders/grid.h b/src/treebuilders/grid.h index 1d7021f8b..be3e88376 100644 --- a/src/treebuilders/grid.h +++ b/src/treebuilders/grid.h @@ -25,24 +25,230 @@ #pragma once +/** + * @file + * @brief Grid construction, copying, clearing, and refinement helpers for multiresolution trees. + * + * @details + * This header declares a family of utilities to *construct* and *modify* the + * topology (grid) of @ref mrcpp::FunctionTree without necessarily computing or + * moving coefficients. The functions support several sources: + * analytic/representable functions, existing trees, vectors of trees, and + * explicit scale counts. + * + * ### Conventions + * - `D` is the spatial dimension (typically 1–3). + * - `T` is the coefficient scalar type (`double` or `std::complex`). + * - Functions named `build_grid` create (or enlarge) the *tree structure* + * of `out` to be adequate for representing the given input(s). + * - Functions named `copy_grid` copy only the *structure* (no coefficients). + * - `copy_func` copies *both* structure and coefficients. + * - Functions named `refine_grid` add resolution either explicitly by a scale + * count or adaptively by a precision criterion. + * - Functions returning `int` report the number of newly created end-nodes + * (i.e., how many refinements were actually performed).
+ */ + #include "functions/RepresentableFunction.h" #include "trees/FunctionTree.h" #include "trees/FunctionTreeVector.h" #include "utils/CompFunction.h" namespace mrcpp { -template void build_grid(FunctionTree &out, int scales); -template void build_grid(FunctionTree &out, const GaussExp &inp, int maxIter = -1); -template void build_grid(FunctionTree &out, const RepresentableFunction &inp, int maxIter = -1); -template void build_grid(FunctionTree &out, FunctionTree &inp, int maxIter = -1); -template void build_grid(FunctionTree &out, FunctionTreeVector &inp, int maxIter = -1); -template void build_grid(FunctionTree &out, std::vector *> &inp, int maxIter = -1); -template void copy_func(FunctionTree &out, FunctionTree &inp); -template void copy_grid(FunctionTree &out, FunctionTree &inp); -template void copy_grid(CompFunction &out, CompFunction &inp); -template void clear_grid(FunctionTree &out); -template int refine_grid(FunctionTree &out, int scales); -template int refine_grid(FunctionTree &out, double prec, bool absPrec = false); -template int refine_grid(FunctionTree &out, FunctionTree &inp); -template int refine_grid(FunctionTree &out, const RepresentableFunction &inp); -} // namespace mrcpp + +/** + * @brief Create a uniform grid of fixed depth. + * + * @tparam D Spatial dimension. + * @tparam T Coefficient scalar type. + * @param[out] out Target tree whose topology will be (re)built. + * @param[in] scales Number of refinement steps from the current state. + * + * @details + * Starting from the current `out` topology (typically roots), subdivide each + * active end-node `scales` times so that a regular grid of depth increased by + * `scales` is obtained. No coefficients are computed or modified. + */ +template +void build_grid(FunctionTree &out, int scales); + +/** + * @brief Build an adaptive grid suitable for a Gaussian expansion. + * + * @tparam D Spatial dimension. + * @param[out] out Target **real** tree to receive the grid. 
+ * @param[in] inp Analytic Gaussian expansion used as refinement oracle. + * @param[in] maxIter Maximum refinement passes; negative means “unbounded” + * until convergence by the internal criterion. + * + * @details + * Iteratively refines the tree so that the structure can represent `inp` + * within the library’s default per-node criterion (e.g., band-limited model or + * local projection error). Coefficients are not guaranteed to be written. + */ +template +void build_grid(FunctionTree &out, const GaussExp &inp, int maxIter = -1); + +/** + * @brief Build an adaptive grid for a generic representable function. + * + * @tparam D Spatial dimension. + * @tparam T Coefficient scalar type. + * @param[out] out Target tree. + * @param[in] inp Representable function serving as refinement oracle. + * @param[in] maxIter Maximum refinement passes; negative means unbounded. + * + * @details + * Uses evaluations/projections of `inp` to determine where refinement is + * needed so that the resulting grid can capture `inp` with the library’s + * default tolerance heuristic. + */ +template +void build_grid(FunctionTree &out, const RepresentableFunction &inp, int maxIter = -1); + +/** + * @brief Build a grid that can accommodate another tree’s resolution/support. + * + * @tparam D Spatial dimension. + * @tparam T Coefficient scalar type. + * @param[out] out Target tree to be enlarged/refined. + * @param[in] inp Source tree whose structure (support + finest scales) is used. + * @param[in] maxIter Maximum refinement passes; negative means unbounded. + * @details + * Ensures that `out` has at least the resolution present in `inp` wherever + * `inp` has support (a *grid union* operation). Coefficients are not copied. + */ +template +void build_grid(FunctionTree &out, FunctionTree &inp, int maxIter = -1); + +/** + * @brief Build a grid that is a union of a vector of trees. + * + * @tparam D Spatial dimension. + * @tparam T Coefficient scalar type. + * @param[out] out Target tree. + * @param[in] inp Vector of trees whose supports/resolutions are merged.
+ * @param[in] maxIter Optional iteration cap for staged refinement strategies. + */ +template +void build_grid(FunctionTree &out, FunctionTreeVector &inp, int maxIter = -1); + +/** + * @brief Build a grid that is a union of a list of tree pointers. + * + * @tparam D Spatial dimension. + * @tparam T Coefficient scalar type. + * @param[out] out Target tree. + * @param[in] inp List of tree pointers to merge. + * @param[in] maxIter Optional iteration cap for staged refinement strategies. + */ +template +void build_grid(FunctionTree &out, std::vector *> &inp, int maxIter = -1); + +/** + * @brief Deep copy a tree structure *and* coefficients. + * + * @tparam D Spatial dimension. + * @tparam T Coefficient scalar type. + * @param[out] out Destination tree (reallocated as needed). + * @param[in] inp Source tree. + */ +template +void copy_func(FunctionTree &out, FunctionTree &inp); + +/** + * @brief Copy only the tree topology (no coefficients). + * + * @tparam D Spatial dimension. + * @tparam T Coefficient scalar type. + * @param[out] out Destination tree (structure rebuilt). + * @param[in] inp Source tree whose topology is replicated. + */ +template +void copy_grid(FunctionTree &out, FunctionTree &inp); + +/** + * @brief Copy only the topology for all components of a composite function. + * + * @tparam D Spatial dimension. + * @param[out] out Destination composite function (components allocated as needed). + * @param[in] inp Source composite function. + * + * @details + * For each component present in @p inp, ensure @p out has a corresponding + * component with identical tree structure. Coefficients are not copied. + */ +template +void copy_grid(CompFunction &out, CompFunction &inp); + +/** + * @brief Clear the grid topology (prune to roots, drop nodes). + * + * @tparam D Spatial dimension. + * @tparam T Coefficient scalar type. + * @param[out] out Tree to clear; MRA association remains intact. 
+ */ +template +void clear_grid(FunctionTree &out); + +/** + * @brief Refine uniformly by a fixed number of scales. + * + * @tparam D Spatial dimension. + * @tparam T Coefficient scalar type. + * @param[in,out] out Tree to refine. + * @param[in] scales Number of subdivision steps to apply. + * @return Number of new end-nodes created by the refinement. + */ +template +int refine_grid(FunctionTree &out, int scales); + +/** + * @brief Adaptive refinement driven by a precision threshold. + * + * @tparam D Spatial dimension. + * @tparam T Coefficient scalar type. + * @param[in,out] out Tree to refine. + * @param[in] prec Target precision (threshold). + * @param[in] absPrec If `true`, interpret @p prec as absolute tolerance; + * otherwise relative to a norm estimate. + * @return Number of new end-nodes created. + * + * @details + * Subdivides those nodes whose local error/indicator exceeds the requested + * threshold. The precise indicator depends on the library configuration + * (e.g., wavelet-norm-based splitting). + */ +template +int refine_grid(FunctionTree &out, double prec, bool absPrec = false); + +/** + * @brief Refine `out` so that its grid is at least as fine as `inp`. + * + * @tparam D Spatial dimension. + * @tparam T Coefficient scalar type. + * @param[in,out] out Destination tree to refine. + * @param[in] inp Source tree providing the target finest scales. + * @return Number of new end-nodes created. + */ +template +int refine_grid(FunctionTree &out, FunctionTree &inp); + +/** + * @brief Adaptive refinement using a representable function as oracle. + * + * @tparam D Spatial dimension. + * @tparam T Coefficient scalar type. + * @param[in,out] out Tree to refine. + * @param[in] inp Representable function guiding refinement. + * @return Number of new end-nodes created. 
+ * + * @details + * Samples or projects @p inp on candidate nodes and refines where the + * estimated local error is above the internal criterion, creating a grid + * appropriate for subsequently projecting @p inp. + */ +template +int refine_grid(FunctionTree &out, const RepresentableFunction &inp); + +} // namespace mrcpp \ No newline at end of file diff --git a/src/treebuilders/map.h b/src/treebuilders/map.h index d1f86e201..696c91ff5 100644 --- a/src/treebuilders/map.h +++ b/src/treebuilders/map.h @@ -25,11 +25,113 @@ #pragma once +/** + * @file + * @brief Nonlinear mapping utilities for multiresolution trees. + * + * @details + * Declares an adaptive routine that applies a user-supplied scalar mapping + * to a multiresolution function represented by a @ref mrcpp::FunctionTree. + * The routine produces an output tree whose grid is refined as needed to meet + * a requested precision. + * + * ### What “map” does + * Given an input scalar field \f$ f(\mathbf{r}) \f$ encoded by `inp`, and a + * scalar-to-scalar function `fmap : ℝ → ℝ`, this routine builds (or refines) + * the topology of `out` and computes coefficients so that + * \f[ + * g(\mathbf{r}) = \mathrm{fmap}\big(f(\mathbf{r})\big) + * \f] + * is represented to within the requested tolerance. + * + * The mapping is *pointwise* in value space (nonlinear allowed) and the grid + * refinement is *adaptive*: nodes are split where approximation error indicates + * additional resolution is required. + * + * ### Typical uses + * - Envelope shaping (e.g., clamp, softplus, \f$x^p\f$). + * - Nonlinearities inside iterative solvers. + * - Post-processing fields (e.g., magnitude, thresholding). + * + * @note Only the **real** scalar case (`double` coefficients) is declared here. + * Complex-valued mappings typically require splitting real/imag components + * explicitly or using dedicated complex routines elsewhere in the library. 
+ */ + #include "trees/FunctionTreeVector.h" namespace mrcpp { + template class FunctionTree; -template void map(double prec, FunctionTree &out, FunctionTree &inp, FMap fmap, int maxIter = -1, bool absPrec = false); +/** + * @brief Apply a scalar mapping to a function tree with adaptive refinement. + * + * @tparam D Spatial dimension (1–3 typical). + * + * @param[in] prec + * Target precision threshold used to control adaptive refinement. + * See @p absPrec for interpretation. + * @param[out] out + * Destination tree. Its topology will be enlarged/refined as needed and its + * coefficients overwritten with the mapped result. + * The tree must be associated with a valid MRA compatible with @p inp. + * @param[in] inp + * Source tree that encodes the input function \f$ f(\mathbf{r}) \f$. + * Logically read-only (will not be modified by a correct implementation). + * @param[in] fmap + * Scalar mapping functor of type `FMap` (typically equivalent + * to `std::function` or any callable with signature + * `double(double)`). It is applied pointwise to sample values of `inp`. + * @param[in] maxIter + * Maximum number of refinement passes. A negative value (default) requests + * unbounded passes until the internal convergence criterion is satisfied + * (e.g., no new nodes created or estimated error below @p prec everywhere). + * @param[in] absPrec + * If `true`, interpret @p prec as an **absolute** tolerance on the local + * error indicator. If `false` (default), use a **relative** tolerance, + * typically scaled by an estimate of \f$\|f\|\f$ (implementation-defined). + * + * @pre + * - `out` and `inp` share compatible MRAs (same domain, basis order, etc.). + * - `fmap` must be pure (side-effect free) and thread-safe. + * + * @post + * - `out`’s topology and coefficients represent + * \f$ g(\mathbf{r}) = \mathrm{fmap}(f(\mathbf{r})) \f$ to within the requested + * tolerance, subject to the library’s split criterion. 
+ * + * @par Precision semantics + * - *Absolute mode* (`absPrec=true`): the error indicator is compared directly + * to @p prec. + * - *Relative mode* (`absPrec=false`): the indicator is scaled by a norm of + * the input (e.g., tree square-norm), so @p prec represents a relative + * threshold. + * + * @par Parallelization + * The routine may exploit OpenMP/MPI internally. Supplying a thread-safe + * `fmap` is required. + * + * @par Exception safety + * Strong guarantee for `inp`. `out` is modified during execution; in case of + * failure it may be left partially updated. + * + * @par Example + * @code + * using Tree = mrcpp::FunctionTree<3,double>; + * Tree fout(mra), fin(mra); + * // ... build/project fin ... + * + * auto square = [](double x){ return x*x; }; + * mrcpp::map<3>(1e-6, fout, fin, square); // fout ≈ (fin)^2 + * @endcode + */ +template +void map(double prec, + FunctionTree &out, + FunctionTree &inp, + FMap fmap, + int maxIter = -1, + bool absPrec = false); } // namespace mrcpp diff --git a/src/treebuilders/multiply.h b/src/treebuilders/multiply.h index 316066483..75e9cf316 100644 --- a/src/treebuilders/multiply.h +++ b/src/treebuilders/multiply.h @@ -25,18 +25,125 @@ #pragma once +/** + * @file + * @brief High-level algebra on multiresolution function trees. + * + * @details + * This header declares scalar and field operations on + * @ref mrcpp::FunctionTree objects: + * - continuous inner products (dot products), + * - pointwise products of two or many trees, + * - powers/squares (element-wise). + * + * Unless stated otherwise, functions honor a target accuracy parameter + * `prec` (see each overload). Implementations typically refine/coarsen + * grids adaptively using wavelet norms and multiresolution estimates until + * the requested tolerance is met (or a `maxIter` cap is reached). 
+ * + * ### Precision semantics + * - `prec` is interpreted as a **relative** tolerance in an L2-like sense + * by default; set `absPrec=true` to treat it as an **absolute** tolerance. + * - `maxIter < 0` means “iterate as needed”; otherwise it limits refinement + * passes (the function may exit early with a looser error). + * + * ### Conjugation semantics (complex trees) + * When `T` is complex, some overloads accept `conjugate=true` to apply + * complex conjugation to the first factor (bra–ket convention), yielding + * products like \f$f \cdot \overline{g}\f$ or \f$|f|^2\f$. + */ + #include "trees/FunctionTreeVector.h" namespace mrcpp { + template class RepresentableFunction; template class FunctionTree; -template () * std::declval())> V dot(FunctionTree &bra, FunctionTree &ket); +/** + * @brief Continuous inner product \f$\langle \text{bra} \mid \text{ket} \rangle\f$ over \f$\mathbb{R}^D\f$. + * + * @tparam D Spatial dimension. + * @tparam T Scalar type of the bra tree (e.g., `double`, `ComplexDouble`). + * @tparam U Scalar type of the ket tree. + * @tparam V Return type deduced as `decltype(T{} * U{})`. + * + * @param bra Multiresolution function tree acting as the bra. + * @param ket Multiresolution function tree acting as the ket. + * @return The scalar inner product value. For complex types, the bra is + * conjugated (i.e., \f$\int \overline{bra(x)}\,ket(x)\,dx\f$). + * + * @pre Both trees must be compatible (same MRA/grid conventions). + * @note Implementations usually reconstruct to consistent representations + * before integration; they may refine adaptively to ensure accuracy. + * @warning For poorly overlapping/aliased grids, the routine may refine + * meshes internally, which can be expensive. 
+ */ +template () * std::declval())> +V dot(FunctionTree &bra, FunctionTree &ket); -template void dot(double prec, FunctionTree &out, FunctionTreeVector &inp_a, FunctionTreeVector &inp_b, int maxIter = -1, bool absPrec = false); +/** + * @brief Contract two vectors of trees into a scalar field: + * \f$out(x) = \sum_i a_i(x)\,\overline{b_i(x)}\f$ (complex) or \f$\sum_i a_i(x)\,b_i(x)\f$ (real). + * + * @tparam D Spatial dimension. + * @tparam T Coefficient type (`double` or `ComplexDouble`). + * @param prec Target accuracy for the constructed field (see “Precision semantics” above). + * @param out Output scalar field tree receiving the contraction. + * @param inp_a Vector of factor trees \f$\{a_i\}\f$. + * @param inp_b Vector of factor trees \f$\{b_i\}\f$; must have the same size and compatible grids as `inp_a`. + * @param maxIter Maximum refinement passes; `-1` for unlimited. + * @param absPrec If `true`, interpret `prec` as absolute tolerance. + * + * @details + * Builds the *pointwise* contraction of two equally sized collections of trees, + * summing products component-wise. This is often used to assemble densities + * or overlaps distributed over space. + * + * The routine adaptively refines `out` to meet `prec`. Input nodes may be + * transiently reconstructed to compatible representations. + */ +template +void dot(double prec, FunctionTree &out, FunctionTreeVector &inp_a, FunctionTreeVector &inp_b, int maxIter = -1, bool absPrec = false); -template double node_norm_dot(FunctionTree &bra, FunctionTree &ket, bool exact = false); +/** + * @brief Fast contraction based on node norms (cheap estimate/upper bound). + * + * @tparam D Spatial dimension. + * @tparam T Coefficient type. + * @param bra First tree. + * @param ket Second tree. + * @param exact If `true`, request exact inner product instead of a norm-based estimate (implementation-dependent). 
+ * @return A scalar quantity derived from node-wise norms; commonly used + * as a quick upper bound or cheap similarity measure. + * + * @note When `exact=true`, implementations may fall back to the same + * evaluation as @ref dot(bra, ket). If exact evaluation is not + * available, `exact` may be ignored. + */ +template +double node_norm_dot(FunctionTree &bra, FunctionTree &ket, bool exact = false); +/** + * @brief Pointwise product of two trees with a scalar prefactor: + * \f$out \leftarrow c \cdot a \cdot (\mathrm{conj}\,b \text{ if requested})\f$. + * + * @tparam D Spatial dimension. + * @tparam T Coefficient type. + * @param prec Target accuracy for `out`. + * @param out Output tree receiving the product. + * @param c Scalar prefactor applied to the product. + * @param inp_a First factor. + * @param inp_b Second factor. + * @param maxIter Maximum refinement passes; `-1` for unlimited. + * @param absPrec If `true`, interpret `prec` as absolute tolerance. + * @param useMaxNorms If `true`, use max-norm heuristics to guide refinement (may be faster, slightly more conservative). + * @param conjugate If `true` and `T` is complex, conjugate the **first** factor (bra–ket convention). + * + * @details + * Produces an adaptively refined tree such that the representation error of + * the pointwise product does not exceed `prec` under the chosen policy. + */ template void multiply(double prec, FunctionTree &out, @@ -48,14 +155,76 @@ void multiply(double prec, bool useMaxNorms = false, bool conjugate = false); +/** + * @brief Pointwise product of an arbitrary number of trees: + * \f$out \leftarrow \prod_{i} f_i\f$ (optional conjugation of the first factor). + * + * @tparam D Spatial dimension. + * @tparam T Coefficient type. + * @param prec Target accuracy for `out`. + * @param out Output tree receiving the product. + * @param inp List of input tree pointers (non-null, compatible MRAs). + * @param maxIter Maximum refinement passes; `-1` for unlimited. 
+ * @param absPrec If `true`, interpret `prec` as absolute tolerance. + * @param useMaxNorms If `true`, enable max-norm driven refinement. + * @param conjugate If `true` and `T` is complex, conjugate the **first** factor only. + * + * @note The algorithm typically multiplies factors incrementally with + * intermediate refinement; ordering can affect performance. + */ template void multiply(double prec, FunctionTree &out, std::vector *> &inp, int maxIter = -1, bool absPrec = false, bool useMaxNorms = false, bool conjugate = false); +/** + * @brief Pointwise product of a vector of trees: + * \f$out \leftarrow \prod_{i} f_i\f$ (optional conjugation of the first factor). + * + * @tparam D Spatial dimension. + * @tparam T Coefficient type. + * @param prec Target accuracy for `out`. + * @param out Output tree receiving the product. + * @param inp Vector wrapper containing input trees (and possibly per-tree scalars). + * @param maxIter Maximum refinement passes; `-1` for unlimited. + * @param absPrec If `true`, interpret `prec` as absolute tolerance. + * @param useMaxNorms If `true`, enable max-norm driven refinement. + * @param conjugate If `true` and `T` is complex, conjugate the **first** factor only. + */ template void multiply(double prec, FunctionTree &out, FunctionTreeVector &inp, int maxIter = -1, bool absPrec = false, bool useMaxNorms = false, bool conjugate = false); -template void power(double prec, FunctionTree &out, FunctionTree &inp, double p, int maxIter = -1, bool absPrec = false); +/** + * @brief Element-wise power: \f$out(x) = \big(inp(x)\big)^{p}\f$. + * + * @tparam D Spatial dimension. + * @tparam T Coefficient type. + * @param prec Target accuracy for `out`. + * @param out Output tree. + * @param inp Input tree. + * @param p Real exponent. + * @param maxIter Maximum refinement passes; `-1` for unlimited. + * @param absPrec If `true`, interpret `prec` as absolute tolerance. 
+ * + * @warning For real `T`, negative bases with non-integer `p` are undefined. + * For complex `T`, the principal branch is typically used. + */ +template +void power(double prec, FunctionTree &out, FunctionTree &inp, double p, int maxIter = -1, bool absPrec = false); -template void square(double prec, FunctionTree &out, FunctionTree &inp, int maxIter = -1, bool absPrec = false, bool conjugate = false); +/** + * @brief Element-wise square. + * + * @tparam D Spatial dimension. + * @tparam T Coefficient type. + * @param prec Target accuracy for `out`. + * @param out Output tree. + * @param inp Input tree. + * @param maxIter Maximum refinement passes; `-1` for unlimited. + * @param absPrec If `true`, interpret `prec` as absolute tolerance. + * @param conjugate If `true` and `T` is complex, compute squared magnitude: + * \f$out = inp \cdot \overline{inp}\f$; otherwise compute + * \f$out = inp \cdot inp\f$. + */ +template +void square(double prec, FunctionTree &out, FunctionTree &inp, int maxIter = -1, bool absPrec = false, bool conjugate = false); -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/treebuilders/project.h b/src/treebuilders/project.h index f9e070ef2..ba3466412 100644 --- a/src/treebuilders/project.h +++ b/src/treebuilders/project.h @@ -25,12 +25,119 @@ #pragma once +/** + * @file + * @brief Projection helpers to expand analytic/representable functions on + * multiresolution bases (function trees). + * + * @details + * These overloads build or refine an output @ref FunctionTree (or a vector of + * trees) so that the supplied function(s) are represented to within a target + * precision. The projection is adaptive: nodes are split where the estimated + * local error exceeds the tolerance, and coefficients are (re)computed only + * where needed. 
+ * + * **Precision semantics** + * - If @p absPrec is `false` (default), @p prec is interpreted as a + * *relative* tolerance with respect to a suitable global/aggregate norm of + * the function (typical L²-relative stopping criterion). + * - If @p absPrec is `true`, @p prec is treated as an *absolute* tolerance + * for local/node-wise thresholds. + * + * **Iteration control** + * - @p maxIter limits the number of refinement passes. Use `-1` for the + * default behavior (iterate until the tolerance is reached or the internal + * refiner deems the grid converged). + * + * **Preconditions and side effects** + * - @p out is modified in-place (grid may be refined/coarsened; coefficients + * are (re)computed). + * - The @p RepresentableFunction or callable provided by the user must be + * well-defined on the domain of @p out’s @ref MultiResolutionAnalysis. + * + * @note Implementations typically perform, per node: + * 1) evaluate the input function on the node’s quadrature/stencil, + * 2) compute scaling/wavelet coefficients, + * 3) estimate local error and decide on further splitting, + * 4) stop when global/local criteria satisfy @p prec or @p maxIter is hit. + */ + #include "MRCPP/mrcpp_declarations.h" #include "trees/FunctionTreeVector.h" #include namespace mrcpp { -template void project(double prec, FunctionTree &out, RepresentableFunction &inp, int maxIter = -1, bool absPrec = false); -template void project(double prec, FunctionTree &out, std::function &r)> func, int maxIter = -1, bool absPrec = false); -template void project(double prec, FunctionTreeVector &out, std::vector &r)>> func, int maxIter = -1, bool absPrec = false); -} // namespace mrcpp + +/** + * @brief Project a @ref RepresentableFunction onto an output function tree. + * + * @tparam D Spatial dimension. + * @tparam T Coefficient scalar type (e.g., `double`, `ComplexDouble`). + * + * @param prec Target tolerance (relative by default, see @p absPrec). 
+ * @param out Destination @ref FunctionTree; refined and filled in-place. + * @param inp Analytic / representable function to project. + * @param maxIter Maximum refinement passes (`-1` = default/unlimited). + * @param absPrec If `true`, interpret @p prec as an absolute tolerance. + * + * @details + * Builds an adaptive multiresolution representation of @p inp in @p out. + * Existing content of @p out may be reused and further refined. + */ +template +void project(double prec, + FunctionTree &out, + RepresentableFunction &inp, + int maxIter = -1, + bool absPrec = false); + +/** + * @brief Project a user-supplied callable onto an output function tree. + * + * @tparam D Spatial dimension. + * @tparam T Coefficient scalar type (e.g., `double`, `ComplexDouble`). + * + * @param prec Target tolerance (relative by default, see @p absPrec). + * @param out Destination @ref FunctionTree; refined and filled in-place. + * @param func Callable (e.g., lambda) mapping @ref Coord to @p T. + * @param maxIter Maximum refinement passes (`-1` = default/unlimited). + * @param absPrec If `true`, interpret @p prec as an absolute tolerance. + * + * @details + * Equivalent to the @ref RepresentableFunction overload, but accepts any + * `std::function&)>` (or compatible lambda) as the source. + */ +template +void project(double prec, + FunctionTree &out, + std::function &r)> func, + int maxIter = -1, + bool absPrec = false); + +/** + * @brief Project multiple callables into a vector of function trees. + * + * @tparam D Spatial dimension. + * @tparam T Coefficient scalar type (e.g., `double`, `ComplexDouble`). + * + * @param prec Target tolerance (relative by default, see @p absPrec). + * @param out Destination @ref FunctionTreeVector; each entry refined / + * filled in-place. It is expected to have the same length as + * @p func (one tree per callable). + * @param func Collection of callables, each mapping @ref Coord to @p T. 
+ * @param maxIter Maximum refinement passes (`-1` = default/unlimited). + * @param absPrec If `true`, interpret @p prec as an absolute tolerance. + * + * @details + * Applies the single-tree callable projection to each element, pairing + * `out[i]` with `func[i]`. All trees should share a compatible + * @ref MultiResolutionAnalysis. + */ +template +void project(double prec, + FunctionTreeVector &out, + std::vector &r)>> func, + int maxIter = -1, + bool absPrec = false); + +} // namespace mrcpp \ No newline at end of file diff --git a/src/trees/BandWidth.h b/src/trees/BandWidth.h index b4ee49e8d..aa8976a4c 100644 --- a/src/trees/BandWidth.h +++ b/src/trees/BandWidth.h @@ -23,41 +23,88 @@ * */ -/* - * BandWidth.h - */ - #pragma once #include #include +#include namespace mrcpp { +/** + * @class BandWidth + * @brief Container for band widths over depths and components + */ class BandWidth final { public: + /** + * @brief Constructor with storage for @p depth + 1 rows + * @param depth Maximum depth to allocate (inclusive) + * @details All entries are initialized to -1 (empty). + */ BandWidth(int depth = 0) : widths(depth + 1, 5) { this->clear(); } + + /** + * @brief Copy-constructor + * @param bw Instance to copy from + */ BandWidth(const BandWidth &bw) : widths(bw.widths) {} + + /// @brief Copy-assign from another instance. BandWidth &operator=(const BandWidth &bw); + /// @brief Reset all width entries to -1 (empty). void clear() { this->widths.setConstant(-1); } + /** + * @brief Check whether the row for @p depth is effectively empty + * @param depth Depth to test + * @return True if @p depth is out of range or the last value of the row is < 0 + */ bool isEmpty(int depth) const; + + /** + * @brief Highest valid depth index stored + * @return The maximum depth (rows - 1) + */ int getDepth() const { return this->widths.rows() - 1; } + + /** + * @brief Cached maximum width for a depth. + * @param depth Depth to query. 
+ * @return Max width at @p depth, or -1 if @p depth exceeds getDepth(). + */ int getMaxWidth(int depth) const { return (depth > getDepth()) ? -1 : this->widths(depth, 4); } + + /** + * @brief Component width accessor. + * @param depth Depth index + * @param index Component index + * @return Width for (@p depth, @p index), or -1 if @p depth exceeds getDepth(). + */ int getWidth(int depth, int index) const { return (depth > getDepth()) ? -1 : this->widths(depth, index); } + + /** + * @brief Set component width and update the cached per-depth maximum. + * @param depth Depth to modify (0..getDepth()). + * @param index Component in {0, 1, 2, 3}. + * @param wd Non-negative band width. + */ void setWidth(int depth, int index, int wd); + /// @brief Formatted printing of the BandWidth instance. friend std::ostream &operator<<(std::ostream &o, const BandWidth &bw) { return bw.print(o); } private: - Eigen::MatrixXi widths; /// column 5 stores max width at depth + /// Matrix of widths; columns 0..3 = components, column 4 = cached max per depth. + Eigen::MatrixXi widths; + /// @brief Formatted printing of the BandWidth instance. std::ostream &print(std::ostream &o) const; }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/trees/BoundingBox.cpp b/src/trees/BoundingBox.cpp index ff86abf2e..e0bc63f91 100644 --- a/src/trees/BoundingBox.cpp +++ b/src/trees/BoundingBox.cpp @@ -32,18 +32,6 @@ namespace mrcpp { -/** @brief Constructor for BoundingBox object. - * - * @param[in] box: [lower, upper] bound in all dimensions. - * @returns New BoundingBox object. - * - * @details Creates a box with appropriate root scale and scaling - * factor to fit the given bounds, which applies to _all_ dimensions. - * Root scale is chosen such that the scaling factor becomes 1 < sfac < 2. - * - * Limitations: Box must be _either_ [0,L] _or_ [-L,L], with L a positive integer. 
- * This is the most general constructor, which will create a world with no periodic boundary conditions. - */ template BoundingBox::BoundingBox(std::array box) { if (box[1] < 1) { MSG_ERROR("Invalid upper bound: " << box[1]); @@ -79,23 +67,6 @@ template BoundingBox::BoundingBox(std::array box) { setDerivedParameters(); } -/** @brief Constructor for BoundingBox object. - * - * @param[in] n: Length scale, default 0. - * @param[in] l: Corner translation, default [0, 0, ...]. - * @param[in] nb: Number of boxes, default [1, 1, ...]. - * @param[in] sf: Scaling factor, default [1.0, 1.0, ...]. - * @param[in] pbc: Periodic boundary conditions, default false. - * @returns New BoundingBox object. - * - * @details Creates a box with given parameters. The parameter n defines the length scale, which, together with sf, determines the unit length of each side of the boxes by \f$ [2^{-n}]^D \f$. - * The parameter l defines the corner translation of the lower corner of the box relative to the world origin. - * The parameter nb defines the number of boxes in each dimension. - * The parameter sf defines the scaling factor, which determines the box translations around the origin, i.e. the amount of boxes around origin. - * The parameter pbc defines whether the world is periodic or not. In this constructor this value makes all dimensions periodic. - * This constructor is used for work in periodic systems. - * - */ template BoundingBox::BoundingBox(int n, const std::array &l, const std::array &nb, const std::array &sf, bool pbc) : cornerIndex(n, l) { @@ -105,19 +76,6 @@ BoundingBox::BoundingBox(int n, const std::array &l, const std::array setDerivedParameters(); } -/** @brief Constructor for BoundingBox object. - * - * @param[in] idx: index of the lower corner of the box. - * @param[in] nb: Number of boxes, default [1, 1, ...]. - * @param[in] sf: Scaling factor, default [1.0, 1.0, ...]. - * @returns New BoundingBox object. - * - * @details Creates a box with given parameters. 
- * The parameter idx defines the index of the lower corner of the box relative to the world origin. - * The parameter nb defines the number of boxes in each dimension. - * The parameter sf defines the scaling factor, which determines the box translations around the origin, i.e. the amount of boxes around origin. - * This constructor creates a world with no periodic boundary conditions. - */ template BoundingBox::BoundingBox(const NodeIndex &idx, const std::array &nb, const std::array &sf) : cornerIndex(idx) { @@ -127,16 +85,6 @@ BoundingBox::BoundingBox(const NodeIndex &idx, const std::array &n setDerivedParameters(); } -/** @brief Constructor for BoundingBox object. - * - * @param[in] sf: Scaling factor, default [1.0, 1.0, ...]. - * @param[in] pbc: Periodic boundary conditions, default true. - * - * @details Creates a box with given parameters. - * The parameter sf defines the scaling factor, which determines the box translations around the origin, i.e. the amount of boxes around origin. - * The parameter pbc defines whether the world is periodic or not. In this constructor this value makes all dimensions periodic. - * This construtor is used for work in periodic systems. - */ template BoundingBox::BoundingBox(const std::array &sf, bool pbc) : cornerIndex() { @@ -146,17 +94,6 @@ BoundingBox::BoundingBox(const std::array &sf, bool pbc) setDerivedParameters(); } -/** @brief Constructor for BoundingBox object. - * - * @param[in] sf: Scaling factor, default [1.0, 1.0, ...]. - * @param[in] pbc: Periodic boundary conditions, default true. - * @returns New BoundingBox object. - * - * @details Creates a box with given parameters. - * The parameter sf defines the scaling factor, which determines the box translations around the origin, i.e. the amount of boxes around origin. - * The parameter pbc defines whether the world is periodic or not. In this constructor this value makes specific dimensions periodic. - * This is used for work in periodic systems. 
- */ template BoundingBox::BoundingBox(const std::array &sf, std::array pbc) : cornerIndex() { @@ -166,14 +103,6 @@ BoundingBox::BoundingBox(const std::array &sf, std::array setDerivedParameters(); } -/** @brief Constructor for BoundingBox object. - * - * @param[in] box: Other BoundingBox object. - * @returns New BoundingBox object. - * - * @details Creates a box identical to the input box paramter. - * This constructor uses all parameters from the other BoundingBox to create a new one. - */ template BoundingBox::BoundingBox(const BoundingBox &box) : cornerIndex(box.cornerIndex) { @@ -183,13 +112,6 @@ BoundingBox::BoundingBox(const BoundingBox &box) setDerivedParameters(); } -/** @brief Assignment operator overload for BoundingBox object. - * - * @returns New BoundingBox object. - * @param[in] box: Other BoundingBox object. - * - * @details Allocates all parameters in this BoundingBox to be that of the other BoundingBox. - */ template BoundingBox &BoundingBox::operator=(const BoundingBox &box) { if (&box != this) { this->cornerIndex = box.cornerIndex; @@ -201,13 +123,6 @@ template BoundingBox &BoundingBox::operator=(const BoundingBox return *this; } -/** @brief Sets the number of boxes in each dimension. - * - * @param[in] nb: Number of boxes, default [1, 1, ...]. - * - * @details For each dimentions D it sets the number of boxes in that dimension in the nBoxes array and the total amount of boxes in the world in the totBoxes variable. - * This just sets counters for the number of boxes in each dimension. - */ template void BoundingBox::setNBoxes(const std::array &nb) { this->totBoxes = 1; for (int d = 0; d < D; d++) { @@ -216,16 +131,6 @@ template void BoundingBox::setNBoxes(const std::array &nb) { } } -/** @brief Computes and sets all derived parameters. - * - * @details For all parameters that have been initialized in the constructor, - * this function will compute the necessary derived parameters in each dimension. 
- * The unit length is set to \a sfac \f$ \cdot 2^{-n} \f$ where \a sfac is the scaling factor (default 1.0) and n is the length scale. - * The unit length is the base unit which is used for the size and positioning of the boxes around origin. - * The boxLength is the total length of the box in each dimension, which is the unit length times the number of boxes in that dimension. - * The lowerBound is computed from the index of the lower corner of the box and the unit length. - * The upperBound is computed to be the lower corner plus the total length in that dimension. - */ template void BoundingBox::setDerivedParameters() { assert(this->totBoxes > 0); const NodeIndex &cIdx = this->cornerIndex; @@ -238,12 +143,6 @@ template void BoundingBox::setDerivedParameters() { } } -/** @brief Sets the number of boxes in each dimension. - * - * @param[in] sf: Scaling factor, default [1.0, 1.0, ...]. - * - * @details This checks that the sf variable has sane values before assigning it to the member variable scalingFactor. - */ template void BoundingBox::setScalingFactors(const std::array &sf) { assert(this->totBoxes > 0); for (auto &x : sf) @@ -252,36 +151,14 @@ template void BoundingBox::setScalingFactors(const std::array{}) scalingFactor.fill(1.0); } -/** @brief Sets which dimensions are periodic. - * - * @param[in] pbc: Boolean which is used to set all dimension to either periodic or not - * - * @details this fills in the periodic array with the values from the input. - */ template void BoundingBox::setPeriodic(bool pbc) { this->periodic.fill(pbc); } -/** @brief Sets which dimensions are periodic. - * - * @param[in] pbs: D-dimensional array holding boolean values for each dimension. - * - * @details This fills in the periodic array with the values from the input array. - */ template void BoundingBox::setPeriodic(std::array pbc) { this->periodic = pbc; } -/** @brief Fetches a NodeIndex object from a given box index. 
- * - * @param[in] bIdx: Box index, the index of the box we want to fetch the cell index from. - * @returns The NodeIndex object of the index given as it is in the Multiresolutoin analysis. - * - * @details During the adaptive refinement, each original box will contain an increasing number of smaller cells, - * each of which will be part of a specific node in the tree. These cells are divided adaptivelly. This function returns the NodeIndex - * object of the cell at the lower back corner of the box object indexed by bIdx. - * Specialized for D=1 below - */ template NodeIndex BoundingBox::getNodeIndex(int bIdx) const { assert(bIdx >= 0 and bIdx <= this->totBoxes); std::array l; @@ -300,13 +177,6 @@ template NodeIndex BoundingBox::getNodeIndex(int bIdx) const { return NodeIndex(getScale(), l); } -/** @brief Fetches the index of a box from a given coordinate. - * - * @param[in] r: D-dimensional array representaing a coordinate in the simulation box - * @returns The index value of the boxes in the position given as it is in the generated world. - * - * @details Specialized for D=1 below - */ template int BoundingBox::getBoxIndex(Coord r) const { if (this->isPeriodic()) { periodic::coord_manipulation(r, this->getPeriodic()); } @@ -334,15 +204,6 @@ template int BoundingBox::getBoxIndex(Coord r) const { return bIdx; } -/** @brief Fetches the index of a box from a given NodeIndex. - * - * @param[in] nIdx: NodeIndex object, representing the node and its index in the adaptive tree. - * @returns The index value of the boxes in which the NodeIndex object is mapping to. - * - * @details During the multiresolution analysis the boxes will be divided into smaller boxes, which means that each individual box will be part of a specific node in the tree. - * Each node will get its own index value, but will still be part of one of the original boxes of the world. 
- * Specialized for D=1 below - */ template int BoundingBox::getBoxIndex(NodeIndex nIdx) const { if (this->isPeriodic()) { periodic::index_manipulation(nIdx, this->getPeriodic()); }; @@ -366,13 +227,6 @@ template int BoundingBox::getBoxIndex(NodeIndex nIdx) const { return bIdx; } -/** @brief Prints information about the BoundinBox object. - * - * @param[in] o: Output stream variable which will be used to print the information - * @returns The output stream variable. - * - * @details A function which prints information about the BoundingBox object. - */ template std::ostream &BoundingBox::print(std::ostream &o) const { int oldprec = Printer::setPrecision(5); o << std::fixed; @@ -401,15 +255,6 @@ template std::ostream &BoundingBox::print(std::ostream &o) const { return o; } -/** @brief Fetches a NodeIndex object from a given box index, specialiced for 1-D. - * - * @param[in] bIdx: Box index, the index of the box we want to fetch the cell index from. - * @returns The NodeIndex object of the index given as it is in the Multiresolutoin analysis. - * - * @details During the adaptive refinement, each original box will contain an increasing number of smaller cells, - * each of which will be part of a specific node in the tree. These cells are divided adaptivelly. This function returns the NodeIndex - * object of the cell at the lower back corner of the box object indexed by bIdx. - */ template <> NodeIndex<1> BoundingBox<1>::getNodeIndex(int bIdx) const { const NodeIndex<1> &cIdx = this->cornerIndex; int n = cIdx.getScale(); @@ -417,11 +262,6 @@ template <> NodeIndex<1> BoundingBox<1>::getNodeIndex(int bIdx) const { return NodeIndex<1>(n, {l}); } -/** @brief Fetches the index of a box from a given coordinate, specialized for 1D. - * - * @param[in] r: 1-dimensional array representaing a coordinate in the simulation box - * @returns The index value of the boxes in the position given as it is in the generated world. 
- */ template <> int BoundingBox<1>::getBoxIndex(Coord<1> r) const { if (this->isPeriodic()) { periodic::coord_manipulation<1>(r, this->getPeriodic()); } @@ -435,14 +275,6 @@ template <> int BoundingBox<1>::getBoxIndex(Coord<1> r) const { return static_cast(iint); } -/** @brief Fetches the index of a box from a given NodeIndex specialized for 1-D. - * - * @param[in] nIdx: NodeIndex object, representing the node and its index in the adaptive tree. - * @returns The index value of the boxes in which the NodeIndex object is mapping to. - * - * @details During the multiresolution analysis the boxes will be divided into smaller boxes, which means that each individual box will be part of a specific node in the tree. - * Each node will get its own index value, but will still be part of one of the original boxes of the world. - */ template <> int BoundingBox<1>::getBoxIndex(NodeIndex<1> nIdx) const { if (this->isPeriodic()) { periodic::index_manipulation<1>(nIdx, this->getPeriodic()); }; diff --git a/src/trees/BoundingBox.h b/src/trees/BoundingBox.h index 28ad9c052..84cfa83d8 100644 --- a/src/trees/BoundingBox.h +++ b/src/trees/BoundingBox.h @@ -27,6 +27,7 @@ #include #include +#include #include "NodeIndex.h" #include "utils/details.h" @@ -35,7 +36,9 @@ namespace mrcpp { -/** @class BoundingBox +/** + * @class BoundingBox + * @tparam D Spatial dimension (1, 2, or 3) * * @brief Class defining the computational domain * @@ -47,67 +50,273 @@ namespace mrcpp { * Box translations relative to the world origin _must_ be an integer * multiple of the given scale size \f$ 2^{-n} \f$. */ - template class BoundingBox { public: + /** + * @brief Constructor for BoundingBox object + * @param box [lower, upper] bound in all dimensions + * + * @details Creates a box with appropriate root scale and scaling + * factor to fit the given bounds, which applies to _all_ dimensions. + * Root scale is chosen such that the scaling factor becomes 1 < sfac < 2. 
+ * + * @note Limitations: Box must be _either_ [0,L] _or_ [-L,L], with L a positive integer. + * This is the most general constructor, which will create a world with no periodic boundary conditions. + */ explicit BoundingBox(std::array box); - explicit BoundingBox(int n = 0, const std::array &l = {}, const std::array &nb = {}, const std::array &sf = {}, bool pbc = false); - explicit BoundingBox(const NodeIndex &idx, const std::array &nb = {}, const std::array &sf = {}); + + /** + * @brief Constructor for BoundingBox object + * @param n Length scale, default 0 + * @param l Corner translation, default [0, 0, ...] + * @param nb Number of boxes, default [1, 1, ...] + * @param sf Scaling factor, default [1.0, 1.0, ...] + * @param pbc Periodic boundary conditions, default false + * + * @details Creates a box with given parameters. The parameter n defines the length scale, which, together with sf, determines the unit length of each side of the boxes by \f$ [2^{-n}]^D \f$. + * The parameter l defines the corner translation of the lower corner of the box relative to the world origin. + * The parameter nb defines the number of boxes in each dimension. + * The parameter sf defines the scaling factor, which determines the box translations around the origin, i.e. the amount of boxes around origin. + * The parameter pbc defines whether the world is periodic or not. In this constructor this value makes all dimensions periodic. + * This constructor is used for work in periodic systems. + */ + explicit BoundingBox(int n = 0, + const std::array &l = {}, + const std::array &nb = {}, + const std::array &sf = {}, + bool pbc = false); + + /** + * @brief Constructor for BoundingBox object + * @param idx index of the lower corner of the box + * @param nb Number of boxes, default [1, 1, ...] + * @param sf Scaling factor, default [1.0, 1.0, ...]
+ * + * @details Creates a box with given parameters. + * The parameter idx defines the index of the lower corner of the box relative to the world origin. + * The parameter nb defines the number of boxes in each dimension. + * The parameter sf defines the scaling factor, which determines the box translations around the origin, i.e. the amount of boxes around origin. + * This constructor creates a world with no periodic boundary conditions. + */ + explicit BoundingBox(const NodeIndex &idx, + const std::array &nb = {}, + const std::array &sf = {}); + + /** + * @brief Constructor for BoundingBox object + * @param sf Scaling factor, default [1.0, 1.0, ...] + * @param pbc Periodic boundary conditions, default true + * + * @details Creates a box with given parameters. + * The parameter sf defines the scaling factor, which determines the box translations around the origin, i.e. the amount of boxes around origin. + * The parameter pbc defines whether the world is periodic or not. In this constructor this value makes all dimensions periodic. + * This constructor is used for work in periodic systems. + */ explicit BoundingBox(const std::array &sf, bool pbc = true); + + /** + * @brief Constructor for BoundingBox object + * @param sf Scaling factor, default [1.0, 1.0, ...] + * @param pbc Periodic boundary conditions, one flag per dimension (no default) + * + * @details Creates a box with given parameters. + * The parameter sf defines the scaling factor, which determines the box translations around the origin, i.e. the amount of boxes around origin. + * The parameter pbc defines whether the world is periodic or not. In this constructor this value makes specific dimensions periodic. + * This is used for work in periodic systems. + */ BoundingBox(const std::array &sf, std::array pbc); + + /** + * @brief Copy constructor for BoundingBox object + * @param box Other BoundingBox object + * + * @details Creates a box identical to the input box parameter.
+ * This constructor uses all parameters from the other BoundingBox to create a new one. + */ BoundingBox(const BoundingBox &box); + + /** + * @brief Assignment operator overload for BoundingBox object + * @param box Other BoundingBox object + * @details Assigns all parameters of the other BoundingBox to this BoundingBox. + * @return Reference to this BoundingBox object + */ BoundingBox &operator=(const BoundingBox &box); + virtual ~BoundingBox() = default; + /** + * @brief Equality: same corner index and per-dimension box counts + * @param box Other BoundingBox object + * @return True if equal, false otherwise + */ inline bool operator==(const BoundingBox &box) const; + /** + * @brief Inequality: differs in corner index or in any per-dimension box count + * @param box Other BoundingBox object + * @return True if not equal, false otherwise + */ inline bool operator!=(const BoundingBox &box) const; + /** + * @brief Fetches a NodeIndex object from a given box index + * @param bIdx The index of the box we want to fetch the cell index from + * + * @details During the adaptive refinement, each original box will contain an increasing number of smaller cells, + * each of which will be part of a specific node in the tree. These cells are divided adaptively. This function returns the NodeIndex + * object of the cell at the lower back corner of the box object indexed by bIdx.
+ * + * @return The NodeIndex object of the index given as it is in the multiresolution analysis + * @note Specialized for D=1 below + */ NodeIndex getNodeIndex(int bIdx) const; + /** + * @brief Fetches the index of a box from a given coordinate + * @param r D-dimensional array representing a coordinate in the simulation box + * @return The index value of the boxes in the position given as it is in the generated world + * @note Specialized for D=1 below + */ int getBoxIndex(Coord r) const; + + /** + * @brief Fetches the index of a box from a given NodeIndex + * @param nIdx NodeIndex object, representing the node and its index in the adaptive tree + * + * @details During the multiresolution analysis the boxes will be divided into smaller boxes, which means that each individual box will be part of a specific node in the tree. + * Each node will get its own index value, but will still be part of one of the original boxes of the world. + * + * @return The index value of the boxes in which the NodeIndex object is mapping to + * @note Specialized for D=1 below + */ int getBoxIndex(NodeIndex nIdx) const; - int size() const { return this->totBoxes; } + int size() const { return this->totBoxes; } ///< @return Total number of boxes + /** + * @param d Dimension index + * @return Number of boxes along dimension @p d + */ int size(int d) const { return this->nBoxes[d]; } - int getScale() const { return this->cornerIndex.getScale(); } + int getScale() const { return this->cornerIndex.getScale(); } ///< @return Root scale \f$ n \f$ + + /** + * @param d Dimension index + * @return Scaling factor to scale this box by along dimension @p d + */ double getScalingFactor(int d) const { return this->scalingFactor[d]; } + /** + * @param d Dimension index + * @return Unit length along dimension @p d + */ double getUnitLength(int d) const { return this->unitLengths[d]; } + /** + * @param d Dimension index + * @return Box length along dimension @p d + */ double getBoxLength(int d) const {
return this->boxLengths[d]; } + /** + * @param d Dimension index + * @return Lower bound of this box coordinates along dimension @p d + */ double getLowerBound(int d) const { return this->lowerBounds[d]; } + /** + * @param d Dimension index + * @return Upper bound of this box coordinates along dimension @p d + */ double getUpperBound(int d) const { return this->upperBounds[d]; } - bool isPeriodic() const { return details::are_any(this->periodic, true); } - const std::array &getPeriodic() const { return this->periodic; } - const Coord &getUnitLengths() const { return this->unitLengths; } - const Coord &getBoxLengths() const { return this->boxLengths; } - const Coord &getLowerBounds() const { return this->lowerBounds; } - const Coord &getUpperBounds() const { return this->upperBounds; } - const NodeIndex &getCornerIndex() const { return this->cornerIndex; } - const std::array &getScalingFactors() const { return this->scalingFactor; } + + bool isPeriodic() const { return details::are_any(this->periodic, true); } ///< @return Is any dimension periodic? 
+ const std::array &getPeriodic() const { return this->periodic; } ///< @return Periodicity flags per dimension + + const Coord &getUnitLengths() const { return this->unitLengths; } ///< @return The unit lengths + const Coord &getBoxLengths() const { return this->boxLengths; } ///< @return The box lengths + const Coord &getLowerBounds() const { return this->lowerBounds; } ///< @return The lower bounds of the coordinates of this box + const Coord &getUpperBounds() const { return this->upperBounds; } ///< @return The upper bounds of the coordinates of this box + const NodeIndex &getCornerIndex() const { return this->cornerIndex; } ///< @return The corner index + const std::array &getScalingFactors() const { return this->scalingFactor; } ///< @return The scaling factors to scale this box by + + /** + * @brief Stream output operator + * @param o Output stream + * @param box BoundingBox object + * @return Reference to output stream + */ friend std::ostream &operator<<(std::ostream &o, const BoundingBox &box) { return box.print(o); } protected: // Fundamental parameters - NodeIndex cornerIndex; ///< Index defining the lower corner of the box - std::array nBoxes{}; ///< Number of boxes in each dim, last entry total - std::array scalingFactor{}; - std::array periodic{}; ///< Sets which dimension has Periodic boundary conditions. 
+ NodeIndex cornerIndex; ///< Index defining the lower corner of the box + std::array nBoxes{}; ///< Number of boxes in each dim, last entry total + std::array scalingFactor{}; ///< Scaling factors to scale this box by, per dimension + std::array periodic{}; ///< Sets which dimension has Periodic boundary conditions // Derived parameters - int totBoxes{1}; - Coord unitLengths; ///< 1/2^initialScale - Coord boxLengths; ///< Total length (unitLength times nBoxes) - Coord lowerBounds; ///< Box lower bound (not real) - Coord upperBounds; ///< Box upper bound (not real) + int totBoxes{1}; ///< Total number of boxes + Coord unitLengths; ///< 1/2^initialScale + Coord boxLengths; ///< Total length (unitLength times nBoxes) + Coord lowerBounds; ///< Box lower bound (not real) + Coord upperBounds; ///< Box upper bound (not real) + + /** + * @brief Sets the number of boxes in each dimension + * @param nb Number of boxes, default [1, 1, ...] + * + * @details For each of the D dimensions it sets the number of boxes in that dimension in the nBoxes array and the total amount of boxes in the world in the totBoxes variable. + * This just sets counters for the number of boxes in each dimension. + */ void setNBoxes(const std::array &nb = {}); + + /** + * @brief Computes and sets all derived parameters + * + * @details For all parameters that have been initialized in the constructor, + * this function will compute the necessary derived parameters in each dimension. + * The unit length is set to \a sfac \f$ \cdot 2^{-n} \f$ where \a sfac is the scaling factor (default 1.0) and n is the length scale. + * The unit length is the base unit which is used for the size and positioning of the boxes around origin. + * The boxLength is the total length of the box in each dimension, which is the unit length times the number of boxes in that dimension. + * The lowerBound is computed from the index of the lower corner of the box and the unit length.
+ * The upperBound is computed to be the lower corner plus the total length in that dimension. + */ void setDerivedParameters(); + + /** + * @brief Sets the scaling factors in each dimension + * @param sf Scaling factor, default [1.0, 1.0, ...] + * + * @details This checks that the sf variable has sane values before assigning it to the member variable scalingFactor. + */ void setScalingFactors(const std::array &sf); + + /** + * @brief Sets which dimensions are periodic + * @param periodic D-dimensional array holding boolean values for each dimension + * + * @details This fills in the periodic array with the values from the input array. + */ void setPeriodic(std::array periodic); + + /** + * @brief Sets whether all dimensions are periodic + * @param periodic Boolean which is used to set all dimensions to either periodic or not + * + * @details This fills in the periodic array with the value from the input. + */ void setPeriodic(bool periodic); + /** + * @brief Prints information about the BoundingBox object + * @param o Output stream variable which will be used to print the information + * + * @details A function which prints information about the BoundingBox object. + * + * @return The output stream variable + */ std::ostream &print(std::ostream &o) const; }; +// Inline comparisons + template bool BoundingBox::operator==(const BoundingBox &box) const { if (getCornerIndex() != box.getCornerIndex()) return false; for (int d = 0; d < D; d++) { @@ -124,4 +333,4 @@ template bool BoundingBox::operator!=(const BoundingBox &box) cons return false; } -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/trees/CornerOperatorTree.cpp b/src/trees/CornerOperatorTree.cpp index 6de235dd3..9d4fa817c 100644 --- a/src/trees/CornerOperatorTree.cpp +++ b/src/trees/CornerOperatorTree.cpp @@ -32,15 +32,6 @@ using namespace Eigen; namespace mrcpp { -/** @brief Calculates band widths of the non-standard form matrices.
- * - * @param[in] prec: Precision used for thresholding - * - * @details It is starting from \f$ l = 2^n \f$ and updating the band width value each time we encounter - * considerable value while keeping decreasing down to \f$ l = 0 \f$, that stands for the distance to the diagonal. - * This procedure is repeated for each matrix \f$ A, B \f$ and \f$ C \f$. - * - */ void CornerOperatorTree::calcBandWidth(double prec) { if (this->bandWidth == nullptr) clearBandWidth(); this->bandWidth = new BandWidth(getDepth()); @@ -71,15 +62,6 @@ void CornerOperatorTree::calcBandWidth(double prec) { println(100, "\nOperator BandWidth" << *this->bandWidth); } -/** @brief Checks if the distance to diagonal is lesser than the operator band width. - * - * @param[in] oTransl: distance to diagonal - * @param[in] o_depth: scaling order - * @param[in] idx: index corresponding to one of the matrices \f$ A, B, C \f$ or \f$ T \f$. - * - * @returns True if \b oTransl is outside of the corner band (close to diagonal) and False otherwise. - * - */ bool CornerOperatorTree::isOutsideBand(int oTransl, int o_depth, int idx) { return abs(oTransl) < this->bandWidth->getWidth(o_depth, idx); } diff --git a/src/trees/CornerOperatorTree.h b/src/trees/CornerOperatorTree.h index 0ac2ad5bd..545fd6813 100644 --- a/src/trees/CornerOperatorTree.h +++ b/src/trees/CornerOperatorTree.h @@ -35,17 +35,35 @@ namespace mrcpp { * * @details Tree structure of operators having corner matrices * \f$ A, B, C \f$ in the non-standard form. - * */ class CornerOperatorTree final : public OperatorTree { public: - using OperatorTree::OperatorTree; // Import the single valid constructor from OperatorTree + /// Inherit the valid constructor from OperatorTree.
+ using OperatorTree::OperatorTree; + CornerOperatorTree(const CornerOperatorTree &tree) = delete; CornerOperatorTree &operator=(const CornerOperatorTree &tree) = delete; ~CornerOperatorTree() override = default; + /** + * @brief Calculates band widths of the non-standard form matrices + * @param prec Precision used for thresholding + * + * @details It is starting from \f$ l = 2^n \f$ and updating the band width value each time we encounter + * considerable value while keeping decreasing down to \f$ l = 0 \f$, that stands for the distance to the diagonal. + * This procedure is repeated for each matrix \f$ A, B \f$ and \f$ C \f$. + */ void calcBandWidth(double prec = -1.0) override; + + /** + * @brief Checks if the distance to diagonal is lesser than the operator band width + * @param oTransl distance to diagonal + * @param o_depth scaling order + * @param idx index corresponding to one of the matrices \f$ A, B, C \f$ or \f$ T \f$ + * + * @returns True if @p oTransl is outside of the corner band (close to diagonal) and False otherwise. + */ bool isOutsideBand(int oTransl, int o_depth, int idx) override; }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/trees/FunctionNode.cpp b/src/trees/FunctionNode.cpp index ff23fb394..0233fd1fb 100644 --- a/src/trees/FunctionNode.cpp +++ b/src/trees/FunctionNode.cpp @@ -42,8 +42,6 @@ using namespace Eigen; namespace mrcpp { -/** Function evaluation. - * Evaluate all polynomials defined on the node. */ template T FunctionNode::evalf(Coord r) { if (not this->hasCoefs()) MSG_ERROR("Evaluating node without coefs"); @@ -87,11 +85,6 @@ template T FunctionNode::evalScaling(const Coord &r return two_n * result; } -/** Function integration. - * - * Wrapper for function integration, that requires different methods depending - * on scaling type. Integrates the function represented on the node on the - * full support of the node. 
*/ template T FunctionNode::integrate() const { if (not this->hasCoefs()) { return 0.0; } switch (this->getScalingType()) { @@ -106,15 +99,6 @@ template T FunctionNode::integrate() const { } } -/** Function integration, Legendre basis. - * - * Integrates the function represented on the node on the full support of the - * node. The Legendre basis is particularly easy to integrate, as the work is - * already done when calculating its coefficients. The coefficients of the - * node is defined as the projection integral - * s_i = int f(x)phi_i(x)dx - * and since the first Legendre function is the constant 1, the first - * coefficient is simply the integral of f(x). */ template T FunctionNode::integrateLegendre() const { double n = (D * this->getScale()) / 2.0; double two_n = std::pow(2.0, -n); @@ -234,10 +218,9 @@ template void FunctionNode::getValues(Matrix void FunctionNode::getAbsCoefs(T *absCoefs) { @@ -381,9 +364,6 @@ template void FunctionNode::dealloc() { } } -/** Update the coefficients of the node by a mw transform of the scaling - * coefficients of the children. Option to overwrite or add up existing - * coefficients. Specialized for D=3 below. */ template void FunctionNode::reCompress() { MWNode::reCompress(); } @@ -408,14 +388,6 @@ template <> void FunctionNode<3>::reCompress() { } } -/** Inner product of the functions represented by the scaling basis of the nodes. - * - * Integrates the product of the functions represented by the scaling basis on - * the node on the full support of the nodes. The scaling basis is fully - * orthonormal, and the inner product is simply the dot product of the - * coefficient vectors. Assumes the nodes have identical support. - * NB: will take conjugate of bra in case of complex values. 
- */ template double dot_scaling(const FunctionNode &bra, const FunctionNode &ket) { assert(bra.hasCoefs()); assert(ket.hasCoefs()); @@ -433,14 +405,6 @@ template double dot_scaling(const FunctionNode &bra, const Fu #endif } -/** Inner product of the functions represented by the scaling basis of the nodes. - * - * Integrates the product of the functions represented by the scaling basis on - * the node on the full support of the nodes. The scaling basis is fully - * orthonormal, and the inner product is simply the dot product of the - * coefficient vectors. Assumes the nodes have identical support. - * NB: will take conjugate of bra in case of complex values. - */ template ComplexDouble dot_scaling(const FunctionNode &bra, const FunctionNode &ket) { assert(bra.hasCoefs()); assert(ket.hasCoefs()); @@ -467,14 +431,6 @@ template ComplexDouble dot_scaling(const FunctionNode return result; } -/** Inner product of the functions represented by the scaling basis of the nodes. - * - * Integrates the product of the functions represented by the scaling basis on - * the node on the full support of the nodes. The scaling basis is fully - * orthonormal, and the inner product is simply the dot product of the - * coefficient vectors. Assumes the nodes have identical support. - * NB: will take conjugate of bra in case of complex values. - */ template ComplexDouble dot_scaling(const FunctionNode &bra, const FunctionNode &ket) { assert(bra.hasCoefs()); assert(ket.hasCoefs()); @@ -493,14 +449,6 @@ template ComplexDouble dot_scaling(const FunctionNode return result; } -/** Inner product of the functions represented by the scaling basis of the nodes. - * - * Integrates the product of the functions represented by the scaling basis on - * the node on the full support of the nodes. The scaling basis is fully - * orthonormal, and the inner product is simply the dot product of the - * coefficient vectors. Assumes the nodes have identical support. 
- * NB: will take conjugate of bra in case of complex values. - */ template ComplexDouble dot_scaling(const FunctionNode &bra, const FunctionNode &ket) { assert(bra.hasCoefs()); assert(ket.hasCoefs()); @@ -519,14 +467,6 @@ template ComplexDouble dot_scaling(const FunctionNode &bra, c return result; } -/** Inner product of the functions represented by the wavelet basis of the nodes. - * - * Integrates the product of the functions represented by the wavelet basis on - * the node on the full support of the nodes. The wavelet basis is fully - * orthonormal, and the inner product is simply the dot product of the - * coefficient vectors. Assumes the nodes have identical support. - * NB: will take conjugate of bra in case of complex values. - */ template double dot_wavelet(const FunctionNode &bra, const FunctionNode &ket) { if (bra.isGenNode() or ket.isGenNode()) return 0.0; @@ -547,14 +487,6 @@ template double dot_wavelet(const FunctionNode &bra, const Fu #endif } -/** Inner product of the functions represented by the wavelet basis of the nodes. - * - * Integrates the product of the functions represented by the wavelet basis on - * the node on the full support of the nodes. The wavelet basis is fully - * orthonormal, and the inner product is simply the dot product of the - * coefficient vectors. Assumes the nodes have identical support. - * NB: will take conjugate of bra in case of complex values. - */ template ComplexDouble dot_wavelet(const FunctionNode &bra, const FunctionNode &ket) { if (bra.isGenNode() or ket.isGenNode()) return 0.0; @@ -583,14 +515,6 @@ template ComplexDouble dot_wavelet(const FunctionNode return result; } -/** Inner product of the functions represented by the wavelet basis of the nodes. - * - * Integrates the product of the functions represented by the wavelet basis on - * the node on the full support of the nodes. The wavelet basis is fully - * orthonormal, and the inner product is simply the dot product of the - * coefficient vectors. 
Assumes the nodes have identical support. - * NB: will take conjugate of bra in case of complex values. - */ template ComplexDouble dot_wavelet(const FunctionNode &bra, const FunctionNode &ket) { if (bra.isGenNode() or ket.isGenNode()) return 0.0; @@ -611,14 +535,6 @@ template ComplexDouble dot_wavelet(const FunctionNode return result; } -/** Inner product of the functions represented by the wavelet basis of the nodes. - * - * Integrates the product of the functions represented by the wavelet basis on - * the node on the full support of the nodes. The wavelet basis is fully - * orthonormal, and the inner product is simply the dot product of the - * coefficient vectors. Assumes the nodes have identical support. - * NB: will take conjugate of bra in case of complex values. - */ template ComplexDouble dot_wavelet(const FunctionNode &bra, const FunctionNode &ket) { if (bra.isGenNode() or ket.isGenNode()) return 0.0; diff --git a/src/trees/FunctionNode.h b/src/trees/FunctionNode.h index d1bfaaa31..b3d02f213 100644 --- a/src/trees/FunctionNode.h +++ b/src/trees/FunctionNode.h @@ -32,63 +32,365 @@ namespace mrcpp { +/** + * @class FunctionNode + * @tparam D Spatial dimension (1, 2, or 3) + * @tparam T Coefficient type (e.g. double, ComplexDouble) + * + * @brief Node of a @ref FunctionTree that stores coefficients and implements + * function-specific operations + * + * @details A FunctionNode is a concrete @ref MWNode specialized for function + * representations. It holds scaling and wavelet coefficients, provides + * allocation and refinement helpers, and exposes utilities for evaluation, + * coefficient access and basic per-node operations such as integration and + * local dot products. + * + * @note FunctionNodes are managed by @ref FunctionTree and @ref NodeAllocator. + * Most users should not construct FunctionNode directly. 
+ */ template class FunctionNode final : public MWNode { public: - FunctionTree &getFuncTree() { return static_cast &>(*this->tree); } - FunctionNode &getFuncParent() { return static_cast &>(*this->parent); } + FunctionTree &getFuncTree() { return static_cast &>(*this->tree); } ///< @return A reference to the tree this node belongs to, cast to a non-const @ref FunctionTree + FunctionNode &getFuncParent() { return static_cast &>(*this->parent); } ///< @return A reference to the parent of this node, cast to a non-const @ref FunctionNode + + /** + * @param i The index of the child + * @return A reference to the child at the given index, cast to a non-const @ref FunctionNode + */ FunctionNode &getFuncChild(int i) { return static_cast &>(*this->children[i]); } - const FunctionTree &getFuncTree() const { return static_cast &>(*this->tree); } - const FunctionNode &getFuncParent() const { return static_cast &>(*this->parent); } + const FunctionTree &getFuncTree() const { return static_cast &>(*this->tree); } ///< @return A reference to the tree this node belongs to, cast to a const @ref FunctionTree + const FunctionNode &getFuncParent() const { return static_cast &>(*this->parent); } ///< @return A reference to the parent of this node, cast to a const @ref FunctionNode + + /** + * @param i The index of the child + * @return A reference to the child at the given index, cast to a const @ref FunctionNode + */ const FunctionNode &getFuncChild(int i) const { return static_cast &>(*this->children[i]); } + /** + * @brief Create child nodes for this node and abort if it is already a + * branch node (see @ref FlagBranchNode) + * @param coefs If `true`, allocate coefficient chunk for child nodes + * + * @details This routine allocates child nodes via the tree's @ref NodeAllocator. + * The tree's node counter is incremented by @ref MWTree::incrementNodeCount. 
+ * Finally, this node is marked as both a branch node (see @ref FlagBranchNode) + * and a non-end node (see @ref FlagEndNode). + */ void createChildren(bool coefs) override; + + /** + * @brief Generate child nodes with the @ref FlagGenNode bit flag set, and + * abort if this node is already a branch node (see @ref FlagBranchNode) + * + * @details This routine creates general or redundant child nodes for + * temporary use. As a result, the tree's node counter remains unchanged, + * and this node is marked only as a branch node (see @ref FlagBranchNode). + */ void genChildren() override; + + /** + * @brief Generate a parent for this node and abort if it already has one + * + * @details This routine allocates the parent node via the tree's @ref + * NodeAllocator and links this node into the parent children array. The + * tree's node counter is incremented by @ref MWTree::incrementNodeCount. + */ void genParent() override; + + /** + * @brief Recursive deallocation of children and all their descendants. + * + * @details This routine uses base class function @ref MWTree::deleteChildren + * for the deallocation. Finally, this node is marked as an end node (see + * @ref FlagEndNode). + */ void deleteChildren() override; + /** + * @brief Function integration + * @return The integral of type @p T + * + * @details Wrapper for function integration, that requires different + * methods depending on scaling type @ref FuncType. Integrates the function + * represented on the node on the full support of the node. This routine + * will return zero if the node does not have coefficients, and abort if the + * node has an invalid type of scaling basis (Legendre or Interpol; see + * MRCPP/constants.h).
+ */ T integrate() const; + /** + * @brief Set values from a vector to the node's coefficients, and update + * metadata of the node + * @param vec Column vector + * + * @details This routine calls @ref MWTree::setCoefBlock to set + * values from the vector, and updates metadata of the node by calling + * @ref MWTree::cvTransform, and @ref MWTree::mwTransform. The + * node is marked as having coefficients, and its square norm and component + * norms are also computed by @ref MWTree::calcNorms. + */ void setValues(const Eigen::Matrix &vec); + + /** + * @brief Extract the node's coefficients into a vector + * @param[out] vec Column vector resized to the number of coefficients of + * the node, see @ref MWTree::getNCoefs + */ void getValues(Eigen::Matrix &vec); + + /** + * @brief Get coefficients corresponding to absolute value of function + * @param[out] absCoefs Coefficients of type @p T + * + * @note The absolute value of function is computed using std::norm(). + */ void getAbsCoefs(T *absCoefs); friend class FunctionTree; friend class NodeAllocator; protected: + /** + * @brief FunctionNode constructor + * + * @note This routine uses @ref MWNode default constructor. + */ FunctionNode() : MWNode() {} - FunctionNode(MWTree *tree, int rIdx) + + /** + * @brief FunctionNode constructor + * @param[in] tree The MWTree the root node belongs to + * @param[in] rIdx The integer specifying the corresponding root node + * + * @details Constructor for root nodes. It actually calls @ref MWNode + * constructor MWNode(tree, rIdx). + */ + explicit FunctionNode(MWTree *tree, int rIdx) : MWNode(tree, rIdx) {} - FunctionNode(MWNode *parent, int cIdx) - : MWNode(parent, cIdx) {} + + /** + * @brief FunctionNode constructor + * @param[in] tree The MWTree the root node belongs to + * @param[in] idx The NodeIndex defining scale and translation of the node + * + * @details Constructor for an empty node, which calls @ref MWNode + * constructor MWNode(tree, idx).
+ */ FunctionNode(MWTree *tree, const NodeIndex &idx) : MWNode(tree, idx) {} + + /** + * @brief FunctionNode constructor + * @param[in] parent Parent node + * @param[in] cIdx Child index of the current node + * + * @details Constructor for leaf nodes. It invokes @ref MWNode constructor + * MWNode(parent, cIdx). + */ + FunctionNode(MWNode *parent, int cIdx) + : MWNode(parent, cIdx) {} + FunctionNode(const FunctionNode &node) = delete; FunctionNode &operator=(const FunctionNode &node) = delete; + + /// @brief Default destructor of FunctionNode ~FunctionNode() = default; + /** + * @brief Evaluate function at a point + * @param[in,out] r The point in space + * @return The evaluated result of type @p T + * + * @details Evaluate all polynomials defined on the node found by + * @ref MWTree::getChildIndex. Trigger an error if the node does not + * have coefficients. For periodic systems, the coordinate r will be mapped + * to the [-1, 1] periodic cell if it is outside the unit cell, see + * @ref periodic::coord_manipulation. + */ T evalf(Coord r); + + /** + * @brief Function evaluation + * @param[in] r Coordinate where the evaluation is performed at + * @return The evaluated result of type @p T + */ T evalScaling(const Coord &r) const; + /// @brief Deallocate the node from the tree void dealloc() override; + + /** + * @brief Update the coefficients of the node by an MW transform of the + * scaling coefficients of the children + * @note There is a specialization for @p D = 3, + * see @ref FunctionNode<3>::reCompress. + */ void reCompress() override; + /** + * @brief Function integration, Legendre basis + * @return The integral of type @p T + * + * @details Integrate the function represented on the node on the full + * support of the node. The Legendre basis is particularly easy to + * integrate, as the work is already done when calculating its + * coefficients. 
The coefficients of the node are defined as the projection + * integral \f$ s_i = \int f(x)\phi_i(x)\mathrm{d}x \f$ and since the first + * Legendre function is the constant 1, the first coefficient is simply the + * integral of \f$ f(x) \f$. + */ T integrateLegendre() const; + + /** + * @brief Function integration, Interpolating basis + * @return The integral of type @p T + * + * @details Integrate the function represented on the node on the full + * support of the node. A bit more involved than in the Legendre basis, as + * it requires some coupling of quadrature weights. + */ T integrateInterpolating() const; + + /** + * @brief Function integration, Interpolating basis + * @return The integral of type @p T + * + * @details Integrate the function represented on the node on the full + * support of the node. A bit more involved than in the Legendre basis, as + * it requires some coupling of quadrature weights. + */ + //FIXME This routine has exactly the same documentation comment as + // integrateInterpolating() in FunctionNode.cpp. T integrateValues() const; }; + +/** + * @brief Inner product of the functions represented by the scaling basis of + * the nodes + * @param[in] bra FunctionNode on bra + * @param[in] ket FunctionNode on ket + * @return The computed inner product + * + * @details Integrates the product of the functions represented by the scaling + * basis on the node on the full support of the nodes. The scaling basis is + * fully orthonormal, and the inner product is simply the dot product of the + * coefficient vectors. Assumes the nodes have identical support. + * @note Conjugate of bra will be taken in case of complex values.
+ */ template double dot_scaling(const FunctionNode &bra, const FunctionNode &ket); -template double dot_wavelet(const FunctionNode &bra, const FunctionNode &ket); +/** + * @brief Inner product of the functions represented by the scaling basis of + * the nodes + * @param[in] bra FunctionNode on bra + * @param[in] ket FunctionNode on ket + * @return The computed inner product + * + * @details Integrates the product of the functions represented by the scaling + * basis on the node on the full support of the nodes. The scaling basis is + * fully orthonormal, and the inner product is simply the dot product of the + * coefficient vectors. Assumes the nodes have identical support. + * @note Conjugate of bra will be taken in case of complex values. + */ template ComplexDouble dot_scaling(const FunctionNode &bra, const FunctionNode &ket); -template ComplexDouble dot_wavelet(const FunctionNode &bra, const FunctionNode &ket); +/** + * @brief Inner product of the functions represented by the scaling basis of + * the nodes + * @param[in] bra FunctionNode on bra + * @param[in] ket FunctionNode on ket + * @return The computed inner product + * + * @details Integrates the product of the functions represented by the scaling + * basis on the node on the full support of the nodes. The scaling basis is + * fully orthonormal, and the inner product is simply the dot product of the + * coefficient vectors. Assumes the nodes have identical support. + * @note Conjugate of bra will be taken in case of complex values. 
+ */ template ComplexDouble dot_scaling(const FunctionNode &bra, const FunctionNode &ket); -template ComplexDouble dot_wavelet(const FunctionNode &bra, const FunctionNode &ket); +/** + * @brief Inner product of the functions represented by the scaling basis of + * the nodes + * @param[in] bra FunctionNode on bra + * @param[in] ket FunctionNode on ket + * @return The computed inner product + * + * @details Integrates the product of the functions represented by the scaling + * basis on the node on the full support of the nodes. The scaling basis is + * fully orthonormal, and the inner product is simply the dot product of the + * coefficient vectors. Assumes the nodes have identical support. + * @note Conjugate of bra will be taken in case of complex values. + */ template ComplexDouble dot_scaling(const FunctionNode &bra, const FunctionNode &ket); + + +/** + * @brief Inner product of the functions represented by the wavelet basis of + * the nodes + * @param[in] bra FunctionNode on bra + * @param[in] ket FunctionNode on ket + * @return The computed inner product + * + * @details Integrates the product of the functions represented by the wavelet + * basis on the node on the full support of the nodes. The wavelet basis is + * fully orthonormal, and the inner product is simply the dot product of the + * coefficient vectors. Assumes the nodes have identical support. + * @note Conjugate of bra will be taken in case of complex values. + */ +template double dot_wavelet(const FunctionNode &bra, const FunctionNode &ket); + +/** + * @brief Inner product of the functions represented by the wavelet basis of + * the nodes + * @param[in] bra FunctionNode on bra + * @param[in] ket FunctionNode on ket + * @return The computed inner product + * + * @details Integrates the product of the functions represented by the wavelet + * basis on the node on the full support of the nodes. 
The wavelet basis is + * fully orthonormal, and the inner product is simply the dot product of the + * coefficient vectors. Assumes the nodes have identical support. + * @note Conjugate of bra will be taken in case of complex values. + */ +template ComplexDouble dot_wavelet(const FunctionNode &bra, const FunctionNode &ket); + +/** + * @brief Inner product of the functions represented by the wavelet basis of + * the nodes + * @param[in] bra FunctionNode on bra + * @param[in] ket FunctionNode on ket + * @return The computed inner product + * + * @details Integrates the product of the functions represented by the wavelet + * basis on the node on the full support of the nodes. The wavelet basis is + * fully orthonormal, and the inner product is simply the dot product of the + * coefficient vectors. Assumes the nodes have identical support. + * @note Conjugate of bra will be taken in case of complex values. + */ +template ComplexDouble dot_wavelet(const FunctionNode &bra, const FunctionNode &ket); + +/** + * @brief Inner product of the functions represented by the wavelet basis of + * the nodes + * @param[in] bra FunctionNode on bra + * @param[in] ket FunctionNode on ket + * @return The computed inner product + * + * @details Integrates the product of the functions represented by the wavelet + * basis on the node on the full support of the nodes. The wavelet basis is + * fully orthonormal, and the inner product is simply the dot product of the + * coefficient vectors. Assumes the nodes have identical support. + * @note Conjugate of bra will be taken in case of complex values. + */ template ComplexDouble dot_wavelet(const FunctionNode &bra, const FunctionNode &ket); +//FIXME There are template specializations in FunctionNode.cpp, do we +// need to document them as well? 
+ } // namespace mrcpp diff --git a/src/trees/FunctionTree.cpp b/src/trees/FunctionTree.cpp index 72ca68e39..39a73b506 100644 --- a/src/trees/FunctionTree.cpp +++ b/src/trees/FunctionTree.cpp @@ -42,15 +42,7 @@ using namespace Eigen; namespace mrcpp { -/** @returns New FunctionTree object - * - * @param[in] mra: Which MRA the function is defined - * @param[in] sh_mem: Pointer to MPI shared memory block - * - * @details Constructs an uninitialized tree, containing only empty root nodes. - * If a shared memory pointer is provided the tree will be allocated in this - * shared memory window, otherwise it will be local to each MPI process. - */ + template FunctionTree::FunctionTree(const MultiResolutionAnalysis &mra, SharedMemory *sh_mem, const std::string &name) : MWTree(mra, name) @@ -107,11 +99,7 @@ template FunctionTree::~FunctionTree() { if (this->getNNodes() > 0) this->deleteRootNodes(); } -/** @brief Read a previously stored tree assuming text/ASCII format, - * in a representation using MADNESS conventions for n, l and index order. - * @param[in] file: File name - * @note This tree must have the exact same MRA the one that was saved(?) - */ + template void FunctionTree::loadTreeTXT(const std::string &file) { std::ifstream in(file); int NDIM, k; @@ -285,10 +273,7 @@ template void FunctionTree::loadTreeTXT(const std::str this->calcSquareNorm(); } -/** @brief Write the tree to disk in text/ASCII format in a representation - * using MADNESS conventions for n, l and index order. 
- * @param[in] file: File name - */ + template void FunctionTree::saveTreeTXT(const std::string &fname) { int nRoots = this->getRootBox().size(); MWNode **roots = this->getRootBox().getNodes(); @@ -357,9 +342,8 @@ template void FunctionTree::saveTreeTXT(const std::str } out.close(); } -/** @brief Write the tree structure to disk, for later use - * @param[in] file: File name, will get ".tree" extension - */ + + template void FunctionTree::saveTree(const std::string &file) { Timer t1; @@ -383,10 +367,7 @@ template void FunctionTree::saveTree(const std::string print::time(10, "Time write", t1); } -/** @brief Read a previously stored tree structure from disk - * @param[in] file: File name, will get ".tree" extension - * @note This tree must have the exact same MRA the one that was saved - */ + template void FunctionTree::loadTree(const std::string &file) { Timer t1; @@ -438,7 +419,7 @@ template T FunctionTree::integrate() const { return jacobian * result; } -/** @returns Integral of a representable function over the grid given by the tree */ + template <> double FunctionTree<3, double>::integrateEndNodes(RepresentableFunction_M &f) { // traverse tree, and treat end nodes only std::vector *> stack; // node from this @@ -473,20 +454,7 @@ template <> double FunctionTree<3, double>::integrateEndNodes(RepresentableFunct return jacobian * result; } -/** @returns Function value in a point, out of bounds returns zero - * - * @param[in] r: Cartesian coordinate - * - * @note This will only evaluate the _scaling_ part of the - * leaf nodes in the tree, which means that the function - * values will not be fully accurate. - * This is done to allow a fast and const function evaluation - * that can be done in OMP parallel. 
If you want to include - * also the _final_ wavelet part you can call the corresponding - * evalf_precise function, _or_ you can manually extend - * the MW grid by one level before evaluating, using - * `mrcpp::refine_grid(tree, 1)` - */ + template T FunctionTree::evalf(const Coord &r) const { // Handle potential scaling const auto scaling_factor = this->getMRA().getWorldBox().getScalingFactors(); @@ -511,16 +479,7 @@ template T FunctionTree::evalf(const Coord &r) cons return coef * result; } -/** @returns Function value in a point, out of bounds returns zero - * - * @param[in] r: Cartesian coordinate - * - * @note This will evaluate the _true_ value (scaling + wavelet) of the - * leaf nodes in the tree. This requires an on-the-fly MW transform - * on the node which makes this function slow and non-const. If you - * need fast evaluation, use refine_grid(tree, 1) first, and then - * evalf. - */ + template T FunctionTree::evalf_precise(const Coord &r) { // Handle potential scaling const auto scaling_factor = this->getMRA().getWorldBox().getScalingFactors(); @@ -546,12 +505,7 @@ template T FunctionTree::evalf_precise(const Coord return coef * result; } -/** @brief In-place square of MW function representations, fixed grid - * - * @details The leaf node point values of the function will be in-place - * squared, no grid refinement. - * - */ + template void FunctionTree::square() { if (this->getNGenNodes() != 0) MSG_ABORT("GenNodes not cleared"); @@ -575,14 +529,7 @@ template void FunctionTree::square() { this->calcSquareNorm(); } -/** @brief In-place power of MW function representations, fixed grid - * - * @param[in] p: Numerical power - * - * @details The leaf node point values of the function will be in-place raised - * to the given power, no grid refinement. 
- * - */ + template void FunctionTree::power(double p) { if (this->getNGenNodes() != 0) MSG_ABORT("GenNodes not cleared"); @@ -606,14 +553,7 @@ template void FunctionTree::power(double p) { this->calcSquareNorm(); } -/** @brief In-place multiplication by a scalar, fixed grid - * - * @param[in] c: Scalar coefficient - * - * @details The leaf node point values of the function will be - * in-place multiplied by the given coefficient, no grid refinement. - * - */ + template void FunctionTree::rescale(T c) { if (this->getNGenNodes() != 0) MSG_ABORT("GenNodes not cleared"); #pragma omp parallel firstprivate(c) num_threads(mrcpp_get_num_threads()) @@ -633,7 +573,7 @@ template void FunctionTree::rescale(T c) { this->calcSquareNorm(); } -/** @brief In-place rescaling by a function norm \f$ ||f||^{-1} \f$, fixed grid */ + template void FunctionTree::normalize() { if (this->getNGenNodes() != 0) MSG_ABORT("GenNodes not cleared"); double sq_norm = this->getSquareNorm(); @@ -641,15 +581,7 @@ template void FunctionTree::normalize() { this->rescale(1.0 / std::sqrt(sq_norm)); } -/** @brief In-place addition with MW function representations, fixed grid - * - * @param[in] c: Numerical coefficient of input function - * @param[in] inp: Input function to add - * - * @details The input function will be added in-place on the current grid of - * the function, i.e. no further grid refinement. 
- * - */ + template void FunctionTree::add(T c, FunctionTree &inp) { if (this->getMRA() != inp.getMRA()) MSG_ABORT("Incompatible MRA"); if (this->getNGenNodes() != 0) MSG_ABORT("GenNodes not cleared"); @@ -670,15 +602,7 @@ template void FunctionTree::add(T c, FunctionTreecalcSquareNorm(); inp.deleteGenerated(); } -/** @brief In-place addition with MW function representations, fixed grid - * - * @param[in] c: Numerical coefficient of input function - * @param[in] inp: Input function to add - * - * @details The input function will be added to the union of the current grid of - * and input the function grid. - * - */ + template void FunctionTree::add_inplace(T c, FunctionTree &inp) { if (this->getMRA() != inp.getMRA()) MSG_ABORT("Incompatible MRA"); if (this->getNGenNodes() != 0) MSG_ABORT("GenNodes not cleared"); @@ -701,15 +625,7 @@ template void FunctionTree::add_inplace(T c, FunctionT inp.deleteGenerated(); } -/** @brief In-place addition of absolute values of MW function representations - * - * @param[in] c Numerical coefficient of input function - * @param[in] inp Input function to add - * - * The absolute value of input function will be added in-place on the current grid of the output - * function, i.e. no further grid refinement. 
- * - */ + template void FunctionTree::absadd(T c, FunctionTree &inp) { if (this->getNGenNodes() != 0) MSG_ABORT("GenNodes not cleared"); #pragma omp parallel firstprivate(c) shared(inp) num_threads(mrcpp_get_num_threads()) @@ -736,15 +652,7 @@ template void FunctionTree::absadd(T c, FunctionTree void FunctionTree::multiply(T c, FunctionTree &inp) { if (this->getMRA() != inp.getMRA()) MSG_ABORT("Incompatible MRA"); if (this->getNGenNodes() != 0) MSG_ABORT("GenNodes not cleared"); @@ -772,14 +680,7 @@ template void FunctionTree::multiply(T c, FunctionTree inp.deleteGenerated(); } -/** @brief In-place mapping with a predefined function f(x), fixed grid - * - * @param[in] fmap: mapping function - * - * @details The input function will be mapped in-place on the current grid - * of the function, i.e. no further grid refinement. - * - */ + template void FunctionTree::map(FMap fmap) { if (this->getNGenNodes() != 0) MSG_ABORT("GenNodes not cleared"); { @@ -839,20 +740,7 @@ template std::ostream &FunctionTree::print(std::ostrea return MWTree::print(o); } -/** @brief Reduce the precision of the tree by deleting nodes - * - * @param prec: New precision criterion - * @param splitFac: Splitting factor: 1, 2 or 3 - * @param absPrec: Use absolute precision - * - * @details This will run the tree building algorithm in "reverse", starting - * from the leaf nodes, and perform split checks on each node based on the given - * precision and the local wavelet norm. - * - * @note The splitting factor appears in the threshold for the wavelet norm as - * \f$ ||w|| < 2^{-sn/2} ||f|| \epsilon \f$. In principal, `s` should be equal - * to the dimension; in practice, it is set to `s=1`. 
- */ + template int FunctionTree::crop(double prec, double splitFac, bool absPrec) { for (int i = 0; i < this->rootBox.size(); i++) { MWNode &root = this->getRootMWNode(i); @@ -864,10 +752,7 @@ template int FunctionTree::crop(double prec, double sp return nChunks; } -/** Traverse tree using BFS and returns an array with the address of the coefs. - * Also returns an array with the corresponding indices defined as the - * values of serialIx in refTree, and an array with the indices of the parent. - * Set index -1 for nodes that are not present in refTree */ + template void FunctionTree::makeCoeffVector(std::vector &coefs, std::vector &indices, @@ -918,10 +803,7 @@ void FunctionTree::makeCoeffVector(std::vector &coefs, } } -/** Traverse tree using DFS and reconstruct it using node info from the - * reference tree and a list of coefficients. - * It is the reference tree (refTree) which is traversed, but one does not descend - * into children if the norm of the tree is smaller than absPrec. */ + template void FunctionTree::makeTreefromCoeff(MWTree &refTree, std::vector coefpVec, std::map &ix2coef, double absPrec, const std::string &mode) { std::vector *> stack; std::map *> ix2node; // gives the nodes in this tree for a given ix @@ -998,9 +880,7 @@ template void FunctionTree::makeTreefromCoeff(MWTree void FunctionTree::appendTreeNoCoeff(MWTree &inTree) { std::vector *> instack; // node from inTree std::vector *> thisstack; // node from this Tree @@ -1039,7 +919,7 @@ template void FunctionTree::appendTreeNoCoeff(MWTree void FunctionTree::appendTreeNoCoeff(MWTree &inTree) { std::vector *> instack; // node from inTree std::vector *> thisstack; // node from this Tree @@ -1131,17 +1011,13 @@ template <> int FunctionTree<3, ComplexDouble>::saveNodesAndRmCoeff() { return this->NodeIndex2serialIx.size(); } -/** @brief Deep copy of tree - * - * @details Exact copy without any binding between old and new tree - */ + template void FunctionTree::deep_copy(FunctionTree *out) { 
copy_grid(*out, *this); copy_func(*out, *this); } -/** @brief New tree with only real part - */ + template FunctionTree *FunctionTree::Real() { FunctionTree *out = new FunctionTree(this->getMRA(), this->getName()); out->setZero(); @@ -1165,8 +1041,7 @@ template FunctionTree *FunctionTree::Real() return out; } -/** @brief New tree with only imaginary part - */ + template FunctionTree *FunctionTree::Imag() { FunctionTree *out = new FunctionTree(this->getMRA(), this->getName()); out->setZero(); @@ -1188,10 +1063,6 @@ template FunctionTree *FunctionTree::Imag() return out; } -/* - * From real to complex tree. Copy everything, and convert double to ComplexDouble for the coefficents. - * Should use a deep_copy if generalized in the future. - */ template <> void FunctionTree<3, double>::CopyTreeToComplex(FunctionTree<3, ComplexDouble> *&outTree) { delete outTree; diff --git a/src/trees/FunctionTree.h b/src/trees/FunctionTree.h index 9d976d6be..792313bf8 100644 --- a/src/trees/FunctionTree.h +++ b/src/trees/FunctionTree.h @@ -33,77 +33,322 @@ namespace mrcpp { -/** @class FunctionTree - * - * @brief Function representation in MW basis +/** + * @class FunctionTree + * @tparam D Spatial dimension (1, 2, or 3). + * @tparam T Coefficient type (e.g. double, ComplexDouble). + * @brief Function representation in the MW basis with adaptive topology. * * @details - * Constructing a full grown FunctionTree involves a number of steps, - * including setting up a memory allocator, constructing root nodes according - * to the given MRA, building an adaptive tree structure and computing MW - * coefficients. The FunctionTree constructor does only half of these steps: - * It takes an MRA argument, which defines the computational domain and scaling - * basis (these are fixed parameters that cannot be changed after construction). 
- * The tree is initialized with a memory allocator and a set of root nodes, but - * it does not compute any coefficients and the function is initially - * *undefined*. An undefined FunctionTree will have a well defined tree - * structure (at the very least the root nodes of the given MRA, but possibly - * with additional refinement) and its MW coefficient will be allocated but - * uninitialized, and its square norm will be negative (minus one). + * The class derives from MWTree (topology and node management) and + * RepresentableFunction (evaluation interface). Typical workflows build + * or refine the tree via calculators/adaptors and then apply algebraic + * transforms in place. */ - -template class FunctionTree final : public MWTree, public RepresentableFunction { +template +class FunctionTree final : public MWTree, public RepresentableFunction { public: + /** + * @brief Construct a tree bound to an MRA with a user label. + * @param mra Multi-resolution analysis (domain and basis function used). + * @param name Optional textual name of the function. + * + * @note This will only create the object. To compute coefficients, + * use a tree builder or calculator afterwards (e.g., projectors). + */ FunctionTree(const MultiResolutionAnalysis &mra, const std::string &name) : FunctionTree(mra, nullptr, name) {} - FunctionTree(const MultiResolutionAnalysis &mra, SharedMemory *sh_mem = nullptr, const std::string &name = "nn"); + + /** + * @brief Construct a tree bound to an MRA with optional shared memory and name. + * @param[in] mra: Which MRA the function is defined + * @param[in] sh_mem: Pointer to MPI shared memory block + * @param name Optional textual name of the function + * + * @returns New FunctionTree object + * + * @details Constructs an uninitialized tree, containing only empty root nodes. + * If a shared memory pointer is provided the tree will be allocated in this + * shared memory window, otherwise it will be local to each MPI process. 
+ */ + FunctionTree(const MultiResolutionAnalysis &mra, + SharedMemory *sh_mem = nullptr, + const std::string &name = "nn"); + + FunctionTree(const FunctionTree &tree) = delete; FunctionTree &operator=(const FunctionTree &tree) = delete; + + /// FunctionTree destructor ~FunctionTree() override; + /** + * @brief Integrate the represented function over the MRA box + * @returns Integral of the function over the entire computational domain + */ T integrate() const; + + /** + * @brief Integrate a representable function using this tree's grid + * @param[in] f RepresentableFunction used as integrand partner + * @returns Integral of the representable function + * + * @details You can evaluate the integral of any representable function + * over the most refined scale of the 'this' FunctionTree's grid. + */ double integrateEndNodes(RepresentableFunction_M &f); - T evalf_precise(const Coord &r); + + /** + * @brief Evaluate with high accuracy at a given coordinate. + * @param r Physical coordinate. + * @return Function value. + * + * @details May be more expensive than evalf due to stricter handling. + */ + + + /** + * @brief Fast but approximate evaluation of this FunctionTree at a given coordinate. + * @param[in] r: Cartesian coordinate to be evaluated + * @return Approximate Function value + * + * @note This will only evaluate the _scaling_ part of the + * leaf nodes in the tree, which means that the function + * values will not be fully accurate. + * This is done to allow a fast and const function evaluation + * that can be done in OMP parallel.
If you want to include + * also the _final_ wavelet part you can call the corresponding + * evalf_precise function, _or_ you can manually extend + * the MW grid by one level before evaluating, using + * `mrcpp::refine_grid(tree, 1)` + */ T evalf(const Coord &r) const override; + /** + * @brief Slow but high-accuracy evaluation of the function at a given coordinate + * @param[in] r: Cartesian coordinate to be evaluated + * @returns Exact value of this FunctionTree in the point r + * @note This will evaluate the _true_ value (scaling + wavelet) of the + * leaf nodes in the tree. This requires an on-the-fly MW transform + * on the node which makes this function slow and non-const. If you + * need fast evaluation, use refine_grid(tree, 1) first, and then + * evalf. + */ + T evalf_precise(const Coord &r); + + /** + * @brief Number of generated (non-root) nodes currently alive + * @return Count of nodes (managed by the generated-node allocator) + */ int getNGenNodes() const { return getGenNodeAllocator().getNNodes(); } + /** + * @brief Collect values on end nodes into a dense Eigen type vector + * @param[out] data Column vector sized to the total number of end-node values + */ void getEndValues(Eigen::Matrix &data); + + /** + * @brief Set end-node values as the components of the dense Eigen type vector + * @param[in] data Column vector holding values; its size must match + */ void setEndValues(Eigen::Matrix &data); - void saveTree(const std::string &file); + /** + * @brief Write the tree to disk in text/ASCII format in a representation + * using MADNESS conventions for n, l and index order. + * @param file Output filename.
+ */ void saveTreeTXT(const std::string &file); + + /** + * @brief Write the tree structure to disk, for later use + * @param[in] file: File name, will get ".tree" extension + */ + void saveTree(const std::string &file); + + /** + * @brief Read a previously stored tree structure from disk + * @param[in] file File name, will get ".tree" extension + * @note This tree must have the exact same MRA as the one that was saved + */ void loadTree(const std::string &file); + + /** + * @brief Read a previously stored tree assuming text/ASCII format, using MADNESS conventions (n, l and index order) + * @param[in] file Input filename + * @note Make sure that the MRA of this tree matches the one used to create the file + */ void loadTreeTXT(const std::string &file); - // In place operations - void square(); + + /** @brief In-place square of MW function representations, fixed grid + * + * @details The leaf node point values of the function will be in-place + * squared, no grid refinement. + * + */ + void square(); + /// Raise the function to power p pointwise. + /** + * @brief In-place power of MW function representations, fixed grid + * + * @param p Exponent + * + * @details The leaf node point values of the function will be in-place raised + * to the given power, no grid refinement. + */ void power(double p); - void rescale(T c); - void normalize(); + /** + * @brief In-place multiplication by a scalar, fixed grid + * + * @param c Scaling factor (with the same data type as the coefficients) + * + * @details The leaf node point values of the function will be + * in-place multiplied by the given coefficient, no grid refinement.
+ */ + void rescale(T c); + void normalize(); ///< In-place rescaling by a function norm \f$ ||f||^{-1} \f$, fixed grid + + /** + * @brief this + inp (fixed grid) + * + * @param c: Numerical coefficient of input function + * @param[in] inp: Input function to be added on this FunctionTree + * + * @details The input function will be added in-place on the current grid of + * the function, i.e. no further grid refinement. Addition done within the MW representations. + */ void add(T c, FunctionTree &inp); + + /** + * @brief this + inp (uniting the two grids) + + * @param c: Numerical coefficient of input function + * @param[in] inp: Input function to be added on this FunctionTree + * + * @details The input function will be added to the union of the current grid of + * this function and the grid of the input function. Addition done within the MW representations. + */ void add_inplace(T c, FunctionTree &inp); + + /** + * @brief this + abs(inp) (fixed grid) + + * @param c: Numerical coefficient of input function + * @param[in] inp: Input function to be added on this FunctionTree + * + * @details The absolute value of input function will be added in-place on the current grid of the output + * function, i.e. no further grid refinement. Addition done within the MW representations. + */ void absadd(T c, FunctionTree &inp); + + /** + * @brief this * (c * inp), fixed grid + * @param c: Numerical coefficient of input function + * @param[in] inp: Input function to be multiplied with this FunctionTree + * + * @details The input function will be multiplied in-place on the current grid + * of the function, i.e. no further grid refinement. + */ void multiply(T c, FunctionTree &inp); + + /** + * @brief In-place mapping with a predefined function f(x), fixed grid + * + * @param[in] fmap: mapping function + * + * @details The input function will be mapped in-place on the current grid + * of the function, i.e. no further grid refinement.
+ */ void map(FMap fmap); + + + + /** + * @brief Number of memory chunks reserved for nodes. + * @return Total chunk count. + */ int getNChunks() { return this->getNodeAllocator().getNChunks(); } + + /** + * @brief Number of memory chunks currently in use. + * @return Used chunk count. + */ int getNChunksUsed() { return this->getNodeAllocator().getNChunksUsed(); } + /** @brief Reduce the precision of the tree by deleting nodes + * + * @param prec: New precision criterion + * @param splitFac: Splitting factor: 1, 2 or 3 + * @param absPrec: Use absolute precision + * + * @details This will run the tree building algorithm in "reverse", starting + * from the leaf nodes, and perform split checks on each node based on the given + * precision and the local wavelet norm. + * + * @note The splitting factor appears in the threshold for the wavelet norm as + * \f$ ||w|| < 2^{-sn/2} ||f|| \epsilon \f$. In principle, `s` should be equal + * to the dimension; in practice, it is set to `s=1`. + */ int crop(double prec, double splitFac = 1.0, bool absPrec = true); + /** @name Typed access to nodes */ + ///@{ + + /// @return i-th end node cast to FunctionNode FunctionNode &getEndFuncNode(int i) { return static_cast &>(this->getEndMWNode(i)); } + + /// @return i-th root node cast to FunctionNode FunctionNode &getRootFuncNode(int i) { return static_cast &>(this->rootBox.getNode(i)); } + /// @return Allocator for generated nodes NodeAllocator &getGenNodeAllocator() { return *this->genNodeAllocator_p; } + + /// @return Allocator for generated nodes const NodeAllocator &getGenNodeAllocator() const { return *this->genNodeAllocator_p; } + /// @return i-th end node cast to FunctionNode const FunctionNode &getEndFuncNode(int i) const { return static_cast &>(this->getEndMWNode(i)); } + + /// @return i-th root node cast to FunctionNode const FunctionNode &getRootFuncNode(int i) const { return static_cast &>(this->rootBox.getNode(i)); } + ///@} + + /** + * @brief Delete nodes that were
generated during the last build/refine step. + * @details Restores the tree to the pre-generation state without touching + * persisted nodes and data, lowering the resolution to the previous scale + */ void deleteGenerated(); + + /** + * @brief Delete generated nodes and their generated parents if they became empty. + */ void deleteGeneratedParents(); + /** + * @brief Will fill the first 4 vectors with the coefficient pointers, indices, parent indices and scale factors + * + * @param[out] coefs Pointers to coefficient blocks per node. + * @param[out] indices Node indices mapped to a compact integer id. + * @param[out] parent_indices Parent ids matching indices. + * @param[out] scalefac Per-node scale factors (e.g. for normalization). + * @param[out] max_index Maximum assigned compact id. + * @param[in] refTree Reference tree defining traversal order. + * @param[in] refNodes Optional explicit node list to follow. + * + * @details Traverse tree using BFS and returns an array with the address of the coefs. + * Also returns an array with the corresponding indices defined as the + * values of serialIx in refTree, and an array with the indices of the parent. + * Set index -1 for nodes that are not present in refTree + * Set parent_indices as -2 for nodes that are not present in refTree + * Set scalefac as 1.0 for nodes that are not present in refTree + * Intended for exporting the tree into custom linear algebra + * back-ends or checkpoint formats. + */ void makeCoeffVector(std::vector &coefs, std::vector &indices, std::vector &parent_indices, @@ -111,25 +356,123 @@ template class FunctionTree final : public MWTree, pub int &max_index, MWTree &refTree, std::vector *> *refNodes = nullptr); - void makeTreefromCoeff(MWTree &refTree, std::vector coefpVec, std::map &ix2coef, double absPrec, const std::string &mode = "adaptive"); + + /** + * @brief Reconstruct a tree topology from a coefficient vector. + * @param[out] refTree Reference topology to follow. 
+ * @param coefpVec Pointers to coefficient blocks. + * @param ix2coef Mapping from node compact id to coefpVec index. + * @param absPrec Threshold for adaptive creation. + * @param mode Creation mode: "adaptive" or fixed variants. + * + * @details Traverse tree using DFS and reconstruct it using node info from the + * reference tree and a list of coefficients. + * It is the reference tree (refTree) which is traversed, but one does not descend + * into children if the norm of the tree is smaller than absPrec. + */ + void makeTreefromCoeff(MWTree &refTree, + std::vector coefpVec, + std::map &ix2coef, + double absPrec, + const std::string &mode = "adaptive"); + + + /** + * @brief Append topology from another tree with real-type coefficients (no coefficients copied) + * @param[in] inTree Input tree. + * + * @note It will append only the nodes structure, without copying any coefficient data, + * therefore it won't matter the datatype of the input tree for the result. + */ void appendTreeNoCoeff(MWTree &inTree); + + /** + * @brief Append topology from another tree with real-type coefficients (no coefficients copied) + * @param[in] inTree Input tree. + * + * @note It will append only the nodes structure, without copying any coefficient data, + * therefore it won't matter the datatype of the input tree for the result. + */ void appendTreeNoCoeff(MWTree &inTree); + + /** + * @brief Copy topology AND coefficients from a real-valued tree + * @param[in] inTree Source tree. + * + * @note The copy process is a shallow copy for the coefficients, i.e. + * the new tree nodes will point to the same coefficient blocks as the input tree. + * Therefore, modifying the coefficients in one tree will affect the other. 
+ */ void CopyTree(FunctionTree &inTree); - // tools for use of local (nodes are stored in Bank) representation - int saveNodesAndRmCoeff(); // put all nodes coefficients in Bank and delete all coefficients + + /** + * @brief Move all node coefficients to a bank and remove them from nodes + * @return Number of nodes affected. + */ + int saveNodesAndRmCoeff(); + + /** + * @brief Deep-copy entire tree into out (topology and data) + * @param[out] out Destination tree pointer (must be non-null and compatible) + * + * @details Exact copy without any binding between old and new tree + */ void deep_copy(FunctionTree *out); + + /** + * @brief Extract real part into a newly allocated real tree + * @return Pointer to a new FunctionTree of type double + */ FunctionTree *Real(); + + /** + * @brief Extract imaginary part into a newly allocated real tree + * @return Pointer to a new FunctionTree of type double + */ FunctionTree *Imag(); + + /** @name Real/complex conversion helpers */ + ///@{ + /** + * @brief Deep-copy this tree into a complex-valued tree. + * + * @param[out] out Destination tree pointer (must be non-null). + * @details Exact copy into a complex tree, with imaginary parts set to zero + */ void CopyTreeToComplex(FunctionTree<3, ComplexDouble> *&out); + /** + * @brief Deep-copy this tree into a complex-valued tree. + * + * @param[out] out Destination tree pointer (must be non-null). + * @details Exact copy into a complex tree, with imaginary parts set to zero + */ void CopyTreeToComplex(FunctionTree<2, ComplexDouble> *&out); + /** + * @brief Deep-copy this tree into a complex-valued tree. + * + * @param[out] out Destination tree pointer (must be non-null). + * @details Exact copy into a complex tree, with imaginary parts set to zero + */ void CopyTreeToComplex(FunctionTree<1, ComplexDouble> *&out); + + /** + * @brief Deep-copy this tree into a real-valued tree. + * + * @param[out] out Destination tree pointer (must be non-null). 
+ * @details Exact copy into a real tree, taking only the real parts + */ void CopyTreeToReal(FunctionTree<3, double> *&out); // for testing + ///@} protected: + /// Allocator for generated nodes. std::unique_ptr> genNodeAllocator_p{nullptr}; + + /// Print a short, human-readable description of the tree. std::ostream &print(std::ostream &o) const override; + /// Allocate and initialize root nodes according to the MRA. void allocRootNodes(); }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/trees/FunctionTreeVector.h b/src/trees/FunctionTreeVector.h index 142113e1f..675def8e6 100644 --- a/src/trees/FunctionTreeVector.h +++ b/src/trees/FunctionTreeVector.h @@ -32,14 +32,38 @@ namespace mrcpp { -template using CoefsFunctionTree = std::tuple *>; -template using FunctionTreeVector = std::vector>; + +/** + * @brief Alias for a weighted FunctionTree pointer. + * + * @tparam D Spatial dimension (1, 2, or 3) + * @tparam T Coefficient type (e.g. double, ComplexDouble) + * + * @details + * The tuple layout is: + * - element 0: numeric coefficient of type @p T, + * - element 1: pointer to a @c FunctionTree. + * + * Ownership of the pointer is not implied by the alias; see @ref clear(). + */ +template +using CoefsFunctionTree = std::tuple *>; + +/** + * @brief Alias for a vector of weighted FunctionTree pointers. + * + * @tparam D Spatial dimension (1, 2, or 3) + * @tparam T Coefficient type (e.g. 
double, ComplexDouble) + */ +template +using FunctionTreeVector = std::vector>; /** @brief Remove all entries in the vector * @param[in] fs: Vector to clear * @param[in] dealloc: Option to free FunctionTree pointer before clearing */ -template void clear(FunctionTreeVector &fs, bool dealloc = false) { +template +void clear(FunctionTreeVector &fs, bool dealloc = false) { if (dealloc) { for (auto &t : fs) { auto f = std::get<1>(t); @@ -50,10 +74,17 @@ template void clear(FunctionTreeVector &fs, bool deall fs.clear(); } -/** @returns Total number of nodes of all trees in the vector - * @param[in] fs: Vector to fetch from +/** + * @brief Compute the total number of nodes across all trees in the vector. + * + * @tparam D Spatial dimension (1, 2, or 3) + * @tparam T Coefficient type (e.g. double, ComplexDouble) + * @param[in] fs Vector to fetch from + * + * @returns Total number of nodes of all trees in the vector */ -template int get_n_nodes(const FunctionTreeVector &fs) { +template +int get_n_nodes(const FunctionTreeVector &fs) { int nNodes = 0; for (const auto &t : fs) { auto f = std::get<1>(t); @@ -62,10 +93,17 @@ template int get_n_nodes(const FunctionTreeVector &fs) return nNodes; } -/** @returns Total size of all trees in the vector, in kB - * @param[in] fs: Vector to fetch from +/** + * @brief Compute the total size of all trees in the vector (in kilobytes). + * + * @tparam D Spatial dimension (1, 2, or 3) + * @tparam T Coefficient type (e.g. double, ComplexDouble) + * @param[in] fs Vector to fetch from. 
+ * + * @returns Total size of all trees in the vector, in kB */ -template int get_size_nodes(const FunctionTreeVector &fs) { +template +int get_size_nodes(const FunctionTreeVector &fs) { int sNodes = 0; for (const auto &t : fs) { auto f = std::get<1>(t); @@ -74,27 +112,55 @@ template int get_size_nodes(const FunctionTreeVector & return sNodes; } -/** @returns Numerical coefficient at given position in vector - * @param[in] fs: Vector to fetch from - * @param[in] i: Position in vector +/** + * @brief Access the numeric coefficient at a given position. + * + * @tparam D Spatial dimension (1, 2, or 3) + * @tparam T Coefficient type (e.g. double, ComplexDouble) + * @param[in] fs Vector to fetch from + * @param[in] i Position in vector (zero-based) + * + * @returns Numerical coefficient at given position in vector + * + * @pre @p i must be a valid index in @p fs. */ -template T get_coef(const FunctionTreeVector &fs, int i) { +template +T get_coef(const FunctionTreeVector &fs, int i) { return std::get<0>(fs[i]); } -/** @returns FunctionTree at given position in vector - * @param[in] fs: Vector to fetch from - * @param[in] i: Position in vector +/** + * @brief Access the FunctionTree at a given position (non-const). + * + * @tparam D Spatial dimension (1, 2, or 3) + * @tparam T Coefficient type (e.g. double, ComplexDouble) + * @param[in] fs Vector to fetch from + * @param[in] i Position in vector (zero-based) + * + * @return FunctionTree at given position in vector + * + * @pre The pointer stored at position @p i must be non-null. */ -template FunctionTree &get_func(FunctionTreeVector &fs, int i) { +template +FunctionTree &get_func(FunctionTreeVector &fs, int i) { return *(std::get<1>(fs[i])); } -/** @returns FunctionTree at given position in vector - * @param[in] fs: Vector to fetch from - * @param[in] i: Position in vector +/** + * @brief Access the FunctionTree at a given position (const).
+ * + * @tparam D Spatial dimension (1, 2, or 3) + * @tparam T Coefficient type (e.g. double, ComplexDouble) + * @param[in] fs Vector to fetch from + * @param[in] i Position in vector (zero-based) + * + * @return FunctionTree at given position in vector + * + * @pre The pointer stored at position @p i must be non-null. */ -template const FunctionTree &get_func(const FunctionTreeVector &fs, int i) { +template +const FunctionTree &get_func(const FunctionTreeVector &fs, int i) { return *(std::get<1>(fs[i])); } -} // namespace mrcpp + +} // namespace mrcpp \ No newline at end of file diff --git a/src/trees/HilbertPath.h b/src/trees/HilbertPath.h index 519e7ac73..c6c005a14 100644 --- a/src/trees/HilbertPath.h +++ b/src/trees/HilbertPath.h @@ -27,31 +27,95 @@ namespace mrcpp { -template class HilbertPath final { +/** + * @class HilbertPath + * @tparam D Spatial dimension (e.g., 2 for quadtree, 3 for octree). + * + * @brief Traverse the nodes of a tree following the Hilbert space-filling curve. + * + * @details The Hilbert curve is a continuous fractal space-filling curve that + * has good locality properties. We use it to traverse the nodes of a tree. + * Each node visit in a Hilbert traversal has an associated **state** that + * determines how the children are ordered. Alternatively, a Z-ordering can be used. + * - @ref getZIndex maps a Hilbert child index to the corresponding Morton + * (Z-order) child index; + * - @ref getHIndex performs the inverse mapping (Morton to Hilbert); and + * - @ref getChildPath returns the orientation state to use after descending + * to a specific Hilbert child. + * + * The mappings are implemented via static lookup tables. + */ +template +class HilbertPath final { public: + /** + * @brief Default constructor + */ HilbertPath() = default; + /** + * @brief Copy constructor + */ HilbertPath(const HilbertPath &p) : path(p.path) {} + /** + * @brief Construct a child path from a parent path and a child index.
+ * + * @param[in] p Parent @ref HilbertPath state. + * @param[in] cIdx Child index expressed in **Morton (Z-order)** for this parent. + * + * @details + * The provided @p cIdx is first converted to the corresponding **Hilbert** + * index for the parent state, then the next orientation state is selected + * via the transition table. + */ HilbertPath(const HilbertPath &p, int cIdx) { int hIdx = p.getHIndex(cIdx); this->path = p.getChildPath(hIdx); } + /** + * @brief Assignment operator + */ HilbertPath &operator=(const HilbertPath &p) { this->path = p.path; return *this; } - - short int getPath() const { return this->path; } + short int getPath() const { return this->path; } ///< @return the current path */ + /** + * @brief Get path index of selected child + * + * @param hIdx Child index in **Hilbert** order for the current state. + * @return Path index for the selected child + */ short int getChildPath(int hIdx) const { return this->pTable[this->path][hIdx]; } - + /** + * @brief Map Hilbert child index to Morton (Z-order) child index + * + * @param hIdx Child index in **Hilbert** order + * @return **Morton** child index. + */ int getZIndex(int hIdx) const { return this->zTable[this->path][hIdx]; } + + /** + * @brief Map Morton (Z-order) child index to Hilbert child index. + * + * @param zIdx Child index in **Morton** order + * @return **Hilbert** child index + */ int getHIndex(int zIdx) const { return this->hTable[this->path][zIdx]; } private: + /// Current Hilbert orientation state (table row selector). short int path{0}; - static const short int pTable[][8]; - static const int zTable[][8]; - static const int hTable[][8]; + + /** + * @name Lookup tables (declared in header, defined in the .cpp) + * Each table has 2^D columns (up to 8 for D=3) and one row per state. 
+ * + */ + static const short int pTable[][8]; ///< Next-state table: state × h -> state' + static const int zTable[][8]; ///< Mapping: state × h -> z + static const int hTable[][8]; ///< Mapping: state × z -> h + /** @} */ }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/trees/MWNode.cpp b/src/trees/MWNode.cpp index 2d521b468..7c952d8a6 100644 --- a/src/trees/MWNode.cpp +++ b/src/trees/MWNode.cpp @@ -41,10 +41,6 @@ using namespace Eigen; namespace mrcpp { -/** @brief MWNode default constructor. - * - * @details Should be used only by NodeAllocator to obtain - * virtual table pointers for the derived classes. */ template MWNode::MWNode() : tree(nullptr) @@ -59,13 +55,6 @@ MWNode::MWNode() MRCPP_INIT_OMP_LOCK(); } -/** @brief MWNode constructor. - * - * @param[in] tree: the MWTree the root node belongs to - * @param[in] idx: the NodeIndex defining scale and translation of the node - * - * @details Constructor for an empty node, given the corresponding MWTree and NodeIndex - */ template MWNode::MWNode(MWTree *tree, const NodeIndex &idx) : tree(tree) @@ -79,14 +68,6 @@ MWNode::MWNode(MWTree *tree, const NodeIndex &idx) MRCPP_INIT_OMP_LOCK(); } -/** @brief MWNode constructor. - * - * @param[in] tree: the MWTree the root node belongs to - * @param[in] rIdx: the integer specifying the corresponding root node - * - * @details Constructor for root nodes. It requires the corresponding - * MWTree and an integer to fetch the right NodeIndex - */ template MWNode::MWNode(MWTree *tree, int rIdx) : tree(tree) @@ -100,14 +81,6 @@ MWNode::MWNode(MWTree *tree, int rIdx) MRCPP_INIT_OMP_LOCK(); } -/** @brief MWNode constructor. - * - * @param[in] parent: parent node - * @param[in] cIdx: child index of the current node - * - * @details Constructor for leaf nodes. It requires the corresponding - * parent and an integer to identify the correct child. 
- */ template MWNode::MWNode(MWNode *parent, int cIdx) : tree(parent->tree) @@ -121,15 +94,6 @@ MWNode::MWNode(MWNode *parent, int cIdx) MRCPP_INIT_OMP_LOCK(); } -/** @brief MWNode copy constructor. - * - * @param[in] node: the original node - * @param[in] allocCoef: if true MW coefficients are allocated and copied from the original node - * - * @details Creates loose nodes and optionally copy coefs. The node - * does not "belong" to the tree: it cannot be accessed by traversing - * the tree. - */ template MWNode::MWNode(const MWNode &node, bool allocCoef, bool SetCoef) : tree(node.tree) @@ -159,18 +123,14 @@ MWNode::MWNode(const MWNode &node, bool allocCoef, bool SetCoef) MRCPP_INIT_OMP_LOCK(); } -/** @brief MWNode destructor. - * - * @details Recursive deallocation of a node and all its decendants - */ template MWNode::~MWNode() { if (this->isLooseNode()) this->freeCoefs(); MRCPP_DESTROY_OMP_LOCK(); } -/** @brief Dummy deallocation of MWNode coefficients. +/* Dummy deallocation of MWNode coefficients. * - * @details This is just to make sure this method never really gets + * This is just to make sure this method never really gets * called (derived classes must implement their own version). This was * to avoid having pure virtual methods in the base class. */ @@ -178,12 +138,6 @@ template void MWNode::dealloc() { NOT_REACHED_ABORT; } -/** @brief Allocate the coefs vector. - * - * @details This is only used by loose nodes, because the loose nodes - * are not treated by the NodeAllocator class. - * - */ template void MWNode::allocCoefs(int n_blocks, int block_size) { if (this->n_coefs != 0) MSG_ABORT("n_coefs should be zero"); if (this->isAllocated()) MSG_ABORT("Coefs already allocated"); @@ -196,12 +150,6 @@ template void MWNode::allocCoefs(int n_blocks, int blo this->setIsAllocated(); } -/** @brief Deallocate the coefs vector. - * - * @details This is only used by loose nodes, because the loose nodes - * are not treated by the NodeAllocator class. 
- * - */ template void MWNode::freeCoefs() { if (not this->isLooseNode()) MSG_ABORT("Only loose nodes here!"); @@ -214,8 +162,6 @@ template void MWNode::freeCoefs() { this->clearIsAllocated(); } -/** @brief Printout of node coefficients - */ template void MWNode::printCoefs() const { if (not this->isAllocated()) MSG_ABORT("Node is not allocated"); println(0, "\nMW coefs"); @@ -226,8 +172,6 @@ template void MWNode::printCoefs() const { } } -/** @brief wraps the MW coefficients into an eigen vector object - */ template void MWNode::getCoefs(Eigen::Matrix &c) const { if (not this->isAllocated()) MSG_ABORT("Node is not allocated"); if (not this->hasCoefs()) MSG_ABORT("Node has no coefs"); @@ -236,9 +180,6 @@ template void MWNode::getCoefs(Eigen::Matrix::Map(this->coefs, this->n_coefs); } -/** @brief sets all MW coefficients and the norms to zero - * - */ template void MWNode::zeroCoefs() { if (not this->isAllocated()) MSG_ABORT("Coefs not allocated " << *this); @@ -247,68 +188,26 @@ template void MWNode::zeroCoefs() { this->setHasCoefs(); } -/** @brief Attach a set of coefs to this node. Only used locally (the tree is not aware of this). - */ template void MWNode::attachCoefs(T *coefs) { this->coefs = coefs; this->setHasCoefs(); } -/** @brief assigns values to a block of coefficients - * - * @param[in] c: the input coefficients - * @param[in] block: the block index - * @param[in] block_size: size of the block - * - * @details a block is typically containing one kind of coefficients - * (given scaling/wavelet in each direction). Its size is then \f$ - * (k+1)^D \f$ and the index is between 0 and \f$ 2^D-1 \f$. 
- */ template void MWNode::setCoefBlock(int block, int block_size, const T *c) { if (not this->isAllocated()) MSG_ABORT("Coefs not allocated"); for (int i = 0; i < block_size; i++) { this->coefs[block * block_size + i] = c[i]; } } -/** @brief adds values to a block of coefficients - * - * @param[in] c: the input coefficients - * @param[in] block: the block index - * @param[in] block_size: size of the block - * - * @details a block is typically containing one kind of coefficients - * (given scaling/wavelet in each direction). Its size is then \f$ - * (k+1)^D \f$ and the index is between 0 and \f$ 2^D-1 \f$. - */ template void MWNode::addCoefBlock(int block, int block_size, const T *c) { if (not this->isAllocated()) MSG_ABORT("Coefs not allocated"); for (int i = 0; i < block_size; i++) { this->coefs[block * block_size + i] += c[i]; } } -/** @brief sets values of a block of coefficients to zero - * - * @param[in] block: the block index - * @param[in] block_size: size of the block - * - * @details a block is typically containing one kind of coefficients - * (given scaling/wavelet in each direction). Its size is then \f$ - * (k+1)^D \f$ and the index is between 0 and \f$ 2^D-1 \f$. - */ template void MWNode::zeroCoefBlock(int block, int block_size) { if (not this->isAllocated()) MSG_ABORT("Coefs not allocated"); for (int i = 0; i < block_size; i++) { this->coefs[block * block_size + i] = 0.0; } } -/** @brief forward MW transform from this node to its children - * - * @param[in] overwrite: if true the coefficients of the children are - * overwritten. If false the values are summed to the already present - * ones. - * - * @details it performs forward MW transform inserting the result - * directly in the right place for each child node. The children must - * already be present and its memory allocated for this to work - * properly. 
- */ template void MWNode::giveChildrenCoefs(bool overwrite) { assert(this->isBranchNode()); if (not this->isAllocated()) MSG_ABORT("Not allocated!"); @@ -334,17 +233,6 @@ template void MWNode::giveChildrenCoefs(bool overwrite } } -/** @brief forward MW transform to compute scaling coefficients of a single child - * - * @param[in] cIdx: child index - * @param[in] overwrite: if true the coefficients of the children are - * overwritten. If false the values are summed to the already present - * ones. - * - * @details it performs forward MW transform in place on a loose - * node. The scaling coefficients of the selected child are then - * copied/summed in the correct child node. - */ template void MWNode::giveChildCoefs(int cIdx, bool overwrite) { MWNode node_i = *this; @@ -365,12 +253,6 @@ template void MWNode::giveChildCoefs(int cIdx, bool ov child.calcNorms(); } -/** Takes a MWParent and generates coefficients, reverse operation from - * giveChildrenCoefs */ -/** @brief backward MW transform to compute scaling/wavelet coefficients of a parent - * - * \warning This routine is only used in connection with Periodic Boundary Conditions - */ template void MWNode::giveParentCoefs(bool overwrite) { MWNode node = *this; MWNode &parent = getMWParent(); @@ -387,12 +269,6 @@ template void MWNode::giveParentCoefs(bool overwrite) parent.calcNorms(); } -/** @brief Copy scaling coefficients from children to parent - * - * @details Takes the scaling coefficients of the children and stores - * them consecutively in the corresponding block of the parent, - * following the usual bitwise notation. 
- */ template void MWNode::copyCoefsFromChildren() { int kp1_d = this->getKp1_d(); int nChildren = this->getTDim(); @@ -403,14 +279,6 @@ template void MWNode::copyCoefsFromChildren() { } } -/** @brief Generates scaling coefficients of children - * - * @details If the node is a leafNode, it takes the scaling&wavelet - * coefficients of the parent and it generates the scaling - * coefficients for the children and stores - * them consecutively in the corresponding block of the parent, - * following the usual bitwise notation. - */ template void MWNode::threadSafeGenChildren() { if (tree->isLocal) { NOT_IMPLEMENTED_ABORT; } MRCPP_SET_OMP_LOCK(); @@ -421,14 +289,6 @@ template void MWNode::threadSafeGenChildren() { MRCPP_UNSET_OMP_LOCK(); } -/** @brief Creates scaling coefficients of children - * - * @details If the node is a leafNode, it takes the scaling&wavelet - * coefficients of the parent and it generates the scaling - * coefficients for the children and stores - * them consecutively in the corresponding block of the parent, - * following the usual bitwise notation. The new node is permanently added to the tree. - */ template void MWNode::threadSafeCreateChildren() { if (tree->isLocal) { NOT_IMPLEMENTED_ABORT; } MRCPP_SET_OMP_LOCK(); @@ -439,16 +299,6 @@ template void MWNode::threadSafeCreateChildren() { MRCPP_UNSET_OMP_LOCK(); } -/** @brief Coefficient-Value transform - * - * @details This routine transforms the scaling coefficients of the node to the - * function values in the corresponding quadrature roots (of its children). - * - * @param[in] operation: forward (coef->value) or backward (value->coef). - * - * NOTE: this routine assumes a 0/1 (scaling on child 0 and 1) - * representation, instead of s/d (scaling and wavelet). 
- */ template void MWNode::cvTransform(int operation, bool firstchild) { int kp1 = this->getKp1(); int kp1_dm1 = math_utils::ipow(kp1, D - 1); @@ -539,25 +389,6 @@ void MWNode::cvTransform(int operation) { } */ -/** @brief Multiwavelet transform - * - * @details Application of the filters on one node to pass from a 0/1 (scaling - * on child 0 and 1) representation to an s/d (scaling and - * wavelet) representation. Bit manipulation is used in order to - * determine the correct filters and whether to apply them or just - * pass to the next couple of indexes. The starting coefficients are - * preserved until the application is terminated, then they are - * overwritten. With minor modifications this code can also be used - * for the inverse mw transform (just use the transpose filters) or - * for the application of an operator (using A, B, C and T parts of an - * operator instead of G1, G0, H1, H0). This is the version where the - * three directions are operated one after the other. Although this - * is formally faster than the other algorithm, the separation of the - * three dimensions prevent the possibility to use the norm of the - * operator in order to discard a priori negligible contributions. - * - * * @param[in] operation: compression (s0,s1->s,d) or reconstruction (s,d->s0,s1). - */ template void MWNode::mwTransform(int operation) { int kp1 = this->getKp1(); int kp1_dm1 = math_utils::ipow(kp1, D - 1); @@ -597,19 +428,16 @@ template void MWNode::mwTransform(int operation) { } } -/** @brief Set all norms to Undefined. */ template void MWNode::clearNorms() { this->squareNorm = -1.0; for (int i = 0; i < this->getTDim(); i++) { this->componentNorms[i] = -1.0; } } -/** @brief Set all norms to zero. */ template void MWNode::zeroNorms() { this->squareNorm = 0.0; for (int i = 0; i < this->getTDim(); i++) { this->componentNorms[i] = 0.0; } } -/** @brief Calculate and store square norm and component norms, if allocated. 
*/ template void MWNode::calcNorms() { this->squareNorm = 0.0; for (int i = 0; i < this->getTDim(); i++) { @@ -619,7 +447,6 @@ template void MWNode::calcNorms() { } } -/** @brief Calculate and return the squared scaling norm. */ template double MWNode::getScalingNorm() const { double sNorm = this->getComponentNorm(0); if (sNorm >= 0.0) { @@ -629,7 +456,6 @@ template double MWNode::getScalingNorm() const { } } -/** @brief Calculate and return the squared wavelet norm. */ template double MWNode::getWaveletNorm() const { double wNorm = 0.0; for (int i = 1; i < this->getTDim(); i++) { @@ -643,7 +469,6 @@ template double MWNode::getWaveletNorm() const { return wNorm; } -/** @brief Calculate the norm of one component (NOT the squared norm!). */ template double MWNode::calcComponentNorm(int i) const { if (this->isGenNode() and i != 0) return 0.0; assert(this->isAllocated()); @@ -658,9 +483,6 @@ template double MWNode::calcComponentNorm(int i) const return std::sqrt(sq_norm); } -/** @brief Update the coefficients of the node by a mw transform of the scaling - * coefficients of the children. - */ template void MWNode::reCompress() { if (this->isGenNode()) NOT_IMPLEMENTED_ABORT; if (this->isBranchNode()) { @@ -672,12 +494,6 @@ template void MWNode::reCompress() { } } -/** @brief Recurse down until an EndNode is found, and then crop children below the given precision threshold - * - * @param[in] prec: precision required - * @param[in] splitFac: factor used in the split check (larger factor means tighter threshold for finer nodes) - * @param[in] absPrec: flag to switch from relative (false) to absolute (true) precision. - */ template bool MWNode::crop(double prec, double splitFac, bool absPrec) { if (this->isEndNode()) { return true; @@ -707,11 +523,6 @@ template void MWNode::genParent() { NOT_REACHED_ABORT; } -/** @brief Recursive deallocation of children and all their descendants. - * - * @details - * Leaves node as LeafNode and children[] as null pointer. 
- */ template void MWNode::deleteChildren() { if (this->isLeafNode()) return; for (int cIdx = 0; cIdx < getTDim(); cIdx++) { @@ -726,7 +537,6 @@ template void MWNode::deleteChildren() { this->setIsLeafNode(); } -/** @brief Recursive deallocation of parent and all their forefathers. */ template void MWNode::deleteParent() { if (this->parent == nullptr) return; MWNode &parent = getMWParent(); @@ -736,7 +546,6 @@ template void MWNode::deleteParent() { this->parent = nullptr; } -/** @brief Deallocation of all generated nodes . */ template void MWNode::deleteGenerated() { if (this->isBranchNode()) { if (this->isEndNode()) { @@ -747,7 +556,6 @@ template void MWNode::deleteGenerated() { } } -/** @brief returns the coordinates of the centre of the node */ template Coord MWNode::getCenter() const { auto two_n = std::pow(2.0, -getScale()); auto scaling_factor = getMWTree().getMRA().getWorldBox().getScalingFactors(); @@ -757,7 +565,6 @@ template Coord MWNode::getCenter() const { return r; } -/** @brief returns the upper bounds of the D-interval defining the node */ template Coord MWNode::getUpperBounds() const { auto two_n = std::pow(2.0, -getScale()); auto scaling_factor = getMWTree().getMRA().getWorldBox().getScalingFactors(); @@ -767,7 +574,6 @@ template Coord MWNode::getUpperBounds() const { return ub; } -/** @brief returns the lower bounds of the D-interval defining the node */ template Coord MWNode::getLowerBounds() const { auto two_n = std::pow(2.0, -getScale()); auto scaling_factor = getMWTree().getMRA().getWorldBox().getScalingFactors(); @@ -777,14 +583,6 @@ template Coord MWNode::getLowerBounds() const { return lb; } -/** @brief Routine to find the path along the tree. - * - * @param[in] nIdx: the sought after node through its NodeIndex - * - * @details Given the translation indices at the final scale, computes the child m - * to be followed at the current scale in oder to get to the requested - * node at the final scale. The result is the index of the child needed. 
- * The index is obtained by bit manipulation of of the translation indices. */ template int MWNode::getChildIndex(const NodeIndex &nIdx) const { assert(isAncestor(nIdx)); int cIdx = 0; @@ -799,12 +597,6 @@ template int MWNode::getChildIndex(const NodeIndex return cIdx; } -/** @brief Routine to find the path along the tree. - * - * @param[in] r: the sought after node through the coordinates of a point in space - * - * @detailsGiven a point in space, determines which child should be followed - * to get to the corresponding terminal node. */ template int MWNode::getChildIndex(const Coord &r) const { assert(hasCoord(r)); int cIdx = 0; @@ -818,18 +610,6 @@ template int MWNode::getChildIndex(const Coord &r) return cIdx; } -/** @brief Returns the quadrature points in a given node - * - * @param[in,out] pts: quadrature points in a \f$ d \times (k+1) \f$ matrix form. - * - * @details The original quadrature points are fetched and then - * dilated and translated. For each cartesian direction \f$ \alpha = - * x,y,z... \f$ the set of quadrature points becomes \f$ x^\alpha_i = - * 2^{-n} (x_i + l^\alpha \f$. By taking all possible - * \f$(k+1)^d\f$ combinations, they will then define a d-dimensional - * grid of quadrature points. - * - */ template void MWNode::getPrimitiveQuadPts(MatrixXd &pts) const { int kp1 = this->getKp1(); pts = MatrixXd::Zero(D, kp1); @@ -842,19 +622,6 @@ template void MWNode::getPrimitiveQuadPts(MatrixXd &pt for (int d = 0; d < D; d++) pts.row(d) = sFac * (roots.array() + static_cast(l[d])); } -/** @brief Returns the quadrature points in a given node - * - * @param[in,out] pts: quadrature points in a \f$ d \times (k+1) \f$ matrix form. - * - * @details The original quadrature points are fetched and then - * dilated and translated to match the quadrature points in the - * children of the given node. For each cartesian direction \f$ \alpha = x,y,z... 
\f$ - * the set of quadrature points becomes \f$ x^\alpha_i = 2^{-n-1} (x_i + 2 l^\alpha + t^\alpha) \f$, where \f$ t^\alpha = - * 0,1 \f$. By taking all possible \f$(k+1)^d\combinations \f$, they will - * then define a d-dimensional grid of quadrature points for the child - * nodes. - * - */ template void MWNode::getPrimitiveChildPts(MatrixXd &pts) const { int kp1 = this->getKp1(); pts = MatrixXd::Zero(D, 2 * kp1); @@ -870,16 +637,6 @@ template void MWNode::getPrimitiveChildPts(MatrixXd &p } } -/** @brief Returns the quadrature points in a given node - * - * @param[in,out] pts: expanded quadrature points in a \f$ d \times - * (k+1)^d \f$ matrix form. - * - * @details The primitive quadrature points are used to obtain a - * tensor-product representation collecting all \f$ (k+1)^d \f$ - * vectors of quadrature points. - * - */ template void MWNode::getExpandedQuadPts(Eigen::MatrixXd &pts) const { MatrixXd prim_pts; getPrimitiveQuadPts(prim_pts); @@ -894,16 +651,6 @@ template void MWNode::getExpandedQuadPts(Eigen::Matrix if (D >= 4) NOT_IMPLEMENTED_ABORT; } -/** @brief Returns the quadrature points in a given node - * - * @param[in,out] pts: expanded quadrature points in a \f$ d \times - * 2^d(k+1)^d \f$ matrix form. - * - * @details The primitive quadrature points of the children are used to obtain a - * tensor-product representation collecting all \f$ 2^d (k+1)^d \f$ - * vectors of quadrature points. - * - */ template void MWNode::getExpandedChildPts(MatrixXd &pts) const { MatrixXd prim_pts; getPrimitiveChildPts(prim_pts); @@ -928,16 +675,6 @@ template void MWNode::getExpandedChildPts(MatrixXd &pt } } -/** @brief Const version of node retriever that NEVER generates. - * - * @param[in] idx: the requested NodeIndex - * - * @details - * Recursive routine to find and return the node with a given NodeIndex. - * This routine returns the appropriate Node, or a NULL pointer if - * the node does not exist, or if it is a GenNode. 
Recursion starts at at this - * node and ASSUMES the requested node is in fact decending from this node. - */ template const MWNode *MWNode::retrieveNodeNoGen(const NodeIndex &idx) const { if (getScale() == idx.getScale()) { // we're done assert(getNodeIndex() == idx); @@ -952,16 +689,6 @@ template const MWNode *MWNode::retrieveNodeNoGen return this->children[cIdx]->retrieveNodeNoGen(idx); } -/** @brief Node retriever that NEVER generates. - * - * @param[in] idx: the requested NodeIndex - * - * @details - * Recursive routine to find and return the node with a given NodeIndex. - * This routine returns the appropriate Node, or a NULL pointer if - * the node does not exist, or if it is a GenNode. Recursion starts at at this - * node and ASSUMES the requested node is in fact decending from this node. - */ template MWNode *MWNode::retrieveNodeNoGen(const NodeIndex &idx) { if (getScale() == idx.getScale()) { // we're done assert(getNodeIndex() == idx); @@ -976,18 +703,6 @@ template MWNode *MWNode::retrieveNodeNoGen(const return this->children[cIdx]->retrieveNodeNoGen(idx); } -/** @brief Node retriever that returns requested Node or EndNode (const version). - * - * @param[in] r: the coordinates of a point in the node - * @param[in] depth: the depth which one needs to descend - * - * @details Recursive routine to find and return the node given the - * coordinates of a point in space. This routine returns the - * appropriate Node, or the EndNode on the path to the requested node, - * and will never create or return GenNodes. Recursion starts at at - * this node and ASSUMES the requested node is in fact decending from - * this node. 
- */ template const MWNode *MWNode::retrieveNodeOrEndNode(const Coord &r, int depth) const { if (getDepth() == depth or this->isEndNode()) { return this; } int cIdx = getChildIndex(r); @@ -995,18 +710,6 @@ template const MWNode *MWNode::retrieveNodeOrEnd return this->children[cIdx]->retrieveNodeOrEndNode(r, depth); } -/** @brief Node retriever that returns requested Node or EndNode. - * - * @param[in] r: the coordinates of a point in the node - * @param[in] depth: the depth which one needs to descend - * - * @details Recursive routine to find and return the node given the - * coordinates of a point in space. This routine returns the - * appropriate Node, or the EndNode on the path to the requested node, - * and will never create or return GenNodes. Recursion starts at at - * this node and ASSUMES the requested node is in fact decending from - * this node. - */ template MWNode *MWNode::retrieveNodeOrEndNode(const Coord &r, int depth) { if (getDepth() == depth or this->isEndNode()) { return this; } int cIdx = getChildIndex(r); @@ -1014,17 +717,6 @@ template MWNode *MWNode::retrieveNodeOrEndNode(c return this->children[cIdx]->retrieveNodeOrEndNode(r, depth); } -/** @brief Node retriever that returns requested Node or EndNode (const version). - * - * @param[in] idx: the NodeIndex of the requested node - * - * @details Recursive routine to find and return the node given the - * coordinates of a point in space. This routine returns the - * appropriate Node, or the EndNode on the path to the requested node, - * and will never create or return GenNodes. Recursion starts at at - * this node and ASSUMES the requested node is in fact decending from - * this node. 
- */ template const MWNode *MWNode::retrieveNodeOrEndNode(const NodeIndex &idx) const { if (getScale() == idx.getScale()) { // we're done assert(getNodeIndex() == idx); @@ -1039,18 +731,6 @@ template const MWNode *MWNode::retrieveNodeOrEnd return this->children[cIdx]->retrieveNodeOrEndNode(idx); } -/** @brief Node retriever that returns requested Node or EndNode. - * - * @param[in] idx: the NodeIndex of the requested node - * - * @details - * Recursive routine to find and return the node given the - * coordinates of a point in space. This routine returns the - * appropriate Node, or the EndNode on the path to the requested node, - * and will never create or return GenNodes. Recursion starts at at - * this node and ASSUMES the requested node is in fact decending from - * this node. - */ template MWNode *MWNode::retrieveNodeOrEndNode(const NodeIndex &idx) { if (getScale() == idx.getScale()) { // we're done assert(getNodeIndex() == idx); @@ -1065,17 +745,6 @@ template MWNode *MWNode::retrieveNodeOrEndNode(c return this->children[cIdx]->retrieveNodeOrEndNode(idx); } -/** @brief Node retriever that ALWAYS returns the requested node. - * - * @param[in] r: the coordinates of a point in the node - * @param[in] depth: the depth which one needs to descend - * - * @details - * Recursive routine to find and return the node with a given NodeIndex. - * This routine always returns the appropriate node, and will generate nodes - * that does not exist. Recursion starts at this node and ASSUMES the - * requested node is in fact decending from this node. - */ template MWNode *MWNode::retrieveNode(const Coord &r, int depth) { if (depth < 0) MSG_ABORT("Invalid argument"); @@ -1087,17 +756,6 @@ template MWNode *MWNode::retrieveNode(const Coor return this->children[cIdx]->retrieveNode(r, depth); } -/** @brief Node retriever that ALWAYS returns the requested node, possibly without coefs. 
- * - * @param[in] idx: the NodeIndex of the requested node - * - * @details - * Recursive routine to find and return the node with a given NodeIndex. This - * routine always returns the appropriate node, and will generate nodes that - * does not exist. Recursion starts at this node and ASSUMES the requested - * node is in fact descending from this node. - * If create = true, the nodes are permanently added to the tree. - */ template MWNode *MWNode::retrieveNode(const NodeIndex &idx, bool create) { if (getScale() == idx.getScale()) { // we're done if (tree->isLocal) { @@ -1123,18 +781,6 @@ template MWNode *MWNode::retrieveNode(const Node return this->children[cIdx]->retrieveNode(idx, create); } -/** Node retriever that ALWAYS returns the requested node. - * - * WARNING: This routine is NOT thread safe! Must be used within omp critical. - * - * @param[in] idx: the NodeIndex of the requested node - * - * @details - * Recursive routine to find and return the node with a given NodeIndex. This - * routine always returns the appropriate node, and will generate nodes that - * does not exist. Recursion starts at this node and ASSUMES the requested - * node is in fact related to this node. - */ template MWNode *MWNode::retrieveParent(const NodeIndex &idx) { if (getScale() < idx.getScale()) MSG_ABORT("Scale error") if (getScale() == idx.getScale()) return this; @@ -1145,15 +791,6 @@ template MWNode *MWNode::retrieveParent(const No return this->parent->retrieveParent(idx); } -/** @brief Gives the norm (absolute value) of the node at the given NodeIndex. - * - * @param[in] idx: the NodeIndex of the requested node - * - * @details - * Recursive routine to find the node with a given NodeIndex. When an EndNode is - * found, do not generate any new node, but rather give the value of the norm - * assuming the function is uniformly distributed within the node. 
- */ template double MWNode::getNodeNorm(const NodeIndex &idx) const { if (this->getScale() == idx.getScale()) { // we're done assert(getNodeIndex() == idx); @@ -1168,10 +805,6 @@ template double MWNode::getNodeNorm(const NodeIndex return this->children[cIdx]->getNodeNorm(idx); } -/** @brief Test if a given coordinate is within the boundaries of the node. - * - * @param[in] r: point coordinates - */ template bool MWNode::hasCoord(const Coord &r) const { double sFac = std::pow(2.0, -getScale()); const NodeIndex &l = getNodeIndex(); @@ -1203,11 +836,6 @@ template bool MWNode::isCompatible(const MWNode // return true; } -/** @brief Test if the node is decending from a given NodeIndex, that is, if they have - * overlapping support. - * - * @param[in] idx: the NodeIndex of the requested node - */ template bool MWNode::isAncestor(const NodeIndex &idx) const { int relScale = idx.getScale() - getScale(); if (relScale < 0) return false; @@ -1223,10 +851,6 @@ template bool MWNode::isDecendant(const NodeIndex & NOT_IMPLEMENTED_ABORT; } -/** @brief printout ofm the node content. - * - * @param[in] o: the output stream - */ template std::ostream &MWNode::print(std::ostream &o) const { std::string flags = " "; o << getNodeIndex(); @@ -1252,12 +876,6 @@ template std::ostream &MWNode::print(std::ostream &o) return o; } -/** @brief recursively set maxSquaredNorm and maxWSquareNorm of parent and descendants - * - * @details - * normalization is such that a constant function gives constant value, - * i.e. 
*not* same normalization as a squareNorm - */ template void MWNode::setMaxSquareNorm() { auto n = this->getScale(); this->maxWSquareNorm = calcScaledWSquareNorm(); @@ -1272,8 +890,7 @@ template void MWNode::setMaxSquareNorm() { } } } -/** @brief recursively reset maxSquaredNorm and maxWSquareNorm of parent and descendants to value -1 - */ + template void MWNode::resetMaxSquareNorm() { auto n = this->getScale(); this->maxSquareNorm = -1.0; diff --git a/src/trees/MWNode.h b/src/trees/MWNode.h index f86313846..8b032698e 100644 --- a/src/trees/MWNode.h +++ b/src/trees/MWNode.h @@ -37,125 +37,337 @@ namespace mrcpp { -/** @class MWNode +/** + * @class MWNode + * @tparam D Spatial dimension (1, 2, or 3) + * @tparam T Coefficient type (e.g. double, ComplexDouble) * * @brief Base class for Multiwavelet nodes * * @details A MWNode will contain the scaling and wavelet coefficients * to represent functions or operators within a Multiwavelet - * framework. The nodes are in multidimensional. The dimensionality is - * set thoucgh the template parameter D=1,2,3. In addition to the - * coefficients the node contains metadata such as the scale, the + * framework. The nodes are multidimensional. The dimensionality is + * set through the template parameter D=1,2,3. In addition to the + * coefficients, the node contains metadata such as the scale, the * translation index, the norm, pointers to parent node and child * nodes, pointer to the corresponding MWTree etc... See member and * data descriptions for details. * + * @note Nodes are created and managed by MWTree and specialized trees + * (e.g., FunctionTree). Most users should not instantiate nodes + * directly; instead, operate at the tree level. 
*/ -template class MWNode { +template +class MWNode { public: + /** + * @brief MWNode copy constructor + * @param[in] node The original node + * @param allocCoef If true, allocate MW coefficients and copy from the original node + * @param SetCoef If true and @p allocCoef is true, copy coefficients + * + * @details Creates loose nodes and optionally copy coefs. The node + * does not "belong" to the tree: It cannot be accessed by traversing + * the tree. + */ MWNode(const MWNode &node, bool allocCoef = true, bool SetCoef = true); - MWNode &operator=(const MWNode &node) = delete; - virtual ~MWNode(); - - int getKp1() const { return getMWTree().getKp1(); } - int getKp1_d() const { return getMWTree().getKp1_d(); } - int getOrder() const { return getMWTree().getOrder(); } - int getScalingType() const { return getMWTree().getMRA().getScalingBasis().getScalingType(); } - int getTDim() const { return (1 << D); } - int getDepth() const { return getNodeIndex().getScale() - getMWTree().getRootScale(); } - int getScale() const { return getNodeIndex().getScale(); } - int getNChildren() const { return (isBranchNode()) ? 
getTDim() : 0; } - int getSerialIx() const { return this->serialIx; } - void setSerialIx(int Ix) { this->serialIx = Ix; } - const NodeIndex &getNodeIndex() const { return this->nodeIndex; } - const HilbertPath &getHilbertPath() const { return this->hilbertPath; } + MWNode &operator=(const MWNode &node) = delete; - Coord getCenter() const; - Coord getUpperBounds() const; - Coord getLowerBounds() const; + /// @brief Recursive deallocation of a node and all its decendants + virtual ~MWNode(); + /* + * Getters and setters + */ + int getOrder() const { return getMWTree().getOrder(); } ///< @return Polynomial order k + int getKp1() const { return getMWTree().getKp1(); } ///< @return k+1 + int getKp1_d() const { return getMWTree().getKp1_d(); } ///< @return (k+1)^D + int getScalingType() const { return getMWTree().getMRA().getScalingBasis().getScalingType(); } ///< @return The type of scaling basis (Legendre or Interpol; see MRCPP/constants.h) + int getTDim() const { return (1 << D); } ///< @return 2^D (number of children per internal node) + int getDepth() const { return getNodeIndex().getScale() - getMWTree().getRootScale(); } ///< @return The depth of this node + int getScale() const { return getNodeIndex().getScale(); } ///< @return The scale of this node + int getNChildren() const { return (isBranchNode()) ? 
getTDim() : 0; } ///< @return The number of children of this node + int getSerialIx() const { return this->serialIx; } ///< @return The index in the serial tree + void setSerialIx(int Ix) { this->serialIx = Ix; } ///< @param Ix The index in the serial tree + + const NodeIndex &getNodeIndex() const { return this->nodeIndex; } ///< @return The index (scale and translation) for this node + const HilbertPath &getHilbertPath() const { return this->hilbertPath; } // TODO document this + + Coord getCenter() const; ///< @return The coordinates of the centre of the node + Coord getUpperBounds() const; ///< @return The upper bounds of the D-interval defining the node + Coord getLowerBounds() const; ///< @return The lower bounds of the D-interval defining the node + + /** + * @brief Test if a given coordinate is within the boundaries of the node + * @param[in] r Point coordinates + */ bool hasCoord(const Coord &r) const; + + /// @warning This method is currently not implemented. bool isCompatible(const MWNode &node); + + /** + * @brief Test if the node is decending from a given NodeIndex, that is, if they have + * overlapping support + * @param[in] idx the NodeIndex of the requested node + */ bool isAncestor(const NodeIndex &idx) const; + + /// @warning This method is currently not implemented. bool isDecendant(const NodeIndex &idx) const; - double getSquareNorm() const { return this->squareNorm; } - double getMaxSquareNorm() const { return (maxSquareNorm > 0.0) ? maxSquareNorm : calcScaledSquareNorm(); } - double getMaxWSquareNorm() const { return (maxWSquareNorm > 0.0) ? maxWSquareNorm : calcScaledWSquareNorm(); } + double getSquareNorm() const { return this->squareNorm; } ///< @return Squared norm of all 2^D (k+1)^D coefficients + double getMaxSquareNorm() const { return (maxSquareNorm > 0.0) ? maxSquareNorm : calcScaledSquareNorm(); } ///< @return Largest squared norm among itself and descendants. + double getMaxWSquareNorm() const { return (maxWSquareNorm > 0.0) ? 
maxWSquareNorm : calcScaledWSquareNorm(); } ///< @return Largest wavelet squared norm among itself and descendants. + /** + * @brief Calculate and return the squared scaling norm + * @return The squared scaling norm + */ double getScalingNorm() const; + /** + * @brief Calculate and return the squared wavelet norm + * @return The squared wavelet norm + */ virtual double getWaveletNorm() const; + /** + * @param i The component index + * @return The squared norm of the component at the given index + */ double getComponentNorm(int i) const { return this->componentNorms[i]; } - int getNCoefs() const { return this->n_coefs; } + int getNCoefs() const { return this->n_coefs; } ///< @return The number of coefficients + /** + * @brief Wraps the MW coefficients into an Eigen vector object + * @param[out] c The coefficient matrix + */ void getCoefs(Eigen::Matrix &c) const; - void printCoefs() const; - - T *getCoefs() { return this->coefs; } - const T *getCoefs() const { return this->coefs; } + void printCoefs() const; ///< @brief Printout of node coefficients + + T *getCoefs() { return this->coefs; } ///< @return The 2^D (k+1)^D MW coefficients + const T *getCoefs() const { return this->coefs; } ///< @return The 2^D (k+1)^D MW coefficients + + /** + * @brief Returns the quadrature points of this node + * + * @param[out] pts Quadrature points in a \f$ d \times (k+1) \f$ matrix form + * + * @details The original quadrature points are fetched and then + * dilated and translated. For each cartesian direction \f$ \alpha = + * x,y,z... \f$ the set of quadrature points becomes \f$ x^\alpha_i = + * 2^{-n} (x_i + l^\alpha) \f$. By taking all possible + * \f$(k+1)^d\f$ combinations, they will then define a d-dimensional + * grid of quadrature points.
+ */ void getPrimitiveQuadPts(Eigen::MatrixXd &pts) const; + + /** + * @brief Returns the quadrature points of the children of this node + * + * @param[out] pts Quadrature points in a \f$ d \times 2(k+1) \f$ matrix form + * + * @details The original quadrature points are fetched and then + * dilated and translated to match the quadrature points in the + * children of this node. For each cartesian direction \f$ \alpha = x,y,z... \f$ + * the set of quadrature points becomes \f$ x^\alpha_i = 2^{-n-1} (x_i + 2 l^\alpha + t^\alpha) \f$, where \f$ t^\alpha = + * 0,1 \f$. By taking all possible \f$(k+1)^d\f$ combinations, they will + * then define a d-dimensional grid of quadrature points for the child + * nodes. + */ void getPrimitiveChildPts(Eigen::MatrixXd &pts) const; + + /** + * @brief Returns the expanded quadrature points of this node + * + * @param[out] pts Expanded quadrature points in a \f$ d \times + * (k+1)^d \f$ matrix form + * + * @details The primitive quadrature points are used to obtain a + * tensor-product representation collecting all \f$ (k+1)^d \f$ + * vectors of quadrature points. + */ void getExpandedQuadPts(Eigen::MatrixXd &pts) const; + + /** + * @brief Returns the expanded quadrature points of the children of this node + * + * @param[out] pts Expanded quadrature points in a \f$ d \times + * 2^d(k+1)^d \f$ matrix form + * + * @details The primitive quadrature points of the children are used to obtain a + * tensor-product representation collecting all \f$ 2^d (k+1)^d \f$ + * vectors of quadrature points.
+ */ void getExpandedChildPts(Eigen::MatrixXd &pts) const; - MWTree &getMWTree() { return static_cast &>(*this->tree); } - MWNode &getMWParent() { return static_cast &>(*this->parent); } + MWTree &getMWTree() { return static_cast &>(*this->tree); } ///< @return The tree this node belongs to + MWNode &getMWParent() { return static_cast &>(*this->parent); } ///< @return The parent of this node + + /** + * @param i The index of the child + * @return The child at the given index + */ MWNode &getMWChild(int i) { return static_cast &>(*this->children[i]); } - const MWTree &getMWTree() const { return static_cast &>(*this->tree); } - const MWNode &getMWParent() const { return static_cast &>(*this->parent); } + const MWTree &getMWTree() const { return static_cast &>(*this->tree); } ///< @return The tree this node belongs to + const MWNode &getMWParent() const { return static_cast &>(*this->parent); } ///< @return The parent of this node + + /** + * @param i The index of the child + * @return The child at the given index + */ const MWNode &getMWChild(int i) const { return static_cast &>(*this->children[i]); } + /// @brief Sets all MW coefficients and the norms to zero void zeroCoefs(); + + /** + * @brief Assigns values to a block of coefficients + * @param block The block index + * @param block_size Size of the block + * @param[in] c The input coefficients + * + * @details A block is typically containing one kind of coefficients + * (given scaling/wavelet in each direction). Its size is then \f$ + * (k+1)^D \f$ and the index is between 0 and \f$ 2^D-1 \f$. + */ void setCoefBlock(int block, int block_size, const T *c); + + /** + * @brief Adds values to a block of coefficients + * @param block The block index + * @param block_size Size of the block + * @param[in] c The input coefficients + * + * @details A block is typically containing one kind of coefficients + * (given scaling/wavelet in each direction). 
Its size is then \f$ + * (k+1)^D \f$ and the index is between 0 and \f$ 2^D-1 \f$. + */ void addCoefBlock(int block, int block_size, const T *c); + + /** + * @brief Sets values of a block of coefficients to zero + * @param[in] block The block index + * @param[in] block_size Size of the block + * + * @details A block is typically containing one kind of coefficients + * (given scaling/wavelet in each direction). Its size is then \f$ + * (k+1)^D \f$ and the index is between 0 and \f$ 2^D-1 \f$. + */ void zeroCoefBlock(int block, int block_size); + + /** + * @brief Attach a set of coefficients to this node. Only used locally (the tree is not aware of this). + * @param[in] coefs The coefficients to attach + * + * @note The number of coefficients must remain the same. + */ void attachCoefs(T *coefs); - void calcNorms(); - void zeroNorms(); - void clearNorms(); + void calcNorms(); ///< @brief Calculate and store square norm and component norms, if allocated. + void zeroNorms(); ///< @brief Set all norms to zero. + void clearNorms(); ///< @brief Set all norms to Undefined. + /* + * Implemented in child classes + */ virtual void createChildren(bool coefs); virtual void genChildren(); virtual void genParent(); + + /** + * @brief Recursive deallocation of children and all their descendants + * + * @details Leaves node as LeafNode and children[] as null pointer + */ virtual void deleteChildren(); - virtual void deleteParent(); - virtual void cvTransform(int kind, bool firstchild = false); - virtual void mwTransform(int kind); + /// @brief Recursive deallocation of parent and all their forefathers. + virtual void deleteParent(); + /** + * @brief Coefficient-Value transform + * @param operation Forward (coef->value) or backward (value->coef) + * + * @details This routine transforms the scaling coefficients of the node to the + * function values in the corresponding quadrature roots (of its children). 
+ * + * @note This routine assumes a 0/1 (scaling on child 0 and 1) + * representation, instead of s/d (scaling and wavelet). + */ + virtual void cvTransform(int operation, bool firstchild = false); // TODO document firstchild parameter + + /** + * @brief Multiwavelet transform + * @param operation Compression (s0,s1->s,d) or reconstruction (s,d->s0,s1) + * + * @details Application of the filters on one node to pass from a 0/1 (scaling + * on child 0 and 1) representation to an s/d (scaling and + * wavelet) representation. Bit manipulation is used in order to + * determine the correct filters and whether to apply them or just + * pass to the next couple of indexes. The starting coefficients are + * preserved until the application is terminated, then they are + * overwritten. With minor modifications this code can also be used + * for the inverse mw transform (just use the transpose filters) or + * for the application of an operator (using A, B, C and T parts of an + * operator instead of G1, G0, H1, H0). This is the version where the + * three directions are operated one after the other. Although this + * is formally faster than the other algorithm, the separation of the + * three dimensions prevent the possibility to use the norm of the + * operator in order to discard a priori negligible contributions. + */ + virtual void mwTransform(int operation); + + /** + * @brief Gives the norm (absolute value) of the node at the given NodeIndex + * @param[in] idx the NodeIndex of the requested node + * + * @details + * Recursive routine to find the node with a given NodeIndex. When an EndNode is + * found, do not generate any new node, but rather give the value of the norm + * assuming the function is uniformly distributed within the node. + */ double getNodeNorm(const NodeIndex &idx) const; - bool hasParent() const { return (parent != nullptr) ? 
true : false; } - bool hasCoefs() const { return (this->status & FlagHasCoefs); } - bool isEndNode() const { return (this->status & FlagEndNode); } - bool isGenNode() const { return (this->status & FlagGenNode); } - bool isRootNode() const { return (this->status & FlagRootNode); } - bool isLeafNode() const { return not(this->status & FlagBranchNode); } - bool isAllocated() const { return (this->status & FlagAllocated); } - bool isBranchNode() const { return (this->status & FlagBranchNode); } - bool isLooseNode() const { return (this->status & FlagLooseNode); } + /* + * Getters and setters + */ + bool hasParent() const { return (parent != nullptr) ? true : false; } ///< @return Whether the node has a parent + bool hasCoefs() const { return (this->status & FlagHasCoefs); } ///< @return Whether the node has coefficients + bool isEndNode() const { return (this->status & FlagEndNode); } ///< @return Whether the node is an end node + bool isGenNode() const { return (this->status & FlagGenNode); } ///< @return Whether the node is a generated node + bool isRootNode() const { return (this->status & FlagRootNode); } ///< @return Whether the node is a root node + bool isLeafNode() const { return not(this->status & FlagBranchNode); } ///< @return Whether the node is a leaf node + bool isAllocated() const { return (this->status & FlagAllocated); } ///< @return Whether the node is fully allocated + bool isBranchNode() const { return (this->status & FlagBranchNode); } ///< @return Whether the node is a branch node + bool isLooseNode() const { return (this->status & FlagLooseNode); } ///< @return Whether the node is a loose node + + /** + * @brief Allows checking the state of a node against a state mask + * @param mask The status mask to compare against + * @return Whether the state of the node matches the given mask + */ bool checkStatus(unsigned char mask) const { return (mask == (this->status & mask)); } - void setHasCoefs() { SET_BITS(status, FlagHasCoefs | FlagAllocated); } -
void setIsEndNode() { SET_BITS(status, FlagEndNode); } - void setIsGenNode() { SET_BITS(status, FlagGenNode); } - void setIsRootNode() { SET_BITS(status, FlagRootNode); } - void setIsLeafNode() { CLEAR_BITS(status, FlagBranchNode); } - void setIsAllocated() { SET_BITS(status, FlagAllocated); } - void setIsBranchNode() { SET_BITS(status, FlagBranchNode); } - void setIsLooseNode() { SET_BITS(status, FlagLooseNode); } - void clearHasCoefs() { CLEAR_BITS(status, FlagHasCoefs); } - void clearIsEndNode() { CLEAR_BITS(status, FlagEndNode); } - void clearIsGenNode() { CLEAR_BITS(status, FlagGenNode); } - void clearIsRootNode() { CLEAR_BITS(status, FlagRootNode); } - void clearIsAllocated() { CLEAR_BITS(status, FlagAllocated); } + void setHasCoefs() { SET_BITS(status, FlagHasCoefs | FlagAllocated); } ///< @brief Marks the node as having coefficients (also sets the allocated flag) + void setIsEndNode() { SET_BITS(status, FlagEndNode); } ///< @brief Marks the node as an end node + void setIsGenNode() { SET_BITS(status, FlagGenNode); } ///< @brief Marks the node as a generated node + void setIsRootNode() { SET_BITS(status, FlagRootNode); } ///< @brief Marks the node as a root node + void setIsLeafNode() { CLEAR_BITS(status, FlagBranchNode); } ///< @brief Marks the node as a leaf node + void setIsAllocated() { SET_BITS(status, FlagAllocated); } ///< @brief Marks the node as allocated + void setIsBranchNode() { SET_BITS(status, FlagBranchNode); } ///< @brief Marks the node as a branch node + void setIsLooseNode() { SET_BITS(status, FlagLooseNode); } ///< @brief Marks the node as a loose node + void clearHasCoefs() { CLEAR_BITS(status, FlagHasCoefs); } ///< @brief Clears the mark for having coefficients + void clearIsEndNode() { CLEAR_BITS(status, FlagEndNode); } ///< @brief Clears the mark for being an end node + void clearIsGenNode() { CLEAR_BITS(status, FlagGenNode); } ///< @brief Clears the mark for being a generated node + void clearIsRootNode() { CLEAR_BITS(status, FlagRootNode); } ///< @brief Clears the 
mark for being a root node + void clearIsAllocated() { CLEAR_BITS(status, FlagAllocated); } ///< @brief Clears the mark for being allocated friend std::ostream &operator<<(std::ostream &o, const MWNode &nd) { return nd.print(o); } + // Friend classes that are allowed to operate on internals. friend class TreeBuilder; friend class MultiplicationCalculator; friend class NodeAllocator; @@ -165,98 +377,377 @@ template class MWNode { friend class FunctionNode; friend class OperatorNode; friend class DerivativeCalculator; - bool isComplex = false; // TODO put as one of the flags - friend class FunctionTree; // required if a ComplexDouble tree access a double node from another tree! + bool isComplex = false; // TODO put as one of the flags + friend class FunctionTree; // required if a ComplexDouble tree access a double node from another tree! friend class FunctionTree; - int childSerialIx{-1}; ///< index of first child in serial Tree, or -1 for leafnodes/endnodes + int childSerialIx{-1}; ///< index of first child in a serial tree, or -1 for leaf nodes/end nodes protected: - MWTree *tree{nullptr}; ///< Tree the node belongs to - MWNode *parent{nullptr}; ///< Parent node - MWNode *children[1 << D]; ///< 2^D children - - double squareNorm{-1.0}; ///< Squared norm of all 2^D (k+1)^D coefficients - double componentNorms[1 << D]; ///< Squared norms of the separeted 2^D components - double maxSquareNorm{-1.0}; ///< Largest squared norm among itself and descendants. - double maxWSquareNorm{-1.0}; ///< Largest wavelet squared norm among itself and descendants. - ///< NB: must be set before used. - T *coefs{nullptr}; ///< the 2^D (k+1)^D MW coefficients - ///< For example, in case of a one dimensional function \f$ f \f$ - ///< this array equals \f$ s_0, \ldots, s_k, d_0, \ldots, d_k \f$, - ///< where scaling coefficients \f$ s_j = s_{jl}^n(f) \f$ - ///< and wavelet coefficients \f$ d_j = d_{jl}^n(f) \f$. - ///< Here \f$ n, l \f$ are unique for every node. 
- int n_coefs{0}; - - int serialIx{-1}; ///< index in serial Tree - int parentSerialIx{-1}; ///< index of parent in serial Tree, or -1 for roots - - NodeIndex nodeIndex; ///< Scale and translation of the node - HilbertPath hilbertPath; ///< To be documented - + MWTree *tree{nullptr}; ///< Tree the node belongs to + MWNode *parent{nullptr}; ///< Parent node (nullptr for root nodes) + MWNode *children[1 << D]; ///< Array of 2^D children (valid if branch node) + + double squareNorm{-1.0}; ///< Squared norm of all 2^D (k+1)^D coefficients + double componentNorms[1 << D]; ///< Squared norms of the separated 2^D components + double maxSquareNorm{-1.0}; ///< Maximum squared norm among the node and descendants + double maxWSquareNorm{-1.0}; ///< Maximum wavelet squared norm among the node and descendants + ///< NB: must be set before used. + T *coefs{nullptr}; ///< The 2^D (k+1)^D MW coefficients + ///< For example, in case of a one dimensional function \f$ f \f$ + ///< this array equals \f$ s_0, \ldots, s_k, d_0, \ldots, d_k \f$, + ///< where scaling coefficients \f$ s_j = s_{jl}^n(f) \f$ + ///< and wavelet coefficients \f$ d_j = d_{jl}^n(f) \f$. + ///< Here \f$ n, l \f$ are unique for every node. + int n_coefs{0}; ///< Number of coefficients in @ref coefs. + + int serialIx{-1}; ///< Index in the serial tree + int parentSerialIx{-1}; ///< Index of the parent in the serial tree, or -1 for root nodes + + NodeIndex nodeIndex; ///< Scale and translation of this node. + HilbertPath hilbertPath; ///< Current Hilbert path state for child ordering. + + /** + * @brief MWNode default constructor + * + * @details Should be used only by NodeAllocator to obtain + * virtual table pointers for the derived classes + */ MWNode(); + + /** + * @brief MWNode constructor + * @param[in] tree The MWTree the root node belongs to + * @param[in] rIdx The integer specifying the corresponding root node + * + * @details Constructor for root nodes. 
It requires the corresponding + * MWTree and an integer to fetch the right NodeIndex. + */ MWNode(MWTree *tree, int rIdx); + + /** + * @brief MWNode constructor + * @param[in] tree The MWTree the root node belongs to + * @param[in] idx The NodeIndex defining scale and translation of the node + * + * @details Constructor for an empty node, given the corresponding MWTree and NodeIndex + */ MWNode(MWTree *tree, const NodeIndex &idx); + + /** + * @brief MWNode constructor + * @param[in] parent Parent node + * @param[in] cIdx Child index of the current node + * + * @details Constructor for leaf nodes. It requires the corresponding + * parent and an integer to identify the correct child. + */ MWNode(MWNode *parent, int cIdx); + // Implemented in child classes virtual void dealloc(); + /** + * @brief Recurse down until an EndNode is found, and then crop children below the given precision threshold + * @param prec The required precision + * @param splitFac Factor used in the split check (larger factor means tighter threshold for finer nodes) + * @param absPrec Flag to switch from relative (false) to absolute (true) precision. + * @return Whether the crop was successful + */ bool crop(double prec, double splitFac, bool absPrec); + /// @brief Initialize thread lock (when OpenMP is enabled). void initNodeLock() { MRCPP_INIT_OMP_LOCK(); } + + /** + * @brief Allocate the coefs vector + * @param n_blocks The number of blocks + * @param block_size The size of a block + * + * @details This is only used by loose nodes, because the loose nodes + * are not treated by the NodeAllocator class. + */ virtual void allocCoefs(int n_blocks, int block_size); + + /** + * @brief Deallocate the coefs vector + * + * @details This is only used by loose nodes, because the loose nodes + * are not treated by the NodeAllocator class. 
+ */ virtual void freeCoefs(); + /** + * @brief recursively set maxSquaredNorm and maxWSquareNorm of parent and descendants + * + * @details + * normalization is such that a constant function gives constant value, + * i.e. *not* same normalization as a squareNorm + */ void setMaxSquareNorm(); + + /// @brief Recursively reset maxSquaredNorm and maxWSquareNorm of parent and descendants to value -1 void resetMaxSquareNorm(); + + /// @return The scaled square norm. double calcScaledSquareNorm() const { return std::pow(2.0, D * getScale()) * getSquareNorm(); } + + /// @return The scaled wavelet square norm. double calcScaledWSquareNorm() const { return std::pow(2.0, D * getScale()) * getWaveletNorm(); } + + /** + * @brief Calculate the norm of one component (NOT the squared norm!) + * @param i The component index + * @return The single component norm + */ virtual double calcComponentNorm(int i) const; + /** + * @brief Update the coefficients of the node by a MW transform of the scaling + * coefficients of the children. + */ virtual void reCompress(); + + /** + * @brief Forward MW transform from this node to its children + * @param overwrite If true, the coefficients of the children are + * overwritten. If false, the values are summed to the already present + * ones. + * + * @details It performs forward MW transform inserting the result + * directly in the right place for each child node. The children must + * already be present and its memory allocated for this to work + * properly. + */ virtual void giveChildrenCoefs(bool overwrite = true); + + /** + * @brief Forward MW transform to compute scaling coefficients of a single child + * @param[in] cIdx The child index + * @param[in] overwrite If true, the coefficients of the children are + * overwritten. If false, the values are summed to the already present + * ones. + * + * @details It performs forward MW transform in place on a loose + * node. 
The scaling coefficients of the selected child are then + * copied/summed in the correct child node. + */ virtual void giveChildCoefs(int cIdx, bool overwrite = true); + + /** @brief Backward MW transform to compute scaling/wavelet coefficients of a parent + * + * @details Takes a MWParent and generates coefficients, reverse operation from + * giveChildrenCoefs. + * + * @note This routine is only used in connection with Periodic Boundary Conditions + */ virtual void giveParentCoefs(bool overwrite = true); + + /** + * @brief Copy scaling coefficients from children to parent + * + * @details Takes the scaling coefficients of the children and stores + * them consecutively in the corresponding block of the parent, + * following the usual bitwise notation. + */ virtual void copyCoefsFromChildren(); + /** + * @brief Routine to find the path along the tree + * @param[in] nIdx The sought after node through its NodeIndex + * + * @details Given the translation indices at the final scale, computes the child m + * to be followed at the current scale in order to get to the requested + * node at the final scale. The result is the index of the child needed. + * The index is obtained by bit manipulation of the translation indices. + */ int getChildIndex(const NodeIndex &nIdx) const; + + /** + * @brief Routine to find the path along the tree + * @param[in] r The sought after node through the coordinates of a point in space + * + * @details Given a point in space, determines which child should be followed + * to get to the corresponding terminal node. 
+ */ int getChildIndex(const Coord &r) const; + /** + * @brief Fast check whether two nodes lie in different branches + * @param rhs The node to compare against + * @return true if two nodes lie in different branches + */ bool diffBranch(const MWNode &rhs) const; + /** + * @brief Node retriever that ALWAYS returns the requested node + * + * @param[in] r The coordinates of a point in the node + * @param depth The depth to descend + * @return The node at the given coordinates + * + * @details Recursive routine to find and return the node with a given NodeIndex. + * This routine always returns the appropriate node, and will generate nodes + * that do not exist. Recursion starts at this node and ASSUMES the + * requested node is in fact descending from this node. + */ MWNode *retrieveNode(const Coord &r, int depth); + + /** + * @brief Node retriever that ALWAYS returns the requested node, possibly without coefs + * @param[in] idx The NodeIndex of the requested node + * @return The node at the given node index + * + * @details Recursive routine to find and return the node with a given NodeIndex. This + * routine always returns the appropriate node, and will generate nodes that + * do not exist. Recursion starts at this node and ASSUMES the requested + * node is in fact descending from this node. + * If create = true, the nodes are permanently added to the tree. + */ MWNode *retrieveNode(const NodeIndex &idx, bool create = false); + + /** + * @brief Node retriever that ALWAYS returns the requested node + * @param[in] idx The NodeIndex of the requested node + * @return The node at the given node index + * + * @details Recursive routine to find and return the node with a given NodeIndex. This + * routine always returns the appropriate node, and will generate nodes that + * do not exist. Recursion starts at this node and ASSUMES the requested + * node is in fact related to this node. + * + * @warning This routine is NOT thread safe! Must be used within omp critical. 
+ */ MWNode *retrieveParent(const NodeIndex &idx); + /** + * @brief Const version of node retriever that NEVER generates + * @param[in] idx The requested NodeIndex + * @returns The requested node + * + * @details Recursive routine to find and return the node with a given NodeIndex. + * This routine returns the appropriate Node, or a NULL pointer if + * the node does not exist, or if it is a GenNode. Recursion starts at this + * node and ASSUMES the requested node is in fact descending from this node. + */ const MWNode *retrieveNodeNoGen(const NodeIndex &idx) const; + + /** + * @brief Node retriever that NEVER generates. + * @param[in] idx The requested NodeIndex + * @returns The requested node + * + * @details Recursive routine to find and return the node with a given NodeIndex. + * This routine returns the appropriate Node, or a NULL pointer if + * the node does not exist, or if it is a GenNode. Recursion starts at this + * node and ASSUMES the requested node is in fact descending from this node. + */ MWNode *retrieveNodeNoGen(const NodeIndex &idx); + /** + * @brief Node retriever that returns requested Node or EndNode (const version) + * @param[in] r The coordinates of a point in the node + * @param depth The depth to descend + * @return The node at the given coordinates + * + * @details Recursive routine to find and return the node given the + * coordinates of a point in space. This routine returns the + * appropriate Node, or the EndNode on the path to the requested node, + * and will never create or return GenNodes. Recursion starts at + * this node and ASSUMES the requested node is in fact descending from + * this node. 
+ */ const MWNode *retrieveNodeOrEndNode(const Coord &r, int depth) const; + + /** + * @brief Node retriever that returns requested Node or EndNode + * @param[in] r The coordinates of a point in the node + * @param depth The depth to descend + * @return The node at the given coordinates + * + * @details Recursive routine to find and return the node given the + * coordinates of a point in space. This routine returns the + * appropriate Node, or the EndNode on the path to the requested node, + * and will never create or return GenNodes. Recursion starts at + * this node and ASSUMES the requested node is in fact descending from + * this node. + */ MWNode *retrieveNodeOrEndNode(const Coord &r, int depth); + /** + * @brief Node retriever that returns requested Node or EndNode (const version) + * @param[in] idx The NodeIndex of the requested node + * @return The requested node + * + * @details Recursive routine to find and return the node with the + * given NodeIndex. This routine returns the + * appropriate Node, or the EndNode on the path to the requested node, + * and will never create or return GenNodes. Recursion starts at + * this node and ASSUMES the requested node is in fact descending from + * this node. + */ const MWNode *retrieveNodeOrEndNode(const NodeIndex &idx) const; + + /** + * @brief Node retriever that returns requested Node or EndNode + * @param[in] idx The NodeIndex of the requested node + * @return The requested node + * + * @details Recursive routine to find and return the node with the + * given NodeIndex. This routine returns the + * appropriate Node, or the EndNode on the path to the requested node, + * and will never create or return GenNodes. Recursion starts at + * this node and ASSUMES the requested node is in fact descending from + * this node. 
+ */ MWNode *retrieveNodeOrEndNode(const NodeIndex &idx); + /** + * @brief Creates scaling coefficients of children + * + * @details If the node is a leaf node, it takes the scaling&wavelet + * coefficients of the parent and it generates the scaling + * coefficients for the children and stores + * them consecutively in the corresponding block of the parent, + * following the usual bitwise notation. The new node is permanently added to the tree. + */ void threadSafeCreateChildren(); + + /** + * @brief Generates scaling coefficients of children + * + * @details If the node is a leaf node, it takes the scaling&wavelet + * coefficients of the parent and it generates the scaling + * coefficients for the children and stores + * them consecutively in the corresponding block of the parent, + * following the usual bitwise notation. + */ void threadSafeGenChildren(); + + + /// @brief Deallocation of all generated nodes void deleteGenerated(); + /** + * @brief Prints of the node content + * @param[in,out] o The output stream + */ virtual std::ostream &print(std::ostream &o) const; + // Bit flags describing node state static const unsigned char FlagBranchNode = B8(00000001); - static const unsigned char FlagGenNode = B8(00000010); - static const unsigned char FlagHasCoefs = B8(00000100); - static const unsigned char FlagAllocated = B8(00001000); - static const unsigned char FlagEndNode = B8(00010000); - static const unsigned char FlagRootNode = B8(00100000); - static const unsigned char FlagLooseNode = B8(01000000); + static const unsigned char FlagGenNode = B8(00000010); + static const unsigned char FlagHasCoefs = B8(00000100); + static const unsigned char FlagAllocated = B8(00001000); + static const unsigned char FlagEndNode = B8(00010000); + static const unsigned char FlagRootNode = B8(00100000); + static const unsigned char FlagLooseNode = B8(01000000); private: - unsigned char status{0}; + unsigned char status{0}; ///< Bit mask of @ref FlagBranchNode, @ref FlagGenNode, 
etc. #ifdef MRCPP_HAS_OMP - omp_lock_t omp_lock; + omp_lock_t omp_lock; ///< Per-node lock for thread-safe edits (OpenMP). #endif }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/trees/MWTree.cpp b/src/trees/MWTree.cpp index 6a646d33f..4f26b9a48 100644 --- a/src/trees/MWTree.cpp +++ b/src/trees/MWTree.cpp @@ -40,15 +40,6 @@ using namespace Eigen; namespace mrcpp { -/** @brief MWTree constructor. - * - * @param[in] mra: the multiresolution analysis object - * @param[in] n: the name of the tree (only for printing purposes) - * - * @details Creates an empty tree object, containing only the set of - * root nodes. The information for the root node configuration to use - * is in the mra object which is passed to the constructor. - */ template MWTree::MWTree(const MultiResolutionAnalysis &mra, const std::string &n) : MRA(mra) @@ -60,19 +51,12 @@ MWTree::MWTree(const MultiResolutionAnalysis &mra, const std::string &n this->nodesAtDepth.push_back(0); } -/** @brief MWTree destructor. */ template MWTree::~MWTree() { this->endNodeTable.clear(); if (this->nodesAtDepth.size() != 1) MSG_ERROR("Nodes at depth != 1 -> " << this->nodesAtDepth.size()); if (this->nodesAtDepth[0] != 0) MSG_ERROR("Nodes at depth 0 != 0 -> " << this->nodesAtDepth[0]); } -/** @brief Deletes all the nodes in the tree - * - * @details This method will recursively delete all the nodes, - * including the root nodes. Derived classes will call this method - * when the object is deleted. - */ template void MWTree::deleteRootNodes() { for (int i = 0; i < this->rootBox.size(); i++) { MWNode &root = this->getRootMWNode(i); @@ -82,14 +66,6 @@ template void MWTree::deleteRootNodes() { } } -/** @brief Remove all nodes in the tree - * - * @details Leaves the tree in the same state as after construction, - * i.e. undefined tree structure containing only root nodes without - * coefficients. 
The assigned memory, including branch and leaf - * nodes, (nodeChunks in NodeAllocator) is NOT released, but is - * immediately available to the new function. - */ template void MWTree::clear() { for (int i = 0; i < this->rootBox.size(); i++) { MWNode &root = this->getRootMWNode(i); @@ -101,11 +77,6 @@ template void MWTree::clear() { this->clearSquareNorm(); } -/** @brief Calculate the squared norm \f$ ||f||^2_{\ldots} \f$ of a function represented as a tree. - * - * @details The norm is calculated using endNodes only. The specific - * type of norm which is computed will depend on the derived class - */ template void MWTree::calcSquareNorm(bool deep) { double treeNorm = 0.0; for (int n = 0; n < this->getNEndNodes(); n++) { @@ -117,29 +88,6 @@ template void MWTree::calcSquareNorm(bool deep) { this->squareNorm = treeNorm; } -/** @brief Full Multiwavelet transform of the tree in either directions - * - * @param[in] type: TopDown (from roots to leaves) or BottomUp (from - * leaves to roots) which specifies the direction of the MW transform - * @param[in] overwrite: if true, the result will overwrite - * preexisting coefficients. - * - * @details It performs a Multiwavlet transform of the whole tree. The - * input parameters will specify the direction (upwards or downwards) - * and whether the result is added to the coefficients or it - * overwrites them. See the documentation for the #mwTransformUp - * and #mwTransformDown for details. 
- * \f[ - * \pmatrix{ - * s_{nl}\\ - * d_{nl} - * } - * \rightleftarrows \pmatrix{ - * s_{n+1,2l}\\ - * s_{n+1,2l+1} - * } - * \f] - */ template void MWTree::mwTransform(int type, bool overwrite) { switch (type) { case TopDown: @@ -154,15 +102,6 @@ template void MWTree::mwTransform(int type, bool overw } } -/** @brief Regenerates all s/d-coeffs by backtransformation - * - * @details It starts at the bottom of the tree (scaling coefficients - * of the leaf nodes) and it generates the scaling and wavelet - * coefficients of the parent node. It then proceeds recursively all the - * way up to the root nodes. This is generally used after a function - * projection to purify the coefficients obtained by quadrature at - * coarser scales which are therefore not precise enough. - */ template void MWTree::mwTransformUp() { std::vector> nodeTable; tree_utils::make_node_table(*this, nodeTable); @@ -180,17 +119,6 @@ template void MWTree::mwTransformUp() { } } -/** @brief Regenerates all scaling coeffs by MW transformation of existing s/w-coeffs - * on coarser scales - * - * @param[in] overwrite: if true the preexisting coefficients are overwritten - * - * @details The transformation starts at the rootNodes and proceeds - * recursively all the way to the leaf nodes. The existing scaling - * coefficeints will either be overwritten or added to. The latter - * operation is generally used after the operator application. - * - */ template void MWTree::mwTransformDown(bool overwrite) { std::vector> nodeTable; tree_utils::make_node_table(*this, nodeTable); @@ -215,12 +143,6 @@ template void MWTree::mwTransformDown(bool overwrite) } } -/** @brief Set the MW coefficients to zero, keeping the same tree structure - * - * @details Keeps the node structure of the tree, even though the zero - * function is representable at depth zero. One should then use \ref cropTree to remove - * unnecessary nodes. 
- */ template void MWTree::setZero() { TreeIterator it(*this); while (it.next()) { @@ -230,13 +152,6 @@ template void MWTree::setZero() { this->squareNorm = 0.0; } -/** @brief Increments node counter by one for non-GenNodes. - * - * @details TO BE DOCUMENTED - * \warning: This routine is not thread - * safe, and must NEVER be called outside a critical region in parallel. - * It's way. way too expensive to lock the tree, so don't even think - * about it. */ template void MWTree::incrementNodeCount(int scale) { int depth = scale - getRootScale(); if (depth < 0) { @@ -254,14 +169,6 @@ template void MWTree::incrementNodeCount(int scale) { } } -/** @brief Decrements node counter by one for non-GenNodes. - * - * @details TO BE DOCUMENTED - * \warning: This routine is not thread - * safe, and must NEVER be called outside a critical region in parallel. - * It's way. way too expensive to lock the tree, so don't even think - * about it. - */ template void MWTree::decrementNodeCount(int scale) { int depth = scale - getRootScale(); if (depth < 0) { @@ -277,10 +184,6 @@ template void MWTree::decrementNodeCount(int scale) { } } -/** @returns Total number of nodes in the tree, at given depth (not in use) - * - * @param[in] depth: Tree depth (0 depth is the coarsest scale) to count. - */ template int MWTree::getNNodesAtDepth(int depth) const { int N = 0; if (depth < 0) { @@ -291,19 +194,11 @@ template int MWTree::getNNodesAtDepth(int depth) const return N; } -/** @returns Size of all MW coefs in the tree, in kB */ template int MWTree::getSizeNodes() const { auto nCoefs = 1ll * getNNodes() * getTDim() * getKp1_d(); return sizeof(T) * nCoefs / 1024; } -/** @brief Finds and returns the node pointer with the given \ref NodeIndex, const version. - * - * @details Recursive routine to find and return the node with a given - * NodeIndex. This routine returns the appropriate Node, or a NULL - * pointer if the node does not exist, or if it is a - * GenNode. 
Recursion starts at the appropriate rootNode. - */ template const MWNode *MWTree::findNode(NodeIndex idx) const { if (getRootBox().isPeriodic()) { periodic::index_manipulation(idx, getRootBox().getPeriodic()); } int rIdx = getRootBox().getBoxIndex(idx); @@ -313,13 +208,6 @@ template const MWNode *MWTree::findNode(NodeInde return root.retrieveNodeNoGen(idx); } -/** @brief Finds and returns the node pointer with the given \ref NodeIndex. - * - * @details Recursive routine to find and return the node with a given - * NodeIndex. This routine returns the appropriate Node, or a NULL - * pointer if the node does not exist, or if it is a - * GenNode. Recursion starts at the appropriate rootNode. - */ template MWNode *MWTree::findNode(NodeIndex idx) { if (getRootBox().isPeriodic()) { periodic::index_manipulation(idx, getRootBox().getPeriodic()); } int rIdx = getRootBox().getBoxIndex(idx); @@ -329,14 +217,6 @@ template MWNode *MWTree::findNode(NodeIndex i return root.retrieveNodeNoGen(idx); } -/** @brief Finds and returns the node reference with the given NodeIndex. - * - * @details This routine ALWAYS returns the node you ask for. If the - * node does not exist, it will be generated by MW - * transform. Recursion starts at the appropriate rootNode and descends - * from this. - * The nodes are permanently added to the tree if create = true - */ template MWNode &MWTree::getNode(NodeIndex idx, bool create) { if (getRootBox().isPeriodic()) periodic::index_manipulation(idx, getRootBox().getPeriodic()); @@ -351,14 +231,6 @@ template MWNode &MWTree::getNode(NodeIndex id return *out; } -/** @brief Finds and returns the node with the given NodeIndex. - * - * @details This routine returns the Node you ask for, or the EndNode - * on the path to the requested node, if the requested one is deeper - * than the leaf node ancestor. It will never create or return - * GenNodes. Recursion starts at the appropriate rootNode and decends - * from this. 
- */ template MWNode &MWTree::getNodeOrEndNode(NodeIndex idx) { if (getRootBox().isPeriodic()) { periodic::index_manipulation(idx, getRootBox().getPeriodic()); } MWNode &root = getRootBox().getNode(idx); @@ -366,13 +238,6 @@ template MWNode &MWTree::getNodeOrEndNode(NodeIn return *root.retrieveNodeOrEndNode(idx); } -/** @brief Finds and returns the node reference with the given NodeIndex. Const version. - * - * @details This routine ALWAYS returns the node you ask for. If the - * node does not exist, it will be generated by MW - * transform. Recursion starts at the appropriate rootNode and decends - * from this. - */ template const MWNode &MWTree::getNodeOrEndNode(NodeIndex idx) const { if (getRootBox().isPeriodic()) { periodic::index_manipulation(idx, getRootBox().getPeriodic()); } const MWNode &root = getRootBox().getNode(idx); @@ -380,15 +245,6 @@ template const MWNode &MWTree::getNodeOrEndNode( return *root.retrieveNodeOrEndNode(idx); } -/** @brief Finds and returns the node at a given depth that contains a given coordinate. - * - * @param[in] depth: requested node depth from root scale. - * @param[in] r: coordinates of an arbitrary point in space - * - * @details This routine ALWAYS returns the node you ask for, and will - * generate nodes that do not exist. Recursion starts at the - * appropriate rootNode and decends from this. - */ template MWNode &MWTree::getNode(Coord r, int depth) { MWNode &root = getRootBox().getNode(r); if (depth >= 0) { @@ -398,15 +254,6 @@ template MWNode &MWTree::getNode(Coord r, int } } -/** @brief Finds and returns the node at a given depth that contains a given coordinate. - * - * @param[in] depth: requested node depth from root scale. - * @param[in] r: coordinates of an arbitrary point in space - * - * @details This routine returns the Node you ask for, or the EndNode on - * the path to the requested node, and will never create or return GenNodes. - * Recursion starts at the appropriate rootNode and decends from this. 
- */ template MWNode &MWTree::getNodeOrEndNode(Coord r, int depth) { if (getRootBox().isPeriodic()) { periodic::coord_manipulation(r, getRootBox().getPeriodic()); } @@ -415,15 +262,6 @@ template MWNode &MWTree::getNodeOrEndNode(Coord< return *root.retrieveNodeOrEndNode(r, depth); } -/** @brief Finds and returns the node at a given depth that contains a given coordinate. Const version - * - * @param[in] depth: requested node depth from root scale. - * @param[in] r: coordinates of an arbitrary point in space - * - * @details This routine returns the Node you ask for, or the EndNode on - * the path to the requested node, and will never create or return GenNodes. - * Recursion starts at the appropriate rootNode and decends from this. - */ template const MWNode &MWTree::getNodeOrEndNode(Coord r, int depth) const { if (getRootBox().isPeriodic()) { periodic::coord_manipulation(r, getRootBox().getPeriodic()); } @@ -431,11 +269,6 @@ template const MWNode &MWTree::getNodeOrEndNode( return *root.retrieveNodeOrEndNode(r, depth); } -/** @brief Returns the list of all EndNodes - * - * @details copies the list of all EndNode pointers into a new vector - * and returns it. - */ template MWNodeVector *MWTree::copyEndNodeTable() { auto *nVec = new MWNodeVector; for (int n = 0; n < getNEndNodes(); n++) { @@ -445,12 +278,6 @@ template MWNodeVector *MWTree::copyEndNodeTable( return nVec; } -/** @brief Recreate the endNodeTable - * - * @details the endNodeTable is first deleted and then rebuilt from - * scratch. It makes use of the TreeIterator to traverse the tree. 
- * - */ template void MWTree::resetEndNodeTable() { clearEndNodeTable(); TreeIterator it(*this, TopDown, Hilbert); @@ -514,8 +341,6 @@ template int MWTree::countAllocNodes(int depth) { // return count; } -/** @brief Prints a summary of the tree structure on the output file - */ template std::ostream &MWTree::print(std::ostream &o) const { o << " square norm: " << this->squareNorm << std::endl; o << " root scale: " << this->getRootScale() << std::endl; @@ -528,11 +353,6 @@ template std::ostream &MWTree::print(std::ostream &o) return o; } -/** @brief sets values for maxSquareNorm in all nodes - * - * @details it defines the upper bound of the squared norm \f$ - * ||f||^2_{\ldots} \f$ in this node or its descendents - */ template void MWTree::makeMaxSquareNorms() { NodeBox &rBox = this->getRootBox(); MWNode **roots = rBox.getNodes(); @@ -542,11 +362,6 @@ template void MWTree::makeMaxSquareNorms() { } } -/** @brief gives serialIx of a node from its NodeIndex - * - * @details gives a unique integer for each nodes corresponding to the position - * of the node in the serialized representation - */ template int MWTree::getIx(NodeIndex nIdx) { if (this->isLocal == false) MSG_ERROR("getIx only implemented in local representation"); if (NodeIndex2serialIx.count(nIdx) == 0) diff --git a/src/trees/MWTree.h b/src/trees/MWTree.h index b0261aca6..aae6ded6e 100644 --- a/src/trees/MWTree.h +++ b/src/trees/MWTree.h @@ -40,7 +40,10 @@ namespace mrcpp { class BankAccount; -/** @class MWTree +/** + * @class MWTree + * @tparam D Spatial dimension (1, 2, or 3) + * @tparam T Coefficient type (e.g. double, ComplexDouble) * * @brief Base class for Multiwavelet tree structures, such as FunctionTree and OperatorTree * @@ -59,95 +62,332 @@ class BankAccount; * will generate the required node on the fly using the MW transform; * some methods will return an empty pointer if the node is not * present. See specific methods for details. 
- * */ template class MWTree { public: + /** + * @brief MWTree constructor. + * + * @param[in] mra The multiresolution analysis object + * @param[in] n The name of the tree (only for printing purposes) + * + * @details Creates an empty tree object, containing only the set of + * root nodes. The information for the root node configuration to use + * is in the mra object which is passed to the constructor. + */ MWTree(const MultiResolutionAnalysis &mra, const std::string &n); + MWTree(const MWTree &tree) = delete; MWTree &operator=(const MWTree &tree) = delete; + + /// @brief MWTree destructor. virtual ~MWTree(); + /** + * @brief Set the MW coefficients to zero, keeping the same tree structure + * + * @details Keeps the node structure of the tree, even though the zero + * function is representable at depth zero. One should then use \ref cropTree to remove + * unnecessary nodes. + */ void setZero(); + + /** @brief Remove all nodes in the tree + * + * @details Leaves the tree in the same state as after construction, + * i.e. undefined tree structure containing only root nodes without + * coefficients. The assigned memory, including branch and leaf + * nodes, (nodeChunks in NodeAllocator) is NOT released, but is + * immediately available to the new function. + */ void clear(); - /** @returns Squared L2 norm of the function */ - double getSquareNorm() const { return this->squareNorm; } + double getSquareNorm() const { return this->squareNorm; } ///< @return The squared L2 norm of the function + + /** @brief Calculate the squared norm \f$ ||f||^2_{\ldots} \f$ of a function represented as a tree. + * + * @details The norm is calculated using endNodes only. The specific + * type of norm which is computed will depend on the derived class. 
+ */ void calcSquareNorm(bool deep = false); - void clearSquareNorm() { this->squareNorm = -1.0; } - - int getOrder() const { return this->order; } - int getKp1() const { return this->order + 1; } - int getKp1_d() const { return this->kp1_d; } - int getDim() const { return D; } - int getTDim() const { return (1 << D); } - /** @returns the total number of nodes in the tree */ - int getNNodes() const { return getNodeAllocator().getNNodes(); } - int getNNegScales() const { return this->nodesAtNegativeDepth.size(); } - int getRootScale() const { return this->rootBox.getScale(); } - int getDepth() const { return this->nodesAtDepth.size(); } - int getNNodesAtDepth(int i) const; - int getSizeNodes() const; - /** @returns */ - NodeBox &getRootBox() { return this->rootBox; } - const NodeBox &getRootBox() const { return this->rootBox; } - const MultiResolutionAnalysis &getMRA() const { return this->MRA; } + void clearSquareNorm() { this->squareNorm = -1.0; } ///< @brief Mark the norm as undefined (sets it to -1) + + int getOrder() const { return this->order; } ///< @return Polynomial order k + int getKp1() const { return this->order + 1; } ///< @return k+1 + int getKp1_d() const { return this->kp1_d; } ///< @return (k+1)^D + int getDim() const { return D; } ///< @return The spatial dimension D + int getTDim() const { return (1 << D); } ///< @return 2^D (number of children per internal node) + int getNNodes() const { return getNodeAllocator().getNNodes(); } ///< @return The total number of nodes in this tree + int getNNegScales() const { return this->nodesAtNegativeDepth.size(); } ///< @return The number of negative scales in this tree + int getRootScale() const { return this->rootBox.getScale(); } ///< @return The root scale of this tree + int getDepth() const { return this->nodesAtDepth.size(); } ///< @return The maximum depth of this tree + int getSizeNodes() const; ///< @return The size of all MW coefficients in the tree (in kB) + /** + * @brief Returns the total number of 
nodes in the tree, at given depth (not in use) + * @param i Tree depth (0 depth is the coarsest scale) to count + * @return Number of nodes at depth i + */ + int getNNodesAtDepth(int i) const; + NodeBox &getRootBox() { return this->rootBox; } ///< @return The container of nodes + const NodeBox &getRootBox() const { return this->rootBox; } ///< @return The container of nodes + const MultiResolutionAnalysis &getMRA() const { return this->MRA; } ///< @return The MRA object used by this tree + + /** + * @brief Full Multiwavelet transform of the tree in either direction + * + * @param type TopDown (from roots to leaves) or BottomUp (from + * leaves to roots) which specifies the direction of the MW transform + * @param overwrite If true, the result will overwrite preexisting coefficients. + * + * @details It performs a Multiwavelet transform of the whole tree. The + * input parameters will specify the direction (upwards or downwards) + * and whether the result is added to the coefficients or it + * overwrites them. See the documentation for the #mwTransformUp + * and #mwTransformDown for details. 
+ * \f[ + * \pmatrix{ + * s_{nl}\\ + * d_{nl} + * } + * \rightleftarrows \pmatrix{ + * s_{n+1,2l}\\ + * s_{n+1,2l+1} + * } + * \f] + */ void mwTransform(int type, bool overwrite = true); + /** + * @brief Set the name of the tree + * @param n The new name + */ void setName(const std::string &n) { this->name = n; } - const std::string &getName() const { return this->name; } + const std::string &getName() const { return this->name; } ///< @return The name of the tree + /** + * @param r Spatial coordinates + * @return The index of the root box containing r + */ int getRootIndex(Coord r) const { return this->rootBox.getBoxIndex(r); } + /** + * @param nIdx Index of a node + * @return The index of the root box containing nIdx + */ int getRootIndex(NodeIndex nIdx) const { return this->rootBox.getBoxIndex(nIdx); } + /** + * @brief Finds and returns the node pointer with the given NodeIndex + * @param nIdx The NodeIndex to search for + * + * @details Recursive routine to find and return the node with a given + * NodeIndex. This routine returns the appropriate Node, or a NULL + * pointer if the node does not exist, or if it is a + * GenNode. Recursion starts at the appropriate rootNode. + * + * @return Pointer to the required node. + */ MWNode *findNode(NodeIndex nIdx); + /** + * @brief Finds and returns the node pointer with the given NodeIndex + * @param nIdx The NodeIndex to search for + * + * @details Recursive routine to find and return the node with a given + * NodeIndex. This routine returns the appropriate Node, or a NULL + * pointer if the node does not exist, or if it is a + * GenNode. Recursion starts at the appropriate rootNode. + * + * @return Pointer to the required node. + */ const MWNode *findNode(NodeIndex nIdx) const; + /** + * @brief Finds and returns the node reference with the given NodeIndex. 
+ * @param nIdx The NodeIndex to search for + * @param create If true, previously non-existing nodes will be stored permanently in the tree + * + * @details This routine ALWAYS returns the node you ask for. If the + * node does not exist, it will be generated by MW + * transform. Recursion starts at the appropriate rootNode and descends + * from this. + * + * @return Reference to the required node. + * @note The nodes are permanently added to the tree if create = true. + */ MWNode &getNode(NodeIndex nIdx, bool create = false); + + /** + * @brief Finds and returns the node (or EndNode) for the given NodeIndex. + * @param nIdx The NodeIndex to search for + * + * @details This routine returns the Node you ask for, or the EndNode + * on the path to the requested node, if the requested one is deeper + * than the leaf node ancestor. It will never create or return + * GenNodes. Recursion starts at the appropriate rootNode and decends + * from this. + * + * @return Reference to the required node or EndNode. + */ MWNode &getNodeOrEndNode(NodeIndex nIdx); + /** + * @brief Finds and returns the node (or EndNode) for the given NodeIndex. + * @param nIdx The NodeIndex to search for + * + * @details This routine returns the Node you ask for, or the EndNode + * on the path to the requested node, if the requested one is deeper + * than the leaf node ancestor. It will never create or return + * GenNodes. Recursion starts at the appropriate rootNode and decends + * from this. + * + * @return Reference to the required node or EndNode. + */ const MWNode &getNodeOrEndNode(NodeIndex nIdx) const; + /** + * @brief Finds and returns the node at a given depth that contains a given coordinate. + * + * @param r Coordinates of an arbitrary point in space + * @param depth Requested node depth from root scale + * + * @details This routine ALWAYS returns the node you ask for, and will + * generate nodes that do not exist. Recursion starts at the + * appropriate rootNode and decends from this. 
+ * + * @return Reference to the required node. + */ MWNode &getNode(Coord r, int depth = -1); + + /** + * @brief Finds and returns the node at a given depth that contains a given coordinate. + * + * @param r Coordinates of an arbitrary point in space + * @param depth Requested node depth from root scale. + * + * @details This routine returns the Node you ask for, or the EndNode on + * the path to the requested node, and will never create or return GenNodes. + * Recursion starts at the appropriate rootNode and decends from this. + * + * @return Reference to the required node or EndNode. + */ MWNode &getNodeOrEndNode(Coord r, int depth = -1); + /** + * @brief Finds and returns the node at a given depth that contains a given coordinate. + * + * @param r Coordinates of an arbitrary point in space + * @param depth Requested node depth from root scale. + * + * @details This routine returns the Node you ask for, or the EndNode on + * the path to the requested node, and will never create or return GenNodes. + * Recursion starts at the appropriate rootNode and decends from this. + * + * @return Reference to the required node or EndNode. 
+ */ const MWNode &getNodeOrEndNode(Coord r, int depth = -1) const; - int getNEndNodes() const { return this->endNodeTable.size(); } - int getNRootNodes() const { return this->rootBox.size(); } + int getNEndNodes() const { return this->endNodeTable.size(); } ///< @return The number of end nodes + int getNRootNodes() const { return this->rootBox.size(); } ///< @return The number of root nodes + + /** + * @param i Index of the end node + * @return Reference to the i-th end node + */ MWNode &getEndMWNode(int i) { return *this->endNodeTable[i]; } + /** + * @param i Index of the root node + * @return Reference to the i-th root node + */ MWNode &getRootMWNode(int i) { return this->rootBox.getNode(i); } + /** + * @param i Index of the end node + * @return Reference to the i-th end node + */ const MWNode &getEndMWNode(int i) const { return *this->endNodeTable[i]; } + /** + * @param i Index of the root node + * @return Reference to the i-th root node + */ const MWNode &getRootMWNode(int i) const { return this->rootBox.getNode(i); } - bool isPeriodic() const { return this->MRA.getWorldBox().isPeriodic(); } + bool isPeriodic() const { return this->MRA.getWorldBox().isPeriodic(); } ///< @return Whether the world is periodic + /** + * @brief Returns the list of all EndNodes + * @details Copies the list of all EndNode pointers into a new vector and returns it. + * @return The copied end-node table. + */ MWNodeVector *copyEndNodeTable(); - MWNodeVector *getEndNodeTable() { return &this->endNodeTable; } - + MWNodeVector *getEndNodeTable() { return &this->endNodeTable; } ///< @return The end-node table + + /** + * @brief Deletes all the nodes in the tree + * @details This method will recursively delete all the nodes, + * including the root nodes. Derived classes will call this method + * when the object is deleted. + */ void deleteRootNodes(); + /** + * @brief Recreate the endNodeTable + * @details the endNodeTable is first deleted and then rebuilt from + * scratch. 
It makes use of the TreeIterator to traverse the tree. + */ void resetEndNodeTable(); + /// @brief Clear the end-node table void clearEndNodeTable() { this->endNodeTable.clear(); } + /// @warning This method is currently not implemented. int countBranchNodes(int depth = -1); + /// @warning This method is currently not implemented. int countLeafNodes(int depth = -1); + /// @warning This method is currently not implemented. int countAllocNodes(int depth = -1); + /// @warning This method is currently not implemented. int countNodes(int depth = -1); - bool isLocal = false; // to know whether the tree coeffcients are stored in the Bank - int getIx(NodeIndex nIdx); // gives serialIx of a stored node from its NodeIndex if isLocal - - void makeMaxSquareNorms(); // sets values for maxSquareNorm and maxWSquareNorm in all nodes - - NodeAllocator &getNodeAllocator() { return *this->nodeAllocator_p; } - const NodeAllocator &getNodeAllocator() const { return *this->nodeAllocator_p; } - MWNodeVector endNodeTable; ///< Final projected nodes - - void getNodeCoeff(NodeIndex nIdx, T *data); // fetch coefficient from a specific node stored in Bank - bool conjugate() const { return this->conj; } - void setConjugate(bool conjug) { this->conj = conjug; } - + /// @brief Whether the tree coefficients are stored in the Bank + bool isLocal = false; + + /** + * @brief Gives serialIx of a node from its NodeIndex + * @param nIdx The NodeIndex of the node + * + * @details Gives a unique integer for each node corresponding to the position + * of the node in the serialized representation. Only works if isLocal == true. + * + * @return The serial index of the node. + */ + int getIx(NodeIndex nIdx); + + /** + * @brief Sets values for maxSquareNorm and maxWSquareNorm in all nodes + * @details It defines the upper bound of the squared norm \f$ + * ||f||^2_{\ldots} \f$ in this node or its descendants. 
+ */ + void makeMaxSquareNorms(); + + NodeAllocator &getNodeAllocator() { return *this->nodeAllocator_p; } ///< @return Reference to the node allocator. + const NodeAllocator &getNodeAllocator() const { return *this->nodeAllocator_p; } ///< @return Reference to the node allocator. + + MWNodeVector endNodeTable; ///< @brief Final projected nodes + + /** + * @brief Fetch coefficients of a specific node stored in Bank + * @param nIdx Node index + * @param[out] data The node coefficients are copied into this array + */ + void getNodeCoeff(NodeIndex nIdx, T *data); + + bool conjugate() const { return this->conj; } ///< @return Whether the tree is conjugated + void setConjugate(bool conjug) { this->conj = conjug; } ///< @param conjug Set whether the tree is conjugated + + /** + * @brief Print a formatted summary of the tree + * @param o The output stream + * @param tree The tree to print + * @return The output stream + */ friend std::ostream &operator<<(std::ostream &o, const MWTree &tree) { return tree.print(o); } + // Friend classes friend class MWNode; friend class FunctionNode; friend class OperatorNode; @@ -156,34 +396,76 @@ template class MWTree { protected: // Parameters that are set in construction and should never change - const MultiResolutionAnalysis MRA; + const MultiResolutionAnalysis MRA; ///< Domain and basis // Constant parameters that are derived internally - const int order; - const int kp1_d; + const int order; ///< Polynomial order k + const int kp1_d; ///< (k+1)^D - std::map, int> NodeIndex2serialIx; // to store nodes serialIx + std::map, int> NodeIndex2serialIx; ///< To store nodes serialIx // Parameters that are dynamic and can be set by user - std::string name; - - std::unique_ptr> nodeAllocator_p{nullptr}; + std::string name; ///< Name of this tree + std::unique_ptr> nodeAllocator_p{nullptr}; ///< Node allocator // Tree data - double squareNorm; - NodeBox rootBox; ///< The actual container of nodes - std::vector nodesAtDepth; ///< Node counter - 
std::vector nodesAtNegativeDepth; ///< Node counter - + double squareNorm; ///< Global squared L2 norm (-1 if undefined). + NodeBox rootBox; ///< The actual container of nodes + std::vector nodesAtDepth; ///< Node counter + std::vector nodesAtNegativeDepth; ///< Node counter + + /** + * @brief Regenerates all scaling coeffs by MW transformation of existing s/w-coeffs + * on coarser scales + * @param overwrite If true, the preexisting coefficients are overwritten + * + * @details The transformation starts at the rootNodes and proceeds + * recursively all the way to the leaf nodes. The existing scaling + * coefficients will either be overwritten or added to. The latter + * operation is generally used after the operator application. + */ virtual void mwTransformDown(bool overwrite); + + /** + * @brief Regenerates all s/d-coeffs by backtransformation + * + * @details It starts at the bottom of the tree (scaling coefficients + * of the leaf nodes) and it generates the scaling and wavelet + * coefficients of the parent node. It then proceeds recursively all the + * way up to the root nodes. This is generally used after a function + * projection to purify the coefficients obtained by quadrature at + * coarser scales which are therefore not precise enough. + */ virtual void mwTransformUp(); + /** + * @brief Increments node counter by one for non-GenNodes + * @param scale Scale of the node + * @warning: This routine is not thread safe, and must NEVER be called + * outside a critical region in parallel. It's way, way too expensive to + * lock the tree, so don't even think about it. + */ void incrementNodeCount(int scale); + + /** + * @brief Decrements node counter by one for non-GenNodes + * @param scale Scale of the node + * @warning: This routine is not thread safe, and must NEVER be called + * outside a critical region in parallel. It's way, way too expensive to + * lock the tree, so don't even think about it. 
+ */ void decrementNodeCount(int scale); - BankAccount *NodesCoeff = nullptr; - bool conj{false}; + BankAccount *NodesCoeff = nullptr; ///< Bank account for node coefficients + bool conj{false}; ///< Whether the tree is conjugated + + /** + * @brief Prints a summary of the tree structure on the output file + * @param o The output stream + * @return The formatted output stream + */ virtual std::ostream &print(std::ostream &o) const; }; -} // namespace mrcpp + +} // namespace mrcpp \ No newline at end of file diff --git a/src/trees/MultiResolutionAnalysis.cpp b/src/trees/MultiResolutionAnalysis.cpp index 43b39c32d..4378ce01d 100644 --- a/src/trees/MultiResolutionAnalysis.cpp +++ b/src/trees/MultiResolutionAnalysis.cpp @@ -32,22 +32,6 @@ namespace mrcpp { -/** @returns New MultiResolutionAnalysis (MRA) object - * - * @brief Constructs a MultiResolutionAnalysis object composed of computational domain (world) and a polynomial basis (Multiwavelets) - * - * @param[in] bb: 2-element integer array [Lower, Upper] defining the bounds for a BoundingBox object representing the computational domain - * @param[in] order: Maximum polynomial order of the multiwavelet basis, - * immediately used in the constructor of an InterPolatingBasis object which becomes an attribute of the MRA - * @param[in] maxDepth: Exponent of the node refinement in base 2, relative to root scale. - * In other words, it is the maximum amount of refinement that we allow in a node, in other to avoid overflow of values. - * - * @details Constructor of the MultiResolutionAnalysis class from scratch, without requiring any pre-existing complex structure. - * The constructor calls the InterpolatingBasis basis constructor to generate the MultiWavelets basis of functions, - * then the BoundingBox constructor to create the computational domain. The constructor then checks if the generated node depth, or - * node refinement is beyond the root scale or the maximum depth allowed, in which case it will abort the process. 
- * Otherwise, the process goes on to setup the filters with the class' setupFilter method. - */ template MultiResolutionAnalysis::MultiResolutionAnalysis(std::array bb, int order, int depth) : maxDepth(depth) @@ -58,18 +42,6 @@ MultiResolutionAnalysis::MultiResolutionAnalysis(std::array bb, int o setupFilter(); } -/** @returns New MultiResolutionAnalysis (MRA) object - * - * @brief Constructs a MultiResolutionAnalysis object composed of computational domain (world) and a polynomial basis (Multiwavelets) from a pre-existing BoundingBox object - * - * @param[in] bb: BoundingBox object representing the computational domain - * @param[in] order: (integer) Maximum polynomial order of the multiwavelet basis, - * immediately used in the constructor of an InterPolatingBasis object which becomes an attribute of the MRA - * @param[in] maxDepth: (integer) Exponent of the node refinement in base 2, relative to root scale. - * In other words, it is the maximum amount of refinement that we allow in a node, in other to avoid overflow of values. - * - * @details Constructor of the MultiResolutionAnalysis class from a BoundingBox object. For more details see the first constructor. - */ template MultiResolutionAnalysis::MultiResolutionAnalysis(const BoundingBox &bb, int order, int depth) : maxDepth(depth) @@ -80,14 +52,6 @@ MultiResolutionAnalysis::MultiResolutionAnalysis(const BoundingBox &bb, in setupFilter(); } -/** @returns New MultiResolutionAnalysis (MRA) object - * - * @brief Copy constructor for a MultiResolutionAnalysis object composed of computational domain (world) and a polynomial basis (Multiwavelets) - * - * @param[in] mra: Pre-existing MRA object - * - * @details Copy a MultiResolutionAnalysis object without modifying the original. For more details see the first constructor. 
- */ template MultiResolutionAnalysis::MultiResolutionAnalysis(const MultiResolutionAnalysis &mra) : maxDepth(mra.maxDepth) @@ -98,17 +62,6 @@ MultiResolutionAnalysis::MultiResolutionAnalysis(const MultiResolutionAnalysi setupFilter(); } -/** @returns New MultiResolutionAnalysis object - * - * @brief Constructor for a MultiResolutionAnalysis object from a pre-existing BoundingBox (computational domain) and a ScalingBasis (Multiwavelet basis) objects - * - * @param[in] bb: Computational domain as a BoundingBox object, taken by constant reference - * @param[in] sb: Polynomial basis (MW) as a ScalingBasis object - * @param[in] depth: Maximum allowed resolution depth, relative to root scale - * - * @details Creates a MRA object from pre-existing BoundingBox and ScalingBasis objects. These objects are taken as reference. For more details about the constructor itself, see the first - * constructor. - */ template MultiResolutionAnalysis::MultiResolutionAnalysis(const BoundingBox &bb, const ScalingBasis &sb, int depth) : maxDepth(depth) @@ -119,16 +72,6 @@ MultiResolutionAnalysis::MultiResolutionAnalysis(const BoundingBox &bb, co setupFilter(); } -/** @returns Whether the two MRA objects are equal. - * - * @brief Equality operator for the MultiResolutionAnalysis class, returns true if both MRAs have the same polynomial basis, computational domain and maximum depth, and false otherwise - * - * @param[in] mra: MRA object, taken by constant reference - * - * @details Equality operator for the MultiResolutionAnalysis class, returns true if both MRAs have the same polynomial basis represented by a BoundingBox object, computational domain (ScalingBasis - * object) and maximum depth (integer), and false otherwise. Computations on different MRA cannot be combined, this operator can be used to make sure that the multiple MRAs are compatible. For more - * information about the meaning of equality for BoundingBox and ScalingBasis objets, see their respective classes. 
- */ template bool MultiResolutionAnalysis::operator==(const MultiResolutionAnalysis &mra) const { if (this->basis != mra.basis) return false; if (this->world != mra.world) return false; @@ -136,16 +79,6 @@ template bool MultiResolutionAnalysis::operator==(const MultiResoluti return true; } -/** @returns Whether the two MRA objects are not equal. - * - * @brief Inequality operator for the MultiResolutionAnalysis class, returns false if both MRAs have the same polynomial basis, computational domain and maximum depth, and true otherwise - * - * @param[in] mra: MRA object, taken by constant reference - * - * @details Inequality operator for the MultiResolutionAnalysis class, returns true if both MRAs have the same polynomial basis represented by a BoundingBox object, computational domain (ScalingBasis - * object) and maximum depth (integer), and false otherwise. Opposite of the == operator. For more information about the meaning of equality for BoundingBox and ScalingBasis objets, see their - * respective classes. - */ template bool MultiResolutionAnalysis::operator!=(const MultiResolutionAnalysis &mra) const { if (this->basis != mra.basis) return true; if (this->world != mra.world) @@ -157,14 +90,6 @@ template bool MultiResolutionAnalysis::operator!=(const MultiResoluti return false; } -/** - * - * @brief Displays the MRA's attributes in the outstream defined in the Printer class - * - * @details This function displays the attributes of the MRA in the using the Printer class. - * By default, the Printer class writes all information in the output file, not the terminal. - * - */ template void MultiResolutionAnalysis::print() const { print::separator(0, ' '); print::header(0, "MultiResolution Analysis"); @@ -174,15 +99,6 @@ template void MultiResolutionAnalysis::print() const { print::separator(0, '=', 2); } -/** - * - * @brief Initializes the MW filters for the given MW basis. 
- * - * @details By calling the get() function for the appropriate MW basis, the global - * FilterCache Singleton object is initialized. Any subsequent reference to this - * particular filter will point to the same unique global object. - * - */ template void MultiResolutionAnalysis::setupFilter() { getLegendreFilterCache(lfilters); getInterpolatingFilterCache(ifilters); @@ -200,11 +116,6 @@ template void MultiResolutionAnalysis::setupFilter() { } } -/** @returns Maximum possible distance between two points in the MRA domain - * - * @brief Computes the difference between the lower and upper bounds of the computational domain - * - */ template double MultiResolutionAnalysis::calcMaxDistance() const { const Coord &lb = getWorldBox().getLowerBounds(); const Coord &ub = getWorldBox().getUpperBounds(); diff --git a/src/trees/MultiResolutionAnalysis.h b/src/trees/MultiResolutionAnalysis.h index 00135d822..c27dc8c42 100644 --- a/src/trees/MultiResolutionAnalysis.h +++ b/src/trees/MultiResolutionAnalysis.h @@ -33,52 +33,131 @@ namespace mrcpp { -/** @class MultiResolutionAnalysis +/** + * @class MultiResolutionAnalysis + * @tparam D Spatial dimension (1, 2, or 3) * - * @brief Class collecting computational domain and MW basis + * @brief Class for MultiResolutionAnalysis templates * - * @details In order to combine different functions and operators in - * mathematical operations, they need to be compatible. That is, they must - * be defined on the same computational domain and constructed using the same - * polynomial basis (order and type). This information constitutes an MRA, - * which needs to be defined and passed as argument to all function and - * operator constructors, and only functions and operators with compatible - * MRAs can be combined in subsequent calculations. 
+ * @details + * The MultiResolutionAnalysis (MRA) object bundles information that must be shared for + * compatible functions and operators: + * - Computational domain (see @ref BoundingBox) + * - MultiResolution scaling basis, as a polynomial order (see @ref ScalingBasis) + * - Maximum refinement depth, relative to the world’s root scale (@ref MaxDepth by default) + * + * Class also contains useful functions to compare MRA objects, + * find max and min box sizes and print a human readable diagnostic for the MRA. */ - -template class MultiResolutionAnalysis final { +template +class MultiResolutionAnalysis final { public: + /** + * @brief Construct from a symmetric domain and a basis order + * + * @param[in] bb 2-element integer array defining domain bounds + * @param[in] order Polynomial order of the multiwavelet basis + * @param[in] depth Maximum refinement depth (relative to root scale). Default is \ref MaxDepth + * + * @details + * Constructor of the MultiResolutionAnalysis class from scratch. + * The scaling basis type is chosen by MRCPP defaults for the given @p order. + * The root scale is inferred from @p bb to keep the per-dimension scaling factor in (1, 2). + */ MultiResolutionAnalysis(std::array bb, int order, int depth = MaxDepth); + + /** + * @brief Constructs MultiResolutionAnalysis object from a pre-existing @ref BoundingBox object + * + * @param[in] bb BoundingBox object representing the computational domain + * @param[in] order Polynomial order of the multiwavelet basis + * @param[in] depth Maximum refinement depth (relative to root scale). Default is \ref MaxDepth + * + * @details + * Creates a MRA object from a pre-existing BoundingBox object, @p bb, with a polynomial order, @p order, to set the basis + * and the maximum amount of allowed refinement in a node, @p depth. + */ MultiResolutionAnalysis(const BoundingBox &bb, int order, int depth = MaxDepth); + + /** + * @brief Construct from a @ref BoundingBox and a fully specified @ref ScalingBasis. 
+ * + * @param[in] bb BoundingBox object representing the computational domain + * @param[in] sb Polynomial basis (MW) as a ScalingBasis object + * @param[in] depth Maximum refinement depth (relative to root scale). Default is \ref MaxDepth + + * @details + * Creates a MRA object from pre-existing BoundingBox, @p bb, and ScalingBasis, @p sb, objects + * and the maximum amount of allowed refinement in a node, @p depth. + */ MultiResolutionAnalysis(const BoundingBox &bb, const ScalingBasis &sb, int depth = MaxDepth); + + /** + * @brief Copy constructor for a MultiResolutionAnalysis object composed of computational domain (world) and a polynomial basis (Multiwavelets) + * @param[in] mra Pre-existing MRA object + * @details Copy a MultiResolutionAnalysis object without modifying the original + */ MultiResolutionAnalysis(const MultiResolutionAnalysis &mra); + + /** @brief Deleted assignment (MRAs are intended to be immutable after construction). */ MultiResolutionAnalysis &operator=(const MultiResolutionAnalysis &mra) = delete; - int getOrder() const { return this->basis.getScalingOrder(); } - int getMaxDepth() const { return this->maxDepth; } - int getMaxScale() const { return this->world.getScale() + this->maxDepth; } + /* + * Getters + */ - const MWFilter &getFilter() const { return *this->filter; } - const ScalingBasis &getScalingBasis() const { return this->basis; } - const BoundingBox &getWorldBox() const { return this->world; } + int getOrder() const { return this->basis.getScalingOrder(); } ///< @return Polynomial order of the scaling basis + int getMaxDepth() const { return this->maxDepth; } ///< @return Maximum refinement depth relative to the world’s root scale + int getMaxScale() const { return this->world.getScale() + this->maxDepth; } ///< @return Sum of world root scale and maximum refinement depth, @ref getMaxDepth + int getRootScale() const { return this->world.getScale(); } ///< @return World root scale + const MWFilter &getFilter() const { return 
*this->filter; } ///< @return Low-level filter associated with the current basis + const ScalingBasis &getScalingBasis() const { return this->basis; } ///< @return Scaling basis type and order + const BoundingBox &getWorldBox() const { return this->world; } ///< @return Computational domain (world box) + + /** + * @brief Convenience: compute a minimal length scale from a tolerance + * @param[in] epsilon Target tolerance + * @return A distance proportional to \f$\sqrt{\epsilon\,2^{-\mathrm{maxScale}}}\f$. + */ double calcMinDistance(double epsilon) const { return std::sqrt(epsilon * std::pow(2.0, -getMaxScale())); } - double calcMaxDistance() const; - int getRootScale() const { return this->world.getScale(); } + /** + * @brief Convenience: compute a maximal relevant distance + * @return Maximum distance of computational (world) domain + * @note The exact definition is basis-dependent + */ + double calcMaxDistance() const; + /** + * @brief Equality operator for the MultiResolutionAnalysis class (basis, domain, depth) + * + * @param[in] mra: MRA object, taken by constant reference + * @returns True if both MRAs have the same polynomial basis, computational domain and maximum depth + * + * @note Two MRAs must be equal to allow mixing functions/operators + */ bool operator==(const MultiResolutionAnalysis &mra) const; + + /** + * @brief Inequality operator for the MultiResolutionAnalysis class (basis, domain, depth) + * @param[in] mra: MRA object, taken by constant reference + * @returns True if MRAs have different polynomial basis, computational domain or maximum depth + */ bool operator!=(const MultiResolutionAnalysis &mra) const; - void print() const; + void print() const; ///< @brief Displays human-readable diagnostics of MRA to outputfile protected: - const int maxDepth; - const ScalingBasis basis; - const BoundingBox world; - MWFilter *filter; + const int maxDepth; ///< Maximum refinement depth permitted by this MRA + const ScalingBasis basis; ///< Scaling basis 
(type and polynomial order) + const BoundingBox world; ///< Computational domain description + MWFilter *filter; ///< Low-level filter derived from @ref basis + /** + * @brief Internal helper to instantiate @ref filter based on @ref basis + */ void setupFilter(); }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/trees/NodeAllocator.cpp b/src/trees/NodeAllocator.cpp index cd3cf9fcc..7ba24fe17 100644 --- a/src/trees/NodeAllocator.cpp +++ b/src/trees/NodeAllocator.cpp @@ -215,7 +215,6 @@ template void NodeAllocator::appendChunk(bool coefs) { std::fill(this->stackStatus.begin() + oldsize, this->stackStatus.end(), 0); } -/** Fill all holes in the chunks with occupied nodes, then remove all empty chunks */ template int NodeAllocator::compress() { MRCPP_SET_OMP_LOCK(); int nNodes = (1 << D); @@ -358,7 +357,7 @@ template int NodeAllocator::findNextOccupied(int sIdx) return sIdx; } -/** Traverse tree and redefine pointer, counter and tables. */ + template void NodeAllocator::reassemble() { MRCPP_SET_OMP_LOCK(); this->nNodes = 0; diff --git a/src/trees/NodeAllocator.h b/src/trees/NodeAllocator.h index 7e33b7e21..a4b8f52b0 100644 --- a/src/trees/NodeAllocator.h +++ b/src/trees/NodeAllocator.h @@ -23,14 +23,6 @@ * */ -/** - * - * \date July, 2016 - * \author Peter Wind \n - * CTCC, University of Tromsø - * - */ - #pragma once #include @@ -40,71 +32,227 @@ namespace mrcpp { +/** + * @class NodeAllocator + * @tparam D Spatial dimension (1, 2, or 3) + * @tparam T Coefficient type (e.g. double, ComplexDouble) + * + * @brief Chunked memory manager for @ref MWNode objects and their coefficients + * + * @details + * Nodes and their coefficient arrays are organized in **chunks** to reduce + * allocation overhead and improve spatial locality. Indices into this storage + * are referred to as *serial indices* (`sIdx`), which are stable within a + * given tree instance until compaction or reassembly occurs. 
+ * + * ### Thread-safety + * When MRCPP is built with OpenMP support (`MRCPP_HAS_OMP`), critical regions + * in the allocator use locks to avoid races during allocation and pointer + * retrieval. Callers are still responsible for higher-level synchronization + * of tree edits. + */ template class NodeAllocator final { public: - NodeAllocator(OperatorTree *tree, SharedMemory *mem, int coefsPerNode, int nodesPerChunk); + /** + * @brief Construct an allocator bound to a function tree + * @param[in] tree Owning @ref FunctionTree instance + * @param[in] mem Optional shared-memory provider for coefficients (may be `nullptr`) + * @param[in] coefsPerNode Number of coefficients per node + * @param[in] nodesPerChunk Maximum number of nodes per chunk + * + * @details Reserves space for chunk pointers to avoid excessive reallocation, + * but does not allocate any chunks until needed. + */ NodeAllocator(FunctionTree *tree, SharedMemory *mem, int coefsPerNode, int nodesPerChunk); + + + /** + * @brief Construct an allocator bound to an operator tree + * @param[in] tree Owning @ref OperatorTree instance + * @param[in] mem Optional shared-memory provider for coefficients (may be `nullptr`) + * @param[in] coefsPerNode Number of coefficients per node + * @param[in] nodesPerChunk Maximum number of nodes per chunk + * + * @details Reserves space for chunk pointers to avoid excessive reallocation, + * but does not allocate any chunks until needed. + */ + NodeAllocator(OperatorTree *tree, SharedMemory *mem, int coefsPerNode, int nodesPerChunk); + + + /// Non-copyable. NodeAllocator(const NodeAllocator &tree) = delete; + /// Non-assignable. NodeAllocator &operator=(const NodeAllocator &tree) = delete; + + /// Destructor; releases all owned chunks (nodes and coefficients). ~NodeAllocator(); + + /** + * @brief Get pointer to a node object by serial index + * @param[in] sIdx Serial index of the node + * @return Pointer to the @ref MWNode instance. 
+ */ + MWNode *getNode_p(int sIdx); + + + /** + * @brief Get pointer to the coefficient array for a node + * @param[in] sIdx Serial index of the node + * @return Pointer to 'T[coefsPerNode]' or 'nullptr' if unavailable. + */ + T *getCoef_p(int sIdx); + + + + /** + * @brief Allocate a consecutive block of nodes + * @param[in] nNodes Number of nodes to allocate + * @param[in] coefs If 'true', also ensure coefficient storage is available + * @return Serial index ('sIdx') of the first newly allocated node (the top of the stack) + * + * @details Allocates a block of @p nNodes consecutive nodes, returning + * the serial index of the first node in the block. If `coefs` is true, + * coefficient arrays are also allocated for each node. If there is not + * enough space in existing chunks, new chunks are allocated to satisfy + * the request + * + * @warning Does not initialize the node objects; caller is responsible + * for placement-new or similar + * + * @warning If insufficient space is available, and allocation of new + * chunks fails, an exception is thrown and no nodes are allocated. + * + * @throw std::bad_alloc if memory allocation fails. + * @throw std::runtime_error if insufficient space is available after + * attempting to allocate new chunks. + * @note May grow the underlying chunk arrays if space is exhausted. + */ int alloc(int nNodes, bool coefs = true); + + /** + * @brief Deallocate a node at serial index + * @param[in] sIdx Serial index of the node to free + * @details Marks the node at serial index @p sIdx as free for future + * allocations. Does not destroy the node object or its coefficient array. + * It also updates the number of allocated nodes. + * + * @warning Does not shrink chunks; it only marks the slot as free. + * + * @throw std::out_of_range if @p sIdx is invalid. + */ void dealloc(int sIdx); + + /** + * @brief Deallocate coefficient arrays for all nodes + * @note Node objects remain allocated; only their coefficient buffers are freed. 
+ */ void deallocAllCoeff(); + /** + * @brief Pre-allocate a number of chunks + * @param[in] nChunks Number of chunks to append + * @param[in] coefs If 'true', allocate coefficient chunks as well + * + * @details It reinitializes the allocator, allocating @p nChunks chunks + * (both nodes and coefficients, if @p coefs is true). It resizes the + * stackStatus vectors with the new total capacity, and resets the + * allocation stack. + * + * @note This method clears any previously allocated nodes and + * their coefficient buffers. + * + * @throw If nChunks <= 0 + */ void init(int nChunks, bool coefs = true); + /** + * @brief Fill all holes in the chunks with occupied nodes, then remove all empty chunks + * @return Number of nodes deleted during compaction + * + * @details After compaction, serial indices may change internally; users + * should refresh any external mappings that depend on 'sIdx'. + */ int compress(); - void reassemble(); + + /** + * @brief Drop trailing unused chunks to release memory. + * @return Number of chunks deleted + * + * @details Scans chunks from the end towards the beginning, deleting any + * chunks that are completely unused. Stops when a chunk with at least + * one occupied node is found. 
+ */ int deleteUnusedChunks(); - int getNNodes() const { return this->nNodes; } - int getNCoefs() const { return this->coefsPerNode; } - int getNChunks() const { return this->nodeChunks.size(); } - int getNChunksUsed() const { return (this->topStack + this->maxNodesPerChunk - 1) / this->maxNodesPerChunk; } - int getNodeChunkSize() const { return this->maxNodesPerChunk * this->sizeOfNode; } - int getCoefChunkSize() const { return this->maxNodesPerChunk * this->coefsPerNode * sizeof(T); } - int getMaxNodesPerChunk() const { return this->maxNodesPerChunk; } + /** + * @brief Traverse tree and redefine pointer, counter and tables + * @details Typically invoked after operations that reorder nodes without + * using @ref compress + */ + void reassemble(); + + + + + int getNNodes() const { return this->nNodes; } ///< @return Number of nodes currently in use (allocated and not freed). + int getNCoefs() const { return this->coefsPerNode; } ///< @return Number of coefficients per node. + int getNChunks() const { return this->nodeChunks.size(); } ///< @return Total number of allocated node chunks. + int getNChunksUsed() const { return (this->topStack + this->maxNodesPerChunk - 1) / this->maxNodesPerChunk; } ///< @return Number of chunks currently used by active nodes. + int getNodeChunkSize() const { return this->maxNodesPerChunk * this->sizeOfNode; } ///< @return Size in bytes of one node chunk (nodes only). + int getCoefChunkSize() const { return this->maxNodesPerChunk * this->coefsPerNode * sizeof(T); } ///< @return Size in bytes of one coefficient chunk. + int getMaxNodesPerChunk() const { return this->maxNodesPerChunk; } ///< @return Maximum number of nodes that fit in a single chunk. + - T *getCoef_p(int sIdx); - MWNode *getNode_p(int sIdx); + /// @return Pointer to the i-th coefficient chunk (contiguous block). T *getCoefChunk(int i) { return this->coefChunks[i]; } + /// @return Pointer to the i-th node chunk (contiguous block). 
MWNode *getNodeChunk(int i) { return this->nodeChunks[i]; } + /// Print allocator status (chunks, usage, sizes) to stdout. void print() const; protected: - int nNodes{0}; // number of nodes actually in use - int topStack{0}; // index of last node on stack - int sizeOfNode{0}; // sizeof(NodeType) - int coefsPerNode{0}; // number of coef for one node - int maxNodesPerChunk{0}; // max number of nodes per allocation + int nNodes{0}; ///< Number of nodes actually in use. + int topStack{0}; ///< Index of the next free slot (stack top). + int sizeOfNode{0}; ///< `sizeof(NodeType)` used in chunks. + int coefsPerNode{0}; ///< Number of coefficients per node. + int maxNodesPerChunk{0}; ///< Capacity (in nodes) of each chunk. - std::vector stackStatus{}; - std::vector coefChunks{}; - std::vector *> nodeChunks{}; + std::vector stackStatus{}; ///< Slot state (occupied/free). + std::vector coefChunks{}; ///< Coefficient chunk base pointers. + std::vector *> nodeChunks{};///< Node chunk base pointers. - char *cvptr{nullptr}; // pointer to virtual table - MWNode *last_p{nullptr}; // pointer just after the last active node, i.e. where to put next node - MWTree *tree_p{nullptr}; // pointer to external object - SharedMemory *shmem_p{nullptr}; // pointer to external object + char *cvptr{nullptr}; ///< Vtable cookie to initialize node objects. + MWNode *last_p{nullptr}; ///< Pointer just past the last active node. + MWTree *tree_p{nullptr}; ///< Back-pointer to owning tree. + SharedMemory *shmem_p{nullptr}; ///< Optional shared-memory backend. + /// @return Whether coefficients are backed by @ref SharedMemory. bool isShared() const { return (this->shmem_p != nullptr); } + /// @return Owning tree (non-const). MWTree &getTree() { return *this->tree_p; } + /// @return Shared-memory provider (non-const). SharedMemory &getMemory() { return *this->shmem_p; } + /// Internal: get coefficient pointer w/o locking (caller must synchronize). 
T *getCoefNoLock(int sIdx); + /// Internal: get node pointer w/o locking (caller must synchronize). MWNode *getNodeNoLock(int sIdx); + /// Move a block of nodes within chunks (used by @ref compress). void moveNodes(int nNodes, int srcIdx, int dstIdx); + /// Append a new chunk; if `coefs` is true, also append a coefficient chunk. void appendChunk(bool coefs); + /// Find next contiguous range of free slots starting at or after `sIdx`. int findNextAvailable(int sIdx, int nNodes) const; + /// Find next occupied slot at or after `sIdx`. int findNextOccupied(int sIdx) const; #ifdef MRCPP_HAS_OMP - omp_lock_t omp_lock; + omp_lock_t omp_lock; ///< OpenMP lock for critical sections. #endif }; diff --git a/src/trees/NodeBox.h b/src/trees/NodeBox.h index 7a7fc086e..ceb83853c 100644 --- a/src/trees/NodeBox.h +++ b/src/trees/NodeBox.h @@ -30,33 +30,101 @@ namespace mrcpp { +/** + * @class NodeBox + * @tparam D Spatial dimension (1, 2, or 3) + * @tparam T Coefficient type (e.g. double, ComplexDouble) + * + * @brief Bounding box with node-pointer storage + */ template class NodeBox final : public BoundingBox { public: + /** + * @brief Construct a NodeBox from a lower-corner index and number of boxes + * @param idx Lower-corner @ref NodeIndex at the world scale + * @param nb Number of boxes per dimension (defaults to all ones) + */ NodeBox(const NodeIndex &idx, const std::array &nb = {}); + + /** + * @brief Copy-construct from another NodeBox + * @param box Source NodeBox + */ NodeBox(const NodeBox &box); + + /** + * @brief Construct from a plain @ref BoundingBox + * @param box Geometric box to take as base + */ NodeBox(const BoundingBox &box); + NodeBox &operator=(const NodeBox &box) = delete; + + /// @brief Destructor, deletes all nodes ~NodeBox() override; + /** + * @brief Store a node pointer in index @p idx + * @param idx Linear box index in `[0, size())` + * @param node Address of the node pointer to store (double pointer) + */ void setNode(int idx, MWNode **node); + + /** + 
* @brief Clear the node pointer stored at index @p idx + * @param idx Linear box index in `[0, size())` + */ void clearNode(int idx) { this->nodes[idx] = nullptr; } + /** + * @brief Get the node stored at the given index @p idx + * @param idx Node index at the world scale + * @return Reference to the node + */ MWNode &getNode(NodeIndex idx); + /** + * @brief Get the node stored at the box containing coordinate @p r. + * @param r Coordinates of a point + * @return Reference to the node + */ MWNode &getNode(Coord r); + /** + * @brief Get the node stored at the given index @p i + * @param i Linear box index (default 0) + * @return Reference to the node + */ MWNode &getNode(int i = 0); + /** + * @brief Get the node stored at the given index @p idx + * @param idx Node index at the world scale + * @return Reference to the node + */ const MWNode &getNode(NodeIndex idx) const; + /** + * @brief Get the node stored at the box containing coordinate @p r. + * @param r Coordinates of a point + * @return Reference to the node + */ const MWNode &getNode(Coord r) const; + /** + * @brief Get the node stored at the given index @p i + * @param i Linear box index (default 0) + * @return Reference to the node + */ const MWNode &getNode(int i = 0) const; - int getNOccupied() const { return this->nOccupied; } - MWNode **getNodes() { return this->nodes; } + int getNOccupied() const { return this->nOccupied; } ///< @return The number of occupied node slots + MWNode **getNodes() { return this->nodes; } ///< @return The nodes stored in this box protected: - int nOccupied; ///< Number of non-zero pointers in box - MWNode **nodes; ///< Container of nodes + int nOccupied; ///< Number of non-null entries in @ref nodes. + MWNode **nodes; ///< Dense array of node pointers (size equals number of boxes). 
+ /// @brief Allocate the node double pointers void allocNodePointers(); + + /// @brief Clear and delete all nodes void deleteNodes(); }; diff --git a/src/trees/NodeIndex.h b/src/trees/NodeIndex.h index f73ded001..7c7cf1081 100644 --- a/src/trees/NodeIndex.h +++ b/src/trees/NodeIndex.h @@ -23,46 +23,96 @@ * */ -/* - * \breif Simple storage class for scale and translation indexes. - * The usefulness of the class becomes evident when examining - * the parallel algorithms for projection & friends. - */ - #pragma once +#include #include #include namespace mrcpp { +/** + * @class NodeIndex + * @tparam D Spatial dimension (1, 2, or 3) + * @brief Storage class for scale and translation indexes + * + * @details + * A NodeIndex encodes the position of a node in a multiresolution tree by: + * - N: Scale of node stored as a short integer + * - L: D-dimensional translation vector of integers + * + * Provides helpers to obtain the parent/child indices, comparisons + * (including a strict weak ordering for associative containers), and utilities + * to test ancestry/sibling relations + * + * The usefulness of the class becomes evident when examining + * the parallel algorithms for projection & friends + */ template class NodeIndex final { public: - // regular constructors + /** + * @brief Regular constructor for NodeIndex + * @param[in] n Scale (defaults to zero) + * @param[in] l Translation vector with dimension D + * + * @details Casts n to a short int (N) and directly assigns L as l + */ NodeIndex(int n = 0, const std::array &l = {}) : N(static_cast(n)) , L(l) {} - - // relative constructors + /** + * @brief Relative constructor of the parent NodeIndex + * @return Parent NodeIndex + * + * @details Parents (N = N - 1) are obtained by floor rounding L/2 + */ NodeIndex parent() const { std::array l; for (int d = 0; d < D; d++) l[d] = (this->L[d] < 0) ? 
(this->L[d] - 1) / 2 : this->L[d] / 2; return NodeIndex(this->N - 1, l); } + /** + * @brief Relative constructor of child NodeIndex + * @param cIdx Child linear index + * @return Child NodeIndex + * + * @details Children (N = N + 1) are obtained by L = 2L + b, where @c b is given by the bits of @p cIdx + */ NodeIndex child(int cIdx) const { std::array l; for (int d = 0; d < D; d++) l[d] = (2 * this->L[d]) + ((cIdx >> d) & 1); return NodeIndex(this->N + 1, l); } - // comparisons + /** + * @brief Defines inequality operator + * @param[in] idx NodeIndex of the comparing node + * @return True if N and/or L are different + */ bool operator!=(const NodeIndex &idx) const { return not(*this == idx); } + /** + * @brief Defines equality operator + * @param[in] idx NodeIndex of comparing node + * @return True if both N and L are equal + */ bool operator==(const NodeIndex &idx) const { bool out = (this->N == idx.N); for (int d = 0; d < D; d++) out &= (this->L[d] == idx.L[d]); return out; } - // defines an order of the nodes (allows to use std::map) + /** + * @brief Defines comparison operator + * @param[in] idy NodeIndex of comparing node + * @return True if *this is smaller than idy + * + * @details + * Comparison rules (by order): + * 1. NodeIndex with smallest N is considered smallest + * 2. 
NodeIndex with the smaller first component of L is considered smaller + * + * @note + * Defines a strict weak ordering, which enables usage in std::map + */ bool operator<(const NodeIndex &idy) const { const NodeIndex &idx = *this; if (idx.N != idy.N) return idx.N < idy.N; @@ -71,19 +121,44 @@ template class NodeIndex final { return idx.L[2] < idy.L[2]; } - // setters - void setScale(int n) { this->N = static_cast(n); } - void setTranslation(const std::array &l) { this->L = l; } - - // value getters - int getScale() const { return this->N; } + /* + * Getters and setters + */ + + int getScale() const { return this->N; } ///< @return Scale of node + std::array getTranslation() const { return this->L; } ///< @return Full translation vector + void setScale(int n) { this->N = static_cast(n); } ///< @param n Scale of node + void setTranslation(const std::array &l) { this->L = l; } ///< @param l Translation vector of dimension D + + /** + * @brief Get a specific component of translation vector, L + * @param[in] d Index of wanted component + * @return Translation vector component @p d + */ int getTranslation(int d) const { return this->L[d]; } - std::array getTranslation() const { return this->L; } - // reference getters + /** + * @brief Define indexing operator of translation vector, L + * @param[in] d Index of wanted component + * @return Translation vector component @p d + */ int &operator[](int d) { return this->L[d]; } + + /** + * @brief Const version of @ref &operator[] + * @param[in] d Index of wanted component + * @return Translation vector component @p d + */ const int &operator[](int d) const { return this->L[d]; } + /** + * @brief Creates output stream of NodeIndex in readable format + * @param o Output stream + * @return A formatted version of @p o + * + * @details + * Prints NodeIndex on the form "[ N | L0, L1, ... 
]" + */ std::ostream &print(std::ostream &o) const { o << "[ " << std::setw(3) << this->N << " | "; for (int d = 0; d < D - 1; d++) o << std::setw(4) << this->L[d] << ", "; @@ -96,12 +171,25 @@ template class NodeIndex final { std::array L{}; ///< Translation index [x,y,z,...] }; -/** @brief ostream printer */ +/** + * @brief Defines operator for print of a @ref NodeIndex + * @param o Output stream + * @param[in] idx NodeIndex of wanted node + * @return Print stream of @ref NodeIndex + */ template std::ostream &operator<<(std::ostream &o, const NodeIndex &idx) { return idx.print(o); } -/** @brief Check whether indices are directly related (not sibling) */ +/** + * @brief Check whether two NodeIndices are directly related + * @param[in] a First NodeIndex + * @param[in] b Second NodeIndex + * @return True if related + * + * @details @p a and @p b are related if they follow the relation rules described + * in the @ref child() and @ref parent() constructors + */ template bool related(const NodeIndex &a, const NodeIndex &b) { const auto &sr = (a.getScale() < b.getScale()) ? a : b; const auto &jr = (a.getScale() >= b.getScale()) ? a : b; @@ -112,9 +200,14 @@ template bool related(const NodeIndex &a, const NodeIndex &b) { return related; } -/** @brief Check whether indices are siblings, i.e. same parent */ +/** + * @brief Check whether two NodeIndices are siblings, i.e. 
same parent + * @param[in] a First NodeIndex + * @param[in] b Second NodeIndex + * @return True if siblings + */ template bool siblings(const NodeIndex &a, const NodeIndex &b) { return (a.parent() == b.parent()); } -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/trees/OperatorNode.cpp b/src/trees/OperatorNode.cpp index 37f576eac..adcb3e0d1 100644 --- a/src/trees/OperatorNode.cpp +++ b/src/trees/OperatorNode.cpp @@ -42,17 +42,6 @@ void OperatorNode::dealloc() { this->tree->getNodeAllocator().dealloc(sIdx); } -/** - * @brief Calculate one specific component norm of the OperatorNode (TODO: needs to be specified more). - * - * @param[in] i: TODO: deens to be specified - * - * @details OperatorNorms are defined as matrix 2-norms that are expensive to calculate. - * Thus we calculate some cheaper upper bounds for this norm for thresholding. - * First a simple vector norm, then a product of the 1- and infinity-norm. - * (TODO: needs to be more presiced). - * - */ double OperatorNode::calcComponentNorm(int i) const { int depth = getDepth(); double prec = getOperTree().getNormPrecision(); @@ -79,20 +68,6 @@ double OperatorNode::calcComponentNorm(int i) const { return norm; } -/** @brief Matrix elements of the non-standard form. - * - * @param[in] i: Index enumerating the matrix type in the non-standard form. - * @returns A submatrix of \f$ (k + 1) \times (k + 1) \f$-size from the non-standard form. - * - * @details OperatorNode is uniquely associted with a scale \f$ n \f$ and translation - * \f$ l = -2^n + 1, \ldots, 2^n = 1 \f$. - * The non-standard form \f$ T_n, B_n, C_n, A_n \f$ defines matrices - * \f$ \sigma_l^n, \beta_l^n, \gamma_l^n, \alpha_l^n \f$ for a given pair \f$ (n, l) \f$. - * One of these matrices is returned by the method according to the choice of the index parameter - * \f$ i = 0, 1, 2, 3 \f$, respectively. - * For example, \f$ \alpha_l^n = \text{getComponent}(3) \f$. 
- * - */ MatrixXd OperatorNode::getComponent(int i) { int depth = getDepth(); double prec = getOperTree().getNormPrecision(); diff --git a/src/trees/OperatorNode.h b/src/trees/OperatorNode.h index f7d313d9d..842f35432 100644 --- a/src/trees/OperatorNode.h +++ b/src/trees/OperatorNode.h @@ -25,41 +25,125 @@ #pragma once +#include // for Eigen::MatrixXd + #include "MWNode.h" #include "OperatorTree.h" namespace mrcpp { +/** + * @class OperatorNode + * + * @brief Node of an @ref OperatorTree. + * + * @details + * An operator node is formally a 2D node which stores the coefficients of an operator + * for a given scale and translation. The translation in this case corresponds to the difference + * in translation index between input and output nodes of the function to which the operator is applied. + * The scaling and wavelet structure of the nodes encodes which component of the operator the coefficients + * refer to (T scaling-scaling, C scaling-wavelet, B wavelet-scaling, A wavelet-wavelet) according to the non-standard form. + * + * The class offers typed accessors to the owning @ref mrcpp::OperatorTree and + * to parent/children nodes, and overrides a few hooks related to allocation, + * child generation, and norm computation that are specific to operator nodes. + * + * @note The spatial dimension is fixed to @c D=2 for operator trees. 
+ */ class OperatorNode final : public MWNode<2> { public: - OperatorTree &getOperTree() { return static_cast(*this->tree); } - OperatorNode &getOperParent() { return static_cast(*this->parent); } - OperatorNode &getOperChild(int i) { return static_cast(*this->children[i]); } - - const OperatorTree &getOperTree() const { return static_cast(*this->tree); } - const OperatorNode &getOperParent() const { return static_cast(*this->parent); } - const OperatorNode &getOperChild(int i) const { return static_cast(*this->children[i]); } + OperatorTree &getOperTree() { return static_cast(*this->tree); } ///< @return Owning operator tree + OperatorNode &getOperParent() { return static_cast(*this->parent); } ///< @return Parent node + OperatorNode &getOperChild(int i) { return static_cast(*this->children[i]); } ///< @return Child @p i as @ref OperatorNode (non-const). + const OperatorTree &getOperTree() const { return static_cast(*this->tree); } ///< @return Owning operator tree + const OperatorNode &getOperParent() const { return static_cast(*this->parent); } ///< @return Parent node as @ref OperatorNode (const) + const OperatorNode &getOperChild(int i) const { return static_cast(*this->children[i]); } ///< @return Child @p i as @ref OperatorNode (const). + /** + * @brief Create child nodes + * + * @param coefs If @c true, also allocate coefficient storage for each child. + * + * @details Overrides @ref MWNode::createChildren to honor operator-specific + * allocation and bookkeeping. + */ void createChildren(bool coefs) override; + + /** + * @brief Generate child nodes and populates their coefficients. + * + * @details Overrides @ref MWNode::genChildren to implement the operator + */ void genChildren() override; + + /** + * @brief Delete all child nodes (and their coefficient storage) + * + * @details Overrides @ref MWNode::deleteChildren with operator-specific cleanup. 
+ */ void deleteChildren() override; friend class OperatorTree; friend class NodeAllocator<2>; protected: + /** + * @brief Default constructor (used by allocators). + */ OperatorNode() : MWNode<2>(){}; + /** + * @brief Root node constructor (used by allocators). + */ OperatorNode(MWTree<2> *tree, int rIdx) : MWNode<2>(tree, rIdx){}; + /** + * @brief Child node constructor (used by allocators). + */ OperatorNode(MWNode<2> *parent, int cIdx) : MWNode<2>(parent, cIdx){}; + /** + * @brief Operator nodes cannot be copied. + */ OperatorNode(const OperatorNode &node) = delete; + /** + * @brief Operator nodes cannot be assigned. + */ OperatorNode &operator=(const OperatorNode &node) = delete; + /** + * @brief Default destructor. + */ ~OperatorNode() = default; - + /** + * @brief Release coefficient storage (if owned) and reset node state. + * @details Overrides @ref MWNode::dealloc to ensure operator-node invariants. + */ void dealloc() override; + + /** + * @brief Calculate the norm of a given component of the OperatorNode + * + * @param[in] i: component index in [0, 3] (2D operator node has 4 components) + * + * @details OperatorNorms are defined as matrix 2-norms that are expensive to calculate. + * Thus we calculate some cheaper upper bounds for this norm for thresholding. + * First a simple vector norm, then a product of the 1- and infinity-norm. + */ double calcComponentNorm(int i) const override; + + /** @brief Gets a given component of the non-standard form. + * + * @param[in] i: Index enumerating the non-standard form component (T, B, C, A). + * @returns The requested \f$ (k + 1) \times (k + 1) \f$-size matrix of the non-standard form. + * + * @details OperatorNode is uniquely associated with a scale \f$ n \f$ and translation + * \f$ l = -2^n + 1, \ldots, 2^n - 1 \f$. + * The non-standard form \f$ T_n, B_n, C_n, A_n \f$ defines matrices + * \f$ \sigma_l^n, \beta_l^n, \gamma_l^n, \alpha_l^n \f$ for a given pair \f$ (n, l) \f$. 
+ * One of these matrices is returned by the method according to the choice of the index parameter + * \f$ i = 0, 1, 2, 3 \f$, respectively. + * For example, \f$ \alpha_l^n = \text{getComponent}(3) \f$. + */ Eigen::MatrixXd getComponent(int i); }; diff --git a/src/trees/OperatorTree.cpp b/src/trees/OperatorTree.cpp index 890f2677c..fb283a99a 100644 --- a/src/trees/OperatorTree.cpp +++ b/src/trees/OperatorTree.cpp @@ -98,14 +98,7 @@ void OperatorTree::clearBandWidth() { this->bandWidth = nullptr; } -/** @brief Calculates band widths of the non-standard form matrices. - * - * @param[in] prec: Precision used for thresholding - * - * @details It is starting from \f$ l = 0 \f$ and updating the band width value each time we encounter - * considerable value while keeping increasing \f$ l \f$, that stands for the distance to the diagonal. - * - */ + void OperatorTree::calcBandWidth(double prec) { if (this->bandWidth == nullptr) clearBandWidth(); this->bandWidth = new BandWidth(getDepth()); @@ -133,32 +126,12 @@ void OperatorTree::calcBandWidth(double prec) { println(100, "\nOperator BandWidth" << *this->bandWidth); } -/** @brief Checks if the distance to diagonal is bigger than the operator band width. - * - * @param[in] oTransl: distance to diagonal - * @param[in] o_depth: scaling order - * @param[in] idx: index corresponding to one of the matrices \f$ A, B, C \f$ or \f$ T \f$. - * - * @returns True if \b oTransl is outside of the band and False otherwise. - * - */ + bool OperatorTree::isOutsideBand(int oTransl, int o_depth, int idx) { return abs(oTransl) > this->bandWidth->getWidth(o_depth, idx); } -/** @brief Cleans up end nodes. - * - * @param[in] trust_scale: there is no cleaning down below \b trust_scale (it speeds up operator building). - * - * @details Traverses the tree and rewrites end nodes having branch node twins, - * i. e. identical with respect to scale and translation. 
- * This method is very handy, when an adaptive operator construction - * can make a significunt noise at low scaling depth. - * Its need comes from the fact that mwTransform up cannot override - * rubbish that can potentially stick to end nodes at a particular level, - * and as a result spread further up to the root with mwTransform. - * - */ + void OperatorTree::removeRoughScaleNoise(int trust_scale) { MWNode<2> *p_rubbish; // possibly inexact end node MWNode<2> *p_counterpart; // exact branch node @@ -191,12 +164,7 @@ void OperatorTree::getMaxTranslations(VectorXi &maxTransl) { } } -/** Make 1D lists, adressable from [-l, l] scale by scale, of operator node - * pointers for fast operator retrieval. This method is not thread safe, - * since it projects missing operator nodes on the fly. Hence, it must NEVER - * be called within a parallel region, or all hell will break loose. This is - * not really a problem, but you have been warned. - */ +// NOT THREAD SAFE void OperatorTree::setupOperNodeCache() { int nScales = this->nodesAtDepth.size(); int rootScale = this->getRootScale(); @@ -245,12 +213,7 @@ void OperatorTree::clearOperNodeCache() { } } -/** Regenerate all s/d-coeffs by backtransformation, starting at the bottom and - * thus purifying all coefficients. Option to overwrite or add up existing - * coefficients of BranchNodes (can be used after operator application). - * Reimplementation of MWTree::mwTransform() without OMP, as calculation - * of OperatorNorm is done using random vectors, which is non-deterministic - * in parallel. FunctionTrees should be fine. */ + void OperatorTree::mwTransformUp() { std::vector> nodeTable; tree_utils::make_node_table(*this, nodeTable); @@ -264,12 +227,7 @@ void OperatorTree::mwTransformUp() { } } -/** Regenerate all scaling coeffs by MW transformation of existing s/w-coeffs - * on coarser scales, starting at the rootNodes. 
Option to overwrite or add up - * existing scaling coefficients (can be used after operator application). - * Reimplementation of MWTree::mwTransform() without OMP, as calculation - * of OperatorNorm is done using random vectors, which is non-deterministic - * in parallel. FunctionTrees should be fine. */ + void OperatorTree::mwTransformDown(bool overwrite) { std::vector> nodeTable; tree_utils::make_node_table(*this, nodeTable); diff --git a/src/trees/OperatorTree.h b/src/trees/OperatorTree.h index 83be4789a..de80ce347 100644 --- a/src/trees/OperatorTree.h +++ b/src/trees/OperatorTree.h @@ -25,53 +25,183 @@ #pragma once +#include // for Eigen::VectorXi + #include "MWTree.h" #include "NodeAllocator.h" namespace mrcpp { +// Forward declarations to avoid including the full headers here. +class BandWidth; +class OperatorNode; + +/** + * @class OperatorTree + * @brief Base class for 2D operator trees in non-standard form + * + * @details + * The tree is organized like any MW tree (roots/branches/leaves) but stores + * operator coefficients. A per-depth **band width** (distance from the main + * diagonal in translation space) can be estimated to prune negligible corner + * blocks during application. 
+ */ class OperatorTree : public MWTree<2> { public: + /** + * @brief Construct an operator tree + * @param[in] mra Multi-resolution analysis (domain + basis) shared by the tree + * @param[in] np “Norm precision” used when estimating/screening norms + * @param[in] name Optional diagnostic name + */ OperatorTree(const MultiResolutionAnalysis<2> &mra, double np, const std::string &name = "nn"); + OperatorTree(const OperatorTree &tree) = delete; OperatorTree &operator=(const OperatorTree &tree) = delete; + + /// Virtual destructor virtual ~OperatorTree() override; + /// @return The precision value used for norm-based screening double getNormPrecision() const { return this->normPrec; } + /** + * @brief Release any existing @ref BandWidth object and set the pointer to null + * @details Call this if the operator has changed and band widths must be recomputed + */ void clearBandWidth(); + + /** @brief Calculates band widths of the non-standard form matrices + * + * @param[in] prec: Precision used for thresholding + * + * @details It is starting from \f$ l = 0 \f$ and updating the band width value each time we encounter + * considerable value while keeping increasing \f$ l \f$, that stands for the distance to the diagonal + */ virtual void calcBandWidth(double prec = -1.0); + + /** @brief Checks if the distance to diagonal is bigger than the operator band width. + * + * @param[in] oTransl: distance to diagonal + * @param[in] o_depth: scaling order + * @param[in] idx: index corresponding to one of the matrices \f$ A, B, C \f$ or \f$ T \f$. + * + * @returns True if \b oTransl is outside of the band and False otherwise + */ virtual bool isOutsideBand(int oTransl, int o_depth, int idx); + + /** @brief Cleans up end nodes. + * + * @param[in] trust_scale: there is no cleaning down below \b trust_scale (it speeds up operator building). + * + * @details Traverses the tree and rewrites end nodes having branch node twins, + * i. e. identical with respect to scale and translation.
+ * This method is very handy, when an adaptive operator construction + * can make a significant noise at low scaling depth. + * Its need comes from the fact that mwTransform up cannot override + * rubbish that can potentially stick to end nodes at a particular level, + * and as a result spread further up to the root with mwTransform. + */ void removeRoughScaleNoise(int trust_scale = 10); + /** + * @brief Make 1D lists, addressable from [-l, l] scale by scale, of operator node pointers for fast operator retrieval + * + * @details Populates @ref nodePtrStore and @ref nodePtrAccess to avoid repeated lookups. + * + * @warning This method is not thread safe, + * since it projects missing operator nodes on the fly. Hence, it must NEVER + * be called within a parallel region, or all hell will break loose. This is + * not really a problem, but you have been warned. + */ void setupOperNodeCache(); + + /// @brief Clear the operator-node caches built by @ref setupOperNodeCache() void clearOperNodeCache(); + /// @return Mutable reference to the stored @ref BandWidth (must exist) BandWidth &getBandWidth() { return *this->bandWidth; } + /// @return Const reference to the stored @ref BandWidth (must exist) const BandWidth &getBandWidth() const { return *this->bandWidth; } + /** + * @brief Fast accessor to a node by indices (scale, diagonal distance) + * + * @param[in] n Scale (depth measured from the root scale). + * @param[in] l Distance to the diagonal (translation difference); l=0 hits the diagonal + * + * @return Reference to the requested @ref OperatorNode. + * @warning Valid only after calling @ref setupOperNodeCache() + */ OperatorNode &getNode(int n, int l) { return *nodePtrAccess[n][l]; - } ///< TODO: It has to be specified more. - ///< \b l is distance to the diagonal.
+ } + /// @overload const OperatorNode &getNode(int n, int l) const { return *nodePtrAccess[n][l]; } - void mwTransformDown(bool overwrite) override; + + + + + + + + /** + * @brief Regenerate all s/d-coeffs by backtransformation, starting at the bottom and thus purifying all coefficients + * + * @note This method takes no arguments. + * + * @details Option to overwrite or add up existing + * coefficients of BranchNodes (can be used after operator application). + * Reimplementation of MWTree::mwTransform() without OMP, as calculation + * of OperatorNorm is done using random vectors, which is non-deterministic + * in parallel. FunctionTrees should be fine. + */ void mwTransformUp() override; + + /** + * @brief Regenerate all scaling coeffs by MW transformation of existing s/w-coeffs on coarser scales, starting at the rootNodes + * + * @param overwrite If @c true, child coefficients may overwrite existing scaling coefficients + * + * @details Option to overwrite or add up existing + * scaling coefficients (can be used after operator application). + * Reimplementation of MWTree::mwTransform() without OMP, as calculation + * of OperatorNorm is done using random vectors, which is non-deterministic + * in parallel. FunctionTrees should be fine. + */ void mwTransformDown(bool overwrite) override; + + + + + /// @overload using MWTree<2>::getNode; + /// @overload using MWTree<2>::findNode; protected: - const double normPrec; - BandWidth *bandWidth; - OperatorNode ***nodePtrStore; ///< Avoids tree lookups - OperatorNode ***nodePtrAccess; ///< Center (l=0) of node list + const double normPrec; ///< Default precision used in norm-based heuristics. + BandWidth *bandWidth; ///< Optional per-depth band-width description (owned). + /// @name Operator-node cache (built by @ref setupOperNodeCache()). + ///@{ + OperatorNode ***nodePtrStore; ///< Storage for contiguous (n,l) -> node pointers.
+ OperatorNode ***nodePtrAccess; ///< Centered view so that index l=0 addresses the diagonal. + ///@} + + /// @brief Allocate all root nodes required by the current world box. void allocRootNodes(); + + /** + * @brief Compute the maximum translation index at each depth. + * @param[out] maxTransl Vector whose @c d-th entry stores the maximum |l| at that depth. + */ void getMaxTranslations(Eigen::VectorXi &maxTransl); + /// @brief Human-readable dump of tree statistics. std::ostream &print(std::ostream &o) const override; }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/trees/TreeIterator.h b/src/trees/TreeIterator.h index 82ea49eb9..33eee6647 100644 --- a/src/trees/TreeIterator.h +++ b/src/trees/TreeIterator.h @@ -23,6 +23,42 @@ * */ +/** + * @file TreeIterator.h + * @brief Iteration helpers for traversing multiwavelet trees. + * + * @details + * This header provides a depth-aware iterator over the nodes of a @ref MWTree. + * It supports different **traversal directions** and **node-ordering schemes**, + * selected via constants defined in @c MRCPP/constants.h: + * - Traversal mode: @c TopDown or @c BottomUp + * In the @c TopDown mode, one iterates from the first root node and recursively + * over the children + * In the @c BottomUp mode, one first traverses the tree all the way down to the + * leaves and then starts iterating from there + * - Iterator type: @c Lebesgue (Z-order) or @c Hilbert + * + * The iterator yields @ref MWNode instances in the requested sequence determined by + * the parameters above + * + * The file contains two classes: @ref TreeIterator and @ref IteratorNode. + * The @ref TreeIterator is the main interface for users, while the @ref IteratorNode + * is mainly a placeholder for a few node-specific flags.
+ * + * @par Example + * @code{.cpp} + * using namespace mrcpp; + * TreeIterator<3,double> it(tree, TopDown, Lebesgue); + * it.setReturnGenNodes(true); // include generated nodes + * it.setMaxDepth(5); // restrict to depth <= 5 + * + * while (it.next()) { + * MWNode<3,double> &nd = it.getNode(); + * // ... inspect nd, read coefficients/norms, etc. + * } + * @endcode + */ + #pragma once #include "MRCPP/constants.h" @@ -30,56 +66,156 @@ namespace mrcpp { +/** + * @class TreeIterator + * + * @brief Iterator for traversing an @ref MWTree. + * + * @tparam D Spatial dimension (1, 2, or 3) + * @tparam T Coefficient type (e.g. double, ComplexDouble) + * + * @details + * The iterator traverses the tree starting from the root node(s), producing nodes + * according to: + * - a **traversal direction** ( @c TopDown or @c BottomUp), and + * - an **ordering scheme** within siblings ( @c Lebesgue or @c Hilbert). + * + * The behavior can be refined with: + * - @ref setReturnGenNodes() to toggle inclusion of generated (non-leaf) nodes, + * - @ref setMaxDepth() to limit the traversal depth, + * - @ref setTraverse() / @ref setIterator() to change policies at runtime. + * + * The iteration state is represented by an internal linked stack of + * @ref IteratorNode instances. + */ template class TreeIterator { public: + /** + * @brief Construct a detached iterator (no tree bound yet). + * + * @param traverse Traversal mode (e.g., @c TopDown or @c BottomUp). + * @param iterator Node-ordering mode (e.g., @c Lebesgue or @c Hilbert). + * + * @note Call @ref init() before the first @ref next() if you use this ctor. + */ TreeIterator(int traverse = TopDown, int iterator = Lebesgue); + /** + * @brief Construct an iterator bound to a tree. + * + * @param tree Tree to traverse. + * @param traverse Traversal mode (e.g., @c TopDown or @c BottomUp). + * @param iterator Node-ordering mode (e.g., @c Lebesgue or @c Hilbert).
+ */ TreeIterator(MWTree &tree, int traverse = TopDown, int iterator = Lebesgue); + /// @brief Destructor (releases internal traversal state). virtual ~TreeIterator(); - - void setReturnGenNodes(bool i = true) { this->returnGenNodes = i; } - void setMaxDepth(int depth) { this->maxDepth = depth; } - void setTraverse(int traverse); - void setIterator(int iterator); - + void setReturnGenNodes(bool i = true) { this->returnGenNodes = i; } ///< @param i If true, generated nodes are included in the sequence. + void setMaxDepth(int depth) { this->maxDepth = depth; } ///< @param depth Non-negative maximum depth; if negative, no limit is applied. + void setTraverse(int traverse);///< @param traverse set Traversal mode (@c TopDown or @c BottomUp). + void setIterator(int iterator);///< @param iterator set Iterator type (@c Lebesgue or @c Hilbert). + MWNode &getNode() { return *this->state->node; } ///< @return Reference to the node yielded by the last successful @ref next() / @ref nextParent(). + /** + * @brief Bind the iterator to a tree and reset traversal state. + * + * @param tree Tree to traverse. + */ void init(MWTree &tree); + /** + * @brief Advance to the next node according to the current policy. + * + * @return @c true if a node is available (use @ref getNode()), @c false when finished. + * + * @details + * if the current @ref IteratorNode is null, return false. + * In @c TopDown mode, try to return the current node first. + * If successful, return true. + * If not, check if the current node has children, and try to return + * the next child node according to the ordering scheme. + * If successful, return true. + * If not, try to move to the next root node, and return its first node + * according to the ordering scheme. + * If successful, return true. + * If not, in @c BottomUp mode, try to return the current node. + * If successful, return true. + * If not, remove the current state and recur invoking a new @ref next(). 
+ */ bool next(); + /** + * @brief Advance to the next parent node according to the current policy. + * + * @return @c true if the parent node is available, @c false when finished. + * + * @details + * Returns the current node or the parent of the current node. The logic makes sure the correct + * parent is returned according to the traversal mode and ordering scheme. In case of PBC calculations, + * the parent may be above the root nodes defining the unit cell. + */ bool nextParent(); - MWNode &getNode() { return *this->state->node; } friend class IteratorNode; protected: - int root; - int nRoots; - int mode; - int type; - int maxDepth; - bool returnGenNodes{true}; - IteratorNode *state; - IteratorNode *initialState; - - int getChildIndex(int i) const; + int root{0}; ///< Index of the current root box. + int nRoots{0}; ///< Number of root boxes in the tree. + int mode{TopDown}; ///< Traversal mode (@c TopDown or @c BottomUp). + int type{Lebesgue}; ///< Iterator type (@c Lebesgue or @c Hilbert). + int maxDepth{-1}; ///< Max depth limit; negative means unlimited. + bool returnGenNodes{true}; ///< If @c true, also return generated (non-leaf) nodes. + IteratorNode *state{nullptr}; ///< Current traversal frame. + IteratorNode *initialState{nullptr}; ///< Initial frame for the current root. - bool tryParent(); - bool tryChild(int i); - bool tryNode(); - bool tryNextRoot(); - bool tryNextRootParent(); - void removeState(); - bool checkDepth(const MWNode &node) const; - bool checkGenerated(const MWNode &node) const; + int getChildIndex(int i) const; ///< @brief Map logical child order [0..2^D) to actual child index based on @ref type. +/** + * @name try... methods + * @brief The following methods test if the node of a given type should be returned. + * @details In addition to returning @c true or @c false, these methods also update the internal + * traversal state accordingly. + * @{ + */ + bool tryParent(); ///< @return @c true if the parent node should be returned. 
+ bool tryChild(int i);///< @return @c true if the child at index @p i should be returned. + bool tryNode(); ///< @return @c true if the current node should be returned. + bool tryNextRoot(); ///< @return @c true if the next root node should be returned. + bool tryNextRootParent(); ///< @return @c true if the parent of the next root node is available and should be returned. +/** @} */ + void removeState(); ///< @brief Remove the current traversal frame from the stack. + bool checkDepth(const MWNode &node) const; ///< @return @c true if the node is within the max depth limit. + bool checkGenerated(const MWNode &node) const; ///< @return @c true if the generated nodes should be included. }; +/** + * @class IteratorNode + * @brief Iterator representing a node in the traversal stack. + * + * @tparam D Spatial dimension (1, 2, or 3) + * @tparam T Coefficient type (e.g. double, ComplexDouble) + * + * @details + * This is an internal placeholder which contains both the pointer to the actual node to return and + * flags to determine if itself, its parent and its children have been already returned. + * It contains: + * - a pointer to the node, + * - a link to the next node in the stack + * - completion flags for the current node, its parent, and its children. + */ template class IteratorNode final { public: - MWNode *node; - IteratorNode *next; - bool doneNode; - bool doneParent; - bool doneChild[1 << D]; + MWNode *node; ///< Current node. + IteratorNode *next; ///< Next node in the stack. + bool doneNode; ///< Whether the node itself has been used. + bool doneParent; ///< Whether the parent node has been used. + bool doneChild[1 << D]; ///< Whether each child has been used. + /** + * @brief Construct a new iterator + * + * @param nd Pointer to the MW node represented by this frame. + * @param nx Link to the next iterator (can be @c nullptr).
+ */ IteratorNode(MWNode *nd, IteratorNode *nx = nullptr); + + /// @brief Recursively delete the linked iterators that follow this one. ~IteratorNode() { delete this->next; } }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/utils/Bank.h b/src/utils/Bank.h index 69719c530..57ce028b3 100644 --- a/src/utils/Bank.h +++ b/src/utils/Bank.h @@ -1,4 +1,55 @@ +/* + * MRCPP, a numerical library based on multiresolution analysis and + * the multiwavelet basis which provide low-scaling algorithms as well as + * rigorous error control in numerical computations. + * Copyright (C) 2021 Stig Rune Jensen, Jonas Juselius, Luca Frediani and contributors. + * + * This file is part of MRCPP. + * + * MRCPP is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * MRCPP is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public License + * along with MRCPP. If not, see . + * + * For information on the complete list of contributors to MRCPP, see: + * + */ + #pragma once +/** + * @file + * @brief Distributed “Bank” service for sharing functions and raw data across MPI ranks. + * + * This header declares a minimal runtime that lets multiple MPI ranks exchange: + * - Multiresolution functions (`CompFunction<3>`) and + * - Raw numeric buffers (`double` / `ComplexDouble`) + * + * The service is organized as a central **Bank** that maintains per-client + * accounts. Each rank interacts with the Bank through a lightweight RAII + * client, **BankAccount**. 
A simple **TaskManager** piggybacks on the same + * infrastructure to distribute integer-indexed tasks and collect “ready” + * notifications. + * + * @par High-level design + * - The **Bank** lives on one or more designated MPI ranks (see `mpi::is_bank` + * in the runtime). Non-bank ranks act as clients. + * - Clients open an **account** and then `put_*` or `get_*` by integer IDs or + * by `NodeIndex<3>` keys. The Bank tracks sizes in kB for accounting. + * - The **TaskManager** provides a tiny work-queue: clients request the next + * task, mark tasks as ready, and optionally consume ready items. + * + * @note The concrete message-passing, blocking semantics, and memory + * ownership rules are implemented in the Bank source (MPI-based). + * This header documents intent and call contracts at a high level. + */ #include "CompFunction.h" #include "parallel.h" @@ -8,128 +59,473 @@ namespace mrcpp { using namespace mpi; +/** + * @brief A deposited item stored by the Bank. + * + * A deposit can represent either a multiresolution function (`orb`) or a + * raw data buffer (`data`). Exactly one of them is expected to be active + * for a given deposit. + */ struct deposit { - CompFunction<3> *orb; - double *data; // for pure data arrays - bool hasdata; - int datasize; + /** Pointer to a deposited function (3D component function). */ + CompFunction<3> *orb = nullptr; + /** Pointer to a deposited plain data buffer (contiguous). */ + double *data = nullptr; // for pure data arrays + /** True if this deposit contains a raw data buffer in @ref data. */ + bool hasdata = false; + /** Size (number of elements) for @ref data when @ref hasdata is true. */ + int datasize = 0; + /** Application-defined identifier used to name and retrieve this deposit. */ int id = -1; // to identify what is deposited - int source; // mpi rank from the source of the data + /** MPI rank that originally deposited the item. 
*/ + int source = 0; // mpi rank from the source of the data }; +/** + * @brief Queue bookkeeping for task-ready notifications. + * + * Associates a queue identifier with a list of client ranks that registered + * interest or contributed ready items. + */ struct queue_struct { - int id; + /** Queue identifier (application-defined). */ + int id = 0; + /** Ranks that have entries or are waiting on this queue. */ std::vector clients; }; +/** + * @brief Command codes exchanged between clients and the Bank/TaskManager. + * + * These enumerators are used as operation selectors in MPI messages. Listed + * values are stable and intentionally explicit to simplify debugging. + */ enum { - // (the values are used to interpret error messages) - CLOSE_BANK, // 0 - CLEAR_BANK, // 1 - NEW_ACCOUNT, // 2 - CLOSE_ACCOUNT, // 3 - GET_ORBITAL, // 4 - GET_FUNCTION_AND_WAIT, // 5 - GET_FUNCTION_AND_DELETE, // 6 - SAVE_ORBITAL, // 7 - GET_FUNCTION, // 8 - SAVE_FUNCTION, // 9 - GET_DATA, // 10 - SAVE_DATA, // 11 - SAVE_NODEDATA, // 12 - GET_NODEDATA, // 13 - GET_NODEBLOCK, // 14 - GET_ORBBLOCK, // 15 - CLEAR_BLOCKS, // 16 - GET_MAXTOTDATA, // 17 - GET_TOTDATA, // 18 - INIT_TASKS, // 19 - GET_NEXTTASK, // 20 - PUT_READYTASK, // 21 - DEL_READYTASK, // 22 - GET_READYTASK, // 23 - GET_READYTASK_DEL, // 24 + CLOSE_BANK, ///< 0 — Shut down the Bank service. + CLEAR_BANK, ///< 1 — Remove all accounts and deposits. + NEW_ACCOUNT, ///< 2 — Open a new client account. + CLOSE_ACCOUNT, ///< 3 — Close (delete) an existing account. + GET_ORBITAL, ///< 4 — Retrieve an orbital (internal legacy op). + GET_FUNCTION_AND_WAIT, ///< 5 — Blocking fetch of a function until available. + GET_FUNCTION_AND_DELETE, ///< 6 — Fetch and erase a function. + SAVE_ORBITAL, ///< 7 — Store an orbital (internal legacy op). + GET_FUNCTION, ///< 8 — Non-blocking fetch of a function if available. + SAVE_FUNCTION, ///< 9 — Store a function. + GET_DATA, ///< 10 — Fetch a raw data buffer. 
+ SAVE_DATA, ///< 11 — Store a raw data buffer. + SAVE_NODEDATA, ///< 12 — Store node-scoped raw data. + GET_NODEDATA, ///< 13 — Fetch node-scoped raw data. + GET_NODEBLOCK, ///< 14 — Fetch a contiguous block for a node id across IDs. + GET_ORBBLOCK, ///< 15 — Fetch a contiguous block for an orbital id across nodes. + CLEAR_BLOCKS, ///< 16 — Clear block caches/aggregations. + GET_MAXTOTDATA, ///< 17 — Query max total stored size (kB). + GET_TOTDATA, ///< 18 — Query per-account total sizes (kB). + INIT_TASKS, ///< 19 — Initialize TaskManager with N tasks. + GET_NEXTTASK, ///< 20 — Acquire the next task index. + PUT_READYTASK, ///< 21 — Mark (i,j) as ready. + DEL_READYTASK, ///< 22 — Remove a ready marker (i,j). + GET_READYTASK, ///< 23 — Get ready list for i (keep). + GET_READYTASK_DEL, ///< 24 — Get ready list for i and consume. }; +/** + * @brief Central repository for distributed function/data sharing and task queues. + * + * The Bank owns per-client accounts, holds deposits, and tracks memory usage + * in kB. Only designated Bank ranks instantiate and `open()` the service; + * clients interact via @ref BankAccount and @ref TaskManager on worker ranks. + * + * @par Thread-safety + * Bank methods are orchestrated via MPI; within a single rank, methods are not + * inherently thread-safe unless otherwise guarded at the call site. + */ class Bank { public: + /** @brief Construct an unopened Bank instance on a Bank rank. */ Bank() = default; + + /** @brief Destructor. Ensures resources are released if not already closed. */ ~Bank(); + + /** + * @brief Start the Bank service (receive loop, state init). + * + * Must be called on the Bank rank(s). After `open()`, the service listens + * for client commands and manages account state. + */ void open(); + + /** + * @brief Stop the Bank service and release all resources. + * + * Closes all accounts and clears all deposits. + */ void close(); + + /** + * @brief Maximum total footprint observed so far, in kB. 
+ * @return Peak cumulative size (across all accounts) since `open()`. + */ int get_maxtotalsize(); + + /** + * @brief Current total sizes per account, in kB. + * @return A vector of sizes aligned with internal account ordering. + */ std::vector get_totalsize(); private: friend class BankAccount; friend class TaskManager; - // used by BankAccount + // ---- Account control (called by clients through Bank's command loop) ---- + + /** + * @brief Create a new account for client rank @p iclient. + * @param iclient Rank creating the account (logical owner). + * @param comm Communicator the client uses to reach the Bank. + * @return Integer account identifier (>0 on success). + */ int openAccount(int iclient, MPI_Comm comm); - int clearAccount(int account, int iclient, MPI_Comm comm); // closes and open fresh account - void closeAccount(int account_id); // remove the account - // used by TaskManager; + /** + * @brief Clear and reinitialize an existing account. + * + * Equivalent to closing and reopening the account, preserving the account id. + * + * @param account Account id to clear. + * @param iclient Requesting client rank. + * @param comm Client communicator. + * @return 0 on success, negative on error. + */ + int clearAccount(int account, int iclient, MPI_Comm comm); + + /** + * @brief Permanently remove an account and all of its deposits. + * @param account_id Account identifier. + */ + void closeAccount(int account_id); + + // ---- Task manager control (internal) ---- + + /** + * @brief Initialize task bookkeeping for @p ntasks items. + * @param ntasks Number of tasks available (0..ntasks-1). + * @param iclient Requesting rank. + * @param comm Client communicator. + * @return Account id of the task manager context. + */ int openTaskManager(int ntasks, int iclient, MPI_Comm comm); + + /** + * @brief Close and remove a TaskManager context. + * @param account_id Associated account id. 
+ */ void closeTaskManager(int account_id); - // used internally by Bank; + // ---- Internal utilities ---- + + /** @brief Remove all accounts and deposits (global reset). */ void clear_bank(); - void remove_account(int account); // remove the content and the account - long long totcurrentsize = 0ll; // number of kB used by all accounts - std::vector accounts; // open bank accounts - std::map *> get_deposits; // gives deposits of an account + /** + * @brief Remove a single account and its content. + * @param account Account id to erase. + */ + void remove_account(int account); + + // ---- Accounting & indices ---- + long long totcurrentsize = 0ll; ///< Sum of all account sizes (kB). + std::vector accounts; ///< Active account ids. + + /** Map: account id → vector of deposits. */ + std::map *> get_deposits; + + /** Map: account id → (item id → index in deposits vector). */ std::map *> get_id2ix; + + /** Map: account id → (queue id → index in queue vector). */ std::map *> get_id2qu; - std::map *> get_queue; // gives deposits of an account - std::map> *> get_readytasks; // used by task manager - std::map currentsize; // total deposited data size (without containers) - long long maxsize = 0; // max total deposited data size (without containers) + + /** Map: account id → queue collection (task-ready queues). */ + std::map *> get_queue; + + /** Map: account id → (i → vector of j ready items). */ + std::map> *> get_readytasks; + + /** Map: account id → current size in kB (without container overhead). */ + std::map currentsize; + + /** Peak total size (kB) observed since last reset. */ + long long maxsize = 0; }; +/** + * @brief RAII client-side view of a Bank account. + * + * A `BankAccount` encapsulates a live account and offers typed methods to + * deposit and retrieve functions or raw buffers. Most methods are thin + * request wrappers; the Bank performs the actual storage. 
+ * + * @note By default, rank and communicator are taken from the MPI worker + * context (`mpi::wrk_rank`, `mpi::comm_wrk`). + * + * @par Ownership & lifetime + * Returned raw pointers (e.g., from `get_orbblock`) typically reference + * storage owned by the Bank. Callers should copy data if it must outlive + * subsequent Bank interactions. See the implementation for exact details. + */ class BankAccount { public: + /** + * @brief Open a new account for @p iclient on communicator @p comm. + * @param iclient Client rank that owns the account (default: @ref mpi::wrk_rank). + * @param comm Communicator used to contact the Bank (default: @ref mpi::comm_wrk). + */ BankAccount(int iclient = wrk_rank, MPI_Comm comm = comm_wrk); + + /** @brief Close the account and release any client-side resources. */ ~BankAccount(); + + /** @brief Bank-assigned account identifier (≥0 when open). */ int account_id = -1; + + /** + * @brief Clear and reinitialize this account. + * @param i Client rank issuing the request. + * @param comm Client communicator. + */ void clear(int i = wrk_rank, MPI_Comm comm = comm_wrk); - // int put_orb(int id, ComplexFunction &orb); - // int get_orb(int id, ComplexFunction &orb, int wait = 0); + + // --- Function storage/retrieval --- + + /** + * @brief Fetch a function by @p id and delete it on the Bank. + * @param id Application-level identifier of the function. + * @param orb Output destination; resized/assigned by the Bank. + * @return 0 on success, negative on error. + */ int get_func_del(int id, CompFunction<3> &orb); + + /** + * @brief Deposit a function under identifier @p id. + * @param id Application-level identifier. + * @param func Function object to store (copied/serialized by Bank). + * @return 0 on success, negative on error. + */ int put_func(int id, CompFunction<3> &func); + + /** + * @brief Fetch a function by @p id. + * @param id Application-level identifier. + * @param func Output destination; resized/assigned by the Bank. 
+ * @param wait If nonzero, block until available; otherwise return immediately if missing. + * @return 0 on success; negative on error; positive (e.g. 1) if not found and @p wait==0. + */ int get_func(int id, CompFunction<3> &func, int wait = 0); + + // --- Raw data buffers by plain id --- + + /** + * @brief Deposit a real-valued buffer. + * @param id Application-level identifier. + * @param size Number of elements in @p data. + * @param data Pointer to contiguous buffer (copied by the Bank). + * @return 0 on success, negative on error. + */ int put_data(int id, int size, double *data); + + /** + * @brief Deposit a complex-valued buffer. + * @copydetails put_data(int,int,double*) + */ int put_data(int id, int size, ComplexDouble *data); + + /** + * @brief Retrieve a real-valued buffer by @p id. + * @param id Identifier previously used with @ref put_data. + * @param size Expected number of elements; used for validation. + * @param data Destination buffer; must have room for @p size elements. + * @return 0 on success; negative on error; positive if not found. + */ int get_data(int id, int size, double *data); + + /** + * @brief Retrieve a complex-valued buffer by @p id. + * @copydetails get_data(int,int,double*) + */ int get_data(int id, int size, ComplexDouble *data); + + // --- Raw data scoped by node index (spatial addressing) --- + + /** + * @brief Deposit real-valued data associated with a node index. + * @param nIdx Spatial node key. + * @param size Number of elements. + * @param data Buffer pointer (copied by the Bank). + * @return 0 on success, negative on error. + */ int put_data(NodeIndex<3> nIdx, int size, double *data); + + /** + * @brief Deposit complex-valued data for a node index. + * @copydetails put_data(NodeIndex<3>,int,double*) + */ int put_data(NodeIndex<3> nIdx, int size, ComplexDouble *data); + + /** + * @brief Retrieve real-valued data for a node index. + * @param nIdx Node key. + * @param size Expected number of elements. 
+ * @param data Output buffer. + * @return 0 on success; negative on error; positive if not found. + */ int get_data(NodeIndex<3> nIdx, int size, double *data); + + /** + * @brief Retrieve complex-valued data for a node index. + * @copydetails get_data(NodeIndex<3>,int,double*) + */ int get_data(NodeIndex<3> nIdx, int size, ComplexDouble *data); + + // --- Node-scoped data grouped under an object id (e.g., orbital id) --- + + /** + * @brief Deposit real-valued data for a specific node @p nodeid within object @p id. + * @param id Object (e.g., orbital) identifier. + * @param nodeid Node identifier within the object. + * @param size Number of elements. + * @param data Buffer pointer (copied by the Bank). + * @return 0 on success, negative on error. + */ int put_nodedata(int id, int nodeid, int size, double *data); + + /** + * @brief Deposit complex-valued data for a node within object @p id. + * @copydetails put_nodedata(int,int,int,double*) + */ int put_nodedata(int id, int nodeid, int size, ComplexDouble *data); + + /** + * @brief Retrieve real-valued data for (@p id, @p nodeid). + * @param id Object identifier. + * @param nodeid Node identifier. + * @param size Expected element count. + * @param data Output buffer. + * @param idVec (Out) List of object ids actually present in the block, if aggregated. + * @return 0 on success; negative on error; positive if not found. + */ int get_nodedata(int id, int nodeid, int size, double *data, std::vector &idVec); + + /** + * @brief Retrieve complex-valued data for (@p id, @p nodeid). + * @copydetails get_nodedata(int,int,int,double*,std::vector&) + */ int get_nodedata(int id, int nodeid, int size, ComplexDouble *data, std::vector &idVec); + + // --- Block retrieval helpers --- + + /** + * @brief Retrieve a contiguous block of all real node data for @p nodeid across ids. + * @param nodeid Node identifier to gather. + * @param data (Out) Pointer to contiguous storage; copy data before next call. 
+ * @param idVec (Out) List of ids participating in the block. + * @return Number of elements in @p data on success; negative on error. + */ int get_nodeblock(int nodeid, double *data, std::vector &idVec); + + /** + * @brief Retrieve a contiguous block of all complex node data for @p nodeid across ids. + * @copydetails get_nodeblock(int,double*,std::vector&) + */ int get_nodeblock(int nodeid, ComplexDouble *data, std::vector &idVec); + + /** + * @brief Retrieve all real-valued node data for an orbital id into a single contiguous block. + * @param orbid Orbital (object) id. + * @param data (Out) Pointer reference to contiguous storage. + * @param nodeidVec (Out) Node ids represented in the block. + * @param bankstart Starting index/offset within the Bank’s internal storage. + * @return Number of elements in the returned block; negative on error. + * + * @note Copy out the data if it must persist beyond this call or subsequent Bank calls. + */ int get_orbblock(int orbid, double *&data, std::vector &nodeidVec, int bankstart); + + /** + * @brief Retrieve all complex-valued node data for an orbital id into a contiguous block. + * @copydetails get_orbblock(int,double*&,std::vector&,int) + */ int get_orbblock(int orbid, ComplexDouble *&data, std::vector &nodeidVec, int bankstart); }; +/** + * @brief Minimal distributed task queue associated with a Bank account. + * + * The TaskManager assigns task indices in [0, @ref n_tasks). Clients can: + * - Request the next task (`next_task()`), + * - Mark specific items as ready (`put_readytask(i,j)` / `del_readytask(i,j)`), + * - Retrieve ready lists (`get_readytask(i, del)`). + * + * The actual synchronization and distribution are performed by the Bank. + */ class TaskManager { public: + /** + * @brief Construct and initialize a task context with @p ntasks tasks. + * @param ntasks Total number of tasks available (0..ntasks-1). + * @param iclient Client rank that opens the context. + * @param comm Communicator for Bank interaction. 
+ */ TaskManager(int ntasks, int iclient = wrk_rank, MPI_Comm comm = comm_wrk); + + /** @brief Destructor; closes the TaskManager context. */ ~TaskManager(); + + /** + * @brief Obtain the next task index to process. + * @return Task index in [0, @ref n_tasks), or negative if none available. + */ int next_task(); + + /** + * @brief Mark item (@p i, @p j) as ready. + * @param i Primary key (e.g., task group/channel). + * @param j Secondary key (e.g., item id). + */ void put_readytask(int i, int j); + + /** + * @brief Remove ready marker (@p i, @p j). + * @param i Primary key. + * @param j Secondary key. + */ void del_readytask(int i, int j); + + /** + * @brief Retrieve the ready list for key @p i. + * @param i Primary key (queue id). + * @param del If nonzero, consume (erase) the ready list; otherwise keep it. + * @return Vector of secondary keys (j values) that are ready. + */ std::vector get_readytask(int i, int del); + + /** @brief Bank account id associated with this task context. */ int account_id = -1; - int task = 0; // used in serial case only - int n_tasks = 0; // used in serial case only + + /** @name Serial fallbacks + * These are used if the runtime is not using MPI distribution. + * @{ */ + int task = 0; ///< Current task pointer (serial mode only). + int n_tasks = 0; ///< Total tasks (serial mode only). + /** @} */ }; +/** + * @brief Fixed size of control messages exchanged with the Bank. + * + * @details This constant is used by the MPI layer to size control payloads. + */ int const message_size = 7; } // namespace mrcpp diff --git a/src/utils/CompFunction.h b/src/utils/CompFunction.h index 33011bd98..6d43217dd 100644 --- a/src/utils/CompFunction.h +++ b/src/utils/CompFunction.h @@ -1,4 +1,53 @@ +/* + * MRCPP, a numerical library based on multiresolution analysis and + * the multiwavelet basis which provide low-scaling algorithms as well as + * rigorous error control in numerical computations. 
+ * Copyright (C) 2021 Stig Rune Jensen, Jonas Juselius, Luca Frediani and contributors. + * + * This file is part of MRCPP. + * + * MRCPP is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * MRCPP is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public License + * along with MRCPP. If not, see <https://www.gnu.org/licenses/>. + * + * For information on the complete list of contributors to MRCPP, see: + * <https://mrcpp.readthedocs.io/> + */ + #pragma once +/** + * @file + * @brief Composite multicomponent function types (real/complex) on MRCPP multiresolution trees. + * + * This header defines: + * - @ref CompFunctionData : POD metadata describing a multicomponent function. + * - @ref TreePtr : Small owning handle to up to four component trees (real/complex), + * with optional MPI shared-memory backing. + * - @ref CompFunction : A high-level wrapper that owns/addresses component trees, + * provides algebra (add/multiply/dot), projection, scaling, + * norms, and utilities. + * - Helpers for deep copies, linear combinations, products, projections, and orthogonalization. + * - @ref CompFunctionVector : Convenience container for 3D functions with utilities for + * rotations and overlap matrices. + * + * Components are stored as MRCPP @ref FunctionTree "FunctionTree" instances. + * Both real (`double`) and complex (`ComplexDouble`) representations are supported. + * + * Parallel notes: + * - If built with MPI and `is_shared == true`, @ref TreePtr can allocate backing storage + * in an MPI shared-memory window (per @ref mpi::comm_share) to reduce duplication. 
+ * - Distribution utilities (e.g., @ref CompFunctionVector::distribute) use the runtime in + * `mpi_utils.h`. + */ #include "mpi_utils.h" #include "trees/FunctionTreeVector.h" @@ -7,46 +56,83 @@ using namespace Eigen; namespace mrcpp { +/** + * @brief Lightweight, trivially-copiable metadata for a multicomponent function. + * + * This POD accompanies the trees comprising a @ref CompFunction. It holds flags about + * real/complex storage, conjugation, component counts, user-defined labels, and on-disk + * layout hints. Arrays have fixed size (4) to simplify MPI packing and shallow copies. + * + * @tparam D Spatial dimension of the function (1–3 supported by MRCPP). + */ template struct CompFunctionData { - // additional data that describe the overall multicomponent function (defined by user): - // occupancy, quantum number, norm, etc. - int Ncomp{0}; // number of components defined - int rank{-1}; // rank (index) if part of a vector - int conj{0}; // soft conjugate (all components) - int CompFn1{0}; - int CompFn2{0}; - int isreal{0}; // trees are defined for T=double - int iscomplex{0}; // trees are defined for T=DoubleComplex - double CompFd1{0.0}; - double CompFd2{0.0}; - double CompFd3{0.0}; - // additional data that describe each component (defined by user): - // occupancy, quantum number, norm, etc. - // Note: defined with fixed size to ease copying and MPI send - int n1[4]{0, 0, 0, 0}; // 0: neutral. otherwise different values are orthogonal to each other (product = 0) - int n2[4]{0, 0, 0, 0}; - int n3[4]{0, 0, 0, 0}; - int n4[4]{0, 0, 0, 0}; - // multiplicative scalar for the function. So far only actively used to take care of imag factor in momentum operator. + /** @name Global function descriptors (user-defined) */ + ///@{ + int Ncomp{0}; ///< Number of components actually defined/allocated (0–4). + int rank{-1}; ///< Rank (index) inside an external vector or basis set. + int conj{0}; ///< Soft-conjugate flag for algebra (applied to all components). 
+ int CompFn1{0}; ///< Free integer tag (user purpose). + int CompFn2{0}; ///< Free integer tag (user purpose). + int isreal{0}; ///< 1 if component trees are real-valued (`T=double`). + int iscomplex{0}; ///< 1 if component trees are complex-valued (`T=ComplexDouble`). + double CompFd1{0.0};///< Free double tag (user purpose). + double CompFd2{0.0};///< Free double tag (user purpose). + double CompFd3{0.0};///< Free double tag (user purpose). + ///@} + + /** @name Per-component user metadata (fixed-size slots 0..3) */ + ///@{ + int n1[4]{0, 0, 0, 0}; ///< Integer label; unequal labels are treated orthogonal in some workflows. + int n2[4]{0, 0, 0, 0}; ///< Additional integer label (user purpose). + int n3[4]{0, 0, 0, 0}; ///< Additional integer label (user purpose). + int n4[4]{0, 0, 0, 0}; ///< Additional integer label (user purpose). + + /** + * @brief Per-component multiplicative factor. + * + * Often used to carry factors like *i* for momentum-like operators without + * explicitly modifying stored coefficients. + */ ComplexDouble c1[4]{{1.0, 0.0}, {1.0, 0.0}, {1.0, 0.0}, {1.0, 0.0}}; - double d1[4]{0.0, 0.0, 0.0, 0.0}; - double d2[4]{0.0, 0.0, 0.0, 0.0}; - double d3[4]{0.0, 0.0, 0.0, 0.0}; - // used for storage on disk - int type{0}; - int order{1}; - int scale{0}; - int depth{0}; - int boxes[3] = {0, 0, 0}; - int corner[3] = {0, 0, 0}; - - // used internally - int shared{0}; - int Nchunks[4]{0, 0, 0, 0}; // number of chunks of each component tree + + double d1[4]{0.0, 0.0, 0.0, 0.0}; ///< Free double tag (user purpose) per component. + double d2[4]{0.0, 0.0, 0.0, 0.0}; ///< Free double tag (user purpose) per component. + double d3[4]{0.0, 0.0, 0.0, 0.0}; ///< Free double tag (user purpose) per component. + ///@} + + /** @name On-disk/storage layout hints (optional) */ + ///@{ + int type{0}; ///< Serialization type code. + int order{1}; ///< Polynomial order or filter order hint. + int scale{0}; ///< Root scale / global scale offset. 
+ int depth{0}; ///< Max depth. + int boxes[3] = {0, 0, 0}; ///< Root box tiling (D components used). + int corner[3] = {0, 0, 0}; ///< Root spatial corner (D components used). + ///@} + + /** @name Internal runtime fields */ + ///@{ + int shared{0}; ///< 1 if this function uses shared-memory trees. + int Nchunks[4]{0, 0, 0, 0}; ///< Chunk count for each component (used for MPI shipping). + ///@} }; +/** + * @brief Owning pointer wrapper for up to four component trees (real and/or complex). + * + * Optionally allocates per-communicator shared memory windows when constructed + * with @p share = true and MPI shared memory is available (see @ref mpi::comm_share + * and @ref mpi::shared_memory_size). + * + * @tparam D Spatial dimension (1–3). + */ template class TreePtr final { public: + /** + * @brief Construct an empty handle. + * @param share If true and MPI is enabled, create shared-memory windows + * for real and complex storage sized per @ref mpi::shared_memory_size (MB). + */ explicit TreePtr(bool share) : shared_mem_real(nullptr) , shared_mem_cplx(nullptr) { @@ -62,6 +148,7 @@ template class TreePtr final { } } + /// Destructor: frees shared windows and any allocated trees. ~TreePtr() { if (this->shared_mem_real != nullptr) delete this->shared_mem_real; if (this->shared_mem_cplx != nullptr) delete this->shared_mem_cplx; @@ -72,43 +159,104 @@ template class TreePtr final { this->cplx[i] = nullptr; } } - CompFunctionData data; - int &Ncomp = data.Ncomp; // number of components defined - int &rank = data.rank; // rank (index) if part of a vector - int &conj = data.conj; // soft conjugate - int &isreal = data.isreal; // T=double - int &iscomplex = data.iscomplex; // T=DoubleComplex - int &share = data.shared; - int *Nchunks = data.Nchunks; + /** @name Metadata forwarding (aliases into @ref data) */ + ///@{ + CompFunctionData data; ///< Attached function metadata. + int &Ncomp = data.Ncomp; ///< Number of active components. 
+ int &rank = data.rank; ///< External rank/index tag. + int &conj = data.conj; ///< Soft conjugation flag. + int &isreal = data.isreal; ///< Real storage flag. + int &iscomplex = data.iscomplex; ///< Complex storage flag. + int &share = data.shared; ///< Shared-memory flag. + int *Nchunks = data.Nchunks; ///< Per-component chunk counts. + ///@} + + /** True if shared-memory windows were requested/allocated. */ bool is_shared = false; + friend class CompFunction; protected: - FunctionTree *real[4]; // Real function - FunctionTree *cplx[4]; // Complex function + /** Component trees (owned). Slots 0..3 are valid when @ref Ncomp > slot. */ + FunctionTree *real[4]; ///< Real components. + FunctionTree *cplx[4]; ///< Complex components. + + /** Optional backing shared-memory windows (one per value type). */ SharedMemory *shared_mem_real; SharedMemory *shared_mem_cplx; }; +/** + * @brief High-level multicomponent function wrapper on MRCPP trees. + * + * A @ref CompFunction manages up to four component trees, either real or complex, + * and exposes utilities such as allocation, projection, algebraic operations, + * normalization, conjugation, and data shipping. + * + * The class shares its internal state through a `std::shared_ptr>` + * to enable lightweight copies and move semantics, while retaining clear + * ownership of the underlying trees. + * + * @tparam D Spatial dimension (1–3). + */ template class CompFunction { public: + /** + * @name Construction + * Constructors optionally attach an @ref MultiResolutionAnalysis context, + * choose component count, and enable shared memory. + */ + ///@{ + /** @brief Construct empty function bound to @p mra (no components allocated). */ CompFunction(MultiResolutionAnalysis &mra); + /** @brief Construct unbound/empty function (MRA set later via allocation). */ CompFunction(); + /** @brief Construct with @p n1 components (0..4). */ CompFunction(int n1); + /** + * @brief Construct with @p n1 components and shared-memory preference. 
+ * @param n1 Number of components to allocate (0..4). + * @param share If true, try to use MPI shared memory for tree storage. + */ CompFunction(int n1, bool share); + /** + * @brief Construct from metadata @p indata. + * @param indata Initial metadata (copied). + * @param alloc If true, allocate trees according to @p indata.Ncomp. + */ CompFunction(const CompFunctionData &indata, bool alloc = false); + /** @brief Copy constructor: shares underlying pointer (trees may be deep-copied by helpers). */ CompFunction(const CompFunction &compfunc); + /** @brief Move constructor. */ CompFunction(CompFunction &&compfunc); + /** @brief Copy assignment. */ CompFunction &operator=(const CompFunction &compfunc); + ///@} + + /** Virtual destructor. Trees are owned by the shared @ref TreePtr and freed accordingly. */ virtual ~CompFunction() = default; - FunctionTree **CompD; // = func_ptr->real so that we can use name CompD instead of func_ptr.real - FunctionTree **CompC; // = func_ptr->cplx + /** @name Raw component access (compatibility aliases) */ + ///@{ + /** + * @brief Pointer-to-array of real component trees (alias of internal storage). + * @warning Valid only when @ref isreal() is true. + */ + FunctionTree **CompD; + /** + * @brief Pointer-to-array of complex component trees (alias of internal storage). + * @warning Valid only when @ref iscomplex() is true. + */ + FunctionTree **CompC; + ///@} + /** Optional human-readable name. */ std::string name; - // additional data that describe each component (defined by user): + /** @name Metadata accessors */ + ///@{ + /** @brief Return a copy of the current metadata. 
*/ CompFunctionData data() const { return func_ptr->data; } int Ncomp() const { return func_ptr->data.Ncomp; } // number of components defined int rank() const { return func_ptr->data.rank; } // rank (index) if part of a vector @@ -120,86 +268,341 @@ template class CompFunction { void defcomplex() { func_ptr->data.isreal = 0; // define as complex func_ptr->data.iscomplex = 1;} int share() const { return func_ptr->data.shared; } - int *Nchunks() const { return func_ptr->data.Nchunks; } // number of chunks of each component tree + /** @return Per-component chunk counts (used for MPI shipping). */ + int *Nchunks() const { return func_ptr->data.Nchunks; } + + /** + * @brief Copy metadata and optionally allocate components (without copying tree data). + * @param alloc If true, allocate tree containers for the copied component count. + * @return A new @ref CompFunction sharing no nodes/coefficients with the source. + */ CompFunction paramCopy(bool alloc = false) const; + + /** + * @brief Integrate the function over the domain. + * @return Complex integral (real-only functions return real part in `.real()`). + */ ComplexDouble integrate() const; + + /** + * @brief L2 norm of the function. + * @return \f$\|f\|_2\f$ as a double. + */ double norm() const; + + /** + * @brief Square L2 norm of the function. + * @return \f$\|f\|_2^2\f$ as a double. + */ double getSquareNorm() const; + + /** + * @brief Allocate @p nalloc component trees. + * @param nalloc Number of components (0..4). Existing components preserved if possible. + * @param zero If true, initialize coefficients to zero. + */ void alloc(int nalloc = 1, bool zero = true); - void alloc_comp(int i = 0); // allocate one specific component + + /** + * @brief Allocate a single component tree. + * @param i Component index (0..3). + */ + void alloc_comp(int i = 0); + + /** + * @brief Attach an externally created real tree as component @p i. + * @param tree Ownership is transferred to this object. 
+ * @param i Component index (0..3). + */ void setReal(FunctionTree *tree, int i = 0); + + /** + * @brief Attach an externally created complex tree as component @p i. + * @copydetails setReal + */ void setCplx(FunctionTree *tree, int i = 0); + + /** @brief Set/get external rank/index label. */ void setRank(int i) { func_ptr->rank = i; }; const int getRank() const { return func_ptr->rank; }; + + /** + * @brief In-place linear update: @f$f \gets f + c \, g@f$. + * @param c Complex scalar. + * @param inp Addend function (components must be layout-compatible). + */ void add(ComplexDouble c, CompFunction inp); + /** + * @brief Remove coefficients/nodes below precision @p prec. + * @param prec Relative (or absolute) precision threshold. + * @return Number of removed nodes or a non-negative status. + */ int crop(double prec); + + /** + * @brief Multiply the entire function by a complex scalar in-place. + * @param c Complex factor. + */ void rescale(ComplexDouble c); + + /** + * @brief Release all component trees and reset to empty. + * + * Metadata is preserved unless tied to tree content. + */ void free(); + + /** @return Total memory footprint of nodes (bytes or implementation-defined units). */ int getSizeNodes() const; + + /** @return Total number of nodes across all component trees. */ int getNNodes() const; + + /** @brief Flush cached MRA-level data (filters, norms) from component trees. */ void flushMRAData(); + + /** @brief Flush cached function-level data (aux norms, temporaries). */ void flushFuncData(); + + /** @brief Snapshot of the current function metadata (same as @ref data()). */ CompFunctionData getFuncData() const; + + /** @name Component accessors (non-const/const). */ + ///@{ FunctionTree &real(int i = 0); FunctionTree &complex(int i = 0); const FunctionTree &real(int i = 0) const; const FunctionTree &complex(int i = 0) const; + ///@} - // NB: All below should be revised. 
Now only for backwards compatibility to ComplexFunction class - - void free(int type) { free(); } - bool hasReal() const { return isreal(); } - bool hasImag() const { return iscomplex(); } - bool isShared() const { return share(); } - bool conjugate() const { return conj(); } + /** @name Backwards-compatibility helpers (legacy ComplexFunction interface) */ + ///@{ + void free(int type) { free(); } ///< Ignored @p type; frees all. + bool hasReal() const { return isreal(); } ///< True if real storage is active. + bool hasImag() const { return iscomplex(); } ///< True if complex storage is active. + bool isShared() const { return share(); } ///< True if shared-memory is active. + bool conjugate() const { return conj(); } ///< True if conjugation is requested. + /** @brief Apply Hermitian adjoint (conjugation + operator-specific flips as implemented). */ void dagger(); - FunctionTree &imag(int i = 0); // does not make sense now - const FunctionTree &imag(int i = 0) const; // does not make sense now + /** @brief Imaginary component accessor (legacy; previously annotated "does not make sense now" -- kept only for source compatibility, see the implementation before relying on it). */ + FunctionTree &imag(int i = 0); + /** @brief Const imaginary component accessor (legacy; previously annotated "does not make sense now" -- kept only for source compatibility, see the implementation before relying on it). */ + const FunctionTree &imag(int i = 0) const; + ///@} + + /** @brief Shared state (trees + metadata). */ std::shared_ptr> func_ptr; }; +/** @name Helpers: copying and algebra on @ref CompFunction + * Functions operate componentwise on underlying trees and obey precision controls. + */ +///@{ +/** + * @brief Ensure @p out is complex-valued, copying/embedding a real @p inp if needed. + * @tparam D Dimension. + */ template void CopyToComplex(CompFunction &out, const CompFunction &inp); + +/** @brief Deep-copy @p inp into *@p out (allocate if needed). */ template void deep_copy(CompFunction *out, const CompFunction &inp); +/** @brief Deep-copy @p inp into @p out (allocate if needed). 
*/ template void deep_copy(CompFunction &out, const CompFunction &inp); -template void add(CompFunction &out, ComplexDouble a, CompFunction inp_a, ComplexDouble b, CompFunction inp_b, double prec, bool conjugate = false); -template void linear_combination(CompFunction &out, const std::vector &c, std::vector> &inp, double prec, bool conjugate = false); -template void multiply(CompFunction &out, CompFunction inp_a, CompFunction inp_b, double prec, bool absPrec = false, bool useMaxNorms = false, bool conjugate = false); + +/** + * @brief Compute @f$out = a \, inp\_a + b \, inp\_b@f$ with adaptive precision. + * @param prec Target precision controlling refinement/cropping. + * @param conjugate If true, apply soft conjugation to inputs as required. + */ template -void multiply(double prec, CompFunction &out, double coef, CompFunction inp_a, CompFunction inp_b, int maxIter = -1, bool absPrec = false, bool useMaxNorms = false, bool conjugate = false); +void add(CompFunction &out, ComplexDouble a, CompFunction inp_a, + ComplexDouble b, CompFunction inp_b, double prec, bool conjugate = false); + +/** + * @brief Linear combination of many inputs: @f$out = \sum_k c_k \, inp_k@f$. + * @param c Coefficients (size must match @p inp). + * @param inp Input functions (modified only for temporary workspace). + * @param prec Target precision. + * @param conjugate Whether to conjugate inputs (soft). + */ +template +void linear_combination(CompFunction &out, const std::vector &c, + std::vector> &inp, double prec, bool conjugate = false); + +/** + * @brief Pointwise product: @f$out = inp\_a \cdot inp\_b@f$ with refinement control. + * @param prec Target precision (relative by default). + * @param absPrec If true, treat @p prec as absolute precision. + * @param useMaxNorms If true, use max norms in error control heuristics. + * @param conjugate If true, apply soft conjugation to first factor. 
+ */ +template +void multiply(CompFunction &out, CompFunction inp_a, CompFunction inp_b, + double prec, bool absPrec = false, bool useMaxNorms = false, bool conjugate = false); + +/** + * @brief Scaled product with loop control: @f$out = coef \cdot inp\_a \cdot inp\_b@f$. + * @param maxIter Limit iterative refinement steps (-1 for default). + * @copydetails multiply(CompFunction&,CompFunction,CompFunction,double,bool,bool,bool) + */ +template +void multiply(double prec, CompFunction &out, double coef, + CompFunction inp_a, CompFunction inp_b, int maxIter = -1, + bool absPrec = false, bool useMaxNorms = false, bool conjugate = false); + +/** @brief Density from a (possibly complex) function: @f$out = |inp|^2@f$. */ template void make_density(CompFunction &out, CompFunction inp, double prec); -template void multiply(CompFunction &out, CompFunction inp_a, CompFunction inp_b, bool absPrec = false, bool useMaxNorms = false, bool conjugate = false); -template void multiply(CompFunction &out, CompFunction &inp_a, RepresentableFunction &f, double prec, int nrefine = 0, bool conjugate = false); -template void multiply(CompFunction &out, CompFunction &inp_a, RepresentableFunction &f, double prec, int nrefine = 0, bool conjugate = false); -template void multiply(CompFunction &out, FunctionTree &inp_a, RepresentableFunction &f, double prec, int nrefine = 0, bool conjugate = false); -template void multiply(CompFunction &out, FunctionTree &inp_a, RepresentableFunction &f, double prec, int nrefine = 0, bool conjugate = false); + +/** @overload */ +template +void multiply(CompFunction &out, CompFunction inp_a, CompFunction inp_b, + bool absPrec = false, bool useMaxNorms = false, bool conjugate = false); + +/** @brief Multiply by an analytic representable real function @p f. 
*/ +template +void multiply(CompFunction &out, CompFunction &inp_a, RepresentableFunction &f, + double prec, int nrefine = 0, bool conjugate = false); + +/** @brief Multiply by an analytic representable complex function @p f. */ +template +void multiply(CompFunction &out, CompFunction &inp_a, RepresentableFunction &f, + double prec, int nrefine = 0, bool conjugate = false); + +/** @brief Multiply a single tree by a representable real function. */ +template +void multiply(CompFunction &out, FunctionTree &inp_a, RepresentableFunction &f, + double prec, int nrefine = 0, bool conjugate = false); + +/** @brief Multiply a single tree by a representable complex function. */ +template +void multiply(CompFunction &out, FunctionTree &inp_a, RepresentableFunction &f, + double prec, int nrefine = 0, bool conjugate = false); + +/** + * @brief Complex inner product @f$\langle bra \,|\, ket \rangle@f$. + * @return Complex inner product consistent with MRCPP normalization. + */ template ComplexDouble dot(CompFunction bra, CompFunction ket); + +/** + * @brief Node-wise norm dot helper (diagnostics / preconditioners). + * @return Real value summarizing node-level contributions. + */ template double node_norm_dot(CompFunction bra, CompFunction ket); +///@} + +/** @name Projection helpers (3D overloads and templated D) + * Project analytic functions onto the multiresolution basis. + */ +///@{ +/** + * @brief Project a real-valued lambda/function @p f onto @p out. + * @param prec Target precision. + */ void project(CompFunction<3> &out, std::function &r)> f, double prec); -void project_real(CompFunction<3> &out, std::function &r)> f, double prec); //overload of project is not always recognized by the compiler +/** @brief Real-valued projection (explicit name to avoid overload ambiguities on some compilers). */ +void project_real(CompFunction<3> &out, std::function &r)> f, double prec); +/** @brief Project a complex-valued lambda/function @p f onto @p out. 
*/ void project(CompFunction<3> &out, std::function &r)> f, double prec); -void project_cplx(CompFunction<3> &out, std::function &r)> f, double prec); //overload of project is not always recognized by the compiler +/** @brief Complex-valued projection (explicit name to avoid overload ambiguities). */ +void project_cplx(CompFunction<3> &out, std::function &r)> f, double prec); + +/** @brief Project a representable real function onto @p out. */ template void project(CompFunction &out, RepresentableFunction &f, double prec); +/** @brief Project a representable complex function onto @p out. */ template void project(CompFunction &out, RepresentableFunction &f, double prec); +///@} + +/** + * @brief Orthogonalize @p Ket against @p Bra to precision @p prec (Gram–Schmidt-like). + * @param prec Target precision controlling projection refinement and cropping. + */ template void orthogonalize(double prec, CompFunction &Bra, CompFunction &Ket); +/** + * @brief Convenience container for 3D composite functions with shared MRA. + * + * Provides utilities for distribution and linear-algebra operations across + * the vector (e.g., rotations and overlap matrices). + */ class CompFunctionVector : public std::vector> { public: + /** @brief Construct a vector with @p N default-initialized functions. */ CompFunctionVector(int N = 0); + + /** @brief Common MRA pointer for all entries (optional but recommended). */ MultiResolutionAnalysis<3> *vecMRA; + + /** + * @brief Distribute internal storage across MPI workers (when enabled). + * + * Typically assigns component ownership/ranks and updates metadata so that + * subsequent parallel operations (add/multiply/dot) can proceed efficiently. + */ void distribute(); }; +/** @name Vector-level linear algebra and IO utilities */ +///@{ +/** + * @brief Apply a unitary (or general) complex rotation @p U in-place: @f$\Phi \gets \Phi U@f$. + * @param prec Optional precision for intermediate truncations (-1 to keep current). 
+ */ void rotate(CompFunctionVector &Phi, const ComplexMatrix &U, double prec = -1.0); + +/** + * @brief Apply a rotation @p U producing @p Psi: @f$\Psi \gets \Phi U@f$. + * @param prec Optional precision for intermediate truncations (-1 to keep current). + */ void rotate(CompFunctionVector &Phi, const ComplexMatrix &U, CompFunctionVector &Psi, double prec = -1.0); -void save_nodes(CompFunctionVector &Phi, mrcpp::FunctionTree<3, double> &refTree, BankAccount &account, int sizes = -1); -CompFunctionVector multiply(CompFunctionVector &Phi, RepresentableFunction<3> &f, double prec = -1.0, CompFunction<3> *Func = nullptr, int nrefine = 1, bool all = false); + +/** + * @brief Store per-node coefficient blocks of @p Phi into @p account. + * @param refTree Reference tree defining the union grid/blocking. + * @param sizes Optional fixed block size; -1 to auto-size. + */ +void save_nodes(CompFunctionVector &Phi, mrcpp::FunctionTree<3, double> &refTree, + BankAccount &account, int sizes = -1); + +/** + * @brief Multiply a vector of functions by a representable function @p f. + * @param prec Target precision (-1 to inherit default). + * @param Func Optional workspace function (reused). + * @param nrefine Number of refinement passes (>=0). + * @param all If true, apply to all components; else honor component flags. + * @return Result vector (same size as @p Phi). + */ +CompFunctionVector multiply(CompFunctionVector &Phi, RepresentableFunction<3> &f, + double prec = -1.0, CompFunction<3> *Func = nullptr, + int nrefine = 1, bool all = false); + +/** + * @brief Set a library-global default MRA used by convenience constructors. + * @param MRA Non-owning pointer (caller keeps it alive). + */ void SetdefaultMRA(MultiResolutionAnalysis<3> *MRA); + +/** + * @brief Vectorized inner products: returns @f$\langle Bra_i \,|\, Ket_i \rangle@f$ for all i. 
+ */ ComplexVector dot(CompFunctionVector &Bra, CompFunctionVector &Ket); + +/** @brief Compute the (symmetric) Löwdin overlap matrix @f$S@f$ for @p Phi. */ ComplexMatrix calc_lowdin_matrix(CompFunctionVector &Phi); + +/** @brief Overlap matrix of a single set against itself. */ ComplexMatrix calc_overlap_matrix(CompFunctionVector &BraKet); + +/** @brief Overlap matrix between two sets @p Bra and @p Ket. */ ComplexMatrix calc_overlap_matrix(CompFunctionVector &Bra, CompFunctionVector &Ket); + +/** + * @brief Pairwise orthogonalization of @p Ket against @p Bra to precision @p prec. + * @param prec Precision target (relative unless the implementation states otherwise). + */ void orthogonalize(double prec, CompFunctionVector &Bra, CompFunctionVector &Ket); +///@} -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/utils/Plotter.h b/src/utils/Plotter.h index 9612dedec..b7cd635b8 100644 --- a/src/utils/Plotter.h +++ b/src/utils/Plotter.h @@ -24,6 +24,19 @@ */ #pragma once +/** + * @file + * @brief Plotting utilities for MRCPP functions and trees. + * + * This header declares a lightweight plotting helper that samples + * multivariate functions (or visualizes trees) on simple, equidistant + * grids derived from user-provided span vectors. It supports 1D (line), + * 2D (surface), and 3D (cube) outputs and can also dump tree grids. + * + * The plotting domain is parameterized by an origin @p O and up to three + * span vectors @p A, @p B, @p C (not required to be orthogonal). For an + * overview of the sampling conventions, see @ref mrcpp::Plotter. + */ #include @@ -36,68 +49,256 @@ namespace mrcpp { -/** @class Plotter +/** + * @class Plotter + * @tparam D Spatial dimension of the *function* being sampled (1–3). + * @tparam T Scalar type of the function values (e.g., double, ComplexDouble). * - * @brief Class for plotting multivariate functions + * @brief Sample multivariate functions on equidistant grids and write results. 
* - * This class will generate an equidistant grid in one (line), two (surf) - * or three (cube) dimensions, and subsequently evaluate the function on - * this grid. + * ### Domain definition + * The sampling region is specified by: + * - Origin **O** + * - Span vectors **A**, **B**, **C** * - * The grid is generated from the vectors A, B and C, relative to the origin O: - * - a linePlot will plot the line spanned by A, starting from O - * - a surfPlot will plot the area spanned by A and B, starting from O - * - a cubePlot will plot the volume spanned by A, B and C, starting from O + * The actual plot type determines how these are used: + * - **Line plot**: points along **A** starting at **O** + * - **Surface plot**: a 2D lattice in the parallelogram spanned by **A** and **B** + * - **Cube plot**: a 3D lattice in the parallelotope spanned by **A**, **B**, **C** * - * The vectors A, B and C do not necessarily have to be orthogonal. + * None of **A**, **B**, **C** need to be orthogonal. * - * The parameter `D` refers to the dimension of the _function_, not the - * dimension of the plot. + * ### Output + * This class writes simple text files (one value per line or a cube-like block) + * suitable for quick inspection or feeding into downstream visualization tools. + * Grid export for trees writes a mesh for node boxes (D=3). * + * @note The template parameter @p D reflects the *intrinsic* dimensionality of + * the function/tree. A 3D function can still be sampled along a 1D line using + * @ref linePlot by providing only **A** (and leaving **B**, **C** unused). */ - template class Plotter { public: + /** + * @brief Construct a plotter with a given origin. + * @param o Plot origin (defaults to the zero vector). + */ explicit Plotter(const Coord &o = {}); virtual ~Plotter() = default; + /** + * @brief Set the filename suffix for a plot type. + * + * @param t Plot type key (see @ref type). + * @param s Suffix including the dot (e.g., ".line", ".surf"). 
+ * + * @details The suffix is appended to the base filename passed to the + * plotting routines. Defaults are set in the constructor. + */ void setSuffix(int t, const std::string &s); + + /** + * @brief Set the plot origin. + * @param o New origin **O**. + */ void setOrigin(const Coord &o); + + /** + * @brief Define (or update) the plot span vectors. + * + * @param a Vector **A** (required). + * @param b Vector **B** (optional; used for 2D/3D sampling). + * @param c Vector **C** (optional; used for 3D sampling). + * + * @note Vectors are not required to be orthogonal. The number of points + * per span is given at call time for each plot type. + */ void setRange(const Coord &a, const Coord &b = {}, const Coord &c = {}); + /** + * @brief Write a grid visualization of a function tree. + * + * @param tree Multiresolution tree to visualize. + * @param fname Base filename (suffix for @ref Grid is appended). + * + * @details Exports the end-node grid (and roots) of @p tree. The concrete + * output is implementation-dependent; for D=3 it is a geomview-friendly + * mesh (see the .grid writer in the implementation). + * + * @warning Meaningful only when the implementation supports the given @p D. + */ void gridPlot(const MWTree &tree, const std::string &fname); + + /** + * @brief Sample a function along a line @f$ O + s\,A @f$. + * + * @param npts Number of equidistant points along **A**; use @c {N}. + * @param func Function to evaluate. + * @param fname Base filename (suffix for @ref Line is appended). + * + * @details Generates @c npts[0] positions: + * @f$ r_i = O + \frac{i}{N-1} A,\ i=0,\dots,N-1 @f$ + * and writes coordinates and values in text form. + * + * @pre @ref setRange must have set a non-zero **A**; otherwise this call + * will fail validation. + */ void linePlot(const std::array &npts, const RepresentableFunction &func, const std::string &fname); + + /** + * @brief Sample a function on a surface spanned by **A**, **B**. 
+ * + * @param npts Number of points along {**A**, **B**}; use @c {Na, Nb}. + * @param func Function to evaluate. + * @param fname Base filename (suffix for @ref Surface is appended). + * + * @details Generates positions + * @f$ r_{ij} = O + \frac{i}{N_a-1}A + \frac{j}{N_b-1}B @f$ and writes + * coordinates and values in text form. + * + * @pre @ref setRange must have set non-zero **A** and **B** when used in 2D/3D. + */ void surfPlot(const std::array &npts, const RepresentableFunction &func, const std::string &fname); + + /** + * @brief Sample a function in a 3D block spanned by **A**, **B**, **C**. + * + * @param npts Number of points along {**A**, **B**, **C**}; use @c {Na, Nb, Nc}. + * @param func Function to evaluate. + * @param fname Base filename (suffix for @ref Cube is appended). + * + * @details Generates positions + * @f$ r_{ijk} = O + \frac{i}{N_a-1}A + \frac{j}{N_b-1}B + \frac{k}{N_c-1}C @f$ + * and writes values in a simple cube-like format suitable for volumetric viewers. + * + * @pre @ref setRange must have set non-zero **A**, **B**, **C** when used in 3D. + */ void cubePlot(const std::array &npts, const RepresentableFunction &func, const std::string &fname); - enum type { Line, Surface, Cube, Grid }; + /** + * @brief Plot type selector used for file suffix mapping. + */ + enum type { Line, /**< 1D sampling along **A** */ + Surface, /**< 2D sampling on **A**–**B** lattice */ + Cube, /**< 3D sampling on **A**–**B**–**C** lattice */ + Grid /**< Grid/mesh export for trees */ + }; protected: - Coord O{}; // Plot origin - Coord A{}; // Vector for line plot - Coord B{}; // Vector for surf plot - Coord C{}; // Vector for cube plot - std::ofstream fstrm{}; - std::ofstream *fout{nullptr}; - std::map suffix{}; + /** @name Plot domain and output state */ + ///@{ + Coord O{}; ///< Plot origin. + Coord A{}; ///< Span vector for line plots and first lattice axis. + Coord B{}; ///< Span vector for surface/cube plots (second axis). 
+ Coord C{}; ///< Span vector for cube plots (third axis). + std::ofstream fstrm{}; ///< Owned output stream storage. + std::ofstream *fout{nullptr}; ///< Active output stream (points to @ref fstrm). + std::map suffix{}; ///< Per-type filename suffix map. + ///@} + /** + * @brief Compute step size to place @p pts samples along a span. + * @param vec Span vector (**A**, **B**, or **C**). + * @param pts Number of points along the span (>= 1). + * @return Component-wise step equals @f$ \frac{\text{vec}}{\max(1, pts-1)} @f$. + * + * @note When @p pts == 1 the single sample is placed at the origin offset, + * and the step is unused (implementation guards against division by zero). + */ Coord calcStep(const Coord &vec, int pts) const; + + /** + * @brief Generate equidistant coordinates for a line plot. + * @param pts_a Points along **A**. + * @return Matrix of size (pts_a × D) with row-wise coordinates. + */ Eigen::MatrixXd calcLineCoordinates(int pts_a) const; + + /** + * @brief Generate equidistant coordinates for a surface plot. + * @param pts_a Points along **A**. + * @param pts_b Points along **B**. + * @return Matrix of size ((pts_a*pts_b) × D) with row-wise coordinates. + */ Eigen::MatrixXd calcSurfCoordinates(int pts_a, int pts_b) const; + + /** + * @brief Generate equidistant coordinates for a cube plot. + * @param pts_a Points along **A**. + * @param pts_b Points along **B**. + * @param pts_c Points along **C**. + * @return Matrix of size ((pts_a*pts_b*pts_c) × D) with row-wise coordinates. + */ Eigen::MatrixXd calcCubeCoordinates(int pts_a, int pts_b, int pts_c) const; + /** + * @brief Evaluate a representable function on given coordinates. + * @param func Function to sample. + * @param coords Matrix of coordinates (N × D), one point per row. + * @return Column vector of values (size N). + * + * @note The implementation may use parallel evaluation (e.g., OpenMP) + * when available at build time.
+ */ Eigen::Matrix evaluateFunction(const RepresentableFunction &func, const Eigen::MatrixXd &coords) const; + /** + * @brief Write coordinates and values as text rows. + * @param coords Coordinates (N × D), one point per row. + * @param values Column vector (size N). + * + * @details Each output line contains D coordinates followed by the + * function value. Floating-point formatting is implementation-defined. + */ void writeData(const Eigen::MatrixXd &coords, const Eigen::Matrix &values); + + /** + * @brief Write volumetric (cube) data. + * @param npts Lattice sizes {Na, Nb, Nc}. + * @param values Column vector of length Na*Nb*Nc. + * + * @details Default implementation may be a stub; specialized versions + * (e.g., D=3) provide actual volume exporters (cube/voxel formats). + */ virtual void writeCube(const std::array &npts, const Eigen::Matrix &values); + /** + * @brief Write a grid/mesh representation of a tree. + * @param tree Tree whose node boxes should be exported. + * + * @details Implementation targets D=3 (geomview-like mesh). Other + * dimensionalities may provide no-ops or alternative encodings. + */ void writeGrid(const MWTree &tree); + + /** + * @brief Emit a single node's box edges/faces to the active stream. + * @param node Node to visualize. + * @param color Renderer-dependent color string (implementation-defined). + */ void writeNodeGrid(const MWNode &node, const std::string &color); private: + /** + * @brief Validate that required span vectors are non-zero. + * @param dim Required plot dimensionality (1, 2, or 3). + * @return @c true if all needed spans (**A**, **B**, **C**) are non-zero. + */ bool verifyRange(int dim) const; + + /** + * @brief Open/prepare the output file stream. + * @param fname Base filename plus suffix (if non-empty). + * + * @details If @p fname is empty, reuses the current stream; otherwise + * closes any previous stream and opens the new one.
+ */ void openPlot(const std::string &fname); + + /** + * @brief Close the output stream if open and reset state. + */ void closePlot(); }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/utils/Printer.h b/src/utils/Printer.h index c021155e8..32f5f63a1 100644 --- a/src/utils/Printer.h +++ b/src/utils/Printer.h @@ -23,12 +23,30 @@ * */ -/** \file Printer.h - * Collection of assertions and a standard error/warn/info/debug - * message interface. +#pragma once +/** + * @file + * @brief Lightweight, level-based printing, diagnostics, and assertion helpers. + * + * This header provides: + * - A singleton-style @ref mrcpp::Printer that controls global print level, + * numeric formatting (scientific vs. fixed), width and precision, and the + * active output stream (stdout or a per-rank file). + * - A small set of convenience functions in @ref mrcpp::print for consistent, + * nicely formatted environment, headers/footers, timers, memory usage, and + * tree statistics. + * - A family of macros (e.g. @ref println, @ref MSG_ABORT) to emit messages + * with source context and optional termination semantics. + * + * ### Print-level convention + * Every printing API takes an integer *level*. Output is produced iff + * `level <= Printer::getPrintLevel()`. Internal MRCPP prints use \>= 10, + * leaving levels 0–9 available for host/user code control. * + * @warning The facilities herein are process-local; when used in MPI programs, + * each rank will emit messages independently unless explicitly gated + * by rank logic in the caller. */ -#pragma once #include #include @@ -41,33 +59,56 @@ namespace mrcpp { class Timer; template class MWTree; -/** @class Printer +/** + * @class Printer + * @brief Process-local controller for formatted, level-gated output. * - * @brief Convenience class to handle printed output + * @details + * The Printer class is used in a singleton-like fashion via static methods. 
+ * Call init once near program start (optionally per MPI rank) to set: + * - global print level (messages with a higher level are suppressed), + * - rank/size metadata (used to route output), + * - destination stream (stdout or a per-rank file). * - * @details The ``Printer`` singleton class holds the current state of the print - * environment. All ``mrcpp::print`` functions, as well as the ``println`` and - * ``printout`` macros, take an integer print level as first argument. When the - * global ``mrcpp::Printer`` is initialized with a given print level, only print - * statements with a *lower* print level will be displayed. All internal printing - * in MRCPP is at print level 10 or higher, so there is some flexibility left - * (levels 0 through 9) for adjusting the print volume within the host program. + * After initialization, helper functions/macros (see mrcpp::print and the + * macros below) consult the configured level and stream. * + * @par Example + * @code{.cpp} + * // Rank 0 prints to screen, others remain silent: + * Printer::init(5, rank, size); + * println(2, "Hello at level 2"); // shown when level >= 2 + * println(8, "Debug details..."); // suppressed when level < 8 + * @endcode */ - class Printer final { public: + /** + * @brief Initialize printing environment. + * + * @param level Maximum verbosity to emit (inclusive). + * @param rank MPI rank of this process (default 0). + * @param size MPI world size (default 1). + * @param file Optional base filename. If provided and @p size>1, + * output is written to `<file>-<rank>.out`; otherwise to + * `<file>.out`. If null, output goes to stdout. When + * @p file is null and @p rank>0, printing is disabled + * by setting the print level to -1. + * + * @note Also sets scientific notation as the default numeric format. + */ static void init(int level = 0, int rank = 0, int size = 1, const char *file = nullptr); - /** @brief Use scientific floating point notation, e.g.
1.0e-2 */ + /** @brief Use scientific floating-point notation (e.g., 1.23e-2). */ static void setScientific() { *out << std::scientific; } - /** @brief Use fixed floating point notation, e.g. 0.01 */ + /** @brief Use fixed floating-point notation (e.g., 0.0123). */ static void setFixed() { *out << std::fixed; } - /** @brief Set new line width for printed output - * @param[in] i: New width (number of characters) - * @returns Old width (number of characters) + /** + * @brief Set line width for formatted helpers. + * @param i New width in characters. + * @return Previous width. */ static int setWidth(int i) { int oldWidth = printWidth; @@ -75,9 +116,10 @@ class Printer final { return oldWidth; } - /** @brief Set new precision for floating point output - * @param[in] i: New precision (digits after comma) - * @returns Old precision (digits after comma) + /** + * @brief Set precision for floating-point output. + * @param i Digits after the decimal point. + * @return Previous precision. */ static int setPrecision(int i) { int oldPrec = printPrec; @@ -86,9 +128,10 @@ class Printer final { return oldPrec; } - /** @brief Set new print level - * @param[in] i: New print level - * @returns Old print level + /** + * @brief Set the global print level threshold. + * @param i New level; only messages with level <= @p i are printed. + * @return Previous print level. */ static int setPrintLevel(int i) { int oldLevel = printLevel; @@ -96,116 +139,251 @@ class Printer final { return oldLevel; } - /** @returns Current line width (number of characters) */ + /** @return Current line width (characters). */ static int getWidth() { return printWidth; } - /** @returns Current precision for floating point output (digits after comma) */ + /** @return Current floating-point precision (digits after decimal). */ static int getPrecision() { return printPrec; } - /** @returns Current print level */ + /** @return Current global print level threshold. 
*/ static int getPrintLevel() { return printLevel; } + /** + * @brief Active output stream (stdout or a file). + * @warning Pointer is owned externally; @ref init manages it appropriately. + */ static std::ostream *out; private: - static int printWidth; - static int printLevel; - static int printPrec; - static int printRank; - static int printSize; + static int printWidth; ///< Line width used by @ref mrcpp::print helpers. + static int printLevel; ///< Global verbosity threshold. + static int printPrec; ///< Floating-point precision (digits). + static int printRank; ///< MPI rank (for routing decisions). + static int printSize; ///< MPI world size. - Printer() = delete; // No instances of this class + Printer() = delete; ///< Non-instantiable utility. + /// @brief Redirect all output to @p o (used internally by @ref init). static void setOutputStream(std::ostream &o) { out = &o; } }; +/** + * @namespace mrcpp::print + * @brief Nicely formatted, level-aware printing helpers. + * + * These helpers produce standardized, aligned, and labeled output for common + * diagnostics: environment summaries, section headers/footers, timers, memory + * usage, and tree statistics. Each function is level-gated via + * @ref Printer::getPrintLevel(). + */ namespace print { + +/** + * @brief Print MRCPP and build environment information. + * @param level Activation level threshold. + * + * @details Includes library version, Git metadata, linear algebra backend, + * and parallelization mode (MPI/OpenMP). + */ void environment(int level); + +/** + * @brief Print a full separator line composed of @p c characters. + * @param level Activation level. + * @param c Filler character (e.g., '-', '='). + * @param newlines Number of extra trailing blank lines (default 0). + */ void separator(int level, const char &c, int newlines = 0); + +/** + * @brief Print a centered header with a framed title. + * @param level Activation level. + * @param txt Header text. 
+ * @param newlines Extra trailing blank lines (default 0). + * @param c Filler character for the frame (default '='). + */ void header(int level, const std::string &txt, int newlines = 0, const char &c = '='); + +/** + * @brief Print a footer containing elapsed wall time and a closing frame. + * @param level Activation level. + * @param timer Timer whose @ref Timer::elapsed is shown. + * @param newlines Extra trailing blank lines (default 0). + * @param c Filler character for the closing line (default '='). + */ void footer(int level, const Timer &timer, int newlines = 0, const char &c = '='); + +/** + * @brief Print current process memory usage. + * @param level Activation level. + * @param txt Label to show before the value (aligned). + */ void memory(int level, const std::string &txt); + +/** + * @brief Print a labeled scalar with unit in aligned columns. + * @param level Activation level. + * @param txt Label. + * @param v Value. + * @param unit Unit string (optional). + * @param p Precision; if negative uses @ref Printer::getPrecision. + * @param sci Scientific formatting when true, fixed when false. + */ void value(int level, const std::string &txt, double v, const std::string &unit = "", int p = -1, bool sci = true); + +/** + * @brief Print an elapsed time value from a @ref Timer. + * @param level Activation level. + * @param txt Label. + * @param timer Timer whose @ref Timer::elapsed is shown. + */ void time(int level, const std::string &txt, const Timer &timer); + +/** + * @brief Print tree statistics (nodes, memory, wall time). + * @param level Activation level. + * @param txt Label/section title. + * @param n Number of nodes. + * @param m Memory usage in kB. + * @param t Elapsed wall time in seconds. 
+ */ void tree(int level, const std::string &txt, int n, int m, double t); -template void tree(int level, const std::string &txt, const MWTree &tree, const Timer &timer); + +/** + * @brief Print tree statistics extracted from an @ref MWTree and a @ref Timer. + * @tparam D Spatial dimension. + * @tparam T Coefficient scalar type. + * @param level Activation level. + * @param txt Label/section title. + * @param tree Tree whose node/memory info is reported. + * @param timer Timer whose elapsed time is reported. + */ +template +void tree(int level, const std::string &txt, const MWTree &tree, const Timer &timer); + } // namespace print -// clang-format off +// ============================================================================ +// Macros : level-aware printing and diagnostics +// ============================================================================ -/** @brief Print text at the given print level, with newline */ -#define println(level, STR) \ +/** + * @def println(level, STR) + * @brief Print text followed by newline if @p level is enabled. + */ +#define println(level, STR) \ { if (level <= mrcpp::Printer::getPrintLevel()) *mrcpp::Printer::out << STR << std::endl; } -/** @brief Print text at the given print level, without newline */ -#define printout(level, STR) \ +/** + * @def printout(level, STR) + * @brief Print text without newline if @p level is enabled. + */ +#define printout(level, STR) \ { if (level <= mrcpp::Printer::getPrintLevel()) *mrcpp::Printer::out << STR; } -/** @brief Print info message */ -#define MSG_INFO(STR) \ - { \ - *mrcpp::Printer::out << "Info: " << __FILE__ << ": " << __func__ << "(), line " << __LINE__ << ": " << STR << std::endl; \ +/** + * @def MSG_INFO(STR) + * @brief Emit an informational message with source location. 
+ */ +#define MSG_INFO(STR) \ + { \ + *mrcpp::Printer::out << "Info: " << __FILE__ << ": " << __func__ \ + << "(), line " << __LINE__ << ": " << STR << std::endl; \ } -/** @brief Print warning message */ -#define MSG_WARN(STR) \ - { \ - *mrcpp::Printer::out << "Warning: " << __func__ << "(), line " << __LINE__ << ": " << STR << std::endl; \ +/** + * @def MSG_WARN(STR) + * @brief Emit a warning message with source location. + */ +#define MSG_WARN(STR) \ + { \ + *mrcpp::Printer::out << "Warning: " << __func__ << "(), line " << __LINE__ \ + << ": " << STR << std::endl; \ } -/** @brief Print error message, no abort*/ -#define MSG_ERROR(STR) \ - { \ - *mrcpp::Printer::out << "Error: " << __func__ << "(), line " << __LINE__ << ": " << STR << std::endl; \ +/** + * @def MSG_ERROR(STR) + * @brief Emit a non-fatal error message with source location. + */ +#define MSG_ERROR(STR) \ + { \ + *mrcpp::Printer::out << "Error: " << __func__ << "(), line " << __LINE__ \ + << ": " << STR << std::endl; \ } -/** @brief Print error message and abort */ -#define MSG_ABORT(STR) \ - { \ - *mrcpp::Printer::out << "Error: " << __FILE__ << ": " << __func__ << "(), line " << __LINE__ << ": " << STR << std::endl; \ - abort(); \ +/** + * @def MSG_ABORT(STR) + * @brief Emit an error message with source location and abort the process. + */ +#define MSG_ABORT(STR) \ + { \ + *mrcpp::Printer::out << "Error: " << __FILE__ << ": " << __func__ \ + << "(), line " << __LINE__ << ": " << STR << std::endl; \ + abort(); \ } -/** @brief You have passed an invalid argument to a function */ -#define INVALID_ARG_ABORT \ - { \ - *mrcpp::Printer::out << "Error, invalid argument passed: " << __func__ << "(), line " << __LINE__ << std::endl; \ - abort(); \ +/** + * @def INVALID_ARG_ABORT + * @brief Abort with a standardized message for invalid arguments. 
+ */ +#define INVALID_ARG_ABORT \ + { \ + *mrcpp::Printer::out << "Error, invalid argument passed: " << __func__ \ + << "(), line " << __LINE__ << std::endl; \ + abort(); \ } -/** @brief You have reached a point in the code that is not yet implemented */ -#define NOT_IMPLEMENTED_ABORT \ - { \ - *mrcpp::Printer::out << "Error: Not implemented, " << __FILE__ ", " << __func__ << "(), line " << __LINE__ << std::endl; \ - abort(); \ +/** + * @def NOT_IMPLEMENTED_ABORT + * @brief Abort with a standardized message for unimplemented code paths. + */ +#define NOT_IMPLEMENTED_ABORT \ + { \ + *mrcpp::Printer::out << "Error: Not implemented, " << __FILE__ << ", " << __func__ \ + << "(), line " << __LINE__ << std::endl; \ + abort(); \ } -/** @brief You have reached a point that should not be reached, bug or inconsistency */ -#define NOT_REACHED_ABORT \ - { \ - *mrcpp::Printer::out << "Error, should not be reached: " << __func__ << "(), line " << __LINE__ << std::endl; \ - abort(); \ +/** + * @def NOT_REACHED_ABORT + * @brief Abort for code paths that should be logically unreachable. + */ +#define NOT_REACHED_ABORT \ + { \ + *mrcpp::Printer::out << "Error, should not be reached: " << __func__ \ + << "(), line " << __LINE__ << std::endl; \ + abort(); \ } -/** @brief You have reached an experimental part of the code, results cannot be trusted */ -#define NEEDS_TESTING \ - { \ - static bool __once = true; \ - if (__once) { \ - __once = false; \ - *mrcpp::Printer::out << "NEEDS TESTING: " << __FILE__ << ", " << __func__ << "(), line " << __LINE__ << std::endl; \ - } \ +/** + * @def NEEDS_TESTING + * @brief Emit a one-time notice that a code path is experimental. + * + * Prints exactly once per process at the first hit, then stays quiet. 
+ */ +#define NEEDS_TESTING \ + { \ + static bool __once = true; \ + if (__once) { \ + __once = false; \ + *mrcpp::Printer::out << "NEEDS TESTING: " << __FILE__ << ", " << __func__ \ + << "(), line " << __LINE__ << std::endl; \ + } \ } -/** @brief You have hit a known bug that is yet to be fixed, results cannot be trusted */ -#define NEEDS_FIX(STR) \ - { \ - static bool __once = true; \ - if (__once) { \ - __once = false; \ - *mrcpp::Printer::out << "NEEDS FIX: " << __FILE__ << ", " << __func__ << "(), line " << __LINE__ << ": " << STR << std::endl; \ \ - } \ +/** + * @def NEEDS_FIX(STR) + * @brief Emit a one-time notice that a known issue affects this code path. + * @param STR Short description of the known issue. + */ +#define NEEDS_FIX(STR) \ + { \ + static bool __once = true; \ + if (__once) { \ + __once = false; \ + *mrcpp::Printer::out << "NEEDS FIX: " << __FILE__ << ", " << __func__ \ + << "(), line " << __LINE__ << ": " << STR << std::endl; \ + } \ } -// clang-format on -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/utils/Timer.h b/src/utils/Timer.h index 4c02201a9..091812eca 100644 --- a/src/utils/Timer.h +++ b/src/utils/Timer.h @@ -29,33 +29,151 @@ namespace mrcpp { +/** + * @file + * @brief Wall-clock timing utilities. + */ + +/** + * @typedef timeT + * @brief Timestamp type used by Timer. + * + * @details + * Alias for `std::chrono::time_point`. + * The actual clock may map to a platform-specific high-resolution source. + * It typically offers sub-microsecond resolution but is not guaranteed to be + * steady on all standard libraries (i.e., it may jump if the underlying clock + * is adjusted). The @ref Timer class measures *wall* time, not CPU time. + */ using timeT = std::chrono::time_point; -/** @class Timer +/** + * @class Timer + * @brief Lightweight wall-time stopwatch with start/resume/stop semantics. + * + * @details + * `Timer` accumulates elapsed *wall* time across one or more running intervals. 
+ * A newly constructed timer can optionally start immediately. + * + * ### State machine + * - **Stopped** (default when constructed with `start_timer == false`): + * - `elapsed()` returns the accumulated time. + * - `resume()` starts a new interval without clearing accumulated time. + * - `start()` clears accumulated time and starts fresh from zero. + * - **Running** (default when constructed with `start_timer == true`): + * - `elapsed()` returns the live time since the most recent start/resume, + * ignoring previously accumulated time until `stop()` is called. + * - `stop()` ends the current interval and adds it to the accumulation. + * + * ### Characteristics + * - Measures wall time (affected by system sleep/suspend). + * - Very low overhead; suitable for inner-loop timing in most cases. + * - Not thread-safe: do not share a single instance across threads without + * external synchronization. + * + * ### Example + * @code{.cpp} + * mrcpp::Timer t; // starts immediately by default + * // ... code section A ... + * t.stop(); + * double a = t.elapsed(); // seconds for section A * - * @brief Records wall time between the execution of two lines of source code + * t.resume(); + * // ... code section B ... + * t.stop(); + * double total = t.elapsed(); // seconds for A + B * + * t.start(); // reset and start from zero + * // ... code section C ... + * double live = t.elapsed(); // live time while running + * @endcode */ - class Timer final { public: + /** + * @brief Construct a timer. + * @param start_timer If true, the timer is started immediately with + * accumulated time cleared to zero. + * + * @note Default is `true` for convenience; pass `false` to construct + * a stopped timer and control the first start explicitly. + */ Timer(bool start_timer = true); + + /** + * @brief Copy constructor. + * @details Copies the running state, accumulated time, and the last start + * timestamp. 
If the source is running, the copy will also be + * running and will measure from the same start instant. + */ Timer(const Timer &timer); + + /** + * @brief Copy assignment. + * @details Assigns running state, accumulated time, and start timestamp. + * Self-assignment is a no-op. + * @return Reference to `*this`. + */ Timer &operator=(const Timer &timer); + /** + * @brief Start from zero. + * @details Resets the accumulated time to 0 and begins a new running + * interval starting "now". Use this to time a fresh region. + */ void start(); + + /** + * @brief Resume without clearing accumulated time. + * @details If the timer is stopped, begins a new running interval starting + * "now". If already running, the call has no effect besides + * potentially issuing a diagnostic in the implementation. + */ void resume(); + + /** + * @brief Stop and accumulate. + * @details Ends the current running interval and adds its duration to the + * accumulated time. If already stopped, the call has no effect + * besides potentially issuing a diagnostic in the implementation. + */ void stop(); + /** + * @brief Get elapsed time in seconds. + * @details + * - If the timer is **running**, returns the time since the most recent + * `start()` or `resume()` (not including previously accumulated time). + * - If the timer is **stopped**, returns the total accumulated time across + * all completed intervals since the last `start()`. + * + * @return Elapsed wall time in seconds as a `double`. + */ double elapsed() const; private: + /// @brief True if the timer is currently running. bool running{false}; + + /// @brief Accumulated time in seconds across completed intervals. double time_used{0.0}; + + /// @brief Timestamp when the current interval was started/resumed. timeT clock_start; + /** + * @brief Current timestamp helper. + * @return `high_resolution_clock::now()`. + */ timeT now() const; + + /** + * @brief Difference between two timestamps. + * @param t2 Later timestamp. 
+ * @param t1 Earlier timestamp. + * @return `(t2 - t1)` expressed in seconds as a `double`. + */ double diffTime(timeT t2, timeT t1) const; }; -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/utils/details.h b/src/utils/details.h index c88c16196..cb6e0b600 100644 --- a/src/utils/details.h +++ b/src/utils/details.h @@ -24,6 +24,18 @@ */ #pragma once +/** + * @file + * @brief Small cross-cutting utilities and helpers for MRCPP internals. + * + * This header declares: + * - Filesystem and environment helpers (e.g., locating MW filter folders). + * - Lightweight process information (Linux memory usage). + * - Tiny array algorithms (equality/any checks, C-array → std::array conversion). + * - Generic collection pretty-printer and a `std::array` stream operator. + * + * Most functions live in the `mrcpp::details` namespace to signal internal use. + */ #include #include @@ -31,16 +43,89 @@ #include namespace mrcpp { +/** + * @namespace mrcpp::details + * @brief Internal utilities; APIs may change without notice. + */ namespace details { + +/** + * @brief Check whether a path refers to an existing directory. + * @param path Path to check. + * @return `true` if the path exists and is a directory, otherwise `false`. + * @note Implementation typically uses `stat(2)` and is therefore + * POSIX-oriented. + */ bool directory_exists(std::string path); + +/** + * @brief Locate the directory containing multiresolution filter files. + * + * The search strategy prefers an explicit environment override and then + * compiled-in locations. + * + * @return Absolute/relative path to a directory with filter files. + * @throws (implementation-defined) if no suitable directory is found. + * + * @details + * The implementation checks (in order): + * 1. Environment variable `MWFILTERS_DIR`, if set and points to a directory. + * 2. Compiled-in source/install search paths (e.g., `mwfilters_source_dir()`, + * `mwfilters_install_dir()`). 
+ */ std::string find_filters(); + +/** + * @brief Return the current process memory usage in kilobytes. + * @return Resident (or data+stack) usage in kB, or a negative value on error. + * @note Implemented via `/proc/self/statm`; available on Linux only. + */ int get_memory_usage(); -template bool are_all_equal(const std::array &exponent); -template bool are_any(const std::array &col, const T eq) { + +/** + * @brief Check if all elements of a fixed-size array of doubles are equal. + * @tparam D Array length. + * @param exponent Input array. + * @return `true` if all elements compare equal to the first element; otherwise `false`. + * @warning Equality is tested with `==` (no tolerance). + */ +template +bool are_all_equal(const std::array &exponent); + +/** + * @brief Test whether any element of an array equals a given value. + * @tparam T Element type (must be equality comparable). + * @tparam D Array length (compile-time). + * @param col Array to scan. + * @param eq Value to compare against. + * @return `true` if at least one element satisfies `element == eq`. + * @note Complexity: O(D). + */ +template +bool are_any(const std::array &col, const T eq) { return std::any_of(col.cbegin(), col.cend(), [eq](const T &el) { return el == eq; }); }; -template std::array convert_to_std_array(T *arr); -template auto stream_collection(const T &coll) -> std::string { + +/** + * @brief Convert a C-style pointer to a fixed-size `std::array`. + * @tparam T Element type. + * @tparam D Number of elements to copy. + * @param arr Pointer to at least `D` contiguous elements of type `T`. + * @return `std::array` with a shallow copy of the `D` elements. + * @warning Caller is responsible for ensuring `arr` has at least `D` valid elements. + */ +template +std::array convert_to_std_array(T *arr); + +/** + * @brief Render a collection to a compact bracketed string. + * @tparam T A range/collection with range-for iteration and streamable elements. + * @param coll Collection to print. 
+ * @return String like `"[e0, e1, ...]"`. + * @note This utility underpins the `operator<<` overload for `std::array`. + */ +template +auto stream_collection(const T &coll) -> std::string { std::ostringstream os; bool first = true; os << "["; @@ -52,9 +137,21 @@ template auto stream_collection(const T &coll) -> std::string { os << "]"; return os.str(); } + } // namespace details -template auto operator<<(std::ostream &os, const std::array &coll) -> std::ostream & { +/** + * @brief Stream insertion for `std::array`, producing a compact bracketed list. + * @tparam T Element type (must be stream-insertable). + * @tparam D Array length. + * @param os Output stream. + * @param coll Array to print. + * @return Reference to @p os. + * @sa mrcpp::details::stream_collection + */ +template +auto operator<<(std::ostream &os, const std::array &coll) -> std::ostream & { return (os << details::stream_collection(coll)); } -} // namespace mrcpp + +} // namespace mrcpp \ No newline at end of file diff --git a/src/utils/math_utils.h b/src/utils/math_utils.h index 3eacfa10b..38cbf7203 100644 --- a/src/utils/math_utils.h +++ b/src/utils/math_utils.h @@ -29,59 +29,263 @@ */ #pragma once +/** + * @file + * @brief Linear algebra and small numerical helpers built on top of Eigen. + * + * This header exposes a compact set of utilities frequently used across MRCPP: + * - Common Eigen-based type aliases (`Vector`, `Matrix`, complex scalars). + * - Small combinatorial helpers (factorial, binomial, integer powers). + * - Tensor products and self–outer-products. + * - Matrix norms (1, 2, and ∞). + * - Multi-index/tensor expansion helpers for separable bases. + * - Hermitian eigendecompositions and block diagonalization. + * - Cartesian products of small sets/vectors. + * - Euclidean distance for MRCPP coordinates. 
+ */ #include #include #include "MRCPP/mrcpp_declarations.h" -using IntVector = Eigen::VectorXi; -using DoubleVector = Eigen::VectorXd; -using ComplexVector = Eigen::VectorXcd; +/** @name Eigen type aliases + * @brief Short, explicit aliases for Eigen vectors/matrices used in MRCPP. + * @{ */ +using IntVector = Eigen::VectorXi; ///< Column vector of integers. +using DoubleVector = Eigen::VectorXd; ///< Column vector of doubles. +using ComplexVector= Eigen::VectorXcd; ///< Column vector of complex doubles. -using IntMatrix = Eigen::MatrixXi; -using DoubleMatrix = Eigen::MatrixXd; -using ComplexMatrix = Eigen::MatrixXcd; +using IntMatrix = Eigen::MatrixXi; ///< Integer matrix. +using DoubleMatrix = Eigen::MatrixXd; ///< Double-precision matrix. +using ComplexMatrix= Eigen::MatrixXcd; ///< Complex double-precision matrix. -using ComplexDouble = std::complex; +using ComplexDouble= std::complex; ///< Convenience alias for complex. +/** @} */ namespace mrcpp { +/** + * @namespace mrcpp::math_utils + * @brief Numerical utilities layered on Eigen; header-only declarations. + */ namespace math_utils { +/** + * @brief Binomial coefficient \f$\binom{n}{j}\f$. + * @param n Non-negative integer. + * @param j Non-negative integer, \f$0 \le j \le n\f$. + * @return \f$\frac{n!}{(n-j)!\,j!}\f$ as a double. + * @note For out-of-domain inputs, the implementation may log an error. + */ double binomial_coeff(int n, int j); + +/** + * @brief Pascal row of binomial coefficients. + * @param order Row index \f$n\f$. + * @return Vector \f$[\binom{n}{0}, \ldots, \binom{n}{n}]\f$. + */ Eigen::VectorXd get_binomial_coefs(unsigned int order); +/** + * @brief Factorial for non-negative integers. + * @param n \f$n \ge 0\f$. + * @return \f$n!\f$ as a double. + */ double factorial(int n); + +/** + * @brief Integer power \f$m^e\f$ for \f$e \ge 0\f$ (loop-based; exact for small ranges). + * @param m Base (integer). + * @param e Exponent (integer, \f$e \ge 0\f$). + * @return \f$m^e\f$ as an int. 
+ */ int ipow(int m, int e); +/** @name Tensor/Kronecker products + * @brief Kronecker products and outer products for vectors/matrices. + * @{ + */ + +/** + * @brief Kronecker product \f$A \otimes B\f$. + */ Eigen::MatrixXd tensor_product(const Eigen::MatrixXd &A, const Eigen::MatrixXd &B); + +/** + * @brief Kronecker product of a matrix and a column vector (treated as \f$B\f$). + */ Eigen::MatrixXd tensor_product(const Eigen::MatrixXd &A, const Eigen::VectorXd &B); + +/** + * @brief Kronecker product of a column vector and a matrix (treated as \f$B\f$). + */ Eigen::MatrixXd tensor_product(const Eigen::VectorXd &A, const Eigen::MatrixXd &B); + +/** + * @brief Outer product \f$A B^\top\f$ of two column vectors. + */ Eigen::MatrixXd tensor_product(const Eigen::VectorXd &A, const Eigen::VectorXd &B); +/** + * @brief Self outer-product \f$A \otimes A\f$ into a flat vector. + * @param A Input column vector. + * @param B Output vector (size must be \f$\mathrm{size}(A)^2\f$). + */ void tensor_self_product(const Eigen::VectorXd &A, Eigen::VectorXd &B); + +/** + * @brief Self outer-product \f$A A^\top\f$ into a matrix. + * @param A Input column vector. + * @param B Output matrix (square, same dimension as \f$A\f$). + */ void tensor_self_product(const Eigen::VectorXd &A, Eigen::MatrixXd &B); +/** @} */ +/** @name Matrix norms + * @brief Induced matrix norms consistent with Eigen semantics. + * @{ + */ +/** + * @brief Infinity norm \f$\|M\|_\infty\f$ (max row 1-norm). + */ double matrix_norm_inf(const Eigen::MatrixXd &M); + +/** + * @brief 1-norm \f$\|M\|_1\f$ (max column 1-norm). + */ double matrix_norm_1(const Eigen::MatrixXd &M); + +/** + * @brief Spectral norm \f$\|M\|_2\f$ (largest singular value). + */ double matrix_norm_2(const Eigen::MatrixXd &M); +/** @} */ -template void apply_filter(T *out, T *in, const Eigen::MatrixXd &filter, int kp1, int kp1_dm1, double fac); +/** + * @brief Apply a linear filter to a coefficient block (templated on scalar type). 
+ * + * Conceptually computes and accumulates a matrix product of the form + * \f$G \leftarrow G + F^\top \cdot \mathrm{filter}\f$, where \f$F\f$ and \f$G\f$ + * are views over `in` and `out` with shapes derived from \p kp1 and \p kp1_dm1. + * + * @tparam T Scalar type (`double` or `ComplexDouble`). + * @param[out] out Output buffer (accumulation destination). + * @param[in] in Input buffer (interpreted as a matrix view). + * @param[in] filter Dense filter matrix to apply. + * @param[in] kp1 Leading polynomial order + 1 (per MR basis). + * @param[in] kp1_dm1 \f$\text{kp1}^{D-1}\f$ helper (stride in the mapped view). + * @param[in] fac If zero, overwrite the destination; otherwise accumulate. + * @warning Buffers must contain at least the required number of elements for the + * implicit matrix views. + */ +template +void apply_filter(T *out, T *in, const Eigen::MatrixXd &filter, int kp1, int kp1_dm1, double fac); +/** + * @brief Expand separable 1D coefficient blocks into an \f$ \text{dim}\f$-D tensor layout. + * + * Recursively multiplies along dimensions using the columns of \p primitive. + * + * @param dim Spatial dimensionality (1–3). + * @param dir Current recursion direction (starting at 0). + * @param kp1 Polynomial order + 1. + * @param kp1_d \f$\text{kp1}^d\f$ where \f$d = \text{dim}\f$ (total coefficients). + * @param primitive Matrix with primitive 1D basis values per dimension. + * @param[in,out] expanded Buffer holding intermediate input and final expanded output. + */ void tensor_expand_coefs(int dim, int dir, int kp1, int kp1_d, const Eigen::MatrixXd &primitive, Eigen::VectorXd &expanded); +/** + * @brief Generate 2D sampling coordinates on a tensor grid spanned by primitive 1D points. + * @param kp1 Points per axis. + * @param primitive Matrix whose rows are the per-axis primitive coordinates. + * @param[out] expanded Output matrix of size \f$(\text{kp1}^2) \times 2\f$. 
+ */ void tensor_expand_coords_2D(int kp1, const Eigen::MatrixXd &primitive, Eigen::MatrixXd &expanded); + +/** + * @brief Generate 3D sampling coordinates on a tensor grid spanned by primitive 1D points. + * @param kp1 Points per axis. + * @param primitive Matrix whose rows are the per-axis primitive coordinates. + * @param[out] expanded Output matrix of size \f$(\text{kp1}^3) \times 3\f$. + */ void tensor_expand_coords_3D(int kp1, const Eigen::MatrixXd &primitive, Eigen::MatrixXd &expanded); +/** + * @brief Hermitian matrix power \f$A^b\f$ via eigendecomposition. + * @param A Hermitian (self-adjoint) complex matrix. + * @param b Real exponent. + * @return \f$U\,\mathrm{diag}(\lambda_i^b)\,U^\dagger\f$ where \f$A = U\,\mathrm{diag}(\lambda_i)\,U^\dagger\f$. + * @note Eigenvalues with magnitude near zero are guarded to avoid blow-ups for negative \p b. + */ ComplexMatrix hermitian_matrix_pow(const ComplexMatrix &A, double b); + +/** + * @brief Diagonalize a Hermitian matrix. + * @param A Input Hermitian matrix (not modified). + * @param[out] diag Real vector of eigenvalues (ascending). + * @return Matrix of eigenvectors as columns (unitary). + */ ComplexMatrix diagonalize_hermitian_matrix(const ComplexMatrix &A, DoubleVector &diag); + +/** + * @brief In-place diagonalization of a Hermitian sub-block. + * + * Replaces the \f$n_\text{size}\times n_\text{size}\f$ block of @p M at + * \f$(n_\text{start}, n_\text{start})\f$ with its eigenvalues on the diagonal + * and writes the corresponding eigenvectors into the same block of @p U. + * + * @param[in,out] M Matrix containing the Hermitian sub-block to diagonalize. + * @param[out] U Matrix receiving the block eigenvectors. + * @param nstart Upper-left index of the block. + * @param nsize Size of the (square) block. 
+ */ void diagonalize_block(ComplexMatrix &M, ComplexMatrix &U, int nstart, int nsize); -template std::vector> cartesian_product(std::vector A, std::vector B); -template std::vector> cartesian_product(std::vector> l_A, std::vector B); -template std::vector> cartesian_product(std::vector a, int dim); +/** @name Cartesian products + * @brief Simple, small-container cartesian products (for enumeration tasks). + * @{ + */ +/** + * @brief Cartesian product \f$A \times B\f$. + * @tparam T Element type. + * @param A First list. + * @param B Second list. + * @return Vector of pairs `[a, b]`. + */ +template +std::vector> cartesian_product(std::vector A, std::vector B); + +/** + * @brief Cartesian product \f$(l\_A) \times B\f$, where each element of @p l_A is itself a tuple. + * @tparam T Element type. + * @param l_A List of partial tuples. + * @param B Second list. + * @return Concatenated tuples. + */ +template +std::vector> cartesian_product(std::vector> l_A, std::vector B); + +/** + * @brief Repeated cartesian power \f$A^{\times \text{dim}}\f$. + * @tparam T Element type. + * @param a Base list. + * @param dim Number of repeats (\f$\ge 1\f$). + * @return All length-\p dim tuples with elements from \p a. + */ +template +std::vector> cartesian_product(std::vector a, int dim); +/** @} */ -template double calc_distance(const Coord &a, const Coord &b); +/** + * @brief Euclidean distance between two D-dimensional coordinates. + * @tparam D Dimension (compile-time). + * @param a First point. + * @param b Second point. + * @return \f$\sqrt{\sum_{i=1}^D (a_i-b_i)^2}\f$. 
+ */ +template +double calc_distance(const Coord &a, const Coord &b); } // namespace math_utils -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/utils/mpi_utils.h b/src/utils/mpi_utils.h index 062d1affa..46c6b42ac 100644 --- a/src/utils/mpi_utils.h +++ b/src/utils/mpi_utils.h @@ -24,43 +24,106 @@ */ #pragma once +/** + * @file + * @brief MPI-facing declarations and a lightweight shared-memory helper for MRCPP. + * + * This header provides: + * - Portable aliases for MPI types (working in non-MPI builds as no-ops). + * - Public globals describing the current MPI topology/roles used by MRCPP. + * - A templated @ref mrcpp::SharedMemory class to allocate a shared-memory + * window among ranks that share a physical node (MPI-3 RMA). + * - Prototypes for shipping trees between ranks: @ref send_tree, @ref recv_tree, @ref share_tree. + * + * @note All MPI symbols are guarded by `MRCPP_HAS_MPI`. In non-MPI builds, dummy + * typedefs are supplied so that client code can still compile. + */ #ifdef MRCPP_HAS_MPI -#include + #include #else -using MPI_Comm = int; -using MPI_Win = int; -using MPI_Request = int; + /// Fallback alias so non-MPI builds can compile client code. + using MPI_Comm = int; + /// Fallback alias so non-MPI builds can compile client code. + using MPI_Win = int; + /// Fallback alias so non-MPI builds can compile client code. + using MPI_Request = int; #endif namespace mrcpp { -using mpi_comm = MPI_Comm; -using mpi_win = MPI_Win; + +/// Alias for MPI communicator (portable across MPI/non-MPI builds). +using mpi_comm = MPI_Comm; +/// Alias for MPI window used by RMA (portable across MPI/non-MPI builds). +using mpi_win = MPI_Win; +/// Alias for MPI request handle (portable across MPI/non-MPI builds). using mpi_request = MPI_Request; + +/** + * @namespace mrcpp::mpi + * @brief Runtime MPI topology, role flags, and communicators used internally by MRCPP. 
+ * + * These externs are set during MRCPP's MPI initialization (see implementation) + * and describe how the current process participates in computation and data + * distribution. They are intentionally kept as simple PODs for easy broadcasting + * and logging. + */ namespace mpi { + +/// If true, the code may choose numerically exact variants of some algorithms. extern bool numerically_exact; -extern int shared_memory_size; +/// Requested per-node shared-memory window size (in MB) for shared allocations. +extern int shared_memory_size; +/// Rank of this process in `MPI_COMM_WORLD`. extern int world_rank; +/// Size of `MPI_COMM_WORLD`. extern int world_size; + +/// Rank within the MRCPP "worker" communicator. extern int wrk_rank; +/// Size of the MRCPP "worker" communicator. extern int wrk_size; + +/// Rank within the node-local shared-memory communicator. extern int share_rank; +/// Size of the node-local shared-memory communicator. extern int share_size; + +/// Rank inside the group communicator that clusters ranks by shared-memory groups. extern int sh_group_rank; + +/// True iff this rank belongs to the bank (data service) group. extern int is_bank; +/// True iff this rank is a worker (i.e., bank client). extern int is_bankclient; + +/// Number of ranks dedicated to the bank (data service). extern int bank_size; +/// Desired number of bank ranks per node (if configured). extern int bank_per_node; + +/// User/auto-configured OpenMP thread count hint for workers. extern int omp_threads; +/// If non-zero, honor the environment's OMP thread count for sizing decisions. extern int use_omp_num_threads; + +/// Total number of bank ranks (including any special managers). extern int tot_bank_size; + +/// Upper bound for usable MPI tags (implementation specific). extern int max_tag; + +/// World-rank of the special task-manager bank (if any). extern int task_bank; +/// Communicator for workers (orbital/function computations). 
extern MPI_Comm comm_wrk; +/// Communicator that groups ranks which share physical memory on the same node. extern MPI_Comm comm_share; +/// Communicator that orders ranks within a shared-memory group. extern MPI_Comm comm_sh_group; +/// Communicator that includes all bank ranks (and possibly clients for RPC). extern MPI_Comm comm_bank; } // namespace mpi @@ -68,34 +131,126 @@ extern MPI_Comm comm_bank; namespace mrcpp { -/** @class SharedMemory +/** + * @class SharedMemory + * @brief Thin RAII wrapper around an MPI-3 shared-memory window (per node). + * + * A `SharedMemory` instance allocates a node-local window using + * `MPI_Win_allocate_shared` (only when compiled with MPI). The window can be used + * to place data structures (e.g., coefficient chunks of a @ref FunctionTree) + * accessible by all ranks on the same physical node without explicit messaging. * - * @brief Shared memory block within a compute node + * @tparam T Element type of the memory window. * - * @details This class defines a shared memory window in a shared MPI - * communicator. In order to allocate a FunctionTree in shared memory, - * simply pass a SharedMemory object to the FunctionTree constructor. + * @par Usage + * - Construct on one or more ranks of `mpi::comm_share` to allocate a window. + * - Use `sh_start_ptr`/`sh_end_ptr`/`sh_max_ptr` to manage a simple bump allocator. + * - Call @ref clear to reset the bump pointer without freeing the window. + * + * @note In non-MPI builds, this class becomes a trivial holder and does not + * allocate any real shared memory. */ -template class SharedMemory { +template +class SharedMemory { public: + /** + * @brief Create (or attach to) a node-local shared-memory window. + * @param comm Node-local communicator (typically @ref mpi::comm_share). + * @param sh_size Window size in megabytes (MB). Only rank 0 in @p comm + * dictates the size; other ranks attach to it. 
+ * + * @details + * When `MRCPP_HAS_MPI` is enabled, this constructor calls: + * - `MPI_Win_allocate_shared` on @p comm with the requested size on rank 0 + * (and size 0 on others), + * - `MPI_Win_shared_query` so that every rank obtains a base pointer + * into the same window, + * - Initializes `sh_start_ptr`, `sh_end_ptr`, and `sh_max_ptr`. + */ SharedMemory(mrcpp::mpi_comm comm, int sh_size); + + /// Deleted copy constructor to avoid double-free of the MPI window. SharedMemory(const SharedMemory &mem) = delete; + /// Deleted copy assignment. SharedMemory &operator=(const SharedMemory &mem) = delete; + + /** + * @brief Destroy the shared window and release resources. + * Calls `MPI_Win_free` when built with MPI. + */ ~SharedMemory(); - void clear(); // show shared memory as entirely available + /** + * @brief Reset the bump pointer so the whole window appears free. + * Does not deallocate or shrink the MPI window. + */ + void clear(); - T *sh_start_ptr; // start of shared block - T *sh_end_ptr; // end of used part - T *sh_max_ptr; // end of shared block - mrcpp::mpi_win sh_win; // MPI window object - int rank; // rank among shared group + /// Pointer to the beginning of the shared window. + T *sh_start_ptr{nullptr}; + /// Pointer to one past the last used element (bump pointer). + T *sh_end_ptr{nullptr}; + /// Pointer to one past the last available element (capacity end). + T *sh_max_ptr{nullptr}; + /// Underlying MPI window handle. + mrcpp::mpi_win sh_win{}; + /// Rank of this process within the shared-memory communicator. 
+ int rank{0}; }; template class FunctionTree; -template void send_tree(FunctionTree &tree, int dst, int tag, mrcpp::mpi_comm comm, int nChunks = -1, bool coeff = true); -template void recv_tree(FunctionTree &tree, int src, int tag, mrcpp::mpi_comm comm, int nChunks = -1, bool coeff = true); -template void share_tree(FunctionTree &tree, int src, int tag, mrcpp::mpi_comm comm); +/** + * @brief Send a @ref FunctionTree to another rank (blocking). + * + * Transfers node/structure (and optionally coefficient) chunks to @p dst + * using point-to-point MPI. If @p nChunks is negative, a small header with the + * number of chunks is sent first. + * + * @tparam D Dimensionality of the function. + * @tparam T Scalar type (`double` or @ref ComplexDouble). + * @param tree FunctionTree to send. + * @param dst Destination rank (in @p comm). + * @param tag Base MPI tag (chunk indices are offset from this). + * @param comm Communicator over which to send. + * @param nChunks Number of chunks to send; if `<0`, the count is sent first. + * @param coeff If true, also send coefficient chunks; otherwise only structure. + */ +template +void send_tree(FunctionTree &tree, int dst, int tag, mrcpp::mpi_comm comm, int nChunks = -1, bool coeff = true); + +/** + * @brief Receive a @ref FunctionTree from another rank (blocking). + * + * Reconstructs tree structure (and optionally coefficients) by receiving the + * same chunk layout produced by @ref send_tree. + * + * @tparam D Dimensionality of the function. + * @tparam T Scalar type (`double` or @ref ComplexDouble). + * @param tree Destination FunctionTree (reinitialized internally). + * @param src Source rank (in @p comm). + * @param tag Base MPI tag (must match sender). + * @param comm Communicator over which to receive. + * @param nChunks Number of chunks to receive; if `<0`, read the header first. + * @param coeff If true, receive coefficient chunks; otherwise only structure. 
+ */ +template +void recv_tree(FunctionTree &tree, int src, int tag, mrcpp::mpi_comm comm, int nChunks = -1, bool coeff = true); + +/** + * @brief Share a @ref FunctionTree to all ranks in a node-local communicator. + * + * Used to mirror the latest version of a shared function across ranks that + * participate in a shared-memory group, without reconstructing the tree from scratch. + * + * @tparam D Dimensionality of the function. + * @tparam T Scalar type (`double` or @ref ComplexDouble). + * @param tree FunctionTree to disseminate. + * @param src Rank that owns the up-to-date copy (in @p comm). + * @param tag Base tag used to coordinate transfers. + * @param comm Communicator comprising the sharing ranks (e.g., @ref mpi::comm_share). + */ +template +void share_tree(FunctionTree &tree, int src, int tag, mrcpp::mpi_comm comm); } // namespace mrcpp diff --git a/src/utils/omp_utils.h b/src/utils/omp_utils.h index 083387c27..9fc87a8af 100644 --- a/src/utils/omp_utils.h +++ b/src/utils/omp_utils.h @@ -24,31 +24,101 @@ */ #pragma once +/** + * @file + * @brief OpenMP utilities and portability shims for MRCPP. + * + * This header centralizes MRCPP's interaction with OpenMP so client code can + * compile and run both with and without OpenMP support. It provides: + * - A consistent way to query the number of threads/rank of a thread. + * - Lightweight lock helpers that compile away in non-OpenMP builds. + * - A global cap on MRCPP-managed threads to avoid oversubscription with Eigen. + * + * @note Eigen is explicitly forced to single-threaded mode via + * `EIGEN_DONT_PARALLELIZE` to prevent nested parallelism when MRCPP + * runs OpenMP regions. + */ +/// Disable Eigen's internal multi-threading to avoid oversubscription. 
#define EIGEN_DONT_PARALLELIZE #ifdef MRCPP_HAS_OMP -#include -#define mrcpp_get_max_threads() omp_get_max_threads() -#define mrcpp_get_num_threads() mrcpp::max_threads -#define mrcpp_get_thread_num() omp_get_thread_num() -#define MRCPP_INIT_OMP_LOCK() omp_init_lock(&this->omp_lock) -#define MRCPP_DESTROY_OMP_LOCK() omp_destroy_lock(&this->omp_lock) -#define MRCPP_SET_OMP_LOCK() omp_set_lock(&this->omp_lock) -#define MRCPP_UNSET_OMP_LOCK() omp_unset_lock(&this->omp_lock) -#define MRCPP_TEST_OMP_LOCK() omp_test_lock(&this->omp_lock) + #include + + /** + * @name Thread query helpers (OpenMP build) + * @{ + */ + + /// Maximum number of threads OpenMP may use for parallel regions. + #define mrcpp_get_max_threads() omp_get_max_threads() + + /** + * Number of threads MRCPP intends to use in parallel regions. + * + * @details This is capped by user/runtime policy via @ref mrcpp::set_max_threads + * and may be lower than @c omp_get_max_threads() to respect node-level limits. + */ + #define mrcpp_get_num_threads() mrcpp::max_threads + + /// Zero-based thread id within a running OpenMP parallel region. + #define mrcpp_get_thread_num() omp_get_thread_num() + /** @} */ + + /** + * @name Lightweight lock helpers (OpenMP build) + * @brief Macros that operate on a member `omp_lock_t omp_lock;`. 
+ * @{ + */ + #define MRCPP_INIT_OMP_LOCK() omp_init_lock(&this->omp_lock) + #define MRCPP_DESTROY_OMP_LOCK() omp_destroy_lock(&this->omp_lock) + #define MRCPP_SET_OMP_LOCK() omp_set_lock(&this->omp_lock) + #define MRCPP_UNSET_OMP_LOCK() omp_unset_lock(&this->omp_lock) + #define MRCPP_TEST_OMP_LOCK() omp_test_lock(&this->omp_lock) + /** @} */ + #else -#define mrcpp_get_max_threads() 1 -#define mrcpp_get_num_threads() 1 -#define mrcpp_get_thread_num() 0 -#define MRCPP_INIT_OMP_LOCK() -#define MRCPP_DESTROY_OMP_LOCK() -#define MRCPP_SET_OMP_LOCK() -#define MRCPP_UNSET_OMP_LOCK() -#define MRCPP_TEST_OMP_LOCK() + /** + * @name Thread/query helpers (non-OpenMP build) + * @brief Serial fallbacks so code compiles and runs without OpenMP. + * @{ + */ + #define mrcpp_get_max_threads() 1 ///< Always 1 in serial builds. + #define mrcpp_get_num_threads() 1 ///< Always 1 in serial builds. + #define mrcpp_get_thread_num() 0 ///< Single thread has id 0. + /** @} */ + + /** + * @name Lock helpers (non-OpenMP build) + * @brief No-ops in serial builds. + * @{ + */ + #define MRCPP_INIT_OMP_LOCK() + #define MRCPP_DESTROY_OMP_LOCK() + #define MRCPP_SET_OMP_LOCK() + #define MRCPP_UNSET_OMP_LOCK() + #define MRCPP_TEST_OMP_LOCK() + /** @} */ #endif namespace mrcpp { + +/** + * @brief Upper bound on threads MRCPP will request for OpenMP regions. + * + * @details This value is used by @c mrcpp_get_num_threads() and allows MRCPP + * to honor node-level thread budgeting (e.g., when co-scheduled with MPI or + * other threaded libraries). In non-OpenMP builds this remains 1. + */ extern int max_threads; + +/** + * @brief Set the global thread cap used by MRCPP parallel regions. + * @param threads Desired number of threads (clamped to at least 1). + * + * @note This does not change system-wide OpenMP settings; it only influences + * MRCPP's internal use (e.g., via @c mrcpp_get_num_threads()). 
+ */ void set_max_threads(int threads); -} // namespace mrcpp + +} // namespace mrcpp \ No newline at end of file diff --git a/src/utils/parallel.h b/src/utils/parallel.h index 395cc1174..417b1d852 100644 --- a/src/utils/parallel.h +++ b/src/utils/parallel.h @@ -1,4 +1,46 @@ +/* + * MRCPP, a numerical library based on multiresolution analysis and + * the multiwavelet basis which provide low-scaling algorithms as well as + * rigorous error control in numerical computations. + * Copyright (C) 2021 Stig Rune Jensen, Jonas Juselius, Luca Frediani and contributors. + * + * This file is part of MRCPP. + * + * MRCPP is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * MRCPP is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public License + * along with MRCPP. If not, see . + * + * For information on the complete list of contributors to MRCPP, see: + * + */ + #pragma once +/** + * @file + * @brief MPI/OpenMP orchestration and collectives for MRCPP. + * + * This header declares the process/thread orchestration utilities and the + * common collective/point-to-point helpers used by MRCPP to distribute + * multiresolution data structures across MPI ranks (and optionally coordinate + * with OpenMP threads). It provides: + * + * - Initialization/finalization of the MRCPP MPI environment. + * - Rank/topology helpers (e.g., “grand master”, ownership checks). + * - Typed send/recv/broadcast for @ref mrcpp::CompFunction and trees. + * - Element-wise allreduce helpers for Eigen vectors/matrices. 
+ * + * All MPI symbols are no-ops in non-MPI builds (compiled without + * `MRCPP_HAS_MPI`), allowing the same interface to work in serial. + */ #include @@ -12,61 +54,241 @@ using namespace Eigen; -using IntVector = Eigen::VectorXi; +using IntVector = Eigen::VectorXi; using DoubleVector = Eigen::VectorXd; -using ComplexVector = Eigen::VectorXcd; +using ComplexVector= Eigen::VectorXcd; -using IntMatrix = Eigen::MatrixXi; +using IntMatrix = Eigen::MatrixXi; using DoubleMatrix = Eigen::MatrixXd; -using ComplexMatrix = Eigen::MatrixXcd; +using ComplexMatrix= Eigen::MatrixXcd; namespace mrcpp { +/** + * @namespace mrcpp::omp + * @brief OpenMP runtime hints used by the parallel layer. + */ namespace omp { -extern int n_threads; +extern int n_threads; ///< Number of OpenMP threads MRCPP intends to use. } // namespace omp -class Bank; -extern Bank dataBank; +class Bank; ///< Forward declaration of the in-memory data bank. +extern Bank dataBank; ///< Global bank instance used by bank ranks. +/** + * @namespace mrcpp::mpi + * @brief MPI utilities, communicators, and collectives. + * + * Functions in this namespace act as thin wrappers around MPI and encode + * MRCPP’s distribution policy for component functions and trees. + */ namespace mpi { +/** @brief World ranks assigned to bank masters (control/data services). */ extern std::vector bankmaster; +/** + * @brief Initialize MRCPP’s MPI environment and process topology. + * + * Sets up communicators (workers, shared-memory groups, bank group), + * partitions ranks into worker/bank roles, and configures OpenMP thread + * counts per rank. Safe to call exactly once at program start. + */ void initialize(); + +/** + * @brief Finalize MRCPP’s MPI environment. + * + * Performs a global barrier, closes the global data bank (if present), and + * calls `MPI_Finalize()` in MPI builds. Safe to call once at program exit. + */ void finalize(); + +/** + * @brief Rank barrier on a given communicator. 
+ * @param comm MPI communicator to synchronize. + * + * In non-MPI builds this is a no-op. + */ void barrier(MPI_Comm comm); +/** + * @brief Whether this rank is the global worker “grand master”. + * @return @c true iff world rank is 0 and the rank is a worker (not a bank). + */ bool grand_master(); + +/** + * @brief Whether this rank is the master of its shared-memory group. + * @return @c true iff rank is 0 within @ref mpi::comm_share. + */ bool share_master(); +/** + * @name Ownership helpers + * @brief Determine whether an object/function is owned by this rank. + * @{ + */ + +/** + * @brief Ownership test for an index. + * @param j Global function index. + * @return @c true if @c j maps to this rank under MRCPP’s block-cyclic policy. + */ bool my_func(int j); + +/** + * @brief Ownership test for a component function (const ref). + * @param func Component function to test. + * @return @c true if @p func belongs to this rank (by @c func.rank()). + */ bool my_func(const CompFunction<3> &func); + +/** + * @brief Ownership test for a component function (pointer). + * @param func Pointer to component function. + * @return @c true if @p func belongs to this rank (by @c func->rank()). + */ bool my_func(CompFunction<3> *func); +/** @} */ -// bool my_unique_orb(const Orbital &orb); +/** + * @brief Free memory held by functions not owned by this rank. + * @param Phi Vector of component functions; foreign entries are freed in place. + */ void free_foreign(CompFunctionVector &Phi); +/** + * @name Point-to-point transfers for component functions + * @brief Send/receive/share a @ref mrcpp::CompFunction across ranks. + * @{ + */ + +/** + * @brief Send a component function to a destination rank. + * @param func Function to send. + * @param dst Destination world rank. + * @param tag Message tag base (submessages will offset from this). + * @param comm Communicator (default: worker communicator). + * + * Sends the function header followed by its component trees. 
Assumes the + * receiver uses @ref recv_function with the same @p tag and @p comm. + */ void send_function(const CompFunction<3> &func, int dst, int tag, MPI_Comm comm = mpi::comm_wrk); + +/** + * @brief Receive a component function from a source rank. + * @param func Function to receive into (resized as needed). + * @param src Source world rank. + * @param tag Message tag base (must match sender). + * @param comm Communicator (default: worker communicator). + */ void recv_function(CompFunction<3> &func, int src, int tag, MPI_Comm comm = mpi::comm_wrk); + +/** + * @brief Update shared-memory replicas of a function after modification. + * @param func Function to share (must be marked shared). + * @param src Rank that produced the update. + * @param tag Base tag for the transfer. + * @param comm Communicator that defines the sharing group. + * + * This call only has an effect if the function was allocated in shared memory. + */ void share_function(CompFunction<3> &func, int src, int tag, MPI_Comm comm); +/** @} */ +/** + * @brief Reduce (sum/accumulate) a function onto rank 0 of @p comm. + * @param prec Cropping precision applied after each accumulation. + * @param func Function buffer holding the local contribution; on rank 0 it + * becomes the global sum; on other ranks it may be left unchanged. + * @param comm Communicator over which to reduce. + * + * Uses a binary-tree pattern in which odd ranks send to the preceding even + * ranks; the receiver adds and crops to control growth. + */ void reduce_function(double prec, CompFunction<3> &func, MPI_Comm comm); + +/** + * @brief Broadcast a function from rank 0 to all ranks in @p comm. + * @param func Buffer to receive (or hold, on root) the broadcast function. + * @param comm Communicator to broadcast over. + * + * Implements the reverse of the binary-tree pattern used by + * @ref reduce_function.
+ */ void broadcast_function(CompFunction<3> &func, MPI_Comm comm); -template void reduce_Tree_noCoeff(mrcpp::FunctionTree<3, T> &tree, MPI_Comm comm); -template void allreduce_Tree_noCoeff(mrcpp::FunctionTree<3, T> &tree, std::vector> &Phi, MPI_Comm comm); -template void broadcast_Tree_noCoeff(mrcpp::FunctionTree<3, T> &tree, MPI_Comm comm); +/** + * @name Tree collectives (no coefficient payload) + * @brief Perform collectives on @ref mrcpp::FunctionTree without coefficients. + * @{ + */ + +/** + * @brief Reduce (union) grids from all ranks to rank 0, excluding coeffs. + * @tparam T Coefficient scalar type of the tree. + * @param tree Output/input tree on each rank; on rank 0 it becomes the union. + * @param comm Communicator. + */ +template +void reduce_Tree_noCoeff(mrcpp::FunctionTree<3, T> &tree, MPI_Comm comm); -template void allreduce_Tree_noCoeff(mrcpp::FunctionTree<3, T> &tree, std::vector> &Phi, MPI_Comm comm); +/** + * @brief Build local union grid, reduce to rank 0, then broadcast to all. + * @tparam T Coefficient scalar type. + * @param tree Target tree to hold the global union grid (no coeffs). + * @param Phi Vector of trees whose grids contribute to the union. + * @param comm Communicator. + */ +template +void allreduce_Tree_noCoeff(mrcpp::FunctionTree<3, T> &tree, + std::vector> &Phi, + MPI_Comm comm); +/** + * @brief Broadcast a no-coeff tree from rank 0 to all ranks. + * @tparam T Coefficient scalar type. + * @param tree Tree to broadcast/receive. + * @param comm Communicator. + */ +template +void broadcast_Tree_noCoeff(mrcpp::FunctionTree<3, T> &tree, MPI_Comm comm); + +/** + * @brief Build union grid from owned components in @p Phi, allreduce to all. + * @tparam T Coefficient scalar type. + * @param tree Output tree receiving the global union grid. + * @param Phi Vector of component functions contributing their grids. + * @param comm Communicator. 
+ */ +template +void allreduce_Tree_noCoeff(mrcpp::FunctionTree<3, T> &tree, + std::vector> &Phi, + MPI_Comm comm); +/** @} */ + +/** + * @name Element-wise allreduce (sum) helpers + * @brief Sum across ranks into every rank for Eigen containers. + * @{ + */ + +/** @brief In-place element-wise sum allreduce for integer vectors. */ void allreduce_vector(IntVector &vec, MPI_Comm comm); +/** @brief In-place element-wise sum allreduce for double vectors. */ void allreduce_vector(DoubleVector &vec, MPI_Comm comm); +/** @brief In-place element-wise sum allreduce for complex vectors. */ void allreduce_vector(ComplexVector &vec, MPI_Comm comm); + +/** @brief In-place element-wise sum allreduce for integer matrices. */ void allreduce_matrix(IntMatrix &vec, MPI_Comm comm); +/** @brief In-place element-wise sum allreduce for double matrices. */ void allreduce_matrix(DoubleMatrix &mat, MPI_Comm comm); +/** @brief In-place element-wise sum allreduce for complex matrices. */ void allreduce_matrix(ComplexMatrix &mat, MPI_Comm comm); +/** @} */ } // namespace mpi -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/utils/periodic_utils.h b/src/utils/periodic_utils.h index f79a893d3..1ea95aaf2 100644 --- a/src/utils/periodic_utils.h +++ b/src/utils/periodic_utils.h @@ -24,12 +24,98 @@ */ #pragma once +/** + * @file + * @brief Periodic boundary utilities for node indices and real-space coordinates. + * + * This header declares helpers for enforcing periodic boundary conditions (PBC) + * on both discrete tree indices (@ref mrcpp::NodeIndex) and continuous + * coordinates (@ref mrcpp::Coord). The utilities are templated on dimension + * @p D and support selectively periodic directions via a boolean mask. + * + * Typical use cases: + * - Normalizing a node index to the canonical unit cell before lookup. + * - Wrapping real-space coordinates into the primary cell when sampling or + * exporting data. 
+ * - Applying periodicity per axis (e.g., 2D slab periodic in x/y but not z). + */ #include "MRCPP/mrcpp_declarations.h" + namespace mrcpp { +/** + * @namespace mrcpp::periodic + * @brief Helpers for periodic index/coordinate manipulation. + * + * The functions here assume MRCPP’s convention where the canonical cell is the + * unit hypercube, and indices/coordinates are normalized accordingly: + * - Discrete indices: @ref NodeIndex logically cover tiles of the unit cell + * at a given resolution (scale). These helpers re-map out-of-range indices + * back into the unit cell modulo the periodic axes. + * - Continuous coordinates: @ref Coord (double-valued) are wrapped by + * subtracting/adding integer lattice vectors along periodic axes so that + * the result lies in the half-open interval [0, 1) per periodic dimension. + */ namespace periodic { -template bool in_unit_cell(NodeIndex idx); -template void index_manipulation(NodeIndex &idx, const std::array &periodic); -template void coord_manipulation(Coord &r, const std::array &periodic); + +/** + * @brief Check whether a node index lies inside the unit cell. + * + * @tparam D Spatial dimension. + * @param idx Node index to test (scale and per-dimension integer indices). + * @return @c true if @p idx is within the canonical unit cell bounds in all + * dimensions; @c false if any component is outside. + * + * @details “Inside” means the discrete index components fall in the valid range + * for the node’s scale with no modular wrap required. This does not modify + * @p idx and performs a pure check. Use @ref index_manipulation to fold an + * index back into the unit cell when periodicity is intended. + */ +template +bool in_unit_cell(NodeIndex idx); + +/** + * @brief Fold a node index into the unit cell under per-axis periodicity. + * + * @tparam D Spatial dimension. 
+ * @param[in,out] idx Node index to normalize; on return, the per-axis integer + * index components are mapped into the unit-cell range for + * the node’s scale when the corresponding axis is periodic. + * @param periodic Boolean mask of length @p D; @c true marks an axis as + * periodic, @c false leaves that axis unchanged (no wrapping). + * + * @details + * For each periodic axis, the index component is reduced modulo the extent at + * the node’s scale so that the resulting index is in-range. For non-periodic + * axes, the index is left as-is (and may remain out-of-bounds if provided so). + * + * @note This function is idempotent for already in-range indices on periodic + * axes and is a no-op for non-periodic axes. + */ +template +void index_manipulation(NodeIndex &idx, const std::array &periodic); + +/** + * @brief Wrap a coordinate into the unit cell under per-axis periodicity. + * + * @tparam D Spatial dimension. + * @param[in,out] r Coordinate to normalize; each periodic component is wrapped + * into the half-open interval [0, 1). + * @param periodic Boolean mask of length @p D; @c true marks an axis as + * periodic, @c false leaves that axis unchanged (no wrapping). + * + * @details + * For each periodic axis, the component @f$r_d@f$ is replaced by + * @f$r_d - \lfloor r_d \rfloor@f$, producing a value in [0, 1). Non-periodic + * axes are not modified. This is equivalent to applying @c std::floor-based + * fractional reduction to each periodic component. + * + * @warning If your simulation cell is scaled or shifted relative to the unit + * cube, convert to reduced coordinates before calling this function, or adjust + * the values accordingly (e.g., divide by box length) and convert back. 
+ */ +template +void coord_manipulation(Coord &r, const std::array &periodic); + } // namespace periodic -} // namespace mrcpp +} // namespace mrcpp \ No newline at end of file diff --git a/src/utils/tree_utils.h b/src/utils/tree_utils.h index 56c8c7d79..52d7a564b 100644 --- a/src/utils/tree_utils.h +++ b/src/utils/tree_utils.h @@ -29,16 +29,149 @@ #include "utils/math_utils.h" namespace mrcpp { +/** + * @file + * @brief Utilities for inspecting and transforming Multiwavelet (MW) trees. + * + * @details + * This header declares helper routines that operate on MRCPP tree structures: + * - adaptive refinement decisions based on wavelet norms, + * - creation of per-scale or flat node tables (Hilbert-ordered), + * - forward and backward multiwavelet transforms between parent/children + * scaling coefficients. + * + * Unless otherwise stated, functions are **not** thread-safe; synchronize at + * a higher level if multiple threads may act on the same tree or buffers. + */ namespace tree_utils { -template bool split_check(const MWNode &node, double prec, double split_fac, bool abs_prec); +/** + * @brief Decide whether a node should be split (refined) based on its wavelet norm. + * + * @tparam D Spatial dimension of the MW tree. + * @tparam T Coefficient type (`double` or `ComplexDouble`). + * @param node Node to be tested. + * @param prec Target accuracy (relative by default). Non-positive disables splitting. + * @param split_fac Scale-dependent factor. If `> MachineZero`, the threshold is + * scaled by \f$2^{-0.5 \cdot \text{split\_fac} \cdot (s+1)}\f$ + * where `s` is the node scale; this makes refinement stricter + * at finer scales. + * @param abs_prec When `true`, interpret `prec` as an **absolute** tolerance. + * When `false`, use a **relative** tolerance multiplied by + * \f$\|f\|\f$ (square-norm taken from the owning tree). + * + * @return `true` if the node’s wavelet norm exceeds the computed threshold and + * the node should be refined; `false` otherwise. 
+ * + * @details + * The decision compares \f$\|\mathbf{w}\|\f$ (node wavelet norm) to a threshold: + * \f[ + * \tau = \max(2\,\text{MachinePrec},\; + * \text{prec} \times (\text{abs\_prec} ? 1 : \|f\|) \times \text{scale\_fac}) + * \f] + * where \f$\text{scale\_fac}\f$ is determined by `split_fac` as described above. + * If the owning tree’s square norm is zero and `abs_prec == false`, a fallback + * of \f$\|f\|=1\f$ is used. + */ +template +bool split_check(const MWNode &node, double prec, double split_fac, bool abs_prec); + +/** + * @brief Build a flat, Hilbert-ordered table of all non-root nodes in a tree. + * + * @tparam D Spatial dimension. + * @tparam T Coefficient type. + * @param tree Input MW tree. + * @param table Output vector receiving pointers to all internal and leaf nodes + * (root depth 0 is skipped). Nodes are traversed in a Hilbert + * space-filling curve order; generator nodes are excluded. + * + * @details + * Useful for linear passes (e.g., I/O, diagnostics, custom sweeps) where a + * contiguous list of nodes is required. + */ +template +void make_node_table(MWTree &tree, MWNodeVector &table); + +/** + * @brief Build per-scale Hilbert-ordered node tables. + * + * @tparam D Spatial dimension. + * @tparam T Coefficient type. + * @param tree Input MW tree. + * @param table Output vector of vectors. Index `d` stores node pointers whose + * depth corresponds to `d - tree.getNNegScales()`. Each inner + * vector is Hilbert-ordered; generator nodes are excluded. + * + * @details + * This form is convenient for level-wise processing such as multigrid cycles, + * visualization, or per-scale statistics. + */ +template +void make_node_table(MWTree &tree, std::vector> &table); -template void make_node_table(MWTree &tree, MWNodeVector &table); -template void make_node_table(MWTree &tree, std::vector> &table); +/** + * @brief Forward MW transform: build children scaling coefficients from a parent block. 
+ * + * @tparam D Spatial dimension (implemented for 1, 2, 3). + * @tparam T Coefficient type (`double` or `ComplexDouble`). + * @param tree Tree providing filter and arity/meta information. + * @param coeff_in Pointer to the parent block (size = `kp1^D` entries, + * where `kp1 = k + 1` for polynomial order `k`), + * laid out in standard MRCPP order. + * @param coeff_out Pointer to the destination buffer for **children** blocks. + * This routine writes (or accumulates) into `2^D` child + * blocks separated by `stride` elements each. + * @param readOnlyScaling If `true`, operate as if only scaling components are + * present (skips mixing with wavelets internally). + * @param stride Stride, in elements, between consecutive child blocks + * inside `coeff_out`. Must be at least `kp1^D`. + * @param overwrite When `true` (default), assign into `coeff_out`. + * When `false`, accumulate (add) into existing values. + * + * @pre + * - `coeff_out` points to sufficient writable storage: + * at least `2^D * stride` elements of type `T`. + * - `coeff_in` points to at least `kp1^D` elements. + * + * @post + * - The `2^D` children scaling blocks are stored in `coeff_out` (assigned or + * accumulated, depending on @p overwrite). + * + * @note + * Complexity is \f$O(2^D \cdot k^{D+1})\f$ for polynomial order `k` (where `kp1 = k+1`). + * For `D > 3` the routine is not implemented. + */ +template +void mw_transform(const MWTree &tree, + T *coeff_in, + T *coeff_out, + bool readOnlyScaling, + int stride, + bool overwrite = true); -template void mw_transform(const MWTree &tree, T *coeff_in, T *coeff_out, bool readOnlyScaling, int stride, bool overwrite = true); // template void mw_transform_back(MWTree &tree, T *coeff_in, T *coeff_out, int stride); -template void mw_transform_back(MWTree<3, T> &tree, T *coeff_in, T *coeff_out, int stride); + +/** + * @brief Backward MW transform (3D specialization): build the parent block from children. + * + * @tparam T Coefficient type (`double` or `ComplexDouble`). + * @param tree Tree providing filter and arity/meta information.
+ * @param coeff_in Pointer to the concatenated **children** blocks (8 blocks in 3D), + * each of size `kp1^3`, separated by `stride` elements. + * @param coeff_out Pointer to the **parent** block storage (size `kp1^3`). + * @param stride Stride, in elements, between consecutive children blocks. + * + * @pre + * - `coeff_in` provides at least `8 * stride` elements. + * - `coeff_out` provides at least `kp1^3` writable elements. + * + * @post + * - The parent scaling block is reconstructed into `coeff_out`. + * + * @note + * Only the \f$D=3\f$ variant is provided. Use @ref mw_transform for the forward direction. + */ +template +void mw_transform_back(MWTree<3, T> &tree, T *coeff_in, T *coeff_out, int stride); } // namespace tree_utils } // namespace mrcpp