Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
473cda4
Implemented XP and XP_utils
Jul 14, 2025
b2f093a
Implemented IMRPhenomX_Return_Spin_Evolution_Coefficients_MSA
Jul 14, 2025
03269bf
Started developing XGetAndSetPrecessionVariables
Jul 28, 2025
de4781e
Implementing IMRPhenomX_Initialize_MSA_System
Jul 28, 2025
b2a3f41
Implemented IMRPhenomX_Initialize_MSA_System
Aug 12, 2025
a1a7528
Implemented IMRPhenomX_Get_PN_beta, sigma, tau, IMRPhenomX_psiofv
Aug 12, 2025
9d390fc
Merged XGetAndSetPrecessionVariables development
Aug 18, 2025
b779b70
Implement XGetAndSetPrecessionVariables and XPCheckMaxOpeningAngle fu…
Sep 29, 2025
054357f
Implement X_SetPrecessingRemnantParams
Oct 6, 2025
5222eb9
Remove functions in XP_utils that were copied from Pv2_utils
Oct 7, 2025
1034c13
Remove unnecessary notes from comments
Oct 9, 2025
6297db7
Fix bugs in IMRPhenomXP_utils while testing IMRPhenomXGetAndSetPreces…
Nov 4, 2025
f827d72
Implement XFinalSpin2017, XSTotR, evaluate_QNMfit_fring22, evaluate_Q…
Nov 4, 2025
b12e55d
Include spherical harmonics implementation from https://github.com/na…
Nov 19, 2025
aa7efb1
Implement X_Return_SNorm_MSA, Get_alphaepsilon_atfref, XLALSimIMRPhen…
Nov 19, 2025
0d854c3
Fix bugs in IMRPhenomXP_utils while testing IMRPhenomXGetAndSetPreces…
Nov 19, 2025
1ebca9f
Fix IMRPhenomX_Return_Roots_MSA to correctly handle vector inputs and…
Nov 27, 2025
d654bab
Implement WignerdCoefficients and XLALSimIMRPhenomXL2PNNS
Nov 27, 2025
abe827a
Fix Return_phi_zeta_costhetaL_MSA, Return_MSA_Corrections_MSA, Return…
Nov 28, 2025
95b228f
Define top level functions to generate IMRPhenomXP waveform
Dec 9, 2025
3d514d2
Add elliptic integrals
Dec 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
207 changes: 207 additions & 0 deletions src/ripplegw/gsl_ellint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
"""
Elliptic integral utilities implemented with JAX.
Some functions are based on https://github.com/tagordon/ellip/tree/main
"""

from __future__ import annotations

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


# relative error will be "less in magnitude than r"
R = 1.0e-15

# GSL constants
GSL_DBL_EPSILON = 2.2204460492503131e-16


@jax.jit
@jnp.vectorize
def rf(x, y, z):
r"""JAX implementation of Carlson's :math:`R_\mathrm{F}`

Computed using the algorithm in Carlson, 1994: https://arxiv.org/pdf/math/9409227.pdf
Code taken from https://github.com/tagordon/ellip/tree/main

Args:
x: arraylike, real valued.
y: arraylike, real valued.
z: arraylike, real valued.

Returns:
The value of the integral :math:`R_\mathrm{F}`

Notes:
``rf`` does not support complex-valued inputs.
``rf`` requires `jax.config.update("jax_enable_x64", True)`
"""

xyz = jnp.array([x, y, z])
a0 = jnp.sum(xyz) / 3.0
v = jnp.max(jnp.abs(a0 - xyz))
q = (3 * R) ** (-1 / 6) * v

# cond = lambda s: s["f"] * q > jnp.abs(s["An"])
def cond(s):
return s["f"] * q > jnp.abs(s["An"])

def body(s):

xyz = s["xyz"]
lam = jnp.sqrt(xyz[0] * xyz[1]) + jnp.sqrt(xyz[0] * xyz[2]) + jnp.sqrt(xyz[1] * xyz[2])

s["An"] = 0.25 * (s["An"] + lam)
s["xyz"] = 0.25 * (s["xyz"] + lam)
s["f"] = s["f"] * 0.25

return s

s = {"f": 1, "An": a0, "xyz": xyz}
s = jax.lax.while_loop(cond, body, s)

x = (a0 - x) / s["An"] * s["f"]
y = (a0 - y) / s["An"] * s["f"]
z = -(x + y)
e2 = x * y - z * z
e3 = x * y * z

return (1 - 0.1 * e2 + e3 / 14 + e2 * e2 / 24 - 3 * e2 * e3 / 44) / jnp.sqrt(s["An"])


@jax.jit
@jnp.vectorize
def ellipfinc(phi, k):
r"""JAX implementation of the incomplete elliptic integral of the first kind
Code taken from https://github.com/tagordon/ellip/tree/main

.. math::

\[F\left(\phi,k\right)=\int_{0}^{\phi}\frac{\,\mathrm{d}\theta}{\sqrt{1-k^{2}{%
\sin}^{2}\theta}}]

Args:
phi: arraylike, real valued.
k: arraylike, real valued.

Returns:
The value of the complete elliptic integral of the first kind, :math:`F(\phi, k)`

Notes:
``ellipfinc`` does not support complex-valued inputs.
``ellipfinc`` requires `jax.config.update("jax_enable_x64", True)`
"""

c = 1.0 / jnp.sin(phi) ** 2
return rf(c - 1, c - k**2, c)


@jax.jit
@jnp.vectorize
def gsl_sf_elljac_e(u, m): # double * sn, double * cn, double * dn
"""
JAX implementation of the Jacobi elliptic functions sn(u|m), cn(u|m), dn(u|m)
Based on https://github.com/ampl/gsl/blob/master/specfunc/elljac.c
"""

def little_m_branch(u, m):
_ = m
sn = jnp.sin(u)
cn = jnp.cos(u)
dn = 1.0
return sn, cn, dn

def little_m_min_one_branch(u, m):
_ = m
sn = jnp.tanh(u)
cn = 1.0 / jnp.cosh(u)
dn = cn
return sn, cn, dn

def main_branch(u, m):
_n = 16
mu = jnp.zeros(_n, dtype=jnp.float64).at[0].set(1.0)
nu = jnp.zeros(_n, dtype=jnp.float64).at[0].set(jnp.sqrt(jnp.clip(1.0 - m, 0.0, jnp.inf)))

def cond(state):
n, mu, nu = state
diff = jnp.abs(mu[n] - nu[n])
tol = 4.0 * GSL_DBL_EPSILON * jnp.abs(mu[n] + nu[n])
return (diff > tol) & (n < _n - 1)

def body(state):
n, mu, nu = state
mu = mu.at[n + 1].set(0.5 * (mu[n] + nu[n]))
nu = nu.at[n + 1].set(jnp.sqrt(mu[n] * nu[n]))
return n + 1, mu, nu

n, mu, nu = jax.lax.while_loop(cond, body, (0, mu, nu))

sin_umu = jnp.sin(u * mu[n])
cos_umu = jnp.cos(u * mu[n])

def cos_branch(args):
sin_umu, cos_umu, n, mu, nu = args
c = jnp.zeros_like(mu).at[n].set(mu[n] * (sin_umu / cos_umu))
d = jnp.zeros_like(mu).at[n].set(1.0)

def cond(kcd):
k, _, _ = kcd
return k > 0

def body(kcd):
k, c, d = kcd
k = k - 1
r = (c[k + 1] * c[k + 1]) / mu[k + 1]
c = c.at[k].set(d[k + 1] * c[k + 1])
d = d.at[k].set((r + nu[k]) / (r + mu[k]))
return k, c, d

_, c, d = jax.lax.while_loop(cond, body, (n, c, d))
dn = jnp.sqrt(jnp.clip(1.0 - m, 0.0, jnp.inf)) / d[0]
cn = dn * jnp.sign(cos_umu) / jnp.hypot(1.0, c[0])
sn = cn * c[0] / jnp.sqrt(jnp.clip(1.0 - m, 0.0, jnp.inf))
return sn, cn, dn

def sin_branch(args):
sin_umu, cos_umu, n, mu, nu = args
c = jnp.zeros_like(mu).at[n].set(mu[n] * (cos_umu / sin_umu))
d = jnp.zeros_like(mu).at[n].set(1.0)

def cond(kcd):
k, _, _ = kcd
return k > 0

def body(kcd):
k, c, d = kcd
k = k - 1
r = (c[k + 1] * c[k + 1]) / mu[k + 1]
c = c.at[k].set(d[k + 1] * c[k + 1])
d = d.at[k].set((r + nu[k]) / (r + mu[k]))
return k, c, d

_, c, d = jax.lax.while_loop(cond, body, (n, c, d))
dn = d[0]
sn = jnp.sign(sin_umu) / jnp.hypot(1.0, c[0])
cn = c[0] * sn
return sn, cn, dn

return jax.lax.cond(
jnp.abs(sin_umu) < jnp.abs(cos_umu), cos_branch, sin_branch, operand=(sin_umu, cos_umu, n, mu, nu)
)

sn, cn, dn = jax.lax.cond(
jnp.abs(m) < 2.0 * GSL_DBL_EPSILON,
lambda args: little_m_branch(*args),
lambda args: jax.lax.cond(
jnp.abs(args[1] - 1.0) < 2.0 * GSL_DBL_EPSILON,
lambda inner_args: little_m_min_one_branch(*inner_args),
lambda inner_args: main_branch(*inner_args),
args,
),
(u, m),
)

return sn, cn, dn
159 changes: 159 additions & 0 deletions src/ripplegw/waveforms/IMRPhenomXP.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
import jax
import jax.numpy as jnp
from ripple import Mc_eta_to_ms

from ..constants import gt, MSUN
import numpy as np
from .IMRPhenomXAS import Phase as PhDPhase
from .IMRPhenomXAS import Amp as PhDAmp
from .IMRPhenomXAS import gen_IMRPhenomXAS
from .IMRPhenomX_utils import PhenomX_amp_coeff_table, PhenomX_phase_coeff_table

from ..typing import Array
from .IMRPhenomXP_utils import *
from .IMRPhenomX_utils import *


def PhenomXPCoreTwistUp22(
Mf, ## Frequency in geometric units (on LAL says Hz?)
hAS, ## Underlying aligned-spin IMRPhenomXAS strain
pWF, ## IMRPhenomX Waveform Struct (TODO)
pPrec ## IMRPhenomXP Precession Struct (TODO)
):

omega = jnp.pi * Mf
logomega = jnp.log(omega)
omega_cbrt = (omega) ** (1 / 3)
omega_cbrt2 = omega_cbrt * omega_cbrt

v = omega_cbrt

vangles = jnp.array([0,0,0])

## Euler Angles from Chatziioannou et al, PRD 95, 104004, (2017), arXiv:1703.03967
vangles = IMRPhenomX_Return_phi_zeta_costhetaL_MSA(v,pWF,pPrec)
alpha = vangles[0] - pPrec["alpha_offset"]
epsilon = vangles[1] - pPrec["epsilon_offset"]
cos_beta = vangles[2]


# print("alpha, epsilon: ", alpha, epsilon)
cBetah, sBetah = WignerdCoefficients_cosbeta(cos_beta)

cBetah2 = cBetah * cBetah
cBetah3 = cBetah2 * cBetah
cBetah4 = cBetah3 * cBetah
sBetah2 = sBetah * sBetah
sBetah3 = sBetah2 * sBetah
sBetah4 = sBetah3 * sBetah

# d2 = jnp.array(
# [
# sBetah4,
# 2 * cBetah * sBetah3,
# jnp.sqrt(6) * sBetah2 * cBetah2,
# 2 * cBetah3 * sBetah,
# cBetah4,
# ]
# ) ## LR in PhenomP.py we don't compute this, but in X.c and P.c yes
## same for dm2

# Y2m are the spherical harmonics with s=-2, l=2, m=-2,-1,0,1,2
Y2mA = jnp.array([pPrec['Y2m2'],pPrec['Y2m1'],pPrec['Y20'],pPrec['Y21'],pPrec['Y22']]) # need to pass Y2m in a 5-component list
hp_sum = 0
hc_sum = 0

cexp_i_alpha = jnp.exp(1j * alpha)
cexp_2i_alpha = cexp_i_alpha * cexp_i_alpha
cexp_mi_alpha = 1.0 / cexp_i_alpha
cexp_m2i_alpha = cexp_mi_alpha * cexp_mi_alpha
A2m2emm = (
cexp_2i_alpha * cBetah4 * Y2mA[0]
- cexp_i_alpha * 2 * cBetah3 * sBetah * Y2mA[1]
+ 1 * jnp.sqrt(6) * sBetah2 * cBetah2 * Y2mA[2]
- cexp_mi_alpha * 2 * cBetah * sBetah3 * Y2mA[3]
+ cexp_m2i_alpha * sBetah4 * Y2mA[4]
)
A22emmstar = (
cexp_m2i_alpha * sBetah4 * jnp.conjugate(Y2mA[0])
+ cexp_mi_alpha * 2 * cBetah * sBetah3 * jnp.conjugate(Y2mA[1])
+ 1 * jnp.sqrt(6) * sBetah2 * cBetah2 * jnp.conjugate(Y2mA[2])
+ cexp_i_alpha * 2 * cBetah3 * sBetah * jnp.conjugate(Y2mA[3])
+ cexp_2i_alpha * cBetah4 * jnp.conjugate(Y2mA[4])
)
hp_sum = A2m2emm + A22emmstar * pPrec['PolarizationSymmetry']
hc_sum = 1j * (A2m2emm - A22emmstar * pPrec['PolarizationSymmetry'])
eps_phase_hP = jnp.exp(-2j * epsilon) * hAS / 2.0

hp = eps_phase_hP * hp_sum
hc = eps_phase_hP * hc_sum

return hp, hc


def _gen_IMRPhenomXP_hphc(f: Array,
params: Array,
prec_params,
f_ref: float):
"""
The following function generates an IMRPhenomXP frequency-domain waveform.
It is the translation of IMRPhenomXPGenerateFD from LAL.
Calls gen_IMRPhenomXAS to generate the coprecessing waveform (In LAL the coprecessing waveform is generated within IMRPhenomXPGenerateFD)
Returns:
--------
hp (array): Strain of the plus polarization
hc (array): Strain of the cross polarization
"""

Mc, _ = ms_to_Mc_eta([params['m1'], params['m2']])
iota = params['inclination']
params_list = [Mc, params['eta'], params['chi1_norm'], params['chi2_norm'], params['D'], params['tc'], params['phi0']]
## line 2372 of https://lscsoft.docs.ligo.org/lalsuite/lalsimulation/_l_a_l_sim_i_m_r_phenom_x_8c_source.html
hcoprec = gen_IMRPhenomXAS(f, params_list, f_ref)

## line 2403 of https://lscsoft.docs.ligo.org/lalsuite/lalsimulation/_l_a_l_sim_i_m_r_phenom_x_8c_source.html
hp, hc = PhenomXPCoreTwistUp22(f, hcoprec, params, prec_params)

## rotate waveform by 2 zeta.
# line 2469 of https://lscsoft.docs.ligo.org/lalsuite/lalsimulation/_l_a_l_sim_i_m_r_phenom_x_8c_source.html
zeta = prec_params['zeta_polarization']
cond = jnp.abs(zeta) > 0.0

def no_rotation(args):
hp, hc, _ = args
return hp, hc

def do_rotation(args):
hp, hc, z = args
angle = 2.0 * z

cosPol = jnp.cos(angle)
sinPol = jnp.sin(angle)

new_hp = cosPol * hp + sinPol * hc
new_hc = cosPol * hc - sinPol * hp

return new_hp, new_hc

hp, hc = jax.lax.cond(
cond,
do_rotation,
no_rotation,
operand=(hp, hc, zeta)
)

return hp, hc

def gen_IMRPhenomXP_hphc(f: Array,
params: Array,
f_ref: float):

params_aux = {}

## line 1192 of https://lscsoft.docs.ligo.org/lalsuite/lalsimulation/_l_a_l_sim_i_m_r_phenom_x_8c_source.html
prec_params = IMRPhenomXGetAndSetPrecessionVariables(params, params_aux)

## line 1213
hp, hc = _gen_IMRPhenomXP_hphc(f, params, prec_params, f_ref)

return hp, hc
Loading