Skip to content

Introduce a somewhat usable "metatize" for TF helper functions #66

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions symbolic_pymc/tensorflow/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from google.protobuf.message import Message

from tensorflow.python.framework import tensor_util, op_def_registry, op_def_library, tensor_shape
from tensorflow.python.eager.context import graph_mode
from tensorflow.core.framework.op_def_pb2 import OpDef
from tensorflow.core.framework.node_def_pb2 import NodeDef

Expand Down Expand Up @@ -802,6 +803,65 @@ def find_opdef(cls, name):

return None

@staticmethod
def metatize_tf_function(func, *args, **kwargs):
"""Convert a TensorFlow function into a version that produces meta objects instead.

The approach used by this function is to turn all meta tensor arguments
into TensorFlow placeholders, execute the TensorFlow function using
those placeholders, then metatize the result and replace the
placeholders with the original arguments.

"""

func_sig = inspect.signature(func)
bound_args = func_sig.bind(*args, **kwargs)
bound_args.apply_defaults()

dummy_args = []
dummy_to_meta = {}
for arg_name, arg_val in bound_args.arguments.items():
if isinstance(arg_val, TFlowMetaTensor):
if isinstance(arg_val.dtype, tf.dtypes.DType):
ph_dtype_tf = arg_val.dtype
else:
ph_dtype_tf = tf.dtypes.variant
with graph_mode():
ph_arg_tf = tf.compat.v1.placeholder(ph_dtype_tf, name=f"dummy_{arg_name}")
dummy_to_meta[ph_arg_tf.name] = arg_val
dummy_args += [ph_arg_tf]
else:
dummy_args += [arg_val]

with graph_mode():
tf_res = func(*dummy_args)
mt_res = mt(tf_res)

def mt_arg_convert(obj):
nonlocal dummy_to_meta

# TODO: We can skip all this if we know that the object is reifiable
# (e.g. has a `.obj`).
if isinstance(obj, TFlowMetaConstant):
return obj
elif isinstance(obj, TFlowMetaTensor):
if str(obj.name) in dummy_to_meta:
return dummy_to_meta[str(obj.name)]

# obj_et = etuplize(obj, shallow=True)
# new_inputs = [mt_arg_convert(i) for i in obj_et[1:]]
# return etuple(obj_et[0], *new_inputs)

new_inputs = [mt_arg_convert(i) for i in obj.inputs]
return obj.op.op_def(*new_inputs)

elif isinstance(obj, (list, tuple)):
return type(obj)([mt_arg_convert(i) for i in obj])
else:
return obj

return mt_arg_convert(mt_res)

def __getattr__(self, obj):

ns_obj = next((getattr(ns, obj) for ns in self.namespaces if hasattr(ns, obj)), None)
Expand Down
22 changes: 22 additions & 0 deletions tests/tensorflow/test_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,3 +401,25 @@ class CustomClass(object):

with pytest.raises(ValueError):
mt(CustomClass())


@pytest.mark.usefixtures("run_with_tensorflow")
@run_in_graph_mode
def test_metatize_function():

# A_tf = tf.convert_to_tensor(np.c_[[1, 2], [3, 4]])
A_tf = tf.compat.v1.placeholder(tf.float64, name='A',
shape=tf.TensorShape([None, None]))

A_shape_tf = tf.shape(A_tf)
A_rows_tf = A_shape_tf[0]
I_A_tf = tf.eye(A_rows_tf)

# A_mt = mt(A_tf)
# A_shape_mt = mt.shape(A_mt)
# A_rows_tf = mt.StridedSlice(A_shape_mt, 0, 1, 1, shrink_axis_mask=1)
A_rows_mt = mt(A_rows_tf)

eye_mt = mt.metatize_tf_function(tf.eye, A_rows_mt)

assert mt(I_A_tf) == eye_mt