-
Notifications
You must be signed in to change notification settings - Fork 0
Use different weight strategies (like in Nova) #42
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
f01f254
8d8f5c4
62e1dcb
9a2ead8
691cf05
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,114 @@ | ||
#' Calculate time-based sample weights for MAP Bayesian fitting
#'
#' Downweights older observations relative to more recent ones during the
#' iterative MAP Bayesian fitting step. Can be passed as the `weights`
#' argument to [run_eval()].
#'
#' The age of a sample is measured relative to the most recent observation,
#' i.e. `max(t) - t`.
#'
#' Available schemes:
#' - `"weight_all"`: all samples weighted equally (weight = 1).
#' - `"weight_last_only"`: only the most recent sample is used (weight = 1),
#'   all others are excluded (weight = 0).
#' - `"weight_last_two_only"`: only the two most recent samples are used.
#' - `"weight_gradient_linear"`: weights increase linearly from a minimum
#'   (`w1`) for samples older than `t1` days to a maximum (`w2`) for samples
#'   more recent than `t2` days. Accepts optional list element `gradient` with
#'   named elements `t1`, `w1`, `t2`, `w2`; elements that are not supplied
#'   keep their default. Default:
#'   `list(t1 = 7, w1 = 0, t2 = 2, w2 = 1)`.
#' - `"weight_gradient_exponential"`: weights decay exponentially with the age
#'   of the sample. Accepts optional list elements `t12_decay` (half-life of
#'   decay in hours, default 48) and `t_start` (delay in hours before decay
#'   starts, default 0).
#'
#' For schemes with additional parameters, pass `weights` as a named list with
#' a `scheme` element plus any scheme-specific elements, e.g.:
#' ```r
#' list(scheme = "weight_gradient_exponential", t12_decay = 72)
#' list(
#'   scheme = "weight_gradient_linear",
#'   gradient = list(t1 = 5, w1 = 0.1, t2 = 1, w2 = 1)
#' )
#' ```
#'
#' Regardless of the scheme, observations at negative times (`t < 0`) always
#' receive a weight of 0.
#'
#' @param weights weighting scheme: a string with the scheme name, or a named
#'   list with a `scheme` element plus optional scheme-specific parameters.
#' @param t numeric vector of observation times (in hours)
#'
#' @returns numeric vector of weights the same length as `t`. Returns `NULL`
#'   if `weights` or `t` is `NULL`, or (with a warning) if the scheme is
#'   missing, is not a single string, or is not one of the available schemes.
#' @export
calculate_fit_weights <- function(weights = NULL, t = NULL) {
  if (is.null(weights) || is.null(t)) {
    return(NULL)
  }

  has_options <- is.list(weights)
  # `[[` instead of `$`: exact name matching only, no partial matching.
  scheme <- if (has_options) weights[["scheme"]] else weights
  if (is.factor(scheme)) {
    scheme <- as.character(scheme)
  }

  valid_schemes <- c(
    "weight_gradient_linear",
    "weight_gradient_exponential",
    "weight_last_only",
    "weight_last_two_only",
    "weight_all"
  )

  # Check the shape of `scheme` before the membership test: a list without a
  # `scheme` element yields NULL, and `if (!NULL %in% x)` would abort with
  # "argument is of length zero" instead of reaching the warning below.
  scheme_ok <- is.character(scheme) && length(scheme) == 1L &&
    scheme %in% valid_schemes
  if (!scheme_ok) {
    warning("Weighting scheme not recognized, ignoring weights.")
    return(NULL)
  }

  # Nothing to weight; also avoids `max()` warning about an empty argument.
  if (length(t) == 0L) {
    return(numeric(0))
  }

  weight_vec <- switch(
    scheme,
    weight_gradient_linear = {
      gradient <- list(t1 = 7, w1 = 0, t2 = 2, w2 = 1)
      user_gradient <- if (has_options) weights[["gradient"]] else NULL
      if (!is.null(user_gradient)) {
        # Only override the elements the caller supplied.
        gradient[names(user_gradient)] <- user_gradient
      }
      if (gradient$t2 > gradient$t1) {
        warning(
          "weight_gradient_linear: t2 (", gradient$t2, ") > t1 (", gradient$t1,
          "). t1 should be the older threshold and t2 the more recent one."
        )
      }
      # Thresholds are given in days, `t` is in hours. Both ramp boundaries
      # are anchored to the latest observation and floored at time 0.
      t_latest <- max(t)
      ramp_start <- max(c(0, t_latest - gradient$t1 * 24))
      ramp_end <- max(c(0, t_latest - gradient$t2 * 24))
      if (ramp_end <= ramp_start) {
        # Degenerate ramp (zero or negative width): fall back to a step.
        ifelse(t >= ramp_end, gradient$w2, gradient$w1)
      } else {
        ramp_slope <- (gradient$w2 - gradient$w1) / (ramp_end - ramp_start)
        ifelse(
          t <= ramp_start, gradient$w1,
          ifelse(
            t >= ramp_end, gradient$w2,
            gradient$w1 + ramp_slope * (t - ramp_start)
          )
        )
      }
    },
    weight_gradient_exponential = {
      t12_decay <- if (has_options) weights[["t12_decay"]] else NULL
      if (is.null(t12_decay)) {
        t12_decay <- 48
      }
      k_decay <- log(2) / t12_decay
      age <- max(t) - t
      decay_delay <- if (has_options) weights[["t_start"]] else NULL
      if (!is.null(decay_delay)) {
        # Samples younger than the delay are not decayed at all.
        age <- pmax(age - decay_delay, 0)
      }
      exp(-k_decay * age)
    },
    weight_last_only = {
      w <- rep(0, length(t))
      w[which.max(t)] <- 1
      w
    },
    weight_last_two_only = {
      w <- rep(0, length(t))
      newest_first <- order(t, decreasing = TRUE)
      w[newest_first[seq_len(min(2L, length(t)))]] <- 1
      w
    },
    weight_all = rep(1, length(t))
  )

  # Observations at negative times are never used for fitting.
  weight_vec[t < 0] <- 0

  weight_vec
}
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Uh oh!
There was an error while loading. Please reload this page.