From 9c839f736eb77dbcfd28d2058db156d3634abc38 Mon Sep 17 00:00:00 2001 From: Kirill Dubovikov Date: Tue, 9 Apr 2024 13:55:50 +0000 Subject: [PATCH] onnx export points dimension fix --- export_to_onnx.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/export_to_onnx.py b/export_to_onnx.py index ba6c13f..b07e414 100644 --- a/export_to_onnx.py +++ b/export_to_onnx.py @@ -38,16 +38,16 @@ def export_onnx_esam(model, output): onnx_model = onnx_models.OnnxEfficientSam(model=model) dynamic_axes = { "batched_images": {0: "batch", 2: "height", 3: "width"}, - "batched_point_coords": {2: "num_points"}, - "batched_point_labels": {2: "num_points"}, + "batched_point_coords": {1: "num_points"}, + "batched_point_labels": {1: "num_points"}, } dummy_inputs = { "batched_images": torch.randn(1, 3, 1080, 1920, dtype=torch.float), "batched_point_coords": torch.randint( - low=0, high=1080, size=(1, 1, 5, 2), dtype=torch.float + low=0, high=1080, size=(1, 5, 1, 2), dtype=torch.float ), "batched_point_labels": torch.randint( - low=0, high=4, size=(1, 1, 5), dtype=torch.float + low=0, high=4, size=(1, 5, 1), dtype=torch.float ), } output_names = ["output_masks", "iou_predictions"] @@ -82,16 +82,16 @@ def export_onnx_esam_decoder(model, output): onnx_model = onnx_models.OnnxEfficientSamDecoder(model=model) dynamic_axes = { "image_embeddings": {0: "batch"}, - "batched_point_coords": {2: "num_points"}, - "batched_point_labels": {2: "num_points"}, + "batched_point_coords": {1: "num_points"}, + "batched_point_labels": {1: "num_points"}, } dummy_inputs = { "image_embeddings": torch.randn(1, 256, 64, 64, dtype=torch.float), "batched_point_coords": torch.randint( - low=0, high=1080, size=(1, 1, 5, 2), dtype=torch.float + low=0, high=1080, size=(1, 5, 1, 2), dtype=torch.float ), "batched_point_labels": torch.randint( - low=0, high=4, size=(1, 1, 5), dtype=torch.float + low=0, high=4, size=(1, 5, 1), dtype=torch.float ), "orig_im_size": torch.tensor([1080, 1920], dtype=torch.long), }