Skip to content

Conversation

@SangbumChoi
Copy link
Contributor

Hi @jozhang97

I made ONNX conversion script for this ResNet50 DETA model

There are two slight modification in original code.

  1. onnx input process
  2. modify boolean type of input to float in cumsum operation

If you have time, please review and merge.
Also if you need further modification feel free to ask.

@SangbumChoi
Copy link
Contributor Author

FYI, to convert this model into tensorRT, it requires different conversion script (e.g. MultiScaleDeformableAttention)

@demuxin
Copy link

demuxin commented May 27, 2024

Hi @SangbumChoi , I tried to convert onnx to tensorrt engine. But there is error:

[E] 2: [myelinBuilderUtils.cpp::getMyelinSupportType::1270] Error Code 2: Internal Error (ForeignNode does not support data-dependent shape for now.)
[!] Invalid Engine. Please ensure the engine was built correctly
[E] FAILED | Runtime: 22.843s | Command: /home/osmagic/anaconda3/envs/pytorch/bin/polygraphy run --trt weights/out_nms_0527_sim_san.onnx

Do you know why there's this error and how to solve it?

I'm looking forward to your reply.

@SangbumChoi
Copy link
Contributor Author

@demuxin AFAIK, you need some custom kernel (deformable attention) to convert this model in to tensorRT. In my experience I have succeed to convert ONNX but not tensorRT.

I would happy to work and collaborate on converting tensorrt engine. Do you have any experience of writing CUDA programming?

@SangbumChoi
Copy link
Contributor Author

@demuxin Will you like to share your email and Slack message to discuss about this?

@xinlin-xiao
Copy link

@SangbumChoi Hi,Swin-L DETA model can be conver to onnx? I use the code conver Swin-L DETA to onnx :

deta.pt2onnx(img_size=(1440,832),weights='/mnt/data1/download_new/DETA-master/exps/public/deta_swin_ft_2024.4.3/best.pt')

def pt2onnx (self,weights,img_size,batch_size=1,device='cuda:0',export_nms=False,simplify=True):
       model=self.model
       print(img_size)
       img = torch.zeros(batch_size, 3, *img_size).to(device)
       try:
           import onnx
           
           print(f' starting export with onnx {onnx.__version__}...')
           f = weights.replace('.pt', '.onnx')  # filename
           torch.onnx.export(model, img, f, verbose=False, opset_version=11, input_names=['images'], output_names=['output'],
                               dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'},  # size(1,3,640,640)
                                               'output': {0: 'batch', 2: 'y', 3: 'x'}} )
           # Checks
           model_onnx = onnx.load(f)  # load onnx model
           onnx.checker.check_model(model_onnx)  # check onnx model
           # print(onnx.helper.printable_graph(model_onnx.graph))  # print

           # Simplify
           if simplify:
               try:
                   # check_requirements(['onnx-simplifier'])
                   import onnxsim

                   print(f'simplifying with onnx-simplifier {onnxsim.__version__}...')
                   model_onnx, check = onnxsim.simplify(model_onnx,
                                                       input_shapes={'images': list(img.shape)} )
                   assert check, 'assert check failed'
                   onnx.save(model_onnx, f)
               except Exception as e:
                   print(f' simplifier failure: {e}')
           # print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
       except Exception as e:
           print(f' export failure: {e}')

It return :

/mnt/data1/download_new/DETA-master/models/deformable_detr.py:243: TracerWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
  for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
 export failure: 0INTERNAL ASSERT FAILED at "../torch/csrc/jit/ir/alias_analysis.cpp":607, please report a bug to PyTorch. We don't have an op for aten::fill_ but it isn't a special case.  Argument types: Tensor, bool, 

Candidates:
	aten::fill_.Scalar(Tensor(a!) self, Scalar value) -> (Tensor(a!))
	aten::fill_.Tensor(Tensor(a!) self, Tensor value) -> (Tensor(a!))

Do you have any suggestions for this mistake? thank you!!

@SangbumChoi
Copy link
Contributor Author

@xinlin-xiao Since Swin-L can be converted into ONNX I think overall the answer might be yes

@TheMattBin
Copy link

TheMattBin commented Sep 22, 2024

Hi, thanks for the great work! I'm also working on onnx export but with Huggingface optimum, but I encountered some issues when I try to export with optimum.

File "/usr/local/lib/python3.10/dist-packages/optimum/exporters/onnx/base.py", line 306, in fix_dynamic_axes
    session = InferenceSession(model_path.as_posix(), providers=providers, sess_options=session_options)
  File "/usr/local/lib/python3.10/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 419, in __init__
    self._create_inference_session(providers, provider_options, disabled_optimizers)
  File "/usr/local/lib/python3.10/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 480, in _create_inference_session
    sess = C.InferenceSession(session_options, self._model_path, True, self._read_config_from_model)
onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph: [ONNXRuntimeError] : 10 : INVALID_GRAPH : Load model from deta_test/model.onnx failed:This is an invalid model. Type Error: Type 'tensor(bool)' of input parameter (/model/Equal_7_output_0) of operator (CumSum) in node (/model/CumSum_1) is invalid.

As I can see from your PR, it seems like you added NestedTensor function for onnx export. May I know the idea behind?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants