diff --git a/convert.py b/convert.py
new file mode 100644
index 0000000..212e169
--- /dev/null
+++ b/convert.py
@@ -0,0 +1,419 @@
+import argparse
+import random
+from pathlib import Path
+from types import MethodType
+
+import numpy as np
+import onnx
+import onnxruntime
+import torch
+import torch.nn.functional as F
+import util.misc as utils
+from models.ops.modules import MSDeformAttn
+from onnxsim import simplify
+
+from models import build_model
+
+
+def get_args_parser():
+    parser = argparse.ArgumentParser("Deformable DETR Detector", add_help=False)
+    parser.add_argument("--lr", default=2e-4, type=float)
+    parser.add_argument(
+        "--lr_backbone_names", default=["backbone.0"], type=str, nargs="+"
+    )
+    parser.add_argument("--lr_backbone", default=2e-5, type=float)
+    parser.add_argument(
+        "--lr_linear_proj_names",
+        default=["reference_points", "sampling_offsets"],
+        type=str,
+        nargs="+",
+    )
+    parser.add_argument("--lr_linear_proj_mult", default=0.1, type=float)
+    parser.add_argument("--batch_size", default=2, type=int)
+    parser.add_argument("--weight_decay", default=1e-4, type=float)
+    parser.add_argument("--epochs", default=50, type=int)
+    parser.add_argument("--lr_drop", default=40, type=int)
+    parser.add_argument("--lr_drop_epochs", default=None, type=int, nargs="+")
+    parser.add_argument(
+        "--clip_max_norm", default=0.1, type=float, help="gradient clipping max norm"
+    )
+
+    parser.add_argument("--sgd", action="store_true")
+
+    # Variants of Deformable DETR
+    parser.add_argument("--with_box_refine", default=False, action="store_true")
+    parser.add_argument("--two_stage", default=False, action="store_true")
+
+    # Model parameters
+    parser.add_argument(
+        "--frozen_weights",
+        type=str,
+        default=None,
+        help="Path to the pretrained model. If set, only the mask head will be trained",
+    )
+
+    # * Backbone
+    parser.add_argument(
+        "--backbone",
+        default="resnet50",
+        type=str,
+        help="Name of the convolutional backbone to use",
+    )
+    parser.add_argument(
+        "--dilation",
+        action="store_true",
+        help="If true, we replace stride with dilation in the last convolutional block (DC5)",
+    )
+    parser.add_argument(
+        "--position_embedding",
+        default="sine",
+        type=str,
+        choices=("sine", "learned"),
+        help="Type of positional embedding to use on top of the image features",
+    )
+    parser.add_argument(
+        "--position_embedding_scale",
+        default=2 * np.pi,
+        type=float,
+        help="position / size * scale",
+    )
+    parser.add_argument(
+        "--num_feature_levels", default=5, type=int, help="number of feature levels"
+    )
+
+    # * Transformer
+    parser.add_argument(
+        "--enc_layers",
+        default=6,
+        type=int,
+        help="Number of encoding layers in the transformer",
+    )
+    parser.add_argument(
+        "--dec_layers",
+        default=6,
+        type=int,
+        help="Number of decoding layers in the transformer",
+    )
+    parser.add_argument(
+        "--dim_feedforward",
+        default=2048,
+        type=int,
+        help="Intermediate size of the feedforward layers in the transformer blocks",
+    )
+    parser.add_argument(
+        "--hidden_dim",
+        default=256,
+        type=int,
+        help="Size of the embeddings (dimension of the transformer)",
+    )
+    parser.add_argument(
+        "--dropout", default=0.1, type=float, help="Dropout applied in the transformer"
+    )
+    parser.add_argument(
+        "--nheads",
+        default=8,
+        type=int,
+        help="Number of attention heads inside the transformer's attentions",
+    )
+    parser.add_argument(
+        "--num_queries", default=900, type=int, help="Number of query slots"
+    )
+    parser.add_argument("--dec_n_points", default=4, type=int)
+    parser.add_argument("--enc_n_points", default=4, type=int)
+
+    # * Segmentation
+    parser.add_argument(
+        "--masks",
+        action="store_true",
+        help="Train segmentation head if the flag is provided",
+    )
+
+    # Loss
+    parser.add_argument(
+        "--no_aux_loss",
+        dest="aux_loss",
+        action="store_false",
+        help="Disables auxiliary decoding losses (loss at each layer)",
+    )
+
+    # * Matcher
+    parser.add_argument("--assign_first_stage", action="store_true")
+    parser.add_argument("--assign_second_stage", action="store_true")
+    parser.add_argument(
+        "--set_cost_class",
+        default=2,
+        type=float,
+        help="Class coefficient in the matching cost",
+    )
+    parser.add_argument(
+        "--set_cost_bbox",
+        default=5,
+        type=float,
+        help="L1 box coefficient in the matching cost",
+    )
+    parser.add_argument(
+        "--set_cost_giou",
+        default=2,
+        type=float,
+        help="giou box coefficient in the matching cost",
+    )
+
+    # * Loss coefficients
+    parser.add_argument("--mask_loss_coef", default=1, type=float)
+    parser.add_argument("--dice_loss_coef", default=1, type=float)
+    parser.add_argument("--cls_loss_coef", default=2, type=float)
+    parser.add_argument("--bbox_loss_coef", default=5, type=float)
+    parser.add_argument("--giou_loss_coef", default=2, type=float)
+    parser.add_argument("--focal_alpha", default=0.25, type=float)
+
+    # dataset parameters
+    parser.add_argument("--dataset_file", default="coco")
+    parser.add_argument("--coco_path", default="./data/coco", type=str)
+    parser.add_argument("--coco_panoptic_path", type=str)
+    parser.add_argument("--remove_difficult", action="store_true")
+    parser.add_argument("--bigger", action="store_true")
+
+    parser.add_argument(
+        "--output_dir", default="", help="path where to save, empty for no saving"
+    )
+    parser.add_argument(
+        "--device", default="cuda", help="device to use for training / testing"
+    )
+    parser.add_argument("--seed", default=42, type=int)
+    parser.add_argument(
+        "--resume", default="adet_2x_checkpoint0023.pth", help="resume from checkpoint"
+    )
+    parser.add_argument(
+        "--finetune",
+        default="adet_2x_checkpoint0023.pth",
+        help="finetune from checkpoint",
+    )
+    parser.add_argument(
+        "--start_epoch", default=0, type=int, metavar="N", help="start epoch"
+    )
+    parser.add_argument("--eval", action="store_true")
+    parser.add_argument("--num_workers", default=2, type=int)
+    parser.add_argument(
+        "--cache_mode",
+        default=False,
+        action="store_true",
+        help="whether to cache images on memory",
+    )
+
+    # onnx parameters
+    parser.add_argument("--h", default=256, type=int)
+    parser.add_argument("--w", default=256, type=int)
+    parser.add_argument("--usegridsample", action="store_true")
+    parser.add_argument("--output", default="model.onnx", type=str)
+
+    return parser
+
+
+def multi_scale_deformable_attn_pytorch(
+    value: torch.Tensor,
+    value_spatial_shapes: torch.Tensor,
+    sampling_locations: torch.Tensor,
+    attention_weights: torch.Tensor,
+) -> torch.Tensor:
+
+    bs, _, num_heads, embed_dims = value.shape
+    _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
+    value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
+    sampling_grids = 2 * sampling_locations - 1
+    sampling_value_list = []
+    for level, (H_, W_) in enumerate(value_spatial_shapes):
+        # bs, H_*W_, num_heads, embed_dims ->
+        # bs, H_*W_, num_heads*embed_dims ->
+        # bs, num_heads*embed_dims, H_*W_ ->
+        # bs*num_heads, embed_dims, H_, W_
+        value_l_ = (
+            value_list[level]
+            .flatten(2)
+            .transpose(1, 2)
+            .reshape(bs * num_heads, embed_dims, H_, W_)
+        )
+        # bs, num_queries, num_heads, num_points, 2 ->
+        # bs, num_heads, num_queries, num_points, 2 ->
+        # bs*num_heads, num_queries, num_points, 2
+        sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
+        # bs*num_heads, embed_dims, num_queries, num_points
+        sampling_value_l_ = F.grid_sample(
+            value_l_,
+            sampling_grid_l_,
+            mode="bilinear",
+            padding_mode="zeros",
+            align_corners=False,
+        )
+        sampling_value_list.append(sampling_value_l_)
+    # (bs, num_queries, num_heads, num_levels, num_points) ->
+    # (bs, num_heads, num_queries, num_levels, num_points) ->
+    # (bs, num_heads, 1, num_queries, num_levels*num_points)
+    attention_weights = attention_weights.transpose(1, 2).reshape(
+        bs * num_heads, 1, num_queries, num_levels * num_points
+    )
+    output = (
+        (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
+        .sum(-1)
+        .view(bs, num_heads * embed_dims, num_queries)
+    )
+    return output.transpose(1, 2).contiguous()
+
+
+def MSMHDA_onnx_export(
+    self,
+    query,
+    reference_points,
+    input_flatten,
+    input_spatial_shapes,
+    input_level_start_index,
+    input_padding_mask=None,
+):
+    N, Len_q, _ = query.shape
+    N, Len_in, _ = input_flatten.shape
+    assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in
+
+    value = self.value_proj(input_flatten)
+    if input_padding_mask is not None:
+        value = value.masked_fill(input_padding_mask[..., None], float(0))
+    value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
+    sampling_offsets = self.sampling_offsets(query).view(
+        N, Len_q, self.n_heads, self.n_levels, self.n_points, 2
+    )
+    attention_weights = self.attention_weights(query).view(
+        N, Len_q, self.n_heads, self.n_levels * self.n_points
+    )
+    attention_weights = F.softmax(attention_weights, -1).view(
+        N, Len_q, self.n_heads, self.n_levels, self.n_points
+    )
+    # N, Len_q, n_heads, n_levels, n_points, 2
+    if reference_points.shape[-1] == 2:
+        offset_normalizer = torch.stack(
+            [input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1
+        )
+        sampling_locations = (
+            reference_points[:, :, None, :, None, :]
+            + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
+        )
+    elif reference_points.shape[-1] == 4:
+        sampling_locations = (
+            reference_points[:, :, None, :, None, :2]
+            + sampling_offsets
+            / self.n_points
+            * reference_points[:, :, None, :, None, 2:]
+            * 0.5
+        )
+    else:
+        raise ValueError(
+            "Last dim of reference_points must be 2 or 4, but get {} instead.".format(
+                reference_points.shape[-1]
+            )
+        )
+
+    output = multi_scale_deformable_attn_pytorch(
+        value, input_spatial_shapes, sampling_locations, attention_weights
+    )
+    output = self.output_proj(output)
+    return output
+
+
+def convert(args, model):
+    if args.resume:
+        if args.resume.startswith("https"):
+            checkpoint = torch.hub.load_state_dict_from_url(
+                args.resume, map_location="cpu", check_hash=True
+            )
+        else:
+            checkpoint = torch.load(args.resume, map_location="cpu")
+        missing_keys, unexpected_keys = model.load_state_dict(
+            checkpoint["model"], strict=True
+        )
+        unexpected_keys = [
+            k
+            for k in unexpected_keys
+            if not (k.endswith("total_params") or k.endswith("total_ops"))
+        ]
+        if len(missing_keys) > 0:
+            print("Missing Keys: {}".format(missing_keys))
+        if len(unexpected_keys) > 0:
+            print("Unexpected Keys: {}".format(unexpected_keys))
+
+    for module in model.modules():
+        if isinstance(module, MSDeformAttn):
+            module.forward = MethodType(MSMHDA_onnx_export, module)
+
+    model.eval()
+
+    inputs = torch.randn(1, 3, args.h, args.w)
+    inputs = inputs
+
+    torch.onnx.export(
+        model,
+        inputs,
+        args.output,
+        input_names=["input"],
+        output_names=["output"],
+        export_params=True,
+        do_constant_folding=True,
+        verbose=True,
+        opset_version=opset_version,
+    )
+
+    model_simple, check = simplify(
+        args.output,
+    )
+    assert check, "Failed to simplify ONNX model."
+
+    onnx.save(model_simple, args.output)
+
+
+def compare(args, model):
+    x = torch.randn(1, 3, 256, 256).cuda()
+
+    model.cuda()
+    torch_out = model(x)
+
+    ort_session = onnxruntime.InferenceSession(
+        args.output, providers=["CUDAExecutionProvider"]
+    )
+
+    def to_numpy(tensor):
+        return (
+            tensor.detach().cpu().numpy()
+            if tensor.requires_grad
+            else tensor.cpu().numpy()
+        )
+
+    # output of onnxruntime
+    ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
+    ort_outs = ort_session.run(None, ort_inputs)
+
+    np.testing.assert_allclose(
+        to_numpy(torch_out["pred_logits"]), ort_outs[0], rtol=1e-03, atol=1e-05
+    )
+    np.testing.assert_allclose(
+        to_numpy(torch_out["pred_boxes"]), ort_outs[1], rtol=1e-03, atol=1e-05
+    )
+
+    print("ONNX Successfully converted!")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        "Deformable DETR training and evaluation script", parents=[get_args_parser()]
+    )
+    args = parser.parse_args()
+    if args.output_dir:
+        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+
+    # fix the seed for reproducibility
+    seed = args.seed + utils.get_rank()
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+
+    opset_version = 16
+
+    model, _, _ = build_model(args)
+
+    convert(args, model)
+    compare(args, model)
diff --git a/models/deformable_transformer.py b/models/deformable_transformer.py
index e4cfd85..479e819 100644
--- a/models/deformable_transformer.py
+++ b/models/deformable_transformer.py
@@ -199,7 +199,7 @@ def forward(self, srcs, masks, pos_embeds, query_embed=None):
             # keep top Q/L indices for L levels
             q_per_l = topk // len(spatial_shapes)
             is_level_ordered = level_ids[keep_inds][None] == torch.arange(len(spatial_shapes), device=level_ids.device)[:,None]  # LS
-            keep_inds_mask = is_level_ordered & (is_level_ordered.cumsum(1) <= q_per_l)  # LS
+            keep_inds_mask = is_level_ordered & (is_level_ordered.float().cumsum(1) <= q_per_l)  # LS
             keep_inds_mask = keep_inds_mask.any(0)  # S

             # pad to Q indices (might let ones filtered from pre-nms sneak by... unlikely because we pick high conf anyways)
diff --git a/models/position_encoding.py b/models/position_encoding.py
index a92f0d3..bf64705 100644
--- a/models/position_encoding.py
+++ b/models/position_encoding.py
@@ -37,7 +37,7 @@ def forward(self, tensor_list: NestedTensor):
         x = tensor_list.tensors
         mask = tensor_list.mask
         assert mask is not None
-        not_mask = ~mask
+        not_mask = (~mask).float()
         y_embed = not_mask.cumsum(1, dtype=torch.float32)
         x_embed = not_mask.cumsum(2, dtype=torch.float32)
         if self.normalize:
diff --git a/requirements.txt b/requirements.txt
index fd84672..09ede8e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,3 +2,6 @@ pycocotools
 tqdm
 cython
 scipy
+onnx
+onnxruntime
+onnxsim
\ No newline at end of file
diff --git a/util/misc.py b/util/misc.py
index ee3cf7b..004ab43 100644
--- a/util/misc.py
+++ b/util/misc.py
@@ -318,26 +318,6 @@ def _max_by_axis(the_list):
     return maxes


-def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
-    # TODO make this more general
-    if tensor_list[0].ndim == 3:
-        # TODO make it support different-sized images
-        max_size = _max_by_axis([list(img.shape) for img in tensor_list])
-        # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
-        batch_shape = [len(tensor_list)] + max_size
-        b, c, h, w = batch_shape
-        dtype = tensor_list[0].dtype
-        device = tensor_list[0].device
-        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
-        mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
-        for img, pad_img, m in zip(tensor_list, tensor, mask):
-            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
-            m[: img.shape[1], :img.shape[2]] = False
-    else:
-        raise ValueError('not supported')
-    return NestedTensor(tensor, mask)
-
-
 class NestedTensor(object):
     def __init__(self, tensors, mask: Optional[Tensor]):
         self.tensors = tensors
@@ -366,6 +346,62 @@ def __repr__(self):
         return str(self.tensors)


+def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
+    # TODO make this more general
+    if tensor_list[0].ndim == 3:
+        if torchvision._is_tracing():
+            # nested_tensor_from_tensor_list() does not export well to ONNX
+            # call _onnx_nested_tensor_from_tensor_list() instead
+            return _onnx_nested_tensor_from_tensor_list(tensor_list)
+
+        # TODO make it support different-sized images
+        max_size = _max_by_axis([list(img.shape) for img in tensor_list])
+        # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
+        batch_shape = [len(tensor_list)] + max_size
+        b, c, h, w = batch_shape
+        dtype = tensor_list[0].dtype
+        device = tensor_list[0].device
+        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
+        mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
+        for img, pad_img, m in zip(tensor_list, tensor, mask):
+            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+            m[: img.shape[1], :img.shape[2]] = False
+    else:
+        raise ValueError('not supported')
+    return NestedTensor(tensor, mask)
+
+
+# _onnx_nested_tensor_from_tensor_list() is an implementation of
+# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
+@torch.jit.unused
+def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
+    max_size = []
+    for i in range(tensor_list[0].dim()):
+        max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64)
+        max_size.append(max_size_i)
+    max_size = tuple(max_size)
+
+    # work around for
+    # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+    # m[: img.shape[1], :img.shape[2]] = False
+    # which is not yet supported in onnx
+    padded_imgs = []
+    padded_masks = []
+    for img in tensor_list:
+        padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
+        padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
+        padded_imgs.append(padded_img)
+
+        m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
+        padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
+        padded_masks.append(padded_mask.to(torch.bool))
+
+    tensor = torch.stack(padded_imgs)
+    mask = torch.stack(padded_masks)
+
+    return NestedTensor(tensor, mask=mask)
+
+
 def setup_for_distributed(is_master):
     """
     This function disables printing when not in master process
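
As a quick shape sanity check for the pure-PyTorch multi_scale_deformable_attn_pytorch fallback added in convert.py, the sketch below (not part of the patch; it assumes convert.py is importable from the repo root and uses made-up per-level feature-map sizes) feeds random tensors in the layout that the patched MSDeformAttn forward produces:

    import torch
    from convert import multi_scale_deformable_attn_pytorch

    bs, num_heads, embed_dims, num_queries, num_points = 1, 8, 32, 900, 4
    # hypothetical per-level feature-map sizes: (num_levels, 2) as (H, W)
    spatial_shapes = torch.tensor([[32, 32], [16, 16], [8, 8], [4, 4]])
    num_levels = spatial_shapes.shape[0]
    len_in = int((spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum())

    value = torch.randn(bs, len_in, num_heads, embed_dims)
    sampling_locations = torch.rand(bs, num_queries, num_heads, num_levels, num_points, 2)
    attention_weights = torch.softmax(
        torch.randn(bs, num_queries, num_heads, num_levels * num_points), dim=-1
    ).view(bs, num_queries, num_heads, num_levels, num_points)

    out = multi_scale_deformable_attn_pytorch(
        value, spatial_shapes, sampling_locations, attention_weights
    )
    assert out.shape == (bs, num_queries, num_heads * embed_dims)  # (1, 900, 256)

This is the same layout MSMHDA_onnx_export builds before calling the function, which is why swapping MSDeformAttn's custom CUDA forward for this fallback keeps the exported graph on standard ONNX ops; the grid_sample call needs GridSample support, hence opset_version = 16 in convert.py.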