-
Notifications
You must be signed in to change notification settings - Fork 3.2k
Aot #33225
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Aot #33225
Conversation
Summary of ChangesHello @danielsuo, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces a foundational Ahead-of-Time (AOT) compilation API into JAX, enabling the definition and efficient caching of computational components. The primary goal is to optimize performance by reducing redundant compilations and facilitating the reuse of pre-compiled graph segments. This is achieved through new modules for component definition and sophisticated caching utilities, alongside targeted modifications to JAX's core tracing and compilation infrastructure to seamlessly integrate AOT capabilities. Highlights
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces a new experimental Ahead-of-Time (AOT) compilation API, centered around the @aot.component decorator. The changes are substantial, adding new files for the AOT logic and utilities, and modifying existing files to support caching and testing.
While this is a promising feature, the implementation appears to be in an early stage and has several critical issues that need to be addressed. For example, there's an incorrect equality implementation in ComponentKey, a module-level cache in pjit.py that will cause memory leaks and is not thread-safe, and a bug in traceback_util.py where a casted variable is not returned. Additionally, the code contains numerous TODO comments and debug logging statements that should be cleaned up. The hardcoding of platforms=["cpu"] also severely limits the current implementation.
jax/_src/pjit.py
Outdated
| return _infer_params_internal(fun, ji, args, kwargs) | ||
|
|
||
|
|
||
| cache = dict() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A module-level dictionary cache is used for caching in _infer_params_internal. This is problematic for several reasons:
- Memory Leak: The cache has no eviction policy and will grow indefinitely, leading to a memory leak.
- Thread Safety: A global mutable dictionary is not thread-safe. In a multi-threaded environment, this can lead to race conditions.
A proper caching mechanism should be used, such as functools.lru_cache or a custom cache implementation that handles thread safety and eviction.
| # TODO(dsuo): What are ordered effects vs effects? | ||
| ordered_effects=traced.jaxpr.effects, | ||
| # TODO(dsuo): Figure out why ctx.platforms=None. | ||
| platforms=["cpu"], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The platforms argument is hardcoded to ["cpu"], which limits this AOT functionality to only the CPU backend. The associated TODO indicates this is a known issue. This should be resolved to correctly use the platform context from the lowering context (ctx.platforms) to support other backends like GPU and TPU.
jax/_src/traceback_util.py
Outdated
| casted = cast(C, reraise_with_filtered_traceback) | ||
|
|
||
| if hasattr(fun, '__name__') and 'wrapper' in fun.__name__: | ||
| logging.info("casted reraise id %s", | ||
| id(casted)) | ||
|
|
||
| return reraise_with_filtered_traceback |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The casted variable is created but the original reraise_with_filtered_traceback is returned. This is likely a bug. The surrounding logging statements also appear to be for debugging and can be removed. The code can be simplified to return the casted value directly.
| casted = cast(C, reraise_with_filtered_traceback) | |
| if hasattr(fun, '__name__') and 'wrapper' in fun.__name__: | |
| logging.info("casted reraise id %s", | |
| id(casted)) | |
| return reraise_with_filtered_traceback | |
| return cast(C, reraise_with_filtered_traceback) |
|
|
||
| @api_boundary | ||
| def cache_miss(*args, **kwargs): | ||
| logging.info('cpp_pjit fun: %s %s', id(fun), fun.__name__) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
No description provided.