diff --git a/export/orbax/export/obm_configs.py b/export/orbax/export/obm_configs.py index 67c6a6a43..9fae75057 100644 --- a/export/orbax/export/obm_configs.py +++ b/export/orbax/export/obm_configs.py @@ -69,7 +69,6 @@ class MixedPriorityBatchingPolicy(enum.Enum): PRIORITY_MERGE = "priority_merge" - # LINT.ThenChange(//depot//orbax/export/obm_export.py) @@ -279,3 +278,15 @@ def __post_init__(self): self.low_priority_batch_options.max_enqueued_batches, is_low_priority_batch_options=True, ) + + +@dataclasses.dataclass(kw_only=True) +class ObmExportOptions: + """Options for Orbax Model Export. + + Attributes: + batch_options: The batch options for the model. + converter_options: The converter options for the model. + """ + + batch_options: BatchOptions | None = None diff --git a/export/orbax/export/serving_config.py b/export/orbax/export/serving_config.py index 65a0a56d5..7ccee7fb5 100644 --- a/export/orbax/export/serving_config.py +++ b/export/orbax/export/serving_config.py @@ -17,9 +17,12 @@ from collections.abc import Callable, Mapping, Sequence import dataclasses from typing import Any, Optional, Text, Union +import warnings from absl import logging import jax import jaxtyping +from orbax.export import constants +from orbax.export import obm_configs from orbax.export.data_processors import data_processor_base import tensorflow as tf @@ -63,6 +66,8 @@ class ServingConfig: # exactly one method which will be used. method_key: Optional[str] = None # Options passed to the Orbax Model export. + obm_export_options: obm_configs.ObmExportOptions | None = None + # DEPRECATED: use `obm_export_options` instead. obm_kwargs: Mapping[str, Any] = dataclasses.field(default_factory=dict) # When set to true, it allows a portion of the preprocessor's outputs to be @@ -103,6 +108,19 @@ class ServingConfig: preprocess_output_passthrough_enabled: bool = False def __post_init__(self): + if self.obm_kwargs: + if self.obm_export_options is not None: + raise ValueError( + 'Both `obm_kwargs` and `obm_export_options` are set. Please only' + ' use `obm_export_options`.' + ) + warnings.warn( + '`obm_kwargs` is deprecated, use `obm_export_options` instead.', + DeprecationWarning, + ) + self.obm_export_options = obm_configs.ObmExportOptions( + batch_options=self.obm_kwargs.get(constants.BATCH_OPTIONS), + ) if not self.signature_key: raise ValueError('`signature_key` must be set.') if self.tf_preprocessor and self.preprocessors: diff --git a/export/orbax/export/serving_config_test.py b/export/orbax/export/serving_config_test.py index fbb445cb0..acdf36586 100644 --- a/export/orbax/export/serving_config_test.py +++ b/export/orbax/export/serving_config_test.py @@ -14,6 +14,7 @@ import jax.numpy as jnp import numpy as np +from orbax.export import obm_configs from orbax.export import serving_config import tensorflow as tf @@ -23,6 +24,29 @@ class ServingConfigTest(tf.test.TestCase): + def test_obm_kwargs_deprecation(self): + batch_opts = obm_configs.BatchOptions( + batch_component=obm_configs.BatchComponent.NO_BATCHING + ) + with self.assertWarnsRegex( + DeprecationWarning, '`obm_kwargs` is deprecated' + ): + sc = ServingConfig( + signature_key='f', obm_kwargs={'batch_options': batch_opts} + ) + self.assertIsNotNone(sc.obm_export_options) + self.assertEqual(sc.obm_export_options.batch_options, batch_opts) # pytype: disable=attribute-error + + def test_obm_kwargs_and_obm_export_options_set_raise_error(self): + with self.assertRaisesRegex( + ValueError, 'Both `obm_kwargs` and `obm_export_options` are set' + ): + ServingConfig( + signature_key='f', + obm_kwargs={'batch_options': 1}, + obm_export_options=obm_configs.ObmExportOptions(), + ) + def test_bind_tf(self): sc = ServingConfig( signature_key='f',