diff --git a/aepsych/__init__.py b/aepsych/__init__.py index ee68af6d5..1d00e0c69 100644 --- a/aepsych/__init__.py +++ b/aepsych/__init__.py @@ -11,6 +11,7 @@ from . import acquisition, config, factory, generators, models, strategy, utils from .config import Config +from .distributions import RTDistWithUniformLapseRate, LogNormalDDMDistribution, ShiftedGammaDDMDistribution, ShiftedInverseGammaDDMDistribution, ShiftedLogNormalDDMDistribution from .likelihoods import BernoulliObjectiveLikelihood from .models import GPClassificationModel from .strategy import SequentialStrategy, Strategy @@ -31,6 +32,11 @@ "BernoulliObjectiveLikelihood", "BernoulliLikelihood", "GaussianLikelihood", + "RTDistWithUniformLapseRate", + "LogNormalDDMDistribution", + "ShiftedGammaDDMDistribution", + "ShiftedInverseGammaDDMDistribution", + "ShiftedLogNormalDDMDistribution", ] try: diff --git a/aepsych/distributions.py b/aepsych/distributions.py new file mode 100644 index 000000000..f97da32ab --- /dev/null +++ b/aepsych/distributions.py @@ -0,0 +1,706 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +import warnings + + +from logging import getLogger +from numbers import Number + +import torch +from torch.distributions import Bernoulli, Exponential, LogNormal, Normal +from gpytorch import constraints +from gpytorch.distributions import Distribution +from torch.distributions import ( + Gamma, + TransformedDistribution, + Uniform, +) +from torch.distributions.transforms import AffineTransform, ExpTransform, PowerTransform +from torch.distributions.utils import broadcast_all + +logger = getLogger() + + +class RTDistWithUniformLapseRate(Distribution): + arg_constraints = { + "lapse_rate": constraints.Interval(1e-5, 0.2), + "max_rt": constraints.Positive(), + } + + def __init__( + self, lapse_rate, max_rt, base_dist, validate_args=False, **kwargs + ): + self.lapse_rate = lapse_rate + self.base_dist = base_dist + self.max_rt = max_rt + self.lapse_dist = Uniform(-self.max_rt, self.max_rt) + self.p_lapse_dist = Bernoulli(self.lapse_rate) + super().__init__(**kwargs, validate_args=validate_args) + + @property + def mean(self): + return ( + self.lapse_rate * self.lapse_dist.mean + + (1 - self.lapse_rate) * self.base_dist.mean + ) + + def log_prob(self, rts): + # rt whose p=0 will have logp=nan, replace with -1000 which will exp() to 0 anyway + # in logsumexp + rt_logp = torch.nan_to_num(self.base_dist.log_prob(rts), nan=-1000) + lapse_logp = self.lapse_dist.log_prob(rts) + + [*batch_shape, rt_shape] = rt_logp.shape + assert rt_shape == lapse_logp.shape[0] + lapse_logp = lapse_logp.expand(*batch_shape, -1) + + mix_logps = torch.stack( + ( + lapse_logp + torch.log(self.lapse_rate), + rt_logp + torch.log(1 - self.lapse_rate), + ), + dim=-1, + ) + return torch.logsumexp(mix_logps, dim=-1) + + def sample(self, sample_shape=torch.Size([])): # noqa B008 + rt_samps = self.base_dist.sample(sample_shape=sample_shape) + unif_samps = self.lapse_dist.rsample(sample_shape=rt_samps.shape) + coinflips = self.p_lapse_dist.sample(sample_shape=rt_samps.shape).int()[..., 0] + return torch.where(coinflips == 1, unif_samps, rt_samps) + + +class ExGaussian(Distribution): + def __init__(self, mean, stddev, lam, validate_args=False, *args, **kwargs): + self.mean = mean + self.stddev = stddev + self.lam = lam + + super().__init__(**kwargs, 
validate_args=validate_args)
+
+    def log_prob(self, x):
+        """
+        Log-density following PyMC's ExGaussian logp: the exact expression when
+        the exponential component is non-negligible, and the Gaussian limit
+        otherwise. ``lam`` is the mean of the exponential component (PyMC's ``nu``).
+        """
+        res = torch.where(
+            self.lam > 0.05 * self.stddev,
+            -torch.log(self.lam)
+            + (self.mean - x) / self.lam
+            + 0.5 * (self.stddev / self.lam) ** 2
+            + torch.log(
+                Normal(loc=self.mean + (self.stddev**2) / self.lam, scale=self.stddev).cdf(x)
+            ),
+            Normal(loc=self.mean, scale=self.stddev).log_prob(x),
+        )
+        return res
+
+    def rsample(self, sample_shape=torch.Size()):  # noqa B008
+        # lam is the exponential component's mean, so the exponential rate is 1 / lam
+        return Normal(loc=self.mean, scale=self.stddev).rsample(
+            sample_shape=sample_shape
+        ) + Exponential(rate=1 / self.lam).rsample(sample_shape=sample_shape)
+
+
+class ShiftedGamma(TransformedDistribution):
+    r"""
+    Creates a shifted gamma distribution parameterized by
+    :attr:`shift`, :attr:`concentration`, and :attr:`rate`, where::
+
+        X - shift ~ Gamma(concentration, rate)
+
+    """
+    arg_constraints = {
+        "concentration": constraints.Positive(),
+        "rate": constraints.Positive(),
+    }
+    support = constraints.Positive()
+    has_rsample = True
+
+    def __init__(
+        self, shift, concentration, rate, validate_args=False, **kwargs
+    ):
+        base_dist = Gamma(concentration, rate, validate_args=validate_args)
+        self.shift = shift
+        super().__init__(
+            base_dist,
+            [AffineTransform(loc=shift, scale=torch.tensor(1.0))],
+            validate_args=validate_args,
+            **kwargs,
+        )
+
+    @property
+    def concentration(self):
+        return self.base_dist.concentration
+
+    @property
+    def rate(self):
+        return self.base_dist.rate
+
+    def log_prob(self, X):
+        return torch.where(X < self.shift, torch.nan, super().log_prob(X))
+
+
+class ShiftedInverseGamma(TransformedDistribution):
+    r"""
+    Creates a shifted inverse-gamma distribution parameterized by
+    :attr:`shift`, :attr:`concentration`, and :attr:`rate`, where::
+
+        1 / (X - shift) ~ Gamma(concentration, rate)
+
+    """
+    arg_constraints = {
+        "concentration": constraints.Positive(),
+        "rate": constraints.Positive(),
+    }
+    support = constraints.Positive()
+    has_rsample = True
+
+    def __init__(
+        self, shift, concentration, rate, validate_args=False, **kwargs
+    ):
+        base_dist = Gamma(concentration, rate, validate_args=validate_args)
+        self.shift = shift
+        super(ShiftedInverseGamma, self).__init__(
+            base_dist,
+            [
+                PowerTransform(exponent=torch.tensor(-1.0)),
+                AffineTransform(loc=shift, scale=torch.tensor(1.0)),
+            ],
+            validate_args=validate_args,
+            **kwargs,
+        )
+
+    @property
+    def concentration(self):
+        return self.base_dist.concentration
+
+    @property
+    def rate(self):
+        return self.base_dist.rate
+
+
+class ShiftedLognormal(TransformedDistribution):
+    r"""
+    Creates a shifted log-normal distribution parameterized by
+    :attr:`shift`, :attr:`loc`, and :attr:`scale`, where::
+
+        log(X - shift) ~ Normal(loc, scale)
+
+    """
+    arg_constraints = {
+        "scale": constraints.Positive(),
+    }
+    support = constraints.Positive()
+    has_rsample = True
+
+    def __init__(self, shift, loc, scale, validate_args=False):
+        base_dist = Normal(loc, scale, validate_args=validate_args)
+        self.shift = shift
+        super(ShiftedLognormal, self).__init__(
+            base_dist,
+            [ExpTransform(), AffineTransform(loc=shift, scale=torch.tensor(1.0))],
+            validate_args=validate_args,
+        )
+
+    def expand(self, batch_shape, _instance=None):
+        new = self._get_checked_instance(ShiftedLognormal, _instance)
+        return super(ShiftedLognormal, self).expand(batch_shape, _instance=new)
+
+    @property
+    def loc(self):
+        return self.base_dist.loc
+
+    @property
+    def scale(self):
+        return self.base_dist.scale
+
+    @property
+    def mean(self):
+        return (self.loc + self.scale.pow(2) / 2).exp() + self.shift
+
+    @property
+    def variance(self):
+        return (self.scale.pow(2).exp() - 1) * (2 * self.loc +
self.scale.pow(2)).exp() + + +def coth(x): + # probably not numerically terrific + return torch.cosh(x) / torch.sinh(x) + + +def csch(x): + # probably not numerically terrific + return 1 / torch.sinh(x) + +class DDMMomentMatchDistribution(Distribution): + """ + Distribution over [choices, rts]. However, since rts are always positive, we use the sign of the RTs + to track choice information (so rt>0 means yes choice, rt<0 means no choice) which means as far as gpytorch knows, + we still have a univariate outcome. + There's basically 2 steps here: + 1. Use the DDM parameters to compute moments of the conditional RT distributions + and choice probabilities using the expressions in https://mae.princeton.edu/sites/default/files/SrivastHSimen-JMatPsy16.pdf. + This is what the base class does. + 2. Moment-match the moments to some nicer distribution, and pretend that's our likelihood. That + is what subclasses do. + """ + + SMALL_DRIFT_CUTOFF = ( + 1e-2 # use this as cutoff to use asymptotic drift -> 0 expressions + ) + arg_constraints = { + "threshold": constraints.Positive(), + "relative_x0": constraints.Interval(0.2, 0.8), + "t0": constraints.Positive(), + } + support = constraints.Positive() + + def __init__( + self, drift, threshold, relative_x0, t0, restrict_skew=False, max_shift=None + ): + + self.drift = drift + self.threshold = threshold + self.max_shift = max_shift + + # relative x0 is scaled 0 to 1, x0 is -thresh to thresh + # boundarySep = thresh * 2 + # relativeInitCond = (x0+z) / boundarySep + # boundarySep * relativeInitCond = x0+z + + self.x0 = threshold * (2 * relative_x0 - 1) + + kz = drift * threshold + kx = drift * self.x0 + + near_zero_drift = drift.abs() < self.SMALL_DRIFT_CUTOFF + + # as abs(drift) -> 0, use different expressions (expr 30 and 32) + rt_mean_yes0 = (4 * threshold**2 - (threshold + self.x0) ** 2) / 3 + rt_mean_no0 = (4 * threshold**2 - (threshold - self.x0) ** 2) / 3 + rt_var_yes0 = (32 * threshold**4 - 2 * (threshold + self.x0) ** 4) / 45 + rt_var_no0 = (32 * threshold**4 - 2 * (threshold - self.x0) ** 4) / 45 + + # for nonzero drift, expr 29 and 31 + self.rt_mean_yes = ( + torch.where( + near_zero_drift, + rt_mean_yes0, + drift ** (-2) * ((2 * kz * coth(2 * kz)) - (kx + kz) * coth(kx + kz)), + ) + + t0 + ) + self.rt_mean_no = ( + torch.where( + near_zero_drift, + rt_mean_no0, + drift ** (-2) * ((2 * kz * coth(2 * kz)) - (-kx + kz) * coth(-kx + kz)), + ) + + t0 + ) + self.rt_var_yes = torch.where( + near_zero_drift, + rt_var_yes0, + drift ** (-4) + * ( + 4 * kz**2 * csch(2 * kz) ** 2 + + 2 * kz * coth(2 * kz) + - (kx + kz) ** 2 * csch(kx + kz) ** 2 + - (kx + kz) * coth(kx + kz) + ), + ) + self.rt_var_no = torch.where( + near_zero_drift, + rt_var_no0, + drift ** (-4) + * ( + 4 * kz**2 * csch(2 * kz) ** 2 + + 2 * kz * coth(2 * kz) + - (-kx + kz) ** 2 * csch(-kx + kz) ** 2 + - (-kx + kz) * coth(-kx + kz) + ), + ) + + # expr 36 + rt_3rd_moment_yes = drift ** (-6) * ( + 12 * kz**2 * csch(2 * kz) ** 2 + + 16 * kz**3 * coth(2 * kz) * csch(2 * kz) ** 2 + + 6 * kz * coth(2 * kz) + - 3 * (kz + kx) ** 2 * csch(kx + kz) ** 2 + - 2 * (kx + kz) ** 3 * coth(kz + kx) * csch(kz + kx) ** 2 + - 3 * (kx + kz) * coth(kx + kz) + ) + rt_3rd_moment_no = drift ** (-6) * ( + 12 * kz**2 * csch(2 * kz) ** 2 + + 16 * kz**3 * coth(2 * kz) * csch(2 * kz) ** 2 + + 6 * kz * coth(2 * kz) + - 3 * (kz - kx) ** 2 * csch(kz - kx) ** 2 + - 2 * (-kx + kz) ** 3 * coth(kz - kx) * csch(kz - kx) ** 2 + - 3 * (-kx + kz) * coth(-kx + kz) + ) + rt_skew_yes = rt_3rd_moment_yes / self.rt_var_yes ** (3 / 2) 
+ rt_skew_no = rt_3rd_moment_no / self.rt_var_no ** (3 / 2) + + # expr 37 + # np.sqrt(45/2) = 4.743416490252569 + SQRT45_2 = 4.743416490252569 + rt_skew_yes0 = SQRT45_2 * ( + (8 * (64 * threshold**6 - (threshold + self.x0) ** 6)) + / (21 * (16 * threshold**4 - (threshold + self.x0) ** 4) ** (3 / 2)) + ) + rt_skew_no0 = SQRT45_2 * ( + (8 * (64 * threshold**6 - (threshold - self.x0) ** 6)) + / (21 * (16 * threshold**4 - (threshold - self.x0) ** 4) ** (3 / 2)) + ) + + self.rt_skew_yes = torch.where(near_zero_drift, rt_skew_yes0, rt_skew_yes) + self.rt_skew_no = torch.where(near_zero_drift, rt_skew_no0, rt_skew_no) + + # expr 6 and 9 + self.response_prob = torch.where( + near_zero_drift, + (threshold - self.x0) / (2 * threshold), + 1 + - (torch.exp(-2 * kx) - torch.exp(-2 * kz)) + / (torch.exp(2 * kz) - torch.exp(-2 * kz)), + ) + + # these will fail if numerical stability is bad, clamp them + self.response_prob = self.response_prob.clamp(min=1e-5, max=1 - 1e-5) + self.rt_var_yes = self.rt_var_yes.clamp(min=1e-5) + self.rt_var_no = self.rt_var_no.clamp(min=1e-5) + self.rt_mean_yes = self.rt_mean_yes.clamp(min=1e-5) + self.rt_mean_no = self.rt_mean_no.clamp(min=1e-5) + if restrict_skew: + self.rt_skew_yes = self.rt_skew_yes.clamp(min=0.01, max=10) + self.rt_skew_no = self.rt_skew_no.clamp(min=0.01, max=10) + + self._make_moment_matched_likelihood() + + def _make_moment_matched_likelihood(self): + raise NotImplementedError + + @property + def mean(self): + return ( + self.response_prob * self.rt_mean_yes + + (1 - self.response_prob) * self.rt_mean_no + ) + + def rsample(self, sample_shape=torch.Size()): # noqa B008 + choices = self.choice_dist.sample(sample_shape=sample_shape) + rt_yes = self.rt_yes_dist.rsample(sample_shape=sample_shape) + rt_no = self.rt_no_dist.rsample(sample_shape=sample_shape) + return torch.where(choices > 0, rt_yes, -rt_no) + + def log_prob(self, signed_rts): + # log p(rt, choice | theta) =log p(rt|choice, theta) + log p(choice | theta) + # p(rt|choice) is our conditional lognormal, p(choice) is bernoulli. 
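+        # sign convention (see class docstring): signed_rts > 0 encodes a "yes"
+        # response at rt = abs(signed_rts), signed_rts < 0 encodes a "no" response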
+ choices = signed_rts > 0 + yes_log_probs = self.rt_yes_dist.log_prob(torch.abs(signed_rts)) + no_log_probs = self.rt_no_dist.log_prob(torch.abs(signed_rts)) + rt_log_probs = torch.where(choices, yes_log_probs, no_log_probs) + + return rt_log_probs + self.choice_dist.log_prob(choices.float()) + + +class LogNormalDDMDistribution(DDMMomentMatchDistribution): + def _make_moment_matched_likelihood(self): + + # moment match to lognormal (from https://en.wikipedia.org/wiki/Log-normal_distribution) + lognormal_mu_yes = torch.log( + self.rt_mean_yes / torch.sqrt(self.rt_var_yes / self.rt_mean_yes**2 + 1) + ) + lognormal_sigma_yes = torch.sqrt( + torch.log(self.rt_var_yes / self.rt_mean_yes**2 + 1) + ) + lognormal_mu_no = torch.log( + self.rt_mean_no / torch.sqrt(self.rt_var_no / self.rt_mean_no**2 + 1) + ) + lognormal_sigma_no = torch.sqrt( + torch.log(self.rt_var_no / self.rt_mean_no**2 + 1) + ) + + assert (lognormal_sigma_yes > 0.0).all(), lognormal_sigma_yes.min() + assert (lognormal_sigma_no > 0.0).all(), lognormal_sigma_no.min() + + self.choice_dist = torch.distributions.Bernoulli(probs=self.response_prob) + self.rt_yes_dist = torch.distributions.LogNormal( + loc=lognormal_mu_yes, scale=lognormal_sigma_yes + ) + self.rt_no_dist = torch.distributions.LogNormal( + loc=lognormal_mu_no, scale=lognormal_sigma_no + ) + + +class ShiftedLogNormalDDMDistribution(DDMMomentMatchDistribution): + def _make_moment_matched_likelihood(self): + + # moment match to shifted lognormal (from https://jod.pm-research.com/content/21/4/103, + # doi:10.3905/jod.2014.21.4.103, lemma 8 + + B_yes = 0.5 * ( + self.rt_skew_yes.square() + + 2 + - torch.sqrt(self.rt_skew_yes**4 + 4 * self.rt_skew_yes.square()) + ) + + shifted_lognormal_shift_yes = self.rt_mean_yes - ( + self.rt_var_yes.sqrt() / self.rt_skew_yes + ) * (1 + B_yes ** (1 / 3) + B_yes ** (-1 / 3)) + + shifted_lognormal_var_yes = torch.log( + 1 + + self.rt_var_yes / ((self.rt_mean_yes - shifted_lognormal_shift_yes) ** 2) + ) + shifted_lognormal_mean_yes = ( + torch.log(self.rt_mean_yes - shifted_lognormal_shift_yes) + - shifted_lognormal_var_yes**2 / 2 + ) + + B_no = 0.5 * ( + self.rt_skew_no.square() + + 2 + - torch.sqrt(self.rt_skew_no**4 + 4 * self.rt_skew_no.square()) + ) + + shifted_lognormal_shift_no = self.rt_mean_no - ( + self.rt_var_no.sqrt() / self.rt_skew_no + ) * (1 + B_no ** (1 / 3) + B_no ** (-1 / 3)) + + shifted_lognormal_var_no = torch.log( + 1 + self.rt_var_no / ((self.rt_mean_no - shifted_lognormal_shift_no) ** 2) + ) + shifted_lognormal_mean_no = ( + torch.log(self.rt_mean_no - shifted_lognormal_shift_no) + - shifted_lognormal_var_no**2 / 2 + ) + + self.choice_dist = torch.distributions.Bernoulli(probs=self.response_prob) + + shifted_lognormal_shift_yes = shifted_lognormal_shift_yes.clamp( + min=0, max=self.max_shift + ) + shifted_lognormal_shift_no = shifted_lognormal_shift_no.clamp( + min=0, max=self.max_shift + ) + + self.rt_yes_dist = ShiftedLognormal( + shift=shifted_lognormal_shift_yes, + loc=shifted_lognormal_mean_yes, + scale=shifted_lognormal_var_yes.sqrt(), + ) + self.rt_no_dist = ShiftedLognormal( + shift=shifted_lognormal_shift_no, + loc=shifted_lognormal_mean_no, + scale=shifted_lognormal_var_no.sqrt(), + ) + + +class ExGaussianDDMDistribution(DDMMomentMatchDistribution): + def _make_moment_matched_likelihood(self): + # moment match to exgaussian (from https://en.wikipedia.org/wiki/Exponentially_modified_Gaussian_distribution#Parameter_estimation) + + # exgaussian is restricted to skew <= 2, so we clamp + + clamped_yes_skew = 
torch.clamp(self.rt_skew_yes, max=torch.tensor(2.0)) + clamped_no_skew = torch.clamp(self.rt_skew_no, max=torch.tensor(2.0)) + tau_yes = self.rt_var_yes.sqrt() * (clamped_yes_skew / 2) ** (1 / 3) + mu_yes = self.rt_mean_yes - tau_yes + var_yes = self.rt_var_yes * (1 - (clamped_yes_skew / 2) ** (2 / 3)) + + tau_no = self.rt_var_no.sqrt() * (clamped_no_skew / 2) ** (1 / 3) + mu_no = self.rt_mean_no - tau_no + var_no = self.rt_var_no * (1 - (clamped_no_skew / 2) ** (2 / 3)) + + self.choice_dist = torch.distributions.Bernoulli(probs=self.response_prob) + self.rt_yes_dist = ExGaussian(m=mu_yes, s=var_yes.sqrt(), l=1 / tau_yes) + self.rt_no_dist = ExGaussian(m=mu_no, s=var_no.sqrt(), l=1 / tau_no) + + +class ShiftedGammaDDMDistribution(DDMMomentMatchDistribution): + def _make_moment_matched_likelihood(self): + # doi:10.1016/j.insmatheco.2020.12.002, + # section 4.2. + + a_yes = 4 / self.rt_skew_yes**2 + scale_yes = (self.rt_var_yes / a_yes).sqrt() + shift_yes = self.rt_mean_yes - a_yes * scale_yes + + a_no = 4 / self.rt_skew_no**2 + scale_no = (self.rt_var_no / a_no).sqrt() + shift_no = self.rt_mean_no - a_no * scale_no + + shift_yes = shift_yes.clamp(min=0, max=self.max_shift) + shift_no = shift_no.clamp(min=0, max=self.max_shift) + + self.choice_dist = torch.distributions.Bernoulli(probs=self.response_prob) + self.rt_yes_dist = ShiftedGamma( + shift=shift_yes, concentration=a_yes, rate=1 / scale_yes + ) + self.rt_no_dist = ShiftedGamma( + shift=shift_no, concentration=a_no, rate=1 / scale_no + ) + + +class ShiftedInverseGammaDDMDistribution(DDMMomentMatchDistribution): + def _make_moment_matched_likelihood(self): + # doi:10.1016/j.insmatheco.2020.12.002, + # section 4.3. + + shift_yes = self.rt_mean_yes - self.rt_var_yes.sqrt() / self.rt_skew_yes * ( + 2 + (4 + self.rt_skew_yes.square()).sqrt() + ) + a_yes = 2 + (self.rt_mean_yes - shift_yes).square() / self.rt_var_yes + b_yes = (self.rt_mean_yes - shift_yes) * (a_yes - 1) + + shift_no = self.rt_mean_no - self.rt_var_no.sqrt() / self.rt_skew_no * ( + 2 + (4 + self.rt_skew_no.square()).sqrt() + ) + a_no = 2 + (self.rt_mean_no - shift_no).square() / self.rt_var_no + b_no = (self.rt_mean_no - shift_no) * (a_no - 1) + + shift_yes = shift_yes.clamp(min=0, max=self.max_shift) + shift_no = shift_no.clamp(min=0, max=self.max_shift) + + self.choice_dist = torch.distributions.Bernoulli(probs=self.response_prob) + self.rt_yes_dist = ShiftedInverseGamma( + shift=shift_yes, concentration=a_yes, rate=b_yes + ) + self.rt_no_dist = ShiftedInverseGamma( + shift=shift_no, concentration=a_no, rate=b_no + ) + + +class DDMDistribution(Distribution): + + arg_constraints = { + "z": constraints.Positive(), + "relative_x0": constraints.Interval(0.0, 1.0), + "t0": constraints.Positive(), + } + def __init__(self, a, z, relative_x0, t0, eps=1e-10, validate_args=True): + + self.a, self.z, self.relative_x0, self.t0 = broadcast_all(a, z, relative_x0, t0) + + if ( + isinstance(a, Number) + and isinstance(z, Number) + and isinstance(relative_x0, Number) + and isinstance(t0, Number) + ): + batch_shape = torch.Size() + else: + batch_shape = self.a.size() + + self.eps = eps + super().__init__(batch_shape=batch_shape, validate_args=validate_args) + + def _standardized_WFPT_large_time(self, t, w, nterms): + # large time expansion from navarro & fuss + + piSqOv2 = 4.93480220054 + # use nterms that's max over the batch. This guarantees + # we'll hit our target precision and enable batched + # computation, but will incur extra cost for the extra + # terms if not needed. 
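+        # Navarro & Fuss (2009) large-time series for the standardized density:
+        #   f(t | w) = pi * sum_{k=1..K} k * exp(-k**2 * (pi**2 / 2) * t) * sin(k * pi * w)
+        # truncated at K = nterms; piSqOv2 above is the precomputed constant pi**2 / 2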
+ k = torch.arange(1, nterms + 1) + k = k.expand(*w.shape, *k.shape) # match batch shape to params + w = w[:, None] # broadcast an extra dim for w we can reduce sum over + + terms = ( + torch.pi + * k + * torch.exp(-(k**2) * t * piSqOv2) + * torch.sin(k * torch.pi * w) + ) + assert terms.shape == (*t.shape[:-1], *self.batch_shape, nterms) + return terms.sum(-1) + + def _standardized_WFPT_small_time(self, t, w, nterms): + # small time expansion navarro & fuss + + fr = math.floor(-(nterms - 1) / 2) + to = math.ceil((nterms - 2) // 2) + k = torch.arange(fr, to + 1) + k = k.expand(*w.shape, *k.shape) + w = w[:, None] # broadcast an extra dim for w we can reduce sum over + + terms = ( + 1 + / torch.sqrt(2 * torch.pi * t**3) + * (w + 2 * k) + * torch.exp(-((w + 2 * k) ** 2) / (2 * t)) + ) + assert terms.shape == (*t.shape[:-1], *self.batch_shape, nterms) + return terms.sum(0) + + def log_prob(self, signed_rt): + """ + Log probability of first passage time of double-threshold wiener process + (aka "pure DDM" of Bogacz et al.). Uses series truncation of Navarro & Fuss 2009 + """ + + shifted_t = signed_rt.abs() - self.t0 # correct for the shift + # normalize time (this also implicitly broadcasts) + normT = shifted_t / (self.relative_x0**2) + + # if t is below NDT, return -inf + t_below_ndt = normT <= 0 + + # by default return hit of lower bound, so if resp is correct flip + # signflip based on choice as needed + driftsign = torch.where(signed_rt > 0, -1, 1) + a = self.a * driftsign + relative_x0 = torch.where(signed_rt > 0, 1 - self.relative_x0, self.relative_x0) + + largeK = torch.ceil( + torch.sqrt( + (-2 * torch.log(torch.pi * normT * self.eps)) / (torch.pi**2 * normT) + ) + ) + smallK = torch.ceil( + 2 + + torch.sqrt( + -2 * normT * torch.log(2 * self.eps * torch.sqrt(2 * torch.pi * normT)) + ) + ) + + # if eps is too big for bound to be valid, adjust + smallK[self.eps > (1 / (2 * torch.sqrt(2 * torch.pi * normT)))] = 2 + bound_invalid = self.eps > (1 / (torch.pi * torch.sqrt(normT))) + largeK[bound_invalid] = torch.ceil( + (1 / (torch.pi * torch.sqrt(normT[bound_invalid]))) + ) + + # pick the smaller of large and small k options, then + # take the max so we can batch properly without needing ragged arrays + nterms = torch.min(largeK, smallK)[torch.logical_not(t_below_ndt)] + if nterms.max() - nterms.min() > 100: + warnings.warn( + "Number of series terms over a batch varies by more than 100, compute costs may be increased", + RuntimeWarning, + stacklevel=2 + ) + + nterms = nterms.max() + + use_large_time = largeK >= smallK + + prob = torch.zeros_like(normT) + prob[t_below_ndt] = -torch.inf + + large_time_approx = self._standardized_WFPT_large_time(normT, relative_x0, nterms) + small_time_approx = self._standardized_WFPT_small_time(normT, relative_x0, nterms) + prob[use_large_time] = large_time_approx[use_large_time.squeeze()] + prob[torch.logical_not(use_large_time)] = small_time_approx[ + torch.logical_not(use_large_time).squeeze() + ] + + boundarySep = 2 * self.z + + # scale from the std case to whatever is our actual + scaler = (1 / relative_x0**2) * torch.exp( + -a * boundarySep * relative_x0 - (a**2 * shifted_t / 2) + ) + + return torch.log(scaler * prob) diff --git a/aepsych/kernels/__init__.py b/aepsych/kernels/__init__.py index 8b2df349c..6c59dc1c7 100644 --- a/aepsych/kernels/__init__.py +++ b/aepsych/kernels/__init__.py @@ -4,3 +4,7 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
+from .pairwisekernel import PairwiseKernel +from .rbf_partial_grad import RBFKernelPartialObsGrad + +__all__ = ["PairwiseKernel", "RBFKernelPartialObsGrad"] diff --git a/aepsych/kernels/pairwisekernel.py b/aepsych/kernels/pairwisekernel.py new file mode 100644 index 000000000..a76b8c75c --- /dev/null +++ b/aepsych/kernels/pairwisekernel.py @@ -0,0 +1,85 @@ +import torch +from gpytorch.kernels import Kernel +from gpytorch.lazy import lazify + + +class PairwiseKernel(Kernel): + """ + Wrapper to convert a kernel K on R^k to a kernel K' on R^{2k}, modeling + functions of the form g(a, b) = f(a) - f(b), where f ~ GP(mu, K). + + Since g is a linear combination of Gaussians, it follows that g ~ GP(0, K') + where K'((a,b), (c,d)) = K(a,c) - K(a, d) - K(b, c) + K(b, d). + + """ + + def __init__(self, latent_kernel, is_partial_obs=False, **kwargs): + super(PairwiseKernel, self).__init__(**kwargs) + + self.latent_kernel = latent_kernel + self.is_partial_obs = is_partial_obs + + def forward(self, x1, x2, diag=False, **params): + r""" + TODO: make last_batch_dim work properly + + d must be 2*k for integer k, k is the dimension of the latent space + Args: + :attr:`x1` (Tensor `n x d` or `b x n x d`): + First set of data + :attr:`x2` (Tensor `m x d` or `b x m x d`): + Second set of data + :attr:`diag` (bool): + Should the Kernel compute the whole kernel, or just the diag? + + Returns: + :class:`Tensor` or :class:`gpytorch.lazy.LazyTensor`. + The exact size depends on the kernel's evaluation mode: + + * `full_covar`: `n x m` or `b x n x m` + * `diag`: `n` or `b x n` + """ + if self.is_partial_obs: + d = x1.shape[-1] - 1 + assert d == x2.shape[-1] - 1, "tensors not the same dimension" + assert d % 2 == 0, "dimension must be even" + + k = int(d / 2) + + # special handling for kernels that (also) do funky + # things with the input dimension + deriv_idx_1 = x1[..., -1][:, None] + deriv_idx_2 = x2[..., -1][:, None] + + a = torch.cat((x1[..., :k], deriv_idx_1), dim=1) + b = torch.cat((x1[..., k:-1], deriv_idx_1), dim=1) + c = torch.cat((x2[..., :k], deriv_idx_2), dim=1) + d = torch.cat((x2[..., k:-1], deriv_idx_2), dim=1) + + else: + d = x1.shape[-1] + + assert d == x2.shape[-1], "tensors not the same dimension" + assert d % 2 == 0, "dimension must be even" + + k = int(d / 2) + + a = x1[..., :k] + b = x1[..., k:] + c = x2[..., :k] + d = x2[..., k:] + + if not diag: + return ( + lazify(self.latent_kernel(a, c, diag=diag, **params)) + + lazify(self.latent_kernel(b, d, diag=diag, **params)) + - lazify(self.latent_kernel(b, c, diag=diag, **params)) + - lazify(self.latent_kernel(a, d, diag=diag, **params)) + ) + else: + return ( + self.latent_kernel(a, c, diag=diag, **params) + + self.latent_kernel(b, d, diag=diag, **params) + - self.latent_kernel(b, c, diag=diag, **params) + - self.latent_kernel(a, d, diag=diag, **params) + ) diff --git a/aepsych/likelihoods/__init__.py b/aepsych/likelihoods/__init__.py index dfe839d4a..824d78d35 100644 --- a/aepsych/likelihoods/__init__.py +++ b/aepsych/likelihoods/__init__.py @@ -9,13 +9,17 @@ from ..config import Config from .bernoulli import BernoulliObjectiveLikelihood +from .ddm import DDMLikelihood, LapseRateRTLikelihood from .ordinal import OrdinalLikelihood from .semi_p import LinearBernoulliLikelihood + __all__ = [ "BernoulliObjectiveLikelihood", "OrdinalLikelihood", "LinearBernoulliLikelihood", + "DDMLikelihood", + "LapseRateRTLikelihood" ] Config.register_module(sys.modules[__name__]) diff --git a/aepsych/likelihoods/ddm.py b/aepsych/likelihoods/ddm.py new file mode 
100644 index 000000000..851142037 --- /dev/null +++ b/aepsych/likelihoods/ddm.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import gpytorch + +import torch +from gpytorch.likelihoods import _OneDimensionalLikelihood + +from aepsych.distributions import RTDistWithUniformLapseRate + +class DDMLikelihood(_OneDimensionalLikelihood): + """ """ + + def __init__(self, distribution, max_shift = None, restrict_skew = False): + super().__init__() + self.distribution = distribution + self.register_parameter( + name="raw_relative_x0", parameter=torch.nn.Parameter(torch.randn(1)) + ) + self.register_constraint("raw_relative_x0", gpytorch.constraints.Interval(0, 1)) + + self.register_parameter( + name="raw_t0", parameter=torch.nn.Parameter(torch.randn(1)) + ) + self.register_constraint("raw_t0", gpytorch.constraints.Interval(0., 1.0)) + + self.register_parameter( + name="raw_threshold", parameter=torch.nn.Parameter(torch.randn(1)) + ) + self.register_constraint("raw_threshold", gpytorch.constraints.Positive()) + + self.max_shift = max_shift + self.restrict_skew = restrict_skew + + def _set_relative_x0(self, value): + value = self.raw_relative_x0_constraint.inverse_transform(value) + self.initialize(raw_relative_x0=value) + + def _set_threshold(self, value): + value = self.raw_threshold_constraint.inverse_transform(value) + self.initialize(raw_threshold=value) + + def _set_t0(self, value): + value = self.raw_t0_constraint.inverse_transform(value) + self.initialize(raw_t0=value) + + @property + def relative_x0(self): + return self.raw_relative_x0_constraint.transform(self.raw_relative_x0) + + @relative_x0.setter + def relative_x0(self, value): + self._set_relative_x0(value) + + @property + def x0(self): + return self.threshold * (2*self.relative_x0 - 1) + + @property + def t0(self): + return self.raw_t0_constraint.transform(self.raw_t0) + + @t0.setter + def t0(self, value): + self._set_t0(value) + + @property + def threshold(self): + return self.raw_threshold_constraint.transform(self.raw_threshold) + + @threshold.setter + def threshold(self, value): + self._set_threshold(value) + + def forward(self, function_samples, *params, **kwargs): + return self.distribution( + drift=function_samples, threshold=self.threshold, relative_x0=self.relative_x0, t0=self.t0, max_shift = self.max_shift, restrict_skew = self.restrict_skew + ) + + @classmethod + def from_config(cls, config): + classname = cls.__name__ + max_shift = config.getfloat(classname, "max_shift", fallback=None) + restrict_skew = config.getboolean(classname, "restrict_skew", fallback=None) + + distribution = config.getobj(classname, "distribution") + + + return cls(distribution=distribution, max_shift=max_shift, restrict_skew=restrict_skew) + + # def log_marginal(self, observations, function_dist, *args, **kwargs): + # """ + # here we need the expectation of logp(r,c|f) w.r.t f + # p(r, c|f) = p(r|c,f)p(c|f), so we can factorize + # the log marginal as E_f log p(r|c,f) + E_f log p(c|f). 
+ # and to the integrals separately + # """ + # choices = observations > 0 + # # rt_log_probs = torch.where(choices, yes_log_probs, no_log_probs) + + # def choice_prob_sampler(function_samples): + # ddmdist = self.forward(function_samples) + # return ddmdist.choice_dist.log_prob(choices.float()).exp() + + # choice_marginal = self.quadrature(choice_prob_sampler, function_dist) + + # def rt_prob_sampler(function_samples): + # ddmdist = self.forward(function_samples) + # yes_probs = ddmdist.rt_yes_dist.log_prob(torch.abs(observations)).exp() + # no_probs = ddmdist.rt_no_dist.log_prob(torch.abs(observations)).exp() + # return torch.where(choices, yes_probs, no_probs) + + # rt_marginal = self.quadrature(rt_prob_sampler, function_dist) + + # return choice_marginal.log() + rt_marginal.log() + + +class LapseRateRTLikelihood(_OneDimensionalLikelihood): + def __init__(self, base_likelihood, max_rt=10.0): + super().__init__() + self.max_rt = max_rt + self.base_likelihood = base_likelihood + self.register_parameter( + name="raw_lapse_rate", parameter=torch.nn.Parameter(torch.randn(1)) + ) + self.register_constraint( + "raw_lapse_rate", gpytorch.constraints.Interval(1e-5, 0.2) + ) # any greater than that and the model is really bad anyway + + @property + def lapse_rate(self): + return self.raw_lapse_rate_constraint.transform(self.raw_lapse_rate) + + def forward(self, function_samples, *args, **kwargs): + base_dist = self.base_likelihood(function_samples, *args, **kwargs) + return RTDistWithUniformLapseRate( + lapse_rate=self.lapse_rate, base_dist=base_dist, max_rt=self.max_rt + ) + + @classmethod + def from_config(cls, config): + classname = cls.__name__ + max_rt = config.getfloat(classname, "max_rt", fallback=10.) + + base_lik_class = config.getobj(classname, "base_likelihood") + + base_lik = base_lik_class.from_config(config) + return cls(base_likelihood = base_lik, max_rt = max_rt) diff --git a/tests/test_ddm_distr.py b/tests/test_ddm_distr.py new file mode 100644 index 000000000..a2e27ea22 --- /dev/null +++ b/tests/test_ddm_distr.py @@ -0,0 +1,181 @@ +#!/usr/bin/env python3 +import unittest + +import numpy as np + +import torch +from functools import partial +from torch.func import grad +from torch import vmap +from numbers import Number + +from aepsych.distributions import DDMMomentMatchDistribution + + +class TestDDMDistr(DDMMomentMatchDistribution): + def _make_moment_matched_likelihood(self): + pass + +global_atol = 1e-3 + +def ddm_mgf(alpha, drift, x0, threshold, response=1): + """ + Moment-generating function of the Wiener First Passage time distribution (DDM) + """ + if response == 0: + drift = -drift.clone() + threshold = -threshold.clone() + return torch.exp(drift * (threshold - x0)) * ( + torch.sinh((threshold + x0) * torch.sqrt(drift**2 - 2 * alpha)) + / torch.sinh(2 * threshold * torch.sqrt(drift**2 - 2 * alpha)) + ) + + +def ddm_cgf(alpha, drift, x0, threshold, response=1): + """ + Cumulant-generating function of the Wiener First Passage time distribution (DDM) + """ + + return torch.log(ddm_mgf(alpha, drift, x0, threshold, response=response)) + + +def ddm_moment_cumulant(n, drift, x0, threshold, fun="cumulant", response=1): + """ + Function to generate arbitrary moments or cumulants of DDM by autodiff, + vectorized over drift (but not other arguments currently, TODO). 
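+    The n-th derivative of the MGF at alpha=0 gives the n-th raw moment, and the
+    n-th derivative of the CGF gives the n-th cumulant.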
+ """ + assert fun in ("moment", "cumulant") + if isinstance(drift, Number): + drift = torch.Tensor([drift]) + if fun == "moment": + deriv_fun = ddm_mgf + elif fun == "cumulant": + deriv_fun = ddm_cgf + else: + raise RuntimeError(f"fun should be moment or cumulant, got {fun}") + for _ in range(n): + deriv_fun = grad(deriv_fun) + moment_fun = partial(deriv_fun, torch.tensor(0.0), response=response) + + moment_fun_vmap = vmap(moment_fun, in_dims=(0, None, None)) + + return moment_fun_vmap(drift, x0, threshold) + + + +class DDMMomemtnMatchTest(unittest.TestCase): + def setUp(self): + np.random.seed(1) + torch.manual_seed(1) + self.f = torch.randn(100) + # things are numerically unstable as drift -> 0. + # in the momentmatch expressions we use limiting expressions + # if drift is too small but we don't have that in moment/cumulant + # so exclude from tests. TODO: can probably improve numerical stability. + self.f = self.f + torch.sign(self.f) * 0.05 + self.relative_x0 = torch.tensor(0.1) + self.t0 = torch.tensor(0.15) + self.rt_dist = TestDDMDistr( + drift=self.f, + threshold=torch.tensor(0.5), + relative_x0=self.relative_x0, + t0=self.t0, + ) + self.x0 = torch.tensor(0.5 * (2 * self.relative_x0- 1)) + + def test_mean(self): + # sanity check mean + expected_yes_mean = ddm_moment_cumulant( + n=1, + drift=self.rt_dist.drift, + x0=self.x0, + threshold=self.rt_dist.threshold, + response=1, + ) + expected_no_mean = ddm_moment_cumulant( + n=1, + drift=self.rt_dist.drift, + x0=self.x0, + threshold=self.rt_dist.threshold, + response=0, + ) + self.assertTrue( + torch.allclose( + self.rt_dist.rt_mean_yes, + expected_yes_mean + self.t0, + atol=global_atol, + ) + ) + self.assertTrue( + torch.allclose( + self.rt_dist.rt_mean_no, + expected_no_mean + self.t0, + atol=global_atol, + ) + ) + + def test_var(self): + # sanity check var + expected_yes_var = ddm_moment_cumulant( + n=2, + drift=self.rt_dist.drift, + x0=self.x0, + threshold=self.rt_dist.threshold, + response=1, + ) + expected_no_var = ddm_moment_cumulant( + n=2, + drift=self.rt_dist.drift, + x0=self.x0, + threshold=self.rt_dist.threshold, + response=0, + ) + self.assertTrue( + torch.allclose(self.rt_dist.rt_var_yes, expected_yes_var, atol=global_atol) + ) + self.assertTrue( + torch.allclose(self.rt_dist.rt_var_no, expected_no_var, atol=global_atol) + ) + + def test_skew(self): + # sanity check skew + expected_yes_var = ddm_moment_cumulant( + n=2, + drift=self.rt_dist.drift, + x0=self.x0, + threshold=self.rt_dist.threshold, + response=1, + ) + expected_no_var = ddm_moment_cumulant( + n=2, + drift=self.rt_dist.drift, + x0=self.x0, + threshold=self.rt_dist.threshold, + response=0, + ) + expected_yes_skew = ddm_moment_cumulant( + n=3, + drift=self.rt_dist.drift, + x0=self.x0, + threshold=self.rt_dist.threshold, + response=1, + ) / (expected_yes_var ** (3 / 2)) + expected_no_skew = ddm_moment_cumulant( + n=3, + drift=self.rt_dist.drift, + x0=self.x0, + threshold=self.rt_dist.threshold, + response=0, + ) / (expected_no_var ** (3 / 2)) + self.assertTrue( + torch.allclose( + self.rt_dist.rt_skew_yes, expected_yes_skew, atol=global_atol + ) + ) + self.assertTrue( + torch.allclose(self.rt_dist.rt_skew_no, expected_no_skew, atol=global_atol) + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_pairwise_kernel.py b/tests/test_pairwise_kernel.py new file mode 100644 index 000000000..fbb408d89 --- /dev/null +++ b/tests/test_pairwise_kernel.py @@ -0,0 +1,151 @@ +#!/usr/bin/env python3 +import unittest + +import numpy as np +import 
numpy.testing as npt +import torch +from aepsych.kernels.rbf_partial_grad import RBFKernelPartialObsGrad +from aepsych.kernels.pairwisekernel import PairwiseKernel +from gpytorch.kernels import RBFKernel + + +class PairwiseKernelTest(unittest.TestCase): + """ + Basic tests that PairwiseKernel is working + """ + + def setUp(self): + self.latent_kernel = RBFKernel() + self.kernel = PairwiseKernel(self.latent_kernel) + + def test_kernelgrad_pairwise(self): + kernel = PairwiseKernel(RBFKernelPartialObsGrad(), is_partial_obs=True) + x1 = torch.rand(torch.Size([2, 4])) + x2 = torch.rand(torch.Size([2, 4])) + + x1 = torch.cat((x1, torch.zeros(2, 1)), dim=1) + x2 = torch.cat((x2, torch.zeros(2, 1)), dim=1) + + deriv_idx_1 = x1[..., -1][:, None] + deriv_idx_2 = x2[..., -1][:, None] + + a = torch.cat((x1[..., :2], deriv_idx_1), dim=1) + b = torch.cat((x1[..., 2:-1], deriv_idx_1), dim=1) + c = torch.cat((x2[..., :2], deriv_idx_2), dim=1) + d = torch.cat((x2[..., 2:-1], deriv_idx_2), dim=1) + + c12 = kernel.forward(x1, x2).evaluate().detach().numpy() + pwc = ( + ( + kernel.latent_kernel.forward(a, c) + - kernel.latent_kernel.forward(a, d) + - kernel.latent_kernel.forward(b, c) + + kernel.latent_kernel.forward(b, d) + ) + .detach() + .numpy() + ) + npt.assert_allclose(c12, pwc, atol=1e-6) + + def test_dim_check(self): + """ + Test that we get expected errors. + """ + x1 = torch.zeros(torch.Size([3])) + x2 = torch.zeros(torch.Size([3])) + x3 = torch.zeros(torch.Size([2])) + x4 = torch.zeros(torch.Size([4])) + + self.assertRaises(AssertionError, self.kernel.forward, x1=x1, x2=x2) + + self.assertRaises(AssertionError, self.kernel.forward, x1=x3, x2=x4) + + def test_covar(self): + """ + Test that we get expected covariances + """ + np.random.seed(1) + torch.manual_seed(1) + + x1 = torch.rand(torch.Size([2, 4])) + x2 = torch.rand(torch.Size([2, 4])) + a = x1[..., :2] + b = x1[..., 2:] + c = x2[..., :2] + d = x2[..., 2:] + c12 = self.kernel.forward(x1, x2).evaluate().detach().numpy() + pwc = ( + ( + self.latent_kernel.forward(a, c) + - self.latent_kernel.forward(a, d) + - self.latent_kernel.forward(b, c) + + self.latent_kernel.forward(b, d) + ) + .detach() + .numpy() + ) + npt.assert_allclose(c12, pwc, atol=1e-6) + + shape = np.array(c12.shape) + npt.assert_equal(shape, np.array([2, 2])) + + x3 = torch.rand(torch.Size([3, 4])) + x4 = torch.rand(torch.Size([6, 4])) + a = x3[..., :2] + b = x3[..., 2:] + c = x4[..., :2] + d = x4[..., 2:] + c34 = self.kernel.forward(x3, x4).evaluate().detach().numpy() + pwc = ( + ( + self.latent_kernel.forward(a, c) + - self.latent_kernel.forward(a, d) + - self.latent_kernel.forward(b, c) + + self.latent_kernel.forward(b, d) + ) + .detach() + .numpy() + ) + npt.assert_allclose(c34, pwc, atol=1e-6) + + shape = np.array(c34.shape) + npt.assert_equal(shape, np.array([3, 6])) + + def test_latent_diag(self): + """ + g(a, a) = 0 for all a, so K((a, a), (a, a)) = 0 + """ + + np.random.seed(1) + torch.manual_seed(1) + a = torch.rand(torch.Size([2, 2])) + + # should get 0 variance on pairs (a,a) + diag = torch.cat((a, a), dim=1) + diagv = self.kernel.forward(diag, diag).evaluate().detach().numpy() + npt.assert_allclose(diagv, 0.0) + + def test_diag(self): + """ + make sure the diagonal is the right shape + """ + np.random.seed(1) + torch.manual_seed(1) + + x1 = torch.rand(torch.Size([2, 2, 4])) + x2 = torch.rand(torch.Size([2, 2, 4])) + + diag = self.kernel(x1, x2, diag=True) + shape = np.array(diag.shape) + npt.assert_equal(shape, np.array([2, 2])) + + x1 = torch.rand(torch.Size([2, 
4])) + x2 = torch.rand(torch.Size([2, 4])) + + diag = self.kernel(x1, x2, diag=True) + shape = np.array(diag.shape) + npt.assert_equal(shape, np.array([2])) + + +if __name__ == "__main__": + unittest.main()
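As a rough usage sketch of how the new pieces compose (assuming the modules land as added above; the drift, reaction-time, and parameter values below are made up for illustration, and in practice the likelihood's raw parameters would be fit as part of a GP model rather than set by hand):

import torch
from aepsych.distributions import LogNormalDDMDistribution
from aepsych.likelihoods import DDMLikelihood, LapseRateRTLikelihood

# DDM likelihood whose conditional RT distributions are moment-matched lognormals
ddm_lik = DDMLikelihood(distribution=LogNormalDDMDistribution)
ddm_lik.threshold = torch.tensor(1.0)    # decision boundary (half the boundary separation)
ddm_lik.relative_x0 = torch.tensor(0.5)  # unbiased starting point
ddm_lik.t0 = torch.tensor(0.2)           # non-decision time

# wrap it in a uniform lapse process on [-max_rt, max_rt]
lik = LapseRateRTLikelihood(base_likelihood=ddm_lik, max_rt=10.0)

drift = torch.tensor([-1.5, -0.5, 0.5, 1.5])      # hypothetical latent drift values
signed_rt = torch.tensor([-0.8, -1.2, 0.9, 0.6])  # sign encodes choice, magnitude is the RT

dist = lik(drift)  # RTDistWithUniformLapseRate over signed RTs
print(dist.log_prob(signed_rt))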