From 473cda4dcbc9486c6453ce2af035cf023f6785ec Mon Sep 17 00:00:00 2001 From: "leonardo.ricca" Date: Mon, 14 Jul 2025 12:18:33 +0200 Subject: [PATCH 01/20] Implemented XP and XP_utils --- src/ripplegw/waveforms/IMRPhenomXP.py | 126 +++ src/ripplegw/waveforms/IMRPhenomXP_utils.py | 800 ++++++++++++++++++++ 2 files changed, 926 insertions(+) create mode 100644 src/ripplegw/waveforms/IMRPhenomXP.py create mode 100644 src/ripplegw/waveforms/IMRPhenomXP_utils.py diff --git a/src/ripplegw/waveforms/IMRPhenomXP.py b/src/ripplegw/waveforms/IMRPhenomXP.py new file mode 100644 index 00000000..2d4abdf5 --- /dev/null +++ b/src/ripplegw/waveforms/IMRPhenomXP.py @@ -0,0 +1,126 @@ +import jax +import jax.numpy as jnp +from ripple import Mc_eta_to_ms + +from ..constants import gt, MSUN +import numpy as np +from .IMRPhenomXAS import Phase as PhDPhase +from .IMRPhenomXAS import Amp as PhDAmp +from .IMRPhenomXAS import gen_IMRPhenomXAS +from .IMRPhenomX_utils import PhenomX_amp_coeff_table, PhenomX_phase_coeff_table + +from ..typing import Array +from .IMRPhenomXP_utils import * +from .IMRPhenomX_utils import * + + +def PhenomXPCoreTwistUp22( + Mf, ## Frequency in geometric units (on LAL says Hz?) + hAS, ## Underlying aligned-spin IMRPhenomXAS strain + pWF, ## IMRPhenomX Waveform Struct (TODO) + pPrec ## IMRPhenomXP Precession Struct (TODO) ## in py is a dict +): + +# check if hp,hc not None not present in PhenomP + + # here it is used to be LAL_MTSUN_SI + # f = fHz * gt * M # Frequency in geometric units + # q = (1.0 + jnp.sqrt(1.0 - 4.0 * eta) - 2.0 * eta) / (2.0 * eta) ## not needed + # m1 = 1.0 / (1.0 + q) # Mass of the smaller BH for unit total mass M=1. + # m2 = q / (1.0 + q) # Mass of the larger BH for unit total mass M=1. + #Sperp = chip * ( + # m2 * m2 + #) # Dimensionfull spin component in the orbital plane. S_perp = S_2_perp ## already in pPrec + # chi_eff = m1 * chi1_l + m2 * chi2_l # effective spin for M=1 + + # SL = chi1_l * m1 * m1 + chi2_l * m2 * m2 # Dimensionfull aligned spin. ## already in pPrec + + omega = jnp.pi * Mf + logomega = jnp.log(omega) + omega_cbrt = (omega) ** (1 / 3) + omega_cbrt2 = omega_cbrt * omega_cbrt + + v = omega_cbrt + + vangles = jnp.array([0,0,0]) ## is it okay to use jnp array?? (instead of a struct) it is also defined inside the function... + + ## Euler Angles from Chatziioannou et al, PRD 95, 104004, (2017), arXiv:1703.03967 + vangles = IMRPhenomX_Return_phi_zeta_costhetaL_MSA(v,pWF,pPrec) ## DONE ## has to output a jnp array + alpha = vangles[0] - pPrec["alpha_offset"] + epsilon = vangles[1] - pPrec["epsilon_offset"] + cos_beta = vangles[2] + + + # print("alpha, epsilon: ", alpha, epsilon) + cBetah, sBetah = WignerdCoefficients_cosbeta(cos_beta) ## DONE + + cBetah2 = cBetah * cBetah + cBetah3 = cBetah2 * cBetah + cBetah4 = cBetah3 * cBetah + sBetah2 = sBetah * sBetah + sBetah3 = sBetah2 * sBetah + sBetah4 = sBetah3 * sBetah + + # d2 = jnp.array( + # [ + # sBetah4, + # 2 * cBetah * sBetah3, + # jnp.sqrt(6) * sBetah2 * cBetah2, + # 2 * cBetah3 * sBetah, + # cBetah4, + # ] + # ) ## LR in PhenomP.py we don't compute this, but in X.c and P.c yes + ## same for dm2 + + # Y2m are the spherical harmonics with s=-2, l=2, m=-2,-1,0,1,2 + Y2mA = jnp.array([pPrec['Y2m2'],pPrec['Y2m1'],pPrec['Y20'],pPrec['Y21'],pPrec['Y22']]) # need to pass Y2m in a 5-component list + hp_sum = 0 + hc_sum = 0 + + cexp_i_alpha = jnp.exp(1j * alpha) + cexp_2i_alpha = cexp_i_alpha * cexp_i_alpha + cexp_mi_alpha = 1.0 / cexp_i_alpha + cexp_m2i_alpha = cexp_mi_alpha * cexp_mi_alpha + A2m2emm = ( + cexp_2i_alpha * cBetah4 * Y2mA[0] + - cexp_i_alpha * 2 * cBetah3 * sBetah * Y2mA[1] + + 1 * jnp.sqrt(6) * sBetah2 * cBetah2 * Y2mA[2] + - cexp_mi_alpha * 2 * cBetah * sBetah3 * Y2mA[3] + + cexp_m2i_alpha * sBetah4 * Y2mA[4] + ) + A22emmstar = ( + cexp_m2i_alpha * sBetah4 * jnp.conjugate(Y2mA[0]) + + cexp_mi_alpha * 2 * cBetah * sBetah3 * jnp.conjugate(Y2mA[1]) + + 1 * jnp.sqrt(6) * sBetah2 * cBetah2 * jnp.conjugate(Y2mA[2]) + + cexp_i_alpha * 2 * cBetah3 * sBetah * jnp.conjugate(Y2mA[3]) + + cexp_2i_alpha * cBetah4 * jnp.conjugate(Y2mA[4]) + ) + hp_sum = A2m2emm + A22emmstar * pPrec['PolarizationSymmetry'] + hc_sum = 1j * (A2m2emm - A22emmstar * pPrec['PolarizationSymmetry']) + eps_phase_hP = jnp.exp(-2j * epsilon) * hAS / 2.0 + + hp = eps_phase_hP * hp_sum + hc = eps_phase_hP * hc_sum + + return hp, hc + + +def gen_IMRPhenomXP_hphc(f: Array, + params: Array, ## equivalent to pWF waveform struct ?? + prec_params, ## equivalent to pPrec precession struct + f_ref: float): ## why needs to be input separetely ?? + """ + Returns: + -------- + hp (array): Strain of the plus polarization + hc (array): Strain of the cross polarization + """ + iota = params[7] + h0 = gen_IMRPhenomXAS(f, params, f_ref) + + hp, hc = PhenomXPCoreTwistUp22(f, h0, params, prec_params) + + hp = h0 * (1 / 2 * (1 + jnp.cos(iota) ** 2)) ## to keep or not to keep ?? + hc = -1j * h0 * jnp.cos(iota) + + return hp, hc \ No newline at end of file diff --git a/src/ripplegw/waveforms/IMRPhenomXP_utils.py b/src/ripplegw/waveforms/IMRPhenomXP_utils.py new file mode 100644 index 00000000..bf3ef409 --- /dev/null +++ b/src/ripplegw/waveforms/IMRPhenomXP_utils.py @@ -0,0 +1,800 @@ +import jax +import jax.numpy as jnp +from ripple import Mc_eta_to_ms + +from typing import Tuple +from ..constants import gt, MSUN +import numpy as np +from .IMRPhenomD import Phase as PhDPhase +from .IMRPhenomD import Amp as PhDAmp +from .IMRPhenomD_utils import ( + get_coeffs, + get_transition_frequencies, + EradRational0815, + FinalSpin0815_s, +) +from ..typing import Array +from .IMRPhenomD_QNMdata import QNMData_a, QNMData_fRD, QNMData_fdamp + + +# helper functions for LALtoPhenomP: +def ROTATEZ(angle, x, y, z): + tmp_x = x * jnp.cos(angle) - y * jnp.sin(angle) + tmp_y = x * jnp.sin(angle) + y * jnp.cos(angle) + return tmp_x, tmp_y, z + + +def ROTATEY(angle, x, y, z): + tmp_x = x * jnp.cos(angle) + z * jnp.sin(angle) + tmp_z = -x * jnp.sin(angle) + z * jnp.cos(angle) + return tmp_x, y, tmp_z + + +def FinalSpin0815(eta, chi1, chi2): + Seta = jnp.sqrt(1.0 - 4.0 * eta) + m1 = 0.5 * (1.0 + Seta) + m2 = 0.5 * (1.0 - Seta) + m1s = m1 * m1 + m2s = m2 * m2 + s = m1s * chi1 + m2s * chi2 + return FinalSpin0815_s(eta, s) + + +def convert_spins( + m1: float, + m2: float, + f_ref: float, + phiRef: float, + incl: float, + s1x: float, + s1y: float, + s1z: float, + s2x: float, + s2y: float, + s2z: float, +) -> Tuple[float, float, float, float, float, float, float]: + # m1 = m1_SI / MSUN # Masses in solar masses + # m2 = m2_SI / MSUN + M = m1 + m2 + m1_2 = m1 * m1 + m2_2 = m2 * m2 + eta = m1 * m2 / (M * M) # Symmetric mass-ratio + + # From the components in the source frame, we can easily determine + # chi1_l, chi2_l, chip and phi_aligned, which we need to return. + # We also compute the spherical angles of J, + # which we need to transform to the J frame + + # Aligned spins + chi1_l = s1z # Dimensionless aligned spin on BH 1 + chi2_l = s2z # Dimensionless aligned spin on BH 2 + + # Magnitude of the spin projections in the orbital plane + S1_perp = m1_2 * jnp.sqrt(s1x**2 + s1y**2) + S2_perp = m2_2 * jnp.sqrt(s2x**2 + s2y**2) + + A1 = 2 + (3 * m2) / (2 * m1) + A2 = 2 + (3 * m1) / (2 * m2) + ASp1 = A1 * S1_perp + ASp2 = A2 * S2_perp + num = jnp.maximum(ASp1, ASp2) + den = A2 * m2_2 # warning: this assumes m2 > m1 + chip = num / den + + m_sec = M * gt + piM = jnp.pi * m_sec + v_ref = (piM * f_ref) ** (1 / 3) + L0 = M * M * L2PNR(v_ref, eta) + J0x_sf = m1_2 * s1x + m2_2 * s2x + J0y_sf = m1_2 * s1y + m2_2 * s2y + J0z_sf = L0 + m1_2 * s1z + m2_2 * s2z + J0 = jnp.sqrt(J0x_sf * J0x_sf + J0y_sf * J0y_sf + J0z_sf * J0z_sf) + + thetaJ_sf = jnp.arccos(J0z_sf / J0) + + phiJ_sf = jnp.arctan2(J0y_sf, J0x_sf) + + phi_aligned = -phiJ_sf + + # First we determine kappa + # in the source frame, the components of N are given in Eq (35c) of T1500606-v6 + Nx_sf = jnp.sin(incl) * jnp.cos(jnp.pi / 2.0 - phiRef) + Ny_sf = jnp.sin(incl) * jnp.sin(jnp.pi / 2.0 - phiRef) + Nz_sf = jnp.cos(incl) + + tmp_x = Nx_sf + tmp_y = Ny_sf + tmp_z = Nz_sf + + tmp_x, tmp_y, tmp_z = ROTATEZ(-phiJ_sf, tmp_x, tmp_y, tmp_z) + tmp_x, tmp_y, tmp_z = ROTATEY(-thetaJ_sf, tmp_x, tmp_y, tmp_z) + + kappa = -jnp.arctan2(tmp_y, tmp_x) + + # Then we determine alpha0, by rotating LN + tmp_x, tmp_y, tmp_z = 0, 0, 1 + tmp_x, tmp_y, tmp_z = ROTATEZ(-phiJ_sf, tmp_x, tmp_y, tmp_z) + tmp_x, tmp_y, tmp_z = ROTATEY(-thetaJ_sf, tmp_x, tmp_y, tmp_z) + tmp_x, tmp_y, tmp_z = ROTATEZ(kappa, tmp_x, tmp_y, tmp_z) + + alpha0 = jnp.arctan2(tmp_y, tmp_x) + + # Finally we determine thetaJ, by rotating N + tmp_x, tmp_y, tmp_z = Nx_sf, Ny_sf, Nz_sf + tmp_x, tmp_y, tmp_z = ROTATEZ(-phiJ_sf, tmp_x, tmp_y, tmp_z) + tmp_x, tmp_y, tmp_z = ROTATEY(-thetaJ_sf, tmp_x, tmp_y, tmp_z) + tmp_x, tmp_y, tmp_z = ROTATEZ(kappa, tmp_x, tmp_y, tmp_z) + Nx_Jf, Nz_Jf = tmp_x, tmp_z + thetaJN = jnp.arccos(Nz_Jf) + + # Finally, we need to redefine the polarizations: + # PhenomP's polarizations are defined following Arun et al (arXiv:0810.5336) + # i.e. projecting the metric onto the P,Q,N triad defined with P=NxJ/|NxJ| (see (2.6) in there). + # By contrast, the triad X,Y,N used in LAL + # ("waveframe" in the nomenclature of T1500606-v6) + # is defined in e.g. eq (35) of this document + # (via its components in the source frame; note we use the defautl Omega=Pi/2). + # Both triads differ from each other by a rotation around N by an angle \zeta + # and we need to rotate the polarizations accordingly by 2\zeta + + Xx_sf = -jnp.cos(incl) * jnp.sin(phiRef) + Xy_sf = -jnp.cos(incl) * jnp.cos(phiRef) + Xz_sf = jnp.sin(incl) + tmp_x, tmp_y, tmp_z = Xx_sf, Xy_sf, Xz_sf + tmp_x, tmp_y, tmp_z = ROTATEZ(-phiJ_sf, tmp_x, tmp_y, tmp_z) + tmp_x, tmp_y, tmp_z = ROTATEY(-thetaJ_sf, tmp_x, tmp_y, tmp_z) + tmp_x, tmp_y, tmp_z = ROTATEZ(kappa, tmp_x, tmp_y, tmp_z) + + # Now the tmp_a are the components of X in the J frame + # We need the polar angle of that vector in the P,Q basis of Arun et al + # P = NxJ/|NxJ| and since we put N in the (pos x)z half plane of the J frame + PArunx_Jf = 0.0 + PAruny_Jf = -1.0 + PArunz_Jf = 0.0 + + # Q = NxP + QArunx_Jf = Nz_Jf + QAruny_Jf = 0.0 + QArunz_Jf = -Nx_Jf + + # Calculate the dot products XdotPArun and XdotQArun + XdotPArun = tmp_x * PArunx_Jf + tmp_y * PAruny_Jf + tmp_z * PArunz_Jf + XdotQArun = tmp_x * QArunx_Jf + tmp_y * QAruny_Jf + tmp_z * QArunz_Jf + + zeta_polariz = jnp.arctan2(XdotQArun, XdotPArun) + return chi1_l, chi2_l, chip, thetaJN, alpha0, phi_aligned, zeta_polariz + + +def SpinWeightedY(theta, phi, s, l, m): + "copied from SphericalHarmonics.c in LAL" + if s == -2: + if l == 2: + if m == -2: + fac = ( + jnp.sqrt(5.0 / (64.0 * jnp.pi)) + * (1.0 - jnp.cos(theta)) + * (1.0 - jnp.cos(theta)) + ) + elif m == -1: + fac = ( + jnp.sqrt(5.0 / (16.0 * jnp.pi)) + * jnp.sin(theta) + * (1.0 - jnp.cos(theta)) + ) + elif m == 0: + fac = jnp.sqrt(15.0 / (32.0 * jnp.pi)) * jnp.sin(theta) * jnp.sin(theta) + elif m == 1: + fac = ( + jnp.sqrt(5.0 / (16.0 * jnp.pi)) + * jnp.sin(theta) + * (1.0 + jnp.cos(theta)) + ) + elif m == 2: + fac = ( + jnp.sqrt(5.0 / (64.0 * jnp.pi)) + * (1.0 + jnp.cos(theta)) + * (1.0 + jnp.cos(theta)) + ) + else: + raise ValueError(f"Invalid mode s={s}, l={l}, m={m} - require |m| <= l") + return fac * np.exp(1j * m * phi) + + +def L2PNR(v: float, eta: float) -> float: + eta2 = eta**2 + x = v**2 + x2 = x**2 + return ( + eta + * ( + 1.0 + + (1.5 + eta / 6.0) * x + + (3.375 - (19.0 * eta) / 8.0 - eta2 / 24.0) * x2 + ) + ) / x**0.5 + + +def WignerdCoefficients(v: float, SL: float, eta: float, Sp: float): + # We define the shorthand s := Sp / (L + SL) + #Sp: perpendicular, SL: parallel + L = L2PNR(v, eta) #### what is this function? + s = Sp / (L + SL) + s2 = s**2 + cos_beta = 1.0 / (1.0 + s2) ** 0.5 + cos_beta_half = ((1.0 + cos_beta) / 2.0) ** 0.5 # cos(beta/2) + sin_beta_half = ((1.0 - cos_beta) / 2.0) ** 0.5 # sin(beta/2) + + return cos_beta_half, sin_beta_half + + +def ComputeNNLOanglecoeffs(q, chil, chip): + ##### -> IMRPhenomXP.pdf for coefficients, same as IMRPhenomPV + m2 = q / (1.0 + q) + m1 = 1.0 / (1.0 + q) + dm = m1 - m2 + mtot = 1.0 + eta = m1 * m2 # mtot = 1 + eta2 = eta * eta + eta3 = eta2 * eta + eta4 = eta3 * eta + mtot2 = mtot * mtot + mtot4 = mtot2 * mtot2 + mtot6 = mtot4 * mtot2 + mtot8 = mtot6 * mtot2 + chil2 = chil * chil + chip2 = chip * chip + chip4 = chip2 * chip2 + dm2 = dm * dm + dm3 = dm2 * dm + m2_2 = m2 * m2 + m2_3 = m2_2 * m2 + m2_4 = m2_3 * m2 + m2_5 = m2_4 * m2 + m2_6 = m2_5 * m2 + m2_7 = m2_6 * m2 + m2_8 = m2_7 * m2 + + angcoeffs = {} + angcoeffs["alphacoeff1"] = -0.18229166666666666 - (5 * dm) / (64.0 * m2) + + angcoeffs["alphacoeff2"] = (-15 * dm * m2 * chil) / (128.0 * mtot2 * eta) - ( + 35 * m2_2 * chil + ) / (128.0 * mtot2 * eta) + + angcoeffs["alphacoeff3"] = ( + -1.7952473958333333 + - (4555 * dm) / (7168.0 * m2) + - (15 * chip2 * dm * m2_3) / (128.0 * mtot4 * eta2) + - (35 * chip2 * m2_4) / (128.0 * mtot4 * eta2) + - (515 * eta) / 384.0 + - (15 * dm2 * eta) / (256.0 * m2_2) + - (175 * dm * eta) / (256.0 * m2) + ) + + angcoeffs["alphacoeff4"] = ( + -(35 * jnp.pi) / 48.0 + - (5 * dm * jnp.pi) / (16.0 * m2) + + (5 * dm2 * chil) / (16.0 * mtot2) + + (5 * dm * m2 * chil) / (3.0 * mtot2) + + (2545 * m2_2 * chil) / (1152.0 * mtot2) + - (5 * chip2 * dm * m2_5 * chil) / (128.0 * mtot6 * eta3) + - (35 * chip2 * m2_6 * chil) / (384.0 * mtot6 * eta3) + + (2035 * dm * m2 * chil) / (21504.0 * mtot2 * eta) + + (2995 * m2_2 * chil) / (9216.0 * mtot2 * eta) + ) + + angcoeffs["alphacoeff5"] = ( + 4.318908476114694 + + (27895885 * dm) / (2.1676032e7 * m2) + - (15 * chip4 * dm * m2_7) / (512.0 * mtot8 * eta4) + - (35 * chip4 * m2_8) / (512.0 * mtot8 * eta4) + - (485 * chip2 * dm * m2_3) / (14336.0 * mtot4 * eta2) + + (475 * chip2 * m2_4) / (6144.0 * mtot4 * eta2) + + (15 * chip2 * dm2 * m2_2) / (256.0 * mtot4 * eta) + + (145 * chip2 * dm * m2_3) / (512.0 * mtot4 * eta) + + (575 * chip2 * m2_4) / (1536.0 * mtot4 * eta) + + (39695 * eta) / 86016.0 + + (1615 * dm2 * eta) / (28672.0 * m2_2) + - (265 * dm * eta) / (14336.0 * m2) + + (955 * eta2) / 576.0 + + (15 * dm3 * eta2) / (1024.0 * m2_3) + + (35 * dm2 * eta2) / (256.0 * m2_2) + + (2725 * dm * eta2) / (3072.0 * m2) + - (15 * dm * m2 * jnp.pi * chil) / (16.0 * mtot2 * eta) + - (35 * m2_2 * jnp.pi * chil) / (16.0 * mtot2 * eta) + + (15 * chip2 * dm * m2_7 * chil2) / (128.0 * mtot8 * eta4) + + (35 * chip2 * m2_8 * chil2) / (128.0 * mtot8 * eta4) + + (375 * dm2 * m2_2 * chil2) / (256.0 * mtot4 * eta) + + (1815 * dm * m2_3 * chil2) / (256.0 * mtot4 * eta) + + (1645 * m2_4 * chil2) / (192.0 * mtot4 * eta) + ) + + angcoeffs["epsiloncoeff1"] = -0.18229166666666666 - (5 * dm) / (64.0 * m2) + angcoeffs["epsiloncoeff2"] = (-15 * dm * m2 * chil) / (128.0 * mtot2 * eta) - ( + 35 * m2_2 * chil + ) / (128.0 * mtot2 * eta) + angcoeffs["epsiloncoeff3"] = ( + -1.7952473958333333 + - (4555 * dm) / (7168.0 * m2) + - (515 * eta) / 384.0 + - (15 * dm2 * eta) / (256.0 * m2_2) + - (175 * dm * eta) / (256.0 * m2) + ) + angcoeffs["epsiloncoeff4"] = ( + -(35 * jnp.pi) / 48.0 + - (5 * dm * jnp.pi) / (16.0 * m2) + + (5 * dm2 * chil) / (16.0 * mtot2) + + (5 * dm * m2 * chil) / (3.0 * mtot2) + + (2545 * m2_2 * chil) / (1152.0 * mtot2) + + (2035 * dm * m2 * chil) / (21504.0 * mtot2 * eta) + + (2995 * m2_2 * chil) / (9216.0 * mtot2 * eta) + ) + angcoeffs["epsiloncoeff5"] = ( + 4.318908476114694 + + (27895885 * dm) / (2.1676032e7 * m2) + + (39695 * eta) / 86016.0 + + (1615 * dm2 * eta) / (28672.0 * m2_2) + - (265 * dm * eta) / (14336.0 * m2) + + (955 * eta2) / 576.0 + + (15 * dm3 * eta2) / (1024.0 * m2_3) + + (35 * dm2 * eta2) / (256.0 * m2_2) + + (2725 * dm * eta2) / (3072.0 * m2) + - (15 * dm * m2 * jnp.pi * chil) / (16.0 * mtot2 * eta) + - (35 * m2_2 * jnp.pi * chil) / (16.0 * mtot2 * eta) + + (375 * dm2 * m2_2 * chil2) / (256.0 * mtot4 * eta) + + (1815 * dm * m2_3 * chil2) / (256.0 * mtot4 * eta) + + (1645 * m2_4 * chil2) / (192.0 * mtot4 * eta) + ) + return angcoeffs + + +def FinalSpin_inplane(m1, m2, chi1_l, chi2_l, chip): + M = m1 + m2 + eta = m1 * m2 / (M * M) + # Here I assume m1 > m2, the convention used in phenomD + # (not the convention of internal phenomP) + q_factor = m1 / M + af_parallel = FinalSpin0815(eta, chi1_l, chi2_l) + Sperp = chip * q_factor * q_factor + af = jnp.copysign(1.0, af_parallel) * jnp.sqrt( + Sperp * Sperp + af_parallel * af_parallel + ) + return af + + +def phP_get_fRD_fdamp(m1, m2, chi1_l, chi2_l, chip): + # m1 > m2 should hold here + finspin = FinalSpin_inplane(m1, m2, chi1_l, chi2_l, chip) + m1_s = m1 * gt + m2_s = m2 * gt + M_s = m1_s + m2_s + eta_s = m1_s * m2_s / (M_s**2.0) + Erad = EradRational0815(eta_s, chi1_l, chi2_l) + fRD = jnp.interp(finspin, QNMData_a, QNMData_fRD) / (1.0 - Erad) + fdamp = jnp.interp(finspin, QNMData_a, QNMData_fdamp) / (1.0 - Erad) + + return fRD / M_s, fdamp / M_s + + +def phP_get_transition_frequencies( + theta: Array, + gamma2: float, + gamma3: float, + chip: float, +) -> Tuple[float, float, float, float, float, float]: + # m1 > m2 should hold here + + m1, m2, chi1, chi2 = theta + M = m1 + m2 + f_RD, f_damp = phP_get_fRD_fdamp(m1, m2, chi1, chi2, chip) + + # Phase transition frequencies + f1 = 0.018 / (M * gt) + f2 = 0.5 * f_RD + + # Amplitude transition frequencies + f3 = 0.014 / (M * gt) + f4_gammaneg_gtr_1 = lambda f_RD_, f_damp_, gamma3_, gamma2_: jnp.abs( + f_RD_ + (-f_damp_ * gamma3_) / gamma2_ + ) + f4_gammaneg_less_1 = lambda f_RD_, f_damp_, gamma3_, gamma2_: jnp.abs( + f_RD_ + (f_damp_ * (-1 + jnp.sqrt(1 - (gamma2_) ** 2.0)) * gamma3_) / gamma2_ + ) + f4 = jax.lax.cond( + gamma2 >= 1, + f4_gammaneg_gtr_1, + f4_gammaneg_less_1, + f_RD, + f_damp, + gamma3, + gamma2, + ) + return f1, f2, f3, f4, f_RD, f_damp + +def IMRPhenomX_Return_phi_zeta_costhetaL_MSA( + v, ## velocity + pWF, ## IMRPhenomX waveform struct + pPrec ## IMRPhenomX precession struct ### LR: casting?? + ): ## has to output a jnp array + + vout = jnp.array([0,0,0]) + L_norm = pWF["eta"] / v + + J_norm = IMRPhenomX_JNorm_MSA(L_norm,pPrec) + + ## J_norm = jax.lax.cond(pPrec["useMSA"], ) + + L_norm3PN = IMRPhenomX_L_norm_3PN_of_v(v, v*v, L_norm, pPrec) ## for 223 + + J_norm3PN = IMRPhenomX_JNorm_MSA(L_norm3PN,pPrec) + + vRoots = IMRPhenomX_Return_Roots_MSA(L_norm,J_norm,pPrec) ## return jnp.array + + pPrec["S32"] = vRoots[0] + pPrec["Smi2"] = vRoots[1] + pPrec["Spl2"] = vRoots[2] + + pPrec["Spl2mSmi2"] = pPrec["Spl2"] - pPrec["Smi2"] + pPrec["Spl2pSmi2"] = pPrec["Spl2"] + pPrec["Smi2"] + pPrec["Spl"] = jnp.sqrt(pPrec["Spl2"]) + pPrec["Smi"] = jnp.sqrt(pPrec["Smi2"]) + + SNorm = IMRPhenomX_Return_SNorm_MSA(v,pPrec) + pPrec["S_norm"] = SNorm + pPrec["S_norm_2"] = SNorm * SNorm + + ''' Get phiz_0_MSA and zeta_0_MSA ''' + vMSA = jax.lax.cond((jnp.fabs(pPrec["Smi2"] - pPrec["Spl2"]) > 1.e-5), + IMRPhenomX_Return_MSA_Corrections_MSA, ## return 3D jnp.array + lambda v, L_norm, J_norm, pPrec: jnp.array([0,0,0]), ## ugly but okay?? + v, L_norm, J_norm, pPrec) + + phiz_MSA = vMSA[0] + zeta_MSA = vMSA[1] + + phiz = IMRPhenomX_Return_phiz_MSA(v,J_norm,pPrec) ## (DONE) + zeta = IMRPhenomX_Return_zeta_MSA(v,pPrec) ## (DONE) + cos_theta_L = IMRPhenomX_costhetaLJ(L_norm3PN,J_norm3PN,SNorm) ## (DONE) + + vout[0] = phiz + phiz_MSA + vout[1] = zeta + zeta_MSA + vout[2] = cos_theta_L + + return vout + +def WignerdCoefficients_cosbeta( + cos_beta ## cos(beta) + ): + ''' Note that the results here are indeed always non-negative ''' + cos_beta_half = + jnp.sqrt( jnp.fabs(1.0 + cos_beta) / 2.0 ) + sin_beta_half = + jnp.sqrt( jnp.fabs(1.0 - cos_beta) / 2.0 ) + + return cos_beta_half, sin_beta_half + +def IMRPhenomX_costhetaLJ( + L_norm: float, + J_norm: float, + S_norm: float + ) -> float: + costhetaLJ = 0.5 * (J_norm**2 + L_norm**2 - S_norm**2) / L_norm * J_norm + + # Clamp the value to the interval [-1.0, 1.0] + costhetaLJ = jnp.clip(costhetaLJ, -1.0, 1.0) + + return costhetaLJ + +def IMRPhenomX_Return_zeta_MSA( + v: float, + pPrec + ) -> float: + invv = 1.0 / v + invv2 = invv * invv + invv3 = invv * invv2 + v2 = v * v + logv = jnp.log(v) + + # Compute zeta using precession coefficients + zeta_out = pPrec["eta"] * ( + pPrec["Omegazeta0_coeff"] * invv3 + + pPrec["Omegazeta1_coeff"] * invv2 + + pPrec["Omegazeta2_coeff"] * invv + + pPrec["Omegazeta3_coeff"] * logv + + pPrec["Omegazeta4_coeff"] * v + + pPrec["Omegazeta5_coeff"] * v2 + ) + pPrec["zeta_0"] + + # Replace NaNs with 0 using jnp.nan_to_num + zeta_out = jnp.nan_to_num(zeta_out, nan=0.0) + + return zeta_out + + +def IMRPhenomX_Return_phiz_MSA( + v: float, + JNorm: float, + pPrec + ) -> float: + + invv = 1.0 / v + invv2 = invv * invv + LNewt = pPrec["eta"] / v + + c1 = pPrec["c1"] + c12 = c1 * c1 + + SAv2 = pPrec["SAv2"] + SAv = pPrec["SAv"] + invSAv = pPrec["invSAv"] + invSAv2 = pPrec["invSAv2"] + + # These are log functions defined in Eq. D27 and D28 of Chatziioannou et al, PRD 95, 104004, (2017), arXiv:1703.03967 + log1 = jnp.log(jnp.abs(c1 + JNorm * pPrec["eta"] + pPrec["eta"] * LNewt)) + log2 = jnp.log(jnp.abs(c1 + JNorm * SAv * v + SAv2 * v)) + + # Eq. D22-D27 of Chatziioannou et al, PRD 95, 104004, (2017), arXiv:1703.03967 + phiz_0_coeff = (JNorm * pPrec["inveta4"]) * ( + 0.5 * c12 - (c1 * pPrec["eta2"] * invv) / 6.0 - (SAv2 * pPrec["eta2"]) / 3.0 - (pPrec["eta4"] * invv2) / 3.0 + ) - (0.5 * c1 * pPrec["inveta"]) * ( + c12 * pPrec["inveta4"] - SAv2 * pPrec["inveta2"] + ) * log1 + + phiz_1_coeff = ( + -0.5 * JNorm * pPrec["inveta2"] * (c1 + pPrec["eta"] * LNewt) + + 0.5 * pPrec["inveta3"] * (c12 - pPrec["eta2"] * SAv2) * log1 + ) + + phiz_2_coeff = -JNorm + SAv * log2 - c1 * log1 * pPrec["inveta"] + + phiz_3_coeff = JNorm * v - pPrec["eta"] * log1 + c1 * log2 * invSAv + + phiz_4_coeff = ( + 0.5 * JNorm * invSAv2 * v * (c1 + v * SAv2) + - 0.5 * invSAv2 * invSAv * (c12 - pPrec["eta2"] * SAv2) * log2 + ) + + phiz_5_coeff = ( + -JNorm * v * ( + 0.5 * c12 * invSAv2 * invSAv2 + - c1 * v * invSAv2 / 6.0 + - v * v / 3.0 + - pPrec["eta2"] * invSAv2 / 3.0 + ) + + 0.5 * c1 * invSAv2 * invSAv2 * invSAv * (c12 - pPrec["eta2"] * SAv2) * log2 + ) + + # Eq. 66 of Chatziioannou et al, PRD 95, 104004, (2017), arXiv:1703.03967 + + # \phi_{z,-1} = \sum^5_{n=0} <\Omega_z>^(n) \phi_z^(n) + \phi_{z,-1}^0 + + # Note that the <\Omega_z>^(n) are given by pPrec->Omegazn_coeff's as in Eqs. D15-D20 + phiz_out = ( + phiz_0_coeff * pPrec["Omegaz0_coeff"] + + phiz_1_coeff * pPrec["Omegaz1_coeff"] + + phiz_2_coeff * pPrec["Omegaz2_coeff"] + + phiz_3_coeff * pPrec["Omegaz3_coeff"] + + phiz_4_coeff * pPrec["Omegaz4_coeff"] + + phiz_5_coeff * pPrec["Omegaz5_coeff"] + + pPrec["phiz_0"] + ) + + # Ensure no NaN (replace with 0.0 if NaN) + phiz_out = jnp.nan_to_num(phiz_out, nan=0.0) + + return phiz_out + + +def IMRPhenomX_Return_MSA_Corrections_MSA( + v, + LNorm, + JNorm, + pPrec + ): + + v2 = v * v + + # Sets c0, c2 and c4 in pPrec as per Eq. B6-B8 of Chatziioannou et al, PRD 95, 104004, (2017), arXiv:1703.03967 + c_vec = IMRPhenomX_Return_Constants_c_MSA(v, JNorm, pPrec) + # Sets d0, d2 and d4 in pPrec as per Eq. B9-B11 of Chatziioannou et al, PRD 95, 104004, (2017), arXiv:1703.03967 + d_vec = IMRPhenomX_Return_Constants_d_MSA(LNorm, JNorm, pPrec) + + c0, c2, c4 = c_vec + d0, d2, d4 = d_vec + + two_d0 = 2.0 * d0 + + # Eq. B20 of Chatziioannou et al, PRD 95, 104004, (2017), arXiv:1703.03967 + sd = jnp.sqrt(jnp.abs(d2 * d2 - 4.0 * d0 * d4)) + + # Eq. F20-21 of Chatziioannou et al, PRD 95, 104004, (2017), arXiv:1703.03967 + A_theta_L = 0.5 * ((JNorm / LNorm) + (LNorm / JNorm) - (pPrec["Spl2"] / (JNorm * LNorm))) + B_theta_L = 0.5 * pPrec["Spl2mSmi2"] / (JNorm * LNorm) + + nc_num = 2.0 * (d0 + d2 + d4) + nc_denom = two_d0 + d2 + sd + + nc = nc_num / nc_denom + nd = nc_denom / two_d0 + + sqrt_nc = jnp.sqrt(jnp.abs(nc)) + sqrt_nd = jnp.sqrt(jnp.abs(nd)) + + psi = IMRPhenomX_Return_Psi_MSA(v, v2, pPrec) + pPrec["psi0"] ## (TODO) + psi_dot = IMRPhenomX_Return_Psi_dot_MSA(v, pPrec) ## (TODO) + + tan_psi = jnp.tan(psi) + atan_psi = jnp.arctan(tan_psi) + + C1 = -0.5 * (c0 / d0 - 2.0 * (c0 + c2 + c4) / nc_num) + C2num = (c0 * (-2.0 * d0 * d4 + d2 * d2 + d2 * d4) - + c2 * d0 * (d2 + 2.0 * d4) + + c4 * d0 * (two_d0 + d2)) + C2den = 2.0 * d0 * sd * (d0 + d2 + d4) + C2 = C2num / C2den + + Cphi = C1 + C2 + Dphi = C1 - C2 + + def compute_Cphi_term(): + + return jnp.abs(( + (c4 * d0 * ((2 * d0 + d2) + sd) - + c2 * d0 * ((d2 + 2.0 * d4) - sd) - + c0 * ((2 * d0 * d4) - (d2 + d4) * (d2 - + sd))) / C2den) * (sqrt_nc / (nc - 1.0)) * (atan_psi - jnp.arctan(sqrt_nc * tan_psi))) / psi_dot + + def compute_Dphi_term(): + return jnp.abs(( + (-c4 * d0 * ((2 * d0 + d2) - sd) + + c2 * d0 * ((d2 + 2.0 * d4) + sd) - + c0 * (-(2 * d0 * d4) + (d2 + d4) * (d2 + sd))) / C2den + ) * (sqrt_nd / (nd - 1.0)) * (atan_psi - jnp.arctan(sqrt_nd * tan_psi))) / psi_dot + + phiz_0_MSA_Cphi_term = jnp.where(nc == 1.0, 0.0, compute_Cphi_term()) + phiz_0_MSA_Dphi_term = jnp.where(nd == 1.0, 0.0, compute_Dphi_term()) + + vMSA_x = phiz_0_MSA_Cphi_term + phiz_0_MSA_Dphi_term + + ##### restart from here + vMSA_y = A_theta_L * vMSA_x + 2.0 * B_theta_L * d0 * ( + phiz_0_MSA_Cphi_term / (sd - d2) - phiz_0_MSA_Dphi_term / (sd + d2)) + + vMSA_x = jnp.where(jnp.isnan(vMSA_x), 0.0, vMSA_x) + vMSA_y = jnp.where(jnp.isnan(vMSA_y), 0.0, vMSA_y) + + return jnp.array([vMSA_x, vMSA_y, 0.0]) + +def IMRPhenomX_JNorm_MSA(LNorm, pPrec): + JNorm2 = (LNorm * LNorm + 2.0 * LNorm * pPrec['c1_over_eta'] + pPrec['SAv2']) + return jnp.sqrt(JNorm2) + +def IMRPhenomX_Return_SNorm_MSA(): + ## TODO: implement elliptic Jacobi functions + return + + +def IMRPhenomX_L_norm_3PN_of_v(v: float, v2: float, L_norm: float, pPrec) -> float: + cL = pPrec['constants_L'] # shorthand + return L_norm * 1.0 + v2 * ( + cL[0] + v * cL[1] + v2 * ( + cL[2] + v * cL[3] + v2 * cL[4] + ) + ) + +def IMRPhenomX_Return_Roots_MSA(LNorm, JNorm, pPrec): + vBCD = IMRPhenomX_Return_Spin_Evolution_Coefficients_MSA(LNorm, JNorm, pPrec) ## TODO + B, C, D = vBCD[0], vBCD[1], vBCD[2] + + B2 = B * B + B3 = B2 * B + BC = B * C + + p = C - B2 / 3.0 + qc = (2.0 / 27.0) * B3 - BC / 3.0 + D + + sqrtarg = jnp.sqrt(-p / 3.0) + acosarg = 1.5 * qc / (p * sqrtarg) + acosarg = jnp.clip(acosarg, -1.0, 1.0) + + theta = jnp.arccos(acosarg) / 3.0 + cos_theta = jnp.cos(theta) + + invalid_case = ( + jnp.isnan(theta) | + jnp.isnan(sqrtarg) | + (pPrec.dotS1Ln == 1.0) | + (pPrec.dotS2Ln == 1.0) | + (pPrec.dotS1Ln == -1.0) | + (pPrec.dotS2Ln == -1.0) | + (pPrec.S1_norm_2 == 0.0) | + (pPrec.S2_norm_2 == 0.0) + ) + + def roots_when_valid(): + tmp1 = 2.0 * sqrtarg * jnp.cos(theta - 4.0 * jnp.pi / 3.0) - B / 3.0 + tmp2 = 2.0 * sqrtarg * jnp.cos(theta - 2.0 * jnp.pi / 3.0) - B / 3.0 + tmp3 = 2.0 * sqrtarg * cos_theta - B / 3.0 + + tmp4 = jnp.maximum(jnp.maximum(tmp1, tmp2), tmp3) + tmp5 = jnp.minimum(jnp.minimum(tmp1, tmp2), tmp3) + + tmp6 = jnp.where( + (tmp4 - tmp3 > 0.0) & (tmp5 - tmp3 < 0.0), + tmp3, + jnp.where((tmp4 - tmp1 > 0.0) & (tmp5 - tmp1 < 0.0), tmp1, tmp2) + ) + + S32 = tmp5 + Smi2 = jnp.abs(tmp6) + Spl2 = jnp.abs(tmp4) + return S32, Smi2, Spl2 + + def roots_when_invalid(): + Smi2 = pPrec['S_0_norm_2'] + Spl2 = Smi2 + 1e-9 + S32 = 0.0 + return S32, Smi2, Spl2 + + S32, Smi2, Spl2 = jax.lax.cond( + invalid_case, + roots_when_invalid, + roots_when_valid + ) + + return jnp.array([S32, Smi2, Spl2]) + +def IMRPhenomX_Return_Constants_c_MSA(v, JNorm, pPrec): + v2 = v * v + v3 = v * v2 + v4 = v2 * v2 + v6 = v3 * v3 + JNorm2 = JNorm * JNorm + Seff = pPrec['Seff'] + + + x = JNorm * ( + 0.75 * (1.0 - Seff * v) * v2 * ( + pPrec['eta3'] + + 4.0 * pPrec['eta3'] * Seff * v + - 2.0 * pPrec['eta'] * ( + JNorm2 - pPrec['Spl2'] + 2.0 * (pPrec['S1_norm_2'] - pPrec['S2_norm_2']) * pPrec['delta_qq'] + ) * v2 + - 4.0 * pPrec['eta'] * Seff * (JNorm2 - pPrec['Spl2']) * v3 + + (JNorm2 - pPrec['Spl2']) ** 2 * v4 * pPrec['inveta'] + ) + ) + + y = JNorm * ( + -1.5 * pPrec['eta'] * (pPrec['Spl2'] - pPrec['Smi2']) + * (1.0 + 2.0 * Seff * v - (JNorm2 - pPrec['Spl2']) * v2 * pPrec['inveta2']) + * (1.0 - Seff * v) * v4 + ) + + z = JNorm * ( + 0.75 * pPrec['inveta'] * (pPrec['Spl2'] - pPrec['Smi2']) ** 2 + * (1.0 - Seff * v) * v6 + ) + + return jnp.array([x, y, z]) + +def IMRPhenomX_Return_Constants_d_MSA(LNorm, JNorm, pPrec): + LNorm2 = LNorm * LNorm + JNorm2 = JNorm * JNorm + + x = - (JNorm2 - (LNorm + pPrec['Spl'])) ** 2 * (JNorm2 - (LNorm - pPrec['Spl'])) ** 2 + + y = -2.0 * (pPrec['Spl2'] - pPrec['Smi2']) * (JNorm2 + LNorm2 - pPrec['Spl2']) + + z = -(pPrec['Spl2'] - pPrec['Smi2']) ** 2 + + return jnp.array([x, y, z]) + +def IMRPhenomX_Return_Psi_MSA(v, v2, pPrec): + return -0.75 * pPrec['g0'] * pPrec['delta_qq'] * (1.0 + pPrec['psi1'] * v + pPrec['psi2'] * v2) / (v2 * v) + +def IMRPhenomX_Return_Psi_dot_MSA(v, pPrec): + v2 = v * v + + A_coeff = -1.5 * v2 * v2 * v2 * (1.0 - v * pPrec['Seff']) * pPrec['sqrt_inveta'] + psi_dot = 0.5 * A_coeff * jnp.sqrt(pPrec['Spl2'] - pPrec['S32']) + + return psi_dot From b2f093ad631b42614badecfea07b7d5cae7eadbe Mon Sep 17 00:00:00 2001 From: "leonardo.ricca" Date: Mon, 14 Jul 2025 12:41:56 +0200 Subject: [PATCH 02/20] Implemented IMRPhenomX_Return_Spin_Evolution_Coefficients_MSA --- src/ripplegw/waveforms/IMRPhenomXP_utils.py | 46 +++++++++++++++++++-- 1 file changed, 42 insertions(+), 4 deletions(-) diff --git a/src/ripplegw/waveforms/IMRPhenomXP_utils.py b/src/ripplegw/waveforms/IMRPhenomXP_utils.py index bf3ef409..5e1e0617 100644 --- a/src/ripplegw/waveforms/IMRPhenomXP_utils.py +++ b/src/ripplegw/waveforms/IMRPhenomXP_utils.py @@ -377,7 +377,7 @@ def phP_get_fRD_fdamp(m1, m2, chi1_l, chi2_l, chip): def phP_get_transition_frequencies( - theta: Array, + theta: jnp.array, gamma2: float, gamma3: float, chip: float, @@ -618,8 +618,8 @@ def IMRPhenomX_Return_MSA_Corrections_MSA( sqrt_nc = jnp.sqrt(jnp.abs(nc)) sqrt_nd = jnp.sqrt(jnp.abs(nd)) - psi = IMRPhenomX_Return_Psi_MSA(v, v2, pPrec) + pPrec["psi0"] ## (TODO) - psi_dot = IMRPhenomX_Return_Psi_dot_MSA(v, pPrec) ## (TODO) + psi = IMRPhenomX_Return_Psi_MSA(v, v2, pPrec) + pPrec["psi0"] + psi_dot = IMRPhenomX_Return_Psi_dot_MSA(v, pPrec) tan_psi = jnp.tan(psi) atan_psi = jnp.arctan(tan_psi) @@ -681,7 +681,7 @@ def IMRPhenomX_L_norm_3PN_of_v(v: float, v2: float, L_norm: float, pPrec) -> flo ) def IMRPhenomX_Return_Roots_MSA(LNorm, JNorm, pPrec): - vBCD = IMRPhenomX_Return_Spin_Evolution_Coefficients_MSA(LNorm, JNorm, pPrec) ## TODO + vBCD = IMRPhenomX_Return_Spin_Evolution_Coefficients_MSA(LNorm, JNorm, pPrec) B, C, D = vBCD[0], vBCD[1], vBCD[2] B2 = B * B @@ -798,3 +798,41 @@ def IMRPhenomX_Return_Psi_dot_MSA(v, pPrec): psi_dot = 0.5 * A_coeff * jnp.sqrt(pPrec['Spl2'] - pPrec['S32']) return psi_dot + +def IMRPhenomX_Return_Spin_Evolution_Coefficients_MSA(LNorm, JNorm, pPrec): + JNorm2 = JNorm * JNorm + LNorm2 = LNorm * LNorm + + S1Norm2 = pPrec['S1_norm_2'] + S2Norm2 = pPrec['S2_norm_2'] + q = pPrec['qq'] + eta = pPrec['eta'] + delta = pPrec['delta_qq'] + deltaSq = delta * delta + Seff = pPrec['Seff'] + + J2mL2 = JNorm2 - LNorm2 + J2mL2Sq = J2mL2 * J2mL2 + + # B coefficient (Eq. B2) + B_coeff = ((LNorm2 + S1Norm2) * q + + 2.0 * LNorm * Seff - + 2.0 * JNorm2 - + S1Norm2 - S2Norm2 + + (LNorm2 + S2Norm2) / q) + + # C coefficient (Eq. B3) + C_coeff = (J2mL2Sq - + 2.0 * LNorm * Seff * J2mL2 - + 2.0 * ((1.0 - q) / q) * LNorm2 * (S1Norm2 - q * S2Norm2) + + 4.0 * eta * LNorm2 * Seff * Seff - + 2.0 * delta * (S1Norm2 - S2Norm2) * Seff * LNorm + + 2.0 * ((1.0 - q) / q) * (q * S1Norm2 - S2Norm2) * JNorm2) + + # D coefficient (Eq. B4) + D_coeff = (((1.0 - q) / q) * (S2Norm2 - q * S1Norm2) * J2mL2Sq + + deltaSq * (S1Norm2 - S2Norm2)**2 * LNorm2 / eta + + 2.0 * delta * LNorm * Seff * (S1Norm2 - S2Norm2) * J2mL2) + + return jnp.array([B_coeff, C_coeff, D_coeff]) + From 03269bf07ef634c3fb97722141f3db8e0d75ab61 Mon Sep 17 00:00:00 2001 From: "leonardo.ricca" Date: Mon, 28 Jul 2025 11:35:02 +0200 Subject: [PATCH 03/20] Started developing XGetAndSetPrecessionVariables --- src/ripplegw/waveforms/IMRPhenomXP_utils.py | 52 +++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/src/ripplegw/waveforms/IMRPhenomXP_utils.py b/src/ripplegw/waveforms/IMRPhenomXP_utils.py index 5e1e0617..06d2b63b 100644 --- a/src/ripplegw/waveforms/IMRPhenomXP_utils.py +++ b/src/ripplegw/waveforms/IMRPhenomXP_utils.py @@ -836,3 +836,55 @@ def IMRPhenomX_Return_Spin_Evolution_Coefficients_MSA(LNorm, JNorm, pPrec): return jnp.array([B_coeff, C_coeff, D_coeff]) +def IMRPhenomXGetAndSetPrecessionVariables(pWF, m1_SI, m2_SI, + chi1x, chi1y, chi1z, + chi2x, chi2y, chi2z, + lalParams): + + pPrec['ExpansionOrder'] = XLALSimInspiralWaveformParamsLookupPhenomXPExpansionOrder(lalParams) + + Mtot_SI = m1_SI + m2_SI + # Normalize masses + m1 = m1_SI / Mtot_SI + m2 = m2_SI / Mtot_SI + M = m1 + m2 + #pWF['M'] = m1 + m2 ### pWF needs to be a dict?? + + # Mass ratio and symmetric mass ratio + q = m1 / m2 + eta = pWF[1] + + ## TODO: compute delta? + ## TODO: compute chieff? + ## TODO: compute twopiGM, piGM? + + # Spin inputs + for i, (x, y, z) in enumerate([(chi1x, chi1y, chi1z), (chi2x, chi2y, chi2z)], start=1): + chi_norm = jnp.sqrt(x*x + y*y + z*z) + pPrec[f'chi{i}x'] = x + pPrec[f'chi{i}y'] = y + pPrec[f'chi{i}z'] = z + pPrec[f'chi{i}_norm'] = chi_norm + + # Dimensionful spins + pPrec['S1x'] = chi1x * m1_2 + pPrec['S1y'] = chi1y * m1_2 + pPrec['S1z'] = chi1z * m1_2 + pPrec['S1_norm'] = jnp.abs(pPrec['chi1_norm']) * m1_2 + pPrec['S2x'] = chi2x * m2_2 + pPrec['S2y'] = chi2y * m2_2 + pPrec['S2z'] = chi2z * m2_2 + pPrec['S2_norm'] = jnp.abs(pPrec['chi2_norm']) * m2_2 + + # In-plane magnitudes + pPrec['chi1_perp'] = jnp.sqrt(chi1x*chi1x + chi1y*chi1y) + pPrec['chi2_perp'] = jnp.sqrt(chi2x*chi2x + chi2y*chi2y) + pPrec['S1_perp'] = m1_2 * pPrec['chi1_perp'] + pPrec['S2_perp'] = m2_2 * pPrec['chi2_perp'] + STot_x = pPrec['S1x'] + pPrec['S2x'] + STot_y = pPrec['S1y'] + pPrec['S2y'] + pPrec['STot_perp'] = jnp.sqrt(STot_x**2 + STot_y**2) + pPrec['chiTot_perp'] = pPrec['STot_perp'] * (M**2) / m1_2 + # pWF['chiTot_perp'] = pPrec['chiTot_perp'] ### pWF needs to be a dict?? + + return 0 # Success From de4781eb5df48ce5e8a1d535331436a777316068 Mon Sep 17 00:00:00 2001 From: "leonardo.ricca" Date: Mon, 28 Jul 2025 13:27:10 +0200 Subject: [PATCH 04/20] Implementing IMRPhenomX_Initialize_MSA_System --- src/ripplegw/waveforms/IMRPhenomXP_utils.py | 322 ++++++++++++++++++++ 1 file changed, 322 insertions(+) diff --git a/src/ripplegw/waveforms/IMRPhenomXP_utils.py b/src/ripplegw/waveforms/IMRPhenomXP_utils.py index 5e1e0617..761d57db 100644 --- a/src/ripplegw/waveforms/IMRPhenomXP_utils.py +++ b/src/ripplegw/waveforms/IMRPhenomXP_utils.py @@ -836,3 +836,325 @@ def IMRPhenomX_Return_Spin_Evolution_Coefficients_MSA(LNorm, JNorm, pPrec): return jnp.array([B_coeff, C_coeff, D_coeff]) +def IMRPhenomX_Initialize_MSA_System(pWF, pPrec, ExpansionOrder): + eta = pPrec['eta'] + eta2 = pPrec['eta2'] + eta3 = pPrec['eta3'] + eta4 = pPrec['eta4'] + + m1 = pWF['m1'] + m2 = pWF['m2'] + + domegadt_constants_NS = jnp.array([ + 96. / 5., -1486. / 35., -264. / 5., 384. * jnp.pi / 5., 34103. / 945., + 13661. / 105., 944. / 15., jnp.pi * (-4159. / 35.), jnp.pi * (-2268. / 5.), + (16447322263. / 7276500. + jnp.pi**2 * 512. / 5. - jnp.log(2.) * 109568. / 175. - jnp.euler_gamma * 54784. / 175.), + (-56198689. / 11340. + jnp.pi**2 * 902. / 5.), + 1623. / 140., -1121. / 27., -54784. / 525., -jnp.pi * 883. / 42., + jnp.pi * 71735. / 63., jnp.pi * 73196. / 63. + ]) + + domegadt_constants_SO = jnp.array([ + -904. / 5., -120., -62638. / 105., 4636. / 5., -6472. / 35., 3372. / 5., + -jnp.pi * 720., -jnp.pi * 2416. / 5., -208520. / 63., 796069. / 105., + -100019. / 45., -1195759. / 945., 514046. / 105., -8709. / 5., + -jnp.pi * 307708. / 105., jnp.pi * 44011. / 7., -jnp.pi * 7992. / 7., + jnp.pi * 151449. / 35. + ]) + + domegadt_constants_SS = jnp.array([ + -494. / 5., -1442. / 5., -233. / 5., -719. / 5. + ]) + + L_csts_nonspin = jnp.array([ + 3. / 2., 1. / 6., 27. / 8., -19. / 8., 1. / 24., 135. / 16., + -6889. / 144. + 41. / 24. * jnp.pi**2, 31. / 24., 7. / 1296. + ]) + + L_csts_spinorbit = jnp.array([ + -14. / 6., -3. / 2., -11. / 2., 133. / 72., -33. / 8., 7. / 4. + ]) + + # Flip q convention: Chatziioannou uses q < 1 (m1 > m2), IMRPhenomX uses q > 1 + q = m2 / m1 # q < 1, m1 > m2 + invq = 1.0 / q + + pPrec['qq'] = q + pPrec['invqq'] = invq + + # Reduced mass + mu = (m1 * m2) / (m1 + m2) + + pPrec['delta_qq'] = (1.0 - pPrec['qq']) / (1.0 + pPrec['qq']) + pPrec['delta2_qq'] = pPrec['delta_qq'] ** 2 + pPrec['delta3_qq'] = pPrec['delta_qq'] * pPrec['delta2_qq'] + pPrec['delta4_qq'] = pPrec['delta_qq'] * pPrec['delta3_qq'] + + # Initialize vectors + S1v = jnp.array([0.0, 0.0, 0.0]) + S2v = jnp.array([0.0, 0.0, 0.0]) + Lhat = jnp.array([0.0, 0.0, 1.0]) + + # Set fixed Lhat variables + pPrec['Lhat_cos_theta'] = 1.0 + pPrec['Lhat_phi'] = 0.0 + pPrec['Lhat_theta'] = 0.0 + + # Dimensionful spin vectors (eta = m1 * m2, q = m2 / m1) + S1v[0] = pPrec['chi1x'] * eta / q + S1v[1] = pPrec['chi1y'] * eta / q + S1v[2] = pPrec['chi1z'] * eta / q + + S2v[0] = pPrec['chi2x'] * eta * q + S2v[1] = pPrec['chi2y'] * eta * q + S2v[2] = pPrec['chi2z'] * eta * q + + # Norms of spin vectors + S1_0_norm = jnp.linalg.norm(S1v) + S2_0_norm = jnp.linalg.norm(S2v) + + # Store initial spin vectors + pPrec['S1_0'] = S1v + pPrec['S2_0'] = S2v + + # Reference velocity v and v^2 + pPrec['v_0'] = jnp.cbrt(pPrec['piGM'] * pWF['fRef']) + pPrec['v_0_2'] = pPrec['v_0'] ** 2 + + # Reference orbital angular momentum + L_0 = Lhat * eta / pPrec['v_0'] + pPrec['L_0'] = L_0 + + dotS1L = jnp.dot(S1v, Lhat) + dotS2L = jnp.dot(S2v, Lhat) + dotS1S2 = jnp.dot(S1v, S2v) + dotS1Ln = dotS1L / S1_0_norm + dotS2Ln = dotS2L / S2_0_norm + + # Store results in the precession structure + pPrec['dotS1L'] = dotS1L + pPrec['dotS2L'] = dotS2L + pPrec['dotS1S2'] = dotS1S2 + pPrec['dotS1Ln'] = dotS1Ln + pPrec['dotS2Ln'] = dotS2Ln + + # --- PN Coefficients for orbital angular momentum --- + pPrec['constants_L'] = jnp.zeros(5) + pPrec['constants_L'] = pPrec['constants_L'].at[0].set( + L_csts_nonspin[0] + eta * L_csts_nonspin[1] + ) + + pPrec['constants_L'] = pPrec['constants_L'].at[1].set( + IMRPhenomX_Get_PN_beta(L_csts_spinorbit[0], L_csts_spinorbit[1], pPrec) ## (TODO) + ) + + pPrec['constants_L'] = pPrec['constants_L'].at[2].set( + L_csts_nonspin[2] + + eta * L_csts_nonspin[3] + + eta ** 2 * L_csts_nonspin[4] + ) + + pPrec['constants_L'] = pPrec['constants_L'].at[3].set( + IMRPhenomX_Get_PN_beta( + L_csts_spinorbit[2] + L_csts_spinorbit[3] * eta, + L_csts_spinorbit[4] + L_csts_spinorbit[5] * eta, + pPrec + ) + ) + + pPrec['constants_L'] = pPrec['constants_L'].at[4].set( + L_csts_nonspin[5] + + L_csts_nonspin[6] * eta + + L_csts_nonspin[7] * eta ** 2 + + L_csts_nonspin[8] * eta ** 3 + ) + + # Effective total spin + Seff = (1.0 + q) * pPrec['dotS1L'] + (1.0 + 1.0 / q) * pPrec['dotS2L'] + pPrec['Seff'] = Seff + + # Initial total spin, S0 = S1 + S2 + S_0 = S1v + S2v + pPrec['S_0'] = S0 + + # Initial total angular momentum J0 = L + S1 + S2 + pPrec['J_0'] = L_0 + S_0 + + pPrec['S_0_norm'] = jnp.linalg.norm(S0) + + pPrec['L_0_norm'] = jnp.linalg.norm(pPrec['L_0']) + pPrec['J_0_norm'] = jnp.linalg.norm(pPrec['J_0']) + + vBCD = IMRPhenomX_Return_Spin_Evolution_Coefficients_MSA(pPrec['L_0_norm'], pPrec['J_0_norm'], pPrec) + + vRoots = IMRPhenomX_Return_Roots_MSA(pPrec['L_0_norm'], pPrec['J_0_norm'], pPrec) + + pPrec['Spl2'] = vRoots[2] + pPrec['Smi2'] = vRoots[1] + pPrec['S32'] = vRoots[0] + + pPrec['Spl2pSmi2'] = pPrec['Spl2'] + pPrec['Smi2'] + pPrec['Spl2mSmi2'] = pPrec['Spl2'] - pPrec['Smi2'] + + pPrec['Spl'] = jnp.sqrt(pPrec['Spl2']) + pPrec['Smi'] = jnp.sqrt(pPrec['Smi2']) + + # Eq. 45 of PRD, 95, 104004, (2017), arXiv:1703.03967 + pPrec['SAv2'] = 0.5 * pPrec['Spl2pSmi2'] + pPrec['SAv'] = jnp.sqrt(pPrec['SAv2']) + pPrec['invSAv2'] = 1.0 / pPrec['SAv2'] + pPrec['invSAv'] = 1.0 / pPrec['SAv'] + + # Eq. 41 of PRD, 95, 104004, (2017), arXiv:1703.03967 + c_1 = 0.5 * (pPrec['J_0_norm']**2 - pPrec['L_0_norm']**2 - pPrec['SAv2']) / (pPrec['L_0_norm'] * eta) + pPrec['c1'] = c_1 + pPrec['c12'] = c_1 ** 2 + pPrec['c1_over_eta'] = c_1 / eta + + # Average spin couplings over one precession cycle: A9 - A14 of arXiv:1703.03967 + omqsq = (1.0 - q)**2 + 1e-16 + omq2 = (1.0 - q**2) + 1e-16 + + pPrec['S1L_pav'] = (c_1 * (1.0 + q) - q * eta * Seff) / (eta * omq2) + pPrec['S2L_pav'] = -q * (c_1 * (1.0 + q) - eta * Seff) / (eta * omq2) + pPrec['S1S2_pav'] = 0.5 * pPrec['SAv2'] - 0.5 * (pPrec['S1_norm_2'] + pPrec['S2_norm_2']) + pPrec['S1Lsq_pav'] = pPrec['S1L_pav'] ** 2 + (pPrec['Spl2mSmi2'] ** 2 * pPrec['v_0_2']) / (32.0 * eta2 * omqsq) + pPrec['S2Lsq_pav'] = pPrec['S2L_pav'] ** 2 + (q**2 * spl2m_smi2_sq * v0_2) / (32.0 * eta2 * omqsq) + pPrec['S1LS2L_pav'] = pPrec['S1L_pav'] * pPrec['S2L_pav'] - q * pPrec['Spl2mSmi2'] * pPrec['v_0_2'] / (32.0 * eta2 * omqsq) + + # beta coefficients + pPrec['beta3'] = ((113.0/12.0) + (25.0/4.0)*(m2/m1)) * pPrec['S1L_pav'] + + ((113.0/12.0) + (25.0/4.0)*(m1/m2)) * pPrec['S2L_pav'] + + pPrec['beta5'] = (((31319.0/1008.0) - (1159.0/24.0)*eta) + (m2/m1)*((809.0/84.0) - (281.0/8.0)*eta)) * pPrec['S1L_pav'] + + (((31319.0/1008.0) - (1159.0/24.0)*eta) + (m1/m2)*((809.0/84.0) - (281.0/8.0)*eta)) * pPrec['S2L_pav'] + + pPrec['beta6'] = jnp.pi * ( + ((75.0/2.0) + (151.0/6.0)*(m2/m1)) * pPrec['S1L_pav'] + + ((75.0/2.0) + (151.0/6.0)*(m1/m2)) * pPrec['S2L_pav'] + ) + + beta7_common = (130325.0/756.0) - (796069.0/2016.0)*eta + (100019.0/864.0)*eta2 + beta7_S = (1195759.0/18144.0 - 257023.0/1008.0 * eta + 2903.0/32.0 * eta2) + + pPrec['beta7'] = beta7_common + (m2/m1) * beta7_S * pPrec['S1L_pav'] + beta7_common + (m1/m2) * beta7_S * pPrec['S2L_pav'] + + pPrec['sigma4'] = (1.0 / mu) * ((247.0/48.0) * pPrec['S1S2_pav'] - (721.0/48.0) * pPrec['S1L_pav'] * pPrec['S2L_pav']) + + (1.0 / (m1**2)) * ((233.0/96.0) * pPrec['S1_norm_2'] - (719.0/96.0) * pPrec['S1Lsq_pav']) + + (1.0 / (m2**2)) * ((233.0/96.0) * pPrec['S2_norm_2'] - (719.0/96.0) * pPrec['S2Lsq_pav']) + + # PN coefficients + pPrec['a0'] = 96.0 * eta / 5.0 + ''' + pPrec['a2'] = (-(743.0/336.0) - (11.0/4.0)*eta) * pPrec['a0'] + pPrec['a3'] = (4.0 * jnp.pi - pPrec['beta3']) * pPrec['a0'] + pPrec['a4'] = ((34103.0/18144.0) + (13661.0/2016.0)*eta + (59.0/18.0)*eta2 - pPrec['sigma4']) * pPrec['a0'] + pPrec['a5'] = (-(4159.0/672.0)*jnp.pi - (189.0/8.0)*jnp.pi*eta - pPrec['beta5']) * pPrec['a0'] + ''' + pPrec['a6'] = ((16447322263.0/139708800.0) + (16.0/3.0)*jnp.pi**2 - (856.0/105.0)*jnp.log(16.0) - (1712.0/105.0)*jnp.euler_gamma + - pPrec['beta6'] + eta*(451.0/48.0)*jnp.pi**2 - (56198689.0/217728.0) + eta2*(541.0/896.0) - eta3*(5605.0/2592.0)) * pPrec['a0'] + pPrec['a7'] = (-(4415.0/4032.0)*jnp.pi + (358675.0/6048.0)*jnp.pi*eta + (91495.0/1512.0)*jnp.pi*eta2 - pPrec['beta7']) * pPrec['a0'] + + # Default behaviour of IMRPhenomXP (223: MSA with fallback to NNLO) + pPrec['a0'] = eta * domegadt_constants_NS[0] + pPrec['a2'] = eta * (domegadt_constants_NS[1] + eta * domegadt_constants_NS[2]) + pPrec['a3'] = eta * (domegadt_constants_NS[3] + + IMRPhenomX_Get_PN_beta(domegadt_constants_SO[0],domegadt_constants_SO[1],pPrec)) + pPrec['a4'] = eta * (domegadt_constants_NS[4] +eta * (domegadt_constants_NS[5] + eta * domegadt_constants_NS[6]) + + IMRPhenomX_Get_PN_sigma(domegadt_constants_SS[0],domegadt_constants_SS[1],pPrec) + + IMRPhenomX_Get_PN_tau(domegadt_constants_SS[2],domegadt_constants_SS[3],pPrec)) + pPrec['a5'] = eta * (domegadt_constants_NS[7] +eta * domegadt_constants_NS[8] + + IMRPhenomX_Get_PN_beta( + domegadt_constants_SO[2] + eta * domegadt_constants_SO[3], + domegadt_constants_SO[4] + eta * domegadt_constants_SO[5], + pPrec + ) + ) + + pPrec['a0_2'] = pPrec['a0'] * pPrec['a0'] + pPrec['a0_3'] = pPrec['a0_2'] * pPrec['a0'] + pPrec['a2_2'] = pPrec['a2'] * pPrec['a2'] + + # g-coefficients from Appendix A of Chatziioannou et al, PRD, 95, 104004, (2017), arXiv:1703.03967. + pPrec['g0'] = 1.0 / pPrec['a0'] + pPrec['g2'] = -pPrec['a2'] / pPrec['a0_2'] + pPrec['g3'] = -pPrec['a3'] / pPrec['a0_2'] + pPrec['g4'] = -(pPrec['a4'] * pPrec['a0'] - pPrec['a2_2']) / pPrec['a0_3'] + pPrec['g5'] = -(pPrec['a5'] * pPrec['a0'] - 2.0 * pPrec['a3'] * pPrec['a2']) / pPrec['a0_3'] + + delta = pPrec['delta_qq'] + delta2 = delta * delta + delta3 = delta * delta2 + delta4 = delta * delta3 + + # Phase coefficients: Eq. 51 and Appendix C of arXiv:1703.03967 + pPrec['psi0'] = 0.0 + pPrec['psi1'] = 3.0 * (2.0 * eta2 * Seff - c_1) / (eta * delta2) + pPrec['psi2'] = 0.0 + + # Precompute useful quantities + c_1_over_nu = pPrec['c1_over_eta'] + c_1_over_nu_2 = c_1_over_eta * c_1_over_eta + one_p_q_sq = (1.0 + q)**2 + Seff_2 = Seff * Seff + q_2 = q * q + one_m_q_sq = (1.0 - q)**2 + one_m_q2_2 = (1.0 - q_2)**2 + one_m_q_4 = one_m_q_sq * one_m_q_sq + + Del1 = 4.0 * c_1_over_nu_2 * one_p_q_sq + Del2 = 8.0 * c_1_over_nu * q * (1.0 + q) * Seff + Del3 = 4.0 * (one_m_q2_2 * pPrec['S1_norm_2'] - q_2 * Seff_2) + Del4 = 4.0 * c_1_over_nu_2 * q_2 * one_p_q_sq + Del5 = 8.0 * c_1_over_nu * q_2 * (1.0 + q) * Seff + Del6 = 4.0 * (one_m_q2_2 * pPrec['S2_norm_2'] - q_2 * Seff_2) + + pPrec['Delta'] = jnp.sqrt(jnp.abs((Del1 - Del2 - Del3) * (Del4 - Del5 - Del6))) + + u1 = 3.0 * pPrec['g2'] / pPrec['g0'] + u2 = 0.75 * one_p_q_sq / one_m_q_4 + u3 = -20.0 * c_1_over_nu_2 * q_2 * one_p_q_sq + u4 = 2.0 * one_m_q2_2 * (q * (2.0 + q) * pPrec['S1_norm_2'] + + (1.0 + 2.0 * q) * pPrec['S2_norm_2'] - 2.0 * q * pPrec['SAv2']) + u5 = 4.0 * q_2 * (7.0 + 6.0 * q + 7.0 * q_2) * c_1_over_nu * Seff + u6 = 2.0 * q_2 * (3.0 + 4.0 * q + 3.0 * q_2) * Seff_2 + u7 = q * pPrec['Delta'] + + # (Eq. C2 of 1703.03967) + pPrec['psi2'] = u1 + u2 * (u3 + u4 + u5 - u6 + u7) + + # Eq. D1 - D5 of 1703.03967 + Rm = pPrec['Spl2'] - pPrec['Smi2'] + Rm_2 = Rm * Rm + cp = pPrec['Spl2'] * eta2 - c1_2 + cm = pPrec['Smi2'] * eta2 - c1_2 + cpcm = jnp.abs(cp * cm) + sqrt_cpcm = jnp.sqrt(cpcm) + a1dD = 0.5 + 0.75 / eta + a2dD = -0.75 * Seff / eta + + # Eq. E3- E4 of 1703.03967 + D2RmSq = (cp - sqrt_cpcm) / eta2 + D4RmSq = -0.5 * Rm * sqrt_cpcm / eta2 - cp / eta4 * (sqrt_cpcm - cp) + + S0m = pPrec['S1_norm_2'] - pPrec['S2_norm_2'] + + aw = -3.0 * (1. + q) / q * (2. * (1. + q) * eta2 * Seff * c_1 - (1. + q) * c1_2 + (1. - q) * eta2 * S0m) + cw = 3.0 / 32.0 / eta * Rm_2 + dw = 4.0 * cp - 4.0 * D2RmSq * eta2 + hw = -2.0 * (2.0 * D2RmSq - Rm) * c_1 + fw = Rm * D2RmSq - D4RmSq - 0.25 * Rm_2 + + adD = aw / dw + hdD = hw / dw + cdD = cw / dw + fdD = fw / dw + + gw = 3. / 16. / eta2 / eta * Rm_2 * (c_1 - eta2 * Seff) + gdD = gw / dw + + # Powers of coefficients + hdD_2 = hdD * hdD + adDfdD = adD * fdD + adDfdDhdD = adDfdD * hdD + adDhdD_2 = adD * hdD_2 From b2a3f418382526428534756563f69fc2b1d7e655 Mon Sep 17 00:00:00 2001 From: "leonardo.ricca" Date: Tue, 12 Aug 2025 11:55:42 +0200 Subject: [PATCH 05/20] Implemented IMRPhenomX_Initialize_MSA_System --- src/ripplegw/waveforms/IMRPhenomXP_utils.py | 135 ++++++++++++++++++-- 1 file changed, 121 insertions(+), 14 deletions(-) diff --git a/src/ripplegw/waveforms/IMRPhenomXP_utils.py b/src/ripplegw/waveforms/IMRPhenomXP_utils.py index 761d57db..49d4c326 100644 --- a/src/ripplegw/waveforms/IMRPhenomXP_utils.py +++ b/src/ripplegw/waveforms/IMRPhenomXP_utils.py @@ -975,12 +975,12 @@ def IMRPhenomX_Initialize_MSA_System(pWF, pPrec, ExpansionOrder): # Initial total spin, S0 = S1 + S2 S_0 = S1v + S2v - pPrec['S_0'] = S0 + pPrec['S_0'] = S_0 # Initial total angular momentum J0 = L + S1 + S2 pPrec['J_0'] = L_0 + S_0 - pPrec['S_0_norm'] = jnp.linalg.norm(S0) + pPrec['S_0_norm'] = jnp.linalg.norm(S_0) pPrec['L_0_norm'] = jnp.linalg.norm(pPrec['L_0']) pPrec['J_0_norm'] = jnp.linalg.norm(pPrec['J_0']) @@ -1019,15 +1019,15 @@ def IMRPhenomX_Initialize_MSA_System(pWF, pPrec, ExpansionOrder): pPrec['S2L_pav'] = -q * (c_1 * (1.0 + q) - eta * Seff) / (eta * omq2) pPrec['S1S2_pav'] = 0.5 * pPrec['SAv2'] - 0.5 * (pPrec['S1_norm_2'] + pPrec['S2_norm_2']) pPrec['S1Lsq_pav'] = pPrec['S1L_pav'] ** 2 + (pPrec['Spl2mSmi2'] ** 2 * pPrec['v_0_2']) / (32.0 * eta2 * omqsq) - pPrec['S2Lsq_pav'] = pPrec['S2L_pav'] ** 2 + (q**2 * spl2m_smi2_sq * v0_2) / (32.0 * eta2 * omqsq) + pPrec['S2Lsq_pav'] = pPrec['S2L_pav'] ** 2 + (q**2 * (pPrec['Spl2mSmi2'] ** 2) * pPrec['v_0_2']) / (32.0 * eta2 * omqsq) pPrec['S1LS2L_pav'] = pPrec['S1L_pav'] * pPrec['S2L_pav'] - q * pPrec['Spl2mSmi2'] * pPrec['v_0_2'] / (32.0 * eta2 * omqsq) # beta coefficients - pPrec['beta3'] = ((113.0/12.0) + (25.0/4.0)*(m2/m1)) * pPrec['S1L_pav'] - + ((113.0/12.0) + (25.0/4.0)*(m1/m2)) * pPrec['S2L_pav'] + pPrec['beta3'] = ((113.0/12.0) + (25.0/4.0)*(m2/m1)) * pPrec['S1L_pav'] + ((113.0/12.0) + + (25.0/4.0)*(m1/m2)) * pPrec['S2L_pav'] - pPrec['beta5'] = (((31319.0/1008.0) - (1159.0/24.0)*eta) + (m2/m1)*((809.0/84.0) - (281.0/8.0)*eta)) * pPrec['S1L_pav'] - + (((31319.0/1008.0) - (1159.0/24.0)*eta) + (m1/m2)*((809.0/84.0) - (281.0/8.0)*eta)) * pPrec['S2L_pav'] + pPrec['beta5'] = (((31319.0/1008.0) - (1159.0/24.0)*eta) + (m2/m1)*((809.0/84.0) + - (281.0/8.0)*eta)) * pPrec['S1L_pav'] + (((31319.0/1008.0) - (1159.0/24.0)*eta) + (m1/m2)*((809.0/84.0) - (281.0/8.0)*eta)) * pPrec['S2L_pav'] pPrec['beta6'] = jnp.pi * ( ((75.0/2.0) + (151.0/6.0)*(m2/m1)) * pPrec['S1L_pav'] + @@ -1039,9 +1039,8 @@ def IMRPhenomX_Initialize_MSA_System(pWF, pPrec, ExpansionOrder): pPrec['beta7'] = beta7_common + (m2/m1) * beta7_S * pPrec['S1L_pav'] + beta7_common + (m1/m2) * beta7_S * pPrec['S2L_pav'] - pPrec['sigma4'] = (1.0 / mu) * ((247.0/48.0) * pPrec['S1S2_pav'] - (721.0/48.0) * pPrec['S1L_pav'] * pPrec['S2L_pav']) - + (1.0 / (m1**2)) * ((233.0/96.0) * pPrec['S1_norm_2'] - (719.0/96.0) * pPrec['S1Lsq_pav']) - + (1.0 / (m2**2)) * ((233.0/96.0) * pPrec['S2_norm_2'] - (719.0/96.0) * pPrec['S2Lsq_pav']) + pPrec['sigma4'] = (1.0 / mu) * ((247.0/48.0) * pPrec['S1S2_pav'] - (721.0/48.0) * pPrec['S1L_pav'] * pPrec['S2L_pav']) + (1.0 / (m1**2)) * ((233.0/96.0) * pPrec['S1_norm_2'] + - (719.0/96.0) * pPrec['S1Lsq_pav']) + (1.0 / (m2**2)) * ((233.0/96.0) * pPrec['S2_norm_2'] - (719.0/96.0) * pPrec['S2Lsq_pav']) # PN coefficients pPrec['a0'] = 96.0 * eta / 5.0 @@ -1094,7 +1093,7 @@ def IMRPhenomX_Initialize_MSA_System(pWF, pPrec, ExpansionOrder): # Precompute useful quantities c_1_over_nu = pPrec['c1_over_eta'] - c_1_over_nu_2 = c_1_over_eta * c_1_over_eta + c_1_over_nu_2 = c_1_over_nu * c_1_over_nu one_p_q_sq = (1.0 + q)**2 Seff_2 = Seff * Seff q_2 = q * q @@ -1126,8 +1125,8 @@ def IMRPhenomX_Initialize_MSA_System(pWF, pPrec, ExpansionOrder): # Eq. D1 - D5 of 1703.03967 Rm = pPrec['Spl2'] - pPrec['Smi2'] Rm_2 = Rm * Rm - cp = pPrec['Spl2'] * eta2 - c1_2 - cm = pPrec['Smi2'] * eta2 - c1_2 + cp = pPrec['Spl2'] * eta2 - pPrec['c12'] + cm = pPrec['Smi2'] * eta2 - pPrec['c12'] cpcm = jnp.abs(cp * cm) sqrt_cpcm = jnp.sqrt(cpcm) a1dD = 0.5 + 0.75 / eta @@ -1139,7 +1138,7 @@ def IMRPhenomX_Initialize_MSA_System(pWF, pPrec, ExpansionOrder): S0m = pPrec['S1_norm_2'] - pPrec['S2_norm_2'] - aw = -3.0 * (1. + q) / q * (2. * (1. + q) * eta2 * Seff * c_1 - (1. + q) * c1_2 + (1. - q) * eta2 * S0m) + aw = -3.0 * (1. + q) / q * (2. * (1. + q) * eta2 * Seff * c_1 - (1. + q) * pPrec['c12'] + (1. - q) * eta2 * S0m) cw = 3.0 / 32.0 / eta * Rm_2 dw = 4.0 * cp - 4.0 * D2RmSq * eta2 hw = -2.0 * (2.0 * D2RmSq - Rm) * c_1 @@ -1158,3 +1157,111 @@ def IMRPhenomX_Initialize_MSA_System(pWF, pPrec, ExpansionOrder): adDfdD = adD * fdD adDfdDhdD = adDfdD * hdD adDhdD_2 = adD * hdD_2 + + # Eq. D10-D15 in PRD, 95, 104004, (2017), arXiv:1703.03967 + pPrec['Omegaz0'] = a1dD + adD + pPrec['Omegaz1'] = a2dD - adD*Seff - adD*hdD + pPrec['Omegaz2'] = adD*hdD*Seff + cdD - adD*fdD + adD*hdD_2 + pPrec['Omegaz3'] = (adDfdD - cdD - adDhdD_2)*(Seff + hdD) + adDfdDhdD + pPrec['Omegaz4'] = (cdD + adDhdD_2 - 2.0*adDfdD)*(hdD*Seff + hdD_2 - fdD) - adD*fdD*fdD + pPrec['Omegaz5'] = (cdD - adDfdD + adDhdD_2) * fdD * (Seff + 2.0*hdD) - (cdD + adDhdD_2 - 2.0*adDfdD) * hdD_2 * (Seff + hdD) - adDfdD*fdD*hdD + + # Condition for MSA fallback to NNLO + condition = jnp.abs(pPrec['Omegaz5']) > 1000.0 + pPrec['MSA_ERROR'] = jnp.where(condition, 1, 0) + + g0 = pPrec['g0'] + + # Eq. 65 coefficients (D16 - D21 of PRD, 95, 104004, (2017), arXiv:1703.03967) + pPrec['Omegaz0_coeff'] = 3.0 * g0 * pPrec['Omegaz0'] + pPrec['Omegaz1_coeff'] = 3.0 * g0 * pPrec['Omegaz1'] + pPrec['Omegaz2_coeff'] = 3.0 * (g0 * pPrec['Omegaz2'] + pPrec['g2'] * pPrec['Omegaz0']) + pPrec['Omegaz3_coeff'] = 3.0 * (g0 * pPrec['Omegaz3'] + pPrec['g2'] * pPrec['Omegaz1'] + pPrec['g3'] * pPrec['Omegaz0']) + pPrec['Omegaz4_coeff'] = 3.0 * (g0 * pPrec['Omegaz4'] + pPrec['g2'] * pPrec['Omegaz2'] + pPrec['g3'] * pPrec['Omegaz1'] + pPrec['g4'] * pPrec['Omegaz0']) + pPrec['Omegaz5_coeff'] = 3.0 * (g0 * pPrec['Omegaz5'] + pPrec['g2'] * pPrec['Omegaz3'] + pPrec['g3'] * pPrec['Omegaz2'] + pPrec['g4'] * pPrec['Omegaz1'] + pPrec['g5'] * pPrec['Omegaz0']) + + # zeta coefficients (Appendix E of PRD, 95, 104004, (2017), arXiv:1703.03967) + c1oveta2 = c_1 / eta2 + pPrec['Omegazeta0'] = pPrec['Omegaz0'] + pPrec['Omegazeta1'] = pPrec['Omegaz1'] + pPrec['Omegaz0'] * c1oveta2 + pPrec['Omegazeta2'] = pPrec['Omegaz2'] + pPrec['Omegaz1'] * c1oveta2 + pPrec['Omegazeta3'] = pPrec['Omegaz3'] + pPrec['Omegaz2'] * c1oveta2 + gdD + pPrec['Omegazeta4'] = pPrec['Omegaz4'] + pPrec['Omegaz3'] * c1oveta2 - gdD * Seff - gdD * hdD + pPrec['Omegazeta5'] = pPrec['Omegaz5'] + pPrec['Omegaz4'] * c1oveta2 + gdD * hdD * Seff + gdD * (hdD_2 - fdD) + + pPrec['Omegazeta0_coeff'] = -pPrec['g0'] * pPrec['Omegazeta0'] + pPrec['Omegazeta1_coeff'] = -1.5 * pPrec['g0'] * pPrec['Omegazeta1'] + pPrec['Omegazeta2_coeff'] = -3.0 * (pPrec['g0'] * pPrec['Omegazeta2'] + pPrec['g2'] * pPrec['Omegazeta0']) + pPrec['Omegazeta3_coeff'] = 3.0 * (pPrec['g0'] * pPrec['Omegazeta3'] + pPrec['g2'] * pPrec['Omegazeta1'] + pPrec['g3'] * pPrec['Omegazeta0']) + pPrec['Omegazeta4_coeff'] = 3.0 * (pPrec['g0'] * pPrec['Omegazeta4'] + pPrec['g2'] * pPrec['Omegazeta2'] + pPrec['g3'] * pPrec['Omegazeta1'] + pPrec['g4'] * pPrec['Omegazeta0']) + pPrec['Omegazeta5_coeff'] = 1.5 * (pPrec['g0'] * pPrec['Omegazeta5'] + pPrec['g2'] * pPrec['Omegazeta3'] + pPrec['g3'] * pPrec['Omegazeta2'] + pPrec['g4'] * pPrec['Omegazeta1'] + pPrec['g5'] * pPrec['Omegazeta0']) + + ## Here we're only considering the default setting, where the expansion order for MSA correction is 5 + ## switch to choose expansion order not yet implemented (TODO) + pPrec['Omegaz5_coeff'] = 0.0 + pPrec['Omegazeta5_coeff'] = 0.0 + + + condition_equal = jnp.abs(pPrec['Smi2'] - pPrec['Spl2']) < 1.0e-5 + + def branch_equal(pPrec): + return 0.0 + + def branch_not_equal(pPrec): + mm_val = jnp.sqrt((pPrec['Smi2'] - pPrec['Spl2']) + (pPrec['S32'] - pPrec['Spl2'])) + tmpB_val = ((pPrec['S_0_norm'] * pPrec['S_0_norm']) - pPrec['Spl2']) / (pPrec['Smi2'] - pPrec['Spl2']) + + vol_elem = jnp.dot(jnp.cross(L_0, S1v),S2v) + vol_sign_val = jnp.sign(vol_elem) + + psi_v0_val = IMRPhenomX_psiofv(pPrec['v_0'], pPrec['v_0_2'], 0.0, + pPrec['psi1'], pPrec['psi2'], pPrec) + + # Clamp tmpB in conditions + cond_case1 = jnp.logical_and(tmpB_val > 1.0, + (tmpB_val - 1.0) < 1.0e-5) + cond_case2 = jnp.logical_and(tmpB_val < 0.0, + tmpB_val > -1.0e-5) + + def case1(): + return gsl_sf_ellint_F( ##(TODO) + jnp.arcsin(vol_sign_val * jnp.sqrt(1.0)), + mm_val + ) - psi_v0_val + + def case2(): + return gsl_sf_ellint_F( + jnp.arcsin(vol_sign_val * jnp.sqrt(0.0)), + mm_val + ) - psi_v0_val + + def case3(): + return gsl_sf_ellint_F( + jnp.arcsin(vol_sign_val * jnp.sqrt(tmpB_val)), + mm_val + ) - psi_v0_val + + psi0_val = jnp.select( + [cond_case1, cond_case2, jnp.logical_not( + jnp.logical_or(tmpB_val > 1.0, tmpB_val < 0.0))], + [case1(), case2(), case3()] + ) + + return psi0_val + + pPrec['psi0'] = jnp.where(condition_equal, + branch_equal(pPrec), + branch_not_equal(pPrec)) + + vMSA = jnp.where(condition_equal, jnp.array([0.,0.,0.]), + IMRPhenomX_Return_MSA_Corrections_MSA(pPrec['v_0'],pPrec['L_0_norm'],pPrec['J_0_norm'],pPrec)) + + phiz_0 = IMRPhenomX_Return_phiz_MSA(pPrec['v_0'],pPrec['J_0_norm'],pPrec) + + zeta_0 = IMRPhenomX_Return_zeta_MSA(pPrec['v_0'],pPrec) + + pPrec['phiz_0'] = - phiz_0 - vMSA[0] + pPrec['zeta_0'] = - zeta_0 - vMSA[1] + + return pPrec \ No newline at end of file From a1a75286a8a5f30580d72b13f20d093de6ab0cab Mon Sep 17 00:00:00 2001 From: "leonardo.ricca" Date: Tue, 12 Aug 2025 12:13:44 +0200 Subject: [PATCH 06/20] Implemented IMRPhenomX_Get_PN_beta, sigma, tau, IMRPhenomX_psiofv --- src/ripplegw/waveforms/IMRPhenomXP_utils.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/src/ripplegw/waveforms/IMRPhenomXP_utils.py b/src/ripplegw/waveforms/IMRPhenomXP_utils.py index 49d4c326..2160c85d 100644 --- a/src/ripplegw/waveforms/IMRPhenomXP_utils.py +++ b/src/ripplegw/waveforms/IMRPhenomXP_utils.py @@ -1264,4 +1264,22 @@ def case3(): pPrec['phiz_0'] = - phiz_0 - vMSA[0] pPrec['zeta_0'] = - zeta_0 - vMSA[1] - return pPrec \ No newline at end of file + return pPrec + +def IMRPhenomX_Get_PN_beta(a, b, pPrec): + return (pPrec['dotS1L'] * (a + b * pPrec['qq']) + + pPrec['dotS2L'] * (a + b / pPrec['qq'])) + +def IMRPhenomX_Get_PN_sigma(a, b, pPrec): + return pPrec['inveta'] * (a * pPrec['dotS1S2'] - + b * pPrec['dotS1L'] * pPrec['dotS2L']) + +def IMRPhenomX_Get_PN_tau(a, b, pPrec): + return ( + pPrec['qq'] * (pPrec['S1_norm_2'] * a - b * pPrec['dotS1L'] * pPrec['dotS1L']) + + (a * pPrec['S2_norm_2'] - b * pPrec['dotS2L'] * pPrec['dotS2L']) / pPrec['qq'] + ) / pPrec['eta'] + +def IMRPhenomX_psiofv(v, v2, psi0, psi1, psi2, pPrec): + # Equation 51 in arXiv:1703.03967 + return psi0 - 0.75 * pPrec['g0'] * pPrec['delta_qq'] * (1.0 + psi1 * v + psi2 * v2) / (v2 * v) \ No newline at end of file From b779b70f8347c5a531839a19e29d62b6d1e69d64 Mon Sep 17 00:00:00 2001 From: "leonardo.ricca" Date: Mon, 29 Sep 2025 15:12:55 +0200 Subject: [PATCH 07/20] Implement XGetAndSetPrecessionVariables and XPCheckMaxOpeningAngle functions --- src/ripplegw/waveforms/IMRPhenomXP_utils.py | 357 +++++++++++++++++++- 1 file changed, 353 insertions(+), 4 deletions(-) diff --git a/src/ripplegw/waveforms/IMRPhenomXP_utils.py b/src/ripplegw/waveforms/IMRPhenomXP_utils.py index c6cb9c17..9a9711e3 100644 --- a/src/ripplegw/waveforms/IMRPhenomXP_utils.py +++ b/src/ripplegw/waveforms/IMRPhenomXP_utils.py @@ -841,7 +841,9 @@ def IMRPhenomXGetAndSetPrecessionVariables(pWF, m1_SI, m2_SI, chi2x, chi2y, chi2z, lalParams): - pPrec['ExpansionOrder'] = XLALSimInspiralWaveformParamsLookupPhenomXPExpansionOrder(lalParams) + pPrec = {} + + pPrec['ExpansionOrder'] = XLALSimInspiralWaveformParamsLookupPhenomXPExpansionOrder(lalParams) ## (TODO) Mtot_SI = m1_SI + m2_SI # Normalize masses @@ -854,7 +856,11 @@ def IMRPhenomXGetAndSetPrecessionVariables(pWF, m1_SI, m2_SI, q = m1 / m2 eta = pWF[1] - ## TODO: compute delta? + m1_2 = m1**2 + m2_2 = m2**2 + + ## TODO: check how is delta stored in pWF + delta = pWF['delta'] ## TODO: compute chieff? ## TODO: compute twopiGM, piGM? @@ -887,7 +893,286 @@ def IMRPhenomXGetAndSetPrecessionVariables(pWF, m1_SI, m2_SI, pPrec['chiTot_perp'] = pPrec['STot_perp'] * (M**2) / m1_2 # pWF['chiTot_perp'] = pPrec['chiTot_perp'] ### pWF needs to be a dict?? - return 0 # Success + ## disable tuned PNR angles, tuned coprec and mode asymmetries in low in-plane spin limit (TODO) + + pPrec['A1'] = 2.0 + (3.0 * m2) / (2.0 * m1) + pPrec['A2'] = 2.0 + (3.0 * m1) / (2.0 * m2) + + pPrec['ASp1'] = pPrec['A1'] * pPrec['S1_perp'] + pPrec['ASp2'] = pPrec['A2'] * pPrec['S2_perp'] + + # S_p = max(A1 S1_perp, A2 S2_perp) + num = jnp.where(pPrec['ASp2'] > pPrec['ASp1'], pPrec['ASp2'], pPrec['ASp1']) + + den = jnp.where(m2 > m1, pPrec['A2'] * m2_2, pPrec['A1'] * m1_2) + + chip = num / den + chi1L = chi1z + chi2L = chi2z + + pPrec['chi_p'] = chip + pWF['chi_p'] = pPrec['chi_p'] # propagate to waveform struct ### pWF needs to be a dict?? + pPrec['phi0_aligned'] = pWF['phi0'] + + # Effective (dimensionful) aligned spin + pPrec['SL'] = chi1L * m1_2 + chi2L * m2_2 + + # Effective (dimensionful) in-plane spin + # (code assumes m1 > m2) + pPrec['Sperp'] = chip * m1_2 + + # initialize error flag + pPrec['MSA_ERROR'] = 0 + + pPrec['pWF22AS'] = None + + pPrec['precessing_tag'] = 2 + + # Initialize PNR variables + pPrec['chi_singleSpin'] = 0.0 + pPrec['costheta_singleSpin'] = 0.0 + pPrec['costheta_final_singleSpin'] = 0.0 + pPrec['chi_singleSpin_antisymmetric'] = 0.0 + pPrec['theta_antisymmetric'] = 0.0 + pPrec['PNR_HM_Mflow'] = 0.0 + pPrec['PNR_HM_Mfhigh'] = 0.0 + + pPrec['PNR_q_window_lower'] = 0.0 + pPrec['PNR_q_window_upper'] = 0.0 + pPrec['PNR_chi_window_lower'] = 0.0 + pPrec['PNR_chi_window_upper'] = 0.0 + # pPrec['PNRInspiralScaling'] = 0 # if needed + + # Call external routine + status = IMRPhenomX_PNR_GetAndSetPNRVariables(pWF, pPrec) ## (TODO) + + # XLAL_CHECK equivalent (trigger MSA_ERROR if failed) + pPrec['MSA_ERROR'] = status + + # Reset alpha, beta, gamma + pPrec['alphaPNR'] = 0.0 + pPrec['betaPNR'] = 0.0 + pPrec['gammaPNR'] = 0.0 + + # Get and/or store CoPrec params + status = IMRPhenomX_PNR_GetAndSetCoPrecParams(pWF, pPrec, lalParams) ## (TODO) + + pPrec['MSA_ERROR'] = status + + pPrec = IMRPhenomX_Initialize_MSA_System(pWF,pPrec,pPrec['ExpansionOrder']) ## is it okay to update + + IMRPhenomX_SetPrecessingRemnantParams(pWF,pPrec,lalParams) ## (TODO) + + chip2 = chip * chip + + chi1L2 = chi1L * chi1L + chi2L2 = chi2L * chi2L + + pPrec['L0'] = 1.0 + pPrec['L1'] = 0.0 + pPrec['L2'] = 1.5 + eta / 6.0 + pPrec['L3'] = (-7.0 * (chi1L + chi2L + chi1L * delta - chi2L * delta) + + 5.0 * (chi1L + chi2L) * eta) / 6.0 + pPrec['L4'] = (81.0 + (-57.0 + eta) * eta) / 24.0 + pPrec['L5'] = (-1650.0 * (chi1L + chi2L + chi1L * delta - chi2L * delta) + + 1336.0 * (chi1L + chi2L) * eta + + 511.0 * (chi1L - chi2L) * delta * eta + + 28.0 * (chi1L + chi2L) * eta**2) / 600.0 + pPrec['L6'] = (10935.0 + eta * (-62001.0 + 1674.0 * eta + 7.0 * eta**2 + + 2214.0 * jnp.pi**2)) / 1296.0 + pPrec['L7'] = 0.0 + pPrec['L8'] = 0.0 + + pPrec['L8L'] = 0.0 + + pPrec['LRef'] = ( + M * M * XLALSimIMRPhenomXLPNAnsatz( ## (TODO) + pWF['v_ref'], ## (TODO) check if variables in pWF are stored correctly + pWF['eta'] / pWF['v_ref'], + pPrec['L0'], pPrec['L1'], pPrec['L2'], pPrec['L3'], pPrec['L4'], + pPrec['L5'], pPrec['L6'], pPrec['L7'], pPrec['L8'], pPrec['L8L'] + ) + ) + + # Source frame J = L + S1 + S2 + pPrec['J0x_Sf'] = m1_2 * chi1x + m2_2 * chi2x + pPrec['J0y_Sf'] = m1_2 * chi1y + m2_2 * chi2y + pPrec['J0z_Sf'] = m1_2 * chi1z + m2_2 * chi2z + pPrec['LRef'] + + pPrec['J0'] = jnp.sqrt( + pPrec['J0x_Sf'] * pPrec['J0x_Sf'] + + pPrec['J0y_Sf'] * pPrec['J0y_Sf'] + + pPrec['J0z_Sf'] * pPrec['J0z_Sf'] + ) + + # Angle between J0 and LN (z-direction) + def small_J0_case(_): + return 0.0 + + def general_case(pPrec): + return jnp.arccos(pPrec['J0z_Sf'] / pPrec['J0']) + + pPrec['thetaJ_Sf'] = jnp.where(pPrec['J0'] < 1e-10, + small_J0_case(pPrec), + general_case(pPrec)) + + phiRef = pWF['phiRef_In'] ## (TODO) check if variables in pWF are stored correctly + + convention = 1 + + max_tol_condition = jnp.logical_and( + jnp.abs(pPrec['J0x_Sf']) < 1e-15, + jnp.abs(pPrec['J0y_Sf']) < 1e-15 + ) + + pPrec['phiJ_Sf'] = jnp.where( + max_tol_condition, + 0.0, + jnp.arctan2(pPrec['J0y_Sf'], pPrec['J0x_Sf']) + ) + + pPrec['phi0_aligned'] = -pPrec['phiJ_Sf'] + pWF['phi0'] = 0 ## (TODO) check if variables in pWF are stored correctly + + # Now rotate from SF to J frame to compute alpha0, the azimuthal angle of LN, as well as + # thetaJ, the angle between J and N. + + pPrec['Nx_Sf'] = jnp.sin(pWF['inclination']) * jnp.cos((jnp.pi / 2.0) - pWF['phiRef_In']) + pPrec['Ny_Sf'] = jnp.sin(pWF['inclination']) * jnp.sin((jnp.pi / 2.0) - pWF['phiRef_In']) + pPrec['Nz_Sf'] = jnp.cos(pWF['inclination']) + + # Temporary vector copy + tmp_v = jnp.array([pPrec['Nx_Sf'], pPrec['Ny_Sf'], pPrec['Nz_Sf']]) + + # Rotate around z, then y + tmp_v = IMRPhenomX_rotate_z(-pPrec['phiJ_Sf'], tmp_v) + tmp_v = IMRPhenomX_rotate_y(-pPrec['thetaJ_Sf'], tmp_v) + + ## Difference in overall - sign w.r.t PhenomPv2 code + pPrec['kappa'] = jnp.where((jnp.abs(tmp_v[0]) < 1e-15) & (jnp.abs(tmp_v[1]) < 1e-15), 0.0, jnp.arctan2(tmp_v[1], tmp_v[0])) + + ## Now determine alpha0 by rotating LN. In the source frame, LN = {0,0,1} + tmp_v = jnp.array([0,0,1]) + tmp_v = IMRPhenomX_rotate_z(-pPrec['phiJ_Sf'], tmp_v) + tmp_v = IMRPhenomX_rotate_y(-pPrec['thetaJ_Sf'], tmp_v) + tmp_v = IMRPhenomX_rotate_z(-pPrec['kappa'], tmp_v) + + pPrec['alpha0'] = jnp.pi - pPrec['kappa'] + + J0dotN = ( + pPrec['J0x_Sf'] * pPrec['Nx_Sf'] + + pPrec['J0y_Sf'] * pPrec['Ny_Sf'] + + pPrec['J0z_Sf'] * pPrec['Nz_Sf'] + ) + + pPrec['thetaJN'] = jnp.arccos(J0dotN / pPrec['J0']) + + pPrec['Nz_Jf'] = jnp.cos(pPrec['thetaJN']) + pPrec['Nx_Jf'] = jnp.sin(pPrec['thetaJN']) + + """ + Define the polarizations used. This follows the conventions adopted for IMRPhenomPv2. + + The IMRPhenomP polarizations are defined following the conventions in Arun et al (arXiv:0810.5336), + i.e. projecting the metric onto the P, Q, N triad defining where: P = (N x J) / |N x J|. + + However, the triad X,Y,N used in LAL (the "waveframe") follows the definition in the + NR Injection Infrastructure (Schmidt et al, arXiv:1703.01076). + + The triads differ from each other by a rotation around N by an angle \zeta. We therefore need to rotate + the polarizations by an angle 2 \zeta. + """ + + pPrec['Xx_Sf'] = -jnp.cos(pWF['inclination']) * jnp.sin(phiRef) + pPrec['Xy_Sf'] = -jnp.cos(pWF['inclination']) * jnp.cos(phiRef) + pPrec['Xz_Sf'] = jnp.sin(pWF['inclination']) + + tmp_v = jnp.array([pPrec['Xx_Sf'],pPrec['Xy_Sf'],pPrec['Xz_Sf']]) + tmp_v = IMRPhenomX_rotate_z(-pPrec['phiJ_Sf'], tmp_v) + tmp_v = IMRPhenomX_rotate_y(-pPrec['thetaJ_Sf'], tmp_v) + tmp_v = IMRPhenomX_rotate_z(-pPrec['kappa'], tmp_v) + + pPrec['PArun_Jf'] = jnp.array([pPrec['Nz_Jf'],0,-pPrec['Nx_Jf']]) + + # Q = (N x P) by construction + pPrec['QArun_Jf'] = jnp.array([0,1,0]) + + pPrec['XdotPArun'] = jnp.dot(tmp_v, pPrec['PArun_Jf']) + pPrec['XdotQArun'] = jnp.dot(tmp_v, pPrec['QArun_Jf']) + + pPrec['zeta_polarization'] = jnp.arctan2(pPrec['XdotQArun'], pPrec['XdotPArun']) + + pPrec['alpha1'] = 0.0 + pPrec['alpha2'] = 0.0 + pPrec['alpha3'] = 0.0 + pPrec['alpha4L'] = 0.0 + pPrec['alpha5'] = 0.0 + pPrec['epsilon1'] = 0.0 + pPrec['epsilon2'] = 0.0 + pPrec['epsilon3'] = 0.0 + pPrec['epsilon4L'] = 0.0 + pPrec['epsilon5'] = 0.0 + + pPrec['epsilon0'] = pPrec['phiJ_Sf'] - jnp.pi + + alpha_offset, epsilon_offset = Get_alphaepsilon_atfref(2, pPrec, pWF) ## (TODO) + pPrec['alpha_offset'] = alpha_offset + pPrec['epsilon_offset'] = epsilon_offset + pPrec['alpha_offset_1'] = alpha_offset + pPrec['epsilon_offset_1'] = epsilon_offset + pPrec['alpha_offset_3'] = alpha_offset + pPrec['epsilon_offset_3'] = epsilon_offset + pPrec['alpha_offset_4'] = alpha_offset + pPrec['epsilon_offset_4'] = epsilon_offset + + pPrec['cexp_i_alpha'] = 0.0 + pPrec['cexp_i_epsilon'] = 0.0 + pPrec['cexp_i_betah'] = 0.0 + + """ + Check whether maximum opening angle becomes larger than \pi/2 or \pi/4. + + If (L + S_L) < 0, then Wigner-d Coefficients will not track the angle between J and L, meaning + that the model may become pathological as one moves away from the aligned-spin limit. + + If this does not happen, then max_beta will be the actual maximum opening angle. + + This function uses a 2PN non-spinning approximation to the orbital angular momentum L, as + the roots can be analytically derived. + """ + + # lalParams key, is set to zero if max_beta > Pi/2, q>7 and conditionalPrecMBand = 1, + # thus disabling multibanding + PrecThresholdMband = IMRPhenomXPCheckMaxOpeningAngle(pWF, pPrec) + + ## TODO: conditions to switch on/ off multibanding + + ytheta = pPrec['thetaJN'] + yphi = 0.0 + pPrec['Y2m2'] = XLALSpinWeightedSphericalHarmonic(ytheta, yphi, -2, 2, -2) ## (TODO) + pPrec['Y2m1'] = XLALSpinWeightedSphericalHarmonic(ytheta, yphi, -2, 2, -1) + pPrec['Y20'] = XLALSpinWeightedSphericalHarmonic(ytheta, yphi, -2, 2, 0) + pPrec['Y21'] = XLALSpinWeightedSphericalHarmonic(ytheta, yphi, -2, 2, 1) + pPrec['Y22'] = XLALSpinWeightedSphericalHarmonic(ytheta, yphi, -2, 2, 2) + pPrec['Y3m3'] = XLALSpinWeightedSphericalHarmonic(ytheta, yphi, -2, 3, -3) + pPrec['Y3m2'] = XLALSpinWeightedSphericalHarmonic(ytheta, yphi, -2, 3, -2) + pPrec['Y3m1'] = XLALSpinWeightedSphericalHarmonic(ytheta, yphi, -2, 3, -1) + pPrec['Y30'] = XLALSpinWeightedSphericalHarmonic(ytheta, yphi, -2, 3, 0) + pPrec['Y31'] = XLALSpinWeightedSphericalHarmonic(ytheta, yphi, -2, 3, 1) + pPrec['Y32'] = XLALSpinWeightedSphericalHarmonic(ytheta, yphi, -2, 3, 2) + pPrec['Y33'] = XLALSpinWeightedSphericalHarmonic(ytheta, yphi, -2, 3, 3) + pPrec['Y4m4'] = XLALSpinWeightedSphericalHarmonic(ytheta, yphi, -2, 4, -4) + pPrec['Y4m3'] = XLALSpinWeightedSphericalHarmonic(ytheta, yphi, -2, 4, -3) + pPrec['Y4m2'] = XLALSpinWeightedSphericalHarmonic(ytheta, yphi, -2, 4, -2) + pPrec['Y4m1'] = XLALSpinWeightedSphericalHarmonic(ytheta, yphi, -2, 4, -1) + pPrec['Y40'] = XLALSpinWeightedSphericalHarmonic(ytheta, yphi, -2, 4, 0) + pPrec['Y41'] = XLALSpinWeightedSphericalHarmonic(ytheta, yphi, -2, 4, 1) + pPrec['Y42'] = XLALSpinWeightedSphericalHarmonic(ytheta, yphi, -2, 4, 2) + pPrec['Y43'] = XLALSpinWeightedSphericalHarmonic(ytheta, yphi, -2, 4, 3) + pPrec['Y44'] = XLALSpinWeightedSphericalHarmonic(ytheta, yphi, -2, 4, 4) + + return pPrec + def IMRPhenomX_Initialize_MSA_System(pWF, pPrec, ExpansionOrder): eta = pPrec['eta'] eta2 = pPrec['eta2'] @@ -1318,6 +1603,52 @@ def case3(): return pPrec +def IMRPhenomXPCheckMaxOpeningAngle( + pWF, + pPrec +): + eta = pWF['eta'] + + # For now, use the 2PN non-spinning maximum opening angle + inner_sqrt = 1539.0 - 1008.0 * eta + 19.0 * eta * eta + numerator = -9.0 - eta + jnp.sqrt(inner_sqrt) + denominator = 81.0 - 57.0 * eta + eta * eta + v_at_max_beta = jnp.sqrt(2.0 / 3.0) * jnp.sqrt(numerator / denominator) + + status, cBetah, sBetah = IMRPhenomXWignerdCoefficients(cBetah, sBetah, v_at_max_beta, pWF, pPrec) ## (TODO) + + jax.lax.cond( + status != 0, + lambda s: jax.debug.print("Call to IMRPhenomXWignerdCoefficients failed."), + lambda s: s, + status + ) + + L_min = XLALSimIMRPhenomXL2PNNS(v_at_max_beta, eta) ## (TODO) + max_beta = 2.0 * jnp.arccos(cBetah) + + def positive_case(vals): + q, conditionalPrecMBand = vals[0,1] + jax.debug.print('The maximum opening angle exceeds Pi/2. The model may be pathological in this regime.') + return jax.lax.cond((q > 7.0 and conditionalPrecMBand == 1), lambda x: x*0, lambda x: x, vals[3]) + + def negative_case(vals): + max_beta = vals[2] + PrecThresholdMband = vals[3] + return PrecThresholdMband + + L_plus_SL = L_min + pPrec['SL'] + + ### missing a warning message if negatice case and max_beta > PI/4 (TODO) + + return jax.lax.cond( + (L_plus_SL < 0.0) & (pPrec['chi_p'] > 0.0), + positive_case, + negative_case, + (pWF['q'], pPrec['conditionalPrecMBand'], max_beta, pPrec['PrecThresholdMband']) + ) + + def IMRPhenomX_Get_PN_beta(a, b, pPrec): return (pPrec['dotS1L'] * (a + b * pPrec['qq']) + pPrec['dotS2L'] * (a + b / pPrec['qq'])) @@ -1334,4 +1665,22 @@ def IMRPhenomX_Get_PN_tau(a, b, pPrec): def IMRPhenomX_psiofv(v, v2, psi0, psi1, psi2, pPrec): # Equation 51 in arXiv:1703.03967 - return psi0 - 0.75 * pPrec['g0'] * pPrec['delta_qq'] * (1.0 + psi1 * v + psi2 * v2) / (v2 * v) \ No newline at end of file + return psi0 - 0.75 * pPrec['g0'] * pPrec['delta_qq'] * (1.0 + psi1 * v + psi2 * v2) / (v2 * v) + +def IMRPhenomX_rotate_y(v, theta): + """Rotate vector(s) v about the y-axis by angle theta (radians).""" + R_y = jnp.array([ + [ jnp.cos(theta), 0.0, jnp.sin(theta)], + [ 0.0, 1.0, 0.0 ], + [-jnp.sin(theta), 0.0, jnp.cos(theta)] + ]) + return R_y @ v + +def IMRPhenomX_rotate_z(v, theta): + """Rotate vector(s) v about the z-axis by angle theta (radians).""" + R_z = jnp.array([ + [ jnp.cos(theta), -jnp.sin(theta), 0.0], + [ jnp.sin(theta), jnp.cos(theta), 0.0], + [ 0.0, 0.0, 1.0] + ]) + return R_z @ v \ No newline at end of file From 054357fa5621a62c2d8e0da6f9b2792d1f03cfef Mon Sep 17 00:00:00 2001 From: "leonardo.ricca" Date: Mon, 6 Oct 2025 18:04:27 +0200 Subject: [PATCH 08/20] Implement X_SetPrecessingRemnantParams --- src/ripplegw/waveforms/IMRPhenomXP_utils.py | 53 ++++++++++++++++++++- 1 file changed, 52 insertions(+), 1 deletion(-) diff --git a/src/ripplegw/waveforms/IMRPhenomXP_utils.py b/src/ripplegw/waveforms/IMRPhenomXP_utils.py index 9a9711e3..bb0c441a 100644 --- a/src/ripplegw/waveforms/IMRPhenomXP_utils.py +++ b/src/ripplegw/waveforms/IMRPhenomXP_utils.py @@ -1615,7 +1615,7 @@ def IMRPhenomXPCheckMaxOpeningAngle( denominator = 81.0 - 57.0 * eta + eta * eta v_at_max_beta = jnp.sqrt(2.0 / 3.0) * jnp.sqrt(numerator / denominator) - status, cBetah, sBetah = IMRPhenomXWignerdCoefficients(cBetah, sBetah, v_at_max_beta, pWF, pPrec) ## (TODO) + status, cBetah, sBetah = WignerdCoefficients(cBetah, sBetah, v_at_max_beta, pWF, pPrec) jax.lax.cond( status != 0, @@ -1648,6 +1648,57 @@ def negative_case(vals): (pWF['q'], pPrec['conditionalPrecMBand'], max_beta, pPrec['PrecThresholdMband']) ) +def IMRPhenomX_SetPrecessingRemnantParams(pWF, pPrec, params): + + # TODO: Not yet implemented toggles: + ## Toggle for PNR coprecessing tuning. PNRUseTunedCoprec == 0 for the moment + ## Toggle for enforced use of non-precessing spin as is required during tuning of PNR's coprecessing model + ## High-level toggle for whether to apply deviations. APPLY_PNR_DEVIATIONS == 0 for the moment + + af_parallel = XLALSimIMRPhenomXFinalSpin2017(pWF['eta'], pPrec['chi1z'], pPrec['chi2z']) ## (TODO) + M = pWF['M'] + Lfinal = M * M * af_parallel - pWF['m1_2'] * pPrec['chi1z'] - pWF['m2_2'] * pPrec['chi2z'] + + # Just implementing FinalSpinMod fsflag == 3, since it is default for MSA approximation + # see arXiv:2004.06503v2 + # Just implementing branch for MSA_ERROR == 0. No fallback to NNLO for the moment + + # what is this flag? + method_flag = XLALSimInspiralWaveformParamsLookupPhenomXPTransPrecessionMethod( ## (TODO) + params + ) + sign = jax.lax.cond( + method_flag == 1, + lambda _: jnp.copysign(1.0, af_parallel), + lambda _: 1.0, + operand=None, + ) + pWF['afinal_prec'] = sign * jnp.sqrt(pPrec['SAv2'] + Lfinal**2 + 2.0 * Lfinal * (pPrec['S1L_pav'] + pPrec['S2L_pav']) + ) / (M**2) + + pWF['afinal'] = pWF['afinal_prec'] + + def clip_afinal_prec(_): + afinal_prec = jnp.copysign(1.0, pWF['afinal_prec']) + afinal = jnp.copysign(1.0, pWF['afinal']) + return {**pWF, 'afinal_prec': afinal_prec, 'afinal': afinal} + + def no_clip(_): + return pWF + + pWF = jax.lax.cond( + jnp.abs(pWF['afinal_prec']) > 1.0, + clip_afinal_prec, + no_clip, + operand=None, + ) + + # Update ringdown and damping frequency: no precession; to be used for PNR tuned deviations + pWF['fRING'] = evaluate_QNMfit_fring22(pWF['afinal']) / pWF['Mfinal'] ## (TODO) + pWF['fDAMP'] = evaluate_QNMfit_fdamp22(pWF['afinal']) / pWF['Mfinal'] ## (TODO) + + return pWF + def IMRPhenomX_Get_PN_beta(a, b, pPrec): return (pPrec['dotS1L'] * (a + b * pPrec['qq']) + From 5222eb9a86aaa2b35a13e76a8fdbf018df9dee6f Mon Sep 17 00:00:00 2001 From: "leonardo.ricca" Date: Tue, 7 Oct 2025 18:02:31 +0200 Subject: [PATCH 09/20] Remove functions in XP_utils that were copied from Pv2_utils --- src/ripplegw/waveforms/IMRPhenomXP_utils.py | 391 +------------------- 1 file changed, 5 insertions(+), 386 deletions(-) diff --git a/src/ripplegw/waveforms/IMRPhenomXP_utils.py b/src/ripplegw/waveforms/IMRPhenomXP_utils.py index bb0c441a..3dbdeb33 100644 --- a/src/ripplegw/waveforms/IMRPhenomXP_utils.py +++ b/src/ripplegw/waveforms/IMRPhenomXP_utils.py @@ -29,387 +29,7 @@ def ROTATEY(angle, x, y, z): tmp_z = -x * jnp.sin(angle) + z * jnp.cos(angle) return tmp_x, y, tmp_z - -def FinalSpin0815(eta, chi1, chi2): - Seta = jnp.sqrt(1.0 - 4.0 * eta) - m1 = 0.5 * (1.0 + Seta) - m2 = 0.5 * (1.0 - Seta) - m1s = m1 * m1 - m2s = m2 * m2 - s = m1s * chi1 + m2s * chi2 - return FinalSpin0815_s(eta, s) - - -def convert_spins( - m1: float, - m2: float, - f_ref: float, - phiRef: float, - incl: float, - s1x: float, - s1y: float, - s1z: float, - s2x: float, - s2y: float, - s2z: float, -) -> Tuple[float, float, float, float, float, float, float]: - # m1 = m1_SI / MSUN # Masses in solar masses - # m2 = m2_SI / MSUN - M = m1 + m2 - m1_2 = m1 * m1 - m2_2 = m2 * m2 - eta = m1 * m2 / (M * M) # Symmetric mass-ratio - - # From the components in the source frame, we can easily determine - # chi1_l, chi2_l, chip and phi_aligned, which we need to return. - # We also compute the spherical angles of J, - # which we need to transform to the J frame - - # Aligned spins - chi1_l = s1z # Dimensionless aligned spin on BH 1 - chi2_l = s2z # Dimensionless aligned spin on BH 2 - - # Magnitude of the spin projections in the orbital plane - S1_perp = m1_2 * jnp.sqrt(s1x**2 + s1y**2) - S2_perp = m2_2 * jnp.sqrt(s2x**2 + s2y**2) - - A1 = 2 + (3 * m2) / (2 * m1) - A2 = 2 + (3 * m1) / (2 * m2) - ASp1 = A1 * S1_perp - ASp2 = A2 * S2_perp - num = jnp.maximum(ASp1, ASp2) - den = A2 * m2_2 # warning: this assumes m2 > m1 - chip = num / den - - m_sec = M * gt - piM = jnp.pi * m_sec - v_ref = (piM * f_ref) ** (1 / 3) - L0 = M * M * L2PNR(v_ref, eta) - J0x_sf = m1_2 * s1x + m2_2 * s2x - J0y_sf = m1_2 * s1y + m2_2 * s2y - J0z_sf = L0 + m1_2 * s1z + m2_2 * s2z - J0 = jnp.sqrt(J0x_sf * J0x_sf + J0y_sf * J0y_sf + J0z_sf * J0z_sf) - - thetaJ_sf = jnp.arccos(J0z_sf / J0) - - phiJ_sf = jnp.arctan2(J0y_sf, J0x_sf) - - phi_aligned = -phiJ_sf - - # First we determine kappa - # in the source frame, the components of N are given in Eq (35c) of T1500606-v6 - Nx_sf = jnp.sin(incl) * jnp.cos(jnp.pi / 2.0 - phiRef) - Ny_sf = jnp.sin(incl) * jnp.sin(jnp.pi / 2.0 - phiRef) - Nz_sf = jnp.cos(incl) - - tmp_x = Nx_sf - tmp_y = Ny_sf - tmp_z = Nz_sf - - tmp_x, tmp_y, tmp_z = ROTATEZ(-phiJ_sf, tmp_x, tmp_y, tmp_z) - tmp_x, tmp_y, tmp_z = ROTATEY(-thetaJ_sf, tmp_x, tmp_y, tmp_z) - - kappa = -jnp.arctan2(tmp_y, tmp_x) - - # Then we determine alpha0, by rotating LN - tmp_x, tmp_y, tmp_z = 0, 0, 1 - tmp_x, tmp_y, tmp_z = ROTATEZ(-phiJ_sf, tmp_x, tmp_y, tmp_z) - tmp_x, tmp_y, tmp_z = ROTATEY(-thetaJ_sf, tmp_x, tmp_y, tmp_z) - tmp_x, tmp_y, tmp_z = ROTATEZ(kappa, tmp_x, tmp_y, tmp_z) - - alpha0 = jnp.arctan2(tmp_y, tmp_x) - - # Finally we determine thetaJ, by rotating N - tmp_x, tmp_y, tmp_z = Nx_sf, Ny_sf, Nz_sf - tmp_x, tmp_y, tmp_z = ROTATEZ(-phiJ_sf, tmp_x, tmp_y, tmp_z) - tmp_x, tmp_y, tmp_z = ROTATEY(-thetaJ_sf, tmp_x, tmp_y, tmp_z) - tmp_x, tmp_y, tmp_z = ROTATEZ(kappa, tmp_x, tmp_y, tmp_z) - Nx_Jf, Nz_Jf = tmp_x, tmp_z - thetaJN = jnp.arccos(Nz_Jf) - - # Finally, we need to redefine the polarizations: - # PhenomP's polarizations are defined following Arun et al (arXiv:0810.5336) - # i.e. projecting the metric onto the P,Q,N triad defined with P=NxJ/|NxJ| (see (2.6) in there). - # By contrast, the triad X,Y,N used in LAL - # ("waveframe" in the nomenclature of T1500606-v6) - # is defined in e.g. eq (35) of this document - # (via its components in the source frame; note we use the defautl Omega=Pi/2). - # Both triads differ from each other by a rotation around N by an angle \zeta - # and we need to rotate the polarizations accordingly by 2\zeta - - Xx_sf = -jnp.cos(incl) * jnp.sin(phiRef) - Xy_sf = -jnp.cos(incl) * jnp.cos(phiRef) - Xz_sf = jnp.sin(incl) - tmp_x, tmp_y, tmp_z = Xx_sf, Xy_sf, Xz_sf - tmp_x, tmp_y, tmp_z = ROTATEZ(-phiJ_sf, tmp_x, tmp_y, tmp_z) - tmp_x, tmp_y, tmp_z = ROTATEY(-thetaJ_sf, tmp_x, tmp_y, tmp_z) - tmp_x, tmp_y, tmp_z = ROTATEZ(kappa, tmp_x, tmp_y, tmp_z) - - # Now the tmp_a are the components of X in the J frame - # We need the polar angle of that vector in the P,Q basis of Arun et al - # P = NxJ/|NxJ| and since we put N in the (pos x)z half plane of the J frame - PArunx_Jf = 0.0 - PAruny_Jf = -1.0 - PArunz_Jf = 0.0 - - # Q = NxP - QArunx_Jf = Nz_Jf - QAruny_Jf = 0.0 - QArunz_Jf = -Nx_Jf - - # Calculate the dot products XdotPArun and XdotQArun - XdotPArun = tmp_x * PArunx_Jf + tmp_y * PAruny_Jf + tmp_z * PArunz_Jf - XdotQArun = tmp_x * QArunx_Jf + tmp_y * QAruny_Jf + tmp_z * QArunz_Jf - - zeta_polariz = jnp.arctan2(XdotQArun, XdotPArun) - return chi1_l, chi2_l, chip, thetaJN, alpha0, phi_aligned, zeta_polariz - - -def SpinWeightedY(theta, phi, s, l, m): - "copied from SphericalHarmonics.c in LAL" - if s == -2: - if l == 2: - if m == -2: - fac = ( - jnp.sqrt(5.0 / (64.0 * jnp.pi)) - * (1.0 - jnp.cos(theta)) - * (1.0 - jnp.cos(theta)) - ) - elif m == -1: - fac = ( - jnp.sqrt(5.0 / (16.0 * jnp.pi)) - * jnp.sin(theta) - * (1.0 - jnp.cos(theta)) - ) - elif m == 0: - fac = jnp.sqrt(15.0 / (32.0 * jnp.pi)) * jnp.sin(theta) * jnp.sin(theta) - elif m == 1: - fac = ( - jnp.sqrt(5.0 / (16.0 * jnp.pi)) - * jnp.sin(theta) - * (1.0 + jnp.cos(theta)) - ) - elif m == 2: - fac = ( - jnp.sqrt(5.0 / (64.0 * jnp.pi)) - * (1.0 + jnp.cos(theta)) - * (1.0 + jnp.cos(theta)) - ) - else: - raise ValueError(f"Invalid mode s={s}, l={l}, m={m} - require |m| <= l") - return fac * np.exp(1j * m * phi) - - -def L2PNR(v: float, eta: float) -> float: - eta2 = eta**2 - x = v**2 - x2 = x**2 - return ( - eta - * ( - 1.0 - + (1.5 + eta / 6.0) * x - + (3.375 - (19.0 * eta) / 8.0 - eta2 / 24.0) * x2 - ) - ) / x**0.5 - - -def WignerdCoefficients(v: float, SL: float, eta: float, Sp: float): - # We define the shorthand s := Sp / (L + SL) - #Sp: perpendicular, SL: parallel - L = L2PNR(v, eta) #### what is this function? - s = Sp / (L + SL) - s2 = s**2 - cos_beta = 1.0 / (1.0 + s2) ** 0.5 - cos_beta_half = ((1.0 + cos_beta) / 2.0) ** 0.5 # cos(beta/2) - sin_beta_half = ((1.0 - cos_beta) / 2.0) ** 0.5 # sin(beta/2) - - return cos_beta_half, sin_beta_half - - -def ComputeNNLOanglecoeffs(q, chil, chip): - ##### -> IMRPhenomXP.pdf for coefficients, same as IMRPhenomPV - m2 = q / (1.0 + q) - m1 = 1.0 / (1.0 + q) - dm = m1 - m2 - mtot = 1.0 - eta = m1 * m2 # mtot = 1 - eta2 = eta * eta - eta3 = eta2 * eta - eta4 = eta3 * eta - mtot2 = mtot * mtot - mtot4 = mtot2 * mtot2 - mtot6 = mtot4 * mtot2 - mtot8 = mtot6 * mtot2 - chil2 = chil * chil - chip2 = chip * chip - chip4 = chip2 * chip2 - dm2 = dm * dm - dm3 = dm2 * dm - m2_2 = m2 * m2 - m2_3 = m2_2 * m2 - m2_4 = m2_3 * m2 - m2_5 = m2_4 * m2 - m2_6 = m2_5 * m2 - m2_7 = m2_6 * m2 - m2_8 = m2_7 * m2 - - angcoeffs = {} - angcoeffs["alphacoeff1"] = -0.18229166666666666 - (5 * dm) / (64.0 * m2) - - angcoeffs["alphacoeff2"] = (-15 * dm * m2 * chil) / (128.0 * mtot2 * eta) - ( - 35 * m2_2 * chil - ) / (128.0 * mtot2 * eta) - - angcoeffs["alphacoeff3"] = ( - -1.7952473958333333 - - (4555 * dm) / (7168.0 * m2) - - (15 * chip2 * dm * m2_3) / (128.0 * mtot4 * eta2) - - (35 * chip2 * m2_4) / (128.0 * mtot4 * eta2) - - (515 * eta) / 384.0 - - (15 * dm2 * eta) / (256.0 * m2_2) - - (175 * dm * eta) / (256.0 * m2) - ) - - angcoeffs["alphacoeff4"] = ( - -(35 * jnp.pi) / 48.0 - - (5 * dm * jnp.pi) / (16.0 * m2) - + (5 * dm2 * chil) / (16.0 * mtot2) - + (5 * dm * m2 * chil) / (3.0 * mtot2) - + (2545 * m2_2 * chil) / (1152.0 * mtot2) - - (5 * chip2 * dm * m2_5 * chil) / (128.0 * mtot6 * eta3) - - (35 * chip2 * m2_6 * chil) / (384.0 * mtot6 * eta3) - + (2035 * dm * m2 * chil) / (21504.0 * mtot2 * eta) - + (2995 * m2_2 * chil) / (9216.0 * mtot2 * eta) - ) - - angcoeffs["alphacoeff5"] = ( - 4.318908476114694 - + (27895885 * dm) / (2.1676032e7 * m2) - - (15 * chip4 * dm * m2_7) / (512.0 * mtot8 * eta4) - - (35 * chip4 * m2_8) / (512.0 * mtot8 * eta4) - - (485 * chip2 * dm * m2_3) / (14336.0 * mtot4 * eta2) - + (475 * chip2 * m2_4) / (6144.0 * mtot4 * eta2) - + (15 * chip2 * dm2 * m2_2) / (256.0 * mtot4 * eta) - + (145 * chip2 * dm * m2_3) / (512.0 * mtot4 * eta) - + (575 * chip2 * m2_4) / (1536.0 * mtot4 * eta) - + (39695 * eta) / 86016.0 - + (1615 * dm2 * eta) / (28672.0 * m2_2) - - (265 * dm * eta) / (14336.0 * m2) - + (955 * eta2) / 576.0 - + (15 * dm3 * eta2) / (1024.0 * m2_3) - + (35 * dm2 * eta2) / (256.0 * m2_2) - + (2725 * dm * eta2) / (3072.0 * m2) - - (15 * dm * m2 * jnp.pi * chil) / (16.0 * mtot2 * eta) - - (35 * m2_2 * jnp.pi * chil) / (16.0 * mtot2 * eta) - + (15 * chip2 * dm * m2_7 * chil2) / (128.0 * mtot8 * eta4) - + (35 * chip2 * m2_8 * chil2) / (128.0 * mtot8 * eta4) - + (375 * dm2 * m2_2 * chil2) / (256.0 * mtot4 * eta) - + (1815 * dm * m2_3 * chil2) / (256.0 * mtot4 * eta) - + (1645 * m2_4 * chil2) / (192.0 * mtot4 * eta) - ) - - angcoeffs["epsiloncoeff1"] = -0.18229166666666666 - (5 * dm) / (64.0 * m2) - angcoeffs["epsiloncoeff2"] = (-15 * dm * m2 * chil) / (128.0 * mtot2 * eta) - ( - 35 * m2_2 * chil - ) / (128.0 * mtot2 * eta) - angcoeffs["epsiloncoeff3"] = ( - -1.7952473958333333 - - (4555 * dm) / (7168.0 * m2) - - (515 * eta) / 384.0 - - (15 * dm2 * eta) / (256.0 * m2_2) - - (175 * dm * eta) / (256.0 * m2) - ) - angcoeffs["epsiloncoeff4"] = ( - -(35 * jnp.pi) / 48.0 - - (5 * dm * jnp.pi) / (16.0 * m2) - + (5 * dm2 * chil) / (16.0 * mtot2) - + (5 * dm * m2 * chil) / (3.0 * mtot2) - + (2545 * m2_2 * chil) / (1152.0 * mtot2) - + (2035 * dm * m2 * chil) / (21504.0 * mtot2 * eta) - + (2995 * m2_2 * chil) / (9216.0 * mtot2 * eta) - ) - angcoeffs["epsiloncoeff5"] = ( - 4.318908476114694 - + (27895885 * dm) / (2.1676032e7 * m2) - + (39695 * eta) / 86016.0 - + (1615 * dm2 * eta) / (28672.0 * m2_2) - - (265 * dm * eta) / (14336.0 * m2) - + (955 * eta2) / 576.0 - + (15 * dm3 * eta2) / (1024.0 * m2_3) - + (35 * dm2 * eta2) / (256.0 * m2_2) - + (2725 * dm * eta2) / (3072.0 * m2) - - (15 * dm * m2 * jnp.pi * chil) / (16.0 * mtot2 * eta) - - (35 * m2_2 * jnp.pi * chil) / (16.0 * mtot2 * eta) - + (375 * dm2 * m2_2 * chil2) / (256.0 * mtot4 * eta) - + (1815 * dm * m2_3 * chil2) / (256.0 * mtot4 * eta) - + (1645 * m2_4 * chil2) / (192.0 * mtot4 * eta) - ) - return angcoeffs - - -def FinalSpin_inplane(m1, m2, chi1_l, chi2_l, chip): - M = m1 + m2 - eta = m1 * m2 / (M * M) - # Here I assume m1 > m2, the convention used in phenomD - # (not the convention of internal phenomP) - q_factor = m1 / M - af_parallel = FinalSpin0815(eta, chi1_l, chi2_l) - Sperp = chip * q_factor * q_factor - af = jnp.copysign(1.0, af_parallel) * jnp.sqrt( - Sperp * Sperp + af_parallel * af_parallel - ) - return af - - -def phP_get_fRD_fdamp(m1, m2, chi1_l, chi2_l, chip): - # m1 > m2 should hold here - finspin = FinalSpin_inplane(m1, m2, chi1_l, chi2_l, chip) - m1_s = m1 * gt - m2_s = m2 * gt - M_s = m1_s + m2_s - eta_s = m1_s * m2_s / (M_s**2.0) - Erad = EradRational0815(eta_s, chi1_l, chi2_l) - fRD = jnp.interp(finspin, QNMData_a, QNMData_fRD) / (1.0 - Erad) - fdamp = jnp.interp(finspin, QNMData_a, QNMData_fdamp) / (1.0 - Erad) - - return fRD / M_s, fdamp / M_s - - -def phP_get_transition_frequencies( - theta: jnp.array, - gamma2: float, - gamma3: float, - chip: float, -) -> Tuple[float, float, float, float, float, float]: - # m1 > m2 should hold here - - m1, m2, chi1, chi2 = theta - M = m1 + m2 - f_RD, f_damp = phP_get_fRD_fdamp(m1, m2, chi1, chi2, chip) - - # Phase transition frequencies - f1 = 0.018 / (M * gt) - f2 = 0.5 * f_RD - - # Amplitude transition frequencies - f3 = 0.014 / (M * gt) - f4_gammaneg_gtr_1 = lambda f_RD_, f_damp_, gamma3_, gamma2_: jnp.abs( - f_RD_ + (-f_damp_ * gamma3_) / gamma2_ - ) - f4_gammaneg_less_1 = lambda f_RD_, f_damp_, gamma3_, gamma2_: jnp.abs( - f_RD_ + (f_damp_ * (-1 + jnp.sqrt(1 - (gamma2_) ** 2.0)) * gamma3_) / gamma2_ - ) - f4 = jax.lax.cond( - gamma2 >= 1, - f4_gammaneg_gtr_1, - f4_gammaneg_less_1, - f_RD, - f_damp, - gamma3, - gamma2, - ) - return f1, f2, f3, f4, f_RD, f_damp +### def IMRPhenomX_Return_phi_zeta_costhetaL_MSA( v, ## velocity @@ -839,11 +459,11 @@ def IMRPhenomX_Return_Spin_Evolution_Coefficients_MSA(LNorm, JNorm, pPrec): def IMRPhenomXGetAndSetPrecessionVariables(pWF, m1_SI, m2_SI, chi1x, chi1y, chi1z, chi2x, chi2y, chi2z, - lalParams): + params): pPrec = {} - pPrec['ExpansionOrder'] = XLALSimInspiralWaveformParamsLookupPhenomXPExpansionOrder(lalParams) ## (TODO) + pPrec['ExpansionOrder'] = XLALSimInspiralWaveformParamsLookupPhenomXPExpansionOrder(params) ## (TODO) Mtot_SI = m1_SI + m2_SI # Normalize masses @@ -943,7 +563,6 @@ def IMRPhenomXGetAndSetPrecessionVariables(pWF, m1_SI, m2_SI, pPrec['PNR_chi_window_upper'] = 0.0 # pPrec['PNRInspiralScaling'] = 0 # if needed - # Call external routine status = IMRPhenomX_PNR_GetAndSetPNRVariables(pWF, pPrec) ## (TODO) # XLAL_CHECK equivalent (trigger MSA_ERROR if failed) @@ -955,13 +574,13 @@ def IMRPhenomXGetAndSetPrecessionVariables(pWF, m1_SI, m2_SI, pPrec['gammaPNR'] = 0.0 # Get and/or store CoPrec params - status = IMRPhenomX_PNR_GetAndSetCoPrecParams(pWF, pPrec, lalParams) ## (TODO) + status = IMRPhenomX_PNR_GetAndSetCoPrecParams(pWF, pPrec, params) ## (TODO) pPrec['MSA_ERROR'] = status pPrec = IMRPhenomX_Initialize_MSA_System(pWF,pPrec,pPrec['ExpansionOrder']) ## is it okay to update - IMRPhenomX_SetPrecessingRemnantParams(pWF,pPrec,lalParams) ## (TODO) + IMRPhenomX_SetPrecessingRemnantParams(pWF,pPrec, params) ## (TODO) chip2 = chip * chip From 1034c13aed16872c4de540ad53166c7d64547da8 Mon Sep 17 00:00:00 2001 From: "leonardo.ricca" Date: Thu, 9 Oct 2025 18:25:40 +0200 Subject: [PATCH 10/20] Remove unnecessary notes from comments --- src/ripplegw/waveforms/IMRPhenomXP.py | 30 +++------ src/ripplegw/waveforms/IMRPhenomXP_utils.py | 69 +++++++++------------ 2 files changed, 39 insertions(+), 60 deletions(-) diff --git a/src/ripplegw/waveforms/IMRPhenomXP.py b/src/ripplegw/waveforms/IMRPhenomXP.py index 2d4abdf5..7edaa34d 100644 --- a/src/ripplegw/waveforms/IMRPhenomXP.py +++ b/src/ripplegw/waveforms/IMRPhenomXP.py @@ -18,23 +18,9 @@ def PhenomXPCoreTwistUp22( Mf, ## Frequency in geometric units (on LAL says Hz?) hAS, ## Underlying aligned-spin IMRPhenomXAS strain pWF, ## IMRPhenomX Waveform Struct (TODO) - pPrec ## IMRPhenomXP Precession Struct (TODO) ## in py is a dict + pPrec ## IMRPhenomXP Precession Struct (TODO) ): -# check if hp,hc not None not present in PhenomP - - # here it is used to be LAL_MTSUN_SI - # f = fHz * gt * M # Frequency in geometric units - # q = (1.0 + jnp.sqrt(1.0 - 4.0 * eta) - 2.0 * eta) / (2.0 * eta) ## not needed - # m1 = 1.0 / (1.0 + q) # Mass of the smaller BH for unit total mass M=1. - # m2 = q / (1.0 + q) # Mass of the larger BH for unit total mass M=1. - #Sperp = chip * ( - # m2 * m2 - #) # Dimensionfull spin component in the orbital plane. S_perp = S_2_perp ## already in pPrec - # chi_eff = m1 * chi1_l + m2 * chi2_l # effective spin for M=1 - - # SL = chi1_l * m1 * m1 + chi2_l * m2 * m2 # Dimensionfull aligned spin. ## already in pPrec - omega = jnp.pi * Mf logomega = jnp.log(omega) omega_cbrt = (omega) ** (1 / 3) @@ -42,17 +28,17 @@ def PhenomXPCoreTwistUp22( v = omega_cbrt - vangles = jnp.array([0,0,0]) ## is it okay to use jnp array?? (instead of a struct) it is also defined inside the function... + vangles = jnp.array([0,0,0]) ## Euler Angles from Chatziioannou et al, PRD 95, 104004, (2017), arXiv:1703.03967 - vangles = IMRPhenomX_Return_phi_zeta_costhetaL_MSA(v,pWF,pPrec) ## DONE ## has to output a jnp array + vangles = IMRPhenomX_Return_phi_zeta_costhetaL_MSA(v,pWF,pPrec) alpha = vangles[0] - pPrec["alpha_offset"] epsilon = vangles[1] - pPrec["epsilon_offset"] cos_beta = vangles[2] # print("alpha, epsilon: ", alpha, epsilon) - cBetah, sBetah = WignerdCoefficients_cosbeta(cos_beta) ## DONE + cBetah, sBetah = WignerdCoefficients_cosbeta(cos_beta) cBetah2 = cBetah * cBetah cBetah3 = cBetah2 * cBetah @@ -106,9 +92,9 @@ def PhenomXPCoreTwistUp22( def gen_IMRPhenomXP_hphc(f: Array, - params: Array, ## equivalent to pWF waveform struct ?? - prec_params, ## equivalent to pPrec precession struct - f_ref: float): ## why needs to be input separetely ?? + params: Array, + prec_params, + f_ref: float): """ Returns: -------- @@ -120,7 +106,7 @@ def gen_IMRPhenomXP_hphc(f: Array, hp, hc = PhenomXPCoreTwistUp22(f, h0, params, prec_params) - hp = h0 * (1 / 2 * (1 + jnp.cos(iota) ** 2)) ## to keep or not to keep ?? + hp = h0 * (1 / 2 * (1 + jnp.cos(iota) ** 2)) hc = -1j * h0 * jnp.cos(iota) return hp, hc \ No newline at end of file diff --git a/src/ripplegw/waveforms/IMRPhenomXP_utils.py b/src/ripplegw/waveforms/IMRPhenomXP_utils.py index 3dbdeb33..230dcd61 100644 --- a/src/ripplegw/waveforms/IMRPhenomXP_utils.py +++ b/src/ripplegw/waveforms/IMRPhenomXP_utils.py @@ -18,23 +18,30 @@ # helper functions for LALtoPhenomP: -def ROTATEZ(angle, x, y, z): - tmp_x = x * jnp.cos(angle) - y * jnp.sin(angle) - tmp_y = x * jnp.sin(angle) + y * jnp.cos(angle) - return tmp_x, tmp_y, z - +def IMRPhenomX_rotate_y(v, theta): + """Rotate vector(s) v about the y-axis by angle theta (radians).""" + R_y = jnp.array([ + [ jnp.cos(theta), 0.0, jnp.sin(theta)], + [ 0.0, 1.0, 0.0 ], + [-jnp.sin(theta), 0.0, jnp.cos(theta)] + ]) + return R_y @ v -def ROTATEY(angle, x, y, z): - tmp_x = x * jnp.cos(angle) + z * jnp.sin(angle) - tmp_z = -x * jnp.sin(angle) + z * jnp.cos(angle) - return tmp_x, y, tmp_z +def IMRPhenomX_rotate_z(v, theta): + """Rotate vector(s) v about the z-axis by angle theta (radians).""" + R_z = jnp.array([ + [ jnp.cos(theta), -jnp.sin(theta), 0.0], + [ jnp.sin(theta), jnp.cos(theta), 0.0], + [ 0.0, 0.0, 1.0] + ]) + return R_z @ v ### def IMRPhenomX_Return_phi_zeta_costhetaL_MSA( v, ## velocity pWF, ## IMRPhenomX waveform struct - pPrec ## IMRPhenomX precession struct ### LR: casting?? + pPrec ## IMRPhenomX precession struct ): ## has to output a jnp array vout = jnp.array([0,0,0]) @@ -66,15 +73,15 @@ def IMRPhenomX_Return_phi_zeta_costhetaL_MSA( ''' Get phiz_0_MSA and zeta_0_MSA ''' vMSA = jax.lax.cond((jnp.fabs(pPrec["Smi2"] - pPrec["Spl2"]) > 1.e-5), IMRPhenomX_Return_MSA_Corrections_MSA, ## return 3D jnp.array - lambda v, L_norm, J_norm, pPrec: jnp.array([0,0,0]), ## ugly but okay?? + lambda v, L_norm, J_norm, pPrec: jnp.array([0,0,0]), v, L_norm, J_norm, pPrec) phiz_MSA = vMSA[0] zeta_MSA = vMSA[1] - phiz = IMRPhenomX_Return_phiz_MSA(v,J_norm,pPrec) ## (DONE) - zeta = IMRPhenomX_Return_zeta_MSA(v,pPrec) ## (DONE) - cos_theta_L = IMRPhenomX_costhetaLJ(L_norm3PN,J_norm3PN,SNorm) ## (DONE) + phiz = IMRPhenomX_Return_phiz_MSA(v,J_norm,pPrec) + zeta = IMRPhenomX_Return_zeta_MSA(v,pPrec) + cos_theta_L = IMRPhenomX_costhetaLJ(L_norm3PN,J_norm3PN,SNorm) vout[0] = phiz + phiz_MSA vout[1] = zeta + zeta_MSA @@ -470,7 +477,6 @@ def IMRPhenomXGetAndSetPrecessionVariables(pWF, m1_SI, m2_SI, m1 = m1_SI / Mtot_SI m2 = m2_SI / Mtot_SI M = m1 + m2 - #pWF['M'] = m1 + m2 ### pWF needs to be a dict?? # Mass ratio and symmetric mass ratio q = m1 / m2 @@ -479,6 +485,7 @@ def IMRPhenomXGetAndSetPrecessionVariables(pWF, m1_SI, m2_SI, m1_2 = m1**2 m2_2 = m2**2 + ### pWF needs to be a dict?? ## TODO: check how is delta stored in pWF delta = pWF['delta'] ## TODO: compute chieff? @@ -511,7 +518,7 @@ def IMRPhenomXGetAndSetPrecessionVariables(pWF, m1_SI, m2_SI, STot_y = pPrec['S1y'] + pPrec['S2y'] pPrec['STot_perp'] = jnp.sqrt(STot_x**2 + STot_y**2) pPrec['chiTot_perp'] = pPrec['STot_perp'] * (M**2) / m1_2 - # pWF['chiTot_perp'] = pPrec['chiTot_perp'] ### pWF needs to be a dict?? + # pWF['chiTot_perp'] = pPrec['chiTot_perp'] ## disable tuned PNR angles, tuned coprec and mode asymmetries in low in-plane spin limit (TODO) @@ -531,7 +538,7 @@ def IMRPhenomXGetAndSetPrecessionVariables(pWF, m1_SI, m2_SI, chi2L = chi2z pPrec['chi_p'] = chip - pWF['chi_p'] = pPrec['chi_p'] # propagate to waveform struct ### pWF needs to be a dict?? + pWF['chi_p'] = pPrec['chi_p'] # propagate to waveform struct pPrec['phi0_aligned'] = pWF['phi0'] # Effective (dimensionful) aligned spin @@ -548,6 +555,10 @@ def IMRPhenomXGetAndSetPrecessionVariables(pWF, m1_SI, m2_SI, pPrec['precessing_tag'] = 2 + ''''' + The following code initializes variables needed for PNR. Since I am implementing only XP->223 for the moment + I assume this is not needed (to be checked further) + # Initialize PNR variables pPrec['chi_singleSpin'] = 0.0 pPrec['costheta_singleSpin'] = 0.0 @@ -577,8 +588,9 @@ def IMRPhenomXGetAndSetPrecessionVariables(pWF, m1_SI, m2_SI, status = IMRPhenomX_PNR_GetAndSetCoPrecParams(pWF, pPrec, params) ## (TODO) pPrec['MSA_ERROR'] = status + ''''' - pPrec = IMRPhenomX_Initialize_MSA_System(pWF,pPrec,pPrec['ExpansionOrder']) ## is it okay to update + pPrec = IMRPhenomX_Initialize_MSA_System(pWF,pPrec,pPrec['ExpansionOrder']) IMRPhenomX_SetPrecessingRemnantParams(pWF,pPrec, params) ## (TODO) @@ -1282,7 +1294,6 @@ def IMRPhenomX_SetPrecessingRemnantParams(pWF, pPrec, params): # see arXiv:2004.06503v2 # Just implementing branch for MSA_ERROR == 0. No fallback to NNLO for the moment - # what is this flag? method_flag = XLALSimInspiralWaveformParamsLookupPhenomXPTransPrecessionMethod( ## (TODO) params ) @@ -1335,22 +1346,4 @@ def IMRPhenomX_Get_PN_tau(a, b, pPrec): def IMRPhenomX_psiofv(v, v2, psi0, psi1, psi2, pPrec): # Equation 51 in arXiv:1703.03967 - return psi0 - 0.75 * pPrec['g0'] * pPrec['delta_qq'] * (1.0 + psi1 * v + psi2 * v2) / (v2 * v) - -def IMRPhenomX_rotate_y(v, theta): - """Rotate vector(s) v about the y-axis by angle theta (radians).""" - R_y = jnp.array([ - [ jnp.cos(theta), 0.0, jnp.sin(theta)], - [ 0.0, 1.0, 0.0 ], - [-jnp.sin(theta), 0.0, jnp.cos(theta)] - ]) - return R_y @ v - -def IMRPhenomX_rotate_z(v, theta): - """Rotate vector(s) v about the z-axis by angle theta (radians).""" - R_z = jnp.array([ - [ jnp.cos(theta), -jnp.sin(theta), 0.0], - [ jnp.sin(theta), jnp.cos(theta), 0.0], - [ 0.0, 0.0, 1.0] - ]) - return R_z @ v \ No newline at end of file + return psi0 - 0.75 * pPrec['g0'] * pPrec['delta_qq'] * (1.0 + psi1 * v + psi2 * v2) / (v2 * v) \ No newline at end of file From 6297db722658e204372f183872f9be9f23fc54da Mon Sep 17 00:00:00 2001 From: "leonardo.ricca" Date: Tue, 4 Nov 2025 16:29:18 +0100 Subject: [PATCH 11/20] Fix bugs in IMRPhenomXP_utils while testing IMRPhenomXGetAndSetPrecessionVariables --- src/ripplegw/waveforms/IMRPhenomXP_utils.py | 81 ++++++++++++--------- 1 file changed, 47 insertions(+), 34 deletions(-) diff --git a/src/ripplegw/waveforms/IMRPhenomXP_utils.py b/src/ripplegw/waveforms/IMRPhenomXP_utils.py index 230dcd61..ac52f818 100644 --- a/src/ripplegw/waveforms/IMRPhenomXP_utils.py +++ b/src/ripplegw/waveforms/IMRPhenomXP_utils.py @@ -1,9 +1,9 @@ import jax import jax.numpy as jnp -from ripple import Mc_eta_to_ms +from ripplegw import Mc_eta_to_ms from typing import Tuple -from ..constants import gt, MSUN +from ..constants import gt, MSUN, G, C import numpy as np from .IMRPhenomD import Phase as PhDPhase from .IMRPhenomD import Amp as PhDAmp @@ -45,7 +45,7 @@ def IMRPhenomX_Return_phi_zeta_costhetaL_MSA( ): ## has to output a jnp array vout = jnp.array([0,0,0]) - L_norm = pWF["eta"] / v + L_norm = pWF['eta'] / v J_norm = IMRPhenomX_JNorm_MSA(L_norm,pPrec) @@ -159,15 +159,15 @@ def IMRPhenomX_Return_phiz_MSA( log2 = jnp.log(jnp.abs(c1 + JNorm * SAv * v + SAv2 * v)) # Eq. D22-D27 of Chatziioannou et al, PRD 95, 104004, (2017), arXiv:1703.03967 - phiz_0_coeff = (JNorm * pPrec["inveta4"]) * ( + phiz_0_coeff = (JNorm * pPrec["inveta"]**4) * ( 0.5 * c12 - (c1 * pPrec["eta2"] * invv) / 6.0 - (SAv2 * pPrec["eta2"]) / 3.0 - (pPrec["eta4"] * invv2) / 3.0 ) - (0.5 * c1 * pPrec["inveta"]) * ( - c12 * pPrec["inveta4"] - SAv2 * pPrec["inveta2"] + c12 * pPrec["inveta"]**4 - SAv2 * pPrec["inveta"]**2 ) * log1 phiz_1_coeff = ( - -0.5 * JNorm * pPrec["inveta2"] * (c1 + pPrec["eta"] * LNewt) - + 0.5 * pPrec["inveta3"] * (c12 - pPrec["eta2"] * SAv2) * log1 + -0.5 * JNorm * pPrec["inveta"]**2 * (c1 + pPrec["eta"] * LNewt) + + 0.5 * pPrec["inveta"]**3 * (c12 - pPrec["eta2"] * SAv2) * log1 ) phiz_2_coeff = -JNorm + SAv * log2 - c1 * log1 * pPrec["inveta"] @@ -328,12 +328,12 @@ def IMRPhenomX_Return_Roots_MSA(LNorm, JNorm, pPrec): invalid_case = ( jnp.isnan(theta) | jnp.isnan(sqrtarg) | - (pPrec.dotS1Ln == 1.0) | - (pPrec.dotS2Ln == 1.0) | - (pPrec.dotS1Ln == -1.0) | - (pPrec.dotS2Ln == -1.0) | - (pPrec.S1_norm_2 == 0.0) | - (pPrec.S2_norm_2 == 0.0) + (pPrec['dotS1Ln'] == 1.0) | + (pPrec['dotS2Ln'] == 1.0) | + (pPrec['dotS1Ln'] == -1.0) | + (pPrec['dotS2Ln'] == -1.0) | + (pPrec['S1_norm_2'] == 0.0) | + (pPrec['S2_norm_2'] == 0.0) ) def roots_when_valid(): @@ -356,7 +356,7 @@ def roots_when_valid(): return S32, Smi2, Spl2 def roots_when_invalid(): - Smi2 = pPrec['S_0_norm_2'] + Smi2 = pPrec['S_0_norm']**2 Spl2 = Smi2 + 1e-9 S32 = 0.0 return S32, Smi2, Spl2 @@ -392,7 +392,7 @@ def IMRPhenomX_Return_Constants_c_MSA(v, JNorm, pPrec): y = JNorm * ( -1.5 * pPrec['eta'] * (pPrec['Spl2'] - pPrec['Smi2']) - * (1.0 + 2.0 * Seff * v - (JNorm2 - pPrec['Spl2']) * v2 * pPrec['inveta2']) + * (1.0 + 2.0 * Seff * v - (JNorm2 - pPrec['Spl2']) * v2 * pPrec['inveta']**2) * (1.0 - Seff * v) * v4 ) @@ -421,7 +421,7 @@ def IMRPhenomX_Return_Psi_MSA(v, v2, pPrec): def IMRPhenomX_Return_Psi_dot_MSA(v, pPrec): v2 = v * v - A_coeff = -1.5 * v2 * v2 * v2 * (1.0 - v * pPrec['Seff']) * pPrec['sqrt_inveta'] + A_coeff = -1.5 * v2 * v2 * v2 * (1.0 - v * pPrec['Seff']) * jnp.sqrt(pPrec['inveta']) psi_dot = 0.5 * A_coeff * jnp.sqrt(pPrec['Spl2'] - pPrec['S32']) return psi_dot @@ -470,7 +470,9 @@ def IMRPhenomXGetAndSetPrecessionVariables(pWF, m1_SI, m2_SI, pPrec = {} - pPrec['ExpansionOrder'] = XLALSimInspiralWaveformParamsLookupPhenomXPExpansionOrder(params) ## (TODO) + ## Here we're only considering the default setting, where the expansion order for MSA correction is 5 + #pPrec['ExpansionOrder'] = XLALSimInspiralWaveformParamsLookupPhenomXPExpansionOrder(params) ## (TODO) + pPrec['ExpansionOrder'] = 5 Mtot_SI = m1_SI + m2_SI # Normalize masses @@ -480,7 +482,12 @@ def IMRPhenomXGetAndSetPrecessionVariables(pWF, m1_SI, m2_SI, # Mass ratio and symmetric mass ratio q = m1 / m2 - eta = pWF[1] + eta = pWF['eta'] + pPrec['eta'] = eta + pPrec['eta2'] = eta**2 + pPrec['eta3'] = eta**3 + pPrec['eta4'] = eta**4 + pPrec['inveta'] = 1 / eta m1_2 = m1**2 m2_2 = m2**2 @@ -489,7 +496,8 @@ def IMRPhenomXGetAndSetPrecessionVariables(pWF, m1_SI, m2_SI, ## TODO: check how is delta stored in pWF delta = pWF['delta'] ## TODO: compute chieff? - ## TODO: compute twopiGM, piGM? + + pPrec['piGM'] = jnp.pi * (m1_SI+m2_SI) * G / C / C / C # Spin inputs for i, (x, y, z) in enumerate([(chi1x, chi1y, chi1z), (chi2x, chi2y, chi2z)], start=1): @@ -508,6 +516,9 @@ def IMRPhenomXGetAndSetPrecessionVariables(pWF, m1_SI, m2_SI, pPrec['S2y'] = chi2y * m2_2 pPrec['S2z'] = chi2z * m2_2 pPrec['S2_norm'] = jnp.abs(pPrec['chi2_norm']) * m2_2 + + pPrec['S1_norm_2'] = pPrec['S1_norm']**2 + pPrec['S2_norm_2'] = pPrec['S2_norm']**2 # In-plane magnitudes pPrec['chi1_perp'] = jnp.sqrt(chi1x*chi1x + chi1y*chi1y) @@ -710,8 +721,8 @@ def general_case(pPrec): However, the triad X,Y,N used in LAL (the "waveframe") follows the definition in the NR Injection Infrastructure (Schmidt et al, arXiv:1703.01076). - The triads differ from each other by a rotation around N by an angle \zeta. We therefore need to rotate - the polarizations by an angle 2 \zeta. + The triads differ from each other by a rotation around N by an angle zeta. We therefore need to rotate + the polarizations by an angle 2 zeta. """ pPrec['Xx_Sf'] = -jnp.cos(pWF['inclination']) * jnp.sin(phiRef) @@ -761,7 +772,7 @@ def general_case(pPrec): pPrec['cexp_i_betah'] = 0.0 """ - Check whether maximum opening angle becomes larger than \pi/2 or \pi/4. + Check whether maximum opening angle becomes larger than pi/2 or pi/4. If (L + S_L) < 0, then Wigner-d Coefficients will not track the angle between J and L, meaning that the model may become pathological as one moves away from the aligned-spin limit. @@ -805,10 +816,10 @@ def general_case(pPrec): return pPrec def IMRPhenomX_Initialize_MSA_System(pWF, pPrec, ExpansionOrder): - eta = pPrec['eta'] - eta2 = pPrec['eta2'] - eta3 = pPrec['eta3'] - eta4 = pPrec['eta4'] + eta = pWF['eta'] + eta2 = eta*eta + eta3 = eta2*eta + eta4 = eta3*eta m1 = pWF['m1'] m2 = pWF['m2'] @@ -869,13 +880,13 @@ def IMRPhenomX_Initialize_MSA_System(pWF, pPrec, ExpansionOrder): pPrec['Lhat_theta'] = 0.0 # Dimensionful spin vectors (eta = m1 * m2, q = m2 / m1) - S1v[0] = pPrec['chi1x'] * eta / q - S1v[1] = pPrec['chi1y'] * eta / q - S1v[2] = pPrec['chi1z'] * eta / q + S1v.at[0].set(pPrec['chi1x'] * eta / q) + S1v.at[1].set(pPrec['chi1y'] * eta / q) + S1v.at[2].set(pPrec['chi1z'] * eta / q) - S2v[0] = pPrec['chi2x'] * eta * q - S2v[1] = pPrec['chi2y'] * eta * q - S2v[2] = pPrec['chi2z'] * eta * q + S2v.at[0].set(pPrec['chi2x'] * eta * q) + S2v.at[1].set(pPrec['chi2y'] * eta * q) + S2v.at[2].set(pPrec['chi2z'] * eta * q) # Norms of spin vectors S1_0_norm = jnp.linalg.norm(S1v) @@ -1079,6 +1090,7 @@ def IMRPhenomX_Initialize_MSA_System(pWF, pPrec, ExpansionOrder): pPrec['Delta'] = jnp.sqrt(jnp.abs((Del1 - Del2 - Del3) * (Del4 - Del5 - Del6))) u1 = 3.0 * pPrec['g2'] / pPrec['g0'] + ### TODO: when m1=m2, one_p_q_sq = 0, division by zero occurs. Check what to do about it u2 = 0.75 * one_p_q_sq / one_m_q_4 u3 = -20.0 * c_1_over_nu_2 * q_2 * one_p_q_sq u4 = 2.0 * one_m_q2_2 * (q * (2.0 + q) * pPrec['S1_norm_2'] @@ -1176,8 +1188,7 @@ def branch_equal(pPrec): return 0.0 def branch_not_equal(pPrec): - mm_val = jnp.sqrt((pPrec['Smi2'] - pPrec['Spl2']) - (pPrec['S32'] - pPrec['Spl2'])) + mm_val = jnp.sqrt((pPrec['Smi2'] - pPrec['Spl2']) / (pPrec['S32'] - pPrec['Spl2'])) tmpB_val = ((pPrec['S_0_norm'] * pPrec['S_0_norm']) - pPrec['Spl2']) / (pPrec['Smi2'] - pPrec['Spl2']) vol_elem = jnp.dot(jnp.cross(L_0, S1v),S2v) @@ -1225,8 +1236,10 @@ def case3(): vMSA = jnp.where(condition_equal, jnp.array([0.,0.,0.]), IMRPhenomX_Return_MSA_Corrections_MSA(pPrec['v_0'],pPrec['L_0_norm'],pPrec['J_0_norm'],pPrec)) + pPrec['phiz_0'] = 0. phiz_0 = IMRPhenomX_Return_phiz_MSA(pPrec['v_0'],pPrec['J_0_norm'],pPrec) + pPrec['zeta_0'] = 0. zeta_0 = IMRPhenomX_Return_zeta_MSA(pPrec['v_0'],pPrec) pPrec['phiz_0'] = - phiz_0 - vMSA[0] From f827d72e21f46c10a6e7cb75c78a5943b10015ec Mon Sep 17 00:00:00 2001 From: "leonardo.ricca" Date: Tue, 4 Nov 2025 17:11:54 +0100 Subject: [PATCH 12/20] Implement XFinalSpin2017, XSTotR, evaluate_QNMfit_fring22, evaluate_QNMfit_fdamp22 functions --- src/ripplegw/waveforms/IMRPhenomXP_utils.py | 145 +++++++++++++++++++- 1 file changed, 142 insertions(+), 3 deletions(-) diff --git a/src/ripplegw/waveforms/IMRPhenomXP_utils.py b/src/ripplegw/waveforms/IMRPhenomXP_utils.py index ac52f818..07f99887 100644 --- a/src/ripplegw/waveforms/IMRPhenomXP_utils.py +++ b/src/ripplegw/waveforms/IMRPhenomXP_utils.py @@ -1337,8 +1337,8 @@ def no_clip(_): ) # Update ringdown and damping frequency: no precession; to be used for PNR tuned deviations - pWF['fRING'] = evaluate_QNMfit_fring22(pWF['afinal']) / pWF['Mfinal'] ## (TODO) - pWF['fDAMP'] = evaluate_QNMfit_fdamp22(pWF['afinal']) / pWF['Mfinal'] ## (TODO) + pWF['fRING'] = evaluate_QNMfit_fring22(pWF['afinal']) / pWF['Mfinal'] + pWF['fDAMP'] = evaluate_QNMfit_fdamp22(pWF['afinal']) / pWF['Mfinal'] return pWF @@ -1359,4 +1359,143 @@ def IMRPhenomX_Get_PN_tau(a, b, pPrec): def IMRPhenomX_psiofv(v, v2, psi0, psi1, psi2, pPrec): # Equation 51 in arXiv:1703.03967 - return psi0 - 0.75 * pPrec['g0'] * pPrec['delta_qq'] * (1.0 + psi1 * v + psi2 * v2) / (v2 * v) \ No newline at end of file + return psi0 - 0.75 * pPrec['g0'] * pPrec['delta_qq'] * (1.0 + psi1 * v + psi2 * v2) / (v2 * v) + + +def XLALSimIMRPhenomXFinalSpin2017(eta, chi1L, chi2L): + + delta = jnp.sqrt(1.0 - 4.0 * eta) + m1 = 0.5 * (1.0 + delta) + m2 = 0.5 * (1.0 - delta) + m1Sq = m1 * m1 + m2Sq = m2 * m2 + + eta2 = eta * eta + eta3 = eta2 * eta + + S = XLALSimIMRPhenomXSTotR(eta, chi1L, chi2L) + S2 = S * S + S3 = S2 * S + + dchi = chi1L - chi2L + dchi2 = dchi * dchi + + noSpin = (3.4641016151377544 * eta + + 20.0830030082033 * eta2 + - 12.333573402277912 * eta3) / (1.0 + 7.2388440419467335 * eta) + + eqSpin = ((m1Sq + m2Sq) * S + + ((-0.8561951310209386 * eta + - 0.09939065676370885 * eta2 + + 1.668810429851045 * eta3) * S + + (0.5881660363307388 * eta + - 2.149269067519131 * eta2 + + 3.4768263932898678 * eta3) * S2 + + (0.142443244743048 * eta + - 0.9598353840147513 * eta2 + + 1.9595643107593743 * eta3) * S3) + / (1.0 + (-0.9142232693081653 + + 2.3191363426522633 * eta + - 9.710576749140989 * eta3) * S)) + + uneqSpin = (0.3223660562764661 * dchi * delta * (1 + 9.332575956437443 * eta) * eta2 + - 0.059808322561702126 * dchi2 * eta3 + + 2.3170397514509933 * dchi * delta * (1 - 3.2624649875884852 * eta) * eta3 * S) + + return noSpin + eqSpin + uneqSpin + +def XLALSimIMRPhenomXSTotR(eta, chi1L, chi2L): + """ + Total spin normalized to [-1, 1] + """ + delta = jnp.sqrt(1.0 - 4.0 * eta) + m1 = 0.5 * (1.0 + delta) + m2 = 0.5 * (1.0 - delta) + + m1s = m1 * m1 + m2s = m2 * m2 + + # Spin combination + return (m1s * chi1L + m2s * chi2L) / (m1s + m2s) + +def evaluate_QNMfit_fring22(a): + """ + This implementation is the same as the in from IMRPhenomX_utils -> get_cutoff_fMs + Taken from https://lscsoft.docs.ligo.org/lalsuite/lalsimulation/_l_a_l_sim_i_m_r_phenom_t_h_m__fits_8c_source.html + Evaluate the ringdown frequency + """ + a2 = a * a + a3 = a2 * a + a4 = a3 * a + a5 = a4 * a + a6 = a5 * a + a7 = a6 * a + + return ( + ( + 0.05947169566573468 + - 0.14989771215394762 * a + + 0.09535606290986028 * a2 + + 0.02260924869042963 * a3 + - 0.02501704155363241 * a4 + - 0.005852438240997211 * a5 + + 0.0027489038393367993 * a6 + + 0.0005821983163192694 * a7 + ) + / ( + 1 + - 2.8570126619966296 * a + + 2.373335413978394 * a2 + - 0.6036964688511505 * a4 + + 0.0873798215084077 * a6 + ) + ) + +def evaluate_QNMfit_fdamp22(a): + """ + This implementation is the same as the in from IMRPhenomX_utils -> get_cutoff_fMs + Taken from https://lscsoft.docs.ligo.org/lalsuite/lalsimulation/_l_a_l_sim_i_m_r_phenom_t_h_m__fits_8c_source.html + Evaluate the ringdown frequency + """ + a2 = a * a + a3 = a2 * a + a4 = a3 * a + a5 = a4 * a + a6 = a5 * a + a7 = a6 * a + + return ( + ( + 0.014158792290965177 + - 0.036989395871554566 * a + + 0.026822526296575368 * a2 + + 0.0008490933750566702 * a3 + - 0.004843996907020524 * a4 + - 0.00014745235759327472 * a5 + + 0.0001504546201236794 * a6 + ) + / ( + 1 + - 2.5900842798681376 * a + + 1.8952576220623967 * a2 + - 0.31416610693042507 * a4 + + 0.009002719412204133 * a6 + ) + ) + + +########################## +## NOT YET IMPLEMENTED ### +########################## + +def gsl_sf_ellint_F(x, y): + """ + TODO: Not yet implemented + """ + return x + +def XLALSimInspiralWaveformParamsLookupPhenomXPTransPrecessionMethod(dict): + """ + TODO: Not yet implemented + """ + return 1 \ No newline at end of file From b12e55d3b09766f8325dbf07e16d81ff933e22b8 Mon Sep 17 00:00:00 2001 From: "leonardo.ricca" Date: Wed, 19 Nov 2025 11:22:49 +0100 Subject: [PATCH 13/20] Include spherical harmonics implementation from https://github.com/narolaharsh/ripple --- src/ripplegw/waveforms/IMRPhenomXP_utils.py | 45 +++++++------ src/ripplegw/waveforms/spherical_harmonics.py | 65 +++++++++++++++++++ 2 files changed, 89 insertions(+), 21 deletions(-) create mode 100644 src/ripplegw/waveforms/spherical_harmonics.py diff --git a/src/ripplegw/waveforms/IMRPhenomXP_utils.py b/src/ripplegw/waveforms/IMRPhenomXP_utils.py index 07f99887..e7270806 100644 --- a/src/ripplegw/waveforms/IMRPhenomXP_utils.py +++ b/src/ripplegw/waveforms/IMRPhenomXP_utils.py @@ -15,6 +15,7 @@ ) from ..typing import Array from .IMRPhenomD_QNMdata import QNMData_a, QNMData_fRD, QNMData_fdamp +from .spherical_harmonics import * # helper functions for LALtoPhenomP: @@ -791,27 +792,29 @@ def general_case(pPrec): ytheta = pPrec['thetaJN'] yphi = 0.0 - pPrec['Y2m2'] = XLALSpinWeightedSphericalHarmonic(ytheta, yphi, -2, 2, -2) ## (TODO) - pPrec['Y2m1'] = XLALSpinWeightedSphericalHarmonic(ytheta, yphi, -2, 2, -1) - pPrec['Y20'] = XLALSpinWeightedSphericalHarmonic(ytheta, yphi, -2, 2, 0) - pPrec['Y21'] = XLALSpinWeightedSphericalHarmonic(ytheta, yphi, -2, 2, 1) - pPrec['Y22'] = XLALSpinWeightedSphericalHarmonic(ytheta, yphi, -2, 2, 2) - pPrec['Y3m3'] = XLALSpinWeightedSphericalHarmonic(ytheta, yphi, -2, 3, -3) - pPrec['Y3m2'] = XLALSpinWeightedSphericalHarmonic(ytheta, yphi, -2, 3, -2) - pPrec['Y3m1'] = XLALSpinWeightedSphericalHarmonic(ytheta, yphi, -2, 3, -1) - pPrec['Y30'] = XLALSpinWeightedSphericalHarmonic(ytheta, yphi, -2, 3, 0) - pPrec['Y31'] = XLALSpinWeightedSphericalHarmonic(ytheta, yphi, -2, 3, 1) - pPrec['Y32'] = XLALSpinWeightedSphericalHarmonic(ytheta, yphi, -2, 3, 2) - pPrec['Y33'] = XLALSpinWeightedSphericalHarmonic(ytheta, yphi, -2, 3, 3) - pPrec['Y4m4'] = XLALSpinWeightedSphericalHarmonic(ytheta, yphi, -2, 4, -4) - pPrec['Y4m3'] = XLALSpinWeightedSphericalHarmonic(ytheta, yphi, -2, 4, -3) - pPrec['Y4m2'] = XLALSpinWeightedSphericalHarmonic(ytheta, yphi, -2, 4, -2) - pPrec['Y4m1'] = XLALSpinWeightedSphericalHarmonic(ytheta, yphi, -2, 4, -1) - pPrec['Y40'] = XLALSpinWeightedSphericalHarmonic(ytheta, yphi, -2, 4, 0) - pPrec['Y41'] = XLALSpinWeightedSphericalHarmonic(ytheta, yphi, -2, 4, 1) - pPrec['Y42'] = XLALSpinWeightedSphericalHarmonic(ytheta, yphi, -2, 4, 2) - pPrec['Y43'] = XLALSpinWeightedSphericalHarmonic(ytheta, yphi, -2, 4, 3) - pPrec['Y44'] = XLALSpinWeightedSphericalHarmonic(ytheta, yphi, -2, 4, 4) + + pPrec['Y2m2'] = compute_sminus2_l2(ytheta, -2) + pPrec['Y2m1'] = compute_sminus2_l2(ytheta, -1) + pPrec['Y20'] = compute_sminus2_l2(ytheta, 0) + pPrec['Y21'] = compute_sminus2_l2(ytheta, 1) + pPrec['Y22'] = compute_sminus2_l2(ytheta, 2) + pPrec['Y3m3'] = compute_sminus2_l3(ytheta, -3) + pPrec['Y3m2'] = compute_sminus2_l3(ytheta, -2) + pPrec['Y3m1'] = compute_sminus2_l3(ytheta, -1) + pPrec['Y30'] = compute_sminus2_l3(ytheta, 0) + pPrec['Y31'] = compute_sminus2_l3(ytheta, 1) + pPrec['Y32'] = compute_sminus2_l3(ytheta, 2) + pPrec['Y33'] = compute_sminus2_l3(ytheta, 3) + pPrec['Y4m4'] = compute_sminus2_l4(ytheta, -4) + pPrec['Y4m3'] = compute_sminus2_l4(ytheta, -3) + pPrec['Y4m2'] = compute_sminus2_l4(ytheta, -2) + pPrec['Y4m1'] = compute_sminus2_l4(ytheta, -1) + pPrec['Y40'] = compute_sminus2_l4(ytheta, 0) + pPrec['Y41'] = compute_sminus2_l4(ytheta, 1) + pPrec['Y42'] = compute_sminus2_l4(ytheta, 2) + pPrec['Y43'] = compute_sminus2_l4(ytheta, 3) + pPrec['Y44'] = compute_sminus2_l4(ytheta, 4) + return pPrec diff --git a/src/ripplegw/waveforms/spherical_harmonics.py b/src/ripplegw/waveforms/spherical_harmonics.py new file mode 100644 index 00000000..60aefd70 --- /dev/null +++ b/src/ripplegw/waveforms/spherical_harmonics.py @@ -0,0 +1,65 @@ +import jax +import jax.numpy as jnp +import math +from ..typing import Array + + +def compute_sminus2_l2(theta, m): + """ + Spin -2 weighted spherical harmonic for l=2, phi=0. + theta: float or array + m: integer in [-2, -1, 0, 1, 2] + """ + + # Compute the fac factor based on m + harmonics = jnp.where(m == -2, jnp.sqrt(5.0 / (64.0 * jnp.pi)) * (1.0 - jnp.cos(theta))**2, + jnp.where(m == -1, jnp.sqrt(5.0 / (16.0 * jnp.pi)) * jnp.sin(theta) * (1.0 - jnp.cos(theta)), + jnp.where(m == 0, jnp.sqrt(15.0 / (32.0 * jnp.pi)) * jnp.sin(theta)**2, + jnp.where(m == 1, jnp.sqrt(5.0 / (16.0 * jnp.pi)) * jnp.sin(theta) * (1.0 + jnp.cos(theta)), + jnp.sqrt(5.0 / (64.0 * jnp.pi)) * (1.0 + jnp.cos(theta))**2 + )))) + + return harmonics + + +def compute_sminus2_l3(theta, m): + """ + Spin -2 weighted spherical harmonic for l=3, phi=0. + theta: scalar or array + m: integer in [-3, 3] + """ + + harmonics = jnp.where(m == -3, jnp.sqrt(21.0 / (2.0 * jnp.pi)) * jnp.cos(theta / 2.0) * (jnp.sin(theta / 2.0) ** 5), + jnp.where(m == -2, jnp.sqrt(7.0 / (4.0 * jnp.pi)) * (2.0 + 3.0 * jnp.cos(theta)) * (jnp.sin(theta / 2.0) ** 4), + jnp.where(m == -1, jnp.sqrt(35.0 / (2.0 * jnp.pi)) * (jnp.sin(theta) + 4.0 * jnp.sin(2.0*theta) - 3.0 * jnp.sin(3.0*theta)) / 32.0, + jnp.where(m == 0, (jnp.sqrt(105.0 / (2.0 * jnp.pi)) * jnp.cos(theta) * (jnp.sin(theta) ** 2)) / 4.0, + jnp.where(m == 1, -jnp.sqrt(35.0 / (2.0 * jnp.pi)) * (jnp.sin(theta) - 4.0 * jnp.sin(2.0*theta) - 3.0 * jnp.sin(3.0*theta)) / 32.0, + jnp.where(m == 2, jnp.sqrt(7.0 / jnp.pi) * (jnp.cos(theta/2.0)**4) * (-2.0 + 3.0 * jnp.cos(theta)) / 2.0, + -jnp.sqrt(21.0 / (2.0 * jnp.pi)) * (jnp.cos(theta/2.0)**5) * jnp.sin(theta/2.0) + )))))) + + return harmonics + + +def compute_sminus2_l4(theta, m): + """ + Spin -2 weighted spherical harmonic for l=4, phi=0. + theta: scalar or array + m: integer in [-4, 4] + """ + + harmonics = jnp.where(m == -4, 3.0 * jnp.sqrt(7.0 / jnp.pi) * (jnp.cos(theta/2.0)**2) * (jnp.sin(theta/2.0)**6), + jnp.where(m == -3, 3.0 * jnp.sqrt(7.0 / (2.0 * jnp.pi)) * jnp.cos(theta/2.0) * (1.0 + 2.0 * jnp.cos(theta)) * (jnp.sin(theta/2.0)**5), + jnp.where(m == -2, 3.0 * (9.0 + 14.0 * jnp.cos(theta) + 7.0 * jnp.cos(2.0*theta)) * (jnp.sin(theta/2.0)**4) / (4.0 * jnp.sqrt(jnp.pi)), + jnp.where(m == -1, 3.0 * (3.0 * jnp.sin(theta) + 2.0 * jnp.sin(2.0*theta) + 7.0 * jnp.sin(3.0*theta) - 7.0 * jnp.sin(4.0*theta)) / (32.0 * jnp.sqrt(2.0 * jnp.pi)), + jnp.where(m == 0, 3.0 * jnp.sqrt(5.0 / (2.0 * jnp.pi)) * (5.0 + 7.0 * jnp.cos(2.0*theta)) * (jnp.sin(theta)**2) / 16.0, + jnp.where(m == 1, 3.0 * (3.0 * jnp.sin(theta) - 2.0 * jnp.sin(2.0*theta) + 7.0 * jnp.sin(3.0*theta) + 7.0 * jnp.sin(4.0*theta)) / (32.0 * jnp.sqrt(2.0 * jnp.pi)), + jnp.where(m == 2, 3.0 * (jnp.cos(theta/2.0)**4) * (9.0 - 14.0 * jnp.cos(theta) + 7.0 * jnp.cos(2.0*theta)) / (4.0 * jnp.sqrt(jnp.pi)), + jnp.where(m == 3, -3.0 * jnp.sqrt(7.0 / (2.0 * jnp.pi)) * (jnp.cos(theta/2.0)**5) * (-1.0 + 2.0 * jnp.cos(theta)) * jnp.sin(theta/2.0), + 3.0 * jnp.sqrt(7.0 / jnp.pi) * (jnp.cos(theta/2.0)**6) * (jnp.sin(theta/2.0)**2) + )))))))) + + return harmonics + + + From aa7efb1e4ed03e0fd04c850842a023633686bfd8 Mon Sep 17 00:00:00 2001 From: "leonardo.ricca" Date: Wed, 19 Nov 2025 11:26:46 +0100 Subject: [PATCH 14/20] Implement X_Return_SNorm_MSA, Get_alphaepsilon_atfref, XLALSimIMRPhenomXLPNAnsatz functions --- src/ripplegw/waveforms/IMRPhenomXP_utils.py | 79 +++++++++++++++++++-- 1 file changed, 75 insertions(+), 4 deletions(-) diff --git a/src/ripplegw/waveforms/IMRPhenomXP_utils.py b/src/ripplegw/waveforms/IMRPhenomXP_utils.py index e7270806..fa1075d7 100644 --- a/src/ripplegw/waveforms/IMRPhenomXP_utils.py +++ b/src/ripplegw/waveforms/IMRPhenomXP_utils.py @@ -295,10 +295,36 @@ def IMRPhenomX_JNorm_MSA(LNorm, pPrec): JNorm2 = (LNorm * LNorm + 2.0 * LNorm * pPrec['c1_over_eta'] + pPrec['SAv2']) return jnp.sqrt(JNorm2) -def IMRPhenomX_Return_SNorm_MSA(): - ## TODO: implement elliptic Jacobi functions - return - +def IMRPhenomX_Return_SNorm_MSA(v, pPrec): + + v2 = v * v + + cancel_condition = jnp.abs(pPrec['Smi2'] - pPrec['Spl2']) < 1e-5 + + def sn_zero(_): + sn = jnp.array(0.0) + return sn + + def sn_jacobi(_): + # Equation 25 of Chatziioannou et al, PRD 95, 104004, (2017), arXiv:1703.03967 + m = (pPrec['Smi2'] - pPrec['Spl2']) / (pPrec['S32'] - pPrec['Spl2']) + + psi = IMRPhenomX_psiofv( + v, v2, + pPrec['psi0'], pPrec['psi1'], pPrec['psi2'], + pPrec + ) + + # Jacobi elliptic functions + sn, cn, dn = gsl_sf_elljac_e(psi, m) + return sn + + sn = jax.lax.cond(cancel_condition, sn_zero, sn_jacobi, operand=None) + + # Equation 23 of Chatziioannou et al, PRD 95, 104004, (2017), arXiv:1703.03967 + SNorm2 = pPrec['Spl2'] + (pPrec['Smi2'] - pPrec['Spl2']) * sn * sn + + return jnp.sqrt(SNorm2) def IMRPhenomX_L_norm_3PN_of_v(v: float, v2: float, L_norm: float, pPrec) -> float: cL = pPrec['constants_L'] # shorthand @@ -1485,6 +1511,51 @@ def evaluate_QNMfit_fdamp22(a): + 0.009002719412204133 * a6 ) ) + +def XLALSimIMRPhenomXLPNAnsatz(v, LNorm, L0, L1, L2, L3, L4, L5, L6, L7, L8, L8L): + """ + Computes the PN orbital angular momentum expansion. + v : Input velocity. + LNorm : Orbital angular momentum normalization (e.g. η / sqrt(x)). + L0–L8, L8L : PN coefficients. + + Returns + L : Post-Newtonian angular momentum. + """ + + x = v * v + x2 = x * x + x3 = x * x2 + x4 = x * x3 + sqx = jnp.sqrt(x) + + L = ( + L0 + + L1 * sqx + + L2 * x + + L3 * (x * sqx) + + L4 * x2 + + L5 * (x2 * sqx) + + L6 * x3 + + L7 * (x3 * sqx) + + L8 * x4 + + L8L * x4 * jnp.log(x) + ) + return LNorm * L + + +def Get_alphaepsilon_atfref(mprime, pPrec, pWF): + + omega_ref = pWF['piM'] * pWF['fRef'] * (2.0 / mprime) + + v = jnp.cbrt(omega_ref) + + vangles = IMRPhenomX_Return_phi_zeta_costhetaL_MSA(v, pWF, pPrec) + + alpha_offset = vangles[0] - pPrec['alpha0'] + epsilon_offset = vangles[1] - pPrec['epsilon0'] + + return alpha_offset, epsilon_offset ########################## From 0d854c31537b5325af8fc059158ac8aaa4342ca1 Mon Sep 17 00:00:00 2001 From: "leonardo.ricca" Date: Wed, 19 Nov 2025 11:27:35 +0100 Subject: [PATCH 15/20] Fix bugs in IMRPhenomXP_utils while testing IMRPhenomXGetAndSetPrecessionVariables --- src/ripplegw/waveforms/IMRPhenomXP_utils.py | 23 +++++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/src/ripplegw/waveforms/IMRPhenomXP_utils.py b/src/ripplegw/waveforms/IMRPhenomXP_utils.py index fa1075d7..741e3a25 100644 --- a/src/ripplegw/waveforms/IMRPhenomXP_utils.py +++ b/src/ripplegw/waveforms/IMRPhenomXP_utils.py @@ -714,17 +714,17 @@ def general_case(pPrec): tmp_v = jnp.array([pPrec['Nx_Sf'], pPrec['Ny_Sf'], pPrec['Nz_Sf']]) # Rotate around z, then y - tmp_v = IMRPhenomX_rotate_z(-pPrec['phiJ_Sf'], tmp_v) - tmp_v = IMRPhenomX_rotate_y(-pPrec['thetaJ_Sf'], tmp_v) + tmp_v = IMRPhenomX_rotate_z(tmp_v, -pPrec['phiJ_Sf']) + tmp_v = IMRPhenomX_rotate_y(tmp_v, -pPrec['thetaJ_Sf']) ## Difference in overall - sign w.r.t PhenomPv2 code pPrec['kappa'] = jnp.where((jnp.abs(tmp_v[0]) < 1e-15) & (jnp.abs(tmp_v[1]) < 1e-15), 0.0, jnp.arctan2(tmp_v[1], tmp_v[0])) ## Now determine alpha0 by rotating LN. In the source frame, LN = {0,0,1} tmp_v = jnp.array([0,0,1]) - tmp_v = IMRPhenomX_rotate_z(-pPrec['phiJ_Sf'], tmp_v) - tmp_v = IMRPhenomX_rotate_y(-pPrec['thetaJ_Sf'], tmp_v) - tmp_v = IMRPhenomX_rotate_z(-pPrec['kappa'], tmp_v) + tmp_v = IMRPhenomX_rotate_z(tmp_v, -pPrec['phiJ_Sf']) + tmp_v = IMRPhenomX_rotate_y(tmp_v, -pPrec['thetaJ_Sf']) + tmp_v = IMRPhenomX_rotate_z(tmp_v, -pPrec['kappa']) pPrec['alpha0'] = jnp.pi - pPrec['kappa'] @@ -757,9 +757,9 @@ def general_case(pPrec): pPrec['Xz_Sf'] = jnp.sin(pWF['inclination']) tmp_v = jnp.array([pPrec['Xx_Sf'],pPrec['Xy_Sf'],pPrec['Xz_Sf']]) - tmp_v = IMRPhenomX_rotate_z(-pPrec['phiJ_Sf'], tmp_v) - tmp_v = IMRPhenomX_rotate_y(-pPrec['thetaJ_Sf'], tmp_v) - tmp_v = IMRPhenomX_rotate_z(-pPrec['kappa'], tmp_v) + tmp_v = IMRPhenomX_rotate_z(tmp_v, -pPrec['phiJ_Sf']) + tmp_v = IMRPhenomX_rotate_y(tmp_v, -pPrec['thetaJ_Sf']) + tmp_v = IMRPhenomX_rotate_z(tmp_v, -pPrec['kappa']) pPrec['PArun_Jf'] = jnp.array([pPrec['Nz_Jf'],0,-pPrec['Nx_Jf']]) @@ -1491,7 +1491,6 @@ def evaluate_QNMfit_fdamp22(a): a4 = a3 * a a5 = a4 * a a6 = a5 * a - a7 = a6 * a return ( ( @@ -1568,6 +1567,12 @@ def gsl_sf_ellint_F(x, y): """ return x +def gsl_sf_elljac_e(x, y): + """ + TODO: Not yet implemented + """ + return jnp.array(1.), jnp.array(2.), jnp.array(3.) + def XLALSimInspiralWaveformParamsLookupPhenomXPTransPrecessionMethod(dict): """ TODO: Not yet implemented From 1ebca9f7843a8465ba0f9fccb8fbf0b8b903030e Mon Sep 17 00:00:00 2001 From: "leonardo.ricca" Date: Thu, 27 Nov 2025 14:52:13 +0100 Subject: [PATCH 16/20] Fix IMRPhenomX_Return_Roots_MSA to correctly handle vector inputs and fix other minor bugs --- src/ripplegw/waveforms/IMRPhenomXP_utils.py | 91 +++++++++++---------- 1 file changed, 47 insertions(+), 44 deletions(-) diff --git a/src/ripplegw/waveforms/IMRPhenomXP_utils.py b/src/ripplegw/waveforms/IMRPhenomXP_utils.py index 741e3a25..7718e45d 100644 --- a/src/ripplegw/waveforms/IMRPhenomXP_utils.py +++ b/src/ripplegw/waveforms/IMRPhenomXP_utils.py @@ -73,9 +73,9 @@ def IMRPhenomX_Return_phi_zeta_costhetaL_MSA( ''' Get phiz_0_MSA and zeta_0_MSA ''' vMSA = jax.lax.cond((jnp.fabs(pPrec["Smi2"] - pPrec["Spl2"]) > 1.e-5), - IMRPhenomX_Return_MSA_Corrections_MSA, ## return 3D jnp.array - lambda v, L_norm, J_norm, pPrec: jnp.array([0,0,0]), - v, L_norm, J_norm, pPrec) + lambda args: IMRPhenomX_Return_MSA_Corrections_MSA(*args), ## return 3D jnp.array + lambda args: jnp.array([0.,0.,0.]), + (v, L_norm, J_norm, pPrec)) phiz_MSA = vMSA[0] zeta_MSA = vMSA[1] @@ -84,9 +84,9 @@ def IMRPhenomX_Return_phi_zeta_costhetaL_MSA( zeta = IMRPhenomX_Return_zeta_MSA(v,pPrec) cos_theta_L = IMRPhenomX_costhetaLJ(L_norm3PN,J_norm3PN,SNorm) - vout[0] = phiz + phiz_MSA - vout[1] = zeta + zeta_MSA - vout[2] = cos_theta_L + vout.at[0].set(phiz + phiz_MSA) + vout.at[1].set(zeta + zeta_MSA) + vout.at[2].set(cos_theta_L) return vout @@ -351,17 +351,18 @@ def IMRPhenomX_Return_Roots_MSA(LNorm, JNorm, pPrec): theta = jnp.arccos(acosarg) / 3.0 cos_theta = jnp.cos(theta) - - invalid_case = ( - jnp.isnan(theta) | - jnp.isnan(sqrtarg) | - (pPrec['dotS1Ln'] == 1.0) | - (pPrec['dotS2Ln'] == 1.0) | - (pPrec['dotS1Ln'] == -1.0) | - (pPrec['dotS2Ln'] == -1.0) | - (pPrec['S1_norm_2'] == 0.0) | - (pPrec['S2_norm_2'] == 0.0) - ) + + print(f'{p=}, {sqrtarg=}, {theta=}, {B=}, {B2=}, {C=}') + + vector_condition = jnp.logical_or(jnp.isnan(theta), + (jnp.isnan(sqrtarg))) + scalar_condition = jnp.logical_or.reduce(jnp.array([(pPrec['dotS1Ln'] == 1.0), + (pPrec['dotS2Ln'] == 1.0), + (pPrec['dotS1Ln'] == -1.0), + (pPrec['dotS2Ln'] == -1.0), + (pPrec['S1_norm_2'] == 0.0), + (pPrec['S2_norm_2'] == 0.0)])) + invalid_case = jnp.logical_or(vector_condition, scalar_condition) def roots_when_valid(): tmp1 = 2.0 * sqrtarg * jnp.cos(theta - 4.0 * jnp.pi / 3.0) - B / 3.0 @@ -380,21 +381,23 @@ def roots_when_valid(): S32 = tmp5 Smi2 = jnp.abs(tmp6) Spl2 = jnp.abs(tmp4) - return S32, Smi2, Spl2 + return jnp.array([S32, Smi2, Spl2]) def roots_when_invalid(): - Smi2 = pPrec['S_0_norm']**2 + Smi2 = pPrec['S_0_norm']**2 * jnp.ones_like(LNorm) Spl2 = Smi2 + 1e-9 - S32 = 0.0 - return S32, Smi2, Spl2 + S32 = jnp.zeros_like(LNorm) + return jnp.array([S32, Smi2, Spl2]) - S32, Smi2, Spl2 = jax.lax.cond( - invalid_case, - roots_when_invalid, - roots_when_valid + roots_array = jnp.where( + jnp.atleast_1d(invalid_case), + roots_when_invalid(), + roots_when_valid() ) + + print(f'{roots_array=}') - return jnp.array([S32, Smi2, Spl2]) + return roots_array def IMRPhenomX_Return_Constants_c_MSA(v, JNorm, pPrec): v2 = v * v @@ -490,17 +493,24 @@ def IMRPhenomX_Return_Spin_Evolution_Coefficients_MSA(LNorm, JNorm, pPrec): return jnp.array([B_coeff, C_coeff, D_coeff]) -def IMRPhenomXGetAndSetPrecessionVariables(pWF, m1_SI, m2_SI, - chi1x, chi1y, chi1z, - chi2x, chi2y, chi2z, - params): +def IMRPhenomXGetAndSetPrecessionVariables(pWF, params): + + m1_SI = pWF['m1']*gt + m2_SI = pWF['m2']*gt + + chi1x, chi1y, chi1z, chi2x, chi2y, chi2z = pWF['chi1x'], pWF['chi1y'], pWF['chi1z'], pWF['chi2x'], pWF['chi2y'], pWF['chi2z'] pPrec = {} ## Here we're only considering the default setting, where the expansion order for MSA correction is 5 - #pPrec['ExpansionOrder'] = XLALSimInspiralWaveformParamsLookupPhenomXPExpansionOrder(params) ## (TODO) + ## pPrec['ExpansionOrder'] = XLALSimInspiralWaveformParamsLookupPhenomXPExpansionOrder(params) ## (TODO) pPrec['ExpansionOrder'] = 5 + ## allow for conditional disabling of precession multibanding given mass ratio and opening angle ## (TODO) -> line 137 of x_precession.C + pPrec['conditionalPrecMBand'] = 0 + ## default value for PrecThresholdMband according to arXiv:2004.06503v2, Table VI + pPrec['PrecThresholdMband'] = 1e-3 + Mtot_SI = m1_SI + m2_SI # Normalize masses m1 = m1_SI / Mtot_SI @@ -655,7 +665,7 @@ def IMRPhenomXGetAndSetPrecessionVariables(pWF, m1_SI, m2_SI, pPrec['L8L'] = 0.0 pPrec['LRef'] = ( - M * M * XLALSimIMRPhenomXLPNAnsatz( ## (TODO) + M * M * XLALSimIMRPhenomXLPNAnsatz( pWF['v_ref'], ## (TODO) check if variables in pWF are stored correctly pWF['eta'] / pWF['v_ref'], pPrec['L0'], pPrec['L1'], pPrec['L2'], pPrec['L3'], pPrec['L4'], @@ -784,7 +794,7 @@ def general_case(pPrec): pPrec['epsilon0'] = pPrec['phiJ_Sf'] - jnp.pi - alpha_offset, epsilon_offset = Get_alphaepsilon_atfref(2, pPrec, pWF) ## (TODO) + alpha_offset, epsilon_offset = Get_alphaepsilon_atfref(2, pPrec, pWF) pPrec['alpha_offset'] = alpha_offset pPrec['epsilon_offset'] = epsilon_offset pPrec['alpha_offset_1'] = alpha_offset @@ -1288,22 +1298,15 @@ def IMRPhenomXPCheckMaxOpeningAngle( denominator = 81.0 - 57.0 * eta + eta * eta v_at_max_beta = jnp.sqrt(2.0 / 3.0) * jnp.sqrt(numerator / denominator) - status, cBetah, sBetah = WignerdCoefficients(cBetah, sBetah, v_at_max_beta, pWF, pPrec) - - jax.lax.cond( - status != 0, - lambda s: jax.debug.print("Call to IMRPhenomXWignerdCoefficients failed."), - lambda s: s, - status - ) + cBetah, sBetah = WignerdCoefficients(v_at_max_beta, pWF, pPrec) - L_min = XLALSimIMRPhenomXL2PNNS(v_at_max_beta, eta) ## (TODO) + L_min = XLALSimIMRPhenomXL2PNNS(v_at_max_beta, eta) max_beta = 2.0 * jnp.arccos(cBetah) def positive_case(vals): - q, conditionalPrecMBand = vals[0,1] + q, conditionalPrecMBand = vals[0], vals[1] jax.debug.print('The maximum opening angle exceeds Pi/2. The model may be pathological in this regime.') - return jax.lax.cond((q > 7.0 and conditionalPrecMBand == 1), lambda x: x*0, lambda x: x, vals[3]) + return jax.lax.cond(jnp.logical_and(q > 7.0, conditionalPrecMBand == 1), lambda x: x*0, lambda x: x, vals[3]) def negative_case(vals): max_beta = vals[2] From d654bab97e8a200dbfac155d2cdf639ba3485e53 Mon Sep 17 00:00:00 2001 From: "leonardo.ricca" Date: Thu, 27 Nov 2025 14:54:22 +0100 Subject: [PATCH 17/20] Implement WignerdCoefficients and XLALSimIMRPhenomXL2PNNS --- src/ripplegw/waveforms/IMRPhenomXP_utils.py | 43 ++++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/src/ripplegw/waveforms/IMRPhenomXP_utils.py b/src/ripplegw/waveforms/IMRPhenomXP_utils.py index 7718e45d..4e22fb86 100644 --- a/src/ripplegw/waveforms/IMRPhenomXP_utils.py +++ b/src/ripplegw/waveforms/IMRPhenomXP_utils.py @@ -54,7 +54,7 @@ def IMRPhenomX_Return_phi_zeta_costhetaL_MSA( L_norm3PN = IMRPhenomX_L_norm_3PN_of_v(v, v*v, L_norm, pPrec) ## for 223 - J_norm3PN = IMRPhenomX_JNorm_MSA(L_norm3PN,pPrec) + J_norm3PN = IMRPhenomX_JNorm_MSA(L_norm3PN,pPrec) vRoots = IMRPhenomX_Return_Roots_MSA(L_norm,J_norm,pPrec) ## return jnp.array @@ -99,6 +99,35 @@ def WignerdCoefficients_cosbeta( return cos_beta_half, sin_beta_half +def WignerdCoefficients( + v, # Cubic root of (Pi * Frequency (geometric)) + pWF, # IMRPhenomX waveform struct + pPrec # IMRPhenomX precession struct +): + # This implementation is different respect to the one in Pv2_utils + + # Orbital angular momentum L + L = XLALSimIMRPhenomXLPNAnsatz( + v, + pWF['eta'] / v, + pPrec['L0'], pPrec['L1'], pPrec['L2'], pPrec['L3'], pPrec['L4'], + pPrec['L5'], pPrec['L6'], pPrec['L7'], pPrec['L8'], pPrec['L8L'] + ) + + # We ignore the sign of L + SL below: + # s = S_perp / (L + SL) + denom = L + pPrec['SL'] + s = pPrec['Sperp'] / denom + s2 = s * s + + cos_beta = jnp.sign(denom) / jnp.sqrt(1.0 + s2) + + # cos(beta/2) and sin(beta/2) + cos_beta_half = jnp.sqrt(jnp.abs(1.0 + cos_beta) / 2.0) + sin_beta_half = jnp.sqrt(jnp.abs(1.0 - cos_beta) / 2.0) + + return cos_beta_half, sin_beta_half + def IMRPhenomX_costhetaLJ( L_norm: float, J_norm: float, @@ -1545,6 +1574,18 @@ def XLALSimIMRPhenomXLPNAnsatz(v, LNorm, L0, L1, L2, L3, L4, L5, L6, L7, L8, L8L ) return LNorm * L +def XLALSimIMRPhenomXL2PNNS(v, eta): + + eta2 = eta * eta + x = v * v + x2 = x * x + sqx = v + + return (eta / sqx) * ( + 1.0 + + x * (1.5 + eta / 6.0) + + x2 * (27.0 / 8.0 - (19.0 * eta) / 8.0 + eta2 / 24.0) + ) def Get_alphaepsilon_atfref(mprime, pPrec, pWF): From abe827abbf23c817da594d7c62bb6554efcfda74 Mon Sep 17 00:00:00 2001 From: "leonardo.ricca" Date: Fri, 28 Nov 2025 17:26:47 +0100 Subject: [PATCH 18/20] Fix Return_phi_zeta_costhetaL_MSA, Return_MSA_Corrections_MSA, Return_Snorm_MSA to correctly hand vector inputs and fix other minor bugs --- src/ripplegw/waveforms/IMRPhenomXP_utils.py | 50 ++++++++++++--------- 1 file changed, 28 insertions(+), 22 deletions(-) diff --git a/src/ripplegw/waveforms/IMRPhenomXP_utils.py b/src/ripplegw/waveforms/IMRPhenomXP_utils.py index 4e22fb86..9de84740 100644 --- a/src/ripplegw/waveforms/IMRPhenomXP_utils.py +++ b/src/ripplegw/waveforms/IMRPhenomXP_utils.py @@ -45,7 +45,6 @@ def IMRPhenomX_Return_phi_zeta_costhetaL_MSA( pPrec ## IMRPhenomX precession struct ): ## has to output a jnp array - vout = jnp.array([0,0,0]) L_norm = pWF['eta'] / v J_norm = IMRPhenomX_JNorm_MSA(L_norm,pPrec) @@ -72,22 +71,29 @@ def IMRPhenomX_Return_phi_zeta_costhetaL_MSA( pPrec["S_norm_2"] = SNorm * SNorm ''' Get phiz_0_MSA and zeta_0_MSA ''' - vMSA = jax.lax.cond((jnp.fabs(pPrec["Smi2"] - pPrec["Spl2"]) > 1.e-5), - lambda args: IMRPhenomX_Return_MSA_Corrections_MSA(*args), ## return 3D jnp.array - lambda args: jnp.array([0.,0.,0.]), - (v, L_norm, J_norm, pPrec)) - - phiz_MSA = vMSA[0] - zeta_MSA = vMSA[1] - - phiz = IMRPhenomX_Return_phiz_MSA(v,J_norm,pPrec) - zeta = IMRPhenomX_Return_zeta_MSA(v,pPrec) - cos_theta_L = IMRPhenomX_costhetaLJ(L_norm3PN,J_norm3PN,SNorm) - - vout.at[0].set(phiz + phiz_MSA) - vout.at[1].set(zeta + zeta_MSA) - vout.at[2].set(cos_theta_L) + condition = jnp.atleast_1d(jnp.fabs(pPrec["Smi2"] - pPrec["Spl2"]) > 1.e-5) + ## this check is needed because of broadcasting if condition is N-dimensional + # (N,) cannot broadcast to (N,3) + condition = condition[..., None] if condition.ndim == 1 else condition + + vMSA = jnp.where(condition, + IMRPhenomX_Return_MSA_Corrections_MSA(v, L_norm, J_norm, pPrec), ## return v.shape[0]x3 dimensional jnp.array + jnp.zeros((jnp.atleast_1d(v).shape[0], 3)), + ) + phiz_MSA = vMSA[:,0] + zeta_MSA = vMSA[:,1] + + phiz = jnp.atleast_1d(IMRPhenomX_Return_phiz_MSA(v,J_norm,pPrec)) + zeta = jnp.atleast_1d(IMRPhenomX_Return_zeta_MSA(v,pPrec)) + cos_theta_L = jnp.atleast_1d(IMRPhenomX_costhetaLJ(L_norm3PN,J_norm3PN,SNorm)) + + vout = jnp.stack([ + phiz + phiz_MSA, + zeta + zeta_MSA, + cos_theta_L + ], axis=-1) + return vout def WignerdCoefficients_cosbeta( @@ -318,7 +324,7 @@ def compute_Dphi_term(): vMSA_x = jnp.where(jnp.isnan(vMSA_x), 0.0, vMSA_x) vMSA_y = jnp.where(jnp.isnan(vMSA_y), 0.0, vMSA_y) - return jnp.array([vMSA_x, vMSA_y, 0.0]) + return jnp.stack([vMSA_x, vMSA_y, jnp.zeros_like(vMSA_x)], axis=-1) def IMRPhenomX_JNorm_MSA(LNorm, pPrec): JNorm2 = (LNorm * LNorm + 2.0 * LNorm * pPrec['c1_over_eta'] + pPrec['SAv2']) @@ -330,11 +336,11 @@ def IMRPhenomX_Return_SNorm_MSA(v, pPrec): cancel_condition = jnp.abs(pPrec['Smi2'] - pPrec['Spl2']) < 1e-5 - def sn_zero(_): + def sn_zero(): sn = jnp.array(0.0) return sn - def sn_jacobi(_): + def sn_jacobi(): # Equation 25 of Chatziioannou et al, PRD 95, 104004, (2017), arXiv:1703.03967 m = (pPrec['Smi2'] - pPrec['Spl2']) / (pPrec['S32'] - pPrec['Spl2']) @@ -348,7 +354,7 @@ def sn_jacobi(_): sn, cn, dn = gsl_sf_elljac_e(psi, m) return sn - sn = jax.lax.cond(cancel_condition, sn_zero, sn_jacobi, operand=None) + sn = jnp.where(cancel_condition, sn_zero(), sn_jacobi()) # Equation 23 of Chatziioannou et al, PRD 95, 104004, (2017), arXiv:1703.03967 SNorm2 = pPrec['Spl2'] + (pPrec['Smi2'] - pPrec['Spl2']) * sn * sn @@ -1595,8 +1601,8 @@ def Get_alphaepsilon_atfref(mprime, pPrec, pWF): vangles = IMRPhenomX_Return_phi_zeta_costhetaL_MSA(v, pWF, pPrec) - alpha_offset = vangles[0] - pPrec['alpha0'] - epsilon_offset = vangles[1] - pPrec['epsilon0'] + alpha_offset = vangles[:,0] - pPrec['alpha0'] + epsilon_offset = vangles[:,1] - pPrec['epsilon0'] return alpha_offset, epsilon_offset From 95b228f2493d6a8a611fafddb03f94ad9166656f Mon Sep 17 00:00:00 2001 From: "leonardo.ricca" Date: Tue, 9 Dec 2025 16:47:34 +0100 Subject: [PATCH 19/20] Define top level functions to generate IMRPhenomXP waveform --- src/ripplegw/gsl_ellint.py | 207 ++++++++++++++++ src/ripplegw/waveforms/IMRPhenomXP.py | 59 ++++- test/benchmark_waveform.py | 327 +++++++++++++++----------- 3 files changed, 456 insertions(+), 137 deletions(-) create mode 100644 src/ripplegw/gsl_ellint.py diff --git a/src/ripplegw/gsl_ellint.py b/src/ripplegw/gsl_ellint.py new file mode 100644 index 00000000..dddc1747 --- /dev/null +++ b/src/ripplegw/gsl_ellint.py @@ -0,0 +1,207 @@ +""" +Elliptic integral utilities implemented with JAX. +Some functions are based on https://github.com/tagordon/ellip/tree/main +""" + +from __future__ import annotations + +import jax +import jax.numpy as jnp + +jax.config.update("jax_enable_x64", True) + + +# relative error will be "less in magnitude than r" +R = 1.0e-15 + +# GSL constants +GSL_DBL_EPSILON = 2.2204460492503131e-16 + + +@jax.jit +@jnp.vectorize +def rf(x, y, z): + r"""JAX implementation of Carlson's :math:`R_\mathrm{F}` + + Computed using the algorithm in Carlson, 1994: https://arxiv.org/pdf/math/9409227.pdf + Code taken from https://github.com/tagordon/ellip/tree/main + + Args: + x: arraylike, real valued. + y: arraylike, real valued. + z: arraylike, real valued. + + Returns: + The value of the integral :math:`R_\mathrm{F}` + + Notes: + ``rf`` does not support complex-valued inputs. + ``rf`` requires `jax.config.update("jax_enable_x64", True)` + """ + + xyz = jnp.array([x, y, z]) + a0 = jnp.sum(xyz) / 3.0 + v = jnp.max(jnp.abs(a0 - xyz)) + q = (3 * R) ** (-1 / 6) * v + + # cond = lambda s: s["f"] * q > jnp.abs(s["An"]) + def cond(s): + return s["f"] * q > jnp.abs(s["An"]) + + def body(s): + + xyz = s["xyz"] + lam = jnp.sqrt(xyz[0] * xyz[1]) + jnp.sqrt(xyz[0] * xyz[2]) + jnp.sqrt(xyz[1] * xyz[2]) + + s["An"] = 0.25 * (s["An"] + lam) + s["xyz"] = 0.25 * (s["xyz"] + lam) + s["f"] = s["f"] * 0.25 + + return s + + s = {"f": 1, "An": a0, "xyz": xyz} + s = jax.lax.while_loop(cond, body, s) + + x = (a0 - x) / s["An"] * s["f"] + y = (a0 - y) / s["An"] * s["f"] + z = -(x + y) + e2 = x * y - z * z + e3 = x * y * z + + return (1 - 0.1 * e2 + e3 / 14 + e2 * e2 / 24 - 3 * e2 * e3 / 44) / jnp.sqrt(s["An"]) + + +@jax.jit +@jnp.vectorize +def ellipfinc(phi, k): + r"""JAX implementation of the incomplete elliptic integral of the first kind + Code taken from https://github.com/tagordon/ellip/tree/main + + .. math:: + + \[F\left(\phi,k\right)=\int_{0}^{\phi}\frac{\,\mathrm{d}\theta}{\sqrt{1-k^{2}{% + \sin}^{2}\theta}}] + + Args: + phi: arraylike, real valued. + k: arraylike, real valued. + + Returns: + The value of the complete elliptic integral of the first kind, :math:`F(\phi, k)` + + Notes: + ``ellipfinc`` does not support complex-valued inputs. + ``ellipfinc`` requires `jax.config.update("jax_enable_x64", True)` + """ + + c = 1.0 / jnp.sin(phi) ** 2 + return rf(c - 1, c - k**2, c) + + +@jax.jit +@jnp.vectorize +def gsl_sf_elljac_e(u, m): # double * sn, double * cn, double * dn + """ + JAX implementation of the Jacobi elliptic functions sn(u|m), cn(u|m), dn(u|m) + Based on https://github.com/ampl/gsl/blob/master/specfunc/elljac.c + """ + + def little_m_branch(u, m): + _ = m + sn = jnp.sin(u) + cn = jnp.cos(u) + dn = 1.0 + return sn, cn, dn + + def little_m_min_one_branch(u, m): + _ = m + sn = jnp.tanh(u) + cn = 1.0 / jnp.cosh(u) + dn = cn + return sn, cn, dn + + def main_branch(u, m): + _n = 16 + mu = jnp.zeros(_n, dtype=jnp.float64).at[0].set(1.0) + nu = jnp.zeros(_n, dtype=jnp.float64).at[0].set(jnp.sqrt(jnp.clip(1.0 - m, 0.0, jnp.inf))) + + def cond(state): + n, mu, nu = state + diff = jnp.abs(mu[n] - nu[n]) + tol = 4.0 * GSL_DBL_EPSILON * jnp.abs(mu[n] + nu[n]) + return (diff > tol) & (n < _n - 1) + + def body(state): + n, mu, nu = state + mu = mu.at[n + 1].set(0.5 * (mu[n] + nu[n])) + nu = nu.at[n + 1].set(jnp.sqrt(mu[n] * nu[n])) + return n + 1, mu, nu + + n, mu, nu = jax.lax.while_loop(cond, body, (0, mu, nu)) + + sin_umu = jnp.sin(u * mu[n]) + cos_umu = jnp.cos(u * mu[n]) + + def cos_branch(args): + sin_umu, cos_umu, n, mu, nu = args + c = jnp.zeros_like(mu).at[n].set(mu[n] * (sin_umu / cos_umu)) + d = jnp.zeros_like(mu).at[n].set(1.0) + + def cond(kcd): + k, _, _ = kcd + return k > 0 + + def body(kcd): + k, c, d = kcd + k = k - 1 + r = (c[k + 1] * c[k + 1]) / mu[k + 1] + c = c.at[k].set(d[k + 1] * c[k + 1]) + d = d.at[k].set((r + nu[k]) / (r + mu[k])) + return k, c, d + + _, c, d = jax.lax.while_loop(cond, body, (n, c, d)) + dn = jnp.sqrt(jnp.clip(1.0 - m, 0.0, jnp.inf)) / d[0] + cn = dn * jnp.sign(cos_umu) / jnp.hypot(1.0, c[0]) + sn = cn * c[0] / jnp.sqrt(jnp.clip(1.0 - m, 0.0, jnp.inf)) + return sn, cn, dn + + def sin_branch(args): + sin_umu, cos_umu, n, mu, nu = args + c = jnp.zeros_like(mu).at[n].set(mu[n] * (cos_umu / sin_umu)) + d = jnp.zeros_like(mu).at[n].set(1.0) + + def cond(kcd): + k, _, _ = kcd + return k > 0 + + def body(kcd): + k, c, d = kcd + k = k - 1 + r = (c[k + 1] * c[k + 1]) / mu[k + 1] + c = c.at[k].set(d[k + 1] * c[k + 1]) + d = d.at[k].set((r + nu[k]) / (r + mu[k])) + return k, c, d + + _, c, d = jax.lax.while_loop(cond, body, (n, c, d)) + dn = d[0] + sn = jnp.sign(sin_umu) / jnp.hypot(1.0, c[0]) + cn = c[0] * sn + return sn, cn, dn + + return jax.lax.cond( + jnp.abs(sin_umu) < jnp.abs(cos_umu), cos_branch, sin_branch, operand=(sin_umu, cos_umu, n, mu, nu) + ) + + sn, cn, dn = jax.lax.cond( + jnp.abs(m) < 2.0 * GSL_DBL_EPSILON, + lambda args: little_m_branch(*args), + lambda args: jax.lax.cond( + jnp.abs(args[1] - 1.0) < 2.0 * GSL_DBL_EPSILON, + lambda inner_args: little_m_min_one_branch(*inner_args), + lambda inner_args: main_branch(*inner_args), + args, + ), + (u, m), + ) + + return sn, cn, dn diff --git a/src/ripplegw/waveforms/IMRPhenomXP.py b/src/ripplegw/waveforms/IMRPhenomXP.py index 7edaa34d..818ca66a 100644 --- a/src/ripplegw/waveforms/IMRPhenomXP.py +++ b/src/ripplegw/waveforms/IMRPhenomXP.py @@ -91,22 +91,69 @@ def PhenomXPCoreTwistUp22( return hp, hc -def gen_IMRPhenomXP_hphc(f: Array, +def _gen_IMRPhenomXP_hphc(f: Array, params: Array, prec_params, f_ref: float): """ + The following function generates an IMRPhenomXP frequency-domain waveform. + It is the translation of IMRPhenomXPGenerateFD from LAL. + Calls gen_IMRPhenomXAS to generate the coprecessing waveform (In LAL the coprecessing waveform is generated within IMRPhenomXPGenerateFD) Returns: -------- hp (array): Strain of the plus polarization hc (array): Strain of the cross polarization """ - iota = params[7] - h0 = gen_IMRPhenomXAS(f, params, f_ref) + + Mc, _ = ms_to_Mc_eta([params['m1'], params['m2']]) + iota = params['inclination'] + params_list = [Mc, params['eta'], params['chi1_norm'], params['chi2_norm'], params['D'], params['tc'], params['phi0']] + ## line 2372 of https://lscsoft.docs.ligo.org/lalsuite/lalsimulation/_l_a_l_sim_i_m_r_phenom_x_8c_source.html + hcoprec = gen_IMRPhenomXAS(f, params_list, f_ref) - hp, hc = PhenomXPCoreTwistUp22(f, h0, params, prec_params) + ## line 2403 of https://lscsoft.docs.ligo.org/lalsuite/lalsimulation/_l_a_l_sim_i_m_r_phenom_x_8c_source.html + hp, hc = PhenomXPCoreTwistUp22(f, hcoprec, params, prec_params) - hp = h0 * (1 / 2 * (1 + jnp.cos(iota) ** 2)) - hc = -1j * h0 * jnp.cos(iota) + ## rotate waveform by 2 zeta. + # line 2469 of https://lscsoft.docs.ligo.org/lalsuite/lalsimulation/_l_a_l_sim_i_m_r_phenom_x_8c_source.html + zeta = prec_params['zeta_polarization'] + cond = jnp.abs(zeta) > 0.0 + + def no_rotation(args): + hp, hc, _ = args + return hp, hc + + def do_rotation(args): + hp, hc, z = args + angle = 2.0 * z + + cosPol = jnp.cos(angle) + sinPol = jnp.sin(angle) + + new_hp = cosPol * hp + sinPol * hc + new_hc = cosPol * hc - sinPol * hp + + return new_hp, new_hc + + hp, hc = jax.lax.cond( + cond, + do_rotation, + no_rotation, + operand=(hp, hc, zeta) + ) + return hp, hc + +def gen_IMRPhenomXP_hphc(f: Array, + params: Array, + f_ref: float): + + params_aux = {} + + ## line 1192 of https://lscsoft.docs.ligo.org/lalsuite/lalsimulation/_l_a_l_sim_i_m_r_phenom_x_8c_source.html + prec_params = IMRPhenomXGetAndSetPrecessionVariables(params, params_aux) + + ## line 1213 + hp, hc = _gen_IMRPhenomXP_hphc(f, params, prec_params, f_ref) + return hp, hc \ No newline at end of file diff --git a/test/benchmark_waveform.py b/test/benchmark_waveform.py index ccfcb105..87cf0010 100644 --- a/test/benchmark_waveform.py +++ b/test/benchmark_waveform.py @@ -5,77 +5,89 @@ TODO: Implement precession here as well. """ + import time -import numpy as np + import jax import jax.numpy as jnp +import lal +import lalsimulation as lalsim +import numpy as np import pandas as pd from tqdm import tqdm -from ripple import get_eff_pads, get_match_arr, ms_to_Mc_eta, lambdas_to_lambda_tildes +from ripple import get_eff_pads, get_match_arr, lambdas_to_lambda_tildes, ms_to_Mc_eta from ripple.constants import PI -import lal -import lalsimulation as lalsim - jax.config.update("jax_enable_x64", True) -########################### +########################### ### Auxiliary functions ### ########################### + def check_is_tidal(waveform_name: str): # Check if the given waveform is supported: bns_waveforms = ["IMRPhenomD_NRTidalv2", "TaylorF2"] bbh_waveforms = ["IMRPhenomD"] - + all_waveforms = bns_waveforms + bbh_waveforms if waveform_name not in all_waveforms: - raise ValueError(f"Waveform approximant {waveform_name} not supported by ripple") - + raise ValueError( + f"Waveform approximant {waveform_name} not supported by ripple" + ) + if waveform_name in bns_waveforms: is_tidal = True else: is_tidal = False - + return is_tidal + def get_jitted_waveform(waveform_name: str, fs: np.array, f_ref: float): if waveform_name == "IMRPhenomD": # Import the waveform - from ripple.waveforms.IMRPhenomD import gen_IMRPhenomD_hphc as waveform_generator - + from ripple.waveforms.IMRPhenomD import ( + gen_IMRPhenomD_hphc as waveform_generator, + ) + # Get jitted version (note, use IMRPhenomD as underlying waveform model) @jax.jit def waveform(theta): hp, _ = waveform_generator(fs, theta, f_ref) return hp - + elif waveform_name == "IMRPhenomD_NRTidalv2": # Import the waveform - from ripple.waveforms.X_NRTidalv2 import gen_NRTidalv2_hphc as waveform_generator - + from ripple.waveforms.X_NRTidalv2 import ( + gen_NRTidalv2_hphc as waveform_generator, + ) + # Get jitted version (note, use IMRPhenomD as underlying waveform model) @jax.jit def waveform(theta): hp, _ = waveform_generator(fs, theta, f_ref, IMRphenom="IMRPhenomD") return hp - + elif waveform_name == "TaylorF2": # Import the waveform from ripple.waveforms.TaylorF2 import gen_TaylorF2_hphc as waveform_generator - + # Get jitted version @jax.jit def waveform(theta): hp, _ = waveform_generator(fs, theta, f_ref) return hp - + else: - raise ValueError(f"Waveform approximant {waveform_name} not supported by ripple") - + raise ValueError( + f"Waveform approximant {waveform_name} not supported by ripple" + ) + return waveform + def get_freqs(f_l, f_u, f_sampling, T): # Build the frequency grid delta_t = 1 / f_sampling @@ -89,10 +101,17 @@ def get_freqs(f_l, f_u, f_sampling, T): ### Match against LAL ### ######################### -def random_match(n: int, bounds: dict, IMRphenom: str = "IMRPhenomD_NRTidalv2", outdir: str = None, psd_file: str = "psds/psd.txt"): + +def random_match( + n: int, + bounds: dict, + IMRphenom: str = "IMRPhenomD_NRTidalv2", + outdir: str = None, + psd_file: str = "psds/psd.txt", +): """ Generates random waveform match scores between LAL and ripple. - + Note, currently only IMRPhenomD is supported. Args: n: int @@ -116,7 +135,7 @@ def random_match(n: int, bounds: dict, IMRphenom: str = "IMRPhenomD_NRTidalv2", f_ref = f_l fs = get_freqs(f_l, f_u, f_sampling, T) df = fs[1] - fs[0] - + waveform = get_jitted_waveform(IMRphenom, fs, f_ref) is_tidal = check_is_tidal(IMRphenom) @@ -129,7 +148,7 @@ def random_match(n: int, bounds: dict, IMRphenom: str = "IMRPhenomD_NRTidalv2", # Mismatches computations: for _ in tqdm(range(n)): non_precessing_matchmaking( - bounds, IMRphenom, f_l, f_u, df, fs, waveform, f_ASD, ASD, thetas, matches + bounds, IMRphenom, f_l, f_u, df, fs, waveform, f_ASD, ASD, thetas, matches ) # Save and report mismatches @@ -144,11 +163,21 @@ def random_match(n: int, bounds: dict, IMRphenom: str = "IMRPhenomD_NRTidalv2", def non_precessing_matchmaking( - bounds, IMRphenom, f_l, f_u, df, fs, waveform, f_ASD, ASD, thetas, matches, + bounds, + IMRphenom, + f_l, + f_u, + df, + fs, + waveform, + f_ASD, + ASD, + thetas, + matches, ): - + is_tidal = check_is_tidal(IMRphenom) - + m1 = np.random.uniform(bounds["m"][0], bounds["m"][1]) m2 = np.random.uniform(bounds["m"][0], bounds["m"][1]) s1 = np.random.uniform(bounds["chi"][0], bounds["chi"][1]) @@ -158,9 +187,9 @@ def non_precessing_matchmaking( dist_mpc = np.random.uniform(bounds["d_L"][0], bounds["d_L"][1]) tc = 0.0 - inclination = np.random.uniform(0, 2*PI) - phi_ref = np.random.uniform(0, 2*PI) - + inclination = np.random.uniform(0, 2 * PI) + phi_ref = np.random.uniform(0, 2 * PI) + # Ensure m1 > m2 if m1 < m2: theta = np.array([m2, m1, s2, s1, l2, l1, dist_mpc, tc, phi_ref, inclination]) @@ -168,21 +197,21 @@ def non_precessing_matchmaking( theta = np.array([m1, m2, s1, s2, l1, l2, dist_mpc, tc, phi_ref, inclination]) else: raise ValueError("Something went wrong with the parameters") - + # If not tidal, remove l1 and l2 from theta if not is_tidal: theta = np.delete(theta, [4, 5]) l1 = 0.0 l2 = 0.0 - + # Get approximant for lal approximant = lalsim.SimInspiralGetApproximantFromString(IMRphenom) - + f_ref = f_l m1_kg = theta[0] * lal.MSUN_SI m2_kg = theta[1] * lal.MSUN_SI distance = dist_mpc * 1e6 * lal.PC_SI - + if is_tidal: laldict = lal.CreateDict() lalsim.SimInspiralWaveformParamsInsertTidalLambda1(laldict, l1) @@ -200,10 +229,10 @@ def non_precessing_matchmaking( m2_kg, 0.0, 0.0, - theta[2], # spin m1 zero component + theta[2], # spin m1 zero component 0.0, 0.0, - theta[3], # spin m2 zero component + theta[3], # spin m2 zero component distance, inclination, phi_ref, @@ -222,28 +251,41 @@ def non_precessing_matchmaking( mask_lal = (freqs_lal > f_l) & (freqs_lal < f_u) freqs_lal = freqs_lal[mask_lal] hp_lalsuite = hp.data.data[mask_lal] - + # Get the ripple waveform Mc, eta = ms_to_Mc_eta(jnp.array([theta[0], theta[1]])) - lambda_tilde, delta_lambda_tilde = lambdas_to_lambda_tildes(jnp.array([l1, l2, m1, m2])) + lambda_tilde, delta_lambda_tilde = lambdas_to_lambda_tildes( + jnp.array([l1, l2, m1, m2]) + ) theta_ripple = jnp.array( - [Mc, eta, theta[2], theta[3], lambda_tilde, delta_lambda_tilde, dist_mpc, tc, phi_ref, inclination] + [ + Mc, + eta, + theta[2], + theta[3], + lambda_tilde, + delta_lambda_tilde, + dist_mpc, + tc, + phi_ref, + inclination, + ] ) - + # If not tidal, remove lambda parameters if not is_tidal: theta_ripple = jnp.delete(theta_ripple, jnp.array([4, 5])) - + hp_ripple = waveform(theta_ripple) - + # Check if strain has NaNs if jnp.isnan(hp_ripple).any(): print("NaNs in ripple strain") - + if jnp.isnan(hp_lalsuite).any(): print("NaNs in lalsuite strain") - + # Compute match PSD_vals = np.interp(fs, f_ASD, ASD) ** 2 pad_low, pad_high = get_eff_pads(fs) @@ -258,63 +300,68 @@ def non_precessing_matchmaking( ) thetas.append(theta) + def save_matches(filename, thetas, matches, verbose=True, is_tidal=False): # Get the parameters, which depends on whether or not tidal: if is_tidal: - m1 = thetas[:, 0] - m2 = thetas[:, 1] - chi1 = thetas[:, 2] - chi2 = thetas[:, 3] - lambda1 = thetas[:, 4] - lambda2 = thetas[:, 5] - dist_mpc = thetas[:, 6] - tc = thetas[:, 7] - phi_ref = thetas[:, 8] + m1 = thetas[:, 0] + m2 = thetas[:, 1] + chi1 = thetas[:, 2] + chi2 = thetas[:, 3] + lambda1 = thetas[:, 4] + lambda2 = thetas[:, 5] + dist_mpc = thetas[:, 6] + tc = thetas[:, 7] + phi_ref = thetas[:, 8] inclination = thetas[:, 9] - + mismatches = np.log10(1 - matches) - - my_dict = {'m1': m1, - 'm2': m2, - 'chi1': chi1, - 'chi2': chi2, - 'lambda1': lambda1, - 'lambda2': lambda2, - 'dist_mpc': dist_mpc, - 'tc': tc, - 'phi_ref': phi_ref, - 'inclination': inclination, - 'match': matches, - 'mismatch': mismatches} + + my_dict = { + "m1": m1, + "m2": m2, + "chi1": chi1, + "chi2": chi2, + "lambda1": lambda1, + "lambda2": lambda2, + "dist_mpc": dist_mpc, + "tc": tc, + "phi_ref": phi_ref, + "inclination": inclination, + "match": matches, + "mismatch": mismatches, + } else: - m1 = thetas[:, 0] - m2 = thetas[:, 1] - chi1 = thetas[:, 2] - chi2 = thetas[:, 3] - dist_mpc = thetas[:, 4] - tc = thetas[:, 5] - phi_ref = thetas[:, 6] + m1 = thetas[:, 0] + m2 = thetas[:, 1] + chi1 = thetas[:, 2] + chi2 = thetas[:, 3] + dist_mpc = thetas[:, 4] + tc = thetas[:, 5] + phi_ref = thetas[:, 6] inclination = thetas[:, 7] - + mismatches = np.log10(1 - matches) - my_dict = {'m1': m1, - 'm2': m2, - 'chi1': chi1, - 'chi2': chi2, - 'dist_mpc': dist_mpc, - 'tc': tc, - 'phi_ref': phi_ref, - 'inclination': inclination, - 'match': matches, - 'mismatch': mismatches} - + my_dict = { + "m1": m1, + "m2": m2, + "chi1": chi1, + "chi2": chi2, + "dist_mpc": dist_mpc, + "tc": tc, + "phi_ref": phi_ref, + "inclination": inclination, + "match": matches, + "mismatch": mismatches, + } + # Sort the dict and print if desired df = pd.DataFrame.from_dict(my_dict) df = df.sort_values(by="mismatch", ascending=False) df.to_csv(filename) - + if verbose: print("Mean mismatch:", np.mean(mismatches)) print("Median mismatch:", np.median(mismatches)) @@ -323,24 +370,26 @@ def save_matches(filename, thetas, matches, verbose=True, is_tidal=False): return df + ########################## ### Speed benchmarking ### ########################## + def benchmark_speed(IMRphenom: str, n: int = 10_000): - + # Specify frequency range f_l = 20 - f_sampling = 1 * 2048 # 2048 for IMRPhenomD benchmark + f_sampling = 1 * 2048 # 2048 for IMRPhenomD benchmark T = 16 f_u = f_sampling // 2 f_ref = f_l fs = get_freqs(f_l, f_u, f_sampling, T) df = fs[1] - fs[0] - + is_tidal = check_is_tidal(IMRphenom) - + m_l, m_u = 0.5, 3.0 chi_l, chi_u = -1, 1 lambda_l, lambda_u = 0, 5000 @@ -354,27 +403,40 @@ def benchmark_speed(IMRphenom: str, n: int = 10_000): dist_mpc = np.random.uniform(0, 1000, n) tc = np.zeros_like(dist_mpc) - inclination = np.random.uniform(0, 2*PI, n) - phi_ref = np.random.uniform(0, 2*PI, n) - + inclination = np.random.uniform(0, 2 * PI, n) + phi_ref = np.random.uniform(0, 2 * PI, n) + waveform = get_jitted_waveform(IMRphenom, fs, f_ref) - + Mc, eta = ms_to_Mc_eta(jnp.array([m1, m2])) - lambda_tilde, delta_lambda_tilde = lambdas_to_lambda_tildes(jnp.array([l1, l2, m1, m2])) + lambda_tilde, delta_lambda_tilde = lambdas_to_lambda_tildes( + jnp.array([l1, l2, m1, m2]) + ) theta_ripple = np.array( - [Mc, eta, s1, s2, lambda_tilde, delta_lambda_tilde, dist_mpc, tc, phi_ref, inclination] + [ + Mc, + eta, + s1, + s2, + lambda_tilde, + delta_lambda_tilde, + dist_mpc, + tc, + phi_ref, + inclination, + ] ).T - + # If not tidal, remove lambda parameters if not is_tidal: theta_ripple = np.delete(theta_ripple, [4, 5], axis=1) - + # Perform the compilation before we time print("JIT compiling") waveform(theta_ripple[0])[0].block_until_ready() print("Finished JIT compiling") - + # First, benchmark the jitted version print("Benchmarking . . .") start = time.time() @@ -386,28 +448,28 @@ def benchmark_speed(IMRphenom: str, n: int = 10_000): # Second, benchmark the vmapped version func = jax.vmap(waveform) func(theta_ripple)[0].block_until_ready() - + print("Benchmarking . . .") start = time.time() func(theta_ripple)[0].block_until_ready() end = time.time() print("Vmapped ripple waveform call takes: %.6f ms" % ((end - start) * 1000 / n)) - - + + def benchmark_speed_lal(IMRphenom, n: int = 10_000): - + # Specify frequency range f_l = 20 - f_sampling = 1 * 2048 # 2048 for IMRPhenomD benchmark + f_sampling = 1 * 2048 # 2048 for IMRPhenomD benchmark T = 16 f_u = f_sampling // 2 f_ref = f_l fs = get_freqs(f_l, f_u, f_sampling, T) df = fs[1] - fs[0] - + is_tidal = check_is_tidal(IMRphenom) - + m_l, m_u = 0.5, 3.0 chi_l, chi_u = -1, 1 lambda_l, lambda_u = 0, 5000 @@ -421,27 +483,27 @@ def benchmark_speed_lal(IMRphenom, n: int = 10_000): dist_mpc = np.random.uniform(0, 1000, n) tc = np.zeros_like(dist_mpc) - inclination = np.random.uniform(0, 2*PI, n) - phi_ref = np.random.uniform(0, 2*PI, n) - + inclination = np.random.uniform(0, 2 * PI, n) + phi_ref = np.random.uniform(0, 2 * PI, n) + theta = np.array([m1, m2, s1, s2, l1, l2, dist_mpc, tc, phi_ref, inclination]).T # Ensure m1 > m2, but now for all entries in theta booleans = theta[:, 0] < theta[:, 1] booleans = np.repeat(booleans[:, np.newaxis], 10, axis=1) theta = np.where(booleans, theta[:, [1, 0, 3, 2, 5, 4, 6, 7, 8, 9]], theta) - + # Get approximant for lal approximant = lalsim.SimInspiralGetApproximantFromString(IMRphenom) - + f_ref = f_l - + # Define the lal waveform generators def lal_waveform_tidal(theta): - + # Get the tidal parameters and create laldict l1 = theta[4] l2 = theta[5] - + laldict = lal.CreateDict() lalsim.SimInspiralWaveformParamsInsertTidalLambda1(laldict, l1) lalsim.SimInspiralWaveformParamsInsertTidalLambda2(laldict, l2) @@ -450,7 +512,7 @@ def lal_waveform_tidal(theta): # Note that these are dquadmon, not quadmon, hence have to subtract 1 since that is added again later lalsim.SimInspiralWaveformParamsInsertdQuadMon1(laldict, quad1 - 1) lalsim.SimInspiralWaveformParamsInsertdQuadMon2(laldict, quad2 - 1) - + m1_kg = theta[0] * lal.MSUN_SI m2_kg = theta[1] * lal.MSUN_SI # Get distance parameter @@ -464,10 +526,10 @@ def lal_waveform_tidal(theta): m2_kg, 0.0, 0.0, - theta[2], # spin m1 zero component + theta[2], # spin m1 zero component 0.0, 0.0, - theta[3], # spin m2 zero component + theta[3], # spin m2 zero component distance, inclination, phi_ref, @@ -483,9 +545,9 @@ def lal_waveform_tidal(theta): ) def lal_waveform_no_tidal(theta): - + # Note: theta still has lambda parameters, but we just select the ones we need here - + m1_kg = theta[0] * lal.MSUN_SI m2_kg = theta[1] * lal.MSUN_SI # Get distance parameter @@ -499,10 +561,10 @@ def lal_waveform_no_tidal(theta): m2_kg, 0.0, 0.0, - theta[2], # spin m1 zero component + theta[2], # spin m1 zero component 0.0, 0.0, - theta[3], # spin m2 zero component + theta[3], # spin m2 zero component distance, inclination, phi_ref, @@ -516,7 +578,7 @@ def lal_waveform_no_tidal(theta): None, approximant, ) - + # Start benchmarking if is_tidal: print("Benchmarking tidal waveform") @@ -535,17 +597,20 @@ def lal_waveform_no_tidal(theta): print("Done") + if __name__ == "__main__": - + # Showing an example of benchmarking: - bounds = {"m": [0.5, 3.0], - "chi": [-0.05, 0.05], - "lambda": [0.0, 5000.0], - "d_L": [30.0, 300.0]} - + bounds = { + "m": [0.5, 3.0], + "chi": [-0.05, 0.05], + "lambda": [0.0, 5000.0], + "d_L": [30.0, 300.0], + } + approximant = "TaylorF2" print(f"Checking approximant {approximant}") print("Checking mismatches wrt LAL") - df = random_match(1000, bounds, approximant, outdir = "./") + df = random_match(1000, bounds, approximant, outdir="./") print("Done. The dataframe is:") - print(df) \ No newline at end of file + print(df) From 3d514d2ca8293965670133184e534b212077cb13 Mon Sep 17 00:00:00 2001 From: "leonardo.ricca" Date: Thu, 11 Dec 2025 17:52:33 +0100 Subject: [PATCH 20/20] Add elliptic integrals --- src/ripplegw/waveforms/IMRPhenomXP_utils.py | 25 +++++---------------- 1 file changed, 5 insertions(+), 20 deletions(-) diff --git a/src/ripplegw/waveforms/IMRPhenomXP_utils.py b/src/ripplegw/waveforms/IMRPhenomXP_utils.py index 9de84740..47842768 100644 --- a/src/ripplegw/waveforms/IMRPhenomXP_utils.py +++ b/src/ripplegw/waveforms/IMRPhenomXP_utils.py @@ -16,6 +16,7 @@ from ..typing import Array from .IMRPhenomD_QNMdata import QNMData_a, QNMData_fRD, QNMData_fdamp from .spherical_harmonics import * +from ..gsl_ellint import ellipfinc, gsl_sf_elljac_e # helper functions for LALtoPhenomP: @@ -387,8 +388,6 @@ def IMRPhenomX_Return_Roots_MSA(LNorm, JNorm, pPrec): theta = jnp.arccos(acosarg) / 3.0 cos_theta = jnp.cos(theta) - print(f'{p=}, {sqrtarg=}, {theta=}, {B=}, {B2=}, {C=}') - vector_condition = jnp.logical_or(jnp.isnan(theta), (jnp.isnan(sqrtarg))) scalar_condition = jnp.logical_or.reduce(jnp.array([(pPrec['dotS1Ln'] == 1.0), @@ -429,8 +428,6 @@ def roots_when_invalid(): roots_when_invalid(), roots_when_valid() ) - - print(f'{roots_array=}') return roots_array @@ -1278,19 +1275,19 @@ def branch_not_equal(pPrec): tmpB_val > -1.0e-5) def case1(): - return gsl_sf_ellint_F( ##(TODO) + return ellipfinc( jnp.arcsin(vol_sign_val * jnp.sqrt(1.0)), mm_val ) - psi_v0_val def case2(): - return gsl_sf_ellint_F( + return ellipfinc( jnp.arcsin(vol_sign_val * jnp.sqrt(0.0)), mm_val ) - psi_v0_val def case3(): - return gsl_sf_ellint_F( + return ellipfinc( jnp.arcsin(vol_sign_val * jnp.sqrt(tmpB_val)), mm_val ) - psi_v0_val @@ -1366,7 +1363,7 @@ def IMRPhenomX_SetPrecessingRemnantParams(pWF, pPrec, params): ## Toggle for enforced use of non-precessing spin as is required during tuning of PNR's coprecessing model ## High-level toggle for whether to apply deviations. APPLY_PNR_DEVIATIONS == 0 for the moment - af_parallel = XLALSimIMRPhenomXFinalSpin2017(pWF['eta'], pPrec['chi1z'], pPrec['chi2z']) ## (TODO) + af_parallel = XLALSimIMRPhenomXFinalSpin2017(pWF['eta'], pPrec['chi1z'], pPrec['chi2z']) M = pWF['M'] Lfinal = M * M * af_parallel - pWF['m1_2'] * pPrec['chi1z'] - pWF['m2_2'] * pPrec['chi2z'] @@ -1611,18 +1608,6 @@ def Get_alphaepsilon_atfref(mprime, pPrec, pWF): ## NOT YET IMPLEMENTED ### ########################## -def gsl_sf_ellint_F(x, y): - """ - TODO: Not yet implemented - """ - return x - -def gsl_sf_elljac_e(x, y): - """ - TODO: Not yet implemented - """ - return jnp.array(1.), jnp.array(2.), jnp.array(3.) - def XLALSimInspiralWaveformParamsLookupPhenomXPTransPrecessionMethod(dict): """ TODO: Not yet implemented