diff --git a/mmdet/apis/ote/apis/detection/config_utils.py b/mmdet/apis/ote/apis/detection/config_utils.py
index 38eb61e1825..8f5b88431b5 100644
--- a/mmdet/apis/ote/apis/detection/config_utils.py
+++ b/mmdet/apis/ote/apis/detection/config_utils.py
@@ -13,11 +13,13 @@
 # and limitations under the License.
 
 import copy
-import glob
 import os
 import tempfile
 from collections import defaultdict
+from typing import Any, Optional, List
+
 from mmcv import Config, ConfigDict
+from mmcv.runner import master_only
 from sc_sdk.entities.datasets import Dataset
 from sc_sdk.entities.label import Label
 from sc_sdk.logging import logger_factory
@@ -26,7 +28,8 @@
 
 from .configuration import OTEDetectionConfig
 
-logger = logger_factory.get_logger("OTEDetectionTask")
+
+logger = logger_factory.get_logger("OTEDetectionTask.config_utils")
 
 
 def patch_config(config: Config, work_dir: str, labels: List[Label], random_seed: Optional[int] = None):
@@ -89,10 +92,10 @@ def prepare_for_testing(config: Config, dataset: Dataset) -> Config:
 
 
 def prepare_for_training(config: Config, train_dataset: Dataset, val_dataset: Dataset,
-                         time_monitor: TimeMonitorCallback, learning_curves: defaultdict) -> Config:
+                         round_id: Any, time_monitor: TimeMonitorCallback, learning_curves: defaultdict) -> Config:
     config = copy.deepcopy(config)
-    prepare_work_dir(config)
+    prepare_work_dir(config, round_id)
     # config.data.test.ote_dataset = dataset.get_subset(Subset.TESTING)
     config.data.val.ote_dataset = val_dataset
@@ -124,7 +127,6 @@ def config_to_string(config: Config) -> str:
         config_copy.data.train.ote_dataset = None
     else:
         config_copy.data.train.dataset.ote_dataset = None
-    # config_copy.labels = [label.name for label in config.labels]
     return Config(config_copy).pretty_text
 
 
@@ -141,6 +143,7 @@ def config_from_string(config_string: str) -> Config:
         return Config.fromfile(temp_file.name)
 
 
+@master_only
 def save_config_to_file(config: Config):
     """ Dump the full config to a file. Filename is 'config.py', it is saved in the current work_dir. """
     filepath = os.path.join(config.work_dir, 'config.py')
@@ -149,16 +152,16 @@
         f.write(config_string)
 
 
-def prepare_work_dir(config: Config) -> str:
+def prepare_work_dir(config: Config, round_id: Any = 0) -> str:
     base_work_dir = config.work_dir
-    checkpoint_dirs = glob.glob(os.path.join(base_work_dir, "checkpoints_round_*"))
-    train_round_checkpoint_dir = os.path.join(base_work_dir, f"checkpoints_round_{len(checkpoint_dirs)}")
-    os.makedirs(train_round_checkpoint_dir)
+
+    train_round_checkpoint_dir = os.path.join(base_work_dir, f"checkpoints_round_{round_id}")
+    os.makedirs(train_round_checkpoint_dir, exist_ok=True)
     logger.info(f"Checkpoints and logs for this training run are stored in {train_round_checkpoint_dir}")
     config.work_dir = train_round_checkpoint_dir
     if 'meta' not in config.runner:
         config.runner.meta = ConfigDict()
-    config.runner.meta.exp_name = f"train_round_{len(checkpoint_dirs)}"
+    config.runner.meta.exp_name = f"train_round_{round_id}"
     # Save training config for debugging. It is saved in the checkpoint dir for this training round
     save_config_to_file(config)
     return train_round_checkpoint_dir
@@ -184,8 +187,6 @@ def set_data_classes(config: Config, label_names: List[str]):
         config.model.roi_head.bbox_head.num_classes = num_classes
     elif 'bbox_head' in config.model:
         config.model.bbox_head.num_classes = num_classes
-    # FIXME. ?
-    # self.config.model.CLASSES = label_names
 
 
 def patch_datasets(config: Config):
diff --git a/mmdet/apis/ote/apis/detection/configuration.py b/mmdet/apis/ote/apis/detection/configuration.py
index 70bbf7aa29e..e54d75f790a 100644
--- a/mmdet/apis/ote/apis/detection/configuration.py
+++ b/mmdet/apis/ote/apis/detection/configuration.py
@@ -115,17 +115,5 @@ class __Postprocessing(ParameterGroup):
             affects_outcome_of=ModelLifecycle.INFERENCE
         )
 
-    @attrs
-    class __AlgoBackend(ParameterGroup):
-        header = string_attribute("Internal Algo Backend parameters")
-        description = header
-        visible_in_ui = boolean_attribute(False)
-
-        template = string_attribute("template.yaml")
-        model = string_attribute("model.py")
-        model_name = string_attribute("object detection model")
-        data_pipeline = string_attribute("ote_data_pipeline.py")
-
     learning_parameters = add_parameter_group(__LearningParameters)
-    algo_backend = add_parameter_group(__AlgoBackend)
     postprocessing = add_parameter_group(__Postprocessing)
diff --git a/mmdet/apis/ote/apis/detection/openvino_task.py b/mmdet/apis/ote/apis/detection/openvino_task.py
index ffa05a7dccf..e910d29223f 100644
--- a/mmdet/apis/ote/apis/detection/openvino_task.py
+++ b/mmdet/apis/ote/apis/detection/openvino_task.py
@@ -33,6 +33,8 @@
 from sc_sdk.usecases.tasks.interfaces.evaluate_interface import IEvaluationTask
 from sc_sdk.usecases.tasks.interfaces.inference_interface import IInferenceTask
 
+from mmcv.runner import master_only
+
 from .configuration import OTEDetectionConfig
 
 
@@ -165,6 +167,7 @@ def __init__(self, task_environment: TaskEnvironment):
         self.model = self.task_environment.model
         self.inferencer = self.load_inferencer()
 
+    @master_only
     def load_inferencer(self) -> OpenVINODetectionInferencer:
         labels = self.task_environment.label_schema.get_labels(include_empty=False)
         return OpenVINODetectionInferencer(self.hparams,
@@ -172,12 +175,14 @@ def load_inferencer(self) -> OpenVINODetectionInferencer:
                                            self.model.get_data("openvino.xml"),
                                            self.model.get_data("openvino.bin"))
 
+    @master_only
     def infer(self, dataset: Dataset, inference_parameters: Optional[InferenceParameters] = None) -> Dataset:
         from tqdm import tqdm
         for dataset_item in tqdm(dataset):
             dataset_item.annotation_scene = self.inferencer.predict(dataset_item.numpy)
         return dataset
 
+    @master_only
     def evaluate(self,
                  output_result_set: ResultSet,
                  evaluation_metric: Optional[str] = None):
diff --git a/mmdet/apis/ote/apis/detection/task.py b/mmdet/apis/ote/apis/detection/task.py
index 558c30ea9fa..cea6930c00c 100644
--- a/mmdet/apis/ote/apis/detection/task.py
+++ b/mmdet/apis/ote/apis/detection/task.py
@@ -19,17 +19,18 @@
 import shutil
 import tempfile
 import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
 import warnings
 from collections import defaultdict
-from mmcv.parallel import MMDataParallel
-from mmcv.runner import load_checkpoint
+from mmcv.parallel import MMDistributedDataParallel
+from mmcv.runner import load_checkpoint, get_dist_info, init_dist, master_only
 from mmcv.utils import Config
 from ote_sdk.configuration.helper.utils import ids_to_strings
 from ote_sdk.entities.inference_parameters import InferenceParameters
 from ote_sdk.entities.label import ScoredLabel
-from ote_sdk.entities.metrics import (CurveMetric, InfoMetric, LineChartInfo,
-                                      MetricsGroup, Performance, ScoreMetric,
-                                      VisualizationInfo, VisualizationType)
+from ote_sdk.entities.metrics import (CurveMetric, LineChartInfo,
+                                      MetricsGroup, Performance, ScoreMetric)
 from ote_sdk.entities.shapes.box import Box
 from ote_sdk.entities.train_parameters import TrainParameters
 from sc_sdk.configuration import cfg_helper
@@ -50,7 +51,7 @@
 from sc_sdk.usecases.tasks.interfaces.unload_interface import IUnload
 from typing import List, Optional, Tuple
 
-from mmdet.apis import export_model, single_gpu_test, train_detector
+from mmdet.apis import export_model, multi_gpu_test, single_gpu_test, train_detector
 from mmdet.apis.ote.apis.detection.config_utils import (patch_config,
                                                         prepare_for_testing,
                                                         prepare_for_training,
@@ -64,6 +65,15 @@
 logger = logger_factory.get_logger("OTEDetectionTask")
 
 
+def init_dist_cpu(launcher, backend, **kwargs):
+    if mp.get_start_method(allow_none=True) is None:
+        mp.set_start_method('spawn')
+    if launcher == 'pytorch':
+        dist.init_process_group(backend=backend, **kwargs)
+    else:
+        raise ValueError(f'Invalid launcher type: {launcher}')
+
+
 class OTEDetectionTask(ITrainingTask, IInferenceTask, IExportTask, IEvaluationTask, IUnload):
 
     task_environment: TaskEnvironment
@@ -74,28 +84,42 @@ def __init__(self, task_environment: TaskEnvironment):
         """
         logger.info(f"Loading OTEDetectionTask.")
+
         self._scratch_space = tempfile.mkdtemp(prefix="ote-det-scratch-")
         logger.info(f"Scratch space created at {self._scratch_space}")
 
         self._task_environment = task_environment
 
         self._hyperparams = hyperparams = task_environment.get_hyper_parameters(OTEDetectionConfig)
 
-        self._model_name = hyperparams.algo_backend.model_name
         self._labels = task_environment.get_labels(False)
 
+        if not torch.distributed.is_initialized():
+            os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
+            os.environ.setdefault("MASTER_PORT", "29500")
+            os.environ.setdefault("WORLD_SIZE", "1")
+            os.environ.setdefault("RANK", "0")
+            if torch.cuda.is_available():
+                init_dist(launcher='pytorch')
+            else:
+                init_dist_cpu(launcher='pytorch', backend="gloo")
+        self._rank, world_size = get_dist_info()
+        logger.warning(f'World size {world_size}, rank {self._rank}')
+
         template_file_path = task_environment.model_template.model_template_path
 
         # Get and prepare mmdet config.
         base_dir = os.path.abspath(os.path.dirname(template_file_path))
-        config_file_path = os.path.join(base_dir, hyperparams.algo_backend.model)
+        config_file_path = os.path.join(base_dir, 'model.py')
         self._config = Config.fromfile(config_file_path)
         patch_config(self._config, self._scratch_space, self._labels, random_seed=42)
         set_hyperparams(self._config, hyperparams)
+        self._config.gpu_ids = range(world_size)
 
         # Create and initialize PyTorch model.
         self._model = self._load_model(task_environment.model)
 
         # Extra control variables.
+        self._training_round_id = 0
         self._training_work_dir = None
         self._is_training = False
         self._should_stop = False
@@ -113,7 +137,6 @@ def _load_model(self, model: Model):
             try:
                 model.load_state_dict(model_data['model'])
                 logger.info(f"Loaded model weights from Task Environment")
-                logger.info(f"Model architecture: {self._model_name}")
             except BaseException as ex:
                 raise ValueError("Could not load the saved model. The model file structure is invalid.") \
                     from ex
@@ -121,8 +144,7 @@
             # If there is no trained model yet, create model with pretrained weights as defined in the model config
            # file.
             model = self._create_model(self._config, from_scratch=False)
-            logger.info(f"No trained model in project yet. Created new model with '{self._model_name}' "
-                        f"architecture and general-purpose pretrained weights.")
+            logger.info(f"No trained model in project yet. Created new model with general-purpose pretrained weights.")
         return model
 
     @staticmethod
@@ -163,38 +185,37 @@ def infer(self, dataset: Dataset, inference_parameters: Optional[InferenceParame
         prediction_results, _ = self._infer_detector(self._model, self._config, dataset, False)
 
-        # Loop over dataset again to assign predictions. Convert from MMDetection format to OTE format
-        for dataset_item, output in zip(dataset, prediction_results):
-            width = dataset_item.width
-            height = dataset_item.height
+        if self._rank == 0:
+            # Loop over dataset again to assign predictions. Convert from MMDetection format to OTE format
+            for dataset_item, output in zip(dataset, prediction_results):
+                width = dataset_item.width
+                height = dataset_item.height
 
-            shapes = []
-            for label_idx, detections in enumerate(output):
-                for i in range(detections.shape[0]):
-                    probability = float(detections[i, 4])
-                    coords = detections[i, :4].astype(float).copy()
-                    coords /= np.array([width, height, width, height], dtype=float)
-                    coords = np.clip(coords, 0, 1)
+                shapes = []
+                for label_idx, detections in enumerate(output):
+                    for i in range(detections.shape[0]):
+                        probability = float(detections[i, 4])
+                        coords = detections[i, :4].astype(float).copy()
+                        coords /= np.array([width, height, width, height], dtype=float)
+                        coords = np.clip(coords, 0, 1)
 
-                    if probability < confidence_threshold:
-                        continue
+                        if probability < confidence_threshold:
+                            continue
 
-                    assigned_label = [ScoredLabel(self._labels[label_idx],
-                                                  probability=probability)]
-                    if coords[3] - coords[1] <= 0 or coords[2] - coords[0] <= 0:
-                        continue
+                        assigned_label = [ScoredLabel(self._labels[label_idx], probability=probability)]
+                        if coords[3] - coords[1] <= 0 or coords[2] - coords[0] <= 0:
+                            continue
 
-                    shapes.append(Annotation(
-                        Box(x1=coords[0], y1=coords[1], x2=coords[2], y2=coords[3]),
-                        labels=assigned_label))
+                        shapes.append(Annotation(
+                            Box(x1=coords[0], y1=coords[1], x2=coords[2], y2=coords[3]),
+                            labels=assigned_label))
 
-            dataset_item.append_annotations(shapes)
+                dataset_item.append_annotations(shapes)
 
         return dataset
 
-    @staticmethod
-    def _infer_detector(model: torch.nn.Module, config: Config, dataset: Dataset,
+    def _infer_detector(self, model: torch.nn.Module, config: Config, dataset: Dataset,
                         eval: Optional[bool] = False, metric_name: Optional[str] = 'mAP') -> Tuple[List, float]:
         model.eval()
         test_config = prepare_for_testing(config, dataset)
@@ -204,22 +225,26 @@ def _infer_detector(model: torch.nn.Module, config: Config, dataset: Dataset,
                                              samples_per_gpu=batch_size,
                                              workers_per_gpu=test_config.data.workers_per_gpu,
                                              num_gpus=1,
-                                             dist=False,
+                                             dist=True,
                                              shuffle=False)
+
         if torch.cuda.is_available():
-            eval_model = MMDataParallel(model.cuda(test_config.gpu_ids[0]),
-                                        device_ids=test_config.gpu_ids)
+            model = MMDistributedDataParallel(
+                model.cuda(),
+                device_ids=[torch.cuda.current_device()],
+                broadcast_buffers=False)
+            eval_predictions = multi_gpu_test(model, mm_val_dataloader)
         else:
-            eval_model = MMDataCPU(model)
-        # Use a single gpu for testing. Set in both mm_val_dataloader and eval_model
-        eval_predictions = single_gpu_test(eval_model, mm_val_dataloader, show=False)
+            model = MMDataCPU(model)
+            eval_predictions = single_gpu_test(model, mm_val_dataloader, show=False)
 
         metric = None
-        if eval:
+        if eval and self._rank == 0:
             metric = mm_val_dataset.evaluate(eval_predictions, metric=metric_name)[metric_name]
         return eval_predictions, metric
 
+    @master_only
     def evaluate(self,
                  output_result_set: ResultSet,
                  evaluation_metric: Optional[str] = None):
@@ -253,6 +278,7 @@ def train(self, dataset: Dataset, output_model: Model, train_parameters: Optiona
         """ Trains a model on a dataset """
 
         set_hyperparams(self._config, self._hyperparams)
+        self._training_round_id += 1
 
         train_dataset = dataset.get_subset(Subset.TRAINING)
         val_dataset = dataset.get_subset(Subset.VALIDATION)
@@ -282,12 +308,13 @@ def train(self, dataset: Dataset, output_model: Model, train_parameters: Optiona
         # Run training.
         self._time_monitor = TimeMonitorCallback(0, 0, 0, 0, update_progress_callback=lambda _: None)
         learning_curves = defaultdict(OTELoggerHook.Curve)
-        training_config = prepare_for_training(config, train_dataset, val_dataset, self._time_monitor, learning_curves)
+        training_config = prepare_for_training(config, train_dataset, val_dataset,
+                                               self._training_round_id, self._time_monitor, learning_curves)
         self._training_work_dir = training_config.work_dir
         mm_train_dataset = build_dataset(training_config.data.train)
         self._is_training = True
         self._model.train()
-        train_detector(model=self._model, dataset=mm_train_dataset, cfg=training_config, validate=True)
+        train_detector(model=self._model, dataset=mm_train_dataset, cfg=training_config, distributed=True, validate=True)
 
         # Check for stop signal when training has stopped. If should_stop is true, training was cancelled and no new
         # model should be returned. Old train model is restored.
@@ -296,15 +323,10 @@
             self._model = old_model
             self._should_stop = False
             self._is_training = False
+            self._training_work_dir = None
             self._time_monitor = None
             return
 
-        # Load the best weights and check if model has improved.
-        training_metrics = self._generate_training_metrics_group(learning_curves)
-        best_checkpoint_path = os.path.join(training_config.work_dir, 'latest.pth')
-        best_checkpoint = torch.load(best_checkpoint_path)
-        self._model.load_state_dict(best_checkpoint['state_dict'])
-
         # Evaluate model performance after training.
         _, final_performance = self._infer_detector(self._model, config, val_dataset, True)
         improved = final_performance > initial_performance
@@ -316,6 +338,7 @@
         else:
             logger.info("First training round, saving the model.")
         # Add mAP metric and loss curves
+        training_metrics = self._generate_training_metrics_group(learning_curves)
         performance = Performance(score=ScoreMetric(value=final_performance, name="mAP"),
                                   dashboard_metrics=training_metrics)
         logger.info('FINAL MODEL PERFORMANCE\n' + str(performance))
@@ -328,9 +351,11 @@
             self._model = old_model
 
         self._is_training = False
+        self._training_work_dir = None
         self._time_monitor = None
 
+    @master_only
     def save_model(self, output_model: Model):
         buffer = io.BytesIO()
         hyperparams = self._task_environment.get_hyper_parameters(OTEDetectionConfig)
@@ -352,6 +377,7 @@ def get_training_progress(self) -> float:
 
         return -1.0
 
+    @master_only
     def cancel_training(self):
         """
         Sends a cancel training signal to gracefully stop the optimizer. The signal consists of creating a
@@ -373,13 +399,6 @@ def _generate_training_metrics_group(self, learning_curves) -> Optional[List[Met
         """
         output: List[MetricsGroup] = []
 
-        # Model architecture
-        architecture = InfoMetric(name='Model architecture', value=self._model_name)
-        visualization_info_architecture = VisualizationInfo(name="Model architecture",
-                                                             visualisation_type=VisualizationType.TEXT)
-        output.append(MetricsGroup(metrics=[architecture],
-                                   visualization_info=visualization_info_architecture))
-
         # Learning curves
         for key, curve in learning_curves.items():
             metric_curve = CurveMetric(xs=curve.x, ys=curve.y, name=key)
@@ -436,20 +455,21 @@ def unload(self):
             ctypes.string_at(0)
         else:
             logger.warning("Got unload request, but not on Docker. Only clearing CUDA cache")
-            torch.cuda.empty_cache()
-            logger.warning(f"Done unloading. "
-                           f"Torch is still occupying {torch.cuda.memory_allocated()} bytes of GPU memory")
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+                logger.warning(f"CUDA cache is cleared. "
" + "Torch is still occupying {torch.cuda.memory_allocated()} bytes of GPU memory") + logger.warning("Done unloading.") + @master_only def export(self, export_type: ExportType, output_model: OptimizedModel): assert export_type == ExportType.OPENVINO optimized_model_precision = ModelPrecision.FP32 - with tempfile.TemporaryDirectory() as tempdir: - optimized_model_dir = os.path.join(tempdir, "export") - logger.info(f'Optimized model will be temporarily saved to "{optimized_model_dir}"') - os.makedirs(optimized_model_dir, exist_ok=True) + with tempfile.TemporaryDirectory(prefix="ote-det-export-") as tempdir: + logger.info(f'Optimized model will be temporarily saved to "{tempdir}"') try: from torch.jit._trace import TracerWarning warnings.filterwarnings("ignore", category=TracerWarning) @@ -470,6 +490,7 @@ def export(self, raise RuntimeError("Optimization was unsuccessful.") from ex + @master_only def _delete_scratch_space(self): """ Remove model checkpoints and mmdet logs diff --git a/mmdet/apis/ote/sample/sample.py b/mmdet/apis/ote/sample/sample.py index d04973b7b5c..a8e478d85e9 100644 --- a/mmdet/apis/ote/sample/sample.py +++ b/mmdet/apis/ote/sample/sample.py @@ -15,6 +15,12 @@ import argparse import os.path as osp import sys +import warnings + +warnings.filterwarnings('ignore', category=DeprecationWarning, message='.*cElementTree is deprecated.*') +warnings.filterwarnings('ignore', category=UserWarning, message='.*Nevergrad package could not be imported.*') +warnings.filterwarnings('ignore', category=UserWarning, message='.*This overload of nonzero is deprecated.*') + from ote_sdk.configuration.helper import create from ote_sdk.entities.inference_parameters import InferenceParameters from sc_sdk.entities.dataset_storage import NullDatasetStorage @@ -37,7 +43,9 @@ reload_hyper_parameters) from mmdet.apis.ote.extension.datasets.mmdataset import MMDatasetAdapter -logger = logger_factory.get_logger('Sample') +logger = logger_factory.get_logger('OTEDetectionSample') +import logging +logger.setLevel(logging.INFO) def parse_args(): diff --git a/tests/test_ote_api.py b/tests/test_ote_api.py index 17787a68e64..7302fa4794e 100644 --- a/tests/test_ote_api.py +++ b/tests/test_ote_api.py @@ -70,7 +70,6 @@ def test_configuration_yaml(): configuration_yaml_converted = yaml.safe_load(configuration_yaml_str) with open(osp.join('mmdet', 'apis', 'ote', 'apis', 'detection', 'configuration.yaml')) as read_file: configuration_yaml_loaded = yaml.safe_load(read_file) - del configuration_yaml_converted['algo_backend'] assert configuration_yaml_converted == configuration_yaml_loaded def test_set_values_as_default():