Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/torch_mlir/tools/import_onnx/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def load_onnx_model(args: argparse.Namespace) -> onnx.ModelProto:

# Load the temp file and the external data.
inferred_model = onnx.load(temp_inferred_file, load_external_data=False)
data_dir = Path(input_dir if args.temp_dir is None else args.data_dir)
data_dir = Path(input_dir if args.data_dir is None else args.data_dir)
onnx.load_external_data_for_model(inferred_model, str(data_dir))

# Remove the inferred shape file unless asked to keep it
Expand Down
40 changes: 39 additions & 1 deletion test/python/onnx_importer/command_line_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,29 @@ def linear_model() -> onnx.ModelProto:
return onnx_model


ALL_MODELS = [const_model, linear_model]
def path_based_shape_inference_model() -> onnx.ModelProto:
    """Build a model whose serialized form exceeds the protobuf size limit.

    A model this large forces the importer down the path-based
    shape-inference route instead of in-memory inference.
    """
    # float32 means 4 bytes per element, so MAXIMUM_PROTOBUF elements
    # serialize to well over MAXIMUM_PROTOBUF bytes.
    element_count = onnx.checker.MAXIMUM_PROTOBUF
    big_values = numpy.random.rand(element_count).astype(numpy.float32)
    assert big_values.nbytes > onnx.checker.MAXIMUM_PROTOBUF
    const_node = make_node(
        "Constant",
        [],
        ["large_const"],
        value=numpy_helper.from_array(big_values, name="large_const"),
    )
    output_info = make_tensor_value_info(
        "large_const", TensorProto.FLOAT, [element_count]
    )
    big_graph = make_graph([const_node], "large_const_graph", [], [output_info])
    model = make_model(big_graph)
    check_model(model)
    return model


ALL_MODELS = [const_model, linear_model, path_based_shape_inference_model]


class CommandLineTest(unittest.TestCase):
Expand Down Expand Up @@ -141,6 +163,20 @@ def run_model_extern(self, onnx_model: onnx.ModelProto, model_name: str):
)
__main__.main(args)

def run_model_explicit_temp_implicit_data(
self, onnx_model: onnx.ModelProto, model_name: str
):
run_path = self.get_run_path(model_name)
model_file = run_path / f"{model_name}-explicit_temp_implicit_data.onnx"
mlir_file = run_path / f"{model_name}-explicit_temp_implicit_data.torch.mlir"
onnx.save(onnx_model, model_file)
temp_dir = run_path / "temp"
temp_dir.mkdir(exist_ok=True)
args = __main__.parse_arguments(
[str(model_file), "-o", str(mlir_file), "--temp-dir", str(temp_dir)]
)
__main__.main(args)

def test_all(self):
for model_func in ALL_MODELS:
model_name = model_func.__name__
Expand All @@ -150,6 +186,8 @@ def test_all(self):
self.run_model_intern(model, model_name)
with self.subTest("External data"):
self.run_model_extern(model, model_name)
with self.subTest("Explicit temp dir, implicit data dir"):
self.run_model_explicit_temp_implicit_data(model, model_name)


if __name__ == "__main__":
Expand Down