diff --git a/keras/src/utils/jax_layer.py b/keras/src/utils/jax_layer.py
index a02af992778..7b912d59289 100644
--- a/keras/src/utils/jax_layer.py
+++ b/keras/src/utils/jax_layer.py
@@ -1,4 +1,7 @@
+import functools
 import inspect
+import itertools
+import string
 
 import numpy as np
 
@@ -12,6 +15,22 @@
 from keras.src.utils import jax_utils
 from keras.src.utils import tracking
 from keras.src.utils.module_utils import jax
+from keras.src.utils.module_utils import tensorflow as tf
+
+if backend.backend() == "tensorflow":
+    tf_no_automatic_dependency_tracking = (
+        tf.__internal__.tracking.no_automatic_dependency_tracking
+    )
+else:
+
+    def tf_no_automatic_dependency_tracking(fn):
+        return fn
+
+
+def _convert_to_jax_key(tensor):
+    if backend.backend() == "tensorflow":
+        return tf.bitcast(tensor, tf.uint32)[0]
+    return tensor
 
 
 @keras_export("keras.layers.JaxLayer")
@@ -219,10 +238,10 @@ def __init__(
         seed=None,
         **kwargs,
     ):
-        if backend.backend() != "jax":
+        if backend.backend() not in ["jax", "tensorflow"]:
             raise ValueError(
-                "JaxLayer is only supported with the JAX backend. Current "
-                f"backend: {backend.backend()}"
+                f"{self.__class__.__name__} is only supported with the JAX or"
+                f" TensorFlow backend. Current backend: {backend.backend()}"
             )
 
         if init_fn is None and params is None and state is None:
@@ -252,6 +271,10 @@ def __init__(
             init_fn, "init_fn", {"rng", "inputs", "training"}, {"inputs"}
         )
 
+        # Attributes for the jax2tf-converted functions.
+        self.jax2tf_training_false_fn = None
+        self.jax2tf_training_true_fn = None
+
     def _validate_signature(self, fn, fn_name, allowed, required):
         fn_parameters = inspect.signature(fn).parameters
         for parameter_name in required:
@@ -272,7 +295,81 @@ def _validate_signature(self, fn, fn_name, allowed, required):
 
         return parameter_names
 
+    def _get_jax2tf_input_shape(self, input_shape):
+        """Convert an input shape into a format suitable for `jax2tf`.
+
+        `jax2tf` expects a letter for each unknown dimension, which allows
+        correlated dimensions. Since correlated dimensions are not supported
+        by Keras, we simply use 'a', 'b', 'c', ... for each unknown dimension.
+        We do, however, use 'batch' for dimension 0 if it is not defined, to
+        correlate the batch size across inputs.
+
+        Example (spaces added for readability):
+        ```
+        input_shape: (None , 4 , None, None, 5 )
+        result:      "(batch, 4 , a   , b   , 5 )"
+        ```
+
+        Args:
+            input_shape: a single shape or a structure of shapes for the
+                inputs.
+        Returns:
+            the shape or shapes structure in the `jax2tf` format as strings.
+        """
+        dim_names = itertools.chain(
+            string.ascii_lowercase,  # a, b, ... z
+            itertools.starmap(  # aa, ab, ... az, ba, bb, ... zz
+                lambda a, b: a + b,
+                itertools.product(string.ascii_lowercase, repeat=2),
+            ),
+        )
+
+        def get_single_jax2tf_shape(shape):
+            jax2tf_shape = []
+
+            for index, dim in enumerate(shape):
+                if dim is not None:
+                    jax2tf_shape.append(str(dim))
+                elif index == 0:
+                    jax2tf_shape.append("batch")
+                else:
+                    jax2tf_shape.append(next(dim_names))
+
+            return "(" + ", ".join(jax2tf_shape) + ")"
+
+        res = tree.map_shape_structure(get_single_jax2tf_shape, input_shape)
+        return res
+
+    def _jax2tf_convert(self, fn, polymorphic_shapes):
+        from jax.experimental import jax2tf
+
+        converted_fn = jax2tf.convert(fn, polymorphic_shapes=polymorphic_shapes)
+        # Autograph won't work with the output of jax2tf.
+        converted_fn = tf.autograph.experimental.do_not_convert(converted_fn)
+        return converted_fn
+
+    def _partial_with_positional(self, fn, index, value):
+        """Return a new partial with one positional argument set to a value.
+
+        This is needed because `jax2tf` only supports positional arguments
+        and `functools.partial` only supports setting positional arguments
+        starting from the left. Our use case is the `training` argument,
+        which is typically the rightmost argument.
+
+        Args:
+            fn: the function to wrap.
+            index: the index of the positional argument to set to `value`.
+            value: the value for the positional argument at `index`.
+        """
+
+        @functools.wraps(fn)
+        def wrapper(*args):
+            args = args[0:index] + (value,) + args[index:]
+            return fn(*args)
+
+        return wrapper
+
     @tracking.no_automatic_dependency_tracking
+    @tf_no_automatic_dependency_tracking
     def _create_variables(self, values, trainable):
         """Create a structure of variables from a structure of JAX arrays.
@@ -296,14 +393,14 @@ def _create_variables(self, values, trainable):
         def create_variable(value):
             if backend.is_tensor(value) or isinstance(
-                value, (np.ndarray, np.generic)
+                value, (np.ndarray, np.generic, jax.Array)
             ):
                 dtype = value.dtype
                 if is_float_dtype(dtype):
                     dtype = None  # Use the layer dtype policy
                 return self.add_weight(
                     value.shape,
-                    initializer=value,
+                    initializer=backend.convert_to_tensor(value),
                     dtype=dtype,
                     trainable=trainable,
                 )
@@ -333,44 +430,46 @@ def create_variable(value):
     def _get_init_rng(self):
         """
-        Returns a JAX `PRNGKey` or structure of `PRNGKey`s to pass to `init_fn`.
+        Returns a key in the form of a backend array of size 2 and dtype
+        uint32 to pass to `init_fn`.
 
-        By default, this returns a single `PRNGKey` retrieved by calling
+        By default, this returns a JAX or TF array of size 2 by calling
         `self.seed_generator.next()`. Override this to return a different
         structure.
 
         Returns:
-            a JAX `PRNGKey` or structure of `PRNGKey`s that will be passed as
-            the `rng` argument of `init_fn`.
+            a key as a JAX or TF array of size 2 and dtype uint32 that will
+            be passed as the `rng` argument of `init_fn`.
         """
         return self.seed_generator.next()
 
     def _get_call_rng(self, training):
         """
-        Returns a JAX `PRNGKey` or structure of `PRNGKey`s to pass to `call_fn`.
+        Returns a key in the form of a backend array of size 2 and dtype
+        uint32 to pass to `call_fn`.
 
-        By default, this returns a single `PRNGKey` retrieved by calling
+        By default, this returns a JAX or TF array of size 2 by calling
         `self.seed_generator.next()` when `training` is `True`, and `None` when
         `training` is `False`. Override this to return a different structure or
         to pass RNGs in inference mode too.
 
         Returns:
-            a JAX `PRNGKey` or structure of `PRNGKey`s that will be passed as
-            the `rng` argument of `call_fn`.
+            a key as a JAX or TF array of size 2 and dtype uint32 that will
+            be passed as the `rng` argument of `call_fn`.
         """
         if training:
            return self.seed_generator.next()
        else:
            return None
 
-    def build(self, input_shape):
-        if self.params is not None or self.state is not None:
-            return
-
-        if jax_utils.is_in_jax_tracing_scope():
+    def _initialize_weights(self, input_shape):
+        if jax_utils.is_in_jax_tracing_scope() or tf.inside_function():
             # This exception is not actually shown, it is caught and a detailed
             # warning about calling 'build' is printed.
-            raise ValueError("'JaxLayer' cannot be built in tracing scope")
+            raise ValueError(
+                "'JaxLayer' cannot be built in a tracing scope "
+                "or inside a tf.function"
+            )
 
         # Initialize `params` and `state` if needed by calling `init_fn`.
         def create_input(shape):
@@ -381,7 +480,12 @@ def create_input(shape):
         init_args = []
         for argument_name in self.init_fn_arguments:
             if argument_name == "rng":
-                init_args.append(self._get_init_rng())
+                init_args.append(
+                    jax.tree_util.tree_map(
+                        lambda x: jax.numpy.array(_convert_to_jax_key(x)),
+                        self._get_init_rng(),
+                    )
+                )
             elif argument_name == "inputs":
                 init_args.append(init_inputs)
             elif argument_name == "training":
@@ -398,6 +502,45 @@ def create_input(shape):
         )
         self.tracked_state = self._create_variables(init_state, trainable=False)
 
+    def build(self, input_shape):
+        if self.params is None and self.state is None:
+            self._initialize_weights(input_shape)
+
+        if backend.backend() == "tensorflow":
+            polymorphic_shapes = []
+            for argument in self.call_fn_arguments:
+                if argument == "inputs":
+                    polymorphic_shapes.append(
+                        self._get_jax2tf_input_shape(input_shape)
+                    )
+                elif argument != "training":
+                    # params, state, rng
+                    polymorphic_shapes.append("...")
+
+            if "training" in self.call_fn_arguments:
+                training_argument_index = self.call_fn_arguments.index(
+                    "training"
+                )
+                self.jax2tf_training_false_fn = self._jax2tf_convert(
+                    self._partial_with_positional(
+                        self.call_fn, training_argument_index, False
+                    ),
+                    polymorphic_shapes,
+                )
+                self.jax2tf_training_true_fn = self._jax2tf_convert(
+                    self._partial_with_positional(
+                        self.call_fn, training_argument_index, True
+                    ),
+                    polymorphic_shapes,
+                )
+            else:
+                self.jax2tf_training_false_fn = self._jax2tf_convert(
+                    self.call_fn,
+                    polymorphic_shapes,
+                )
+                self.jax2tf_training_true_fn = None
+        super().build(input_shape)
+
     def call(self, inputs, training=False):
         def unwrap_variable(variable):
             return None if variable is None else variable.value
@@ -413,11 +556,16 @@ def unwrap_variable(variable):
                 jax.tree_util.tree_map(unwrap_variable, self.state)
             )
             elif argument_name == "rng":
-                call_args.append(self._get_call_rng(training))
+                call_args.append(
+                    jax.tree_util.tree_map(
+                        _convert_to_jax_key, self._get_call_rng(training)
+                    )
+                )
             elif argument_name == "inputs":
                 call_args.append(inputs)
             elif argument_name == "training":
-                call_args.append(training)
+                if backend.backend() == "jax":
+                    call_args.append(training)
 
         def assign_state_to_variable(value, variable):
             # This exists only to make debugging this error case easier.
@@ -429,14 +577,23 @@ def assign_state_to_variable(value, variable):
             if not isinstance(variable, backend.Variable):
                 raise ValueError(
                 )
             variable.assign(value)
 
-        if self.has_state:
-            predictions, new_state = self.call_fn(*call_args)
-            jax.tree_util.tree_map(
-                assign_state_to_variable, new_state, self.state
-            )
-            return predictions
-        else:
-            return self.call_fn(*call_args)
+        def call_with_fn(fn):
+            if self.has_state:
+                predictions, new_state = fn(*call_args)
+                jax.tree_util.tree_map(
+                    assign_state_to_variable, new_state, self.state
+                )
+                return predictions
+            else:
+                return fn(*call_args)
+
+        if backend.backend() == "jax":
+            return call_with_fn(self.call_fn)
+        elif backend.backend() == "tensorflow":
+            if training and self.jax2tf_training_true_fn is not None:
+                return call_with_fn(self.jax2tf_training_true_fn)
+            else:
+                return call_with_fn(self.jax2tf_training_false_fn)
 
     def get_config(self):
         config = {
@@ -556,12 +713,6 @@ def __init__(
         # Late import to only require Flax when this is used.
         from flax.core import scope as flax_scope
 
-        if backend.backend() != "jax":
-            raise ValueError(
-                "FlaxLayer is only supported with the JAX backend. Current "
-                f"backend: {backend.backend()}"
-            )
-
         self.module = module
         self.method = method
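The conversion that the new `build` performs on the TensorFlow backend can be seen in isolation with a short standalone sketch (not the PR's code; it assumes `jax` and `tensorflow` are installed). The shape string `"(batch, 4)"` is what `_get_jax2tf_input_shape` would produce for an input shape of `(None, 4)`, and `"..."` tells `jax2tf` to take the shapes of `params` from the actual arguments, mirroring the `polymorphic_shapes` list built above:

```python
import jax.numpy as jnp
import numpy as np
import tensorflow as tf
from jax.experimental import jax2tf


def call_fn(params, inputs):
    # Stand-in for a real model's apply function.
    return jnp.tanh(inputs @ params)


params = np.ones((4, 2), dtype="float32")

# "batch" marks dimension 0 as symbolic, so one conversion serves all
# batch sizes; "..." keeps the params shapes concrete.
tf_fn = jax2tf.convert(call_fn, polymorphic_shapes=["...", "(batch, 4)"])

print(tf_fn(params, tf.ones((8, 4))).shape)   # (8, 2)
print(tf_fn(params, tf.ones((32, 4))).shape)  # (32, 2)
```

For use inside a `tf.function`, the PR additionally wraps the converted function in `tf.autograph.experimental.do_not_convert`, since Autograph cannot process `jax2tf` output.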
Current " - f"backend: {backend.backend()}" - ) - self.module = module self.method = method diff --git a/keras/src/utils/jax_layer_test.py b/keras/src/utils/jax_layer_test.py index 009ecd402e5..778674d7b93 100644 --- a/keras/src/utils/jax_layer_test.py +++ b/keras/src/utils/jax_layer_test.py @@ -1,3 +1,4 @@ +import math import os import jax @@ -11,6 +12,8 @@ from keras.src import layers from keras.src import metrics from keras.src import models +from keras.src import ops +from keras.src import random from keras.src import saving from keras.src import testing from keras.src import tree @@ -179,9 +182,10 @@ def from_config(cls, config): @pytest.mark.skipif( - backend.backend() != "jax", - reason="JaxLayer and FlaxLayer are only supported with JAX backend", + backend.backend() not in ["jax", "tensorflow"], + reason="JaxLayer and FlaxLayer are only supported with JAX and TF backend", ) +@pytest.mark.skipif(testing.tensorflow_uses_gpu(), reason="GPU test failure") class TestJaxLayer(testing.TestCase): def _test_layer( self, @@ -194,16 +198,18 @@ def _test_layer( non_trainable_params, ): # Fake MNIST data - x_train = np.random.uniform(size=(320, 28, 28, 1)) - y_train = np.eye(num_classes, dtype="int32")[ - (np.random.uniform(size=(320,)) * num_classes).astype("int32") - ] - x_test = np.random.uniform(size=(32, 28, 28, 1)) + x_train = random.uniform(shape=(320, 28, 28, 1)) + y_train_indices = ops.cast( + ops.random.uniform(shape=(320,), minval=0, maxval=num_classes), + dtype="int32", + ) + y_train = ops.one_hot(y_train_indices, num_classes, dtype="int32") + x_test = random.uniform(shape=(32, 28, 28, 1)) def _count_params(weights): count = 0 for weight in weights: - count = count + np.prod(weight.shape) + count = count + math.prod(ops.shape(weight)) return count def verify_weights_and_params(layer): @@ -257,7 +263,7 @@ def verify_weights_and_params(layer): for before, after in zip(ntw1_before_fit, ntw1_after_fit): self.assertNotAllClose(before, after) - expected_ouput_shape = (x_test.shape[0], num_classes) + expected_ouput_shape = (ops.shape(x_test)[0], num_classes) output1 = model1(x_test) self.assertEqual(output1.shape, expected_ouput_shape) predict1 = model1.predict(x_test, steps=1) diff --git a/log.log b/log.log new file mode 100644 index 00000000000..df06bfe8567 --- /dev/null +++ b/log.log @@ -0,0 +1,162 @@ +============================= test session starts ============================== +platform darwin -- Python 3.12.10, pytest-8.4.2, pluggy-1.6.0 -- /Users/wenyiguo/keras/venv/bin/python3.12 +cachedir: .pytest_cache +rootdir: /Users/wenyiguo/keras +configfile: pyproject.toml +plugins: cov-7.0.0 +collecting ... 
diff --git a/log.log b/log.log
new file mode 100644
index 00000000000..df06bfe8567
--- /dev/null
+++ b/log.log
@@ -0,0 +1,162 @@
+============================= test session starts ==============================
+platform darwin -- Python 3.12.10, pytest-8.4.2, pluggy-1.6.0 -- /Users/wenyiguo/keras/venv/bin/python3.12
+cachedir: .pytest_cache
+rootdir: /Users/wenyiguo/keras
+configfile: pyproject.toml
+plugins: cov-7.0.0
+collecting ... collected 1 item
+
+keras/src/utils/jax_layer_test.py::TestJaxLayer::test_flax_layer_training_independent_bound_method FAILED
+
+=================================== FAILURES ===================================
+________ TestJaxLayer.test_flax_layer_training_independent_bound_method ________
+
+self = 
+flax_model_class = 
+flax_model_method = 'forward', init_kwargs = {}, trainable_weights = 8
+trainable_params = 648226, non_trainable_weights = 0, non_trainable_params = 0
+
+    @parameterized.named_parameters(
+        {
+            "testcase_name": "training_independent_bound_method",
+            "flax_model_class": "FlaxTrainingIndependentModel",
+            "flax_model_method": "forward",
+            "init_kwargs": {},
+            "trainable_weights": 8,
+            "trainable_params": 648226,
+            "non_trainable_weights": 0,
+            "non_trainable_params": 0,
+        },
+        {
+            "testcase_name": "training_rng_unbound_method",
+            "flax_model_class": "FlaxDropoutModel",
+            "flax_model_method": None,
+            "init_kwargs": {
+                "method": "flax_dropout_wrapper",
+            },
+            "trainable_weights": 8,
+            "trainable_params": 648226,
+            "non_trainable_weights": 0,
+            "non_trainable_params": 0,
+        },
+        {
+            "testcase_name": "training_rng_state_no_method",
+            "flax_model_class": "FlaxBatchNormModel",
+            "flax_model_method": None,
+            "init_kwargs": {},
+            "trainable_weights": 13,
+            "trainable_params": 354258,
+            "non_trainable_weights": 8,
+            "non_trainable_params": 536,
+        },
+        {
+            "testcase_name": "training_rng_unbound_method_dtype_policy",
+            "flax_model_class": "FlaxDropoutModel",
+            "flax_model_method": None,
+            "init_kwargs": {
+                "method": "flax_dropout_wrapper",
+                "dtype": DTypePolicy("mixed_float16"),
+            },
+            "trainable_weights": 8,
+            "trainable_params": 648226,
+            "non_trainable_weights": 0,
+            "non_trainable_params": 0,
+        },
+    )
+    @pytest.mark.skipif(flax is None, reason="Flax library is not available.")
+    def test_flax_layer(
+        self,
+        flax_model_class,
+        flax_model_method,
+        init_kwargs,
+        trainable_weights,
+        trainable_params,
+        non_trainable_weights,
+        non_trainable_params,
+    ):
+        flax_model_class = FLAX_OBJECTS.get(flax_model_class)
+        if "method" in init_kwargs:
+            init_kwargs["method"] = FLAX_OBJECTS.get(init_kwargs["method"])
+
+        def create_wrapper(**kwargs):
+            params = kwargs.pop("params") if "params" in kwargs else None
+            state = kwargs.pop("state") if "state" in kwargs else None
+            if params and state:
+                variables = {**params, **state}
+            elif params:
+                variables = params
+            elif state:
+                variables = state
+            else:
+                variables = None
+            kwargs["variables"] = variables
+            flax_model = flax_model_class()
+            if flax_model_method:
+                kwargs["method"] = getattr(flax_model, flax_model_method)
+            return FlaxLayer(flax_model_class(), **kwargs)
+
+>       self._test_layer(
+            flax_model_class.__name__,
+            create_wrapper,
+            init_kwargs,
+            trainable_weights,
+            trainable_params,
+            non_trainable_weights,
+            non_trainable_params,
+        )
+
+keras/src/utils/jax_layer_test.py:488: 
+_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
+keras/src/utils/jax_layer_test.py:231: in _test_layer
+    outputs1 = layer1(inputs1)
+    ^^^^^^^^^^^^^^^
+keras/src/utils/traceback_utils.py:113: in error_handler
+    return fn(*args, **kwargs)
+           ^^^^^^^^^^^^^^^^^^^
+keras/src/layers/layer.py:866: in __call__
+    self._maybe_build(call_spec)
+keras/src/layers/layer.py:1477: in _maybe_build
+    self.build(**shapes_dict)
+keras/src/layers/layer.py:231: in build_wrapper
+    original_build_method(*args, **kwargs)
+keras/src/utils/jax_layer.py:510: in build
+    self._initialize_weights(input_shape)
+keras/src/utils/jax_layer.py:497: in _initialize_weights
+    init_result = self.init_fn(*init_args)
+                  ^^^^^^^^^^^^^^^^^^^^^^^^
+_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
+
+rng = {'dropout': None, 'params': None}
+inputs = Array([[[[1.],
+          [1.],
+          [1.],
+          [1.],
+          [1.],
+          [1.],
+          [1.],
+          [1.]...
+          [1.],
+          [1.],
+          [1.],
+          [1.],
+          [1.],
+          [1.],
+          [1.]]]], dtype=float32)
+
+    def init_without_training(rng, inputs):
+        return self._variables_to_params_and_state(
+>           self.module.init(
+                rng,
+                inputs,
+                method=self.method,
+            )
+        )
+E       ValueError: First argument passed to an init function should be a ``jax.PRNGKey`` or a dictionary mapping strings to ``jax.PRNGKey``.
+E       --------------------
+E       For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
+
+keras/src/utils/jax_layer.py:755: ValueError
+=========================== short test summary info ============================
+FAILED keras/src/utils/jax_layer_test.py::TestJaxLayer::test_flax_layer_training_independent_bound_method - ValueError: First argument passed to an init function should be a ``jax.PRNGKey`` or a dictionary mapping strings to ``jax.PRNGKey``.
+--------------------
+For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
+============================== 1 failed in 1.72s ===============================
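For reference, a standalone illustration (not from the PR; assumes `jax` is installed) of what Flax's `module.init` expects as its first argument: a JAX PRNG key, i.e. a uint32 array of shape `(2,)` in the legacy format, or a dict mapping RNG stream names to such keys. The failure in the log arises because the dict reaching `module.init` contains `None` values (`{'dropout': None, 'params': None}`) instead of keys:

```python
import jax
import jax.numpy as jnp

# A legacy-format PRNG key is just a uint32 array of size 2.
key = jax.random.PRNGKey(0)
assert key.shape == (2,) and key.dtype == jnp.uint32

# A plain uint32 array of size 2 (e.g. rebuilt from a bitcast TF seed, as
# `_convert_to_jax_key` produces) works anywhere a legacy key is accepted.
rngs = {"params": jnp.asarray(key), "dropout": jax.random.split(key)[0]}
print(jax.random.normal(rngs["dropout"], (3,)))  # draws without error
```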