diff --git a/python/torch_mlir/tools/import_onnx/__main__.py b/python/torch_mlir/tools/import_onnx/__main__.py index 5ec5c63f820e..6863c150bc7d 100644 --- a/python/torch_mlir/tools/import_onnx/__main__.py +++ b/python/torch_mlir/tools/import_onnx/__main__.py @@ -146,7 +146,7 @@ def load_onnx_model(args: argparse.Namespace) -> onnx.ModelProto: # Load the temp file and the external data. inferred_model = onnx.load(temp_inferred_file, load_external_data=False) - data_dir = Path(input_dir if args.temp_dir is None else args.data_dir) + data_dir = Path(input_dir if args.data_dir is None else args.data_dir) onnx.load_external_data_for_model(inferred_model, str(data_dir)) # Remove the inferred shape file unless asked to keep it diff --git a/test/python/onnx_importer/command_line_test.py b/test/python/onnx_importer/command_line_test.py index 998e1098030f..38753352ca2d 100644 --- a/test/python/onnx_importer/command_line_test.py +++ b/test/python/onnx_importer/command_line_test.py @@ -87,7 +87,29 @@ def linear_model() -> onnx.ModelProto: return onnx_model -ALL_MODELS = [const_model, linear_model] +def path_based_shape_inference_model() -> onnx.ModelProto: + # Create a model with a serialized form that's large enough to require + # path-based shape inference. + large_tensor = numpy.random.rand(onnx.checker.MAXIMUM_PROTOBUF).astype( + numpy.float32 + ) + assert large_tensor.nbytes > onnx.checker.MAXIMUM_PROTOBUF + node1 = make_node( + "Constant", + [], + ["large_const"], + value=numpy_helper.from_array(large_tensor, name="large_const"), + ) + X = make_tensor_value_info( + "large_const", TensorProto.FLOAT, [onnx.checker.MAXIMUM_PROTOBUF] + ) + graph = make_graph([node1], "large_const_graph", [], [X]) + onnx_model = make_model(graph) + check_model(onnx_model) + return onnx_model + + +ALL_MODELS = [const_model, linear_model, path_based_shape_inference_model] class CommandLineTest(unittest.TestCase): @@ -141,6 +163,20 @@ def run_model_extern(self, onnx_model: onnx.ModelProto, model_name: str): ) __main__.main(args) + def run_model_explicit_temp_implicit_data( + self, onnx_model: onnx.ModelProto, model_name: str + ): + run_path = self.get_run_path(model_name) + model_file = run_path / f"{model_name}-explicit_temp_implicit_data.onnx" + mlir_file = run_path / f"{model_name}-explicit_temp_implicit_data.torch.mlir" + onnx.save(onnx_model, model_file) + temp_dir = run_path / "temp" + temp_dir.mkdir(exist_ok=True) + args = __main__.parse_arguments( + [str(model_file), "-o", str(mlir_file), "--temp-dir", str(temp_dir)] + ) + __main__.main(args) + def test_all(self): for model_func in ALL_MODELS: model_name = model_func.__name__ @@ -150,6 +186,8 @@ def test_all(self): self.run_model_intern(model, model_name) with self.subTest("External data"): self.run_model_extern(model, model_name) + with self.subTest("Explicit temp dir, implicit data dir"): + self.run_model_explicit_temp_implicit_data(model, model_name) if __name__ == "__main__":