-
Notifications
You must be signed in to change notification settings - Fork 3.2k
Sici Asymptotic Expansion #33178
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Sici Asymptotic Expansion #33178
Conversation
Summary of ChangesHello @michaeleliot, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly refines the Highlights
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request is a great improvement to the sici function, bringing it closer to the Cephes implementation by adding an asymptotic expansion and an approximation for large values. The code is well-structured with helper functions for different value ranges.
Regarding your concern about inefficiency from computing all three branches (_sici_series, _sici_asympt, _sici_approx) and then selecting the result: this is an idiomatic and often necessary pattern in JAX. To make a function JIT-compatible, we generally need to avoid data-dependent control flow at the Python level. jnp.select (or jnp.where) is the standard way to express this kind of branching for array inputs. While it may seem inefficient, XLA is often able to optimize these patterns, and the performance impact might be less than you expect, especially on parallel hardware. The alternative using lax.cond would require vmap and be more complex to implement for array inputs.
I've added a few suggestions to clean up the special value handling and improve readability, along with a minor naming suggestion for the new test.
| si = jnp.where(x == 0, 0.0, si) | ||
| si = jnp.where(isposinf(x), np.pi / 2, si) | ||
| si = jnp.where(isneginf(x), -np.pi / 2, si) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since this function is only called for x > 1e9, the checks for x == 0 and isneginf(x) are redundant and can be removed for clarity and to avoid unnecessary computation.
si = jnp.where(isposinf(x), np.pi / 2, si)
# The x == 0 and isneginf(x) cases are unreachable because this
# function is only used for x > 1e9.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
While this is currently only called this way, my reasoning was that the function might be called at some point with other values. Happy to change this if this isn't preferred.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree with the bot here.
| ci = jnp.where(x == 0, -np.inf, ci) | ||
| ci = jnp.where(isposinf(x), 0.0, ci) | ||
| ci = jnp.where(isneginf(x), np.nan, ci) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fix: get polyval ordering correct fix: correct tensor shape issue fix: correct approx to handle infinity and -infinity better fix: correct small nit with function names add tests fix: test name
jakevdp
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great, thanks! A few comments below; in particular I think we should expand the test case.
|
|
||
| return si, ci | ||
|
|
||
| def _sici_approx(x: ArrayLike): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd add a comment here, something like # sici approximation valid for x > 1E9
| si = jnp.where(x == 0, 0.0, si) | ||
| si = jnp.where(isposinf(x), np.pi / 2, si) | ||
| si = jnp.where(isneginf(x), -np.pi / 2, si) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree with the bot here.
| return si, ci | ||
|
|
||
|
|
||
| def _sici_series(x: ArrayLike): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add a comment about range of validity
| return si, ci | ||
|
|
||
|
|
||
| def _sici_asympt(x: ArrayLike): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add a comment about range of validity
| z = 1.0 / (x * x) | ||
|
|
||
| # Values come from Cephes Implementation used by Scipy https://github.com/jeremybarnes/cephes/blob/60f27df395b8322c2da22c83751a2366b82d50d1/misc/sici.c | ||
| FN4 = jnp.array([ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please use np.array rather than jnp.array for lists of python floats; jnp.array(list) can be very slow.
| self.assertAllClose(si_jax, expected_si, atol=1e-6, rtol=1e-6) | ||
| self.assertAllClose(ci_jax, expected_ci, atol=1e-6, rtol=1e-6) | ||
|
|
||
| def testSiciValueRanges(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Rather than specific values, I think it would be better to test sici with randomly generated arrays generated from jtu.rand_default(scale=10) and jtu.rand_default(scale=1E9) You could parametrize the test over the scale value (use jtu.sample_product)
Brings the Sici function closer to the Cephes implementation by adding asymptotic expansion for values > 4, and doing fixed approximation for values > 1e9.
Couple Thoughts
Related #32052
Closes #33081