diff --git a/symbolic_pymc/tensorflow/meta.py b/symbolic_pymc/tensorflow/meta.py index 8834775..54d44ce 100644 --- a/symbolic_pymc/tensorflow/meta.py +++ b/symbolic_pymc/tensorflow/meta.py @@ -17,6 +17,7 @@ from google.protobuf.message import Message from tensorflow.python.framework import tensor_util, op_def_registry, op_def_library, tensor_shape +from tensorflow.python.eager.context import graph_mode from tensorflow.core.framework.op_def_pb2 import OpDef from tensorflow.core.framework.node_def_pb2 import NodeDef @@ -802,6 +803,65 @@ def find_opdef(cls, name): return None + @staticmethod + def metatize_tf_function(func, *args, **kwargs): + """Convert a TensorFlow function into a version that produces meta objects instead. + + The approach used by this function is to turn all meta tensor arguments + into TensorFlow placeholders, execute the TensorFlow function using + those placeholders, then metatize the result and replace the + placeholders with the original arguments. + + """ + + func_sig = inspect.signature(func) + bound_args = func_sig.bind(*args, **kwargs) + bound_args.apply_defaults() + + dummy_args = [] + dummy_to_meta = {} + for arg_name, arg_val in bound_args.arguments.items(): + if isinstance(arg_val, TFlowMetaTensor): + if isinstance(arg_val.dtype, tf.dtypes.DType): + ph_dtype_tf = arg_val.dtype + else: + ph_dtype_tf = tf.dtypes.variant + with graph_mode(): + ph_arg_tf = tf.compat.v1.placeholder(ph_dtype_tf, name=f"dummy_{arg_name}") + dummy_to_meta[ph_arg_tf.name] = arg_val + dummy_args += [ph_arg_tf] + else: + dummy_args += [arg_val] + + with graph_mode(): + tf_res = func(*dummy_args) + mt_res = mt(tf_res) + + def mt_arg_convert(obj): + nonlocal dummy_to_meta + + # TODO: We can skip all this if we know that the object is reifiable + # (e.g. has a `.obj`). + if isinstance(obj, TFlowMetaConstant): + return obj + elif isinstance(obj, TFlowMetaTensor): + if str(obj.name) in dummy_to_meta: + return dummy_to_meta[str(obj.name)] + + # obj_et = etuplize(obj, shallow=True) + # new_inputs = [mt_arg_convert(i) for i in obj_et[1:]] + # return etuple(obj_et[0], *new_inputs) + + new_inputs = [mt_arg_convert(i) for i in obj.inputs] + return obj.op.op_def(*new_inputs) + + elif isinstance(obj, (list, tuple)): + return type(obj)([mt_arg_convert(i) for i in obj]) + else: + return obj + + return mt_arg_convert(mt_res) + def __getattr__(self, obj): ns_obj = next((getattr(ns, obj) for ns in self.namespaces if hasattr(ns, obj)), None) diff --git a/tests/tensorflow/test_meta.py b/tests/tensorflow/test_meta.py index 41433ad..f4c7724 100644 --- a/tests/tensorflow/test_meta.py +++ b/tests/tensorflow/test_meta.py @@ -401,3 +401,25 @@ class CustomClass(object): with pytest.raises(ValueError): mt(CustomClass()) + + +@pytest.mark.usefixtures("run_with_tensorflow") +@run_in_graph_mode +def test_metatize_function(): + + # A_tf = tf.convert_to_tensor(np.c_[[1, 2], [3, 4]]) + A_tf = tf.compat.v1.placeholder(tf.float64, name='A', + shape=tf.TensorShape([None, None])) + + A_shape_tf = tf.shape(A_tf) + A_rows_tf = A_shape_tf[0] + I_A_tf = tf.eye(A_rows_tf) + + # A_mt = mt(A_tf) + # A_shape_mt = mt.shape(A_mt) + # A_rows_tf = mt.StridedSlice(A_shape_mt, 0, 1, 1, shrink_axis_mask=1) + A_rows_mt = mt(A_rows_tf) + + eye_mt = mt.metatize_tf_function(tf.eye, A_rows_mt) + + assert mt(I_A_tf) == eye_mt