From f07842059633fd6812a83909268b05aaf77c5de0 Mon Sep 17 00:00:00 2001 From: Michael Eliot Date: Wed, 5 Nov 2025 22:30:15 -0500 Subject: [PATCH] first pass at separate functions fix: get polyval ordering correct fix: correct tensor shape issue fix: correct approx to handle infinity and -infinity better fix: correct small nit with function names add tests fix: test name --- jax/_src/scipy/special.py | 138 ++++++++++++++++++++++ tests/lax_scipy_special_functions_test.py | 10 ++ 2 files changed, 148 insertions(+) diff --git a/jax/_src/scipy/special.py b/jax/_src/scipy/special.py index 099d2a52cca0..1d295a112480 100644 --- a/jax/_src/scipy/special.py +++ b/jax/_src/scipy/special.py @@ -2115,6 +2115,7 @@ def expi_jvp(primals, tangents): (x_dot,) = tangents return expi(x), jnp.exp(x) / x * x_dot + @custom_derivatives.custom_jvp @jit def sici(x: ArrayLike) -> tuple[Array, Array]: @@ -2150,7 +2151,35 @@ def sici(x: ArrayLike) -> tuple[Array, Array]: raise ValueError( f"Argument `x` to sici must be real-valued. Got dtype {x.dtype}." ) + + si_series, ci_series = _sici_series(x) + si_asymp, ci_asymp = _sici_asympt(x) + si_approx, ci_approx = _sici_approx(x) + + cond1 = x <= 4 + cond2 = (x > 4) & (x <= 1e9) + + si = jnp.select([cond1, cond2], [si_series, si_asymp], si_approx) + ci = jnp.select([cond1, cond2], [ci_series, ci_asymp], ci_approx) + + return si, ci + +def _sici_approx(x: ArrayLike): + si = (np.pi / 2) - jnp.cos(x) / x + ci = jnp.sin(x) / x + + si = jnp.where(x == 0, 0.0, si) + si = jnp.where(isposinf(x), np.pi / 2, si) + si = jnp.where(isneginf(x), -np.pi / 2, si) + + ci = jnp.where(x == 0, -np.inf, ci) + ci = jnp.where(isposinf(x), 0.0, ci) + ci = jnp.where(isneginf(x), np.nan, ci) + return si, ci + + +def _sici_series(x: ArrayLike): def si_series(x): # Values come from Cephes Implementation used by Scipy https://github.com/jeremybarnes/cephes/blob/60f27df395b8322c2da22c83751a2366b82d50d1/misc/sici.c SN = np.array([-8.39167827910303881427E-11, @@ -2201,6 +2230,115 @@ def ci_series(x): return si, ci + +def _sici_asympt(x: ArrayLike): + s = jnp.sin(x) + c = jnp.cos(x) + z = 1.0 / (x * x) + + # Values come from Cephes Implementation used by Scipy https://github.com/jeremybarnes/cephes/blob/60f27df395b8322c2da22c83751a2366b82d50d1/misc/sici.c + FN4 = jnp.array([ + 4.23612862892216586994E0, + 5.45937717161812843388E0, + 1.62083287701538329132E0, + 1.67006611831323023771E-1, + 6.81020132472518137426E-3, + 1.08936580650328664411E-4, + 5.48900223421373614008E-7, + ], dtype=x.dtype) + FD4 = jnp.array([ + 1, + 8.16496634205391016773E0, + 7.30828822505564552187E0, + 1.86792257950184183883E0, + 1.78792052963149907262E-1, + 7.01710668322789753610E-3, + 1.10034357153915731354E-4, + 5.48900252756255700982E-7, + ], dtype=x.dtype) + GN4 = jnp.array([ + 8.71001698973114191777E-2, + 6.11379109952219284151E-1, + 3.97180296392337498885E-1, + 7.48527737628469092119E-2, + 5.38868681462177273157E-3, + 1.61999794598934024525E-4, + 1.97963874140963632189E-6, + 7.82579040744090311069E-9, + ], dtype=x.dtype) + GD4 = jnp.array([ + 1, + 1.64402202413355338886E0, + 6.66296701268987968381E-1, + 9.88771761277688796203E-2, + 6.22396345441768420760E-3, + 1.73221081474177119497E-4, + 2.02659182086343991969E-6, + 7.82579218933534490868E-9, + ], dtype=x.dtype) + + FN8 = jnp.array([ + 4.55880873470465315206E-1, + 7.13715274100146711374E-1, + 1.60300158222319456320E-1, + 1.16064229408124407915E-2, + 3.49556442447859055605E-4, + 4.86215430826454749482E-6, + 3.20092790091004902806E-8, + 9.41779576128512936592E-11, + 9.70507110881952024631E-14, + ], dtype=x.dtype) + FD8 = jnp.array([ + 1.0, + 9.17463611873684053703E-1, + 1.78685545332074536321E-1, + 1.22253594771971293032E-2, + 3.58696481881851580297E-4, + 4.92435064317881464393E-6, + 3.21956939101046018377E-8, + 9.43720590350276732376E-11, + 9.70507110881952025725E-14, + ], dtype=x.dtype) + GN8 = jnp.array([ + 6.97359953443276214934E-1, + 3.30410979305632063225E-1, + 3.84878767649974295920E-2, + 1.71718239052347903558E-3, + 3.48941165502279436777E-5, + 3.47131167084116673800E-7, + 1.70404452782044526189E-9, + 3.85945925430276600453E-12, + 3.14040098946363334640E-15, + ], dtype=x.dtype) + GD8 = jnp.array([ + 1.0, + 1.68548898811011640017E0, + 4.87852258695304967486E-1, + 4.67913194259625806320E-2, + 1.90284426674399523638E-3, + 3.68475504442561108162E-5, + 3.57043223443740838771E-7, + 1.72693748966316146736E-9, + 3.87830166023954706752E-12, + 3.14040098946363335242E-15, + ], dtype=x.dtype) + + f4 = jnp.polyval(FN4, z) / (x * jnp.polyval(FD4, z)) + g4 = z * jnp.polyval(GN4, z) / jnp.polyval(GD4, z) + + f8 = jnp.polyval(FN8, z) / (x * jnp.polyval(FD8, z)) + g8 = z * jnp.polyval(GN8, z) / jnp.polyval(GD8, z) + + mask = x < 8.0 + f = jnp.where(mask, f4, f8) + g = jnp.where(mask, g4, g8) + + si = (np.pi / 2) - f * c - g * s + si = jnp.sign(x) * si + ci = f * s - g * c + + return si, ci + @sici.defjvp @jit def sici_jvp(primals, tangents): diff --git a/tests/lax_scipy_special_functions_test.py b/tests/lax_scipy_special_functions_test.py index c9ccaf4ca9ee..4f2d9c6f4a10 100644 --- a/tests/lax_scipy_special_functions_test.py +++ b/tests/lax_scipy_special_functions_test.py @@ -400,6 +400,16 @@ def testSiciEdgeCases(self): self.assertAllClose(si_jax, expected_si, atol=1e-6, rtol=1e-6) self.assertAllClose(ci_jax, expected_ci, atol=1e-6, rtol=1e-6) + def testSiciValueRanges(self): + dtype = jnp.zeros(0).dtype + x_samples = np.array([2, 6, 10, 1e15], dtype=dtype) + scipy_op = lambda x: osp_special.sici(x) + lax_op = lambda x: lsp_special.sici(x) + si_scipy, ci_scipy = scipy_op(x_samples) + si_jax, ci_jax = lax_op(x_samples) + self.assertAllClose(si_jax, si_scipy, atol=1e-6, rtol=1e-6) + self.assertAllClose(ci_jax, ci_scipy, atol=1e-6, rtol=1e-6) + def testSiciRaiseOnComplexInput(self): samples = jnp.arange(5, dtype=complex) with self.assertRaisesRegex(ValueError, "Argument `x` to sici must be real-valued."):