keras/src/utils/jax_layer.py (223 changes: 187 additions & 36 deletions)
@@ -1,4 +1,7 @@
import functools
import inspect
import itertools
import string

import numpy as np

@@ -12,6 +15,22 @@
from keras.src.utils import jax_utils
from keras.src.utils import tracking
from keras.src.utils.module_utils import jax
from keras.src.utils.module_utils import tensorflow as tf

if backend.backend() == "tensorflow":
tf_no_automatic_dependency_tracking = (
tf.__internal__.tracking.no_automatic_dependency_tracking
)
else:

def tf_no_automatic_dependency_tracking(fn):
return fn


def _convert_to_jax_key(tensor):
if backend.backend() == "tensorflow":
return tf.bitcast(tensor, tf.uint32)[0]
return tensor


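As a quick illustration of `_convert_to_jax_key` on the TensorFlow backend: `tf.bitcast` reinterprets the seed tensor's bytes as `uint32`, which is the two-word layout of a raw JAX PRNG key. A minimal sketch, assuming the seed arrives as an int64 tensor of shape `(2,)` (the exact shape and dtype of `self.seed_generator.next()` are backend-internal):

```
import tensorflow as tf
import jax

# Hypothetical seed tensor; the real one comes from self.seed_generator.next().
seed = tf.constant([42, 0], dtype=tf.int64)
key = tf.bitcast(seed, tf.uint32)  # each int64 splits into two uint32 words
print(key.shape)                   # (2, 2)
print(key[0].shape, key[0].dtype)  # (2,) uint32, what _convert_to_jax_key returns
# A raw JAX PRNG key has the same layout:
print(jax.random.PRNGKey(42).shape)  # (2,)
```
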
@keras_export("keras.layers.JaxLayer")
@@ -219,10 +238,10 @@ def __init__(
seed=None,
**kwargs,
):
-        if backend.backend() != "jax":
+        if backend.backend() not in ["jax", "tensorflow"]:
            raise ValueError(
-                "JaxLayer is only supported with the JAX backend. Current "
-                f"backend: {backend.backend()}"
+                f"{self.__class__.__name__} is only supported with the JAX or"
+                f" TensorFlow backend. Current backend: {backend.backend()}"
)

if init_fn is None and params is None and state is None:
@@ -252,6 +271,10 @@ def __init__(
init_fn, "init_fn", {"rng", "inputs", "training"}, {"inputs"}
)

# Attributes for jax2tf functions
self.jax2tf_training_false_fn = None
self.jax2tf_training_true_fn = None

def _validate_signature(self, fn, fn_name, allowed, required):
fn_parameters = inspect.signature(fn).parameters
for parameter_name in required:
@@ -272,7 +295,81 @@ def _validate_signature(self, fn, fn_name, allowed, required):

return parameter_names

def _get_jax2tf_input_shape(self, input_shape):
"""Convert input shape in a format suitable for `jax2tf`.

        `jax2tf` expects a letter for each unknown dimension, which allows
        correlated dimensions. Since Keras does not support correlated
        dimensions, we simply use 'a', 'b', 'c'... for each unknown dimension.
        However, we use 'batch' for dimension 0 when it is undefined, so that
        the batch size is correlated across inputs.

Example (spaces added for readability):
```
input_shape: (None , 4 , None, None, 5 )
result: "(batch, 4 , a , b , 5 )"
```

Args:
input_shape: a single shape or a structure of shapes for the inputs.
Returns:
the shape or shapes structure in the `jax2tf` format as strings.
"""
dim_names = itertools.chain(
string.ascii_lowercase, # a, b, ... z
itertools.starmap( # aa, ab, ... az, ba, bb, ... zz
lambda a, b: a + b,
itertools.product(string.ascii_lowercase, repeat=2),
),
)

def get_single_jax2tf_shape(shape):
jax2tf_shape = []

for index, dim in enumerate(shape):
if dim is not None:
jax2tf_shape.append(str(dim))
elif index == 0:
jax2tf_shape.append("batch")
else:
jax2tf_shape.append(next(dim_names))

return "(" + ", ".join(jax2tf_shape) + ")"

res = tree.map_shape_structure(get_single_jax2tf_shape, input_shape)
return res

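A hypothetical usage of this helper (it is private, so this is for illustration only; the trivial `call_fn` and empty `params` exist just to construct the layer):

```
import keras

layer = keras.layers.JaxLayer(call_fn=lambda params, inputs: inputs, params={})
print(layer._get_jax2tf_input_shape([(None, 4), (None, None, 5)]))
# ["(batch, 4)", "(batch, a, 5)"] -- dimension 0 is shared as "batch",
# while each other unknown dimension gets a fresh letter.
```
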
def _jax2tf_convert(self, fn, polymorphic_shapes):
from jax.experimental import jax2tf

converted_fn = jax2tf.convert(fn, polymorphic_shapes=polymorphic_shapes)
# Autograph won't work with the output of jax2tf.
converted_fn = tf.autograph.experimental.do_not_convert(converted_fn)
return converted_fn

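A self-contained sketch of the conversion this performs; the function and shapes are made up:

```
import tensorflow as tf
from jax.experimental import jax2tf

def double(x):
    return x * 2.0

# "(batch, 4)" keeps dimension 0 symbolic, matching the strings produced by
# `_get_jax2tf_input_shape`; one entry is given per positional argument.
tf_fn = jax2tf.convert(double, polymorphic_shapes=["(batch, 4)"])
tf_fn = tf.autograph.experimental.do_not_convert(tf_fn)
print(tf_fn(tf.ones((8, 4))).shape)  # (8, 4) -- any batch size works
```
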
def _partial_with_positional(self, fn, index, value):
"""Return a new partial with one positional argument set to a value.

        This is needed because `jax2tf` only supports positional arguments and
        `functools.partial` only supports setting positional arguments starting
        from the left. Our use case is the `training` argument, which is
        typically the rightmost argument.

Args:
fn: the function to wrap.
index: the index of the positional argument to set to `value`.
value: the value for the positional argument at `index`.
"""

@functools.wraps(fn)
def wrapper(*args):
args = args[0:index] + (value,) + args[index:]
return fn(*args)

return wrapper

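To see why `functools.partial` alone is not enough, here is the same helper used standalone on a made-up `call_fn`:

```
import functools

def _partial_with_positional(fn, index, value):
    @functools.wraps(fn)
    def wrapper(*args):
        args = args[0:index] + (value,) + args[index:]
        return fn(*args)
    return wrapper

def call_fn(params, inputs, training):
    return inputs * 2 if training else inputs

# functools.partial(call_fn, False) would pin `params`, the leftmost argument.
fn = _partial_with_positional(call_fn, 2, False)
assert fn({"w": 1}, 3) == 3  # same as call_fn({"w": 1}, 3, False)
```
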
@tracking.no_automatic_dependency_tracking
@tf_no_automatic_dependency_tracking
def _create_variables(self, values, trainable):
"""Create a structure of variables from a structure of JAX arrays.

@@ -296,14 +393,14 @@ def _create_variables(self, values, trainable):

def create_variable(value):
if backend.is_tensor(value) or isinstance(
-                value, (np.ndarray, np.generic)
+                value, (np.ndarray, np.generic, jax.Array)
):
dtype = value.dtype
if is_float_dtype(dtype):
dtype = None # Use the layer dtype policy
return self.add_weight(
value.shape,
-                initializer=value,
+                initializer=backend.convert_to_tensor(value),
dtype=dtype,
trainable=trainable,
)
@@ -333,44 +430,46 @@ def create_variable(value):

def _get_init_rng(self):
"""
-        Returns a JAX `PRNGKey` or structure of `PRNGKey`s to pass to `init_fn`.
+        Returns a key, in the form of a backend array of size 2 and dtype
+        uint32, to pass to `init_fn`.

-        By default, this returns a single `PRNGKey` retrieved by calling
+        By default, this returns a JAX or TF array of size 2 by calling
        `self.seed_generator.next()`. Override this to return a different
        structure.

        Returns:
-            a JAX `PRNGKey` or structure of `PRNGKey`s that will be passed as
-            the `rng` argument of `init_fn`.
+            a key, as a JAX or TF array of size 2 and dtype uint32, that will
+            be passed as the `rng` argument of `init_fn`.
"""
return self.seed_generator.next()

def _get_call_rng(self, training):
"""
-        Returns a JAX `PRNGKey` or structure of `PRNGKey`s to pass to `call_fn`.
+        Returns a key, in the form of a backend array of size 2 and dtype
+        uint32, to pass to `call_fn`.

-        By default, this returns a single `PRNGKey` retrieved by calling
+        By default, this returns a JAX or TF array of size 2 by calling
        `self.seed_generator.next()` when `training` is `True`, and `None` when
        `training` is `False`. Override this to return a different structure or
        to pass RNGs in inference mode too.

        Returns:
-            a JAX `PRNGKey` or structure of `PRNGKey`s that will be passed as
-            the `rng` argument of `call_fn`.
+            a key, as a JAX or TF array of size 2 and dtype uint32, that will
+            be passed as the `rng` argument of `call_fn`.
"""
if training:
return self.seed_generator.next()
else:
return None

-    def build(self, input_shape):
-        if self.params is not None or self.state is not None:
-            return
-
-        if jax_utils.is_in_jax_tracing_scope():
+    def _initialize_weights(self, input_shape):
+        if jax_utils.is_in_jax_tracing_scope() or tf.inside_function():
            # This exception is not actually shown, it is caught and a detailed
            # warning about calling 'build' is printed.
-            raise ValueError("'JaxLayer' cannot be built in tracing scope")
+            raise ValueError(
+                "'JaxLayer' cannot be built in a tracing scope "
+                "or inside a tf.function"
+            )

# Initialize `params` and `state` if needed by calling `init_fn`.
def create_input(shape):
@@ -381,7 +480,12 @@ def create_input(shape):
init_args = []
for argument_name in self.init_fn_arguments:
if argument_name == "rng":
-                init_args.append(self._get_init_rng())
+                init_args.append(
+                    jax.tree_util.tree_map(
+                        lambda x: jax.numpy.array(_convert_to_jax_key(x)),
+                        self._get_init_rng(),
+                    )
+                )
elif argument_name == "inputs":
init_args.append(init_inputs)
elif argument_name == "training":
@@ -398,6 +502,45 @@ def create_input(shape):
)
self.tracked_state = self._create_variables(init_state, trainable=False)

def build(self, input_shape):
if self.params is None and self.state is None:
self._initialize_weights(input_shape)

if backend.backend() == "tensorflow":
polymorphic_shapes = []
for argument in self.call_fn_arguments:
if argument == "inputs":
polymorphic_shapes.append(
self._get_jax2tf_input_shape(input_shape)
)
elif argument != "training":
# params, state, rng
polymorphic_shapes.append("...")

if "training" in self.call_fn_arguments:
training_argument_index = self.call_fn_arguments.index(
"training"
)
self.jax2tf_training_false_fn = self._jax2tf_convert(
self._partial_with_positional(
self.call_fn, training_argument_index, False
),
polymorphic_shapes,
)
self.jax2tf_training_true_fn = self._jax2tf_convert(
self._partial_with_positional(
self.call_fn, training_argument_index, True
),
polymorphic_shapes,
)
else:
self.jax2tf_training_false_fn = self._jax2tf_convert(
self.call_fn,
polymorphic_shapes,
)
self.jax2tf_training_true_fn = None
super().build(input_shape)

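The reason `build` prepares two converted functions rather than passing `training` through at runtime: `call_fn` typically branches on `training` at the Python level, so the value must be concrete when `jax2tf` traces the function. A made-up sketch of the pattern:

```
import tensorflow as tf
from jax.experimental import jax2tf

def call_fn(x, training):
    # Python-level branch: `training` must be a concrete bool at trace time.
    return x * 2.0 if training else x

# One conversion per value of `training`, mirroring the logic above:
fn_false = jax2tf.convert(lambda x: call_fn(x, False), polymorphic_shapes=["(batch, 2)"])
fn_true = jax2tf.convert(lambda x: call_fn(x, True), polymorphic_shapes=["(batch, 2)"])
print(fn_false(tf.ones((1, 2)))[0, 0].numpy())  # 1.0
print(fn_true(tf.ones((1, 2)))[0, 0].numpy())   # 2.0
```
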
def call(self, inputs, training=False):
def unwrap_variable(variable):
return None if variable is None else variable.value
@@ -413,11 +556,16 @@ def unwrap_variable(variable):
jax.tree_util.tree_map(unwrap_variable, self.state)
)
elif argument_name == "rng":
-                call_args.append(self._get_call_rng(training))
+                call_args.append(
+                    jax.tree_util.tree_map(
+                        _convert_to_jax_key, self._get_call_rng(training)
+                    )
+                )
elif argument_name == "inputs":
call_args.append(inputs)
elif argument_name == "training":
-                call_args.append(training)
+                if backend.backend() == "jax":
+                    call_args.append(training)

def assign_state_to_variable(value, variable):
# This exists only to make debugging this error case easier.
@@ -429,14 +577,23 @@ def assign_state_to_variable(value, variable):
)
variable.assign(value)

-        if self.has_state:
-            predictions, new_state = self.call_fn(*call_args)
-            jax.tree_util.tree_map(
-                assign_state_to_variable, new_state, self.state
-            )
-            return predictions
-        else:
-            return self.call_fn(*call_args)
+        def call_with_fn(fn):
+            if self.has_state:
+                predictions, new_state = fn(*call_args)
+                jax.tree_util.tree_map(
+                    assign_state_to_variable, new_state, self.state
+                )
+                return predictions
+            else:
+                return fn(*call_args)
+
+        if backend.backend() == "jax":
+            return call_with_fn(self.call_fn)
+        elif backend.backend() == "tensorflow":
+            if training and self.jax2tf_training_true_fn is not None:
+                return call_with_fn(self.jax2tf_training_true_fn)
+            else:
+                return call_with_fn(self.jax2tf_training_false_fn)

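Putting it together, a sketch of what this change enables: a JAX-defined layer running under the TensorFlow backend. All names and shapes here are made up; only the public `JaxLayer` arguments are taken from the existing API:

```
import os
os.environ["KERAS_BACKEND"] = "tensorflow"

import jax.numpy as jnp
import keras

def my_call(params, inputs):
    return jnp.matmul(inputs, params["w"])

layer = keras.layers.JaxLayer(call_fn=my_call, params={"w": jnp.ones((4, 2))})
out = layer(keras.ops.ones((3, 4)))  # dispatched to the jax2tf-converted fn
print(out.shape)  # (3, 2)
```
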
def get_config(self):
config = {
@@ -556,12 +713,6 @@ def __init__(
# Late import to only require Flax when this is used.
from flax.core import scope as flax_scope

-        if backend.backend() != "jax":
-            raise ValueError(
-                "FlaxLayer is only supported with the JAX backend. Current "
-                f"backend: {backend.backend()}"
-            )

self.module = module
self.method = method
