Skip to content

Conversation

@wenyi-guo
Copy link
Collaborator

Support both tf and jax backend for JaxLayer

@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @wenyi-guo, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the interoperability of Keras with JAX and Flax by enabling JaxLayer and FlaxLayer to operate seamlessly within a TensorFlow backend environment. It achieves this by integrating jax2tf to convert JAX functions into TensorFlow-compatible graphs, allowing JAX/Flax models to be used as Keras layers even when Keras is configured to use TensorFlow. This change broadens the utility of these layers, offering greater flexibility for developers working across different deep learning frameworks.

Highlights

  • TensorFlow Backend Support for JaxLayer: The JaxLayer and FlaxLayer classes now officially support the TensorFlow backend, allowing JAX and Flax models to be integrated into Keras models running on TensorFlow.
  • jax2tf Integration: New internal utilities have been introduced to facilitate the conversion of JAX functions to TensorFlow graphs using jax2tf, including handling polymorphic shapes and the training argument, ensuring seamless execution.
  • Unified Random Number Generation: The random number generation within JaxLayer and FlaxLayer has been standardized to consistently use JAX's PRNG keys and Keras's random module, replacing the previous SeedGenerator.
  • Flexible Output Shape Computation: A new compute_output_shape_fn argument has been added to JaxLayer and FlaxLayer, providing a mechanism for users to explicitly define how the output shape is computed, with a robust fallback in the base Layer class.
  • Expanded Test Coverage: The test suite for JaxLayer and FlaxLayer has been updated to include scenarios for the TensorFlow backend and now utilizes Keras's ops and random modules for consistent data generation.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request adds support for jax2tf in JaxLayer, enabling its use with both JAX and TensorFlow backends. This is a significant and well-implemented enhancement. The changes are primarily in keras/src/utils/jax_layer.py and its corresponding test file, with a minor supporting change in keras/src/layers/layer.py. The implementation correctly handles jax2tf conversion, state management, and RNG for both backends. I've identified a critical bug in a new helper function that could lead to infinite recursion and a minor issue in the test setup. With these fixes, this will be a solid contribution.

wenyi-guo and others added 3 commits November 10, 2025 15:30
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@codecov-commenter
Copy link

codecov-commenter commented Nov 10, 2025

Codecov Report

❌ Patch coverage is 17.10526% with 63 lines in your changes missing coverage. Please review.
✅ Project coverage is 61.56%. Comparing base (4d30a7f) to head (6210659).
⚠️ Report is 5 commits behind head on master.

Files with missing lines Patch % Lines
keras/src/utils/jax_layer.py 17.10% 62 Missing and 1 partial ⚠️

❗ There is a different number of reports uploaded between BASE (4d30a7f) and HEAD (6210659). Click for more details.

HEAD has 6 uploads less than BASE
Flag BASE (4d30a7f) HEAD (6210659)
keras 5 2
keras-tensorflow 1 0
keras-torch 1 0
keras-jax 1 0
Additional details and impacted files
@@             Coverage Diff             @@
##           master   #21842       +/-   ##
===========================================
- Coverage   82.66%   61.56%   -21.10%     
===========================================
  Files         577      577               
  Lines       59477    59568       +91     
  Branches     9329     9345       +16     
===========================================
- Hits        49167    36674    -12493     
- Misses       7907    20569    +12662     
+ Partials     2403     2325       -78     
Flag Coverage Δ
keras 61.55% <17.10%> (-20.94%) ⬇️
keras-jax ?
keras-numpy 57.51% <17.10%> (-0.04%) ⬇️
keras-openvino 34.33% <17.10%> (-0.01%) ⬇️
keras-tensorflow ?
keras-torch ?

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Copy link
Collaborator

@hertschuh hertschuh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR!

Copy link
Collaborator

@hertschuh hertschuh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is getting close!

@google-ml-butler google-ml-butler bot added kokoro:force-run ready to pull Ready to be merged into the codebase labels Nov 12, 2025
@google-ml-butler google-ml-butler bot removed the ready to pull Ready to be merged into the codebase label Nov 13, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants