diff --git a/omega/src/estimator/omega.py b/omega/src/estimator/omega.py index 944beb3..c7c6cf8 100644 --- a/omega/src/estimator/omega.py +++ b/omega/src/estimator/omega.py @@ -32,47 +32,79 @@ def get_synthetic_data(l, true_omega, true_overdispersion): return neg_binom.sample() -def dichotomous_search(point_estimate, func, bound, step=5, tol=1e-3): +def get_bisection_root(f, a, b, tol=1e-6, max_iter=1000): + """ + Find the root of the function f(x) using the bisection method. + + Parameters: + f (function): The function for which the root is to be found. + a (float): The left endpoint of the interval. + b (float): The right endpoint of the interval. + tol (float): The tolerance, which determines when to stop searching. + max_iter (int): The maximum number of iterations. + + Returns: + float: The estimated root of the function. + """ + + # Check if the function changes signs over the interval + if f(a) * f(b) >= 0: + raise ValueError("The function must have different signs at the endpoints a and b.") - # step 1: looking for the upper limit - a = point_estimate - b = a + step - diff = bound - func(b) - while diff > 0: - a = b - b += step - diff = bound - func(b) - eps = abs(a - b) + # Initialize the iteration counter + iteration = 0 - # step 2: refine upper limit with dichotomous search - while eps > tol: - middle = (a + b) / 2 - if (bound - func(middle)) * (bound - func(b)) > 0: - b = middle + while (b - a) / 2 > tol and iteration < max_iter: + # Find the midpoint + c = (a + b) / 2 + + # Check if the midpoint is a root or if the interval is sufficiently small + if f(c) == 0: + return c + + # Narrow the search interval + if f(c) * f(a) < 0: + b = c else: - a = middle - eps = abs(a - b) + a = c + + # Increment the iteration counter + iteration += 1 - upper_limit = a + # Return the midpoint as the best estimate for the root + return (a + b) / 2 - # step 3: set a lower limit - a = point_estimate - b = 0.01 - eps = abs(a - b) - - # step 4: refine lower limit with dichotomous search - while eps > tol: - middle = (a + b) / 2 - if (bound - func(middle)) * (bound - func(b)) > 0: - b = middle - else: - a = middle - eps = abs(a - b) + +def get_bounds(f, mu_hat, threshold=0): + + """ + Find the bounds where the input function f crosses the threshold + with respect to the value at mu_hat. + + Parameters: + f (function): function + mu_hat (float): value in f input space + threshold (float): threshold + + Returns: + (float, float): (lower, upper) values at which f crosses the theshold + """ + + g = lambda x: f(x) - threshold - lower_limit = a + b = mu_hat + while g(mu_hat) * g(b) >= 0: + b += 1 + upper = get_bisection_root(g, mu_hat, b) - return lower_limit, upper_limit + a = mu_hat + while g(a) * g(mu_hat) >= 0: + a -= 1 + lower = get_bisection_root(g, a, mu_hat) + + return lower, upper + @tf.function(autograph=False, experimental_compile=True) def sampler(num_results, num_burnin_steps, log_prob_func): @@ -189,7 +221,7 @@ def twice_llr(w): alpha = 0.05 chi2 = tfp.distributions.Chi2(1) llr_boundary = chi2.quantile(1-alpha).numpy() - lower, upper = dichotomous_search(omega_hat, twice_llr, llr_boundary) + lower, upper = get_bounds(twice_llr, omega_hat, threshold=llr_boundary) return omega_hat.numpy(), lower.numpy(), upper.numpy(), pvalue.numpy(), self.res.numpy()