@@ -2115,6 +2115,7 @@ def expi_jvp(primals, tangents):
21152115 (x_dot ,) = tangents
21162116 return expi (x ), jnp .exp (x ) / x * x_dot
21172117
2118+
21182119@custom_derivatives .custom_jvp
21192120@jit
21202121def sici (x : ArrayLike ) -> tuple [Array , Array ]:
@@ -2150,7 +2151,35 @@ def sici(x: ArrayLike) -> tuple[Array, Array]:
21502151 raise ValueError (
21512152 f"Argument `x` to sici must be real-valued. Got dtype { x .dtype } ."
21522153 )
2154+
2155+ si_series , ci_series = _sici_series (x )
2156+ si_asymp , ci_asymp = _sici_asympt (x )
2157+ si_approx , ci_approx = _sici_approx (x )
2158+
2159+ cond1 = x <= 4
2160+ cond2 = (x > 4 ) & (x <= 1e9 )
2161+
2162+ si = jnp .select ([cond1 , cond2 ], [si_series , si_asymp ], si_approx )
2163+ ci = jnp .select ([cond1 , cond2 ], [ci_series , ci_asymp ], ci_approx )
2164+
2165+ return si , ci
2166+
2167+ def _sici_approx (x : ArrayLike ):
2168+ si = (np .pi / 2 ) - jnp .cos (x ) / x
2169+ ci = jnp .sin (x ) / x
2170+
2171+ si = jnp .where (x == 0 , 0.0 , si )
2172+ si = jnp .where (isposinf (x ), np .pi / 2 , si )
2173+ si = jnp .where (isneginf (x ), - np .pi / 2 , si )
2174+
2175+ ci = jnp .where (x == 0 , - np .inf , ci )
2176+ ci = jnp .where (isposinf (x ), 0.0 , ci )
2177+ ci = jnp .where (isneginf (x ), np .nan , ci )
21532178
2179+ return si , ci
2180+
2181+
2182+ def _sici_series (x : ArrayLike ):
21542183 def si_series (x ):
21552184 # Values come from Cephes Implementation used by Scipy https://github.com/jeremybarnes/cephes/blob/60f27df395b8322c2da22c83751a2366b82d50d1/misc/sici.c
21562185 SN = np .array ([- 8.39167827910303881427E-11 ,
@@ -2201,6 +2230,115 @@ def ci_series(x):
22012230
22022231 return si , ci
22032232
2233+
2234+ def _sici_asympt (x : ArrayLike ):
2235+ s = jnp .sin (x )
2236+ c = jnp .cos (x )
2237+ z = 1.0 / (x * x )
2238+
2239+ # Values come from Cephes Implementation used by Scipy https://github.com/jeremybarnes/cephes/blob/60f27df395b8322c2da22c83751a2366b82d50d1/misc/sici.c
2240+ FN4 = jnp .array ([
2241+ 4.23612862892216586994E0 ,
2242+ 5.45937717161812843388E0 ,
2243+ 1.62083287701538329132E0 ,
2244+ 1.67006611831323023771E-1 ,
2245+ 6.81020132472518137426E-3 ,
2246+ 1.08936580650328664411E-4 ,
2247+ 5.48900223421373614008E-7 ,
2248+ ], dtype = x .dtype )
2249+ FD4 = jnp .array ([
2250+ 1 ,
2251+ 8.16496634205391016773E0 ,
2252+ 7.30828822505564552187E0 ,
2253+ 1.86792257950184183883E0 ,
2254+ 1.78792052963149907262E-1 ,
2255+ 7.01710668322789753610E-3 ,
2256+ 1.10034357153915731354E-4 ,
2257+ 5.48900252756255700982E-7 ,
2258+ ], dtype = x .dtype )
2259+ GN4 = jnp .array ([
2260+ 8.71001698973114191777E-2 ,
2261+ 6.11379109952219284151E-1 ,
2262+ 3.97180296392337498885E-1 ,
2263+ 7.48527737628469092119E-2 ,
2264+ 5.38868681462177273157E-3 ,
2265+ 1.61999794598934024525E-4 ,
2266+ 1.97963874140963632189E-6 ,
2267+ 7.82579040744090311069E-9 ,
2268+ ], dtype = x .dtype )
2269+ GD4 = jnp .array ([
2270+ 1 ,
2271+ 1.64402202413355338886E0 ,
2272+ 6.66296701268987968381E-1 ,
2273+ 9.88771761277688796203E-2 ,
2274+ 6.22396345441768420760E-3 ,
2275+ 1.73221081474177119497E-4 ,
2276+ 2.02659182086343991969E-6 ,
2277+ 7.82579218933534490868E-9 ,
2278+ ], dtype = x .dtype )
2279+
2280+ FN8 = jnp .array ([
2281+ 4.55880873470465315206E-1 ,
2282+ 7.13715274100146711374E-1 ,
2283+ 1.60300158222319456320E-1 ,
2284+ 1.16064229408124407915E-2 ,
2285+ 3.49556442447859055605E-4 ,
2286+ 4.86215430826454749482E-6 ,
2287+ 3.20092790091004902806E-8 ,
2288+ 9.41779576128512936592E-11 ,
2289+ 9.70507110881952024631E-14 ,
2290+ ], dtype = x .dtype )
2291+ FD8 = jnp .array ([
2292+ 1.0 ,
2293+ 9.17463611873684053703E-1 ,
2294+ 1.78685545332074536321E-1 ,
2295+ 1.22253594771971293032E-2 ,
2296+ 3.58696481881851580297E-4 ,
2297+ 4.92435064317881464393E-6 ,
2298+ 3.21956939101046018377E-8 ,
2299+ 9.43720590350276732376E-11 ,
2300+ 9.70507110881952025725E-14 ,
2301+ ], dtype = x .dtype )
2302+ GN8 = jnp .array ([
2303+ 6.97359953443276214934E-1 ,
2304+ 3.30410979305632063225E-1 ,
2305+ 3.84878767649974295920E-2 ,
2306+ 1.71718239052347903558E-3 ,
2307+ 3.48941165502279436777E-5 ,
2308+ 3.47131167084116673800E-7 ,
2309+ 1.70404452782044526189E-9 ,
2310+ 3.85945925430276600453E-12 ,
2311+ 3.14040098946363334640E-15 ,
2312+ ], dtype = x .dtype )
2313+ GD8 = jnp .array ([
2314+ 1.0 ,
2315+ 1.68548898811011640017E0 ,
2316+ 4.87852258695304967486E-1 ,
2317+ 4.67913194259625806320E-2 ,
2318+ 1.90284426674399523638E-3 ,
2319+ 3.68475504442561108162E-5 ,
2320+ 3.57043223443740838771E-7 ,
2321+ 1.72693748966316146736E-9 ,
2322+ 3.87830166023954706752E-12 ,
2323+ 3.14040098946363335242E-15 ,
2324+ ], dtype = x .dtype )
2325+
2326+ f4 = jnp .polyval (FN4 , z ) / (x * jnp .polyval (FD4 , z ))
2327+ g4 = z * jnp .polyval (GN4 , z ) / jnp .polyval (GD4 , z )
2328+
2329+ f8 = jnp .polyval (FN8 , z ) / (x * jnp .polyval (FD8 , z ))
2330+ g8 = z * jnp .polyval (GN8 , z ) / jnp .polyval (GD8 , z )
2331+
2332+ mask = x < 8.0
2333+ f = jnp .where (mask , f4 , f8 )
2334+ g = jnp .where (mask , g4 , g8 )
2335+
2336+ si = (np .pi / 2 ) - f * c - g * s
2337+ si = jnp .sign (x ) * si
2338+ ci = f * s - g * c
2339+
2340+ return si , ci
2341+
22042342@sici .defjvp
22052343@jit
22062344def sici_jvp (primals , tangents ):
0 commit comments