Skip to content

Commit 355b6eb

Browse files
olegshaldybinOrbax Authors
authored andcommitted
Internal change
PiperOrigin-RevId: 844284634
1 parent d53e964 commit 355b6eb

File tree

6 files changed

+58
-48
lines changed

6 files changed

+58
-48
lines changed

export/orbax/export/data_processors/tf_data_processor_test.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,9 @@ def test_prepare_succeeds(self):
9090
)
9191
self.assertEqual(
9292
processor.output_signature,
93-
obm.ShloTensorSpec(shape=(None, 3), dtype=obm.ShloDType.f64),
93+
obm.ShloTensorSpec(
94+
shape=(None, 3), dtype=obm.ShloDType.f64, name='output_0'
95+
),
9496
)
9597

9698
def test_prepare_polymorphic_function_with_default_input_signature(self):
@@ -127,7 +129,9 @@ def preprocessor_callable(x, y):
127129
)
128130
self.assertEqual(
129131
processor.output_signature,
130-
obm.ShloTensorSpec(shape=(None, 4), dtype=obm.ShloDType.f32),
132+
obm.ShloTensorSpec(
133+
shape=(None, 4), dtype=obm.ShloDType.f32, name='output_0'
134+
),
131135
)
132136

133137
def test_suppress_x64_output(self):
@@ -146,7 +150,9 @@ def test_suppress_x64_output(self):
146150
processor.prepare(input_signature, suppress_x64_output=True)
147151
self.assertEqual(
148152
processor.output_signature,
149-
obm.ShloTensorSpec(shape=(None, 3), dtype=obm.ShloDType.f32),
153+
obm.ShloTensorSpec(
154+
shape=(None, 3), dtype=obm.ShloDType.f32, name='output_0'
155+
),
150156
)
151157

152158
def test_convert_to_bfloat16(self):
@@ -167,7 +173,9 @@ def func(x):
167173
)
168174
self.assertEqual(
169175
processor.output_signature,
170-
obm.ShloTensorSpec(shape=(2, 3), dtype=obm.ShloDType.bf16),
176+
obm.ShloTensorSpec(
177+
shape=(2, 3), dtype=obm.ShloDType.bf16, name='output_0'
178+
),
171179
)
172180
self.assertLen(processor.concrete_function.variables, 1)
173181
self.assertEqual(

export/orbax/export/oex_orchestration_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@ def tf_fn(a):
2727
return a
2828

2929

30+
def tf_t(shape, name=None, dtype=tf.float32):
31+
return tf.TensorSpec(shape=shape, dtype=dtype, name=name)
32+
33+
3034
class TestProcessor(data_processor_base.DataProcessor):
3135

3236
def prepare(self, input_signature):

export/orbax/export/testdata/expected_mnist_oex_orchestration_pipelines.textproto

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,20 @@ name_to_pipeline {
6666
}
6767
}
6868
}
69+
outputs {
70+
named_tensor_types {
71+
name: "__OUTPUT_NAME__"
72+
tensor_type {
73+
shape {
74+
shape_with_known_rank {
75+
dimension_sizes {
76+
}
77+
}
78+
}
79+
dtype: f32
80+
}
81+
}
82+
}
6983
}
7084
model_functions {
7185
model_function_name: "__MODEL_FUNCTION_NAME__"

model/orbax/experimental/model/tf2obm/tf_concrete_functions_to_obm.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,8 @@ def tf_concrete_function_name_to_obm_function(
8888
'Both `fn` and `output_signature` are provided. Please provide only '
8989
'one of them.'
9090
)
91-
input_signature = utils.get_input_signature(fn)
92-
output_signature = utils.get_output_signature(fn)
91+
input_signature = fn.structured_input_signature
92+
output_signature = get_output_signature(fn)
9393

9494
input_names, _, _ = _flat_input_signature(fn)
9595
output_names = _output_names(fn)
@@ -258,7 +258,7 @@ def _flat_input_signature(
258258
fn: tf.types.experimental.ConcreteFunction,
259259
) -> SignatureFlat:
260260
"""Returns the flattened input signature of the given function."""
261-
leaves, tree_def = jax_tree_util.tree_flatten(utils.get_input_signature(fn))
261+
leaves, tree_def = jax_tree_util.tree_flatten(fn.structured_input_signature)
262262
# The argument names in SavedModel's SignatureDef may not match the names in
263263
# the input signature due to internal name mangling, hence we're looking
264264
# it up in the FunctionDef.
@@ -304,14 +304,35 @@ def _output_names(
304304
) -> Sequence[str]:
305305
"""Returns the flattened output signature of the given function."""
306306
leaves_with_path = jax_tree_util.tree_leaves_with_path(
307-
utils.get_output_signature(fn)
307+
fn.structured_outputs
308308
)
309309
if not leaves_with_path:
310310
return []
311311
paths, _ = zip(*leaves_with_path)
312312
return [_output_name(path) for path in paths]
313313

314314

315+
def get_output_signature(
316+
fn: tf.types.experimental.ConcreteFunction,
317+
) -> utils.TfSignature:
318+
"""Returns the output signature of the TF function.
319+
320+
Tensor names in the output signature match the output names of the TF function
321+
in the TF SavedModel.
322+
323+
Args:
324+
fn: A concrete TF function.
325+
"""
326+
output_names_iter = iter(list(_output_names(fn)))
327+
328+
return jax_tree_util.tree_map(
329+
lambda t: tf.TensorSpec(
330+
shape=t.shape, dtype=t.dtype, name=next(output_names_iter)
331+
),
332+
fn.structured_outputs,
333+
)
334+
335+
315336
def to_keyword_only_fn(
316337
f: tf.types.experimental.ConcreteFunction,
317338
) -> tf.types.experimental.ConcreteFunction:

model/orbax/experimental/model/tf2obm/tf_concrete_functions_to_obm_test.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from orbax.experimental.model import core as obm
2424
from orbax.experimental.model.tf2obm import tf_concrete_function_handle_pb2
2525
from orbax.experimental.model.tf2obm import tf_concrete_functions_to_obm as tf_obm
26-
from orbax.experimental.model.tf2obm import utils
2726
import tensorflow as tf
2827

2928
from tensorflow.python.util.protobuf import compare
@@ -62,7 +61,7 @@ def f():
6261
lambda spec: tf.zeros(shape=spec.shape, dtype=spec.dtype), tree
6362
)
6463

65-
return utils.get_output_signature(f.get_concrete_function())
64+
return tf_obm.get_output_signature(f.get_concrete_function())
6665

6766

6867
_INPUT_SIGNATURES = (
@@ -225,12 +224,12 @@ def is_spec_equiv(a, b):
225224
return True
226225

227226
self.assertTreeEquiv(
228-
tf_obm.utils.get_input_signature(new_cf),
227+
new_cf.structured_input_signature,
229228
new_input_sig,
230229
is_spec_equiv,
231230
)
232231
self.assertTreeEquiv(
233-
tf_obm.utils.get_output_signature(new_cf),
232+
tf_obm.get_output_signature(new_cf),
234233
new_output_sig,
235234
is_spec_equiv,
236235
)

model/orbax/experimental/model/tf2obm/utils.py

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -77,39 +77,3 @@ def tf_signature_to_obm_spec(tree: TfSignature) -> obm.Tree[obm.ShloTensorSpec]:
7777
f'Failed to convert TF signature {tree} of type {type(tree)} to OBM.'
7878
) from err
7979

80-
81-
def get_input_signature(
82-
concrete_function: tf.types.experimental.ConcreteFunction,
83-
) -> TfSignature:
84-
return concrete_function.structured_input_signature
85-
86-
87-
def get_output_signature(
88-
concrete_function: tf.types.experimental.ConcreteFunction,
89-
) -> TfSignature:
90-
"""Gets the output signature from a concrete function.
91-
92-
Args:
93-
concrete_function: The concrete function to get the output signature from.
94-
95-
Returns:
96-
The output signature as a PyTree of `tf.TensorSpec`s.
97-
98-
Raises:
99-
ValueError: If the structured_outputs cannot be converted to
100-
`tf.TensorSpec`.
101-
"""
102-
try:
103-
# The structured_outputs are `SymbolicTensor`s with "name" that we don't
104-
# need. To make a unified path to obm.ShloTensorSpec, we convert them to
105-
# `TensorSpec`s (without name) first.
106-
output_signature = jax_tree_util.tree_map(
107-
lambda x: tf.TensorSpec(shape=x.shape, dtype=x.dtype),
108-
concrete_function.structured_outputs,
109-
)
110-
except Exception as err:
111-
raise ValueError(
112-
'Failed to convert TF structured_outputs'
113-
f' {concrete_function.structured_outputs} to tf.TensorSpec.'
114-
) from err
115-
return output_signature

0 commit comments

Comments
 (0)