200 changes: 125 additions & 75 deletions tfx_addons/predictions_to_bigquery/executor.py
@@ -19,8 +19,7 @@
import datetime
import os
import re
-from collections.abc import Mapping, Sequence
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, Optional, Union

import apache_beam as beam
import numpy as np
@@ -31,7 +30,7 @@
from tensorflow_serving.apis import prediction_log_pb2
from tfx import types
from tfx.dsl.components.base import base_beam_executor
-from tfx.types import artifact_utils
+from tfx.types import Artifact, artifact_utils

# TODO(cezequiel): Move relevant functions in utils module here.
from tfx_addons.predictions_to_bigquery import utils
@@ -41,79 +40,126 @@
_DEFAULT_TIMESTRING_FORMAT = '%Y%m%d_%H%M%S'
_REQUIRED_EXEC_PROPERTIES = (
    'bq_table_name',
-    'bq_dataset',
    'filter_threshold',
    'gcp_project',
    'gcs_temp_dir',
    'vocab_label_file',
)
_REGEX_CHARS_TO_REPLACE = re.compile(r'[^a-zA-Z0-9_]')
+_REGEX_BQ_TABLE_NAME = re.compile(r'^[\w-]*:?[\w_]+\.[\w_]+$')


-def _check_exec_properties(exec_properties: Mapping[str, Any]) -> None:
+def _check_exec_properties(exec_properties: dict[str, Any]) -> None:
  for key in _REQUIRED_EXEC_PROPERTIES:
    if exec_properties[key] is None:
      raise ValueError(f'{key} must be set in exec_properties')


-def _get_labels(transform_output_uri: str, vocab_file: str) -> Sequence[str]:
-  tf_transform_output = tft.TFTransformOutput(transform_output_uri)
-  tft_vocab = tf_transform_output.vocabulary_by_name(vocab_filename=vocab_file)
+def _get_prediction_log_path(inference_results: list[Artifact]) -> str:
+  inference_results_uri = artifact_utils.get_single_uri(inference_results)
+  return f'{inference_results_uri}/*.gz'


+def _get_tft_output(
+    transform_graph: Optional[list[Artifact]] = None
+) -> Optional[tft.TFTransformOutput]:
+  if transform_graph is None:
+    return None
+
+  transform_graph_uri = artifact_utils.get_single_uri(transform_graph)
+  return tft.TFTransformOutput(transform_graph_uri)


+def _get_labels(tft_output: tft.TFTransformOutput,
+                vocab_file: str) -> list[str]:
+  tft_vocab = tft_output.vocabulary_by_name(vocab_filename=vocab_file)
  return [label.decode() for label in tft_vocab]


-def _get_bq_table_name(
-    basename: str,
-    timestamp: Optional[datetime.datetime] = None,
-    timestring_format: Optional[str] = None,
-) -> str:
+def _check_bq_table_name(bq_table_name: str) -> None:
+  if _REGEX_BQ_TABLE_NAME.match(bq_table_name) is None:
+    raise ValueError('Invalid BigQuery table name.'
+                     ' Specify in either `PROJECT:DATASET.TABLE` or'
+                     ' `DATASET.TABLE` format.')


+def _add_bq_table_name_suffix(basename: str,
+                              timestamp: Optional[datetime.datetime] = None,
+                              timestring_format: Optional[str] = None) -> str:
  if timestamp is not None:
    timestring_format = timestring_format or _DEFAULT_TIMESTRING_FORMAT
    return basename + '_' + timestamp.strftime(timestring_format)
  return basename
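A quick sketch of how the two new table-name helpers compose. This is not part of the diff; the import path and the timestamp value are assumed for illustration:

```python
import datetime

from tfx_addons.predictions_to_bigquery import executor

# `PROJECT:DATASET.TABLE` and `DATASET.TABLE` both pass the regex check.
executor._check_bq_table_name('my-project:my_dataset.predictions')

# With a timestamp, the suffix uses _DEFAULT_TIMESTRING_FORMAT ('%Y%m%d_%H%M%S').
name = executor._add_bq_table_name_suffix(
    'my-project:my_dataset.predictions',
    timestamp=datetime.datetime(2023, 2, 1, 12, 30, 0))
assert name == 'my-project:my_dataset.predictions_20230201_123000'

# Without a timestamp, the basename passes through unchanged.
assert executor._add_bq_table_name_suffix('my_dataset.predictions') == (
    'my_dataset.predictions')
```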


def _get_additional_bq_parameters(
-    expiration_days: Optional[int] = None,
-    table_partitioning: bool = False,
-) -> Mapping[str, Any]:
+    table_expiration_days: Optional[int] = None,
+    table_partitioning: Optional[bool] = False,
+) -> dict[str, Any]:
  output = {}
  if table_partitioning:
    time_partitioning = {'type': 'DAY'}
    logging.info('BigQuery table time partitioning set to DAY')
-    if expiration_days:
-      expiration_time_delta = datetime.timedelta(days=expiration_days)
+    if table_expiration_days:
+      expiration_time_delta = datetime.timedelta(days=table_expiration_days)
      expiration_milliseconds = expiration_time_delta.total_seconds() * 1000
      logging.info(
-          f'BigQuery table partition expiration time set to {expiration_days}'
-          ' days')
+          f'BigQuery table expiration set to {table_expiration_days} days.')
      time_partitioning['expirationMs'] = expiration_milliseconds
    output['timePartitioning'] = time_partitioning
  return output
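For reviewers, a sketch of the shapes this helper produces. Note that `table_expiration_days` only takes effect when `table_partitioning` is set, since the expiration is attached to `timePartitioning`:

```python
from tfx_addons.predictions_to_bigquery import executor

# Partitioning only: day-partitioned table, no expiration.
assert executor._get_additional_bq_parameters(table_partitioning=True) == {
    'timePartitioning': {'type': 'DAY'}
}

# 30 days -> milliseconds; total_seconds() makes it a float.
params = executor._get_additional_bq_parameters(table_expiration_days=30,
                                                table_partitioning=True)
assert params == {
    'timePartitioning': {'type': 'DAY', 'expirationMs': 2592000000.0}
}

# Expiration without partitioning yields an empty dict.
assert executor._get_additional_bq_parameters(table_expiration_days=30) == {}
```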


-def _get_features(
-    *,
-    schema_uri: Optional[str] = None,
+# TODO(cezequiel): Move to a separate module with called functions.
+# pylint: disable=protected-access
+def _parse_features_from_prediction_results(
+    prediction_log_path: str) -> dict[str, Any]:
+  filepath = tf.io.gfile.glob(prediction_log_path)[0]
+  compression_type = utils._get_compress_type(filepath)
+  dataset = tf.data.TFRecordDataset([filepath],
+                                    compression_type=compression_type)
+
+  for bytes_record in dataset.take(1):
+    prediction_log = prediction_log_pb2.PredictionLog.FromString(
+        bytes_record.numpy())
+
+  example_bytes = (
+      prediction_log.predict_log.request.inputs['examples'].string_val[0])
+  example = tf.train.Example.FromString(example_bytes)
+  features = {}
+
+  for name, feature_proto in example.features.feature.items():
+    feature_dtype = utils._get_feature_type(feature=feature_proto)
+    feature = tf.io.VarLenFeature(dtype=feature_dtype)
+    features[name] = feature
+
+  return features
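The function above infers a feature spec by peeking at the first `tf.train.Example` embedded in the prediction log. A self-contained sketch of the same inference step; `_KIND_TO_DTYPE` is a stand-in for whatever mapping `utils._get_feature_type` applies, and the feature names and values are made up:

```python
import tensorflow as tf

# A serialized tf.train.Example like the ones embedded in a PredictionLog request.
example = tf.train.Example()
example.features.feature['title'].bytes_list.value.append(b'foo')
example.features.feature['price'].float_list.value.append(9.99)

# Map each feature proto's populated oneof field to a dtype and wrap it in a
# VarLenFeature, mirroring the loop above.
_KIND_TO_DTYPE = {
    'bytes_list': tf.string,
    'float_list': tf.float32,
    'int64_list': tf.int64,
}
features = {
    name: tf.io.VarLenFeature(dtype=_KIND_TO_DTYPE[proto.WhichOneof('kind')])
    for name, proto in example.features.feature.items()
}
# -> {'title': VarLenFeature(tf.string), 'price': VarLenFeature(tf.float32)}
```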


+def _get_schema_features(
+    schema: Optional[list[Artifact]] = None,
+    tft_output: Optional[tft.TFTransformOutput] = None,
    prediction_log_path: Optional[str] = None,
-) -> Mapping[str, Any]:
-  if schema_uri:
+) -> dict[str, Any]:
+  if schema is not None:
+    schema_uri = artifact_utils.get_single_uri(schema)
    schema_file = os.path.join(schema_uri, _SCHEMA_FILE_NAME)
    return utils.load_schema(schema_file)

-  if not prediction_log_path:
-    raise ValueError('Specify one of `schema_uri` or `prediction_log_path`.')
+  if tft_output is not None:
+    return tft_output.raw_feature_spec()

-  return utils.parse_schema(prediction_log_path)
+  if prediction_log_path is None:
+    raise ValueError(
+        'Specify one of `schema`, `tft_output` or `prediction_log_path`.')
+
+  return _parse_features_from_prediction_results(prediction_log_path)
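The fallback order is: schema artifact first, then the TFTransform feature spec, then inference from the logs. A hypothetical call showing that order; the transform directory and GCS path are made up:

```python
import tensorflow_transform as tft

from tfx_addons.predictions_to_bigquery import executor

# Assuming a TFTransform output directory from an earlier Transform run.
tft_output = tft.TFTransformOutput('/tmp/transform_output')

features = executor._get_schema_features(
    schema=None,              # no schema artifact attached
    tft_output=tft_output,    # wins here: raw_feature_spec() is returned
    prediction_log_path='gs://my-bucket/inference/*.gz',  # last resort
)
```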


def _get_bq_field_name_from_key(key: str) -> str:
  field_name = _REGEX_CHARS_TO_REPLACE.sub('_', key)
  return re.sub('_+', '_', field_name).strip('_')
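Examples of the field-name sanitization above (not part of the diff): non-alphanumerics collapse to single underscores and edge underscores are stripped.

```python
from tfx_addons.predictions_to_bigquery import executor

assert executor._get_bq_field_name_from_key('feature.name') == 'feature_name'
assert executor._get_bq_field_name_from_key('a b/c') == 'a_b_c'
assert executor._get_bq_field_name_from_key('_id_') == 'id'
```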


-def _features_to_bq_schema(features: Mapping[str, Any],
-                           required: bool = False):
+def _features_to_bq_schema(features: dict[str, Any], required: bool = False):
  bq_schema_fields_ = utils.feature_to_bq_schema(features, required=required)
  bq_schema_fields = []
  for field in bq_schema_fields_:
@@ -128,8 +174,7 @@ def _features_to_bq_schema(features: Mapping[str, Any],


def _tensor_to_native_python_value(
-    tensor: Union[tf.Tensor, tf.sparse.SparseTensor]
-) -> Optional[Union[int, float, str]]:
+    tensor: Union[tf.Tensor, tf.sparse.SparseTensor]) -> Optional[Any]:
  """Converts a TF Tensor to a native Python value."""
  if isinstance(tensor, tf.sparse.SparseTensor):
    values = tensor.values.numpy()
@@ -139,42 +184,43 @@ def _tensor_to_native_python_value(
    return None
  values = np.squeeze(values)  # Removes extra dimension, e.g. shape (n, 1).
  values = values.item()  # Converts to native Python type
-  if isinstance(values, Sequence) and isinstance(values[0], bytes):
+  if isinstance(values, list) and isinstance(values[0], bytes):
    return [v.decode('utf-8') for v in values]
  if isinstance(values, bytes):
    return values.decode('utf-8')
  return values
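A sketch of the intended conversions, with the import path assumed; the empty-tensor case depends on lines collapsed between the hunks:

```python
import tensorflow as tf

from tfx_addons.predictions_to_bigquery import executor

# Single-element dense tensor -> native Python scalar.
executor._tensor_to_native_python_value(tf.constant([3]))       # -> 3

# Bytes values are decoded to str.
executor._tensor_to_native_python_value(tf.constant([b'cat']))  # -> 'cat'
```

One thing worth a second look: since `values.item()` yields a scalar, the `isinstance(values, list)` branch (like the `Sequence` check it replaces) looks unreachable, and `.item()` on a multi-element array raises instead.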


@beam.typehints.with_input_types(str)
-@beam.typehints.with_output_types(beam.typehints.Iterable[Tuple[str, str,
+@beam.typehints.with_output_types(beam.typehints.Iterable[tuple[str, str,
                                                                 Any]])
class FilterPredictionToDictFn(beam.DoFn):
"""Converts a PredictionLog proto to a dict."""
def __init__(
self,
labels: List,
features: Any,
features: dict[str, tf.io.FixedLenFeature],
timestamp: datetime.datetime,
filter_threshold: float,
labels: Optional[list[str]] = None,
score_multiplier: float = 1.,
):
super().__init__()
self._labels = labels
self._features = features
self._timestamp = timestamp
self._filter_threshold = filter_threshold
self._labels = labels
self._score_multiplier = score_multiplier
self._timestamp = timestamp

-  def _parse_prediction(self, predictions: npt.ArrayLike):
+  def _parse_prediction(
+      self, predictions: npt.ArrayLike) -> tuple[Optional[str], float]:
    prediction_id = np.argmax(predictions)
    logging.debug("Prediction id: %s", prediction_id)
    logging.debug("Predictions: %s", predictions)
-    label = self._labels[prediction_id]
+    label = self._labels[prediction_id] if self._labels is not None else None
    score = predictions[0][prediction_id]
    return label, score
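Why `predictions[0][prediction_id]` pairs with a flattened `np.argmax`: the scores arrive as a `(1, num_classes)` array. A sketch with made-up scores:

```python
import numpy as np

predictions = np.array([[0.1, 0.7, 0.2]])  # one example, three classes
prediction_id = np.argmax(predictions)     # flattens -> 1
score = predictions[0][prediction_id]      # -> 0.7
# With labels=['cat', 'dog', 'bird'] the label is 'dog';
# with labels=None (no TFT vocab), the label is None.
```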

-  def _parse_example(self, serialized: bytes) -> Mapping[str, Any]:
+  def _parse_example(self, serialized: bytes) -> dict[str, Any]:
    parsed_example = tf.io.parse_example(serialized, self._features)
    output = {}
    for key, tensor in parsed_example.items():
@@ -191,17 +237,18 @@ def process(self, element, *args, **kwargs): # pylint: disable=missing-function
    del args, kwargs  # unused

    parsed_prediction_scores = tf.make_ndarray(
-        element.predict_log.response.outputs["outputs"])
+        element.predict_log.response.outputs['outputs'])
    label, score = self._parse_prediction(parsed_prediction_scores)
    if score >= self._filter_threshold:
      output = {
-          "category_label": label,
          # Workaround to issue with the score value having additional non-zero values
          # in higher decimal places.
          # e.g. 0.8 -> 0.800000011920929
-          "score": round(score * self._score_multiplier, _DECIMAL_PLACES),
-          "datetime": self._timestamp,
+          'score': round(score * self._score_multiplier, _DECIMAL_PLACES),
+          'datetime': self._timestamp,
      }
+      if label is not None:
+        output['category_label'] = label
      output.update(
          self._parse_example(
              element.predict_log.request.inputs['examples'].string_val))
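The shape of one emitted row, for a prediction that clears `filter_threshold` while a label vocabulary is available; the `title` feature and all values are made up:

```python
import datetime

row = {
    'score': 0.8,  # round(score * score_multiplier, _DECIMAL_PLACES)
    'datetime': datetime.datetime(2023, 2, 1, 12, 30, 0),
    'category_label': 'dog',  # omitted entirely when labels is None
    'title': 'foo',           # merged in by _parse_example
}
```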
@@ -212,9 +259,9 @@ class Executor(base_beam_executor.BaseBeamExecutor):
"""Implements predictions-to-bigquery component logic."""
def Do(
self,
input_dict: Mapping[str, List[types.Artifact]],
output_dict: Mapping[str, List[types.Artifact]],
exec_properties: Mapping[str, Any],
input_dict: dict[str, list[types.Artifact]],
output_dict: dict[str, list[types.Artifact]],
exec_properties: dict[str, Any],
) -> None:
"""Do function for predictions_to_bq executor."""

@@ -223,36 +270,41 @@ def Do(
    # Check required keys set in exec_properties
    _check_exec_properties(exec_properties)

-    # get labels from tf transform generated vocab file
-    labels = _get_labels(
-        artifact_utils.get_single_uri(input_dict['transform_graph']),
-        exec_properties['vocab_label_file'],
-    )
-    logging.info(f"found the following labels from TFT vocab: {labels}")
-
-    # set BigQuery table name and timestamp suffix if specified.
-    bq_table_name = _get_bq_table_name(exec_properties['bq_table_name'],
-                                       timestamp,
-                                       exec_properties['table_suffix'])
-
-    # set prediction result file path and decoder
-    inference_results_uri = artifact_utils.get_single_uri(
-        input_dict["inference_results"])
-    prediction_log_path = f"{inference_results_uri}/*.gz"
+    # Get prediction log file path and decoder
+    prediction_log_path = _get_prediction_log_path(
+        input_dict['inference_results'])
    prediction_log_decoder = beam.coders.ProtoCoder(
        prediction_log_pb2.PredictionLog)

+    tft_output = _get_tft_output(input_dict.get('transform_graph'))
+
    # get schema features
-    features = _get_features(schema_uri=artifact_utils.get_single_uri(
-        input_dict["schema"]),
-                             prediction_log_path=prediction_log_path)
+    features = _get_schema_features(
+        schema=input_dict.get('schema'),
+        tft_output=tft_output,
+        prediction_log_path=prediction_log_path,
+    )
+
+    # get label names from TFTransformOutput object, if applicable
+    if tft_output is not None and 'vocab_label_file' in exec_properties:
+      labels = _get_labels(tft_output, exec_properties['vocab_label_file'])
+      logging.info(f'Found the following labels from TFT vocab: {labels}.')
+    else:
+      labels = None
+      logging.info('No TFTransform output given; no labels parsed.')
+
+    # set BigQuery table name and timestamp suffix if specified.
+    _check_bq_table_name(exec_properties['bq_table_name'])
+    bq_table_name = _add_bq_table_name_suffix(
+        exec_properties['bq_table_name'], timestamp,
+        exec_properties['table_time_suffix'])

    # generate bigquery schema from tf transform features
    bq_schema = _features_to_bq_schema(features)
    logging.info(f'generated bq_schema: {bq_schema}.')

    additional_bq_parameters = _get_additional_bq_parameters(
-        exec_properties.get('expiration_time_delta'),
+        exec_properties.get('table_expiration_days'),
        exec_properties.get('table_partitioning'))

    # run the Beam pipeline to write the inference data to bigquery
@@ -262,14 +314,12 @@ def Do(
            prediction_log_path, coder=prediction_log_decoder)
        | 'Filter and Convert to Dict' >> beam.ParDo(
            FilterPredictionToDictFn(
-                labels=labels,
                features=features,
                timestamp=timestamp,
-                filter_threshold=exec_properties['filter_threshold']))
-        | 'Write Dict to BQ' >> beam.io.gcp.bigquery.WriteToBigQuery(
+                filter_threshold=exec_properties['filter_threshold'],
+                labels=labels))
+        | 'Write Dict to BQ' >> beam.io.WriteToBigQuery(
            table=bq_table_name,
-            dataset=exec_properties['bq_dataset'],
            project=exec_properties['gcp_project'],
            schema=bq_schema,
            additional_bq_parameters=additional_bq_parameters,
            create_disposition=beam.io.BigQueryDisposition.CREATE_IF_NEEDED,
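For context, a hypothetical `exec_properties` dict matching the keys this `Do()` reads. All values are made up, and `gcs_temp_dir` is presumably consumed by the BigQuery sink below the visible hunk:

```python
exec_properties = {
    'bq_table_name': 'my-project:my_dataset.predictions',
    'filter_threshold': 0.5,
    'gcp_project': 'my-project',
    'gcs_temp_dir': 'gs://my-bucket/tmp',
    'vocab_label_file': 'labels_vocab',
    'table_time_suffix': '%Y%m%d_%H%M%S',  # optional timestamp suffix format
    'table_expiration_days': 30,           # optional
    'table_partitioning': True,            # optional
}
```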