Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion export/orbax/export/obm_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ class MixedPriorityBatchingPolicy(enum.Enum):
PRIORITY_MERGE = "priority_merge"



# LINT.ThenChange(//depot//orbax/export/obm_export.py)


Expand Down Expand Up @@ -279,3 +278,15 @@ def __post_init__(self):
self.low_priority_batch_options.max_enqueued_batches,
is_low_priority_batch_options=True,
)


@dataclasses.dataclass(kw_only=True)
class ObmExportOptions:
"""Options for Orbax Model Export.

Attributes:
batch_options: The batch options for the model.
converter_options: The converter options for the model.
"""

batch_options: BatchOptions | None = None
18 changes: 18 additions & 0 deletions export/orbax/export/serving_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@
from collections.abc import Callable, Mapping, Sequence
import dataclasses
from typing import Any, Optional, Text, Union
import warnings
from absl import logging
import jax
import jaxtyping
from orbax.export import constants
from orbax.export import obm_configs
from orbax.export.data_processors import data_processor_base
import tensorflow as tf

Expand Down Expand Up @@ -63,6 +66,8 @@ class ServingConfig:
# exactly one method which will be used.
method_key: Optional[str] = None
# Options passed to the Orbax Model export.
obm_export_options: obm_configs.ObmExportOptions | None = None
# DEPRECATED: use `obm_export_options` instead.
obm_kwargs: Mapping[str, Any] = dataclasses.field(default_factory=dict)

# When set to true, it allows a portion of the preprocessor's outputs to be
Expand Down Expand Up @@ -103,6 +108,19 @@ class ServingConfig:
preprocess_output_passthrough_enabled: bool = False

def __post_init__(self):
if self.obm_kwargs:
if self.obm_export_options is not None:
raise ValueError(
'Both `obm_kwargs` and `obm_export_options` are set. Please only'
' use `obm_export_options`.'
)
warnings.warn(
'`obm_kwargs` is deprecated, use `obm_export_options` instead.',
DeprecationWarning,
)
self.obm_export_options = obm_configs.ObmExportOptions(
batch_options=self.obm_kwargs.get(constants.BATCH_OPTIONS),
)
if not self.signature_key:
raise ValueError('`signature_key` must be set.')
if self.tf_preprocessor and self.preprocessors:
Expand Down
24 changes: 24 additions & 0 deletions export/orbax/export/serving_config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import jax.numpy as jnp
import numpy as np
from orbax.export import obm_configs
from orbax.export import serving_config
import tensorflow as tf

Expand All @@ -23,6 +24,29 @@

class ServingConfigTest(tf.test.TestCase):

def test_obm_kwargs_deprecation(self):
batch_opts = obm_configs.BatchOptions(
batch_component=obm_configs.BatchComponent.NO_BATCHING
)
with self.assertWarnsRegex(
DeprecationWarning, '`obm_kwargs` is deprecated'
):
sc = ServingConfig(
signature_key='f', obm_kwargs={'batch_options': batch_opts}
)
self.assertIsNotNone(sc.obm_export_options)
self.assertEqual(sc.obm_export_options.batch_options, batch_opts) # pytype: disable=attribute-error

def test_obm_kwargs_and_obm_export_options_set_raise_error(self):
with self.assertRaisesRegex(
ValueError, 'Both `obm_kwargs` and `obm_export_options` are set'
):
ServingConfig(
signature_key='f',
obm_kwargs={'batch_options': 1},
obm_export_options=obm_configs.ObmExportOptions(),
)

def test_bind_tf(self):
sc = ServingConfig(
signature_key='f',
Expand Down
Loading