Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 138 additions & 0 deletions jax/_src/scipy/special.py
Original file line number Diff line number Diff line change
Expand Up @@ -2115,6 +2115,7 @@ def expi_jvp(primals, tangents):
(x_dot,) = tangents
return expi(x), jnp.exp(x) / x * x_dot


@custom_derivatives.custom_jvp
@jit
def sici(x: ArrayLike) -> tuple[Array, Array]:
Expand Down Expand Up @@ -2150,7 +2151,35 @@ def sici(x: ArrayLike) -> tuple[Array, Array]:
raise ValueError(
f"Argument `x` to sici must be real-valued. Got dtype {x.dtype}."
)

si_series, ci_series = _sici_series(x)
si_asymp, ci_asymp = _sici_asympt(x)
si_approx, ci_approx = _sici_approx(x)

cond1 = x <= 4
cond2 = (x > 4) & (x <= 1e9)

si = jnp.select([cond1, cond2], [si_series, si_asymp], si_approx)
ci = jnp.select([cond1, cond2], [ci_series, ci_asymp], ci_approx)

return si, ci

def _sici_approx(x: ArrayLike):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd add a comment here, something like # sici approximation valid for x > 1E9

si = (np.pi / 2) - jnp.cos(x) / x
ci = jnp.sin(x) / x

si = jnp.where(x == 0, 0.0, si)
si = jnp.where(isposinf(x), np.pi / 2, si)
si = jnp.where(isneginf(x), -np.pi / 2, si)
Comment on lines +2171 to +2173
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Since this function is only called for x > 1e9, the checks for x == 0 and isneginf(x) are redundant and can be removed for clarity and to avoid unnecessary computation.

      si = jnp.where(isposinf(x), np.pi / 2, si)
      # The x == 0 and isneginf(x) cases are unreachable because this
      # function is only used for x > 1e9.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While this is currently only called this way, my reasoning was that the function might be called at some point with other values. Happy to change this if this isn't preferred.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with the bot here.


ci = jnp.where(x == 0, -np.inf, ci)
ci = jnp.where(isposinf(x), 0.0, ci)
ci = jnp.where(isneginf(x), np.nan, ci)
Comment on lines +2175 to +2177
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Similarly, since this function is only called for x > 1e9, the checks for x == 0 and isneginf(x) are redundant and can be removed.

      ci = jnp.where(isposinf(x), 0.0, ci)
      # The x == 0 and isneginf(x) cases are unreachable because this
      # function is only used for x > 1e9.


return si, ci


def _sici_series(x: ArrayLike):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a comment about range of validity

def si_series(x):
# Values come from Cephes Implementation used by Scipy https://github.com/jeremybarnes/cephes/blob/60f27df395b8322c2da22c83751a2366b82d50d1/misc/sici.c
SN = np.array([-8.39167827910303881427E-11,
Expand Down Expand Up @@ -2201,6 +2230,115 @@ def ci_series(x):

return si, ci


def _sici_asympt(x: ArrayLike):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a comment about range of validity

s = jnp.sin(x)
c = jnp.cos(x)
z = 1.0 / (x * x)

# Values come from Cephes Implementation used by Scipy https://github.com/jeremybarnes/cephes/blob/60f27df395b8322c2da22c83751a2366b82d50d1/misc/sici.c
FN4 = jnp.array([
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use np.array rather than jnp.array for lists of python floats; jnp.array(list) can be very slow.

4.23612862892216586994E0,
5.45937717161812843388E0,
1.62083287701538329132E0,
1.67006611831323023771E-1,
6.81020132472518137426E-3,
1.08936580650328664411E-4,
5.48900223421373614008E-7,
], dtype=x.dtype)
FD4 = jnp.array([
1,
8.16496634205391016773E0,
7.30828822505564552187E0,
1.86792257950184183883E0,
1.78792052963149907262E-1,
7.01710668322789753610E-3,
1.10034357153915731354E-4,
5.48900252756255700982E-7,
], dtype=x.dtype)
GN4 = jnp.array([
8.71001698973114191777E-2,
6.11379109952219284151E-1,
3.97180296392337498885E-1,
7.48527737628469092119E-2,
5.38868681462177273157E-3,
1.61999794598934024525E-4,
1.97963874140963632189E-6,
7.82579040744090311069E-9,
], dtype=x.dtype)
GD4 = jnp.array([
1,
1.64402202413355338886E0,
6.66296701268987968381E-1,
9.88771761277688796203E-2,
6.22396345441768420760E-3,
1.73221081474177119497E-4,
2.02659182086343991969E-6,
7.82579218933534490868E-9,
], dtype=x.dtype)

FN8 = jnp.array([
4.55880873470465315206E-1,
7.13715274100146711374E-1,
1.60300158222319456320E-1,
1.16064229408124407915E-2,
3.49556442447859055605E-4,
4.86215430826454749482E-6,
3.20092790091004902806E-8,
9.41779576128512936592E-11,
9.70507110881952024631E-14,
], dtype=x.dtype)
FD8 = jnp.array([
1.0,
9.17463611873684053703E-1,
1.78685545332074536321E-1,
1.22253594771971293032E-2,
3.58696481881851580297E-4,
4.92435064317881464393E-6,
3.21956939101046018377E-8,
9.43720590350276732376E-11,
9.70507110881952025725E-14,
], dtype=x.dtype)
GN8 = jnp.array([
6.97359953443276214934E-1,
3.30410979305632063225E-1,
3.84878767649974295920E-2,
1.71718239052347903558E-3,
3.48941165502279436777E-5,
3.47131167084116673800E-7,
1.70404452782044526189E-9,
3.85945925430276600453E-12,
3.14040098946363334640E-15,
], dtype=x.dtype)
GD8 = jnp.array([
1.0,
1.68548898811011640017E0,
4.87852258695304967486E-1,
4.67913194259625806320E-2,
1.90284426674399523638E-3,
3.68475504442561108162E-5,
3.57043223443740838771E-7,
1.72693748966316146736E-9,
3.87830166023954706752E-12,
3.14040098946363335242E-15,
], dtype=x.dtype)

f4 = jnp.polyval(FN4, z) / (x * jnp.polyval(FD4, z))
g4 = z * jnp.polyval(GN4, z) / jnp.polyval(GD4, z)

f8 = jnp.polyval(FN8, z) / (x * jnp.polyval(FD8, z))
g8 = z * jnp.polyval(GN8, z) / jnp.polyval(GD8, z)

mask = x < 8.0
f = jnp.where(mask, f4, f8)
g = jnp.where(mask, g4, g8)

si = (np.pi / 2) - f * c - g * s
si = jnp.sign(x) * si
ci = f * s - g * c

return si, ci

@sici.defjvp
@jit
def sici_jvp(primals, tangents):
Expand Down
10 changes: 10 additions & 0 deletions tests/lax_scipy_special_functions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,16 @@ def testSiciEdgeCases(self):
self.assertAllClose(si_jax, expected_si, atol=1e-6, rtol=1e-6)
self.assertAllClose(ci_jax, expected_ci, atol=1e-6, rtol=1e-6)

def testSiciValueRanges(self):
Copy link
Collaborator

@jakevdp jakevdp Nov 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than specific values, I think it would be better to test sici with randomly generated arrays generated from jtu.rand_default(scale=10) and jtu.rand_default(scale=1E9) You could parametrize the test over the scale value (use jtu.sample_product)

dtype = jnp.zeros(0).dtype
x_samples = np.array([2, 6, 10, 1e15], dtype=dtype)
scipy_op = lambda x: osp_special.sici(x)
lax_op = lambda x: lsp_special.sici(x)
si_scipy, ci_scipy = scipy_op(x_samples)
si_jax, ci_jax = lax_op(x_samples)
self.assertAllClose(si_jax, si_scipy, atol=1e-6, rtol=1e-6)
self.assertAllClose(ci_jax, ci_scipy, atol=1e-6, rtol=1e-6)

def testSiciRaiseOnComplexInput(self):
samples = jnp.arange(5, dtype=complex)
with self.assertRaisesRegex(ValueError, "Argument `x` to sici must be real-valued."):
Expand Down