Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 66 additions & 34 deletions omega/src/estimator/omega.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,47 +32,79 @@ def get_synthetic_data(l, true_omega, true_overdispersion):
return neg_binom.sample()


def dichotomous_search(point_estimate, func, bound, step=5, tol=1e-3):
def get_bisection_root(f, a, b, tol=1e-6, max_iter=1000):
"""
Find the root of the function f(x) using the bisection method.

Parameters:
f (function): The function for which the root is to be found.
a (float): The left endpoint of the interval.
b (float): The right endpoint of the interval.
tol (float): The tolerance, which determines when to stop searching.
max_iter (int): The maximum number of iterations.

Returns:
float: The estimated root of the function.
"""

# Check if the function changes signs over the interval
if f(a) * f(b) >= 0:
raise ValueError("The function must have different signs at the endpoints a and b.")

# step 1: looking for the upper limit
a = point_estimate
b = a + step
diff = bound - func(b)
while diff > 0:
a = b
b += step
diff = bound - func(b)
eps = abs(a - b)
# Initialize the iteration counter
iteration = 0

# step 2: refine upper limit with dichotomous search
while eps > tol:
middle = (a + b) / 2
if (bound - func(middle)) * (bound - func(b)) > 0:
b = middle
while (b - a) / 2 > tol and iteration < max_iter:
# Find the midpoint
c = (a + b) / 2

# Check if the midpoint is a root or if the interval is sufficiently small
if f(c) == 0:
return c

# Narrow the search interval
if f(c) * f(a) < 0:
b = c
else:
a = middle
eps = abs(a - b)
a = c

# Increment the iteration counter
iteration += 1

upper_limit = a
# Return the midpoint as the best estimate for the root
return (a + b) / 2

# step 3: set a lower limit
a = point_estimate
b = 0.01
eps = abs(a - b)

# step 4: refine lower limit with dichotomous search
while eps > tol:
middle = (a + b) / 2
if (bound - func(middle)) * (bound - func(b)) > 0:
b = middle
else:
a = middle
eps = abs(a - b)

def get_bounds(f, mu_hat, threshold=0):

"""
Find the bounds where the input function f crosses the threshold
with respect to the value at mu_hat.

Parameters:
f (function): function
mu_hat (float): value in f input space
threshold (float): threshold

Returns:
(float, float): (lower, upper) values at which f crosses the theshold
"""

g = lambda x: f(x) - threshold

lower_limit = a
b = mu_hat
while g(mu_hat) * g(b) >= 0:
b += 1
upper = get_bisection_root(g, mu_hat, b)

return lower_limit, upper_limit
a = mu_hat
while g(a) * g(mu_hat) >= 0:
a -= 1

lower = get_bisection_root(g, a, mu_hat)

return lower, upper


@tf.function(autograph=False, experimental_compile=True)
def sampler(num_results, num_burnin_steps, log_prob_func):
Expand Down Expand Up @@ -189,7 +221,7 @@ def twice_llr(w):
alpha = 0.05
chi2 = tfp.distributions.Chi2(1)
llr_boundary = chi2.quantile(1-alpha).numpy()
lower, upper = dichotomous_search(omega_hat, twice_llr, llr_boundary)
lower, upper = get_bounds(twice_llr, omega_hat, threshold=llr_boundary)

return omega_hat.numpy(), lower.numpy(), upper.numpy(), pvalue.numpy(), self.res.numpy()

Expand Down