From 7695f7a4cade5801f330c3e5cf9e2ab260400b0f Mon Sep 17 00:00:00 2001
From: davidwalter2
Date: Sun, 11 Jan 2026 19:08:59 -0500
Subject: [PATCH] Implement BB full with gamma via numerically solving for
 nbeta

---
 rabbit/fitter.py | 166 ++++++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 148 insertions(+), 18 deletions(-)

diff --git a/rabbit/fitter.py b/rabbit/fitter.py
index faba62a..2f90680 100644
--- a/rabbit/fitter.py
+++ b/rabbit/fitter.py
@@ -81,15 +81,6 @@ def __init__(
         else:
             self.binByBinStatType = options.binByBinStatType
 
-        if (
-            self.binByBinStat
-            and self.binByBinStatMode == "full"
-            and not self.binByBinStatType.startswith("normal")
-        ):
-            raise Exception(
-                'bin-by-bin stat only for option "--binByBinStatMode full" with "--binByBinStatType normal"'
-            )
-
         if (
             options.covarianceFit
             and self.binByBinStat
@@ -231,9 +222,14 @@ def __init__(
         self.sumw = self.indata.sumw
 
        if self.binByBinStatType in ["gamma", "normal-multiplicative"]:
-            self.kstat = self.sumw**2 / self.varbeta
-            self.betamask = (self.varbeta == 0.0) | (self.kstat == 0.0)
-            self.kstat = tf.where(self.betamask, 1.0, self.kstat)
+            self.betamask = (self.varbeta == 0.0) | (self.sumw == 0.0)
+            self.kstat = tf.where(self.betamask, 1.0, self.sumw**2 / self.varbeta)
+
+            if self.binByBinStatType == "gamma" and self.binByBinStatMode == "full":
+                self.nbeta = tf.Variable(
+                    tf.ones_like(self.nobs), trainable=True, name="nbeta"
+                )
+
         elif self.binByBinStatType == "normal-additive":
             # precompute decomposition of composite matrix to speed up
             # calculation of profiled beta values
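Note: the hunk above reorders the kstat construction so that bins with zero MC
variance or zero sum of weights are masked before the effective event count
kstat = sumw**2 / varbeta (the usual Barlow-Beeston-style effective count) is
formed. A minimal standalone sketch of the pattern with illustrative toy
values, not taken from the patch; since tf.where evaluates both branches, the
sketch additionally guards the denominator so the discarded lanes stay finite:

    import tensorflow as tf

    # hypothetical per-bin template sums and MC variances; zeros mark empty bins
    sumw = tf.constant([10.0, 0.0, 4.0], dtype=tf.float64)
    varbeta = tf.constant([2.5, 0.0, 0.0], dtype=tf.float64)

    # mask bins where the effective count is undefined, then fill them with a
    # harmless placeholder so downstream beta formulas stay finite
    betamask = (varbeta == 0.0) | (sumw == 0.0)
    safe_var = tf.where(betamask, tf.ones_like(varbeta), varbeta)
    kstat = tf.where(betamask, tf.ones_like(sumw), sumw**2 / safe_var)

    print(betamask.numpy())  # [False  True  True]
    print(kstat.numpy())     # [40.  1.  1.]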
@@ -1405,13 +1401,74 @@ def _compute_yields_with_beta(self, profile=True, compute_norm=False, full=True)
         if self.chisqFit:
             if self.binByBinStatType == "gamma":
                 kstat = self.kstat[: self.indata.nbins]
+                betamask = self.betamask[: self.indata.nbins]
 
-                abeta = nexp_profile**2
-                bbeta = kstat * self.varnobs - nexp_profile * self.nobs
-                cbeta = -kstat * self.varnobs * beta0
-                beta = solve_quad_eq(abeta, bbeta, cbeta)
+                if self.binByBinStatMode == "lite":
+                    abeta = nexp_profile**2
+                    bbeta = kstat * self.varnobs - nexp_profile * self.nobs
+                    cbeta = -kstat * self.varnobs * beta0
+                    beta = solve_quad_eq(abeta, bbeta, cbeta)
+                elif self.binByBinStatMode == "full":
+                    norm_profile = norm[: self.indata.nbins]
+
+                    # solve for nbeta numerically with Newton's method (does not work with forward-mode differentiation, i.e. use --globalImpacts with --globalImpactsDisableJVP)
+                    def fnll_nbeta(x):
+                        beta = (
+                            kstat
+                            * beta0
+                            / (
+                                kstat
+                                + ((x - self.nobs) / self.varnobs)[..., None]
+                                * norm_profile
+                            )
+                        )
+                        beta = tf.where(betamask, beta0, beta)
+                        new_nexp = tf.reduce_sum(beta * norm_profile, axis=-1)
+                        ln = 0.5 * (new_nexp - self.nobs) ** 2 / self.varnobs
+                        lbeta = tf.reduce_sum(
+                            kstat * (beta - beta0)
+                            - kstat
+                            * beta0
+                            * (tf.math.log(beta) - tf.math.log(beta0)),
+                            axis=-1,
+                        )
+                        return ln + lbeta
+
+                    def body(i, edm):
+                        with tf.GradientTape() as t2:
+                            with tf.GradientTape() as t1:
+                                nll = fnll_nbeta(nexp_profile * self.nbeta)
+                            grad = t1.gradient(nll, self.nbeta)
+                        hess = t2.gradient(grad, self.nbeta)
+
+                        eps = 1e-8
+                        safe_hess = tf.where(hess > 0, hess, tf.ones_like(hess))
+                        step = grad / (safe_hess + eps)
+
+                        self.nbeta.assign_sub(step)
+
+                        return i + 1, tf.reduce_max(0.5 * grad * step)
+
+                    def cond(i, edm):
+                        return tf.logical_and(i < 50, edm > 1e-10)
+
+                    i0 = tf.constant(0)
+                    edm0 = tf.constant(tf.float64.max, dtype=tf.float64)
+                    tf.while_loop(cond, body, loop_vars=(i0, edm0))
+
+                    beta = (
+                        kstat
+                        * beta0
+                        / (
+                            kstat
+                            + (
+                                (nexp_profile * self.nbeta - self.nobs)
+                                / self.varnobs
+                            )[..., None]
+                            * norm_profile
+                        )
+                    )
 
-                betamask = self.betamask[: self.indata.nbins]
                 beta = tf.where(betamask, beta0, beta)
             elif self.binByBinStatType == "normal-multiplicative":
                 kstat = self.kstat[: self.indata.nbins]
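Note: a self-contained sketch of the Newton-iteration-in-tf.while_loop pattern
used in the hunk above, reduced to a hypothetical elementwise objective. The
iteration cap (50), the EDM threshold (1e-10), and the curvature guard mirror
the patch; everything else here is illustrative only:

    import tensorflow as tf

    x = tf.Variable(tf.constant([2.0, -1.0], dtype=tf.float64), name="x")

    def fnll(v):
        # toy objective with its minimum at v = 1 in each element
        return 0.5 * (v - 1.0) ** 2

    def body(i, edm):
        with tf.GradientTape() as t2:
            with tf.GradientTape() as t1:
                nll = tf.reduce_sum(fnll(x))
            grad = t1.gradient(nll, x)
        # gradient of the gradient gives the Hessian diagonal for an
        # elementwise objective
        hess = t2.gradient(grad, x)

        # guard against non-positive curvature, as in the patch
        safe_hess = tf.where(hess > 0, hess, tf.ones_like(hess))
        step = grad / (safe_hess + 1e-8)
        x.assign_sub(step)

        # estimated distance to minimum, used as the convergence criterion
        return i + 1, tf.reduce_max(0.5 * grad * step)

    def cond(i, edm):
        return tf.logical_and(i < 50, edm > 1e-10)

    i0 = tf.constant(0)
    edm0 = tf.constant(tf.float64.max, dtype=tf.float64)
    tf.while_loop(cond, body, loop_vars=(i0, edm0))
    print(x.numpy())  # -> approximately [1. 1.]

The convergence test is expressed with tensors rather than Python booleans so
that the same loop also traces cleanly inside a tf.function graph.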
@@ -1560,7 +1617,72 @@ def _compute_yields_with_beta(self, profile=True, compute_norm=False, full=True)
                 kstat = self.kstat[: self.indata.nbins]
                 betamask = self.betamask[: self.indata.nbins]
 
-                beta = (self.nobs + kstat * beta0) / (nexp_profile + kstat)
+                if self.binByBinStatMode == "lite":
+                    beta = (self.nobs + kstat * beta0) / (nexp_profile + kstat)
+                elif self.binByBinStatMode == "full":
+                    norm_profile = norm[: self.indata.nbins]
+
+                    # solve for nbeta numerically with Newton's method (does not work with forward-mode differentiation, i.e. use --globalImpacts with --globalImpactsDisableJVP)
+                    def fnll_nbeta(x):
+                        beta = (
+                            kstat
+                            * beta0
+                            / (
+                                norm_profile
+                                + kstat
+                                - (self.nobs / x)[..., None] * norm_profile
+                            )
+                        )
+                        beta = tf.where(betamask, beta0, beta)
+                        new_nexp = tf.reduce_sum(beta * norm_profile, axis=-1)
+                        ln = (
+                            new_nexp
+                            - self.nobs
+                            - self.nobs
+                            * (tf.math.log(new_nexp) - tf.math.log(self.nobs))
+                        )
+                        lbeta = tf.reduce_sum(
+                            kstat * (beta - beta0)
+                            - kstat
+                            * beta0
+                            * (tf.math.log(beta) - tf.math.log(beta0)),
+                            axis=-1,
+                        )
+                        return ln + lbeta
+
+                    def body(i, edm):
+                        with tf.GradientTape() as t2:
+                            with tf.GradientTape() as t1:
+                                nll = fnll_nbeta(nexp_profile * self.nbeta)
+                            grad = t1.gradient(nll, self.nbeta)
+                        hess = t2.gradient(grad, self.nbeta)
+
+                        eps = 1e-8
+                        safe_hess = tf.where(hess > 0, hess, tf.ones_like(hess))
+                        step = grad / (safe_hess + eps)
+                        self.nbeta.assign_sub(step)
+                        return i + 1, tf.reduce_max(0.5 * grad * step)
+
+                    def cond(i, edm):
+                        return tf.logical_and(i < 50, edm > 1e-10)
+
+                    i0 = tf.constant(0)
+                    edm0 = tf.constant(tf.float64.max, dtype=tf.float64)
+                    tf.while_loop(cond, body, loop_vars=(i0, edm0))
+
+                    beta = (
+                        kstat
+                        * beta0
+                        / (
+                            norm_profile
+                            - (self.nobs / (nexp_profile * self.nbeta))[
+                                ..., None
+                            ]
+                            * norm_profile
+                            + kstat
+                        )
+                    )
+                    beta = tf.where(betamask, beta0, beta)
 
             elif self.binByBinStatType == "normal-multiplicative":
                 kstat = self.kstat[: self.indata.nbins]
@@ -2063,6 +2185,14 @@ def loss_val_grad_hess_beta(self, profile=True):
             grad = t1.gradient(val, self.ubeta)
         hess = t2.jacobian(grad, self.ubeta)
 
+        grad = tf.reshape(grad, [-1])
+        hess = tf.reshape(hess, [grad.shape[0], grad.shape[0]])
+
+        betamask = ~tf.reshape(self.betamask, [-1])
+        grad = grad[betamask]
+        hess = tf.boolean_mask(hess, betamask, axis=0)
+        hess = tf.boolean_mask(hess, betamask, axis=1)
+
         return val, grad, hess
 
     def minimize(self):
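Note: the final hunk flattens the beta gradient and Hessian and drops the
masked bins before returning them. A minimal sketch of the same boolean-mask
pattern on hypothetical 3-bin inputs (illustrative values only):

    import tensorflow as tf

    # hypothetical flattened gradient and Hessian for 3 bins, one of them masked
    grad = tf.constant([0.5, 0.0, -1.2], dtype=tf.float64)
    hess = tf.eye(3, dtype=tf.float64)
    betamask = tf.constant([False, True, False])  # True marks an excluded bin

    keep = ~betamask
    grad = tf.boolean_mask(grad, keep)
    hess = tf.boolean_mask(hess, keep, axis=0)  # drop masked rows
    hess = tf.boolean_mask(hess, keep, axis=1)  # drop masked columns

    print(grad.numpy())        # [ 0.5 -1.2]
    print(hess.numpy().shape)  # (2, 2)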