From f9a73506307a60a99630d5b42f9670b737aeaf82 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Manuel=20L=C3=B3pez-Ib=C3=A1=C3=B1ez?= <2620021+MLopez-Ibanez@users.noreply.github.com> Date: Thu, 18 May 2023 16:40:37 +0100 Subject: [PATCH] Add fast-path to rtnorm --- R/utils.R | 64 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/R/utils.R b/R/utils.R index ab471fb..8ba2907 100755 --- a/R/utils.R +++ b/R/utils.R @@ -126,6 +126,69 @@ qtnorm <- function(p, mean=0, sd=1, lower=-Inf, upper=Inf, lower.tail=TRUE, log. rtnorm <- function (n, mean = 0, sd = 1, lower = -Inf, upper = Inf) { if (length(n) > 1) n <- length(n) + # Fast-path for frequent case. + if (length(mean) == 1L && length(sd) == 1L && length(lower) == 1L && length(upper) == 1L) { + lower <- (lower - mean) / sd ## Algorithm works on mean 0, sd 1 scale + upper <- (upper - mean) / sd + nas <- is.na(mean) | is.na(sd) | is.na(lower) | is.na(upper) + if (any(nas)) warning("NAs produced") + alg <- if ((lower > upper) && nas) -1L # return NaN + else if ((lower < 0 && upper == Inf) || + (lower == -Inf && upper > 0) || + (is.finite(lower) && is.finite(upper) && (lower < 0) && (upper > 0) && (upper - lower > sqrt(2*pi)))) + 0L # standard "simulate from normal and reject if outside limits" method. Use if bounds are wide. + else if (lower >= 0 && (upper > lower + 2*sqrt(exp(1)) / + (lower + sqrt(lower^2 + 4)) * exp((lower*2 - lower*sqrt(lower^2 + 4)) / 4))) + 1L # rejection sampling with exponential proposal. Use if lower >> mean + else if (upper <= 0 && (-lower > -upper + 2*sqrt(exp(1)) / + (-upper + sqrt(upper^2 + 4)) * exp((upper*2 - -upper*sqrt(upper^2 + 4)) / 4))) + 2L # rejection sampling with exponential proposal. Use if upper << mean. + else 3L # rejection sampling with uniform proposal. Use if bounds are narrow and central. + + ret <- rep_len(NaN, n) + if (alg == -1L) { + return(ret) + } else if (alg == 0L) { + ind.no <- seq_len(n) + while (length(ind.no) > 0) { + y <- rnorm(length(ind.no)) + done <- which(y >= lower & y <= upper) + ret[ind.no[done]] <- y[done] + ind.no <- setdiff(ind.no, ind.no[done]) + } + } else if (alg == 1L) { + ind.expl <- seq_len(n) + a <- (lower + sqrt(lower^2 + 4)) / 2 + while (length(ind.expl) > 0) { + z <- rexp(length(ind.expl), a) + lower + u <- runif(length(ind.expl)) + done <- which((u <= exp(-(z - a)^2 / 2)) & (z <= upper)) + ret[ind.expl[done]] <- z[done] + ind.expl <- setdiff(ind.expl, ind.expl[done]) + } + } else if (alg == 2L) { + ind.expu <- seq_len(n) + a <- (-upper + sqrt(upper^2 +4)) / 2 + while (length(ind.expu) > 0) { + z <- rexp(length(ind.expu), a) - upper + u <- runif(length(ind.expu)) + done <- which((u <= exp(-(z - a)^2 / 2)) & (z <= -lower)) + ret[ind.expu[done]] <- -z[done] + ind.expu <- setdiff(ind.expu, ind.expu[done]) + } + } else { + ind.u <- seq_len(n) + K <- if (lower > 0) lower^2 else if (upper < 0) upper^2 else 0 + while (length(ind.u) > 0) { + z <- runif(length(ind.u), lower, upper) + rho <- exp((K - z^2) / 2) + u <- runif(length(ind.u)) + done <- which(u <= rho) + ret[ind.u[done]] <- z[done] + ind.u <- setdiff(ind.u, ind.u[done]) + } + } + } else { mean <- rep(mean, length=n) sd <- rep(sd, length=n) lower <- rep(lower, length=n) @@ -194,6 +257,7 @@ rtnorm <- function (n, mean = 0, sd = 1, lower = -Inf, upper = Inf) { ind.u <- setdiff(ind.u, ind.u[done]) } stopifnot(length(ind.u) == 0) + } ret*sd + mean }