From 2cd264d61e00f9dfc2d3a9ec9c0c56ec64454cb1 Mon Sep 17 00:00:00 2001 From: zhangnd Date: Tue, 9 Oct 2018 11:03:35 +0800 Subject: [PATCH 1/2] add convert op BatchPermutation to Gather --- tools/convert_pkl_to_pb.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tools/convert_pkl_to_pb.py b/tools/convert_pkl_to_pb.py index e447adfbd..8e4572f69 100644 --- a/tools/convert_pkl_to_pb.py +++ b/tools/convert_pkl_to_pb.py @@ -268,6 +268,14 @@ def convert_upsample_nearest(op): height_scale=float(scale)) return resize_nearest_op + @op_filter(type='BatchPermutation') + def convert_batch_permutation(op): + gather_op = core.CreateOperator('Gather', + list(op.input), + list(op.output), + name=op.name) + return gather_op + @op_filter() def convert_rpn_rois(op): for j in range(len(op.input)): @@ -291,6 +299,7 @@ def convert_remove_op(op): # so run separately convert_op_in_proto(net, convert_remove_op) convert_op_in_proto(net, convert_upsample_nearest) + convert_op_in_proto(net, convert_batch_permutation) convert_op_in_proto(net, convert_python) convert_op_in_proto(net, convert_op_name) convert_op_in_proto(net, convert_rpn_rois) From d5000ee664dfcd7a0cdda20e73115dac5fc7963c Mon Sep 17 00:00:00 2001 From: zhangnd Date: Tue, 9 Oct 2018 14:03:59 +0800 Subject: [PATCH 2/2] keep logic simple --- detectron/utils/model_convert_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/detectron/utils/model_convert_utils.py b/detectron/utils/model_convert_utils.py index 17752db2b..a0d1e83ae 100644 --- a/detectron/utils/model_convert_utils.py +++ b/detectron/utils/model_convert_utils.py @@ -370,9 +370,9 @@ def compare_model(model1_func, model2_func, test_image, check_blobs): n1, n2 = cb1[idx], cb2[idx] r1 = res1[n1] if n1 in res1 else None r2 = res2[n2] if n2 in res2 else None - assert r1 is not None or r2 is None, \ + assert r1 is not None, \ "Blob {} in model1 is None".format(n1) - assert r2 is not None or r1 is None, \ + assert r2 is not None, \ "Blob {} in model2 is None".format(n2) assert r1.shape == r2.shape, \ "Blob {} and {} shape mismatched: {} vs {}".format(