diff --git a/python/sparkdl/__init__.py b/python/sparkdl/__init__.py index aa15059a..06b91bc8 100644 --- a/python/sparkdl/__init__.py +++ b/python/sparkdl/__init__.py @@ -13,15 +13,17 @@ # limitations under the License. # +from .graph.input import TFInputGraph from .image.imageIO import imageSchema, imageType, readImages from .transformers.keras_image import KerasImageFileTransformer from .transformers.named_image import DeepImagePredictor, DeepImageFeaturizer from .transformers.tf_image import TFImageTransformer +from .transformers.tf_tensor import TFTransformer from .transformers.utils import imageInputPlaceholder + __all__ = [ 'imageSchema', 'imageType', 'readImages', - 'TFImageTransformer', - 'DeepImagePredictor', 'DeepImageFeaturizer', - 'KerasImageFileTransformer', + 'TFImageTransformer', 'TFInputGraph', 'TFTransformer', + 'DeepImagePredictor', 'DeepImageFeaturizer', 'KerasImageFileTransformer', 'imageInputPlaceholder'] diff --git a/python/sparkdl/graph/builder.py b/python/sparkdl/graph/builder.py index 86c3b3ce..67510fc7 100644 --- a/python/sparkdl/graph/builder.py +++ b/python/sparkdl/graph/builder.py @@ -47,19 +47,20 @@ def __init__(self, graph=None, using_keras=False): self.graph = graph or tf.Graph() self.sess = tf.Session(graph=self.graph) if using_keras: + self.using_keras = True self.keras_prev_sess = K.get_session() else: + self.using_keras = False self.keras_prev_sess = None def __enter__(self): - self.sess.as_default() self.sess.__enter__() - if self.keras_prev_sess is not None: + if self.using_keras: K.set_session(self.sess) return self def __exit__(self, *args): - if self.keras_prev_sess is not None: + if self.using_keras: K.set_session(self.keras_prev_sess) self.sess.__exit__(*args) @@ -268,4 +269,3 @@ def fromList(cls, functions): gfn = issn.asGraphFunction(first_inputs, last_outputs) return gfn - diff --git a/python/sparkdl/graph/input.py b/python/sparkdl/graph/input.py new file mode 100644 index 00000000..53b923c7 --- /dev/null +++ 
b/python/sparkdl/graph/input.py @@ -0,0 +1,254 @@ +# Copyright 2017 Databricks, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import absolute_import, division, print_function + +import tensorflow as tf +from tensorflow.core.protobuf import meta_graph_pb2 + +import sparkdl.graph.utils as tfx + +__all__ = ["TFInputGraph"] + +class TFInputGraph(object): + """ + An opaque serializable object containing TensorFlow graph. + + [WARNING] This class should not be called by any user code. + """ + def __init__(self): + raise NotImplementedError( + "Please do NOT construct TFInputGraph directly. Instead, use one of the helper functions") + + @classmethod + def _new_obj_internal(cls): + # pylint: disable=attribute-defined-outside-init + obj = object.__new__(cls) + # TODO: for (de-)serialization, the class should correspond to a ProtocolBuffer definition. 
+ obj.graph_def = None + obj.input_tensor_name_from_signature = None + obj.output_tensor_name_from_signature = None + return obj + + def translateInputMapping(self, input_mapping): + assert self.input_tensor_name_from_signature is not None + _input_mapping = {} + if isinstance(input_mapping, dict): + input_mapping = list(input_mapping.items()) + assert isinstance(input_mapping, list) + for col_name, sig_key in input_mapping: + tnsr_name = self.input_tensor_name_from_signature[sig_key] + _input_mapping[col_name] = tnsr_name + return _input_mapping + + def translateOutputMapping(self, output_mapping): + assert self.output_tensor_name_from_signature is not None + _output_mapping = {} + if isinstance(output_mapping, dict): + output_mapping = list(output_mapping.items()) + assert isinstance(output_mapping, list) + for sig_key, col_name in output_mapping: + tnsr_name = self.output_tensor_name_from_signature[sig_key] + _output_mapping[tnsr_name] = col_name + return _output_mapping + + @classmethod + def fromGraph(cls, graph, sess, feed_names, fetch_names): + """ + Construct a TFInputGraphBuilder from a in memory tf.Graph object + """ + assert isinstance(graph, tf.Graph), \ + ('expect tf.Graph type but got', type(graph)) + + def import_graph_fn(_sess): + assert _sess == sess, 'must have the same session' + return _GinBuilderInfo() + + return _GinBuilder(import_graph_fn, sess, graph).build(feed_names, fetch_names) + + @classmethod + def fromGraphDef(cls, graph_def, feed_names, fetch_names): + """ + Construct a TFInputGraphBuilder from a tf.GraphDef object + """ + assert isinstance(graph_def, tf.GraphDef), \ + ('expect tf.GraphDef type but got', type(graph_def)) + + def import_graph_fn(sess): + with sess.as_default(): + tf.import_graph_def(graph_def, name='') + return _GinBuilderInfo() + + return _GinBuilder(import_graph_fn).build(feed_names, fetch_names) + + @classmethod + def fromCheckpoint(cls, checkpoint_dir, feed_names, fetch_names): + return 
cls._from_checkpoint_impl(checkpoint_dir, + signature_def_key=None, + feed_names=feed_names, fetch_names=fetch_names) + + @classmethod + def fromCheckpointWithSignature(cls, checkpoint_dir, signature_def_key): + assert signature_def_key is not None + return cls._from_checkpoint_impl(checkpoint_dir, + signature_def_key, + feed_names=None, fetch_names=None) + + @classmethod + def fromSavedModel(cls, saved_model_dir, tag_set, feed_names, fetch_names): + return cls._from_saved_model_impl(saved_model_dir, tag_set, + signature_def_key=None, + feed_names=feed_names, fetch_names=fetch_names) + + @classmethod + def fromSavedModelWithSignature(cls, saved_model_dir, tag_set, signature_def_key): + assert signature_def_key is not None + return cls._from_saved_model_impl(saved_model_dir, tag_set, + signature_def_key=signature_def_key, + feed_names=None, fetch_names=None) + + @classmethod + def _from_checkpoint_impl(cls, + checkpoint_dir, + signature_def_key=None, + feed_names=None, + fetch_names=None): + """ + Construct a TFInputGraphBuilder from a model checkpoint + """ + assert (feed_names is None) == (fetch_names is None), \ + 'feed_names and fetch_names, if provided must appear together' + assert (feed_names is None) != (signature_def_key is None), \ + 'must either provide feed_names or signature_def_key' + + def import_graph_fn(sess): + # Load checkpoint and import the graph + with sess.as_default(): + ckpt_path = tf.train.latest_checkpoint(checkpoint_dir) + + # NOTE(phi-dbq): we must manually load meta_graph_def to get the signature_def + # the current `import_graph_def` function seems to ignore + # any signature_def fields in a checkpoint's meta_graph_def. 
+ meta_graph_def = meta_graph_pb2.MetaGraphDef() + with open("{}.meta".format(ckpt_path), 'rb') as fin: + meta_graph_def.ParseFromString(fin.read()) + + saver = tf.train.import_meta_graph(meta_graph_def, clear_devices=True) + saver.restore(sess, ckpt_path) + + sig_def = None + if signature_def_key is not None: + sig_def = meta_graph_def.signature_def[signature_def_key] + assert sig_def, 'signature_def_key {} provided, '.format(signature_def_key) + \ + 'but failed to find it from the meta_graph_def ' + \ + 'from checkpoint {}'.format(checkpoint_dir) + + return _GinBuilderInfo(sig_def=sig_def) + + return _GinBuilder(import_graph_fn).build(feed_names, fetch_names) + + @classmethod + def _from_saved_model_impl(cls, saved_model_dir, tag_set, + signature_def_key=None, + feed_names=None, + fetch_names=None): + """ + Construct a TFInputGraphBuilder from a SavedModel + """ + assert (feed_names is None) == (fetch_names is None), \ + 'feed_names and fetch_names, if provided must appear together' + assert (feed_names is None) != (signature_def_key is None), \ + 'must either provide feed_names or signature_def_key' + + def import_graph_fn(sess): + tag_sets = tag_set.split(',') + meta_graph_def = tf.saved_model.loader.load(sess, tag_sets, saved_model_dir) + + sig_def = None + if signature_def_key is not None: + sig_def = tf.contrib.saved_model.get_signature_def_by_key( + meta_graph_def, signature_def_key) + + return _GinBuilderInfo(sig_def=sig_def) + + return _GinBuilder(import_graph_fn).build(feed_names, fetch_names) + + +class _GinBuilderInfo(object): + def __init__(self, sig_def=None): + self.sig_def = sig_def + self.feed_names = None + self.feed_mapping = None + self.fetch_names = None + self.fetch_mapping = None + + def extract_signatures(self): + assert self.sig_def is not None, \ + "ask to find sigdef mapping, but not found any" + + self.feed_mapping = {} + self.feed_names = [] + for sigdef_key, tnsr_info in self.sig_def.inputs.items(): + tnsr_name = tnsr_info.name + 
self.feed_mapping[sigdef_key] = tnsr_name + self.feed_names.append(tnsr_name) + + self.fetch_mapping = {} + self.fetch_names = [] + for sigdef_key, tnsr_info in self.sig_def.outputs.items(): + tnsr_name = tnsr_info.name + self.fetch_mapping[sigdef_key] = tnsr_name + self.fetch_names.append(tnsr_name) + +class _GinBuilder(object): + def __init__(self, import_graph_fn, sess=None, graph=None): + self.import_graph_fn = import_graph_fn + assert (sess is None) == (graph is None) + if sess is not None: + self.graph = graph + self.sess = sess + self._should_clean = False + else: + self.graph = tf.Graph() + self.sess = tf.Session(graph=self.graph) + self._should_clean = True + + def _build_impl(self, feed_names, fetch_names): + # pylint: disable=protected-access,attribute-defined-outside-init + gin = TFInputGraph._new_obj_internal() + assert (feed_names is None) == (fetch_names is None) + must_have_sig_def = fetch_names is None + # NOTE(phi-dbq): both have to be set to default + with self.sess.as_default(), self.graph.as_default(): + _ginfo = self.import_graph_fn(self.sess) + if must_have_sig_def: + _ginfo.extract_signatures() + feed_names = _ginfo.feed_names + fetch_names = _ginfo.fetch_names + gin.input_tensor_name_from_signature = _ginfo.feed_mapping + gin.output_tensor_name_from_signature = _ginfo.fetch_mapping + + for tnsr_name in feed_names: + assert tfx.get_op(self.graph, tnsr_name) + fetches = [tfx.get_tensor(self.graph, tnsr_name) for tnsr_name in fetch_names] + gin.graph_def = tfx.strip_and_freeze_until(fetches, self.graph, self.sess) + return gin + + def build(self, feed_names=None, fetch_names=None): + try: + gin = self._build_impl(feed_names, fetch_names) + finally: + if self._should_clean: + self.sess.close() + return gin diff --git a/python/sparkdl/graph/utils.py b/python/sparkdl/graph/utils.py index 45d8b065..75dec230 100644 --- a/python/sparkdl/graph/utils.py +++ b/python/sparkdl/graph/utils.py @@ -95,31 +95,49 @@ def get_tensor(graph, tfobj_or_name): 
'cannot locate tensor {} in current graph'.format(_tensor_name) return tnsr -def as_tensor_name(name): +def as_tensor_name(tfobj_or_name): """ Derive tf.Tensor name from an op/tensor name. - We do not check if the tensor exist (as no graph parameter is passed in). + If the input is a name, we do not check if the tensor exist + (as no graph parameter is passed in). - :param name: op name or tensor name + :param tfobj_or_name: either a tf.Tensor, tf.Operation or a name to either """ - assert isinstance(name, six.string_types) - name_parts = name.split(":") - assert len(name_parts) <= 2, name_parts - if len(name_parts) < 2: - name += ":0" - return name + if isinstance(tfobj_or_name, six.string_types): + # If input is a string, assume it is a name and infer the corresponding tensor name. + # WARNING: this depends on TensorFlow's tensor naming convention + name = tfobj_or_name + name_parts = name.split(":") + assert len(name_parts) <= 2, name_parts + if len(name_parts) < 2: + name += ":0" + return name + elif hasattr(tfobj_or_name, 'graph'): + tfobj = tfobj_or_name + return get_tensor(tfobj.graph, tfobj).name + else: + raise TypeError('invalid tf.Tensor name query type {}'.format(type(tfobj_or_name))) -def as_op_name(name): +def as_op_name(tfobj_or_name): """ - Derive tf.Operation name from an op/tensor name - We do not check if the operation exist (as no graph parameter is passed in). + Derive tf.Operation name from an op/tensor name. + If the input is a name, we do not check if the operation exist + (as no graph parameter is passed in). - :param name: op name or tensor name + :param tfobj_or_name: either a tf.Tensor, tf.Operation or a name to either """ - assert isinstance(name, six.string_types) - name_parts = name.split(":") - assert len(name_parts) <= 2, name_parts - return name_parts[0] + if isinstance(tfobj_or_name, six.string_types): + # If input is a string, assume it is a name and infer the corresponding operation name. 
+ # WARNING: this depends on TensorFlow's operation naming convention + name = tfobj_or_name + name_parts = name.split(":") + assert len(name_parts) <= 2, name_parts + return name_parts[0] + elif hasattr(tfobj_or_name, 'graph'): + tfobj = tfobj_or_name + return get_op(tfobj.graph, tfobj).name + else: + raise TypeError('invalid tf.Operation name query type {}'.format(type(tfobj_or_name))) def op_name(graph, tfobj_or_name): """ diff --git a/python/sparkdl/param/__init__.py b/python/sparkdl/param/__init__.py index 98a8f7dd..ca1a9121 100644 --- a/python/sparkdl/param/__init__.py +++ b/python/sparkdl/param/__init__.py @@ -14,7 +14,11 @@ # from sparkdl.param.shared_params import ( - keyword_only, HasInputCol, HasOutputCol, HasLabelCol, HasKerasModel, - HasKerasLoss, HasKerasOptimizer, HasOutputNodeName, SparkDLTypeConverters) + keyword_only, HasInputCol, HasOutputCol, HasLabelCol, + # TFTransformer Params + HasInputMapping, HasOutputMapping, HasTFInputGraph, HasTFHParams, + # Keras Estimator Params + HasKerasModel, HasKerasLoss, HasKerasOptimizer, HasOutputNodeName) +from sparkdl.param.converters import SparkDLTypeConverters from sparkdl.param.image_params import ( CanLoadImage, HasInputImageNodeName, HasOutputMode, OUTPUT_MODES) diff --git a/python/sparkdl/param/converters.py b/python/sparkdl/param/converters.py new file mode 100644 index 00000000..bea6bb3e --- /dev/null +++ b/python/sparkdl/param/converters.py @@ -0,0 +1,119 @@ +# Copyright 2017 Databricks, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +import six + +import tensorflow as tf + +from pyspark.ml.param import TypeConverters + +import sparkdl.graph.utils as tfx +from sparkdl.graph.input import TFInputGraph +import sparkdl.utils.keras_model as kmutil + +__all__ = ['SparkDLTypeConverters'] + +def _try_convert_tf_tensor_mapping(value, is_key_tf_tensor=True): + if isinstance(value, dict): + strs_pair_seq = [] + for k, v in value.items(): + try: + if is_key_tf_tensor: + _pair = (tfx.as_tensor_name(k), v) + else: + _pair = (k, tfx.as_tensor_name(v)) + except: + err_msg = "Can NOT convert {} (type {}) to tf.Tensor name" + _not_tf_op = k if is_key_tf_tensor else v + raise TypeError(err_msg.format(_not_tf_op, type(_not_tf_op))) + + str_val = v if is_key_tf_tensor else k + if not isinstance(str_val, six.string_types): + err_msg = 'expect string type for {}, but got {}' + raise TypeError(err_msg.format(str_val, type(str_val))) + + strs_pair_seq.append(_pair) + + return sorted(strs_pair_seq) + + if is_key_tf_tensor: + raise TypeError("Could not convert %s to tf.Tensor name to str mapping" % type(value)) + else: + raise TypeError("Could not convert %s to str to tf.Tensor name mapping" % type(value)) + + +class SparkDLTypeConverters(object): + @staticmethod + def toTFGraph(value): + if isinstance(value, tf.Graph): + return value + else: + raise TypeError("Could not convert %s to TensorFlow Graph" % type(value)) + + @staticmethod + def toTFInputGraph(value): + if isinstance(value, TFInputGraph): + return value + else: + raise TypeError("Could not convert %s to TFInputGraph" % type(value)) + + @staticmethod + def asColumnToTensorNameMap(value): + return _try_convert_tf_tensor_mapping(value, is_key_tf_tensor=False) + + @staticmethod + def asTensorNameToColumnMap(value): + return _try_convert_tf_tensor_mapping(value, is_key_tf_tensor=True) + + @staticmethod + def toTFHParams(value): + if isinstance(value, 
tf.contrib.training.HParams): + return value + else: + raise TypeError("Could not convert %s to TensorFlow HParams" % type(value)) + + @staticmethod + def toStringOrTFTensor(value): + if isinstance(value, tf.Tensor): + return value + else: + try: + return TypeConverters.toString(value) + except TypeError: + raise TypeError("Could not convert %s to tensorflow.Tensor or str" % type(value)) + + @staticmethod + def supportedNameConverter(supportedList): + def converter(value): + if value in supportedList: + return value + else: + raise TypeError("%s %s is not in the supported list." % type(value), str(value)) + + return converter + + @staticmethod + def toKerasLoss(value): + if kmutil.is_valid_loss_function(value): + return value + raise ValueError( + "Named loss not supported in Keras: {} type({})".format(value, type(value))) + + @staticmethod + def toKerasOptimizer(value): + if kmutil.is_valid_optimizer(value): + return value + raise TypeError( + "Named optimizer not supported in Keras: {} type({})".format(value, type(value))) diff --git a/python/sparkdl/param/shared_params.py b/python/sparkdl/param/shared_params.py index e169e891..1d3c4882 100644 --- a/python/sparkdl/param/shared_params.py +++ b/python/sparkdl/param/shared_params.py @@ -20,14 +20,17 @@ """ from functools import wraps - -import tensorflow as tf +import six from pyspark.ml.param import Param, Params, TypeConverters -import sparkdl.utils.keras_model as kmutil +from sparkdl.graph.input import TFInputGraph +from sparkdl.param.converters import SparkDLTypeConverters -# From pyspark +######################################################## +# Copied from PySpark for backward compatibility. +# They first appeared in Apache Spark version 2.1.1. +######################################################## def keyword_only(func): """ @@ -36,12 +39,14 @@ def keyword_only(func): .. 
note:: Should only be used to wrap a method where first arg is `self` """ + @wraps(func) def wrapper(self, *args, **kwargs): if len(args) > 0: raise TypeError("Method %s forces keyword arguments." % func.__name__) self._input_kwargs = kwargs return func(self, **kwargs) + return wrapper @@ -50,10 +55,8 @@ class HasInputCol(Params): Mixin for param inputCol: input column name. """ - inputCol = Param(Params._dummy(), "inputCol", "input column name.", typeConverter=TypeConverters.toString) - - def __init__(self): - super(HasInputCol, self).__init__() + inputCol = Param( + Params._dummy(), "inputCol", "input column name.", typeConverter=TypeConverters.toString) def setInputCol(self, value): """ @@ -73,8 +76,8 @@ class HasOutputCol(Params): Mixin for param outputCol: output column name. """ - outputCol = Param(Params._dummy(), - "outputCol", "output column name.", typeConverter=TypeConverters.toString) + outputCol = Param( + Params._dummy(), "outputCol", "output column name.", typeConverter=TypeConverters.toString) def __init__(self): super(HasOutputCol, self).__init__() @@ -92,54 +95,9 @@ def getOutputCol(self): """ return self.getOrDefault(self.outputCol) -############################################ +######################################################## # New in sparkdl -############################################ - -class SparkDLTypeConverters(object): - - @staticmethod - def toStringOrTFTensor(value): - if isinstance(value, tf.Tensor): - return value - else: - try: - return TypeConverters.toString(value) - except TypeError: - raise TypeError("Could not convert %s to tensorflow.Tensor or str" % type(value)) - - @staticmethod - def toTFGraph(value): - # TODO: we may want to support tf.GraphDef in the future instead of tf.Graph since user - # is less likely to mess up using GraphDef vs Graph (e.g. constants vs variables). 
- if isinstance(value, tf.Graph): - return value - else: - raise TypeError("Could not convert %s to tensorflow.Graph type" % type(value)) - - @staticmethod - def supportedNameConverter(supportedList): - def converter(value): - if value in supportedList: - return value - else: - raise TypeError("%s %s is not in the supported list." % type(value), str(value)) - - return converter - - @staticmethod - def toKerasLoss(value): - if kmutil.is_valid_loss_function(value): - return value - raise ValueError( - "Named loss not supported in Keras: {} type({})".format(value, type(value))) - - @staticmethod - def toKerasOptimizer(value): - if kmutil.is_valid_optimizer(value): - return value - raise TypeError( - "Named optimizer not supported in Keras: {} type({})".format(value, type(value))) +######################################################## class HasOutputNodeName(Params): @@ -233,3 +191,83 @@ def seKerasLoss(self, value): def getKerasLoss(self): return self.getOrDefault(self.kerasLoss) + + +class HasOutputMapping(Params): + """ + Mixin for param outputMapping: ordered list of ('outputTensorOpName', 'outputColName') pairs + """ + outputMapping = Param(Params._dummy(), + "outputMapping", + "Mapping output :class:`tf.Tensor` names to DataFrame column names", + typeConverter=SparkDLTypeConverters.asTensorNameToColumnMap) + + def setOutputMapping(self, value): + # NOTE(phi-dbq): due to the nature of TensorFlow import modes, we can only derive the + # serializable TFInputGraph object once the inputMapping and outputMapping + # parameters are provided. 
+ raise NotImplementedError( + "Please use the Transformer's constructor to assign `outputMapping` field.") + + def getOutputMapping(self): + return self.getOrDefault(self.outputMapping) + + +class HasInputMapping(Params): + """ + Mixin for param inputMapping: ordered list of ('inputColName', 'inputTensorOpName') pairs + """ + inputMapping = Param(Params._dummy(), + "inputMapping", + "Mapping input DataFrame column names to :class:`tf.Tensor` names", + typeConverter=SparkDLTypeConverters.asColumnToTensorNameMap) + + def setInputMapping(self, value): + # NOTE(phi-dbq): due to the nature of TensorFlow import modes, we can only derive the + # serializable TFInputGraph object once the inputMapping and outputMapping + # parameters are provided. + raise NotImplementedError( + "Please use the Transformer's constructor to assign `inputMapping` field.") + + def getInputMapping(self): + return self.getOrDefault(self.inputMapping) + + +class HasTFInputGraph(Params): + """ + Mixin for param tfInputGraph: a serializable object derived from a TensorFlow computation graph. + """ + tfInputGraph = Param(Params._dummy(), + "tfInputGraph", + "A serializable object derived from a TensorFlow computation graph", + typeConverter=SparkDLTypeConverters.toTFInputGraph) + + def __init__(self): + super(HasTFInputGraph, self).__init__() + self._setDefault(tfInputGraph=None) + + def setTFInputGraph(self, value): + # NOTE(phi-dbq): due to the nature of TensorFlow import modes, we can only derive the + # serializable TFInputGraph object once the inputMapping and outputMapping + # parameters are provided. 
+ raise NotImplementedError( + "Please use the Transformer's constructor to assign `tfInputGraph` field.") + + def getTFInputGraph(self): + return self.getOrDefault(self.tfInputGraph) + + +class HasTFHParams(Params): + """ + Mixin for TensorFlow model hyper-parameters + """ + tfHParams = Param(Params._dummy(), + "hparams", + "instance of :class:`tf.contrib.training.HParams`, a key-value map-like object", + typeConverter=SparkDLTypeConverters.toTFHParams) + + def setTFHParams(self, value): + return self._set(tfHParams=value) + + def getTFHParams(self): + return self.getOrDefault(self.tfHParams) diff --git a/python/sparkdl/transformers/tf_tensor.py b/python/sparkdl/transformers/tf_tensor.py new file mode 100644 index 00000000..8b11d9af --- /dev/null +++ b/python/sparkdl/transformers/tf_tensor.py @@ -0,0 +1,102 @@ +# Copyright 2017 Databricks, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from __future__ import absolute_import, division, print_function + +import logging +import tensorflow as tf +from tensorflow.python.tools import optimize_for_inference_lib as infr_opt +import tensorframes as tfs + +from pyspark.ml import Transformer + +import sparkdl.graph.utils as tfx +from sparkdl.graph.input import TFInputGraph +from sparkdl.param import (keyword_only, SparkDLTypeConverters, HasInputMapping, + HasOutputMapping, HasTFInputGraph, HasTFHParams) + +__all__ = ['TFTransformer'] + +logger = logging.getLogger('sparkdl') + +class TFTransformer(Transformer, HasTFInputGraph, HasTFHParams, HasInputMapping, HasOutputMapping): + """ + Applies the TensorFlow graph to the array column in DataFrame. + + Restrictions of the current API: + + We assume that + - All graphs have a "minibatch" dimension (i.e. an unknown leading + dimension) in the tensor shapes. + - Input DataFrame has an array column where all elements have the same length + """ + + @keyword_only + def __init__(self, tfInputGraph=None, inputMapping=None, outputMapping=None, tfHParms=None): + """ + __init__(self, tfInputGraph=None, inputMapping=None, outputMapping=None, tfHParms=None) + """ + super(TFTransformer, self).__init__() + kwargs = self._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, tfInputGraph=None, inputMapping=None, outputMapping=None, tfHParms=None): + """ + setParams(self, tfInputGraph=None, inputMapping=None, outputMapping=None, tfHParms=None) + """ + super(TFTransformer, self).__init__() + kwargs = self._input_kwargs + # Further canonicalization, e.g. 
converting dict to sorted str pairs happens here + return self._set(**kwargs) + + def _optimize_for_inference(self): + """ Optimize the graph for inference """ + gin = self.getTFInputGraph() + input_mapping = self.getInputMapping() + output_mapping = self.getOutputMapping() + input_node_names = [tfx.as_op_name(tnsr_name) for _, tnsr_name in input_mapping] + output_node_names = [tfx.as_op_name(tnsr_name) for tnsr_name, _ in output_mapping] + + # NOTE(phi-dbq): Spark DataFrame assumes float64 as default floating point type + opt_gdef = infr_opt.optimize_for_inference(gin.graph_def, + input_node_names, + output_node_names, + tf.float64.as_datatype_enum) + return opt_gdef + + def _transform(self, dataset): + graph_def = self._optimize_for_inference() + input_mapping = self.getInputMapping() + output_mapping = self.getOutputMapping() + + graph = tf.Graph() + with tf.Session(graph=graph): + analyzed_df = tfs.analyze(dataset) + + out_tnsr_op_names = [tfx.as_op_name(tnsr_name) for tnsr_name, _ in output_mapping] + tf.import_graph_def(graph_def=graph_def, name='', return_elements=out_tnsr_op_names) + + feed_dict = dict((tfx.op_name(graph, tnsr_name), col_name) + for col_name, tnsr_name in input_mapping) + fetches = [tfx.get_tensor(graph, tnsr_op_name) for tnsr_op_name in out_tnsr_op_names] + + out_df = tfs.map_blocks(fetches, analyzed_df, feed_dict=feed_dict) + + # We still have to rename output columns + for old_colname, new_colname in output_mapping: + if old_colname != new_colname: + out_df = out_df.withColumnRenamed(old_colname, new_colname) + + return out_df diff --git a/python/tests/param/__init__.py b/python/tests/param/__init__.py new file mode 100644 index 00000000..7084f22b --- /dev/null +++ b/python/tests/param/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright 2017 Databricks, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/python/tests/param/params_test.py b/python/tests/param/params_test.py new file mode 100644 index 00000000..0c10411a --- /dev/null +++ b/python/tests/param/params_test.py @@ -0,0 +1,69 @@ +# Copyright 2017 Databricks, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import sys + +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + +from sparkdl.param.converters import SparkDLTypeConverters as conv + +class ParamsConverterTest(unittest.TestCase): + # pylint: disable=protected-access + + def test_tf_input_mapping_converter(self): + valid_tnsr_input = {'colA': 'tnsrOpA:0', + 'colB': 'tnsrOpB:0'} + valid_op_input = {'colA': 'tnsrOpA', + 'colB': 'tnsrOpB'} + valid_input_mapping_result = [('colA', 'tnsrOpA:0'), + ('colB', 'tnsrOpB:0')] + + for valid_input_mapping in [valid_op_input, valid_tnsr_input]: + res = conv.asColumnToTensorNameMap(valid_input_mapping) + self.assertEqual(valid_input_mapping_result, res) + + def test_tf_output_mapping_converter(self): + valid_tnsr_output = {'tnsrOpA:0': 'colA', + 'tnsrOpB:0': 'colB'} + valid_op_output = {'tnsrOpA': 'colA', + 'tnsrOpB': 'colB'} + valid_output_mapping_result = [('tnsrOpA:0', 'colA'), + ('tnsrOpB:0', 'colB')] + + for valid_output_mapping in [valid_tnsr_output, valid_op_output]: + res = conv.asTensorNameToColumnMap(valid_output_mapping) + self.assertEqual(valid_output_mapping_result, res) + + + def test_invalid_input_mapping(self): + for invalid in [['a1', 'b2'], ('c3', 'd4'), [('a', 1), ('b', 2)]]: + with self.assertRaises(TypeError): + conv.asColumnToTensorNameMap(invalid) + conv.asTensorNameToColumnMap(invalid) + + with self.assertRaises(TypeError): + # Wrong value type: must be string + conv.asTensorNameToColumnMap({1: 'a', 2.0: 'b'}) + conv.asColumnToTensorNameMap({'a': 1, 'b': 2.0}) + + # Wrong containter type: only accept dict + conv.asColumnToTensorNameMap([('colA', 'tnsrA:0'), ('colB', 'tnsrB:0')]) + conv.asTensorNameToColumnMap([('tnsrA:0', 'colA'), ('tnsrB:0', 'colB')]) diff --git a/python/tests/transformers/tf_tensor_test.py b/python/tests/transformers/tf_tensor_test.py new file mode 100644 index 
# Copyright 2017 Databricks, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from __future__ import absolute_import, division, print_function

from contextlib import contextmanager
from glob import glob
import os
import shutil
import tempfile

from keras.layers import Conv1D, Dense, Flatten, MaxPool1D
import numpy as np
import tensorflow as tf
import tensorframes as tfs

from pyspark.sql.types import Row

from sparkdl.graph.builder import IsolatedSession
# NOTE(review): wildcard import is what brings TFInputGraph into scope here
from sparkdl.graph.input import *
import sparkdl.graph.utils as tfx
from sparkdl.transformers.tf_tensor import TFTransformer

from ..tests import SparkDLTestCase


class TFTransformerTest(SparkDLTestCase):
    """
    End-to-end tests for :class:`TFTransformer` built from every supported
    :class:`TFInputGraph` source: a live ``tf.Graph``/``GraphDef``, an exported
    SavedModel (with and without signatures), and a model checkpoint (with and
    without an injected signature_def).

    Each test builds a TF graph inside ``_run_test_in_tf_session``, registers
    one or more transformers, and the context manager then compares the
    transformers' DataFrame output against values computed directly with
    ``sess.run`` on the same random input vectors.
    """

    def setUp(self):
        # Shape of the test data: `num_vecs` rows of `vec_size`-dim vectors.
        self.vec_size = 17
        self.num_vecs = 31

        # Canonical column / TF-operation names; replicated in setup_iomap
        # when a test uses more than one input/output tensor.
        self.input_col = 'vec'
        self.input_op_name = 'tnsrOpIn'
        self.output_col = 'outputCol'
        self.output_op_name = 'tnsrOpOut'

        # Populated by setup_iomap: column<->op-name maps plus the flat
        # tensor-name lists used by the TFInputGraph factory methods.
        self.feed_names = []
        self.fetch_names = []
        self.input_mapping = {}
        self.output_mapping = {}
        self.setup_iomap(replica=1)

        # Transformers registered by each test; all are checked at the end of
        # the _run_test_in_tf_session context.
        self.transformers = []
        self.test_case_results = []
        # Build a temporary directory, which might or might not be used by the test
        self.model_output_root = tempfile.mkdtemp()

    def tearDown(self):
        # Best-effort cleanup of the temp dir (it may not have been used).
        shutil.rmtree(self.model_output_root, ignore_errors=True)

    def _build_default_session_tests(self, sess):
        """ Register the transformers every test shares: one built from the
        live graph+session and one from the serialized GraphDef. """
        gin = TFInputGraph.fromGraph(
            sess.graph, sess, self.feed_names, self.fetch_names)
        self.build_standard_transformers(sess, gin)

        gin = TFInputGraph.fromGraphDef(
            sess.graph.as_graph_def(), self.feed_names, self.fetch_names)
        self.build_standard_transformers(sess, gin)

    def build_standard_transformers(self, sess, tf_input_graph):
        """ Build a TFTransformer from `tf_input_graph`, translating the
        op-name based input/output mappings into tensor names, and register it. """
        def _add_transformer(imap, omap):
            trnsfmr = TFTransformer(
                tfInputGraph=tf_input_graph, inputMapping=imap, outputMapping=omap)
            self.transformers.append(trnsfmr)

        # tfx.tensor_name resolves an op name to its output tensor name
        # (i.e. appends the ':0' suffix when needed).
        imap = dict((col, tfx.tensor_name(sess.graph, op_name))
                    for col, op_name in self.input_mapping.items())
        omap = dict((tfx.tensor_name(sess.graph, op_name), col)
                    for op_name, col in self.output_mapping.items())
        _add_transformer(imap, omap)

    def setup_iomap(self, replica=1):
        """ (Re)build the column<->op-name mappings; with replica > 1, suffix
        each name with a replica index to exercise multiple I/O tensors. """
        self.input_mapping = {}
        self.feed_names = []
        self.output_mapping = {}
        self.fetch_names = []

        if replica > 1:
            for i in range(replica):
                colname = '{}_replica{:03d}'.format(self.input_col, i)
                tnsr_op_name = '{}_replica{:03d}'.format(self.input_op_name, i)
                self.input_mapping[colname] = tnsr_op_name
                self.feed_names.append(tnsr_op_name + ':0')

                colname = '{}_replica{:03d}'.format(self.output_col, i)
                tnsr_op_name = '{}_replica{:03d}'.format(self.output_op_name, i)
                self.output_mapping[tnsr_op_name] = colname
                self.fetch_names.append(tnsr_op_name + ':0')
        else:
            self.input_mapping = {self.input_col: self.input_op_name}
            self.feed_names = [self.input_op_name + ':0']
            self.output_mapping = {self.output_op_name: self.output_col}
            self.fetch_names = [self.output_op_name + ':0']

    @contextmanager
    def _run_test_in_tf_session(self):
        """ [THIS IS NOT A TEST]: encapsulate general test workflow

        1. Create random local feature rows and the analyzed Spark DataFrame.
        2. Yield a live session so the caller can build a graph and register
           transformers.
        3. Compute reference outputs row-by-row with sess.run.
        4. Run every registered transformer on the DataFrame and assert its
           output is numerically close to the reference.
        """

        # Build local features and DataFrame from it
        local_features = []
        for idx in range(self.num_vecs):
            _dict = {'idx': idx}
            for colname, _ in self.input_mapping.items():
                _dict[colname] = np.random.randn(self.vec_size).tolist()

            local_features.append(Row(**_dict))

        df = self.session.createDataFrame(local_features)
        # tensorframes needs the tensor shape metadata attached to the schema
        analyzed_df = tfs.analyze(df)

        # Build the TensorFlow graph
        graph = tf.Graph()
        with tf.Session(graph=graph) as sess, graph.as_default():
            # Build test graph and transformers from here
            yield sess

            # Get the reference data
            _results = []
            for row in local_features:
                fetches = [tfx.get_tensor(graph, tnsr_op_name)
                           for tnsr_op_name in self.output_mapping.keys()]
                feed_dict = {}
                for colname, tnsr_op_name in self.input_mapping.items():
                    tnsr = tfx.get_tensor(graph, tnsr_op_name)
                    # Add the leading batch dimension expected by the placeholders
                    feed_dict[tnsr] = np.array(row[colname])[np.newaxis, :]

                curr_res = sess.run(fetches, feed_dict=feed_dict)
                _results.append(np.ravel(curr_res))

            out_ref = np.hstack(_results)

        # Apply the transform
        for transfomer in self.transformers:
            out_df = transfomer.transform(analyzed_df)
            out_colnames = []
            for old_colname, new_colname in self.output_mapping.items():
                out_colnames.append(new_colname)
                if old_colname != new_colname:
                    # NOTE(review): presumably the transformer names output
                    # columns after the tensor op name — confirm against
                    # TFTransformer; rename to the expected column name here.
                    out_df = out_df.withColumnRenamed(old_colname, new_colname)

            # NOTE(review): assumes collect() preserves the input row order of
            # the local DataFrame; the 'idx' column is not used for sorting.
            _results = []
            for row in out_df.select(out_colnames).collect():
                curr_res = [row[colname] for colname in out_colnames]
                _results.append(np.ravel(curr_res))
            out_tgt = np.hstack(_results)

            err_msg = 'not close => {} != {}, max_diff {}'
            self.assertTrue(np.allclose(out_ref, out_tgt),
                            msg=err_msg.format(out_ref.shape, out_tgt.shape,
                                               np.max(np.abs(out_ref - out_tgt))))

    def test_build_from_tf_graph(self):
        """ Build TFTransformer from tf.Graph """
        with self._run_test_in_tf_session() as sess:
            # Begin building graph
            x = tf.placeholder(tf.float64, shape=[None, self.vec_size], name=self.input_op_name)
            _ = tf.reduce_mean(x, axis=1, name=self.output_op_name)
            # End building graph

            self._build_default_session_tests(sess)

    def test_build_from_saved_model(self):
        """ Build TFTransformer from saved model """
        # Setup saved model export directory
        saved_model_root = self.model_output_root
        saved_model_dir = os.path.join(saved_model_root, 'saved_model')
        serving_tag = "serving_tag"
        serving_sigdef_key = 'prediction_signature'
        builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir)

        with self._run_test_in_tf_session() as sess:
            # Model definition: begin
            x = tf.placeholder(tf.float64, shape=[None, self.vec_size], name=self.input_op_name)
            w = tf.Variable(tf.random_normal([self.vec_size], dtype=tf.float64),
                            dtype=tf.float64, name='varW')
            z = tf.reduce_mean(x * w, axis=1, name=self.output_op_name)
            # Model definition ends

            # The variable must be initialized before the model is exported
            sess.run(w.initializer)

            sig_inputs = {
                'input_sig': tf.saved_model.utils.build_tensor_info(x)}
            sig_outputs = {
                'output_sig': tf.saved_model.utils.build_tensor_info(z)}

            serving_sigdef = tf.saved_model.signature_def_utils.build_signature_def(
                inputs=sig_inputs,
                outputs=sig_outputs)

            builder.add_meta_graph_and_variables(sess,
                                                 [serving_tag],
                                                 signature_def_map={
                                                     serving_sigdef_key: serving_sigdef})
            builder.save()

            # Build the transformer from exported serving model
            # We are using signatures, thus must provide the keys
            tfInputGraph = TFInputGraph.fromSavedModelWithSignature(
                saved_model_dir, serving_tag, serving_sigdef_key)

            # Translate signature I/O keys into concrete column<->tensor mappings
            inputMapping = tfInputGraph.translateInputMapping({
                self.input_col: 'input_sig'
            })
            outputMapping = tfInputGraph.translateOutputMapping({
                'output_sig': self.output_col
            })
            trans_with_sig = TFTransformer(tfInputGraph=tfInputGraph,
                                           inputMapping=inputMapping,
                                           outputMapping=outputMapping)
            self.transformers.append(trans_with_sig)

            # Build the transformer from exported serving model
            # We are not using signatures, thus must provide tensor/operation names
            gin = TFInputGraph.fromSavedModel(
                saved_model_dir, serving_tag, self.feed_names, self.fetch_names)
            self.build_standard_transformers(sess, gin)

            gin = TFInputGraph.fromGraph(
                sess.graph, sess, self.feed_names, self.fetch_names)
            self.build_standard_transformers(sess, gin)

    def test_build_from_checkpoint(self):
        """ Build TFTransformer from a model checkpoint """
        # Build the TensorFlow graph
        model_ckpt_dir = self.model_output_root
        ckpt_path_prefix = os.path.join(model_ckpt_dir, 'model_ckpt')
        serving_sigdef_key = 'prediction_signature'

        with self._run_test_in_tf_session() as sess:
            x = tf.placeholder(tf.float64, shape=[None, self.vec_size], name=self.input_op_name)
            w = tf.Variable(tf.random_normal([self.vec_size], dtype=tf.float64),
                            dtype=tf.float64, name='varW')
            z = tf.reduce_mean(x * w, axis=1, name=self.output_op_name)
            sess.run(w.initializer)
            saver = tf.train.Saver(var_list=[w])
            _ = saver.save(sess, ckpt_path_prefix, global_step=2702)

            # Prepare the signature_def
            serving_sigdef = tf.saved_model.signature_def_utils.build_signature_def(
                inputs={
                    'input_sig': tf.saved_model.utils.build_tensor_info(x)
                },
                outputs={
                    'output_sig': tf.saved_model.utils.build_tensor_info(z)
                })

            # A rather contrived way to add signature def to a meta_graph
            meta_graph_def = tf.train.export_meta_graph()

            # Find the meta_graph file (there should be only one)
            _ckpt_meta_fpaths = glob('{}/*.meta'.format(model_ckpt_dir))
            self.assertEqual(len(_ckpt_meta_fpaths), 1, msg=','.join(_ckpt_meta_fpaths))
            ckpt_meta_fpath = _ckpt_meta_fpaths[0]

            # Add signature_def to the meta_graph and serialize it
            # This will overwrite the existing meta_graph_def file
            meta_graph_def.signature_def[serving_sigdef_key].CopyFrom(serving_sigdef)
            with open(ckpt_meta_fpath, mode='wb') as fout:
                fout.write(meta_graph_def.SerializeToString())

            # Build the transformer from exported serving model
            # We are using signatures, thus must provide the keys
            tfInputGraph = TFInputGraph.fromCheckpointWithSignature(
                model_ckpt_dir, serving_sigdef_key)

            inputMapping = tfInputGraph.translateInputMapping({
                self.input_col: 'input_sig'
            })
            outputMapping = tfInputGraph.translateOutputMapping({
                'output_sig': self.output_col
            })
            trans_with_sig = TFTransformer(tfInputGraph=tfInputGraph,
                                           inputMapping=inputMapping,
                                           outputMapping=outputMapping)
            self.transformers.append(trans_with_sig)

            # Transformer without using signature_def
            gin = TFInputGraph.fromCheckpoint(model_ckpt_dir, self.feed_names, self.fetch_names)
            self.build_standard_transformers(sess, gin)

            gin = TFInputGraph.fromGraph(
                sess.graph, sess, self.feed_names, self.fetch_names)
            self.build_standard_transformers(sess, gin)

    def test_multi_io(self):
        """ Build TFTransformer with multiple I/O tensors """
        # Three replicated input columns feeding three independent outputs
        self.setup_iomap(replica=3)
        with self._run_test_in_tf_session() as sess:
            xs = []
            for tnsr_op_name in self.input_mapping.values():
                x = tf.placeholder(tf.float64, shape=[None, self.vec_size], name=tnsr_op_name)
                xs.append(x)

            zs = []
            for i, tnsr_op_name in enumerate(self.output_mapping.keys()):
                z = tf.reduce_mean(xs[i], axis=1, name=tnsr_op_name)
                zs.append(z)

            self._build_default_session_tests(sess)

    def test_mixed_keras_graph(self):
        """ Build mixed keras graph """
        # Build a graph mixing raw TF ops with Keras layers inside an
        # IsolatedSession, export it as a GraphFunction, then import that
        # GraphDef into a fresh session for the standard transformer tests.
        with IsolatedSession(using_keras=True) as issn:
            tnsr_in = tf.placeholder(
                tf.double, shape=[None, self.vec_size], name=self.input_op_name)
            inp = tf.expand_dims(tnsr_in, axis=2)
            # Keras layers does not take tf.double
            inp = tf.cast(inp, tf.float32)
            conv = Conv1D(filters=4, kernel_size=2)(inp)
            pool = MaxPool1D(pool_size=2)(conv)
            flat = Flatten()(pool)
            dense = Dense(1)(flat)
            # We must keep the leading dimension of the output
            redsum = tf.reduce_logsumexp(dense, axis=1)
            tnsr_out = tf.cast(redsum, tf.double, name=self.output_op_name)

            # Initialize the variables
            init_op = tf.global_variables_initializer()
            issn.run(init_op)
            # We could train the model ... but skip it here
            gfn = issn.asGraphFunction([tnsr_in], [tnsr_out])

        with self._run_test_in_tf_session() as sess:
            tf.import_graph_def(gfn.graph_def, name='')
            self._build_default_session_tests(sess)