Skip to content

Conversation

@michaeleliot
Copy link
Contributor

Brings the Sici function closer to the Cephes implementation by adding asymptotic expansion for values > 4, and doing fixed approximation for values > 1e9.

Couple Thoughts

  • The Cephes code splits first on 4 & 1e9, and then splits on 8. I keep that form, but we could easily split into more functions and only have one branch.
  • I couldn't think of a way to do the branch selection depending on value, so I compute it using all 3 functions and then just pick the ones I need. This feels inefficient.
  • Not 100% sure how to test this. I selected values from each branch, but I don't know how to test for the time issue as I can't reproduce it on large values. Maybe also a test explicitly relating the call with jax.numpy.linalg.solve?

Related #32052
Closes #33081

@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @michaeleliot, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly refines the sici (Sine and Cosine Integral) function within JAX's SciPy special functions. By integrating advanced numerical methods, specifically asymptotic expansions and a fixed approximation for extremely large inputs, the function's accuracy and computational efficiency are enhanced across a wider domain. This brings the JAX implementation closer to established numerical libraries like Cephes, ensuring more robust behavior for diverse use cases.

Highlights

  • Asymptotic Expansion for Sici: Introduced an asymptotic expansion method for the sici function, specifically for input values greater than 4, to improve accuracy and performance for larger numbers.
  • Fixed Approximation for Large Values: Implemented a fixed approximation for sici for extremely large input values (greater than 1e9), further enhancing numerical stability and efficiency at the upper end of the domain.
  • Dynamic Method Selection: The sici function now dynamically selects between series expansion (for small values), asymptotic expansion (for intermediate values), and fixed approximation (for very large values) based on the input magnitude, mirroring the robust Cephes implementation.
  • Expanded Test Coverage: Added new tests to explicitly verify the sici function's behavior and accuracy across these different value ranges, ensuring consistency with SciPy's implementation.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request is a great improvement to the sici function, bringing it closer to the Cephes implementation by adding an asymptotic expansion and an approximation for large values. The code is well-structured with helper functions for different value ranges.

Regarding your concern about inefficiency from computing all three branches (_sici_series, _sici_asympt, _sici_approx) and then selecting the result: this is an idiomatic and often necessary pattern in JAX. To make a function JIT-compatible, we generally need to avoid data-dependent control flow at the Python level. jnp.select (or jnp.where) is the standard way to express this kind of branching for array inputs. While it may seem inefficient, XLA is often able to optimize these patterns, and the performance impact might be less than you expect, especially on parallel hardware. The alternative using lax.cond would require vmap and be more complex to implement for array inputs.

I've added a few suggestions to clean up the special value handling and improve readability, along with a minor naming suggestion for the new test.

Comment on lines +2171 to +2173
si = jnp.where(x == 0, 0.0, si)
si = jnp.where(isposinf(x), np.pi / 2, si)
si = jnp.where(isneginf(x), -np.pi / 2, si)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Since this function is only called for x > 1e9, the checks for x == 0 and isneginf(x) are redundant and can be removed for clarity and to avoid unnecessary computation.

      si = jnp.where(isposinf(x), np.pi / 2, si)
      # The x == 0 and isneginf(x) cases are unreachable because this
      # function is only used for x > 1e9.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While this is currently only called this way, my reasoning was that the function might be called at some point with other values. Happy to change this if this isn't preferred.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with the bot here.

Comment on lines +2175 to +2177
ci = jnp.where(x == 0, -np.inf, ci)
ci = jnp.where(isposinf(x), 0.0, ci)
ci = jnp.where(isneginf(x), np.nan, ci)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Similarly, since this function is only called for x > 1e9, the checks for x == 0 and isneginf(x) are redundant and can be removed.

      ci = jnp.where(isposinf(x), 0.0, ci)
      # The x == 0 and isneginf(x) cases are unreachable because this
      # function is only used for x > 1e9.

fix: get polyval ordering correct

fix: correct tensor shape issue

fix: correct approx to handle infinity and -infinity better

fix: correct small nit with function names

add tests

fix: test name
Copy link
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great, thanks! A few comments below; in particular I think we should expand the test case.


return si, ci

def _sici_approx(x: ArrayLike):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd add a comment here, something like # sici approximation valid for x > 1E9

Comment on lines +2171 to +2173
si = jnp.where(x == 0, 0.0, si)
si = jnp.where(isposinf(x), np.pi / 2, si)
si = jnp.where(isneginf(x), -np.pi / 2, si)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with the bot here.

return si, ci


def _sici_series(x: ArrayLike):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a comment about range of validity

return si, ci


def _sici_asympt(x: ArrayLike):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a comment about range of validity

z = 1.0 / (x * x)

# Values come from Cephes Implementation used by Scipy https://github.com/jeremybarnes/cephes/blob/60f27df395b8322c2da22c83751a2366b82d50d1/misc/sici.c
FN4 = jnp.array([
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use np.array rather than jnp.array for lists of python floats; jnp.array(list) can be very slow.

self.assertAllClose(si_jax, expected_si, atol=1e-6, rtol=1e-6)
self.assertAllClose(ci_jax, expected_ci, atol=1e-6, rtol=1e-6)

def testSiciValueRanges(self):
Copy link
Collaborator

@jakevdp jakevdp Nov 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than specific values, I think it would be better to test sici with randomly generated arrays generated from jtu.rand_default(scale=10) and jtu.rand_default(scale=1E9) You could parametrize the test over the scale value (use jtu.sample_product)

@jakevdp jakevdp self-assigned this Nov 10, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Code does not terminate when using jax.scipy.special.sici together with jax.numpy.linalg.solve

2 participants