Skip to content

Commit f078420

Browse files
committed
first pass at separate functions
fix: get polyval ordering correct fix: correct tensor shape issue fix: correct approx to handle infinity and -infinity better fix: correct small nit with function names add tests fix: test name
1 parent 445d75c commit f078420

File tree

2 files changed

+148
-0
lines changed

2 files changed

+148
-0
lines changed

jax/_src/scipy/special.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2115,6 +2115,7 @@ def expi_jvp(primals, tangents):
21152115
(x_dot,) = tangents
21162116
return expi(x), jnp.exp(x) / x * x_dot
21172117

2118+
21182119
@custom_derivatives.custom_jvp
21192120
@jit
21202121
def sici(x: ArrayLike) -> tuple[Array, Array]:
@@ -2150,7 +2151,35 @@ def sici(x: ArrayLike) -> tuple[Array, Array]:
21502151
raise ValueError(
21512152
f"Argument `x` to sici must be real-valued. Got dtype {x.dtype}."
21522153
)
2154+
2155+
si_series, ci_series = _sici_series(x)
2156+
si_asymp, ci_asymp = _sici_asympt(x)
2157+
si_approx, ci_approx = _sici_approx(x)
2158+
2159+
cond1 = x <= 4
2160+
cond2 = (x > 4) & (x <= 1e9)
2161+
2162+
si = jnp.select([cond1, cond2], [si_series, si_asymp], si_approx)
2163+
ci = jnp.select([cond1, cond2], [ci_series, ci_asymp], ci_approx)
2164+
2165+
return si, ci
2166+
2167+
def _sici_approx(x: ArrayLike):
2168+
si = (np.pi / 2) - jnp.cos(x) / x
2169+
ci = jnp.sin(x) / x
2170+
2171+
si = jnp.where(x == 0, 0.0, si)
2172+
si = jnp.where(isposinf(x), np.pi / 2, si)
2173+
si = jnp.where(isneginf(x), -np.pi / 2, si)
2174+
2175+
ci = jnp.where(x == 0, -np.inf, ci)
2176+
ci = jnp.where(isposinf(x), 0.0, ci)
2177+
ci = jnp.where(isneginf(x), np.nan, ci)
21532178

2179+
return si, ci
2180+
2181+
2182+
def _sici_series(x: ArrayLike):
21542183
def si_series(x):
21552184
# Values come from Cephes Implementation used by Scipy https://github.com/jeremybarnes/cephes/blob/60f27df395b8322c2da22c83751a2366b82d50d1/misc/sici.c
21562185
SN = np.array([-8.39167827910303881427E-11,
@@ -2201,6 +2230,115 @@ def ci_series(x):
22012230

22022231
return si, ci
22032232

2233+
2234+
def _sici_asympt(x: ArrayLike):
2235+
s = jnp.sin(x)
2236+
c = jnp.cos(x)
2237+
z = 1.0 / (x * x)
2238+
2239+
# Values come from Cephes Implementation used by Scipy https://github.com/jeremybarnes/cephes/blob/60f27df395b8322c2da22c83751a2366b82d50d1/misc/sici.c
2240+
FN4 = jnp.array([
2241+
4.23612862892216586994E0,
2242+
5.45937717161812843388E0,
2243+
1.62083287701538329132E0,
2244+
1.67006611831323023771E-1,
2245+
6.81020132472518137426E-3,
2246+
1.08936580650328664411E-4,
2247+
5.48900223421373614008E-7,
2248+
], dtype=x.dtype)
2249+
FD4 = jnp.array([
2250+
1,
2251+
8.16496634205391016773E0,
2252+
7.30828822505564552187E0,
2253+
1.86792257950184183883E0,
2254+
1.78792052963149907262E-1,
2255+
7.01710668322789753610E-3,
2256+
1.10034357153915731354E-4,
2257+
5.48900252756255700982E-7,
2258+
], dtype=x.dtype)
2259+
GN4 = jnp.array([
2260+
8.71001698973114191777E-2,
2261+
6.11379109952219284151E-1,
2262+
3.97180296392337498885E-1,
2263+
7.48527737628469092119E-2,
2264+
5.38868681462177273157E-3,
2265+
1.61999794598934024525E-4,
2266+
1.97963874140963632189E-6,
2267+
7.82579040744090311069E-9,
2268+
], dtype=x.dtype)
2269+
GD4 = jnp.array([
2270+
1,
2271+
1.64402202413355338886E0,
2272+
6.66296701268987968381E-1,
2273+
9.88771761277688796203E-2,
2274+
6.22396345441768420760E-3,
2275+
1.73221081474177119497E-4,
2276+
2.02659182086343991969E-6,
2277+
7.82579218933534490868E-9,
2278+
], dtype=x.dtype)
2279+
2280+
FN8 = jnp.array([
2281+
4.55880873470465315206E-1,
2282+
7.13715274100146711374E-1,
2283+
1.60300158222319456320E-1,
2284+
1.16064229408124407915E-2,
2285+
3.49556442447859055605E-4,
2286+
4.86215430826454749482E-6,
2287+
3.20092790091004902806E-8,
2288+
9.41779576128512936592E-11,
2289+
9.70507110881952024631E-14,
2290+
], dtype=x.dtype)
2291+
FD8 = jnp.array([
2292+
1.0,
2293+
9.17463611873684053703E-1,
2294+
1.78685545332074536321E-1,
2295+
1.22253594771971293032E-2,
2296+
3.58696481881851580297E-4,
2297+
4.92435064317881464393E-6,
2298+
3.21956939101046018377E-8,
2299+
9.43720590350276732376E-11,
2300+
9.70507110881952025725E-14,
2301+
], dtype=x.dtype)
2302+
GN8 = jnp.array([
2303+
6.97359953443276214934E-1,
2304+
3.30410979305632063225E-1,
2305+
3.84878767649974295920E-2,
2306+
1.71718239052347903558E-3,
2307+
3.48941165502279436777E-5,
2308+
3.47131167084116673800E-7,
2309+
1.70404452782044526189E-9,
2310+
3.85945925430276600453E-12,
2311+
3.14040098946363334640E-15,
2312+
], dtype=x.dtype)
2313+
GD8 = jnp.array([
2314+
1.0,
2315+
1.68548898811011640017E0,
2316+
4.87852258695304967486E-1,
2317+
4.67913194259625806320E-2,
2318+
1.90284426674399523638E-3,
2319+
3.68475504442561108162E-5,
2320+
3.57043223443740838771E-7,
2321+
1.72693748966316146736E-9,
2322+
3.87830166023954706752E-12,
2323+
3.14040098946363335242E-15,
2324+
], dtype=x.dtype)
2325+
2326+
f4 = jnp.polyval(FN4, z) / (x * jnp.polyval(FD4, z))
2327+
g4 = z * jnp.polyval(GN4, z) / jnp.polyval(GD4, z)
2328+
2329+
f8 = jnp.polyval(FN8, z) / (x * jnp.polyval(FD8, z))
2330+
g8 = z * jnp.polyval(GN8, z) / jnp.polyval(GD8, z)
2331+
2332+
mask = x < 8.0
2333+
f = jnp.where(mask, f4, f8)
2334+
g = jnp.where(mask, g4, g8)
2335+
2336+
si = (np.pi / 2) - f * c - g * s
2337+
si = jnp.sign(x) * si
2338+
ci = f * s - g * c
2339+
2340+
return si, ci
2341+
22042342
@sici.defjvp
22052343
@jit
22062344
def sici_jvp(primals, tangents):

tests/lax_scipy_special_functions_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,16 @@ def testSiciEdgeCases(self):
400400
self.assertAllClose(si_jax, expected_si, atol=1e-6, rtol=1e-6)
401401
self.assertAllClose(ci_jax, expected_ci, atol=1e-6, rtol=1e-6)
402402

403+
def testSiciValueRanges(self):
404+
dtype = jnp.zeros(0).dtype
405+
x_samples = np.array([2, 6, 10, 1e15], dtype=dtype)
406+
scipy_op = lambda x: osp_special.sici(x)
407+
lax_op = lambda x: lsp_special.sici(x)
408+
si_scipy, ci_scipy = scipy_op(x_samples)
409+
si_jax, ci_jax = lax_op(x_samples)
410+
self.assertAllClose(si_jax, si_scipy, atol=1e-6, rtol=1e-6)
411+
self.assertAllClose(ci_jax, ci_scipy, atol=1e-6, rtol=1e-6)
412+
403413
def testSiciRaiseOnComplexInput(self):
404414
samples = jnp.arange(5, dtype=complex)
405415
with self.assertRaisesRegex(ValueError, "Argument `x` to sici must be real-valued."):

0 commit comments

Comments
 (0)