diff --git a/export/orbax/export/data_processors/tf_data_processor_test.py b/export/orbax/export/data_processors/tf_data_processor_test.py index dd35d9b55..ae3d62ff9 100644 --- a/export/orbax/export/data_processors/tf_data_processor_test.py +++ b/export/orbax/export/data_processors/tf_data_processor_test.py @@ -90,7 +90,9 @@ def test_prepare_succeeds(self): ) self.assertEqual( processor.output_signature, - obm.ShloTensorSpec(shape=(None, 3), dtype=obm.ShloDType.f64), + obm.ShloTensorSpec( + shape=(None, 3), dtype=obm.ShloDType.f64, name='output_0' + ), ) def test_prepare_polymorphic_function_with_default_input_signature(self): @@ -127,7 +129,9 @@ def preprocessor_callable(x, y): ) self.assertEqual( processor.output_signature, - obm.ShloTensorSpec(shape=(None, 4), dtype=obm.ShloDType.f32), + obm.ShloTensorSpec( + shape=(None, 4), dtype=obm.ShloDType.f32, name='output_0' + ), ) def test_suppress_x64_output(self): @@ -146,7 +150,9 @@ def test_suppress_x64_output(self): processor.prepare(input_signature, suppress_x64_output=True) self.assertEqual( processor.output_signature, - obm.ShloTensorSpec(shape=(None, 3), dtype=obm.ShloDType.f32), + obm.ShloTensorSpec( + shape=(None, 3), dtype=obm.ShloDType.f32, name='output_0' + ), ) def test_convert_to_bfloat16(self): @@ -167,7 +173,9 @@ def func(x): ) self.assertEqual( processor.output_signature, - obm.ShloTensorSpec(shape=(2, 3), dtype=obm.ShloDType.bf16), + obm.ShloTensorSpec( + shape=(2, 3), dtype=obm.ShloDType.bf16, name='output_0' + ), ) self.assertLen(processor.concrete_function.variables, 1) self.assertEqual( diff --git a/export/orbax/export/oex_orchestration_test.py b/export/orbax/export/oex_orchestration_test.py index 8a6adf4c5..0901f49a9 100644 --- a/export/orbax/export/oex_orchestration_test.py +++ b/export/orbax/export/oex_orchestration_test.py @@ -27,6 +27,10 @@ def tf_fn(a): return a +def tf_t(shape, name=None, dtype=tf.float32): + return tf.TensorSpec(shape=shape, dtype=dtype, name=name) + + class TestProcessor(data_processor_base.DataProcessor): def prepare(self, input_signature): diff --git a/export/orbax/export/testdata/expected_mnist_oex_orchestration_pipelines.textproto b/export/orbax/export/testdata/expected_mnist_oex_orchestration_pipelines.textproto index bfeceaf9b..84ea11588 100644 --- a/export/orbax/export/testdata/expected_mnist_oex_orchestration_pipelines.textproto +++ b/export/orbax/export/testdata/expected_mnist_oex_orchestration_pipelines.textproto @@ -66,6 +66,20 @@ name_to_pipeline { } } } + outputs { + named_tensor_types { + name: "__OUTPUT_NAME__" + tensor_type { + shape { + shape_with_known_rank { + dimension_sizes { + } + } + } + dtype: f32 + } + } + } } model_functions { model_function_name: "__MODEL_FUNCTION_NAME__" diff --git a/model/orbax/experimental/model/tf2obm/tf_concrete_functions_to_obm.py b/model/orbax/experimental/model/tf2obm/tf_concrete_functions_to_obm.py index 58c9c2113..7dd9ed5bf 100644 --- a/model/orbax/experimental/model/tf2obm/tf_concrete_functions_to_obm.py +++ b/model/orbax/experimental/model/tf2obm/tf_concrete_functions_to_obm.py @@ -88,8 +88,8 @@ def tf_concrete_function_name_to_obm_function( 'Both `fn` and `output_signature` are provided. Please provide only ' 'one of them.' ) - input_signature = utils.get_input_signature(fn) - output_signature = utils.get_output_signature(fn) + input_signature = fn.structured_input_signature + output_signature = get_output_signature(fn) input_names, _, _ = _flat_input_signature(fn) output_names = _output_names(fn) @@ -258,7 +258,7 @@ def _flat_input_signature( fn: tf.types.experimental.ConcreteFunction, ) -> SignatureFlat: """Returns the flattened input signature of the given function.""" - leaves, tree_def = jax_tree_util.tree_flatten(utils.get_input_signature(fn)) + leaves, tree_def = jax_tree_util.tree_flatten(fn.structured_input_signature) # The argument names in SavedModel's SignatureDef may not match the names in # the input signature due to internal name mangling, hence we're looking # it up in the FunctionDef. @@ -304,7 +304,7 @@ def _output_names( ) -> Sequence[str]: """Returns the flattened output signature of the given function.""" leaves_with_path = jax_tree_util.tree_leaves_with_path( - utils.get_output_signature(fn) + fn.structured_outputs ) if not leaves_with_path: return [] @@ -312,6 +312,27 @@ def _output_names( return [_output_name(path) for path in paths] +def get_output_signature( + fn: tf.types.experimental.ConcreteFunction, +) -> utils.TfSignature: + """Returns the output signature of the TF function. + + Tensor names in the output signature match the output names of the TF function + in the TF SavedModel. + + Args: + fn: A concrete TF function. + """ + output_names_iter = iter(list(_output_names(fn))) + + return jax_tree_util.tree_map( + lambda t: tf.TensorSpec( + shape=t.shape, dtype=t.dtype, name=next(output_names_iter) + ), + fn.structured_outputs, + ) + + def to_keyword_only_fn( f: tf.types.experimental.ConcreteFunction, ) -> tf.types.experimental.ConcreteFunction: diff --git a/model/orbax/experimental/model/tf2obm/tf_concrete_functions_to_obm_test.py b/model/orbax/experimental/model/tf2obm/tf_concrete_functions_to_obm_test.py index c32e6bcea..1bd17b568 100644 --- a/model/orbax/experimental/model/tf2obm/tf_concrete_functions_to_obm_test.py +++ b/model/orbax/experimental/model/tf2obm/tf_concrete_functions_to_obm_test.py @@ -23,7 +23,6 @@ from orbax.experimental.model import core as obm from orbax.experimental.model.tf2obm import tf_concrete_function_handle_pb2 from orbax.experimental.model.tf2obm import tf_concrete_functions_to_obm as tf_obm -from orbax.experimental.model.tf2obm import utils import tensorflow as tf from tensorflow.python.util.protobuf import compare @@ -62,7 +61,7 @@ def f(): lambda spec: tf.zeros(shape=spec.shape, dtype=spec.dtype), tree ) - return utils.get_output_signature(f.get_concrete_function()) + return tf_obm.get_output_signature(f.get_concrete_function()) _INPUT_SIGNATURES = ( @@ -225,12 +224,12 @@ def is_spec_equiv(a, b): return True self.assertTreeEquiv( - tf_obm.utils.get_input_signature(new_cf), + new_cf.structured_input_signature, new_input_sig, is_spec_equiv, ) self.assertTreeEquiv( - tf_obm.utils.get_output_signature(new_cf), + tf_obm.get_output_signature(new_cf), new_output_sig, is_spec_equiv, ) diff --git a/model/orbax/experimental/model/tf2obm/utils.py b/model/orbax/experimental/model/tf2obm/utils.py index 4db7552e5..6c835327d 100644 --- a/model/orbax/experimental/model/tf2obm/utils.py +++ b/model/orbax/experimental/model/tf2obm/utils.py @@ -77,39 +77,3 @@ def tf_signature_to_obm_spec(tree: TfSignature) -> obm.Tree[obm.ShloTensorSpec]: f'Failed to convert TF signature {tree} of type {type(tree)} to OBM.' ) from err - -def get_input_signature( - concrete_function: tf.types.experimental.ConcreteFunction, -) -> TfSignature: - return concrete_function.structured_input_signature - - -def get_output_signature( - concrete_function: tf.types.experimental.ConcreteFunction, -) -> TfSignature: - """Gets the output signature from a concrete function. - - Args: - concrete_function: The concrete function to get the output signature from. - - Returns: - The output signature as a PyTree of `tf.TensorSpec`s. - - Raises: - ValueError: If the structured_outputs cannot be converted to - `tf.TensorSpec`. - """ - try: - # The structured_outputs are `SymbolicTensor`s with "name" that we don't - # need. To make a unified path to obm.ShloTensorSpec, we convert them to - # `TensorSpec`s (without name) first. - output_signature = jax_tree_util.tree_map( - lambda x: tf.TensorSpec(shape=x.shape, dtype=x.dtype), - concrete_function.structured_outputs, - ) - except Exception as err: - raise ValueError( - 'Failed to convert TF structured_outputs' - f' {concrete_function.structured_outputs} to tf.TensorSpec.' - ) from err - return output_signature