diff --git a/python/docs/sparkdl.rst b/python/docs/sparkdl.rst index c92e60cc..bf0c86f8 100644 --- a/python/docs/sparkdl.rst +++ b/python/docs/sparkdl.rst @@ -6,8 +6,10 @@ Subpackages .. toctree:: + sparkdl.estimators sparkdl.graph sparkdl.image + sparkdl.param sparkdl.transformers sparkdl.udf sparkdl.utils diff --git a/python/requirements.txt b/python/requirements.txt index a98a4d17..9d2133fa 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -4,6 +4,7 @@ h5py>=2.7.0 keras==2.0.4 # NOTE: this package has only been tested with keras 2.0.4 and may not work with other releases nose>=1.3.7 # for testing numpy>=1.11.2 +parameterized>=0.6.1 # for testing pillow>=4.1.1,<4.2 pygments>=2.2.0 tensorflow==1.3.0 diff --git a/python/sparkdl/__init__.py b/python/sparkdl/__init__.py index aa15059a..06b91bc8 100644 --- a/python/sparkdl/__init__.py +++ b/python/sparkdl/__init__.py @@ -13,15 +13,17 @@ # limitations under the License. # +from .graph.input import TFInputGraph from .image.imageIO import imageSchema, imageType, readImages from .transformers.keras_image import KerasImageFileTransformer from .transformers.named_image import DeepImagePredictor, DeepImageFeaturizer from .transformers.tf_image import TFImageTransformer +from .transformers.tf_tensor import TFTransformer from .transformers.utils import imageInputPlaceholder + __all__ = [ 'imageSchema', 'imageType', 'readImages', - 'TFImageTransformer', - 'DeepImagePredictor', 'DeepImageFeaturizer', - 'KerasImageFileTransformer', + 'TFImageTransformer', 'TFInputGraph', 'TFTransformer', + 'DeepImagePredictor', 'DeepImageFeaturizer', 'KerasImageFileTransformer', 'imageInputPlaceholder'] diff --git a/python/sparkdl/graph/builder.py b/python/sparkdl/graph/builder.py index 86c3b3ce..a7d7122f 100644 --- a/python/sparkdl/graph/builder.py +++ b/python/sparkdl/graph/builder.py @@ -47,19 +47,20 @@ def __init__(self, graph=None, using_keras=False): self.graph = graph or tf.Graph() self.sess = tf.Session(graph=self.graph) if using_keras: + self.using_keras = True self.keras_prev_sess = K.get_session() else: + self.using_keras = False self.keras_prev_sess = None def __enter__(self): - self.sess.as_default() self.sess.__enter__() - if self.keras_prev_sess is not None: + if self.using_keras: K.set_session(self.sess) return self def __exit__(self, *args): - if self.keras_prev_sess is not None: + if self.using_keras: K.set_session(self.keras_prev_sess) self.sess.__exit__(*args) @@ -87,8 +88,8 @@ def asGraphFunction(self, inputs, outputs, strip_and_freeze=True): else: gdef = self.graph.as_graph_def(add_shapes=True) return GraphFunction(graph_def=gdef, - input_names=[tfx.validated_input(self.graph, elem) for elem in inputs], - output_names=[tfx.validated_output(self.graph, elem) for elem in outputs]) + input_names=[tfx.validated_input(elem, self.graph) for elem in inputs], + output_names=[tfx.validated_output(elem, self.graph) for elem in outputs]) def importGraphFunction(self, gfn, input_map=None, prefix="GFN-IMPORT", **gdef_kargs): """ @@ -130,8 +131,8 @@ def importGraphFunction(self, gfn, input_map=None, prefix="GFN-IMPORT", **gdef_k return_elements=gfn.output_names, name=scope_name, **gdef_kargs) - feeds = [tfx.get_tensor(self.graph, name) for name in input_names] - fetches = [tfx.get_tensor(self.graph, name) for name in output_names] + feeds = [tfx.get_tensor(name, self.graph) for name in input_names] + fetches = [tfx.get_tensor(name, self.graph) for name in output_names] return (feeds, fetches) @@ -233,7 +234,7 @@ def fromList(cls, functions): _, first_gfn = functions[0] feeds, _ = issn.importGraphFunction(first_gfn, prefix='') for tnsr in feeds: - name = tfx.op_name(issn.graph, tnsr) + name = tfx.op_name(tnsr, issn.graph) first_input_info.append((tnsr.dtype, tnsr.shape, name)) # TODO: make sure that this graph is not reused to prevent name conflict # Report error if the graph is not manipulated by anyone else @@ -268,4 +269,3 @@ def fromList(cls, functions): gfn = issn.asGraphFunction(first_inputs, last_outputs) return gfn - diff --git a/python/sparkdl/graph/input.py b/python/sparkdl/graph/input.py new file mode 100644 index 00000000..67ab1119 --- /dev/null +++ b/python/sparkdl/graph/input.py @@ -0,0 +1,355 @@ +# Copyright 2017 Databricks, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import absolute_import, division, print_function + +import tensorflow as tf +from tensorflow.core.protobuf import meta_graph_pb2 # pylint: disable=no-name-in-module + +import sparkdl.graph.utils as tfx + +__all__ = ["TFInputGraph"] + +# pylint: disable=invalid-name,wrong-spelling-in-comment,wrong-spelling-in-docstring + +class TFInputGraph(object): + """ + An opaque object containing TensorFlow graph. + This object can be serialized. + + .. note:: We recommend constructing this object using one of the class constructor methods. + + - :py:meth:`fromGraph` + - :py:meth:`fromGraphDef` + - :py:meth:`fromCheckpoint` + - :py:meth:`fromCheckpointWithSignature` + - :py:meth:`fromSavedModel` + - :py:meth:`fromSavedModelWithSignature` + + + When the graph contains serving signatures in which a set of well-known names are associated + with their corresponding raw tensor names in the graph, we extract and store them here. + For example, the TensorFlow saved model may contain the following structure, + so that end users can retrieve the the input tensor via `well_known_input_sig` and + the output tensor via `well_known_output_sig` without knowing the actual tensor names a priori. + + .. code-block:: python + + sigdef: {'well_known_prediction_signature': + inputs { key: "well_known_input_sig" + value { + name: "tnsrIn:0" + dtype: DT_DOUBLE + tensor_shape { dim { size: -1 } dim { size: 17 } } + } + } + outputs { key: "well_known_output_sig" + value { + name: "tnsrOut:0" + dtype: DT_DOUBLE + tensor_shape { dim { size: -1 } } + } + }} + + + In this case, the class will internally store the mapping from signature names to tensor names. + + .. code-block:: python + + {'well_known_input_sig': 'tnsrIn:0'} + {'well_known_output_sig': 'tnsrOut:0'} + + + :param graph_def: :py:obj:`tf.GraphDef`, a serializable object containing the topology and + computation units of the TensorFlow graph. The graph object is prepared for + inference, i.e. the variables are converted to constants and operations like + BatchNormalization_ are converted to be independent of input batch. + + .. _BatchNormalization: https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization + + :param input_tensor_name_from_signature: dict, signature key names mapped to tensor names. + Please see the example above. + :param output_tensor_name_from_signature: dict, signature key names mapped to tensor names + Please see the example above. + """ + + + def __init__(self, graph_def, input_tensor_name_from_signature, + output_tensor_name_from_signature): + self.graph_def = graph_def + self.input_tensor_name_from_signature = input_tensor_name_from_signature + self.output_tensor_name_from_signature = output_tensor_name_from_signature + + def translateInputMapping(self, input_mapping): + """ + When the meta_graph contains signature_def, we expect users to provide + input and output mapping with respect to the tensor reference keys + embedded in the `signature_def`. + + This function translates the input_mapping into the canonical format, + which maps input DataFrame column names to tensor names. + + :param input_mapping: dict, DataFrame column name to tensor reference names + defined in the signature_def key. + """ + assert self.input_tensor_name_from_signature is not None + _input_mapping = {} + if isinstance(input_mapping, dict): + input_mapping = list(input_mapping.items()) + assert isinstance(input_mapping, list) + for col_name, sig_key in input_mapping: + tnsr_name = self.input_tensor_name_from_signature[sig_key] + _input_mapping[col_name] = tnsr_name + return _input_mapping + + def translateOutputMapping(self, output_mapping): + """ + When the meta_graph contains signature_def, we expect users to provide + input and output mapping with respect to the tensor reference keys + embedded in the `signature_def`. + + This function translates the output_mapping into the canonical format, + which maps tensor names into input DataFrame column names. + + :param output_mapping: dict, tensor reference names defined in the signature_def keys + into the output DataFrame column names. + """ + assert self.output_tensor_name_from_signature is not None + _output_mapping = {} + if isinstance(output_mapping, dict): + output_mapping = list(output_mapping.items()) + assert isinstance(output_mapping, list) + for sig_key, col_name in output_mapping: + tnsr_name = self.output_tensor_name_from_signature[sig_key] + _output_mapping[tnsr_name] = col_name + return _output_mapping + + @classmethod + def fromGraph(cls, graph, sess, feed_names, fetch_names): + """ + Construct a TFInputGraph from a in memory `tf.Graph` object. + The graph might contain variables that are maintained in the provided session. + Thus we need an active session in which the graph's variables are initialized or + restored. We do not close the session. As a result, this constructor can be used + inside a standard TensorFlow session context. + + .. code-block:: python + + with tf.Session() as sess: + graph = import_my_tensorflow_graph(...) + input = TFInputGraph.fromGraph(graph, sess, ...) + + :param graph: a :py:class:`tf.Graph` object containing the topology and computation units of + the TensorFlow graph. + :param feed_names: list, names of the input tensors. + :param fetch_names: list, names of the output tensors. + """ + return _build_with_feeds_fetches(sess=sess, graph=graph, feed_names=feed_names, + fetch_names=fetch_names) + + @classmethod + def fromGraphDef(cls, graph_def, feed_names, fetch_names): + """ + Construct a TFInputGraph from a tf.GraphDef object. + + :param graph_def: :py:class:`tf.GraphDef`, a serializable object containing the topology and + computation units of the TensorFlow graph. + :param feed_names: list, names of the input tensors. + :param fetch_names: list, names of the output tensors. + """ + assert isinstance(graph_def, tf.GraphDef), \ + ('expect tf.GraphDef type but got', type(graph_def)) + + graph = tf.Graph() + with tf.Session(graph=graph) as sess: + tf.import_graph_def(graph_def, name='') + return _build_with_feeds_fetches(sess=sess, graph=graph, feed_names=feed_names, + fetch_names=fetch_names) + + @classmethod + def fromCheckpoint(cls, checkpoint_dir, feed_names, fetch_names): + """ + Construct a TFInputGraph object from a checkpoint, ignore the embedded + signature_def, if there is any. + + :param checkpoint_dir: str, name of the directory containing the TensorFlow graph + training checkpoint. + :param feed_names: list, names of the input tensors. + :param fetch_names: list, names of the output tensors. + """ + return _from_checkpoint_impl(checkpoint_dir, signature_def_key=None, feed_names=feed_names, + fetch_names=fetch_names) + + @classmethod + def fromCheckpointWithSignature(cls, checkpoint_dir, signature_def_key): + """ + Construct a TFInputGraph object from a checkpoint, using the embedded + signature_def. Throw an error if we cannot find an entry with the `signature_def_key` + inside the `signature_def`. + + :param checkpoint_dir: str, name of the directory containing the TensorFlow graph + training checkpoint. + :param signature_def_key: str, key (name) of the signature_def to use. It should be in + the list of `signature_def` structures saved with the checkpoint. + """ + assert signature_def_key is not None + return _from_checkpoint_impl(checkpoint_dir, signature_def_key, feed_names=None, + fetch_names=None) + + @classmethod + def fromSavedModel(cls, saved_model_dir, tag_set, feed_names, fetch_names): + """ + Construct a TFInputGraph object from a saved model (`tf.SavedModel`) directory. + Ignore the the embedded signature_def, if there is any. + + :param saved_model_dir: str, name of the directory containing the TensorFlow graph + training checkpoint. + :param tag_set: str, name of the graph stored in this meta_graph of the saved model + that we are interested in using. + :param feed_names: list, names of the input tensors. + :param fetch_names: list, names of the output tensors. + """ + return _from_saved_model_impl(saved_model_dir, tag_set, signature_def_key=None, + feed_names=feed_names, fetch_names=fetch_names) + + @classmethod + def fromSavedModelWithSignature(cls, saved_model_dir, tag_set, signature_def_key): + """ + Construct a TFInputGraph object from a saved model (`tf.SavedModel`) directory, + using the embedded signature_def. Throw error if we cannot find an entry with + the `signature_def_key` inside the `signature_def`. + + :param saved_model_dir: str, name of the directory containing the TensorFlow graph + training checkpoint. + :param tag_set: str, name of the graph stored in this meta_graph of the saved model + that we are interested in using. + :param signature_def_key: str, key (name) of the signature_def to use. It should be in + the list of `signature_def` structures saved with the + TensorFlow `SavedModel`. + """ + assert signature_def_key is not None + return _from_saved_model_impl(saved_model_dir, tag_set, signature_def_key=signature_def_key, + feed_names=None, fetch_names=None) + + +def _from_checkpoint_impl(checkpoint_dir, signature_def_key, feed_names, fetch_names): + """ + Construct a TFInputGraph from a model checkpoint. + Notice that one should either provide the `signature_def_key` or provide both + `feed_names` and `fetch_names`. Please set the unprovided values to None. + + :param signature_def_key: str, name of the mapping contained inside the `signature_def` + from which we retrieve the signature key to tensor names mapping. + :param feed_names: list, names of the input tensors. + :param fetch_names: list, names of the output tensors. + """ + assert (feed_names is None) == (fetch_names is None), \ + 'feed_names and fetch_names, if provided must be both non-None.' + assert (feed_names is None) != (signature_def_key is None), \ + 'must either provide feed_names or singnature_def_key' + + graph = tf.Graph() + with tf.Session(graph=graph) as sess: + # Load checkpoint and import the graph + ckpt_path = tf.train.latest_checkpoint(checkpoint_dir) + + # NOTE(phi-dbq): we must manually load meta_graph_def to get the signature_def + # the current `import_graph_def` function seems to ignore + # any signature_def fields in a checkpoint's meta_graph_def. + meta_graph_def = meta_graph_pb2.MetaGraphDef() + with open("{}.meta".format(ckpt_path), 'rb') as fin: + meta_graph_def.ParseFromString(fin.read()) + + saver = tf.train.import_meta_graph(meta_graph_def, clear_devices=True) + saver.restore(sess, ckpt_path) + + if signature_def_key is not None: + sig_def = meta_graph_def.signature_def[signature_def_key] + return _build_with_sig_def(sess=sess, graph=graph, sig_def=sig_def) + else: + return _build_with_feeds_fetches(sess=sess, graph=graph, feed_names=feed_names, + fetch_names=fetch_names) + +def _from_saved_model_impl(saved_model_dir, tag_set, signature_def_key, feed_names, fetch_names): + """ + Construct a TFInputGraph from a SavedModel. + Notice that one should either provide the `signature_def_key` or provide both + `feed_names` and `fetch_names`. Please set the unprovided values to None. + + :param signature_def_key: str, name of the mapping contained inside the `signature_def` + from which we retrieve the signature key to tensor names mapping. + :param feed_names: list, names of the input tensors. + :param fetch_names: list, names of the output tensors. + """ + assert (feed_names is None) == (fetch_names is None), \ + 'feed_names and fetch_names, if provided must be both non-None.' + assert (feed_names is None) != (signature_def_key is None), \ + 'must either provide feed_names or singnature_def_key' + + graph = tf.Graph() + with tf.Session(graph=graph) as sess: + tag_sets = tag_set.split(',') + meta_graph_def = tf.saved_model.loader.load(sess, tag_sets, saved_model_dir) + + if signature_def_key is not None: + sig_def = tf.contrib.saved_model.get_signature_def_by_key(meta_graph_def, + signature_def_key) + return _build_with_sig_def(sess=sess, graph=graph, sig_def=sig_def) + else: + return _build_with_feeds_fetches(sess=sess, graph=graph, feed_names=feed_names, + fetch_names=fetch_names) + + +def _build_with_sig_def(sess, graph, sig_def): + # pylint: disable=protected-access + assert sig_def, 'signature_def must not be None' + + with sess.as_default(), graph.as_default(): + feed_mapping = {} + feed_names = [] + for sigdef_key, tnsr_info in sig_def.inputs.items(): + tnsr_name = tnsr_info.name + feed_mapping[sigdef_key] = tnsr_name + feed_names.append(tnsr_name) + + fetch_mapping = {} + fetch_names = [] + for sigdef_key, tnsr_info in sig_def.outputs.items(): + tnsr_name = tnsr_info.name + fetch_mapping[sigdef_key] = tnsr_name + fetch_names.append(tnsr_name) + + for tnsr_name in feed_names: + assert tfx.get_op(tnsr_name, graph), \ + 'requested tensor {} but found none in graph {}'.format(tnsr_name, graph) + fetches = [tfx.get_tensor(tnsr_name, graph) for tnsr_name in fetch_names] + graph_def = tfx.strip_and_freeze_until(fetches, graph, sess) + + return TFInputGraph(graph_def=graph_def, input_tensor_name_from_signature=feed_mapping, + output_tensor_name_from_signature=fetch_mapping) + + +def _build_with_feeds_fetches(sess, graph, feed_names, fetch_names): + assert feed_names is not None, "must provide feed_names" + assert fetch_names is not None, "must provide fetch names" + + with sess.as_default(), graph.as_default(): + for tnsr_name in feed_names: + assert tfx.get_op(tnsr_name, graph), \ + 'requested tensor {} but found none in graph {}'.format(tnsr_name, graph) + fetches = [tfx.get_tensor(tnsr_name, graph) for tnsr_name in fetch_names] + graph_def = tfx.strip_and_freeze_until(fetches, graph, sess) + + return TFInputGraph(graph_def=graph_def, input_tensor_name_from_signature=None, + output_tensor_name_from_signature=None) diff --git a/python/sparkdl/graph/tensorframes_udf.py b/python/sparkdl/graph/tensorframes_udf.py index 54027b8d..aa1531b4 100644 --- a/python/sparkdl/graph/tensorframes_udf.py +++ b/python/sparkdl/graph/tensorframes_udf.py @@ -33,7 +33,7 @@ def makeGraphUDF(graph, udf_name, fetches, feeds_to_fields_map=None, blocked=Fal .. code-block:: python from sparkdl.graph.tensorframes_udf import makeUDF - + with IsolatedSession() as issn: x = tf.placeholder(tf.double, shape=[], name="input_x") z = tf.add(x, 3, name='z') @@ -45,7 +45,7 @@ def makeGraphUDF(graph, udf_name, fetches, feeds_to_fields_map=None, blocked=Fal df = spark.createDataFrame([Row(xCol=float(x)) for x in range(100)]) df.createOrReplaceTempView("my_float_table") - spark.sql("select my_tensorflow_udf(xCol) as zCol from my_float_table").show() + spark.sql("select my_tensorflow_udf(xCol) as zCol from my_float_table").show() :param graph: :py:class:`tf.Graph`, a TensorFlow Graph :param udf_name: str, name of the SQL UDF @@ -77,18 +77,18 @@ def makeGraphUDF(graph, udf_name, fetches, feeds_to_fields_map=None, blocked=Fal tfs.core._add_graph(graph, jvm_builder) # Obtain the fetches and their shapes - fetch_names = [tfx.tensor_name(graph, fetch) for fetch in fetches] - fetch_shapes = [tfx.get_shape(graph, fetch) for fetch in fetches] + fetch_names = [tfx.tensor_name(fetch, graph) for fetch in fetches] + fetch_shapes = [tfx.get_shape(fetch, graph) for fetch in fetches] # Traverse the graph nodes and obtain all the placeholders and their shapes placeholder_names = [] placeholder_shapes = [] for node in graph.as_graph_def(add_shapes=True).node: if len(node.input) == 0 and str(node.op) == 'Placeholder': - tnsr_name = tfx.tensor_name(graph, node.name) + tnsr_name = tfx.tensor_name(node.name, graph) tnsr = graph.get_tensor_by_name(tnsr_name) try: - tnsr_shape = tfx.get_shape(graph, tnsr) + tnsr_shape = tfx.get_shape(tnsr, graph) placeholder_names.append(tnsr_name) placeholder_shapes.append(tnsr_shape) except ValueError: @@ -98,7 +98,7 @@ def makeGraphUDF(graph, udf_name, fetches, feeds_to_fields_map=None, blocked=Fal jvm_builder.shape(fetch_names + placeholder_names, fetch_shapes + placeholder_shapes) jvm_builder.fetches(fetch_names) # Passing feeds to TensorFrames - placeholder_op_names = [tfx.op_name(graph, name) for name in placeholder_names] + placeholder_op_names = [tfx.op_name(name, graph) for name in placeholder_names] # Passing the graph input to DataFrame column mapping and additional placeholder names tfs.core._add_inputs(jvm_builder, feeds_to_fields_map, placeholder_op_names) diff --git a/python/sparkdl/graph/utils.py b/python/sparkdl/graph/utils.py index 45d8b065..64e093fe 100644 --- a/python/sparkdl/graph/utils.py +++ b/python/sparkdl/graph/utils.py @@ -16,8 +16,6 @@ import logging import six -import webbrowser -from tempfile import NamedTemporaryFile import tensorflow as tf @@ -35,14 +33,15 @@ def validated_graph(graph): """ - Check if the input is a valid tf.Graph + Check if the input is a valid :py:class:`tf.Graph` and return it. + Raise an error otherwise. - :param graph: tf.Graph, a TensorFlow Graph object + :param graph: :py:class:`tf.Graph`, a TensorFlow Graph object """ assert isinstance(graph, tf.Graph), 'must provide tf.Graph, but get {}'.format(type(graph)) return graph -def get_shape(graph, tfobj_or_name): +def get_shape(tfobj_or_name, graph): """ Return the shape of the tensor as a list @@ -50,38 +49,44 @@ def get_shape(graph, tfobj_or_name): :param tfobj_or_name: either a tf.Tensor, tf.Operation or a name to either """ graph = validated_graph(graph) - _shape = get_tensor(graph, tfobj_or_name).get_shape().as_list() + _shape = get_tensor(tfobj_or_name, graph).get_shape().as_list() return [-1 if x is None else x for x in _shape] -def get_op(graph, tfobj_or_name): +def get_op(tfobj_or_name, graph): """ - Get a tf.Operation object + Get a :py:class:`tf.Operation` object. - :param graph: tf.Graph, a TensorFlow Graph object - :param tfobj_or_name: either a tf.Tensor, tf.Operation or a name to either + :param tfobj_or_name: either a :py:class:`tf.Tensor`, :py:class:`tf.Operation` or + a name to either. + :param graph: a :py:class:`tf.Graph` object containing the operation. + By default the graph we don't require this argument to be provided. """ graph = validated_graph(graph) + _assert_same_graph(tfobj_or_name, graph) if isinstance(tfobj_or_name, tf.Operation): return tfobj_or_name name = tfobj_or_name if isinstance(tfobj_or_name, tf.Tensor): name = tfobj_or_name.name if not isinstance(name, six.string_types): - raise TypeError('invalid op request for {} of {}'.format(name, type(name))) - _op_name = as_op_name(name) + raise TypeError('invalid op request for [type {}] {}'.format(type(name), name)) + _op_name = op_name(name, graph=None) op = graph.get_operation_by_name(_op_name) - assert op is not None, \ - 'cannot locate op {} in current graph'.format(_op_name) + err_msg = 'cannot locate op {} in the current graph, got [type {}] {}' + assert isinstance(op, tf.Operation), err_msg.format(_op_name, type(op), op) return op -def get_tensor(graph, tfobj_or_name): +def get_tensor(tfobj_or_name, graph): """ - Get a tf.Tensor object + Get a :py:class:`tf.Tensor` object - :param graph: tf.Graph, a TensorFlow Graph object - :param tfobj_or_name: either a tf.Tensor, tf.Operation or a name to either + :param tfobj_or_name: either a :py:class:`tf.Tensor`, :py:class:`tf.Operation` or + a name to either. + :param graph: a :py:class:`tf.Graph` object containing the tensor. + By default the graph we don't require this argument to be provided. """ graph = validated_graph(graph) + _assert_same_graph(tfobj_or_name, graph) if isinstance(tfobj_or_name, tf.Tensor): return tfobj_or_name name = tfobj_or_name @@ -89,59 +94,71 @@ def get_tensor(graph, tfobj_or_name): name = tfobj_or_name.name if not isinstance(name, six.string_types): raise TypeError('invalid tensor request for {} of {}'.format(name, type(name))) - _tensor_name = as_tensor_name(name) + _tensor_name = tensor_name(name, graph=None) tnsr = graph.get_tensor_by_name(_tensor_name) - assert tnsr is not None, \ - 'cannot locate tensor {} in current graph'.format(_tensor_name) + err_msg = 'cannot locate tensor {} in the current graph, got [type {}] {}' + assert isinstance(tnsr, tf.Tensor), err_msg.format(_tensor_name, type(tnsr), tnsr) return tnsr -def as_tensor_name(name): - """ - Derive tf.Tensor name from an op/tensor name. - We do not check if the tensor exist (as no graph parameter is passed in). - - :param name: op name or tensor name - """ - assert isinstance(name, six.string_types) - name_parts = name.split(":") - assert len(name_parts) <= 2, name_parts - if len(name_parts) < 2: - name += ":0" - return name - -def as_op_name(name): - """ - Derive tf.Operation name from an op/tensor name - We do not check if the operation exist (as no graph parameter is passed in). - - :param name: op name or tensor name - """ - assert isinstance(name, six.string_types) - name_parts = name.split(":") - assert len(name_parts) <= 2, name_parts - return name_parts[0] - -def op_name(graph, tfobj_or_name): - """ - Get the name of a tf.Operation - - :param graph: tf.Graph, a TensorFlow Graph object - :param tfobj_or_name: either a tf.Tensor, tf.Operation or a name to either - """ - graph = validated_graph(graph) - return get_op(graph, tfobj_or_name).name - -def tensor_name(graph, tfobj_or_name): - """ - Get the name of a tf.Tensor - - :param graph: tf.Graph, a TensorFlow Graph object - :param tfobj_or_name: either a tf.Tensor, tf.Operation or a name to either - """ - graph = validated_graph(graph) - return get_tensor(graph, tfobj_or_name).name +def tensor_name(tfobj_or_name, graph=None): + """ + Derive the :py:class:`tf.Tensor` name from a :py:class:`tf.Operation` or :py:class:`tf.Tensor` + object, or its name. + If a name is provided and the graph is not, we will derive the tensor name based on + TensorFlow's naming convention. + If the input is a TensorFlow object, or the graph is given, we also check that + the tensor exists in the associated graph. + + :param tfobj_or_name: either a :py:class:`tf.Tensor`, :py:class:`tf.Operation` or + a name to either. + :param graph: a :py:class:`tf.Graph` object containing the tensor. + By default the graph we don't require this argument to be provided. + """ + if graph is not None: + return get_tensor(tfobj_or_name, graph).name + if isinstance(tfobj_or_name, six.string_types): + # If input is a string, assume it is a name and infer the corresponding tensor name. + # WARNING: this depends on TensorFlow's tensor naming convention + name = tfobj_or_name + name_parts = name.split(":") + assert len(name_parts) <= 2, name_parts + if len(name_parts) < 2: + name += ":0" + return name + elif hasattr(tfobj_or_name, 'graph'): + return get_tensor(tfobj_or_name, tfobj_or_name.graph).name + else: + raise TypeError('invalid tf.Tensor name query type {}'.format(type(tfobj_or_name))) + +def op_name(tfobj_or_name, graph=None): + """ + Derive the :py:class:`tf.Operation` name from a :py:class:`tf.Operation` or + :py:class:`tf.Tensor` object, or its name. + If a name is provided and the graph is not, we will derive the operation name based on + TensorFlow's naming convention. + If the input is a TensorFlow object, or the graph is given, we also check that + the operation exists in the associated graph. + + :param tfobj_or_name: either a :py:class:`tf.Tensor`, :py:class:`tf.Operation` or + a name to either. + :param graph: a :py:class:`tf.Graph` object containing the operation. + By default the graph we don't require this argument to be provided. + """ + if graph is not None: + return get_op(tfobj_or_name, graph).name + if isinstance(tfobj_or_name, six.string_types): + # If input is a string, assume it is a name and infer the corresponding operation name. + # WARNING: this depends on TensorFlow's operation naming convention + name = tfobj_or_name + name_parts = name.split(":") + assert len(name_parts) <= 2, name_parts + return name_parts[0] + elif hasattr(tfobj_or_name, 'graph'): + return get_op(tfobj_or_name, tfobj_or_name.graph).name + else: + raise TypeError('invalid tf.Operation name query type {}'.format(type(tfobj_or_name))) -def validated_output(graph, tfobj_or_name): +def validated_output(tfobj_or_name, graph): """ Validate and return the output names useable GraphFunction @@ -149,9 +166,9 @@ def validated_output(graph, tfobj_or_name): :param tfobj_or_name: either a tf.Tensor, tf.Operation or a name to either """ graph = validated_graph(graph) - return op_name(graph, tfobj_or_name) + return op_name(tfobj_or_name, graph) -def validated_input(graph, tfobj_or_name): +def validated_input(tfobj_or_name, graph): """ Validate and return the input names useable GraphFunction @@ -159,7 +176,7 @@ def validated_input(graph, tfobj_or_name): :param tfobj_or_name: either a tf.Tensor, tf.Operation or a name to either """ graph = validated_graph(graph) - name = op_name(graph, tfobj_or_name) + name = op_name(tfobj_or_name, graph) op = graph.get_operation_by_name(name) assert 'Placeholder' == op.type, \ ('input must be Placeholder, but get', op.type) @@ -186,7 +203,7 @@ def strip_and_freeze_until(fetches, graph, sess=None, return_graph=False): gdef_frozen = tf.graph_util.convert_variables_to_constants( sess, graph.as_graph_def(add_shapes=True), - [op_name(graph, tnsr) for tnsr in fetches]) + [op_name(tnsr, graph) for tnsr in fetches]) if should_close_session: sess.close() @@ -198,3 +215,9 @@ def strip_and_freeze_until(fetches, graph, sess=None, return_graph=False): return g else: return gdef_frozen + + +def _assert_same_graph(tfobj, graph): + if graph is not None and hasattr(tfobj, 'graph'): + err_msg = 'the graph of TensorFlow element {} != graph {}' + assert tfobj.graph == graph, err_msg.format(tfobj, graph) diff --git a/python/sparkdl/param/__init__.py b/python/sparkdl/param/__init__.py index 98a8f7dd..ca1a9121 100644 --- a/python/sparkdl/param/__init__.py +++ b/python/sparkdl/param/__init__.py @@ -14,7 +14,11 @@ # from sparkdl.param.shared_params import ( - keyword_only, HasInputCol, HasOutputCol, HasLabelCol, HasKerasModel, - HasKerasLoss, HasKerasOptimizer, HasOutputNodeName, SparkDLTypeConverters) + keyword_only, HasInputCol, HasOutputCol, HasLabelCol, + # TFTransformer Params + HasInputMapping, HasOutputMapping, HasTFInputGraph, HasTFHParams, + # Keras Estimator Params + HasKerasModel, HasKerasLoss, HasKerasOptimizer, HasOutputNodeName) +from sparkdl.param.converters import SparkDLTypeConverters from sparkdl.param.image_params import ( CanLoadImage, HasInputImageNodeName, HasOutputMode, OUTPUT_MODES) diff --git a/python/sparkdl/param/converters.py b/python/sparkdl/param/converters.py new file mode 100644 index 00000000..25a2e3a1 --- /dev/null +++ b/python/sparkdl/param/converters.py @@ -0,0 +1,199 @@ +# Copyright 2017 Databricks, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# pylint: disable=invalid-name,import-error + +""" SparkDLTypeConverters + +Type conversion utilities for defining MLlib `Params` used in Spark Deep Learning Pipelines. + +.. note:: We follow the convention of MLlib to name these utilities "converters", + but most of them act as type checkers that return the argument if it is + the desired type and raise `TypeError` otherwise. +""" + +import six + +import tensorflow as tf + +from pyspark.ml.param import TypeConverters + +from sparkdl.graph.input import * +import sparkdl.utils.keras_model as kmutil + +__all__ = ['SparkDLTypeConverters'] + +class SparkDLTypeConverters(object): + """ + .. note:: DeveloperApi + + Methods for type conversion functions for :py:func:`Param.typeConverter`. + These methods are similar to :py:class:`spark.ml.param.TypeConverters`. + They provide support for the `Params` types introduced in Spark Deep Learning Pipelines. + """ + + @staticmethod + def toTFGraph(value): + """ + Convert a value to a :py:obj:`tf.Graph` object, if possible. + """ + if not isinstance(value, tf.Graph): + raise TypeError("Could not convert %s to tf.Graph" % type(value)) + return value + + @staticmethod + def toTFInputGraph(value): + if isinstance(value, TFInputGraph): + return value + else: + raise TypeError("Could not convert %s to TFInputGraph" % type(value)) + + @staticmethod + def asColumnToTensorNameMap(value): + """ + Convert a value to a column name to :py:class:`tf.Tensor` name mapping + as a sorted list (in lexicographical order) of string pairs, if possible. + """ + if not isinstance(value, dict): + err_msg = "Could not convert [type {}] {} to column name to tf.Tensor name mapping" + raise TypeError(err_msg.format(type(value), value)) + + strs_pair_seq = [] + for _maybe_col_name, _maybe_tnsr_name in value.items(): + _check_is_str(_maybe_col_name) + _check_is_tensor_name(_maybe_tnsr_name) + strs_pair_seq.append((_maybe_col_name, _maybe_tnsr_name)) + + return sorted(strs_pair_seq) + + @staticmethod + def asTensorNameToColumnMap(value): + """ + Convert a value to a :py:class:`tf.Tensor` name to column name mapping + as a sorted list (in lexicographical order) of string pairs, if possible. + """ + if not isinstance(value, dict): + err_msg = "Could not convert [type {}] {} to tf.Tensor name to column name mapping" + raise TypeError(err_msg.format(type(value), value)) + + strs_pair_seq = [] + for _maybe_tnsr_name, _maybe_col_name in value.items(): + _check_is_str(_maybe_col_name) + _check_is_tensor_name(_maybe_tnsr_name) + strs_pair_seq.append((_maybe_tnsr_name, _maybe_col_name)) + + return sorted(strs_pair_seq) + + @staticmethod + def toTFHParams(value): + """ + Check that the given value is a :py:class:`tf.contrib.training.HParams` object, + and return it. Raise an error otherwise. + """ + if not isinstance(value, tf.contrib.training.HParams): + raise TypeError("Could not convert %s to TensorFlow HParams" % type(value)) + + return value + + @staticmethod + def toTFTensorName(value): + """ + Check if a value is a valid :py:class:`tf.Tensor` name and return it. + Raise an error otherwise. + """ + if isinstance(value, tf.Tensor): + return value.name + try: + _check_is_tensor_name(value) + return value + except Exception as exc: + err_msg = "Could not convert [type {}] {} to tf.Tensor name. {}" + raise TypeError(err_msg.format(type(value), value, exc)) + + @staticmethod + def buildSupportedItemConverter(supportedList): + """ + Create a "converter" that try to check if a value is part of the supported list of values. + + :param supportedList: list, containing supported objects. + :return: a converter that try to check if a value is part of the `supportedList` and return it. + Raise an error otherwise. + """ + + def converter(value): + """ Implementing the conversion logic """ + if value not in supportedList: + err_msg = "[type {}] {} is not in the supported list: {}" + raise TypeError(err_msg.format(type(value), str(value), supportedList)) + + return value + + return converter + + @staticmethod + def toKerasLoss(value): + """ + Check if a value is a valid Keras loss function name and return it. + Otherwise raise an error. + """ + # return early in for clarify as well as less indentation + if not kmutil.is_valid_loss_function(value): + err_msg = "Named loss not supported in Keras: [type {}] {}" + raise ValueError(err_msg.format(type(value), value)) + + return value + + @staticmethod + def toKerasOptimizer(value): + """ + Check if a value is a valid name of Keras optimizer and return it. + Otherwise raise an error. + """ + if not kmutil.is_valid_optimizer(value): + err_msg = "Named optimizer not supported in Keras: [type {}] {}" + raise TypeError(err_msg.format(type(value), value)) + + return value + + +def _check_is_tensor_name(_maybe_tnsr_name): + """ Check if the input is a valid tensor name or raise a `TypeError` otherwise. """ + if not isinstance(_maybe_tnsr_name, six.string_types): + err_msg = "expect tensor name to be of string type, but got [type {}]" + raise TypeError(err_msg.format(type(_maybe_tnsr_name))) + + # The check is taken from TensorFlow's NodeDef protocol buffer. + # Each input is "node:src_output" with "node" being a string name and + # "src_output" indicating which output tensor to use from "node". If + # "src_output" is 0 the ":0" suffix can be omitted. Regular inputs + # may optionally be followed by control inputs that have the format + # "^node". + # Reference: + # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/node_def.proto + # https://stackoverflow.com/questions/36150834/how-does-tensorflow-name-tensors + try: + _, src_idx = _maybe_tnsr_name.split(":") + _ = int(src_idx) + except Exception: + err_msg = "Tensor name must be of type :, but got {}" + raise TypeError(err_msg.format(_maybe_tnsr_name)) + + +def _check_is_str(_maybe_str): + """ Check if the value is a valid string type or raise a `TypeError` otherwise. """ + # We only check if the column name candidate is a string type + if not isinstance(_maybe_str, six.string_types): + err_msg = 'expect string type but got type {} for {}' + raise TypeError(err_msg.format(type(_maybe_str), _maybe_str)) diff --git a/python/sparkdl/param/image_params.py b/python/sparkdl/param/image_params.py index 6807ce2a..6ca2ff6d 100644 --- a/python/sparkdl/param/image_params.py +++ b/python/sparkdl/param/image_params.py @@ -107,7 +107,7 @@ class HasOutputMode(Params): "How the output column should be formatted. 'vector' for a 1-d MLlib " + "Vector of floats. 'image' to format the output to work with the image " + "tools in this package.", - typeConverter=SparkDLTypeConverters.supportedNameConverter(OUTPUT_MODES)) + typeConverter=SparkDLTypeConverters.buildSupportedItemConverter(OUTPUT_MODES)) def setOutputMode(self, value): return self._set(outputMode=value) diff --git a/python/sparkdl/param/shared_params.py b/python/sparkdl/param/shared_params.py index e169e891..1116aa54 100644 --- a/python/sparkdl/param/shared_params.py +++ b/python/sparkdl/param/shared_params.py @@ -12,22 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. # - """ Some parts are copied from pyspark.ml.param.shared and some are complementary to pyspark.ml.param. The copy is due to some useful pyspark fns/classes being private APIs. """ - +import textwrap from functools import wraps - -import tensorflow as tf +import six from pyspark.ml.param import Param, Params, TypeConverters -import sparkdl.utils.keras_model as kmutil +from sparkdl.graph.input import TFInputGraph +from sparkdl.param.converters import SparkDLTypeConverters + +######################################################## +# Copied from PySpark for backward compatibility. +# They first appeared in Apache Spark version 2.1.1. +######################################################## -# From pyspark def keyword_only(func): """ @@ -36,12 +39,14 @@ def keyword_only(func): .. note:: Should only be used to wrap a method where first arg is `self` """ + @wraps(func) def wrapper(self, *args, **kwargs): if len(args) > 0: raise TypeError("Method %s forces keyword arguments." % func.__name__) self._input_kwargs = kwargs return func(self, **kwargs) + return wrapper @@ -50,10 +55,8 @@ class HasInputCol(Params): Mixin for param inputCol: input column name. """ - inputCol = Param(Params._dummy(), "inputCol", "input column name.", typeConverter=TypeConverters.toString) - - def __init__(self): - super(HasInputCol, self).__init__() + inputCol = Param(Params._dummy(), "inputCol", "input column name.", + typeConverter=TypeConverters.toString) def setInputCol(self, value): """ @@ -73,8 +76,8 @@ class HasOutputCol(Params): Mixin for param outputCol: output column name. """ - outputCol = Param(Params._dummy(), - "outputCol", "output column name.", typeConverter=TypeConverters.toString) + outputCol = Param(Params._dummy(), "outputCol", "output column name.", + typeConverter=TypeConverters.toString) def __init__(self): super(HasOutputCol, self).__init__() @@ -92,54 +95,10 @@ def getOutputCol(self): """ return self.getOrDefault(self.outputCol) -############################################ + +######################################################## # New in sparkdl -############################################ - -class SparkDLTypeConverters(object): - - @staticmethod - def toStringOrTFTensor(value): - if isinstance(value, tf.Tensor): - return value - else: - try: - return TypeConverters.toString(value) - except TypeError: - raise TypeError("Could not convert %s to tensorflow.Tensor or str" % type(value)) - - @staticmethod - def toTFGraph(value): - # TODO: we may want to support tf.GraphDef in the future instead of tf.Graph since user - # is less likely to mess up using GraphDef vs Graph (e.g. constants vs variables). - if isinstance(value, tf.Graph): - return value - else: - raise TypeError("Could not convert %s to tensorflow.Graph type" % type(value)) - - @staticmethod - def supportedNameConverter(supportedList): - def converter(value): - if value in supportedList: - return value - else: - raise TypeError("%s %s is not in the supported list." % type(value), str(value)) - - return converter - - @staticmethod - def toKerasLoss(value): - if kmutil.is_valid_loss_function(value): - return value - raise ValueError( - "Named loss not supported in Keras: {} type({})".format(value, type(value))) - - @staticmethod - def toKerasOptimizer(value): - if kmutil.is_valid_optimizer(value): - return value - raise TypeError( - "Named optimizer not supported in Keras: {} type({})".format(value, type(value))) +######################################################## class HasOutputNodeName(Params): @@ -233,3 +192,76 @@ def seKerasLoss(self, value): def getKerasLoss(self): return self.getOrDefault(self.kerasLoss) + + +class HasOutputMapping(Params): + """ + Mixin for param outputMapping: ordered list of ('outputTensorOpName', 'outputColName') pairs + """ + outputMapping = Param(Params._dummy(), + "outputMapping", + "Mapping output :class:`tf.Tensor` names to DataFrame column names", + typeConverter=SparkDLTypeConverters.asTensorNameToColumnMap) + + def setOutputMapping(self, value): + return self._set(outputMapping=value) + + def getOutputMapping(self): + return self.getOrDefault(self.outputMapping) + + +class HasInputMapping(Params): + """ + Mixin for param inputMapping: ordered list of ('inputColName', 'inputTensorOpName') pairs + """ + inputMapping = Param(Params._dummy(), + "inputMapping", + "Mapping input DataFrame column names to :class:`tf.Tensor` names", + typeConverter=SparkDLTypeConverters.asColumnToTensorNameMap) + + def setInputMapping(self, value): + return self._set(inputMapping=value) + + def getInputMapping(self): + return self.getOrDefault(self.inputMapping) + + +class HasTFInputGraph(Params): + """ + Mixin for param tfInputGraph: a serializable object derived from a TensorFlow computation graph. + """ + tfInputGraph = Param(Params._dummy(), + "tfInputGraph", + "A serializable object derived from a TensorFlow computation graph", + typeConverter=SparkDLTypeConverters.toTFInputGraph) + + def __init__(self): + super(HasTFInputGraph, self).__init__() + self._setDefault(tfInputGraph=None) + + def setTFInputGraph(self, value): + return self._set(tfInputGraph=value) + + def getTFInputGraph(self): + return self.getOrDefault(self.tfInputGraph) + + +class HasTFHParams(Params): + """ + Mixin for TensorFlow model hyper-parameters + """ + tfHParams = Param(Params._dummy(), "hparams", + textwrap.dedent("""\ + instance of :class:`tf.contrib.training.HParams`, a namespace-like + key-value object, storing parameters to be used to define the final + TensorFlow graph for the Transformer. + + Currently used values are: + - `batch_size`: number of samples evaluated together in inference steps"""), + typeConverter=SparkDLTypeConverters.toTFHParams) + + def setTFHParams(self, value): + return self._set(tfHParam=value) + + def getTFHParams(self): + return self.getOrDefault(self.tfHParams) diff --git a/python/sparkdl/transformers/keras_applications.py b/python/sparkdl/transformers/keras_applications.py index 733ac654..50c30d4f 100644 --- a/python/sparkdl/transformers/keras_applications.py +++ b/python/sparkdl/transformers/keras_applications.py @@ -109,4 +109,3 @@ def _testKerasModel(self, include_top): "InceptionV3": InceptionV3Model, "Xception": XceptionModel } - diff --git a/python/sparkdl/transformers/keras_image.py b/python/sparkdl/transformers/keras_image.py index de10fc87..3c2762d9 100644 --- a/python/sparkdl/transformers/keras_image.py +++ b/python/sparkdl/transformers/keras_image.py @@ -76,14 +76,14 @@ def _transform(self, dataset): return transformer.transform(image_df).drop(self._loadedImageCol()) def _loadTFGraph(self): - with KSessionWrap() as (sess, g): + with KSessionWrap() as (sess, graph): assert K.backend() == "tensorflow", \ "Keras backend is not tensorflow but KerasImageTransformer only supports " + \ "tensorflow-backed Keras models." - with g.as_default(): + with graph.as_default(): K.set_learning_phase(0) # Testing phase model = load_model(self.getModelFile()) - out_op_name = tfx.op_name(g, model.output) + out_op_name = tfx.op_name(model.output, graph) self._inputTensor = model.input.name self._outputTensor = model.output.name - return tfx.strip_and_freeze_until([out_op_name], g, sess, return_graph=True) + return tfx.strip_and_freeze_until([out_op_name], graph, sess, return_graph=True) diff --git a/python/sparkdl/transformers/named_image.py b/python/sparkdl/transformers/named_image.py index 156c4e1e..f3139b7a 100644 --- a/python/sparkdl/transformers/named_image.py +++ b/python/sparkdl/transformers/named_image.py @@ -40,7 +40,7 @@ class DeepImagePredictor(Transformer, HasInputCol, HasOutputCol): """ modelName = Param(Params._dummy(), "modelName", "A deep learning model name", - typeConverter=SparkDLTypeConverters.supportedNameConverter(SUPPORTED_MODELS)) + typeConverter=SparkDLTypeConverters.buildSupportedItemConverter(SUPPORTED_MODELS)) decodePredictions = Param(Params._dummy(), "decodePredictions", "If true, output predictions in the (class, description, probability) format", typeConverter=TypeConverters.toBoolean) @@ -125,7 +125,7 @@ class DeepImageFeaturizer(Transformer, HasInputCol, HasOutputCol): """ modelName = Param(Params._dummy(), "modelName", "A deep learning model name", - typeConverter=SparkDLTypeConverters.supportedNameConverter(SUPPORTED_MODELS)) + typeConverter=SparkDLTypeConverters.buildSupportedItemConverter(SUPPORTED_MODELS)) @keyword_only def __init__(self, inputCol=None, outputCol=None, modelName=None): @@ -169,7 +169,7 @@ class _NamedImageTransformer(Transformer, HasInputCol, HasOutputCol): """ modelName = Param(Params._dummy(), "modelName", "A deep learning model name", - typeConverter=SparkDLTypeConverters.supportedNameConverter(SUPPORTED_MODELS)) + typeConverter=SparkDLTypeConverters.buildSupportedItemConverter(SUPPORTED_MODELS)) featurize = Param(Params._dummy(), "featurize", "If true, output features. If false, output predictions. Either way the output is a vector.", typeConverter=TypeConverters.toBoolean) diff --git a/python/sparkdl/transformers/tf_image.py b/python/sparkdl/transformers/tf_image.py index da37fcad..152a7fea 100644 --- a/python/sparkdl/transformers/tf_image.py +++ b/python/sparkdl/transformers/tf_image.py @@ -28,6 +28,12 @@ import sparkdl.utils.jvmapi as JVMAPI import sparkdl.graph.utils as tfx +__all__ = ['TFImageTransformer'] + +IMAGE_INPUT_TENSOR_NAME = tfx.tensor_name(utils.IMAGE_INPUT_PLACEHOLDER_NAME) +USER_GRAPH_NAMESPACE = 'given' +NEW_OUTPUT_PREFIX = 'sdl_flattened' + class TFImageTransformer(Transformer, HasInputCol, HasOutputCol, HasOutputMode): """ Applies the Tensorflow graph to the image column in DataFrame. @@ -47,42 +53,39 @@ class TFImageTransformer(Transformer, HasInputCol, HasOutputCol, HasOutputMode): since a new session is created inside this transformer. """ - USER_GRAPH_NAMESPACE = 'given' - NEW_OUTPUT_PREFIX = 'sdl_flattened' - graph = Param(Params._dummy(), "graph", "A TensorFlow computation graph", typeConverter=SparkDLTypeConverters.toTFGraph) inputTensor = Param(Params._dummy(), "inputTensor", "A TensorFlow tensor object or name representing the input image", - typeConverter=SparkDLTypeConverters.toStringOrTFTensor) + typeConverter=SparkDLTypeConverters.toTFTensorName) outputTensor = Param(Params._dummy(), "outputTensor", "A TensorFlow tensor object or name representing the output", - typeConverter=SparkDLTypeConverters.toStringOrTFTensor) + typeConverter=SparkDLTypeConverters.toTFTensorName) @keyword_only def __init__(self, inputCol=None, outputCol=None, graph=None, - inputTensor=utils.IMAGE_INPUT_PLACEHOLDER_NAME, outputTensor=None, + inputTensor=IMAGE_INPUT_TENSOR_NAME, outputTensor=None, outputMode="vector"): """ __init__(self, inputCol=None, outputCol=None, graph=None, - inputTensor=utils.IMAGE_INPUT_PLACEHOLDER_NAME, outputTensor=None, + inputTensor=IMAGE_INPUT_TENSOR_NAME, outputTensor=None, outputMode="vector") """ super(TFImageTransformer, self).__init__() - self._setDefault(inputTensor=utils.IMAGE_INPUT_PLACEHOLDER_NAME) - self._setDefault(outputMode="vector") kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only def setParams(self, inputCol=None, outputCol=None, graph=None, - inputTensor=utils.IMAGE_INPUT_PLACEHOLDER_NAME, outputTensor=None, + inputTensor=IMAGE_INPUT_TENSOR_NAME, outputTensor=None, outputMode="vector"): """ setParams(self, inputCol=None, outputCol=None, graph=None, - inputTensor=utils.IMAGE_INPUT_PLACEHOLDER_NAME, outputTensor=None, + inputTensor=IMAGE_INPUT_TENSOR_NAME, outputTensor=None, outputMode="vector") """ + self._setDefault(inputTensor=IMAGE_INPUT_TENSOR_NAME) + self._setDefault(outputMode="vector") kwargs = self._input_kwargs return self._set(**kwargs) @@ -99,18 +102,12 @@ def getGraph(self): return self.getOrDefault(self.graph) def getInputTensor(self): - tensor_or_name = self.getOrDefault(self.inputTensor) - if isinstance(tensor_or_name, tf.Tensor): - return tensor_or_name - else: - return self.getGraph().get_tensor_by_name(tensor_or_name) + tensor_name = self.getOrDefault(self.inputTensor) + return self.getGraph().get_tensor_by_name(tensor_name) def getOutputTensor(self): - tensor_or_name = self.getOrDefault(self.outputTensor) - if isinstance(tensor_or_name, tf.Tensor): - return tensor_or_name - else: - return self.getGraph().get_tensor_by_name(tensor_or_name) + tensor_name = self.getOrDefault(self.outputTensor) + return self.getGraph().get_tensor_by_name(tensor_name) def _transform(self, dataset): graph = self.getGraph() @@ -139,7 +136,7 @@ def _transform(self, dataset): "__sdl_image_data") ) - tfs_output_name = tfx.op_name(final_graph, output_tensor) + tfs_output_name = tfx.op_name(output_tensor, final_graph) original_output_name = self._getOriginalOutputTensorName() output_shape = final_graph.get_tensor_by_name(original_output_name).shape output_mode = self.getOrDefault(self.outputMode) @@ -185,7 +182,7 @@ def _addReshapeLayers(self, tf_graph, dtype="uint8"): # Add on the original graph tf.import_graph_def(gdef, input_map={input_tensor_name: image_reshaped_expanded}, return_elements=[self.getOutputTensor().name], - name=self.USER_GRAPH_NAMESPACE) + name=USER_GRAPH_NAMESPACE) # Flatten the output for tensorframes output_node = g.get_tensor_by_name(self._getOriginalOutputTensorName()) @@ -204,13 +201,13 @@ def _stripGraph(self, tf_graph): return g def _getOriginalOutputTensorName(self): - return self.USER_GRAPH_NAMESPACE + '/' + self.getOutputTensor().name + return USER_GRAPH_NAMESPACE + '/' + self.getOutputTensor().name def _getFinalOutputTensorName(self): - return self.NEW_OUTPUT_PREFIX + '_' + self.getOutputTensor().name + return NEW_OUTPUT_PREFIX + '_' + self.getOutputTensor().name def _getFinalOutputOpName(self): - return tfx.as_op_name(self._getFinalOutputTensorName()) + return tfx.op_name(self._getFinalOutputTensorName()) def _convertOutputToImage(self, df, tfs_output_col, output_shape): assert len(output_shape) == 4, str(output_shape) + " does not have 4 dimensions" diff --git a/python/sparkdl/transformers/tf_tensor.py b/python/sparkdl/transformers/tf_tensor.py new file mode 100644 index 00000000..7207f5f1 --- /dev/null +++ b/python/sparkdl/transformers/tf_tensor.py @@ -0,0 +1,105 @@ +# Copyright 2017 Databricks, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import absolute_import, division, print_function + +import logging +import tensorflow as tf +from tensorflow.python.tools import optimize_for_inference_lib as infr_opt +import tensorframes as tfs + +from pyspark.ml import Transformer + +import sparkdl.graph.utils as tfx +from sparkdl.param import (keyword_only, HasInputMapping, HasOutputMapping, + HasTFInputGraph, HasTFHParams) + +__all__ = ['TFTransformer'] + +logger = logging.getLogger('sparkdl') + +class TFTransformer(Transformer, HasTFInputGraph, HasTFHParams, HasInputMapping, HasOutputMapping): + """ + Applies the TensorFlow graph to the array column in DataFrame. + + Restrictions of the current API: + + We assume that + - All the inputs of the graphs have a "minibatch" dimension (i.e. an unknown leading + dimension) in the tensor shapes. + - Input DataFrame has an array column where all elements have the same length + - The transformer is expected to work on blocks of data at the same time. + """ + + @keyword_only + def __init__(self, tfInputGraph=None, inputMapping=None, outputMapping=None, tfHParms=None): + """ + __init__(self, tfInputGraph=None, inputMapping=None, outputMapping=None, tfHParms=None) + """ + super(TFTransformer, self).__init__() + kwargs = self._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, tfInputGraph=None, inputMapping=None, outputMapping=None, tfHParms=None): + """ + setParams(self, tfInputGraph=None, inputMapping=None, outputMapping=None, tfHParms=None) + """ + super(TFTransformer, self).__init__() + kwargs = self._input_kwargs + # Further conanonicalization, e.g. converting dict to sorted str pairs happens here + return self._set(**kwargs) + + def _optimize_for_inference(self): + """ Optimize the graph for inference """ + gin = self.getTFInputGraph() + input_mapping = self.getInputMapping() + output_mapping = self.getOutputMapping() + input_node_names = [tfx.op_name(tnsr_name) for _, tnsr_name in input_mapping] + output_node_names = [tfx.op_name(tnsr_name) for tnsr_name, _ in output_mapping] + + # NOTE(phi-dbq): Spark DataFrame assumes float64 as default floating point type + opt_gdef = infr_opt.optimize_for_inference(gin.graph_def, + input_node_names, + output_node_names, + # TODO: below is the place to change for + # the `float64` data type issue. + tf.float64.as_datatype_enum) + return opt_gdef + + def _transform(self, dataset): + graph_def = self._optimize_for_inference() + input_mapping = self.getInputMapping() + output_mapping = self.getOutputMapping() + + graph = tf.Graph() + with tf.Session(graph=graph): + analyzed_df = tfs.analyze(dataset) + + out_tnsr_op_names = [tfx.op_name(tnsr_name) for tnsr_name, _ in output_mapping] + tf.import_graph_def(graph_def=graph_def, name='', return_elements=out_tnsr_op_names) + + feed_dict = dict((tfx.op_name(tnsr_name, graph), col_name) + for col_name, tnsr_name in input_mapping) + fetches = [tfx.get_tensor(tnsr_op_name, graph) for tnsr_op_name in out_tnsr_op_names] + + out_df = tfs.map_blocks(fetches, analyzed_df, feed_dict=feed_dict) + + # We still have to rename output columns + for tnsr_name, new_colname in output_mapping: + old_colname = tfx.op_name(tnsr_name, graph) + if old_colname != new_colname: + out_df = out_df.withColumnRenamed(old_colname, new_colname) + + return out_df diff --git a/python/tests/graph/test_builder.py b/python/tests/graph/test_builder.py index b0736896..93b3c9f5 100644 --- a/python/tests/graph/test_builder.py +++ b/python/tests/graph/test_builder.py @@ -78,15 +78,15 @@ def test_get_graph_elements(self): z = tf.add(x, 3, name='z') g = issn.graph - self.assertEqual(tfx.get_tensor(g, z), z) - self.assertEqual(tfx.get_tensor(g, x), x) - self.assertEqual(g.get_tensor_by_name("x:0"), tfx.get_tensor(g, x)) - self.assertEqual("x:0", tfx.tensor_name(g, x)) - self.assertEqual(g.get_operation_by_name("x"), tfx.get_op(g, x)) - self.assertEqual("x", tfx.op_name(g, x)) - self.assertEqual("z", tfx.op_name(g, z)) - self.assertEqual(tfx.tensor_name(g, z), "z:0") - self.assertEqual(tfx.tensor_name(g, x), "x:0") + self.assertEqual(tfx.get_tensor(z, g), z) + self.assertEqual(tfx.get_tensor(x, g), x) + self.assertEqual(g.get_tensor_by_name("x:0"), tfx.get_tensor(x, g)) + self.assertEqual("x:0", tfx.tensor_name(x, g)) + self.assertEqual(g.get_operation_by_name("x"), tfx.get_op(x, g)) + self.assertEqual("x", tfx.op_name(x, g)) + self.assertEqual("z", tfx.op_name(z, g)) + self.assertEqual(tfx.tensor_name(z, g), "z:0") + self.assertEqual(tfx.tensor_name(x, g), "x:0") def test_import_export_graph_function(self): """ Function import and export must be consistent """ diff --git a/python/tests/graph/test_import.py b/python/tests/graph/test_import.py new file mode 100644 index 00000000..36501568 --- /dev/null +++ b/python/tests/graph/test_import.py @@ -0,0 +1,322 @@ +# Copyright 2017 Databricks, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import absolute_import, division, print_function + +import contextlib +import shutil +import numpy as np +import os +import tensorflow as tf +import tempfile +import glob + +import sparkdl.graph.utils as tfx +from sparkdl.graph.input import TFInputGraph + + +class TestGraphImport(object): + def test_graph_novar(self): + gin = _build_graph_input(lambda session: + TFInputGraph.fromGraph(session.graph, session, [_tensor_input_name], + [_tensor_output_name])) + _check_input_novar(gin) + + def test_graphdef_novar(self): + gin = _build_graph_input(lambda session: + TFInputGraph.fromGraphDef(session.graph.as_graph_def(), + [_tensor_input_name], [_tensor_output_name])) + _check_input_novar(gin) + + def test_saved_model_novar(self): + with _make_temp_directory() as tmp_dir: + saved_model_dir = os.path.join(tmp_dir, 'saved_model') + + def gin_fun(session): + _build_saved_model(session, saved_model_dir) + # Build the transformer from exported serving model + # We are using signatures, thus must provide the keys + return TFInputGraph.fromSavedModelWithSignature(saved_model_dir, _serving_tag, + _serving_sigdef_key) + + gin = _build_graph_input(gin_fun) + _check_input_novar(gin) + + def test_saved_model_iomap(self): + with _make_temp_directory() as tmp_dir: + saved_model_dir = os.path.join(tmp_dir, 'saved_model') + graph = tf.Graph() + with tf.Session(graph=graph) as sess, graph.as_default(): + _build_graph() + _build_saved_model(sess, saved_model_dir) + # Build the transformer from exported serving model + # We are using signatures, thus must provide the keys + gin = TFInputGraph.fromSavedModelWithSignature(saved_model_dir, _serving_tag, + _serving_sigdef_key) + + _input_mapping_with_sigdef = {'inputCol': _tensor_input_signature} + # Input mapping for the Transformer + _translated_input_mapping = gin.translateInputMapping(_input_mapping_with_sigdef) + _expected_input_mapping = {'inputCol': tfx.tensor_name(_tensor_input_name)} + # Output mapping for the Transformer + _output_mapping_with_sigdef = {_tensor_output_signature: 'outputCol'} + _translated_output_mapping = gin.translateOutputMapping(_output_mapping_with_sigdef) + _expected_output_mapping = {tfx.tensor_name(_tensor_output_name): 'outputCol'} + + err_msg = "signature based input mapping {} and output mapping {} " + \ + "must be translated correctly into tensor name based mappings" + assert _translated_input_mapping == _expected_input_mapping \ + and _translated_output_mapping == _expected_output_mapping, \ + err_msg.format(_translated_input_mapping, _translated_output_mapping) + + + def test_saved_graph_novar(self): + with _make_temp_directory() as tmp_dir: + saved_model_dir = os.path.join(tmp_dir, 'saved_model') + + def gin_fun(session): + _build_saved_model(session, saved_model_dir) + return TFInputGraph.fromGraph(session.graph, session, [_tensor_input_name], [_tensor_output_name]) + + gin = _build_graph_input(gin_fun) + _check_input_novar(gin) + + def test_checkpoint_sig_var(self): + with _make_temp_directory() as tmp_dir: + def gin_fun(session): + _build_checkpointed_model(session, tmp_dir) + return TFInputGraph.fromCheckpointWithSignature(tmp_dir, _serving_sigdef_key) + + gin = _build_graph_input_var(gin_fun) + _check_input_novar(gin) + + def test_checkpoint_nosig_var(self): + with _make_temp_directory() as tmp_dir: + def gin_fun(session): + _build_checkpointed_model(session, tmp_dir) + return TFInputGraph.fromCheckpoint(tmp_dir, + [_tensor_input_name], [_tensor_output_name]) + + gin = _build_graph_input_var(gin_fun) + _check_input_novar(gin) + + def test_checkpoint_graph_var(self): + with _make_temp_directory() as tmp_dir: + def gin_fun(session): + _build_checkpointed_model(session, tmp_dir) + return TFInputGraph.fromGraph(session.graph, session, + [_tensor_input_name], [_tensor_output_name]) + + gin = _build_graph_input_var(gin_fun) + _check_input_novar(gin) + + def test_graphdef_novar_2(self): + gin = _build_graph_input_2(lambda session: + TFInputGraph.fromGraphDef(session.graph.as_graph_def(), + [_tensor_input_name], [_tensor_output_name])) + _check_output_2(gin, np.array([1, 2, 3]), np.array([2, 2, 2]), 1) + + def test_saved_graph_novar_2(self): + with _make_temp_directory() as tmp_dir: + saved_model_dir = os.path.join(tmp_dir, 'saved_model') + + def gin_fun(session): + _build_saved_model(session, saved_model_dir) + return TFInputGraph.fromGraph(session.graph, session, [_tensor_input_name], [_tensor_output_name]) + + gin = _build_graph_input_2(gin_fun) + _check_output_2(gin, np.array([1, 2, 3]), np.array([2, 2, 2]), 1) + +_serving_tag = "serving_tag" +_serving_sigdef_key = 'prediction_signature' +# The name of the input tensor +_tensor_input_name = "input_tensor" +# For testing graphs with 2 inputs +_tensor_input_name_2 = "input_tensor_2" +# The name of the output tensor (scalar) +_tensor_output_name = "output_tensor" +# Input signature name +_tensor_input_signature = 'well_known_input_sig' +# Output signature name +_tensor_output_signature = 'well_known_output_sig' +# The name of the variable +_tensor_var_name = "variable" +# The size of the input tensor +_tensor_size = 3 + + +def _build_checkpointed_model(session, tmp_dir): + """ + Writes a model checkpoint in the given directory. The graph is assumed to be generated + with _build_graph_var. + """ + ckpt_path_prefix = os.path.join(tmp_dir, 'model_ckpt') + input_tensor = tfx.get_tensor(_tensor_input_name, session.graph) + output_tensor = tfx.get_tensor(_tensor_output_name, session.graph) + w = tfx.get_tensor(_tensor_var_name, session.graph) + saver = tf.train.Saver(var_list=[w]) + _ = saver.save(session, ckpt_path_prefix, global_step=2702) + sig_inputs = {_tensor_input_signature: tf.saved_model.utils.build_tensor_info(input_tensor)} + sig_outputs = {_tensor_output_signature: tf.saved_model.utils.build_tensor_info(output_tensor)} + serving_sigdef = tf.saved_model.signature_def_utils.build_signature_def( + inputs=sig_inputs, outputs=sig_outputs) + + # A rather contrived way to add signature def to a meta_graph + meta_graph_def = tf.train.export_meta_graph() + + # Find the meta_graph file (there should be only one) + _ckpt_meta_fpaths = glob.glob('{}/*.meta'.format(tmp_dir)) + assert len(_ckpt_meta_fpaths) == 1, \ + 'expected only one meta graph, but got {}'.format(','.join(_ckpt_meta_fpaths)) + ckpt_meta_fpath = _ckpt_meta_fpaths[0] + + # Add signature_def to the meta_graph and serialize it + # This will overwrite the existing meta_graph_def file + meta_graph_def.signature_def[_serving_sigdef_key].CopyFrom(serving_sigdef) + with open(ckpt_meta_fpath, mode='wb') as fout: + fout.write(meta_graph_def.SerializeToString()) + + +def _build_saved_model(session, saved_model_dir): + """ + Saves a model in a file. The graph is assumed to be generated with _build_graph_novar. + """ + builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir) + input_tensor = tfx.get_tensor(_tensor_input_name, session.graph) + output_tensor = tfx.get_tensor(_tensor_output_name, session.graph) + sig_inputs = {_tensor_input_signature: tf.saved_model.utils.build_tensor_info(input_tensor)} + sig_outputs = {_tensor_output_signature: tf.saved_model.utils.build_tensor_info(output_tensor)} + serving_sigdef = tf.saved_model.signature_def_utils.build_signature_def( + inputs=sig_inputs, outputs=sig_outputs) + + builder.add_meta_graph_and_variables( + session, [_serving_tag], signature_def_map={_serving_sigdef_key: serving_sigdef}) + builder.save() + + +@contextlib.contextmanager +def _make_temp_directory(): + temp_dir = tempfile.mkdtemp() + try: + yield temp_dir + finally: + shutil.rmtree(temp_dir) + + +def _build_graph_input(gin_function): + """ + Makes a session and a default graph, loads the simple graph into it, and then calls + gin_function(session) to return the graph input object + """ + graph = tf.Graph() + with tf.Session(graph=graph) as s, graph.as_default(): + _build_graph() + return gin_function(s) + + +def _build_graph_input_2(gin_function): + """ + Makes a session and a default graph, loads the simple graph into it (graph_2), and then calls + gin_function(session) to return the graph input object + """ + graph = tf.Graph() + with tf.Session(graph=graph) as s, graph.as_default(): + _build_graph_2() + return gin_function(s) + + +def _build_graph_input_var(gin_function): + """ + Makes a session and a default graph, loads the simple graph into it that contains a variable, + and then calls gin_function(session) to return the graph input object + """ + graph = tf.Graph() + with tf.Session(graph=graph) as s, graph.as_default(): + _build_graph_var(s) + return gin_function(s) + + +def _build_graph(): + """ + Given a session (implicitly), adds nodes of computations + + It takes a vector input, with vec_size columns and returns an int32 scalar. + """ + x = tf.placeholder(tf.int32, shape=[_tensor_size], name=_tensor_input_name) + _ = tf.reduce_max(x, name=_tensor_output_name) + + +def _build_graph_2(): + """ + Given a session (implicitly), adds nodes of computations with two inputs. + + It takes a vector input, with vec_size columns and returns an int32 scalar. + """ + x1 = tf.placeholder(tf.int32, shape=[_tensor_size], name=_tensor_input_name) + x2 = tf.placeholder(tf.int32, shape=[_tensor_size], name=_tensor_input_name_2) + # Make sure that the inputs are not used in a symmetric manner. + _ = tf.reduce_max(x1 - x2, name=_tensor_output_name) + + +def _build_graph_var(session): + """ + Given a session, adds nodes that include one variable. + """ + x = tf.placeholder(tf.int32, shape=[_tensor_size], name=_tensor_input_name) + w = tf.Variable(tf.ones(shape=[_tensor_size], dtype=tf.int32), name=_tensor_var_name) + _ = tf.reduce_max(x * w, name=_tensor_output_name) + session.run(w.initializer) + + +def _check_input_novar(gin): + """ + Tests that the graph from _build_graph has been serialized in the InputGraph object. + """ + _check_output(gin, np.array([1, 2, 3]), 3) + + +def _check_output(gin, tf_input, expected): + """ + Takes a TFInputGraph object (assumed to have the input and outputs of the given + names above) and compares the outcome against some expected outcome. + """ + graph = tf.Graph() + graph_def = gin.graph_def + with tf.Session(graph=graph) as sess: + tf.import_graph_def(graph_def, name="") + tgt_feed = tfx.get_tensor(_tensor_input_name, graph) + tgt_fetch = tfx.get_tensor(_tensor_output_name, graph) + # Run on the testing target + tgt_out = sess.run(tgt_fetch, feed_dict={tgt_feed: tf_input}) + # Working on integers, the calculation should be exact + assert np.all(tgt_out == expected), (tgt_out, expected) + + +# TODO: we could factorize with _check_output, but this is not worth the time doing it. +def _check_output_2(gin, tf_input1, tf_input2, expected): + """ + Takes a TFInputGraph object (assumed to have the input and outputs of the given + names above) and compares the outcome against some expected outcome. + """ + graph = tf.Graph() + graph_def = gin.graph_def + with tf.Session(graph=graph) as sess: + tf.import_graph_def(graph_def, name="") + tgt_feed1 = tfx.get_tensor(_tensor_input_name, graph) + tgt_feed2 = tfx.get_tensor(_tensor_input_name_2, graph) + tgt_fetch = tfx.get_tensor(_tensor_output_name, graph) + # Run on the testing target + tgt_out = sess.run(tgt_fetch, feed_dict={tgt_feed1: tf_input1, tgt_feed2: tf_input2}) + # Working on integers, the calculation should be exact + assert np.all(tgt_out == expected), (tgt_out, expected) diff --git a/python/tests/graph/test_pieces.py b/python/tests/graph/test_pieces.py index 1497d137..9d659265 100644 --- a/python/tests/graph/test_pieces.py +++ b/python/tests/graph/test_pieces.py @@ -55,7 +55,7 @@ def exec_gfn_spimg_decode(spimg_dict, img_dtype): gfn = gfac.buildSpImageConverter(img_dtype) with IsolatedSession() as issn: feeds, fetches = issn.importGraphFunction(gfn, prefix="") - feed_dict = dict((tnsr, spimg_dict[tfx.op_name(issn.graph, tnsr)]) for tnsr in feeds) + feed_dict = dict((tnsr, spimg_dict[tfx.op_name(tnsr, issn.graph)]) for tnsr in feeds) img_out = issn.run(fetches[0], feed_dict=feed_dict) return img_out @@ -159,7 +159,7 @@ def test_pipeline(self): with IsolatedSession() as issn: # Need blank import scope name so that spimg fields match the input names feeds, fetches = issn.importGraphFunction(piped_model, prefix="") - feed_dict = dict((tnsr, spimg_input_dict[tfx.op_name(issn.graph, tnsr)]) for tnsr in feeds) + feed_dict = dict((tnsr, spimg_input_dict[tfx.op_name(tnsr, issn.graph)]) for tnsr in feeds) preds_tgt = issn.run(fetches[0], feed_dict=feed_dict) # Uncomment the line below to see the graph # tfx.write_visualization_html(issn.graph, diff --git a/python/tests/graph/test_utils.py b/python/tests/graph/test_utils.py new file mode 100644 index 00000000..4847c9b1 --- /dev/null +++ b/python/tests/graph/test_utils.py @@ -0,0 +1,174 @@ +# Copyright 2017 Databricks, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import absolute_import, division, print_function + +from collections import namedtuple +# Use this to create parameterized test cases +from parameterized import parameterized + +import tensorflow as tf + +import sparkdl.graph.utils as tfx + +from ..tests import PythonUnitTestCase + +TestCase = namedtuple('TestCase', ['data', 'description']) + + +def _gen_tensor_op_string_input_tests(): + op_name = 'someOp' + for tnsr_idx in [0, 1, 2, 3, 5, 8, 15, 17]: + tnsr_name = '{}:{}'.format(op_name, tnsr_idx) + yield TestCase(data=(op_name, tfx.op_name(tnsr_name)), + description='test tensor name to op name') + yield TestCase(data=(tnsr_name, tfx.tensor_name(tnsr_name)), + description='test tensor name to tensor name') + + +def _gen_invalid_tensor_or_op_input_with_wrong_types(): + for wrong_val in [7, 1.2, tf.Graph()]: + yield TestCase(data=wrong_val, description='wrong type {}'.format(type(wrong_val))) + + +def _gen_invalid_tensor_or_op_with_graph_pairing(): + tnsr = tf.constant(1427.08, name='someConstOp') + other_graph = tf.Graph() + op_name = tnsr.op.name + + # Test get_tensor and get_op with non-associated tensor/op and graph inputs + _comm_suffix = ' with wrong graph' + yield TestCase(data=lambda: tfx.get_op(tnsr, other_graph), + description='test get_op from tensor' + _comm_suffix) + yield TestCase(data=lambda: tfx.get_tensor(tnsr, other_graph), + description='test get_tensor from tensor' + _comm_suffix) + yield TestCase(data=lambda: tfx.get_op(tnsr.name, other_graph), + description='test get_op from tensor name' + _comm_suffix) + yield TestCase(data=lambda: tfx.get_tensor(tnsr.name, other_graph), + description='test get_tensor from tensor name' + _comm_suffix) + yield TestCase(data=lambda: tfx.get_op(tnsr.op, other_graph), + description='test get_op from op' + _comm_suffix) + yield TestCase(data=lambda: tfx.get_tensor(tnsr.op, other_graph), + description='test get_tensor from op' + _comm_suffix) + yield TestCase(data=lambda: tfx.get_op(op_name, other_graph), + description='test get_op from op name' + _comm_suffix) + yield TestCase(data=lambda: tfx.get_tensor(op_name, other_graph), + description='test get_tensor from op name' + _comm_suffix) + + +def _gen_valid_tensor_op_input_combos(): + op_name = 'someConstOp' + tnsr_name = '{}:0'.format(op_name) + tnsr = tf.constant(1427.08, name=op_name) + graph = tnsr.graph + + # Test for op_name + yield TestCase(data=(op_name, tfx.op_name(tnsr)), + description='get op name from tensor (no graph)') + yield TestCase(data=(op_name, tfx.op_name(tnsr, graph)), + description='get op name from tensor (with graph)') + yield TestCase(data=(op_name, tfx.op_name(tnsr_name)), + description='get op name from tensor name (no graph)') + yield TestCase(data=(op_name, tfx.op_name(tnsr_name, graph)), + description='get op name from tensor name (with graph)') + yield TestCase(data=(op_name, tfx.op_name(tnsr.op)), + description='get op name from op (no graph)') + yield TestCase(data=(op_name, tfx.op_name(tnsr.op, graph)), + description='get op name from op (with graph)') + yield TestCase(data=(op_name, tfx.op_name(op_name)), + description='get op name from op name (no graph)') + yield TestCase(data=(op_name, tfx.op_name(op_name, graph)), + description='get op name from op name (with graph)') + + # Test for tensor_name + yield TestCase(data=(tnsr_name, tfx.tensor_name(tnsr)), + description='get tensor name from tensor (no graph)') + yield TestCase(data=(tnsr_name, tfx.tensor_name(tnsr, graph)), + description='get tensor name from tensor (with graph)') + yield TestCase(data=(tnsr_name, tfx.tensor_name(tnsr_name)), + description='get tensor name from tensor name (no graph)') + yield TestCase(data=(tnsr_name, tfx.tensor_name(tnsr_name, graph)), + description='get tensor name from tensor name (with graph)') + yield TestCase(data=(tnsr_name, tfx.tensor_name(tnsr.op)), + description='get tensor name from op (no graph)') + yield TestCase(data=(tnsr_name, tfx.tensor_name(tnsr.op, graph)), + description='get tensor name from op (with graph)') + yield TestCase(data=(tnsr_name, tfx.tensor_name(tnsr_name)), + description='get tensor name from op name (no graph)') + yield TestCase(data=(tnsr_name, tfx.tensor_name(tnsr_name, graph)), + description='get tensor name from op name (with graph)') + + # Test for get_tensor + yield TestCase(data=(tnsr, tfx.get_tensor(tnsr, graph)), + description='get tensor from tensor') + yield TestCase(data=(tnsr, tfx.get_tensor(tnsr_name, graph)), + description='get tensor from tensor name') + yield TestCase(data=(tnsr, tfx.get_tensor(tnsr.op, graph)), + description='get tensor from op') + yield TestCase(data=(tnsr, tfx.get_tensor(op_name, graph)), + description='get tensor from op name') + + # Test for get_op + yield TestCase(data=(tnsr.op, tfx.get_op(tnsr, graph)), + description='get op from tensor') + yield TestCase(data=(tnsr.op, tfx.get_op(tnsr_name, graph)), + description='get op from tensor name') + yield TestCase(data=(tnsr.op, tfx.get_op(tnsr.op, graph)), + description='get op from op') + yield TestCase(data=(tnsr.op, tfx.get_op(op_name, graph)), + description='test op from op name') + + +class TFeXtensionGraphUtilsTest(PythonUnitTestCase): + @parameterized.expand(_gen_tensor_op_string_input_tests) + def test_valid_tensor_op_name_inputs(self, data, description): + """ Must get correct names from valid graph element names """ + name_a, name_b = data + self.assertEqual(name_a, name_b, msg=description) + + @parameterized.expand(_gen_invalid_tensor_or_op_input_with_wrong_types) + def test_invalid_tensor_name_inputs_with_wrong_types(self, data, description): + """ Must fail when provided wrong types """ + with self.assertRaises(TypeError, msg=description): + tfx.tensor_name(data) + + @parameterized.expand(_gen_invalid_tensor_or_op_input_with_wrong_types) + def test_invalid_op_name_inputs_with_wrong_types(self, data, description): + """ Must fail when provided wrong types """ + with self.assertRaises(TypeError, msg=description): + tfx.op_name(data) + + @parameterized.expand(_gen_invalid_tensor_or_op_input_with_wrong_types) + def test_invalid_op_inputs_with_wrong_types(self, data, description): + """ Must fail when provided wrong types """ + with self.assertRaises(TypeError, msg=description): + tfx.get_op(data, tf.Graph()) + + @parameterized.expand(_gen_invalid_tensor_or_op_input_with_wrong_types) + def test_invalid_tensor_inputs_with_wrong_types(self, data, description): + """ Must fail when provided wrong types """ + with self.assertRaises(TypeError, msg=description): + tfx.get_tensor(data, tf.Graph()) + + @parameterized.expand(_gen_valid_tensor_op_input_combos) + def test_valid_tensor_op_object_inputs(self, data, description): + """ Must get correct graph elements from valid graph elements or their names """ + tfobj_or_name_a, tfobj_or_name_b = data + self.assertEqual(tfobj_or_name_a, tfobj_or_name_b, msg=description) + + @parameterized.expand(_gen_invalid_tensor_or_op_with_graph_pairing) + def test_invalid_tensor_op_object_graph_pairing(self, data, description): + """ Must fail with non-associated tensor/op and graph inputs """ + with self.assertRaises((KeyError, AssertionError, TypeError), msg=description): + data() diff --git a/python/tests/param/__init__.py b/python/tests/param/__init__.py new file mode 100644 index 00000000..7084f22b --- /dev/null +++ b/python/tests/param/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright 2017 Databricks, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/python/tests/param/params_test.py b/python/tests/param/params_test.py new file mode 100644 index 00000000..f7479385 --- /dev/null +++ b/python/tests/param/params_test.py @@ -0,0 +1,75 @@ +# Copyright 2017 Databricks, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import absolute_import, division, print_function + +from collections import namedtuple +# Use this to create parameterized test cases +from parameterized import parameterized + +from sparkdl.param.converters import SparkDLTypeConverters + +from ..tests import PythonUnitTestCase + +TestCase = namedtuple('TestCase', ['data', 'description']) + +_shared_invalid_test_cases = [ + TestCase(data=['a1', 'b2'], description='required pair but got single element'), + TestCase(data=('c3', 'd4'), description='required pair but got single element'), + TestCase(data=[('a', 1), ('b', 2)], description='only accept dict, but got list'), + TestCase(data={1: 'a', 2.0: 'b'}, description='wrong mapping type'), + TestCase(data={'a': 1.0, 'b': 2}, description='wrong mapping type'), +] +_col2tnsr_test_cases = _shared_invalid_test_cases + [ + TestCase(data={'colA': 'tnsrOpA', 'colB': 'tnsrOpB'}, + description='tensor name required'), +] +_tnsr2col_test_cases = _shared_invalid_test_cases + [ + TestCase(data={'tnsrOpA': 'colA', 'tnsrOpB': 'colB'}, + description='tensor name required'), +] + +class ParamsConverterTest(PythonUnitTestCase): + """ + Test MLlib Params introduced in Spark Deep Learning Pipeline + Additional test cases are attached via the meta class `TestGenMeta`. + """ + + def test_tf_input_mapping_converter(self): + """ Test valid input mapping conversion """ + valid_tnsr_input = {'colA': 'tnsrOpA:0', 'colB': 'tnsrOpB:0'} + valid_input_mapping_result = [('colA', 'tnsrOpA:0'), ('colB', 'tnsrOpB:0')] + + res = SparkDLTypeConverters.asColumnToTensorNameMap(valid_tnsr_input) + self.assertEqual(valid_input_mapping_result, res) + + def test_tf_output_mapping_converter(self): + """ Test valid output mapping conversion """ + valid_tnsr_output = {'tnsrOpA:0': 'colA', 'tnsrOpB:0': 'colB'} + valid_output_mapping_result = [('tnsrOpA:0', 'colA'), ('tnsrOpB:0', 'colB')] + + res = SparkDLTypeConverters.asTensorNameToColumnMap(valid_tnsr_output) + self.assertEqual(valid_output_mapping_result, res) + + @parameterized.expand(_col2tnsr_test_cases) + def test_invalid_input_mapping(self, data, description): + """ Test invalid column name to tensor name mapping """ + with self.assertRaises(TypeError, msg=description): + SparkDLTypeConverters.asColumnToTensorNameMap(data) + + @parameterized.expand(_tnsr2col_test_cases) + def test_invalid_output_mapping(self, data, description): + """ Test invalid tensor name to column name mapping """ + with self.assertRaises(TypeError, msg=description): + SparkDLTypeConverters.asTensorNameToColumnMap(data) diff --git a/python/tests/tests.py b/python/tests/tests.py index d93b31a8..4bf9d65d 100644 --- a/python/tests/tests.py +++ b/python/tests/tests.py @@ -29,22 +29,37 @@ from pyspark.sql import SQLContext from pyspark.sql import SparkSession +class PythonUnitTestCase(unittest.TestCase): + # We try to use unittest2 for python 2.6 or earlier + # This class is created to avoid replicating this logic in various places. + pass -class SparkDLTestCase(unittest.TestCase): +class TestSparkContext(object): @classmethod - def setUpClass(cls): + def setup_env(cls): cls.sc = SparkContext('local[*]', cls.__name__) cls.sql = SQLContext(cls.sc) cls.session = SparkSession.builder.getOrCreate() @classmethod - def tearDownClass(cls): + def tear_down_env(cls): cls.session.stop() cls.session = None cls.sc.stop() cls.sc = None cls.sql = None + +class SparkDLTestCase(TestSparkContext, unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.setup_env() + + @classmethod + def tearDownClass(cls): + cls.tear_down_env() + def assertDfHasCols(self, df, cols = []): map(lambda c: self.assertIn(c, df.columns), cols) diff --git a/python/tests/transformers/tf_transformer_test.py b/python/tests/transformers/tf_transformer_test.py new file mode 100644 index 00000000..849a84d7 --- /dev/null +++ b/python/tests/transformers/tf_transformer_test.py @@ -0,0 +1,146 @@ +# Copyright 2017 Databricks, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import absolute_import, division, print_function + +import numpy as np +import tensorflow as tf + +from pyspark.sql.types import Row + +import tensorframes as tfs + +import sparkdl.graph.utils as tfx +from sparkdl.graph.input import TFInputGraph +from sparkdl.transformers.tf_tensor import TFTransformer + +from ..tests import SparkDLTestCase + +class TFTransformerTests(SparkDLTestCase): + def test_graph_novar(self): + transformer = _build_transformer(lambda session: + TFInputGraph.fromGraph(session.graph, session, + [_tensor_input_name], + [_tensor_output_name])) + gin = transformer.getTFInputGraph() + local_features = _build_local_features() + expected = _get_expected_result(gin, local_features) + dataset = self.session.createDataFrame(local_features) + _check_transformer_output(transformer, dataset, expected) + + +# The name of the input tensor +_tensor_input_name = "input_tensor" +# The name of the output tensor (scalar) +_tensor_output_name = "output_tensor" +# The size of the input tensor +_tensor_size = 3 +# Input mapping for the Transformer +_input_mapping = {'inputCol': tfx.tensor_name(_tensor_input_name)} +# Output mapping for the Transformer +_output_mapping = {tfx.tensor_name(_tensor_output_name): 'outputCol'} +# Numerical threshold +_all_close_tolerance = 1e-5 + + +def _build_transformer(gin_function): + """ + Makes a session and a default graph, loads the simple graph into it, and then calls + gin_function(session) to build the :py:obj:`TFInputGraph` object. + Return the :py:obj:`TFTransformer` created from it. + """ + graph = tf.Graph() + with tf.Session(graph=graph) as sess, graph.as_default(): + _build_graph(sess) + gin = gin_function(sess) + + return TFTransformer(tfInputGraph=gin, + inputMapping=_input_mapping, + outputMapping=_output_mapping) + + +def _build_graph(sess): + """ + Given a session (implicitly), adds nodes of computations + + It takes a vector input, with `_tensor_size` columns and returns an float64 scalar. + """ + x = tf.placeholder(tf.float64, shape=[None, _tensor_size], name=_tensor_input_name) + _ = tf.reduce_max(x, axis=1, name=_tensor_output_name) + +def _build_local_features(): + """ + Build numpy array (i.e. local) features. + """ + # Build local features and DataFrame from it + local_features = [] + np.random.seed(997) + for idx in range(100): + _dict = {'idx': idx} + for colname, _ in _input_mapping.items(): + _dict[colname] = np.random.randn(_tensor_size).tolist() + + local_features.append(Row(**_dict)) + + return local_features + +def _get_expected_result(gin, local_features): + """ + Running the graph in the :py:obj:`TFInputGraph` object and compute the expected results. + :param: gin, a :py:obj:`TFInputGraph` + :return: expected results in NumPy array + """ + graph = tf.Graph() + with tf.Session(graph=graph) as sess, graph.as_default(): + # Build test graph and transformers from here + tf.import_graph_def(gin.graph_def, name='') + + # Build the results + _results = [] + for row in local_features: + fetches = [tfx.get_tensor(tnsr_name, graph) + for tnsr_name, _ in _output_mapping.items()] + feed_dict = {} + for colname, tnsr_name in _input_mapping.items(): + tnsr = tfx.get_tensor(tnsr_name, graph) + feed_dict[tnsr] = np.array(row[colname])[np.newaxis, :] + + curr_res = sess.run(fetches, feed_dict=feed_dict) + _results.append(np.ravel(curr_res)) + + expected = np.hstack(_results) + + return expected + +def _check_transformer_output(transformer, dataset, expected): + """ + Given a transformer and a spark dataset, check if the transformer + produces the expected results. + """ + analyzed_df = tfs.analyze(dataset) + out_df = transformer.transform(analyzed_df) + + # Collect transformed values + out_colnames = list(_output_mapping.values()) + _results = [] + for row in out_df.select(out_colnames).collect(): + curr_res = [row[colname] for colname in out_colnames] + _results.append(np.ravel(curr_res)) + out_tgt = np.hstack(_results) + + _err_msg = 'not close => shape {} != {}, max_diff {} > {}' + max_diff = np.max(np.abs(expected - out_tgt)) + err_msg = _err_msg.format(expected.shape, out_tgt.shape, + max_diff, _all_close_tolerance) + assert np.allclose(expected, out_tgt, atol=_all_close_tolerance), err_msg diff --git a/python/tests/udf/keras_sql_udf_test.py b/python/tests/udf/keras_sql_udf_test.py index d1473b3c..5c67c854 100644 --- a/python/tests/udf/keras_sql_udf_test.py +++ b/python/tests/udf/keras_sql_udf_test.py @@ -66,7 +66,7 @@ def test_simple_keras_udf(self): makeGraphUDF(issn.graph, 'my_keras_model_udf', model.outputs, - {tfx.op_name(issn.graph, model.inputs[0]): 'image_col'}) + {tfx.op_name(model.inputs[0], issn.graph): 'image_col'}) # Run the training procedure # Export the graph in this IsolatedSession as a GraphFunction # gfn = issn.asGraphFunction(model.inputs, model.outputs) @@ -168,4 +168,3 @@ def test_map_blocks_sql_1(self): data2 = df2.collect() assert len(data2) == 5, data2 assert data2[0].z == 3.0, data2 -