diff --git a/detic/predictor.py b/detic/predictor.py index 318205a..7d116c7 100644 --- a/detic/predictor.py +++ b/detic/predictor.py @@ -5,6 +5,7 @@ from collections import deque import cv2 import torch +import random from detectron2.data import MetadataCatalog from detectron2.engine.defaults import DefaultPredictor @@ -37,7 +38,7 @@ def get_clip_embeddings(vocabulary, prompt='a '): } class VisualizationDemo(object): - def __init__(self, cfg, args, + def __init__(self, cfg, args, instance_mode=ColorMode.IMAGE, parallel=False): """ Args: @@ -50,10 +51,12 @@ def __init__(self, cfg, args, self.metadata = MetadataCatalog.get("__unused") self.metadata.thing_classes = args.custom_vocabulary.split(',') classifier = get_clip_embeddings(self.metadata.thing_classes) + self._default_vocabulary = None else: self.metadata = MetadataCatalog.get( BUILDIN_METADATA_PATH[args.vocabulary]) classifier = BUILDIN_CLASSIFIER[args.vocabulary] + self._default_vocabulary = args.vocabulary num_classes = len(self.metadata.thing_classes) self.cpu_device = torch.device("cpu") @@ -67,6 +70,26 @@ def __init__(self, cfg, args, self.predictor = DefaultPredictor(cfg) reset_cls_test(self.predictor.model, classifier, num_classes) + def change_vocabulary(self, vocab): + """ + Args: + vocab (str): The comma separated string of vocabulary + """ + self.metadata = MetadataCatalog.get("__unused+"+str(random.random())) + self.metadata.thing_classes = vocab.split(',') + classifier = get_clip_embeddings(self.metadata.thing_classes) + num_classes = len(self.metadata.thing_classes) + reset_cls_test(self.predictor.model, classifier, num_classes) + + def set_defalt_vocabulary(self): + if not self._default_vocabulary: + raise RuntimeError("The VisualizationDemo is not initalized with buildin vocabulary") + self.metadata = MetadataCatalog.get( + BUILDIN_METADATA_PATH[self._default_vocabulary]) + classifier = BUILDIN_CLASSIFIER[self._default_vocabulary] + num_classes = len(self.metadata.thing_classes) + reset_cls_test(self.predictor.model, classifier, num_classes) + def run_on_image(self, image): """ Args: