166 changes: 148 additions & 18 deletions rabbit/fitter.py
@@ -81,15 +81,6 @@ def __init__(
else:
self.binByBinStatType = options.binByBinStatType

if (
self.binByBinStat
and self.binByBinStatMode == "full"
and not self.binByBinStatType.startswith("normal")
):
raise Exception(
'bin-by-bin stat only for option "--binByBinStatMode full" with "--binByBinStatType normal"'
)

if (
options.covarianceFit
and self.binByBinStat
@@ -231,9 +222,14 @@ def __init__(
self.sumw = self.indata.sumw

if self.binByBinStatType in ["gamma", "normal-multiplicative"]:
self.kstat = self.sumw**2 / self.varbeta
self.betamask = (self.varbeta == 0.0) | (self.kstat == 0.0)
self.kstat = tf.where(self.betamask, 1.0, self.kstat)
self.betamask = (self.varbeta == 0.0) | (self.sumw == 0.0)
self.kstat = tf.where(self.betamask, 1.0, self.sumw**2 / self.varbeta)
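# kstat = sumw**2 / varbeta is the effective number of unweighted events per bin;
# empty bins (sumw == 0 or varbeta == 0) are masked and assigned a dummy kstat of 1.0
# so the division stays finite and their beta parameters stay fixed at beta0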

if self.binByBinStatType == "gamma" and self.binByBinStatMode == "full":
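# in full mode the profiled per-bin yield scale nbeta has no closed form and is
# solved iteratively in _compute_yields_with_beta, so it is kept as a persistent
# tf.Variable that is updated in place by the Newton iteration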
self.nbeta = tf.Variable(
tf.ones_like(self.nobs), trainable=True, name="nbeta"
)

elif self.binByBinStatType == "normal-additive":
# precompute decomposition of composite matrix to speed up
# calculation of profiled beta values
@@ -1405,13 +1401,74 @@ def _compute_yields_with_beta(self, profile=True, compute_norm=False, full=True)
if self.chisqFit:
if self.binByBinStatType == "gamma":
kstat = self.kstat[: self.indata.nbins]
betamask = self.betamask[: self.indata.nbins]

abeta = nexp_profile**2
bbeta = kstat * self.varnobs - nexp_profile * self.nobs
cbeta = -kstat * self.varnobs * beta0
beta = solve_quad_eq(abeta, bbeta, cbeta)
if self.binByBinStatMode == "lite":
abeta = nexp_profile**2
bbeta = kstat * self.varnobs - nexp_profile * self.nobs
cbeta = -kstat * self.varnobs * beta0
beta = solve_quad_eq(abeta, bbeta, cbeta)
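# solve_quad_eq is expected to return the (unique) positive root; a numerically
# stable sketch, given here only as an illustrative assumption rather than the
# actual implementation:
# def solve_quad_eq(a, b, c):
#     q = -0.5 * (b + tf.math.sign(b) * tf.math.sqrt(b**2 - 4.0 * a * c))
#     return tf.where(b > 0.0, c / q, q / a)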
elif self.binByBinStatMode == "full":
norm_profile = norm[: self.indata.nbins]

# solve nbeta numerically with Newton's method (this does not work with forward-mode differentiation, i.e. use --globalImpacts together with --globalImpactsDisableJVP)
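# for a fixed per-bin total yield x, stationarity in beta of the chi-square term
# plus gamma penalty gives the closed form
# beta = kstat * beta0 / (kstat + ((x - nobs) / varnobs) * norm_profile),
# so the objective below depends on x alone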
def fnll_nbeta(x):
beta = (
kstat
* beta0
/ (
kstat
+ ((x - self.nobs) / self.varnobs)[..., None]
* norm_profile
)
)
beta = tf.where(betamask, beta0, beta)
new_nexp = tf.reduce_sum(beta * norm_profile, axis=-1)
ln = 0.5 * (new_nexp - self.nobs) ** 2 / self.varnobs
lbeta = tf.reduce_sum(
kstat * (beta - beta0)
- kstat
* beta0
* (tf.math.log(beta) - tf.math.log(beta0)),
axis=-1,
)
return ln + lbeta
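# the loop below performs one damped Newton step per iteration: nested tapes give
# the per-bin gradient and (diagonal) Hessian of the objective with respect to
# nbeta, non-positive curvature is replaced by 1 to keep the step well defined,
# and edm = max(0.5 * grad * step) estimates the remaining distance to the minimum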

def body(i, edm):
with tf.GradientTape() as t2:
with tf.GradientTape() as t1:
nll = fnll_nbeta(nexp_profile * self.nbeta)
grad = t1.gradient(nll, self.nbeta)
hess = t2.gradient(grad, self.nbeta)

eps = 1e-8
safe_hess = tf.where(hess > 0, hess, tf.ones_like(hess))
step = grad / (safe_hess + eps)

self.nbeta.assign_sub(step)

return i + 1, tf.reduce_max(0.5 * grad * step)

def cond(i, edm):
return tf.logical_and(i < 50, edm > 1e-10)

i0 = tf.constant(0)
edm0 = tf.constant(tf.float64.max)
tf.while_loop(cond, body, loop_vars=(i0, edm0))
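# with nbeta converged, the profiled beta follows in closed form (the same
# expression used inside fnll_nbeta, evaluated at x = nexp_profile * self.nbeta)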

beta = (
kstat
* beta0
/ (
kstat
+ (
(nexp_profile * self.nbeta - self.nobs)
/ self.varnobs
)[..., None]
* norm_profile
)
)

betamask = self.betamask[: self.indata.nbins]
beta = tf.where(betamask, beta0, beta)
elif self.binByBinStatType == "normal-multiplicative":
kstat = self.kstat[: self.indata.nbins]
@@ -1560,7 +1617,72 @@ def _compute_yields_with_beta(self, profile=True, compute_norm=False, full=True)
kstat = self.kstat[: self.indata.nbins]
betamask = self.betamask[: self.indata.nbins]

beta = (self.nobs + kstat * beta0) / (nexp_profile + kstat)
if self.binByBinStatMode == "lite":
beta = (self.nobs + kstat * beta0) / (nexp_profile + kstat)
elif self.binByBinStatMode == "full":
norm_profile = norm[: self.indata.nbins]

# solve nbeta numerically with Newton's method (this does not work with forward-mode differentiation, i.e. use --globalImpacts together with --globalImpactsDisableJVP)
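# for a fixed per-bin total yield x, stationarity in beta of the Poisson NLL plus
# gamma penalty gives the closed form
# beta = kstat * beta0 / (kstat + (1 - nobs / x) * norm_profile),
# which is the expression below with the (1 - nobs / x) factor expanded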
def fnll_nbeta(x):
beta = (
kstat
* beta0
/ (
norm_profile
+ kstat
- (self.nobs / x)[..., None] * norm_profile
)
)
beta = tf.where(betamask, beta0, beta)
new_nexp = tf.reduce_sum(beta * norm_profile, axis=-1)
ln = (
new_nexp
- self.nobs
- self.nobs
* (tf.math.log(new_nexp) - tf.math.log(self.nobs))
)
lbeta = tf.reduce_sum(
kstat * (beta - beta0)
- kstat
* beta0
* (tf.math.log(beta) - tf.math.log(beta0)),
axis=-1,
)
return ln + lbeta
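# the Newton iteration below mirrors the chi-square branch above, now applied to
# the Poisson objective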

def body(i, edm):
with tf.GradientTape() as t2:
with tf.GradientTape() as t1:
nll = fnll_nbeta(nexp_profile * self.nbeta)
grad = t1.gradient(nll, self.nbeta)
hess = t2.gradient(grad, self.nbeta)

eps = 1e-8
safe_hess = tf.where(hess > 0, hess, tf.ones_like(hess))
step = grad / (safe_hess + eps)
self.nbeta.assign_sub(step)
return i + 1, tf.reduce_max(0.5 * grad * step)

def cond(i, edm):
return tf.logical_and(i < 50, edm > 1e-10)

i0 = tf.constant(0)
edm0 = tf.constant(tf.float64.max)
tf.while_loop(cond, body, loop_vars=(i0, edm0))
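# with nbeta converged, beta again follows in closed form (the expression inside
# fnll_nbeta evaluated at x = nexp_profile * self.nbeta)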

beta = (
kstat
* beta0
/ (
norm_profile
- (self.nobs / (nexp_profile * self.nbeta))[
..., None
]
* norm_profile
+ kstat
)
)

beta = tf.where(betamask, beta0, beta)
elif self.binByBinStatType == "normal-multiplicative":
kstat = self.kstat[: self.indata.nbins]
@@ -2063,6 +2185,14 @@ def loss_val_grad_hess_beta(self, profile=True):
grad = t1.gradient(val, self.ubeta)
hess = t2.jacobian(grad, self.ubeta)
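# flatten the gradient and Hessian to a 1D parameter vector and restrict both to
# the floating (unmasked) beta parameters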

grad = tf.reshape(grad, [-1])
hess = tf.reshape(hess, [grad.shape[0], grad.shape[0]])

betamask = ~tf.reshape(self.betamask, [-1])
grad = grad[betamask]
hess = tf.boolean_mask(hess, betamask, axis=0)
hess = tf.boolean_mask(hess, betamask, axis=1)

return val, grad, hess

def minimize(self):