Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions export/orbax/export/data_processors/tf_data_processor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,9 @@ def test_prepare_succeeds(self):
)
self.assertEqual(
processor.output_signature,
obm.ShloTensorSpec(shape=(None, 3), dtype=obm.ShloDType.f64),
obm.ShloTensorSpec(
shape=(None, 3), dtype=obm.ShloDType.f64, name='output_0'
),
)

def test_prepare_polymorphic_function_with_default_input_signature(self):
Expand Down Expand Up @@ -127,7 +129,9 @@ def preprocessor_callable(x, y):
)
self.assertEqual(
processor.output_signature,
obm.ShloTensorSpec(shape=(None, 4), dtype=obm.ShloDType.f32),
obm.ShloTensorSpec(
shape=(None, 4), dtype=obm.ShloDType.f32, name='output_0'
),
)

def test_suppress_x64_output(self):
Expand All @@ -146,7 +150,9 @@ def test_suppress_x64_output(self):
processor.prepare(input_signature, suppress_x64_output=True)
self.assertEqual(
processor.output_signature,
obm.ShloTensorSpec(shape=(None, 3), dtype=obm.ShloDType.f32),
obm.ShloTensorSpec(
shape=(None, 3), dtype=obm.ShloDType.f32, name='output_0'
),
)

def test_convert_to_bfloat16(self):
Expand All @@ -167,7 +173,9 @@ def func(x):
)
self.assertEqual(
processor.output_signature,
obm.ShloTensorSpec(shape=(2, 3), dtype=obm.ShloDType.bf16),
obm.ShloTensorSpec(
shape=(2, 3), dtype=obm.ShloDType.bf16, name='output_0'
),
)
self.assertLen(processor.concrete_function.variables, 1)
self.assertEqual(
Expand Down
4 changes: 4 additions & 0 deletions export/orbax/export/oex_orchestration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ def tf_fn(a):
return a


def tf_t(shape, name=None, dtype=tf.float32):
return tf.TensorSpec(shape=shape, dtype=dtype, name=name)


class TestProcessor(data_processor_base.DataProcessor):

def prepare(self, input_signature):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,20 @@ name_to_pipeline {
}
}
}
outputs {
named_tensor_types {
name: "__OUTPUT_NAME__"
tensor_type {
shape {
shape_with_known_rank {
dimension_sizes {
}
}
}
dtype: f32
}
}
}
}
model_functions {
model_function_name: "__MODEL_FUNCTION_NAME__"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ def tf_concrete_function_name_to_obm_function(
'Both `fn` and `output_signature` are provided. Please provide only '
'one of them.'
)
input_signature = utils.get_input_signature(fn)
output_signature = utils.get_output_signature(fn)
input_signature = fn.structured_input_signature
output_signature = get_output_signature(fn)

input_names, _, _ = _flat_input_signature(fn)
output_names = _output_names(fn)
Expand Down Expand Up @@ -258,7 +258,7 @@ def _flat_input_signature(
fn: tf.types.experimental.ConcreteFunction,
) -> SignatureFlat:
"""Returns the flattened input signature of the given function."""
leaves, tree_def = jax_tree_util.tree_flatten(utils.get_input_signature(fn))
leaves, tree_def = jax_tree_util.tree_flatten(fn.structured_input_signature)
# The argument names in SavedModel's SignatureDef may not match the names in
# the input signature due to internal name mangling, hence we're looking
# it up in the FunctionDef.
Expand Down Expand Up @@ -304,14 +304,35 @@ def _output_names(
) -> Sequence[str]:
"""Returns the flattened output signature of the given function."""
leaves_with_path = jax_tree_util.tree_leaves_with_path(
utils.get_output_signature(fn)
fn.structured_outputs
)
if not leaves_with_path:
return []
paths, _ = zip(*leaves_with_path)
return [_output_name(path) for path in paths]


def get_output_signature(
fn: tf.types.experimental.ConcreteFunction,
) -> utils.TfSignature:
"""Returns the output signature of the TF function.

Tensor names in the output signature match the output names of the TF function
in the TF SavedModel.

Args:
fn: A concrete TF function.
"""
output_names_iter = iter(list(_output_names(fn)))

return jax_tree_util.tree_map(
lambda t: tf.TensorSpec(
shape=t.shape, dtype=t.dtype, name=next(output_names_iter)
),
fn.structured_outputs,
)


def to_keyword_only_fn(
f: tf.types.experimental.ConcreteFunction,
) -> tf.types.experimental.ConcreteFunction:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from orbax.experimental.model import core as obm
from orbax.experimental.model.tf2obm import tf_concrete_function_handle_pb2
from orbax.experimental.model.tf2obm import tf_concrete_functions_to_obm as tf_obm
from orbax.experimental.model.tf2obm import utils
import tensorflow as tf

from tensorflow.python.util.protobuf import compare
Expand Down Expand Up @@ -62,7 +61,7 @@ def f():
lambda spec: tf.zeros(shape=spec.shape, dtype=spec.dtype), tree
)

return utils.get_output_signature(f.get_concrete_function())
return tf_obm.get_output_signature(f.get_concrete_function())


_INPUT_SIGNATURES = (
Expand Down Expand Up @@ -225,12 +224,12 @@ def is_spec_equiv(a, b):
return True

self.assertTreeEquiv(
tf_obm.utils.get_input_signature(new_cf),
new_cf.structured_input_signature,
new_input_sig,
is_spec_equiv,
)
self.assertTreeEquiv(
tf_obm.utils.get_output_signature(new_cf),
tf_obm.get_output_signature(new_cf),
new_output_sig,
is_spec_equiv,
)
Expand Down
36 changes: 0 additions & 36 deletions model/orbax/experimental/model/tf2obm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,39 +77,3 @@ def tf_signature_to_obm_spec(tree: TfSignature) -> obm.Tree[obm.ShloTensorSpec]:
f'Failed to convert TF signature {tree} of type {type(tree)} to OBM.'
) from err


def get_input_signature(
concrete_function: tf.types.experimental.ConcreteFunction,
) -> TfSignature:
return concrete_function.structured_input_signature


def get_output_signature(
concrete_function: tf.types.experimental.ConcreteFunction,
) -> TfSignature:
"""Gets the output signature from a concrete function.

Args:
concrete_function: The concrete function to get the output signature from.

Returns:
The output signature as a PyTree of `tf.TensorSpec`s.

Raises:
ValueError: If the structured_outputs cannot be converted to
`tf.TensorSpec`.
"""
try:
# The structured_outputs are `SymbolicTensor`s with "name" that we don't
# need. To make a unified path to obm.ShloTensorSpec, we convert them to
# `TensorSpec`s (without name) first.
output_signature = jax_tree_util.tree_map(
lambda x: tf.TensorSpec(shape=x.shape, dtype=x.dtype),
concrete_function.structured_outputs,
)
except Exception as err:
raise ValueError(
'Failed to convert TF structured_outputs'
f' {concrete_function.structured_outputs} to tf.TensorSpec.'
) from err
return output_signature
Loading