From 1f7386af30b3c617d9cd9db9f25427788ede951e Mon Sep 17 00:00:00 2001 From: Tunan Wang Date: Wed, 14 May 2025 09:10:55 -0500 Subject: [PATCH 01/14] initialize --- src/Solvers/Loggers/moving_average_logger.jl | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 src/Solvers/Loggers/moving_average_logger.jl diff --git a/src/Solvers/Loggers/moving_average_logger.jl b/src/Solvers/Loggers/moving_average_logger.jl new file mode 100644 index 00000000..e69de29b From 1c31e8464d4b29c8bcc2e4de63d88869d610c6ca Mon Sep 17 00:00:00 2001 From: Tunan Wang Date: Fri, 16 May 2025 11:15:47 -0500 Subject: [PATCH 02/14] helpers --- src/Solvers/Loggers/ma_helpers/dist_info.jl | 56 +++++++++ src/Solvers/Loggers/ma_helpers/ma_info.jl | 130 ++++++++++++++++++++ 2 files changed, 186 insertions(+) create mode 100644 src/Solvers/Loggers/ma_helpers/dist_info.jl create mode 100644 src/Solvers/Loggers/ma_helpers/ma_info.jl diff --git a/src/Solvers/Loggers/ma_helpers/dist_info.jl b/src/Solvers/Loggers/ma_helpers/dist_info.jl new file mode 100644 index 00000000..7d5d13ac --- /dev/null +++ b/src/Solvers/Loggers/ma_helpers/dist_info.jl @@ -0,0 +1,56 @@ +# This file contains the components that are needed for storing +# and using distribution information for moving average method: +# Structs: SEDistInfo +# Functions: get_uncertainty, get_SE_constants! + +######################################### +# Structs +######################################### +""" + SEDistInfo + +A mutable structure that stores information about a distribution (i.e., sampling method) +in the sub-Exponential family. + +# Fields +- `sampler::Union{DataType, Nothing}`, the type of sampling method. +- `dimension::Int64`, the dimension that of the space that is being sampled. +- `block_dimension::Int64`, the dimension of the sample. +- `sigma2::Union{Float64, Nothing}`, the variance parameter in the sub-Exponential family. If not specified by the user, a value is selected from a table based on the `sampler`. 
+ If the `sampler` is not in the table, then `sigma2` is set to `1`. +- `omega::Union{Float64, Nothing}`, the exponential distrbution parameter. If not specified + by the user, a value is selected from a table based on the `sampler`. + If the `sampler` is not in the table, then `omega` is set to `1`. +- `eta::Float64`, a parameter for adjusting the conservativeness of the distribution, higher + value means a less conservative estimate. A recommended value is `1`. +- `scaling::Float64`, a scaling parameter for the norm-squared of the sketched residual to + ensure its expectation is the norm-squared of the residual. + +For more information see: +- Pritchard, Nathaniel, and Vivak Patel. "Solving, tracking and stopping streaming linear + inverse problems." Inverse Problems (2024). doi:10.1088/1361-6420/ad5583. +- Pritchard, Nathaniel, and Vivak Patel. “Towards Practical Large-Scale Randomized Iterative + Least Squares Solvers through Uncertainty Quantification.” SIAM/ASA J. Uncertainty + Quantification 11 (2022): 996-1024. 
doi.org/10.1137/22M1515057 +""" +mutable struct SEDistInfo + sampler::Union{DataType, Nothing} + dimension::Int64 + block_dimension::Int64 + sigma2::Union{Float64, Nothing} + omega::Union{Float64, Nothing} + eta::Float64 + scaling::Float64 +end + +function SEDistInfo(; sampler=nothing, dimension=0, block_dimension=0, sigma2=nothing, omega=nothing, eta=1.0, scaling=0.0) + eta > 0 || throw(ArgumentError("eta must be positive")) + new(sampler, dimension, block_dimension, sigma2, omega, eta, scaling) +end + +######################################### +# Functions +######################################### +#Function that will return rho and its uncertainty from a LSLogMA type +""" \ No newline at end of file diff --git a/src/Solvers/Loggers/ma_helpers/ma_info.jl b/src/Solvers/Loggers/ma_helpers/ma_info.jl new file mode 100644 index 00000000..89e92638 --- /dev/null +++ b/src/Solvers/Loggers/ma_helpers/ma_info.jl @@ -0,0 +1,130 @@ +# This file contains the components that are needed for storing +# and update moving average information for moving average method: +# Structs: MAInfo +# Functions: update_ma! + +######################################### +# Structs +######################################### +""" + MAInfo + +A mutable structure that stores information relevant to the moving average of the + progress estimator. + +# Fields +- `lambda1::Int64`, the width of the moving average during the fast convergence phase of the algorithm. + During this fast convergence phase, the majority of variation of the sketched estimator comes from + improvement in the solution and thus wide moving average windows inaccurately represent progress. +- `lambda2::Int64`, the width of the moving average in the slower convergence phase. In the slow convergence + phase, each iterate differs from the previous one by a small amount and thus most of the observed variation + arises from the randomness of the sketched progress estimator, which is best smoothed by a wide moving + average width. 
+- `lambda::Int64`, the width of the moving average at the current iteration. This value is not controlled by + the user. +- `flag::Bool`, a boolean indicating which phase we are in, a value of `true` indicates slow convergence phase. +- `idx::Int64`, the index indcating what value should be replaced in the moving average buffer. +- `res_window::Vector{Float64}`, the moving average buffer. + +For more information see: +- Pritchard, Nathaniel, and Vivak Patel. "Solving, tracking and stopping streaming linear + inverse problems." Inverse Problems (2024). doi:10.1088/1361-6420/ad5583. +- Pritchard, Nathaniel, and Vivak Patel. “Towards Practical Large-Scale Randomized Iterative + Least Squares Solvers through Uncertainty Quantification.” SIAM/ASA J. Uncertainty + Quantification 11 (2022): 996-1024. doi.org/10.1137/22M1515057 +""" +mutable struct MAInfo + lambda1::Int64 + lambda2::Int64 + lambda::Int64 + flag::Bool + idx::Int64 + res_window::Vector{Float64} +end + +######################################### +# Functions +######################################### +""" + update_ma!( + log::LoggerRecipe, + res::Union{AbstractVector, Real}, + lambda_base::Int64, + iter::Int64 + ) + +Function that updates the moving average tracking statistic. + +# Arguments +- `log::LoggerRecipe`, the parent of moving average log structure. +- `res::Union{AbstractVector, Real}`, the sketched residual for the current iteration. +- `lambda_base::Int64`, which lambda, between lambda1 and lambda2, is currently being used. +- `iter::Int64`, the current iteration. + +# Returns +- Updates the log datatype and does not explicitly return anything. 
+""" +function update_ma!( + log::LoggerRecipe, # log::L where L <: LoggerRecipe + res::Union{AbstractVector, Real}, + lambda_base::Int64, + iter::Int64, +) + # Variable to store the sum of the terms for rho + accum = 0 + # Variable to store the sum of the terms for iota + accum2 = 0 + ma_info = log.ma_info + ma_info.idx = ma_info.idx < ma_info.lambda2 && iter != 0 ? ma_info.idx + 1 : 1 + ma_info.res_window[ma_info.idx] = res + #Check if entire storage buffer can be used + if ma_info.lambda == ma_info.lambda2 + # Compute the moving average + for i in 1:ma_info.lambda2 + accum += ma_info.res_window[i] + accum2 += ma_info.res_window[i]^2 + end + + if mod(iter, log.collection_rate) == 0 || iter == 0 + push!(log.lambda_hist, ma_info.lambda) + push!(log.resid_hist, accum / ma_info.lambda) + (:iota_hist in fieldnames(typeof(log))) && push!(log.iota_hist, accum2 / ma_info.lambda) + end + + else + # Consider the case when lambda <= lambda1 or lambda1 < lambda < lambda2 + diff = ma_info.idx - ma_info.lambda + # Because the storage of the residual is based dependent on lambda2 and + # we want to sum only the previous lamdda terms we could have a situation + # where we want the first `idx` terms of the buffer and the last `diff` + # terms of the buffer. Doing this requires two loops + # If `diff` is negative there idx is not far enough into the buffer and + # two sums will be needed + startp1 = diff < 0 ? 1 : (diff + 1) + + # Assuming that the width of the buffer is lambda2 + startp2 = diff < 0 ? ma_info.lambda2 + diff + 1 : 2 + endp2 = diff < 0 ? 
 ma_info.lambda2 : 1 + + # Compute the moving average two loop setup required when lambda < lambda2 + for i in startp1:ma_info.idx + accum += ma_info.res_window[i] + accum2 += ma_info.res_window[i]^2 + end + + for i in startp2:endp2 + accum += ma_info.res_window[i] + accum2 += ma_info.res_window[i]^2 + end + + #Update the log variable with the information for this update + if mod(iter, log.collection_rate) == 0 || iter == 0 + push!(log.lambda_hist, ma_info.lambda) + push!(log.resid_hist, accum / ma_info.lambda) + (:iota_hist in fieldnames(typeof(log))) && push!(log.iota_hist, accum2 / ma_info.lambda) + end + + ma_info.lambda += ma_info.lambda < lambda_base ? 1 : 0 + end + +end \ No newline at end of file From 4e69156743deef269c962ae81be2b11225aff7e0 Mon Sep 17 00:00:00 2001 From: Tunan Wang Date: Sat, 17 May 2025 21:44:37 -0500 Subject: [PATCH 03/14] helpers --- src/Solvers/Loggers/ma_helpers/dist_info.jl | 157 +++++++++++++++++--- src/Solvers/Loggers/ma_helpers/ma_info.jl | 32 ++-- 2 files changed, 155 insertions(+), 34 deletions(-) diff --git a/src/Solvers/Loggers/ma_helpers/dist_info.jl b/src/Solvers/Loggers/ma_helpers/dist_info.jl index 7d5d13ac..89b07f78 100644 --- a/src/Solvers/Loggers/ma_helpers/dist_info.jl +++ b/src/Solvers/Loggers/ma_helpers/dist_info.jl @@ -14,17 +14,17 @@ in the sub-Exponential family. # Fields - `sampler::Union{DataType, Nothing}`, the type of sampling method. -- `dimension::Int64`, the dimension that of the space that is being sampled. -- `block_dimension::Int64`, the dimension of the sample. -- `sigma2::Union{Float64, Nothing}`, the variance parameter in the sub-Exponential family. +- `dimension::Integer`, the dimension of the space that is being sampled. +- `block_dimension::Integer`, the dimension of the sample. +- `sigma2::Union{AbstractFloat, Nothing}`, the variance parameter in the sub-Exponential family. If not specified by the user, a value is selected from a table based on the `sampler`. 
+ If the `sampler` is not in the table, then `sigma2` is set to `1`. -- `omega::Union{Float64, Nothing}`, the exponential distrbution parameter. If not specified +- `omega::Union{AbstractFloat, Nothing}`, the exponential distribution parameter. If not specified by the user, a value is selected from a table based on the `sampler`. If the `sampler` is not in the table, then `omega` is set to `1`. -- `eta::Float64`, a parameter for adjusting the conservativeness of the distribution, higher +- `eta::AbstractFloat`, a parameter for adjusting the conservativeness of the distribution, higher value means a less conservative estimate. A recommended value is `1`. -- `scaling::Float64`, a scaling parameter for the norm-squared of the sketched residual to +- `scaling::AbstractFloat`, a scaling parameter for the norm-squared of the sketched residual to ensure its expectation is the norm-squared of the residual. For more information see: @@ -36,21 +36,142 @@ For more information see: """ mutable struct SEDistInfo sampler::Union{DataType, Nothing} - dimension::Int64 - block_dimension::Int64 - sigma2::Union{Float64, Nothing} - omega::Union{Float64, Nothing} - eta::Float64 - scaling::Float64 + dimension::Integer + block_dimension::Integer + sigma2::Union{AbstractFloat, Nothing} + omega::Union{AbstractFloat, Nothing} + eta::AbstractFloat + scaling::AbstractFloat + function SEDistInfo(sampler, dimension, block_dimension, sigma2, omega, eta, scaling) + eta > 0 || throw(ArgumentError("eta must be positive")) + return new(sampler, dimension, block_dimension, sigma2, omega, eta, scaling) + end end -function SEDistInfo(; sampler=nothing, dimension=0, block_dimension=0, sigma2=nothing, omega=nothing, eta=1.0, scaling=0.0) - eta > 0 || throw(ArgumentError("eta must be positive")) - new(sampler, dimension, block_dimension, sigma2, omega, eta, scaling) -end +SEDistInfo(; + sampler=nothing, + dimension=0, + block_dimension=0, + sigma2=nothing, + omega=nothing, + eta=1.0, + scaling=0.0 + ) = 
SEDistInfo(sampler, dimension, block_dimension, sigma2, omega, eta, scaling) ######################################### # Functions ######################################### -#Function that will return rho and its uncertainty from a LSLogMA type -""" \ No newline at end of file +#Function that will return rho and its uncertainty from a LoggerRecipe type +""" + get_uncertainty(log::LoggerRecipe; alpha::AbstractFloat = 0.05) + +A function that gets the uncertainty from LoggerRecipe or LSLogFullMA type. + +# Arguments +- `hist::LoggerRecipe`, the parent structure of moving average log structure, + i.e. LoggerRecipe and LSLogFullMA types. Specifically, the information of + distribution (`dist_info`), and all histories stored in the structure. +- `alpha::AbstractFloat`, the confidence level. + +# Returns +- A `(1-alpha)`-credible intervals for every `rho` in the `log`, specifically + it returns a tuple with (rho, Upper bound, Lower bound). +""" +function get_uncertainty(hist::LoggerRecipe; alpha::AbstractFloat = 0.05) + l = length(hist.iota_hist) + upper = zeros(l) + lower = zeros(l) + # If the constants for the sub-Exponential distribution are not defined then define them + if typeof(hist.dist_info.sigma2) <: Nothing + throw(ArgumentError("The SE constants are empty, please set them in dist_info field of LoggerRecipe first.")) + end + + for i in 1:l + width = hist.lambda_hist[i] + iota = hist.iota_hist[i] + rho = hist.resid_hist[i] + #Define the variance term for the Gaussian part + cG = hist.dist_info.sigma2 * (1 + log(width)) * iota / (hist.dist_info.eta * width) + #If there is an omega in the sub-Exponential distribution then skip that calculation + if typeof(hist.dist_info.omega) <: Nothing + # Compute the threshold bound in the case where there is no omega + diffG = sqrt(cG * 2 * log(2/(alpha))) + upper[i] = rho + diffG + lower[i] = rho - diffG + else + #compute error bound when there is an omega + diffG = sqrt(cG * 2 * log(2/(alpha))) + diffO = sqrt(iota) * 2 * 
log(2/(alpha)) * hist.dist_info.omega / (hist.dist_info.eta * width) + diffM = max(diffG, diffO) + upper[i] = rho + diffM + lower[i] = rho - diffM + end + + end + + return (hist.resid_hist, upper, lower) +end + +""" + get_SE_constants!(log::LoggerRecipe, sampler::Type{T<:CompressorRecipe}) + +A function that returns a default set of sub-Exponential constants for each sampling method. + This function is not exported and thus the user does not have direct access to it. + +# Arguments +- `log::LoggerRecipe`, the log containing all the tracking information. Specifically, + the information of distribution (`dist_info`). +- `sampler::Type{CompressorRecipe}`, the type of sampler being used. + +# Returns +- Performs an inplace update of the sub-Exponential constants for the log. Additionally, + updates the scaling constant to ensure expectation of block norms is equal to true norm. + If default is not a defined a warning is returned that sigma2 is set 1 and scaling + is set to 1. +""" +function get_SE_constants!(log::LoggerRecipe, sampler::Type{T}) where T<:CompressorRecipe + @warn "No constants defined for method of type $sampler. By default we set sigma2 to 1 and scaling to 1." 
+ log.dist_info.sigma2 = 1 + log.dist_info.scaling = 1 +end + +for type in (LinSysVecRowDetermCyclic,LinSysVecRowHopRandCyclic, + LinSysVecRowOneRandCyclic, LinSysVecRowSVSampler, + LinSysVecRowRandCyclic, LinSysVecRowUnidSampler, + LinSysVecRowDistCyclic, LinSysVecRowResidCyclic, + LinSysVecRowMaxResidual, LinSysVecRowMaxDistance,) + @eval begin + function get_SE_constants!(log::LoggerRecipe, sampler::Type{$type}) + log.dist_info.sigma2 = log.dist_info.dimension^2 / (4 * log.dist_info.block_dimension^2 * log.dist_info.eta) + log.dist_info.scaling = log.dist_info.dimension / log.dist_info.block_dimension + end + + end + +end + + +# Column subsetting methods have same constants as in row case +for type in (LinSysVecColOneRandCyclic, LinSysVecColDetermCyclic) + @eval begin + function get_SE_constants!(log::LoggerRecipe, sampler::Type{$type}) + log.dist_info.sigma2 = log.dist_info.dimension^2 / (4 * log.dist_info.block_dimension^2 * log.dist_info.eta) + log.dist_info.scaling = log.dist_info.dimension / log.dist_info.block_dimension + end + + end + +end + +# For row samplers with gaussian sampling we have sigma2 = 1/.2345 and omega = .1127 +for type in (LinSysVecRowGaussSampler, LinSysVecRowSparseGaussSampler) + @eval begin + function get_SE_constants!(log::LoggerRecipe, sampler::Type{$type}) + log.dist_info.sigma2 = log.dist_info.block_dimension / (0.2345 * log.dist_info.eta) + log.dist_info.omega = .1127 + log.dist_info.scaling = 1. + end + + end + +end \ No newline at end of file diff --git a/src/Solvers/Loggers/ma_helpers/ma_info.jl b/src/Solvers/Loggers/ma_helpers/ma_info.jl index 89e92638..aa992941 100644 --- a/src/Solvers/Loggers/ma_helpers/ma_info.jl +++ b/src/Solvers/Loggers/ma_helpers/ma_info.jl @@ -13,18 +13,18 @@ A mutable structure that stores information relevant to the moving average of th progress estimator. # Fields -- `lambda1::Int64`, the width of the moving average during the fast convergence phase of the algorithm. 
+- `lambda1::Integer`, the width of the moving average during the fast convergence phase of the algorithm. During this fast convergence phase, the majority of variation of the sketched estimator comes from improvement in the solution and thus wide moving average windows inaccurately represent progress. -- `lambda2::Int64`, the width of the moving average in the slower convergence phase. In the slow convergence +- `lambda2::Integer`, the width of the moving average in the slower convergence phase. In the slow convergence phase, each iterate differs from the previous one by a small amount and thus most of the observed variation arises from the randomness of the sketched progress estimator, which is best smoothed by a wide moving average width. -- `lambda::Int64`, the width of the moving average at the current iteration. This value is not controlled by +- `lambda::Integer`, the width of the moving average at the current iteration. This value is not controlled by the user. - `flag::Bool`, a boolean indicating which phase we are in, a value of `true` indicates slow convergence phase. -- `idx::Int64`, the index indcating what value should be replaced in the moving average buffer. -- `res_window::Vector{Float64}`, the moving average buffer. +- `idx::Integer`, the index indicating what value should be replaced in the moving average buffer. +- `res_window::Vector{<:AbstractFloat}`, the moving average buffer. For more information see: - Pritchard, Nathaniel, and Vivak Patel. "Solving, tracking and stopping streaming linear @@ -34,12 +34,12 @@ For more information see: Quantification 11 (2022): 996-1024. 
doi.org/10.1137/22M1515057 """ mutable struct MAInfo - lambda1::Int64 - lambda2::Int64 - lambda::Int64 + lambda1::Integer + lambda2::Integer + lambda::Integer flag::Bool - idx::Int64 - res_window::Vector{Float64} + idx::Integer + res_window::Vector{<:AbstractFloat} end ######################################### @@ -49,8 +49,8 @@ end update_ma!( log::LoggerRecipe, res::Union{AbstractVector, Real}, - lambda_base::Int64, - iter::Int64 + lambda_base::Integer, + iter::Integer ) Function that updates the moving average tracking statistic. @@ -58,8 +58,8 @@ Function that updates the moving average tracking statistic. # Arguments - `log::LoggerRecipe`, the parent of moving average log structure. - `res::Union{AbstractVector, Real}`, the sketched residual for the current iteration. -- `lambda_base::Int64`, which lambda, between lambda1 and lambda2, is currently being used. -- `iter::Int64`, the current iteration. +- `lambda_base::Integer`, which lambda, between lambda1 and lambda2, is currently being used. +- `iter::Integer`, the current iteration. # Returns - Updates the log datatype and does not explicitly return anything. @@ -67,8 +67,8 @@ Function that updates the moving average tracking statistic. 
function update_ma!( log::LoggerRecipe, # log::L where L <: LoggerRecipe res::Union{AbstractVector, Real}, - lambda_base::Int64, - iter::Int64, + lambda_base::Integer, + iter::Integer, ) # Variable to store the sum of the terms for rho accum = 0 From b77ab9d35e1b5a81fb294815dcb2dfdf8e2579b1 Mon Sep 17 00:00:00 2001 From: Tunan Wang Date: Tue, 20 May 2025 00:56:03 -0500 Subject: [PATCH 04/14] basic_logger --- src/Solvers/Loggers/basic_logger.jl | 2 +- src/Solvers/Loggers/ma_helpers/dist_info.jl | 83 ++++++----- src/Solvers/Loggers/moving_average_logger.jl | 142 +++++++++++++++++++ 3 files changed, 193 insertions(+), 34 deletions(-) diff --git a/src/Solvers/Loggers/basic_logger.jl b/src/Solvers/Loggers/basic_logger.jl index e1e906d3..d1cfc5e8 100644 --- a/src/Solvers/Loggers/basic_logger.jl +++ b/src/Solvers/Loggers/basic_logger.jl @@ -1,7 +1,7 @@ """ BasicLogger <: Logger -This is a mutable struct that contains the `max_it` parameter and stores the error metric +This is a struct that contains the `max_it` parameter and stores the error metric in a vector. Checks convergence of the solver based on the log information. # Fields diff --git a/src/Solvers/Loggers/ma_helpers/dist_info.jl b/src/Solvers/Loggers/ma_helpers/dist_info.jl index 89b07f78..422d9cf2 100644 --- a/src/Solvers/Loggers/ma_helpers/dist_info.jl +++ b/src/Solvers/Loggers/ma_helpers/dist_info.jl @@ -69,7 +69,8 @@ A function that gets the uncertainty from LoggerRecipe or LSLogFullMA type. # Arguments - `hist::LoggerRecipe`, the parent structure of moving average log structure, - i.e. LoggerRecipe and LSLogFullMA types. Specifically, the information of + TODO: check the types + i.e. MALogger and FullMALogger types. Specifically, the information of distribution (`dist_info`), and all histories stored in the structure. - `alpha::AbstractFloat`, the confidence level. 
@@ -90,16 +91,16 @@ function get_uncertainty(hist::LoggerRecipe; alpha::AbstractFloat = 0.05) width = hist.lambda_hist[i] iota = hist.iota_hist[i] rho = hist.resid_hist[i] - #Define the variance term for the Gaussian part + # Define the variance term for the Gaussian part cG = hist.dist_info.sigma2 * (1 + log(width)) * iota / (hist.dist_info.eta * width) - #If there is an omega in the sub-Exponential distribution then skip that calculation + # If there is an omega in the sub-Exponential distribution then skip that calculation if typeof(hist.dist_info.omega) <: Nothing # Compute the threshold bound in the case where there is no omega diffG = sqrt(cG * 2 * log(2/(alpha))) upper[i] = rho + diffG lower[i] = rho - diffG else - #compute error bound when there is an omega + # Compute error bound when there is an omega diffG = sqrt(cG * 2 * log(2/(alpha))) diffO = sqrt(iota) * 2 * log(2/(alpha)) * hist.dist_info.omega / (hist.dist_info.eta * width) diffM = max(diffG, diffO) @@ -113,7 +114,7 @@ function get_uncertainty(hist::LoggerRecipe; alpha::AbstractFloat = 0.05) end """ - get_SE_constants!(log::LoggerRecipe, sampler::Type{T<:CompressorRecipe}) + get_SE_constants!(log::LoggerRecipe, sampler::Type{T<:Compressor}) A function that returns a default set of sub-Exponential constants for each sampling method. This function is not exported and thus the user does not have direct access to it. @@ -121,7 +122,7 @@ A function that returns a default set of sub-Exponential constants for each samp # Arguments - `log::LoggerRecipe`, the log containing all the tracking information. Specifically, the information of distribution (`dist_info`). -- `sampler::Type{CompressorRecipe}`, the type of sampler being used. +- `sampler::Type{Compressor}`, the type of sampler being used. # Returns - Performs an inplace update of the sub-Exponential constants for the log. 
Additionally, @@ -129,47 +130,63 @@ A function that returns a default set of sub-Exponential constants for each samp If default is not a defined a warning is returned that sigma2 is set 1 and scaling is set to 1. """ -function get_SE_constants!(log::LoggerRecipe, sampler::Type{T}) where T<:CompressorRecipe +function get_SE_constants!(log::LoggerRecipe, sampler::Type{T}) where T<:Compressor @warn "No constants defined for method of type $sampler. By default we set sigma2 to 1 and scaling to 1." log.dist_info.sigma2 = 1 log.dist_info.scaling = 1 end -for type in (LinSysVecRowDetermCyclic,LinSysVecRowHopRandCyclic, - LinSysVecRowOneRandCyclic, LinSysVecRowSVSampler, - LinSysVecRowRandCyclic, LinSysVecRowUnidSampler, - LinSysVecRowDistCyclic, LinSysVecRowResidCyclic, - LinSysVecRowMaxResidual, LinSysVecRowMaxDistance,) - @eval begin - function get_SE_constants!(log::LoggerRecipe, sampler::Type{$type}) - log.dist_info.sigma2 = log.dist_info.dimension^2 / (4 * log.dist_info.block_dimension^2 * log.dist_info.eta) - log.dist_info.scaling = log.dist_info.dimension / log.dist_info.block_dimension - end - end +# for type in (LinSysVecRowDetermCyclic,LinSysVecRowHopRandCyclic, +# LinSysVecRowOneRandCyclic, LinSysVecRowSVSampler, +# LinSysVecRowRandCyclic, LinSysVecRowUnidSampler, +# LinSysVecRowDistCyclic, LinSysVecRowResidCyclic, +# LinSysVecRowMaxResidual, LinSysVecRowMaxDistance,) +# @eval begin +# function get_SE_constants!(log::LoggerRecipe, sampler::Type{$type}) +# log.dist_info.sigma2 = log.dist_info.dimension^2 / (4 * log.dist_info.block_dimension^2 * log.dist_info.eta) +# log.dist_info.scaling = log.dist_info.dimension / log.dist_info.block_dimension +# end -end +# end +# end -# Column subsetting methods have same constants as in row case -for type in (LinSysVecColOneRandCyclic, LinSysVecColDetermCyclic) - @eval begin - function get_SE_constants!(log::LoggerRecipe, sampler::Type{$type}) - log.dist_info.sigma2 = log.dist_info.dimension^2 / (4 * 
log.dist_info.block_dimension^2 * log.dist_info.eta) - log.dist_info.scaling = log.dist_info.dimension / log.dist_info.block_dimension - end - end +# # Column subsetting methods have same constants as in row case +# for type in (LinSysVecColOneRandCyclic, LinSysVecColDetermCyclic) +# @eval begin +# function get_SE_constants!(log::LoggerRecipe, sampler::Type{$type}) +# log.dist_info.sigma2 = log.dist_info.dimension^2 / (4 * log.dist_info.block_dimension^2 * log.dist_info.eta) +# log.dist_info.scaling = log.dist_info.dimension / log.dist_info.block_dimension +# end -end +# end + +# end + +# # For row samplers with gaussian sampling we have sigma2 = 1/.2345 and omega = .1127 +# for type in (LinSysVecRowGaussSampler, LinSysVecRowSparseGaussSampler) +# @eval begin +# function get_SE_constants!(log::LoggerRecipe, sampler::Type{$type}) +# log.dist_info.sigma2 = log.dist_info.block_dimension / (0.2345 * log.dist_info.eta) +# log.dist_info.omega = .1127 +# log.dist_info.scaling = 1. +# end + +# end + +# end -# For row samplers with gaussian sampling we have sigma2 = 1/.2345 and omega = .1127 -for type in (LinSysVecRowGaussSampler, LinSysVecRowSparseGaussSampler) +for type in (Gaussian) @eval begin function get_SE_constants!(log::LoggerRecipe, sampler::Type{$type}) - log.dist_info.sigma2 = log.dist_info.block_dimension / (0.2345 * log.dist_info.eta) - log.dist_info.omega = .1127 - log.dist_info.scaling = 1. + if sampler.compression_dim == 1 + log.dist_info.sigma2 = log.dist_info.block_dimension / (0.2345 * log.dist_info.eta) + log.dist_info.omega = .1127 + log.dist_info.scaling = 1. 
+ end + end end diff --git a/src/Solvers/Loggers/moving_average_logger.jl b/src/Solvers/Loggers/moving_average_logger.jl index e69de29b..acac1efe 100644 --- a/src/Solvers/Loggers/moving_average_logger.jl +++ b/src/Solvers/Loggers/moving_average_logger.jl @@ -0,0 +1,142 @@ +""" + FullMALogger <: Logger + +A structure that stores information of specification about a randomized linear solver's + behavior. The log assumes that the full linear system is available for processing. + The goal of this log is usually for research, development or testing as it is unlikely + that the entire residual vector is readily available. + +# Fields +- `collection_rate::Integer`, the frequency with which to record information about progress + to append to the remaining fields, starting with the initialization + (i.e., iteration `0`). For example, `collection_rate` = 3 means the iteration + difference between each records is 3, i.e. recording information at + iteration `0`, `3`, `6`, `9`, .... +- `ma_info::MAInfo`, [`MAInfo`](@ref) +- `resid_hist::Vector{AbstractFloat}`, retains a vector of numbers corresponding to the residual + (uses the whole system to compute the residual). These values are stored at iterates + specified by `collection_rate`. +- `lambda_hist::Vector{Integer}`, contains the widths of the moving average. + These values are stored at iterates specified by `collection_rate`. +- `resid_norm::Function`, a function that accepts a single vector argument and returns a + scalar. Used to compute the residual size. +- `iterations::Integer`, the number of iterations of the solver. +- `converged::Bool`, a flag to indicate whether the system has converged by some measure. + By default this is `false`. 
+ +# Constructors + + FullMALogger(;collection_rate=1, lambda1=1, lambda2=30, resid_norm=norm) + +## Keywords +- `collection_rate::Integer`, the frequency with which to record information about progress + to append to the remaining fields, starting with the initialization + (i.e., iteration `0`). For example, `collection_rate` = 3 means the iteration + difference between each records is 3, i.e. recording information at + iteration `0`, `3`, `6`, `9`, .... By default, this is set to `1`. +- `lambda1::Integer`, the TODO. By default, this is set to `1`. +- `lambda2::Integer`, the TODO. By default, this is set to `30`. +- `resid_norm::Function`, a function that accepts a single vector argument and returns a + scalar. Used to compute the residual size. By default, `norm`, which is Euclidean + norm, is set. + +## Returns +- A `FullMALogger` object. + +## Throws TODO +- `ArgumentError` if `compression_dim` is non-positive, if `nnz` is exceeds + `compression_dim`, or if `nnz` is non-positive. +""" +struct FullMALogger <: Logger + collection_rate::Integer + ma_info::MAInfo + resid_hist::Vector{AbstractFloat} + lambda_hist::Vector{Integer} + resid_norm::Function + iterations::Integer + converged::Bool +end + +FullMALogger(; + collection_rate::Integer=1, + lambda1::Integer=1, + lambda2::Integer=30, + resid_norm::Function=norm, + ) = LSLogFullMA(collection_rate, + MAInfo(lambda1, lambda2, lambda1, false, 1, zeros(lambda2)), + AbstractFloat[], + Int64[], + resid_norm, + -1, + false + ) + + +""" + FullMALoggerRecipe <: LoggerRecipe + + TODO +This is a mutable struct that contains the `max_it` parameter and stores the error metric + in a vector. Checks convergence of the solver based on the log information. 
+ +# Fields + +""" + + + + + + + + + + + + + + + + + +# Common interface for update +function update_logger!( + log::FullMALogger, + sampler::LinSysSampler, + x::AbstractVector, + samp::Tuple, + iter::Int64, + A::AbstractArray, + b::AbstractVector, +) + # Update iteration counter + log.iterations = iter + + ############################### + # Implement moving average (MA) + ############################### + ma_info = log.ma_info + # Compute the current residual to second power to align with theory + res_norm_iter = log.resid_norm(A * x - b) + res::Float64 = res_norm_iter^2 + + # Check if MA is in lambda1 or lambda2 regime + if ma_info.flag + update_ma!(log, res, ma_info.lambda2, iter) + else + # Check if we can switch between lambda1 and lambda2 regime + # If it is in the monotonic decreasing of the sketched residual then we are in a lambda1 regime + # otherwise we switch to the lambda2 regime which is indicated by the changing of the flag + # because update_ma changes res_window and ma_info.idx we must check condition first + flag_cond = iter == 0 || res <= ma_info.res_window[ma_info.idx] + update_ma!(log, res, ma_info.lambda1, iter) + ma_info.flag = !flag_cond + end + +end + + + + + + From 28fdb0af0a5d49eaf041e330b036034632befc9f Mon Sep 17 00:00:00 2001 From: Tunan Wang Date: Tue, 20 May 2025 17:38:42 -0500 Subject: [PATCH 05/14] changes --- src/Solvers/Loggers.jl | 12 ++- src/Solvers/Loggers/ma_helpers/ma_info.jl | 29 +++--- src/Solvers/Loggers/ma_helpers/ma_stop.jl | 97 ++++++++++++++++++++ src/Solvers/Loggers/moving_average_logger.jl | 77 +++++++++++++--- 4 files changed, 184 insertions(+), 31 deletions(-) create mode 100644 src/Solvers/Loggers/ma_helpers/ma_stop.jl diff --git a/src/Solvers/Loggers.jl b/src/Solvers/Loggers.jl index 44bc24b4..e0273996 100644 --- a/src/Solvers/Loggers.jl +++ b/src/Solvers/Loggers.jl @@ -89,7 +89,15 @@ function reset_logger!(logger::LoggerRecipe) return nothing end -############################## -# Include Logger Files 
+############################### +# Include Logger Methods Files ############################### include("Loggers/basic_logger.jl") +include("Loggers/moving_average_logger.jl") + +################################################ +# Include Logger Moving Average helpers Files +################################################ +include("Loggers/ma_helpers/ma_info.jl") +include("Loggers/ma_helpers/dist_info.jl") +include("Loggers/ma_helpers/ma_stop.jl") diff --git a/src/Solvers/Loggers/ma_helpers/ma_info.jl b/src/Solvers/Loggers/ma_helpers/ma_info.jl index aa992941..320a4eea 100644 --- a/src/Solvers/Loggers/ma_helpers/ma_info.jl +++ b/src/Solvers/Loggers/ma_helpers/ma_info.jl @@ -66,8 +66,8 @@ Function that updates the moving average tracking statistic. """ function update_ma!( log::LoggerRecipe, # log::L where L <: LoggerRecipe - res::Union{AbstractVector, Real}, - lambda_base::Integer, + res::Union{AbstractVector,Real}, + lambda_base::Integer, iter::Integer, ) # Variable to store the sum of the terms for rho @@ -78,17 +78,18 @@ function update_ma!( ma_info.idx = ma_info.idx < ma_info.lambda2 && iter != 0 ? 
ma_info.idx + 1 : 1 ma_info.res_window[ma_info.idx] = res #Check if entire storage buffer can be used - if ma_info.lambda == ma_info.lambda2 + if ma_info.lambda == ma_info.lambda2 # Compute the moving average - for i in 1:ma_info.lambda2 + for i in 1:(ma_info.lambda2) accum += ma_info.res_window[i] accum2 += ma_info.res_window[i]^2 end - + if mod(iter, log.collection_rate) == 0 || iter == 0 push!(log.lambda_hist, ma_info.lambda) - push!(log.resid_hist, accum / ma_info.lambda) - (:iota_hist in fieldnames(typeof(log))) && push!(log.iota_hist, accum2 / ma_info.lambda) + push!(log.resid_hist, accum / ma_info.lambda) + (:iota_hist in fieldnames(typeof(log))) && + push!(log.iota_hist, accum2 / ma_info.lambda) end else @@ -101,13 +102,13 @@ function update_ma!( # If `diff` is negative there idx is not far enough into the buffer and # two sums will be needed startp1 = diff < 0 ? 1 : (diff + 1) - + # Assuming that the width of the buffer is lambda2 - startp2 = diff < 0 ? ma_info.lambda2 + diff + 1 : 2 + startp2 = diff < 0 ? ma_info.lambda2 + diff + 1 : 2 endp2 = diff < 0 ? ma_info.lambda2 : 1 # Compute the moving average two loop setup required when lambda < lambda2 - for i in startp1:ma_info.idx + for i in startp1:(ma_info.idx) accum += ma_info.res_window[i] accum2 += ma_info.res_window[i]^2 end @@ -120,11 +121,11 @@ function update_ma!( #Update the log variable with the information for this update if mod(iter, log.collection_rate) == 0 || iter == 0 push!(log.lambda_hist, ma_info.lambda) - push!(log.resid_hist, accum / ma_info.lambda) - (:iota_hist in fieldnames(typeof(log))) && push!(log.iota_hist, accum2 / ma_info.lambda) + push!(log.resid_hist, accum / ma_info.lambda) + (:iota_hist in fieldnames(typeof(log))) && + push!(log.iota_hist, accum2 / ma_info.lambda) end - + ma_info.lambda += ma_info.lambda < lambda_base ? 
1 : 0 end - end \ No newline at end of file diff --git a/src/Solvers/Loggers/ma_helpers/ma_stop.jl b/src/Solvers/Loggers/ma_helpers/ma_stop.jl new file mode 100644 index 00000000..631ca727 --- /dev/null +++ b/src/Solvers/Loggers/ma_helpers/ma_stop.jl @@ -0,0 +1,97 @@ +""" + MAStop + +A structure that specifies a stopping criterion that incoroporates the randomness of the moving average estimator. That is, once a method + achieves a certain number of iterations, it stops. + +# Fields +- `max_iter::Integer`, the maximum number of iterations. +- `threshold::AbstractFloat`, the value of the estimator that is sufficient progress. +- `delta1::AbstractFloat`, the percent below the threshold does the true value of the progress estimator need to be for not stopping to be a mistake. This is equivalent to stopping too late. +- `delta2::AbstractFloat`, the percent above the threshold does the true value of the progress estimator need to be for stopping to be a + mistake. This is equivalent to stopping too early. +- `chi1::AbstractFloat`, the probability that the stopping too late action occurs. +- `chi2::AbstractFloat`, the probability that the stopping too early action occurs. +# Constructors +- Calling MAStop(iter) will specify the users desired maximum number of iterations, threshold = 1e-10, delta1 = .9, delta2 = 1.1, chi1 = 0.01, and chi2 = 0.01. 
+""" +struct MAStop + max_iter::Integer + threshold::AbstractFloat + delta1::AbstractFloat + delta2::AbstractFloat + chi1::AbstractFloat + chi2::AbstractFloat +end + +function MAStop(iter; threshold=1e-10, delta1=0.9, delta2=1.1, chi1=0.01, chi2=0.01) + return MAStop(iter, threshold, delta1, delta2, chi1, chi2) +end + +# Common interface for stopping criteria +function check_stop_criterion(log::MALogger, stop::MAStop) + its = log.iterations + if its > 0 + I_threshold = iota_threshold(log, stop) + thresholdChecks = + sqrt(log.iota_hist[its]) <= I_threshold && log.resid_hist[its] <= stop.threshold + else + thresholdChecks = false + end + return (thresholdChecks || its == stop.max_iter ? true : false) +end + +""" + iota_threshold(hist::LSLogMA, stop::LSStopMA) + +Function that computes the stopping criterion using the sub-Exponential distribution from the `LSLogMA`, and the stopping criterion information in `LSSopMA`. This function is not exported and thus not directly callable by the user. + +# Inputs +- `hist::LSLogMA`, the log information for the moving average tracking. +- `stop::LSStopMA`, the stopping information for the stopping criterion. + +# Ouputs +Returns the stoppping criterion value. + +Pritchard, Nathaniel, and Vivak Patel. "Solving, Tracking and Stopping Streaming Linear Inverse Problems." arXiv preprint arXiv:2201.05741 (2024). 
+""" +function iota_threshold(hist::LSLogMA, stop::MAStop) + delta1 = stop.delta1 + delta2 = stop.delta2 + chi1 = stop.chi1 + chi2 = stop.chi2 + threshold = stop.threshold + lambda = hist.ma_info.lambda + # If the constants for the sub-Exponential distribution are not defined then define them + + if typeof(hist.dist_info.sigma2) <: Nothing + get_SE_constants!(hist, hist.dist_info.sampler) + end + #If there is an omega in the sub-Exponential distribution then skip that calculation + if typeof(hist.dist_info.omega) <: Nothing + # Compute the threshold bound in the case where there is no omega + c = min( + (1 - delta1)^2 * threshold^2 / (2 * log(1 / chi1)), + (delta2 - 1)^2 * threshold^2 / (2 * log(1 / chi2)), + ) + c /= + (hist.dist_info.sigma2 * sqrt(hist.iota_hist[hist.iterations])) * + (1 + log(lambda)) / lambda + else + #compute error bound when there is an omega + siota = + (hist.dist_info.sigma2 * sqrt(hist.iota_hist[hist.iterations])) * + (1 + log(lambda)) / lambda + min1 = min( + (1 - delta1)^2 * threshold^2 / (2 * log(1 / chi1) * siota), + lambda * (1 - delta1) * threshold / (2 * log(1 / chi1) * hist.dist_info.omega), + ) + min2 = min( + (delta2 - 1)^2 * threshold^2 / (2 * log(1 / chi2) * siota), + lambda * (delta2 - 1) * threshold / (2 * log(1 / chi2) * hist.dist_info.omega), + ) + c = min(min1, min2) + end + + return c +end \ No newline at end of file diff --git a/src/Solvers/Loggers/moving_average_logger.jl b/src/Solvers/Loggers/moving_average_logger.jl index acac1efe..1db4efb1 100644 --- a/src/Solvers/Loggers/moving_average_logger.jl +++ b/src/Solvers/Loggers/moving_average_logger.jl @@ -48,27 +48,42 @@ A structure that stores information of specification about a randomized linear s `compression_dim`, or if `nnz` is non-positive. 
""" struct FullMALogger <: Logger + max_it::Integer collection_rate::Integer ma_info::MAInfo - resid_hist::Vector{AbstractFloat} - lambda_hist::Vector{Integer} - resid_norm::Function - iterations::Integer - converged::Bool + threshold_info::MAStop + stopping_criterion::Function + # resid_hist::Vector{AbstractFloat} + # lambda_hist::Vector{Integer} + # resid_norm::Function + # iterations::Integer + # converged::Bool + function BasicLogger(max_it, collection_rate, threshold_info, stopping_criterion) + if max_it < 0 + throw(ArgumentError("Field `max_it` must be positive or 0.")) + elseif collection_rate < 1 + throw(ArgumentError("Field `colection_rate` must be positive.")) + elseif collection_rate > max_it && max_it > 0 + throw(ArgumentError("Field `colection_rate` must be smaller than `max_it`.")) + end + + return new(max_it, collection_rate, threshold_info, stopping_criterion) + end + end FullMALogger(; collection_rate::Integer=1, lambda1::Integer=1, lambda2::Integer=30, - resid_norm::Function=norm, + # resid_norm::Function=norm, ) = LSLogFullMA(collection_rate, MAInfo(lambda1, lambda2, lambda1, false, 1, zeros(lambda2)), - AbstractFloat[], - Int64[], - resid_norm, - -1, - false + # AbstractFloat[], + # Int64[], + # resid_norm, + # -1, + # false ) @@ -76,12 +91,46 @@ FullMALogger(; FullMALoggerRecipe <: LoggerRecipe TODO -This is a mutable struct that contains the `max_it` parameter and stores the error metric +The recipe contains the information of `FullMALogger`, stores the error metric in a vector. Checks convergence of the solver based on the log information. 
# Fields """ +mutable struct FullMALoggerRecipe{F<:Function} <: LoggerRecipe + max_it::Integer + error::AbstractFloat + iteration::Integer + record_location::Integer + collection_rate::Integer + converged::Bool + resid_hist::Vector{AbstractFloat} + lambda_hist::Vector{Integer} + threshold_info::MAStop + stopping_criterion::F +end + +function complete_logger(logger::FullMALogger) + # By using ceil if we divide exactly we always have space to record last value, if it + # does not divide exactly we have one more than required and thus enough space to record + # the last value + max_collection = Int(ceil(logger.max_it / logger.collection_rate)) + # Use one more than max_it to collect + res_hist = zeros(max_collection + 1) + lambda_hist = zeros(max_collection + 1) + return FullMALoggerRecipe{typeof(logger.stopping_criterion)}(logger.max_it, + 0.0, + logger.threshold_info, + 1, + 1, + logger.collection_rate, + false, + res_hist, + lambda_hist, + MAStop, + logger.stopping_criterion + ) +end @@ -102,10 +151,9 @@ This is a mutable struct that contains the `max_it` parameter and stores the err # Common interface for update function update_logger!( log::FullMALogger, - sampler::LinSysSampler, x::AbstractVector, samp::Tuple, - iter::Int64, + iter::Integer, A::AbstractArray, b::AbstractVector, ) @@ -139,4 +187,3 @@ end - From a647c9c363f166f65b782d6e15a52944e9e97888 Mon Sep 17 00:00:00 2001 From: Tunan Wang Date: Wed, 21 May 2025 00:12:45 -0500 Subject: [PATCH 06/14] updatees --- src/Solvers/Loggers/ma_helpers/ma_stop.jl | 6 +- src/Solvers/Loggers/moving_average_logger.jl | 75 ++++++++++++-------- 2 files changed, 49 insertions(+), 32 deletions(-) diff --git a/src/Solvers/Loggers/ma_helpers/ma_stop.jl b/src/Solvers/Loggers/ma_helpers/ma_stop.jl index 631ca727..29c8a3a3 100644 --- a/src/Solvers/Loggers/ma_helpers/ma_stop.jl +++ b/src/Solvers/Loggers/ma_helpers/ma_stop.jl @@ -5,7 +5,6 @@ A structure that specifies a stopping criterion that incoroporates the randomnes achieves a 
certain number of iterations, it stops. # Fields -- `max_iter::Integer`, the maximum number of iterations. - `threshold::AbstractFloat`, the value of the estimator that is sufficient progress. - `delta1::AbstractFloat`, the percent below the threshold does the true value of the progress estimator need to be for not stopping to be a mistake. This is equivalent to stopping too late. - `delta2::AbstractFloat`, the percent above the threshold does the true value of the progress estimator need to be for stopping to be a @@ -16,7 +15,6 @@ A structure that specifies a stopping criterion that incoroporates the randomnes - Calling MAStop(iter) will specify the users desired maximum number of iterations, threshold = 1e-10, delta1 = .9, delta2 = 1.1, chi1 = 0.01, and chi2 = 0.01. """ struct MAStop - max_iter::Integer threshold::AbstractFloat delta1::AbstractFloat delta2::AbstractFloat @@ -24,8 +22,8 @@ struct MAStop chi2::AbstractFloat end -function MAStop(iter; threshold=1e-10, delta1=0.9, delta2=1.1, chi1=0.01, chi2=0.01) - return MAStop(iter, threshold, delta1, delta2, chi1, chi2) +function MAStop(;threshold=1e-10, delta1=0.9, delta2=1.1, chi1=0.01, chi2=0.01) + return MAStop(threshold, delta1, delta2, chi1, chi2) end # Common interface for stopping criteria diff --git a/src/Solvers/Loggers/moving_average_logger.jl b/src/Solvers/Loggers/moving_average_logger.jl index 1db4efb1..9a42e6e8 100644 --- a/src/Solvers/Loggers/moving_average_logger.jl +++ b/src/Solvers/Loggers/moving_average_logger.jl @@ -58,7 +58,7 @@ struct FullMALogger <: Logger # resid_norm::Function # iterations::Integer # converged::Bool - function BasicLogger(max_it, collection_rate, threshold_info, stopping_criterion) + function BasicLogger(max_it, collection_rate, ma_info, threshold_info, stopping_criterion) if max_it < 0 throw(ArgumentError("Field `max_it` must be positive or 0.")) elseif collection_rate < 1 @@ -67,18 +67,23 @@ struct FullMALogger <: Logger throw(ArgumentError("Field `colection_rate` must 
be smaller than `max_it`.")) end - return new(max_it, collection_rate, threshold_info, stopping_criterion) + return new(max_it, collection_rate, ma_info, threshold_info, stopping_criterion) end end FullMALogger(; + max_it::Integer=0, collection_rate::Integer=1, lambda1::Integer=1, lambda2::Integer=30, - # resid_norm::Function=norm, - ) = LSLogFullMA(collection_rate, + threshold_info(), + stopping_criterion=check_stop_criterion + ) = LSLogFullMA(max_it, + collection_rate, MAInfo(lambda1, lambda2, lambda1, false, 1, zeros(lambda2)), + threshold_info(), + stopping_criterion # AbstractFloat[], # Int64[], # resid_norm, @@ -119,17 +124,17 @@ function complete_logger(logger::FullMALogger) res_hist = zeros(max_collection + 1) lambda_hist = zeros(max_collection + 1) return FullMALoggerRecipe{typeof(logger.stopping_criterion)}(logger.max_it, - 0.0, - logger.threshold_info, - 1, - 1, - logger.collection_rate, - false, - res_hist, - lambda_hist, - MAStop, - logger.stopping_criterion - ) + 0.0, + logger.threshold_info, + 1, + 1, + logger.collection_rate, + false, + res_hist, + lambda_hist, + MAStop(), + logger.stopping_criterion + ) end @@ -150,37 +155,51 @@ end # Common interface for update function update_logger!( - log::FullMALogger, - x::AbstractVector, - samp::Tuple, - iter::Integer, - A::AbstractArray, - b::AbstractVector, + logger::FullMALoggerRecipe, + error::AbstractFloat, + iteration::Integer ) # Update iteration counter - log.iterations = iter + logger.iterations = iteration + logger.error = error ############################### # Implement moving average (MA) ############################### - ma_info = log.ma_info + ma_info = logger.ma_info # Compute the current residual to second power to align with theory - res_norm_iter = log.resid_norm(A * x - b) - res::Float64 = res_norm_iter^2 + res::AbstractFloat = error^2 # Check if MA is in lambda1 or lambda2 regime if ma_info.flag - update_ma!(log, res, ma_info.lambda2, iter) + update_ma!(logger, res, ma_info.lambda2, 
iteration) else # Check if we can switch between lambda1 and lambda2 regime # If it is in the monotonic decreasing of the sketched residual then we are in a lambda1 regime # otherwise we switch to the lambda2 regime which is indicated by the changing of the flag # because update_ma changes res_window and ma_info.idx we must check condition first - flag_cond = iter == 0 || res <= ma_info.res_window[ma_info.idx] - update_ma!(log, res, ma_info.lambda1, iter) + flag_cond = iteration == 0 || res <= ma_info.res_window[ma_info.idx] + update_ma!(logger, res, ma_info.lambda1, iteration) ma_info.flag = !flag_cond end + ############################### + # Stop the algorithm + ############################### + # Always check max_it stopping criterion + # Compute in this way to avoid bounds error from searching in the max_it + 1 location + logger.converged = iteration <= logger.max_it ? + logger.stopping_criterion(logger, logger.threshold_info) : + true + + # log according to collection rate or if we have converged + if rem(iteration, logger.collection_rate) == 0 || logger.converged + logger.hist[logger.record_location] = error + logger.record_location += 1 + end + + return nothing + end From 3f8a450f4c3433ee0bdd553d77067f74364c4b0b Mon Sep 17 00:00:00 2001 From: Tunan Wang Date: Wed, 21 May 2025 10:25:54 -0500 Subject: [PATCH 07/14] updates --- src/Solvers/Loggers/ma_helpers/dist_info.jl | 21 +++--- src/Solvers/Loggers/ma_helpers/ma_info.jl | 6 +- src/Solvers/Loggers/ma_helpers/ma_stop.jl | 68 ++++++++++++-------- src/Solvers/Loggers/moving_average_logger.jl | 39 +++-------- 4 files changed, 63 insertions(+), 71 deletions(-) diff --git a/src/Solvers/Loggers/ma_helpers/dist_info.jl b/src/Solvers/Loggers/ma_helpers/dist_info.jl index 422d9cf2..4040a849 100644 --- a/src/Solvers/Loggers/ma_helpers/dist_info.jl +++ b/src/Solvers/Loggers/ma_helpers/dist_info.jl @@ -63,22 +63,21 @@ SEDistInfo(; ######################################### #Function that will return rho and its 
uncertainty from a LoggerRecipe type """ - get_uncertainty(log::LoggerRecipe; alpha::AbstractFloat = 0.05) + get_uncertainty(log::MALoggerRecipe; alpha::AbstractFloat = 0.05) -A function that gets the uncertainty from LoggerRecipe or LSLogFullMA type. +A function that gets the uncertainty from moving average logger recipe types. # Arguments -- `hist::LoggerRecipe`, the parent structure of moving average log structure, - TODO: check the types - i.e. MALogger and FullMALogger types. Specifically, the information of - distribution (`dist_info`), and all histories stored in the structure. +- `hist::MALoggerRecipe`, the moving average logger recipe structure, + i.e. MALogger type. Specifically, the information of distribution (`dist_info`), + and all histories stored in the structure. - `alpha::AbstractFloat`, the confidence level. # Returns - A `(1-alpha)`-credible intervals for every `rho` in the `log`, specifically it returns a tuple with (rho, Upper bound, Lower bound). """ -function get_uncertainty(hist::LoggerRecipe; alpha::AbstractFloat = 0.05) +function get_uncertainty(hist::MALoggerRecipe; alpha::AbstractFloat = 0.05) l = length(hist.iota_hist) upper = zeros(l) lower = zeros(l) @@ -114,14 +113,14 @@ function get_uncertainty(hist::LoggerRecipe; alpha::AbstractFloat = 0.05) end """ - get_SE_constants!(log::LoggerRecipe, sampler::Type{T<:Compressor}) + get_SE_constants!(log::MALoggerRecipe, sampler::Type{T<:Compressor}) A function that returns a default set of sub-Exponential constants for each sampling method. This function is not exported and thus the user does not have direct access to it. # Arguments -- `log::LoggerRecipe`, the log containing all the tracking information. Specifically, - the information of distribution (`dist_info`). +- `log::MALoggerRecipe`, the log containing all the tracking + information. Specifically, the information of distribution (`dist_info`). - `sampler::Type{Compressor}`, the type of sampler being used. 
# Returns @@ -130,7 +129,7 @@ A function that returns a default set of sub-Exponential constants for each samp If default is not a defined a warning is returned that sigma2 is set 1 and scaling is set to 1. """ -function get_SE_constants!(log::LoggerRecipe, sampler::Type{T}) where T<:Compressor +function get_SE_constants!(log::MALoggerRecipe, sampler::Type{T}) where T<:Compressor @warn "No constants defined for method of type $sampler. By default we set sigma2 to 1 and scaling to 1." log.dist_info.sigma2 = 1 log.dist_info.scaling = 1 diff --git a/src/Solvers/Loggers/ma_helpers/ma_info.jl b/src/Solvers/Loggers/ma_helpers/ma_info.jl index 320a4eea..1fefdebc 100644 --- a/src/Solvers/Loggers/ma_helpers/ma_info.jl +++ b/src/Solvers/Loggers/ma_helpers/ma_info.jl @@ -47,7 +47,7 @@ end ######################################### """ update_ma!( - log::LoggerRecipe, + log::Union{MALoggerRecipe, FullMALoggerRecipe}, res::Union{AbstractVector, Real}, lambda_base::Integer, iter::Integer @@ -56,7 +56,7 @@ end Function that updates the moving average tracking statistic. # Arguments -- `log::LoggerRecipe`, the parent of moving average log structure. +- `log::Union{MALoggerRecipe, FullMALoggerRecipe}`, the moving average logger recipe structure. - `res::Union{AbstractVector, Real}`, the sketched residual for the current iteration. - `lambda_base::Integer`, which lambda, between lambda1 and lambda2, is currently being used. - `iter::Integer`, the current iteration. @@ -65,7 +65,7 @@ Function that updates the moving average tracking statistic. - Updates the log datatype and does not explicitly return anything. 
""" function update_ma!( - log::LoggerRecipe, # log::L where L <: LoggerRecipe + log::Union{MALoggerRecipe, FullMALoggerRecipe}, res::Union{AbstractVector,Real}, lambda_base::Integer, iter::Integer, diff --git a/src/Solvers/Loggers/ma_helpers/ma_stop.jl b/src/Solvers/Loggers/ma_helpers/ma_stop.jl index 29c8a3a3..bad39f5d 100644 --- a/src/Solvers/Loggers/ma_helpers/ma_stop.jl +++ b/src/Solvers/Loggers/ma_helpers/ma_stop.jl @@ -5,7 +5,7 @@ A structure that specifies a stopping criterion that incoroporates the randomnes achieves a certain number of iterations, it stops. # Fields -- `threshold::AbstractFloat`, the value of the estimator that is sufficient progress. +- `threshold::Union{AbstractFloat, Tuple}`, the value of the estimator that is sufficient progress. - `delta1::AbstractFloat`, the percent below the threshold does the true value of the progress estimator need to be for not stopping to be a mistake. This is equivalent to stopping too late. - `delta2::AbstractFloat`, the percent above the threshold does the true value of the progress estimator need to be for stopping to be a mistake. This is equivalent to stopping too early. @@ -15,7 +15,7 @@ A structure that specifies a stopping criterion that incoroporates the randomnes - Calling MAStop(iter) will specify the users desired maximum number of iterations, threshold = 1e-10, delta1 = .9, delta2 = 1.1, chi1 = 0.01, and chi2 = 0.01. """ struct MAStop - threshold::AbstractFloat + threshold::Union{AbstractFloat, Tuple} delta1::AbstractFloat delta2::AbstractFloat chi1::AbstractFloat @@ -26,67 +26,83 @@ function MAStop(;threshold=1e-10, delta1=0.9, delta2=1.1, chi1=0.01, chi2=0.01) return MAStop(threshold, delta1, delta2, chi1, chi2) end + +""" + threshold_stop(log::MALoggerRecipe) + +Function that takes an input threshold and stops when the most recent entry in the history + vector is less than the threshold. + +# Arguments + - `log::MALoggerRecipe`, a structure containing the moving average logger information. 
+ +# Bool + - Returns a Bool indicating if the stopping threshold is satisfied. +""" # Common interface for stopping criteria -function check_stop_criterion(log::MALogger, stop::MAStop) +function check_stop_criterion(log::MALoggerRecipe) its = log.iterations if its > 0 - I_threshold = iota_threshold(log, stop) + I_threshold = iota_threshold(log) thresholdChecks = - sqrt(log.iota_hist[its]) <= I_threshold && log.resid_hist[its] <= stop.threshold + sqrt(log.iota_hist[its]) <= I_threshold && log.resid_hist[its] <= log.threshold_info.threshold else thresholdChecks = false end - return (thresholdChecks || its == stop.max_iter ? true : false) + return thresholdChecks end + """ - iota_threshold(hist::LSLogMA, stop::LSStopMA) + iota_threshold(log::MALoggerRecipe) -Function that computes the stopping criterion using the sub-Exponential distribution from the `LSLogMA`, and the stopping criterion information in `LSSopMA`. This function is not exported and thus not directly callable by the user. +Function that computes the stopping criterion using the sub-Exponential distribution + from the `MALoggerRecipe`, and the stopping criterion information in TODO: use this stucture? + `MAStop`. This + function is not exported and thus not directly callable by the user. -# Inputs -- `hist::LSLogMA`, the log information for the moving average tracking. -- `stop::LSStopMA`, the stopping information for the stopping criterion. +# Arguments +- `log::MALoggerRecipe`, the log information for the moving average tracking. # Ouputs -Returns the stoppping criterion value. +- Returns the stoppping criterion value. Pritchard, Nathaniel, and Vivak Patel. "Solving, Tracking and Stopping Streaming Linear Inverse Problems." arXiv preprint arXiv:2201.05741 (2024). 
""" -function iota_threshold(hist::LSLogMA, stop::MAStop) - delta1 = stop.delta1 - delta2 = stop.delta2 - chi1 = stop.chi1 - chi2 = stop.chi2 - threshold = stop.threshold - lambda = hist.ma_info.lambda +function iota_threshold(log::MALoggerRecipe) + delta1 = log.threshold_info.delta1 + delta2 = log.threshold_info.delta2 + chi1 = log.threshold_info.chi1 + chi2 = log.threshold_info.chi2 + threshold = log.threshold_info.threshold + lambda = log.ma_info.lambda # If the constants for the sub-Exponential distribution are not defined then define them - if typeof(hist.dist_info.sigma2) <: Nothing - get_SE_constants!(hist, hist.dist_info.sampler) + if typeof(log.dist_info.sigma2) <: Nothing + get_SE_constants!(log, log.dist_info.sampler) end #If there is an omega in the sub-Exponential distribution then skip that calculation - if typeof(hist.dist_info.omega) <: Nothing + if typeof(log.dist_info.omega) <: Nothing # Compute the threshold bound in the case where there is no omega c = min( (1 - delta1)^2 * threshold^2 / (2 * log(1 / chi1)), (delta2 - 1)^2 * threshold^2 / (2 * log(1 / chi2)), ) c /= - (hist.dist_info.sigma2 * sqrt(hist.iota_hist[hist.iterations])) * + (log.dist_info.sigma2 * sqrt(log.iota_log[log.iterations])) * (1 + log(lambda)) / lambda else #compute error bound when there is an omega siota = - (hist.dist_info.sigma2 * sqrt(hist.iota_hist[hist.iterations])) * + (log.dist_info.sigma2 * sqrt(log.iota_log[log.iterations])) * (1 + log(lambda)) / lambda min1 = min( (1 - delta1)^2 * threshold^2 / (2 * log(1 / chi1) * siota), - lambda * (1 - delta1) * threshold / (2 * log(1 / chi1) * hist.dist_info.omega), + lambda * (1 - delta1) * threshold / (2 * log(1 / chi1) * log.dist_info.omega), ) min2 = min( (delta2 - 1)^2 * threshold^2 / (2 * log(1 / chi2) * siota), - lambda * (delta2 - 1) * threshold / (2 * log(1 / chi2) * hist.dist_info.omega), + lambda * (delta2 - 1) * threshold / (2 * log(1 / chi2) * log.dist_info.omega), ) c = min(min1, min2) end diff --git 
a/src/Solvers/Loggers/moving_average_logger.jl b/src/Solvers/Loggers/moving_average_logger.jl index 9a42e6e8..c0dd7a66 100644 --- a/src/Solvers/Loggers/moving_average_logger.jl +++ b/src/Solvers/Loggers/moving_average_logger.jl @@ -53,12 +53,7 @@ struct FullMALogger <: Logger ma_info::MAInfo threshold_info::MAStop stopping_criterion::Function - # resid_hist::Vector{AbstractFloat} - # lambda_hist::Vector{Integer} - # resid_norm::Function - # iterations::Integer - # converged::Bool - function BasicLogger(max_it, collection_rate, ma_info, threshold_info, stopping_criterion) + function FullMALogger(max_it, collection_rate, ma_info, threshold_info, stopping_criterion) if max_it < 0 throw(ArgumentError("Field `max_it` must be positive or 0.")) elseif collection_rate < 1 @@ -73,22 +68,17 @@ struct FullMALogger <: Logger end FullMALogger(; - max_it::Integer=0, - collection_rate::Integer=1, - lambda1::Integer=1, - lambda2::Integer=30, - threshold_info(), + max_it=0, + collection_rate=1, + lambda1=1, + lambda2=30, + threshold_info=MAStop(), stopping_criterion=check_stop_criterion ) = LSLogFullMA(max_it, collection_rate, MAInfo(lambda1, lambda2, lambda1, false, 1, zeros(lambda2)), - threshold_info(), + threshold_info, stopping_criterion - # AbstractFloat[], - # Int64[], - # resid_norm, - # -1, - # false ) @@ -132,7 +122,7 @@ function complete_logger(logger::FullMALogger) false, res_hist, lambda_hist, - MAStop(), + logger.threshold_info, logger.stopping_criterion ) end @@ -140,19 +130,6 @@ end - - - - - - - - - - - - - # Common interface for update function update_logger!( logger::FullMALoggerRecipe, From 99f27216c3e8e345d037ddeedaa9412e9cce3119 Mon Sep 17 00:00:00 2001 From: Tunan Wang Date: Fri, 23 May 2025 11:11:40 -0500 Subject: [PATCH 08/14] ma_updates --- docs/src/api/loggers.md | 12 ++ src/Solvers/Loggers.jl | 2 - src/Solvers/Loggers/ma_helpers/dist_info.jl | 193 ------------------ src/Solvers/Loggers/ma_helpers/ma_info.jl | 6 +- 
src/Solvers/Loggers/ma_helpers/ma_stop.jl | 111 ---------- src/Solvers/Loggers/moving_average_logger.jl | 117 ++++++----- test/Solvers/Loggers/moving_average_logger.jl | 21 ++ 7 files changed, 104 insertions(+), 358 deletions(-) delete mode 100644 src/Solvers/Loggers/ma_helpers/dist_info.jl delete mode 100644 src/Solvers/Loggers/ma_helpers/ma_stop.jl create mode 100644 test/Solvers/Loggers/moving_average_logger.jl diff --git a/docs/src/api/loggers.md b/docs/src/api/loggers.md index 8b7275a4..8f02c1aa 100644 --- a/docs/src/api/loggers.md +++ b/docs/src/api/loggers.md @@ -16,6 +16,16 @@ BasicLogger BasicLoggerRecipe +MALogger + +MALoggerRecipe + +``` + +### Moving average logger structures +```@docs +MAInfo + ``` ## Exported Functions @@ -27,4 +37,6 @@ update_logger! reset_logger! threshold_stop + +update_ma! ``` diff --git a/src/Solvers/Loggers.jl b/src/Solvers/Loggers.jl index e0273996..0284270b 100644 --- a/src/Solvers/Loggers.jl +++ b/src/Solvers/Loggers.jl @@ -99,5 +99,3 @@ include("Loggers/moving_average_logger.jl") # Include Logger Moving Average helpers Files ################################################ include("Loggers/ma_helpers/ma_info.jl") -include("Loggers/ma_helpers/dist_info.jl") -include("Loggers/ma_helpers/ma_stop.jl") diff --git a/src/Solvers/Loggers/ma_helpers/dist_info.jl b/src/Solvers/Loggers/ma_helpers/dist_info.jl deleted file mode 100644 index 4040a849..00000000 --- a/src/Solvers/Loggers/ma_helpers/dist_info.jl +++ /dev/null @@ -1,193 +0,0 @@ -# This file contains the components that are needed for storing -# and using distribution infromation for moving average method: -# Structs: SEDistInfo -# Functions: get_uncertainty, get_SE_constants! - -######################################### -# Structs -######################################### -""" - SEDistInfo - -A mutable structure that stores information about a distribution (i.e., sampling method) -in the sub-Exponential family. 
- -# Fields -- `sampler::Union{DataType, Nothing}`, the type of sampling method. -- `dimension::Integer`, the dimension that of the space that is being sampled. -- `block_dimension::Integer`, the dimension of the sample. -- `sigma2::Union{AbstractFloat, Nothing}`, the variance parameter in the sub-Exponential family. - If not specified by the user, a value is selected from a table based on the `sampler`. - If the `sampler` is not in the table, then `sigma2` is set to `1`. -- `omega::Union{AbstractFloat, Nothing}`, the exponential distrbution parameter. If not specified - by the user, a value is selected from a table based on the `sampler`. - If the `sampler` is not in the table, then `omega` is set to `1`. -- `eta::AbstractFloat`, a parameter for adjusting the conservativeness of the distribution, higher - value means a less conservative estimate. A recommended value is `1`. -- `scaling::AbstractFloat`, a scaling parameter for the norm-squared of the sketched residual to - ensure its expectation is the norm-squared of the residual. - -For more information see: -- Pritchard, Nathaniel, and Vivak Patel. "Solving, tracking and stopping streaming linear - inverse problems." Inverse Problems (2024). doi:10.1088/1361-6420/ad5583. -- Pritchard, Nathaniel, and Vivak Patel. “Towards Practical Large-Scale Randomized Iterative - Least Squares Solvers through Uncertainty Quantification.” SIAM/ASA J. Uncertainty - Quantification 11 (2022): 996-1024. 
doi.org/10.1137/22M1515057 -""" -mutable struct SEDistInfo - sampler::Union{DataType, Nothing} - dimension::Integer - block_dimension::Integer - sigma2::Union{AbstractFloat, Nothing} - omega::Union{AbstractFloat, Nothing} - eta::AbstractFloat - scaling::AbstractFloat - function SEDistInfo(sampler, dimension, block_dimension, sigma2, omega, eta, scaling) - eta > 0 || throw(ArgumentError("eta must be positive")) - return new(sampler, dimension, block_dimension, sigma2, omega, eta, scaling) - end -end - -SEDistInfo(; - sampler=nothing, - dimension=0, - block_dimension=0, - sigma2=nothing, - omega=nothing, - eta=1.0, - scaling=0.0 - ) = SEDistInfo(sampler, dimension, block_dimension, sigma2, omega, eta, scaling) - -######################################### -# Functions -######################################### -#Function that will return rho and its uncertainty from a LoggerRecipe type -""" - get_uncertainty(log::MALoggerRecipe; alpha::AbstractFloat = 0.05) - -A function that gets the uncertainty from moving average logger recipe types. - -# Arguments -- `hist::MALoggerRecipe`, the moving average logger recipe structure, - i.e. MALogger type. Specifically, the information of distribution (`dist_info`), - and all histories stored in the structure. -- `alpha::AbstractFloat`, the confidence level. - -# Returns -- A `(1-alpha)`-credible intervals for every `rho` in the `log`, specifically - it returns a tuple with (rho, Upper bound, Lower bound). 
-""" -function get_uncertainty(hist::MALoggerRecipe; alpha::AbstractFloat = 0.05) - l = length(hist.iota_hist) - upper = zeros(l) - lower = zeros(l) - # If the constants for the sub-Exponential distribution are not defined then define them - if typeof(hist.dist_info.sigma2) <: Nothing - throw(ArgumentError("The SE constants are empty, please set them in dist_info field of LoggerRecipe first.")) - end - - for i in 1:l - width = hist.lambda_hist[i] - iota = hist.iota_hist[i] - rho = hist.resid_hist[i] - # Define the variance term for the Gaussian part - cG = hist.dist_info.sigma2 * (1 + log(width)) * iota / (hist.dist_info.eta * width) - # If there is an omega in the sub-Exponential distribution then skip that calculation - if typeof(hist.dist_info.omega) <: Nothing - # Compute the threshold bound in the case where there is no omega - diffG = sqrt(cG * 2 * log(2/(alpha))) - upper[i] = rho + diffG - lower[i] = rho - diffG - else - # Compute error bound when there is an omega - diffG = sqrt(cG * 2 * log(2/(alpha))) - diffO = sqrt(iota) * 2 * log(2/(alpha)) * hist.dist_info.omega / (hist.dist_info.eta * width) - diffM = max(diffG, diffO) - upper[i] = rho + diffM - lower[i] = rho - diffM - end - - end - - return (hist.resid_hist, upper, lower) -end - -""" - get_SE_constants!(log::MALoggerRecipe, sampler::Type{T<:Compressor}) - -A function that returns a default set of sub-Exponential constants for each sampling method. - This function is not exported and thus the user does not have direct access to it. - -# Arguments -- `log::MALoggerRecipe`, the log containing all the tracking - information. Specifically, the information of distribution (`dist_info`). -- `sampler::Type{Compressor}`, the type of sampler being used. - -# Returns -- Performs an inplace update of the sub-Exponential constants for the log. Additionally, - updates the scaling constant to ensure expectation of block norms is equal to true norm. 
- If default is not a defined a warning is returned that sigma2 is set 1 and scaling - is set to 1. -""" -function get_SE_constants!(log::MALoggerRecipe, sampler::Type{T}) where T<:Compressor - @warn "No constants defined for method of type $sampler. By default we set sigma2 to 1 and scaling to 1." - log.dist_info.sigma2 = 1 - log.dist_info.scaling = 1 -end - - -# for type in (LinSysVecRowDetermCyclic,LinSysVecRowHopRandCyclic, -# LinSysVecRowOneRandCyclic, LinSysVecRowSVSampler, -# LinSysVecRowRandCyclic, LinSysVecRowUnidSampler, -# LinSysVecRowDistCyclic, LinSysVecRowResidCyclic, -# LinSysVecRowMaxResidual, LinSysVecRowMaxDistance,) -# @eval begin -# function get_SE_constants!(log::LoggerRecipe, sampler::Type{$type}) -# log.dist_info.sigma2 = log.dist_info.dimension^2 / (4 * log.dist_info.block_dimension^2 * log.dist_info.eta) -# log.dist_info.scaling = log.dist_info.dimension / log.dist_info.block_dimension -# end - -# end - -# end - - -# # Column subsetting methods have same constants as in row case -# for type in (LinSysVecColOneRandCyclic, LinSysVecColDetermCyclic) -# @eval begin -# function get_SE_constants!(log::LoggerRecipe, sampler::Type{$type}) -# log.dist_info.sigma2 = log.dist_info.dimension^2 / (4 * log.dist_info.block_dimension^2 * log.dist_info.eta) -# log.dist_info.scaling = log.dist_info.dimension / log.dist_info.block_dimension -# end - -# end - -# end - -# # For row samplers with gaussian sampling we have sigma2 = 1/.2345 and omega = .1127 -# for type in (LinSysVecRowGaussSampler, LinSysVecRowSparseGaussSampler) -# @eval begin -# function get_SE_constants!(log::LoggerRecipe, sampler::Type{$type}) -# log.dist_info.sigma2 = log.dist_info.block_dimension / (0.2345 * log.dist_info.eta) -# log.dist_info.omega = .1127 -# log.dist_info.scaling = 1. 
-# end - -# end - -# end - -for type in (Gaussian) - @eval begin - function get_SE_constants!(log::LoggerRecipe, sampler::Type{$type}) - if sampler.compression_dim == 1 - log.dist_info.sigma2 = log.dist_info.block_dimension / (0.2345 * log.dist_info.eta) - log.dist_info.omega = .1127 - log.dist_info.scaling = 1. - end - - end - - end - -end \ No newline at end of file diff --git a/src/Solvers/Loggers/ma_helpers/ma_info.jl b/src/Solvers/Loggers/ma_helpers/ma_info.jl index 1fefdebc..cbd4f550 100644 --- a/src/Solvers/Loggers/ma_helpers/ma_info.jl +++ b/src/Solvers/Loggers/ma_helpers/ma_info.jl @@ -47,7 +47,7 @@ end ######################################### """ update_ma!( - log::Union{MALoggerRecipe, FullMALoggerRecipe}, + log::MALoggerRecipe, res::Union{AbstractVector, Real}, lambda_base::Integer, iter::Integer @@ -56,7 +56,7 @@ end Function that updates the moving average tracking statistic. # Arguments -- `log::Union{MALoggerRecipe, FullMALoggerRecipe}`, the moving average logger recipe structure. +- `log::MALoggerRecipe`, the moving average logger recipe structure. - `res::Union{AbstractVector, Real}`, the sketched residual for the current iteration. - `lambda_base::Integer`, which lambda, between lambda1 and lambda2, is currently being used. - `iter::Integer`, the current iteration. @@ -65,7 +65,7 @@ Function that updates the moving average tracking statistic. - Updates the log datatype and does not explicitly return anything. """ function update_ma!( - log::Union{MALoggerRecipe, FullMALoggerRecipe}, + log::MALoggerRecipe, res::Union{AbstractVector,Real}, lambda_base::Integer, iter::Integer, diff --git a/src/Solvers/Loggers/ma_helpers/ma_stop.jl b/src/Solvers/Loggers/ma_helpers/ma_stop.jl deleted file mode 100644 index bad39f5d..00000000 --- a/src/Solvers/Loggers/ma_helpers/ma_stop.jl +++ /dev/null @@ -1,111 +0,0 @@ -""" - MAStop - -A structure that specifies a stopping criterion that incoroporates the randomness of the moving average estimator. 
That is, once a method - achieves a certain number of iterations, it stops. - -# Fields -- `threshold::Union{AbstractFloat, Tuple}`, the value of the estimator that is sufficient progress. -- `delta1::AbstractFloat`, the percent below the threshold does the true value of the progress estimator need to be for not stopping to be a mistake. This is equivalent to stopping too late. -- `delta2::AbstractFloat`, the percent above the threshold does the true value of the progress estimator need to be for stopping to be a - mistake. This is equivalent to stopping too early. -- `chi1::AbstractFloat`, the probability that the stopping too late action occurs. -- `chi2::AbstractFloat`, the probability that the stopping too early action occurs. -# Constructors -- Calling MAStop(iter) will specify the users desired maximum number of iterations, threshold = 1e-10, delta1 = .9, delta2 = 1.1, chi1 = 0.01, and chi2 = 0.01. -""" -struct MAStop - threshold::Union{AbstractFloat, Tuple} - delta1::AbstractFloat - delta2::AbstractFloat - chi1::AbstractFloat - chi2::AbstractFloat -end - -function MAStop(;threshold=1e-10, delta1=0.9, delta2=1.1, chi1=0.01, chi2=0.01) - return MAStop(threshold, delta1, delta2, chi1, chi2) -end - - -""" - threshold_stop(log::MALoggerRecipe) - -Function that takes an input threshold and stops when the most recent entry in the history - vector is less than the threshold. - -# Arguments - - `log::MALoggerRecipe`, a structure containing the moving average logger information. - -# Bool - - Returns a Bool indicating if the stopping threshold is satisfied. 
-""" -# Common interface for stopping criteria -function check_stop_criterion(log::MALoggerRecipe) - its = log.iterations - if its > 0 - I_threshold = iota_threshold(log) - thresholdChecks = - sqrt(log.iota_hist[its]) <= I_threshold && log.resid_hist[its] <= log.threshold_info.threshold - else - thresholdChecks = false - end - return thresholdChecks -end - - -""" - iota_threshold(log::MALoggerRecipe) - -Function that computes the stopping criterion using the sub-Exponential distribution - from the `MALoggerRecipe`, and the stopping criterion information in TODO: use this stucture? - `MAStop`. This - function is not exported and thus not directly callable by the user. - -# Arguments -- `log::MALoggerRecipe`, the log information for the moving average tracking. - -# Ouputs -- Returns the stoppping criterion value. - -Pritchard, Nathaniel, and Vivak Patel. "Solving, Tracking and Stopping Streaming Linear Inverse Problems." arXiv preprint arXiv:2201.05741 (2024). -""" -function iota_threshold(log::MALoggerRecipe) - delta1 = log.threshold_info.delta1 - delta2 = log.threshold_info.delta2 - chi1 = log.threshold_info.chi1 - chi2 = log.threshold_info.chi2 - threshold = log.threshold_info.threshold - lambda = log.ma_info.lambda - # If the constants for the sub-Exponential distribution are not defined then define them - - if typeof(log.dist_info.sigma2) <: Nothing - get_SE_constants!(log, log.dist_info.sampler) - end - #If there is an omega in the sub-Exponential distribution then skip that calculation - if typeof(log.dist_info.omega) <: Nothing - # Compute the threshold bound in the case where there is no omega - c = min( - (1 - delta1)^2 * threshold^2 / (2 * log(1 / chi1)), - (delta2 - 1)^2 * threshold^2 / (2 * log(1 / chi2)), - ) - c /= - (log.dist_info.sigma2 * sqrt(log.iota_log[log.iterations])) * - (1 + log(lambda)) / lambda - else - #compute error bound when there is an omega - siota = - (log.dist_info.sigma2 * sqrt(log.iota_log[log.iterations])) * - (1 + log(lambda)) 
/ lambda - min1 = min( - (1 - delta1)^2 * threshold^2 / (2 * log(1 / chi1) * siota), - lambda * (1 - delta1) * threshold / (2 * log(1 / chi1) * log.dist_info.omega), - ) - min2 = min( - (delta2 - 1)^2 * threshold^2 / (2 * log(1 / chi2) * siota), - lambda * (delta2 - 1) * threshold / (2 * log(1 / chi2) * log.dist_info.omega), - ) - c = min(min1, min2) - end - - return c -end \ No newline at end of file diff --git a/src/Solvers/Loggers/moving_average_logger.jl b/src/Solvers/Loggers/moving_average_logger.jl index c0dd7a66..8612920f 100644 --- a/src/Solvers/Loggers/moving_average_logger.jl +++ b/src/Solvers/Loggers/moving_average_logger.jl @@ -1,13 +1,13 @@ """ - FullMALogger <: Logger + MALogger <: Logger A structure that stores information of specification about a randomized linear solver's behavior. The log assumes that the full linear system is available for processing. The goal of this log is usually for research, development or testing as it is unlikely that the entire residual vector is readily available. -# Fields -- `collection_rate::Integer`, the frequency with which to record information about progress +# Fields TODO +- `collection_rate::Int64`, the frequency with which to record information about progress to append to the remaining fields, starting with the initialization (i.e., iteration `0`). For example, `collection_rate` = 3 means the iteration difference between each records is 3, i.e. recording information at @@ -26,7 +26,7 @@ A structure that stores information of specification about a randomized linear s # Constructors - FullMALogger(;collection_rate=1, lambda1=1, lambda2=30, resid_norm=norm) + MALogger(;collection_rate=1, lambda1=1, lambda2=30) ## Keywords - `collection_rate::Integer`, the frequency with which to record information about progress @@ -36,24 +36,21 @@ A structure that stores information of specification about a randomized linear s iteration `0`, `3`, `6`, `9`, .... By default, this is set to `1`. - `lambda1::Integer`, the TODO. 
By default, this is set to `1`. - `lambda2::Integer`, the TODO. By default, this is set to `30`. -- `resid_norm::Function`, a function that accepts a single vector argument and returns a - scalar. Used to compute the residual size. By default, `norm`, which is Euclidean - norm, is set. ## Returns -- A `FullMALogger` object. +- A `MALogger` object. ## Throws TODO - `ArgumentError` if `compression_dim` is non-positive, if `nnz` is exceeds `compression_dim`, or if `nnz` is non-positive. """ -struct FullMALogger <: Logger - max_it::Integer +struct MALogger <: Logger + max_it::Int64 collection_rate::Integer ma_info::MAInfo - threshold_info::MAStop + threshold_info::Union{Float64, Tuple} stopping_criterion::Function - function FullMALogger(max_it, collection_rate, ma_info, threshold_info, stopping_criterion) + function MALogger(max_it, collection_rate, ma_info, threshold_info, stopping_criterion) if max_it < 0 throw(ArgumentError("Field `max_it` must be positive or 0.")) elseif collection_rate < 1 @@ -67,45 +64,45 @@ struct FullMALogger <: Logger end -FullMALogger(; - max_it=0, - collection_rate=1, - lambda1=1, - lambda2=30, - threshold_info=MAStop(), - stopping_criterion=check_stop_criterion - ) = LSLogFullMA(max_it, - collection_rate, - MAInfo(lambda1, lambda2, lambda1, false, 1, zeros(lambda2)), - threshold_info, - stopping_criterion - ) +MALogger(; + max_it=0, + collection_rate=1, + lambda1=1, + lambda2=30, + threshold_info=1e-10, + stopping_criterion=check_stop_criterion + ) = LSLogFullMA(max_it, + collection_rate, + MAInfo(lambda1, lambda2, lambda1, false, 1, zeros(lambda2)), + threshold_info, + stopping_criterion + ) """ - FullMALoggerRecipe <: LoggerRecipe + MALoggerRecipe <: LoggerRecipe TODO -The recipe contains the information of `FullMALogger`, stores the error metric +The recipe contains the information of `MALogger`, stores the error metric in a vector. Checks convergence of the solver based on the log information. 
# Fields """ -mutable struct FullMALoggerRecipe{F<:Function} <: LoggerRecipe - max_it::Integer +mutable struct MALoggerRecipe{F<:Function} <: LoggerRecipe + max_it::Int64 error::AbstractFloat - iteration::Integer - record_location::Integer + iteration::Int64 + record_location::Int64 collection_rate::Integer converged::Bool resid_hist::Vector{AbstractFloat} lambda_hist::Vector{Integer} - threshold_info::MAStop + threshold_info::Union{Float64, Tuple} stopping_criterion::F end -function complete_logger(logger::FullMALogger) +function complete_logger(logger::MALogger) # By using ceil if we divide exactly we always have space to record last value, if it # does not divide exactly we have one more than required and thus enough space to record # the last value @@ -113,28 +110,26 @@ function complete_logger(logger::FullMALogger) # Use one more than max_it to collect res_hist = zeros(max_collection + 1) lambda_hist = zeros(max_collection + 1) - return FullMALoggerRecipe{typeof(logger.stopping_criterion)}(logger.max_it, - 0.0, - logger.threshold_info, - 1, - 1, - logger.collection_rate, - false, - res_hist, - lambda_hist, - logger.threshold_info, - logger.stopping_criterion - ) + return MALoggerRecipe{typeof(logger.stopping_criterion)}(logger.max_it, + 0.0, + logger.threshold_info, + 1, + 1, + logger.collection_rate, + false, + res_hist, + lambda_hist, + logger.threshold_info, + logger.stopping_criterion + ) end - - # Common interface for update function update_logger!( - logger::FullMALoggerRecipe, + logger::MALoggerRecipe, error::AbstractFloat, - iteration::Integer + iteration::Int64 ) # Update iteration counter logger.iterations = iteration @@ -181,5 +176,29 @@ end +function reset_logger!(logger::MALoggerRecipe) + logger.error = 0.0 + logger.iteration = 1 + logger.record_location = 1 + logger.converged = false + fill!(logger.hist, 0.0) + return nothing +end + + + +""" + threshold_stop(log::MALoggerRecipe) +Function that takes an input threshold and stops when the most 
recent entry in the history +vector is less than the threshold. +# Arguments + - `log::MALoggerRecipe`, a structure containing the logger information + +# Bool + - Returns a Bool indicating if the stopping threshold is satisfied. +""" +function threshold_stop(log::MALoggerRecipe) + return log.error < log.threshold_info +end diff --git a/test/Solvers/Loggers/moving_average_logger.jl b/test/Solvers/Loggers/moving_average_logger.jl new file mode 100644 index 00000000..0c2a1a40 --- /dev/null +++ b/test/Solvers/Loggers/moving_average_logger.jl @@ -0,0 +1,21 @@ +module moving_average_logger + using Test, RLinearAlgebra, Random + include("../../test_helpers/field_test_macros.jl") + include("../../test_helpers/approx_tol.jl") + using .FieldTest + using .ApproxTol + @testset "Logger MALogger" begin + Random.seed!(21321) + n_rows = 4 + n_cols = 2 + A = rand(n_rows, n_cols) + b = rand(n_rows) + + + + + + + end + +end From e97ab9ae6e686274b5dd7746a2082491db9bc685 Mon Sep 17 00:00:00 2001 From: Tunan Wang Date: Fri, 23 May 2025 11:24:11 -0500 Subject: [PATCH 09/14] rlinearalgebra.jl --- src/RLinearAlgebra.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/RLinearAlgebra.jl b/src/RLinearAlgebra.jl index 046c02d1..27548945 100644 --- a/src/RLinearAlgebra.jl +++ b/src/RLinearAlgebra.jl @@ -28,8 +28,9 @@ export complete_solver, update_solver!, rsolve, rsolve! # Export Logger types and functions export Logger, LoggerRecipe -export BasicLogger, BasicLoggerRecipe -export complete_logger, update_logger!, reset_logger! +export MAInfo +export BasicLogger, BasicLoggerRecipe, MALogger, MALoggerRecipe +export complete_logger, update_logger!, reset_logger!, update_ma! 
export threshold_stop # Export SubSolver types From bfef75e7ca3fae10ed824e7fb93ce8811d86a428 Mon Sep 17 00:00:00 2001 From: Tunan Wang Date: Fri, 23 May 2025 11:39:54 -0500 Subject: [PATCH 10/14] Solve dependence error --- src/Solvers/Loggers.jl | 9 +++++---- src/Solvers/Loggers/ma_helpers/ma_info.jl | 6 +++--- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/Solvers/Loggers.jl b/src/Solvers/Loggers.jl index 0284270b..3bf30925 100644 --- a/src/Solvers/Loggers.jl +++ b/src/Solvers/Loggers.jl @@ -89,13 +89,14 @@ function reset_logger!(logger::LoggerRecipe) return nothing end +################################################ +# Include Logger Moving Average helpers Files +################################################ +include("Loggers/ma_helpers/ma_info.jl") + ############################### # Include Logger Methods Files ############################### include("Loggers/basic_logger.jl") include("Loggers/moving_average_logger.jl") -################################################ -# Include Logger Moving Average helpers Files -################################################ -include("Loggers/ma_helpers/ma_info.jl") diff --git a/src/Solvers/Loggers/ma_helpers/ma_info.jl b/src/Solvers/Loggers/ma_helpers/ma_info.jl index cbd4f550..e2777c42 100644 --- a/src/Solvers/Loggers/ma_helpers/ma_info.jl +++ b/src/Solvers/Loggers/ma_helpers/ma_info.jl @@ -47,7 +47,7 @@ end ######################################### """ update_ma!( - log::MALoggerRecipe, + log::LoggerRecipe, res::Union{AbstractVector, Real}, lambda_base::Integer, iter::Integer @@ -56,7 +56,7 @@ end Function that updates the moving average tracking statistic. # Arguments -- `log::MALoggerRecipe`, the moving average logger recipe structure. +- `log::LoggerRecipe`, the parent of moving average logger recipe structure. - `res::Union{AbstractVector, Real}`, the sketched residual for the current iteration. - `lambda_base::Integer`, which lambda, between lambda1 and lambda2, is currently being used. 
- `iter::Integer`, the current iteration. @@ -65,7 +65,7 @@ Function that updates the moving average tracking statistic. - Updates the log datatype and does not explicitly return anything. """ function update_ma!( - log::MALoggerRecipe, + log::LoggerRecipe, res::Union{AbstractVector,Real}, lambda_base::Integer, iter::Integer, From 4fab32eeee5fe43248431bc7071898c5a35a811c Mon Sep 17 00:00:00 2001 From: Tunan Wang Date: Fri, 23 May 2025 14:29:06 -0500 Subject: [PATCH 11/14] update logger --- src/Solvers/Loggers/moving_average_logger.jl | 65 +++++++++---------- test/Solvers/Loggers/moving_average_logger.jl | 6 ++ 2 files changed, 37 insertions(+), 34 deletions(-) diff --git a/src/Solvers/Loggers/moving_average_logger.jl b/src/Solvers/Loggers/moving_average_logger.jl index 8612920f..f6801e27 100644 --- a/src/Solvers/Loggers/moving_average_logger.jl +++ b/src/Solvers/Loggers/moving_average_logger.jl @@ -70,13 +70,13 @@ MALogger(; lambda1=1, lambda2=30, threshold_info=1e-10, - stopping_criterion=check_stop_criterion - ) = LSLogFullMA(max_it, - collection_rate, - MAInfo(lambda1, lambda2, lambda1, false, 1, zeros(lambda2)), - threshold_info, - stopping_criterion - ) + stopping_criterion=threshold_stop + ) = MALogger(max_it, + collection_rate, + MAInfo(lambda1, lambda2, lambda1, false, 1, zeros(lambda2)), + threshold_info, + stopping_criterion + ) """ @@ -96,6 +96,7 @@ mutable struct MALoggerRecipe{F<:Function} <: LoggerRecipe record_location::Int64 collection_rate::Integer converged::Bool + ma_info::MAInfo resid_hist::Vector{AbstractFloat} lambda_hist::Vector{Integer} threshold_info::Union{Float64, Tuple} @@ -112,11 +113,11 @@ function complete_logger(logger::MALogger) lambda_hist = zeros(max_collection + 1) return MALoggerRecipe{typeof(logger.stopping_criterion)}(logger.max_it, 0.0, - logger.threshold_info, 1, 1, logger.collection_rate, false, + logger.ma_info, res_hist, lambda_hist, logger.threshold_info, @@ -132,32 +133,9 @@ function update_logger!( iteration::Int64 ) # 
Update iteration counter - logger.iterations = iteration + logger.iteration = iteration logger.error = error - ############################### - # Implement moving average (MA) - ############################### - ma_info = logger.ma_info - # Compute the current residual to second power to align with theory - res::AbstractFloat = error^2 - - # Check if MA is in lambda1 or lambda2 regime - if ma_info.flag - update_ma!(logger, res, ma_info.lambda2, iteration) - else - # Check if we can switch between lambda1 and lambda2 regime - # If it is in the monotonic decreasing of the sketched residual then we are in a lambda1 regime - # otherwise we switch to the lambda2 regime which is indicated by the changing of the flag - # because update_ma changes res_window and ma_info.idx we must check condition first - flag_cond = iteration == 0 || res <= ma_info.res_window[ma_info.idx] - update_ma!(logger, res, ma_info.lambda1, iteration) - ma_info.flag = !flag_cond - end - - ############################### - # Stop the algorithm - ############################### # Always check max_it stopping criterion # Compute in this way to avoid bounds error from searching in the max_it + 1 location logger.converged = iteration <= logger.max_it ? 
@@ -166,7 +144,25 @@ function update_logger!( # log according to collection rate or if we have converged if rem(iteration, logger.collection_rate) == 0 || logger.converged - logger.hist[logger.record_location] = error + ############################### + # Implement moving average (MA) + ############################### + ma_info = logger.ma_info + # Compute the current residual to second power to align with theory + res::AbstractFloat = error^2 + + # Check if MA is in lambda1 or lambda2 regime + if ma_info.flag + update_ma!(logger, res, ma_info.lambda2, iteration) + else + # Check if we can switch between lambda1 and lambda2 regime + # If it is in the monotonic decreasing of the sketched residual then we are in a lambda1 regime + # otherwise we switch to the lambda2 regime which is indicated by the changing of the flag + # because update_ma changes res_window and ma_info.idx we must check condition first + flag_cond = iteration == 0 || res <= ma_info.res_window[ma_info.idx] + update_ma!(logger, res, ma_info.lambda1, iteration) + ma_info.flag = !flag_cond + end logger.record_location += 1 end @@ -181,7 +177,8 @@ function reset_logger!(logger::MALoggerRecipe) logger.iteration = 1 logger.record_location = 1 logger.converged = false - fill!(logger.hist, 0.0) + fill!(logger.resid_hist, 0.0) + fill!(logger.lambda_hist, 0.0) return nothing end diff --git a/test/Solvers/Loggers/moving_average_logger.jl b/test/Solvers/Loggers/moving_average_logger.jl index 0c2a1a40..65bf7a3c 100644 --- a/test/Solvers/Loggers/moving_average_logger.jl +++ b/test/Solvers/Loggers/moving_average_logger.jl @@ -11,6 +11,12 @@ module moving_average_logger A = rand(n_rows, n_cols) b = rand(n_rows) + # How to use the logger, the error for the update + a = MALogger() + b = complete_logger(a) + reset_logger!(b) + update_logger!(b, 0.5, 1) + From c0052afd9a5f6c3a7c1375592374f3be1011e354 Mon Sep 17 00:00:00 2001 From: Tunan Wang Date: Sat, 24 May 2025 00:52:17 -0500 Subject: [PATCH 12/14] change ma_update 
to make it update outside --- src/Solvers/Loggers/ma_helpers/ma_info.jl | 34 +++++++----- src/Solvers/Loggers/moving_average_logger.jl | 56 ++++++++++++-------- 2 files changed, 55 insertions(+), 35 deletions(-) diff --git a/src/Solvers/Loggers/ma_helpers/ma_info.jl b/src/Solvers/Loggers/ma_helpers/ma_info.jl index e2777c42..bcbb550e 100644 --- a/src/Solvers/Loggers/ma_helpers/ma_info.jl +++ b/src/Solvers/Loggers/ma_helpers/ma_info.jl @@ -85,12 +85,17 @@ function update_ma!( accum2 += ma_info.res_window[i]^2 end - if mod(iter, log.collection_rate) == 0 || iter == 0 - push!(log.lambda_hist, ma_info.lambda) - push!(log.resid_hist, accum / ma_info.lambda) - (:iota_hist in fieldnames(typeof(log))) && - push!(log.iota_hist, accum2 / ma_info.lambda) - end + # Record the moving average error for stopping + log.lambda_origin = ma_info.lambda + log.ma_error = accum / ma_info.lambda + log.iota_error = accum2 / ma_info.lambda + + # if mod(iter, log.collection_rate) == 0 || iter == 0 + # push!(log.lambda_hist, ma_info.lambda) + # push!(log.resid_hist, accum / ma_info.lambda) + # (:iota_hist in fieldnames(typeof(log))) && + # push!(log.iota_hist, accum2 / ma_info.lambda) + # end else # Consider the case when lambda <= lambda1 or lambda1 < lambda < lambda2 @@ -118,13 +123,18 @@ function update_ma!( accum2 += ma_info.res_window[i]^2 end + # Record the moving average error for stopping + log.lambda_origin = ma_info.lambda + log.ma_error = accum / ma_info.lambda + log.iota_error = accum2 / ma_info.lambda + #Update the log variable with the information for this update - if mod(iter, log.collection_rate) == 0 || iter == 0 - push!(log.lambda_hist, ma_info.lambda) - push!(log.resid_hist, accum / ma_info.lambda) - (:iota_hist in fieldnames(typeof(log))) && - push!(log.iota_hist, accum2 / ma_info.lambda) - end + # if mod(iter, log.collection_rate) == 0 || iter == 0 + # push!(log.lambda_hist, ma_info.lambda) + # push!(log.resid_hist, accum / ma_info.lambda) + # (:iota_hist in 
fieldnames(typeof(log))) && + # push!(log.iota_hist, accum2 / ma_info.lambda) + # end ma_info.lambda += ma_info.lambda < lambda_base ? 1 : 0 end diff --git a/src/Solvers/Loggers/moving_average_logger.jl b/src/Solvers/Loggers/moving_average_logger.jl index f6801e27..4deffcc1 100644 --- a/src/Solvers/Loggers/moving_average_logger.jl +++ b/src/Solvers/Loggers/moving_average_logger.jl @@ -91,13 +91,15 @@ The recipe contains the information of `MALogger`, stores the error metric """ mutable struct MALoggerRecipe{F<:Function} <: LoggerRecipe max_it::Int64 - error::AbstractFloat + ma_error::AbstractFloat + iota_error::AbstractFloat iteration::Int64 record_location::Int64 collection_rate::Integer converged::Bool ma_info::MAInfo resid_hist::Vector{AbstractFloat} + lambda_origin::Integer lambda_hist::Vector{Integer} threshold_info::Union{Float64, Tuple} stopping_criterion::F @@ -112,6 +114,7 @@ function complete_logger(logger::MALogger) res_hist = zeros(max_collection + 1) lambda_hist = zeros(max_collection + 1) return MALoggerRecipe{typeof(logger.stopping_criterion)}(logger.max_it, + 0.0, 0.0, 1, 1, @@ -119,6 +122,7 @@ function complete_logger(logger::MALogger) false, logger.ma_info, res_hist, + 0, lambda_hist, logger.threshold_info, logger.stopping_criterion @@ -134,7 +138,26 @@ function update_logger!( ) # Update iteration counter logger.iteration = iteration - logger.error = error + + ############################### + # Implement moving average (MA) + ############################### + ma_info = logger.ma_info + # Compute the current residual to second power to align with theory + res::AbstractFloat = error^2 + + # Check if MA is in lambda1 or lambda2 regime + if ma_info.flag + update_ma!(logger, res, ma_info.lambda2, iteration) + else + # Check if we can switch between lambda1 and lambda2 regime + # If it is in the monotonic decreasing of the sketched residual then we are in a lambda1 regime + # otherwise we switch to the lambda2 regime which is indicated by the changing 
of the flag + # because update_ma changes res_window and ma_info.idx we must check condition first + flag_cond = iteration == 0 || res <= ma_info.res_window[ma_info.idx] + update_ma!(logger, res, ma_info.lambda1, iteration) + ma_info.flag = !flag_cond + end # Always check max_it stopping criterion # Compute in this way to avoid bounds error from searching in the max_it + 1 location @@ -144,25 +167,10 @@ function update_logger!( # log according to collection rate or if we have converged if rem(iteration, logger.collection_rate) == 0 || logger.converged - ############################### - # Implement moving average (MA) - ############################### - ma_info = logger.ma_info - # Compute the current residual to second power to align with theory - res::AbstractFloat = error^2 - - # Check if MA is in lambda1 or lambda2 regime - if ma_info.flag - update_ma!(logger, res, ma_info.lambda2, iteration) - else - # Check if we can switch between lambda1 and lambda2 regime - # If it is in the monotonic decreasing of the sketched residual then we are in a lambda1 regime - # otherwise we switch to the lambda2 regime which is indicated by the changing of the flag - # because update_ma changes res_window and ma_info.idx we must check condition first - flag_cond = iteration == 0 || res <= ma_info.res_window[ma_info.idx] - update_ma!(logger, res, ma_info.lambda1, iteration) - ma_info.flag = !flag_cond - end + + logger.lambda_hist[logger.record_location] = logger.lambda_origin + logger.resid_hist[logger.record_location] = logger.ma_error + logger.record_location += 1 end @@ -173,7 +181,9 @@ end function reset_logger!(logger::MALoggerRecipe) - logger.error = 0.0 + logger.ma_error = 0.0 + logger.iota_error = 0.0 + logger.lambda_origin = 0 logger.iteration = 1 logger.record_location = 1 logger.converged = false @@ -197,5 +207,5 @@ vector is less than the threshold. - Returns a Bool indicating if the stopping threshold is satisfied. 
""" function threshold_stop(log::MALoggerRecipe) - return log.error < log.threshold_info + return log.ma_error < log.threshold_info end From 920a13e39433c7e7ded58fae6ed70589db5d638a Mon Sep 17 00:00:00 2001 From: Tunan Wang Date: Tue, 27 May 2025 10:22:16 -0500 Subject: [PATCH 13/14] change the fields name --- src/Solvers/Loggers/ma_helpers/ma_info.jl | 8 ++++---- src/Solvers/Loggers/moving_average_logger.jl | 14 +++++++------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/Solvers/Loggers/ma_helpers/ma_info.jl b/src/Solvers/Loggers/ma_helpers/ma_info.jl index bcbb550e..ca31b653 100644 --- a/src/Solvers/Loggers/ma_helpers/ma_info.jl +++ b/src/Solvers/Loggers/ma_helpers/ma_info.jl @@ -87,12 +87,12 @@ function update_ma!( # Record the moving average error for stopping log.lambda_origin = ma_info.lambda - log.ma_error = accum / ma_info.lambda + log.error = accum / ma_info.lambda log.iota_error = accum2 / ma_info.lambda # if mod(iter, log.collection_rate) == 0 || iter == 0 # push!(log.lambda_hist, ma_info.lambda) - # push!(log.resid_hist, accum / ma_info.lambda) + # push!(log.hist, accum / ma_info.lambda) # (:iota_hist in fieldnames(typeof(log))) && # push!(log.iota_hist, accum2 / ma_info.lambda) # end @@ -125,13 +125,13 @@ function update_ma!( # Record the moving average error for stopping log.lambda_origin = ma_info.lambda - log.ma_error = accum / ma_info.lambda + log.error = accum / ma_info.lambda log.iota_error = accum2 / ma_info.lambda #Update the log variable with the information for this update # if mod(iter, log.collection_rate) == 0 || iter == 0 # push!(log.lambda_hist, ma_info.lambda) - # push!(log.resid_hist, accum / ma_info.lambda) + # push!(log.hist, accum / ma_info.lambda) # (:iota_hist in fieldnames(typeof(log))) && # push!(log.iota_hist, accum2 / ma_info.lambda) # end diff --git a/src/Solvers/Loggers/moving_average_logger.jl b/src/Solvers/Loggers/moving_average_logger.jl index 4deffcc1..967d906a 100644 --- 
a/src/Solvers/Loggers/moving_average_logger.jl +++ b/src/Solvers/Loggers/moving_average_logger.jl @@ -13,7 +13,7 @@ A structure that stores information of specification about a randomized linear s difference between each records is 3, i.e. recording information at iteration `0`, `3`, `6`, `9`, .... - `ma_info::MAInfo`, [`MAInfo`](@ref) -- `resid_hist::Vector{AbstractFloat}`, retains a vector of numbers corresponding to the residual +- `hist::Vector{AbstractFloat}`, retains a vector of numbers corresponding to the residual (uses the whole system to compute the residual). These values are stored at iterates specified by `collection_rate`. - `lambda_hist::Vector{Integer}`, contains the widths of the moving average. @@ -91,14 +91,14 @@ The recipe contains the information of `MALogger`, stores the error metric """ mutable struct MALoggerRecipe{F<:Function} <: LoggerRecipe max_it::Int64 - ma_error::AbstractFloat + error::AbstractFloat # Moving average error, named for field check iota_error::AbstractFloat iteration::Int64 record_location::Int64 collection_rate::Integer converged::Bool ma_info::MAInfo - resid_hist::Vector{AbstractFloat} + hist::Vector{AbstractFloat} # Residual history, named for field check lambda_origin::Integer lambda_hist::Vector{Integer} threshold_info::Union{Float64, Tuple} @@ -169,7 +169,7 @@ function update_logger!( if rem(iteration, logger.collection_rate) == 0 || logger.converged logger.lambda_hist[logger.record_location] = logger.lambda_origin - logger.resid_hist[logger.record_location] = logger.ma_error + logger.hist[logger.record_location] = logger.error logger.record_location += 1 end @@ -181,13 +181,13 @@ end function reset_logger!(logger::MALoggerRecipe) - logger.ma_error = 0.0 + logger.error = 0.0 logger.iota_error = 0.0 logger.lambda_origin = 0 logger.iteration = 1 logger.record_location = 1 logger.converged = false - fill!(logger.resid_hist, 0.0) + fill!(logger.hist, 0.0) fill!(logger.lambda_hist, 0.0) return nothing end @@ -207,5 +207,5 
@@ vector is less than the threshold. - Returns a Bool indicating if the stopping threshold is satisfied. """ function threshold_stop(log::MALoggerRecipe) - return log.ma_error < log.threshold_info + return log.error < log.threshold_info end From 6df932ddbd09d12f9468b13226ad3b58cfa04be9 Mon Sep 17 00:00:00 2001 From: Tunan Wang Date: Wed, 28 May 2025 09:10:43 -0500 Subject: [PATCH 14/14] documents --- docs/src/refs.bib | 2 + src/Solvers/Loggers/ma_helpers/ma_info.jl | 39 ++++--- src/Solvers/Loggers/moving_average_logger.jl | 101 ++++++++++++------- 3 files changed, 88 insertions(+), 54 deletions(-) diff --git a/docs/src/refs.bib b/docs/src/refs.bib index c6869065..b78a0428 100644 --- a/docs/src/refs.bib +++ b/docs/src/refs.bib @@ -56,3 +56,5 @@ @article{pritchard2024solving urldate = {2025-01-24}, abstract = {Abstract In large-scale applications including medical imaging, collocation differential equation solvers, and estimation with differential privacy, the underlying linear inverse problem can be reformulated as a streaming problem. In theory, the streaming problem can be effectively solved using memory-efficient, exponentially-converging streaming solvers. In special cases when the underlying linear inverse problem is finite-dimensional, streaming solvers can periodically evaluate the residual norm at a substantial computational cost. When the underlying system is infinite dimensional, streaming solver can only access noisy estimates of the residual. While such noisy estimates are computationally efficient, they are useful only when their accuracy is known. In this work, we rigorously develop a general family of computationally-practical residual estimators and their uncertainty sets for streaming solvers, and we demonstrate the accuracy of our methods on a number of large-scale linear problems. 
Thus, we further enable the practical use of streaming solvers for important classes of linear inverse problems.} } + + diff --git a/src/Solvers/Loggers/ma_helpers/ma_info.jl b/src/Solvers/Loggers/ma_helpers/ma_info.jl index ca31b653..062521b5 100644 --- a/src/Solvers/Loggers/ma_helpers/ma_info.jl +++ b/src/Solvers/Loggers/ma_helpers/ma_info.jl @@ -9,8 +9,13 @@ """ MAInfo -A mutable structure that stores information relevant to the moving average of the - progress estimator. +A mutable structure that stores information relevant to the moving average (MA) of a + progress estimator, such as a residual. It manages different MA window widths + (`lambda1`, `lambda2`) for different convergence phases and tracks the current + MA window (`res_window`). + +See [pritchard2024solving](@cite) for more information on the underlying MA methods. + # Fields - `lambda1::Integer`, the width of the moving average during the fast convergence phase of the algorithm. @@ -20,18 +25,14 @@ A mutable structure that stores information relevant to the moving average of th phase, each iterate differs from the previous one by a small amount and thus most of the observed variation arises from the randomness of the sketched progress estimator, which is best smoothed by a wide moving average width. -- `lambda::Integer`, the width of the moving average at the current iteration. This value is not controlled by - the user. -- `flag::Bool`, a boolean indicating which phase we are in, a value of `true` indicates slow convergence phase. -- `idx::Integer`, the index indcating what value should be replaced in the moving average buffer. -- `res_window::Vector{<:AbstractFloat}`, the moving average buffer. - -For more information see: -- Pritchard, Nathaniel, and Vivak Patel. "Solving, tracking and stopping streaming linear - inverse problems." Inverse Problems (2024). doi:10.1088/1361-6420/ad5583. -- Pritchard, Nathaniel, and Vivak Patel. 
“Towards Practical Large-Scale Randomized Iterative - Least Squares Solvers through Uncertainty Quantification.” SIAM/ASA J. Uncertainty - Quantification 11 (2022): 996-1024. doi.org/10.1137/22M1515057 +- `lambda::Integer`, the actual width of the moving average being used at the current iteration. + This field is updated internally and is not set directly by the user. +- `flag::Bool`, a boolean flag indicating the current convergence phase. A value of `true` + typically indicates the "slow" convergence phase (using `lambda2`). +- `idx::Integer`, the current index within the `res_window` buffer where the next residual value + will be stored, implementing a circular buffer. +- `res_window::Vector{<:AbstractFloat}`, the buffer storing the recent residual values + used to compute the moving average. """ mutable struct MAInfo lambda1::Integer @@ -53,7 +54,10 @@ end iter::Integer ) -Function that updates the moving average tracking statistic. +Updates the moving average statistics stored within the `log.ma_info` field of a + `MALoggerRecipe`. It computes the moving average and second moment (iota) of the + provided residual `res` and updates `log.error`, `log.iota_error`, and + `log.lambda_origin` fields of the `MALoggerRecipe`. # Arguments - `log::LoggerRecipe`, the parent of moving average logger recipe structure. @@ -62,7 +66,8 @@ Function that updates the moving average tracking statistic. - `iter::Integer`, the current iteration. # Returns -- Updates the log datatype and does not explicitly return anything. +- `nothing` (The `log` object, specifically its `ma_info`, `error`, `iota_error`, and + `lambda_origin` fields, is modified in-place). """ function update_ma!( log::LoggerRecipe, @@ -138,4 +143,6 @@ function update_ma!( ma_info.lambda += ma_info.lambda < lambda_base ? 
1 : 0 end + + return nothing end \ No newline at end of file diff --git a/src/Solvers/Loggers/moving_average_logger.jl b/src/Solvers/Loggers/moving_average_logger.jl index 967d906a..95256cb8 100644 --- a/src/Solvers/Loggers/moving_average_logger.jl +++ b/src/Solvers/Loggers/moving_average_logger.jl @@ -2,47 +2,52 @@ MALogger <: Logger A structure that stores information of specification about a randomized linear solver's - behavior. The log assumes that the full linear system is available for processing. - The goal of this log is usually for research, development or testing as it is unlikely - that the entire residual vector is readily available. + behavior. The log assumes that the full linear system is available for processing. + The goal of this log is usually for research, development or testing as it is unlikely + that the entire residual vector is readily available. -# Fields TODO +# Fields +- `max_it::Int64`, the maximum number of iterations for the solver. - `collection_rate::Int64`, the frequency with which to record information about progress - to append to the remaining fields, starting with the initialization - (i.e., iteration `0`). For example, `collection_rate` = 3 means the iteration - difference between each records is 3, i.e. recording information at - iteration `0`, `3`, `6`, `9`, .... + to append to the remaining fields, starting with the initialization + (i.e., iteration `0`). For example, `collection_rate` = `3` means the iteration + difference between consecutive records is `3`, i.e., information is recorded at + iteration `0`, `3`, `6`, `9`, .... - `ma_info::MAInfo`, [`MAInfo`](@ref) -- `hist::Vector{AbstractFloat}`, retains a vector of numbers corresponding to the residual - (uses the whole system to compute the residual). These values are stored at iterates - specified by `collection_rate`. -- `lambda_hist::Vector{Integer}`, contains the widths of the moving average. - These values are stored at iterates specified by `collection_rate`.
-- `resid_norm::Function`, a function that accepts a single vector argument and returns a - scalar. Used to compute the residual size. -- `iterations::Integer`, the number of iterations of the solver. -- `converged::Bool`, a flag to indicate whether the system has converged by some measure. - By default this is `false`. +- `threshold_info::Union{Float64, Tuple}`, the parameters used for stopping the algorithm. +- `stopping_criterion::Function`, function that evaluates the stopping criterion. # Constructors - MALogger(;collection_rate=1, lambda1=1, lambda2=30) + MALogger(; + max_it=0, + collection_rate=1, + lambda1=1, + lambda2=30, + threshold_info=1e-10, + stopping_criterion=threshold_stop + ) ## Keywords -- `collection_rate::Integer`, the frequency with which to record information about progress - to append to the remaining fields, starting with the initialization - (i.e., iteration `0`). For example, `collection_rate` = 3 means the iteration - difference between each records is 3, i.e. recording information at - iteration `0`, `3`, `6`, `9`, .... By default, this is set to `1`. -- `lambda1::Integer`, the TODO. By default, this is set to `1`. -- `lambda2::Integer`, the TODO. By default, this is set to `30`. +- `max_it::Int64`, the maximum number of iterations for the solver. Default: `0`. +- `collection_rate::Integer`, the frequency for recording progress. Default: `1`. +- `lambda1::Integer`, the width of the moving average during the initial "fast" + convergence phase. Default: `1`. +- `lambda2::Integer`, the width of the moving average during the later "slow" + convergence phase. Default: `30`. +- `threshold_info::Union{Float64, Tuple}`, parameters for the stopping criterion. + Default: `1e-10`. +- `stopping_criterion::Function`, the function used to check for convergence. + Default: `threshold_stop`. ## Returns - A `MALogger` object. 
-## Throws TODO -- `ArgumentError` if `compression_dim` is non-positive, if `nnz` is exceeds - `compression_dim`, or if `nnz` is non-positive. +## Throws +- `ArgumentError` if `max_it` is negative. +- `ArgumentError` if `collection_rate` is less than `1`. +- `ArgumentError` if `max_it` is positive and `collection_rate` is greater + than `max_it`. """ struct MALogger <: Logger max_it::Int64 @@ -80,14 +85,33 @@ MALogger(; """ - MALoggerRecipe <: LoggerRecipe + MALoggerRecipe{F<:Function} <: LoggerRecipe - TODO -The recipe contains the information of `MALogger`, stores the error metric - in a vector. Checks convergence of the solver based on the log information. +A mutable structure that contains the fully initialized state and pre-allocated memory + for the `MALogger`. It stores the error metric history and moving average information, + and checks for convergence based on this log. # Fields - +- `max_it::Int64`, the maximum number of iterations for the solver. +- `error::AbstractFloat`, the current error metric (often a moving average of residuals). + Named `error` for compatibility with generic field checks. +- `iota_error::AbstractFloat`, the second-moment (iota) statistic of the residual that + `update_ma!` computes and stores alongside the moving average held in `error`. +- `iteration::Int64`, the current iteration number of the solver. +- `record_location::Int64`, the index in `hist` and `lambda_hist` to store data. +- `collection_rate::Integer`, the frequency with which progress information is recorded. +- `converged::Bool`, a flag indicating whether the stopping criterion has been met. + Default is `false` upon initialization. +- `ma_info::MAInfo`, [`MAInfo`](@ref) +- `hist::Vector{AbstractFloat}`, a vector storing the history of the primary error metric + (e.g., moving average of squared residuals) at specified `collection_rate` intervals. +- `lambda_origin::Integer`, stores the current lambda value (lambda1 or lambda2) being used by the + moving average calculation at the time of recording.
+- `lambda_hist::Vector{Integer}`, a vector storing the history of the `lambda_origin` values + at specified `collection_rate` intervals. +- `threshold_info::Union{Float64, Tuple}`, parameters used by the `stopping_criterion`. +- `stopping_criterion::F` (where `F<:Function`), the function used to evaluate + the stopping criterion. """ mutable struct MALoggerRecipe{F<:Function} <: LoggerRecipe max_it::Int64 @@ -197,14 +221,15 @@ end """ threshold_stop(log::MALoggerRecipe) -Function that takes an input threshold and stops when the most recent entry in the history -vector is less than the threshold. +Default stopping criterion that checks if the current error metric in the logger + is below a specified threshold. # Arguments - - `log::MALoggerRecipe`, a structure containing the logger information +- `log::MALoggerRecipe`, a structure containing the logger information. # Bool - - Returns a Bool indicating if the stopping threshold is satisfied. +- `Bool`, Returns `true` if `log.error` is less than `log.threshold_info`, + indicating the stopping threshold is satisfied. Otherwise, returns `false`. """ function threshold_stop(log::MALoggerRecipe) return log.error < log.threshold_info