Skip to content

Commit 5ac3418

Browse files
hejiang0116Orbax Authors
authored andcommitted
Refactor ServingConfig to use ObmExportOptions dataclass.
PiperOrigin-RevId: 841912045
1 parent d53e964 commit 5ac3418

File tree

3 files changed

+54
-1
lines changed

3 files changed

+54
-1
lines changed

export/orbax/export/obm_configs.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@ class MixedPriorityBatchingPolicy(enum.Enum):
6969
PRIORITY_MERGE = "priority_merge"
7070

7171

72-
7372
# LINT.ThenChange(//depot//orbax/export/obm_export.py)
7473

7574

@@ -279,3 +278,15 @@ def __post_init__(self):
279278
self.low_priority_batch_options.max_enqueued_batches,
280279
is_low_priority_batch_options=True,
281280
)
281+
282+
283+
@dataclasses.dataclass(kw_only=True)
284+
class ObmExportOptions:
285+
"""Options for Orbax Model Export.
286+
287+
Attributes:
288+
batch_options: The batch options for the model.
289+
converter_options: The converter options for the model.
290+
"""
291+
292+
batch_options: BatchOptions | None = None

export/orbax/export/serving_config.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,12 @@
1717
from collections.abc import Callable, Mapping, Sequence
1818
import dataclasses
1919
from typing import Any, Optional, Text, Union
20+
import warnings
2021
from absl import logging
2122
import jax
2223
import jaxtyping
24+
from orbax.export import constants
25+
from orbax.export import obm_configs
2326
from orbax.export.data_processors import data_processor_base
2427
import tensorflow as tf
2528

@@ -63,6 +66,8 @@ class ServingConfig:
6366
# exactly one method which will be used.
6467
method_key: Optional[str] = None
6568
# Options passed to the Orbax Model export.
69+
obm_export_options: obm_configs.ObmExportOptions | None = None
70+
# DEPRECATED: use `obm_export_options` instead.
6671
obm_kwargs: Mapping[str, Any] = dataclasses.field(default_factory=dict)
6772

6873
# When set to true, it allows a portion of the preprocessor's outputs to be
@@ -103,6 +108,19 @@ class ServingConfig:
103108
preprocess_output_passthrough_enabled: bool = False
104109

105110
def __post_init__(self):
111+
if self.obm_kwargs:
112+
if self.obm_export_options is not None:
113+
raise ValueError(
114+
'Both `obm_kwargs` and `obm_export_options` are set. Please only'
115+
' use `obm_export_options`.'
116+
)
117+
warnings.warn(
118+
'`obm_kwargs` is deprecated, use `obm_export_options` instead.',
119+
DeprecationWarning,
120+
)
121+
self.obm_export_options = obm_configs.ObmExportOptions(
122+
batch_options=self.obm_kwargs.get(constants.BATCH_OPTIONS),
123+
)
106124
if not self.signature_key:
107125
raise ValueError('`signature_key` must be set.')
108126
if self.tf_preprocessor and self.preprocessors:

export/orbax/export/serving_config_test.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import jax.numpy as jnp
1616
import numpy as np
17+
from orbax.export import obm_configs
1718
from orbax.export import serving_config
1819
import tensorflow as tf
1920

@@ -23,6 +24,29 @@
2324

2425
class ServingConfigTest(tf.test.TestCase):
2526

27+
def test_obm_kwargs_deprecation(self):
28+
batch_opts = obm_configs.BatchOptions(
29+
batch_component=obm_configs.BatchComponent.NO_BATCHING
30+
)
31+
with self.assertWarnsRegex(
32+
DeprecationWarning, '`obm_kwargs` is deprecated'
33+
):
34+
sc = ServingConfig(
35+
signature_key='f', obm_kwargs={'batch_options': batch_opts}
36+
)
37+
self.assertIsNotNone(sc.obm_export_options)
38+
self.assertEqual(sc.obm_export_options.batch_options, batch_opts) # pytype: disable=attribute-error
39+
40+
def test_obm_kwargs_and_obm_export_options_set_raise_error(self):
41+
with self.assertRaisesRegex(
42+
ValueError, 'Both `obm_kwargs` and `obm_export_options` are set'
43+
):
44+
ServingConfig(
45+
signature_key='f',
46+
obm_kwargs={'batch_options': 1},
47+
obm_export_options=obm_configs.ObmExportOptions(),
48+
)
49+
2650
def test_bind_tf(self):
2751
sc = ServingConfig(
2852
signature_key='f',

0 commit comments

Comments
 (0)