From f01032a71acbb3c65fd8bc234d2ad8ea4a252790 Mon Sep 17 00:00:00 2001 From: wenyi-guo Date: Mon, 10 Nov 2025 12:05:49 -0800 Subject: [PATCH 01/24] Support jax2tf in JaxLayer for tf backend --- keras/src/layers/layer.py | 5 +- keras/src/utils/jax_layer.py | 240 +++++++++++++++++++++++++----- keras/src/utils/jax_layer_test.py | 34 +++-- 3 files changed, 234 insertions(+), 45 deletions(-) diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py index 9e6c928e3ee4..504627f6b524 100644 --- a/keras/src/layers/layer.py +++ b/keras/src/layers/layer.py @@ -1145,7 +1145,10 @@ def compute_output_spec(self, *args, **kwargs): call_spec=call_spec, class_name=self.__class__.__name__, ) - output_shape = self.compute_output_shape(**shapes_dict) + try: + output_shape = self.compute_output_shape(**shapes_dict) + except NotImplementedError as e: + return super().compute_output_spec(*args, **kwargs) if ( isinstance(output_shape, list) diff --git a/keras/src/utils/jax_layer.py b/keras/src/utils/jax_layer.py index a02af992778f..b117667ff81a 100644 --- a/keras/src/utils/jax_layer.py +++ b/keras/src/utils/jax_layer.py @@ -1,8 +1,13 @@ import inspect +import collections +import functools +import itertools import numpy as np +import string from keras.src import backend +from keras.src import random from keras.src import tree from keras.src.api_export import keras_export from keras.src.backend.common.variables import is_float_dtype @@ -11,7 +16,18 @@ from keras.src.saving import serialization_lib from keras.src.utils import jax_utils from keras.src.utils import tracking +from keras.src import ops from keras.src.utils.module_utils import jax +from keras.src.utils.module_utils import tensorflow as tf + + +def standardize_pytree_collections(pytree): + if isinstance(pytree, collections.abc.Mapping): + return {k: standardize_pytree_collections(v) for k, v in pytree.items()} + elif isinstance(pytree, collections.abc.Sequence): + return [standardize_pytree_collections(v) for v in pytree] + else: + return pytree @keras_export("keras.layers.JaxLayer") @@ -196,6 +212,9 @@ def my_haiku_module_fn(inputs, training): init_fn: the function to call to initialize the model. See description above for the list of arguments it takes and the outputs it returns. If `None`, then `params` and/or `state` must be provided. + compute_output_shape_fn: Function that takes the input shape + (a tuple or nested structure of tuples) and returns the output + shape (a tuple or nested structure of tuples). params: A `PyTree` containing all the model trainable parameters. This allows passing trained parameters or controlling the initialization. If both `params` and `state` are `None`, `init_fn` is called at @@ -214,15 +233,16 @@ def __init__( self, call_fn, init_fn=None, + compute_output_shape_fn=None, params=None, state=None, seed=None, **kwargs, ): - if backend.backend() != "jax": + if backend.backend() not in ["jax", "tensorflow"]: raise ValueError( - "JaxLayer is only supported with the JAX backend. Current " - f"backend: {backend.backend()}" + "JaxLayer is only supported with the JAX or Tensorflow backend" + f". 
Current backend: {backend.backend()}" ) if init_fn is None and params is None and state is None: @@ -233,7 +253,10 @@ def __init__( super().__init__(**kwargs) self.call_fn = call_fn self.init_fn = init_fn - self.seed_generator = backend.random.SeedGenerator(seed) + self.compute_output_shape_fn = compute_output_shape_fn + if seed is None: + seed = random.seed_generator.make_default_seed() + self.jax_rng = jax.random.PRNGKey(seed) self.tracked_params = self._create_variables(params, trainable=True) self.tracked_state = self._create_variables(state, trainable=False) if self.params is not None or self.state is not None: @@ -252,6 +275,10 @@ def __init__( init_fn, "init_fn", {"rng", "inputs", "training"}, {"inputs"} ) + # Attributes for jax2tf functions + self.jax2tf_training_false_fn = None + self.jax2tf_training_true_fn = None + def _validate_signature(self, fn, fn_name, allowed, required): fn_parameters = inspect.signature(fn).parameters for parameter_name in required: @@ -272,6 +299,79 @@ def _validate_signature(self, fn, fn_name, allowed, required): return parameter_names + def _get_jax2tf_input_shape(self, input_shape): + """Convert input shape in a format suitable for `jax2tf`. + + `jax2tf` expects a letter for each unknown dimension, which allows + correlated dimensions. Since correlated dimensions are not supported by + Keras, we simply use 'a', 'b', 'c'..., for each unknown dimension. We + however use 'batch' for dimension 0 if not defined to correlate the + batch size across inputs. + + Example (spaces added for readability): + ``` + input_shape: (None , 4 , None, None, 5 ) + result: "(batch, 4 , a , b , 5 )" + ``` + + Args: + input_shape: a single shape or a structure of shapes for the inputs. + Returns: + the shape or shapes structure in the `jax2tf` format as strings. + """ + dim_names = itertools.chain( + string.ascii_lowercase, # a, b, ... z + itertools.starmap( # aa, ab, ... az, ba, bb, ... zz + lambda a, b: a + b, + itertools.product(string.ascii_lowercase, repeat=2), + ), + ) + + def get_single_jax2tf_shape(shape): + jax2tf_shape = [] + + for index, dim in enumerate(shape): + if dim is not None: + jax2tf_shape.append(str(dim)) + elif index == 0: + jax2tf_shape.append("batch") + else: + jax2tf_shape.append(next(dim_names)) + + return "(" + ", ".join(jax2tf_shape) + ")" + + res = tree.map_shape_structure(get_single_jax2tf_shape, input_shape) + return res + + def _jax2tf_convert(self, fn, polymorphic_shapes): + from jax.experimental import jax2tf + + converted_fn = jax2tf.convert(fn, polymorphic_shapes=polymorphic_shapes) + # Autograph won't work with the output of jax2tf. + converted_fn = tf.autograph.experimental.do_not_convert(converted_fn) + return converted_fn + + def _partial_with_positional(self, fn, index, value): + """Return a new partial with one positional argument set to a value. + + This is needed because `jax2tf` only supports positional arguments and + `functools.partial` only supports setting positional arguments starting + from the left. Our use case is the `training` argument which is + typically the righmost argument. + + Args: + fn: the function to wrap. + index: the index of the positional argument to set to `value`. + value: the value for the positional argument at `index`. 
+ """ + + @functools.wraps(fn) + def wrapper(*args): + args = args[0:index] + (value,) + args[index:] + return fn(*args) + + return wrapper + @tracking.no_automatic_dependency_tracking def _create_variables(self, values, trainable): """Create a structure of variables from a structure of JAX arrays. @@ -296,14 +396,18 @@ def _create_variables(self, values, trainable): def create_variable(value): if backend.is_tensor(value) or isinstance( - value, (np.ndarray, np.generic) + value, (np.ndarray, np.generic, jax.Array) ): dtype = value.dtype if is_float_dtype(dtype): dtype = None # Use the layer dtype policy return self.add_weight( value.shape, - initializer=value, + initializer=( + backend.convert_to_tensor(value) + if value is not None + else None + ), dtype=dtype, trainable=trainable, ) @@ -328,8 +432,15 @@ def create_variable(value): else: self.state = variables - flat_variables, _ = jax.tree_util.tree_flatten(variables) - return flat_variables + if backend.backend() == "jax": + flat_variables, _ = jax.tree_util.tree_flatten(variables) + return flat_variables + elif backend.backend() == "tensorflow": + return variables + + def _split_jax_rng(self): + self.jax_rng, subkey = jax.random.split(self.jax_rng) + return subkey def _get_init_rng(self): """ @@ -343,7 +454,7 @@ def _get_init_rng(self): a JAX `PRNGKey` or structure of `PRNGKey`s that will be passed as the `rng` argument of `init_fn`. """ - return self.seed_generator.next() + return self._split_jax_rng() def _get_call_rng(self, training): """ @@ -359,23 +470,23 @@ def _get_call_rng(self, training): the `rng` argument of `call_fn`. """ if training: - return self.seed_generator.next() + return self._split_jax_rng() else: return None - def build(self, input_shape): - if self.params is not None or self.state is not None: - return - - if jax_utils.is_in_jax_tracing_scope(): + def _initialize_weights(self, input_shape): + if jax_utils.is_in_jax_tracing_scope() or tf.inside_function(): # This exception is not actually shown, it is caught and a detailed # warning about calling 'build' is printed. - raise ValueError("'JaxLayer' cannot be built in tracing scope") + raise ValueError( + "'JaxLayer' cannot be built in tracing scope" + "or inside tf function" + ) # Initialize `params` and `state` if needed by calling `init_fn`. 
def create_input(shape): shape = [d if d is not None else 1 for d in shape] - return jax.numpy.ones(shape) + return ops.ones(shape) init_inputs = tree.map_shape_structure(create_input, input_shape) init_args = [] @@ -398,6 +509,45 @@ def create_input(shape): ) self.tracked_state = self._create_variables(init_state, trainable=False) + def build(self, input_shape): + if self.params is None and self.state is None: + self._initialize_weights(input_shape) + + if backend.backend() == "tensorflow": + polymorphic_shapes = [] + for argument in self.call_fn_arguments: + if argument == "inputs": + polymorphic_shapes.append( + self._get_jax2tf_input_shape(input_shape) + ) + elif argument != "training": + # params, state, rng + polymorphic_shapes.append("...") + + if "training" in self.call_fn_arguments: + training_argument_index = self.call_fn_arguments.index( + "training" + ) + self.jax2tf_training_false_fn = self._jax2tf_convert( + self._partial_with_positional( + self.call_fn, training_argument_index, False + ), + polymorphic_shapes, + ) + self.jax2tf_training_true_fn = self._jax2tf_convert( + self._partial_with_positional( + self.call_fn, training_argument_index, True + ), + polymorphic_shapes, + ) + else: + self.jax2tf_training_false_fn = self._jax2tf_convert( + self.call_fn, + polymorphic_shapes, + ) + self.jax2tf_training_true_fn = None + super().build(input_shape) + def call(self, inputs, training=False): def unwrap_variable(variable): return None if variable is None else variable.value @@ -417,7 +567,8 @@ def unwrap_variable(variable): elif argument_name == "inputs": call_args.append(inputs) elif argument_name == "training": - call_args.append(training) + if backend.backend() == "jax": + call_args.append(training) def assign_state_to_variable(value, variable): # This exists only to make debugging this error case easier. @@ -429,14 +580,39 @@ def assign_state_to_variable(value, variable): ) variable.assign(value) - if self.has_state: - predictions, new_state = self.call_fn(*call_args) - jax.tree_util.tree_map( - assign_state_to_variable, new_state, self.state - ) - return predictions + def call_with_fn(fn): + if self.has_state: + predictions, new_state = fn(*call_args) + if backend.backend() == "jax": + jax.tree_util.tree_map( + assign_state_to_variable, new_state, self.state + ) + elif backend.backend() == "tensorflow": + jax.tree_util.tree_map( + assign_state_to_variable, + standardize_pytree_collections(new_state), + standardize_pytree_collections(self.state), + ) + return predictions + else: + return fn(*call_args) + + if backend.backend() == "jax": + return call_with_fn(self.call_fn) + elif backend.backend() == "tensorflow": + if self.jax2tf_training_true_fn is None: + return call_with_fn(self.jax2tf_training_false_fn) + else: + if training: + return call_with_fn(self.jax2tf_training_true_fn) + else: + return call_with_fn(self.jax2tf_training_false_fn) + + def compute_output_shape(self, input_shape): + if self.compute_output_shape_fn: + return self.compute_output_shape_fn(input_shape) else: - return self.call_fn(*call_args) + return super().compute_output_shape(input_shape) def get_config(self): config = { @@ -549,6 +725,7 @@ def my_flax_module_wrapper(module, inputs, training): def __init__( self, module, + compute_output_shape_fn=None, method=None, variables=None, **kwargs, @@ -556,12 +733,6 @@ def __init__( # Late import to only require Flax when this is used. 
from flax.core import scope as flax_scope - if backend.backend() != "jax": - raise ValueError( - "FlaxLayer is only supported with the JAX backend. Current " - f"backend: {backend.backend()}" - ) - self.module = module self.method = method @@ -618,6 +789,7 @@ def init_without_training(rng, inputs): super().__init__( call_fn=call_fn, init_fn=init_fn, + compute_output_shape_fn=compute_output_shape_fn, params=params, state=state, **kwargs, @@ -650,13 +822,13 @@ def _variables_to_params_and_state(self, variables): def _get_init_rng(self): return { - "params": self.seed_generator.next(), - "dropout": self.seed_generator.next(), + "params": self._split_jax_rng(), + "dropout": self._split_jax_rng(), } def _get_call_rng(self, training): if training: - return {"dropout": self.seed_generator.next()} + return {"dropout": self._split_jax_rng()} else: return {} diff --git a/keras/src/utils/jax_layer_test.py b/keras/src/utils/jax_layer_test.py index 009ecd402e5f..52fee536c659 100644 --- a/keras/src/utils/jax_layer_test.py +++ b/keras/src/utils/jax_layer_test.py @@ -19,6 +19,8 @@ from keras.src.saving import object_registration from keras.src.utils.jax_layer import FlaxLayer from keras.src.utils.jax_layer import JaxLayer +from keras.src import ops +from keras.src import random try: import flax @@ -69,6 +71,11 @@ def jax_stateful_apply(params, state, inputs, training): return outputs, state +@object_registration.register_keras_serializable() +def stateless_compute_output_shape(input_shape): + return (input_shape[0], num_classes) + + if flax is not None: @object_registration.register_keras_serializable() @@ -179,8 +186,8 @@ def from_config(cls, config): @pytest.mark.skipif( - backend.backend() != "jax", - reason="JaxLayer and FlaxLayer are only supported with JAX backend", + backend.backend() not in ["jax", "tensorflow"], + reason="JaxLayer and FlaxLayer are only supported with JAX and TF backend", ) class TestJaxLayer(testing.TestCase): def _test_layer( @@ -194,16 +201,18 @@ def _test_layer( non_trainable_params, ): # Fake MNIST data - x_train = np.random.uniform(size=(320, 28, 28, 1)) - y_train = np.eye(num_classes, dtype="int32")[ - (np.random.uniform(size=(320,)) * num_classes).astype("int32") - ] - x_test = np.random.uniform(size=(32, 28, 28, 1)) + x_train = random.uniform(shape=(320, 28, 28, 1)) + y_train_indices = ops.cast( + ops.random.uniform(shape=(320,), minval=0, maxval=num_classes), + dtype="int32", + ) + y_train = ops.one_hot(y_train_indices, num_classes, dtype="int32") + x_test = random.uniform(shape=(32, 28, 28, 1)) def _count_params(weights): count = 0 for weight in weights: - count = count + np.prod(weight.shape) + count = count + ops.prod(ops.shape(weight)) return count def verify_weights_and_params(layer): @@ -257,7 +266,7 @@ def verify_weights_and_params(layer): for before, after in zip(ntw1_before_fit, ntw1_after_fit): self.assertNotAllClose(before, after) - expected_ouput_shape = (x_test.shape[0], num_classes) + expected_ouput_shape = (ops.shape(x_test)[0], num_classes) output1 = model1(x_test) self.assertEqual(output1.shape, expected_ouput_shape) predict1 = model1.predict(x_test, steps=1) @@ -478,7 +487,12 @@ def create_wrapper(**kwargs): flax_model = flax_model_class() if flax_model_method: kwargs["method"] = getattr(flax_model, flax_model_method) - return FlaxLayer(flax_model_class(), **kwargs) + if backend.backend() == "jax": + return FlaxLayer(flax_model_class(), **kwargs) + elif backend.backend() == "tensorflow": + return FlaxLayer( + flax_model, stateless_compute_output_shape, 
**kwargs + ) self._test_layer( flax_model_class.__name__, From b57d7d9694e05d2736f828d44c741dd655724866 Mon Sep 17 00:00:00 2001 From: Wenyi Guo <41378453+wenyi-guo@users.noreply.github.com> Date: Mon, 10 Nov 2025 12:10:55 -0800 Subject: [PATCH 02/24] Update keras/src/utils/jax_layer.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- keras/src/utils/jax_layer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/keras/src/utils/jax_layer.py b/keras/src/utils/jax_layer.py index b117667ff81a..723829ba632e 100644 --- a/keras/src/utils/jax_layer.py +++ b/keras/src/utils/jax_layer.py @@ -22,6 +22,8 @@ def standardize_pytree_collections(pytree): + if isinstance(pytree, (str, bytes)): + return pytree if isinstance(pytree, collections.abc.Mapping): return {k: standardize_pytree_collections(v) for k, v in pytree.items()} elif isinstance(pytree, collections.abc.Sequence): From a5d98c49d7eb0895a36c7f52ae7056158089872e Mon Sep 17 00:00:00 2001 From: wenyi-guo Date: Mon, 10 Nov 2025 12:41:44 -0800 Subject: [PATCH 03/24] format --- keras/src/layers/layer.py | 2 +- keras/src/utils/jax_layer.py | 8 ++++---- keras/src/utils/jax_layer_test.py | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py index 504627f6b524..1df416509952 100644 --- a/keras/src/layers/layer.py +++ b/keras/src/layers/layer.py @@ -1147,7 +1147,7 @@ def compute_output_spec(self, *args, **kwargs): ) try: output_shape = self.compute_output_shape(**shapes_dict) - except NotImplementedError as e: + except NotImplementedError: return super().compute_output_spec(*args, **kwargs) if ( diff --git a/keras/src/utils/jax_layer.py b/keras/src/utils/jax_layer.py index 723829ba632e..721c5c0fbfcd 100644 --- a/keras/src/utils/jax_layer.py +++ b/keras/src/utils/jax_layer.py @@ -1,12 +1,13 @@ -import inspect - import collections import functools +import inspect import itertools -import numpy as np import string +import numpy as np + from keras.src import backend +from keras.src import ops from keras.src import random from keras.src import tree from keras.src.api_export import keras_export @@ -16,7 +17,6 @@ from keras.src.saving import serialization_lib from keras.src.utils import jax_utils from keras.src.utils import tracking -from keras.src import ops from keras.src.utils.module_utils import jax from keras.src.utils.module_utils import tensorflow as tf diff --git a/keras/src/utils/jax_layer_test.py b/keras/src/utils/jax_layer_test.py index 52fee536c659..1fd1eb3282cd 100644 --- a/keras/src/utils/jax_layer_test.py +++ b/keras/src/utils/jax_layer_test.py @@ -11,6 +11,8 @@ from keras.src import layers from keras.src import metrics from keras.src import models +from keras.src import ops +from keras.src import random from keras.src import saving from keras.src import testing from keras.src import tree @@ -19,8 +21,6 @@ from keras.src.saving import object_registration from keras.src.utils.jax_layer import FlaxLayer from keras.src.utils.jax_layer import JaxLayer -from keras.src import ops -from keras.src import random try: import flax From 88f595db35fa779b0a60a8896c81a3109650bd1a Mon Sep 17 00:00:00 2001 From: wenyi-guo Date: Tue, 11 Nov 2025 14:05:41 -0800 Subject: [PATCH 04/24] address comments --- keras/src/export/saved_model.py | 6 ---- keras/src/layers/layer.py | 5 +-- keras/src/utils/jax_layer.py | 60 +++++++++---------------------- keras/src/utils/jax_layer_test.py | 15 ++------ 4 files changed, 20 insertions(+), 66 deletions(-) 
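(For orientation before the per-file diffs: the sketch below shows the kind of end-to-end usage this series enables — a JAX function wrapped as a `keras.layers.JaxLayer` and executed under the TensorFlow backend, where `call_fn` is converted with `jax2tf` when the layer is built. It is a minimal illustration only; the function name `my_stateless_apply` and the parameter shapes are assumptions for the example, not code from these patches.)

```
# Illustrative sketch only -- not part of the patch. It assumes the public
# `keras.layers.JaxLayer` API; `my_stateless_apply` and the parameter shapes
# are made up for the example.
import os
os.environ["KERAS_BACKEND"] = "tensorflow"  # must be set before importing keras

import jax.numpy as jnp
import numpy as np
import keras

def my_stateless_apply(params, inputs):
    # Flatten (batch, 28, 28, 1) images and apply a single dense projection.
    x = inputs.reshape((inputs.shape[0], -1))
    return jnp.matmul(x, params["w"]) + params["b"]

params = {
    "w": np.zeros((784, 10), dtype="float32"),
    "b": np.zeros((10,), dtype="float32"),
}

# With the TF backend, JaxLayer converts `call_fn` via jax2tf at build time.
layer = keras.layers.JaxLayer(call_fn=my_stateless_apply, params=params)
outputs = layer(np.ones((2, 28, 28, 1), dtype="float32"))
print(outputs.shape)  # (2, 10)
```

When the layer is built from a symbolic input with an unknown batch size, dimension 0 is given the shared symbolic name `batch` and other unknown dimensions get letter names, so the jax2tf-converted function stays shape-polymorphic; when a `training` argument is present, `call_fn` is converted once per training value at build time.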
diff --git a/keras/src/export/saved_model.py b/keras/src/export/saved_model.py index d5009a7ec4a6..95ec9afa4223 100644 --- a/keras/src/export/saved_model.py +++ b/keras/src/export/saved_model.py @@ -540,12 +540,6 @@ def write_out(self, filepath, options=None, verbose=True): ) def _convert_to_tf_variable(self, backend_variable): - if not isinstance(backend_variable, backend.Variable): - raise TypeError( - "`backend_variable` must be a `backend.Variable`. " - f"Recevied: backend_variable={backend_variable} of type " - f"({type(backend_variable)})" - ) return tf.Variable( backend_variable.value, dtype=backend_variable.dtype, diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py index 1df416509952..9e6c928e3ee4 100644 --- a/keras/src/layers/layer.py +++ b/keras/src/layers/layer.py @@ -1145,10 +1145,7 @@ def compute_output_spec(self, *args, **kwargs): call_spec=call_spec, class_name=self.__class__.__name__, ) - try: - output_shape = self.compute_output_shape(**shapes_dict) - except NotImplementedError: - return super().compute_output_spec(*args, **kwargs) + output_shape = self.compute_output_shape(**shapes_dict) if ( isinstance(output_shape, list) diff --git a/keras/src/utils/jax_layer.py b/keras/src/utils/jax_layer.py index 721c5c0fbfcd..462d094ab29b 100644 --- a/keras/src/utils/jax_layer.py +++ b/keras/src/utils/jax_layer.py @@ -19,7 +19,7 @@ from keras.src.utils import tracking from keras.src.utils.module_utils import jax from keras.src.utils.module_utils import tensorflow as tf - +from keras.src.backend import jax as jax_backend def standardize_pytree_collections(pytree): if isinstance(pytree, (str, bytes)): @@ -214,9 +214,6 @@ def my_haiku_module_fn(inputs, training): init_fn: the function to call to initialize the model. See description above for the list of arguments it takes and the outputs it returns. If `None`, then `params` and/or `state` must be provided. - compute_output_shape_fn: Function that takes the input shape - (a tuple or nested structure of tuples) and returns the output - shape (a tuple or nested structure of tuples). params: A `PyTree` containing all the model trainable parameters. This allows passing trained parameters or controlling the initialization. If both `params` and `state` are `None`, `init_fn` is called at @@ -235,7 +232,6 @@ def __init__( self, call_fn, init_fn=None, - compute_output_shape_fn=None, params=None, state=None, seed=None, @@ -243,7 +239,7 @@ def __init__( ): if backend.backend() not in ["jax", "tensorflow"]: raise ValueError( - "JaxLayer is only supported with the JAX or Tensorflow backend" + f"{self.__class__.__name__} is only supported with the JAX or Tensorflow backend" f". 
Current backend: {backend.backend()}" ) @@ -255,10 +251,7 @@ def __init__( super().__init__(**kwargs) self.call_fn = call_fn self.init_fn = init_fn - self.compute_output_shape_fn = compute_output_shape_fn - if seed is None: - seed = random.seed_generator.make_default_seed() - self.jax_rng = jax.random.PRNGKey(seed) + self.seed_generator = backend.random.SeedGenerator(seed, backend=jax_backend) self.tracked_params = self._create_variables(params, trainable=True) self.tracked_state = self._create_variables(state, trainable=False) if self.params is not None or self.state is not None: @@ -405,11 +398,7 @@ def create_variable(value): dtype = None # Use the layer dtype policy return self.add_weight( value.shape, - initializer=( - backend.convert_to_tensor(value) - if value is not None - else None - ), + initializer=backend.convert_to_tensor(value), dtype=dtype, trainable=trainable, ) @@ -434,15 +423,8 @@ def create_variable(value): else: self.state = variables - if backend.backend() == "jax": - flat_variables, _ = jax.tree_util.tree_flatten(variables) - return flat_variables - elif backend.backend() == "tensorflow": - return variables - - def _split_jax_rng(self): - self.jax_rng, subkey = jax.random.split(self.jax_rng) - return subkey + flat_variables, _ = jax.tree_util.tree_flatten(variables) + return flat_variables def _get_init_rng(self): """ @@ -456,7 +438,7 @@ def _get_init_rng(self): a JAX `PRNGKey` or structure of `PRNGKey`s that will be passed as the `rng` argument of `init_fn`. """ - return self._split_jax_rng() + return self.seed_generator.next() def _get_call_rng(self, training): """ @@ -472,7 +454,7 @@ def _get_call_rng(self, training): the `rng` argument of `call_fn`. """ if training: - return self._split_jax_rng() + return self.seed_generator.next() else: return None @@ -488,7 +470,7 @@ def _initialize_weights(self, input_shape): # Initialize `params` and `state` if needed by calling `init_fn`. 
def create_input(shape): shape = [d if d is not None else 1 for d in shape] - return ops.ones(shape) + return jax.numpy.ones(shape) init_inputs = tree.map_shape_structure(create_input, input_shape) init_args = [] @@ -602,19 +584,11 @@ def call_with_fn(fn): if backend.backend() == "jax": return call_with_fn(self.call_fn) elif backend.backend() == "tensorflow": - if self.jax2tf_training_true_fn is None: - return call_with_fn(self.jax2tf_training_false_fn) + if training and self.jax2tf_training_true_fn is not None: + return call_with_fn(self.jax2tf_training_true_fn) else: - if training: - return call_with_fn(self.jax2tf_training_true_fn) - else: - return call_with_fn(self.jax2tf_training_false_fn) - - def compute_output_shape(self, input_shape): - if self.compute_output_shape_fn: - return self.compute_output_shape_fn(input_shape) - else: - return super().compute_output_shape(input_shape) + return call_with_fn(self.jax2tf_training_false_fn) + def get_config(self): config = { @@ -727,7 +701,6 @@ def my_flax_module_wrapper(module, inputs, training): def __init__( self, module, - compute_output_shape_fn=None, method=None, variables=None, **kwargs, @@ -791,7 +764,6 @@ def init_without_training(rng, inputs): super().__init__( call_fn=call_fn, init_fn=init_fn, - compute_output_shape_fn=compute_output_shape_fn, params=params, state=state, **kwargs, @@ -824,13 +796,13 @@ def _variables_to_params_and_state(self, variables): def _get_init_rng(self): return { - "params": self._split_jax_rng(), - "dropout": self._split_jax_rng(), + "params": self.seed_generator.next(), + "dropout": self.seed_generator.next(), } def _get_call_rng(self, training): if training: - return {"dropout": self._split_jax_rng()} + return {"dropout": self.seed_generator.next()} else: return {} diff --git a/keras/src/utils/jax_layer_test.py b/keras/src/utils/jax_layer_test.py index 1fd1eb3282cd..8813dc336696 100644 --- a/keras/src/utils/jax_layer_test.py +++ b/keras/src/utils/jax_layer_test.py @@ -6,6 +6,7 @@ import pytest import tensorflow as tf from absl.testing import parameterized +import math from keras.src import backend from keras.src import layers @@ -71,11 +72,6 @@ def jax_stateful_apply(params, state, inputs, training): return outputs, state -@object_registration.register_keras_serializable() -def stateless_compute_output_shape(input_shape): - return (input_shape[0], num_classes) - - if flax is not None: @object_registration.register_keras_serializable() @@ -212,7 +208,7 @@ def _test_layer( def _count_params(weights): count = 0 for weight in weights: - count = count + ops.prod(ops.shape(weight)) + count = count + math.prod(ops.shape(weight)) return count def verify_weights_and_params(layer): @@ -487,12 +483,7 @@ def create_wrapper(**kwargs): flax_model = flax_model_class() if flax_model_method: kwargs["method"] = getattr(flax_model, flax_model_method) - if backend.backend() == "jax": - return FlaxLayer(flax_model_class(), **kwargs) - elif backend.backend() == "tensorflow": - return FlaxLayer( - flax_model, stateless_compute_output_shape, **kwargs - ) + return FlaxLayer(flax_model_class(), **kwargs) self._test_layer( flax_model_class.__name__, From b2fd9ba22e6bf7cc7a4c812b98b398faeca053e2 Mon Sep 17 00:00:00 2001 From: wenyi-guo Date: Tue, 11 Nov 2025 14:35:34 -0800 Subject: [PATCH 05/24] lint --- keras/src/utils/jax_layer.py | 12 ++++++------ keras/src/utils/jax_layer_test.py | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/keras/src/utils/jax_layer.py b/keras/src/utils/jax_layer.py index 
462d094ab29b..601690960ffd 100644 --- a/keras/src/utils/jax_layer.py +++ b/keras/src/utils/jax_layer.py @@ -7,10 +7,9 @@ import numpy as np from keras.src import backend -from keras.src import ops -from keras.src import random from keras.src import tree from keras.src.api_export import keras_export +from keras.src.backend import jax as jax_backend from keras.src.backend.common.variables import is_float_dtype from keras.src.backend.common.variables import standardize_dtype from keras.src.layers.layer import Layer @@ -19,7 +18,7 @@ from keras.src.utils import tracking from keras.src.utils.module_utils import jax from keras.src.utils.module_utils import tensorflow as tf -from keras.src.backend import jax as jax_backend + def standardize_pytree_collections(pytree): if isinstance(pytree, (str, bytes)): @@ -239,8 +238,8 @@ def __init__( ): if backend.backend() not in ["jax", "tensorflow"]: raise ValueError( - f"{self.__class__.__name__} is only supported with the JAX or Tensorflow backend" - f". Current backend: {backend.backend()}" + f"{self.__class__.__name__} is only supported with the JAX or" + f" Tensorflow backend. Current backend: {backend.backend()}" ) if init_fn is None and params is None and state is None: @@ -251,7 +250,8 @@ def __init__( super().__init__(**kwargs) self.call_fn = call_fn self.init_fn = init_fn - self.seed_generator = backend.random.SeedGenerator(seed, backend=jax_backend) + self.seed_generator = backend.random.SeedGenerator( + seed, backend=jax_backend) self.tracked_params = self._create_variables(params, trainable=True) self.tracked_state = self._create_variables(state, trainable=False) if self.params is not None or self.state is not None: diff --git a/keras/src/utils/jax_layer_test.py b/keras/src/utils/jax_layer_test.py index 8813dc336696..902ec19c2e35 100644 --- a/keras/src/utils/jax_layer_test.py +++ b/keras/src/utils/jax_layer_test.py @@ -1,3 +1,4 @@ +import math import os import jax @@ -6,7 +7,6 @@ import pytest import tensorflow as tf from absl.testing import parameterized -import math from keras.src import backend from keras.src import layers From e18f9136a840607f6e30eb561c687ba3b2bc3748 Mon Sep 17 00:00:00 2001 From: wenyi-guo Date: Tue, 11 Nov 2025 14:53:49 -0800 Subject: [PATCH 06/24] local import --- keras/src/utils/jax_layer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/keras/src/utils/jax_layer.py b/keras/src/utils/jax_layer.py index 601690960ffd..37a0613aac05 100644 --- a/keras/src/utils/jax_layer.py +++ b/keras/src/utils/jax_layer.py @@ -9,7 +9,6 @@ from keras.src import backend from keras.src import tree from keras.src.api_export import keras_export -from keras.src.backend import jax as jax_backend from keras.src.backend.common.variables import is_float_dtype from keras.src.backend.common.variables import standardize_dtype from keras.src.layers.layer import Layer @@ -236,6 +235,8 @@ def __init__( seed=None, **kwargs, ): + from keras.src.backend import jax as jax_backend + if backend.backend() not in ["jax", "tensorflow"]: raise ValueError( f"{self.__class__.__name__} is only supported with the JAX or" From 62897ed57ae9cac5ecdcfe9cdd6ac9104525b206 Mon Sep 17 00:00:00 2001 From: wenyi-guo Date: Tue, 11 Nov 2025 15:13:05 -0800 Subject: [PATCH 07/24] format --- keras/src/utils/jax_layer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/keras/src/utils/jax_layer.py b/keras/src/utils/jax_layer.py index 37a0613aac05..955d203c94db 100644 --- a/keras/src/utils/jax_layer.py +++ b/keras/src/utils/jax_layer.py @@ 
-252,7 +252,8 @@ def __init__( self.call_fn = call_fn self.init_fn = init_fn self.seed_generator = backend.random.SeedGenerator( - seed, backend=jax_backend) + seed, backend=jax_backend + ) self.tracked_params = self._create_variables(params, trainable=True) self.tracked_state = self._create_variables(state, trainable=False) if self.params is not None or self.state is not None: @@ -589,7 +590,6 @@ def call_with_fn(fn): return call_with_fn(self.jax2tf_training_true_fn) else: return call_with_fn(self.jax2tf_training_false_fn) - def get_config(self): config = { From 0e140b465ff02dce390b6d32b1edd27563e06783 Mon Sep 17 00:00:00 2001 From: wenyi-guo Date: Tue, 11 Nov 2025 17:27:32 -0800 Subject: [PATCH 08/24] Use backend respective seed generator, but passing dtype uint32 as that's the dtype for jax key. --- keras/src/random/seed_generator.py | 4 +++- keras/src/utils/jax_layer.py | 21 ++++++++++++++++----- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/keras/src/random/seed_generator.py b/keras/src/random/seed_generator.py index dd2adbc13bbe..2161de249919 100644 --- a/keras/src/random/seed_generator.py +++ b/keras/src/random/seed_generator.py @@ -63,6 +63,7 @@ def __init__(self, seed=None, name=None, **kwargs): self.name = name custom_backend = kwargs.pop("backend", None) + dtype = kwargs.pop("dtype", None) if kwargs: raise ValueError(f"Unrecognized keyword arguments: {kwargs}") if custom_backend is not None: @@ -84,10 +85,11 @@ def seed_initializer(*args, **kwargs): return self.backend.convert_to_tensor([seed, 0], dtype=dtype) with self.backend.name_scope(self.name, caller=self): + dtype = dtype if dtype else self.backend.random_seed_dtype() self.state = self.backend.Variable( seed_initializer, shape=(2,), - dtype=self.backend.random_seed_dtype(), + dtype=dtype, trainable=False, aggregation="none", name="seed_generator_state", diff --git a/keras/src/utils/jax_layer.py b/keras/src/utils/jax_layer.py index 955d203c94db..e64be70d35d7 100644 --- a/keras/src/utils/jax_layer.py +++ b/keras/src/utils/jax_layer.py @@ -251,8 +251,12 @@ def __init__( super().__init__(**kwargs) self.call_fn = call_fn self.init_fn = init_fn + if backend.backend() == "jax": + dtype = jax.numpy.uint32 + elif backend.backend() == "tensorflow": + dtype = tf.uint32 self.seed_generator = backend.random.SeedGenerator( - seed, backend=jax_backend + seed=seed, dtype=dtype ) self.tracked_params = self._create_variables(params, trainable=True) self.tracked_state = self._create_variables(state, trainable=False) @@ -439,8 +443,15 @@ def _get_init_rng(self): Returns: a JAX `PRNGKey` or structure of `PRNGKey`s that will be passed as the `rng` argument of `init_fn`. 
+ """ - return self.seed_generator.next() + + from keras.src.backend.jax.core import convert_to_tensor + + if backend.backend() == 'jax': + return self.seed_generator.next() + if backend.backend() == 'tensorflow': + return convert_to_tensor(self.seed_generator.next()) def _get_call_rng(self, training): """ @@ -797,13 +808,13 @@ def _variables_to_params_and_state(self, variables): def _get_init_rng(self): return { - "params": self.seed_generator.next(), - "dropout": self.seed_generator.next(), + "params": super()._get_init_rng(), + "dropout": super()._get_init_rng(), } def _get_call_rng(self, training): if training: - return {"dropout": self.seed_generator.next()} + return {"dropout": super()._get_call_rng(training)} else: return {} From 4efbddfd52ad07bc9fea172c04f91b24e76d156d Mon Sep 17 00:00:00 2001 From: wenyi-guo Date: Tue, 11 Nov 2025 17:31:16 -0800 Subject: [PATCH 09/24] format --- keras/src/utils/jax_layer.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/keras/src/utils/jax_layer.py b/keras/src/utils/jax_layer.py index e64be70d35d7..731fffac78e0 100644 --- a/keras/src/utils/jax_layer.py +++ b/keras/src/utils/jax_layer.py @@ -235,7 +235,6 @@ def __init__( seed=None, **kwargs, ): - from keras.src.backend import jax as jax_backend if backend.backend() not in ["jax", "tensorflow"]: raise ValueError( @@ -443,14 +442,14 @@ def _get_init_rng(self): Returns: a JAX `PRNGKey` or structure of `PRNGKey`s that will be passed as the `rng` argument of `init_fn`. - + """ from keras.src.backend.jax.core import convert_to_tensor - - if backend.backend() == 'jax': + + if backend.backend() == "jax": return self.seed_generator.next() - if backend.backend() == 'tensorflow': + if backend.backend() == "tensorflow": return convert_to_tensor(self.seed_generator.next()) def _get_call_rng(self, training): From 0012271d31b45237d17228acf59b7c5eaf7bc288 Mon Sep 17 00:00:00 2001 From: wenyi-guo Date: Tue, 11 Nov 2025 21:31:03 -0800 Subject: [PATCH 10/24] resolve comments --- keras/src/export/saved_model.py | 6 + keras/src/utils/jax_layer.py | 59 ++--- log.log | 403 ++++++++++++++++++++++++++++++++ 3 files changed, 441 insertions(+), 27 deletions(-) create mode 100644 log.log diff --git a/keras/src/export/saved_model.py b/keras/src/export/saved_model.py index 95ec9afa4223..d5009a7ec4a6 100644 --- a/keras/src/export/saved_model.py +++ b/keras/src/export/saved_model.py @@ -540,6 +540,12 @@ def write_out(self, filepath, options=None, verbose=True): ) def _convert_to_tf_variable(self, backend_variable): + if not isinstance(backend_variable, backend.Variable): + raise TypeError( + "`backend_variable` must be a `backend.Variable`. " + f"Recevied: backend_variable={backend_variable} of type " + f"({type(backend_variable)})" + ) return tf.Variable( backend_variable.value, dtype=backend_variable.dtype, diff --git a/keras/src/utils/jax_layer.py b/keras/src/utils/jax_layer.py index 731fffac78e0..4a0bf0061c13 100644 --- a/keras/src/utils/jax_layer.py +++ b/keras/src/utils/jax_layer.py @@ -30,6 +30,16 @@ def standardize_pytree_collections(pytree): return pytree +if backend.backend() == "tensorflow": + tf_no_automatic_dependency_tracking = ( + tf.__internal__.tracking.no_automatic_dependency_tracking + ) +else: + + def tf_no_automatic_dependency_tracking(fn): + return fn + + @keras_export("keras.layers.JaxLayer") class JaxLayer(Layer): """Keras Layer that wraps a JAX model. 
@@ -235,7 +245,6 @@ def __init__( seed=None, **kwargs, ): - if backend.backend() not in ["jax", "tensorflow"]: raise ValueError( f"{self.__class__.__name__} is only supported with the JAX or" @@ -373,6 +382,7 @@ def wrapper(*args): return wrapper @tracking.no_automatic_dependency_tracking + @tf_no_automatic_dependency_tracking def _create_variables(self, values, trainable): """Create a structure of variables from a structure of JAX arrays. @@ -433,24 +443,20 @@ def create_variable(value): def _get_init_rng(self): """ - Returns a JAX `PRNGKey` or structure of `PRNGKey`s to pass to `init_fn`. + Returns a key in form of the backend array of size 2 dtype uint32 + to pass to `init_fn`. - By default, this returns a single `PRNGKey` retrieved by calling + By default, this returns a Jax or TF array of size 2 by calling `self.seed_generator.next()`. Override this to return a different structure. Returns: - a JAX `PRNGKey` or structure of `PRNGKey`s that will be passed as - the `rng` argument of `init_fn`. - + a key as an Jax or TF array of size 2 dtype uint32 will be passed + as the `rng` argument of `init_fn`. """ - - from keras.src.backend.jax.core import convert_to_tensor - - if backend.backend() == "jax": - return self.seed_generator.next() - if backend.backend() == "tensorflow": - return convert_to_tensor(self.seed_generator.next()) + next = self.seed_generator.next() + print("!!type", type(next), next) + return next def _get_call_rng(self, training): """ @@ -471,6 +477,8 @@ def _get_call_rng(self, training): return None def _initialize_weights(self, input_shape): + from keras.src.backend.jax.core import convert_to_tensor + if jax_utils.is_in_jax_tracing_scope() or tf.inside_function(): # This exception is not actually shown, it is caught and a detailed # warning about calling 'build' is printed. 
@@ -488,7 +496,11 @@ def create_input(shape): init_args = [] for argument_name in self.init_fn_arguments: if argument_name == "rng": - init_args.append(self._get_init_rng()) + init_args.append( + jax.tree_util.tree_map( + convert_to_tensor, self._get_init_rng() + ) + ) elif argument_name == "inputs": init_args.append(init_inputs) elif argument_name == "training": @@ -579,16 +591,9 @@ def assign_state_to_variable(value, variable): def call_with_fn(fn): if self.has_state: predictions, new_state = fn(*call_args) - if backend.backend() == "jax": - jax.tree_util.tree_map( - assign_state_to_variable, new_state, self.state - ) - elif backend.backend() == "tensorflow": - jax.tree_util.tree_map( - assign_state_to_variable, - standardize_pytree_collections(new_state), - standardize_pytree_collections(self.state), - ) + jax.tree_util.tree_map( + assign_state_to_variable, new_state, self.state + ) return predictions else: return fn(*call_args) @@ -807,13 +812,13 @@ def _variables_to_params_and_state(self, variables): def _get_init_rng(self): return { - "params": super()._get_init_rng(), - "dropout": super()._get_init_rng(), + "params": self.seed_generator.next(), + "dropout": self.seed_generator.next(), } def _get_call_rng(self, training): if training: - return {"dropout": super()._get_call_rng(training)} + return {"dropout": self.seed_generator.next()} else: return {} diff --git a/log.log b/log.log new file mode 100644 index 000000000000..34a21dddd0c8 --- /dev/null +++ b/log.log @@ -0,0 +1,403 @@ +============================= test session starts ============================== +platform darwin -- Python 3.12.10, pytest-8.4.2, pluggy-1.6.0 -- /Users/wenyiguo/keras/venv/bin/python3.12 +cachedir: .pytest_cache +rootdir: /Users/wenyiguo/keras +configfile: pyproject.toml +plugins: cov-7.0.0 +collecting ... 
collected 28 items
[~400 lines of captured pytest console output elided from the added log.log: per-test Keras model summaries, model.fit progress bars with terminal-escape artifacts, and SavedModel export listings with TensorSpec capture addresses for the JaxLayer/FlaxLayer tests; every test in keras/src/utils/jax_layer_test.py::TestJaxLayer reported PASSED.]
============================= 28 passed in 15.55s ==============================

From 0b468af44310150f0b0a1d571b243b77bc126576 Mon Sep 17 00:00:00 2001 From: wenyi-guo Date: Tue, 11 Nov 2025 21:31:45 -0800 Subject: [PATCH 11/24] update docstring --- keras/src/utils/jax_layer.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/keras/src/utils/jax_layer.py b/keras/src/utils/jax_layer.py index
4a0bf0061c13..aa0db30acb12 100644 --- a/keras/src/utils/jax_layer.py +++ b/keras/src/utils/jax_layer.py @@ -460,16 +460,17 @@ def _get_init_rng(self): def _get_call_rng(self, training): """ - Returns a JAX `PRNGKey` or structure of `PRNGKey`s to pass to `call_fn`. + Returns an RNG key, as a backend array of size 2 and dtype uint32, + to pass to `call_fn`. - By default, this returns a single `PRNGKey` retrieved by calling + By default, this returns a JAX or TF array of size 2 by calling `self.seed_generator.next()` when `training` is `True`, and `None` when `training` is `False`. Override this to return a different structure or to pass RNGs in inference mode too. Returns: - a JAX `PRNGKey` or structure of `PRNGKey`s that will be passed as - the `rng` argument of `call_fn`. + a key, as a JAX or TF array of size 2 and dtype uint32, to be passed + as the `rng` argument of `call_fn`. """ if training: return self.seed_generator.next() From ac47583d1b5454b7ee84423fcb242d2f286ab85f Mon Sep 17 00:00:00 2001 From: wenyi-guo Date: Wed, 12 Nov 2025 09:47:26 -0800 Subject: [PATCH 12/24] resolve comments --- keras/src/utils/jax_layer.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/keras/src/utils/jax_layer.py b/keras/src/utils/jax_layer.py index aa0db30acb12..47595cc80ee4 100644 --- a/keras/src/utils/jax_layer.py +++ b/keras/src/utils/jax_layer.py @@ -19,17 +19,6 @@ from keras.src.utils.module_utils import tensorflow as tf -def standardize_pytree_collections(pytree): - if isinstance(pytree, (str, bytes)): - return pytree - if isinstance(pytree, collections.abc.Mapping): - return {k: standardize_pytree_collections(v) for k, v in pytree.items()} - elif isinstance(pytree, collections.abc.Sequence): - return [standardize_pytree_collections(v) for v in pytree] - else: - return pytree - - if backend.backend() == "tensorflow": tf_no_automatic_dependency_tracking = ( tf.__internal__.tracking.no_automatic_dependency_tracking @@ -455,7 +444,6 @@ def _get_init_rng(self): as the `rng` argument of `init_fn`. """ next = self.seed_generator.next() - print("!!type", type(next), next) return next def _get_call_rng(self, training): From 4e7aa1a34a7fc8484f147b0685325f5e9483d120 Mon Sep 17 00:00:00 2001 From: wenyi-guo Date: Wed, 12 Nov 2025 09:48:47 -0800 Subject: [PATCH 13/24] delete log --- log.log | 403 -------------------------------------------------------- 1 file changed, 403 deletions(-) delete mode 100644 log.log diff --git a/log.log b/log.log deleted file mode 100644 index 34a21dddd0c8..000000000000 --- a/log.log +++ /dev/null @@ -1,403 +0,0 @@ -============================= test session starts ============================== -platform darwin -- Python 3.12.10, pytest-8.4.2, pluggy-1.6.0 -- /Users/wenyiguo/keras/venv/bin/python3.12 -cachedir: .pytest_cache -rootdir: /Users/wenyiguo/keras -configfile: pyproject.toml -plugins: cov-7.0.0 -collecting ...
collected 28 items - -keras/src/utils/jax_layer_test.py::TestJaxLayer::test_flax_layer_training_independent_bound_method Model: "FlaxTrainingIndependentModel1" -┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ -┃ Layer (type) ┃ Output Shape ┃ Param # ┃ -┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ -│ input_layer (InputLayer) │ (None, 28, 28, 1) │ 0 │ -├─────────────────────────────────┼────────────────────────┼───────────────┤ -│ flax_layer (FlaxLayer) │ (None, 10) │ 648,226 │ -└─────────────────────────────────┴────────────────────────┴───────────────┘ - Total params: 648,226 (2.47 MB) - Trainable params: 648,226 (2.47 MB) - Non-trainable params: 0 (0.00 B) -  1/10 ━━━━━━━━━━━━━━━━━━━━ 5s 655ms/step - categorical_accuracy: 0.0625 - loss: 2.2743  7/10 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step - categorical_accuracy: 0.0810 - loss: 2.3819  10/10 ━━━━━━━━━━━━━━━━━━━━ 1s 10ms/step - categorical_accuracy: 0.1063 - loss: 2.3672 - 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 37ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 41ms/step -Model: "FlaxTrainingIndependentModel2" -┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ -┃ Layer (type) ┃ Output Shape ┃ Param # ┃ -┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ -│ flax_layer_1 (FlaxLayer) │ (None, 10) │ 648,226 │ -└─────────────────────────────────┴────────────────────────┴───────────────┘ - Total params: 648,226 (2.47 MB) - Trainable params: 648,226 (2.47 MB) - Non-trainable params: 0 (0.00 B) - 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 37ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 42ms/step -Model: "FlaxTrainingIndependentModel2" -┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ -┃ Layer (type) ┃ Output Shape ┃ Param # ┃ -┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ -│ flax_layer_1 (FlaxLayer) │ (None, 10) │ 648,226 │ -└─────────────────────────────────┴────────────────────────┴───────────────┘ - Total params: 1,944,680 (7.42 MB) - Trainable params: 648,226 (2.47 MB) - Non-trainable params: 0 (0.00 B) - Optimizer params: 1,296,454 (4.95 MB) - 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 37ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 42ms/step -Saved artifact at '/var/folders/jk/kp6ss1yd1mx710fqvdvcrzfr00r8_1/T/tmprgwrpg8r/jax_layer_export'. 
The following endpoints are available: - -* Endpoint 'serve' - args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32, name='keras_tensor_2') -Output Type: - TensorSpec(shape=(None, 10), dtype=tf.float32, name=None) -Captures: - 13552996560: TensorSpec(shape=(), dtype=tf.resource, name=None) - 13552996752: TensorSpec(shape=(), dtype=tf.resource, name=None) - 13552994064: TensorSpec(shape=(), dtype=tf.resource, name=None) - 13552996944: TensorSpec(shape=(), dtype=tf.resource, name=None) - 13552992912: TensorSpec(shape=(), dtype=tf.resource, name=None) - 13552993488: TensorSpec(shape=(), dtype=tf.resource, name=None) - 13552997904: TensorSpec(shape=(), dtype=tf.resource, name=None) - 13552997328: TensorSpec(shape=(), dtype=tf.resource, name=None) -PASSED -keras/src/utils/jax_layer_test.py::TestJaxLayer::test_flax_layer_training_rng_state_no_method Model: "FlaxBatchNormModel1" -┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ -┃ Layer (type) ┃ Output Shape ┃ Param # ┃ -┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ -│ input_layer (InputLayer) │ (None, 28, 28, 1) │ 0 │ -├─────────────────────────────────┼────────────────────────┼───────────────┤ -│ flax_layer (FlaxLayer) │ (None, 10) │ 354,794 │ -└─────────────────────────────────┴────────────────────────┴───────────────┘ - Total params: 354,794 (1.35 MB) - Trainable params: 354,258 (1.35 MB) - Non-trainable params: 536 (2.09 KB) -  1/10 ━━━━━━━━━━━━━━━━━━━━ 10s 1s/step - categorical_accuracy: 0.0312 - loss: 2.8184  4/10 ━━━━━━━━━━━━━━━━━━━━ 0s 21ms/step - categorical_accuracy: 0.0482 - loss: 2.7029  7/10 ━━━━━━━━━━━━━━━━━━━━ 0s 21ms/step - categorical_accuracy: 0.0649 - loss: 2.6597 10/10 ━━━━━━━━━━━━━━━━━━━━ 0s 21ms/step - categorical_accuracy: 0.0701 - loss: 2.6486 10/10 ━━━━━━━━━━━━━━━━━━━━ 1s 22ms/step - categorical_accuracy: 0.0844 - loss: 2.6382 - 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 61ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 65ms/step -Model: "FlaxBatchNormModel2" -┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ -┃ Layer (type) ┃ Output Shape ┃ Param # ┃ -┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ -│ flax_layer_1 (FlaxLayer) │ (None, 10) │ 354,794 │ -└─────────────────────────────────┴────────────────────────┴───────────────┘ - Total params: 354,794 (1.35 MB) - Trainable params: 354,258 (1.35 MB) - Non-trainable params: 536 (2.09 KB) - 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 62ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 66ms/step -Model: "FlaxBatchNormModel2" -┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ -┃ Layer (type) ┃ Output Shape ┃ Param # ┃ -┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ -│ flax_layer_1 (FlaxLayer) │ (None, 10) │ 354,794 │ -└─────────────────────────────────┴────────────────────────┴───────────────┘ - Total params: 1,063,312 (4.06 MB) - Trainable params: 354,258 (1.35 MB) - Non-trainable params: 536 (2.09 KB) - Optimizer params: 708,518 (2.70 MB) - 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 60ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 64ms/step -Saved artifact at '/var/folders/jk/kp6ss1yd1mx710fqvdvcrzfr00r8_1/T/tmp5f2wr2e9/jax_layer_export'. 
The following endpoints are available: - -* Endpoint 'serve' - args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32, name='keras_tensor_2') -Output Type: - TensorSpec(shape=(None, 10), dtype=tf.float32, name=None) -Captures: - 15041532880: TensorSpec(shape=(), dtype=tf.resource, name=None) - 15041530384: TensorSpec(shape=(), dtype=tf.resource, name=None) - 15041533648: TensorSpec(shape=(), dtype=tf.resource, name=None) - 15041533072: TensorSpec(shape=(), dtype=tf.resource, name=None) - 15041530000: TensorSpec(shape=(), dtype=tf.resource, name=None) - 15041529232: TensorSpec(shape=(), dtype=tf.resource, name=None) - 15041539792: TensorSpec(shape=(), dtype=tf.resource, name=None) - 15041541328: TensorSpec(shape=(), dtype=tf.resource, name=None) - 15041534416: TensorSpec(shape=(), dtype=tf.resource, name=None) - 15041530192: TensorSpec(shape=(), dtype=tf.resource, name=None) - 15041537680: TensorSpec(shape=(), dtype=tf.resource, name=None) - 13552998864: TensorSpec(shape=(), dtype=tf.resource, name=None) - 13552998672: TensorSpec(shape=(), dtype=tf.resource, name=None) - 13553004048: TensorSpec(shape=(), dtype=tf.resource, name=None) - 13552998288: TensorSpec(shape=(), dtype=tf.resource, name=None) - 13552999440: TensorSpec(shape=(), dtype=tf.resource, name=None) - 13553007888: TensorSpec(shape=(), dtype=tf.resource, name=None) - 13553004432: TensorSpec(shape=(), dtype=tf.resource, name=None) - 13553003280: TensorSpec(shape=(), dtype=tf.resource, name=None) - 13552998480: TensorSpec(shape=(), dtype=tf.resource, name=None) - 13552995408: TensorSpec(shape=(), dtype=tf.resource, name=None) -PASSED -keras/src/utils/jax_layer_test.py::TestJaxLayer::test_flax_layer_training_rng_unbound_method Model: "FlaxDropoutModel1" -┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ -┃ Layer (type) ┃ Output Shape ┃ Param # ┃ -┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ -│ input_layer (InputLayer) │ (None, 28, 28, 1) │ 0 │ -├─────────────────────────────────┼────────────────────────┼───────────────┤ -│ flax_layer (FlaxLayer) │ (None, 10) │ 648,226 │ -└─────────────────────────────────┴────────────────────────┴───────────────┘ - Total params: 648,226 (2.47 MB) - Trainable params: 648,226 (2.47 MB) - Non-trainable params: 0 (0.00 B) -  1/10 ━━━━━━━━━━━━━━━━━━━━ 5s 666ms/step - categorical_accuracy: 0.0000e+00 - loss: 2.3471  7/10 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step - categorical_accuracy: 0.0779 - loss: 2.3606  10/10 ━━━━━━━━━━━━━━━━━━━━ 1s 10ms/step - categorical_accuracy: 0.1187 - loss: 2.3308 - 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 36ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 40ms/step -Model: "FlaxDropoutModel2" -┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ -┃ Layer (type) ┃ Output Shape ┃ Param # ┃ -┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ -│ flax_layer_1 (FlaxLayer) │ (None, 10) │ 648,226 │ -└─────────────────────────────────┴────────────────────────┴───────────────┘ - Total params: 648,226 (2.47 MB) - Trainable params: 648,226 (2.47 MB) - Non-trainable params: 0 (0.00 B) - 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 36ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 41ms/step -Model: "FlaxDropoutModel2" -┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ -┃ Layer (type) ┃ Output Shape ┃ Param # ┃ -┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ -│ flax_layer_1 (FlaxLayer) │ (None, 10) │ 648,226 │ 
-└─────────────────────────────────┴────────────────────────┴───────────────┘ - Total params: 1,944,680 (7.42 MB) - Trainable params: 648,226 (2.47 MB) - Non-trainable params: 0 (0.00 B) - Optimizer params: 1,296,454 (4.95 MB) - 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 36ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 40ms/step -Saved artifact at '/var/folders/jk/kp6ss1yd1mx710fqvdvcrzfr00r8_1/T/tmpa8wp4lmo/jax_layer_export'. The following endpoints are available: - -* Endpoint 'serve' - args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32, name='keras_tensor_2') -Output Type: - TensorSpec(shape=(None, 10), dtype=tf.float32, name=None) -Captures: - 13553005968: TensorSpec(shape=(), dtype=tf.resource, name=None) - 13553000400: TensorSpec(shape=(), dtype=tf.resource, name=None) - 13553000592: TensorSpec(shape=(), dtype=tf.resource, name=None) - 13553002512: TensorSpec(shape=(), dtype=tf.resource, name=None) - 15446802512: TensorSpec(shape=(), dtype=tf.resource, name=None) - 15446807312: TensorSpec(shape=(), dtype=tf.resource, name=None) - 15446809232: TensorSpec(shape=(), dtype=tf.resource, name=None) - 15446806736: TensorSpec(shape=(), dtype=tf.resource, name=None) -PASSED -keras/src/utils/jax_layer_test.py::TestJaxLayer::test_flax_layer_training_rng_unbound_method_dtype_policy Model: "FlaxDropoutModel1" -┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ -┃ Layer (type) ┃ Output Shape ┃ Param # ┃ -┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ -│ input_layer (InputLayer) │ (None, 28, 28, 1) │ 0 │ -├─────────────────────────────────┼────────────────────────┼───────────────┤ -│ flax_layer (FlaxLayer) │ (None, 10) │ 648,226 │ -└─────────────────────────────────┴────────────────────────┴───────────────┘ - Total params: 648,226 (2.47 MB) - Trainable params: 648,226 (2.47 MB) - Non-trainable params: 0 (0.00 B) -  1/10 ━━━━━━━━━━━━━━━━━━━━ 6s 733ms/step - categorical_accuracy: 0.0625 - loss: 2.3310  7/10 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step - categorical_accuracy: 0.0840 - loss: 2.3561  10/10 ━━━━━━━━━━━━━━━━━━━━ 1s 10ms/step - categorical_accuracy: 0.0750 - loss: 2.3423 - 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 40ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 44ms/step -Model: "FlaxDropoutModel2" -┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ -┃ Layer (type) ┃ Output Shape ┃ Param # ┃ -┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ -│ flax_layer_1 (FlaxLayer) │ (None, 10) │ 648,226 │ -└─────────────────────────────────┴────────────────────────┴───────────────┘ - Total params: 648,226 (2.47 MB) - Trainable params: 648,226 (2.47 MB) - Non-trainable params: 0 (0.00 B) - 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 40ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 45ms/step -Model: "FlaxDropoutModel2" -┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ -┃ Layer (type) ┃ Output Shape ┃ Param # ┃ -┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ -│ flax_layer_1 (FlaxLayer) │ (None, 10) │ 648,226 │ -└─────────────────────────────────┴────────────────────────┴───────────────┘ - Total params: 1,944,680 (7.42 MB) - Trainable params: 648,226 (2.47 MB) - Non-trainable params: 0 (0.00 B) - Optimizer params: 1,296,454 (4.95 MB) - 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 40ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 44ms/step -Saved artifact at '/var/folders/jk/kp6ss1yd1mx710fqvdvcrzfr00r8_1/T/tmpumrh4fc7/jax_layer_export'. 
The following endpoints are available: - -* Endpoint 'serve' - args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32, name='keras_tensor_3') -Output Type: - TensorSpec(shape=(None, 10), dtype=tf.float16, name=None) -Captures: - 15316476432: TensorSpec(shape=(), dtype=tf.resource, name=None) - 15316482960: TensorSpec(shape=(), dtype=tf.resource, name=None) - 15316476816: TensorSpec(shape=(), dtype=tf.resource, name=None) - 15316477584: TensorSpec(shape=(), dtype=tf.resource, name=None) - 15316482768: TensorSpec(shape=(), dtype=tf.resource, name=None) - 15316471632: TensorSpec(shape=(), dtype=tf.resource, name=None) - 15316473936: TensorSpec(shape=(), dtype=tf.resource, name=None) - 15316472784: TensorSpec(shape=(), dtype=tf.resource, name=None) -PASSED -keras/src/utils/jax_layer_test.py::TestJaxLayer::test_jax_layer_training_independent !!type tf.Tensor([704741653 0], shape=(2,), dtype=uint32) -Model: "jax_stateless_apply1" -┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ -┃ Layer (type) ┃ Output Shape ┃ Param # ┃ -┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ -│ input_layer (InputLayer) │ (None, 28, 28, 1) │ 0 │ -├─────────────────────────────────┼────────────────────────┼───────────────┤ -│ jax_layer (JaxLayer) │ (None, 10) │ 266,610 │ -└─────────────────────────────────┴────────────────────────┴───────────────┘ - Total params: 266,610 (1.02 MB) - Trainable params: 266,610 (1.02 MB) - Non-trainable params: 0 (0.00 B) -  1/10 ━━━━━━━━━━━━━━━━━━━━ 3s 404ms/step - categorical_accuracy: 0.0938 - loss: 2.3464 10/10 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - categorical_accuracy: 0.1000 - loss: 2.6237 - 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 24ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 29ms/step -Model: "jax_stateless_apply2" -┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ -┃ Layer (type) ┃ Output Shape ┃ Param # ┃ -┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ -│ jax_layer_1 (JaxLayer) │ (None, 10) │ 266,610 │ -└─────────────────────────────────┴────────────────────────┴───────────────┘ - Total params: 266,610 (1.02 MB) - Trainable params: 266,610 (1.02 MB) - Non-trainable params: 0 (0.00 B) - 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 25ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 29ms/step -!!type tf.Tensor([766770715 0], shape=(2,), dtype=uint32) -Model: "jax_stateless_apply2" -┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ -┃ Layer (type) ┃ Output Shape ┃ Param # ┃ -┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ -│ jax_layer_1 (JaxLayer) │ (None, 10) │ 266,610 │ -└─────────────────────────────────┴────────────────────────┴───────────────┘ - Total params: 799,832 (3.05 MB) - Trainable params: 266,610 (1.02 MB) - Non-trainable params: 0 (0.00 B) - Optimizer params: 533,222 (2.03 MB) - 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 24ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 28ms/step -Saved artifact at '/var/folders/jk/kp6ss1yd1mx710fqvdvcrzfr00r8_1/T/tmpq_facvuv/jax_layer_export'. 
The following endpoints are available: - -* Endpoint 'serve' - args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32, name='keras_tensor_2') -Output Type: - TensorSpec(shape=(None, 10), dtype=tf.float32, name=None) -Captures: - 15448744848: TensorSpec(shape=(), dtype=tf.resource, name=None) - 15448737360: TensorSpec(shape=(), dtype=tf.resource, name=None) - 15448741008: TensorSpec(shape=(), dtype=tf.resource, name=None) - 15448747728: TensorSpec(shape=(), dtype=tf.resource, name=None) - 15448748112: TensorSpec(shape=(), dtype=tf.resource, name=None) - 15448739664: TensorSpec(shape=(), dtype=tf.resource, name=None) -!!type tf.Tensor([852611657 0], shape=(2,), dtype=uint32) -PASSED -keras/src/utils/jax_layer_test.py::TestJaxLayer::test_jax_layer_training_state !!type tf.Tensor([248521834 0], shape=(2,), dtype=uint32) -Model: "jax_stateful_apply1" -┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ -┃ Layer (type) ┃ Output Shape ┃ Param # ┃ -┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ -│ input_layer (InputLayer) │ (None, 28, 28, 1) │ 0 │ -├─────────────────────────────────┼────────────────────────┼───────────────┤ -│ jax_layer (JaxLayer) │ (None, 10) │ 266,611 │ -└─────────────────────────────────┴────────────────────────┴───────────────┘ - Total params: 266,611 (1.02 MB) - Trainable params: 266,610 (1.02 MB) - Non-trainable params: 1 (4.00 B) -  1/10 ━━━━━━━━━━━━━━━━━━━━ 5s 622ms/step - categorical_accuracy: 0.1250 - loss: 2.3601 10/10 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - categorical_accuracy: 0.0875 - loss: 2.6616 - 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 28ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 32ms/step -Model: "jax_stateful_apply2" -┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ -┃ Layer (type) ┃ Output Shape ┃ Param # ┃ -┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ -│ jax_layer_1 (JaxLayer) │ (None, 10) │ 266,611 │ -└─────────────────────────────────┴────────────────────────┴───────────────┘ - Total params: 266,611 (1.02 MB) - Trainable params: 266,610 (1.02 MB) - Non-trainable params: 1 (4.00 B) - 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 26ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 31ms/step -!!type tf.Tensor([990546260 0], shape=(2,), dtype=uint32) -Model: "jax_stateful_apply2" -┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ -┃ Layer (type) ┃ Output Shape ┃ Param # ┃ -┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ -│ jax_layer_1 (JaxLayer) │ (None, 10) │ 266,611 │ -└─────────────────────────────────┴────────────────────────┴───────────────┘ - Total params: 799,833 (3.05 MB) - Trainable params: 266,610 (1.02 MB) - Non-trainable params: 1 (4.00 B) - Optimizer params: 533,222 (2.03 MB) - 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 27ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 31ms/step -Saved artifact at '/var/folders/jk/kp6ss1yd1mx710fqvdvcrzfr00r8_1/T/tmpmmigd4b8/jax_layer_export'. 
The following endpoints are available: - -* Endpoint 'serve' - args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32, name='keras_tensor_2') -Output Type: - TensorSpec(shape=(None, 10), dtype=tf.float32, name=None) -Captures: - 15562756688: TensorSpec(shape=(), dtype=tf.resource, name=None) - 15562756880: TensorSpec(shape=(), dtype=tf.resource, name=None) - 15562763216: TensorSpec(shape=(), dtype=tf.resource, name=None) - 15562766672: TensorSpec(shape=(), dtype=tf.resource, name=None) - 15562759376: TensorSpec(shape=(), dtype=tf.resource, name=None) - 15562763600: TensorSpec(shape=(), dtype=tf.resource, name=None) - 15562757456: TensorSpec(shape=(), dtype=tf.resource, name=None) -!!type tf.Tensor([971074536 0], shape=(2,), dtype=uint32) -PASSED -keras/src/utils/jax_layer_test.py::TestJaxLayer::test_jax_layer_training_state_dtype_policy !!type tf.Tensor([539657327 0], shape=(2,), dtype=uint32) -Model: "jax_stateful_apply1" -┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ -┃ Layer (type) ┃ Output Shape ┃ Param # ┃ -┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ -│ input_layer (InputLayer) │ (None, 28, 28, 1) │ 0 │ -├─────────────────────────────────┼────────────────────────┼───────────────┤ -│ jax_layer (JaxLayer) │ (None, 10) │ 266,611 │ -└─────────────────────────────────┴────────────────────────┴───────────────┘ - Total params: 266,611 (1.02 MB) - Trainable params: 266,610 (1.02 MB) - Non-trainable params: 1 (4.00 B) -  1/10 ━━━━━━━━━━━━━━━━━━━━ 3s 443ms/step - categorical_accuracy: 0.0938 - loss: 2.3004 10/10 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - categorical_accuracy: 0.0906 - loss: 2.6952 - 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 31ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 35ms/step -Model: "jax_stateful_apply2" -┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ -┃ Layer (type) ┃ Output Shape ┃ Param # ┃ -┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ -│ jax_layer_1 (JaxLayer) │ (None, 10) │ 266,611 │ -└─────────────────────────────────┴────────────────────────┴───────────────┘ - Total params: 266,611 (1.02 MB) - Trainable params: 266,610 (1.02 MB) - Non-trainable params: 1 (4.00 B) - 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 31ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 35ms/step -!!type tf.Tensor([99170990 0], shape=(2,), dtype=uint32) -Model: "jax_stateful_apply2" -┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ -┃ Layer (type) ┃ Output Shape ┃ Param # ┃ -┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ -│ jax_layer_1 (JaxLayer) │ (None, 10) │ 266,611 │ -└─────────────────────────────────┴────────────────────────┴───────────────┘ - Total params: 799,833 (3.05 MB) - Trainable params: 266,610 (1.02 MB) - Non-trainable params: 1 (4.00 B) - Optimizer params: 533,222 (2.03 MB) - 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 31ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 35ms/step -Saved artifact at '/var/folders/jk/kp6ss1yd1mx710fqvdvcrzfr00r8_1/T/tmp2vkivmt3/jax_layer_export'. 
The following endpoints are available: - -* Endpoint 'serve' - args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32, name='keras_tensor_3') -Output Type: - TensorSpec(shape=(None, 10), dtype=tf.float16, name=None) -Captures: - 15668397456: TensorSpec(shape=(), dtype=tf.resource, name=None) - 15668397264: TensorSpec(shape=(), dtype=tf.resource, name=None) - 15668397648: TensorSpec(shape=(), dtype=tf.resource, name=None) - 15668402832: TensorSpec(shape=(), dtype=tf.resource, name=None) - 15668400528: TensorSpec(shape=(), dtype=tf.resource, name=None) - 15668400144: TensorSpec(shape=(), dtype=tf.resource, name=None) - 15668403408: TensorSpec(shape=(), dtype=tf.resource, name=None) -!!type tf.Tensor([506803281 0], shape=(2,), dtype=uint32) -PASSED -keras/src/utils/jax_layer_test.py::TestJaxLayer::test_rng_seeding !!type tf.Tensor([906691060 0], shape=(2,), dtype=uint32) -!!type tf.Tensor([906691060 0], shape=(2,), dtype=uint32) -PASSED -keras/src/utils/jax_layer_test.py::TestJaxLayer::test_state_mismatch_during_update_mapping_instead_of_sequence PASSED -keras/src/utils/jax_layer_test.py::TestJaxLayer::test_state_mismatch_during_update_missing_dict_key PASSED -keras/src/utils/jax_layer_test.py::TestJaxLayer::test_state_mismatch_during_update_missing_variable_in_list PASSED -keras/src/utils/jax_layer_test.py::TestJaxLayer::test_state_mismatch_during_update_no_initial_state PASSED -keras/src/utils/jax_layer_test.py::TestJaxLayer::test_state_mismatch_during_update_sequence_instead_of_mapping PASSED -keras/src/utils/jax_layer_test.py::TestJaxLayer::test_state_mismatch_during_update_sequence_instead_of_variable PASSED -keras/src/utils/jax_layer_test.py::TestJaxLayer::test_with_different_argument_order !!type tf.Tensor([548977049 0], shape=(2,), dtype=uint32) -PASSED -keras/src/utils/jax_layer_test.py::TestJaxLayer::test_with_flax_state_no_params PASSED -keras/src/utils/jax_layer_test.py::TestJaxLayer::test_with_minimal_arguments PASSED -keras/src/utils/jax_layer_test.py::TestJaxLayer::test_with_missing_inputs_in_call_fn PASSED -keras/src/utils/jax_layer_test.py::TestJaxLayer::test_with_missing_inputs_in_init_fn PASSED -keras/src/utils/jax_layer_test.py::TestJaxLayer::test_with_no_init_fn_and_no_params PASSED -keras/src/utils/jax_layer_test.py::TestJaxLayer::test_with_polymorphic_shape_more_than_26_dimension_names PASSED -keras/src/utils/jax_layer_test.py::TestJaxLayer::test_with_state_jax_registered_node_class PASSED -keras/src/utils/jax_layer_test.py::TestJaxLayer::test_with_state_non_tensor_leaves PASSED -keras/src/utils/jax_layer_test.py::TestJaxLayer::test_with_state_none_leaves PASSED -keras/src/utils/jax_layer_test.py::TestJaxLayer::test_with_structures_as_inputs_and_outputs PASSED -keras/src/utils/jax_layer_test.py::TestJaxLayer::test_with_training_in_call_fn_but_not_init_fn !!type tf.Tensor([957413343 0], shape=(2,), dtype=uint32) -PASSED -keras/src/utils/jax_layer_test.py::TestJaxLayer::test_with_unsupported_argument_in_call_fn PASSED -keras/src/utils/jax_layer_test.py::TestJaxLayer::test_with_unsupported_argument_in_init_fn PASSED - -============================= 28 passed in 15.55s ============================== From ac85e30b34b126c21348714de9837c333d40ef39 Mon Sep 17 00:00:00 2001 From: wenyi-guo Date: Wed, 12 Nov 2025 11:45:20 -0800 Subject: [PATCH 14/24] use jax random dtype --- keras/src/utils/jax_layer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/keras/src/utils/jax_layer.py b/keras/src/utils/jax_layer.py index 
47595cc80ee4..57a0662595d4 100644 --- a/keras/src/utils/jax_layer.py +++ b/keras/src/utils/jax_layer.py @@ -1,4 +1,3 @@ -import collections import functools import inspect import itertools @@ -11,6 +10,9 @@ from keras.src.api_export import keras_export from keras.src.backend.common.variables import is_float_dtype from keras.src.backend.common.variables import standardize_dtype +from keras.src.backend.jax.core import ( + random_seed_dtype as jax_random_seed_dtype, +) from keras.src.layers.layer import Layer from keras.src.saving import serialization_lib from keras.src.utils import jax_utils @@ -18,7 +20,6 @@ from keras.src.utils.module_utils import jax from keras.src.utils.module_utils import tensorflow as tf - if backend.backend() == "tensorflow": tf_no_automatic_dependency_tracking = ( tf.__internal__.tracking.no_automatic_dependency_tracking @@ -252,6 +253,7 @@ def __init__( dtype = jax.numpy.uint32 elif backend.backend() == "tensorflow": dtype = tf.uint32 + dtype = jax_random_seed_dtype() self.seed_generator = backend.random.SeedGenerator( seed=seed, dtype=dtype ) From a1429b56ed4c48f2d9924e191577de9b10bb3abc Mon Sep 17 00:00:00 2001 From: wenyi-guo Date: Wed, 12 Nov 2025 13:01:09 -0800 Subject: [PATCH 15/24] remove lines --- keras/src/utils/jax_layer.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/keras/src/utils/jax_layer.py b/keras/src/utils/jax_layer.py index 57a0662595d4..802dd9739c8f 100644 --- a/keras/src/utils/jax_layer.py +++ b/keras/src/utils/jax_layer.py @@ -249,10 +249,6 @@ def __init__( super().__init__(**kwargs) self.call_fn = call_fn self.init_fn = init_fn - if backend.backend() == "jax": - dtype = jax.numpy.uint32 - elif backend.backend() == "tensorflow": - dtype = tf.uint32 dtype = jax_random_seed_dtype() self.seed_generator = backend.random.SeedGenerator( seed=seed, dtype=dtype From c9360e72c53a81e3ea7bc79e59d29d1a749afb2b Mon Sep 17 00:00:00 2001 From: wenyi-guo Date: Wed, 12 Nov 2025 13:03:42 -0800 Subject: [PATCH 16/24] local import --- keras/src/utils/jax_layer.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/keras/src/utils/jax_layer.py b/keras/src/utils/jax_layer.py index 802dd9739c8f..1c1a69822df9 100644 --- a/keras/src/utils/jax_layer.py +++ b/keras/src/utils/jax_layer.py @@ -10,9 +10,6 @@ from keras.src.api_export import keras_export from keras.src.backend.common.variables import is_float_dtype from keras.src.backend.common.variables import standardize_dtype -from keras.src.backend.jax.core import ( - random_seed_dtype as jax_random_seed_dtype, -) from keras.src.layers.layer import Layer from keras.src.saving import serialization_lib from keras.src.utils import jax_utils @@ -235,6 +232,10 @@ def __init__( seed=None, **kwargs, ): + from keras.src.backend.jax.core import ( + random_seed_dtype as jax_random_seed_dtype, + ) + if backend.backend() not in ["jax", "tensorflow"]: raise ValueError( f"{self.__class__.__name__} is only supported with the JAX or" From 29751dda6c0766dd612d0d2bb8d375d9a57737db Mon Sep 17 00:00:00 2001 From: wenyi-guo Date: Wed, 12 Nov 2025 13:09:38 -0800 Subject: [PATCH 17/24] format --- keras/src/utils/jax_layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras/src/utils/jax_layer.py b/keras/src/utils/jax_layer.py index 1c1a69822df9..8d35aeb8a839 100644 --- a/keras/src/utils/jax_layer.py +++ b/keras/src/utils/jax_layer.py @@ -235,7 +235,7 @@ def __init__( from keras.src.backend.jax.core import ( random_seed_dtype as jax_random_seed_dtype, ) - + if backend.backend() not in ["jax", 
"tensorflow"]: raise ValueError( f"{self.__class__.__name__} is only supported with the JAX or" From 5238252c3aab347eb98c4e5d78bad45c86815227 Mon Sep 17 00:00:00 2001 From: wenyi-guo Date: Wed, 12 Nov 2025 17:46:39 -0800 Subject: [PATCH 18/24] Change seed_gen to backend respective dtype and convert later for gpu test --- keras/src/utils/jax_layer.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/keras/src/utils/jax_layer.py b/keras/src/utils/jax_layer.py index 8d35aeb8a839..6ebd96c1d02d 100644 --- a/keras/src/utils/jax_layer.py +++ b/keras/src/utils/jax_layer.py @@ -27,6 +27,16 @@ def tf_no_automatic_dependency_tracking(fn): return fn +def _convert_to_jax_key(tensor): + if backend.backend() == "tensorflow": + if tensor.dtype == tf.int64: + key_uint32 = tf.bitcast(tensor, tf.uint32)[0] + if tf.is_symbolic_tensor(key_uint32): + return key_uint32 + else: + return jax.numpy.array(key_uint32) + + @keras_export("keras.layers.JaxLayer") class JaxLayer(Layer): """Keras Layer that wraps a JAX model. @@ -250,9 +260,8 @@ def __init__( super().__init__(**kwargs) self.call_fn = call_fn self.init_fn = init_fn - dtype = jax_random_seed_dtype() self.seed_generator = backend.random.SeedGenerator( - seed=seed, dtype=dtype + seed=seed ) self.tracked_params = self._create_variables(params, trainable=True) self.tracked_state = self._create_variables(state, trainable=False) @@ -442,8 +451,7 @@ def _get_init_rng(self): a key as an Jax or TF array of size 2 dtype uint32 will be passed as the `rng` argument of `init_fn`. """ - next = self.seed_generator.next() - return next + return self.seed_generator.next() def _get_call_rng(self, training): """ @@ -486,7 +494,7 @@ def create_input(shape): if argument_name == "rng": init_args.append( jax.tree_util.tree_map( - convert_to_tensor, self._get_init_rng() + _convert_to_jax_key, self._get_init_rng() ) ) elif argument_name == "inputs": @@ -559,7 +567,11 @@ def unwrap_variable(variable): jax.tree_util.tree_map(unwrap_variable, self.state) ) elif argument_name == "rng": - call_args.append(self._get_call_rng(training)) + call_args.append( + jax.tree_util.tree_map( + _convert_to_jax_key, self._get_call_rng(training) + ) + ) elif argument_name == "inputs": call_args.append(inputs) elif argument_name == "training": From a7373a8f33d1fd2784423b30cb8eba2640f3f714 Mon Sep 17 00:00:00 2001 From: wenyi-guo Date: Wed, 12 Nov 2025 17:47:09 -0800 Subject: [PATCH 19/24] format --- keras/src/utils/jax_layer.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/keras/src/utils/jax_layer.py b/keras/src/utils/jax_layer.py index 6ebd96c1d02d..0901cc2a7e3c 100644 --- a/keras/src/utils/jax_layer.py +++ b/keras/src/utils/jax_layer.py @@ -242,9 +242,6 @@ def __init__( seed=None, **kwargs, ): - from keras.src.backend.jax.core import ( - random_seed_dtype as jax_random_seed_dtype, - ) if backend.backend() not in ["jax", "tensorflow"]: raise ValueError( @@ -260,9 +257,7 @@ def __init__( super().__init__(**kwargs) self.call_fn = call_fn self.init_fn = init_fn - self.seed_generator = backend.random.SeedGenerator( - seed=seed - ) + self.seed_generator = backend.random.SeedGenerator(seed=seed) self.tracked_params = self._create_variables(params, trainable=True) self.tracked_state = self._create_variables(state, trainable=False) if self.params is not None or self.state is not None: @@ -473,7 +468,6 @@ def _get_call_rng(self, training): return None def _initialize_weights(self, input_shape): - from keras.src.backend.jax.core import 
convert_to_tensor if jax_utils.is_in_jax_tracing_scope() or tf.inside_function(): # This exception is not actually shown, it is caught and a detailed From 1801c2e3eecc0abeff8fe32e218e1bed25cf66ac Mon Sep 17 00:00:00 2001 From: wenyi-guo Date: Wed, 12 Nov 2025 17:50:04 -0800 Subject: [PATCH 20/24] format again --- keras/src/utils/jax_layer.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/keras/src/utils/jax_layer.py b/keras/src/utils/jax_layer.py index 0901cc2a7e3c..53f0f20e7266 100644 --- a/keras/src/utils/jax_layer.py +++ b/keras/src/utils/jax_layer.py @@ -242,7 +242,6 @@ def __init__( seed=None, **kwargs, ): - if backend.backend() not in ["jax", "tensorflow"]: raise ValueError( f"{self.__class__.__name__} is only supported with the JAX or" @@ -468,7 +467,6 @@ def _get_call_rng(self, training): return None def _initialize_weights(self, input_shape): - if jax_utils.is_in_jax_tracing_scope() or tf.inside_function(): # This exception is not actually shown, it is caught and a detailed # warning about calling 'build' is printed. From 87daedf930aea0fe6ca54bec288353ff5ccbfb9c Mon Sep 17 00:00:00 2001 From: wenyi-guo Date: Wed, 12 Nov 2025 18:33:05 -0800 Subject: [PATCH 21/24] rever seed_generator --- keras/src/random/seed_generator.py | 4 +- keras/src/utils/jax_layer.py | 2 +- log.log | 162 +++++++++++++++++++++++++++++ 3 files changed, 164 insertions(+), 4 deletions(-) create mode 100644 log.log diff --git a/keras/src/random/seed_generator.py b/keras/src/random/seed_generator.py index 2161de249919..dd2adbc13bbe 100644 --- a/keras/src/random/seed_generator.py +++ b/keras/src/random/seed_generator.py @@ -63,7 +63,6 @@ def __init__(self, seed=None, name=None, **kwargs): self.name = name custom_backend = kwargs.pop("backend", None) - dtype = kwargs.pop("dtype", None) if kwargs: raise ValueError(f"Unrecognized keyword arguments: {kwargs}") if custom_backend is not None: @@ -85,11 +84,10 @@ def seed_initializer(*args, **kwargs): return self.backend.convert_to_tensor([seed, 0], dtype=dtype) with self.backend.name_scope(self.name, caller=self): - dtype = dtype if dtype else self.backend.random_seed_dtype() self.state = self.backend.Variable( seed_initializer, shape=(2,), - dtype=dtype, + dtype=self.backend.random_seed_dtype(), trainable=False, aggregation="none", name="seed_generator_state", diff --git a/keras/src/utils/jax_layer.py b/keras/src/utils/jax_layer.py index 53f0f20e7266..2fb488c6bd71 100644 --- a/keras/src/utils/jax_layer.py +++ b/keras/src/utils/jax_layer.py @@ -256,7 +256,7 @@ def __init__( super().__init__(**kwargs) self.call_fn = call_fn self.init_fn = init_fn - self.seed_generator = backend.random.SeedGenerator(seed=seed) + self.seed_generator = backend.random.SeedGenerator(seed) self.tracked_params = self._create_variables(params, trainable=True) self.tracked_state = self._create_variables(state, trainable=False) if self.params is not None or self.state is not None: diff --git a/log.log b/log.log new file mode 100644 index 000000000000..df06bfe8567e --- /dev/null +++ b/log.log @@ -0,0 +1,162 @@ +============================= test session starts ============================== +platform darwin -- Python 3.12.10, pytest-8.4.2, pluggy-1.6.0 -- /Users/wenyiguo/keras/venv/bin/python3.12 +cachedir: .pytest_cache +rootdir: /Users/wenyiguo/keras +configfile: pyproject.toml +plugins: cov-7.0.0 +collecting ... 
collected 1 item + +keras/src/utils/jax_layer_test.py::TestJaxLayer::test_flax_layer_training_independent_bound_method FAILED + +=================================== FAILURES =================================== +________ TestJaxLayer.test_flax_layer_training_independent_bound_method ________ + +self = +flax_model_class = +flax_model_method = 'forward', init_kwargs = {}, trainable_weights = 8 +trainable_params = 648226, non_trainable_weights = 0, non_trainable_params = 0 + + @parameterized.named_parameters( + { + "testcase_name": "training_independent_bound_method", + "flax_model_class": "FlaxTrainingIndependentModel", + "flax_model_method": "forward", + "init_kwargs": {}, + "trainable_weights": 8, + "trainable_params": 648226, + "non_trainable_weights": 0, + "non_trainable_params": 0, + }, + { + "testcase_name": "training_rng_unbound_method", + "flax_model_class": "FlaxDropoutModel", + "flax_model_method": None, + "init_kwargs": { + "method": "flax_dropout_wrapper", + }, + "trainable_weights": 8, + "trainable_params": 648226, + "non_trainable_weights": 0, + "non_trainable_params": 0, + }, + { + "testcase_name": "training_rng_state_no_method", + "flax_model_class": "FlaxBatchNormModel", + "flax_model_method": None, + "init_kwargs": {}, + "trainable_weights": 13, + "trainable_params": 354258, + "non_trainable_weights": 8, + "non_trainable_params": 536, + }, + { + "testcase_name": "training_rng_unbound_method_dtype_policy", + "flax_model_class": "FlaxDropoutModel", + "flax_model_method": None, + "init_kwargs": { + "method": "flax_dropout_wrapper", + "dtype": DTypePolicy("mixed_float16"), + }, + "trainable_weights": 8, + "trainable_params": 648226, + "non_trainable_weights": 0, + "non_trainable_params": 0, + }, + ) + @pytest.mark.skipif(flax is None, reason="Flax library is not available.") + def test_flax_layer( + self, + flax_model_class, + flax_model_method, + init_kwargs, + trainable_weights, + trainable_params, + non_trainable_weights, + non_trainable_params, + ): + flax_model_class = FLAX_OBJECTS.get(flax_model_class) + if "method" in init_kwargs: + init_kwargs["method"] = FLAX_OBJECTS.get(init_kwargs["method"]) + + def create_wrapper(**kwargs): + params = kwargs.pop("params") if "params" in kwargs else None + state = kwargs.pop("state") if "state" in kwargs else None + if params and state: + variables = {**params, **state} + elif params: + variables = params + elif state: + variables = state + else: + variables = None + kwargs["variables"] = variables + flax_model = flax_model_class() + if flax_model_method: + kwargs["method"] = getattr(flax_model, flax_model_method) + return FlaxLayer(flax_model_class(), **kwargs) + +> self._test_layer( + flax_model_class.__name__, + create_wrapper, + init_kwargs, + trainable_weights, + trainable_params, + non_trainable_weights, + non_trainable_params, + ) + +keras/src/utils/jax_layer_test.py:488: +_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ +keras/src/utils/jax_layer_test.py:231: in _test_layer + outputs1 = layer1(inputs1) + ^^^^^^^^^^^^^^^ +keras/src/utils/traceback_utils.py:113: in error_handler + return fn(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^ +keras/src/layers/layer.py:866: in __call__ + self._maybe_build(call_spec) +keras/src/layers/layer.py:1477: in _maybe_build + self.build(**shapes_dict) +keras/src/layers/layer.py:231: in build_wrapper + original_build_method(*args, **kwargs) +keras/src/utils/jax_layer.py:510: in build + self._initialize_weights(input_shape) +keras/src/utils/jax_layer.py:497: in 
_initialize_weights + init_result = self.init_fn(*init_args) + ^^^^^^^^^^^^^^^^^^^^^^^^ +_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ + +rng = {'dropout': None, 'params': None} +inputs = Array([[[[1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.]... [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.]]]], dtype=float32) + + def init_without_training(rng, inputs): + return self._variables_to_params_and_state( +> self.module.init( + rng, + inputs, + method=self.method, + ) + ) +E ValueError: First argument passed to an init function should be a ``jax.PRNGKey`` or a dictionary mapping strings to ``jax.PRNGKey``. +E -------------------- +E For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these. + +keras/src/utils/jax_layer.py:755: ValueError +=========================== short test summary info ============================ +FAILED keras/src/utils/jax_layer_test.py::TestJaxLayer::test_flax_layer_training_independent_bound_method - ValueError: First argument passed to an init function should be a ``jax.PRNGKey`` or a dictionary mapping strings to ``jax.PRNGKey``. +-------------------- +For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these. +============================== 1 failed in 1.72s =============================== From 233f8a5e81584e3dd609f4daa0bc132c57dc50e3 Mon Sep 17 00:00:00 2001 From: wenyi-guo Date: Wed, 12 Nov 2025 18:35:57 -0800 Subject: [PATCH 22/24] fix jax backend bug --- keras/src/utils/jax_layer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/keras/src/utils/jax_layer.py b/keras/src/utils/jax_layer.py index 2fb488c6bd71..5b451fc63076 100644 --- a/keras/src/utils/jax_layer.py +++ b/keras/src/utils/jax_layer.py @@ -35,6 +35,7 @@ def _convert_to_jax_key(tensor): return key_uint32 else: return jax.numpy.array(key_uint32) + return tensor @keras_export("keras.layers.JaxLayer") From 9df6e56c9ba1edc2f167daf4cff23ae7818682e2 Mon Sep 17 00:00:00 2001 From: wenyi-guo Date: Thu, 13 Nov 2025 13:56:06 -0800 Subject: [PATCH 23/24] address comment --- keras/src/utils/jax_layer.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/keras/src/utils/jax_layer.py b/keras/src/utils/jax_layer.py index 5b451fc63076..7b912d59289e 100644 --- a/keras/src/utils/jax_layer.py +++ b/keras/src/utils/jax_layer.py @@ -29,12 +29,7 @@ def tf_no_automatic_dependency_tracking(fn): def _convert_to_jax_key(tensor): if backend.backend() == "tensorflow": - if tensor.dtype == tf.int64: - key_uint32 = tf.bitcast(tensor, tf.uint32)[0] - if tf.is_symbolic_tensor(key_uint32): - return key_uint32 - else: - return jax.numpy.array(key_uint32) + return tf.bitcast(tensor, tf.uint32)[0] return tensor @@ -487,7 +482,8 @@ def create_input(shape): if argument_name == "rng": init_args.append( jax.tree_util.tree_map( - _convert_to_jax_key, self._get_init_rng() + lambda x: jax.numpy.array(_convert_to_jax_key(x)), + self._get_init_rng(), ) ) elif argument_name == "inputs": From 6210659c6e04f5a1788b022209588416575e598e Mon Sep 17 00:00:00 2001 From: wenyi-guo Date: Tue, 18 Nov 2025 11:15:12 -0800 Subject: [PATCH 24/24] skip gpu test --- keras/src/utils/jax_layer_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/keras/src/utils/jax_layer_test.py b/keras/src/utils/jax_layer_test.py index 902ec19c2e35..778674d7b937 100644 --- a/keras/src/utils/jax_layer_test.py +++ 
b/keras/src/utils/jax_layer_test.py @@ -185,6 +185,7 @@ def from_config(cls, config): backend.backend() not in ["jax", "tensorflow"], reason="JaxLayer and FlaxLayer are only supported with JAX and TF backend", ) +@pytest.mark.skipif(testing.tensorflow_uses_gpu(), reason="GPU test failure") class TestJaxLayer(testing.TestCase): def _test_layer( self,
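
For reference, the snippet below is a minimal standalone sketch, not part of any patch above, of the key-conversion idea behind `_convert_to_jax_key`: on the TensorFlow backend, `seed_generator.next()` yields an int64 tensor of shape `(2,)`, which a jax2tf-converted `call_fn` cannot take as `rng`, so it is bitcast to uint32 and trimmed to a 2-element key. The concrete seed value and the eager-mode conversion to a JAX array are illustrative assumptions.

```python
# Minimal sketch of the bitcast-based key conversion, assuming eager mode and an
# illustrative seed state; it mirrors the logic shown in the patches above.
import jax.numpy as jnp
import tensorflow as tf

# Stand-in for `self.seed_generator.next()` on the TF backend: shape (2,), int64.
seed_state = tf.constant([704741653, 0], dtype=tf.int64)

# Each int64 is reinterpreted as two uint32 values, giving shape (2, 2); keeping
# the first row yields a (2,)-shaped uint32 key.
key_uint32 = tf.bitcast(seed_state, tf.uint32)[0]

# Outside tf.function tracing, the eager tensor can be converted to a JAX array
# for use as a PRNG key, e.g. when calling `init_fn` directly.
jax_key = jnp.array(key_uint32)
print(jax_key.shape, jax_key.dtype)  # (2,) uint32
```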