diff --git a/docs/src/api/loggers.md b/docs/src/api/loggers.md
index 8b7275a4..8f02c1aa 100644
--- a/docs/src/api/loggers.md
+++ b/docs/src/api/loggers.md
@@ -16,6 +16,16 @@ BasicLogger
 
 BasicLoggerRecipe
 
+MALogger
+
+MALoggerRecipe
+
+```
+
+### Moving average logger structures
+```@docs
+MAInfo
+
 ```
 
 ## Exported Functions
@@ -27,4 +37,6 @@ update_logger!
 reset_logger!
 
 threshold_stop
+
+update_ma!
 ```
diff --git a/docs/src/refs.bib b/docs/src/refs.bib
index c6869065..b78a0428 100644
--- a/docs/src/refs.bib
+++ b/docs/src/refs.bib
@@ -56,3 +56,5 @@ @article{pritchard2024solving
   urldate = {2025-01-24},
   abstract = {Abstract In large-scale applications including medical imaging, collocation differential equation solvers, and estimation with differential privacy, the underlying linear inverse problem can be reformulated as a streaming problem. In theory, the streaming problem can be effectively solved using memory-efficient, exponentially-converging streaming solvers. In special cases when the underlying linear inverse problem is finite-dimensional, streaming solvers can periodically evaluate the residual norm at a substantial computational cost. When the underlying system is infinite dimensional, streaming solver can only access noisy estimates of the residual. While such noisy estimates are computationally efficient, they are useful only when their accuracy is known. In this work, we rigorously develop a general family of computationally-practical residual estimators and their uncertainty sets for streaming solvers, and we demonstrate the accuracy of our methods on a number of large-scale linear problems. Thus, we further enable the practical use of streaming solvers for important classes of linear inverse problems.}
 }
+
+
diff --git a/src/RLinearAlgebra.jl b/src/RLinearAlgebra.jl
index 046c02d1..27548945 100644
--- a/src/RLinearAlgebra.jl
+++ b/src/RLinearAlgebra.jl
@@ -28,8 +28,9 @@ export complete_solver, update_solver!, rsolve, rsolve!
 
 # Export Logger types and functions
 export Logger, LoggerRecipe
-export BasicLogger, BasicLoggerRecipe
-export complete_logger, update_logger!, reset_logger!
+export MAInfo
+export BasicLogger, BasicLoggerRecipe, MALogger, MALoggerRecipe
+export complete_logger, update_logger!, reset_logger!, update_ma!
 export threshold_stop
 
 # Export SubSolver types
diff --git a/src/Solvers/Loggers.jl b/src/Solvers/Loggers.jl
index 44bc24b4..3bf30925 100644
--- a/src/Solvers/Loggers.jl
+++ b/src/Solvers/Loggers.jl
@@ -89,7 +89,14 @@ function reset_logger!(logger::LoggerRecipe)
     return nothing
 end
 
-##############################
-# Include Logger Files
+################################################
+# Include Logger Moving Average helpers Files
+################################################
+include("Loggers/ma_helpers/ma_info.jl")
+
+###############################
+# Include Logger Methods Files
 ###############################
 include("Loggers/basic_logger.jl")
+include("Loggers/moving_average_logger.jl")
+
diff --git a/src/Solvers/Loggers/basic_logger.jl b/src/Solvers/Loggers/basic_logger.jl
index e1e906d3..d1cfc5e8 100644
--- a/src/Solvers/Loggers/basic_logger.jl
+++ b/src/Solvers/Loggers/basic_logger.jl
@@ -1,7 +1,7 @@
 """
     BasicLogger <: Logger
 
-This is a mutable struct that contains the `max_it` parameter and stores the error metric
+This is a struct that contains the `max_it` parameter and stores the error metric
 in a vector. Checks convergence of the solver based on the log information.
 
 # Fields
diff --git a/src/Solvers/Loggers/ma_helpers/ma_info.jl b/src/Solvers/Loggers/ma_helpers/ma_info.jl
new file mode 100644
index 00000000..062521b5
--- /dev/null
+++ b/src/Solvers/Loggers/ma_helpers/ma_info.jl
@@ -0,0 +1,130 @@
+# This file contains the components needed to store and update the
+# moving average information used by the moving average method:
+# Structs: MAInfo
+# Functions: update_ma!
+
+#########################################
+# Structs
+#########################################
+"""
+    MAInfo
+
+A mutable structure that stores information relevant to the moving average (MA) of a
+    progress estimator, such as a residual. It manages different MA window widths
+    (`lambda1`, `lambda2`) for different convergence phases and tracks the current
+    MA window (`res_window`).
+
+See [pritchard2024solving](@cite) for more information on the underlying MA methods.
+
+# Fields
+- `lambda1::Integer`, the width of the moving average during the fast convergence phase
+    of the algorithm. During this fast convergence phase, the majority of variation of
+    the sketched estimator comes from improvement in the solution and thus wide moving
+    average windows inaccurately represent progress.
+- `lambda2::Integer`, the width of the moving average in the slower convergence phase.
+    In the slow convergence phase, each iterate differs from the previous one by a small
+    amount and thus most of the observed variation arises from the randomness of the
+    sketched progress estimator, which is best smoothed by a wide moving average width.
+- `lambda::Integer`, the actual width of the moving average being used at the current
+    iteration. This field is updated internally and is not set directly by the user.
+- `flag::Bool`, a boolean flag indicating the current convergence phase. A value of
+    `true` indicates the "slow" phase, in which the window may grow up to `lambda2`.
+- `idx::Integer`, the index within the `res_window` buffer of the most recently stored
+    residual value. The buffer is written circularly.
+- `res_window::Vector{<:AbstractFloat}`, the buffer storing the recent residual values
+    used to compute the moving average.
+"""
+mutable struct MAInfo
+    lambda1::Integer
+    lambda2::Integer
+    lambda::Integer
+    flag::Bool
+    idx::Integer
+    res_window::Vector{<:AbstractFloat}
+end
+
+#########################################
+# Functions
+#########################################
+"""
+    update_ma!(
+        log::LoggerRecipe,
+        res::Union{AbstractVector, Real},
+        lambda_base::Integer,
+        iter::Integer
+    )
+
+Updates the moving average statistics stored within the `log.ma_info` field of a
+    `MALoggerRecipe`. It stores `res` in the circular buffer, computes the moving
+    average and second moment (iota) over the current window, and updates the
+    `log.error`, `log.iota_error`, and `log.lambda_origin` fields.
+
+# Arguments
+- `log::LoggerRecipe`, a logger recipe with `ma_info`, `error`, `iota_error`, and
+    `lambda_origin` fields (e.g., a `MALoggerRecipe`).
+- `res::Union{AbstractVector, Real}`, the sketched residual for the current iteration.
+    It is written into a single slot of `ma_info.res_window`, so it must be
+    convertible to the element type of that buffer.
+- `lambda_base::Integer`, the width that the moving average window is allowed to grow
+    to in the current phase (`lambda1` or `lambda2`).
+- `iter::Integer`, the current iteration. Iteration `0` restarts the circular buffer.
+
+# Returns
+- `nothing` (The `log` object, specifically its `ma_info`, `error`, `iota_error`, and
+    `lambda_origin` fields, is modified in-place).
+"""
+function update_ma!(
+    log::LoggerRecipe,
+    res::Union{AbstractVector,Real},
+    lambda_base::Integer,
+    iter::Integer,
+)
+    ma_info = log.ma_info
+    window = ma_info.res_window
+    # Width used for this update, captured before the window is allowed to grow
+    width = ma_info.lambda
+    # Sum of the window entries (for rho) and of their squares (for iota)
+    accum = zero(eltype(window))
+    accum2 = zero(eltype(window))
+
+    # Advance the circular buffer, wrapping after lambda2 entries or at iteration 0
+    ma_info.idx = ma_info.idx < ma_info.lambda2 && iter != 0 ? ma_info.idx + 1 : 1
+    window[ma_info.idx] = res
+
+    if width == ma_info.lambda2
+        # The entire storage buffer is the window
+        for i in 1:(ma_info.lambda2)
+            accum += window[i]
+            accum2 += window[i]^2
+        end
+    else
+        # The buffer holds lambda2 entries but only the last `width` are summed. When
+        # `idx` is not far enough into the buffer (`offset < 0`) the window wraps
+        # around, so the first `idx` entries and the last `-offset` entries are needed
+        offset = ma_info.idx - width
+        startp1 = offset < 0 ? 1 : (offset + 1)
+        # When the window does not wrap, 2:1 is an empty range and the loop is skipped
+        startp2 = offset < 0 ? ma_info.lambda2 + offset + 1 : 2
+        endp2 = offset < 0 ? ma_info.lambda2 : 1
+
+        for i in startp1:(ma_info.idx)
+            accum += window[i]
+            accum2 += window[i]^2
+        end
+
+        for i in startp2:endp2
+            accum += window[i]
+            accum2 += window[i]^2
+        end
+
+        # Grow the window by one until it reaches the width of the current phase
+        ma_info.lambda += width < lambda_base ? 1 : 0
+    end
+
+    # Record the moving average error for stopping
+    log.lambda_origin = width
+    log.error = accum / width
+    log.iota_error = accum2 / width
+
+    return nothing
+end
diff --git a/src/Solvers/Loggers/moving_average_logger.jl b/src/Solvers/Loggers/moving_average_logger.jl
new file mode 100644
index 00000000..95256cb8
--- /dev/null
+++ b/src/Solvers/Loggers/moving_average_logger.jl
@@ -0,0 +1,243 @@
+"""
+    MALogger <: Logger
+
+A structure that stores the specification of a logger that tracks a randomized linear
+    solver's behavior through a moving average of its progress estimates. The log
+    assumes that the full linear system is available for processing. The goal of this
+    log is usually research, development or testing as it is unlikely that the entire
+    residual vector is readily available.
+
+# Fields
+- `max_it::Int64`, the maximum number of iterations for the solver.
+- `collection_rate::Integer`, the frequency with which to record information about
+    progress to append to the remaining fields, starting with the initialization
+    (i.e., iteration `0`). For example, `collection_rate` = `3` means the iteration
+    difference between each record is `3`, i.e., recording information at
+    iterations `0`, `3`, `6`, `9`, ....
+- `ma_info::MAInfo`, the moving average parameters, see [`MAInfo`](@ref).
+- `threshold_info::Union{Float64, Tuple}`, the parameters used for stopping the algorithm.
+- `stopping_criterion::Function`, function that evaluates the stopping criterion. It is
+    called with the `MALoggerRecipe` as its only argument and must return a `Bool`.
+
+# Constructors
+
+    MALogger(;
+        max_it=0,
+        collection_rate=1,
+        lambda1=1,
+        lambda2=30,
+        threshold_info=1e-10,
+        stopping_criterion=threshold_stop
+    )
+
+## Keywords
+- `max_it::Int64`, the maximum number of iterations for the solver. Default: `0`.
+- `collection_rate::Integer`, the frequency for recording progress. Default: `1`.
+- `lambda1::Integer`, the width of the moving average during the initial "fast"
+    convergence phase. Default: `1`.
+- `lambda2::Integer`, the width of the moving average during the later "slow"
+    convergence phase. Default: `30`.
+- `threshold_info::Union{Float64, Tuple}`, parameters for the stopping criterion.
+    Default: `1e-10`.
+- `stopping_criterion::Function`, the function used to check for convergence.
+    Default: `threshold_stop`.
+
+## Returns
+- A `MALogger` object.
+
+## Throws
+- `ArgumentError` if `max_it` is negative.
+- `ArgumentError` if `collection_rate` is less than `1`.
+- `ArgumentError` if `max_it` is positive and `collection_rate` is greater
+    than `max_it`.
+- `ArgumentError` if `lambda1` is less than `1` or `lambda2` is less than `lambda1`.
+"""
+struct MALogger <: Logger
+    max_it::Int64
+    collection_rate::Integer
+    ma_info::MAInfo
+    threshold_info::Union{Float64, Tuple}
+    stopping_criterion::Function
+    function MALogger(max_it, collection_rate, ma_info, threshold_info, stopping_criterion)
+        if max_it < 0
+            throw(ArgumentError("Field `max_it` must be positive or 0."))
+        elseif collection_rate < 1
+            throw(ArgumentError("Field `collection_rate` must be positive."))
+        elseif collection_rate > max_it && max_it > 0
+            throw(ArgumentError("Field `collection_rate` must not exceed `max_it`."))
+        elseif ma_info.lambda1 < 1
+            throw(ArgumentError("Field `lambda1` must be positive."))
+        elseif ma_info.lambda2 < ma_info.lambda1
+            # The buffer has lambda2 slots, so a wider lambda1 window cannot be stored
+            throw(ArgumentError("Field `lambda2` must be at least `lambda1`."))
+        end
+
+        return new(max_it, collection_rate, ma_info, threshold_info, stopping_criterion)
+    end
+end
+
+function MALogger(;
+    max_it=0,
+    collection_rate=1,
+    lambda1=1,
+    lambda2=30,
+    threshold_info=1e-10,
+    stopping_criterion=threshold_stop,
+)
+    ma_info = MAInfo(lambda1, lambda2, lambda1, false, 1, zeros(lambda2))
+    return MALogger(max_it, collection_rate, ma_info, threshold_info, stopping_criterion)
+end
+
+"""
+    MALoggerRecipe{F<:Function} <: LoggerRecipe
+
+A mutable structure that contains the fully initialized state and pre-allocated memory
+    for the `MALogger`. It stores the error metric history and moving average information,
+    and checks for convergence based on this log.
+
+# Fields
+- `max_it::Int64`, the maximum number of iterations for the solver.
+- `error::AbstractFloat`, the current error metric: the moving average of the squared
+    values passed to `update_logger!`. Named `error` for compatibility with generic
+    field checks.
+- `iota_error::AbstractFloat`, the moving average of the squares of the entries in the
+    moving average window (a second moment of the window).
+- `iteration::Int64`, the current iteration number of the solver.
+- `record_location::Int64`, the index in `hist` and `lambda_hist` to store data.
+- `collection_rate::Integer`, the frequency with which progress information is recorded.
+- `converged::Bool`, a flag indicating whether the stopping criterion has been met.
+    Default is `false` upon initialization.
+- `ma_info::MAInfo`, the moving average state, see [`MAInfo`](@ref).
+- `hist::Vector{AbstractFloat}`, a vector storing the history of `error` at the
+    specified `collection_rate` intervals.
+- `lambda_origin::Integer`, the width of the moving average window that was used for
+    the most recent update.
+- `lambda_hist::Vector{Integer}`, a vector storing the history of the `lambda_origin`
+    values at the specified `collection_rate` intervals.
+- `threshold_info::Union{Float64, Tuple}`, parameters used by the `stopping_criterion`.
+- `stopping_criterion::F` (where `F<:Function`), the function used to evaluate
+    the stopping criterion.
+"""
+mutable struct MALoggerRecipe{F<:Function} <: LoggerRecipe
+    max_it::Int64
+    error::AbstractFloat # Moving average error, named for field check
+    iota_error::AbstractFloat
+    iteration::Int64
+    record_location::Int64
+    collection_rate::Integer
+    converged::Bool
+    ma_info::MAInfo
+    hist::Vector{AbstractFloat} # Moving average history, named for field check
+    lambda_origin::Integer
+    lambda_hist::Vector{Integer}
+    threshold_info::Union{Float64, Tuple}
+    stopping_criterion::F
+end
+
+# Build the recipe for a `MALogger` with pre-allocated history buffers
+function complete_logger(logger::MALogger)
+    # By using ceil if we divide exactly we always have space to record last value, if it
+    # does not divide exactly we have one more than required and thus enough space to
+    # record the last value
+    max_collection = Int(ceil(logger.max_it / logger.collection_rate))
+    # Use one more than max_collection to also record the initialization
+    res_hist = zeros(max_collection + 1)
+    lambda_hist = zeros(Int, max_collection + 1)
+    # `MAInfo` is mutable, so every recipe gets fresh state rather than sharing the
+    # logger's instance with every other recipe completed from the same logger
+    ma = logger.ma_info
+    window = zeros(eltype(ma.res_window), ma.lambda2)
+    ma_info = MAInfo(ma.lambda1, ma.lambda2, ma.lambda1, false, 1, window)
+    return MALoggerRecipe{typeof(logger.stopping_criterion)}(
+        logger.max_it,
+        0.0,
+        0.0,
+        1,
+        1,
+        logger.collection_rate,
+        false,
+        ma_info,
+        res_hist,
+        0,
+        lambda_hist,
+        logger.threshold_info,
+        logger.stopping_criterion,
+    )
+end
+
+# Common interface for update
+function update_logger!(logger::MALoggerRecipe, error::AbstractFloat, iteration::Int64)
+    # Update iteration counter
+    logger.iteration = iteration
+
+    ###############################
+    # Implement moving average (MA)
+    ###############################
+    ma_info = logger.ma_info
+    # Compute the current residual to second power to align with theory
+    res::AbstractFloat = error^2
+
+    # Check if MA is in lambda1 or lambda2 regime
+    if ma_info.flag
+        update_ma!(logger, res, ma_info.lambda2, iteration)
+    else
+        # While the sketched residual decreases monotonically we stay in the lambda1
+        # regime, otherwise we switch to the lambda2 regime by setting the flag.
+        # Because update_ma! changes res_window and ma_info.idx the condition must be
+        # checked first
+        flag_cond = iteration == 0 || res <= ma_info.res_window[ma_info.idx]
+        update_ma!(logger, res, ma_info.lambda1, iteration)
+        ma_info.flag = !flag_cond
+    end
+
+    # Always check max_it stopping criterion
+    # Compute in this way to avoid bounds error from searching in the max_it + 1 location.
+    # The criterion receives only the recipe, which already carries `threshold_info`
+    logger.converged =
+        iteration <= logger.max_it ? logger.stopping_criterion(logger) : true
+
+    # log according to collection rate or if we have converged
+    if rem(iteration, logger.collection_rate) == 0 || logger.converged
+        logger.lambda_hist[logger.record_location] = logger.lambda_origin
+        logger.hist[logger.record_location] = logger.error
+        logger.record_location += 1
+    end
+
+    return nothing
+end
+
+# Return the recipe, including its moving average state, to its initial state
+function reset_logger!(logger::MALoggerRecipe)
+    logger.error = 0.0
+    logger.iota_error = 0.0
+    logger.lambda_origin = 0
+    logger.iteration = 1
+    logger.record_location = 1
+    logger.converged = false
+    fill!(logger.hist, 0.0)
+    fill!(logger.lambda_hist, 0)
+    # Without this a reused recipe would start from the previous window and phase
+    ma_info = logger.ma_info
+    ma_info.lambda = ma_info.lambda1
+    ma_info.flag = false
+    ma_info.idx = 1
+    fill!(ma_info.res_window, 0)
+    return nothing
+end
+
+"""
+    threshold_stop(log::MALoggerRecipe)
+
+Default stopping criterion that checks if the current error metric in the logger
+    is below a specified threshold.
+
+# Arguments
+- `log::MALoggerRecipe`, a structure containing the logger information.
+
+# Returns
+- `Bool`, Returns `true` if `log.error` is less than `log.threshold_info`,
+    indicating the stopping threshold is satisfied. Otherwise, returns `false`.
+"""
+function threshold_stop(log::MALoggerRecipe)
+    return log.error < log.threshold_info
+end
diff --git a/test/Solvers/Loggers/moving_average_logger.jl b/test/Solvers/Loggers/moving_average_logger.jl
new file mode 100644
index 00000000..65bf7a3c
--- /dev/null
+++ b/test/Solvers/Loggers/moving_average_logger.jl
@@ -0,0 +1,79 @@
+module moving_average_logger
+    using Test, RLinearAlgebra, Random
+    include("../../test_helpers/field_test_macros.jl")
+    include("../../test_helpers/approx_tol.jl")
+    using .FieldTest
+    using .ApproxTol
+    @testset "Logger MALogger" begin
+        @testset "Constructor validation" begin
+            @test_throws ArgumentError MALogger(; max_it=-1)
+            @test_throws ArgumentError MALogger(; collection_rate=0)
+            @test_throws ArgumentError MALogger(; max_it=2, collection_rate=3)
+            @test_throws ArgumentError MALogger(; lambda1=0)
+            @test_throws ArgumentError MALogger(; lambda1=5, lambda2=2)
+        end
+
+        @testset "Recipes do not share moving average state" begin
+            logger = MALogger(; max_it=5, lambda1=1, lambda2=2)
+            first_recipe = complete_logger(logger)
+            second_recipe = complete_logger(logger)
+            @test first_recipe.ma_info !== logger.ma_info
+            @test first_recipe.ma_info !== second_recipe.ma_info
+        end
+
+        @testset "Moving average update" begin
+            recipe = complete_logger(MALogger(; max_it=5, lambda1=1, lambda2=2))
+            # The logger averages squared errors: 4, 9, 1, 4
+            update_logger!(recipe, 2.0, 0)
+            @test recipe.error == 4.0
+            @test recipe.iota_error == 16.0
+            @test !recipe.ma_info.flag
+            # An increase in the residual switches to the lambda2 regime
+            update_logger!(recipe, 3.0, 1)
+            @test recipe.error == 9.0
+            @test recipe.ma_info.flag
+            update_logger!(recipe, 1.0, 2)
+            @test recipe.error == 1.0
+            @test recipe.ma_info.lambda == 2
+            update_logger!(recipe, 2.0, 3)
+            @test recipe.error == 2.5
+            @test recipe.iota_error == 8.5
+            @test recipe.lambda_origin == 2
+            @test recipe.hist[1:4] == [4.0, 9.0, 1.0, 2.5]
+            @test recipe.lambda_hist[1:4] == [1, 1, 1, 2]
+            @test recipe.record_location == 5
+            @test !recipe.converged
+        end
+
+        @testset "Stopping" begin
+            # The threshold criterion is evaluated while iteration <= max_it
+            recipe = complete_logger(MALogger(; max_it=5))
+            update_logger!(recipe, 0.0, 0)
+            @test recipe.converged
+            # Exceeding max_it always stops
+            recipe = complete_logger(MALogger())
+            update_logger!(recipe, 0.5, 1)
+            @test recipe.converged
+            @test recipe.hist[1] == 0.25
+        end
+
+        @testset "Reset" begin
+            recipe = complete_logger(MALogger(; max_it=5, lambda1=1, lambda2=2))
+            for (iteration, err) in enumerate([2.0, 3.0, 1.0, 2.0])
+                update_logger!(recipe, err, iteration - 1)
+            end
+            reset_logger!(recipe)
+            @test recipe.error == 0.0
+            @test recipe.record_location == 1
+            @test !recipe.converged
+            @test all(iszero, recipe.hist)
+            @test recipe.ma_info.lambda == 1
+            @test !recipe.ma_info.flag
+            @test recipe.ma_info.idx == 1
+            @test all(iszero, recipe.ma_info.res_window)
+            # A reset recipe behaves like a freshly completed one
+            update_logger!(recipe, 1.0, 0)
+            @test recipe.error == 1.0
+        end
+    end
+end