Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions docs/src/api/loggers.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,16 @@ BasicLogger

BasicLoggerRecipe

MALogger

MALoggerRecipe

```

### Moving average logger structures
```@docs
MAInfo

```

## Exported Functions
Expand All @@ -27,4 +37,6 @@ update_logger!
reset_logger!

threshold_stop

update_ma!
```
2 changes: 2 additions & 0 deletions docs/src/refs.bib
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,5 @@ @article{pritchard2024solving
urldate = {2025-01-24},
abstract = {Abstract In large-scale applications including medical imaging, collocation differential equation solvers, and estimation with differential privacy, the underlying linear inverse problem can be reformulated as a streaming problem. In theory, the streaming problem can be effectively solved using memory-efficient, exponentially-converging streaming solvers. In special cases when the underlying linear inverse problem is finite-dimensional, streaming solvers can periodically evaluate the residual norm at a substantial computational cost. When the underlying system is infinite dimensional, streaming solver can only access noisy estimates of the residual. While such noisy estimates are computationally efficient, they are useful only when their accuracy is known. In this work, we rigorously develop a general family of computationally-practical residual estimators and their uncertainty sets for streaming solvers, and we demonstrate the accuracy of our methods on a number of large-scale linear problems. Thus, we further enable the practical use of streaming solvers for important classes of linear inverse problems.}
}


5 changes: 3 additions & 2 deletions src/RLinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ export complete_solver, update_solver!, rsolve, rsolve!

# Export Logger types and functions
export Logger, LoggerRecipe
export BasicLogger, BasicLoggerRecipe
export complete_logger, update_logger!, reset_logger!
export MAInfo
export BasicLogger, BasicLoggerRecipe, MALogger, MALoggerRecipe
export complete_logger, update_logger!, reset_logger!, update_ma!
export threshold_stop

# Export SubSolver types
Expand Down
11 changes: 9 additions & 2 deletions src/Solvers/Loggers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,14 @@ function reset_logger!(logger::LoggerRecipe)
return nothing
end

##############################
# Include Logger Files
################################################
# Include Logger Moving Average helpers Files
################################################
include("Loggers/ma_helpers/ma_info.jl")

###############################
# Include Logger Methods Files
###############################
include("Loggers/basic_logger.jl")
include("Loggers/moving_average_logger.jl")

2 changes: 1 addition & 1 deletion src/Solvers/Loggers/basic_logger.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
BasicLogger <: Logger

This is a mutable struct that contains the `max_it` parameter and stores the error metric
This is a struct that contains the `max_it` parameter and stores the error metric
in a vector. Checks convergence of the solver based on the log information.

# Fields
Expand Down
148 changes: 148 additions & 0 deletions src/Solvers/Loggers/ma_helpers/ma_info.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
# This file contains the components that are needed for storing
# and update moving average information for moving average method:
# Structs: MAInfo
# Functions: update_ma!

#########################################
# Structs
#########################################
"""
MAInfo

A mutable structure that stores information relevant to the moving average (MA) of a
progress estimator, such as a residual. It manages different MA window widths
(`lambda1`, `lambda2`) for different convergence phases and tracks the current
MA window (`res_window`).

See [pritchard2024solving](@cite) for more information on the underlying MA methods.


# Fields
- `lambda1::Integer`, the width of the moving average during the fast convergence phase of the algorithm.
During this fast convergence phase, the majority of variation of the sketched estimator comes from
improvement in the solution and thus wide moving average windows inaccurately represent progress.
- `lambda2::Integer`, the width of the moving average in the slower convergence phase. In the slow convergence
phase, each iterate differs from the previous one by a small amount and thus most of the observed variation
arises from the randomness of the sketched progress estimator, which is best smoothed by a wide moving
average width.
- `lambda::Integer`, the actual width of the moving average being used at the current iteration.
This field is updated internally and is not set directly by the user.
- `flag::Bool`, a boolean flag indicating the current convergence phase. A value of `true`
typically indicates the "slow" convergence phase (using `lambda2`).
- `idx::Integer`, the current index within the `res_window` buffer where the next residual value
will be stored, implementing a circular buffer.
- `res_window::Vector{<:AbstractFloat}`, the buffer storing the recent residual values
used to compute the moving average.
"""
mutable struct MAInfo
lambda1::Integer
lambda2::Integer
lambda::Integer
flag::Bool
idx::Integer
res_window::Vector{<:AbstractFloat}
end

#########################################
# Functions
#########################################
"""
update_ma!(
log::LoggerRecipe,
res::Union{AbstractVector, Real},
lambda_base::Integer,
iter::Integer
)

Updates the moving average statistics stored within the `log.ma_info` field of a
`MALoggerRecipe`. It computes the moving average and second moment (iota) of the
provided residual `res` and updates `log.error`, `log.iota_error`, and
`log.lambda_origin` fields of the `MALoggerRecipe`.

# Arguments
- `log::LoggerRecipe`, the parent of moving average logger recipe structure.
- `res::Union{AbstractVector, Real}`, the sketched residual for the current iteration.
- `lambda_base::Integer`, which lambda, between lambda1 and lambda2, is currently being used.
- `iter::Integer`, the current iteration.

# Returns
- `nothing` (The `log` object, specifically its `ma_info`, `error`, `iota_error`, and
`lambda_origin` fields, is modified in-place).
"""
function update_ma!(
log::LoggerRecipe,
res::Union{AbstractVector,Real},
lambda_base::Integer,
iter::Integer,
)
# Variable to store the sum of the terms for rho
accum = 0
# Variable to store the sum of the terms for iota
accum2 = 0
ma_info = log.ma_info
ma_info.idx = ma_info.idx < ma_info.lambda2 && iter != 0 ? ma_info.idx + 1 : 1
ma_info.res_window[ma_info.idx] = res
#Check if entire storage buffer can be used
if ma_info.lambda == ma_info.lambda2
# Compute the moving average
for i in 1:(ma_info.lambda2)
accum += ma_info.res_window[i]
accum2 += ma_info.res_window[i]^2
end

Check warning on line 91 in src/Solvers/Loggers/ma_helpers/ma_info.jl

View check run for this annotation

Codecov / codecov/patch

src/Solvers/Loggers/ma_helpers/ma_info.jl#L88-L91

Added lines #L88 - L91 were not covered by tests

# Record the moving average error for stopping
log.lambda_origin = ma_info.lambda
log.error = accum / ma_info.lambda
log.iota_error = accum2 / ma_info.lambda

Check warning on line 96 in src/Solvers/Loggers/ma_helpers/ma_info.jl

View check run for this annotation

Codecov / codecov/patch

src/Solvers/Loggers/ma_helpers/ma_info.jl#L94-L96

Added lines #L94 - L96 were not covered by tests

# if mod(iter, log.collection_rate) == 0 || iter == 0
# push!(log.lambda_hist, ma_info.lambda)
# push!(log.hist, accum / ma_info.lambda)
# (:iota_hist in fieldnames(typeof(log))) &&
# push!(log.iota_hist, accum2 / ma_info.lambda)
# end

else
# Consider the case when lambda <= lambda1 or lambda1 < lambda < lambda2
diff = ma_info.idx - ma_info.lambda
# Because the storage of the residual is based dependent on lambda2 and
# we want to sum only the previous lamdda terms we could have a situation
# where we want the first `idx` terms of the buffer and the last `diff`
# terms of the buffer. Doing this requires two loops
# If `diff` is negative there idx is not far enough into the buffer and
# two sums will be needed
startp1 = diff < 0 ? 1 : (diff + 1)

# Assuming that the width of the buffer is lambda2
startp2 = diff < 0 ? ma_info.lambda2 + diff + 1 : 2
endp2 = diff < 0 ? ma_info.lambda2 : 1

# Compute the moving average two loop setup required when lambda < lambda2
for i in startp1:(ma_info.idx)
accum += ma_info.res_window[i]
accum2 += ma_info.res_window[i]^2
end

for i in startp2:endp2
accum += ma_info.res_window[i]
accum2 += ma_info.res_window[i]^2
end

Check warning on line 129 in src/Solvers/Loggers/ma_helpers/ma_info.jl

View check run for this annotation

Codecov / codecov/patch

src/Solvers/Loggers/ma_helpers/ma_info.jl#L127-L129

Added lines #L127 - L129 were not covered by tests

# Record the moving average error for stopping
log.lambda_origin = ma_info.lambda
log.error = accum / ma_info.lambda
log.iota_error = accum2 / ma_info.lambda

#Update the log variable with the information for this update
# if mod(iter, log.collection_rate) == 0 || iter == 0
# push!(log.lambda_hist, ma_info.lambda)
# push!(log.hist, accum / ma_info.lambda)
# (:iota_hist in fieldnames(typeof(log))) &&
# push!(log.iota_hist, accum2 / ma_info.lambda)
# end

ma_info.lambda += ma_info.lambda < lambda_base ? 1 : 0
end

return nothing
end
Loading
Loading